Neural Tangent Kernel (NTK)
Key Points
- Neural Tangent Kernel (NTK) describes how wide neural networks train like kernel machines, turning gradient descent into kernel regression in the infinite-width limit.
- For common activations like ReLU, the NTK can be computed using closed-form layer-wise recursions that depend only on input angles and norms.
- Training a wide network with small learning rates is equivalent to training a linear model in function space with kernel K = NTK.
- In practice, you can use the NTK as a kernel to perform kernel ridge regression with a closed-form solution.
- Empirical NTKs from finite, randomly initialized networks concentrate around the infinite-width NTK as width grows.
- Computing the full NTK Gram matrix for n examples costs O(n^{2} d L) time and O(n^{2}) memory; solving the resulting kernel regression is typically O(n^{3}).
- ReLU NTK uses simple trigonometric formulas: sin and arccos of the angle between inputs, making it straightforward to implement in C++.
- The NTK explains why very wide networks learn features slowly (or not at all) and mostly fit data by adjusting linear coefficients in a fixed feature space.
Prerequisites
- Multivariable calculus — Gradients and Jacobians define the NTK as inner products of parameter derivatives.
- Linear algebra — Kernel Gram matrices, positive semi-definiteness, and solving linear systems are essential.
- Probability and random variables — Initialization distributions and the law of large numbers underlie the infinite-width limit.
- Numerical linear algebra — Stable solution of (K + \lambda I)\alpha = y requires understanding conditioning and regularization.
- Neural networks basics — Layered architectures, activations, and parameterization are required to interpret the NTK recursion.
- Kernel methods — Kernel regression and Gram matrices are the functional counterparts of wide-network training.
- Optimization (gradient descent) — NTK describes gradient flow dynamics and convergence behavior.
- C++ programming basics — Implementing NTK recursions and solvers requires comfort with arrays, loops, and numerical care.
Detailed Explanation
01 Overview
The Neural Tangent Kernel (NTK) is a powerful theory that connects neural networks to kernel methods. When a neural network becomes very wide (many neurons per layer) and is trained with small learning rates, its learning dynamics simplify dramatically: the network’s outputs evolve like a kernel regression model with a specific kernel, the NTK. This means we can predict and analyze training by studying a deterministic kernel function instead of the full, complicated parameter dynamics.

For common architectures, especially fully connected networks with ReLU activations, the NTK can be computed by a simple recursion over layers using only input norms and angles. Intuitively, the NTK captures how similar two inputs are according to the network at initialization and how gradient descent will couple their predictions. The Gram matrix built from the NTK across training examples controls convergence speed and interpolation properties.

In the infinite-width limit, the NTK remains constant during training (it does not change with the parameters), making the learning problem linear in function space. This “linearization” turns nonlinear deep learning into a convex, kernel-based problem, enabling closed-form solutions and precise generalization analyses.

Practically, NTK theory helps explain why very wide networks train reliably, why they can perfectly fit data (interpolate), and how their generalization depends on the data’s alignment with the NTK. It also offers a computational tool: instead of training a wide network, one can compute the NTK Gram matrix and perform kernel regression to get identical predictions in the infinite-width regime.
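The "linearization" mentioned above can be written down concretely as a first-order Taylor expansion of the network outputs in the parameters around initialization \theta_0:

```latex
f(x;\theta) \;\approx\; f(x;\theta_0) \;+\; \nabla_\theta f(x;\theta_0)^{\top} (\theta - \theta_0)
```

Under this approximation, gradient descent only adjusts the coefficients of the fixed features \nabla_\theta f(x;\theta_0); the kernel induced by these features is exactly the NTK, which is why training reduces to kernel regression.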
02 Intuition & Analogies
Imagine you run a community meeting where each participant (a neuron) raises a sign indicating how much they agree with a statement (the output). If there are only a few participants, who they are and how they change their minds matters a lot. But if you gather an enormous crowd with similar backgrounds (very wide layers with i.i.d. initialization), the law of large numbers kicks in: the crowd’s behavior becomes predictable and smooth. When you tweak the meeting rules slightly (small learning rate), individual opinions barely move; instead, the overall pattern of agreement between different questions (inputs) is determined by a stable similarity function. This similarity function is the NTK. It tells you how a small adjustment aimed at improving the answer for one question will affect the answers to other, similar questions. If two questions are similar under the NTK, improving one will also improve the other, because their gradients with respect to parameters point in similar directions. In contrast, dissimilar questions hardly influence each other.

Another analogy: think of a giant orchestra (the network) where each instrument plays so quietly (small parameter change) that you mostly hear the initial harmony (the initialization). As you nudge the volume knobs slightly, the music changes along directions already present in the original harmony. The NTK is the score that encodes which notes are coupled—turning up a violin at one pitch also lifts the violas at related pitches. With infinitely many instruments, this score becomes fixed and predictable. Thus, training is like adjusting volumes in a fixed, high-dimensional space defined by the NTK, rather than composing entirely new melodies (features).
03 Formal Definition
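For a scalar-output network f(x;\theta) with parameters \theta \in \mathbb{R}^{P}, the (empirical) NTK is the inner product of parameter gradients:

```latex
\Theta_{\theta}(x, x') \;=\; \big\langle \nabla_{\theta} f(x;\theta),\, \nabla_{\theta} f(x';\theta) \big\rangle
\;=\; \sum_{p=1}^{P} \frac{\partial f(x;\theta)}{\partial \theta_{p}}\,\frac{\partial f(x';\theta)}{\partial \theta_{p}}
```

Under NTK parameterization, as the layer widths tend to infinity, \Theta_{\theta} at random initialization converges to a deterministic kernel \Theta that stays fixed throughout gradient-flow training. The training-set outputs f_t then obey the linear ODE \frac{d f_t}{dt} = -\Theta(X, X)\,(f_t - y), which is the basis for the kernel-regression correspondence used throughout this article.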
04 When to Use
- Theoretical analysis: Use NTK to analyze training dynamics, convergence, and generalization of very wide networks, especially to understand why small learning rates and overparameterization yield stable training.
- Practical surrogate for wide nets: If your architecture is a fully connected network (or another architecture with known NTK) and datasets are modest in size (hundreds to a few thousands), compute the NTK Gram matrix and apply kernel ridge regression instead of training the network.
- Sanity checks and diagnostics: Compare empirical NTK from a finite network to the infinite-width NTK to gauge whether the width is large enough for linearization to hold.
- Curriculum and augmentation effects: Study how data transformations change pairwise NTK similarities to predict learning couplings among examples.
- Small-data regimes: Kernel methods often shine with limited data; NTK offers a theoretically grounded kernel tailored to the chosen architecture and activation.
- Rapid prototyping: When experimenting with architectures or activations, computing their NTKs gives insight into what the model “believes” about function smoothness and invariances before you commit to full training.
⚠️ Common Mistakes
- Ignoring parameterization: The NTK limit relies on specific scaling (NTK parameterization). Using standard He or Xavier without the correct 1/\sqrt{width} factors or mismatching bias/weight variances breaks the theoretical correspondence.
- Confusing NNGP with NTK: The NNGP kernel describes function values at initialization; the NTK governs training dynamics. They coincide for linear models and in some shallow cases but generally differ.
- Forgetting normalization: Base-layer covariance often uses \Sigma^{(0)}(x, x') = \frac{1}{d} x^{\top} x'. Skipping input scaling can distort angles and cause numerical issues in recursion.
- Numerical instability in angles: Computing \theta = \arccos(c) requires clamping c to [−1, 1]. Without clamping, floating-point errors can produce NaNs.
- Overlooking O(n^{3}) solve costs: Kernel ridge regression needs solving (K + \lambda I)\alpha = y, which is cubic in n. For large n, prefer iterative solvers or low-rank approximations.
- Assuming NTK is fixed at finite width: At practical widths, the empirical NTK changes during training. Linearization holds approximately; monitor deviations if using large learning rates or many epochs.
- Misusing regularization: With near-singular K, you need ridge regularization (\lambda > 0). Setting \lambda = 0 can lead to unstable inverses and overfitting.
- Mixing feature and kernel viewpoints: The NTK is a kernel over inputs, not a feature map you directly apply SGD to. If you want features, use the Jacobian (neural tangent features) explicitly.
Key Formulas
NTK definition
\Theta(x, x') = \big\langle \nabla_{\theta} f(x;\theta),\, \nabla_{\theta} f(x';\theta) \big\rangle
Explanation: The NTK equals the inner product of parameter gradients of network outputs for two inputs. It measures how updates that improve one input’s prediction transfer to the other.
Base covariance
\Sigma^{(0)}(x, x') = \frac{1}{d} x^{\top} x'
Explanation: The layer-0 covariance is the scaled input dot product. Dividing by d normalizes variance across input dimensions.
ReLU correlation
C(\theta) = \frac{\sin\theta + (\pi - \theta)\cos\theta}{2\pi}
Explanation: For ReLU, the correlation function depends only on the angle \theta between inputs. This closed form drives the layer-wise covariance recursion.
NNGP recursion (ReLU)
\Sigma^{(\ell+1)}(x, x') = \sigma_w^{2}\,\sqrt{q\,q'}\;C(\theta) + \sigma_b^{2}, \qquad \cos\theta = \frac{s}{\sqrt{q\,q'}}
Explanation: Given current-layer variances q, q' and covariance s (entering through the angle \theta), the next-layer covariance adds a weighted ReLU correlation plus bias variance.
Derivative correlation (ReLU)
\dot{\Sigma}^{(\ell+1)}(x, x') = \sigma_w^{2}\,\frac{\pi - \theta}{2\pi}
Explanation: This is the expected product of ReLU derivatives across two pre-activations. It scales the propagation of the NTK through layers.
NTK recursion
\Theta^{(0)} = \Sigma^{(0)}, \qquad \Theta^{(\ell+1)} = \Theta^{(\ell)}\,\dot{\Sigma}^{(\ell+1)} + \Sigma^{(\ell+1)}
Explanation: Initialization ties the NTK to the base covariance. Each layer multiplies the previous NTK by the derivative correlation and adds the new covariance.
Kernel ridge regression
\alpha = (K + \lambda I)^{-1} y, \qquad f(x_*) = k_*^{\top}\alpha
Explanation: Given Gram matrix K and regularization \lambda, solve for coefficients \alpha. Predictions at a test point use its kernel vector k_* against the training data.
Gradient flow in function space
\frac{d f_t}{dt} = -K\,(f_t - y), \qquad f_t = y + e^{-Kt}(f_0 - y)
Explanation: With fixed NTK K, gradient descent follows a linear ODE toward the labels. Solutions can be written with matrix exponentials for exact dynamics.
NTK convergence
\hat{\Theta}_m(x, x') \xrightarrow{\;m \to \infty\;} \Theta(x, x')
Explanation: As width m grows, the empirical NTK concentrates to a deterministic kernel almost surely, justifying kernel-based training in the limit.
Effective dimension
N(\lambda) = \operatorname{tr}\!\big(K (K + \lambda I)^{-1}\big)
Explanation: This quantity summarizes the complexity of the kernel at regularization scale \lambda and appears in learning-curve and generalization analyses.
Complexity Analysis
Building the n \times n NTK Gram matrix costs O(n^{2} d) for the base covariances plus O(n^{2} L) for the L layer-recursion updates, with O(n^{2}) memory for the matrix itself. Kernel ridge regression then requires solving (K + \lambda I)\alpha = y, which is O(n^{3}) with direct methods; predicting on n_* test points costs O(n_* n d L) for the cross-kernel plus O(n_* n) for the final matrix-vector products.
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

// Clamp value to [-1, 1] for numerical stability before acos
static inline double clamp_cos(double c) {
    if (c > 1.0) return 1.0;
    if (c < -1.0) return -1.0;
    return c;
}

// Compute base covariance Sigma^(0)(x, x') = (1/d) x^T x'
static vector<vector<double>> base_cov(const vector<vector<double>>& A,
                                       const vector<vector<double>>& B) {
    int nA = (int)A.size();
    int nB = (int)B.size();
    int d = (int)A[0].size();
    vector<vector<double>> S(nA, vector<double>(nB, 0.0));
    for (int i = 0; i < nA; ++i) {
        for (int j = 0; j < nB; ++j) {
            double dot = 0.0;
            for (int k = 0; k < d; ++k) dot += A[i][k] * B[j][k];
            S[i][j] = dot / (double)d; // normalization by input dim
        }
    }
    return S;
}

// Compute per-input variances (diagonals) at layer 0: q = Sigma^(0)(x, x)
static vector<double> base_var(const vector<vector<double>>& X) {
    int n = (int)X.size(), d = (int)X[0].size();
    vector<double> q(n, 0.0);
    for (int i = 0; i < n; ++i) {
        double s = 0.0;
        for (int k = 0; k < d; ++k) s += X[i][k] * X[i][k];
        q[i] = s / (double)d;
    }
    return q;
}

// One layer update for ReLU NNGP and derivative correlation
// Inputs: current cross-cov S (nA x nB), variances qA (nA), qB (nB)
// Outputs: next S, next qA, next qB, and dotSigma (nA x nB)
static void relu_layer_update(const vector<vector<double>>& S,
                              const vector<double>& qA,
                              const vector<double>& qB,
                              double sigma_w2, double sigma_b2,
                              vector<vector<double>>& S_next,
                              vector<double>& qA_next,
                              vector<double>& qB_next,
                              vector<vector<double>>& dotSig) {
    const double PI = acos(-1.0);
    int nA = (int)S.size();
    int nB = (int)S[0].size();

    // Update cross-covariances and derivative correlations
    S_next.assign(nA, vector<double>(nB, 0.0));
    dotSig.assign(nA, vector<double>(nB, 0.0));
    for (int i = 0; i < nA; ++i) {
        for (int j = 0; j < nB; ++j) {
            double qi = max(qA[i], 1e-12); // avoid divide-by-zero
            double qj = max(qB[j], 1e-12);
            double c = S[i][j] / sqrt(qi * qj);
            c = clamp_cos(c);
            double theta = acos(c);
            double C = (sin(theta) + (PI - theta) * cos(theta)) / (2.0 * PI);
            S_next[i][j] = sigma_w2 * sqrt(qi * qj) * C + sigma_b2;
            double dotC = (PI - theta) / (2.0 * PI);
            dotSig[i][j] = sigma_w2 * dotC;
        }
    }

    // Update variances (diagonals) for A and B: theta=0 -> C=1/2, dotC=1/2
    qA_next.assign(nA, 0.0);
    qB_next.assign(nB, 0.0);
    for (int i = 0; i < nA; ++i) {
        qA_next[i] = sigma_w2 * qA[i] * 0.5 + sigma_b2;
    }
    for (int j = 0; j < nB; ++j) {
        qB_next[j] = sigma_w2 * qB[j] * 0.5 + sigma_b2;
    }
}

// Compute infinite-width ReLU NTK between A (nA x d) and B (nB x d)
// with L hidden layers, weight/bias variances sigma_w2 and sigma_b2.
static vector<vector<double>> ntk_infinite_relu(const vector<vector<double>>& A,
                                                const vector<vector<double>>& B,
                                                int L,
                                                double sigma_w2,
                                                double sigma_b2) {
    // Base covariances and variances
    vector<vector<double>> S = base_cov(A, B); // Sigma^(0)(A,B)
    vector<double> qA = base_var(A);           // diag for A at layer 0
    vector<double> qB = base_var(B);           // diag for B at layer 0

    // NTK initialized as Theta^(0) = Sigma^(0)
    vector<vector<double>> T = S; // Theta

    for (int ell = 0; ell < L; ++ell) {
        vector<vector<double>> S_next, dotSig;
        vector<double> qA_next, qB_next;
        relu_layer_update(S, qA, qB, sigma_w2, sigma_b2,
                          S_next, qA_next, qB_next, dotSig);

        // Theta^(ell+1) = Theta^(ell) .* dotSigma^(ell+1) + Sigma^(ell+1)
        int nA = (int)A.size(), nB = (int)B.size();
        vector<vector<double>> T_next(nA, vector<double>(nB, 0.0));
        for (int i = 0; i < nA; ++i)
            for (int j = 0; j < nB; ++j)
                T_next[i][j] = T[i][j] * dotSig[i][j] + S_next[i][j];

        // advance
        S.swap(S_next);
        qA.swap(qA_next);
        qB.swap(qB_next);
        T.swap(T_next);
    }
    return T;
}

// Solve A x = b with Gaussian elimination and partial pivoting
static vector<double> solve_linear_system(vector<vector<double>> A, vector<double> b) {
    int n = (int)A.size();
    for (int i = 0; i < n; ++i) {
        // Pivot
        int piv = i;
        for (int r = i + 1; r < n; ++r)
            if (fabs(A[r][i]) > fabs(A[piv][i])) piv = r;
        if (fabs(A[piv][i]) < 1e-14) throw runtime_error("Singular matrix or ill-conditioned");
        if (piv != i) {
            swap(A[piv], A[i]);
            swap(b[piv], b[i]);
        }
        // Eliminate below
        for (int r = i + 1; r < n; ++r) {
            double f = A[r][i] / A[i][i];
            if (f == 0.0) continue;
            for (int c = i; c < n; ++c) A[r][c] -= f * A[i][c];
            b[r] -= f * b[i];
        }
    }
    // Back substitution
    vector<double> x(n, 0.0);
    for (int i = n - 1; i >= 0; --i) {
        double s = b[i];
        for (int c = i + 1; c < n; ++c) s -= A[i][c] * x[c];
        x[i] = s / A[i][i];
    }
    return x;
}

int main() {
    // Small demo dataset: y = sin(sum(x)) with 2D inputs
    vector<vector<double>> X = {{-1.0, -0.5}, {-0.5, 0.3}, {0.2, -0.7}, {0.9, 0.8}};
    vector<double> y;
    for (auto &v : X) y.push_back(sin(v[0] + v[1]));

    // Test points
    vector<vector<double>> Xtest = {{0.0, 0.0}, {0.5, 0.5}, {-0.8, 0.6}};

    int L = 2;              // number of hidden layers
    double sigma_w2 = 2.0;  // typical for ReLU with NTK parametrization
    double sigma_b2 = 0.0;
    double lambda = 1e-3;   // ridge regularization

    // Train Gram matrix K and cross-kernel K_* = Theta(Xtest, X)
    auto K = ntk_infinite_relu(X, X, L, sigma_w2, sigma_b2);
    auto Kstar = ntk_infinite_relu(Xtest, X, L, sigma_w2, sigma_b2);

    // Add ridge to K
    for (int i = 0; i < (int)K.size(); ++i) K[i][i] += lambda;

    // Solve (K + lambda I) alpha = y
    vector<double> alpha = solve_linear_system(K, y);

    // Predict f(Xtest) = K_* alpha
    vector<double> ypred(Xtest.size(), 0.0);
    for (int i = 0; i < (int)Xtest.size(); ++i)
        for (int j = 0; j < (int)X.size(); ++j)
            ypred[i] += Kstar[i][j] * alpha[j];

    cout.setf(ios::fixed); cout << setprecision(6);
    for (int i = 0; i < (int)Xtest.size(); ++i) {
        cout << "x*: (" << Xtest[i][0] << ", " << Xtest[i][1] << ") -> pred = " << ypred[i]
             << ", target ~ " << sin(Xtest[i][0] + Xtest[i][1]) << "\n";
    }
    return 0;
}
```
This program computes the infinite-width ReLU NTK using the standard layer-wise recursion and uses it as a kernel for kernel ridge regression. It builds the training Gram matrix K, solves the linear system with ridge regularization, and predicts on test inputs via K_* alpha. The implementation carefully maintains per-layer variances and clamps cosine values to avoid NaNs in arccos.
```cpp
#include <bits/stdc++.h>
using namespace std;

// Two-layer network: f(x) = (1/sqrt(m)) * sum_j a_j * ReLU(w_j^T x)
// Empirical NTK for scalar output equals the gradient inner product:
//   NTK(x,x') = (1/m) * (h(x) . h(x'))
//             + (1/m) * (x . x') * sum_j a_j^2 * 1{z_j(x)>0} 1{z_j(x')>0}
// where z_j(x) = w_j^T x and h_j(x) = ReLU(z_j(x)).

struct TwoLayer {
    int m, d;                  // width and input dim
    vector<vector<double>> W;  // m x d
    vector<double> a;          // m
    mt19937 rng;
    normal_distribution<double> nd;

    TwoLayer(int m_, int d_) : m(m_), d(d_), rng(123), nd(0.0, 1.0) {
        W.assign(m, vector<double>(d, 0.0));
        a.assign(m, 0.0);
        // NTK parameterization: W_ij ~ N(0, 1/d), a_j ~ N(0, 1)
        for (int j = 0; j < m; ++j) {
            for (int k = 0; k < d; ++k) W[j][k] = nd(rng) / sqrt((double)d);
            a[j] = nd(rng); // variance 1
        }
    }

    static inline double relu(double z) { return z > 0 ? z : 0; }

    // Compute z, h for a batch of inputs
    void forward(const vector<vector<double>>& X,
                 vector<vector<double>>& Z,
                 vector<vector<double>>& H) const {
        int n = (int)X.size();
        Z.assign(n, vector<double>(m, 0.0));
        H.assign(n, vector<double>(m, 0.0));
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < m; ++j) {
                double s = 0.0;
                for (int k = 0; k < d; ++k) s += W[j][k] * X[i][k];
                Z[i][j] = s;
                H[i][j] = relu(s);
            }
        }
    }

    // Empirical NTK Gram matrix among X (n x d)
    vector<vector<double>> empirical_ntk(const vector<vector<double>>& X) const {
        int n = (int)X.size();
        vector<vector<double>> Z, H;
        forward(X, Z, H);
        vector<vector<double>> K(n, vector<double>(n, 0.0));
        // Precompute x dot products
        vector<vector<double>> Xdot(n, vector<double>(n, 0.0));
        for (int i = 0; i < n; ++i) {
            for (int j = i; j < n; ++j) {
                double s = 0.0;
                for (int k = 0; k < d; ++k) s += X[i][k] * X[j][k];
                Xdot[i][j] = Xdot[j][i] = s;
            }
        }
        // For each pair (i,j): sum over neurons
        for (int i = 0; i < n; ++i) {
            for (int j = i; j < n; ++j) {
                double term_a = 0.0;   // sum_u h_i[u] * h_j[u] (1/m applied below)
                double gate_sum = 0.0; // sum_u a_u^2 * 1{z_i>0} 1{z_j>0}
                for (int u = 0; u < m; ++u) {
                    term_a += H[i][u] * H[j][u];
                    if (Z[i][u] > 0 && Z[j][u] > 0) gate_sum += a[u] * a[u];
                }
                double Kij = (term_a + gate_sum * Xdot[i][j]) / (double)m;
                K[i][j] = K[j][i] = Kij;
            }
        }
        return K;
    }
};

int main() {
    // Small dataset
    vector<vector<double>> X = {{-1.0, -0.5}, {-0.5, 0.3}, {0.2, -0.7}, {0.9, 0.8}};

    int d = 2;
    for (int m : {64, 256}) {
        TwoLayer net(m, d);
        auto K = net.empirical_ntk(X);
        cout << "Width m=" << m << ", empirical NTK (first row):\n";
        cout.setf(ios::fixed); cout << setprecision(6);
        for (int j = 0; j < (int)X.size(); ++j)
            cout << K[0][j] << (j + 1 == (int)X.size() ? "\n" : " ");
    }
    return 0;
}
```
This code computes the empirical NTK for a two-layer ReLU network without any autograd by exploiting closed-form gradients. The NTK between two inputs splits into a hidden-activation term and a gate-weight term proportional to x·x'. As width m increases, the empirical NTK concentrates around the infinite-width NTK for L=1 (after matching variances).