Sharpness-Aware Minimization (SAM)
Key Points
- Sharpness-Aware Minimization (SAM) trains models to perform well even when their weights are slightly perturbed, seeking flatter minima that generalize better.
- SAM solves a robust objective: minimize the worst-case loss within a small ball around the current weights, written as min_w max_{‖ε‖≤ρ} L(w+ε).
- In practice SAM takes two gradients per step: one to find an adversarial weight perturbation and another to update using the gradient at the perturbed weights.
- The perturbation points in the direction of the gradient and is normalized by the chosen norm, e.g., ε = ρ · g/‖g‖₂ for an L2 ball, or ε = ρ · sign(g) for an L∞ ball.
- SAM roughly doubles computation compared to vanilla SGD/Adam, but often improves test accuracy and robustness significantly.
- The radius ρ is a critical hyperparameter: too small has little effect; too large can destabilize or over-smooth training.
- SAM can be combined with base optimizers like SGD with momentum or Adam; weight decay and batch normalization need care.
- Use SAM when you care about generalization, distribution-shift resilience, or mild robustness, especially in overparameterized neural networks.
Prerequisites
- Gradient Descent and SGD — SAM builds directly on gradient-based updates and mini-batch sampling.
- Vector Norms and Dual Norms — Understanding L2 vs. L∞ balls and the dual-norm relationship is key to forming ε correctly.
- First-Order Taylor Expansion — SAM uses linearization of the loss to approximate the inner maximization.
- Logistic and Linear Regression Losses — Examples compute analytic gradients for MSE and cross-entropy.
- Momentum/Adam Optimizers — SAM is often paired with these base optimizers for practical training.
- Numerical Stability Basics — Normalization and logarithms in cross-entropy require small epsilons to avoid NaNs.
- Regularization (Weight Decay) — To apply or exclude regularization consistently across the two SAM passes.
- Mini-batch Training and Data Shuffling — SAM’s two evaluations should use the same batch to maintain correctness.
Detailed Explanation
01 Overview
Hook: Imagine you’re tuning a guitar string. You don’t just want perfect pitch at one microscopic point of tension; you want it to sound right even if the tension shifts a tiny bit. Models are similar: we don’t want them to be good only at one razor-thin setting of parameters; we want them to be robust to tiny changes.

Concept: Sharpness-Aware Minimization (SAM) is an optimization method that prefers flat valleys over sharp pits in the loss landscape by explicitly minimizing the worst-case loss in a small neighborhood around the current weights. This is written as a min–max problem where we pick weights that still perform well after a small, adversarial nudge.

Example: In standard training, you’d compute a gradient and step downhill. With SAM, you first shift your weights in the direction that would most increase the loss within a tiny radius, compute the gradient there, and then step downhill using that gradient, nudging you toward regions where the loss doesn’t rise quickly in any nearby direction.
02 Intuition & Analogies
Think of hiking in fog. If you stop in a narrow, sharp pit, one small step can send you climbing steeply up; it’s unstable. If you settle in a broad, flat meadow, you can wander slightly without gaining much altitude. SAM prefers those meadows. Another analogy: Consider crafting a key. A key that only works when inserted at an exact sub-millimeter position (sharp minimum) is unreliable; a key that works despite small misalignments (flat minimum) is robust. SAM formalizes this by asking: “What if my weights are nudged slightly in the worst possible direction within a small limit? Will I still have low loss?” To answer, SAM first identifies the local “worst nudge” by following the current gradient and scaling it to a fixed norm-bound radius. Then it evaluates the loss gradient at that nudged position, updating parameters to reduce that worst-case loss. This two-step lookahead steers learning away from brittle configurations that depend on precise parameter values and toward regions with gentle curvature—often linked with better test performance and stability under perturbations, noise, or minor distribution shifts. In short, SAM trades a bit more computation per step for solutions that are less sensitive to tiny parameter noise.
03 Formal Definition
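Formally, SAM replaces the usual loss minimization with a min–max problem (written here to match the Key Formulas section):

```latex
\min_{w} \; \max_{\|\epsilon\|_p \le \rho} \; L(w + \epsilon)
```

The inner maximization is approximated with a first-order Taylor expansion of \(L\) around \(w\), which for \(p = 2\) gives the closed-form maximizer \(\epsilon^{*} \approx \rho \, \nabla L(w) / \|\nabla L(w)\|_2\); the outer minimization then takes a gradient step evaluated at \(w + \epsilon^{*}\).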
04 When to Use
- Deep neural networks where overfitting or brittle minima harm test accuracy.
- Small or noisy datasets, where encouraging flatness improves stability and generalization.
- Tasks sensitive to distribution shift or mild parameter noise (e.g., deployment with quantization, stochastic layers, or slight data drift).
- When you already use SGD/Adam and can afford roughly 2× compute per step for improved robustness.
- Avoid or tune carefully when: compute budget is extremely tight; losses are strictly convex with well-behaved curvature (SAM’s benefits may be marginal); batch normalization or dropout interactions are tricky (ensure consistent mini-batch/BN behavior for both gradients); or when ρ is too large relative to the curvature scale, which can over-smooth and slow convergence.
- Use cases: image classification, NLP fine-tuning, and tabular models (including linear/logistic regression) where a flat solution is preferred. Start with ρ in a small range (e.g., 0.01–0.1 for normalized parameter scales) and tune it along with the learning rate and weight decay.
⚠️Common Mistakes
- Using the gradient at the original weights for the update instead of at the perturbed weights. The SAM update must use ∇L(w+ε*).
- Forgetting to normalize the perturbation. For L2 SAM, ε = ρ · g/‖g‖₂; without normalization, the step size depends on gradient magnitude and breaks the min–max rationale.
- Choosing ρ too large. This can cause divergence or excessive smoothing that slows learning and reduces accuracy; tune ρ jointly with the learning rate.
- Inconsistent mini-batches/statistics between the two gradients. For BatchNorm or dropout, ensure the same batch/behavior is used for both evaluations (often done via a closure in frameworks).
- Applying weight decay or regularization inconsistently during the two passes. Keep the same loss definition across both passes, or apply decay only in the final update.
- Mixing norms unintentionally. If you intend L∞ SAM, use the sign-based perturbation; if L2 SAM, normalize by the L2 norm. Guard numerical stability with a small epsilon (e.g., 1e-12) to avoid division by zero.
Key Formulas
SAM Robust Objective
Explanation: We choose parameters w that minimize the worst possible loss after any small perturbation ε within a p-norm ball of radius ρ. This formalizes the idea of robustness to tiny parameter changes.
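In symbols, this robust objective is commonly written as:

```latex
\min_{w} \; \max_{\|\epsilon\|_p \le \rho} \; L(w + \epsilon)
```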
First-Order Taylor Approximation
Explanation: Near w, the change in loss is approximately the inner product of the gradient and the perturbation. This justifies finding ε along the gradient direction for the inner maximization.
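The linearization described here is:

```latex
L(w + \epsilon) \approx L(w) + \epsilon^{\top} \nabla L(w)
```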
Dual Norm Relation
Explanation: The worst-case linear increase under a unit p-ball equals the q-norm of the gradient, where q is the dual of p. This yields a closed-form objective value for the inner problem.
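The dual-norm identity being used is:

```latex
\max_{\|\epsilon\|_p \le \rho} \epsilon^{\top} g \;=\; \rho \, \|g\|_q,
\qquad \frac{1}{p} + \frac{1}{q} = 1
```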
L2-Ball Maximizer
Explanation: Within an L2 ball, the perturbation that maximizes the linearized loss points exactly in the gradient direction with magnitude ρ. This is the most common SAM variant.
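The L2 maximizer in closed form:

```latex
\epsilon^{*} \;=\; \rho \, \frac{\nabla L(w)}{\|\nabla L(w)\|_2}
```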
L∞-Ball Maximizer
Explanation: Within an L∞ ball, the perturbation that maximizes the linearized loss sets each component to ±ρ based on the gradient’s sign. This yields a sign-based perturbation.
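The L∞ maximizer in closed form:

```latex
\epsilon^{*} \;=\; \rho \, \operatorname{sign}\!\left(\nabla L(w)\right)
```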
SAM Update Rule
Explanation: After forming the adversarial perturbation using the current gradient, we compute the gradient at the perturbed weights and perform a standard gradient step. Any base optimizer can be used for this step.
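For an SGD base optimizer with learning rate η and the L2 perturbation, the update reads:

```latex
w_{t+1} \;=\; w_t - \eta \, \nabla L\!\left(w_t + \epsilon^{*}\right),
\qquad \epsilon^{*} = \rho \, \frac{\nabla L(w_t)}{\|\nabla L(w_t)\|_2}
```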
Sharpness Proxy
Explanation: The worst-case loss increase within a small ball is approximately proportional to the gradient’s q-norm. Minimizing this proxy encourages flatter regions where the loss rises slowly.
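Combining the Taylor approximation with the dual-norm relation gives the proxy:

```latex
\max_{\|\epsilon\|_p \le \rho} L(w + \epsilon) - L(w) \;\approx\; \rho \, \|\nabla L(w)\|_q
```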
Compute Overhead
Explanation: SAM requires two forward-backward evaluations per iteration: one to find ε and one to compute the update gradient. This roughly doubles the per-step compute cost.
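As a rough accounting (a sketch, not a precise cost model):

```latex
\text{cost}_{\mathrm{SAM\ step}} \;\approx\; 2 \times \text{cost}_{\mathrm{forward\text{-}backward}}
```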
SAM with Momentum
Explanation: SAM can be paired with momentum by accumulating the gradient evaluated at the perturbed weights before applying the update. This smooths noisy gradients.
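With momentum coefficient μ and velocity v, the paired update is:

```latex
v_{t+1} \;=\; \mu \, v_t + \nabla L\!\left(w_t + \epsilon^{*}\right),
\qquad w_{t+1} \;=\; w_t - \eta \, v_{t+1}
```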
Complexity Analysis
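Assuming one forward–backward pass over a batch costs \(T_{\text{grad}}\) and the model has \(d\) parameters, the per-iteration cost of SAM (consistent with the compute-overhead formula above) is:

```latex
T_{\mathrm{SAM}} \approx 2\,T_{\text{grad}} + O(d),
\qquad \text{extra memory} = O(d)
```

The \(O(d)\) terms cover forming ε and storing the second gradient; asymptotically, SAM has the same complexity class as the base optimizer with a constant factor of about two.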
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

struct Dataset {
    vector<vector<double>> X; // n x d
    vector<double> y;         // n
};

// Utility functions
static double dot(const vector<double>& a, const vector<double>& b) {
    double s = 0.0;
    for (size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
    return s;
}
static double l2_norm(const vector<double>& v) {
    double s = 0.0;
    for (double x : v) s += x * x;
    return sqrt(max(s, 0.0));
}
static void add_scaled(vector<double>& dst, const vector<double>& src, double alpha) {
    for (size_t i = 0; i < dst.size(); ++i) dst[i] += alpha * src[i];
}

// Linear regression: prediction = w^T x + b (packed as params [w0..wd-1, b])
static double predict_one(const vector<double>& params, const vector<double>& x) {
    size_t d = x.size();
    double b = params[d];
    return dot(x, params) + b; // iterating over x reads only params[0..d-1]
}

// Compute MSE loss and gradient over a mini-batch
static double mse_and_grad(const Dataset& data, const vector<int>& batch_idx,
                           const vector<double>& params, vector<double>& grad) {
    size_t d = params.size() - 1; // last entry is the bias
    fill(grad.begin(), grad.end(), 0.0);
    double loss = 0.0;
    size_t m = batch_idx.size();
    for (int idx : batch_idx) {
        const auto& x = data.X[idx];
        double y = data.y[idx];
        double yhat = predict_one(params, x);
        double err = yhat - y; // residual
        loss += err * err;     // squared error
        // gradient: d/dw = 2/m * err * x; d/db = 2/m * err
        for (size_t j = 0; j < d; ++j) grad[j] += (2.0 / m) * err * x[j];
        grad[d] += (2.0 / m) * err; // bias term
    }
    return loss / m; // mean squared error
}

// One SAM step (L2 ball):
//   1) g = grad L(w)
//   2) epsilon = rho * g / ||g||_2
//   3) g_sam = grad L(w + epsilon)
//   4) w <- w - eta * g_sam
static void sam_step_l2(const Dataset& data, const vector<int>& batch_idx,
                        vector<double>& params, double lr, double rho) {
    size_t dtotal = params.size();
    vector<double> g(dtotal, 0.0), g_sam(dtotal, 0.0);

    // Compute gradient at current params
    mse_and_grad(data, batch_idx, params, g);

    // Form adversarial perturbation epsilon
    double normg = l2_norm(g);
    const double eps = 1e-12; // numerical stability
    vector<double> epsilon(dtotal, 0.0);
    if (normg > eps) {
        for (size_t i = 0; i < dtotal; ++i) epsilon[i] = rho * g[i] / normg;
    } // else leave epsilon = 0

    // Compute gradient at perturbed params
    vector<double> params_pert = params;
    add_scaled(params_pert, epsilon, 1.0);
    mse_and_grad(data, batch_idx, params_pert, g_sam);

    // Update using gradient at perturbed point
    for (size_t i = 0; i < dtotal; ++i) params[i] -= lr * g_sam[i];
}

// Generate a simple synthetic linear dataset: y = w*^T x + b* + noise
static Dataset make_synthetic_linear(size_t n, size_t d, unsigned seed = 42) {
    mt19937 rng(seed);
    normal_distribution<double> nx(0.0, 1.0), nnoise(0.0, 0.1);
    vector<double> w_true(d, 0.0);
    for (size_t j = 0; j < d; ++j) w_true[j] = (j % 2 == 0 ? 1.0 : -0.5);
    double b_true = 0.7;

    Dataset data;
    data.X.resize(n, vector<double>(d));
    data.y.resize(n);
    for (size_t i = 0; i < n; ++i) {
        for (size_t j = 0; j < d; ++j) data.X[i][j] = nx(rng);
        data.y[i] = dot(w_true, data.X[i]) + b_true + nnoise(rng);
    }
    return data;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    size_t n = 512, d = 10;
    Dataset data = make_synthetic_linear(n, d);

    // Parameters (w, b) initialized to zeros
    vector<double> params(d + 1, 0.0);

    // Training hyperparameters
    double lr = 0.05;    // learning rate
    double rho = 0.05;   // SAM radius
    int epochs = 50;     // number of passes
    int batch_size = 64;

    mt19937 rng(123);
    vector<int> indices(n);
    iota(indices.begin(), indices.end(), 0);

    for (int ep = 0; ep < epochs; ++ep) {
        shuffle(indices.begin(), indices.end(), rng);
        for (size_t start = 0; start < n; start += batch_size) {
            size_t end = min(n, start + (size_t)batch_size);
            vector<int> batch(indices.begin() + start, indices.begin() + end);
            sam_step_l2(data, batch, params, lr, rho);
        }
        // Evaluate full training MSE each epoch
        vector<double> gtmp(params.size(), 0.0);
        double loss = mse_and_grad(data, indices, params, gtmp);
        cout << "Epoch " << ep + 1 << ": MSE = " << loss << "\n";
    }
    // Print final parameters
    cout << "Final bias (b): " << params[d] << "\n";
    cout << "First 5 weights: ";
    for (size_t j = 0; j < min((size_t)5, d); ++j) cout << params[j] << ' ';
    cout << "\n";
    return 0;
}
```
This program trains a linear regression model using L2-SAM. For each mini-batch, it computes the gradient, forms an L2-normalized perturbation ε with radius ρ, re-computes the gradient at w+ε, and updates parameters with SGD. The synthetic dataset ensures the code runs end-to-end and shows SAM’s training dynamics.
```cpp
#include <bits/stdc++.h>
using namespace std;

struct Dataset {
    vector<vector<double>> X; // n x d
    vector<int> y;            // n (0 or 1)
};

static double dot(const vector<double>& a, const vector<double>& b) {
    double s = 0.0;
    for (size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
    return s;
}
static double sigmoid(double z) { return 1.0 / (1.0 + exp(-z)); }
static void add_scaled(vector<double>& dst, const vector<double>& src, double alpha) {
    for (size_t i = 0; i < dst.size(); ++i) dst[i] += alpha * src[i];
}
static void sign_vector(const vector<double>& v, vector<double>& sgn) {
    for (size_t i = 0; i < v.size(); ++i) sgn[i] = (v[i] > 0 ? 1.0 : (v[i] < 0 ? -1.0 : 0.0));
}

// Logistic regression: p = sigmoid(w^T x + b). Loss: average binary cross-entropy.
static double bce_and_grad(const Dataset& data, const vector<int>& batch_idx,
                           const vector<double>& params, vector<double>& grad) {
    size_t d = params.size() - 1; // last entry is the bias
    fill(grad.begin(), grad.end(), 0.0);
    double loss = 0.0;
    size_t m = batch_idx.size();
    for (int idx : batch_idx) {
        const auto& x = data.X[idx];
        int y = data.y[idx];
        double z = dot(x, params) + params[d]; // iterating over x reads only params[0..d-1]
        double p = sigmoid(z);
        // BCE: -[ y log p + (1-y) log(1-p) ]; clamp to avoid log(0)
        const double eps = 1e-12;
        loss += -(y * log(max(p, eps)) + (1 - y) * log(max(1.0 - p, eps)));
        double err = p - y; // gradient of BCE wrt z
        for (size_t j = 0; j < d; ++j) grad[j] += err * x[j] / m;
        grad[d] += err / m; // bias
    }
    return loss / m;
}

// One L∞-SAM step with a momentum base optimizer
static void sam_step_linf_momentum(const Dataset& data, const vector<int>& batch_idx,
                                   vector<double>& params, vector<double>& velocity,
                                   double lr, double rho, double momentum) {
    size_t dtotal = params.size();
    vector<double> g(dtotal, 0.0), g_sam(dtotal, 0.0), eps_vec(dtotal, 0.0);

    // 1) Gradient at current params
    bce_and_grad(data, batch_idx, params, g);

    // 2) L∞-ball perturbation: epsilon_i = rho * sign(g_i)
    sign_vector(g, eps_vec);
    for (size_t i = 0; i < dtotal; ++i) eps_vec[i] *= rho;

    // 3) Gradient at perturbed params
    vector<double> params_pert = params;
    add_scaled(params_pert, eps_vec, 1.0);
    bce_and_grad(data, batch_idx, params_pert, g_sam);

    // 4) Momentum update using g_sam
    for (size_t i = 0; i < dtotal; ++i) {
        velocity[i] = momentum * velocity[i] + g_sam[i];
        params[i] -= lr * velocity[i];
    }
}

// Create a simple 2D, roughly linearly separable dataset
static Dataset make_classification(size_t n_per_class = 200, unsigned seed = 7) {
    mt19937 rng(seed);
    normal_distribution<double> n1x(-1.0, 0.6), n1y(0.0, 0.6);
    normal_distribution<double> n2x(1.0, 0.6), n2y(0.2, 0.6);
    Dataset data;
    data.X.reserve(2 * n_per_class);
    data.y.reserve(2 * n_per_class);
    for (size_t i = 0; i < n_per_class; ++i) {
        data.X.push_back({n1x(rng), n1y(rng)});
        data.y.push_back(0);
    }
    for (size_t i = 0; i < n_per_class; ++i) {
        data.X.push_back({n2x(rng), n2y(rng)});
        data.y.push_back(1);
    }
    return data;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    Dataset data = make_classification();
    size_t n = data.X.size();
    size_t d = data.X[0].size();

    vector<double> params(d + 1, 0.0);   // w and b
    vector<double> velocity(d + 1, 0.0); // momentum state

    double lr = 0.1;    // learning rate
    double rho = 0.05;  // SAM radius (L∞)
    double mu = 0.9;    // momentum
    int epochs = 40;    // training epochs
    int batch_size = 64;

    mt19937 rng(1234);
    vector<int> indices(n);
    iota(indices.begin(), indices.end(), 0);

    for (int ep = 0; ep < epochs; ++ep) {
        shuffle(indices.begin(), indices.end(), rng);
        for (size_t start = 0; start < n; start += batch_size) {
            size_t end = min(n, start + (size_t)batch_size);
            vector<int> batch(indices.begin() + start, indices.begin() + end);
            sam_step_linf_momentum(data, batch, params, velocity, lr, rho, mu);
        }
        // Report average BCE on full data
        vector<double> gtmp(params.size(), 0.0);
        double loss = bce_and_grad(data, indices, params, gtmp);
        cout << "Epoch " << ep + 1 << ": BCE = " << loss << "\n";
    }

    cout << "Final bias: " << params[d] << "\n";
    cout << "Weights: ";
    for (size_t j = 0; j < d; ++j) cout << params[j] << ' ';
    cout << "\n";
    return 0;
}
```
This example trains a logistic regression classifier using L∞-SAM with momentum. The inner maximization over an L∞ ball sets ε_i = ρ·sign(g_i). We recompute the gradient at w+ε and then apply a momentum update. It demonstrates combining SAM with a common base optimizer and a different norm choice.