Stage 1: Masked SFT on High-Quality Reasoning Traces
Stage 2: Efficient Policy Gradient Algorithm for dLLMs - diffu-GRPO
Adapting RL algorithms to masked dLLMs poses unique challenges: existing approaches for AR models (PPO and GRPO) rely on computing the log-probabilities of generated sequences, a quantity that is not directly available for dLLMs. Whereas AR models factorize the sequence likelihood token by token via the chain rule, dLLMs lack this natural decomposition because they generate through an iterative, non-sequential unmasking process.
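To make the gap concrete, the quantity PPO and GRPO rely on for AR models is the chain-rule sum below; this is a standard identity written here for illustration, with q a prompt and o a sampled completion (notation assumed, not taken from this section):

```latex
% Exact sequence log-probability available to an AR policy via the chain rule:
\log \pi_\theta(o \mid q) \;=\; \sum_{k=1}^{|o|} \log \pi_\theta\big(o_k \mid q,\, o_{<k}\big)
```

A masked dLLM produces o by iteratively unmasking positions in parallel, so no closed-form left-to-right decomposition of log π_θ(o | q) exists, which is what motivates the estimator below.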
To address this, we propose an efficient log-probability estimator based on a mean-field approximation of the sequence log-probability: the sequence-level log-probability is decomposed into a sum of independent per-token terms, and each term is obtained via one-step per-token log-probability estimation with random prompt masking.
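A minimal sketch of this estimator is given below, assuming the mean-field decomposition log π_θ(o | q) ≈ Σ_k log π_θ(o_k | q). The interface is hypothetical: `mask_predictor` stands in for the dLLM's denoiser, `mask_id` for its mask token, and `prompt_mask_prob` for the prompt-masking ratio; none of these names come from the released code.

```python
import torch
import torch.nn.functional as F

def one_step_per_token_logprobs(mask_predictor, prompt_ids, completion_ids,
                                mask_id, prompt_mask_prob=0.15):
    """Estimate per-token log-probabilities of a completion in one forward pass.

    `mask_predictor` is a hypothetical denoiser: given a (partially masked)
    token sequence of shape (seq_len,), it returns per-position vocabulary
    logits of shape (seq_len, vocab_size). All completion tokens are masked,
    prompt tokens are masked at random with probability `prompt_mask_prob`,
    and a single forward pass yields the per-token estimates used in the
    mean-field approximation of the sequence log-probability.
    """
    # Randomly mask a fraction of prompt tokens (stochastic prompt masking).
    keep = torch.rand(prompt_ids.shape) >= prompt_mask_prob
    masked_prompt = torch.where(keep, prompt_ids, torch.full_like(prompt_ids, mask_id))

    # Mask every completion token; the model predicts all of them simultaneously.
    masked_completion = torch.full_like(completion_ids, mask_id)
    model_input = torch.cat([masked_prompt, masked_completion], dim=-1)

    # Single forward pass over the masked sequence.
    logits = mask_predictor(model_input)                # (seq_len, vocab_size)
    completion_logits = logits[prompt_ids.shape[-1]:]   # completion positions only
    log_probs = F.log_softmax(completion_logits, dim=-1)

    # Log-probability of each realized completion token.
    per_token = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
    # Mean-field approximation: log pi(o | q) ~= sum_k log pi(o_k | q).
    return per_token, per_token.sum()

# Toy usage with a uniform-logit "denoiser" (illustration only).
vocab_size, mask_id = 32, 31
toy_denoiser = lambda ids: torch.zeros(ids.shape[-1], vocab_size)
prompt = torch.randint(0, mask_id, (8,))
completion = torch.randint(0, mask_id, (12,))
per_token_lp, seq_lp = one_step_per_token_logprobs(toy_denoiser, prompt, completion, mask_id)
```

Because every call draws a fresh random prompt mask, repeated calls on the same (prompt, completion) pair return slightly different estimates; this property is what the inner-update scheme below exploits.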
Using this estimator, we extend GRPO to masked dLLMs with the following objective:
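The objective is sketched below in standard GRPO form, written with the one-step per-token estimates π̂; the notation (group size G, advantages Â, clipping range ε, KL weight β, reference policy π_ref) follows the usual GRPO formulation and is assumed rather than quoted from this section:

```latex
\mathcal{J}_{\text{diffu-GRPO}}(\theta) =
\mathbb{E}_{\,q \sim \mathcal{D},\; \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\mathrm{old}}}(\cdot \mid q)}
\left[
  \frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{k=1}^{|o_i|}
  \min\!\Big( \rho_{i,k}\,\hat{A}_{i,k},\;
              \mathrm{clip}\big(\rho_{i,k},\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_{i,k} \Big)
  \;-\; \beta\, D_{\mathrm{KL}}\!\left( \pi_\theta \,\Vert\, \pi_{\mathrm{ref}} \right)
\right],
\qquad
\rho_{i,k} = \frac{\hat{\pi}_\theta(o_{i,k} \mid q)}{\hat{\pi}_{\theta_{\mathrm{old}}}(o_{i,k} \mid q)}
```

Here π̂ denotes the one-step per-token estimate described above, so the importance ratio ρ is computed from single-forward-pass estimates rather than exact AR likelihoods.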
On-policy RL algorithms typically perform multiple gradient updates per batch of samples, which requires balancing the number of outer batch iterations against the number of inner gradient updates. Our log-probability estimator introduces stochastic masking that yields a different perturbed view of the same (prompt, completion) pairs on each call, acting as a regularizer for policy optimization. This lets us scale the number of inner updates (μ) to larger values while maintaining stable learning dynamics, which reduces the number of outer batch iterations and online generations needed and thereby lowers computational cost substantially; see the sketch below.
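The outer/inner schedule can be sketched as follows. This is an illustrative loop, not the released training code: `policy.estimate_logprobs`, `sample_completions`, and `reward_fn` are hypothetical interfaces, and the estimator is assumed to redraw the prompt mask on every call.

```python
import torch

def train_diffu_grpo_style(policy, sample_completions, reward_fn, optimizer,
                           prompt_batches, num_outer_iters, mu=8,
                           prompt_mask_prob=0.15, eps=0.2):
    """Illustrative outer/inner update schedule (hypothetical interfaces).

    Completions are generated once per outer iteration; each of the `mu`
    inner gradient steps re-estimates log-probabilities under a fresh
    random prompt mask, so every inner step sees a different perturbed
    view of the same (prompt, completion) pairs.
    """
    for it in range(num_outer_iters):
        prompts = prompt_batches[it % len(prompt_batches)]

        # Expensive step: online generation happens once per outer iteration.
        completions = sample_completions(policy, prompts)            # (G, T) token ids
        rewards = reward_fn(prompts, completions)                    # (G,) scalars
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # Behavior-policy log-probs, estimated under one random prompt mask.
        with torch.no_grad():
            old_logps = policy.estimate_logprobs(prompts, completions,
                                                 prompt_mask_prob)   # (G, T)

        # Cheap steps: mu inner updates reuse the same completions; each call
        # to the estimator redraws the prompt mask, which regularizes training.
        for _ in range(mu):
            new_logps = policy.estimate_logprobs(prompts, completions,
                                                 prompt_mask_prob)   # (G, T)
            ratio = torch.exp(new_logps - old_logps)
            clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps)
            per_token = torch.minimum(ratio * advantages[:, None],
                                      clipped * advantages[:, None])
            loss = -per_token.mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```

The design point this sketch illustrates is the cost asymmetry: raising μ amortizes each round of online generation over more gradient steps, while the per-update re-masking keeps those extra steps from overfitting to a single view of the batch.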