Approximation of Log-Partition Function in Policy Mirror Descent Induces Implicit Regularization for LLM Post-Training
Key Summary
- The paper studies a simple way to train giant language models with reinforcement learning by replacing a hard-to-compute term (the log-partition function) with something easy: the mean reward.
- This simple replacement (called PMD-MEAN) secretly adds an extra safety belt (a chi-squared regularizer) on top of the usual KL regularizer, making updates more cautious when rewards are low.
- Mathematically, the ideal PMD-MEAN update uses the Lambert-W function, which gently squashes big probability jumps.
- Because updates are more conservative early on, PMD-MEAN stays stable even when using stale, off-policy rollouts and small regularization.
- Compared to directly fitting the partition-normalized target (PMD-PART), PMD-MEAN is less sensitive to noise from few rollouts, so it avoids overfitting bad estimates.
- On math reasoning benchmarks (AIME 2024/2025), PMD-MEAN beats strong baselines like GRPO and matches advanced variants like GSPO, while training faster via larger rollout batches.
- The paper provides theory (closed-form solution, convergence insights) that explains PMD-MEAN’s stability and why it works well in practice.
- This helps practitioners train LLMs more simply and robustly, especially in asynchronous or large-batch settings where off-policy data is common.
Why This Research Matters
Training LLMs with RL often uses off-policy data and few rollouts, which can make learning unstable. PMD-MEAN offers a simple swap—use mean reward instead of a hard normalizer—that secretly adds an adaptive safety belt via χ² regularization. This keeps probability updates reasonable, especially when most answers are wrong early on, avoiding collapse. It also reduces engineering complexity since you don’t need heavy importance sampling tricks to stay stable. In practice, this means faster, steadier training of reasoning-focused LLMs that solve more problems with fewer headaches. The theory explains why it works, giving teams confidence to adopt it in real systems.
Detailed Explanation
01 Background & Problem Definition
🍞 Hook: Imagine you’re teaching a huge class of robot students to solve math problems. You give them practice questions (prompts), they try answers (responses), and you cheer or boo (rewards). You want them to improve fast without suddenly forgetting how to write sentences or going wild with risky guesses.
🥬 The Concept: Before this paper, large language models (LLMs) were often improved with reinforcement learning (RL) using a method called policy mirror descent (PMD). PMD says: make the model prefer answers with higher rewards, but don’t let it drift too far from what it currently does well. This “don’t drift too far” is measured with KL-divergence, like a leash that prevents giant jumps. PMD has a beautiful closed-form update that reweights the current policy by how good each answer is, and then it normalizes everything using a special number called the partition function.
Why it matters: Without that normalization, probabilities wouldn’t add up to 1. But computing that partition function exactly is hard for LLMs because their action space (all possible long responses) is huge, and we only have a small set of sampled rollouts per prompt. If we try to estimate it from just a few samples, the estimate can be noisy and unstable, especially when the regularization is small (which engineers like because it learns faster).
🍞 Anchor: Think of making fruit punch for the class. You know how much each kid likes strawberry vs. lemon (rewards), so you add more strawberry if it’s liked more. But you must still make sure the total drink fits the pitcher (normalize). For a giant class, measuring the exact total across all kids’ tastes (partition function) is really hard from a few taste tests.
🍞 Hook: You know how sometimes the practice answers are generated by an older version of the model (stale), while you’re updating a newer model version? That’s called off-policy training, and it can cause mismatches, like grading last week’s homework to teach today’s class.
🥬 The Concept: Popular methods like PPO or GRPO try to fix off-policy mismatch using importance sampling and clipping—like adjusting grades by how likely that answer would be under the current model. It helps, but brings many dials to tune and can still be unstable with long responses and asynchronous rollouts.
Why it matters: If the mismatch is large and your fixes are imperfect, training can wobble or collapse.
🍞 Anchor: It’s like trying to steer a bike by looking in the rearview mirror. You can adjust, but it’s tricky and wobbly.
🍞 Hook: A surprisingly simple trick used in Kimi K1.5/K2 said, “What if we don’t try to estimate the exact normalizer? What if we just use the mean reward as a stand-in?”
🥬 The Concept: This leads to PMD-MEAN. Instead of fitting the exact ideal target that needs the log-partition function, PMD-MEAN sets the target in log-policy space to ‘advantage divided by temperature’ (advantage = reward minus mean reward), and just does regression. No explicit partition function!
Why it matters: This makes training much simpler and more stable in practice, especially when rollouts are limited or off-policy.
🍞 Anchor: It’s like saying, “I won’t measure everyone’s exact thirst to the last drop. I’ll just balance around the average thirst and keep going.”
🍞 Hook: But is this approximation accurate? And what does it actually optimize, mathematically?
🥬 The Concept: The paper answers this by fully characterizing PMD-MEAN’s solution. It shows that PMD-MEAN is equivalent to solving a mirror descent problem with a mixed regularizer: usual KL plus an extra chi-squared (χ²) term whose weight adapts based on the mean reward. This extra term is like a stronger seatbelt against big probability spikes, especially when the average reward is low (common early in training).
Why it matters: This reveals a principled reason why PMD-MEAN is stable and robust with few samples—it resists overreacting to noisy rewards.
🍞 Anchor: Early in training, when most answers are wrong, PMD-MEAN doesn’t let the model suddenly jump to a few seemingly lucky answers. It tiptoes instead of leaping, so it doesn’t trip.
— New concept sandwich blocks used here:
- Policy Mirror Descent (PMD): 🍞 You know how a coach helps an athlete improve by giving feedback without changing everything at once? 🥬 PMD is a way to update a model to like better actions more, while staying close to its current habits using KL-divergence as a leash. It works by reweighting probabilities by reward and then normalizing. Without PMD, updates could be wild and break language fluency. 🍞 Example: When teaching math solutions, PMD nudges the model to prefer correct paths but keeps its grammar and style steady.
- KL-Divergence: 🍞 Imagine comparing two smoothie recipes: how different are they? 🥬 KL-divergence measures how one probability distribution differs from another. PMD uses it to prevent the new policy from drifting too far from the old policy. Without it, the model might overfit a few lucky samples. 🍞 Example: If yesterday’s policy loved apples and today’s suddenly hates them, KL says, “Too big a jump!”
- Log-Partition Function: 🍞 Think of resizing all recipe amounts so the total fits one pitcher. 🥬 The partition function is the normalizer that ensures reweighted probabilities still sum to 1. Without it, your probability “recipe” overflows or doesn’t fill the pitcher. 🍞 Example: After boosting probabilities for good answers, the partition function scales the whole distribution back to sum to 1.
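To make the “hard to estimate” point concrete, here is a tiny numerical sketch (Python, not from the paper) of how noisy a few-sample Monte Carlo estimate of the log-partition term log E[exp(r/τ)] becomes when the pass rate is low and τ is small. The pass rate, τ, and rollout counts below are illustrative assumptions.

```python
import numpy as np

# Toy illustration (not the paper's code): the exact PMD update reweights the old
# policy by exp(reward / tau) and renormalizes with the partition function. With
# only a few rollouts per prompt, the Monte Carlo estimate of that log-partition
# term is very noisy; PMD-MEAN sidesteps it entirely.

rng = np.random.default_rng(0)
tau = 0.05              # temperature / regularization strength (illustrative)
pass_rate = 0.1         # most answers are wrong early in training (illustrative)

def mc_log_partition(num_rollouts):
    """Estimate log E_{y ~ pi_old}[exp(r(y) / tau)] from sampled binary rewards."""
    rewards = rng.binomial(1, pass_rate, size=num_rollouts)
    return np.log(np.mean(np.exp(rewards / tau)))

# Exact value for 0/1 rewards: log((1 - p) + p * exp(1 / tau)).
exact = np.log((1 - pass_rate) + pass_rate * np.exp(1 / tau))

for n in (8, 16, 256, 16384):
    estimates = [mc_log_partition(n) for _ in range(200)]
    print(f"rollouts={n:6d}  exact={exact:5.2f}  "
          f"mean estimate={np.mean(estimates):5.2f}  std={np.std(estimates):5.2f}")
```

With 8 or 16 rollouts the estimate often collapses to 0 (no correct answer sampled) or swings by many nats, which is exactly the fragility that motivates replacing this term with the mean reward.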
02 Core Idea
🍞 Hook: Imagine you have a dimmer switch for how boldly your model should change: sometimes you want quick learning, sometimes gentle nudges. What if a simple average—mean reward—could set that dimmer smartly and automatically?
🥬 The Concept (Aha! in one sentence): Approximating the tough log-partition term with the mean reward makes the PMD update behave like mirror descent with an adaptive mix of KL and chi-squared (χ²) regularization, which naturally adds extra caution when rewards are low.
How it works (intuitively):
- Start from PMD: reweight by reward while staying close via KL.
- Replace the hard log-partition function with mean reward, and fit the target in log-policy space by regression.
- The exact population solution of this regression uses the Lambert-W function, which shrinks big probability changes more than the usual exponential would.
- This is equivalent to solving a mirror descent step with KL + χ² penalties, where the χ² weight grows when the mean reward is small.
- Result: Early on, when the model is mostly wrong, updates are conservative and stable; as the model improves, updates can become bolder.
Why it matters: Without this adaptive extra regularization, small-sample noise and off-policy mismatch can make updates twitchy or explode. The χ² term acts like a smart shock absorber.
🍞 Anchor: Training on math problems where most attempts are wrong at first, PMD-MEAN avoids over-boosting the few lucky attempts. It improves steadily instead of roller-coastering.
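To ground the recipe above, here is a minimal PyTorch sketch of the regression step for the rollouts of a single prompt. It assumes sequence-level log-probabilities are already available; the function name, numbers, and shapes are illustrative, not the authors’ implementation.

```python
import torch

def pmd_mean_loss(logp_new, logp_old, rewards, tau=0.05):
    """Minimal sketch of the PMD-MEAN regression for one prompt (illustrative).

    logp_new: sequence log-probs under the policy being trained, shape [G].
    logp_old: sequence log-probs under the sampling (old) policy, shape [G].
    rewards:  scalar rewards for the G rollouts, shape [G].
    """
    # Advantage = reward minus the per-prompt mean reward; the mean reward is the
    # stand-in for the log-partition term.
    advantage = rewards - rewards.mean()
    # Target in log-policy space is advantage / tau.
    target = advantage / tau
    # Regress the log-ratio onto the target; no partition function is estimated.
    log_ratio = logp_new - logp_old.detach()
    return ((log_ratio - target) ** 2).mean()

# Hypothetical usage for one prompt with 4 rollouts (one correct, three wrong):
logp_old = torch.tensor([-42.0, -40.5, -55.2, -38.9])
logp_new = logp_old.clone().requires_grad_(True)
rewards = torch.tensor([1.0, 0.0, 0.0, 0.0])
loss = pmd_mean_loss(logp_new, logp_old, rewards, tau=0.05)
loss.backward()
```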
Three analogies for the same idea:
- Thermostat analogy: The mean reward is like room temperature. When it’s cold (low reward), the system moves slowly to avoid overshoot (extra χ² regularization). As it warms (higher reward), it allows faster changes.
- Car suspension analogy: KL is the standard shock absorber; PMD-MEAN adds an adaptive second shock (χ²) that stiffens on bumpy roads (noisy, low-reward phases), preventing the car from bouncing.
- Class grading analogy: Raising grades based on recent quizzes (rewards) is fine, but if the class average is low, you adjust more cautiously so one lucky quiz doesn’t decide the semester.
Before vs. After:
- Before: People either estimated the partition function (risky with few samples), or used complex off-policy corrections (many knobs, tricky stability), or had to crank regularization high (slow learning).
- After: Replace the partition with the mean reward, do simple regression, and get built-in extra regularization that adapts to how well the model is doing. More stability, less tinkering.
Why it works (intuition without equations):
- Reweighting by reward alone can make some actions explode in probability. The missing normalizer usually reins this in, but it’s hard to estimate.
- Mean reward shifts the target around an average, which centers updates and reduces variance.
- The exact solution shows a Lambert-W shrinkage, which grows slower than plain exponentials for big arguments, stopping runaway ratios.
- That shrinkage is identical to adding a χ² penalty that’s stronger when average reward is low—exactly when we most need caution.
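For readers who want the two penalties written out, these are the standard definitions of the KL and χ² divergences between the new policy π and the old policy π_old (the paper’s exact adaptive mixing weight is not reproduced here):

```latex
\[
  \mathrm{KL}\bigl(\pi \,\|\, \pi_{\mathrm{old}}\bigr)
    = \sum_{y} \pi(y\mid x)\,\log\frac{\pi(y\mid x)}{\pi_{\mathrm{old}}(y\mid x)},
  \qquad
  \chi^{2}\bigl(\pi \,\|\, \pi_{\mathrm{old}}\bigr)
    = \sum_{y} \frac{\bigl(\pi(y\mid x)-\pi_{\mathrm{old}}(y\mid x)\bigr)^{2}}{\pi_{\mathrm{old}}(y\mid x)}.
\]
```

Because the χ² term charges a squared penalty on the probability ratio, a single action ballooning in probability costs far more under χ² than under KL, which matches the “stronger seatbelt” intuition above.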
Building blocks (each as a sandwich):
- Mean Reward Approximation: 🍞 You know how a report card often compares each test score to the class average? 🥬 We replace the complicated normalizer with the mean reward per prompt, and aim to match ‘reward minus mean’ (advantage) in log-prob space. Without this, we’d need a noisy estimate of a huge normalizer. 🍞 Example: For a prompt with mostly wrong answers, we center around that poor average to avoid overreacting to a rare correct guess.
- χ²-Regularization: 🍞 Picture a suitcase limit: you can’t pack too much more than before. 🥬 χ²-regularization penalizes large probability ratio spikes even more strongly than KL does. Without it, a few actions can balloon in probability from noisy wins. 🍞 Example: If one solution looks awesome by luck, χ² says, “Not so fast—prove it consistently.”
- Adaptive Mixed Regularization: 🍞 Think of dimming lights: brighter if it’s daytime, dimmer at night. 🥬 The method mixes KL and χ², and the χ² weight auto-adjusts using the mean reward. Low mean reward → stronger χ² → safer steps. High mean reward → looser χ² → faster learning. Without adaptivity, you’d need manual tuning. 🍞 Example: Early in math training (low pass rate), the method treads lightly; later, it speeds up as the student improves.
- Lambert-W Function: 🍞 Imagine a special undo button for equations where the thing you’re solving for appears both as a number and inside an exponent. 🥬 Lambert-W solves those twisty equations and, here, produces a probability update that naturally squashes huge jumps. Without Lambert-W’s shape, the update would act too much like a rocket booster. 🍞 Example: When a candidate answer looks super good, Lambert-W keeps the boost reasonable instead of letting it skyrocket.
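As a toy illustration of the squashing described in the Lambert-W block just above, the snippet below compares plain exponential reweighting exp(s) with the Lambert-W value W(exp(s)) as s (think advantage/τ) grows. This shows only the shape of the function; it is not the paper’s exact closed-form update.

```python
import numpy as np
from scipy.special import lambertw

# Toy comparison (not the paper's exact update): W(x) satisfies W(x) * exp(W(x)) = x
# and grows roughly like log(x) for large x, so W(exp(s)) grows roughly linearly in s
# while exp(s) explodes.
for s in (1.0, 5.0, 10.0, 20.0):        # s plays the role of advantage / tau
    boltzmann = np.exp(s)                # plain exponential reweighting
    squashed = lambertw(np.exp(s)).real  # Lambert-W applied to the same quantity
    print(f"s={s:5.1f}   exp(s)={boltzmann:12.3e}   W(exp(s))={squashed:7.3f}")
```

For s = 20, exp(s) is in the hundreds of millions while W(exp(s)) stays around 17, so the reweighting grows roughly linearly in s instead of exponentially.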
03 Methodology
At a high level: Prompts and old-policy rollouts → compute rewards and mean rewards → build a simple regression target in log-policy space → train the policy to match this target → get a stable, adaptively regularized update.
Step-by-step recipe (with why each step exists and an example; a code sketch of the combined update follows the full list):
- Collect off-policy rollouts.
- What happens: For each prompt x, sample multiple responses y from the current/older policy (asynchronous batches are okay). Score each with a reward r(x, y). Rewards can be binary (right/wrong) or other bounded scores.
- Why it exists: We need data to learn which actions are better. Without rollouts, no learning signal.
- Example: For a math problem, we generate 16 solutions and mark each as correct (+1) or incorrect (0 or −1).
- Compute mean reward per prompt and the advantage.
- What happens: For each prompt, compute the average reward across its sampled responses. Then compute advantage Δ(x,y) = r(x,y) − mean_reward(x).
- Why it exists: Centering around the mean reduces variance—updates depend on how much better/worse than typical an answer is, not the raw score. Without centering, noisy highs could dominate.
- Example: If for a prompt the 16 responses have 2 correct and 14 wrong (mean ≈ 0.125 with 0/1 rewards), then a correct answer has advantage +0.875 and a wrong answer has −0.125.
- Build a target in log-policy space.
- What happens: Define the target s*(x,y) = Δ(x,y) / τ, where τ is the temperature (regularization strength). We want the log ratio log π_new(y|x) − log π_old(y|x) to be close to s*.
- Why it exists: Working in log-space turns multiplicative changes in probabilities into additive targets and simplifies optimization. Without log-space, training would be more brittle and scale poorly.
- Example: With τ = 0.05 and advantage +0.875, s* ≈ 17.5—strongly increase that response (but see Step 6 for how the method keeps this safe).
- Regress the policy to the target.
- What happens: Minimize the squared error between the model’s log-prob ratio and s*, across sampled (x,y). In practice, we sum losses over tokens in the sequence (length-normalized) or over sequences (sequence-level), depending on the implementation.
- Why it exists: Directly tracking the target is simpler than estimating the partition function. Without this regression view, we would struggle with normalizing a massive action space.
- Example: The model adjusts its logits so that answers better than average get higher log-prob, and worse ones get lower, in proportion to advantage/τ.
- Enforce normalization implicitly.
- What happens: Although we don’t compute the partition function, the learned policy must still be a valid distribution (probabilities sum to 1). The regression with a fixed old-policy baseline implicitly balances increases and decreases across actions, and the population optimum satisfies a normalization via a special Lambert-W solution.
- Why it exists: Without normalization pressure, the model could inflate probabilities everywhere. The optimization structure plus the softmax ensures legal distributions.
- Example: If a few answers go up, others must come down so total probability stays 1.
- The secret sauce: implicit χ² regularization via Lambert-W.
- What happens: The exact population solution for PMD-MEAN takes the form of a Lambert-W shrinkage on probability ratios. Equivalently, PMD-MEAN solves a mirror descent subproblem with KL + χ² penalties, where the χ² weight depends on the mean reward.
- Why it exists: χ² adds a stronger penalty for large changes than KL alone, which is exactly what you want when the data are noisy or rewards are low. Without it, noise could trigger overshooting and instability.
- Example: Early in training, when pass rates are small, negative (wrong) answers shrink by about exp(−p/τ), which is more gentle than the partition-normalized update would force, avoiding collapse.
- Pick τ and batch sizes.
- What happens: Choose τ (smaller τ pushes stronger updates; larger τ is gentler) and use larger global rollout batches to amortize generation cost. Off-policy staleness is okay because PMD-MEAN is robust to it.
- Why it exists: τ is your learning boldness knob; bigger rollout batches reduce time per token and improve efficiency. Without thoughtful τ and batching, you may train slowly or unstably.
- Example: The paper used τ in {0.005, 0.01, 0.02} for 7B and up to 0.1 for 30B MoE, and global batches of 512 prompts to get 4.6× speedups over strict on-policy training.
- Practical loss forms.
- What happens: For each sampled response, use a leave-one-out mean reward as a stable baseline. Train with length-normalized sequence-level objectives, similar to RLOO/GRPO styles but without explicit partition estimation.
- Why it exists: Leave-one-out avoids bias from including the target in its own baseline and stabilizes variance. Without it, targets can be slightly biased and noisier.
- Example: For 16 responses, each one’s baseline is the average of the other 15 rewards.
- Monitor policy ratios and stability.
- What happens: Track min/max log π_new/π_old during training steps. PMD-MEAN shows more conservative decreases on bad answers than partition-based updates, especially early.
- Why it exists: This confirms the theory: χ²-like behavior limits spikes and prevents collapse. Without monitoring, you might miss subtle instabilities.
- Example: Plots show PMD-MEAN’s negative-action shrinkage is weaker at first, then strengthens as accuracy improves.
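Putting steps 2 through 8 together, here is a slightly fuller sketch than the one in the Core Idea section, adding the leave-one-out baseline and a length-normalized sequence log-ratio. Tensor names, shapes, and the masking details are assumptions made for illustration; they are not the authors’ code.

```python
import torch

def pmd_mean_step(logp_new_tokens, logp_old_tokens, lengths, rewards, tau):
    """One PMD-MEAN update for the G rollouts of a single prompt (illustrative names).

    logp_new_tokens: [G, T] per-token log-probs under the trainable policy.
    logp_old_tokens: [G, T] per-token log-probs under the sampling policy.
    lengths:         [G] response lengths in tokens (for length normalization).
    rewards:         [G] scalar rewards, e.g. 1.0 correct / 0.0 incorrect.
    """
    G = rewards.shape[0]
    # Leave-one-out baseline: each rollout is compared to the mean of the other G-1.
    loo_mean = (rewards.sum() - rewards) / (G - 1)
    target = (rewards - loo_mean) / tau                     # advantage / tau

    # Mask out padding, then use a length-normalized sequence-level log-ratio.
    T = logp_new_tokens.shape[1]
    mask = (torch.arange(T)[None, :] < lengths[:, None]).float()
    diff = (logp_new_tokens - logp_old_tokens.detach()) * mask
    log_ratio = diff.sum(dim=1) / lengths

    return ((log_ratio - target) ** 2).mean()

# Hypothetical shapes: 16 rollouts of up to 512 tokens, mostly-wrong early phase.
G, T, tau = 16, 512, 0.05
logp_old = -torch.rand(G, T)
logp_new = logp_old.clone().requires_grad_(True)
lengths = torch.randint(64, T, (G,)).float()
rewards = (torch.rand(G) < 0.2).float()
loss = pmd_mean_step(logp_new, logp_old, lengths, rewards, tau)
loss.backward()
```

Sequence-level versus token-level summation and the exact normalization are implementation choices the write-up mentions; the mask-and-normalize form here is just one of them.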
What breaks without each step:
- No mean baseline → noisier targets → overreaction to rare good samples.
- No regression in log-space → hard normalization, scaling issues.
- No implicit χ² → large ratio spikes → instability with few rollouts.
- No large rollout batches → slow training; less amortized generation.
Concrete data example:
- Prompt: A math puzzle. 16 rollouts: 3 correct (1.0), 13 wrong (0.0). Mean reward = 0.1875.
- Advantages: correct = +0.8125; wrong = −0.1875.
- Targets: s* = advantage/τ (suppose τ = 0.05 → s_correct ≈ 16.25; s_wrong ≈ −3.75).
- Update: The model increases prob. of those 3 correct paths and decreases the others, but Lambert-W/χ² behavior curbs extreme jumps, especially for the majority-wrong set.
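The same arithmetic in a few lines of Python, so the numbers above can be checked directly:

```python
# Checking the numbers above: 16 rollouts, 3 correct (1.0) and 13 wrong (0.0), tau = 0.05.
rewards = [1.0] * 3 + [0.0] * 13
mean_reward = sum(rewards) / len(rewards)   # 0.1875
tau = 0.05
adv_correct = 1.0 - mean_reward             # +0.8125
adv_wrong = 0.0 - mean_reward               # -0.1875
print(adv_correct / tau, adv_wrong / tau)   # 16.25 -3.75
```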
Secret sauce summary:
- Replace the fragile partition estimate with the mean reward.
- Use regression in log-policy space to keep it simple and stable.
- Gain an automatic, adaptive χ² safety belt (via Lambert-W), most active when rewards are low.
— New concept sandwich blocks used here:
- Off-Policy Training (staleness): 🍞 You know how reading last week’s essays to grade today’s class can be mismatched? 🥬 Off-policy means your data were made by an older policy than the one you’re updating now. PMD-MEAN handles this well without heavy corrections. Without such robustness, training can wobble. 🍞 Example: Asynchronous rollouts where the sampler lags behind the learner.
- Importance Sampling & Clipping (context): 🍞 Think of weighting past homework by how relevant it is to today’s lesson. 🥬 Importance sampling reweights old samples to mimic the current policy; clipping limits extreme weights. Without clipping, variance can explode; with too much clipping, bias grows. 🍞 Example: PPO/GRPO use token-level clipping to tame off-policy drift.
- Boltzmann Reweighting (context): 🍞 Like giving extra dessert to students with higher scores. 🥬 The ideal PMD update multiplies old probabilities by exp(reward/τ) and then normalizes (partition function). Without normalization, totals don’t sum to 1. 🍞 Example: Better answers get exponentially more probability, but everyone gets scaled to fit the probability ‘plate.’
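For contrast with the importance-sampling-and-clipping machinery described in the last block, here is the standard PPO-style clipped surrogate in PyTorch. PMD-MEAN does not use it; it is shown only to highlight the moving parts that the mean-reward regression avoids.

```python
import torch

def ppo_clipped_loss(logp_new, logp_old, advantages, eps=0.2):
    """Standard PPO-style clipped surrogate (context only; PMD-MEAN does not use it).

    logp_new, logp_old, advantages: tensors of matching shape (per token or per
    sequence); eps is the clipping range.
    """
    ratio = torch.exp(logp_new - logp_old.detach())      # importance weight
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps)
    # Pessimistic (minimum) objective, negated to form a loss to minimize.
    return -torch.minimum(ratio * advantages, clipped * advantages).mean()
```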
04 Experiments & Results
The test: The authors trained LLMs on math reasoning (DAPO-Math-17k) and evaluated on AIME 2024 and 2025. They compared PMD-MEAN to GRPO (a strong baseline) and also to GSPO (an advanced variant that improves stability for MoE models). They also measured training efficiency (time per token) and stability (no collapse, smooth reward curves).
The competition:
- GRPO: widely used, critic-free policy gradient with group baselines and clipping.
- On-policy gradient: a simpler baseline that avoids staleness but is slower because of small batches.
- GSPO: a stronger GRPO variant using sequence-level importance sampling and clipping for MoE stability.
Scoreboard with context:
- Qwen2.5-7B (dense):
- GRPO: Average AIME24/25 ≈ 13.8.
- On-policy: ≈ 18.49 but slower (small batches).
- PMD-MEAN (τ = 0.005): ≈ 19.58 average, which is a clear jump above GRPO (think moving from a B− to a solid A−), and similar to or better than on-policy but much faster.
- PMD-MEAN (τ = 0.02): ≈ 19.58 average as well, showing robustness across τ choices.
- Qwen3-30B-A3B-Base (MoE):
- GRPO: ≈ 32.24 average.
- PMD-MEAN (τ = 0.1): ≈ 44.01 average, which is like boosting class rank from middle to top quartile.
Efficiency:
- PMD-MEAN with large global batches (off-policy) achieves around 4.6× speedup over strict on-policy training while delivering equal or better accuracy.
- Generation cost dominates; PMD-MEAN amortizes it by using larger rollout batches and asynchronous pipelines.
Stability findings:
- PMD-MEAN shows smooth, steadily rising training rewards and evaluation accuracy.
- PMD-PART (fitting the exact partition-normalized target) can become highly unstable and even collapse unless τ is much larger (which slows learning). This matches the theory: PMD-PART is more sensitive to finite-sample errors early on when the pass rate is low.
Surprising or notable results:
- The more conservative shrinkage of bad actions in PMD-MEAN, especially early, leads to better stability and still strong final performance—confirming the predicted mixed KL–χ² behavior.
- Against GSPO, which adds sophisticated importance sampling and clipping, PMD-MEAN is competitive or better with a simpler training loop—suggesting that the implicit χ² regularization can replace some engineering complexity.
What the numbers mean in plain terms:
- Absolute AIME gains of roughly +6 points over GRPO on the 7B model and about +12 points on the 30B MoE model are big in math reasoning: that’s the difference between often missing problems and solving many more, without adding complicated tricks.
- Time-per-token drops significantly, making it much more practical to scale training.
Takeaway: PMD-MEAN is a simple method that wins on stability, speed, and accuracy in these math reasoning tasks, precisely where off-policy staleness and limited rollouts used to make training fragile.
05 Discussion & Limitations
Limitations (be specific):
- Approximation gap: Replacing the log-partition with the mean reward is not exact, especially at very small τ. While the χ² regularization helps, there’s still a mismatch that can make positive-action targets slightly too aggressive in some cases.
- Early-phase speed vs. caution: PMD-MEAN is intentionally conservative when mean rewards are low. This stabilizes learning but can be slower than the ideal (partition-based) update in a perfect, infinite-sample world.
- Bandit framing: The analysis is done in a contextual bandit view (single-step decisions). Extending every guarantee cleanly to long-horizon settings may require extra care.
- Reward design: Binary rewards are simple but coarse. Different or noisy reward designs may change the best τ or the practical behavior of the adaptive regularization.
- Implementation assumptions: Full-support policies (no zero probabilities) are assumed in the analysis. Aggressive top-k/top-p truncation might break some guarantees if probabilities hit zero.
Required resources:
- Access to old and current policies’ log-probabilities over generated sequences.
- Moderate to large batch generation capability (to amortize cost and stabilize statistics).
- GPUs or accelerators for sequence-level training; memory for storing rollouts.
- Basic RLHF/RLVR setup: prompts, rollout sampling, and reward computation.
When not to use PMD-MEAN:
- If you can reliably estimate the partition function with many samples and want the fastest theoretical contraction, PMD-PART may converge in fewer ideal steps (though it risks instability in practice).
- If your training is strictly on-policy with tiny sequence lengths and no staleness, simpler on-policy methods might suffice and be easier to reason about.
- If your reward signal requires very fine-grained credit assignment that benefits from a value critic network, a policy-gradient-plus-critic approach could be preferable.
Open questions:
- Adaptive τ scheduling: Theory hints that scaling τ with per-prompt pass rate p could further improve robustness and speed. What’s the best schedule in practice?
- Beyond bandits: How do the mixed KL–χ² benefits carry over to multi-step RL settings with credit assignment and bootstrapping?
- Reward shapes: How does PMD-MEAN behave with dense rewards, shaped rewards, or learned reward models? Any new stability patterns?
- Composability: Can we combine PMD-MEAN with light-weight importance sampling, oversampling hard prompts, or curriculum strategies for extra gains without losing theoretical clarity?
- Model architecture: Does the χ² effect interact with MoE routing and long-context models in special ways (e.g., better expert balance, less collapse)?
06 Conclusion & Future Work
Three-sentence summary: The paper shows that a simple mean-reward approximation to the log-partition function (PMD-MEAN) leads to a closed-form update using the Lambert-W function and is exactly equivalent to mirror descent with an adaptive KL–χ² regularizer. This extra χ² term automatically tightens the leash on probability changes when rewards are low, yielding robust, conservative updates that are less sensitive to small-sample noise and off-policy staleness. Experiments on math reasoning confirm higher accuracy, smoother training, and faster throughput compared to strong baselines.
Main achievement: Revealing the precise mathematics and practical mechanism behind PMD-MEAN—its Lambert-W update and its equivalence to adaptive mixed KL–χ² regularization—explains why it is stable and effective for LLM post-training without extra engineering complexity.
Future directions:
- Design adaptive τ schedules tied to pass rates or uncertainty estimates.
- Extend theory and practice beyond bandits to full RL (multi-step) with value estimates.
- Combine PMD-MEAN with gentle importance sampling or curriculum sampling to accelerate learning while keeping stability.
- Study interactions with MoE routing, long context, and reasoning-specific rewards.
Why remember this: A tiny change—swap a hard normalizer for the mean reward—quietly builds in a powerful, adaptive safety belt (χ²), making off-policy LLM RL simpler, sturdier, and faster. It’s a rare win where less machinery yields both better theory and better practice.
Practical Applications
- Stabilize off-policy RL training for LLMs in asynchronous or large-batch pipelines.
- Speed up training by using larger rollout batches without heavy importance sampling machinery.
- Improve math and code reasoning models where rewards are sparse or binary (correct/incorrect).
- Reduce training collapses in mixture-of-experts models by curbing extreme probability jumps.
- Use a simple loss (log-policy regression to advantage/τ) instead of estimating partition functions.
- Tune fewer knobs: rely on adaptive χ² behavior to keep early training cautious and robust.
- Monitor policy ratio stats (log π_new/π_old) to confirm conservative early updates and healthy progression (a small helper sketch follows this list).
- Combine with modest curriculum or oversampling of hard prompts for further gains without destabilizing.
- Apply to RL from verifiable rewards (e.g., unit tests, auto-graders) where reward noise can mislead.
- Prototype safer RLHF/RLAIF post-training recipes with simpler code paths and clearer theory.
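A minimal helper for the monitoring suggestion above (illustrative, not from the paper): track the extremes of the sequence-level log ratio during training and watch for early spikes.

```python
import torch

def log_ratio_stats(logp_new, logp_old):
    """Min/max/mean of log(pi_new / pi_old) over a batch of sequences (illustrative)."""
    log_ratio = logp_new - logp_old
    return {
        "min": log_ratio.min().item(),
        "max": log_ratio.max().item(),
        "mean": log_ratio.mean().item(),
    }

# Example: logp_new and logp_old are [batch] sequence log-probs logged each step;
# conservative early updates should show small-magnitude extremes that grow later.
```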