
Reinforcing Few-step Generators via Reward-Tilted Distribution Matching

Reward-Tilted DMD  ·  Ambient-Consistent Distillation  ·  Hybrid Policy Gradient



Yushi Huang¹·²*, Xiangxin Zhou¹*, Ruoyu Wang¹·³*, Chi Zhang³, Jun Zhang², Tianyu Pang¹

¹Tencent Hunyuan    ²The Hong Kong University of Science and Technology    ³Westlake University

* Equal contribution  ·  † Work done during internship at Tencent Hunyuan  ·  ‡ Corresponding author


📖 Abstract

We propose Reward-Tilted Distribution Matching Distillation (RTDMD), a two-stage framework that unifies distribution-matching distillation with reward-guided RL for few-step flow generators. Minimizing the KL divergence to a reward-tilted teacher distribution decomposes naturally into a distribution-matching term and a reward-maximization term, instantiated as Ambient-Consistent DMD (AC-DMD) for the cold-start stage and as a hybrid policy gradient (SubGRPO + final-step reward back-propagation) for the RL stage. With 4 NFE, RTDMD sets a new state of the art on SD3-M, SD3.5-M, and FLUX.2 4B; the distilled FLUX.2 4B even beats the full FLUX.2 9B teacher (50 NFE) on most rewards.

4-step samples from RTDMD-distilled FLUX.2 4B (no classifier-free guidance).
Qualitative comparison for few-step diffusion models (4 NFE).

🍭 Method Overview

RTDMD overview. Det. = deterministic final step, Stoc. = stochastic intermediate steps. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).

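For the generator $G_\theta$ with output distribution $p_\theta$, teacher distribution $p_\psi$, reward-tilted teacher $\tilde{p}_\psi$, reward $r$, and reward weight $\beta$, the gradient of the reward-tilted KL objective decomposes as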

$$
\nabla_\theta D_{\text{KL}}(p_\theta \| \tilde{p}_\psi) = \underbrace{\nabla_\theta D_{\text{KL}}(p_\theta \| p_\psi)}_{\text{distribution matching}} - \beta\,\underbrace{\nabla_\theta \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}\big[r(\hat{\mathbf{x}}_0)\big]}_{\text{reward maximization}}.
$$
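A short derivation recovers this decomposition. It assumes (our assumption, not stated here) that the tilted teacher has the exponential form $\tilde{p}_\psi(\mathbf{x}) = p_\psi(\mathbf{x})\, e^{\beta r(\mathbf{x})} / Z$, which is the form that produces the coefficient $\beta$ above, with a normalizer $Z$ that does not depend on $\theta$:

```latex
\begin{aligned}
D_{\text{KL}}(p_\theta \| \tilde{p}_\psi)
  &= \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}\big[\log p_\theta(\hat{\mathbf{x}}_0) - \log p_\psi(\hat{\mathbf{x}}_0) - \beta\, r(\hat{\mathbf{x}}_0) + \log Z\big] \\
  &= D_{\text{KL}}(p_\theta \| p_\psi) - \beta\, \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}\big[r(\hat{\mathbf{x}}_0)\big] + \log Z .
\end{aligned}
```

Because $\log Z$ is constant in $\theta$, taking $\nabla_\theta$ of the last line yields exactly the two terms of the decomposition.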

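The cold-start stage optimizes the distribution-matching term alone; the RL stage adds the reward-maximization term while keeping AC-DMD. Each stage maps to a trainer exposed by the CLI: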

| Stage | Trainer | Key knobs |
|---|---|---|
| 1. AC-DMD cold start | ACDMDTrainer (--trainer ac_dmd) | sub-interval renoising, consistency weight γ, CPS sampler η = 0.9 |
| 2. RTDMD RL fine-tune | RTDMDTrainer (--trainer rtdmd) | SubGRPO + final-step BP + AC-DMD |
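The RL stage relies on SubGRPO. As background, the sketch below shows the standard GRPO group-relative advantage that SubGRPO presumably builds on, plus a hypothetical `generator_objective` that switches between the two stages' objectives following the decomposition above. Both functions are illustrative assumptions, not code from the RTDMD repo; the SubGRPO-specific details are not reproduced here.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Standard GRPO advantage: normalize each sample's reward by the mean and
    (population) standard deviation of its prompt group. eps avoids division by
    zero when every sample in the group gets the same reward."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


def generator_objective(dm_loss, reward_term, beta, stage):
    """Hypothetical stage switch: stage 1 (AC-DMD) minimizes only the
    distribution-matching loss; stage 2 (RTDMD) subtracts beta * reward,
    mirroring D_KL(p_theta || p_psi) - beta * E[r]."""
    return dm_loss if stage == 1 else dm_loss - beta * reward_term


# Example: four samples generated for one prompt
print(group_relative_advantages([0.2, 0.4, 0.6, 0.8]))
```

Group normalization makes the advantage scale-free across prompts whose rewards live on different ranges, which is the usual motivation for GRPO over a learned value baseline.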

📦 Contents

This repository hosts the 4-NFE LoRA checkpoints distilled from FLUX.2-klein 4B with RTDMD.

.
├── cold_start/
│   └── generator_ema.pt    # Stage-1 AC-DMD LoRA (4 NFE base)
└── rtdmd/
    └── generator_ema.pt    # Stage-2 RTDMD LoRA (stacked on top of cold_start)

Each generator_ema.pt is a state_dict saved with torch.save that contains only LoRA adapter weights (lora_A / lora_B, rank 32, alpha 64). The two adapters are designed to be stacked: the cold-start LoRA distills FLUX.2-klein 4B down to 4 NFE, and the RTDMD LoRA further fine-tunes that distilled model with reward-tilted RL.
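For intuition on rank 32 / alpha 64 and on stacking, here is a minimal pure-Python sketch. It assumes the standard LoRA update W' = W + (alpha / r) · B A (so the scale is 64 / 32 = 2 here) and treats stacking as adding two independent deltas to the same base weight. The helpers `matmul`, `lora_delta`, `add`, and `infer_rank` are hypothetical illustrations, not part of this repository.

```python
def matmul(a, b):
    """Plain nested-list matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def lora_delta(lora_a, lora_b, rank, alpha):
    """Standard LoRA update (alpha / rank) * B @ A; B is (out, r), A is (r, in)."""
    scale = alpha / rank
    return [[scale * x for x in row] for row in matmul(lora_b, lora_a)]


def add(w, delta):
    """Element-wise w + delta."""
    return [[x + d for x, d in zip(rw, rd)] for rw, rd in zip(w, delta)]


def infer_rank(shapes):
    """Return the set of ranks read from lora_A shapes (r, in_features).
    `shapes` maps key -> shape tuple, e.g. {k: tuple(v.shape) for k, v in state.items()}."""
    non_lora = [k for k in shapes if "lora_A" not in k and "lora_B" not in k]
    assert not non_lora, f"unexpected non-LoRA keys: {non_lora}"
    return {s[0] for k, s in shapes.items() if "lora_A" in k}


W = [[1.0, 0.0], [0.0, 1.0]]
A1, B1 = [[1.0, 0.0]], [[1.0], [0.0]]   # toy rank-1 stand-in for the cold-start adapter
A2, B2 = [[0.0, 1.0]], [[0.0], [1.0]]   # toy rank-1 stand-in for the RTDMD adapter
# Real checkpoints use alpha / r = 64 / 32 = 2; reuse that scale with r = 1, alpha = 2
stacked = add(add(W, lora_delta(A1, B1, 1, 2)), lora_delta(A2, B2, 1, 2))
print(stacked)  # [[3.0, 0.0], [0.0, 3.0]]
```

Because the two deltas are simply summed, applying them in either order gives the same merged weight under this additive assumption.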


🚀 Usage

Option 1 — RTDMD inference CLI (recommended)

The simplest path is to clone the RTDMD repo and let it stack both LoRAs and run the CPS sampler for you:

git clone https://github.com/Harahan/RTDMD.git && cd RTDMD
pip install -r requirements.txt && pip install -e .

# Download this repo
huggingface-cli download Harahan/FLUX2-4B-RTDMD --local-dir ./ckpts/flux2_4b

# Run 4-NFE inference (single GPU)
python inference.py configs/inference/flux2_4b.yaml \
    --override lora_paths='["./ckpts/flux2_4b/cold_start/generator_ema.pt","./ckpts/flux2_4b/rtdmd/generator_ema.pt"]' \
    --override eval_reward=false \
    --prompt "a cute cat sitting on a windowsill"
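To drive the same command from Python (for example, looping over many prompts), a small helper can build the argument list. `build_inference_argv` is a hypothetical helper, not part of the RTDMD repo; it uses only the config path and the --override / --prompt flags shown above, and assumes it runs from inside the cloned RTDMD directory.

```python
import json
import shlex


def build_inference_argv(prompt, lora_paths,
                         config="configs/inference/flux2_4b.yaml",
                         eval_reward=False):
    """Build the argv of the shell example above (hypothetical helper)."""
    # Compact JSON list, matching the quoted lora_paths value in the shell example
    lora_list = json.dumps(list(lora_paths), separators=(",", ":"))
    return [
        "python", "inference.py", config,
        "--override", f"lora_paths={lora_list}",
        "--override", f"eval_reward={str(eval_reward).lower()}",
        "--prompt", prompt,
    ]


CKPTS = ["./ckpts/flux2_4b/cold_start/generator_ema.pt",
         "./ckpts/flux2_4b/rtdmd/generator_ema.pt"]

argv = build_inference_argv("a cute cat sitting on a windowsill", CKPTS)
print(shlex.join(argv))  # shell-quoted, copy-pasteable command
# To run it: subprocess.run(argv, check=True)
```

Passing a list to subprocess avoids the shell quoting that the lora_paths value otherwise needs.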

Option 2 — Plain diffusers

import torch
from diffusers import Flux2KleinPipeline
from huggingface_hub import hf_hub_download
from peft import LoraConfig

base = "black-forest-labs/FLUX.2-klein-4B"
pipe = Flux2KleinPipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda")

# Inject LoRA adapters with the rank/alpha used during training
TARGETS = [
    "to_q", "to_k", "to_v", "to_out.0",
    "add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
    "to_qkv_mlp_proj",
] + [f"single_transformer_blocks.{i}.attn.to_out" for i in range(20)]
pipe.transformer.add_adapter(
    LoraConfig(r=32, lora_alpha=64, target_modules=TARGETS, init_lora_weights="gaussian")
)

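# Sequentially load cold-start then RTDMD weights into the same adapter.
# load_state_dict copies values in place, so keys present in both files end up
# holding the RTDMD values.
for ckpt in ["cold_start/generator_ema.pt", "rtdmd/generator_ema.pt"]:
    path = hf_hub_download("Harahan/FLUX2-4B-RTDMD", ckpt)
    # Plain tensor state_dict, so the safer weights_only=True loader is sufficient
    state = torch.load(path, map_location="cpu", weights_only=True)
    result = pipe.transformer.load_state_dict(state, strict=False)
    # strict=False silently skips keys that match no parameter; fail loudly instead
    if result.unexpected_keys:
        raise RuntimeError(f"{ckpt}: {len(result.unexpected_keys)} keys did not match the model")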

# 4-step sampling with the pipeline's default scheduler (not CPS; see the note below)
pipe(prompt="a cute cat sitting on a windowsill",
     num_inference_steps=4, guidance_scale=1.0).images[0].save("out.png")

Note: RTDMD is trained with the CPS (Coefficients-Preserving Sampling) scheduler at η = 0.9. The default Flow-Matching Euler scheduler still produces reasonable samples at 4 NFE, but only the RTDMD inference CLI reproduces the paper's numbers exactly.


📄 Citation

@misc{huang2026reinforcingfewstepgeneratorsrewardtilted,
      title={Reinforcing Few-step Generators via Reward-Tilted Distribution Matching}, 
      author={Yushi Huang and Xiangxin Zhou and Ruoyu Wang and Chi Zhang and Jun Zhang and Tianyu Pang},
      year={2026},
      eprint={2605.26108},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2605.26108}, 
}

⚖️ License

Apache 2.0 — same as the upstream RTDMD repo. The base model black-forest-labs/FLUX.2-klein-4B is governed by its own license; please review and comply with it separately.
