Bridge: Adam, Learning Rate Theory & Neural Loss Landscape Analysis
The optimization theory of Module 05 connects directly to the practical choices made in every deep learning training run: why Adam adapts step sizes per parameter, why warmup is necessary for large learning rates, why flat minima generalize better, and why sharpness-aware minimization (SAM) outperforms SGD on test accuracy despite no improvement in training loss. This lesson closes the loop from theory to practice.
Concepts
Every practical training choice in deep learning — step size, momentum, weight decay schedule, batch size — has a precise optimization-theoretic interpretation grounded in the theory of this module. Adam's adaptive step sizes are a diagonal Fisher approximation. Learning rate warmup avoids the "catapult phase" instability. SAM's second forward pass computes a gradient at the worst nearby point. Understanding why each works — not just that it works — lets you debug training failures systematically rather than by trial and error.
Adam as Approximate Natural Gradient
The Adam optimizer (Kingma & Ba, 2015) maintains exponential moving averages of the gradient and its elementwise square:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_{t+1} = \theta_t - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}$$

with defaults $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$.
Connection to natural gradient. As shown in Module 04, the natural gradient step uses $\theta_{t+1} = \theta_t - \eta\, F^{-1} \nabla L(\theta_t)$, where $F$ is the Fisher information matrix. The diagonal of $F$ is:

$$F_{ii} = \mathbb{E}\big[(\nabla_{\theta_i} \log p_\theta)^2\big] \approx \mathbb{E}\big[g_i^2\big]$$

The Adam denominator $\sqrt{\hat{v}_t}$ approximates $\sqrt{\mathrm{diag}(F)}$ — so Adam is approximately natural gradient descent with a diagonal Fisher approximation (applied as an inverse square root rather than a full inverse) plus momentum. Each parameter gets a step size adapted to its gradient magnitude, automatically correcting for different scales in different directions.
Interpretation. A parameter with consistently large gradients (high Fisher diagonal) gets a small effective step size (it's already well-informed). A parameter with small gradients (low Fisher) gets a large effective step size (it needs more nudging). This is why Adam is dramatically better than plain SGD on ill-conditioned problems. The diagonal Fisher approximation is the cheapest possible natural gradient — using only $O(d)$ memory versus $O(d^2)$ for the full Fisher — and for most architectures it captures the dominant conditioning problem (different layers having different gradient magnitudes) while discarding the off-diagonal correlations that are expensive to maintain.
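A minimal numeric sketch of this scaling behavior (toy gradients and hyperparameters are ours, not from the lesson): two parameters whose gradients differ by 1000x get SGD steps that differ by 1000x, while their first Adam steps are both close to $\eta$.

```python
import numpy as np

# Illustrative sketch: Adam's per-parameter scaling at step t = 1.
# Two parameters whose gradients differ by 1000x in magnitude, as when
# different layers carry very different gradient scales.
eta, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
g = np.array([100.0, 0.1])          # gradients at step t = 1

# SGD step: proportional to the raw gradient, so a 1000x size difference.
sgd_step = eta * g

# Adam step at t = 1: bias correction makes m_hat = g and v_hat = g^2,
# so the update is roughly eta * sign(g) for every parameter.
m_hat = ((1 - beta1) * g) / (1 - beta1**1)
v_hat = ((1 - beta2) * g**2) / (1 - beta2**1)
adam_step = eta * m_hat / (np.sqrt(v_hat) + eps)

print(sgd_step)     # wildly different scales
print(adam_step)    # both close to eta = 1e-3
```

The adaptive denominator erases the raw scale difference: each coordinate moves a distance set by $\eta$, not by its gradient magnitude.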
Convergence. For convex problems with $G$-bounded gradients, Adam achieves $O(\sqrt{T})$ regret — the same rate as AdaGrad, but with better constants due to the exponential moving average. For non-convex problems — and even some simple convex ones — with fixed $\beta_2$, the guarantee is fragile: there exist counterexamples (Reddi et al., 2018) where Adam diverges without the AMSGrad correction $\hat{v}_t \leftarrow \max(\hat{v}_{t-1}, \hat{v}_t)$.
AdamW (weight decay decoupled). Standard Adam with $L_2$ regularization adds the regularizer gradient $\lambda\theta_t$ to $g_t$ before computing the moments — so the decay interacts with the adaptive scaling. AdamW decouples weight decay by subtracting $\eta\lambda\theta_t$ directly from the parameters:

$$\theta_{t+1} = \theta_t - \eta\left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon} + \lambda\theta_t\right)$$

This applies the intended regularization without distorting the adaptive step sizes. Modern LLM training uses AdamW exclusively.
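The two decay placements can be contrasted in a single update step. This is an illustrative sketch (toy values and the helper name `adam_dir` are ours): with the decay folded into the gradient, the adaptive denominator largely normalizes it away; with decoupled decay, larger weights decay proportionally more.

```python
import numpy as np

# Sketch: one update step, Adam + L2 versus AdamW (toy values, t = 1).
eta, beta1, beta2, eps, lam = 1e-3, 0.9, 0.999, 1e-8, 0.01
theta = np.array([10.0, 0.1])       # weights with a 100x scale difference
g = np.array([1.0, 1.0])            # loss gradient only (no decay)

def adam_dir(g_eff, t=1):
    # Bias-corrected Adam direction m_hat / (sqrt(v_hat) + eps).
    m_hat = ((1 - beta1) * g_eff) / (1 - beta1**t)
    v_hat = ((1 - beta2) * g_eff**2) / (1 - beta2**t)
    return m_hat / (np.sqrt(v_hat) + eps)

# Adam + L2: the decay enters the moments and gets rescaled by the
# adaptive denominator, so its effect is almost normalized away.
theta_l2 = theta - eta * adam_dir(g + lam * theta)

# AdamW: decay applied directly to the parameters, outside the scaling,
# so the large weight (10.0) decays 100x more than the small one (0.1).
theta_w = theta - eta * (adam_dir(g) + lam * theta)

print(theta_l2, theta_w)
```

The design point: only the AdamW form shrinks weights in proportion to their magnitude, which is what $L_2$ regularization is supposed to do.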
Learning Rate Schedules and the Catapult Phase
Why learning rate schedules matter. From convergence theory: for a $\mu$-strongly convex, $L$-smooth problem, the constant step $\eta = 2/(\mu+L)$ is optimal for GD. But in practice, deep networks benefit from:
- Starting with a large learning rate (explores the landscape, finds flat basins)
- Decaying the learning rate (converges tightly within a basin)
Linear warmup. Large initial learning rates cause instability at the start — the loss can spike or diverge because the gradients are misaligned with the loss landscape. Linearly increasing $\eta$ from 0 to $\eta_{\max}$ over the first $T_{\text{warmup}}$ steps ("warmup") stabilizes early training.
Theoretical grounding (Lewkowycz et al., 2020 — "catapult phase"). For learning rates above a critical threshold $\eta > 2/\lambda_0$ (where $\lambda_0$ is the initial sharpness, i.e. the largest Hessian eigenvalue), gradient descent enters the "catapult phase" — the sharpness initially increases (progressive sharpening), then stabilizes at $\approx 2/\eta$ (edge of stability). The iterate is "launched" into flatter regions of the landscape. This is why large learning rates find flatter minima.
Cosine annealing. The schedule

$$\eta_t = \frac{\eta_{\max}}{2}\left(1 + \cos\frac{\pi t}{T}\right), \qquad t \in [0, T]$$

provides smooth decay from $\eta_{\max}$ to 0. Used with restarts (SGDR, Loshchilov & Hutter, 2017):

$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\frac{\pi T_{cur}}{T_i}\right)$$

where the period $T_i$ doubles after each restart. The periodically-reset learning rate escapes local basins and samples multiple solutions — weight averaging across restarts often improves generalization.
OneCycleLR (super-convergence, Smith, 2018): a single cycle with a large peak learning rate — training can converge in a fraction of the usual number of epochs. The peak learning rate is set near the boundary of instability.
Edge of Stability
Edge of Stability (EoS) (Cohen et al., 2021): in practice, gradient descent with a fixed learning rate $\eta$ drives the sharpness $\lambda_{\max}(\nabla^2 L)$ to stabilize around $2/\eta$, even though this is formally in the "unstable" regime.
Why? When the sharpness exceeds $2/\eta$, GD makes oscillatory steps that happen to reduce sharpness — a self-stabilizing feedback loop. The iterate oscillates within a small ball while the ball drifts toward lower loss.
Implication. For a given step size $\eta$, training converges to a region where $\lambda_{\max} \approx 2/\eta$. Larger $\eta$ → smaller sharpness → flatter minimum → better generalization. This gives formal support for the empirical observation that SGD with large step sizes generalizes better.
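The $2/\eta$ threshold is easy to verify on a quadratic, where the sharpness is known exactly. A minimal sketch (toy values are ours):

```python
# Stability threshold check on the 1-D quadratic L(w) = 0.5 * lam * w^2,
# whose sharpness (Hessian) is exactly lam. GD multiplies the iterate by
# (1 - eta * lam) each step, so it converges iff lam < 2 / eta.
def run_gd(lam, eta=0.1, steps=50, w=1.0):
    for _ in range(steps):
        w -= eta * lam * w
    return abs(w)

print(run_gd(lam=19.0))   # lam < 2/eta = 20: oscillates but shrinks to ~0
print(run_gd(lam=21.0))   # lam > 2/eta = 20: oscillates and diverges
```

At sharpness exactly $2/\eta$ the iterate neither grows nor shrinks, which is why EoS training hovers at that value rather than crossing it.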
Sharpness-Aware Minimization (SAM)
SAM (Foret et al., 2021) directly minimizes the worst-case loss in a $\rho$-ball around the current parameters:

$$\min_\theta \; \max_{\|\epsilon\|_2 \le \rho} L(\theta + \epsilon)$$

Inner maximization (approximate). Linearizing $L$ gives the approximate maximizer over the $\rho$-ball:

$$\hat{\epsilon}(\theta) = \rho\, \frac{\nabla L(\theta)}{\|\nabla L(\theta)\|_2}$$
SAM update:
- Compute the gradient at the perturbed point: $g = \nabla L(\theta_t + \hat{\epsilon}(\theta_t))$
- Update: $\theta_{t+1} = \theta_t - \eta\, g$
This requires two forward-backward passes per step (one to compute $\hat{\epsilon}$, one to compute $g$) — doubling compute cost. In return, SAM consistently finds flatter minima with lower $\lambda_{\max}$ and better test performance.
Connection to PAC-Bayes. The SAM objective appears as the dominant term on the right-hand side of a PAC-Bayes generalization bound (Foret et al., 2021):

$$L_{\text{test}}(\theta) \;\le\; \max_{\|\epsilon\|_2 \le \rho} L_{\text{train}}(\theta + \epsilon) \;+\; h\!\left(\frac{\|\theta\|_2^2}{\rho^2}\right)$$

where $h$ is an increasing complexity function. Minimizing the SAM objective directly attacks the sharpness term of this bound.
Optimizer Comparison Table
| Optimizer | Step-size adaptation | Convergence guarantee | Practical default? |
|---|---|---|---|
| SGD | None (uniform) | $O(1/\sqrt{T})$ (convex) | For CV with LR tuning |
| SGD+Momentum | Momentum (Polyak) | Same as SGD | Wide use in CV |
| AdaGrad | Cumulative $\sum_t g_t^2$ | $O(\sqrt{T})$ regret (convex) | Sparse data, NLP (old) |
| RMSProp | EMA of $g_t^2$ | Heuristic | Hidden layers in RNNs |
| Adam | EMA of $g_t$ and $g_t^2$ | $O(\sqrt{T})$ regret (convex, with caveats) | Default for LLMs/NLP |
| AdamW | Adam + decoupled decay | Same as Adam | LLM pre-training |
| SAM+Adam | Adam + flatness penalty | None (non-convex) | SOTA image classif. |
Worked Example
Example 1: Adam Convergence Bound
For a convex problem with $G_\infty$-bounded gradients ($\|g_t\|_\infty \le G_\infty$), Adam's regret after $T$ rounds satisfies, up to constants depending on $\beta_1$, $\beta_2$, and the feasible diameter $D_\infty$:

$$R(T) = O\!\left(D_\infty \sum_{i=1}^d \|g_{1:T,i}\|_2\right), \qquad \|g_{1:T,i}\|_2 \le G_\infty\sqrt{T}$$

In the adaptive case where gradients are sparse (many $g_{t,i}$ are zero), the term $\sum_i \|g_{1:T,i}\|_2$ can be much smaller than its worst-case value $d\, G_\infty \sqrt{T}$ — this is why Adam is particularly effective for embeddings and attention layers with sparse gradient patterns.
Example 2: Cosine Annealing Schedule
For a 300-epoch training run with peak learning rate $\eta_{\max}$, a 5-epoch linear warmup, and cosine decay over the remaining 295 epochs, the learning rate at epoch $t$ is:

$$\eta_t = \begin{cases} \eta_{\max}\, t/5 & t \le 5 \\[4pt] \dfrac{\eta_{\max}}{2}\left(1 + \cos\dfrac{\pi(t-5)}{295}\right) & t > 5 \end{cases}$$

At $t = 5$: $\eta = \eta_{\max}$ (peak). At $t = 152.5$ (midpoint of the cosine phase): $\eta = \eta_{\max}/2$. At $t = 300$: $\eta = 0$.

The gradual decay prevents the oscillatory behavior of a step schedule (which abruptly drops $\eta$, causing the model to "jolt" into a sharper basin) while also fully converging to a stationary point as $\eta_t \to 0$.
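The warmup-plus-cosine schedule can be sketched as a small function (a hedged sketch; $\eta_{\max}$ is set to 1.0 so values read as fractions of the peak rate):

```python
import math

# Warmup + cosine schedule: 5 linear-warmup epochs, then cosine decay
# to 0 over the remaining 295 of 300 epochs.
def lr(t, eta_max=1.0, warmup=5, total=300):
    if t <= warmup:
        return eta_max * t / warmup
    return 0.5 * eta_max * (1 + math.cos(math.pi * (t - warmup) / (total - warmup)))

print(lr(5))      # 1.0, the peak at the end of warmup
print(lr(152.5))  # 0.5, the midpoint of the cosine phase
print(lr(300))    # 0.0, fully annealed
```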
Example 3: SAM Two-Step Update
For a batch loss $L(\theta)$ with current iterate $\theta_t$ and radius $\rho$:

Step 1 (perturbation): Compute $g_1 = \nabla L(\theta_t)$. Set $\hat{\epsilon} = \rho\, g_1 / \|g_1\|_2$ (the gradient direction, normalized to the $\rho$-sphere).

Step 2 (gradient at perturbed point): Compute $g_2 = \nabla L(\theta_t + \hat{\epsilon})$. This is the gradient at the worst-case nearby point.

Update: $\theta_{t+1} = \theta_t - \eta\, g_2$.
The key difference from vanilla GD: $g_2$ points in the direction of steepest ascent at the perturbed point $\theta_t + \hat{\epsilon}$, not at $\theta_t$. Near a sharp minimum, $\hat{\epsilon}$ pushes toward the sharp ridge, and the step $-\eta\, g_2$ then moves away from it — the net effect is to seek flatter regions.
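The two-pass step can be sketched on a toy quadratic whose Hessian is known, which makes the "sees the sharpness" effect visible (the matrix, $\rho$, and $\eta$ are illustrative values of ours):

```python
import numpy as np

# SAM two-pass step on L(w) = 0.5 * w^T A w, with a sharp axis (curvature 10)
# and a flat axis (curvature 1).
A = np.diag([10.0, 1.0])
rho, eta = 0.05, 0.01
w = np.array([1.0, 1.0])

g1 = A @ w                                    # pass 1: gradient at w
eps_hat = rho * g1 / np.linalg.norm(g1)       # ascent step onto the rho-sphere
g2 = A @ (w + eps_hat)                        # pass 2: gradient at the worst-case point
w_new = w - eta * g2                          # descend using the perturbed gradient

# The perturbed gradient is amplified far more along the sharp axis,
# so the update pushes harder away from the sharp direction.
print(g2 - g1)
```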
Connections
Where Your Intuition Breaks
AdamW is so ubiquitous that it's tempting to treat it as theoretically principled for all settings. But AdamW has no convergence guarantee for non-convex objectives with a fixed learning rate — there exist pathological (if contrived) non-convex problems where Adam oscillates without converging. The reason it works in practice for neural networks is not covered by any general theorem; it's empirical robustness on the specific geometry of neural loss surfaces. More practically: AdamW's bias-correction terms ($1-\beta_1^t$, $1-\beta_2^t$) are critical during the first few hundred steps. At step $t = 1$ with $\beta_2 = 0.999$, the uncorrected $v_1 = 0.001\, g_1^2$ — its square root is a near-zero denominator that would cause enormous steps without the correction. This is why learning rate warmup is not optional when using AdamW at large scale: the optimizer is not theoretically stable at full learning rate from step 1.
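The near-zero-denominator claim is a one-liner to check (illustrative values; $g_1 = 1$ is ours):

```python
# At step 1 with beta2 = 0.999, the uncorrected second moment is
# 0.001 * g^2. Its square root (~0.032 for g = 1) understates the
# gradient RMS by ~30x, so the raw step would be ~30x too large.
beta2, g = 0.999, 1.0
v1 = (1 - beta2) * g**2         # uncorrected: 0.001
v1_hat = v1 / (1 - beta2**1)    # bias-corrected: recovers g^2 = 1.0
print(v1, v1_hat)
```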
Adam's step size adaptation is diagonal natural gradient descent. The update divides the (smoothed) gradient by the (smoothed) RMS of past gradients. This is essentially natural gradient descent with a diagonal Fisher approximation: each parameter's step size is scaled inversely to the information it carries about the loss. Parameters with high gradient variance (high information) take small steps; parameters with low gradient variance take large steps. This is why Adam handles different parameter scales automatically, whereas SGD with a global learning rate struggles on problems with parameters of very different magnitudes.
The edge of stability connects theory to practice. The finding that GD with step size $\eta$ drives the sharpness to $\approx 2/\eta$ has a beautiful implication: the learning rate sets the sharpness of the minimum found! A larger $\eta$ gives a flatter minimum (lower sharpness), which empirically corresponds to better generalization. This gives a mechanistic explanation for the folklore that "larger learning rates generalize better" — it's not just exploration, it's that a larger LR directly imposes a flatness constraint on the found solution.
SAM's double compute cost is not always worth it. SAM requires two forward-backward passes per step. At billion-parameter scale, this is often prohibitive. Efficient SAM variants (ASAM, mSAM, sparse SAM) reduce the cost by applying perturbations only to a subset of parameters or by normalizing by parameter magnitude. For models trained at limited compute budgets (most academic research), the SAM overhead may outweigh its generalization benefit relative to simply training longer with standard Adam and better data augmentation. The generalization benefit of SAM is most pronounced for small datasets and models trained without strong augmentation.