
Training Dynamics

A neural network is only as good as its training run. The same architecture with poor initialization, a mistuned learning rate, or unclipped gradients can fail to converge entirely — while the same architecture with proper setup trains stably to strong results. This lesson covers the four pillars of stable deep learning training: weight initialization, learning rate schedules, gradient clipping, and batch size effects.

Theory

[Interactive chart: learning-rate schedules over 200 training steps — Warmup + Cosine, Step Decay, and Constant — with an adjustable warmup length; a dashed line marks the end of warmup.]

Learning rate is the size of each step toward lower loss. Too large: you overshoot the minimum and bounce. Too small: you crawl and may get trapped in local minima. Schedules solve this by starting larger (explore the landscape) and decaying over time (refine toward a minimum). The three curves above show warmup+cosine, step decay, and constant — each encoding a different assumption about when to slow down.

Weight Initialization

Before the first forward pass, every weight must be initialized. The goal: keep activation variance roughly constant through all layers. Too large, and activations explode; too small, and they vanish — in both cases, gradients become useless before reaching early layers.

Xavier / Glorot Initialization (for sigmoid, tanh): derives the scale by requiring that variance is preserved through the linear transformation $y = Wx$:

$$\text{Var}(y_j) = n_{in} \cdot \text{Var}(W_{ij}) \cdot \text{Var}(x_i)$$

Setting $\text{Var}(y) = \text{Var}(x)$ requires $\text{Var}(W) = 1/n_{in}$. Taking the average of the forward and backward pass constraints:

$$W \sim \mathcal{U}\!\left(-\sqrt{\frac{6}{n_{in} + n_{out}}},\; \sqrt{\frac{6}{n_{in} + n_{out}}}\right) \quad \text{or} \quad W \sim \mathcal{N}\!\left(0,\; \sqrt{\frac{2}{n_{in} + n_{out}}}\right)$$

He / Kaiming Initialization (for ReLU): ReLU zeroes out half of its inputs, effectively halving the variance. The correction factor of 2 compensates:

$$W \sim \mathcal{N}\!\left(0,\; \sqrt{\frac{2}{n_{in}}}\right)$$

He initialization is the default for any network using ReLU or its variants (Leaky ReLU, GELU, SiLU). Xavier is better for sigmoid/tanh where both positive and negative outputs are preserved.

```python
import torch.nn as nn

# PyTorch defaults
nn.Linear(256, 256)          # Kaiming uniform by default for Linear
nn.Conv2d(64, 128, 3)        # Kaiming uniform by default for Conv

# Explicit He init (after constructing the module)
layer = nn.Linear(256, 256)
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

# Explicit Xavier init
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)   # biases typically initialized to 0
```
💡Intuition

Think of initialization as tuning the gain of each layer before any signal has passed through. Too much gain and signals amplify to infinity; too little and they decay to zero. Kaiming and Xavier solve the same equation: what gain factor keeps the signal's amplitude stable from input to output?
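This gain-tuning view can be checked analytically, without running a network. Using the variance-propagation rule from the Xavier derivation above, plus the fact that ReLU roughly halves the variance of a zero-mean signal, a short sketch — the depth and fan-in values here are illustrative:

```python
# Analytic variance propagation through a stack of ReLU layers.
# Per layer: Var(pre-activation) = fan_in * Var(W) * Var(input);
# ReLU then keeps roughly half of that variance.
def activation_variance(depth, fan_in, weight_var, input_var=1.0):
    var = input_var
    for _ in range(depth):
        var = fan_in * weight_var * var   # linear layer y = Wx
        var *= 0.5                        # ReLU halves the variance
    return var

fan_in = 256
he_var = 2.0 / fan_in   # He/Kaiming: Var(W) = 2 / fan_in
print(activation_variance(10, fan_in, he_var))     # 1.0 — stable at every depth
print(activation_variance(10, fan_in, 1.0))        # ~1e21 — explodes
print(activation_variance(10, fan_in, 0.01 ** 2))  # ~1e-19 — vanishes
```

The factor-of-2 in He init exactly cancels the ReLU halving, which is why the variance stays at 1.0 regardless of depth.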

Learning Rate Schedules

A fixed learning rate rarely gives best results. Too high early: unstable training. Too low throughout: slow convergence. The solution: vary the learning rate during training.

Warmup: start with a very small learning rate and ramp it up linearly over the first $w$ steps. This prevents large random gradients in the first steps (when weights are random) from destabilizing the model before it has learned useful representations:

$$\text{lr}(t) = \text{lr}_{\max} \cdot \frac{t}{w}, \quad t \leq w$$

Warmup is required for Transformer training (see the Attention lesson) and is increasingly used for any deep network.

Warmup is necessary because early in training, the gradient estimates are highly noisy — the network hasn't oriented itself in the loss landscape yet. A large learning rate applied to noisy gradients pushes parameters far in random directions. Starting small lets the model find a coarse direction first, then accelerate once the gradients become meaningful.

Cosine Annealing: after warmup, decay the learning rate following a cosine curve from $\text{lr}_{\max}$ to $\text{lr}_{\min}$:

$$\text{lr}(t) = \text{lr}_{\min} + \frac{1}{2}(\text{lr}_{\max} - \text{lr}_{\min})\left(1 + \cos\frac{\pi (t - w)}{T - w}\right)$$

Cosine annealing is the most commonly used schedule in modern deep learning — it decays smoothly and avoids the abrupt drops of step decay.

Step Decay: multiply the learning rate by a factor $\gamma < 1$ every $k$ epochs. Simple but creates discontinuities that can cause instability:

$$\text{lr}(t) = \text{lr}_0 \cdot \gamma^{\lfloor t/k \rfloor}$$
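The three formulas above can be sketched framework-free. The warmup length, total steps, and learning rates below are illustrative values, not ones any particular paper prescribes:

```python
import math

def lr_schedule(t, lr_max, lr_min=0.0, warmup=500, total=10_000):
    """Linear warmup for `warmup` steps, then cosine annealing to lr_min."""
    if t <= warmup:
        return lr_max * t / warmup
    progress = (t - warmup) / (total - warmup)   # 0 at warmup end, 1 at total
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))

def step_decay(t, lr0, gamma=0.1, k=30):
    """Multiply lr0 by gamma once every k epochs."""
    return lr0 * gamma ** (t // k)

print(lr_schedule(250, 1e-3))     # 0.0005 — halfway through warmup
print(lr_schedule(500, 1e-3))     # 0.001  — peak at warmup end
print(lr_schedule(10_000, 1e-3))  # 0.0    — decayed to lr_min
```

Plotting `lr_schedule` over `t` reproduces the warmup+cosine curve from the chart above: a linear ramp, then a smooth half-cosine descent.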

```python
import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Cosine annealing with warmup (using the transformers library scheduler)
from transformers import get_cosine_schedule_with_warmup

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=500,
    num_training_steps=10_000,
)

# Or PyTorch-native cosine annealing (no warmup)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=10_000, eta_min=1e-6
)

# In the training loop:
optimizer.step()
scheduler.step()
```
| Schedule | Shape | Best for |
|---|---|---|
| Constant | Flat | Baselines, short runs |
| Warmup + cosine | Ramp up, then smooth decay | Transformers, LLMs, most modern networks |
| Step decay | Staircase | CNNs (original ResNet paper) |
| OneCycleLR | Triangle + decay | Fast training (super-convergence) |

Gradient Clipping

When gradients become very large — especially in RNNs or deep networks with unstable initialization — a single step can send weights to extreme values, causing NaN loss or divergence. Gradient clipping rescales the gradient if its norm exceeds a threshold:

$$\text{if } \|\mathbf{g}\|_2 > \tau: \quad \mathbf{g} \leftarrow \tau \cdot \frac{\mathbf{g}}{\|\mathbf{g}\|_2}$$

The key insight: we clip the global gradient norm across all parameters, not per-parameter. This preserves the direction of the gradient while controlling its magnitude:

```python
import torch.nn as nn

# After loss.backward(), before optimizer.step()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

A threshold of $\tau = 1.0$ is standard for Transformers and LLMs; $\tau = 5.0$ is common for RNNs. You can diagnose whether clipping is needed by logging the gradient norm during training — if it regularly exceeds your threshold by 10× or more, your learning rate may be too high or your initialization poor.
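The clipping rule is easy to verify by hand. Below is a framework-free sketch of global-norm clipping — `clip_global_norm` is a hypothetical helper for illustration, not a PyTorch API (the real call is `nn.utils.clip_grad_norm_`):

```python
import math

def clip_global_norm(grads, tau):
    """Rescale gradients so the global L2 norm across all parameter
    groups is at most tau, preserving the gradient's direction."""
    norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    if norm > tau:
        scale = tau / norm
        grads = [[g * scale for g in vec] for vec in grads]
    return grads, norm

# Two "parameter groups" with global norm sqrt(9 + 16 + 144) = 13
clipped, norm = clip_global_norm([[3.0, 4.0], [12.0]], tau=1.0)
print(norm)                            # 13.0 — pre-clip norm, useful for logging
print(clipped[0][0] / clipped[0][1])   # ≈ 0.75, same ratio as 3/4: direction kept
```

Note that the norm is computed over *all* groups jointly — clipping each group separately would change the gradient's direction.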

⚠️Warning

Never clip gradients before calling loss.backward() — the gradient doesn't exist yet. The order is always: forward → loss → backward → clip → step. Clipping after optimizer.step() is also wrong — you'd be clipping already-applied updates.

Batch Size and Gradient Noise

The batch size controls how much of the dataset is averaged into each gradient update. Larger batches mean lower-variance (smoother) gradient estimates; smaller batches mean higher-variance (noisier) estimates.

The generalization tradeoff: large-batch training converges faster in wall-clock time but tends to find sharp minima — regions where loss is low but curvature is high, leading to worse generalization. Small-batch training finds flat minima that generalize better, because the noise acts as implicit regularization.

The linear scaling rule (Goyal et al., 2017): when increasing the batch size by a factor $k$, scale the learning rate by $k$ to maintain equivalent training dynamics. With warmup, this works well up to batch sizes of ~8K; beyond that, accuracy degrades.

$$\text{lr}_{\text{new}} = \text{lr}_{\text{base}} \times \frac{B_{\text{new}}}{B_{\text{base}}}$$
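As a worked instance of the rule (the batch sizes and base learning rate below are illustrative, not a recommendation):

```python
def scaled_lr(lr_base, batch_base, batch_new):
    """Linear scaling rule: learning rate grows in proportion to batch size."""
    return lr_base * batch_new / batch_base

# Going from batch 256 at lr 1e-3 to batch 1024 means scaling lr by 4x
print(scaled_lr(1e-3, batch_base=256, batch_new=1024))  # 0.004
```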

Gradient accumulation simulates a large batch on a small GPU by accumulating gradients over multiple forward passes before updating:

```python
accumulation_steps = 4   # effective batch = batch_size × 4
optimizer.zero_grad()

for i, (X, y) in enumerate(train_loader):
    loss = criterion(model(X), y) / accumulation_steps
    loss.backward()                    # accumulate gradients

    if (i + 1) % accumulation_steps == 0:
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
```
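Why divide the loss by `accumulation_steps`? Because the optimizer should see the gradient of the *mean* loss over the effective batch. A framework-free check with a one-parameter least-squares model — the data points and weight value are made up for illustration:

```python
def grad(w, batch):
    """Gradient of the mean squared error for a 1-parameter model y ≈ w·x."""
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)

data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 9.0)]
w = 0.5

full = grad(w, data)                          # one big batch of 4

accum, steps = 0.0, 2                         # two micro-batches of 2
for i in range(0, len(data), 2):
    accum += grad(w, data[i:i + 2]) / steps   # scaled loss ⇒ scaled gradient

print(full, accum)   # -22.0 -22.0 — accumulation reproduces the full-batch gradient
```

Without the division, the accumulated gradient would be `steps` times too large — equivalent to silently multiplying the learning rate by the accumulation factor.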

Walkthrough

Diagnosing a Failing Training Run

Suppose you train a 10-layer MLP on CIFAR-10 and observe the loss becoming NaN at step 200. Work through the checklist:

Step 1 — Check gradient norm. Add a log line before each optimizer step:

```python
# Passing max_norm=inf returns the total gradient norm without clipping anything
norm = nn.utils.clip_grad_norm_(model.parameters(), float('inf'))
print(f"step {i}: grad_norm={norm:.2f}")
```

You see norms of 300–800 before the NaN. This points to exploding gradients, not bad data.

Step 2 — Identify the cause. High gradient norm with NaN loss has two likely causes: LR too large, or no warmup on a randomly initialized model. Check: did you use lr=1e-3 without warmup? With random weights, early gradient directions are noisy and high-magnitude.

Step 3 — Apply fixes in order:

  1. Add gradient clipping: clip_grad_norm_(model.parameters(), 1.0) — this alone may be enough
  2. Add a 500-step warmup before the full learning rate kicks in
  3. If still unstable, verify init: confirm the final linear layer uses kaiming_normal_ or PyTorch defaults

Step 4 — Verify recovery. After the fixes, the gradient norm should stay below 2.0 throughout training and the loss should trend steadily downward over the first 1000 steps. If it doesn't, the LR itself may be too high — try halving it.

This pattern — log norm, isolate cause, apply the smallest fix — applies to any training instability.

Analysis & Evaluation

Where Your Intuition Breaks

Cosine decay should reach zero at the end of training — you're done, so the learning rate should be zero. In practice, decaying all the way to zero can slightly hurt final performance: a small residual learning rate allows continued adaptation to the tail of the data distribution. Many implementations therefore use a nonzero floor such as $\eta_{\min} = 0.1 \times \eta_{\max}$ rather than zero.

Initialization vs No Initialization

Practical impact of initialization choice on a 10-layer MLP (ReLU activations):

| Init | Mean activation (layer 10) | Gradient norm (layer 1) | Converges? |
|---|---|---|---|
| All zeros | 0 (symmetry never broken) | 0 | No |
| $\mathcal{N}(0, 1)$ (too large) | ~1e8 (exploded) | NaN | No |
| $\mathcal{N}(0, 0.01)$ (too small) | ~1e-9 (vanished) | ~0 | No |
| Kaiming normal | ~1.0 (stable) | ~0.1 | Yes |

The difference between a good and bad initialization is often the difference between a model that trains and one that doesn't — especially for networks deeper than 5–6 layers.

Diagnosing Training Runs

A practical checklist for unstable training:

| Symptom | Likely cause | Fix |
|---|---|---|
| Loss is NaN from step 1 | LR too high or bad init | Reduce LR, check init |
| Loss spikes then recovers | LR too high without warmup | Add warmup |
| Loss plateaus early | LR too low | Increase LR or use a schedule |
| Gradient norm > 100 regularly | Exploding gradients | Clip at 1.0, reduce LR |
| Train loss falls, val loss rises immediately | Overfitting | Regularize (dropout, weight decay) |
| Both losses plateau high | Underfitting | Increase capacity, reduce regularization |
🚀Production

Standard training recipe for a new deep learning project:

  • Init: Kaiming normal for ReLU networks; Xavier for sigmoid/tanh; default PyTorch init is usually correct
  • Optimizer: AdamW with weight_decay=0.01 — decoupled weight decay is more correct than Adam + L2
  • LR: start with 1e-3 for Adam/AdamW; 0.1 for SGD with momentum
  • Schedule: warmup for 5–10% of total steps + cosine annealing to lr_min=1e-6
  • Gradient clipping: max_norm=1.0 — add it by default, it rarely hurts and often prevents catastrophic divergence
  • Batch size: 32–256 for most tasks; use gradient accumulation if GPU memory is a constraint
  • Monitor: log gradient norm, learning rate, and train/val loss every N steps — most training failures are visible in these three signals
