Training Dynamics
A neural network is only as good as its training run. The same architecture with poor initialization, a mistuned learning rate, or unclipped gradients can fail to converge entirely — while the same architecture with proper setup trains stably to strong results. This lesson covers the four pillars of stable deep learning training: weight initialization, learning rate schedules, gradient clipping, and batch size effects.
Theory
Learning rate is the size of each step toward lower loss. Too large: you overshoot the minimum and bounce around. Too small: you crawl and may get trapped in poor local minima. Schedules solve this by starting larger (explore the landscape) and decaying over time (refine toward a minimum). Three common schedules — warmup+cosine, step decay, and constant — each encode a different assumption about when to slow down.
Weight Initialization
Before the first forward pass, every weight must be initialized. The goal: keep activation variance roughly constant through all layers. Too large, and activations explode; too small, and they vanish — in both cases, gradients become useless before reaching early layers.
Xavier / Glorot Initialization (for sigmoid, tanh): derives the scale by requiring that variance is preserved through the linear transformation $y = Wx$:

$$\mathrm{Var}(y_i) = n_{\text{in}} \, \mathrm{Var}(W_{ij}) \, \mathrm{Var}(x_j)$$
Setting $\mathrm{Var}(y_i) = \mathrm{Var}(x_j)$ requires $\mathrm{Var}(W_{ij}) = 1/n_{\text{in}}$. Taking the average of the forward and backward pass constraints:

$$\mathrm{Var}(W_{ij}) = \frac{2}{n_{\text{in}} + n_{\text{out}}}$$
He / Kaiming Initialization (for ReLU): ReLU zeroes out half of its inputs, effectively halving the variance. The correction factor of 2 compensates:

$$\mathrm{Var}(W_{ij}) = \frac{2}{n_{\text{in}}}$$
He initialization is the default for any network using ReLU or its variants (Leaky ReLU, GELU, SiLU). Xavier is better for sigmoid/tanh where both positive and negative outputs are preserved.
import torch.nn as nn

# PyTorch defaults
nn.Linear(256, 256)   # Kaiming uniform by default for Linear
nn.Conv2d(64, 128, 3) # Kaiming uniform by default for Conv

# Explicit He init (after constructing the module)
layer = nn.Linear(256, 256)
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

# Explicit Xavier init
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias) # biases typically initialized to 0

Think of initialization as tuning the gain of each layer before any signal has passed through. Too much gain and signals amplify to infinity; too little and they decay to zero. Kaiming and Xavier solve the same equation: what gain factor keeps the signal's amplitude stable from input to output?
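The gain intuition can be checked numerically. The sketch below is a toy pure-Python simulation (no PyTorch; the function name and sizes are illustrative, not from the lesson): it pushes one random input through 10 Linear+ReLU layers and reports the standard deviation of the final activations under He-scaled versus too-small weights.

```python
import math
import random

random.seed(0)

def final_layer_std(n=64, depth=10, weight_std=None):
    """Push one random unit-variance input through `depth` Linear+ReLU
    layers with i.i.d. Gaussian weights; return the final activation std."""
    x = [random.gauss(0, 1) for _ in range(n)]
    for _ in range(depth):
        std = weight_std(n)
        W = [[random.gauss(0, std) for _ in range(n)] for _ in range(n)]
        # linear transform followed by ReLU
        x = [max(0.0, sum(W[i][j] * x[j] for j in range(n))) for i in range(n)]
    mean = sum(x) / n
    return math.sqrt(sum((v - mean) ** 2 for v in x) / n)

kaiming = final_layer_std(weight_std=lambda n: math.sqrt(2.0 / n))  # He: sqrt(2/fan_in)
tiny = final_layer_std(weight_std=lambda n: 0.01)                   # fixed small std

print(f"Kaiming init, layer-10 std: {kaiming:.3f}")  # stays on the order of 1
print(f"Tiny init,    layer-10 std: {tiny:.2e}")     # collapses toward 0
```

With He scaling the signal survives all 10 layers at roughly its original amplitude; with a fixed small std it shrinks by a constant factor per layer and has effectively vanished by the end, matching the table in the Analysis section.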
Learning Rate Schedules
A fixed learning rate rarely gives best results. Too high early: unstable training. Too low throughout: slow convergence. The solution: vary the learning rate during training.
Warmup: start with a very small learning rate and ramp it up linearly over the first $T_{\text{warmup}}$ steps. This prevents large random gradients in the first steps (when weights are random) from destabilizing the model before it has learned useful representations:

$$\eta_t = \eta_{\max} \cdot \frac{t}{T_{\text{warmup}}}, \qquad t \le T_{\text{warmup}}$$
Warmup is required for Transformer training (see the Attention lesson) and is increasingly used for any deep network.
Warmup is necessary because early in training, gradient estimates are highly noisy — the network hasn't oriented itself in the loss landscape yet. A large learning rate applied to noisy gradients pushes parameters far in random directions. Starting small lets the model find a coarse direction first, then accelerate once the gradients become meaningful.
Cosine Annealing: after warmup, decay the learning rate following a cosine curve from $\eta_{\max}$ to $\eta_{\min}$, where $t$ is the current step and $T$ the total number of decay steps:

$$\eta_t = \eta_{\min} + \frac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\frac{\pi t}{T}\right)$$
Cosine annealing is the most commonly used schedule in modern deep learning — it decays smoothly and avoids the abrupt drops of step decay.
Step Decay: multiply the learning rate by a factor $\gamma$ every $k$ epochs. Simple but creates discontinuities that can cause instability:

$$\eta_t = \eta_0 \cdot \gamma^{\lfloor t/k \rfloor}$$
import torch.optim as optim
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
# Cosine annealing with warmup (using transformers library scheduler)
from transformers import get_cosine_schedule_with_warmup
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=500,
num_training_steps=10_000,
)
# Or PyTorch native cosine annealing
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=10_000, eta_min=1e-6
)
# In training loop:
optimizer.step()
scheduler.step()

| Schedule | Shape | Best for |
|---|---|---|
| Constant | Flat | Baselines, short runs |
| Warmup + cosine | Ramp up then smooth decay | Transformers, LLMs, most modern networks |
| Step decay | Staircase | CNNs (ResNet original paper) |
| OneCycleLR | Triangle + decay | Fast training (super-convergence) |
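The warmup + cosine schedule in the table can also be written from scratch in a few lines. A minimal sketch (the function name `lr_at` is illustrative), matching the formulas in this section:

```python
import math

def lr_at(step, total_steps, warmup_steps, lr_max=1e-3, lr_min=1e-6):
    """Linear warmup to lr_max, then cosine decay to lr_min."""
    if step < warmup_steps:
        return lr_max * step / warmup_steps       # linear ramp
    # progress through the cosine phase, in [0, 1]
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))

print(lr_at(0, 10_000, 500))       # start of warmup: lr is 0
print(lr_at(500, 10_000, 500))     # warmup just ended: peak lr
print(lr_at(10_000, 10_000, 500))  # end of training: decayed to the floor
```

Plotting `lr_at` over all 10,000 steps reproduces the ramp-then-smooth-decay shape described above.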
Gradient Clipping
When gradients become very large — especially in RNNs or deep networks with unstable initialization — a single step can send weights to extreme values, causing NaN loss or divergence. Gradient clipping rescales the gradient $g$ if its norm exceeds a threshold $c$:

$$g \leftarrow g \cdot \min\left(1, \frac{c}{\lVert g \rVert_2}\right)$$
The key insight: we clip the global gradient norm across all parameters, not per-parameter. This preserves the direction of the gradient while controlling its magnitude:
# After loss.backward(), before optimizer.step()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

A threshold of 1.0 is standard for Transformers and LLMs; a looser threshold (often 5.0) is common for RNNs. You can diagnose whether clipping is needed by logging the gradient norm during training — if it regularly exceeds your threshold by 10×+, your learning rate may be too high or your initialization poor.
Never clip gradients before calling loss.backward() — the gradient doesn't exist yet. The order is always: forward → loss → backward → clip → step. Clipping after optimizer.step() is also wrong — you'd be clipping already-applied updates.
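To see why global-norm clipping preserves direction, it can be sketched in a few lines of plain Python over a flat list of gradient values (PyTorch's `clip_grad_norm_` does the same thing across all parameter tensors at once):

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """g <- g * min(1, max_norm / ||g||): cap magnitude, keep direction."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]  # every component scaled equally
    return grads, total_norm

clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm)                   # 5.0 before clipping
print(math.hypot(*clipped))   # ~1.0 after: capped at the threshold
```

Because every component is multiplied by the same scalar, the ratio between components (the update direction) is untouched; only the step length changes.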
Batch Size and Gradient Noise
The batch size controls how much of the dataset is averaged into each gradient update. Larger batches mean lower-variance (smoother) gradient estimates; smaller batches mean higher-variance (noisier) estimates.
The generalization tradeoff: large-batch training converges faster in wall-clock time but tends to find sharp minima — regions where loss is low but curvature is high, leading to worse generalization. Small-batch training finds flat minima that generalize better, because the noise acts as implicit regularization.
The linear scaling rule (Goyal et al., 2017): when increasing batch size by a factor $k$, scale the learning rate by $k$ to maintain equivalent training dynamics. With warmup this works well up to batch sizes of ~8K; beyond that, accuracy degrades.
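The rule itself is a one-liner; a sketch (helper name `scaled_lr` is illustrative):

```python
def scaled_lr(base_lr, base_batch_size, new_batch_size):
    """Linear scaling rule: lr grows in proportion to batch size."""
    return base_lr * (new_batch_size / base_batch_size)

# e.g. a recipe tuned at batch 256 with lr 0.1, scaled up to batch 2048 (k = 8)
print(scaled_lr(0.1, 256, 2048))  # 0.8
```

Remember to pair the scaled rate with warmup — jumping straight to the large learning rate is exactly the failure mode warmup exists to prevent.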
Gradient accumulation simulates a large batch on a small GPU by accumulating gradients over multiple forward passes before updating:
accumulation_steps = 4  # effective batch = batch_size × 4
optimizer.zero_grad()
for i, (X, y) in enumerate(train_loader):
    loss = criterion(model(X), y) / accumulation_steps
    loss.backward()  # accumulate gradients
    if (i + 1) % accumulation_steps == 0:
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

Walkthrough
Diagnosing a Failing Training Run
Suppose you train a 10-layer MLP on CIFAR-10 and observe loss becoming NaN at step 200. Work through the checklist:
Step 1 — Check gradient norm. Add a log line before each optimizer step:
norm = nn.utils.clip_grad_norm_(model.parameters(), float('inf'))
print(f"step {i}: grad_norm={norm:.2f}")

You see norms of 300–800 before the NaN. This points to exploding gradients, not bad data.
Step 2 — Identify the cause. High gradient norm with NaN loss has two likely causes: LR too large, or no warmup on a randomly initialized model. Check: did you use lr=1e-3 without warmup? With random weights, early gradient directions are noisy and high-magnitude.
Step 3 — Apply fixes in order:
- Add gradient clipping: `clip_grad_norm_(model.parameters(), 1.0)` — this alone may be enough
- Add a 500-step warmup before the full learning rate kicks in
- If still unstable, verify init: confirm the final linear layer uses `kaiming_normal_` or PyTorch defaults
Step 4 — Verify recovery. After fixes, gradient norm should stay below 2.0 throughout training and loss should decrease monotonically in the first 1000 steps. If it doesn't, the LR itself may be too high — try halving it.
This pattern — log norm, isolate cause, apply the smallest fix — applies to any training instability.
Analysis & Evaluation
Where Your Intuition Breaks
Cosine decay should reach zero at the end of training — you're done, so the learning rate should be zero. In practice, decaying all the way to zero can slightly hurt final performance: a small residual learning rate allows continued adaptation to the tail of the data distribution. Most implementations use a minimum learning rate on the order of 1e-6 rather than zero.
Initialization vs No Initialization
Practical impact of initialization choice on a 10-layer MLP (ReLU activations):
| Init | Mean activation (layer 10) | Gradient norm (layer 1) | Converges? |
|---|---|---|---|
| All zeros | 0 (symmetry never broken) | 0 | No |
| Normal, σ too large | ~1e8 (exploded) | NaN | No |
| Normal, σ too small | ~1e-9 (vanished) | ~0 | No |
| Kaiming normal | ~1.0 (stable) | ~0.1 | Yes |
The difference between a good and bad initialization is often the difference between a model that trains and one that doesn't — especially for networks deeper than 5–6 layers.
Diagnosing Training Runs
A practical checklist for unstable training:
| Symptom | Likely cause | Fix |
|---|---|---|
| Loss is NaN from step 1 | LR too high or bad init | Reduce LR, check init |
| Loss spikes then recovers | LR too high without warmup | Add warmup |
| Loss plateaus early | LR too low | Increase LR or use schedule |
| Gradient norm > 100 regularly | Exploding gradients | Clip at 1.0, reduce LR |
| Train loss falls, val loss rises immediately | Overfitting | Regularize (dropout, weight decay) |
| Both losses plateau high | Underfitting | Increase capacity, reduce regularization |
Standard training recipe for a new deep learning project:
- Init: Kaiming normal for ReLU networks; Xavier for sigmoid/tanh; default PyTorch init is usually correct
- Optimizer: AdamW with `weight_decay=0.01` — decoupled weight decay is more correct than Adam + L2
- LR: start with `1e-3` for Adam/AdamW; `0.1` for SGD with momentum
- Schedule: warmup for 5–10% of total steps + cosine annealing to `lr_min=1e-6`
- Gradient clipping: `max_norm=1.0` — add it by default; it rarely hurts and often prevents catastrophic divergence
- Batch size: 32–256 for most tasks; use gradient accumulation if GPU memory is a constraint
- Monitor: log gradient norm, learning rate, and train/val loss every N steps — most training failures are visible in these three signals
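Assembled into code, the recipe might look like the following sketch — the model and data are toy stand-ins, and `SequentialLR` is one way (among several) to chain linear warmup into cosine annealing in native PyTorch:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# toy stand-in model and data; swap in your own
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
X, y = torch.randn(64, 16), torch.randint(0, 4, (64,))
criterion = nn.CrossEntropyLoss()

total_steps, warmup_steps = 200, 20
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
# linear warmup chained into cosine annealing
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                    total_iters=warmup_steps),
        optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                             T_max=total_steps - warmup_steps,
                                             eta_min=1e-6),
    ],
    milestones=[warmup_steps],
)

for step in range(total_steps):
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    # log grad_norm here — it is the clipped-norm return value
    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

print(f"final loss {loss.item():.3f}, final lr {scheduler.get_last_lr()[0]:.1e}")
```

This is full-batch on a 64-sample toy set for brevity; in a real run the loop iterates over a `DataLoader` and logs `grad_norm`, the learning rate, and val loss every N steps, per the monitoring item above.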