SonicMoE: Accelerating MoE with IO and Tile-aware Optimizations
Key Summary
- •SonicMoE makes Mixture-of-Experts (MoE) models train faster and use less memory by redesigning how data is moved and computed on GPUs.
- •It keeps activation memory small (not growing with expert granularity) by avoiding caching big tensors and recomputing the right pieces at the right time.
- •Its GPU kernels overlap memory movement (IO) with math, so the GPU is busy instead of waiting around.
- •A new tile-aware 'token rounding' router trims away padding waste in Grouped GEMM, giving up to 16% extra speed in sparse settings without hurting quality.
- •On a 7B MoE, SonicMoE cuts activation memory per layer by up to 45% compared to ScatterMoE.
- •On H100 GPUs, SonicMoE reaches an average of 88% of a strong cuBLAS upper-bound throughput for forward passes.
- •End-to-end training is more efficient: 213B tokens/day on 64 H100s with SonicMoE vs. 225B tokens/day on 96 H100s with ScatterMoE (similar throughput with one-third fewer GPUs).
- •Kernels are written to use Hopper/Blackwell features like async loads/stores and ping-pong scheduling for high utilization.
- •The method works best for fine-grained, highly sparse MoEs, where older kernels become memory-bound and waste padding FLOPs.
- •SonicMoE is open-sourced, making these speedups and memory savings accessible to others.
Why This Research Matters
SonicMoE makes large language models cheaper and faster to train by cutting memory usage and keeping GPUs busy instead of waiting on data. That means labs can train strong models with fewer GPUs, lowering costs and environmental impact. By removing padding waste and overlapping work, it turns hardware time into useful learning instead of overhead. Because the approach preserves model quality, you don’t trade accuracy for speed. Open-source kernels let the community benefit immediately and build on the ideas. As MoEs become the norm for scaling, these optimizations help everyone reach better models sooner.
Detailed Explanation
01 Background & Problem Definition
🍞 Top Bread (Hook) You know how, when a school has many teachers, each one is great at a different subject? If the principal sends each kid to the best teacher for their question, everyone learns faster—and teachers don’t do extra work.
🥬 Filling (The Actual Concept)
- What it is: Mixture of Experts (MoE) models do something similar: a router sends each token to a few specialized mini-networks (experts) instead of using all experts every time.
- How it works: (1) A router scores how good each expert is for a token; (2) it picks top-K experts; (3) those experts do their mini-MLPs; (4) the outputs get combined; (5) training updates the right experts.
- Why it matters: You get the power of many parameters without paying the full compute bill each step—quality per FLOP improves.
🍞 Bottom Bread (Anchor) Imagine asking, “What’s 7×8?” The math expert answers. For “Write a poem,” the language expert helps. You don’t ask every teacher every time.
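To ground the recipe above, here is a minimal PyTorch sketch of token-choice top-K routing and expert combination. It is a naive reference, not SonicMoE's fused kernels; names like `router_w` and `experts` are illustrative, and some models apply the softmax after the top-K selection rather than before.

```python
import torch
import torch.nn.functional as F

def moe_forward_naive(x, router_w, experts, k):
    """Naive token-choice top-K MoE forward (reference sketch only).
    x: (T, d) tokens; router_w: (d, E) router weights; experts: list of E small MLPs."""
    scores = F.softmax(x @ router_w, dim=-1)          # (T, E) routing probabilities
    topk_s, topk_idx = scores.topk(k, dim=-1)         # (T, K) weights and expert ids
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):              # one small batch of tokens per expert
        tok, slot = (topk_idx == e).nonzero(as_tuple=True)
        if tok.numel() == 0:
            continue
        out[tok] += topk_s[tok, slot].unsqueeze(-1) * expert(x[tok])
    return out
```

The per-expert loop in this sketch is exactly the work that Grouped GEMM (next explainer) batches into one efficient launch.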
🍞 Top Bread (Hook) Imagine a cafeteria line: kids (tokens) head to different stations (experts). If each station must reach across the room for ingredients every time, the line slows way down.
🥬 Filling (The Actual Concept)
- What it is: Grouped GEMM is the GPU’s way of batching many small matrix multiplies (one per expert) into a single, efficient operation.
- How it works: (1) Gather the right tokens per expert; (2) run many small GEMMs together; (3) optionally pad tiny leftovers to fit the hardware’s tile sizes.
- Why it matters: Without grouping and tiling, the GPU wastes time on tiny, scattered work. With it, you get the throughput of a well-run factory line.
🍞 Bottom Bread (Anchor) Like cooking many mini-pizzas together on a big tray instead of baking them one by one.
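Here is a minimal sketch of what a varlen-M Grouped GEMM computes, written as an explicit Python loop. It is a reference for the math only; real grouped-GEMM kernels fuse this loop into a single launch and pad each group up to a fixed tile size.

```python
import torch

def grouped_gemm_reference(x_sorted, w, group_sizes):
    """Reference for varlen-M grouped GEMM: one small GEMM per expert, concatenated.
    x_sorted: (sum(group_sizes), d) tokens sorted by expert;
    w: (E, d, n) per-expert weights; group_sizes: list of tokens per expert."""
    outs, start = [], 0
    for e, m in enumerate(group_sizes):
        outs.append(x_sorted[start:start + m] @ w[e])  # (m, d) @ (d, n) -> (m, n)
        start += m
    return torch.cat(outs, dim=0)
```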
🍞 Top Bread (Hook) Have you ever tried to carry every toy you might use for the entire afternoon? It’s heavy and clumsy. Better to grab only what you need.
🥬 Filling (The Actual Concept)
- What it is: Activation memory footprint is how much intermediate data (activations) we store for the backward pass.
- How it works: (1) Forward computes and often saves activations; (2) Backward reuses them to compute gradients; (3) If we cache too much, memory explodes.
- Why it matters: In fine-grained MoEs (many experts with small hidden sizes), naive caching grows with the number of activated experts, quickly hitting memory limits.
🍞 Bottom Bread (Anchor) Storing every worksheet you ever used is wasteful; keep only what the test needs now.
🍞 Top Bread (Hook) If you pack cookies into boxes of 12 but have 13 cookies, you either stuff a half-empty second box or change how you count so you don’t waste space.
🥬 Filling (The Actual Concept)
- What it is: Token rounding makes the number of tokens sent to each expert a neat multiple of the GPU’s tile size (like 128), cutting padding waste in GEMM.
- How it works: (1) Do normal top-K routing; (2) per expert, round its token count up or down to the nearest tile multiple; (3) ensure changes only touch at most one tile; (4) keep total work the same on average.
- Why it matters: Without rounding, GEMM pads small leftovers—doing FLOPs that don’t help learning.
🍞 Bottom Bread (Anchor) It’s like seating kids at tables of 8; if 17 show up, it’s smarter to make 16 (two full tables) and reassign one, rather than set a third table for one kid.
🍞 Top Bread (Hook) When you bake and do dishes at the same time, dinner’s ready sooner. If you do one after the other, you wait a lot.
🥬 Filling (The Actual Concept)
- What it is: Memory IO overlap means moving data and doing math at the same time so the GPU never sits idle.
- How it works: (1) Split work into tiles; (2) while one tile computes on tensor cores, another tile’s data is fetched/stored asynchronously; (3) ping-pong the roles so compute and IO overlap.
- Why it matters: Fine-grained MoEs are often memory-bound; hiding IO keeps throughput high.
🍞 Bottom Bread (Anchor) Boil pasta while frying sauce: dinner finishes faster than doing them one after the other.
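The overlap idea can be illustrated at a coarse level with two CUDA streams in PyTorch: while chunk i is being multiplied, chunk i+1 is copied to the GPU. This is only an analogy for the tile-level ping-pong inside SonicMoE's kernels, assumes a CUDA-capable GPU, and uses illustrative names such as `overlapped_matmul` and `copy_stream`.

```python
import torch

def overlapped_matmul(chunks_cpu, weight_gpu):
    """Coarse-grained IO/compute overlap with two CUDA streams (illustration only).
    chunks_cpu: list of (m, d) CPU tensors; weight_gpu: (d, n) tensor already on the GPU."""
    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()
    results = []
    with torch.cuda.stream(copy_stream):              # prefetch the first chunk
        next_dev = chunks_cpu[0].pin_memory().to("cuda", non_blocking=True)
    for i in range(len(chunks_cpu)):
        compute_stream.wait_stream(copy_stream)       # make sure chunk i has arrived
        cur = next_dev
        cur.record_stream(compute_stream)             # chunk i now belongs to the compute stream
        if i + 1 < len(chunks_cpu):
            with torch.cuda.stream(copy_stream):      # kick off the next copy in parallel
                next_dev = chunks_cpu[i + 1].pin_memory().to("cuda", non_blocking=True)
        results.append(cur @ weight_gpu)              # this math overlaps with the in-flight copy
    return torch.cat(results, dim=0)
```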
The world before: MoEs were already great at squeezing more quality out of each unit of compute by waking only a few experts per token. But as open models trended toward fine-grained (many small experts) and sparser (more total experts, still only a few activated) setups, hardware efficiency got worse. Why? Three main reasons: (1) Activation memory ballooned if you cached big intermediate tensors across many experts; (2) arithmetic intensity dropped—there was proportionally more data movement per unit of math; and (3) padding waste in Grouped GEMM grew as each expert received fewer tokens, leaving half-empty tiles to compute anyway.
The problem: State-of-the-art kernels like ScatterMoE and MoMoE were not optimized for this new regime. They often did gathers in separate kernels (extra IO), computed gradients in ways that forced caching large tensors like Y (extra memory), and didn’t systematically overlap IO with compute, so the GPU waited. In very sparse setups, grouped GEMM’s tile quantization caused substantial wasted FLOPs.
Failed attempts: Prior systems tried fusing some steps, using block-sparse math, or relying on strong library GEMMs that assume neatly packed inputs. But these approaches either left IO uncovered, required big activation caches, or still suffered from padding waste and stream bubbles. They ran into non-determinism from atomics, suffered from register pressure, or lacked the fine-grained async control needed on Hopper/Blackwell GPUs.
The gap: A co-designed approach was missing—one that (a) restructures the math to avoid caching large activations for backward, (b) fuses gathers and epilogues smartly, and (c) actively overlaps IO and compute, plus (d) changes routing slightly to align with hardware tiles.
Real stakes: Faster MoE training means fewer GPUs or fewer hours to reach the same quality. That lowers cost, energy, and carbon footprint. It also makes strong models accessible to more teams. And for users, it leads to better AI models delivered sooner and tuned on a reasonable budget.
02 Core Idea
🍞 Top Bread (Hook) Imagine turning a busy kitchen into a smooth assembly line: you prep while cooking, batch dishes by size, and avoid storing too many half-finished plates on the counter.
🥬 Filling (The Actual Concept)
- What it is (one sentence): SonicMoE co-designs the MoE math and GPU kernels to avoid large activation caches, overlap memory IO with compute, and round token counts to tile sizes—unlocking high throughput with low memory.
- How it works (recipe-style):
- Redesign backward so gradients (like dS, dH) are computed without needing to cache Y or dY.
- Fuse gathers with GEMM prologue and fuse activation math with the epilogue to cut IO trips.
- Use ping-pong scheduling and async load/store to overlap IO and Tensor Core math.
- Round per-expert token counts to tile multiples to avoid padding waste in grouped GEMM.
- Why it matters: Without these, fine-grained, sparse MoEs get stuck in the memory-bound slow lane—training stalls or overuses GPUs.
🍞 Bottom Bread (Anchor) It’s like packing cupcakes into boxes of exactly 12 while the next batch bakes and frosting is made—no idle ovens, no half-empty boxes, no overflowing counter.
Three analogies for the same idea:
- Factory analogy: Instead of building one gadget at a time (separate kernels), SonicMoE sets up stations (fused stages) so parts keep flowing; warehousing is minimized (no big activation caches), and shipping and assembly happen in parallel (overlap IO and compute). Token rounding makes all boxes the same size (tile multiples), so no space is wasted.
- School analogy: Students (tokens) go to teachers (experts) in neatly sized groups that fit classroom seats (tiles). Hall monitors (kernels) coordinate so as one class learns, the next is already lining up (overlap). The school avoids storing mountains of worksheets (activations) and only keeps what’s needed for the final exam (backward).
- Traffic analogy: Cars (data) hit green lights in sync (overlap), roads are bundled into express lanes (grouped GEMM), and lanes are sized to full car lengths (tiles) to prevent gaps. No one parks on the highway shoulder (no huge caches).
Before vs After:
- Before: Backward needed cached Y/dY, routers used vanilla top-K with jagged sizes, gathers and epilogues weren’t fully fused, and GEMM often waited for IO. Result: high memory, wasted padding FLOPs, and underutilized GPUs.
- After: Backward computes dS via <dA', A> and dH via fused dSwiGLU, so Y/dY caching isn’t needed. Gathers fuse into prologues; epilogues compute more math and write asynchronously. IO and math overlap. Token counts align to tile sizes. Result: 45% less activation memory in some 7B settings; up to 1.86× throughput vs. strong baselines; extra 16% speed from token rounding in sparse regimes.
Why it works (intuition):
- Cutting IO wins in a memory-bound regime: fewer round-trips and better overlap inflate effective arithmetic intensity.
- Computing dS with <dA', A> reduces both memory traffic and reduction complexity (reduce over n, not d).
- Ping-pong scheduling keeps Tensor Cores hot while data moves; async TMA/stores prevent epilogue from stalling the next tile.
- Tile-aware routing removes structural inefficiency (padding), turning wasted FLOPs into useful ones with minimal routing deviation (≤ one tile per expert).
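The second bullet rests on a one-line chain-rule identity. Here is a sketch of the derivation, using the notation from the methodology below (S routing weights, A post-activation, W2 down-projection, dO the upstream gradient):

```latex
% Per token t and activated expert e: Y_{e,t} = A_{e,t} W_{2,e},
% and the combined output is O_t = \sum_e S_{t,e}\, Y_{e,t}. Therefore
\frac{\partial L}{\partial S_{t,e}}
  = \langle dO_t,\; Y_{e,t} \rangle
  = \langle dO_t,\; A_{e,t} W_{2,e} \rangle
  = \langle dO_t W_{2,e}^{\top},\; A_{e,t} \rangle
  = \langle dA'_{e,t},\; A_{e,t} \rangle,
  \qquad dA' := dO\, W_2^{\top}.
% dS needs only dA' and A (width n each), never the cached Y or dY (width d).
```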
Building blocks (with sandwich explainers):
- MoE granularity and sparsity 🍞 Hook: Picking smaller Lego bricks (experts) gives more detailed shapes but means handling more pieces. 🥬 Concept: Granularity is d/n (how small each expert's hidden size n is relative to the model width d); sparsity is E/K (how many experts exist relative to how many are activated). Higher granularity and sparsity improve quality per FLOP but lower arithmetic intensity and increase IO. 🍞 Anchor: Many small classes (experts) each get few students; good teaching, more hallway traffic.
- Grouped GEMM tiles 🍞 Hook: Muffin pans have fixed cup sizes. 🥬 Concept: GEMM uses fixed tiles (e.g., 128) and pads leftovers; small leftovers cause waste. 🍞 Anchor: If you bake 10 muffins in a 12-cup pan, two cups heat air.
- Gather/Scatter fusion 🍞 Hook: Put ingredients straight onto the pan instead of into a bowl first. 🥬 Concept: Fuse token gathering with GEMM loads; avoid separate gather kernels; prefer gather+sum over scatter on store to keep stores async. 🍞 Anchor: Fewer trips to the pantry during cooking.
- Epilogue fusion 🍞 Hook: Frost muffins while they’re still in the tray. 🥬 Concept: Do activation functions and dS inside GEMM epilogues; fewer extra kernels and less IO. 🍞 Anchor: Finish more work before removing the tray, saving time.
- IO-compute overlap 🍞 Hook: Do laundry while vacuuming. 🥬 Concept: While one tile computes, the next tile’s data transfers; switch roles in a ping-pong rhythm. 🍞 Anchor: Housework ends sooner when tasks overlap.
- Token rounding 🍞 Hook: Seat kids in tables of 8; don’t leave singles at a new table. 🥬 Concept: Round each expert’s token count to nearest tile multiple; at most one tile changes; preserves total tokens on average. 🍞 Anchor: Two full tables beat one full and one half-empty.
03 Methodology
At a high level: Input tokens → Router picks experts → Up-proj (gather+GEMM+act) → Down-proj (GEMM+async store) → Expert aggregation (gather+sum) → Output. Then backward mirrors this with fused steps that avoid caching big tensors, while IO overlaps with compute.
Step-by-step (with reasons and examples):
- Routing (supports standard top-K or SonicMoE Token Rounding)
- What happens: For each token, compute scores over E experts, pick top-K. With token rounding (TR), adjust each expert’s token count to tile-size multiples (e.g., 128) by minimally padding or trimming at most one tile.
- Why this step exists: Experts must know which tokens to process; TR prevents padding waste when experts receive small, uneven counts.
- Example: T=24,576 tokens, E=128 experts, K=8, M_tile=128. If Expert 37 gets 649 tokens (649=5×128+9), we round to 640 (drop 9) or 768 (pad 119), whichever is closer, changing at most one tile’s worth.
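A minimal sketch of the per-expert rounding rule (illustrative only; the actual router also decides which specific tokens get dropped or padded and keeps total work balanced across experts):

```python
def round_to_tile(token_counts, tile=128):
    """Round each expert's token count to the nearest multiple of the tile size,
    changing at most one tile's worth of tokens per expert (reference sketch)."""
    rounded = []
    for c in token_counts:
        down = (c // tile) * tile        # drop the partial leftover tile
        up = down + tile                 # or pad it up to a full tile
        rounded.append(down if c - down <= up - c else up)
    return rounded

# Expert 37 from the example: 649 = 5*128 + 9 -> 640, since dropping 9 tokens
# is cheaper than padding 119; 768 is already a perfect multiple and stays put.
assert round_to_tile([649, 768]) == [640, 768]
```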
- Forward Up-Projection (A kernel): gather + varlen-M Grouped GEMM + activation (SwiGLU) fused in epilogue
- What happens: For each expert, directly gather its tokens from X while loading into shared memory; multiply by W1,e; apply SwiGLU in epilogue; write H and A.
- Why this step exists: Turns d-dimensional embeddings into 2n then n; gather fusion avoids a separate, IO-heavy kernel.
- Example: d=1536, n=256. For Expert e with Te tokens, compute Te×d times d×2n, then act to Te×n.
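A reference for what this fused A kernel computes for a single expert, assuming the usual SwiGLU split of H into gate and up halves (a layout assumption on our part; the real kernel performs the gather, GEMM, and activation in one pass through shared memory):

```python
import torch
import torch.nn.functional as F

def up_proj_expert(x, token_idx, w1_e):
    """Reference for one expert's up-projection: gather -> GEMM -> SwiGLU.
    x: (T, d) all tokens; token_idx: (Te,) tokens routed to this expert; w1_e: (d, 2n)."""
    x_e = x[token_idx]                 # gather (fused into the GEMM prologue in SonicMoE)
    h = x_e @ w1_e                     # (Te, 2n)
    gate, up = h.chunk(2, dim=-1)      # split into the two SwiGLU halves
    a = F.silu(gate) * up              # (Te, n); fused into the GEMM epilogue in SonicMoE
    return h, a                        # H is cached for backward recomputation; A feeds W2
```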
- Forward Down-Projection (Y kernel): varlen-M Grouped GEMM with async store
- What happens: Multiply A by W2,e to get Y (Te×d) and store to memory asynchronously (no scatter fused in epilogue).
- Why this step exists: Projects back to d. Async store keeps tensor cores busy; avoiding scatter here prevents synchronous stalls.
- Example: For each expert, A_e (Te×n) × (n×d) → Y_e (Te×d), store Y_e contiguously.
- Expert Aggregation (O kernel): gather-and-sum per token
- What happens: For each token, gather its experts’ Y slices, multiply by routing weights, and sum to produce final output O.
- Why this step exists: Equivalent to experts scattering—but gathering lets us keep stores async earlier and batch sums efficiently.
- Example: Token t activates experts {e1,...,eK}, compute O_t = Σ S_{t,ei}·Y_{ei,t}.
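A reference for the O kernel's gather-and-sum, assuming each token's K expert outputs have already been gathered into a (T, K, d) layout (the name `y_per_token` is illustrative):

```python
import torch

def aggregate_outputs(y_per_token, topk_scores):
    """O_t = sum_i S[t, i] * Y[t, i]: weighted sum over each token's K activated experts.
    y_per_token: (T, K, d) gathered expert outputs; topk_scores: (T, K) routing weights."""
    return (topk_scores.unsqueeze(-1) * y_per_token).sum(dim=1)   # (T, d)
```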
- Backward: Down-Proj Activation Gradient (dH kernel): gather dO + varlen-M GEMM + fused dSwiGLU + compute dS = <dA', A>
- What happens: Gather dO for each expert; compute dA' = dO·W2^T; scale by routing weights to get dA; recompute A from the cached H (activation recomputation); compute dS via the inner product <dA', A>; output dH (via fused dSwiGLU) and A' = broadcast(S)·A for the weight gradients.
- Why this step exists: Critical trick—computing dS via <dA', A> means we never need cached Y or dY, slashing activation memory and IO.
- Example: Instead of loading Y (Te×d), we use already-in-register dA' (Te×n) and A (Te×n), reducing memory and doing reductions over n (smaller) not d.
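The identity behind this trick can be checked numerically against autograd for one expert's token group; this is a self-contained sanity check, not SonicMoE code:

```python
import torch

torch.manual_seed(0)
Te, n, d = 16, 8, 32
A = torch.randn(Te, n)                    # post-activation (recomputed from cached H)
W2 = torch.randn(n, d)
S = torch.randn(Te, requires_grad=True)   # routing weights for these (token, expert) pairs
dO = torch.randn(Te, d)                   # upstream gradient of the combined output

# Forward contribution of this expert: O_contrib[t] = S[t] * (A[t] @ W2)
O_contrib = S.unsqueeze(-1) * (A @ W2)
O_contrib.backward(dO)                    # autograd's dS lands in S.grad

# SonicMoE-style dS: row-wise inner product <dA', A> with dA' = dO @ W2^T
dA_prime = dO @ W2.T                      # (Te, n), needed anyway for dA
dS_manual = (dA_prime * A).sum(dim=-1)    # never touches Y and never caches dY

assert torch.allclose(S.grad, dS_manual, atol=1e-4)
```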
- Backward: Down-Proj Weight Gradient (dW2 kernel): varlen-K Grouped GEMM with gather
- What happens: Use A' and dO to compute dW2 = A'^T·dO with a grouped GEMM that varies over K dimension; gathers are fused.
- Why this step exists: Updates W2 without separate gather kernels; overlapped IO keeps throughput high.
- Example: Per expert, (n×Te)×(Te×d) → (n×d).
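A reference for the varlen-K pattern, where the GEMM reduction length is each expert's token count; the same loop shape covers dW1 with X and dH in place of A' and dO (a reference only, with the gathers and IO overlap being what the fused kernel adds):

```python
import torch

def grouped_weight_grad(lhs_sorted, rhs_sorted, group_sizes):
    """Per-expert weight gradient, e.g. dW2_e = A'_e^T @ dO_e, with a variable reduction length.
    lhs_sorted: (sum(Te), n); rhs_sorted: (sum(Te), d); returns an (E, n, d) stack."""
    grads, start = [], 0
    for m in group_sizes:                 # the reduction dimension Te differs per expert
        grads.append(lhs_sorted[start:start + m].T @ rhs_sorted[start:start + m])
        start += m
    return torch.stack(grads)             # (E, n, d)
```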
- Backward: Up-Proj Activation Gradient (dX~ kernel): varlen-M Grouped GEMM
- What happens: Compute dX~ per expert from dH and W1^T.
- Why this step exists: Gets each expert’s contribution to input gradients.
- Example: dH_e (Te×2n) × (2n×d) → dX~_e (Te×d).
- Backward: Up-Proj Weight Gradient (dW1 kernel): varlen-K Grouped GEMM with gather
- What happens: Gather X per expert; compute dW1 = X_e^T · dH_e.
- Why this step exists: Updates W1 with fused gather and overlapped IO.
- Example: (d×Te)×(Te×2n) → (d×2n).
- Backward: Expert Aggregation (dX kernel): gather-and-sum per token
- What happens: For each token, sum its experts’ dX~ parts to form final dX.
- Why this step exists: Mirror of forward aggregation to route gradients back properly.
- Example: dX_t = Σ_e∈topK dX~_{e,t}.
Secret sauce (what’s truly clever):
- Minimal activation caching: Only X, H, and tiny routing metadata persist. Y and dY aren’t cached; dS comes from <dA', A>. Result: activation memory doesn’t grow with expert granularity.
- IO-savvy fusion: Gather fused into GEMM prologues; activation and dS fused into epilogues; async stores prevent stalls.
- Overlap everywhere: Ping-pong scheduling and async TMA keep tensor cores busy even with heavy epilogues.
- Tile-aware routing: TR aligns workloads to hardware tile sizes, removing padding waste while changing at most one tile per expert.
Concrete walk-through with numbers:
- Setup: T=24,576, d=1536, n=256, E=128, K=8, M_tile=128 (7B-like).
- Routing: Each token activates 8 of 128 experts. Suppose Expert 12 gets 768 tokens (perfect multiple); Expert 37 gets 649; TR rounds Expert 37 to 640 (drop 9) because it’s closer than 768.
- Forward: Up-proj fuses gather→GEMM→SwiGLU, writes A. Down-proj writes Y with async store, no scatter. O kernel gathers K experts/tok and sums.
- Backward: dH kernel gathers dO, does dA' = dO·W2^T, computes dS = <dA', A> in epilogue, and outputs dH and A'. Then dW2 and dW1 via grouped GEMM with fused gathers; dX~ and final dX via aggregation.
- Outcome: Peak activations remain small; IO is overlapped; padding FLOPs largely eliminated; throughput rises.
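A back-of-the-envelope estimate of why skipping the Y/dY cache matters for this setup, assuming bf16 (2 bytes/element); this is an illustrative calculation, not the paper's exact accounting, which also counts routing metadata and framework overheads:

```python
T, d, n, K = 24_576, 1536, 256, 8
bytes_per_elem = 2                         # bf16

X = T * d * bytes_per_elem                 # cached input tokens
H = T * K * 2 * n * bytes_per_elem         # cached up-projection output per activated expert
Y = T * K * d * bytes_per_elem             # what a naive backward would also cache

GiB = 1024 ** 3
print(f"X: {X / GiB:.2f} GiB, H: {H / GiB:.2f} GiB, Y (avoided): {Y / GiB:.2f} GiB")
# ~0.07 GiB + ~0.19 GiB kept vs. ~0.56 GiB avoided per layer: the skipped Y cache alone
# outweighs X and H combined, and it scales with d rather than the small expert width n.
```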
04 Experiments & Results
The test: Measure two things—(1) peak activation memory per layer and (2) compute throughput (TFLOPS) for forward/backward across model sizes and sparsity levels. Also report end-to-end tokens/day training throughput on clusters.
The competition: SonicMoE vs. ScatterMoE, MoMoE, MegaBlocks, Megatron, and strong grouped-GEMM baselines built on DeepGEMM (“DeepGEMM++” where needed). cuBLAS dense BMM is shown as an upper-bound reference for forward GEMM math.
The scoreboard (with context):
- Activation memory: For a 7B MoE (d=1536, n=256, E=128, K=8), SonicMoE cuts peak per-layer activation memory by up to 45% vs. ScatterMoE. At 30B and 120B, the advantage grows further (saves multiple GiB/layer vs. some baselines). Crucially, SonicMoE’s memory stays essentially flat as expert granularity increases, while others grow.
- Forward throughput: SonicMoE reaches on average 88% (max 91%, min 86%) of a strong cuBLAS upper bound for one MoE layer’s forward pass on H100, even though cuBLAS doesn’t include router overhead. Against DeepGEMM++ (a strong, fused baseline), SonicMoE still pulls ahead, especially as granularity rises.
- Backward throughput: SonicMoE’s fused dH (computing dS and dSwiGLU in epilogue) and overlapped IO raise throughput dramatically—up to 83% faster than ScatterMoE and 115% faster than MoMoE on the backward pass in 7B settings.
- End-to-end training: On 64 H100s, SonicMoE trains a 7B MoE at 213B tokens/day—roughly matching ScatterMoE’s 225B tokens/day on 96 H100s. In other words, similar throughput with one-third fewer GPUs.
- Token rounding (TR) gains: In highly sparse regimes (e.g., scaling E while keeping K fixed), TR adds up to 16% extra TFLOPS for the core MoE kernels over vanilla top-K routing by eliminating tile-padding waste. On real model configs (e.g., Qwen3-Next-80B-A3B-Thinking), TR delivers ~20% forward and ~8% backward speedups over top-K.
Surprising findings:
- Quality holds with TR: Training with TR and evaluating with standard top-K shows similar or slightly better perplexity/accuracy in sparse regimes. Despite changing at most one tile per expert, TR retains routing faithfulness well enough to preserve downstream quality.
- Heavy epilogues can still be fast: By overlapping IO and using ping-pong scheduling, even epilogues that do a lot (e.g., dS, activation grads) can keep tensor cores highly utilized.
- Gather-then-sum beats scatter-on-store: Avoiding synchronous scatter stores in GEMM’s epilogue yields higher overall throughput than fusing scatter, even if an extra aggregation kernel is needed later.
Interpretation in plain terms:
- 45% less memory is like turning a 10-shelf storage problem into 5–6 shelves—suddenly your training fits without juggling.
- 1.86× faster kernels on a flagship GPU family mean large cost and time savings.
- Matching a bigger cluster’s throughput with fewer GPUs is like winning a relay with fewer runners—you coordinated better.
- TR’s +16% in sparse settings is like removing the last bits of slop from your workflow; the cleaner the batches, the faster the work.
05 Discussion & Limitations
Limitations (be specific):
- Hardware focus: The kernels use Hopper/Blackwell features (e.g., cp.async, TMA, ping-pong scheduling, TMEM/UMMA). Other GPUs or accelerators may lack these, reducing gains until analogous features are used.
- Engineering complexity: The fused epilogues, relay warps across CTAs, and persistent schedulers are tricky to implement and tune; portability and maintenance take effort.
- Tile-size dependency: TR depends on tile multiples (often 128). If future hardware changes optimal tile sizes, TR must adapt; very small microbatches, where the average token count per expert is near or below the tile size, are more sensitive.
- Distributed overlap not fully covered: SonicMoE focuses on single-node kernel efficiency; overlapping all2all/expert-parallel communication with compute remains an opportunity.
- Inference routing: TR is a training-time router; for autoregressive decoding, standard token-choice top-K is used at eval. While results show good agreement, it’s still a train/eval mismatch to monitor.
Required resources:
- Modern NVIDIA GPUs (H100/Blackwell) to see full benefits; sufficient memory bandwidth; a training stack that can integrate custom kernels (PyTorch interface provided).
- Typical large-scale pretraining infra (FSDP/ZeRO, data pipelines, logging/profiling).
When NOT to use:
- Tiny models or dense MLPs where grouped GEMM and routing overheads dominate potential gains.
- Extremely small batches where average tokens per expert is at or below tile size; TR’s benefits shrink and routing noise can rise.
- Non-NVIDIA or older GPUs without async IO features; simpler kernels may be preferable.
Open questions:
- Can communication (expert-parallel all2all) be overlapped with GEMM as effectively as IO is overlapped here?
- How far can low precision (FP8/MXFP formats) push memory savings without numerical issues in fused epilogues?
- Can we design train-and-infer-consistent routing that is tile-aware yet decoding-friendly?
- How to autotune tile sizes and fusion choices across diverse model shapes for near-peak performance out of the box?
06 Conclusion & Future Work
Three-sentence summary: SonicMoE rethinks MoE kernels to minimize activation caches, fuse gathers/epilogues smartly, and overlap memory IO with compute, so fine-grained, sparse MoEs run fast instead of stalling on memory. A tile-aware token rounding router further removes padding waste, adding up to 16% speed in sparse regimes without hurting model quality. Together, these changes deliver up to 45% lower activation memory and up to 1.86× kernel throughput gains on Hopper, translating to fewer GPUs for the same training throughput.
Main achievement: A practical, open-source, hardware-aware redesign that keeps MoE training in the high-utilization lane—activation memory flat vs. granularity, heavy epilogues still fast, and padding waste gone.
Future directions: Extend to low-precision/microscaling formats (FP8/MXFP8/MXFP4), co-design overlap with distributed expert-parallel communication, and evolve tile-aware routing into decoding-compatible strategies. Autotuners that pick optimal fusion/scheduling per shape and hardware could further simplify adoption.
Why remember this: As MoEs grow more granular and sparse for better quality per FLOP, naïve kernels bog down. SonicMoE shows that with careful math choices and IO-smart GPU kernels, we can keep the speed and the savings—making large models cheaper and faster to train without sacrificing results.
Practical Applications
- •Train fine-grained, sparse MoE LLMs on fewer GPUs while keeping the same throughput.
- •Reduce out-of-memory errors by minimizing activation caches in backward passes.
- •Speed up pretraining pipelines by overlapping IO and compute on Hopper/Blackwell GPUs.
- •Use token rounding to eliminate padding waste when scaling the number of experts.
- •Retrofit existing MoE stacks (PyTorch) with SonicMoE kernels to boost TFLOPS.
- •Lower cloud costs for large-scale MoE experiments by improving tokens/day per GPU.
- •Enable bigger MoE models per node (same memory budget) to explore better scaling laws.
- •Benchmark and tune tile sizes and routing strategies for maximum throughput on given hardware.
- •Improve multi-tenant training efficiency by keeping GPU utilization high even with variable workloads.