Grokking & Delayed Generalization
Key Points
- Grokking is when a model suddenly starts to generalize well long after it has already memorized the training set.
- During grokking, the training loss stays near zero for a long time while the test loss remains high, then abruptly drops.
- This delayed generalization often happens when implicit or explicit regularization gradually favors a simpler, rule-like solution over a memorized one.
- Overparameterized models with weight decay and long training are common settings where grokking is observed.
- You can think of it as a phase transition in learning dynamics driven by a competition between spurious/memorization features and true signal features.
- Monitoring the generalization gap, weight norms, and training time reveals when the shift from memorization to rules occurs.
- Toy setups with per-example “memorization” features versus global signal features reproduce grokking-like curves.
- Early stopping or too-weak regularization can prevent grokking, leaving the model stuck in memorization mode.
Prerequisites
- Train/Test Split and Evaluation — To interpret generalization curves and the gap between training and test performance.
- Linear Models and Logistic Regression — To understand the objective, gradients, and decision boundaries used in the toy demonstrations.
- Gradient Descent and SGD — To follow the optimization dynamics and their long-horizon effects on solutions.
- Regularization (L1/L2, Weight Decay) — To see how explicit penalties bias solutions toward simpler hypotheses over time.
- Vector Norms and Inner Products — To reason about simplicity (low-norm solutions) and compute gradients efficiently.
- Overparameterization and Capacity — To understand how memorization becomes possible and why regularization is needed.
- Probability Basics — To interpret expected risk, randomness in data generation, and variability across runs.
Detailed Explanation
01 Overview
Imagine training a model that quickly gets every training example correct, yet keeps failing on the test set for a very long time. Then, surprisingly, after many more epochs of training, its test accuracy suddenly jumps to a high value. This phenomenon is called grokking. It was first observed in small algorithmic tasks (like modular arithmetic) where models memorized the dataset early on and only later discovered the underlying general rule.
Grokking shows up most clearly in overparameterized models—networks with enough capacity to memorize. The model has two competing ways to reduce training loss: (1) memorize each example using many parameters, or (2) discover a compact rule that works for all inputs. Both drive the training loss down, but only the second option generalizes. With weight decay or other regularizers, the learning dynamics slowly push the model toward simpler solutions over time.
Empirically, learning curves display a long plateau of low training loss and poor test performance, followed by an abrupt improvement in test metrics—like a phase transition. Theoretical perspectives link this to implicit regularization via optimization (e.g., gradient descent biasing toward low-norm solutions), minimum description length (MDL), and capacity control. Practically, you can reproduce grokking on toy problems by combining high capacity, strong weight decay, and long training. Understanding grokking helps practitioners design training protocols and interpret surprising learning dynamics in modern deep learning.
02 Intuition & Analogies
Hook: Have you ever crammed for an exam by rote-memorizing answers, only to later, after weeks of casually thinking about the material, suddenly realize you actually understand the concept? That "aha" moment mirrors grokking in machine learning.
Concept: Early in training, an overparameterized model can memorize the training set—like a student memorizing solutions—because it has enough parameters to store answers. This produces near-perfect training accuracy but poor test accuracy since the memorized specifics do not transfer. Over time, however, the combination of optimization and regularization (like weight decay) slowly penalizes this brittle lookup approach. Meanwhile, a simpler rule that explains many data points consistently earns small but steady improvements. The competition resembles two teams in a long game: the memorization team scores first and easily, but accumulates penalties, while the rule-learning team advances slowly and steadily. Eventually, the second team overtakes the first, and test performance jumps.
Example: Consider classifying points with two kinds of features. The first kind are unique IDs for each training point—perfect for memorization but useless on new data. The second kind are meaningful coordinates that reflect the true boundary. A linear model with weight decay can initially push the ID-related weights high (near-zero training loss, poor test loss). Weight decay gradually suppresses that bloated solution because many large weights are costly. At the same time, the coordinate-based weights keep getting reinforced across examples. After enough steps, the coordinate-based solution dominates; generalization improves sharply. That sudden leap is grokking.
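A rough norm calculation makes this competition concrete. It is a sketch only: c and c' stand for illustrative per-coordinate weight magnitudes, not quantities from any specific experiment.

```latex
% Memorizing n_train labels via one-hot ID features needs a weight of
% size roughly c on each of n_train coordinates; the coordinate-based
% rule needs only d_sig weights of size roughly c':
\|w_{\mathrm{mem}}\|_2^2 \approx n_{\mathrm{train}}\, c^2
\qquad \text{vs.} \qquad
\|w_{\mathrm{sig}}\|_2^2 \approx d_{\mathrm{sig}}\, c'^2 .
% With d_sig << n_train, the penalty (lambda/2) * ||w||^2 charges the
% memorized solution roughly n_train / d_sig times more, so sustained
% weight decay eventually favors the coordinate-based rule.
```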
03 Formal Definition
Let L_train(t) denote the training loss and Acc_test(t) the test accuracy after t steps of training. Grokking occurs when there is a long interval during which L_train(t) is near zero while Acc_test(t) remains near chance, followed by an abrupt transition: at some much later time, Acc_test(t) rises above a chosen threshold τ. Writing t_grok = min{ t : Acc_test(t) ≥ τ and L_train(t) ≤ ε }, we say a run exhibits grokking when t_grok greatly exceeds the first time at which the training loss dropped below ε.
04 When to Use
Grokking is a phenomenon to watch for (and sometimes leverage), not a training recipe to apply everywhere. You should be aware of it when:
- Training highly overparameterized models on small or algorithmic datasets, where memorization is easy and true structure is compact (e.g., modular arithmetic, string transformations, synthetic tasks).
- Using strong explicit regularization (e.g., weight decay) or implicit regularization (e.g., specific optimizers or architectures) that prefers simpler hypotheses but may need long training time to overcome early memorization.
- Investigating surprising learning curves (long plateaus, sudden leaps in test accuracy) and trying to attribute them to optimization dynamics rather than data leakage or bugs.
- Designing experiments to probe inductive biases (e.g., observing whether a model eventually prefers low-norm or MDL-optimal solutions under sustained training).
Practical use cases include: validating that a model can, in principle, learn an underlying rule if trained long enough; studying how hyperparameters (weight decay, batch size, learning rate schedules) affect the timing of generalization; and constructing toy benchmarks that separate memorization from understanding. On the flip side, if you strictly need reliable early generalization (e.g., production training budgets), you may want to avoid conditions that produce grokking or use early stopping and stronger data augmentation to prevent it.
⚠️Common Mistakes
- Confusing grokking with ordinary overfitting: In overfitting, test performance usually degrades as training continues. In grokking, test performance can be poor for a long time and then suddenly improve. Always plot both train and test curves over a long horizon.
- Stopping too early: Early stopping can freeze the model in the memorization regime, missing the later transition. If you suspect grokking, extend training and monitor generalization metrics.
- Insufficient regularization: Without an explicit or implicit simplicity bias (e.g., weight decay), the memorization solution may persist indefinitely. Tune regularization strength; too weak prevents the shift, too strong can underfit.
- Misattributing randomness: Different seeds can change when (or if) grokking appears. Run multiple seeds and report variability.
- Ignoring feature design in toy demos: To reproducibly demonstrate grokking, construct data where memorization features exist only in training, while signal features persist in test. Otherwise, the effect may be weak or invisible.
- Overinterpreting single runs: A sudden test jump can happen by chance. Use confidence intervals, multiple runs, and diagnostics like weight norms or sparsity to support a grokking claim.
Key Formulas
Generalization Gap
G(t) = L_test(t) − L_train(t)
Explanation: This measures how much worse the model performs on test data than on training data at time t. In grokking, G(t) stays large for a long time and then drops sharply.
L2-Regularized Objective
J(w) = (1/n) Σ_{i=1}^{n} ℓ(f(x_i; w), y_i) + (λ/2) ‖w‖²
Explanation: The total objective equals the average loss plus an L2 penalty on the weights. Increasing λ pushes the solution toward smaller norms, often encouraging simpler rules.
Weight Decay Update (Decoupled)
w_{t+1} = (1 − ηλ) w_t − η ∇L_B(w_t)
Explanation: With learning rate η and minibatch B, the update first shrinks the weights by the factor (1 − ηλ) and then applies a gradient step. Over time this penalizes large memorization weights.
Logistic Loss
ℓ(y, z) = log(1 + exp(−yz))
Explanation: Common binary classification loss with labels y ∈ {−1, +1}. Its gradient pushes predictions to align with labels and is convenient for demonstrating training dynamics.
Max-Margin Implicit Bias
w_t / ‖w_t‖ → w* / ‖w*‖, where w* = argmin ‖w‖ subject to y_i (w · x_i) ≥ 1 for all i
Explanation: For separable data, gradient descent on certain losses (e.g., logistic) converges in direction to the maximum-margin classifier. This links optimization dynamics to simpler, low-norm solutions.
Grokking Time
t_grok = min{ t : Acc_test(t) ≥ τ and L_train(t) ≤ ε }
Explanation: Defines the first time t at which test accuracy crosses a chosen threshold τ while training loss is already near zero (at most ε). It formalizes the delayed jump in performance.
PAC-Bayes (Sketch)
E_{w∼Q}[L(w)] ≤ E_{w∼Q}[L̂(w)] + √( (KL(Q‖P) + ln(n/δ)) / (2(n − 1)) )
Explanation: A representative PAC-Bayes bound relates generalization to a complexity term (the KL divergence between posterior Q and prior P). Lower complexity (simpler solutions) tightens the bound, connecting regularization to generalization.
Complexity Analysis
For the toy demonstrations below, each SGD epoch touches every training example and every feature, costing O(n_train · d) time with total dimension d = d_sig + n_train; training for E epochs therefore costs O(E · n_train · (d_sig + n_train)) time and O(n_train · d) memory for the design matrix. The long horizons needed to observe grokking make total runtime, not per-step cost, the dominant expense.
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

// Logistic regression on synthetic data that mixes
// - signal features (generalize to test)
// - per-example one-hot memorization features (only exist in train)
// This setup often shows delayed generalization ("grokking").

struct Dataset {
    vector<vector<double>> X; // features
    vector<int> y;            // labels in {-1, +1}
};

// Generate synthetic data
// n_train: number of training samples
// n_test: number of test samples
// d_sig: number of global signal features
// Returns train and test datasets with total dimension d = d_sig + n_train
pair<Dataset, Dataset> make_dataset(int n_train, int n_test, int d_sig, unsigned seed = 42) {
    mt19937 rng(seed);
    normal_distribution<double> gauss(0.0, 1.0);

    // True signal weights (unit-normalized for stability)
    vector<double> w_sig(d_sig);
    for (int j = 0; j < d_sig; ++j) w_sig[j] = gauss(rng);
    double norm = 0.0; for (double v : w_sig) norm += v * v; norm = sqrt(max(1e-12, norm));
    for (double &v : w_sig) v /= norm;

    int d = d_sig + n_train; // total features: signal + memorization one-hots

    auto gen_xy = [&](int n_samples, bool is_train) {
        Dataset D; D.X.resize(n_samples, vector<double>(d, 0.0)); D.y.resize(n_samples);
        for (int i = 0; i < n_samples; ++i) {
            // Signal features
            vector<double> xs(d_sig);
            for (int j = 0; j < d_sig; ++j) xs[j] = gauss(rng);
            // Label is the sign of the signal projection (noise-free)
            double z = 0.0; for (int j = 0; j < d_sig; ++j) z += w_sig[j] * xs[j];
            int y = (z >= 0.0) ? +1 : -1;

            // Write features into full vector
            for (int j = 0; j < d_sig; ++j) D.X[i][j] = xs[j];

            if (is_train) {
                // Per-example one-hot memorization feature:
                // position d_sig + i is set to 1 for example i
                D.X[i][d_sig + i] = 1.0;
            } else {
                // Test examples have no memorization coordinates (remain 0)
            }
            D.y[i] = y;
        }
        return D;
    };

    Dataset train = gen_xy(n_train, true);
    Dataset test = gen_xy(n_test, false);
    return {train, test};
}

struct Logger {
    vector<int> epochs; vector<double> train_acc, test_acc, w_norm;
    void log(int e, double ta, double va, double wn) {
        epochs.push_back(e); train_acc.push_back(ta); test_acc.push_back(va); w_norm.push_back(wn);
    }
    void print_summary() {
        cout << "epoch,train_acc,test_acc,weight_norm\n";
        for (size_t i = 0; i < epochs.size(); ++i) {
            cout << epochs[i] << "," << train_acc[i] << "," << test_acc[i] << "," << w_norm[i] << "\n";
        }
    }
};

// Compute classification accuracy of the linear predictor sign(w . x)
double accuracy(const vector<vector<double>>& X, const vector<int>& y, const vector<double>& w) {
    int n = (int)X.size(), d = (n ? (int)X[0].size() : 0);
    int correct = 0;
    for (int i = 0; i < n; ++i) {
        double z = 0.0; for (int j = 0; j < d; ++j) z += w[j] * X[i][j];
        int pred = (z >= 0.0) ? +1 : -1;
        if (pred == y[i]) ++correct;
    }
    return n ? (double)correct / n : 0.0;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Hyperparameters
    int n_train = 512;
    int n_test = 512;
    int d_sig = 5;       // small global signal dimensionality
    double lr = 0.05;    // learning rate
    double wd = 1e-3;    // weight decay (L2)
    int epochs = 6000;   // long training horizon to reveal delayed generalization
    int log_every = 200;

    auto [train, test] = make_dataset(n_train, n_test, d_sig, 123);
    int d = (int)train.X[0].size();

    vector<double> w(d, 0.0); // initialize to zeros

    // Indices for SGD
    vector<int> idx(n_train); iota(idx.begin(), idx.end(), 0);
    mt19937 rng(1234);

    Logger logger;

    auto step_decay = [&](vector<double>& w, double lr, double wd) {
        // Decoupled weight decay (AdamW-style): w <- (1 - lr*wd) * w
        double factor = max(0.0, 1.0 - lr * wd);
        for (double &wj : w) wj *= factor;
    };

    auto logistic_grad_coeff = [](double y, double z) {
        // derivative wrt z of logistic loss log(1+exp(-y z)) is -y / (1 + exp(y z))
        return -y / (1.0 + exp(y * z));
    };

    for (int e = 1; e <= epochs; ++e) {
        shuffle(idx.begin(), idx.end(), rng);
        for (int it : idx) {
            // Apply decoupled weight decay each step
            step_decay(w, lr, wd);
            // Compute z and gradient on one example
            double z = 0.0; for (int j = 0; j < d; ++j) z += w[j] * train.X[it][j];
            double gcoef = logistic_grad_coeff((double)train.y[it], z);
            // w <- w - lr * gcoef * x
            for (int j = 0; j < d; ++j) w[j] -= lr * gcoef * train.X[it][j];
        }

        if (e % log_every == 0 || e == 1 || e == epochs) {
            double ta = accuracy(train.X, train.y, w);
            double va = accuracy(test.X, test.y, w);
            double wn = 0.0; for (double v : w) wn += v * v; wn = sqrt(wn);
            logger.log(e, ta, va, wn);
        }
    }

    // Print CSV for plotting externally
    logger.print_summary();

    cerr << "Final Train Acc: " << logger.train_acc.back() << "\n";
    cerr << "Final Test Acc: " << logger.test_acc.back() << "\n";
    cerr << "Final ||w||2 : " << logger.w_norm.back() << "\n";

    return 0;
}
```
We construct a dataset with two competing feature sets: (1) a small set of signal features shared by train and test that encodes the true rule; and (2) per-example one-hot features present only for training points, enabling perfect memorization. A logistic-regression model with decoupled L2 weight decay is trained by SGD. Early in training, the per-example weights rapidly memorize the labels, driving training accuracy near 1.0 while test accuracy stays low. Over many epochs, weight decay penalizes the many memorization weights, while gradients on the shared signal features accumulate consistently. Eventually, the signal-based solution dominates, and test accuracy jumps—demonstrating delayed generalization (grokking-like behavior). The program logs epoch, train/test accuracy, and weight norm as CSV for plotting.
```cpp
#include <bits/stdc++.h>
using namespace std;

struct Dataset { vector<vector<double>> X; vector<int> y; };

pair<Dataset, Dataset> make_dataset(int n_train, int n_test, int d_sig, unsigned seed = 7) {
    mt19937 rng(seed);
    normal_distribution<double> gauss(0.0, 1.0);

    vector<double> w_sig(d_sig);
    for (int j = 0; j < d_sig; ++j) w_sig[j] = gauss(rng);
    double norm = 0.0; for (double v : w_sig) norm += v * v; norm = sqrt(max(1e-12, norm));
    for (double &v : w_sig) v /= norm;

    int d = d_sig + n_train;
    auto gen = [&](int n_samples, bool is_train) {
        Dataset D; D.X.assign(n_samples, vector<double>(d, 0.0)); D.y.resize(n_samples);
        for (int i = 0; i < n_samples; ++i) {
            double z = 0.0;
            for (int j = 0; j < d_sig; ++j) { double xj = gauss(rng); D.X[i][j] = xj; z += w_sig[j] * xj; }
            D.y[i] = (z >= 0.0) ? +1 : -1;
            if (is_train) D.X[i][d_sig + i] = 1.0; // memorization one-hot only in train
        }
        return D;
    };
    return { gen(n_train, true), gen(n_test, false) };
}

struct RunCfg { int epochs; double lr; double wd; unsigned seed; };

double accuracy(const Dataset& D, const vector<double>& w) {
    int n = (int)D.X.size(), d = (n ? (int)D.X[0].size() : 0);
    int ok = 0;
    for (int i = 0; i < n; ++i) {
        double z = 0; for (int j = 0; j < d; ++j) z += w[j] * D.X[i][j];
        int p = (z >= 0) ? +1 : -1;
        ok += (p == D.y[i]);
    }
    return n ? (double)ok / n : 0.0;
}

vector<double> train_sgd(const Dataset& train, const RunCfg& cfg) {
    int n = (int)train.X.size(), d = (n ? (int)train.X[0].size() : 0);
    vector<double> w(d, 0.0);
    vector<int> idx(n); iota(idx.begin(), idx.end(), 0);
    mt19937 rng(cfg.seed);
    auto decay = [&](vector<double>& w) { double f = max(0.0, 1.0 - cfg.lr * cfg.wd); for (double &wj : w) wj *= f; };
    auto gcoef = [](double y, double z) { return -y / (1.0 + exp(y * z)); };

    for (int e = 1; e <= cfg.epochs; ++e) {
        shuffle(idx.begin(), idx.end(), rng);
        for (int i : idx) {
            decay(w);
            double z = 0; for (int j = 0; j < d; ++j) z += w[j] * train.X[i][j];
            double gc = gcoef((double)train.y[i], z);
            for (int j = 0; j < d; ++j) w[j] -= cfg.lr * gc * train.X[i][j];
        }
    }
    return w;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n_train = 512, n_test = 512, d_sig = 5;
    auto [Dtr, Dte] = make_dataset(n_train, n_test, d_sig, 2024);

    RunCfg early { 500, 0.05, 1e-3, 1u };  // stops early (likely still memorizing)
    RunCfg longr { 6000, 0.05, 1e-3, 2u }; // long training (allows shift to rules)

    auto w_early = train_sgd(Dtr, early);
    auto w_long  = train_sgd(Dtr, longr);

    double tr_e = accuracy(Dtr, w_early), te_e = accuracy(Dte, w_early);
    double tr_l = accuracy(Dtr, w_long),  te_l = accuracy(Dte, w_long);

    auto norm2 = [](const vector<double>& w) { double s = 0; for (double v : w) s += v * v; return sqrt(s); };

    cout << fixed << setprecision(4);
    cout << "Early Stop -> Train Acc: " << tr_e << ", Test Acc: " << te_e << ", ||w||2: " << norm2(w_early) << "\n";
    cout << "Long Train -> Train Acc: " << tr_l << ", Test Acc: " << te_l << ", ||w||2: " << norm2(w_long) << "\n";

    // Expectation: train accuracy ~1.0 in both runs; test accuracy low for early, higher for long.
    return 0;
}
```
This program uses the same competing-features dataset but compares two regimes: early stopping (few epochs) versus long training (many epochs) under identical weight decay. The early-stopped run typically has high training accuracy but poor test accuracy—it is stuck in memorization. The long run allows the implicit/explicit regularization to suppress memorization weights and amplify signal weights, often yielding a much higher test accuracy. Reporting the weight norms hints at movement toward a simpler solution in the long run.