Vision Transformers (ViT)
Convolutional neural networks dominated computer vision for a decade by exploiting translation equivariance and local structure through sliding kernels. Vision Transformers discard both inductive biases: the image is split into a grid of flat patches, each patch becomes a token, and a standard transformer encoder processes the token sequence. With sufficient data and scale, ViT matches or exceeds CNNs on the major classification benchmarks, and it scales more predictably with compute. Understanding how patch tokenization works, why position embeddings matter, and how ViT compares to CNNs is foundational for modern computer vision.
Theory
The image is divided into non-overlapping patches of size P×P pixels. Each patch becomes one "token" — the 2D analogue of a word token in a language model.
ViT-B/16: 224×224 image, 16×16 patches → 196 tokens + 1 CLS · D=768 · 12 layers · 86M params
A Vision Transformer treats an image as a sequence of patches, exactly like a language model treats a sentence as a sequence of words. Divide the image into a grid of non-overlapping squares, flatten each square into a vector, and pass them through a standard transformer encoder. The key insight is that no architectural changes to the transformer are needed: images and text become the same kind of input.
Patch tokenization
An image $x \in \mathbb{R}^{H \times W \times C}$ is divided into non-overlapping patches of size $P \times P$ pixels:

$$N = \frac{HW}{P^2}$$

The patch count formula is forced by the transformer's requirement for a fixed-length sequence of tokens: $N = HW/P^2$ is the only way to partition the image into non-overlapping, equal-sized $P \times P$ patches that tile it exactly. Every choice of $P$ is a trade-off: smaller patches produce more tokens (richer spatial representation at quadratic attention cost); larger patches produce fewer tokens (cheaper attention but coarser spatial resolution). ViT-B/16 (16×16 patches, 196 tokens at 224×224) is the standard trade-off point for ImageNet-scale images; the snippet below makes the numbers concrete.
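A quick sketch of the trade-off (the 224×224 image size and the patch sizes here are just illustrative choices): token count drops quadratically as the patch size grows, and the cost of full self-attention drops with the square of the token count.

img_size = 224
for p in (8, 16, 32):
    n_tokens = (img_size // p) ** 2    # patches per image
    attn_cost = n_tokens ** 2          # pairwise attention scores per layer
    print(f"P={p:2d}: {n_tokens:4d} tokens, {attn_cost:,} attention pairs")
# P= 8:  784 tokens, 614,656 attention pairs
# P=16:  196 tokens, 38,416 attention pairs
# P=32:   49 tokens, 2,401 attention pairs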
Each patch $x_p^i \in \mathbb{R}^{P^2 \cdot C}$ is flattened and linearly projected to dimension $D$ with a learnable matrix $E \in \mathbb{R}^{(P^2 \cdot C) \times D}$:

$$z_0^i = x_p^i E + e_i^{\text{pos}}$$

where $e_i^{\text{pos}} \in \mathbb{R}^{D}$ is a learnable 1D position embedding. This creates a sequence of patch tokens, identical in structure to word tokens in BERT.

A learnable class token is prepended to the sequence. After $L$ transformer layers, the class token representation $z_L^0$ is used for classification:

$$\hat{y} = \mathrm{softmax}\!\left(W_{\text{head}}\,\mathrm{LN}(z_L^0) + b_{\text{head}}\right)$$
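The "flatten each patch, then project linearly" step is usually implemented as a strided convolution, which computes exactly the same thing. A minimal sketch of the equivalence (dimensions match ViT-B/16; copying the weights is only for the comparison):

import torch
import torch.nn as nn

P, C, D = 16, 3, 768
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)        # patch embedding as a conv
linear = nn.Linear(C * P * P, D)                        # explicit flatten + project
linear.weight.data = conv.weight.data.reshape(D, -1)    # share weights for the check
linear.bias.data = conv.bias.data

x = torch.randn(1, C, 224, 224)
out_conv = conv(x).flatten(2).transpose(1, 2)            # (1, 196, 768)
patches = x.unfold(2, P, P).unfold(3, P, P)              # (1, C, 14, 14, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 196, -1)
out_linear = linear(patches)                             # (1, 196, 768)
print(torch.allclose(out_conv, out_linear, atol=1e-4))   # True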
Multi-head self-attention over patches
The transformer encoder is identical to the original (Vaswani et al., 2017). For each layer $\ell = 1, \dots, L$:

Layer normalization + MSA:

$$z'_\ell = \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}$$

Layer normalization + MLP:

$$z_\ell = \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell$$
The multi-head self-attention computes attention over all token pairs (patches + class token). This gives global receptive field from layer 1 — unlike CNNs which require many layers for distant pixels to interact.
Query, key, value projections per head $h = 1, \dots, H$:

$$Q_h = z W_h^Q, \quad K_h = z W_h^K, \quad V_h = z W_h^V, \quad \mathrm{head}_h = \mathrm{softmax}\!\left(\frac{Q_h K_h^\top}{\sqrt{d_h}}\right) V_h$$

where $W_h^Q, W_h^K, W_h^V \in \mathbb{R}^{D \times d_h}$ and $d_h = D / H$ (64 for ViT-B: $D = 768$, $H = 12$). The head outputs are concatenated and projected back to dimension $D$.
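A minimal shape sketch of one MSA layer with ViT-B/16 dimensions, showing the common trick of fusing the per-head Q, K, V projections into a single linear layer (the module instances here are just for the demonstration):

import torch
import torch.nn as nn

B, N, D, H = 2, 197, 768, 12          # batch, tokens (196 patches + CLS), dim, heads
d_h = D // H                           # 64 dims per head

x = torch.randn(B, N, D)
qkv_proj = nn.Linear(D, 3 * D)         # fused Q, K, V projection
out_proj = nn.Linear(D, D)             # output projection after concatenating heads

qkv = qkv_proj(x).reshape(B, N, 3, H, d_h).permute(2, 0, 3, 1, 4)
q, k, v = qkv                                              # each (B, H, N, d_h)
scores = (q @ k.transpose(-2, -1)) / d_h ** 0.5            # (B, H, N, N) attention logits
out = (scores.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, D)
out = out_proj(out)
print(out.shape)                                           # torch.Size([2, 197, 768])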
Position embeddings
Unlike convolutions, self-attention has no built-in notion of spatial order: permuting the input tokens simply permutes the outputs, so the model cannot tell where a patch came from. Position embeddings inject this spatial structure. A quick demonstration follows; the main embedding variants are listed after it.
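A small check of this property (a sketch; the dimensions are arbitrary): shuffling the tokens fed to a self-attention layer just shuffles its output in the same way.

import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
x = torch.randn(1, 10, 64)                                 # 10 tokens, no position information
perm = torch.randperm(10)

out, _ = attn(x, x, x)                                     # attention over the original order
out_perm, _ = attn(x[:, perm], x[:, perm], x[:, perm])     # attention over shuffled tokens
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))   # True: order carries no signal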
1D learnable (standard in ViT): a single learned vector $e_i^{\text{pos}} \in \mathbb{R}^D$ per position index $i$. Simple and surprisingly effective.
2D position embeddings: separate row and column embeddings, concatenated or added. Better for high-resolution inputs or tasks requiring precise spatial reasoning.
Relative position bias (Swin Transformer): instead of absolute positions, learn a bias $B_{ij}$ for each pair of relative positions, added to the attention logits:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}} + B\right) V$$

where $B_{ij} = b_{\Delta(i,j)}$ and $\Delta(i, j)$ is the relative 2D offset between patches $i$ and $j$.
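A minimal sketch of how such a bias table can be built for a small window of patches (the window size and head count are arbitrary; Swin's actual implementation differs in details):

import torch
import torch.nn as nn

window = 7                                    # 7x7 window of patches
num_heads = 4
# one learnable bias per head for every possible 2D offset in [-6, 6] x [-6, 6]
bias_table = nn.Parameter(torch.zeros((2 * window - 1) ** 2, num_heads))

# precompute, for every pair of positions (i, j), the index of its relative offset
coords = torch.stack(torch.meshgrid(
    torch.arange(window), torch.arange(window), indexing="ij")).flatten(1)   # (2, 49)
rel = coords[:, :, None] - coords[:, None, :]              # (2, 49, 49) row/col offsets
rel = rel + (window - 1)                                    # shift offsets to be >= 0
index = rel[0] * (2 * window - 1) + rel[1]                  # (49, 49) flat offset index

B = bias_table[index].permute(2, 0, 1)                      # (num_heads, 49, 49)
print(B.shape)                                              # added to QK^T / sqrt(d) per head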
ViT scaling behavior and the data requirement
ViT is a data-hungry model. CNN inductive biases (locality, translation equivariance) help on small datasets; ViT lacks them and needs to learn spatial structure from scratch.
Empirical finding (Dosovitskiy et al., 2020):
- ViT-Large pretrained on ImageNet-21k (14M images) is competitive with comparably pretrained ResNets (BiT)
- ViT-Huge trained on JFT-300M (300M images) surpasses ResNets
- On ImageNet-1k alone (1.2M images), ViT underperforms equivalent CNNs
Scaling laws: ViT follows near-power-law scaling with model size and data, similar to language models. This predictable scaling is a key reason ViT became the architecture of choice in large-scale vision systems.
Swin Transformer: hierarchical ViT
Pure ViT scales quadratically with the number of tokens ($O(N^2)$ self-attention), and the token count grows with image resolution. Swin Transformer (Liu et al., 2021) addresses this with the following modifications:

Windowed attention: compute self-attention within local non-overlapping windows of $M \times M$ patches, reducing complexity from $O(N^2)$ to $O(N M^2)$, linear in the number of tokens for a fixed window size (sketched after this list).
Shifted windows: alternate between regular and shifted window partitions to enable cross-window connections without full global attention.
Hierarchical feature maps: like CNNs, Swin merges patches (patch merging = spatial downsampling) to build multi-scale feature maps — enabling dense prediction tasks (detection, segmentation) that require spatial resolution.
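A minimal sketch of the window-partition step: a 2D grid of patch tokens is reshaped into independent windows so attention can run within each one (the dimensions are illustrative, matching Swin-T's first stage):

import torch

B, H, W, D = 1, 56, 56, 96         # 56x56 patch grid, 96-dim tokens
M = 7                               # window size

x = torch.randn(B, H, W, D)
# partition the grid into non-overlapping MxM windows
windows = x.reshape(B, H // M, M, W // M, M, D)
windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, M * M, D)
print(windows.shape)                # torch.Size([64, 49, 96]): 64 windows of 49 tokens each
# attention now costs 64 * 49^2 score pairs per image instead of (56*56)^2 for the full grid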
Walkthrough
ViT forward pass
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
    """Divide image into patches and project to D dimensions."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2
        # Equivalent to: flatten patch + linear projection
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (B, C, H, W) -> (B, D, H/P, W/P) -> (B, N, D)
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim), nn.Dropout(dropout),
        )

    def forward(self, x):
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)
        x = self.blocks(x)
        x = self.norm(x)
        return self.head(x[:, 0])  # CLS token -> classification
# ViT-B/16
model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768,
depth=12, num_heads=12, num_classes=1000)
img = torch.randn(2, 3, 224, 224)
logits = model(img)  # (2, 1000)
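As a sanity check on the implementation above (a quick sketch), the parameter count should come out close to the 86M quoted for ViT-B/16:

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")   # roughly 86M for the ViT-B/16 configuration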
Fine-tuning with timm
import timm, torch

# Load pretrained ViT-B/16
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)

# Stage 1: head only
for p in model.parameters():
    p.requires_grad = False
for p in model.head.parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW(model.head.parameters(), lr=1e-3, weight_decay=0.01)

# Stage 2: full fine-tune with cosine decay + warmup
for p in model.parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.05)
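The "cosine decay + warmup" schedule mentioned above isn't shown; a minimal sketch of one common way to build it with a LambdaLR (the warmup length and total steps are placeholder values):

import math
import torch

warmup_steps, total_steps = 500, 10_000      # placeholder schedule lengths

def lr_lambda(step):
    # linear warmup, then cosine decay to zero
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# inside the training loop: optimizer.step(); scheduler.step()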
Analysis & Evaluation
Where Your Intuition Breaks
A common belief: ViT requires far more training data than CNNs and can't match CNN performance without ImageNet-21k pretraining. The original ViT paper did need larger pretraining datasets to be competitive on ImageNet-1k, but DeiT (Touvron et al., 2021) closed much of this gap by training ViT from scratch on ImageNet-1k with heavy data augmentation and knowledge distillation, with DeiT-S (the ViT-S/16 architecture) reaching ResNet-50-level accuracy. The data-efficiency gap comes from inductive biases: CNNs have translation equivariance and local connectivity built in, while ViT must learn these regularities from data. With sufficient regularization and augmentation, ViT closes most of the gap and can match or exceed CNNs at comparable parameter counts.
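A hedged sketch of the kind of augmentation pipeline such training relies on, using torchvision transforms (the actual DeiT recipe also includes Mixup, CutMix, and repeated augmentation, and its exact hyperparameters differ from the illustrative values here):

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),       # strong learned-policy augmentation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25),
])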
ViT vs CNN comparison
| | CNN (ResNet, ConvNeXt) | ViT |
|---|---|---|
| Inductive bias | Locality, translation equiv. | None |
| Receptive field | Grows with depth | Global from layer 1 |
| Data efficiency | Good on small datasets | Needs large-scale pretraining |
| Scaling | Sub-linear with params | Near-linear with data and compute |
| Dense prediction | Native multi-scale | Needs adaptation (Swin, FPN) |
| Training stability | Robust | Sensitive to LR, warmup |
ViT model variants
| Model | Patch | Layers | D | Heads | Params | ImageNet Top-1 |
|---|---|---|---|---|---|---|
| ViT-S/16 | 16 | 12 | 384 | 6 | 22M | 81.4% |
| ViT-B/16 | 16 | 12 | 768 | 12 | 86M | 85.2% |
| ViT-L/16 | 16 | 24 | 1024 | 16 | 307M | 87.1% |
| ViT-H/14 | 14 | 32 | 1280 | 16 | 632M | 88.5% |
Key failure modes
Low-data regimes: without CNNs' inductive biases, ViT needs 10–100x more data to match equivalent CNN performance. Hybrid architectures (CNN stem + transformer) bridge this gap.
High-resolution inputs: self-attention is $O(N^2)$ in the number of tokens. A 512×512 image with 16×16 patches gives $32^2 = 1024$ tokens, making attention roughly 27× more expensive than at 224×224 (196 tokens). Use Swin or windowed attention for high-resolution tasks.
Position embedding interpolation: when fine-tuning at different resolutions, interpolate the pretrained position embeddings bilinearly — naive transfer degrades performance significantly.
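A minimal sketch of that interpolation for a model like the one defined earlier (assumes a learnable pos_embed laid out as [CLS] followed by a square grid of patch tokens; the grid sizes are illustrative):

import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed, old_grid=14, new_grid=32):
    # pos_embed: (1, 1 + old_grid**2, D) -> (1, 1 + new_grid**2, D)
    cls_tok, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    D = patch_pos.shape[-1]
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, D).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode='bilinear', align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, D)
    return torch.cat([cls_tok, patch_pos], dim=1)

# e.g. adapting a 224-pretrained ViT-B/16 (14x14 grid) to 512x512 inputs (32x32 grid)
new_pos = interpolate_pos_embed(torch.randn(1, 197, 768))
print(new_pos.shape)   # torch.Size([1, 1025, 768])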