How I Study AI
📚 Theory · Intermediate

Grokking & Delayed Generalization

Key Points

  • Grokking is when a model suddenly starts to generalize well long after it has already memorized the training set.
  • During grokking, the training loss stays near zero for a long time while the test loss remains high, then abruptly drops.
  • This delayed generalization often happens when implicit or explicit regularization gradually favors a simpler, rule-like solution over a memorized one.
  • Overparameterized models with weight decay and long training are common settings where grokking is observed.
  • You can think of it as a phase transition in learning dynamics driven by a competition between spurious/memorization features and true signal features.
  • Monitoring the generalization gap, weight norms, and training time reveals when the shift from memorization to rules occurs.
  • Toy setups with per-example “memorization” features versus global signal features reproduce grokking-like curves.
  • Early stopping or too-weak regularization can prevent grokking, leaving the model stuck in memorization mode.

Prerequisites

  • →Train/Test Split and Evaluation — To interpret generalization curves and the gap between training and test performance.
  • →Linear Models and Logistic Regression — To understand the objective, gradients, and decision boundaries used in the toy demonstrations.
  • →Gradient Descent and SGD — To follow the optimization dynamics and their long-horizon effects on solutions.
  • →Regularization (L1/L2, Weight Decay) — To see how explicit penalties bias solutions toward simpler hypotheses over time.
  • →Vector Norms and Inner Products — To reason about simplicity (low-norm solutions) and compute gradients efficiently.
  • →Overparameterization and Capacity — To understand how memorization becomes possible and why regularization is needed.
  • →Probability Basics — To interpret expected risk, randomness in data generation, and variability across runs.

Detailed Explanation


01 Overview

Imagine training a model that quickly gets every training example correct, yet keeps failing on the test set for a very long time. Then, surprisingly, after many more epochs of training, its test accuracy suddenly jumps to a high value. This phenomenon is called grokking. It was first observed in small algorithmic tasks (like modular arithmetic) where models memorized the dataset early on and only later discovered the underlying general rule.

Grokking shows up most clearly in overparameterized models—networks with enough capacity to memorize. The model has two competing ways to reduce training loss: (1) memorize each example using many parameters, or (2) discover a compact rule that works for all inputs. Both drive the training loss down, but only the second option generalizes. With weight decay or other regularizers, the learning dynamics slowly push the model toward simpler solutions over time.

Empirically, learning curves display a long plateau of low training loss and poor test performance, followed by an abrupt improvement in test metrics—like a phase transition. Theoretical perspectives link this to implicit regularization via optimization (e.g., gradient descent biasing toward low-norm solutions), minimum description length (MDL), and capacity control. Practically, you can reproduce grokking on toy problems by combining high capacity, strong weight decay, and long training. Understanding grokking helps practitioners design training protocols and interpret surprising learning dynamics in modern deep learning.

02 Intuition & Analogies

Hook: Have you ever crammed for an exam by rote-memorizing answers, only to later, after weeks of casually thinking about the material, suddenly realize you actually understand the concept? That "aha" moment mirrors grokking in machine learning.

Concept: Early in training, an overpowered model can memorize the training set—like a student memorizing solutions—because it has enough parameters to store answers. This produces near-perfect training accuracy but poor test accuracy since the memorized specifics do not transfer. Over time, however, the combination of optimization and regularization (like weight decay) slowly penalizes this brittle lookup approach. Meanwhile, a simpler rule that explains many data points consistently earns small but steady improvements. The competition resembles two teams in a long game: the memorization team scores first and easily, but accumulates penalties, while the rule-learning team advances slowly and steadily. Eventually, the second team overtakes the first, and test performance jumps.

Example: Consider classifying points with two kinds of features. The first kind are unique IDs for each training point—perfect for memorization but useless on new data. The second kind are meaningful coordinates that reflect the true boundary. A linear model with weight decay can initially push the ID-related weights high (near-zero training loss, poor test loss). Weight decay gradually suppresses that bloated solution because many large weights are costly. At the same time, the coordinate-based weights keep getting reinforced across examples. After enough steps, the coordinate-based solution dominates; generalization improves sharply. That sudden leap is grokking.

03 Formal Definition

Let D_train and D_test be the training and test distributions, and let W_t be the model parameters after t optimization steps. Define the empirical risk R_train(W_t) and the test risk R_test(W_t) under a loss ℓ. The generalization gap is G(t) = R_test(W_t) − R_train(W_t). We say delayed generalization (grokking) occurs when there exist times T_mem and T_gen with T_gen ≫ T_mem such that:

1) R_train(W_t) ≈ 0 for t ≥ T_mem (near-zero training loss is maintained),
2) R_test(W_t) remains high for t in [T_mem, T_gen), then drops sharply near T_gen,
3) the drop is associated with an increase in simplicity bias (e.g., lower norm, flatter minimum, or lower description length) of the effective solution.

A common mechanistic model posits two feature subspaces: a memorization subspace M spanned by per-example indicators (fits D_train exactly but has no support on D_test), and a signal subspace S that captures the true rule (fits both). Let W_t = W_t^(M) + W_t^(S) be the corresponding decomposition. With weight decay (L2), the optimization dynamics favor low-norm solutions. Because W_t^(M) requires many coordinates to fit the labels, its norm and penalty are large; W_t^(S) aggregates signal consistently across examples and can achieve low loss with a smaller norm. Over a long horizon, regularization shrinks W_t^(M) and lets W_t^(S) dominate, leading to a late, abrupt improvement in R_test(W_t).

04 When to Use

Grokking is a phenomenon to watch for (and sometimes leverage), not a training recipe to apply everywhere. You should be aware of it when:

  • Training highly overparameterized models on small or algorithmic datasets, where memorization is easy and true structure is compact (e.g., modular arithmetic, string transformations, synthetic tasks).
  • Using strong explicit regularization (e.g., weight decay) or implicit regularization (e.g., specific optimizers or architectures) that prefers simpler hypotheses but may need long training time to overcome early memorization.
  • Investigating surprising learning curves (long plateaus, sudden leaps in test accuracy) and trying to attribute them to optimization dynamics rather than data leakage or bugs.
  • Designing experiments to probe inductive biases (e.g., observing whether a model eventually prefers low-norm or MDL-optimal solutions under sustained training).

Practical use cases include: validating that a model can, in principle, learn an underlying rule if trained long enough; studying how hyperparameters (weight decay, batch size, learning rate schedules) affect the timing of generalization; and constructing toy benchmarks that separate memorization from understanding. On the flip side, if you strictly need reliable early generalization (e.g., production training budgets), you may want to avoid conditions that produce grokking or use early stopping and stronger data augmentation to prevent it.

⚠️Common Mistakes

  • Confusing grokking with ordinary overfitting: In overfitting, test performance usually degrades as training continues. In grokking, test performance can be poor for a long time and then suddenly improve. Always plot both train and test curves over a long horizon.
  • Stopping too early: Early stopping can freeze the model in the memorization regime, missing the later transition. If you suspect grokking, extend training and monitor generalization metrics.
  • Insufficient regularization: Without an explicit or implicit simplicity bias (e.g., weight decay), the memorization solution may persist indefinitely. Tune regularization strength; too weak prevents the shift, too strong can underfit.
  • Misattributing randomness: Different seeds can change when (or if) grokking appears. Run multiple seeds and report variability.
  • Ignoring feature design in toy demos: To reproducibly demonstrate grokking, construct data where memorization features exist only in training, while signal features persist in test. Otherwise, the effect may be weak or invisible.
  • Overinterpreting single runs: A sudden test jump can happen by chance. Use confidence intervals, multiple runs, and diagnostics like weight norms or sparsity to support a grokking claim.

Key Formulas

Generalization Gap

G(t) = R_test(W_t) − R_train(W_t)

Explanation: This measures how much worse the model performs on test data than on training data at time t. In grokking, G(t) stays large for a long time and then drops sharply.

L2-Regularized Objective

J(W) = (1/n) ∑_{i=1}^{n} ℓ(y_i, f_W(x_i)) + (λ/2) ‖W‖₂²

Explanation: The total objective equals average loss plus an L2 penalty on weights. Increasing λ pushes the solution toward smaller norms, often encouraging simpler rules.

Weight Decay Update (Decoupled)

W_{t+1} = (1 − ηλ) W_t − η ∇_W [ (1/m) ∑_{i∈B_t} ℓ(y_i, f_W(x_i)) ]

Explanation: With learning rate η and minibatch B_t of size m, the update shrinks weights by (1 − ηλ) and then applies a gradient step. Over time this penalizes large memorization weights.

Logistic Loss

ℓ_logistic(y, z) = log(1 + e^(−yz))

Explanation: Common binary classification loss with labels y ∈ {−1, +1}. Its gradient pushes predictions to align with labels and is convenient for demonstrating training dynamics.

Max-Margin Implicit Bias

w* = argmin_w ‖w‖₂  s.t.  y_i (w⊤ x_i) ≥ 1, ∀i

Explanation: For separable data, gradient descent on certain losses converges in direction to the maximum-margin classifier. This links optimization dynamics to simpler, low-norm solutions.

Grokking Time

T_grok = min{ t : A_test(t) ≥ τ ∧ R_train(t) ≈ 0 }

Explanation: Defines the first time when test accuracy crosses a chosen threshold τ while training loss is already near zero. It formalizes the delayed jump in performance.

PAC-Bayes (Sketch)

R_test(f) ≤ R_train(f) + √[ (KL(Q‖P) + ln(2n/δ)) / (2(n−1)) ]

Explanation: A representative PAC-Bayes bound relates generalization to a complexity term (KL divergence between posterior Q and prior P). Lower complexity (simpler solutions) tightens the bound, connecting regularization to generalization.

Complexity Analysis

In the provided C++ toy demonstrations, we use linear models trained with stochastic gradient descent (SGD) and L2 weight decay over synthetic datasets that mix per-example memorization features with global signal features. Let n be the number of training samples, d the total feature dimension, E the number of epochs, and b the batch size (we use b = 1 for pure SGD). Per step, computing the dot product and gradient costs O(d). An epoch processes n examples, costing O(n·d), so the total training time is O(E·n·d). In our toy setup d = d_sig + n (due to the n one-hot memorization features), so d is on the order of n and the worst-case training time scales like O(E·n²). This quadratic-in-n behavior arises from the intentionally constructed memorization features and is acceptable for the small n (hundreds to a few thousand) used in demos. Evaluation on the test set adds O(n_test · d) per evaluation; logging periodically keeps this overhead small.

Space complexity is O(d) for the weight vector plus O(n·d) to store the dataset densely. However, the memorization features are extremely sparse (one-hot), so a sparse representation can reduce data storage to O(n·d_sig): each example keeps its d_sig dense signal values plus a single one-hot index. In our simple, dense implementation we accept O(n·d) memory for clarity, but a sparse format would be more efficient.

In the real deep networks where grokking was first observed, per-step cost depends on the architecture (e.g., a transformer's cost grows with depth L and model width d_model, roughly O(L·d_model²) per token, plus attention terms), and training can run for millions of steps. The key takeaway is that grokking requires long training horizons; computational cost is therefore dominated by the total number of optimization steps and the evaluation cadence you choose for monitoring generalization.

Code Examples

Toy Grokking via Competing Features (Memorization vs Rule) with L2 Weight Decay
#include <bits/stdc++.h>
using namespace std;

// Logistic regression on synthetic data that mixes
// - signal features (generalize to test)
// - per-example one-hot memorization features (only exist in train)
// This setup often shows delayed generalization ("grokking").

struct Dataset {
    vector<vector<double>> X; // features
    vector<int> y;            // labels in {-1, +1}
};

// Generate synthetic data.
// n_train: number of training samples
// n_test : number of test samples
// d_sig  : number of global signal features
// Returns train and test datasets with total dimension d = d_sig + n_train.
pair<Dataset, Dataset> make_dataset(int n_train, int n_test, int d_sig, unsigned seed = 42) {
    mt19937 rng(seed);
    normal_distribution<double> gauss(0.0, 1.0);

    // True signal weights (unit-normalized for stability)
    vector<double> w_sig(d_sig);
    for (int j = 0; j < d_sig; ++j) w_sig[j] = gauss(rng);
    double norm = 0.0;
    for (double v : w_sig) norm += v * v;
    norm = sqrt(max(1e-12, norm));
    for (double &v : w_sig) v /= norm;

    int d = d_sig + n_train; // total features: signal + memorization one-hots

    auto gen_xy = [&](int n_samples, bool is_train) {
        Dataset D;
        D.X.assign(n_samples, vector<double>(d, 0.0));
        D.y.resize(n_samples);
        for (int i = 0; i < n_samples; ++i) {
            // Signal features
            vector<double> xs(d_sig);
            for (int j = 0; j < d_sig; ++j) xs[j] = gauss(rng);
            // Noiseless label: the sign of the signal projection
            double z = 0.0;
            for (int j = 0; j < d_sig; ++j) z += w_sig[j] * xs[j];
            int y = (z >= 0.0) ? +1 : -1;

            // Write signal features into the full vector
            for (int j = 0; j < d_sig; ++j) D.X[i][j] = xs[j];

            if (is_train) {
                // Per-example one-hot memorization feature:
                // position d_sig + i is set to 1 for example i
                D.X[i][d_sig + i] = 1.0;
            }
            // Test examples have no memorization coordinates (they remain 0)
            D.y[i] = y;
        }
        return D;
    };

    Dataset train = gen_xy(n_train, true);
    Dataset test  = gen_xy(n_test, false);
    return {train, test};
}

struct Logger {
    vector<int> epochs;
    vector<double> train_acc, test_acc, w_norm;
    void log(int e, double ta, double va, double wn) {
        epochs.push_back(e);
        train_acc.push_back(ta);
        test_acc.push_back(va);
        w_norm.push_back(wn);
    }
    void print_summary() {
        cout << "epoch,train_acc,test_acc,weight_norm\n";
        for (size_t i = 0; i < epochs.size(); ++i) {
            cout << epochs[i] << "," << train_acc[i] << ","
                 << test_acc[i] << "," << w_norm[i] << "\n";
        }
    }
};

// Accuracy of the linear classifier sign(w . x)
double accuracy(const vector<vector<double>>& X, const vector<int>& y, const vector<double>& w) {
    int n = (int)X.size(), d = (n ? (int)X[0].size() : 0);
    int correct = 0;
    for (int i = 0; i < n; ++i) {
        double z = 0.0;
        for (int j = 0; j < d; ++j) z += w[j] * X[i][j];
        int pred = (z >= 0.0) ? +1 : -1;
        if (pred == y[i]) ++correct;
    }
    return n ? (double)correct / n : 0.0;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Hyperparameters
    int n_train = 512;
    int n_test = 512;
    int d_sig = 5;       // small global signal dimensionality
    double lr = 0.05;    // learning rate
    double wd = 1e-3;    // weight decay (L2)
    int epochs = 6000;   // long training horizon to reveal delayed generalization
    int log_every = 200;

    auto [train, test] = make_dataset(n_train, n_test, d_sig, 123);
    int d = (int)train.X[0].size();

    vector<double> w(d, 0.0); // initialize to zeros

    // Indices for SGD
    vector<int> idx(n_train);
    iota(idx.begin(), idx.end(), 0);
    mt19937 rng(1234);

    Logger logger;

    auto step_decay = [&](vector<double>& w, double lr, double wd) {
        // Decoupled weight decay (AdamW-style): w <- (1 - lr*wd) * w
        double factor = max(0.0, 1.0 - lr * wd);
        for (double &wj : w) wj *= factor;
    };

    auto logistic_grad_coeff = [](double y, double z) {
        // Derivative wrt z of the logistic loss log(1 + exp(-y z)) is -y / (1 + exp(y z))
        return -y / (1.0 + exp(y * z));
    };

    for (int e = 1; e <= epochs; ++e) {
        shuffle(idx.begin(), idx.end(), rng);
        for (int it : idx) {
            // Apply decoupled weight decay each step
            step_decay(w, lr, wd);
            // Compute z and the gradient on one example
            double z = 0.0;
            for (int j = 0; j < d; ++j) z += w[j] * train.X[it][j];
            double gcoef = logistic_grad_coeff((double)train.y[it], z);
            // w <- w - lr * gcoef * x
            for (int j = 0; j < d; ++j) w[j] -= lr * gcoef * train.X[it][j];
        }

        if (e % log_every == 0 || e == 1 || e == epochs) {
            double ta = accuracy(train.X, train.y, w);
            double va = accuracy(test.X, test.y, w);
            double wn = 0.0;
            for (double v : w) wn += v * v;
            wn = sqrt(wn);
            logger.log(e, ta, va, wn);
        }
    }

    // Print CSV for plotting externally
    logger.print_summary();

    cerr << "Final Train Acc: " << logger.train_acc.back() << "\n";
    cerr << "Final Test Acc : " << logger.test_acc.back() << "\n";
    cerr << "Final ||w||2   : " << logger.w_norm.back() << "\n";

    return 0;
}

We construct a dataset with two competing feature sets: (1) a small set of signal features shared by train and test that encodes the true rule; and (2) per-example one-hot features present only for training points, enabling perfect memorization. A logistic-regression model with decoupled L2 weight decay is trained by SGD. Early in training, the per-example weights rapidly memorize the labels, driving training accuracy near 1.0 while test accuracy stays low. Over many epochs, weight decay penalizes the many memorization weights, while gradients on the shared signal features accumulate consistently. Eventually, the signal-based solution dominates, and test accuracy jumps—demonstrating delayed generalization (grokking-like behavior). The program logs epoch, train/test accuracy, and weight norm as CSV for plotting.

Time: O(E * n_train * d), where d = d_sig + n_train (so often O(E * n_train^2)).
Space: O(n_train * d) for dense data storage plus O(d) for weights; sparse storage could reduce data memory to O(n_train * d_sig).
Early Stopping vs. Long Training: Grokking Can Be Missed
#include <bits/stdc++.h>
using namespace std;

struct Dataset { vector<vector<double>> X; vector<int> y; };

pair<Dataset, Dataset> make_dataset(int n_train, int n_test, int d_sig, unsigned seed = 7) {
    mt19937 rng(seed);
    normal_distribution<double> gauss(0.0, 1.0);

    vector<double> w_sig(d_sig);
    for (int j = 0; j < d_sig; ++j) w_sig[j] = gauss(rng);
    double norm = 0.0;
    for (double v : w_sig) norm += v * v;
    norm = sqrt(max(1e-12, norm));
    for (double &v : w_sig) v /= norm;

    int d = d_sig + n_train;
    auto gen = [&](int n_samples, bool is_train) {
        Dataset D;
        D.X.assign(n_samples, vector<double>(d, 0.0));
        D.y.resize(n_samples);
        for (int i = 0; i < n_samples; ++i) {
            double z = 0.0;
            for (int j = 0; j < d_sig; ++j) {
                double xj = gauss(rng);
                D.X[i][j] = xj;
                z += w_sig[j] * xj;
            }
            D.y[i] = (z >= 0.0) ? +1 : -1;
            if (is_train) D.X[i][d_sig + i] = 1.0; // memorization one-hot only in train
        }
        return D;
    };
    return { gen(n_train, true), gen(n_test, false) };
}

struct RunCfg { int epochs; double lr; double wd; unsigned seed; };

double accuracy(const Dataset& D, const vector<double>& w) {
    int n = (int)D.X.size(), d = (n ? (int)D.X[0].size() : 0);
    int ok = 0;
    for (int i = 0; i < n; ++i) {
        double z = 0.0;
        for (int j = 0; j < d; ++j) z += w[j] * D.X[i][j];
        int p = (z >= 0.0) ? +1 : -1;
        ok += (p == D.y[i]);
    }
    return n ? (double)ok / n : 0.0;
}

vector<double> train_sgd(const Dataset& train, const RunCfg& cfg) {
    int n = (int)train.X.size(), d = (n ? (int)train.X[0].size() : 0);
    vector<double> w(d, 0.0);
    vector<int> idx(n);
    iota(idx.begin(), idx.end(), 0);
    mt19937 rng(cfg.seed);
    auto decay = [&](vector<double>& w) {
        double f = max(0.0, 1.0 - cfg.lr * cfg.wd);
        for (double &wj : w) wj *= f;
    };
    auto gcoef = [](double y, double z) { return -y / (1.0 + exp(y * z)); };

    for (int e = 1; e <= cfg.epochs; ++e) {
        shuffle(idx.begin(), idx.end(), rng);
        for (int i : idx) {
            decay(w);
            double z = 0.0;
            for (int j = 0; j < d; ++j) z += w[j] * train.X[i][j];
            double gc = gcoef((double)train.y[i], z);
            for (int j = 0; j < d; ++j) w[j] -= cfg.lr * gc * train.X[i][j];
        }
    }
    return w;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n_train = 512, n_test = 512, d_sig = 5;
    auto [Dtr, Dte] = make_dataset(n_train, n_test, d_sig, 2024);

    RunCfg early { 500,  0.05, 1e-3, 1u }; // stops early (likely still memorizing)
    RunCfg longr { 6000, 0.05, 1e-3, 2u }; // long training (allows shift to rules)

    auto w_early = train_sgd(Dtr, early);
    auto w_long  = train_sgd(Dtr, longr);

    double tr_e = accuracy(Dtr, w_early), te_e = accuracy(Dte, w_early);
    double tr_l = accuracy(Dtr, w_long),  te_l = accuracy(Dte, w_long);

    auto norm2 = [](const vector<double>& w) {
        double s = 0.0;
        for (double v : w) s += v * v;
        return sqrt(s);
    };

    cout << fixed << setprecision(4);
    cout << "Early Stop -> Train Acc: " << tr_e << ", Test Acc: " << te_e
         << ", ||w||2: " << norm2(w_early) << "\n";
    cout << "Long Train -> Train Acc: " << tr_l << ", Test Acc: " << te_l
         << ", ||w||2: " << norm2(w_long) << "\n";

    // Expectation: train accuracy ~1.0 in both runs; test accuracy low for
    // the early stop, higher after long training.
    return 0;
}

This program uses the same competing-features dataset but compares two regimes: early stopping (few epochs) versus long training (many epochs) under identical weight decay. The early-stopped run typically has high training accuracy but poor test accuracy—it is stuck in memorization. The long run allows the implicit/explicit regularization to suppress memorization weights and amplify signal weights, often yielding a much higher test accuracy. Reporting the weight norms hints at movement toward a simpler solution in the long run.

Time: O(E * n_train * d) per run; two runs double the constant factor.
Space: O(n_train * d) for data and O(d) for weights.
#grokking #delayed-generalization #weight-decay #implicit-regularization #generalization-gap #phase-transition #overparameterized-models #logistic-regression #stochastic-gradient-descent #spurious-features #signal-features #double-descent #minimum-description-length #pac-bayes #early-stopping