Natural Gradient Method
Key Points
- Natural gradient scales the ordinary gradient by the inverse Fisher information matrix to account for the geometry of probability distributions.
- It is the steepest descent direction measured in KL-divergence, making it invariant to reparameterization of the model.
- Computing the full Fisher matrix is expensive, so practical implementations use approximations (diagonal, block-diagonal like K-FAC) or solve Fisher-vector systems with conjugate gradient.
- For logistic regression, the Fisher matrix equals \( X^\top W X / N \), where \( W \) is diagonal with entries \( W_{ii} = p_i(1 - p_i) \), letting us compute Fisher-vector products efficiently without forming the matrix.
- Natural gradient steps often converge faster on ill-conditioned problems than vanilla SGD because they precondition by curvature.
- You should add damping \( (F + \lambda I) \) to stabilize updates and avoid numerical issues when the Fisher is singular or poorly estimated.
- Empirical Fisher (based on sample gradients) and true Fisher (expectation over the model) are different; know which one you are using.
- In C++, you can implement natural gradient efficiently using Fisher-vector products plus a conjugate gradient solver, avoiding explicit matrix inversion.
Prerequisites
- Gradient Descent — Natural gradient modifies the basic gradient update using a geometry-aware preconditioner.
- Probability and Log-Likelihood — The Fisher information and natural gradient are defined via log-likelihoods and expectations over the model.
- Linear Algebra (Vectors, Matrices) — Understanding matrix–vector products, positive-definiteness, and solving linear systems is essential.
- Logistic Regression — Provides a concrete model where the Fisher has a simple, efficient structure.
- Conjugate Gradient Method — Efficiently solves \( (F + \lambda I) s = g \) without forming \( F \).
- KL Divergence — Natural gradient is the steepest descent direction under the KL-induced metric.
Detailed Explanation
01 Overview
The natural gradient method is an optimization technique designed specifically for probabilistic models. Instead of moving parameters straight along the negative gradient, it rescales that step using the inverse of the Fisher information matrix, a measure of local curvature that captures how sensitive the model’s probability distribution is to parameter changes. The canonical update is \( \theta_{t+1} = \theta_t - \eta\, F(\theta_t)^{-1} \nabla L(\theta_t) \), where \( L \) is typically the average negative log-likelihood (or a loss aligned with it), and \( F \) is the Fisher information matrix.
Why do this? Ordinary gradients depend on how you choose to parameterize your model; change coordinates and the step can behave quite differently. The natural gradient uses a geometry-aware metric so that steps correspond to the most efficient change in the model’s predictive distribution, not just in parameter space. In practice, this often leads to faster and more stable convergence, especially when parameters have different scales or the problem is ill-conditioned.
The major challenge is computational: \( F \) can be huge. Instead of forming and inverting it, we use approximations (diagonal, block-diagonal) or iterative solvers that need only Fisher-vector products. This keeps the method practical for models like logistic/softmax regression and neural networks, where such products can be computed in time similar to a gradient evaluation.
02 Intuition & Analogies
Imagine you’re hiking on a landscape while wearing shoes that stretch or shrink distances depending on direction—some directions feel steep, others flat, even if visually the terrain looks the same. In ordinary gradient descent, you measure “steepest” using a fixed ruler (Euclidean geometry). But probabilistic models live on a curved surface defined by how outputs (probabilities) change with parameters. Two equal-length parameter steps can cause very different changes in the predicted distribution.
Natural gradient replaces the fixed ruler with one that respects what we care about: changes in the model’s probability distribution. The Fisher information matrix acts like a smart tape measure. It stretches directions where small parameter tweaks drastically change predictions and compresses directions where parameters barely matter. Then, when you take a “steepest descent” step using this ruler, the method automatically shortens steps in sensitive directions and lengthens them in insensitive ones.
A helpful analogy: navigating a city with non-uniform blocks. If you only count intersections (Euclidean steps), you might pick a slow route. If instead you consider actual walking time accounting for traffic and block length (the Fisher metric), you’ll choose a more efficient path. Natural gradient is the “time-aware” choice—it moves in the direction that most quickly reduces loss measured in how the model’s probabilities improve.
In linear logistic regression, for instance, different features may have wildly different scales or frequencies. Natural gradient’s rescaling is like auto-adjusting your step size per direction according to how predictive and uncertain those features are, yielding more balanced progress.
03 Formal Definition
Let \( p_\theta(x) \) be a model with parameters \( \theta \in \mathbb{R}^d \). The Fisher information matrix is \( F(\theta) = \mathbb{E}_{x \sim p_\theta}\!\left[ \nabla_\theta \log p_\theta(x)\, \nabla_\theta \log p_\theta(x)^\top \right] \). For a loss \( L(\theta) \), typically the average negative log-likelihood, the natural gradient is \( \tilde{\nabla} L(\theta) = F(\theta)^{-1} \nabla L(\theta) \), and the update rule is \( \theta_{t+1} = \theta_t - \eta\, F(\theta_t)^{-1} \nabla L(\theta_t) \).
04 When to Use
- Ill-conditioned problems: If gradients oscillate or progress is slow due to poorly scaled parameters, natural gradient preconditions steps with curvature, often accelerating convergence.
- Probabilistic models: Logistic/softmax regression, variational inference, and probabilistic neural networks, where the Fisher metric matches how predictions change.
- Reparameterization concerns: When different parameterizations (e.g., changing units or activations) alter the behavior of vanilla gradients, natural gradient provides invariance and stability.
- Large models with structure: When you can compute efficient Fisher–vector products (e.g., GLMs, networks with layerwise structure) or use approximations like diagonal or K-FAC.
- Policy optimization in reinforcement learning: Natural policy gradient uses the Fisher of the policy to achieve stable updates under a KL trust region.
- Trust region methods: If you want to control the change in distribution via KL constraints, natural gradient (possibly with damping) naturally fits as a step direction or as part of a TRPO-like procedure.
- Small to medium dimensions: For small d, you can form and factorize F directly; for large d, prefer iterative solvers and approximations.
⚠️ Common Mistakes
- Confusing empirical Fisher with true Fisher: The empirical Fisher uses observed labels; the true Fisher is an expectation under the model. They can differ substantially, especially far from optimum. Be explicit about which you use.
- Explicitly inverting F: Computing and inverting \( F \) is numerically unstable and expensive. Instead, solve \( F s = g \) with conjugate gradient or use approximations.
- Missing damping: The Fisher can be singular or poorly conditioned. Add \( \lambda I \) (Levenberg–Marquardt style) to stabilize solves and tune \( \lambda \).
- Wrong gradients: Natural gradient uses gradients of the loss compatible with log-likelihood. If you mix signs or use MSE for a probabilistic model without adjustment, you can get inconsistent behavior.
- Stale Fisher: Using an outdated Fisher for many steps can hurt. Recompute or refresh approximations regularly (e.g., per epoch or with EMA).
- Ignoring batch effects: Small minibatches yield noisy Fisher estimates. Use larger batches or momentum/EMA to smooth.
- Parameterization traps: Constraints (e.g., variances must be positive) require appropriate parameterization (e.g., log-variance) so the Fisher is well-defined and steps remain valid.
Key Formulas
Natural Gradient Update
\( \theta_{t+1} = \theta_t - \eta\, F(\theta_t)^{-1} \nabla L(\theta_t) \)
Explanation: Update parameters by scaling the gradient with the inverse Fisher information. This rescales steps according to the model’s geometry and typically improves conditioning.
Fisher Information Matrix
\( F(\theta) = \mathbb{E}_{x \sim p_\theta}\!\left[ \nabla_\theta \log p_\theta(x)\, \nabla_\theta \log p_\theta(x)^\top \right] \)
Explanation: The Fisher is the expected outer product of score functions. It defines a Riemannian metric that measures how parameter changes affect the distribution.
Steepest Descent in KL Geometry
\( \min_{\delta}\ \nabla L(\theta)^\top \delta \quad \text{s.t.} \quad \tfrac{1}{2}\, \delta^\top F(\theta)\, \delta \le \epsilon \)
Explanation: The natural gradient is the solution to minimizing the linearized loss subject to a KL-based trust region. Solving gives \( \delta \propto -F^{-1} \nabla L \).
KL Second-Order Approximation
\( D_{\mathrm{KL}}\!\left( p_\theta \,\|\, p_{\theta + \delta} \right) \approx \tfrac{1}{2}\, \delta^\top F(\theta)\, \delta \)
Explanation: For small parameter changes, the KL divergence is approximately quadratic with curvature F. This ties KL geometry to the Fisher matrix.
Empirical Fisher
\( \hat{F}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \log p_\theta(x_i)\, \nabla_\theta \log p_\theta(x_i)^\top \)
Explanation: A sample-based estimate of the Fisher using observed data. It’s easy to compute but can differ from the true Fisher away from the model distribution.
Damped Natural Gradient Solve
\( \left( F(\theta) + \lambda I \right) s = \nabla L(\theta) \)
Explanation: Instead of inverting F, solve a linear system, possibly with damping. Iterative solvers like conjugate gradient only need Fisher–vector products.
Fisher–Vector Product (General)
\( F v = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \log p_\theta(x_i) \left( \nabla_\theta \log p_\theta(x_i)^\top v \right) \)
Explanation: This identity lets us compute the product of F with a vector v using per-sample gradients, avoiding explicit construction of F.
Fisher for Logistic Regression
\( F(\theta) = \frac{1}{N} X^\top W X, \qquad W_{ii} = p_i (1 - p_i) \)
Explanation: For binary logistic regression, the Fisher equals a weighted feature covariance. This structure enables efficient Fv computation: \( F v = X^\top (W (X v)) / N \).
Complexity Analysis
Forming \( F \) explicitly costs \( O(N d^2) \), and factorizing or inverting it costs \( O(d^3) \). A Fisher–vector product for logistic regression costs \( O(N d) \), so \( k \) conjugate gradient iterations cost \( O(k N d) \), comparable to a handful of extra gradient evaluations per step.
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

// Sigmoid function
static inline double sigmoid(double z) { return 1.0 / (1.0 + exp(-z)); }

// Conjugate Gradient solver for (A + lambda I) s = b where A is accessed via a callback Av(v)
struct CGSolver {
    int max_iters;
    double tol;
    double lambda;
    function<void(const vector<double>&, vector<double>&)> Av; // computes A v

    CGSolver(int max_iters_=20, double tol_=1e-8, double lambda_=1e-3)
        : max_iters(max_iters_), tol(tol_), lambda(lambda_) {}

    // Solve (A + lambda I) s = b. We pass size via b.size().
    vector<double> solve(const vector<double>& b) {
        int d = (int)b.size();
        vector<double> s(d, 0.0), r = b, p = r, Ap(d, 0.0);
        // r = b - (A + lambda I) * s, but s=0 initially so r=b
        double rsold = inner_product(r.begin(), r.end(), r.begin(), 0.0);
        for (int it = 0; it < max_iters; ++it) {
            // Ap = (A + lambda I) p = A p + lambda p
            Av(p, Ap);
            for (int i = 0; i < d; ++i) Ap[i] += lambda * p[i];
            double pAp = inner_product(p.begin(), p.end(), Ap.begin(), 0.0);
            double alpha = rsold / max(pAp, 1e-30);
            for (int i = 0; i < d; ++i) s[i] += alpha * p[i];
            for (int i = 0; i < d; ++i) r[i] -= alpha * Ap[i];
            double rsnew = inner_product(r.begin(), r.end(), r.begin(), 0.0);
            if (sqrt(rsnew) < tol) break;
            double beta = rsnew / max(rsold, 1e-30);
            for (int i = 0; i < d; ++i) p[i] = r[i] + beta * p[i];
            rsold = rsnew;
        }
        return s;
    }
};

// Helper to compute X * v and X^T * w for dense data
struct DenseData {
    int N, d;
    vector<vector<double>> X; // N x d
    vector<int> y;            // labels in {0,1}

    DenseData(int N_, int d_) : N(N_), d(d_), X(N_, vector<double>(d_, 0.0)), y(N_, 0) {}

    vector<double> Xv(const vector<double>& v) const {
        vector<double> out(N, 0.0);
        for (int i = 0; i < N; ++i) {
            double s = 0.0;
            for (int j = 0; j < d; ++j) s += X[i][j] * v[j];
            out[i] = s;
        }
        return out;
    }

    vector<double> XTw(const vector<double>& w) const {
        vector<double> out(d, 0.0);
        for (int j = 0; j < d; ++j) {
            double s = 0.0;
            for (int i = 0; i < N; ++i) s += X[i][j] * w[i];
            out[j] = s;
        }
        return out;
    }
};

// Build synthetic, roughly linearly separable dataset with label noise
DenseData make_dataset(int N, int d, uint64_t seed=42) {
    DenseData data(N, d);
    mt19937_64 rng(seed);
    normal_distribution<double> gauss(0.0, 1.0);
    vector<double> w_true(d);
    for (int j = 0; j < d; ++j) w_true[j] = gauss(rng);
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j < d; ++j) data.X[i][j] = gauss(rng);
        double z = inner_product(data.X[i].begin(), data.X[i].end(), w_true.begin(), 0.0);
        double p = sigmoid(z);
        bernoulli_distribution bern(p * 0.9 + 0.05); // add some label noise
        data.y[i] = (int)bern(rng);
    }
    return data;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N = 2000, d = 20;
    DenseData data = make_dataset(N, d);

    // Parameters
    vector<double> theta(d, 0.0);
    double lr = 1.0;      // learning rate on the natural direction
    double lambda = 1e-2; // damping for stability (F + lambda I)

    // Recompute probabilities each iteration
    auto compute_prob = [&](const vector<double>& th) {
        vector<double> p(N, 0.0);
        for (int i = 0; i < N; ++i) {
            double z = inner_product(data.X[i].begin(), data.X[i].end(), th.begin(), 0.0);
            p[i] = sigmoid(z);
        }
        return p;
    };

    auto compute_grad = [&](const vector<double>& p) {
        // Gradient of average NLL: g = X^T (p - y) / N
        vector<double> diff(N, 0.0);
        for (int i = 0; i < N; ++i) diff[i] = p[i] - (double)data.y[i];
        vector<double> g = data.XTw(diff);
        for (double &gi : g) gi /= (double)N;
        return g;
    };

    // Fisher–vector product for logistic regression:
    // F v = X^T (W (X v)) / N, with W_ii = p_i (1 - p_i)
    vector<double> p = compute_prob(theta);
    auto Av = [&](const vector<double>& v, vector<double>& out) {
        vector<double> Xv = data.Xv(v);
        vector<double> W_Xv(N, 0.0);
        for (int i = 0; i < N; ++i) W_Xv[i] = p[i] * (1.0 - p[i]) * Xv[i];
        out = data.XTw(W_Xv);
        for (double &oi : out) oi /= (double)N;
    };

    CGSolver cg(15, 1e-8, lambda);
    cg.Av = Av; // captures p by reference, so refreshing p refreshes the Fisher

    for (int iter = 0; iter < 30; ++iter) {
        p = compute_prob(theta);            // refresh probabilities each iter
        vector<double> g = compute_grad(p); // average gradient of NLL
        vector<double> s = cg.solve(g);     // solve (F + lambda I) s = g

        // Natural gradient step
        for (int j = 0; j < d; ++j) theta[j] -= lr * s[j];

        // Report average NLL
        double nll = 0.0;
        for (int i = 0; i < N; ++i) {
            double z = inner_product(data.X[i].begin(), data.X[i].end(), theta.begin(), 0.0);
            double pi = sigmoid(z);
            pi = min(max(pi, 1e-12), 1.0 - 1e-12); // clip for numerical stability
            nll += -(data.y[i] ? log(pi) : log(1.0 - pi));
        }
        nll /= (double)N;
        if (iter % 5 == 0) cerr << "iter=" << iter << "\tNLL=" << nll << "\n";
    }

    // Print first few parameters
    cout << fixed << setprecision(4);
    for (int j = 0; j < min(d, 5); ++j) cout << theta[j] << (j+1 < min(d,5) ? ' ' : '\n');
    return 0;
}
```
This program fits binary logistic regression using natural gradient steps. It avoids forming the Fisher matrix by computing Fisher–vector products Fv = X^T(W(Xv))/N, where W_ii = p_i(1−p_i). A conjugate gradient (CG) solver computes s ≈ (F + λI)^{-1} g, with g the average gradient of negative log-likelihood. The parameters are updated by θ ← θ − η s. Damping λ improves numerical stability. The complexity per CG iteration is O(Nd), similar to a pair of matrix–vector multiplies over the dataset.
```cpp
#include <bits/stdc++.h>
using namespace std;

// Data: x_i ~ Normal(mu, sigma^2). We estimate mu with known sigma.
// Per-sample score: d/dmu log p(x_i|mu) = (x_i - mu) / sigma^2.
// Fisher per sample: 1 / sigma^2. For the total NLL over N samples, F = N / sigma^2;
// for the average NLL used below, F_avg = 1 / sigma^2, so the natural gradient
// step equals the ordinary gradient scaled by sigma^2.

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N = 500;
    double sigma = 2.0;   // known standard deviation
    double mu_true = 3.5; // ground-truth mean
    uint64_t seed = 123;
    mt19937_64 rng(seed);
    normal_distribution<double> gauss(mu_true, sigma);

    vector<double> x(N);
    for (int i = 0; i < N; ++i) x[i] = gauss(rng);

    // Start far from the true mean
    double mu_sgd = -5.0;
    double mu_nat = -5.0;

    // Learning rates
    double eta_sgd = 0.05; // for ordinary gradient descent on average NLL
    double eta_nat = 1.0;  // for natural gradient (scaled by F^{-1})

    // True Fisher for this model (average NLL): F_avg = 1 / sigma^2
    // Average NLL gradient: g = (1/N) * sum_i (mu - x_i) / sigma^2 = (mu - x_bar) / sigma^2
    double Fisher_avg = 1.0 / (sigma * sigma);
    double Finv_avg = 1.0 / Fisher_avg; // = sigma^2

    auto mean_of = [&](const vector<double>& v){ return accumulate(v.begin(), v.end(), 0.0) / (double)v.size(); };
    double xbar = mean_of(x);

    for (int t = 0; t < 40; ++t) {
        // Ordinary gradient of average NLL: g = (mu - xbar) / sigma^2
        double g_sgd = (mu_sgd - xbar) / (sigma * sigma);
        mu_sgd -= eta_sgd * g_sgd;

        // Natural gradient: s = F^{-1} g = sigma^2 * g = (mu - xbar)
        double g_nat = (mu_nat - xbar) / (sigma * sigma);
        double s = Finv_avg * g_nat; // equals (mu - xbar)
        mu_nat -= eta_nat * s;       // one step moves directly toward xbar

        if (t % 5 == 0) {
            cout << "t=" << setw(2) << t
                 << "\tmu_sgd=" << setw(8) << fixed << setprecision(4) << mu_sgd
                 << "\tmu_nat=" << setw(8) << mu_nat << '\n';
        }
    }

    cout << "sample mean xbar = " << fixed << setprecision(4) << xbar << '\n';
    return 0;
}
```
For Gaussian mean estimation with known variance, the Fisher for the mean parameter under the average NLL is F = 1/σ^2, so F^{-1} = σ^2. The natural gradient direction equals the ordinary gradient multiplied by σ^2, which cancels the σ^2 in the gradient and yields s = (μ − x̄). With η_nat = 1, a single natural-gradient step jumps directly to the sample mean. This simple case illustrates how natural gradient removes parameter scaling from the update.