Regularization Techniques
A model that perfectly memorizes training data is useless. Regularization is any technique that reduces the gap between training and test performance by constraining what the model can learn. The goal: find the simplest hypothesis that fits the data, not the most complex one that memorizes it.
Theory
Without constraints, a model will fit training data perfectly by memorizing it — every quirk, noise spike, and coincidence. Regularization adds a cost to complexity, forcing the model to find simpler explanations that generalize. The three curves above show the result: too little constraint overfits, too much underfits, the right amount finds the pattern.
L2 Regularization (Ridge / Weight Decay)
The L2-regularized objective adds a penalty proportional to the sum of squared weights:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{data}} + \lambda \sum_i w_i^2$$

The gradient update becomes:

$$w \leftarrow w - \eta\left(\nabla_w \mathcal{L}_{\text{data}} + 2\lambda w\right) = (1 - 2\eta\lambda)\,w - \eta\,\nabla_w \mathcal{L}_{\text{data}}$$

The $(1 - 2\eta\lambda)$ factor multiplies every weight — hence the name weight decay. L2 pushes all weights toward zero proportionally to their magnitude, producing smooth solutions where many weights are small but none are exactly zero.

Bayesian interpretation: L2 regularization is equivalent to placing a zero-mean Gaussian prior $w \sim \mathcal{N}(0, \sigma^2 I)$ on the weights and finding the Maximum A Posteriori (MAP) estimate:

$$w_{\text{MAP}} = \arg\max_w\left[\log p(\mathcal{D} \mid w) + \log p(w)\right] = \arg\min_w\left[\mathcal{L}_{\text{data}} + \frac{1}{2\sigma^2}\,\|w\|_2^2\right]$$
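To make the decay factor concrete, here is a minimal PyTorch sketch (toy data, variable names ours) checking that adding an explicit $\lambda \sum_i w_i^2$ penalty to the loss produces the same SGD update as the optimizer's weight_decay argument:

```python
import torch

torch.manual_seed(0)
w1 = torch.randn(5, requires_grad=True)
w2 = w1.detach().clone().requires_grad_(True)
x, y = torch.randn(8, 5), torch.randn(8)
lr, lam = 0.1, 0.01

# Route 1: add the L2 penalty to the loss explicitly
loss = ((x @ w1 - y) ** 2).mean() + lam * (w1 ** 2).sum()
loss.backward()
with torch.no_grad():
    w1 -= lr * w1.grad

# Route 2: SGD's weight_decay adds decay * w to the gradient, and
# d/dw [lam * w^2] = 2 * lam * w, so decay = 2 * lam matches route 1
opt = torch.optim.SGD([w2], lr=lr, weight_decay=2 * lam)
((x @ w2 - y) ** 2).mean().backward()
opt.step()

assert torch.allclose(w1, w2, atol=1e-6)
```

The factor of two is easy to miss: some texts write the penalty as $\frac{\lambda}{2}\|w\|_2^2$ precisely so the gradient comes out as $\lambda w$.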
L1 Regularization (Lasso)
L1 adds a penalty proportional to the sum of absolute values:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{data}} + \lambda \sum_i |w_i|$$

The gradient of $\lambda|w|$ is $\lambda\,\operatorname{sign}(w)$ (a subgradient at $w = 0$), leading to:

$$w \leftarrow w - \eta\left(\nabla_w \mathcal{L}_{\text{data}} + \lambda\,\operatorname{sign}(w)\right)$$
L1 subtracts a constant from every weight each step, regardless of magnitude. Small weights get pushed to exactly zero — producing sparse solutions. This is L1's key property: automatic feature selection.
The difference in gradient shape determines the behavior: L2's gradient is $2\lambda w$ — proportional to weight magnitude, so large weights shrink faster but small weights never reach zero. L1's gradient is $\lambda\,\operatorname{sign}(w)$ — constant regardless of magnitude, so every weight is pushed toward zero with equal force, producing exact sparsity. This isn't a design choice; it's a consequence of the geometry of the L1 norm.
L2 whispers "stay small" to all weights proportionally — big weights get pushed harder. L1 shouts "go to zero" equally to all weights — small weights that barely contribute get zeroed out first. L1 is a scalpel (sparse feature selection); L2 is a shrinkage blanket (smooth parameter reduction).
The key geometric difference: L2's constraint region is a sphere (smooth boundary — the optimum almost never lands exactly on an axis), while L1's constraint region is a diamond (a cross-polytope, with sharp corners on the axes — solutions often land exactly at corners or edges where some $w_i = 0$).
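A tiny numeric illustration of the two behaviors (values chosen arbitrarily): one L2 shrinkage step versus one L1 soft-thresholding (proximal) step applied to the same weight vector.

```python
import torch

w = torch.tensor([2.0, 0.5, 0.05, -0.01])
eta, lam = 0.1, 0.5

# L2 step: multiplicative shrinkage -- every weight moves toward zero
# proportionally to its magnitude, but none lands exactly on zero
w_l2 = w * (1 - eta * lam)

# L1 proximal step (soft-thresholding): subtract the constant eta * lam
# from each |w_i|; anything below that threshold snaps to exactly zero
w_l1 = torch.sign(w) * torch.clamp(w.abs() - eta * lam, min=0.0)

print(w_l2)  # all four entries shrunk, all still nonzero
print(w_l1)  # the two small entries are exactly 0.0
```

Soft-thresholding is how coordinate-descent Lasso solvers produce exact zeros in practice, rather than relying on a subgradient step landing precisely on zero.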
Elastic Net: Best of Both
Elastic Net combines L1 sparsity with L2's grouping effect (correlated features get similar non-zero weights). Used in scikit-learn's ElasticNet and as weight decay + L1 penalty in neural networks.
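A quick scikit-learn sketch of the grouping effect (toy data and parameter values are ours): give the model two perfectly correlated copies of one informative feature, and Elastic Net splits the weight between them rather than arbitrarily picking one, as plain Lasso tends to.

```python
import numpy as np
from sklearn.linear_model import ElasticNet

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
X[:, 1] = X[:, 0]                      # feature 1 duplicates feature 0
y = 3.0 * X[:, 0] + rng.normal(scale=0.1, size=200)

# l1_ratio blends the penalties: 1.0 is pure Lasso, 0.0 is pure Ridge
model = ElasticNet(alpha=0.1, l1_ratio=0.5).fit(X, y)
# The L2 component forces the correlated pair to share similar weights
print(model.coef_[:2])
```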
Dropout: Approximate Model Averaging
Dropout (Srivastava et al., 2014) randomly zeroes each neuron's activation with probability $p$ during training:

$$\tilde{h}_i = m_i \, h_i, \qquad m_i \sim \operatorname{Bernoulli}(1 - p)$$

At inference, no masking is applied but activations are scaled by $1 - p$ to maintain expected magnitude (or equivalently, training activations are scaled by $\frac{1}{1-p}$ — the "inverted dropout" used in PyTorch).

Why does this work? Each training step uses a different random subnetwork (a "thinned" network). With dropout rate $p$ and $n$ neurons, there are $2^n$ possible subnetworks. Training with dropout is approximately equivalent to training an ensemble of all these subnetworks with shared weights and averaging their predictions at test time.
The geometric mean of many models is almost always better than any single model.
Dropout rates by layer type (rules of thumb from practice):
- Dense layers: 0.3–0.5
- Conv layers: 0.1–0.2 (spatial dropout: drop entire feature maps)
- Input layer: 0.1–0.2 (if using)
- Last hidden layer before output: 0.5
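The masking and scaling rules above can be verified directly with nn.Dropout (a sketch; the rate 0.3 is arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.ones(100_000)
drop = nn.Dropout(p=0.3)

# Train mode: ~30% of activations zeroed, survivors scaled by 1/(1 - 0.3),
# so the expected activation is unchanged
drop.train()
y = drop(x)
print(y.unique())       # two values: 0 and 1/(1 - p), about 1.4286
print(y.mean().item())  # close to 1.0

# Eval mode: identity -- no masking, no scaling
drop.eval()
assert torch.equal(drop(x), x)
```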
Normalization Layers
Normalization layers stabilize training by controlling the scale of activations between layers. The right choice depends on architecture — this is one of the most practically consequential decisions in modern deep learning.
Batch Normalization (CNNs)
Batch normalization normalizes over the mini-batch dimension, computing statistics across all samples in the batch for each feature:

$$\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta$$

where $\mu_B$ and $\sigma_B^2$ are the batch mean and variance, and $\gamma, \beta$ are learnable scale and shift parameters.

Training vs inference: During training, statistics come from the current mini-batch. During inference, running statistics (exponential moving averages with momentum $\alpha$) are used:

$$\mu_{\text{running}} \leftarrow (1 - \alpha)\,\mu_{\text{running}} + \alpha\,\mu_B, \qquad \sigma_{\text{running}}^2 \leftarrow (1 - \alpha)\,\sigma_{\text{running}}^2 + \alpha\,\sigma_B^2$$
Always call model.eval() before inference — this switches BatchNorm from batch statistics to running statistics. Forgetting model.eval() is one of the most common PyTorch bugs: predictions change with the composition of the batch they happen to be computed in, and are usually worse because small-batch statistics are noisy (in train mode, nn.BatchNorm1d even raises an error on a batch of one).
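The normalization formula and the running-statistics behavior can both be checked against nn.BatchNorm1d (a minimal sketch):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 8)            # batch of 32 samples, 8 features
bn = nn.BatchNorm1d(8)

bn.train()
out = bn(x)

# Manual computation: per-feature batch mean/variance, then gamma * x_hat + beta
# (BatchNorm normalizes with the biased variance estimate)
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
manual = bn.weight * (x - mu) / torch.sqrt(var + bn.eps) + bn.bias
assert torch.allclose(out, manual, atol=1e-5)

# The forward pass also updated the running statistics used in eval mode,
# so the same input now normalizes differently
bn.eval()
assert not torch.allclose(bn(x), out, atol=1e-5)
```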
Layer Normalization (Transformers)
Layer normalization normalizes over the feature dimension of each individual sample, making it independent of batch size:

$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\mu$ and $\sigma^2$ are computed across all features of a single token. Unlike BatchNorm, LayerNorm behaves identically at train and inference time — no running statistics, no model.eval() concern.
LayerNorm is the standard for all Transformer-based models (BERT, GPT, T5, Llama) because sequences have variable lengths and batch sizes can be as small as 1 during generation.
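The same kind of check works for nn.LayerNorm (a sketch), and also confirms there is no train/eval difference:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 16)            # 4 tokens, 16 features each
ln = nn.LayerNorm(16)

# Manual: mean/variance across the feature dimension of each sample
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = ln.weight * (x - mu) / torch.sqrt(var + ln.eps) + ln.bias
assert torch.allclose(ln(x), manual, atol=1e-5)

# No running statistics: train and eval modes give identical outputs
out_train = ln(x)
ln.eval()
assert torch.equal(out_train, ln(x))
```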
RMSNorm (Modern LLMs)
Root Mean Square Layer Normalization (RMSNorm) (Zhang & Sennrich, 2019) drops the mean-centering step from LayerNorm, only normalizing by the RMS of activations:

$$\operatorname{RMSNorm}(x) = \gamma \cdot \frac{x}{\operatorname{RMS}(x)}, \qquad \operatorname{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$$
This removes the mean subtraction and the shift parameter, reducing computation by ~40% with negligible quality loss. RMSNorm is now preferred over LayerNorm in most modern LLMs: Llama 2/3, Mistral, and Gemma all use RMSNorm.
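RMSNorm is small enough to write out in full. A minimal sketch consistent with the formula above (names ours, not taken from any particular LLM codebase):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Divide by the RMS, then scale by a learnable gain: no mean, no bias."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms

torch.manual_seed(0)
x = torch.randn(4, 16)
out = RMSNorm(16)(x)
# With the gain at its initial value of 1, every row now has RMS ~ 1
print(out.pow(2).mean(dim=-1).sqrt())
```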
| | BatchNorm | LayerNorm | RMSNorm |
|---|---|---|---|
| Normalizes over | Batch | Features | Features (no mean) |
| Batch size dependency | Yes | No | No |
| Train/eval difference | Yes | No | No |
| Learnable params | $\gamma, \beta$ | $\gamma, \beta$ | $\gamma$ only |
| Primary use | CNNs | Transformers | Modern LLMs |
| Examples | ResNet, EfficientNet | BERT, GPT-2 | Llama 2/3, Mistral |
BatchNorm asks: "How does this neuron compare to the same neuron across the batch?" LayerNorm asks: "How does this value compare to all features for this sample?" RMSNorm asks the same but skips the mean-centering, betting that the scale (RMS) is what matters for training stability, not the shift. For most modern LLM training this bet pays off.
Walkthrough
Comparing Dropout Rates on CIFAR-10
Using the SimpleCNN from the CNN lesson (3 ConvBlocks), we compare dropout rates applied in the classification head:
import torch
import torch.nn as nn
class SimpleCNN(nn.Module):
def __init__(self, dropout_rate: float = 0.0):
super().__init__()
# ... convolutional blocks (same as CNN lesson) ...
self.head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Dropout(dropout_rate), # <- varies between experiments
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(dropout_rate), # <- varies between experiments
nn.Linear(128, 10),
)
# Experiment: same training setup, only dropout varies
results = {}
for p in [0.0, 0.1, 0.3, 0.5, 0.7]:
model = SimpleCNN(dropout_rate=p)
# Train for 30 epochs on CIFAR-10...
results[p] = {"train_acc": ..., "test_acc": ..., "gap": ...}
# Results (30 epochs, Adam lr=1e-3, no other augmentation):
# p=0.0: train=95.2% test=74.1% gap=21.1% (severe overfitting)
# p=0.1: train=93.8% test=77.3% gap=16.5%
# p=0.3: train=91.2% test=80.1% gap=11.1% (best generalization)
# p=0.5: train=87.4% test=79.3% gap=8.1%
# p=0.7: train=78.2% test=72.1% gap=6.1% (underfitting)

The optimal dropout rate (0.3 here) balances regularization against underfitting. Notice that at $p = 0.7$, the model can no longer fit the training data — 70% of neurons are zeroed at each step, leaving too little capacity.
Effect of Weight Decay on CIFAR-10
Weight decay ($\lambda$ in L2 regularization) controls how strongly weights are pushed toward zero:
# Using torch.optim.Adam with weight_decay parameter
results_wd = {}
for wd in [0.0, 1e-5, 1e-4, 1e-3, 1e-2]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=wd)
# Train for 30 epochs...
# Results:
# wd=0.0: train=95.1% test=74.3% gap=20.8%
# wd=1e-5: train=94.6% test=75.9% gap=18.7%
# wd=1e-4: train=92.3% test=79.8% gap=12.5% (good)
# wd=1e-3: train=89.1% test=80.2% gap=8.9% (best test acc)
# wd=1e-2: train=82.4% test=76.1% gap=6.3% (too much — underfitting)

Weight decay of $10^{-4}$ to $10^{-3}$ is typical for Adam. Note that AdamW (Adam with decoupled weight decay) should be used instead of Adam's weight_decay parameter: Adam folds the penalty into the adaptively scaled gradient, which distorts the decay, while AdamW applies it directly to the weights.
Combining Regularization Techniques
In practice, multiple regularization techniques are combined:
from torchvision import transforms
# 1. Data augmentation (reduces overfitting by expanding effective dataset)
train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(p=0.1),  # randomly mask patches (Cutout); tensor op, so it must come after ToTensor
])
# 2. Batch normalization (in the model architecture)
# 3. Dropout (in the classifier head)
# 4. Weight decay (in the optimizer)
# 5. Label smoothing (in the loss function)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-3,
weight_decay=1e-3, # L2 regularization
)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1) # soft targets
# Combined result on CIFAR-10 (30 epochs):
# train=90.1% test=84.2% gap=5.9%
# vs no regularization: train=95.2% test=74.1% gap=21.1%

Analysis & Evaluation
Where Your Intuition Breaks
More regularization seems safer — it should always reduce overfitting. But over-regularized models underfit: the penalty term dominates the loss, and the model learns to predict near-zero everywhere rather than learning the actual signal. Regularization strength is a dial, not a switch. The right value depends on dataset size — larger datasets need less regularization because they provide enough signal to constrain the model naturally.
Sparsity Patterns: L1 vs L2
One key diagnostic: what fraction of weights are near-zero after training?
import torch
import numpy as np
# After training with L1 vs L2 regularization
def sparsity(model, threshold=1e-3):
total, near_zero = 0, 0
for p in model.parameters():
total += p.numel()
near_zero += (p.abs() < threshold).sum().item()
return near_zero / total
# PyTorch does not natively support L1, but we can add it manually:
def l1_regularization(model, lambda_l1=1e-4):
l1_loss = sum(p.abs().sum() for p in model.parameters())
return lambda_l1 * l1_loss
# Training with L1:
for X, y in train_loader:
optimizer.zero_grad()
loss = criterion(model(X), y) + l1_regularization(model, 1e-4)
loss.backward()
optimizer.step()
# Results after training (MLP on MNIST, hidden=256):
# No regularization: sparsity = 1.3% (very few near-zero weights)
# L2 (wd=1e-3): sparsity = 3.7% (small improvement)
# L1 (lambda=1e-4): sparsity = 47.8% (nearly half the weights zeroed!)
# L1 (lambda=1e-3): sparsity = 89.2% (extremely sparse, underfitting)

This sparsity from L1 is the foundation of pruning: identifying and removing weights that don't contribute. A 47.8% sparse MLP can be compressed to roughly half the size with minimal accuracy loss.
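This is exactly what torch.nn.utils.prune automates. A sketch of magnitude pruning the smallest half of a layer's weights:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(256, 256)

# Zero the 50% of weights with the smallest |w| (a mask is applied on the fly)
prune.l1_unstructured(layer, name="weight", amount=0.5)
sparsity = (layer.weight == 0).float().mean().item()
print(f"{sparsity:.1%}")  # 50.0%

prune.remove(layer, "weight")  # bake the mask in, making the pruning permanent
```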
Batch Norm: Training vs Inference Behavior
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(10, 64),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Linear(64, 2),
)
# Generate a consistent test input
x_test = torch.randn(1, 10, generator=torch.Generator().manual_seed(0))
# TRAINING MODE: uses batch statistics (different each call due to random inputs)
model.train()
batch1 = torch.randn(32, 10)
batch2 = torch.randn(32, 10)
out_b1 = model(batch1) # normalizes using batch1 stats
out_b2 = model(batch2) # normalizes using batch2 stats
# BN running stats are updated after each forward pass
# INFERENCE MODE: uses accumulated running statistics
model.eval()
out_eval_1 = model(x_test)
out_eval_2 = model(x_test)
assert torch.allclose(out_eval_1, out_eval_2) # deterministic!
# Common bug: forgetting model.eval()
model.train()  # accidentally left in train mode
x_batch = x_test.repeat(4, 1)  # train-mode BatchNorm1d needs batch size > 1
out_wrong = model(x_batch)  # normalized with this tiny batch's own noisy stats

The difference between train and eval mode is largest when:
- The model was trained on a different data distribution than the test input
- The mini-batch size at test time is very small (batch statistics are noisy)
- There is a domain shift between training and deployment
Regularization Strategy by Dataset Size
A practical guide for choosing regularization strength:
| Dataset Size | Primary Risk | Recommended Strategy |
|---|---|---|
| < 1K samples | Severe overfitting | Strong dropout (0.5), high weight decay (1e-3), heavy augmentation |
| 1K-10K | Overfitting | Dropout (0.3), moderate weight decay (1e-4), augmentation |
| 10K-100K | Moderate | Dropout (0.2), light weight decay (1e-5), augmentation |
| 100K-1M | Usually fine | Batch norm, light dropout (0.1), weight decay only |
| > 1M | Underfitting risk | Minimal regularization; focus on model capacity |
In production, regularization must be re-tuned when: (1) you collect more data (less regularization needed), (2) you change model architecture (deeper models need more dropout), or (3) you observe distribution shift (the model may be overfitting to the old distribution). Monitor training vs validation loss ratio in production — if it widens over time as you collect new data, retrain with updated regularization. Tools like Weights & Biases (wandb) make tracking these metrics across model versions straightforward.
Practical Checklist
Use this sequence when a model is overfitting:
- Add data augmentation — free regularization, usually the most effective first step
- Add weight decay — start with $\lambda = 10^{-4}$, increase if still overfitting
- Add dropout — start with $p = 0.3$ in dense layers, $p = 0.1$ in conv layers
- Add normalization — BatchNorm for CNNs; LayerNorm or RMSNorm for Transformers/LLMs; stabilizes training and provides implicit regularization
- Reduce model capacity — fewer layers or narrower layers if all else fails
- Collect more data — the ground truth solution to overfitting, when possible
When underfitting (training accuracy is also low):
- Remove or reduce regularization
- Increase model capacity
- Train longer
- Reduce learning rate