
Bridge: Adam, Learning Rate Theory & Neural Loss Landscape Analysis

The optimization theory of Module 05 connects directly to the practical choices made in every deep learning training run: why Adam adapts step sizes per parameter, why warmup is necessary for large learning rates, why flat minima generalize better, and why sharpness-aware minimization (SAM) outperforms SGD on test accuracy despite no improvement in training loss. This lesson closes the loop from theory to practice.

Concepts

Every practical training choice in deep learning — step size, momentum, weight decay schedule, batch size — has a precise optimization-theoretic interpretation grounded in the theory of this module. Adam's adaptive step sizes are a diagonal Fisher approximation. Learning rate warmup avoids the "catapult phase" instability. SAM's second forward pass computes a gradient at the worst nearby point. Understanding why each works — not just that it works — lets you debug training failures systematically rather than by trial and error.

Adam as Approximate Natural Gradient

The Adam optimizer (Kingma & Ba, 2015):

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\nabla\mathcal{L}_t && \text{(first moment, momentum)} \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)(\nabla\mathcal{L}_t)^2 && \text{(second moment, adaptive scaling)} \\
\hat{m}_t &= m_t / (1-\beta_1^t), \quad \hat{v}_t = v_t/(1-\beta_2^t) && \text{(bias correction)} \\
\theta_t &= \theta_{t-1} - \eta \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon)
\end{aligned}
$$

with defaults $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\varepsilon = 10^{-8}$.
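The update above can be sketched in a few lines of NumPy. This is an illustrative toy implementation, not a reference one, run on an ill-conditioned quadratic where per-parameter scaling pays off:

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-indexed step count used for bias correction."""
    m = beta1 * m + (1 - beta1) * grad       # first moment (momentum)
    v = beta2 * v + (1 - beta2) * grad**2    # second moment (adaptive scaling)
    m_hat = m / (1 - beta1**t)               # bias-corrected moments
    v_hat = v / (1 - beta2**t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Minimize the ill-conditioned quadratic f(x) = x_1^2 + 100 x_2^2.
theta = np.array([1.0, 1.0])
m, v = np.zeros(2), np.zeros(2)
for t in range(1, 2001):
    grad = np.array([2.0 * theta[0], 200.0 * theta[1]])
    theta, m, v = adam_step(theta, grad, m, v, t)
```

Despite the 100:1 conditioning, both coordinates make progress at comparable rates, because each is normalized by its own gradient RMS.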

Connection to natural gradient. As shown in Module 04, the natural gradient step uses $\mathcal{I}(\theta)^{-1}\nabla\mathcal{L}$, where $\mathcal{I}(\theta)$ is the Fisher information matrix. The diagonal of $\mathcal{I}(\theta)$ is:

$$
\mathcal{I}(\theta)_{jj} = \mathbb{E}\!\left[\left(\frac{\partial \log p(y\mid x;\theta)}{\partial\theta_j}\right)^2\right] \approx \mathbb{E}\big[(\nabla_j\mathcal{L})^2\big].
$$

The Adam denominator $\sqrt{\hat{v}_t}$ approximates $\sqrt{\mathcal{I}(\theta)_{jj}}$ — so $\hat{m}_t/(\sqrt{\hat{v}_t}+\varepsilon)$ is approximately the natural gradient with a diagonal Fisher approximation, plus momentum. Each parameter gets a step size adapted to its gradient magnitude, automatically correcting for different scales in different directions.

Interpretation. A parameter with consistently large gradients (high Fisher diagonal) gets a small effective step size (it's already well-informed). A parameter with small gradients (low Fisher) gets a large effective step size (it needs more nudging). This is why Adam is dramatically better than plain SGD on ill-conditioned problems. The diagonal Fisher approximation is the cheapest possible natural gradient — using only $O(n)$ memory versus $O(n^2)$ for the full Fisher — and for most architectures it captures the dominant conditioning problem (different layers having different gradient magnitudes) while discarding the off-diagonal correlations that are expensive to maintain.

Convergence. For convex problems with $G$-bounded gradients, Adam achieves $O(G\log T/\sqrt{T})$ regret — the same rate as AdaGrad, but with better constants due to the exponential moving average. For non-convex problems with a fixed $\eta$ there is no general guarantee: there exist counterexamples, even convex ones (Reddi et al., 2018), on which Adam fails to converge.

AdamW (weight decay decoupled). Standard Adam with $L_2$ regularization adds the regularizer gradient to $\nabla\mathcal{L}$ before computing moments — this entangles the regularizer with the adaptive scaling. AdamW decouples weight decay by directly subtracting from the parameters:

$$
\theta_t \leftarrow \theta_{t-1} - \eta\left(\hat{m}_t/(\sqrt{\hat{v}_t}+\varepsilon) + \lambda\theta_{t-1}\right).
$$

This applies weight decay without distorting the adaptive step sizes (for adaptive methods, decoupled decay and $L_2$ regularization are not equivalent). Modern LLM training uses AdamW almost exclusively.
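A minimal sketch of the decoupled update (illustrative, with a hypothetical function name `adamw_step`): the decay term bypasses the moment estimates entirely, so with a zero loss gradient one step is pure multiplicative shrinkage.

```python
import numpy as np

def adamw_step(theta, grad, m, v, t, lr=0.1, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.1):
    """AdamW: moments use the raw loss gradient; decay acts on parameters directly."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    # L2-style Adam would instead add weight_decay * theta to `grad` above,
    # letting the adaptive scaling distort the regularizer.
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta, m, v

# With grad = 0, the step reduces to theta <- theta * (1 - lr * weight_decay).
theta, m, v = adamw_step(np.array([1.0]), np.zeros(1), np.zeros(1), np.zeros(1), t=1)
```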

Learning Rate Schedules and the Catapult Phase

Why learning rate schedules matter. From convergence theory: for a strongly convex, $L$-smooth problem, the constant step size $\alpha = 1/L$ is optimal for GD. But in practice, deep networks benefit from:

  1. Starting with a large learning rate (explores the landscape, finds flat basins)
  2. Decaying the learning rate (converges tightly within a basin)

Linear warmup. Large initial learning rates cause instability at the start — the loss can spike or diverge because the gradients are misaligned with the loss landscape. Linearly increasing $\eta$ from 0 to $\eta_{\max}$ over $T_\text{warm}$ steps ("warmup") stabilizes early training.

Theoretical grounding (Lewkowycz et al., 2020 — "catapult phase"). For learning rates above a threshold $\eta_c \approx 2/\lambda_{\max}(\nabla^2\mathcal{L})$ (the classical stability threshold), gradient descent enters the "catapult phase" — the sharpness (max Hessian eigenvalue) initially increases (progressive sharpening), then stabilizes at $2/\eta$ (edge of stability). The iterate is "launched" into flatter regions of the landscape. This is why large learning rates find flatter minima.

Cosine annealing. The schedule $\eta_t = \eta_{\max}\cdot\frac{1+\cos(\pi t/T)}{2}$ for $t \in [0,T]$ provides smooth decay from $\eta_{\max}$ to 0. Used with restarts (SGDR, Loshchilov & Hutter, 2017):

$$
\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max}-\eta_{\min})\left(1+\cos\!\left(\frac{\pi\,(t \bmod T_i)}{T_i}\right)\right),
$$

where $T_i$ doubles after each restart. The periodically reset learning rate escapes local basins and samples multiple solutions — weight averaging across restarts often improves generalization.

OneCycleLR (super-convergence, Smith, 2018): a single cycle with a large peak learning rate — training can converge in $1/10$ the usual epochs. The peak learning rate is set near the boundary of instability.
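The SGDR formula with period doubling can be sketched as a plain function (a toy sketch; library schedulers such as PyTorch's `CosineAnnealingWarmRestarts` track this state incrementally instead):

```python
import math

def sgdr_lr(t, eta_min=0.0, eta_max=0.1, T0=10):
    """Cosine annealing with warm restarts; cycle length T_i doubles each restart."""
    Ti, start = T0, 0
    while t >= start + Ti:   # locate the cycle containing step t
        start += Ti
        Ti *= 2
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * (t - start) / Ti))

# eta resets to eta_max at t = 0, 10, 30, 70, ... and decays within each cycle.
```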

Edge of Stability

Edge of Stability (EoS) (Cohen et al., 2021): in practice, gradient descent with a fixed learning rate $\eta$ drives the sharpness $\lambda_{\max}(\nabla^2\mathcal{L})$ up to about $2/\eta$ and holds it there — at, and even slightly beyond, the threshold where classical smoothness-based analysis predicts divergence (formally the "unstable" regime).

Why? When sharpness exceeds $2/\eta$, GD makes oscillatory steps that happen to reduce sharpness — a self-stabilizing feedback loop. The iterate oscillates within a ball while the ball drifts toward lower loss.

Implication. For a given step size $\eta$, training converges to a region where $\lambda_{\max}(\nabla^2\mathcal{L}) \approx 2/\eta$. Larger $\eta$ → smaller sharpness → flatter minimum → better generalization. This gives formal support for the empirical observation that SGD with large step sizes generalizes better.
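The self-stabilizing feedback requires curvature that varies across the landscape, so a quadratic cannot reproduce EoS itself — but it does isolate the $2/\eta$ threshold. A toy check on $f(x) = \tfrac{1}{2}\lambda x^2$, where GD contracts iff $\eta < 2/\lambda$:

```python
def gd_final_abs(eta, lam=10.0, x0=1.0, steps=50):
    """GD on f(x) = 0.5 * lam * x^2: each step multiplies x by (1 - eta*lam),
    so |1 - eta*lam| < 1, i.e. eta < 2/lam, is the classical stability condition."""
    x = x0
    for _ in range(steps):
        x = (1 - eta * lam) * x
    return abs(x)

# Sharpness lam = 10 gives the threshold eta_c = 2/10 = 0.2.
stable = gd_final_abs(0.19)     # just below: oscillates but contracts
unstable = gd_final_abs(0.21)   # just above: oscillates and blows up
```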

Sharpness-Aware Minimization (SAM)

SAM (Foret et al., 2021) directly minimizes the worst-case loss in a ball around the current parameters:

$$
\min_\theta \max_{\|\epsilon\| \leq \rho} \mathcal{L}(\theta + \epsilon).
$$

Inner maximization (approximate). The maximizer over the $\rho$-ball, to first order:

$$
\hat\epsilon(\theta) = \arg\max_{\|\epsilon\|\leq\rho} \mathcal{L}(\theta+\epsilon) \approx \rho\cdot\frac{\nabla_\theta\mathcal{L}(\theta)}{\|\nabla_\theta\mathcal{L}(\theta)\|}.
$$

SAM update:

  1. Compute the gradient at the perturbed point: $g = \nabla_\theta\mathcal{L}(\theta + \hat\epsilon(\theta))$
  2. Update: $\theta \leftarrow \theta - \eta g$

This requires two forward-backward passes per step (one to compute $\hat\epsilon$, one to compute $g$) — doubling compute cost. In return, SAM consistently finds flatter minima with lower $\lambda_{\max}(\nabla^2\mathcal{L})$ and better test performance.
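The two-pass update can be sketched directly (a toy sketch with an illustrative `sam_step`; `loss_grad` stands in for a framework's backward pass), exercised on a simple quadratic just to show the mechanics:

```python
import numpy as np

def sam_step(theta, loss_grad, lr=0.1, rho=0.05):
    """One SAM step: ascend to the approximate worst point in the rho-ball,
    then descend from theta using the gradient measured there."""
    g0 = loss_grad(theta)                          # pass 1: gradient at theta
    eps = rho * g0 / (np.linalg.norm(g0) + 1e-12)  # first-order inner maximizer
    g_sam = loss_grad(theta + eps)                 # pass 2: gradient at theta + eps
    return theta - lr * g_sam

# Mechanics check on L(x) = 0.5 * ||x||^2, whose gradient is x itself.
theta = np.array([1.0, -2.0])
for _ in range(100):
    theta = sam_step(theta, lambda x: x)
```

On this convex toy the perturbation simply overshoots slightly; the flatness-seeking behavior only shows up on landscapes with minima of different sharpness.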

Connection to PAC-Bayes. The SAM objective is motivated by a PAC-Bayes generalization bound of the form:

$$
\mathcal{L}_{\text{test}} \lesssim \mathcal{L}_{\text{train}} + \sqrt{\frac{\operatorname{tr}(\nabla^2\mathcal{L})}{n_{\text{data}}}} \lesssim \mathcal{L}_{\text{train}} + O\!\left(\frac{\lambda_{\max}}{\sqrt{n_{\text{data}}}}\right).
$$

Minimizing the SAM objective directly attacks the right-hand side's sharpness term.

Optimizer Comparison Table

| Optimizer | Step-size adaptation | Convergence guarantee | Practical default? |
|---|---|---|---|
| SGD | None (uniform) | $O(1/\sqrt{T})$ convex | For CV with LR tuning |
| SGD+Momentum | Momentum (Polyak) | Same as SGD | Wide use in CV |
| AdaGrad | Cumulative $\sum g_t^2$ | $O(\log T/\sqrt{T})$ | Sparse data, NLP (old) |
| RMSProp | EMA of $g_t^2$ | Heuristic | Hidden layers in RNNs |
| Adam | EMA of $g_t$, $g_t^2$ | $O(\log T/\sqrt{T})$ | Default for LLMs/NLP |
| AdamW | Adam + decoupled decay | Same as Adam | LLM pre-training |
| SAM+Adam | Adam + flatness penalty | None (non-convex) | SOTA image classif. |

Worked Example

Example 1: Adam Convergence Bound

For a convex problem with $G$-bounded gradients ($\|\nabla f_t\| \leq G$), Adam's regret after $T$ rounds satisfies:

$$
\text{Regret}(T) = \sum_{t=1}^T f_t(x_t) - \min_x \sum_{t=1}^T f_t(x) \leq \frac{d\sqrt{T}\,G^2\,\eta\,\sqrt{1-\beta_2^T}}{\sqrt{1-\beta_2}\,(1-\beta_1)\,\varepsilon} + O\!\left(\frac{G^2}{\eta(1-\beta_1)}\sum_j \|g_{1:T,j}\|_2\right).
$$

In the adaptive case where gradients are sparse (many $g_{t,j}$ are zero), the second term can be much smaller than for uniform SGD — this is why Adam is particularly effective for embeddings and attention layers with sparse gradient patterns.

Example 2: Cosine Annealing Schedule

For a 300-epoch training run with $\eta_{\max} = 0.1$ and a 5-epoch warmup, the learning rate at epoch $t$:

$$
\eta_t = \begin{cases} 0.1 \cdot t/5 & t \leq 5 \text{ (linear warmup)} \\ 0.1 \cdot \dfrac{1+\cos(\pi(t-5)/295)}{2} & t > 5 \text{ (cosine decay)} \end{cases}
$$

At $t = 5$: $\eta = 0.1$. At $t = 150$: $\eta \approx 0.1\cdot(1+\cos(\pi/2))/2 = 0.05$. At $t = 300$: $\eta = 0.1\cdot(1+\cos\pi)/2 = 0$.

The gradual decay prevents the oscillatory behavior of a step schedule (which abruptly drops $\eta$, causing the model to "jolt" into a sharper basin) while also fully converging to a stationary point as $\eta \to 0$.
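The piecewise schedule in this example can be checked numerically (a small sketch; the name `lr_at` is illustrative):

```python
import math

def lr_at(t, eta_max=0.1, warm=5, total=300):
    """Linear warmup for `warm` epochs, then cosine decay over the remaining epochs."""
    if t <= warm:
        return eta_max * t / warm
    return eta_max * (1 + math.cos(math.pi * (t - warm) / (total - warm))) / 2

# lr_at(5) -> 0.1 (peak); lr_at(150) -> ~0.051 (near midpoint); lr_at(300) -> 0.0
```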

Example 3: SAM Two-Step Update

For a batch loss $\mathcal{L}(\theta)$ with $\rho = 0.05$:

Step 1 (perturbation): Compute $g_0 = \nabla_\theta\mathcal{L}(\theta)$. Set $\hat\epsilon = \rho\cdot g_0/\|g_0\|$ (the gradient direction, normalized to the $\rho$-sphere).

Step 2 (gradient at perturbed point): Compute $g_{\text{SAM}} = \nabla_\theta\mathcal{L}(\theta + \hat\epsilon)$. This is the gradient at the worst-case nearby point.

Update: $\theta \leftarrow \theta - \eta g_{\text{SAM}}$.

The key difference from vanilla GD: $g_{\text{SAM}}$ is the direction of steepest ascent at the perturbed point $\theta + \hat\epsilon$, not at $\theta$. Near a sharp minimum, $\hat\epsilon$ pushes toward the sharp ridge, and $g_{\text{SAM}}$ then moves away from it — the net effect is to seek flatter regions.

Connections

Where Your Intuition Breaks

AdamW is so ubiquitous that it's tempting to treat it as theoretically principled for all settings. But AdamW has no convergence guarantee for non-convex objectives with a fixed learning rate — there exist pathological (if contrived) non-convex problems where Adam oscillates without converging. The reason it works in practice for neural networks is not covered by any general theorem; it's empirical robustness on the specific geometry of neural loss surfaces. More practically: AdamW's bias-correction terms ($1-\beta_1^t$, $1-\beta_2^t$) are critical during the first few hundred steps. At step $t=1$ with $\beta_2 = 0.999$, the uncorrected $v_1 = 0.001 \cdot g_1^2$ — a near-zero denominator that would cause enormous steps without the $/(1-\beta_2^t)$ correction. This is why learning rate warmup is not optional when using AdamW at large scale: the optimizer is not theoretically stable at full learning rate from step 1.
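The magnitude gap is easy to verify (assuming a unit-magnitude first gradient for illustration):

```python
beta2 = 0.999
g1 = 1.0                      # assume the first gradient has unit magnitude

v1 = (1 - beta2) * g1**2      # uncorrected second moment after step 1: 0.001
v1_hat = v1 / (1 - beta2**1)  # bias-corrected: exactly g1^2 = 1.0

# sqrt(v1) ~ 0.032 vs sqrt(v1_hat) = 1.0: without the correction the first
# step would be ~31x too large for the same learning rate.
```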

💡Intuition

Adam's step size adaptation is diagonal natural gradient descent. The update $\hat{m}_t/\sqrt{\hat{v}_t}$ divides the (smoothed) gradient by the (smoothed) RMS of past gradients. This is exactly the natural gradient with a diagonal Fisher approximation: each parameter's step size is scaled inversely to the information it carries about the loss. Parameters with high gradient variance (high information) take small steps; parameters with low gradient variance take large steps. This is why Adam handles different parameter scales automatically, whereas SGD with a global learning rate struggles on problems with parameters of very different magnitudes.

💡Intuition

The edge of stability connects theory to practice. The finding that GD with step size $\eta$ drives sharpness to $2/\eta$ has a beautiful implication: the learning rate sets the sharpness of the minimum found! A larger $\eta$ gives a flatter minimum (lower sharpness), which empirically corresponds to better generalization. This gives a mechanistic explanation for the folklore that "larger learning rates generalize better" — it's not just exploration, it's that larger LR directly imposes a flatness constraint on the found solution.

⚠️Warning

SAM's double compute cost is not always worth it. SAM requires two backward passes per step. At billion-parameter scale, this is often prohibitive. Efficient SAM variants (ASAM, mSAM, sparse SAM) reduce the cost by applying perturbations only to a subset of parameters or normalizing by parameter magnitude. For models trained at limited compute budgets (most academic research), the SAM overhead may outweigh its generalization benefit relative to simply training longer with standard Adam and better data augmentation. The generalization benefit of SAM is most pronounced for small datasets and without strong augmentation.
