Jet-RL: Enabling On-Policy FP8 Reinforcement Learning with Unified Training and Rollout Precision Flow
Key Summary
- Reinforcement learning (RL) for large language models is slow because the rollout (text generation) stage can take more than 70% of training time, especially for long, step-by-step answers.
- Many teams tried to speed this up by using FP8 only for rollout and BF16 for training, but the resulting precision mismatch made learning unstable on long sequences and hard tasks.
- Jet-RL fixes this by using the same FP8 precision flow in both training and rollout, keeping learning truly on-policy.
- The key is a unified precision graph and fine-grained FP8 quantization (128×128 per-block for weights, 1×128 per-group for activations and gradients) for all linear-layer GEMMs, while keeping a BF16 master copy of the weights.
- Across Llama3.1-8B, Qwen2.5-7B, and Qwen3-8B-Base, Jet-RL converges reliably and usually stays within about 1% of full BF16 accuracy, with gaps of up to about 3% in the hardest 16K-token settings.
- Jet-RL speeds up rollout by up to 1.33×, speeds up the training phase by up to 1.41×, and delivers up to 1.16× end-to-end speedup on 8B models.
- It avoids expensive inter-step calibration, reducing engineering complexity and training stalls.
- BF16-train + FP8-rollout can collapse at 16K-token generations; Jet-RL remains stable even for long, hard reasoning.
- The method uses real FP8 kernels (DeepGEMM) and integrates with vLLM and VeRL to work in practical RL pipelines.
Why This Research Matters
Jet-RL makes long-reasoning RL both faster and more reliable by making the math used during training match the math used during rollout. That lowers training cost and energy use, so more teams can afford to build strong reasoning models. Stable learning with 8K–16K-token rollouts supports long step-by-step solutions in education, science, code generation, and planning without collapse. Because it avoids repeated calibration stalls and integrates with real systems (vLLM, VeRL, DeepGEMM), it is practical for production pipelines. In short, Jet-RL helps the AI community move from risky speed hacks to dependable, efficient low-precision training for the era of long CoT.
Detailed Explanation
01 Background & Problem Definition
🍞 Hook: Imagine you’re solving a big puzzle by writing out your thinking step by step. The longer your explanation, the more time it takes, right?
🥬 The Concept: Chain-of-Thought (CoT) is when a model writes its reasoning step by step before answering. How it works:
- The model reads the question.
- It generates many reasoning tokens in order.
- It uses those steps to form the final answer.
Why it matters: Without CoT, models may guess; with CoT, they can solve tougher problems. But long step-by-step thinking takes time and compute.
🍞 Anchor: For a hard math word problem, a model might write 6,000+ tokens of reasoning before giving the final number.
🍞 Hook: You know how a coach wants the team to practice the same way they play in the real game?
🥬 The Concept: Reinforcement Learning (RL) is how models learn by trying, getting rewards, and improving their strategy (policy). How it works:
- Try to answer by generating text (an action).
- Get a reward (was it correct/helpful?).
- Update the policy to do better next time.
Why it matters: RL helps models get good at reasoning, not just memorizing.
🍞 Anchor: A model tries several solutions to a math problem, gets a reward for the right one, and learns to do more of what worked.
🍞 Hook: If you practice your free throws with a heavy ball, but play with a normal ball in the game, your shots might be off!
🥬 The Concept: On-Policy Training means you learn using the same strategy and conditions you’ll use when you act. How it works:
- Use the current model to generate data (rollouts).
- Update the same model based on that data.
- Keep practice and game consistent.
Why it matters: If practice (training) and the game (rollout) don't match, learning can go off-track.
🍞 Anchor: A basketball player practicing with the same ball and hoop as the real game learns more reliably.
🍞 Hook: You know how taking a test takes longer than studying flashcards at home?
🥬 The Concept: Rollout Phase is when the model actually generates long answers token by token to see how it performs. How it works:
- Read the prompt.
- Autoregressively generate many tokens.
- Score the results with reward models or critics.
Why it matters: Rollout can consume over 70% of RL training time for long responses, making it the main bottleneck.
🍞 Anchor: For 8K–16K-token answers, rollout time dominates the clock like a final exam that lasts all afternoon.
🍞 Hook: Think of numbers as jars of marbles—smaller jars are easier to carry quickly.
🥬 The Concept: FP8 Quantization stores numbers in 8-bit floating point to save memory and speed up compute. How it works:
- Find a scale for a tensor so values fit FP8.
- Convert (quantize) values to FP8.
- Use fast FP8 hardware to multiply and add.
Why it matters: Without FP8, computations (especially rollouts) are slower and more costly.
🍞 Anchor: Using FP8 lets the model generate tokens faster, like switching to lighter running shoes.
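The three steps above can be sketched in plain Python. This is a minimal numerical simulation, not a real FP8 kernel. It assumes the E4M3 FP8 format (4 exponent bits, 3 mantissa bits, largest finite value 448, round-half-to-even, saturating instead of overflowing) and a single per-tensor scale. The function names are hypothetical.

```python
import math

E4M3_MAX = 448.0  # assumed: largest finite value in the FP8 E4M3 format


def round_to_e4m3(x: float) -> float:
    """Round a float to the nearest E4M3 value (saturating, round-half-to-even)."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)              # saturate instead of overflowing
    e = max(math.frexp(a)[1] - 1, -6)      # exponent of a; -6 is the smallest normal exponent
    step = 2.0 ** (e - 3)                  # 3 mantissa bits -> 8 steps per power of two
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def quantize_per_tensor(xs):
    """Step 1: pick one scale so the largest magnitude maps to E4M3_MAX. Step 2: round."""
    amax = max(abs(v) for v in xs)
    scale = amax / E4M3_MAX if amax > 0 else 1.0
    return [round_to_e4m3(v / scale) for v in xs], scale


def dequantize(q, scale):
    """Map FP8 codes back to real values (what the GEMM effectively sees)."""
    return [v * scale for v in q]


values = [0.1, -2.5, 3.0, 0.75]
q, scale = quantize_per_tensor(values)
print(q)                     # [15.0, -384.0, 448.0, 112.0]
print(dequantize(q, scale))  # approximately `values`, each within 1/16 relative error
```

The relative rounding error of a normal E4M3 value is at most 1/16, which is why the choice of scale (and, later, its granularity) matters so much.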
🍞 Hook: Shorthand can be faster than full sentences, but still clear enough to study from.
🥬 The Concept: BF16 Precision is a 16-bit format that keeps training more stable than 8-bit in many cases. How it works:
- Keep FP32's 8 exponent bits (so the same numeric range, which suits gradients) but only 7 mantissa bits instead of 23.
- Run training math in BF16 or mixed precision.
Why it matters: Pure low-precision training can be unstable; BF16 is a popular training default.
🍞 Anchor: Many LLMs train in BF16 to balance speed and accuracy.
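To make the BF16 layout concrete, here is a minimal sketch that rounds a Python float to BF16. It keeps the top 16 bits of the value's float32 encoding, using round-half-to-even. It relies only on the standard library (`struct`), is not how any training framework implements the cast, and ignores NaN.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a float to the nearest BF16 value (round-half-to-even); NaN is not handled.

    BF16 keeps the top 16 bits of an IEEE float32: 1 sign, 8 exponent, 7 mantissa bits.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    keep_lsb = (bits >> 16) & 1                           # lowest bit that BF16 keeps
    bits = (bits + 0x7FFF + keep_lsb) & 0xFFFF0000        # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(to_bf16(1.0 + 2**-8))      # 1.0      (exact tie, stays on the even neighbor)
print(to_bf16(1.0 + 3 * 2**-8))  # 1.015625 (exact tie, rounds up to the even neighbor)
print(to_bf16(3.0e38))           # still finite: same range as FP32
```

The last line is the practical difference from FP8. A value like 3e38 needs no scaling in BF16, whereas E4M3 saturates at 448, so FP8 always needs an explicit scale factor.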
🍞 Hook: If you whisper a story through many people, tiny mistakes can snowball into nonsense.
🥬 The Concept: Catastrophic Collapse is when training suddenly falls apart and accuracy tanks. How it works:
- Small mismatches accumulate over long generations.
- The learned policy drifts from good behavior.
- Performance collapses on hard or long tasks.
Why it matters: Collapse wastes compute and ruins results.
🍞 Anchor: Prior methods using BF16 for training but FP8 for rollout often collapsed at 16K tokens.
🍞 Hook: Imagine grading a soccer game not only by the final score but by how helpful each play was.
🥬 The Concept: Generalized Advantage Estimation (GAE) is a smarter way to estimate how good actions were. How it works:
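- Compute a per-step TD error: the reward plus the discounted value of the next state, minus the critic's value of the current state.
- Sum future TD errors with an exponential decay (γλ) to trade bias against variance and reduce noise.
- Produce advantages that guide updates.
Why it matters: Without GAE (for example, with raw Monte Carlo returns), advantage estimates are noisier and updates can be jittery.
🍞 Anchor: In PPO-style training, GAE helps judge which generated steps improved the final answer.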
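The GAE recursion can be written in a few lines. This sketch follows the standard definition: TD error δ_t = r_t + γV(s_{t+1}) − V(s_t), and advantage A_t = δ_t + γλA_{t+1}. The function name and the default γ and λ values are illustrative assumptions, not settings from Jet-RL.

```python
def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation over one trajectory.

    rewards: r_0 .. r_{T-1}
    values:  V(s_0) .. V(s_T); pass V(s_T) = 0.0 if the episode ended.
    Returns (advantages, returns) with returns[t] = advantages[t] + V(s_t).
    """
    assert len(values) == len(rewards) + 1
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # TD error: how much better this step went than the critic expected
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # Exponentially weighted sum of this and all later TD errors
        running = delta + gamma * lam * running
        advantages[t] = running
    returns = [a + v for a, v in zip(advantages, values[:-1])]
    return advantages, returns


# Only the final token is rewarded (e.g., a correct final answer).
adv, ret = compute_gae([0.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.0], gamma=1.0, lam=1.0)
print(adv)  # [0.5, 0.5, 0.5]
```

With λ = 1, the end-of-sequence reward is credited to every earlier step. With λ = 0, only the last step's TD error survives, giving advantages of [0.0, 0.0, 0.5].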
The world before: LLMs got better at reasoning by writing long Chain-of-Thoughts and using RL to improve them. But rollout time soared with longer answers, eating more than 70% of the total training time. Teams tried a quick fix: keep training in BF16 (stable) and speed up rollout by casting to FP8. This sped up generation but secretly broke the on-policy rule: the model learned from BF16 math in training but acted with FP8 math at rollout. Tiny numeric differences seemed harmless for short answers, but for 8K–16K tokens those errors piled up like compounding whispers.
The problem: BF16-train + FP8-rollout became unstable on long sequences and hard tasks, sometimes failing to converge. Inter-step FP8 calibration was too slow to do frequently, so many pipelines skipped it, making the mismatch worse. The gap: no framework ensured the exact same precision flow in training and rollout to preserve on-policy learning while still getting FP8 speed.
Real stakes: Faster, stable RL means cheaper training, greener energy use, and more accessible research. It also means models that can reason longer without breaking—useful for science, tutoring, coding, and planning where multi-step thinking is essential.
02 Core Idea
🍞 Hook: If you practice piano on a lightly detuned keyboard but perform on a perfectly tuned one, your fingers learn the wrong feel.
🥬 The Concept: The key idea is to make training and rollout use the exact same FP8 precision flow so learning stays on-policy and stable while still being fast. How it works:
- Build one unified precision graph for forward passes in both training and rollout.
- Quantize the same tensors with the same granularity in both.
- Keep a BF16 master copy of weights to stabilize updates.
- Use FP8 GEMMs for all linear layers (forward and backward) and save activations in FP8.
Why it matters: Without identical precision flows, small numeric mismatches grow during long generations, causing drift and instability.
🍞 Anchor: With Jet-RL, the math done during learning matches the math during generation, like practicing and performing on the same instrument.
Three analogies:
- Glasses prescription: If you practice reading with weak glasses but take the test with strong ones, you’ll misjudge line spacing. Jet-RL keeps the same prescription for practice and test.
- Shoe traction: Training in cleats but playing in sneakers messes up your steps. Jet-RL wears the same shoes on both fields.
- Recipe oven: Testing cookies in a toaster but selling ones baked in a convection oven yields surprises. Jet-RL bakes all batches in the same oven.
Before vs After:
- Before (BF16-train + FP8-rollout): Faster rollouts, but off-policy drift, instability on long sequences and harder datasets, sometimes collapse.
- After (Jet-RL unified FP8): Similar speedups but with on-policy consistency, stable convergence across 8K–16K tokens, and accuracy close to BF16 (often within ~1%).
🍞 Hook: Remember the telephone game where tiny mis-hearings add up to a different message?
🥬 The Concept: Precision-flow consistency is the secret. Why it works (intuition):
- The policy you update must be the policy that generated your data; otherwise, PPO-style methods chase a moving, mismatched target.
- Long rollouts amplify tiny numeric differences at every token step; keeping precision identical clamps down drift.
- Fine-grained quantization (per-block weights, per-group activations/gradients) tames outliers so FP8 stays accurate.
- Storing activations in FP8 aligns forward/backward paths and reduces memory traffic.
Why it matters: This alignment guards against divergence and keeps rewards meaningful to the actual acting policy.
🍞 Anchor: Like tracking footprints in wet cement: the first mismatch skews the path, so Jet-RL lays matching footprints with the same shoes.
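To see why a precision mismatch hurts PPO-style methods specifically, consider the per-token clipped surrogate. This sketch assumes the "old" log-probabilities come from the rollout engine and the "new" ones from the trainer's forward pass. All numbers are hypothetical.

Before any parameter update, the two log-probabilities should be identical, so the ratio should be exactly 1. If a pipeline instead recomputes the old log-probs with the trainer, the ratio starts at 1, but the samples were still drawn by a numerically different policy. In that case the mismatch becomes a hidden off-policy bias rather than a visible one.

```python
import math


def ppo_clipped_objective(logp_new, logp_old, advantage, clip_eps=0.2):
    """Per-token PPO clipped surrogate (to be maximized); returns (objective, ratio)."""
    ratio = math.exp(logp_new - logp_old)
    clipped = max(min(ratio, 1.0 + clip_eps), 1.0 - clip_eps)
    return min(ratio * advantage, clipped * advantage), ratio


# On-policy: the trainer reproduces exactly the log-prob the rollout recorded.
obj, ratio = ppo_clipped_objective(logp_new=-1.20, logp_old=-1.20, advantage=1.0)
print(ratio, obj)             # 1.0 1.0: the token contributes its full gradient signal

# Mismatch: the rollout numerics assigned this token a different log-prob.
obj, ratio = ppo_clipped_objective(logp_new=-0.90, logp_old=-1.20, advantage=1.0)
print(round(ratio, 3), obj)   # 1.35 1.2: clipped before any update, so no gradient
```

A ratio outside [1 − ε, 1 + ε] before training even starts means the clipping budget is spent on numerical noise rather than on real policy change.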
Building blocks (each with a quick sandwich):
- 🍞 Hook: Sorting LEGO bricks into small tubs is easier than one giant bin. 🥬 The Concept: Quantization Granularity is how many values share one FP8 scale factor. How it works: Use 128×128 per-block scales for weights and 1×128 per-group scales for activations/gradients. Why it matters: With coarse scales, one outlier sets the scale for everything and small values lose detail; fine-grained scales keep accuracy. 🍞 Anchor: Grouping bricks by color and size helps you rebuild the exact model.
- 🍞 Hook: A math class uses three types of problems: compute, track weight mistakes, and fix inputs. 🥬 The Concept: GEMMs in training: FProp (forward compute), WGrad (weight gradients), DGrad (activation gradients). How it works: All three use FP8 inputs with fine-grained scaling, then output BF16. Why it matters: If one GEMM stays high-precision but others are FP8, the flow mismatches again. 🍞 Anchor: Using the same calculator rules for homework and tests yields consistent grades.
- 🍞 Hook: You keep a neat original drawing in a folder but work on copies. 🥬 The Concept: BF16 Master Weights are the safe original, while FP8 copies are used for forward/backward math. How it works: Update the BF16 master; re-quantize to FP8 for compute. Why it matters: Prevents weight drift and preserves long-term learning quality. 🍞 Anchor: Artists trace from the master sketch to avoid deforming the original line art.
- 🍞 Hook: Replay the same song on the same speaker to check clarity. 🥬 The Concept: Unified Precision Graph means the rollout engine is literally a subgraph of the training forward pass. How it works: Identical edges (precision/granularity) in both; inference just omits backward. Why it matters: Eliminates policy mismatch. 🍞 Anchor: Practice playlist and performance playlist use the same audio chain, no surprises.
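A quick way to see why granularity matters is to quantize a row of small activations that contains one outlier. Do it twice: once with a single per-tensor scale and once with 1×128 per-group scales. This is a pure-Python simulation under an assumed E4M3 format (3 mantissa bits, max 448), the data are synthetic, and the function names are hypothetical.

```python
import math

E4M3_MAX = 448.0  # assumed FP8 E4M3 maximum finite value


def round_to_e4m3(x):
    """Nearest E4M3 value: 3 mantissa bits, saturating at 448, round-half-to-even."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    e = max(math.frexp(a)[1] - 1, -6)      # -6 is the smallest normal exponent
    step = 2.0 ** (e - 3)
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def fake_quantize_groups(xs, group_size):
    """Quantize then dequantize with one scale per contiguous group.

    group_size == len(xs) gives per-tensor scaling; 128 gives 1x128 per-group scaling.
    """
    out, scales = [], []
    for start in range(0, len(xs), group_size):
        group = xs[start:start + group_size]
        amax = max(abs(v) for v in group)
        scale = amax / E4M3_MAX if amax > 0 else 1.0
        scales.append(scale)
        out.extend(round_to_e4m3(v / scale) * scale for v in group)
    return out, scales


row = [0.01 * math.sin(i) for i in range(256)]  # synthetic small activations
row[5] = 1000.0                                  # one outlier channel


def total_abs_error(approx):
    return sum(abs(a - b) for a, b in zip(approx, row))


per_tensor, tensor_scales = fake_quantize_groups(row, group_size=len(row))
per_group, group_scales = fake_quantize_groups(row, group_size=128)
print(len(tensor_scales), len(group_scales))                     # 1 2
print(total_abs_error(per_group) < total_abs_error(per_tensor))  # True
```

The group containing the outlier is quantized identically in both cases. All of the gain comes from the other group, whose small values are no longer squeezed into FP8's coarse subnormal range by the outlier's scale.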
03 Methodology
High-level recipe: Prompt → (Unified FP8 Rollout) → Rewards/Critic/Ref (prefill) → (Unified FP8 Training Step: forward + backward with BF16 master weights) → Updated Actor (sync) → Repeat.
Step 1: Unified FP8 precision flow for the forward pass (training and rollout)
- What happens: The exact same quantization decisions (which tensors are FP8, scaling granularity) are applied in the training forward pass and the rollout engine. We treat the rollout graph as a subgraph of the training forward graph.
- Why this step exists: If forward math differs between training and rollout, the policy is off-policy, and tiny numeric deviations snowball during long Chain-of-Thought sequences.
- Example: For an 8B model such as Qwen3-8B-Base generating 8K tokens, each token depends on all the tokens before it. Even a small per-token mismatch (say 0.1%) can compound into a large trajectory divergence; matching precision keeps the trajectory aligned.
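The compounding argument can be checked with simple arithmetic. This sketch makes a deliberately pessimistic, hypothetical assumption: every sampled token is 0.1% more likely under one precision than under the other, and the sequence-level likelihood ratio is the product of the per-token ratios. Real per-token errors vary in sign and partly cancel, so treat the numbers as an illustration of the trend, not a measurement from the paper.

```python
import math


def sequence_likelihood_ratio(per_token_ratio: float, num_tokens: int) -> float:
    """Product of identical per-token probability ratios, computed in log space."""
    return math.exp(num_tokens * math.log(per_token_ratio))


for n in (500, 2_000, 8_000, 16_000):
    print(n, round(sequence_likelihood_ratio(1.001, n), 2))
# about 1.65 at 500 tokens, but about 2,970 at 8K tokens
```

A bias that is negligible for a short answer turns into a sequence-level discrepancy of several orders of magnitude at long-CoT lengths. That is why the mismatch tends to surface only at 8K–16K tokens.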
🍞 Hook: Using the same ruler for class and exams avoids measuring mistakes. 🥬 The Concept: Precision Flow Graph. How it works:
- Make a node for each op/weight and an edge for each tensor.
- Mark edges with data precision and granularity.
- Ensure the rollout graph is a subgraph of the training forward graph.
Why it matters: It's a blueprint ensuring identical math.
🍞 Anchor: Both practice and game plans come from the same playbook.
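The subgraph rule can be expressed as a simple check over edge annotations. This is a hypothetical sketch: the edge names, the `EdgeSpec` type, and the string labels are invented for illustration and are not Jet-RL's or vLLM's data structures. A full training graph would also contain backward edges that the rollout graph omits; only forward edges are shown here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EdgeSpec:
    dtype: str        # e.g. "fp8_e4m3" or "bf16"
    granularity: str  # e.g. "1x128", "128x128", or "none"


# Hypothetical edges for one projection: (producer, consumer) -> precision annotation.
TRAIN_FORWARD = {
    ("rmsnorm", "q_proj"): EdgeSpec("fp8_e4m3", "1x128"),
    ("q_proj.weight", "q_proj"): EdgeSpec("fp8_e4m3", "128x128"),
    ("q_proj", "attention"): EdgeSpec("bf16", "none"),
}


def precision_mismatches(rollout, training):
    """List edges where the rollout graph is not an exact subgraph of the training graph."""
    problems = []
    for edge, spec in rollout.items():
        if edge not in training:
            problems.append((edge, "missing in training graph"))
        elif training[edge] != spec:
            problems.append((edge, f"rollout {spec} vs training {training[edge]}"))
    return problems


aligned_rollout = dict(TRAIN_FORWARD)
mismatched_rollout = {**TRAIN_FORWARD,
                      ("rmsnorm", "q_proj"): EdgeSpec("fp8_e4m3", "per-tensor")}

print(precision_mismatches(aligned_rollout, TRAIN_FORWARD))     # []
print(precision_mismatches(mismatched_rollout, TRAIN_FORWARD))  # one entry for rmsnorm -> q_proj
```

Even a change in granularity alone, with the same FP8 dtype, counts as a mismatch under this rule, because it changes the numbers the rollout computes.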
Step 2: FP8 quantization scheme for linear layers
- What happens: All three core GEMMs (FProp, WGrad, DGrad) consume FP8 inputs using fine-grained quantization; outputs are BF16. We use 128×128 per-block FP8 for weights and 1×128 per-group FP8 for activations/gradients. Forward activation quantization is fused with preceding ops when possible.
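- Why this step exists: With per-tensor FP8 scaling, a single outlier sets the scale for the whole tensor and small values lose precision, which can destabilize training. Fine granularity captures local ranges and outliers, preserving accuracy while enabling FP8 speed.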
- Example: In attention’s output projection, activations are quantized in 1×128 groups; weights in 128×128 blocks. The FP8 GEMM runs fast on Tensor Cores, then outputs BF16 for the next layers to consume stably.
🍞 Hook: Different tools for different jobs: scissors for paper, clippers for hedges. 🥬 The Concept: GEMM Roles (FProp/WGrad/DGrad). How it works:
- FProp: activation (1×128 groups) × weight (128×128 blocks) in FP8.
- WGrad: needs gradients in both 1×128 and 128×1 group layouts; quantization and requantization between them are fused.
- DGrad: mirrors FProp's shapes for gradients.
Why it matters: Each path must be consistently FP8 to keep the precision flow unified.
🍞 Anchor: It's like matching plug shapes so all appliances work in the same outlet.
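The FProp case can be simulated to show how two sets of scales combine inside one FP8 GEMM. The kernel multiplies FP8 codes within each K-group, then multiplies each partial sum by the activation group's scale and the weight block's scale before accumulating.

This pure-Python sketch is not DeepGEMM's API. It assumes E4M3 rounding (max 448), uses a group size of 4 instead of 128 so the example stays tiny, and stores the weight as [K][N] rather than PyTorch's [N][K] layout. Because scales are applied per K-chunk, the groups must run along the GEMM's reduction dimension. One reading of why WGrad needs the transposed 128×1 layout is that it reduces over the token dimension.

```python
import math

E4M3_MAX = 448.0  # assumed FP8 E4M3 maximum finite value


def round_to_e4m3(x):
    """Nearest E4M3 value: 3 mantissa bits, saturating at 448, round-half-to-even."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    e = max(math.frexp(a)[1] - 1, -6)
    step = 2.0 ** (e - 3)
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def quantize_activation_groups(x, g):
    """Activations [M][K]: one scale per row per K-group of width g (the 1 x g layout)."""
    q, s = [], []
    for row in x:
        qrow, srow = [], []
        for k0 in range(0, len(row), g):
            chunk = row[k0:k0 + g]
            amax = max(abs(v) for v in chunk)
            scale = amax / E4M3_MAX if amax > 0 else 1.0
            srow.append(scale)
            qrow.extend(round_to_e4m3(v / scale) for v in chunk)
        q.append(qrow)
        s.append(srow)
    return q, s


def quantize_weight_blocks(w, g):
    """Weights [K][N]: one scale per g x g block."""
    K, N = len(w), len(w[0])
    s = []
    for k0 in range(0, K, g):
        srow = []
        for n0 in range(0, N, g):
            amax = max(abs(w[k][n])
                       for k in range(k0, min(k0 + g, K))
                       for n in range(n0, min(n0 + g, N)))
            srow.append(amax / E4M3_MAX if amax > 0 else 1.0)
        s.append(srow)
    q = [[round_to_e4m3(w[k][n] / s[k // g][n // g]) for n in range(N)] for k in range(K)]
    return q, s


def fp8_gemm(qx, sx, qw, sw, g):
    """Y = X @ W from FP8 codes: scale each K-group's partial sum, then accumulate."""
    M, K, N = len(qx), len(qw), len(qw[0])
    y = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k0 in range(0, K, g):
                partial = sum(qx[i][k] * qw[k][j] for k in range(k0, min(k0 + g, K)))
                acc += partial * sx[i][k0 // g] * sw[k0 // g][j // g]
            y[i][j] = acc  # a real kernel would emit this as BF16
    return y


G = 4  # stands in for 128 to keep the example tiny
X = [[0.5 + 0.1 * ((3 * i + 7 * k) % 5) for k in range(8)] for i in range(4)]  # [M=4][K=8]
W = [[0.2 + 0.05 * ((5 * k + j) % 7) for j in range(4)] for k in range(8)]     # [K=8][N=4]
qx, sx = quantize_activation_groups(X, G)
qw, sw = quantize_weight_blocks(W, G)
Y = fp8_gemm(qx, sx, qw, sw, G)
exact = [[sum(X[i][k] * W[k][j] for k in range(8)) for j in range(4)] for i in range(4)]
worst = max(abs(Y[i][j] - exact[i][j]) / exact[i][j] for i in range(4) for j in range(4))
print(f"worst relative error vs. exact matmul: {worst:.2%}")
```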
Step 3: Backward pass and saved activations
- What happens: Activations saved for backward are stored in FP8 to match the forward’s quantized outputs. Gradients flowing between ops remain BF16 to avoid underflow/noise, but the GEMMs (WGrad/DGrad) still use FP8 inputs with our fine-grained scheme.
- Why this step exists: Saving activations in FP8 aligns forward/backward math and keeps memory bandwidth lower, while BF16 gradients preserve convergence stability.
- Example: RMSNorm outputs get quantized to FP8; those FP8 activations are what the backward GEMMs consume (with fused quant steps where needed).
🍞 Hook: Keep the master recipe safe, and cook from copies. 🥬 The Concept: BF16 Master Weights. How it works:
- Maintain BF16 master parameters.
- On each update, apply optimizer steps to the BF16 master.
- Quantize to FP8 for forward/backward compute.
Why it matters: Prevents long-term drift; FP8 math stays fast while BF16 keeps learning sturdy.
🍞 Anchor: A chef writes edits on the master card, then hands copies to line cooks.
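A toy single-weight example shows what the master copy buys. It assumes a hypothetical fixed update of 0.02 per step and, for simplicity, FP8 E4M3 rounding with a unit scale. Both helper functions are illustrative simulations, not framework casts.

Updating a weight stored only in FP8 rounds the change away every step. Updating the BF16 master and then re-quantizing lets small changes accumulate until they cross an FP8 step.

```python
import math
import struct


def to_bf16(x):
    """Round to the nearest BF16 value (round-half-to-even on float32 bits; NaN ignored)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def round_to_e4m3(x):
    """Nearest FP8 E4M3 value with a unit scale (3 mantissa bits, saturating at 448)."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), 448.0)
    e = max(math.frexp(a)[1] - 1, -6)
    step = 2.0 ** (e - 3)
    return math.copysign(min(round(a / step) * step, 448.0), x)


lr_times_grad = 0.02   # hypothetical per-step update size
fp8_only = 1.0         # weight kept only in FP8 and updated in place
master = to_bf16(1.0)  # BF16 master copy
for _ in range(10):
    fp8_only = round_to_e4m3(fp8_only - lr_times_grad)  # the update is rounded away every step
    master = to_bf16(master - lr_times_grad)            # optimizer step on the BF16 master
    fp8_copy = round_to_e4m3(master)                    # fresh FP8 copy used for compute

print(fp8_only, master, fp8_copy)  # 1.0 0.8046875 0.8125
```

The master ends at 0.8047 rather than the ideal 0.8 because BF16 also rounds each update. The FP8 copy is only ever regenerated from the master, never updated in place.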
Step 4: Inference and system integration
- What happens: vLLM serves rollout with the exact same quantization and kernels as training forward. VeRL runs PPO/GRPO-style updates with Jet-RL’s precision rules. DeepGEMM provides efficient FP8 kernels; Triton implements fused quant/norm/transpose.
- Why this step exists: System parity ensures the rollout path truly matches training’s forward path, eliminating inter-step calibration and downtime.
- Example: We quantize weights at parameter-update time (cheap) so the inference engine always has fresh FP8 weights matching the training precision flow.
🍞 Hook: Borrowing a friend's game console won't help if your controller mapping is different. 🥬 The Concept: No Inter-Step Calibration. How it works:
- Because the precision flow is identical, we don't need data-dependent calibration after each update.
- We avoid tens-of-minutes stalls for 8B models.
Why it matters: More time training, less time waiting.
🍞 Anchor: Skipping the re-tuning step after every song keeps the concert moving.
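Here is one way to see why per-step calibration can be dropped. It assumes scales are computed just in time from the live tensor, as fine-grained FP8 recipes typically do, rather than frozen from a calibration pass. Under that assumption, a scale calibrated before a policy update goes stale once activation ranges shift. The 3× range shift below is hypothetical and only illustrates the failure mode.

```python
import math

E4M3_MAX = 448.0  # assumed FP8 E4M3 maximum finite value


def round_to_e4m3(x):
    """Nearest E4M3 value: 3 mantissa bits, saturating at 448, round-half-to-even."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    e = max(math.frexp(a)[1] - 1, -6)
    step = 2.0 ** (e - 3)
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def fake_quant(xs, scale):
    """Quantize to E4M3 with the given scale, then dequantize."""
    return [round_to_e4m3(v / scale) * scale for v in xs]


def max_abs_error(approx, exact):
    return max(abs(a - b) for a, b in zip(approx, exact))


acts_before = [0.1 * math.sin(i) for i in range(128)]  # range seen when calibrating
acts_after = [3.0 * v for v in acts_before]            # hypothetical: range grows 3x after an update

calibrated_scale = max(abs(v) for v in acts_before) / E4M3_MAX  # frozen at calibration time
dynamic_scale = max(abs(v) for v in acts_after) / E4M3_MAX      # recomputed from the live tensor

stale_err = max_abs_error(fake_quant(acts_after, calibrated_scale), acts_after)
fresh_err = max_abs_error(fake_quant(acts_after, dynamic_scale), acts_after)
print(stale_err > fresh_err)  # True: the stale scale saturates values at the old range
```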
Step 5: Handling long rollouts and hard tasks
- What happens: The unified FP8 flow keeps rollout and training tightly matched even at 16K tokens or on DeepMATH. Small numeric differences don’t accumulate into off-policy drift.
- Why this step exists: Prior BF16-train + FP8-rollout approaches collapsed on long sequences; Jet-RL was specifically designed to prevent that.
- Example: Qwen3-8B-Base at 16K length: the baseline FP8 rollout failed to converge, while Jet-RL stayed within ~2.7% of BF16 average scores.
The secret sauce:
- Identical precision flow eliminates the root cause of mismatch.
- Fine-grained FP8 (1×128 groups and 128×128 blocks) stabilizes low-precision math.
- BF16 master weights and BF16 gradient transport preserve learning quality.
- Practical kernels (DeepGEMM) plus fused ops (Triton) turn the theory into real speed.
What breaks without each step:
- If forward precision mismatches training vs rollout: off-policy drift and instability.
- If granularity is too coarse: FP8 errors spike, harming accuracy.
- If saved activations aren't FP8: the forward and backward passes see different values, and memory traffic rises.
- If gradients are quantized too aggressively: underflow/noise harms convergence.
- If inter-step calibration is required: training stalls kill end-to-end gains.
04 Experiments & Results
The test: Measure stability, accuracy, and speed under real RL settings.
- Why: Rollout dominates time; we need to know if FP8 is both faster and stable on long CoT.
- What: Compare Jet-RL vs BF16 training and BF16-train + FP8-rollout across datasets (GSM8K, MATH, DeepMATH) and models (Llama3.1-8B, Qwen2.5-7B, Qwen3-8B-Base) at 8K and 16K rollouts.
- Metrics: Benchmark scores (GSM8K, MATH500, AMC, GPQA, SuperGPQA), convergence behavior, and throughput (tokens/s, step-time speedup).
The competition:
- BF16 full training: accuracy-strong but slower.
- BF16-train + FP8-rollout: popular in practice but often unstable on long/hard tasks.
- Jet-RL: unified FP8 flow in training and rollout.
Scoreboard with context:
8K rollout length:
- Llama3.1-8B on GSM8K+MATH: BF16-train + FP8-rollout averaged about a 9.8% drop vs BF16, while Jet-RL slightly beat BF16 (+2.0%) on average. That’s like Jet-RL getting an A when the FP8 baseline got a C+.
- Qwen2.5-7B: BF16-train + FP8-rollout did not converge; Jet-RL finished stably, only ~1% behind BF16 on average. That’s the difference between finishing a marathon vs dropping out halfway.
- Qwen3-8B-Base: BF16-train + FP8-rollout lost ~2.9% vs BF16; Jet-RL cut the gap to ~1.1%.
16K rollout length and DeepMATH:
- Qwen3-8B-Base at 16K: BF16-train + FP8-rollout didn’t converge. Jet-RL converged, about 2.7% behind BF16 on average—like going from a likely fail to a solid B.
- DeepMATH (harder): BF16-train + FP8-rollout suffered a 10.3% average drop vs BF16; Jet-RL shrank that to ~0.9%. That’s the difference between a big red mark and nearly matching the teacher’s key.
- Qwen2.5-7B at 16K: BF16-train + FP8-rollout dropped ~5.0%; Jet-RL cut that to ~3.0%.
Speed results (tokens/s and step-time):
- Rollout throughput: FP8 achieved 1.07×–1.33× speedups depending on model size and tensor parallel degree. Bigger models (e.g., 32B) saw up to 1.33×; heavy parallelism reduced the gain since communication overhead grows.
- Training phase: On Qwen3-8B, actor updates sped up ~1.54× and reference model prefill ~1.80×, giving a ~1.41× training-phase speedup.
- End-to-end: Combining rollout and training speedups, Jet-RL reached up to ~1.16× step-time speedup on 8B models; larger models are expected to benefit even more.
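The end-to-end number is smaller than either phase speedup because each phase keeps its share of the original step time (Amdahl's law). This sketch makes several simplifying assumptions:
- rollout takes 70% of the baseline step, as cited earlier;
- weight-sync and other overheads are ignored;
- the phase speedups are picked from the reported ranges.

It is not a reproduction of the paper's measured configuration.

```python
def end_to_end_speedup(rollout_fraction, rollout_speedup, train_speedup):
    """Amdahl-style estimate: each phase shrinks only its own share of the baseline step."""
    new_time = rollout_fraction / rollout_speedup + (1 - rollout_fraction) / train_speedup
    return 1.0 / new_time


print(round(end_to_end_speedup(0.7, 1.07, 1.41), 3))  # 1.153 (low end of the rollout range)
print(round(end_to_end_speedup(0.7, 1.33, 1.41), 3))  # 1.353 (best rollout speedup)
```

Because rollout dominates the step, the end-to-end gain tracks the rollout speedup much more closely than the training-phase speedup.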
Surprising findings:
- Jet-RL sometimes slightly outperformed BF16 on averages (e.g., Llama3.1-8B), suggesting fine-grained FP8 may add mild regularization or smoother optimization in some regimes.
- FP8-train + FP8-rollout consistency mattered more than FP8 alone: mismatched flows (BF16-train + FP8-rollout) failed even when FP8 hardware was fast.
- The benefit of FP8 increased with model size, but too much tensor parallelism ate into gains due to communication costs.
Takeaway in kid-friendly terms: When practice and the game used the same rules (same FP8 math), the team played better and faster, even in long matches. When the rules differed between practice and game, the team stumbled—especially in overtime.
05 Discussion & Limitations
Limitations:
- Ultra-large models (>32B) weren’t fully explored in end-to-end RL due to resource limits; speedups should grow with size, but communication overhead and memory patterns need careful tuning.
- Gradients between ops are kept in BF16 for stability; further quantizing them could save bandwidth but risks underflow and convergence issues.
- Extremely high tensor parallel degrees reduce FP8’s relative speedup because communication starts to dominate.
- The approach depends on strong FP8 kernels (e.g., DeepGEMM) and integration with engines like vLLM; older hardware or missing kernels will reduce benefits.
Required resources:
- NVIDIA GPUs with FP8 Tensor Core support (e.g., H100-class), plus DeepGEMM kernels and a framework stack (VeRL + vLLM) that can mirror the precision graph.
- Engineering to fuse quantization with preceding ops, manage BF16 master weights, and ensure rollout engines exactly mirror training precision.
When not to use:
- Very short rollouts (e.g., <1–2K tokens) or trivial tasks where rollout isn’t a bottleneck; the extra quantization plumbing may not justify itself.
- Pipelines lacking FP8-capable hardware/kernels; you won’t see speedups and might add complexity.
- If your system is heavily communication-bound (very high tensor parallelism), first fix parallelism bottlenecks.
Open questions:
- Can we safely quantize gradient transport for even more bandwidth savings without hurting convergence?
- How does unified FP8 interact with alternative formats (e.g., NVFP4) or hybrid precision schedules?
- Can asynchronous RL pipelines (e.g., AReaL-like) plus unified FP8 further boost utilization while staying on-policy enough in practice?
- Are there theoretical guarantees on stability with matched precision flows over very long horizons (e.g., >32K tokens)?
- What auto-tuning tools can co-design quantization granularity, kernel selection, and parallelism for each model size?
06 Conclusion & Future Work
Three-sentence summary:
- Jet-RL shows that the main reason FP8 rollouts destabilized RL was a training–inference precision mismatch that violated on-policy learning.
- By enforcing the exact same FP8 precision flow in training and rollout, and keeping BF16 master weights, Jet-RL delivers stable convergence close to BF16 accuracy.
- It also speeds up both rollout and training, improving end-to-end step time without costly inter-step calibration.
Main achievement:
- A practical, unified FP8 RL framework that preserves on-policy consistency at scale, stabilizing long Chain-of-Thought training while unlocking up to 1.33× rollout and 1.41× training-phase speedups.
Future directions:
- Extend to larger models (>32B) with careful communication/parallelism co-design.
- Explore safe gradient quantization and hybrid precision schedules (e.g., NVFP4+FP8).
- Combine with asynchronous RL systems and improved sampling strategies to further raise utilization.
Why remember this:
- Jet-RL turns FP8 from a risky speed trick into a reliable training tool: same math in practice and in the game. That simple alignment unlocks faster, greener, and more stable reasoning RL—just when long, careful thinking is becoming the standard for powerful AI.
Practical Applications
- Train reasoning models with PPO/GRPO-style RL and unified FP8 to cut costs while preserving stability on long rollouts.
- Adopt the same FP8 quantization scheme (1×128 activations/gradients, 128×128 weights) across training and rollout.
- Maintain BF16 master weights and re-quantize them to FP8 for compute at each step to stabilize updates.
- Save activations in FP8 and keep inter-op gradient transport in BF16 to balance stability and speed.
- Integrate DeepGEMM FP8 kernels and fuse activation quantization with preceding ops via Triton.
- Mirror the training forward precision graph exactly in vLLM (rollout); avoid inter-step calibration.
- Tune the tensor parallel degree to avoid communication-dominated regimes that blunt FP8 gains.
- Use Jet-RL for hard tasks (e.g., DeepMATH, 16K+ tokens) where BF16-train + FP8-rollout often collapses.
- Continuously monitor on-policy consistency by checking that logits, or the per-token log-probabilities of sampled tokens, match between the training forward pass and rollout.
- Run ablations to validate granularity choices and confirm end-to-end speedups on your specific hardware.
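The consistency check in the list above could look like the following sketch. It assumes you already have, for each sampled token, its log-probability from the rollout engine and from the trainer's forward pass. How to extract these from vLLM or VeRL is not shown. The function name, the 0.05 tolerance, and the report fields are hypothetical.

```python
import math


def on_policy_report(train_logprobs, rollout_logprobs, tolerance=0.05):
    """Compare per-token log-probs of the sampled tokens from the two engines."""
    diffs = [t - r for t, r in zip(train_logprobs, rollout_logprobs)]
    max_abs = max(abs(d) for d in diffs)
    # Importance ratio pi_train / pi_rollout per token; 1.0 everywhere means perfectly on-policy.
    mean_ratio = sum(math.exp(d) for d in diffs) / len(diffs)
    return {"max_abs_logprob_diff": max_abs,
            "mean_ratio": mean_ratio,
            "ok": max_abs <= tolerance}


matched = on_policy_report([-0.5, -1.2, -0.05], [-0.5, -1.2, -0.05])
drifted = on_policy_report([-0.5, -1.2, -0.05], [-0.5, -1.5, -0.05])
print(matched["ok"], drifted["ok"])  # True False
```

Tracking the maximum rather than the mean deviation makes a single badly mismatched token visible, since one outlier token is enough to trip PPO-style clipping.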