DASH: Faster Shampoo via Batched Block Preconditioning and Efficient Inverse-Root Solvers
Key Summary
- •Shampoo is a smart optimizer that can train models better than AdamW, but it used to be slow because it had to compute tricky inverse matrix roots.
- •DASH speeds Shampoo up by stacking many small matrices into one 3D pile so the GPU can work on them all at once.
- •DASH adds faster math recipes (Newton-DB and Chebyshev) for those inverse roots, reducing how many expensive steps are needed.
- •A new scaling trick with multi-Power-Iteration makes those recipes converge faster and more reliably than the usual Frobenius-norm scaling.
- •On big language model training (953M parameters), DASH made optimizer steps up to 4.83× faster than the well-optimized Distributed Shampoo.
- •Newton-DB gave the best validation perplexity per iteration among all tested methods, even matching or beating EVD while running much faster.
- •Running the Coupled-Newton method in FP16 was safe and sped things up further, but FP16 Newton-DB was unstable (future work).
- •DASH’s GPU-aware design also improved memory usage and made load balancing across GPUs easier.
- •Overall, DASH makes high-quality second-order optimization practical without paying a huge time cost.
Why This Research Matters
Training large AI models is expensive and slow, so cutting optimizer time by 3–5× without hurting quality saves real money and energy. Better scaling and stability mean fewer crashes and less time tuning tricky settings. Models trained with Shampoo also tend to have fewer activation outliers, which helps with compression and deployment on smaller devices. Faster, steadier optimization lets researchers run more experiments and reach better solutions faster. For industry, this can shorten product cycles and reduce carbon footprints. And because DASH is GPU-friendly, it makes high-quality second-order methods accessible to more teams. In short, DASH turns a great idea (Shampoo) into a practical tool for everyday large-scale training.
Detailed Explanation
01Background & Problem Definition
🍞 Hook: Imagine you’re biking up a hill. You can pedal harder (more power), or you can shift gears (smarter power). In deep learning, AdamW is like pedaling harder, while second-order optimizers like Shampoo are like picking the perfect gear.
🥬 The Concept: Optimizers are rules for how a model takes steps to get better. Shampoo is a second-order optimizer that understands how parameters affect each other, so it can take smarter steps than AdamW.
- How it works:
- Look at the gradient (the direction to improve) from the current batch.
- Build two helper matrices (preconditioners) that summarize how parameters co-vary.
- Take an update step using special matrix roots (like inverse square roots) of those helpers.
- Why it matters: Without this “smart gear,” the model may wobble and take many small, wasteful steps—training slower and sometimes ending with lower-quality solutions.
🍞 Anchor: When training a language model, Shampoo can reach a good loss in fewer, steadier steps than AdamW because it adapts to the landscape more intelligently.
🍞 Hook: You know how a chef preps ingredients into tidy bowls so cooking is fast? Shampoo does “prep” too, but its prep step used to be slow.
🥬 The Concept: The slow part is computing inverse matrix roots—like finding a special “square-root-like” transform of a matrix and then flipping it.
- How it works:
- From gradients, form small square matrices (blocks).
- Compute each block’s inverse square/fourth root.
- Use them to scale (precondition) the gradient for a smarter step.
- Why it matters: If this step is slow or unstable, the whole optimizer drags, making Shampoo feel impractical.
🍞 Anchor: It’s like spending so long sharpening knives that dinner is late; even if the recipe is great, slow prep ruins your timing.
🍞 Hook: Think of a library where one librarian must read every book to sort them; that takes forever.
🥬 The Concept: Older methods (like EVD) to compute those matrix roots are hard to parallelize on GPUs and get slow for big matrices.
- How it works:
- EVD breaks a matrix into eigenvectors/eigenvalues.
- It then transforms those values and rebuilds the matrix.
- This involves steps that don’t map cleanly to fast GPU operations.
- Why it matters: If the core step isn’t GPU-friendly, training wastes your hardware’s potential.
🍞 Anchor: Your GPU is a stadium full of runners; EVD asks many to wait their turn instead of sprinting together.
🍞 Hook: Picture doing homework one page at a time vs. batching lots of similar problems together.
🥬 The Concept: Distributed Shampoo helped by breaking big matrices into B×B blocks, but it still processed many blocks one after another.
- How it works:
- Split big layers into blocks.
- Compute a root for each block.
- Do this block-by-block, often sequentially.
- Why it matters: You save some work, but still lose time by not batching the math to use the GPU fully.
🍞 Anchor: It’s like sorting socks one pair at a time; better than sorting the whole pile blindly, but still not as fast as sorting by colors all at once.
🍞 Hook: You know how learning a new shortcut can turn a long walk into a quick jog?
🥬 The Concept: The gap was that Shampoo’s best math wasn’t arranged to fully match how GPUs like to work, and scaling before iterations wasn’t ideal. That left speed and stability on the table.
- How it works:
- GPUs love big, batched matrix multiplications.
- Using the wrong scaling (like Frobenius norm) can slow convergence of root-finding.
- A better scaling (from Power-Iteration) and batched processing unlock both speed and stability.
- Why it matters: Without aligning math to hardware and using the right scaling, you can’t realize Shampoo’s full potential.
🍞 Anchor: After switching to faster roads (batched ops) and better directions (power-based scaling), the trip gets shorter and smoother.
Real stakes in daily life:
- Faster training means less waiting and lower compute costs for language models, vision systems, and recommendation engines.
- Better-behaved models (fewer activation outliers) are easier to compress, which helps put AI on phones or small servers.
- A quicker, more stable optimizer helps researchers run more experiments and reach better solutions.
- For companies, this can mean faster iteration cycles and reduced energy bills.
02Core Idea
🍞 Hook: Imagine stacking pizza boxes so a delivery person can grab many at once. Doing it box-by-box is slow; grabbing a whole stack is fast.
🥬 The Concept (Aha! in one sentence): DASH makes Shampoo fast by stacking many preconditioner blocks into one 3D batch for GPUs and by using faster, better-scaled inverse-root solvers (like Newton-DB), so you get high-quality steps without the old slowdown.
- How it works:
- Turn lots of B×B blocks into a big 3D tensor and process them together (batched bmm).
- Replace slow EVD with GPU-friendly iterative solvers (Newton-DB, Chebyshev) that use only matrix multiplies.
- Scale each block with a robust spectral estimate from multi-Power-Iteration so iterations converge quickly.
- Where safe, use FP16 for further speedups (e.g., CN in FP16) while keeping accuracy.
- Why it matters: Without batching and better solvers, you waste GPU power and need many more iterations, making Shampoo impractical at scale.
🍞 Anchor: It’s like a bakery: line up trays (3D stacking), bake with a faster oven (Newton-DB), and preheat to the right temperature (multi-Power-Iteration scaling) so every batch finishes quicker and tastier.
Three analogies:
- Assembly line: Instead of one mechanic fixing one bike at a time (sequential blocks), run a conveyor belt (batched tensors) with many bikes and identical tools (bmm) for speed.
- Reading glasses: Newton-DB and Chebyshev are like better lenses that let you focus faster than squinting (EVD) so you don’t waste minutes refocusing.
- Right map scale: Frobenius scaling is like a blurry map; spectral scaling via Power-Iteration is a crisp map that gets you to the destination in fewer turns.
Before vs After:
- Before: Blocks computed one-by-one, EVD or cautious CN, Frobenius scaling pushing eigenvalues too close to zero, unstable or slow convergence.
- After: Blocks stacked and batched, Newton-DB/Chebyshev used where effective, robust spectral scaling from multi-Power-Iteration, FP16 CN where safe—leading to 3–5× faster steps with equal or better perplexity.
Why it works (intuition):
- GPUs are happiest doing big batches of the same operation (matrix multiplies). Stacking turns many small problems into one big, efficient job.
- Iterative solvers like Newton-DB and CN need the matrix spectrum to sit in a friendly range; spectral scaling puts it there, so you need fewer iterations.
- FP16 halves data size and boosts tensor-core throughput; using it where numerically safe gives easy speed wins.
Building blocks (each with a mini-sandwich):
- 🍞 Hook: Sorting lots of index cards is faster if you rubber-band them into stacks. 🥬 The Concept: Batched block preconditioning stacks many B×B blocks into a 3D tensor so the GPU can process them in one go.
- Steps: Gather blocks → stack into (N, B, B) → run batched bmm iterations → unstack results.
- Why it matters: Without stacking, you waste GPU time launching and tearing down tiny jobs. 🍞 Anchor: One batched call can compute 100 block roots in roughly the time separate per-block calls would spend on just a few (see the sketch below).
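A minimal PyTorch sketch (not the paper's code) of this stacking idea: the same matrix multiply run block-by-block versus once over a stacked (N, B, B) tensor; the sizes are illustrative.

```python
import torch

N, B = 128, 256                        # number of blocks, block size (illustrative)
blocks = [torch.randn(B, B) for _ in range(N)]

# Sequential: one small matmul launch per block (launch-bound on a GPU).
seq_out = [m @ m for m in blocks]

# Batched: stack into (N, B, B) so a single bmm kernel covers every block.
stacked = torch.stack(blocks)          # shape (N, B, B)
bat_out = torch.bmm(stacked, stacked)

# Same results, far fewer kernel launches.
print(torch.allclose(bat_out[0], seq_out[0], atol=1e-3))
```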
- 🍞 Hook: Following two recipes at once can save time if they share steps. 🥬 The Concept: Newton-DB is a matrix method that simultaneously moves toward the square root and its inverse using only matrix multiplies.
- Steps: Initialize Y=A, Z=I; build E=(3I−ZY)/2; update Y←YE, Z←EZ; repeat; chain twice to get A^(−1/4).
- Why it matters: Without Newton-DB, you either rely on slower EVD or need more CN iterations. 🍞 Anchor: In tests, Newton-DB matched or beat EVD quality with less time per step.
- 🍞 Hook: Approximating a curve with a few Lego bricks can be surprisingly accurate. 🥬 The Concept: Chebyshev polynomials approximate inverse roots by evaluating a fixed polynomial with an efficient recurrence (Clenshaw’s algorithm).
- Steps: Pre-fit coefficients on an interval; map eigenvalues into that interval; run d lightweight matmuls; get A^(−1/p).
- Why it matters: Without a good polynomial approximation, you need more expensive iterations. 🍞 Anchor: With degree ~60, you get a fast, fixed-cost estimate—handy when iteration budgets are tight.
- 🍞 Hook: Using a ruler that’s too long for a tiny drawing makes measurements awkward. 🥬 The Concept: Frobenius-norm scaling can shrink eigenvalues too much, slowing iterative convergence; spectral scaling (from Power-Iteration) keeps them closer to 1.
- Steps: Estimate top eigenvalue λ_max; divide A by about 2·λ_est to safely satisfy convergence; iterate fewer times.
- Why it matters: Without good scaling, you need more steps or risk instability. 🍞 Anchor: Switching from Frobenius to spectral scaling cut iterations needed for small eigenvalues dramatically.
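A tiny numeric check of the scaling point above (PyTorch; the matrix is a made-up well-conditioned block, not data from the paper): for such a B×B block, ‖A‖_F exceeds λ_max by roughly √B, so dividing by it pushes the whole spectrum far below 1.

```python
import torch

B = 1024
# A symmetric PSD block whose eigenvalues all sit near 1 (well conditioned).
Q, _ = torch.linalg.qr(torch.randn(B, B))
eigs = 1.0 + 0.1 * torch.rand(B)
A = Q @ torch.diag(eigs) @ Q.T

lam_max = torch.linalg.eigvalsh(A).max().item()   # true top eigenvalue
fro = torch.linalg.norm(A).item()                 # Frobenius norm ~ sqrt(B) * lam_max

print(f"lambda_max ~ {lam_max:.2f}, ||A||_F ~ {fro:.1f}")
# Frobenius scaling leaves the top eigenvalue near 1/sqrt(B) instead of near 1,
# exactly the regime where Newton-style iterations need many more steps.
print(f"top eigenvalue after Frobenius scaling: {lam_max / fro:.4f}")
print("top eigenvalue after spectral scaling (A / (2*lam_max)): 0.5")
```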
- 🍞 Hook: If you start several treasure hunts at once, you’re more likely to find the biggest prize quickly. 🥬 The Concept: Multi-Power-Iteration runs many starting vectors in parallel to robustly estimate the largest eigenvalue.
- Steps: Pick 16–32 random starts → multiply by A repeatedly → pick the vector with the biggest Rayleigh quotient.
- Why it matters: A single start can get stuck on a smaller eigenvalue; multiple starts avoid that. 🍞 Anchor: On stacked blocks, multi-PI adds almost no time but greatly improves scaling reliability.
03Methodology
High-level flow: Gradients → build/EMA preconditioners (L, R) in blocks → stack blocks into 3D tensors → scale each block via multi-Power-Iteration → pick a solver (Newton-DB, CN, or EVD) to get inverse roots → graft with Adam’s norm → apply preconditioned update → load-balance across GPUs and sync.
Step-by-step with sandwiches and examples:
- Compute gradients
- 🍞 Hook: You feel which way a marble will roll on a tilted table.
🥬 The Concept: A gradient tells the model which way to change to get less loss.
- Steps: Run a batch → compute loss → backprop to get G for each layer.
- Why it matters: No gradient, no direction. 🍞 Anchor: For a layer with shape (m, n), G is an m×n matrix of “pushes.”
- Build preconditioners (EMA of GG^T and G^T G)
- 🍞 Hook: Keeping a running average of your test scores smooths out lucky days.
🥬 The Concept: Shampoo keeps exponential moving averages (L and R) of GG^T and G^T G.
- Steps: L_t = β_LR·L_{t−1} + (1−β_LR)·GG^T; R_t = β_LR·R_{t−1} + (1−β_LR)·G^T G; add εI for stability.
- Why it matters: Without these, you miss parameter correlations and lose second-order benefits. 🍞 Anchor: In a 2048-dim embedding, L and R summarize how features co-vary across the batch.
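A minimal sketch of this EMA update for one layer (PyTorch; beta_lr is an illustrative decay value, not the paper's setting):

```python
import torch

def update_preconditioners(G, L, R, beta_lr=0.95):
    """Exponential moving averages of GG^T (left) and G^T G (right)
    for an m x n gradient G."""
    L = beta_lr * L + (1.0 - beta_lr) * (G @ G.T)   # (m, m)
    R = beta_lr * R + (1.0 - beta_lr) * (G.T @ G)   # (n, n)
    return L, R

G = torch.randn(512, 256)                            # gradient of a (512, 256) layer
L, R = torch.zeros(512, 512), torch.zeros(256, 256)
L, R = update_preconditioners(G, L, R)
# A small ridge eps*I is added before the inverse-root solve
# to keep each block safely positive definite.
```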
- Split into blocks and stack into 3D tensors
- 🍞 Hook: Sorting socks by color and size first makes folding faster later.
🥬 The Concept: Break layers into B×B blocks; then stack all blocks with the same B into one (N, B, B) tensor.
- Steps: Tile G into blocks; form matching L and R blocks; stack all same-shaped blocks across layers.
- Why it matters: Without stacking, you do many small calls that underuse the GPU. 🍞 Anchor: For B=1024, a big embedding layer might yield ~62 full-size blocks that are stacked and solved in one batched pass.
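A minimal sketch of the tiling-and-stacking step (PyTorch): each gradient is cut into B×B blocks and the per-block L and R statistics are stacked into (N, B, B) batches. Edge blocks and padding are ignored here, and in the real optimizer these statistics feed the EMA state from the previous step; both simplifications are mine.

```python
import torch

def blocked_preconditioner_batch(grads, B):
    """Tile each gradient into B x B blocks and return stacked (N, B, B)
    batches of left (G_b G_b^T) and right (G_b^T G_b) block statistics."""
    Ls, Rs = [], []
    for G in grads:
        m, n = G.shape
        for i in range(0, m - m % B, B):       # full-size tiles only (sketch)
            for j in range(0, n - n % B, B):
                Gb = G[i:i + B, j:j + B]
                Ls.append(Gb @ Gb.T)
                Rs.append(Gb.T @ Gb)
    return torch.stack(Ls), torch.stack(Rs)    # each (N, B, B)

grads = [torch.randn(2048, 1024), torch.randn(1024, 1024)]
L_batch, R_batch = blocked_preconditioner_batch(grads, B=512)
print(L_batch.shape)                            # torch.Size([12, 512, 512])
```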
- Scale each block using multi-Power-Iteration
- 🍞 Hook: Calibrating a scale before weighing fruit makes results trustworthy.
🥬 The Concept: Estimate the top eigenvalue λ_max per block using multiple simultaneous starts, then divide the block by about 2·λ_est.
- Steps: Initialize 16–32 random vectors; repeatedly multiply by A; pick the one with the largest Rayleigh quotient; scale block.
- Why it matters: Without good scaling, root solvers may converge slowly or misbehave. 🍞 Anchor: Compared to Frobenius scaling (often 10–100× too big), spectral scaling keeps eigenvalues near 1 and cuts iteration counts.
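A minimal batched sketch of this scaling step (PyTorch): several power-iteration starts per block, the largest Rayleigh quotient as λ_est, then division by 2·λ_est. The numbers of starts and iterations are illustrative.

```python
import torch

def batched_top_eig(A, starts=16, iters=10):
    """Estimate lambda_max for each PSD block in an (N, B, B) batch
    using several simultaneous power-iteration starts per block."""
    N, B, _ = A.shape
    V = torch.randn(N, B, starts, device=A.device, dtype=A.dtype)
    for _ in range(iters):
        V = torch.bmm(A, V)                                  # (N, B, starts)
        V = V / V.norm(dim=1, keepdim=True).clamp_min(1e-30)
    rq = (V * torch.bmm(A, V)).sum(dim=1)                    # Rayleigh quotient per start
    return rq.max(dim=1).values                              # best start per block

def spectrally_scale(A):
    lam = batched_top_eig(A)
    scale = 2.0 * lam.clamp_min(1e-30)
    return A / scale.view(-1, 1, 1), scale                   # spectrum now in (0, ~0.5]
```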
- Compute inverse roots with a chosen solver
- Options, all batched over (N, B, B):
a) Newton-DB (recommended for quality)
- 🍞 Hook: Two hands washing each other get clean faster.
🥬 The Concept: N-DB jointly updates Y≈A^(1/2) and Z≈A^(−1/2) with only matmuls; chain twice for A^(−1/4).
- Steps: Init Y=A, Z=I; E=(3I−ZY)/2; Y←YE; Z←EZ; repeat for a few iterations; run again on Y to get A^(−1/4).
- Why it matters: Fewer, well-behaved iterations at good scaling produce high-quality inverses. 🍞 Anchor: On Llama-953M, N-DB achieved the lowest validation perplexity per iteration among tested methods.
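A minimal batched sketch of the Newton-DB steps listed above (PyTorch). It assumes the blocks were already spectrally scaled as in the previous step, and the fixed iteration count is an illustrative choice.

```python
import torch

def newton_db_sqrt_pair(A, iters=20):
    """Coupled iteration: Y -> A^(1/2), Z -> A^(-1/2) for an (N, B, B) batch
    whose blocks have eigenvalues safely inside (0, 1)."""
    N, B, _ = A.shape
    I = torch.eye(B, device=A.device, dtype=A.dtype).expand(N, B, B)
    Y, Z = A.clone(), I.clone()
    for _ in range(iters):
        E = 0.5 * (3.0 * I - torch.bmm(Z, Y))
        Y = torch.bmm(Y, E)
        Z = torch.bmm(E, Z)
    return Y, Z

def newton_db_inv_fourth_root(A_scaled, scale, iters=20):
    """Chain twice: A^(-1/4) = (A^(1/2))^(-1/2), then undo the 2*lambda scaling."""
    Y, _ = newton_db_sqrt_pair(A_scaled, iters)   # Y ~ (A/s)^(1/2)
    _, Z = newton_db_sqrt_pair(Y, iters)          # Z ~ (A/s)^(-1/4)
    return Z / scale.view(-1, 1, 1) ** 0.25       # ~ A^(-1/4)
```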
b) Coupled-Newton (CN, fastest in FP16)
- 🍞 Hook: Using a reliable shortcut can be faster if the road is smooth.
🥬 The Concept: CN maintains X_k ≈ A^(−1/p) and a stabilizer M_k; each iteration uses a few matmuls.
- Steps: Compute C_k from M_k; update X_{k+1} = X_k·C_k and M_{k+1} = C_k^p·M_k.
- Why it matters: Without CN, you might use slower EVD; CN in FP16 is very fast and accurate in practice. 🍞 Anchor: With FP16, CN reached down to ~138 ms per optimizer step in tests (B=1024), a 4.83× speedup vs prior baselines.
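A minimal batched sketch of the coupled iteration (PyTorch). The specific choices here, C_k = ((p+1)I − M_k)/p and a Frobenius-norm-based initialization as in common coupled-Newton implementations, plus the fixed iteration count, are assumptions for illustration rather than the paper's exact recipe.

```python
import torch

def coupled_newton_inv_root(A, p=4, iters=25, ridge=1e-12):
    """Batched coupled Newton iteration for A^(-1/p) over an (N, B, B) batch
    of symmetric PSD blocks. Sketch only: no early stopping or error checks."""
    N, B, _ = A.shape
    I = torch.eye(B, device=A.device, dtype=A.dtype).expand(N, B, B).clone()
    A = A + ridge * I
    # Start inside the iteration's convergence region (assumed norm-based scaling).
    z = (1.0 + p) / (2.0 * torch.linalg.matrix_norm(A, keepdim=True))
    X = I * z.pow(1.0 / p)                                    # X_0 ~ z^(1/p) * I
    M = A * z                                                 # M_0 = z * A
    for _ in range(iters):
        C = ((p + 1.0) * I - M) / p                           # C_k from M_k
        X = torch.bmm(X, C)                                   # X_{k+1} = X_k C_k
        M = torch.bmm(torch.linalg.matrix_power(C, p), M)     # M_{k+1} = C_k^p M_k
    return X                                                  # ~ A^(-1/p)
```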
c) EVD (accuracy reference)
- 🍞 Hook: Rulers and protractors are precise, but slow to use on every line.
🥬 The Concept: Decompose A=QΛQ^T; compute Λ^(−1/p); reassemble.
- Steps: eigendecompose → transform eigenvalues → reconstruct.
- Why it matters: Reliable but not GPU-friendly; kept as a baseline. 🍞 Anchor: EVD often needed update frequency f=10 to be practical; still slower than batched iterative solvers.
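A minimal batched EVD reference for comparison (PyTorch); the epsilon clamp is an illustrative safeguard.

```python
import torch

def evd_inv_root(A, p=4, eps=1e-12):
    """Reference inverse p-th root via eigendecomposition, batched over (N, B, B).
    Accurate, but the eigensolver maps less cleanly onto batched GPU matmuls."""
    lam, Q = torch.linalg.eigh(A)                          # A = Q diag(lam) Q^T
    lam = lam.clamp_min(eps) ** (-1.0 / p)
    return Q @ torch.diag_embed(lam) @ Q.transpose(-1, -2)
```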
d) Chebyshev (fixed-cost approximation)
- 🍞 Hook: Using a good stencil lets you draw curves quickly.
🥬 The Concept: Evaluate a pre-fit polynomial via Clenshaw to approximate A^(−1/p) with d matmuls.
- Steps: Map spectrum into [ε,1+ε]; run recurrence; use degree ~40–100 depending on tolerance.
- Why it matters: Fixed time per step; helpful when iteration budgets must be bounded. 🍞 Anchor: At certain scales, Chebyshev in DASH ran under ~90 ms per step with solid perplexity.
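A minimal sketch of the polynomial route: fit Chebyshev coefficients for x^(−1/p) on [ε, 1+ε] (NumPy interpolation), then evaluate the series on each block with Clenshaw's recurrence using only batched matmuls (PyTorch). The degree, interval, and fitting route are illustrative assumptions, and the blocks must already be scaled so their eigenvalues fall inside [ε, 1+ε].

```python
import numpy as np
import torch

def cheb_coeffs(p=4, eps=1e-3, deg=60):
    """Chebyshev coefficients of f(x) = x^(-1/p) on the interval [eps, 1 + eps]."""
    fit = np.polynomial.chebyshev.Chebyshev.interpolate(
        lambda x: x ** (-1.0 / p), deg, domain=[eps, 1.0 + eps])
    return [float(c) for c in fit.coef], eps

def cheb_inv_root(A, coeffs, eps):
    """Clenshaw evaluation of the fitted series at an (N, B, B) batch whose
    eigenvalues lie in [eps, 1 + eps]; one matmul and a few adds per term."""
    N, B, _ = A.shape
    I = torch.eye(B, device=A.device, dtype=A.dtype).expand(N, B, B)
    T = 2.0 * A - (1.0 + 2.0 * eps) * I            # map spectrum onto [-1, 1]
    b1 = torch.zeros_like(A)
    b2 = torch.zeros_like(A)
    for k in range(len(coeffs) - 1, 0, -1):
        b0 = coeffs[k] * I + 2.0 * torch.bmm(T, b1) - b2
        b2, b1 = b1, b0
    return coeffs[0] * I + torch.bmm(T, b1) - b2   # ~ A^(-1/p)
```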
- Grafting with Adam’s magnitude
- 🍞 Hook: Walk in a smart direction (Shampoo) but match your stride length to a known good walker (Adam).
🥬 The Concept: Use Shampoo’s direction but scale its length to Adam’s, leveraging Adam’s tuned LR schedule.
- Steps: Compute U = preconditioned gradient; P = Adam-like direction; scale s = ‖P‖_F/‖U‖_F; update θ ← θ − η·s·U.
- Why it matters: Without grafting, Shampoo can be numerically finicky; grafting stabilizes and matches known LR schedules. 🍞 Anchor: It’s like steering with a precise compass (Shampoo) but pacing yourself like a runner you trust (Adam).
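A minimal sketch of the grafting rule above (PyTorch; the names are illustrative):

```python
import torch

def grafted_step(param, shampoo_dir, adam_dir, lr, eps=1e-16):
    """Use Shampoo's direction U, but rescale it to Adam's step length."""
    s = adam_dir.norm() / shampoo_dir.norm().clamp_min(eps)   # ||P||_F / ||U||_F
    param -= lr * s * shampoo_dir                             # theta <- theta - lr*s*U
    return param
```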
- Apply update and load-balance across GPUs
- 🍞 Hook: Sharing chores evenly finishes cleaning faster.
🥬 The Concept: Greedy load balancing assigns heavy layers to less-loaded GPUs; then synchronize parameters.
- Steps: Sort layers by size; assign to GPUs with least load; compute locally; all-reduce/broadcast to sync.
- Why it matters: Without balanced work, some GPUs idle while others lag. 🍞 Anchor: DASH used less memory per GPU than the prior system and kept workers well utilized.
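A minimal sketch of the greedy assignment (plain Python); the per-layer cost numbers are made up for illustration.

```python
import heapq

def greedy_assign(layer_costs, num_gpus):
    """Assign layers, largest first, to the currently least-loaded GPU."""
    heap = [(0.0, gpu) for gpu in range(num_gpus)]             # (load, gpu id)
    heapq.heapify(heap)
    assignment = {}
    for name, cost in sorted(layer_costs.items(), key=lambda kv: -kv[1]):
        load, gpu = heapq.heappop(heap)                        # least-loaded GPU
        assignment[name] = gpu
        heapq.heappush(heap, (load + cost, gpu))
    return assignment

print(greedy_assign({"emb": 8.0, "head": 7.0, "mlp.0": 5.0, "attn.0": 3.0}, num_gpus=2))
# -> {'emb': 0, 'head': 1, 'mlp.0': 1, 'attn.0': 0}
```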
The secret sauce:
- Stacking turns many tiny, launch-bound tasks into a few big, compute-bound kernels that saturate tensor cores.
- Spectral scaling from multi-Power-Iteration aligns the problem to the sweet spot where Newton-like iterations converge fast.
- Selective low precision (FP16 for CN) buys free speedups while keeping accuracy; where unsafe (N-DB in FP16), stay in FP32.
04Experiments & Results
The test and why: The authors trained a 953M-parameter Llama-style model on the C4 dataset with a Chinchilla-optimal token budget (~20 tokens/parameter), measuring two things that matter: validation perplexity (how well the model predicts text) and optimizer step time (how long each update takes). They compared methods that only changed the optimizer step; forward/backward pass times were constant.
Competition/baselines:
- Distributed Shampoo (DIST): The widely used, well-optimized reference implementation.
- Variants of inverse-root solvers: EVD, Coupled-Newton (CN), and Newton-DB (N-DB).
- Precision and scaling variants: FP32 vs FP16 where safe; Frobenius vs Power-Iteration scaling.
Scoreboard with context (selected highlights):
- Overall: DASH matched or improved validation perplexity while making optimizer steps up to 4.83× faster than DIST in one-to-one settings.
- CN (speed champ):
- B=1024, FP32: DIST 666 ms → DASH 149 ms (≈4.47× faster) with the same perplexity (~11.87).
- B=1024, FP16: DIST 471 ms → DASH 138 ms (≈3.41× faster). Versus DIST’s CN-FP32 666 ms, that’s ≈4.83× faster.
- B=2048, FP16: DIST 243 ms → DASH 169 ms (≈1.44× faster).
- Note: FP16 CN was stable; BF16 diverged in tests.
- EVD (accuracy reference):
- B=2048, f=1: DIST 2200 ms → DASH 1747 ms (≈1.26× faster) with essentially identical perplexity (~11.72–11.73).
- B=1024, f=10: DIST 355 ms → DASH 315 ms (≈1.13× faster). The heavy cost still appears at update steps (same as f=1), but skipped steps are cheap (~35 ms).
- N-DB (quality champ):
- B=1024, Frobenius scaling: DIST 558 ms → DASH 188 ms (≈2.97× faster); perplexity improved vs EVD.
- B=1024, Power-Iteration scaling: DIST slowed to 740 ms (sequential PI is costly), but DASH took 194 ms (≈3.81× faster) and achieved even better perplexity (~11.68 vs ~11.76 with Frobenius).
- B=2048 with Power-Iteration: DIST 355 ms → DASH 284 ms (≈1.25× faster), matching EVD’s f=1 quality. Frobenius-scaling runs with N-DB at B=2048 were unstable in DIST but stable in DASH with PI scaling.
Making numbers meaningful:
- A 4.83× speedup is like finishing your homework before the bell while others are only halfway—same answers, less time.
- Perplexity drops like 11.76 → 11.68 are small but real; it’s the difference between a solid B+ and an A− on nuanced language prediction.
- Using Power-Iteration scaling not only sped up convergence of the math steps but also reduced validation perplexity consistently for N-DB.
Surprising findings:
- N-DB, an iterative method, sometimes matched or beat EVD in validation perplexity, despite being far faster.
- FP16 boosted CN safely but made N-DB unstable (FP32 required for N-DB); BF16 for CN diverged due to lower mantissa precision.
- Increasing block size from 1024 to 2048 didn’t consistently improve perplexity but did cost more time—suggesting practitioners shouldn’t assume “bigger block = better.”
Total training time impact:
- Even when the forward/backward pass dominates, shaving optimizer step from, say, 355 ms (DIST EVD f=10) to 138 ms (DASH CN FP16) saved about 5% of total runtime on a 953M model over ~9000 steps—tens of minutes at this scale.
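- Back-of-the-envelope on those figures: (355 ms − 138 ms) × 9000 steps ≈ 1950 s, i.e., roughly 33 minutes saved per run.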
Takeaways:
- If you want the best quality: N-DB with Power-Iteration scaling in DASH reached the lowest perplexity.
- If you want the absolute fastest steps: CN in FP16 with Frobenius scaling gave the top speed while matching baseline quality.
- Stacking plus spectral scaling is a double win: faster and more stable than Frobenius scaling, especially for N-DB.
05Discussion & Limitations
Limitations (specific):
- Newton-DB in FP16 was numerically unstable in these tests; it needed FP32. BF16 caused CN to diverge.
- Chebyshev approximations showed promise but were sensitive to setup; picking degree and precision requires care.
- Very large blocks (e.g., > embedding size) can be rank-limited, adding noise to spectra and not improving quality.
- Power-Iteration adds overhead if done sequentially; DASH’s stacking makes it cheap, but non-stacked pipelines could slow down.
Required resources:
- Modern GPUs with strong tensor-core support to realize the full batched bmm speedups.
- Enough GPU memory to store stacked block tensors and cached inverse roots (though DASH improved memory balance vs prior art).
- Distributed training setup (e.g., DDP/ZeRO) to benefit from load balancing and synchronized updates.
When NOT to use:
- Ultra-low-precision-only environments (mandatory BF16 everywhere) where iterative solvers become unstable.
- Tiny models where the optimizer step is already negligible compared to forward/backward (little speedup to harvest).
- Pipelines that cannot batch or stack blocks (e.g., very irregular shapes with no common B) may not see big gains.
Open questions:
- Can we stabilize Newton-DB in FP16 (e.g., mixed-precision tricks, error compensation, or blockwise adaptive damping)?
- Dynamic solver selection: can we cheaply estimate per-block condition numbers and pick CN, N-DB, or Chebyshev adaptively?
- How do these gains scale to multi-billion or trillion-parameter models with tensor and pipeline parallelism?
- Can spectral scaling be further improved (e.g., subspace iteration) without adding too much overhead?
- Can we integrate robust polynomial methods as plug-and-play fallbacks when Newton-like methods stall?
Bottom line: DASH nails the systems–numerics handshake—batched execution plus better scaling—and shows second-order methods can be both high-quality and fast in practice.
06Conclusion & Future Work
Three-sentence summary:
- This paper introduces DASH, a GPU-optimized rethinking of the Shampoo optimizer that stacks block preconditioners into 3D tensors and swaps in faster inverse-root solvers.
- By adding robust spectral scaling via multi-Power-Iteration and using FP16 where safe, DASH cuts optimizer step time by up to 4.83× while matching or improving validation perplexity.
- Newton-DB emerges as a high-quality choice, and the overall design makes strong second-order methods practical at large scale.
Main achievement:
- Turning Shampoo’s main bottleneck—computing inverse matrix roots—into a GPU-friendly, batched, and well-scaled workflow that delivers both speed and quality.
Future directions:
- Stabilize Newton-DB in lower precision; develop dynamic per-block solver selection; extend to even larger, more distributed training regimes; and explore tighter polynomial approximations.
Why remember this:
- DASH shows that aligning math with hardware (batching) and using the right pre-iteration scaling (spectral, not Frobenius) can transform a slow-but-smart optimizer into a fast-and-smart one—bringing second-order benefits to everyday large-scale training.
Practical Applications
- •Speed up pretraining of large language models while maintaining or improving validation perplexity.
- •Reduce training costs for recommendation systems by using faster optimizer steps on GPUs.
- •Improve model deployability by leveraging Shampoo’s tendency to reduce activation outliers (easier compression/quantization).
- •Stabilize training in tricky regimes by using spectral scaling and grafting to avoid divergence.
- •Run more hyperparameter searches within the same compute budget due to faster optimizer steps.
- •Adopt mixed-precision training safely by using FP16 for CN while keeping N-DB in FP32.
- •Enable faster fine-tuning of foundation models by reusing DASH’s batched block preconditioning.
- •Scale multi-GPU training more efficiently with DASH’s load balancing and stacked memory layout.
- •Prototype new inverse-root solvers or damping heuristics easily within DASH’s batched framework.