Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers
Key Summary
- The paper shows that large language models often get stuck with weight-matrix scales set by training hyperparameters instead of by the data, which quietly hurts performance.
- This happens because random gradient noise pushes weight matrices to grow while weight decay pulls them back, creating a fixed balance point for matrix scales.
- The authors add learnable multipliers (little dials) that sit on top of matrix layers and freely learn the best scale from data.
- They use both scalar multipliers (one dial per layer) and vector multipliers (a dial per row and per column) to unlock scale at many levels.
- These multipliers learn stable scales that matrix weights alone could not, leading to richer features across layers and within layers.
- The method works with different optimizers (Adam and Muon) and architectures (attention and SSM), and adds no inference cost because the multipliers can be merged into the weights.
- Care is needed to avoid symmetry-related drift and to handle gradient clipping, but a tiny weight decay on the multipliers and excluding them from the global gradient norm restore stability.
- In long training runs, learnable multipliers improve average benchmark scores by about 1.2 percentage points over Adam and about 1.1 points over Muon, with especially strong gains on reasoning tasks.
- They also reduce the need for expensive μP forward and weight-decay multiplier tuning, since the multipliers learn the right scales by themselves.
Why This Research Matters
This work frees a hidden bottleneck in how LLMs learn: matrix scales were set by training settings, not by the data. Learnable multipliers give each layer (and even each feature) the power to set its own loudness cleanly, which leads to richer internal representations. The method is practical: it adds almost no training overhead, merges away at inference for zero runtime cost, and works with both Adam and Muon. It also reduces expensive, brittle hyperparameter sweeps for forward and weight-decay multipliers in μP workflows. Most importantly, it translates into better real-world performance, especially on reasoning tasks, which people notice when they ask models to plan, explain, or solve problems.
Detailed Explanation
01 Background & Problem Definition
🍞 Hook: You know how a microphone can squeal if its volume knob is set wrong, even if the singer is great? The sound equipment’s settings matter as much as the song.
🥬 The Concept (Weight Decay): Weight decay is like a gentle hand that keeps model numbers from getting too big during training.
- How it works:
- Every update, nudge weights slightly toward zero.
- Balance that nudge with learning signals from data.
- Keep growth under control so training stays stable.
- Why it matters: Without it, some weights can blow up, breaking training. 🍞 Anchor: It’s like slowly letting air out of an overinflated balloon so it doesn’t pop.
🍞 Hook: Imagine learning to ride a bike on a bumpy road. Even if you steer correctly, bumps make you wobble.
🥬 The Concept (Stochastic Gradient Noise): Gradient noise is the random wobble in learning updates caused by training on mini-batches.
- How it works:
- Each mini-batch is a random sample of data.
- Its gradient points slightly differently each time.
- These jitters add up like tiny random pushes.
- Why it matters: Too much wobble can grow weights in accidental directions. 🍞 Anchor: Like waves pushing a boat around while you try to paddle straight.
🍞 Hook: Think of a tug-of-war between a stretchy rubber band and a spring.
🥬 The Concept (Noise–Weight Decay Equilibrium): Matrix weights settle at a balance point where random noise pushes out and weight decay pulls back.
- How it works:
- Noise tends to expand matrices.
- Weight decay shrinks them.
- Over time, a steady typical size emerges, set by training hyperparameters (learning rate and decay), not by the data.
- Why it matters: If this balance point isn’t the best for the task, the model is stuck with a so-so volume knob. 🍞 Anchor: The model ends up singing too softly or too loudly because the room (training settings) picked the volume—not the song (data).
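If it helps to see the balance point in code, here is a toy simulation (plain PyTorch, simplified to SGD-style noise updates rather than the paper's Adam setup; all constants are made up): a matrix driven only by random noise plus decoupled weight decay settles at a size fixed by the learning rate and decay, with no data involved.

```python
import torch

# Toy illustration: pure Gaussian "gradient noise" pushes the matrix out,
# decoupled weight decay pulls it back. Its RMS settles at a level set by
# lr and wd alone: the data never appears in this loop.
torch.manual_seed(0)
lr, wd, sigma = 1e-2, 0.1, 1.0
W = torch.zeros(128, 128)
for step in range(10_000):
    noise_grad = sigma * torch.randn_like(W)   # stand-in for minibatch noise
    W = (1 - lr * wd) * W - lr * noise_grad    # decoupled (AdamW-style) decay
    if step % 2_000 == 0:
        print(f"step {step:5d}  RMS {W.pow(2).mean().sqrt().item():.3f}")

# The plateau sits near sigma * sqrt(lr / (2 * wd)), independent of any signal.
print("predicted equilibrium RMS:", sigma * (lr / (2 * wd)) ** 0.5)
```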
🍞 Hook: Picture a ruler that rescales numbers so they’re easier to compare.
🥬 The Concept (RMSNorm): RMSNorm keeps activations at a steady overall size using a learned per-feature scale.
- How it works:
- Measure the root-mean-square (RMS) of the vector.
- Divide by that RMS to normalize.
- Multiply by learned per-feature weights to allow helpful rescaling.
- Why it matters: It stabilizes deep networks and already gives some features their own dials. 🍞 Anchor: It’s like leveling flour in a measuring cup before adding it to dough.
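For reference, a minimal RMSNorm sketch (a generic PyTorch implementation, not taken from the paper's code; the class name and eps value are illustrative) shows the learned per-feature dials the text mentions.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm: divide by the root-mean-square of the features,
    then apply a learned per-feature scale (the 'dials' mentioned above)."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-feature scale
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight
```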
🍞 Hook: Imagine rules for how big to pour ingredients when you double a recipe.
🥬 The Concept (Maximal Update Parametrization, μP): μP gives scaling rules so models of different widths behave similarly and hyperparameters transfer.
- How it works:
- Decide how to scale learning rates and in-layer multipliers with width.
- Keep activation sizes and relative update sizes stable as width grows.
- Tune small models, then transfer to big ones.
- Why it matters: Saves compute and keeps training predictable when changing model size. 🍞 Anchor: Like using a “times two” recipe card so taste stays the same in a bigger cake.
🍞 Hook: Different GPS apps can guide you to the same place with different shortcuts.
🥬 The Concept (AdamW and Muon Optimizers): AdamW and Muon are common optimizers that both use explicit weight decay to keep training stable.
- How it works:
- Estimate gradients with memory (AdamW) or specialized rules (Muon).
- Apply updates to move weights toward lower loss.
- Use weight decay to control size.
- Why it matters: Even with different update rules, both depend on decay to stay stable—and both inherit the same equilibrium trap for matrices. 🍞 Anchor: Two different driving routes, but both rely on speed limits (decay) to avoid crashes.
The world before: Big language models (LLMs) relied on weight decay to avoid unstable training. Researchers also used μP so that results and hyperparameters could transfer across sizes. RMSNorm’s per-feature weights already gave some flexible scaling. But a quiet problem remained: matrix layers (the big weight tables inside attention, MLPs, and SSMs) didn’t get to pick their own scale. Instead, their size was mostly set by a balance between random gradient noise and weight decay. The result was a one-size-fits-all volume knob, chosen by hyperparameters rather than data.
The problem: Many tasks need different feature scales across layers and even across individual features. If the matrix weights can’t learn those scales, the model can’t fully express the patterns in data. Performance plateaus a bit below what’s possible, especially for tasks needing careful coordination across layers, like reasoning.
Failed attempts: People already tried good initializations, different optimizers, more normalization, and μP tuning of forward multipliers. These helped but still pinned matrix scales to the noise–decay balance. You could hand-tune many multipliers (as in μP workflows), but that’s compute-heavy, brittle across datasets, and still not truly data-adaptive during training.
The gap: Let the model learn the volume knob for matrices directly. Scalars can learn freely (we rarely decay scalar or vector weights like RMSNorm), and their gradients are naturally less noisy because they average over many elements. So, what if we attach learnable dials—scalars, or one per row and per column—on top of matrix layers?
Real stakes: In real life, this means better models without changing inference speed or memory. It reduces costly hyperparameter hunts, works across optimizers, and seems to help reasoning benchmarks the most—useful for coding help, tutoring, planning, and more.
02 Core Idea
🍞 Hook: Imagine a giant mixing board with lots of sliders. Until now, many of those sliders were locked by the room’s wiring (training settings), not by the music (data).
🥬 The Concept (Learnable Multipliers): Learnable multipliers are tiny trainable dials that sit on top of matrix layers to set their overall scale—and even per-row and per-column scales—based on the data.
- How it works:
- Wrap each matrix W with multipliers: a single scalar s (one dial), or a vector per row r and per column c (many small dials).
- Train these multipliers along with the weights.
- Because each multiplier pools many gradients, it doesn’t suffer noisy Brownian growth like matrices do.
- The matrix stays stable with decay; the multipliers set the final effective scale.
- Why it matters: This breaks the noise–decay trap so features can find their best loudness, depth by depth and feature by feature. 🍞 Anchor: It’s like putting easy-to-turn knobs on top of a fixed amp; the band can finally set the volume for each instrument during the show.
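As a first concrete picture, here is a minimal sketch of the scalar case (assuming PyTorch; the class name ScaledLinear and the shapes are illustrative, not the authors' code): the matrix stays under ordinary weight decay, while one trainable scalar sets the effective scale.

```python
import torch
import torch.nn as nn

class ScaledLinear(nn.Module):
    """Hypothetical sketch: a linear layer with one learnable scale dial."""
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.linear = nn.Linear(d_in, d_out, bias=False)  # matrix: kept under decay
        self.s = nn.Parameter(torch.ones(()))             # scalar multiplier (the dial)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Effective weight is s * W: the data can move s freely even while
        # weight decay pins the typical size of W itself.
        return self.s * self.linear(x)
```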
- The “Aha!” moment in one sentence: If matrix sizes are stuck at a noise–decay balance, give them clean, learnable volume knobs (multipliers) that aren’t stuck—so the data picks the right scale.
- Multiple analogies:
- Stereo analogy: Your receiver (matrix) is locked to a factory volume. Add a volume knob (multiplier) that you can actually turn during the song.
- Photography analogy: The camera’s ISO (matrix norm) is auto-locked by lighting (noise vs decay). Add an exposure slider (multiplier) so the photo isn’t too dark or too bright.
- Classroom analogy: Every student (feature) had the same speaking volume set by school rules. Give each student a microphone with its own gain knob so they can be heard appropriately.
- Before vs After:
- Before: Matrix scales are set mainly by learning rate and weight decay, not the data. Layer outputs across depth and width can’t diversify as much as they should.
- After: Multipliers learn per-layer and per-feature scales, creating richer internal representations and better loss. Tuning burdens shrink because forward and decay multipliers no longer need exhaustive μP sweeps.
- Why it works (the intuition, no equations):
- Matrix elements update with lots of random noise; decay fights that growth and pins their size.
- A scalar multiplier collects signals from an entire matrix; a row/column multiplier collects from a whole row/column. This averaging makes its gradient less noisy and more signal-driven.
- So the matrix can stay safely regularized, while the multipliers learn the right scale guided by data.
- Building blocks:
- 🍞 Hook: Picture a big master volume dial. 🥬 The Concept (Scalar Multiplier): One dial s for an entire matrix layer’s scale.
- How it works: Multiply the matrix by s; train s like any parameter; merge it back for inference.
- Why it matters: Quickly corrects layer-level loudness. 🍞 Anchor: Turning up the whole song’s volume.
- 🍞 Hook: Now imagine separate dials for each instrument section. 🥬 The Concept (Vector Multipliers): Per-row r (output features) and per-column c (input features) dials.
- How it works: Scale each row and column, so different outputs/inputs get their own loudness.
- Why it matters: Unlocks within-layer diversity so some features can whisper and others shout. 🍞 Anchor: Violins softer, drums louder—same orchestra, better balance.
- 🍞 Hook: Have you ever tugged on one rope and another loosened? 🥬 The Concept (Multiplicative and Normalization Symmetries): Some parts can grow while others shrink with no change in output, leading to drift.
- How it works: If two factors only matter as a product, one can rise while the other falls; if outputs are normalized, residuals can grow unbounded.
- Why it matters: Causes instability and NaNs in low-precision training unless lightly controlled. 🍞 Anchor: Two gears spinning in opposite ways but the machine behaves the same—still, the gears can fly off if not bounded.
- 🍞 Hook: Think of trimming only the right branches on a tree. 🥬 The Concept (Light WD on Multipliers): A tiny decay on multipliers curbs symmetry drift without re-trapping scales.
- How it works: Just enough decay to stop runaway growth from symmetries.
- Why it matters: Keeps training stable while preserving freedom to learn scale. 🍞 Anchor: A gentle fence keeps the dog in the yard without putting it on a leash.
- 🍞 Hook: Some models are wider, like choirs with more singers. 🥬 The Concept (Width Scaling with LRMs): As width grows, the learned multipliers (LRMs) naturally adjust so key activations stay in a healthy range.
- How it works: With fixed LR and WD, matrix norms stay similar across widths; multipliers adapt to keep outputs well-scaled.
- Why it matters: Reduces the need to handcraft width scaling rules for these layers. 🍞 Anchor: Adding more choir members but the conductor keeps the overall loudness just right.
03 Methodology
At a high level: Input tokens → Standard LLM layers (attention, MLP, SSM) reparameterized with multipliers → Train with careful stability tweaks → Merge multipliers into matrices for inference → Output predictions.
🍞 Hook: Imagine slipping adjustable lenses in front of every big camera lens.
🥬 The Concept (Reparameterization with Multipliers): Replace each effective matrix by multipliers times a learnable base matrix.
- How it works:
- Pick places to add multipliers: scalar s for whole-layer scale; vectors r (rows) and c (columns) for fine-grained scale.
- Compute outputs with the reparameterized weights (e.g., y = diag(r)·W·diag(c)·x, so r scales the rows of W and c scales its columns).
- Train W and the multipliers together.
- For inference, fold multipliers into W to avoid extra cost.
- Why it matters: Keeps runtime fast while letting training discover the right scales. 🍞 Anchor: You use zoom lenses while composing the photo, but you lock in the final image without the extra hardware.
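A minimal sketch of the row/column version (again a hypothetical PyTorch illustration, not the paper's code; the class name and initialization are assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiplierLinear(nn.Module):
    """Illustrative reparameterization: effective weight = diag(r) @ W @ diag(c)."""
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.W = nn.Parameter(torch.randn(d_out, d_in) / d_in ** 0.5)
        self.r = nn.Parameter(torch.ones(d_out))  # per-row (output-feature) dial
        self.c = nn.Parameter(torch.ones(d_in))   # per-column (input-feature) dial

    def effective_weight(self) -> torch.Tensor:
        # Scale each row of W by r and each column by c.
        return self.r.unsqueeze(1) * self.W * self.c.unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.effective_weight())
```

During training you would use effective_weight() (or, equivalently, scale the input by c and the output by r); at deployment time the dials get merged into W, as in the last recipe step below.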
Recipe steps:
- Choose multiplier placement per block
- Attention: Put row/column multipliers around out projections; prefer queries over keys for per-head control; avoid redundant Q/K multipliers that only appear as a product.
- MLP: Multipliers around up/gate/down projections; avoid redundancy with RMSNorm’s column weights.
- SSM (Mamba2): Put multipliers where outputs go through nonlinearities or where no symmetry traps appear; leverage existing conv1d and skip-scale D that already act like vector dials.
- Why this step exists: Redundant placements create symmetries that cause drift and instability.
- Example: If both Q and K get row multipliers, their product stays the same while individual norms explode.
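A tiny numerical check of that example (a toy illustration with made-up shapes, not the authors' code): per-feature dials on Q and K enter the attention logits only through their product, so rescaling them in opposite directions changes nothing.

```python
import torch

torch.manual_seed(0)
q = torch.randn(16, 64)                  # [positions, head_dim]
k = torch.randn(16, 64)
r_q = torch.rand(64) + 0.5               # hypothetical per-feature dials on Q
r_k = torch.rand(64) + 0.5               # ... and on K
a = torch.rand(64) * 3 + 0.5             # arbitrary rescaling along the symmetry

logits_1 = (q * r_q) @ (k * r_k).T
logits_2 = (q * (r_q * a)) @ (k * (r_k / a)).T   # grow one dial, shrink the other
print(torch.allclose(logits_1, logits_2, atol=1e-4))  # True: only the product matters
```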
- Stabilize symmetry directions
- Apply a small weight decay just to multipliers (e.g., ~2e-3) to tame multiplicative and normalization symmetries.
- Why this step exists: Low-precision training magnifies drift; tiny decay stops runaways without re-trapping scale.
- Example: With light WD, Q/K multipliers stop see-sawing while their product stays useful.
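One way to realize this (a sketch under the assumption that multiplier parameters can be picked out by name; the 2e-3 value follows the text, everything else is illustrative) is to give the multipliers their own AdamW parameter group with a tiny decay:

```python
import torch

def build_optimizer(model: torch.nn.Module, lr: float = 3e-4):
    # Assumption: multiplier parameters carry "mult" in their registered names
    # (e.g. the s / r / c dials from the sketches above live in such modules).
    mults, others = [], []
    for name, p in model.named_parameters():
        (mults if "mult" in name else others).append(p)
    return torch.optim.AdamW(
        [
            {"params": others, "weight_decay": 0.1},   # usual decay on matrices
            {"params": mults, "weight_decay": 2e-3},   # tiny decay curbs symmetry drift
        ],
        lr=lr,
    )
```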
- Handle gradient clipping smartly
- Exclude multipliers from the global gradient-norm clipping calculation.
- Why this step exists: Early in training, multipliers can have large gradients that, if counted, over-clip everyone else and slow learning.
- Example: Excluding them lowers the clip factor, improving early loss and long-term performance.
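A hedged sketch of that exclusion (reusing the same name-based grouping assumption; the function name is made up):

```python
import torch

def clip_excluding_multipliers(model: torch.nn.Module, max_norm: float = 1.0):
    # Compute the global gradient norm over everything EXCEPT the multiplier
    # dials, so their large early gradients cannot over-clip the rest.
    non_mult = [p for name, p in model.named_parameters() if "mult" not in name]
    torch.nn.utils.clip_grad_norm_(non_mult, max_norm)
    # The multipliers are simply left unclipped here; they could also be
    # clipped separately with their own norm if needed.
```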
- Make width scaling hands-off
- Keep LR and WD fixed across widths in experiments; observe that matrices keep similar norms while multipliers adjust to keep key activations stable.
- Why this step exists: Confirms LRMs learn the needed scale with width, reducing hyperparameter headaches.
- Example: Attention logits and SSM dt signals stay in a healthy range as width grows.
- Train with your favorite optimizer
- Works with AdamW or Muon; both rely on weight decay for matrices and benefit similarly from LRMs.
- Why this step exists: Shows the idea is optimizer-agnostic.
- Example: Gains with Adam and Muon were comparable in both short and long runs.
- Merge for deployment
- After training, multiply multipliers into W once so inference has no extra parameters or latency.
- Why this step exists: Zero runtime overhead is key for serving.
- Example: The serving model looks identical to a normal LLM.
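A minimal sketch of the merge (assuming a layer shaped like the MultiplierLinear sketch above; the function name is illustrative):

```python
import torch

@torch.no_grad()
def fold_multipliers(layer):
    # Expects a layer with attributes W (d_out x d_in), r (d_out), c (d_in),
    # as in the MultiplierLinear sketch. Bake the dials into the weight once,
    # so serving uses a plain Linear with no extra parameters or latency.
    d_out, d_in = layer.W.shape
    fused = torch.nn.Linear(d_in, d_out, bias=False)
    fused.weight.copy_(layer.r.unsqueeze(1) * layer.W * layer.c.unsqueeze(0))
    return fused
```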
Concrete mini-experiments as examples:
- Projector (LM head) test: Sweep the matrix’s balance point by changing LR/WD while holding effective learning speed constant. Without multipliers, loss worsens at extreme scales; with a scalar or vector multiplier, logits stay well-scaled and loss stays good.
- MLP test: Make MLP matrices too big/small by LR/WD tweaks. Without multipliers (and with RMSNorm weights frozen), losses worsen at large scales; adding three scalar multipliers for MLP layers keeps outputs matched to the rest of the network and restores performance. Same pattern for Adam and Muon.
🍞 Hook: Secret ingredients make recipes shine.
🥬 The Concept (The Secret Sauce): Multipliers average gradients over many elements, so they hear the true signal louder than the noise.
- How it works:
- A scalar multiplier’s gradient sums information across the whole matrix.
- Row/column multipliers pool over a row/column.
- This reduces gradient noise, so multipliers aren’t pushed into the noisy equilibrium.
- Why it matters: They remain data-driven, setting just-right scales that matrices alone couldn’t learn under decay. 🍞 Anchor: A choir director listening to the whole section, not just one singer, makes steadier volume decisions.
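A quick autograd check makes the pooling claim concrete (a toy example, not from the paper): for y = s · (W x), the gradient with respect to the scalar s equals the inner product of W with the gradient of the effective weight, so every element of the matrix contributes to it.

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 64, requires_grad=True)
s = torch.ones((), requires_grad=True)      # scalar multiplier dial
x = torch.randn(64)

loss = (s * (W @ x)).pow(2).sum()
loss.backward()

# With W_eff = s * W, autograd gives dL/dW = s * dL/dW_eff, so dL/dW_eff = W.grad / s.
# The scalar's gradient pools the whole matrix: dL/ds = <W, dL/dW_eff>.
with torch.no_grad():
    pooled = (W * (W.grad / s)).sum()
print(torch.allclose(s.grad, pooled))       # True: s hears all 64*64 elements at once
```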
04 Experiments & Results
🍞 Hook: If you want to test a bike’s gears, you ride uphill, downhill, and on flat roads.
🥬 The Concept (What they measured and why): The authors tested whether multipliers actually free feature scales and improve real-world performance.
- How it works:
- Stress tests: Intentionally push matrix norms too high or too low by changing LR/WD.
- Compare configurations: no multipliers, scalar multipliers, vector multipliers.
- Track norms of matrices, multipliers, and activations across depth and width.
- Validate on long pretraining and downstream benchmarks.
- Why it matters: Shows multipliers fix the scale trap and translate into better scores. 🍞 Anchor: Like proving your new shock absorbers help on bumpy roads and also make daily driving smoother.
The competition (baselines):
- μP-tuned multipliers (forward, LR, WD) as a strong baseline.
- Optimizers: AdamW vs Muon.
- Architectures: Hybrid attention + SSM (Mamba2) with gated MLP blocks.
Key tests and findings:
- Projector (LM head) sweep:
- Goal: See if the LM head can keep logits well-scaled when its matrix is forced too big or too small.
- Finding: Without multipliers, loss degrades at extreme scales; with scalar or vector multipliers, logit norms stay just-right and loss stays strong.
- Bonus: A simulated pure-noise (Brownian-motion-style) Adam run matched the matrix norm trajectories but not the multiplier norms, evidence that the multipliers aren’t noise-trapped.
- MLP sweep:
- Goal: Change MLP matrices’ balance point while freezing RMSNorm weights so other blocks can’t auto-fix scale.
- Finding: Without multipliers, loss worsens at large scales (mismatch between blocks). Adding three scalar MLP multipliers brings scales back into harmony and restores loss. Same behavior under Adam and Muon.
- Depth- and width-wise scale diversity:
- With scalar multipliers everywhere, later layers often contribute more (their outputs get larger), suggesting previously underused capacity.
- With vector multipliers, the distribution of per-row norms widens a lot, showing richer within-layer feature scales in attention, SSM, and MLP.
- Symmetry handling:
- Without any decay on multipliers, multiplicative (e.g., Q/K) and normalization symmetries cause drifting norms and potential instability.
- A tiny decay on multipliers stops drift while keeping scale freedom.
- Gradient clipping fix:
- If you include multipliers in the global clip norm, they can trigger heavy clipping early on, shrinking all updates.
- Excluding them from the global clip restores healthy gradients and better long-run loss.
Scoreboard (long-run results, ~200 GT):
- Average benchmark gains vs baselines:
- Adam → Adam+LRM: +1.21 percentage points on average.
- Muon → Muon+LRM: +1.10 points on average.
- Highlights:
- Reasoning-heavy tasks (BBH, GSM8K, MATH lvl5) improved notably, sometimes by multiple points.
- Knowledge benchmarks (ARC-C, MMLU) improved modestly.
- Context: A one-point average gain is like moving from a solid B to a high B+ across many exams; it’s especially meaningful at scale.
Surprising findings:
- Multipliers deliver a similar improvement to switching optimizers (Adam→Muon). Stacking both helps further.
- Using μP’s tuned forward and WD multipliers isn’t necessary when LRMs are used; the learned dials find good scales themselves. But tuned learning rate multipliers still matter.
- Under width scaling with fixed LR/WD, key activations remain stable thanks to learned multipliers, suggesting less need for hand-crafted scaling rules in those parts.
05 Discussion & Limitations
🍞 Hook: Even the best new bikes have quirks you should know before riding downhill fast.
🥬 The Concept (Limitations and caveats): Learnable multipliers are powerful but need care.
- What they can’t do (yet):
- They don’t replace tuned learning rate multipliers; LR tuning still matters.
- They don’t fully solve all symmetries; a tiny WD or careful placement is still needed.
- The best projector (LM head) setup isn’t trivial—row multipliers on logits can encourage lazy training.
- Resources required:
- Training adds tiny overhead; inference has zero overhead after merging multipliers.
- You need a stable training setup (mixed precision, good clipping settings, etc.).
- When not to use:
- If your pipeline can’t adjust gradient clipping or add small WD to multipliers, you might hit instability.
- If your task is extremely sensitive to exact μP scaling rules you already rely on, you may want to run a careful A/B first.
- Open questions:
- Can we fully generalize ÎĽP scaling laws to include learned multipliers (e.g., how LR and WD for multipliers should scale with width)?
- Can we define a measurable signal-to-noise test that predicts which tensors can learn scale without decay?
- Why does alignment between inputs and weights seem to shrink mildly with width even when key activations stay stable—what’s the theory story?
- Which “circuits” in reasoning tasks benefit most from freed scales? 🍞 Anchor: It’s like knowing you still need a helmet, a small tune-up, and a test ride before racing.
06 Conclusion & Future Work
Three-sentence summary: Matrix layers in LLMs often get stuck at a size chosen by the tug-of-war between gradient noise and weight decay, which isn’t always the best for the data. Learnable multipliers—simple trainable dials placed as scalars and as per-row/per-column vectors—free those scales so features can find the right loudness across depth and width. This brings consistent, optimizer-agnostic gains, reduces hand-tuning, and adds no inference cost.
Main achievement: Showing that a lightweight reparameterization (learnable multipliers) reliably breaks the noise–decay equilibrium trap for matrices, leading to richer internal feature scales and better end-task performance.
Future directions: Build full scaling rules that include multipliers; design symmetry-robust placements with theory-backed WD for multipliers; map which reasoning circuits benefit most; and explore projector designs that boost feature learning without lazy shortcuts.
Why remember this: It’s a clean, practical idea—you add dials where the music was stuck, let the data set the volume, and you get better songs without buying a bigger amp. In LLM terms, you free matrix scales, learn richer representations, and improve benchmarks, all while keeping serving costs unchanged.
Practical Applications
- Pretraining LLMs with learnable multipliers to improve loss and downstream scores without changing inference costs.
- Reducing μP tuning workload by skipping forward and weight-decay multiplier sweeps; keep only learning-rate multiplier tuning.
- Improving reasoning-heavy capabilities (e.g., GSM8K, BBH, MATH) by enabling richer scale diversity across features.
- Stabilizing training in mixed precision by adding tiny weight decay to multipliers to curb symmetry drift.
- Avoiding early training slowdowns by excluding multipliers from the global gradient-norm clipping computation.
- Adopting multiplier placements that avoid redundant symmetries (e.g., prefer query multipliers over key in attention).
- Applying the same approach to hybrid architectures (attention + SSM) and to MLP blocks for universal benefit.
- Scaling model width with fewer hyperparameter changes because multipliers naturally keep key activations well-sized.
- Conducting safe ablations of projector (LM head) multipliers to avoid lazy training while still benefiting from scale freedom.
- Merging multipliers into matrices after training to deploy models with no added latency or memory footprint.