
Sharpness-Aware Minimization (SAM)

Key Points

  • Sharpness-Aware Minimization (SAM) trains models to perform well even when their weights are slightly perturbed, seeking flatter minima that generalize better.
  • SAM solves a robust objective: minimize the worst-case loss within a small ball around the current weights, written as \( \min_w \max_{\|\epsilon\| \le \rho} L(w + \epsilon) \).
  • In practice SAM takes two gradients per step: one to find an adversarial weight perturbation ε, and another to update using the gradient at the perturbed weights.
  • The perturbation points in the direction of the gradient and is scaled to the ball’s radius, e.g., ε = ρ · g/‖g‖₂ for an L2 ball, or ε_i = ρ · sign(g_i) for an L∞ ball.
  • SAM roughly doubles computation compared to vanilla SGD/Adam, but often improves test accuracy and robustness significantly.
  • The radius ρ is a critical hyperparameter: too small has little effect; too large can destabilize or over-smooth training.
  • SAM can be combined with base optimizers like SGD with momentum or Adam; weight decay and batch normalization need care.
  • Use SAM when you care about generalization, resilience to distribution shift, or mild robustness—especially in overparameterized neural networks.

Prerequisites

  • →Gradient Descent and SGD — SAM builds directly on gradient-based updates and mini-batch sampling.
  • →Vector Norms and Dual Norms — Understanding L2 vs. L∞ balls and the dual-norm relationship is key to forming ε correctly.
  • →First-Order Taylor Expansion — SAM uses linearization of the loss to approximate the inner maximization.
  • →Logistic and Linear Regression Losses — Examples compute analytic gradients for MSE and cross-entropy.
  • →Momentum/Adam Optimizers — SAM is often paired with these base optimizers for practical training.
  • →Numerical Stability Basics — Normalization and logarithms in cross-entropy require small epsilons to avoid NaNs.
  • →Regularization (Weight Decay) — To apply or exclude regularization consistently across the two SAM passes.
  • →Mini-batch Training and Data Shuffling — SAM’s two evaluations should use the same batch to maintain correctness.

Detailed Explanation


01 Overview

Hook: Imagine you’re tuning a guitar string. You don’t just want perfect pitch at one microscopic point of tension; you want it to sound right even if the tension shifts a tiny bit. Models are similar: we don’t want them to be good only at one razor-thin setting of parameters; we want them to be robust to tiny changes.

Concept: Sharpness-Aware Minimization (SAM) is an optimization method that prefers flat valleys over sharp pits in the loss landscape by explicitly minimizing the worst-case loss in a small neighborhood around the current weights. This is written as a min–max problem where we pick weights that still perform well after a small, adversarial nudge.

Example: In standard training, you’d compute a gradient and step downhill. With SAM, you first shift your weights in the direction that would most increase the loss within a tiny radius, compute the gradient there, and then step downhill using that gradient—nudging you toward regions where the loss doesn’t rise quickly in any nearby direction.

02 Intuition & Analogies

Think of hiking in fog. If you stop in a narrow, sharp pit, one small step can send you climbing steeply up; it’s unstable. If you settle in a broad, flat meadow, you can wander slightly without gaining much altitude. SAM prefers those meadows. Another analogy: Consider crafting a key. A key that only works when inserted at an exact sub-millimeter position (sharp minimum) is unreliable; a key that works despite small misalignments (flat minimum) is robust. SAM formalizes this by asking: “What if my weights are nudged slightly in the worst possible direction within a small limit? Will I still have low loss?” To answer, SAM first identifies the local “worst nudge” by following the current gradient and scaling it to a fixed norm-bound radius. Then it evaluates the loss gradient at that nudged position, updating parameters to reduce that worst-case loss. This two-step lookahead steers learning away from brittle configurations that depend on precise parameter values and toward regions with gentle curvature—often linked with better test performance and stability under perturbations, noise, or minor distribution shifts. In short, SAM trades a bit more computation per step for solutions that are less sensitive to tiny parameter noise.

03 Formal Definition

SAM optimizes a robust objective: \( \min_{w \in \mathbb{R}^d} \max_{\|\epsilon\|_p \le \rho} L(w + \epsilon) \), where \(L\) is the empirical training loss, \(\epsilon\) is a parameter perturbation constrained by a norm ball of radius \(\rho > 0\), and \(p\) selects the norm (commonly 2 or \(\infty\)). Using a first-order Taylor expansion, \( L(w+\epsilon) \approx L(w) + \nabla L(w)^\top \epsilon \). The inner maximization over the \(p\)-ball then equals \( \rho \, \|\nabla L(w)\|_q \), where \(q\) is the dual norm satisfying \(1/p + 1/q = 1\). The maximizer direction is along the gradient: for \(p=2\), \( \epsilon^* = \rho \, \nabla L(w)/\|\nabla L(w)\|_2 \); for \(p=\infty\), \( \epsilon^*_i = \rho \, \mathrm{sign}([\nabla L(w)]_i) \). SAM performs updates using the gradient evaluated at the perturbed weights, i.e., \( w_{t+1} = w_t - \eta \, \nabla L(w_t + \epsilon^*(w_t)) \). In practice, the gradient used to form \(\epsilon^*\) is computed on the same mini-batch, and the base step can integrate momentum or Adam. This procedure approximates the gradient of the robust objective and biases the optimizer toward flatter minima where nearby perturbations cause little loss increase.

04 When to Use

  • Deep neural networks where overfitting or brittle minima harm test accuracy.
  • Small or noisy datasets, where encouraging flatness improves stability and generalization.
  • Tasks sensitive to distribution shift or mild parameter noise (e.g., deployment with quantization, stochastic layers, or slight data drift).
  • When you already use SGD/Adam and can afford roughly 2× compute per step for improved robustness.

Avoid or tune carefully when: the compute budget is extremely tight; losses are strictly convex with well-behaved curvature (SAM’s benefits may be marginal); batch normalization or dropout interactions are tricky (ensure consistent mini-batch/BN behavior for both gradients); or when ρ is too large relative to the curvature scale, which can over-smooth and slow convergence.

Use cases: image classification, NLP fine-tuning, and tabular models (including linear/logistic regression) where a flat solution is preferred. Start with ρ in a small range (e.g., 0.01–0.1 for normalized parameter scales) and tune it jointly with the learning rate and weight decay.

⚠️Common Mistakes

  • Using the gradient at the original weights for the update instead of at the perturbed weights. The SAM update must use ∇L(w + ε*).
  • Forgetting to normalize the perturbation. For L2 SAM, ε = ρ · g/‖g‖₂; without normalization, the perturbation magnitude depends on the gradient magnitude and breaks the min–max rationale.
  • Choosing ρ too large. This can cause divergence or excessive smoothing that slows learning and reduces accuracy; tune ρ jointly with the learning rate.
  • Inconsistent mini-batches/statistics between the two gradients. For BatchNorm or dropout, ensure the same batch/behavior is used for both evaluations (often done via a closure in frameworks).
  • Applying weight decay or regularization inconsistently during the two passes. Keep the same loss definition across both passes, or apply decay only in the final update.
  • Mixing norms unintentionally. If you intend L∞ SAM, use the sign-based perturbation; if L2 SAM, normalize by the L2 norm. Guard the division with a small epsilon (e.g., 1e-12) to avoid dividing by zero.

Key Formulas

SAM Robust Objective

\( \min_{w \in \mathbb{R}^d} \; \max_{\|\epsilon\|_p \le \rho} \; L(w + \epsilon) \)

Explanation: We choose parameters w that minimize the worst possible loss after any small perturbation ε within a p-norm ball of radius ρ. This formalizes the idea of robustness to tiny parameter changes.

First-Order Taylor Approximation

\( L(w + \epsilon) \approx L(w) + \nabla L(w)^\top \epsilon \)

Explanation: Near w, the change in loss is approximately the inner product of the gradient and the perturbation. This justifies finding ε along the gradient direction for the inner maximization.

Dual Norm Relation

\( \max_{\|\epsilon\|_p \le 1} \nabla L(w)^\top \epsilon = \|\nabla L(w)\|_q, \quad \text{with } \tfrac{1}{p} + \tfrac{1}{q} = 1 \)

Explanation: The worst-case linear increase under a unit p-ball equals the q-norm of the gradient, where q is the dual of p. This yields a closed-form objective value for the inner problem.

L2-Ball Maximizer

\( \epsilon^* = \rho \cdot \dfrac{\nabla L(w)}{\|\nabla L(w)\|_2} \)

Explanation: Within an L2 ball, the perturbation that maximizes the linearized loss points exactly in the gradient direction with magnitude ρ. This is the most common SAM variant.

L∞-Ball Maximizer

\( \epsilon^*_i = \rho \cdot \mathrm{sign}([\nabla L(w)]_i) \)

Explanation: Within an L∞ ball, the perturbation that maximizes the linearized loss sets each component to ±ρ based on the gradient’s sign. This yields a sign-based perturbation.

SAM Update Rule

\( w_{t+1} = w_t - \eta \, \nabla L\big(w_t + \epsilon^*(w_t)\big) \)

Explanation: After forming the adversarial perturbation using the current gradient, we compute the gradient at the perturbed weights and perform a standard gradient step. Any base optimizer can be used for this step.

Sharpness Proxy

\( S_\rho(w) := \max_{\|\epsilon\| \le \rho} L(w + \epsilon) - L(w) \approx \rho \, \|\nabla L(w)\|_q \)

Explanation: The worst-case loss increase within a small ball is approximately proportional to the gradient’s q-norm. Minimizing this proxy encourages flatter regions where the loss rises slowly.

Compute Overhead

\( C_{\text{SAM}} \approx 2 \, C_{\text{base}} \)

Explanation: SAM requires two forward-backward evaluations per iteration: one to find ε and one to compute the update gradient. This roughly doubles the per-step compute cost.

SAM with Momentum

\( v_{t+1} = \mu v_t + \nabla L(w_t + \epsilon^*), \qquad w_{t+1} = w_t - \eta \, v_{t+1} \)

Explanation: SAM can be paired with momentum by accumulating the gradient evaluated at the perturbed weights before applying the update. This smooths noisy gradients.

Complexity Analysis

Let d be the number of parameters and n the mini-batch size. For linear or logistic models with dense features, one forward-backward pass computes gradients in O(nd) time and O(d) additional memory (beyond storing the batch). SAM performs two gradient computations per iteration: (1) at w to form the perturbation ε, and (2) at w+ε to compute the update gradient. Thus, the per-iteration time is about 2·O(nd) = O(nd), with a constant factor ≈2 compared to the base optimizer. Forming ε itself costs O(d) time and O(d) memory, dominated by the gradient evaluations.

Memory overhead includes storing the temporary perturbed weights (or ε), adding O(d), and optionally momentum/Adam states (O(d) more). Therefore, the total auxiliary memory for SAM+SGD is O(d), and for SAM+momentum O(d) (velocity), both modest relative to model size.

In deep networks, each pass corresponds to a full forward-backward over the architecture; SAM therefore incurs roughly a 2× compute overhead per step and slight activation-memory increases if both passes are held simultaneously (commonly avoided by running the passes sequentially). Convergence-wise, SAM may require fewer epochs to reach good generalization due to its regularizing effect, but wall-clock time often increases unless compensated by larger batch sizes or fewer epochs. Proper tuning of ρ and the learning rate can mitigate slowdowns while preserving generalization gains.

Code Examples

L2-SAM for Linear Regression (MSE) with SGD
#include <bits/stdc++.h>
using namespace std;

struct Dataset {
  vector<vector<double>> X; // n x d
  vector<double> y;         // n
};

// Utility functions
static double dot(const vector<double>& a, const vector<double>& b) {
  double s = 0.0;
  for (size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
  return s;
}
static double l2_norm(const vector<double>& v) {
  double s = 0.0;
  for (double x : v) s += x * x;
  return sqrt(max(s, 0.0));
}
static void add_scaled(vector<double>& dst, const vector<double>& src, double alpha) {
  for (size_t i = 0; i < dst.size(); ++i) dst[i] += alpha * src[i];
}

// Linear regression: prediction = w^T x + b (packed as params [w0..wd-1, b])
static double predict_one(const vector<double>& params, const vector<double>& x) {
  size_t d = x.size();
  // Iterate over x so only params[0..d-1] act as weights; params[d] is the bias.
  return dot(x, params) + params[d];
}

// Compute MSE loss and gradient over a mini-batch
static double mse_and_grad(const Dataset& data, const vector<int>& batch_idx,
                           const vector<double>& params, vector<double>& grad) {
  size_t d = params.size() - 1; // last is bias
  fill(grad.begin(), grad.end(), 0.0);
  double loss = 0.0;
  size_t m = batch_idx.size();
  for (int idx : batch_idx) {
    const auto& x = data.X[idx];
    double y = data.y[idx];
    double yhat = predict_one(params, x);
    double err = yhat - y; // residual
    loss += err * err;     // squared error
    // gradient: d/dw = 2/m * err * x; d/db = 2/m * err
    for (size_t j = 0; j < d; ++j) grad[j] += (2.0 / m) * err * x[j];
    grad[d] += (2.0 / m) * err; // bias term
  }
  return loss / m; // mean squared error
}

// One SAM step (L2 ball):
// 1) g = grad L(w)
// 2) epsilon = rho * g / ||g||_2
// 3) g_sam = grad L(w + epsilon)
// 4) w <- w - eta * g_sam
static void sam_step_l2(const Dataset& data, const vector<int>& batch_idx,
                        vector<double>& params, double lr, double rho) {
  size_t dtotal = params.size();
  vector<double> g(dtotal, 0.0), g_sam(dtotal, 0.0);

  // Compute gradient at current params
  mse_and_grad(data, batch_idx, params, g);

  // Form adversarial perturbation epsilon
  double normg = l2_norm(g);
  const double eps = 1e-12; // numerical stability
  vector<double> epsilon(dtotal, 0.0);
  if (normg > eps) {
    for (size_t i = 0; i < dtotal; ++i) epsilon[i] = rho * g[i] / normg;
  } // else leave epsilon = 0

  // Compute gradient at perturbed params
  vector<double> params_pert = params;
  add_scaled(params_pert, epsilon, 1.0);
  mse_and_grad(data, batch_idx, params_pert, g_sam);

  // Update using gradient at perturbed point
  for (size_t i = 0; i < dtotal; ++i) params[i] -= lr * g_sam[i];
}

// Generate a simple synthetic linear dataset: y = w* x + b + noise
static Dataset make_synthetic_linear(size_t n, size_t d, unsigned seed = 42) {
  mt19937 rng(seed);
  normal_distribution<double> nx(0.0, 1.0), nnoise(0.0, 0.1);
  vector<double> w_true(d, 0.0);
  for (size_t j = 0; j < d; ++j) w_true[j] = (j % 2 == 0 ? 1.0 : -0.5);
  double b_true = 0.7;

  Dataset data;
  data.X.resize(n, vector<double>(d));
  data.y.resize(n);
  for (size_t i = 0; i < n; ++i) {
    for (size_t j = 0; j < d; ++j) data.X[i][j] = nx(rng);
    data.y[i] = dot(w_true, data.X[i]) + b_true + nnoise(rng);
  }
  return data;
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  size_t n = 512, d = 10;
  Dataset data = make_synthetic_linear(n, d);

  // Parameters (w, b) initialized to zeros
  vector<double> params(d + 1, 0.0);

  // Training hyperparameters
  double lr = 0.05;  // learning rate
  double rho = 0.05; // SAM radius
  int epochs = 50;   // number of passes
  int batch_size = 64;

  mt19937 rng(123);
  vector<int> indices(n);
  iota(indices.begin(), indices.end(), 0);

  for (int ep = 0; ep < epochs; ++ep) {
    shuffle(indices.begin(), indices.end(), rng);
    for (size_t start = 0; start < n; start += batch_size) {
      size_t end = min(n, start + (size_t)batch_size);
      vector<int> batch(indices.begin() + start, indices.begin() + end);
      sam_step_l2(data, batch, params, lr, rho);
    }
    // Evaluate full training MSE each epoch
    vector<double> gtmp(params.size(), 0.0);
    double loss = mse_and_grad(data, indices, params, gtmp);
    cout << "Epoch " << ep + 1 << ": MSE = " << loss << "\n";
  }
  // Print final parameters
  cout << "Final bias (b): " << params[d] << "\n";
  cout << "First 5 weights: ";
  for (size_t j = 0; j < min((size_t)5, d); ++j) cout << params[j] << ' ';
  cout << "\n";
  return 0;
}

This program trains a linear regression model using L2-SAM. For each mini-batch, it computes the gradient, forms an L2-normalized perturbation ε with radius ρ, re-computes the gradient at w+ε, and updates parameters with SGD. The synthetic dataset ensures the code runs end-to-end and shows SAM’s training dynamics.

Time: O(n d) per epoch. Each mini-batch step does two gradient evaluations, so it is ~2× the cost of vanilla SGD, but still Θ(n d) overall.
Space: O(d) for parameters and gradients, plus O(d) for the temporary perturbation and perturbed parameters.
L∞-SAM for Logistic Regression with Momentum
#include <bits/stdc++.h>
using namespace std;

struct Dataset {
  vector<vector<double>> X; // n x d
  vector<int> y;            // n (0 or 1)
};

static double dot(const vector<double>& a, const vector<double>& b) {
  double s = 0.0;
  for (size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
  return s;
}
static double sigmoid(double z) { return 1.0 / (1.0 + exp(-z)); }
static void add_scaled(vector<double>& dst, const vector<double>& src, double alpha) {
  for (size_t i = 0; i < dst.size(); ++i) dst[i] += alpha * src[i];
}
static void sign_vector(const vector<double>& v, vector<double>& sgn) {
  for (size_t i = 0; i < v.size(); ++i) sgn[i] = (v[i] > 0 ? 1.0 : (v[i] < 0 ? -1.0 : 0.0));
}

// Logistic regression: p = sigmoid(w^T x + b). Loss: average binary cross-entropy.
static double bce_and_grad(const Dataset& data, const vector<int>& batch_idx,
                           const vector<double>& params, vector<double>& grad) {
  size_t d = params.size() - 1; // last is bias
  fill(grad.begin(), grad.end(), 0.0);
  double loss = 0.0;
  size_t m = batch_idx.size();
  for (int idx : batch_idx) {
    const auto& x = data.X[idx];
    int y = data.y[idx];
    // Iterate over x so only params[0..d-1] act as weights; params[d] is the bias.
    double z = dot(x, params) + params[d];
    double p = sigmoid(z);
    // BCE: -[ y log p + (1-y) log(1-p) ]
    const double eps = 1e-12;
    loss += -(y * log(max(p, eps)) + (1 - y) * log(max(1 - p, eps)));
    double err = (p - y); // gradient of BCE wrt z
    for (size_t j = 0; j < d; ++j) grad[j] += err * x[j] / m;
    grad[d] += err / m; // bias
  }
  return loss / m;
}

// One L∞-SAM step with momentum base optimizer
static void sam_step_linf_momentum(const Dataset& data, const vector<int>& batch_idx,
                                   vector<double>& params, vector<double>& velocity,
                                   double lr, double rho, double momentum) {
  size_t dtotal = params.size();
  vector<double> g(dtotal, 0.0), g_sam(dtotal, 0.0), eps_vec(dtotal, 0.0);

  // 1) Gradient at current params
  bce_and_grad(data, batch_idx, params, g);

  // 2) L∞-ball perturbation: epsilon_i = rho * sign(g_i)
  sign_vector(g, eps_vec);
  for (size_t i = 0; i < dtotal; ++i) eps_vec[i] *= rho;

  // 3) Gradient at perturbed params
  vector<double> params_pert = params;
  add_scaled(params_pert, eps_vec, 1.0);
  bce_and_grad(data, batch_idx, params_pert, g_sam);

  // 4) Momentum update using g_sam
  for (size_t i = 0; i < dtotal; ++i) {
    velocity[i] = momentum * velocity[i] + g_sam[i];
    params[i] -= lr * velocity[i];
  }
}

// Create a simple 2D, roughly linearly separable dataset
static Dataset make_classification(size_t n_per_class = 200, unsigned seed = 7) {
  mt19937 rng(seed);
  normal_distribution<double> n1x(-1.0, 0.6), n1y(0.0, 0.6);
  normal_distribution<double> n2x(1.0, 0.6), n2y(0.2, 0.6);
  Dataset data;
  data.X.reserve(2 * n_per_class);
  data.y.reserve(2 * n_per_class);
  for (size_t i = 0; i < n_per_class; ++i) {
    data.X.push_back({n1x(rng), n1y(rng)});
    data.y.push_back(0);
  }
  for (size_t i = 0; i < n_per_class; ++i) {
    data.X.push_back({n2x(rng), n2y(rng)});
    data.y.push_back(1);
  }
  return data;
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  Dataset data = make_classification();
  size_t n = data.X.size();
  size_t d = data.X[0].size();

  vector<double> params(d + 1, 0.0);   // w and b
  vector<double> velocity(d + 1, 0.0); // momentum state

  double lr = 0.1;   // learning rate
  double rho = 0.05; // SAM radius (L∞)
  double mu = 0.9;   // momentum
  int epochs = 40;   // training epochs
  int batch_size = 64;

  mt19937 rng(1234);
  vector<int> indices(n);
  iota(indices.begin(), indices.end(), 0);

  for (int ep = 0; ep < epochs; ++ep) {
    shuffle(indices.begin(), indices.end(), rng);
    for (size_t start = 0; start < n; start += batch_size) {
      size_t end = min(n, start + (size_t)batch_size);
      vector<int> batch(indices.begin() + start, indices.begin() + end);
      sam_step_linf_momentum(data, batch, params, velocity, lr, rho, mu);
    }
    // Report average BCE on full data
    vector<double> gtmp(params.size(), 0.0);
    double loss = bce_and_grad(data, indices, params, gtmp);
    cout << "Epoch " << ep + 1 << ": BCE = " << loss << "\n";
  }

  cout << "Final bias: " << params[d] << "\n";
  cout << "Weights: ";
  for (size_t j = 0; j < d; ++j) cout << params[j] << ' ';
  cout << "\n";
  return 0;
}

This example trains a logistic regression classifier using L∞-SAM with momentum. The inner maximization over an L∞ ball sets ε_i = ρ·sign(g_i). We recompute the gradient at w+ε and then apply a momentum update. It demonstrates combining SAM with a common base optimizer and a different norm choice.

Time: O(n d) per epoch. Each SAM step needs two gradient computations, maintaining Θ(n d) per epoch with an extra constant factor ≈2.
Space: O(d) for parameters, gradients, and momentum; O(d) for the temporary perturbation and perturbed parameters.
Tags: sharpness-aware minimization · sam optimizer · robust optimization · flat minima · dual norm · l2 norm · linf norm · generalization · momentum · stochastic gradient descent · logistic regression · mse · adversarial direction · neural network training