Convolutional Neural Networks
Convolutional Neural Networks (CNNs) exploit a fundamental property of visual data: a cat is a cat whether it appears in the top-left or bottom-right of an image. CNNs bake this prior into the architecture through weight sharing — the same filter scans every position — slashing parameter counts and dramatically improving generalization. Strictly speaking, weight sharing gives translation equivariance (shift the input and the feature maps shift with it); pooling and the classification head then convert that into approximate translation invariance.
Theory
A convolution is a sliding pattern detector. The kernel is a small template — an edge detector, say — and sliding it across the input asks "does this pattern appear here?" at every position. The output is a map of where the pattern was detected and how strongly. Stack many kernels and you get many pattern maps; stack many layers and early layers detect edges, later layers detect textures, and deep layers detect objects.
The Convolution Operation
The discrete 2D convolution (technically cross-correlation in deep learning) between input X and kernel K:

S[i, j] = Σ_m Σ_n X[i + m, j + n] · K[m, n]
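As a sanity check, the sum above can be written directly in a few lines of plain Python — a sketch for intuition, not a replacement for optimized library kernels:

```python
def cross_correlate2d(X, K):
    """Slide kernel K over image X (no padding, stride 1), summing elementwise products."""
    H, W, kH, kW = len(X), len(X[0]), len(K), len(K[0])
    out = []
    for i in range(H - kH + 1):
        row = []
        for j in range(W - kW + 1):
            row.append(sum(X[i + m][j + n] * K[m][n]
                           for m in range(kH) for n in range(kW)))
        out.append(row)
    return out

# A simple vertical-edge template fires where intensity jumps left-to-right
X = [[0, 0, 1, 1]] * 4
K = [[-1, 1],
     [-1, 1]]
print(cross_correlate2d(X, K))  # [[0, 2, 0], [0, 2, 0], [0, 2, 0]]
```

The strongest responses line up exactly with the intensity edge — the "map of where the pattern was detected and how strongly."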
Shared weights across all positions — the same kernel values everywhere — is what makes CNNs efficient. A 3×3 kernel has 9 parameters regardless of image size. A fully-connected layer processing a 32×32×3 image would need 3,072 parameters per output neuron. Weight sharing also encodes the assumption that a pattern useful in one region is useful everywhere — translational equivariance built into the architecture.
For a k×k kernel applied to an n×n input with padding p and stride s, the output spatial dimensions are:

out = ⌊(n + 2p − k) / s⌋ + 1

With k = 3, p = 1, s = 1: out = n (same padding — output matches input size).
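The formula is worth encoding once and reusing; a minimal helper:

```python
def conv_out(n: int, k: int, p: int, s: int) -> int:
    """Output spatial size for an n×n input, k×k kernel, padding p, stride s."""
    return (n + 2 * p - k) // s + 1

assert conv_out(32, 3, 1, 1) == 32    # same padding preserves size
assert conv_out(32, 3, 1, 2) == 16    # stride 2 halves the feature map
assert conv_out(224, 7, 3, 2) == 112  # e.g. a 7×7 stride-2 stem on a 224×224 input
```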
Parameter Sharing Advantage
This is the key insight. Compare a dense layer to a conv layer for a 32×32×3 image producing 64 feature maps:
Dense (fully-connected): Every output neuron connects to every input pixel.
For 3,072 inputs (32×32×3) and 64 output neurons: 3,072 × 64 = 196,608 parameters just for one layer.
Conv (3×3 kernel, same padding): Every output position uses the same filters — 64 of them, each 3×3 spanning 3 input channels: 3 × 3 × 3 × 64 = 1,728 parameters.
That is a 114× reduction in parameters. For a 224×224×3 ImageNet image producing 64 feature maps, the dense layer needs 224 × 224 × 3 × 64 ≈ 9.6M parameters against the same 1,728 for the conv layer, a ratio of more than 5,500×.
Weight sharing means the filter for detecting a horizontal edge is the same everywhere. It must be a good horizontal-edge detector regardless of position — which forces it to generalize. A dense layer could memorize which pixel positions correlate with the label; a conv layer cannot.
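The arithmetic behind these counts (biases omitted) is short enough to check directly:

```python
# Dense: each of 64 output neurons connects to every one of 32*32*3 = 3,072 pixels.
dense_params = 32 * 32 * 3 * 64
# Conv: 64 filters, each 3×3 spanning 3 input channels, shared across all positions.
conv_params = 3 * 3 * 3 * 64
print(dense_params, conv_params)          # 196608 1728
print(round(dense_params / conv_params))  # 114
```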
Receptive Fields
The receptive field of a neuron is the region of the input it can "see." Stacking conv layers grows the receptive field:
- After conv1 (3×3, stride 1): RF = 3×3
- After conv2 (3×3, stride 1): RF = 5×5
- After conv3 (3×3, stride 1): RF = 7×7
With stride-2 or max-pooling, growth is faster. After 3 layers of 3×3 convolutions with 2×2 max-pooling between them, the RF covers a large fraction of the input.
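The growth rule composes layer by layer: RF += (k − 1) × jump, where the jump (input pixels per output step) multiplies by each stride. A small helper, assuming layers are given as (kernel_size, stride) pairs:

```python
def receptive_field(layers):
    """Receptive field after a stack of (kernel_size, stride) layers."""
    rf, jump = 1, 1
    for k, s in layers:
        rf += (k - 1) * jump  # each layer sees (k-1) extra steps of the map below
        jump *= s             # stride compounds the input pixels per output step
    return rf

assert receptive_field([(3, 1)] * 3) == 7  # three stacked 3×3 convs, as above
# 3×3 conv + 2×2 stride-2 max-pool, repeated: the RF grows much faster
print(receptive_field([(3, 1), (2, 2), (3, 1), (2, 2), (3, 1)]))  # 18
```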
Pooling
Max pooling takes the maximum activation in a spatial window (typically 2×2, stride 2):

y[i, j] = max_{0 ≤ m, n < 2} x[2i + m, 2j + n]
This halves spatial dimensions, doubles the effective receptive field, and provides local translation invariance. The gradient during backprop is 1 at the max location and 0 elsewhere (a switch variable).
Global average pooling reduces each feature map to a single scalar by averaging over all spatial positions, replacing large dense layers in the classification head.
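Both operations are simple enough to sketch in plain Python (lists instead of tensors, for clarity):

```python
def max_pool2x2(x):
    """2×2 max pooling, stride 2: keep the strongest activation in each window."""
    return [[max(x[2*i][2*j], x[2*i][2*j+1], x[2*i+1][2*j], x[2*i+1][2*j+1])
             for j in range(len(x[0]) // 2)]
            for i in range(len(x) // 2)]

def global_avg_pool(fmap):
    """Collapse one feature map to a single scalar: mean over all spatial positions."""
    return sum(sum(row) for row in fmap) / (len(fmap) * len(fmap[0]))

x = [[1, 3, 2, 0],
     [4, 2, 1, 1],
     [0, 0, 5, 6],
     [1, 2, 7, 8]]
print(max_pool2x2(x))      # [[4, 2], [2, 8]]
print(global_avg_pool(x))  # 2.6875
```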
Batch Normalization
After each conv layer, batch norm normalizes activations across the batch:

x̂ = (x − μ_B) / √(σ_B² + ε),   y = γ·x̂ + β

where μ_B and σ_B² are computed over the current mini-batch, and γ, β are learnable scale/shift parameters. At inference, running exponential moving averages replace batch statistics.
Batch norm's regularization effect is strong enough that many architectures train without dropout when using batch norm. The noise from estimating batch statistics prevents co-adaptation of features and acts like a form of data augmentation.
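The training-time forward pass fits in a few lines. This is a per-feature sketch over a 1-D batch (real BatchNorm2d normalizes each channel over the batch and both spatial dimensions):

```python
def batch_norm(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize a batch to zero mean / unit variance, then apply learnable scale-shift."""
    mu = sum(xs) / len(xs)
    var = sum((x - mu) ** 2 for x in xs) / len(xs)  # biased variance, as in training
    return [gamma * (x - mu) / (var + eps) ** 0.5 + beta for x in xs]

out = batch_norm([1.0, 2.0, 3.0, 4.0])
print([round(v, 3) for v in out])  # [-1.342, -0.447, 0.447, 1.342]
```

γ and β let the network undo the normalization if that helps; with γ = 1, β = 0 the output is simply standardized.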
CNN Architecture
Walkthrough
Dataset: CIFAR-10
CIFAR-10 contains 60,000 color images (32×32, RGB) across 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.
- Training: 50,000 images (5,000 per class)
- Test: 10,000 images
- Random baseline: 10%
- Human performance: ~94%
- Our 3-block CNN: ~82% test accuracy
- ResNet-18: ~93%
Data Augmentation
Augmentation is critical for CIFAR-10. Without it, a CNN typically overfits: 95%+ training accuracy but only ~70% test accuracy.
from torchvision import transforms
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4), # crop from 40x40 padded image
transforms.RandomHorizontalFlip(p=0.5), # mirror left-right
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.4914, 0.4822, 0.4465), # CIFAR-10 per-channel means
std =(0.2023, 0.1994, 0.2010), # CIFAR-10 per-channel stds
),
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
std =(0.2023, 0.1994, 0.2010)),
])

Code Implementation
train.py

"""
CNN Training Pipeline — CIFAR-10 Image Classification
======================================================
Architecture: 3-block CNN with batch norm and residual skip connections
Optimizer: Adam with cosine annealing LR schedule
Target: ~82% test accuracy in 30 epochs
"""
import argparse
import json
from pathlib import Path
import mlflow
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class ConvBlock(nn.Module):
"""Conv -> BN -> ReLU -> Conv -> BN + residual skip connection."""
def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
super().__init__()
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_ch)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_ch)
self.relu = nn.ReLU(inplace=True)
# 1x1 conv skip if dimensions change
self.skip = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
nn.BatchNorm2d(out_ch),
) if (in_ch != out_ch or stride != 1) else nn.Identity()
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
return self.relu(out + self.skip(x))
class SimpleCNN(nn.Module):
def __init__(self, num_classes: int = 10):
super().__init__()
self.stem = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
)
self.layer1 = ConvBlock(32, 64, stride=2) # 32x32 -> 16x16
self.layer2 = ConvBlock(64, 128, stride=2) # 16x16 -> 8x8
self.layer3 = ConvBlock(128, 256, stride=2) # 8x8 -> 4x4
self.head = nn.Sequential(
nn.AdaptiveAvgPool2d(1), # global average pooling -> (B, 256, 1, 1)
nn.Flatten(), # -> (B, 256)
nn.Linear(256, num_classes),
)
def forward(self, x):
return self.head(self.layer3(self.layer2(self.layer1(self.stem(x)))))
def get_loaders(batch_size: int = 128):
train_tf = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.2, 0.2, 0.2),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
test_tf = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_ds = datasets.CIFAR10("./data", train=True, download=True, transform=train_tf)
test_ds = datasets.CIFAR10("./data", train=False, transform=test_tf)
return (
DataLoader(train_ds, batch_size=batch_size, shuffle=True,
num_workers=4, pin_memory=True),
DataLoader(test_ds, batch_size=batch_size, shuffle=False,
num_workers=4, pin_memory=True),
)
def train(cfg: dict) -> None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hp = cfg["hyperparams"]
train_loader, test_loader = get_loaders(hp["batch_size"])
model = SimpleCNN(num_classes=10).to(device)
optimizer = optim.Adam(model.parameters(), lr=hp["learning_rate"],
weight_decay=hp["weight_decay"])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=hp["epochs"])
criterion = nn.CrossEntropyLoss()
mlflow.set_experiment("cnn-cifar10")
with mlflow.start_run():
mlflow.log_params(hp)
best_acc = 0.0
for epoch in range(hp["epochs"]):
model.train()
total_loss, correct, total = 0.0, 0, 0
for X, y in train_loader:
X, y = X.to(device), y.to(device)
optimizer.zero_grad()
logits = model(X)
loss = criterion(logits, y)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
correct += (logits.argmax(1) == y).sum().item()
total += len(y)
scheduler.step()
# Evaluate
model.eval()
tcorrect, ttotal = 0, 0
with torch.no_grad():
for X, y in test_loader:
X, y = X.to(device), y.to(device)
tcorrect += (model(X).argmax(1) == y).sum().item()
ttotal += len(y)
test_acc = tcorrect / ttotal
mlflow.log_metrics({
"train_loss": total_loss / len(train_loader),
"train_acc": correct / total,
"test_acc": test_acc,
}, step=epoch)
print(f"Ep {epoch+1:3d} | loss={total_loss/len(train_loader):.4f} "
f"| train={correct/total:.4f} | test={test_acc:.4f}")
if test_acc > best_acc:
best_acc = test_acc
torch.save(model.state_dict(), "best_model.pt")
mlflow.log_artifact("best_model.pt")
print(f"Best test accuracy: {best_acc:.4f}")
# Typical: 0.8237 after 30 epochs
def main():
    import yaml  # local import keeps the module importable without PyYAML installed

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="../../config.yaml")
    args = parser.parse_args()
    with open(args.config) as f:
        train(yaml.safe_load(f))
if __name__ == "__main__":
    main()

Analysis & Evaluation
Where Your Intuition Breaks
Deeper CNNs detect more global features — true in general, but max pooling loses spatial precision at each layer. After 4–5 pooling operations, the network knows what is in the image but not where. This is why object detection architectures (YOLO, Faster R-CNN) use feature pyramids or skip connections to recover spatial information — pure depth without spatial preservation is insufficient for localization tasks.
Accuracy Progression
| Epoch | Train Loss | Train Acc | Test Acc |
|---|---|---|---|
| 1 | 1.6234 | 41.2% | 43.7% |
| 5 | 1.0821 | 62.8% | 61.3% |
| 10 | 0.7234 | 74.6% | 71.8% |
| 20 | 0.4521 | 83.9% | 79.4% |
| 30 | 0.2893 | 90.1% | 82.4% |
The train-test gap after 30 epochs (90.1% vs 82.4%) indicates mild overfitting. Adding CutMix augmentation or label smoothing would close this gap by 1-2%.
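Label smoothing in PyTorch is a one-liner, nn.CrossEntropyLoss(label_smoothing=0.1); mechanically it blends the one-hot target toward the uniform distribution, as this sketch shows:

```python
def smooth_labels(one_hot, eps=0.1):
    """Replace hard targets with (1 - eps)·one_hot + eps/K, K = number of classes."""
    K = len(one_hot)
    return [(1 - eps) * y + eps / K for y in one_hot]

print([round(v, 3) for v in smooth_labels([0, 0, 1, 0])])  # [0.025, 0.025, 0.925, 0.025]
```

The softened targets penalize overconfident predictions, which typically narrows the train-test gap slightly on CIFAR-10.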
Batch Normalization Effect
Removing batch norm from the same architecture:
| Config | Test Acc @ Ep10 | Test Acc @ Ep30 | Notes |
|---|---|---|---|
| With BN | 71.8% | 82.4% | Stable, lr=1e-3 |
| Without BN | 58.3% | 71.2% | High variance, slow |
| Without BN, lr/10 | 61.1% | 68.7% | Safer LR needed |
Batch norm lets you use 10× larger learning rates because it keeps activations in a well-scaled range throughout training. Without it, a high LR causes exploding gradients in early layers. This is why BN enabled training significantly deeper networks after its introduction in 2015.
Receptive Field Growth
For our 3-block architecture (each block applies two 3×3 convs, the first with stride 2):

| Layer | Output Size | Receptive Field |
|---|---|---|
| Input | 32×32 | 1×1 |
| Stem | 32×32 | 3×3 |
| Layer1 s=2 | 16×16 | 9×9 |
| Layer2 s=2 | 8×8 | 21×21 |
| Layer3 s=2 | 4×4 | 45×45 (exceeds the 32×32 input) |
| Global Pool | 1×1 | full image |
By the final layer, each output neuron has "seen" the entire input image — essential for classifying whole objects rather than local textures.
Filter Visualization
import matplotlib.pyplot as plt

model = SimpleCNN()  # architecture from train.py
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
filters = model.stem[0].weight.data.cpu()  # (32, 3, 3, 3)
fig, axes = plt.subplots(4, 8, figsize=(16, 8))
for i, ax in enumerate(axes.flat):
f = filters[i].permute(1, 2, 0).numpy()
f = (f - f.min()) / (f.max() - f.min() + 1e-8)
ax.imshow(f); ax.axis("off")
plt.suptitle("Learned Conv1 Filters — edge and color detectors")
plt.savefig("filters.png", dpi=150)

Typical patterns: Gabor-like edge detectors at various orientations, color opponent filters (red-green, blue-yellow), and blob detectors. These closely resemble V1 simple cells in the mammalian visual cortex.
If your filters look like random noise after training, the network failed to learn meaningful features. Common causes: LR too high (destroyed features early), insufficient data without augmentation, or a normalization bug that prevents convergence. Always check that your Normalize() mean/std match the actual dataset statistics.
Production-Ready Code
"""serve_api/app.py — CIFAR-10 CNN inference endpoint."""
import io
from pathlib import Path
import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms
CLASSES = ["airplane","automobile","bird","cat","deer",
"dog","frog","horse","ship","truck"]
app = FastAPI(title="CIFAR-10 CNN API", version="1.0.0")
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_ch)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_ch)
self.relu = nn.ReLU(inplace=True)
self.skip = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
nn.BatchNorm2d(out_ch),
) if (in_ch != out_ch or stride != 1) else nn.Identity()
    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.skip(x))
class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.stem = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1, bias=False),
nn.BatchNorm2d(32), nn.ReLU())
self.layer1 = ConvBlock(32, 64, stride=2)
self.layer2 = ConvBlock(64, 128, stride=2)
self.layer3 = ConvBlock(128, 256, stride=2)
self.head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(),
nn.Linear(256, num_classes))
def forward(self, x):
return self.head(self.layer3(self.layer2(self.layer1(self.stem(x)))))
model = SimpleCNN()
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
@app.get("/health")
def health():
return {"status": "ok"}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
img = Image.open(io.BytesIO(await file.read())).convert("RGB")
tensor = transform(img).unsqueeze(0)
with torch.inference_mode():
logits = model(tensor)
probs = torch.softmax(logits, 1).squeeze()
top3 = probs.topk(3)
return JSONResponse({
"prediction": CLASSES[probs.argmax()],
"confidence": round(probs.max().item(), 4),
"top3": [
{"class": CLASSES[i], "probability": round(p.item(), 4)}
for i, p in zip(top3.indices.tolist(), top3.values.tolist())
],
})

For batch inference, use torch.compile() (PyTorch 2.0+) for a 20-40% speedup. For deployment-agnostic serving, export to Open Neural Network Exchange (ONNX): torch.onnx.export(model, dummy_input, "model.onnx", opset_version=17). ONNX Runtime typically achieves 2-4× throughput vs PyTorch on CPU inference, and runs on edge devices without a Python runtime.