
Vision Transformers (ViT)

Convolutional neural networks dominated computer vision for a decade by exploiting translation invariance and local structure through sliding kernels. Vision Transformers discard both inductive biases entirely: the image is split into a grid of flat patches, each patch becomes a token, and a standard transformer encoder processes the token sequence. With sufficient data and scale, ViT matches or exceeds CNNs on every major benchmark — and unlike CNNs, it scales predictably with compute. Understanding how patch tokenization works, why position embeddings matter, and how ViT compares to CNNs is foundational for everything in modern computer vision.

Theory

Vision Transformer — image to tokens

The image is divided into P×P non-overlapping patches. Each patch becomes one "token" — the 2D analogue of a word token in a language model.

[Figure: an 8×8 grid of patches indexed 0–63, illustrating tokenization. patch size: P×P px · num tokens: (H/P)·(W/P) · + class token: [CLS] prepended · pos embed: learned 1D]

ViT-B/16: 224×224 image, 16×16 patches → 196 tokens + 1 CLS · D=768 · 12 layers · 86M params

A Vision Transformer treats an image as a sequence of patches, exactly like a language model treats a sentence as a sequence of words. Divide the image into a grid of non-overlapping squares, flatten each square into a vector, and pass them through a standard transformer encoder. The key insight is that no architectural changes to the transformer are needed: images and text become the same kind of input.

Patch tokenization

An image $x \in \mathbb{R}^{H \times W \times C}$ is divided into $N$ non-overlapping patches of size $P \times P$ pixels: $N = \frac{H \cdot W}{P^2}$

The patch count formula is forced by the transformer's requirement for a fixed-length token sequence: $N = H \cdot W / P^2$ is the only way to partition the image into non-overlapping, equal-sized patches that tile it exactly. Every choice of $P$ is a trade-off: smaller patches produce more tokens (richer spatial representation at quadratic attention cost); larger patches produce fewer tokens (cheaper attention but coarser spatial resolution). ViT-B/16 (16×16 patches, 196 tokens at 224×224) is the standard trade-off point on ImageNet-scale images.
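The trade-off can be made concrete with a quick token-count calculation (a minimal sketch; the patch sizes compared are illustrative):

```python
def num_tokens(h, w, p):
    """Number of non-overlapping p x p patches tiling an h x w image."""
    assert h % p == 0 and w % p == 0, "patch size must divide the image exactly"
    return (h // p) * (w // p)

for p in (8, 16, 32):
    n = num_tokens(224, 224, p)
    # self-attention cost grows with n**2, so halving P quadruples N
    # and multiplies attention cost by ~16
    print(f"P={p:2d}: {n:4d} tokens, attention pairs = {n * n:,}")
```

At 224×224 this prints 784 tokens for P=8, 196 for P=16, and 49 for P=32.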

Each patch $x_p^i \in \mathbb{R}^{P^2 \cdot C}$ is flattened and linearly projected to dimension $D$: $z_0^i = E \, x_p^i + e_{\text{pos}}^i, \quad E \in \mathbb{R}^{D \times (P^2 \cdot C)}$

where $e_{\text{pos}}^i$ is a learnable 1D position embedding. This creates a sequence of $N$ patch tokens, identical in structure to word tokens in BERT.

A learnable class token $z_0^0 = x_{\text{class}}$ is prepended to the sequence. After $L$ transformer layers, the class token representation is used for classification: $y = \text{MLP}(z_L^0)$
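The tokenization pipeline above can be checked shape-by-shape in a few lines (a minimal sketch with random weights; in a real model $E$, $e_{\text{pos}}$, and the class token are learned parameters):

```python
import torch

B, H, W, C, P, D = 2, 224, 224, 3, 16, 768
N = (H // P) * (W // P)                          # 196 patches

x = torch.randn(B, C, H, W)
# carve out P x P patches and flatten each to a P*P*C vector: (B, N, P*P*C)
patches = x.unfold(2, P, P).unfold(3, P, P)      # (B, C, 14, 14, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, N, C * P * P)

E = torch.randn(C * P * P, D)                    # projection (transpose of the D x (P^2 C) matrix in the text)
e_pos = torch.randn(1, N + 1, D)                 # learned 1D position embeddings
x_class = torch.randn(1, 1, D)                   # learnable [CLS] token

tokens = patches @ E                             # (B, N, D)
z0 = torch.cat([x_class.expand(B, -1, -1), tokens], dim=1) + e_pos
print(z0.shape)                                  # torch.Size([2, 197, 768])
```

The Conv2d trick in the walkthrough below computes the same flatten-and-project in one fused operation.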

Multi-head self-attention over patches

The transformer encoder is identical to the original (Vaswani et al., 2017). For each layer $l$:

Layer normalization + MSA: $z'_l = \text{MSA}(\text{LN}(z_{l-1})) + z_{l-1}$

Layer normalization + MLP: $z_l = \text{MLP}(\text{LN}(z'_l)) + z'_l$

Multi-head self-attention computes attention over all pairs of the $N+1$ tokens (patches + class token). This gives a global receptive field from layer 1, unlike CNNs, which require many layers before distant pixels can interact.

Query, key, and value projections per head $h$: $Q_h = z W_Q^h, \quad K_h = z W_K^h, \quad V_h = z W_V^h$

$\text{Attn}_h = \text{softmax}\!\left(\frac{Q_h K_h^\top}{\sqrt{d_k}}\right) V_h$

where $d_k = D / n_\text{heads}$.
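A single head can be spelled out directly from these formulas (a minimal sketch with random weights; dimensions follow ViT-B: $D = 768$, 12 heads, 197 tokens):

```python
import torch
import torch.nn.functional as F

N, D, n_heads = 197, 768, 12
d_k = D // n_heads                           # 64

z = torch.randn(N, D)                        # token sequence (patches + [CLS])
W_q, W_k, W_v = (torch.randn(D, d_k) for _ in range(3))

Q, K, V = z @ W_q, z @ W_k, z @ W_v          # (N, d_k) each
A = F.softmax(Q @ K.T / d_k**0.5, dim=-1)    # (N, N): every token attends to every token
out = A @ V                                  # (N, d_k)

print(A.shape, out.shape)                    # torch.Size([197, 197]) torch.Size([197, 64])
```

The $(N, N)$ shape of `A` is the global receptive field: row $i$ is a probability distribution over all tokens, including distant patches, already at the first layer.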

Position embeddings

Unlike CNNs, self-attention is permutation-invariant — it treats the input as a set with no inherent order. Position embeddings inject spatial structure.

1D learnable (standard in ViT): a single learned vector $e_i \in \mathbb{R}^D$ per position index $i = 0, \ldots, N$. Simple and surprisingly effective.

2D position embeddings: separate row and column embeddings, concatenated or added. Better for high-resolution inputs or tasks requiring precise spatial reasoning.

Relative position bias (Swin Transformer): instead of absolute positions, learn a bias $B_{ij}$ for each pair of relative positions, added to the attention logits: $A_{ij} = \frac{q_i^\top k_j}{\sqrt{d_k}} + B_{r(i,j)}$

where $r(i,j)$ is the relative 2D offset between patches $i$ and $j$.
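The index $r(i,j)$ can be precomputed once per window size, in the style of Swin (a minimal sketch; the 3×3 window is illustrative, and the bias table is zero-initialized here rather than learned):

```python
import torch

M = 3                                             # window side in patches (illustrative)
coords = torch.stack(torch.meshgrid(
    torch.arange(M), torch.arange(M), indexing="ij"), dim=-1).reshape(-1, 2)  # (M*M, 2)

# relative 2D offset between every pair of window positions: (M*M, M*M, 2)
rel = coords[:, None, :] - coords[None, :, :]
# shift offsets to be non-negative, then flatten to a single lookup index r(i, j)
rel = rel + (M - 1)
r = rel[..., 0] * (2 * M - 1) + rel[..., 1]       # values in [0, (2M-1)^2)

bias_table = torch.zeros((2 * M - 1) ** 2)        # learnable parameter in a real model
B_ij = bias_table[r]                              # (M*M, M*M) bias added to attention logits
print(r.shape, int(r.max()) + 1)                  # torch.Size([9, 9]) 25
```

Because the table has only $(2M-1)^2$ entries, pairs of patches at the same relative offset share one learned bias, which is exactly the translation-sharing property this scheme borrows from convolutions.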

ViT scaling behavior and the data requirement

ViT is a data-hungry model. CNN inductive biases (locality, translation equivariance) help on small datasets; ViT lacks them and needs to learn spatial structure from scratch.

Empirical finding (Dosovitskiy et al., 2020):

  • On ImageNet-1k alone (1.2M images), ViT underperforms comparably sized CNNs
  • Pretrained on ImageNet-21k (14M images), ViT becomes competitive with strong ResNets
  • Pretrained on JFT-300M (300M images), ViT-Huge surpasses ResNets

Scaling laws: ViT follows near-power-law scaling with model size and data, similar to language models. This predictable scaling is a key reason ViT became the architecture of choice in large-scale vision systems.

Swin Transformer: hierarchical ViT

Pure ViT's self-attention cost is $O(N^2)$ in the token count, and $N$ grows linearly with pixel count, so cost balloons at high resolution. Swin Transformer (Liu et al., 2021) addresses this with windowed attention and a hierarchical design:

Windowed attention: compute self-attention within local non-overlapping windows of $M \times M$ patches, reducing complexity to $O(N \cdot M^2)$.

Shifted windows: alternate between regular and shifted window partitions to enable cross-window connections without full global attention.

Hierarchical feature maps: like CNNs, Swin merges patches (patch merging = spatial downsampling) to build multi-scale feature maps — enabling dense prediction tasks (detection, segmentation) that require spatial resolution.
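The windowed and shifted-window partitions reduce to reshapes plus a cyclic `torch.roll` (a minimal sketch; the 8×8 feature map and $M = 4$ window are illustrative, and the attention call itself is omitted):

```python
import torch

B, H, W, D = 1, 8, 8, 16                 # feature map measured in patches (illustrative)
M = 4                                    # window side

def window_partition(x, M):
    """(B, H, W, D) -> (num_windows*B, M*M, D); attention then runs per window."""
    Bn, H, W, D = x.shape
    x = x.view(Bn, H // M, M, W // M, M, D)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, M * M, D)

x = torch.randn(B, H, W, D)
windows = window_partition(x, M)          # (4, 16, 16): four 4x4 windows
# shifted windows: cyclically shift by M//2 before partitioning, as in Swin,
# so the next layer's windows straddle the previous layer's window boundaries
shifted = torch.roll(x, shifts=(-M // 2, -M // 2), dims=(1, 2))
shifted_windows = window_partition(shifted, M)
print(windows.shape, shifted_windows.shape)
```

Each window attends over only $M^2 = 16$ tokens instead of all $N = 64$, which is where the $O(N \cdot M^2)$ complexity comes from.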

Walkthrough

ViT forward pass

python
import torch
import torch.nn as nn
 
class PatchEmbedding(nn.Module):
    """Divide image into patches and project to D dimensions."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2
        # Equivalent to: flatten patch + linear projection
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
 
    def forward(self, x):
        # x: (B, C, H, W) -> (B, D, H/P, W/P) -> (B, N, D)
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x
 
 
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim), nn.Dropout(dropout),
        )
 
    def forward(self, x):
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
 
 
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
 
    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)
        x = self.blocks(x)
        x = self.norm(x)
        return self.head(x[:, 0])   # CLS token -> classification
 
 
# ViT-B/16
model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768,
                          depth=12, num_heads=12, num_classes=1000)
img = torch.randn(2, 3, 224, 224)
logits = model(img)   # (2, 1000)

Fine-tuning with timm

python
import timm, torch
 
# Load pretrained ViT-B/16
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
 
# Stage 1: head only
for p in model.parameters():
    p.requires_grad = False
for p in model.head.parameters():
    p.requires_grad = True
 
optimizer = torch.optim.AdamW(model.head.parameters(), lr=1e-3, weight_decay=0.01)
 
# Stage 2: full fine-tune with cosine decay + warmup
for p in model.parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.05)

Analysis & Evaluation

Where Your Intuition Breaks

"ViT requires far more training data than CNNs and can't match CNN performance without ImageNet-21k pretraining." This was true of the original ViT paper, which did need larger pretraining datasets for competitive ImageNet-1k performance. But DeiT (Touvron et al., 2021) closed the gap, training ViT-S/16 from scratch on ImageNet-1k with strong data augmentation and knowledge distillation and matching ResNet-50-class performance. The data-efficiency gap comes from inductive biases: CNNs build in translation equivariance and local connectivity, while ViT must learn these regularities from data. With sufficient regularization and augmentation, ViT recovers comparable behavior and matches or exceeds CNNs at the same parameter count.

ViT vs CNN comparison

|                    | CNN (ResNet, ConvNeXt) | ViT                              |
|--------------------|------------------------|----------------------------------|
| Inductive bias     | Locality, translation equiv. | None                       |
| Receptive field    | Grows with depth       | Global from layer 1              |
| Data efficiency    | Good on small datasets | Needs large-scale pretraining    |
| Scaling            | Sub-linear with params | Near-linear with data and compute |
| Dense prediction   | Native multi-scale     | Needs adaptation (Swin, FPN)     |
| Training stability | Robust                 | Sensitive to LR, warmup          |

ViT model variants

| Model    | Patch | Layers | D    | Heads | Params | ImageNet Top-1 |
|----------|-------|--------|------|-------|--------|----------------|
| ViT-S/16 | 16    | 12     | 384  | 6     | 22M    | 81.4%          |
| ViT-B/16 | 16    | 12     | 768  | 12    | 86M    | 85.2%          |
| ViT-L/16 | 16    | 24     | 1024 | 16    | 307M   | 87.1%          |
| ViT-H/14 | 14    | 32     | 1280 | 16    | 632M   | 88.5%          |

Key failure modes

Low-data regimes: without CNNs' inductive biases, ViT needs 10–100x more data to match equivalent CNN performance. Hybrid architectures (CNN stem + transformer) bridge this gap.

High-resolution inputs: self-attention is $O(N^2)$. A 512×512 image with 16×16 patches gives $N = 1024$ tokens, making attention roughly 27× more expensive than at 224×224. Use Swin or windowed attention for high-resolution tasks.

Position embedding interpolation: when fine-tuning at different resolutions, interpolate the pretrained position embeddings bilinearly — naive transfer degrades performance significantly.
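The interpolation itself is a short reshape-and-resize (a minimal sketch assuming a square grid of patch tokens plus one [CLS] token, as in ViT-B/16; a random tensor stands in for the pretrained embeddings):

```python
import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed, new_grid):
    """Bilinearly resize (1, 1 + g*g, D) position embeddings to a new_grid x new_grid layout."""
    cls_pe, patch_pe = pos_embed[:, :1], pos_embed[:, 1:]   # [CLS] is not spatial: keep as-is
    g = int(patch_pe.shape[1] ** 0.5)                       # old grid side, e.g. 14
    D = patch_pe.shape[-1]
    patch_pe = patch_pe.reshape(1, g, g, D).permute(0, 3, 1, 2)   # (1, D, g, g)
    patch_pe = F.interpolate(patch_pe, size=(new_grid, new_grid),
                             mode="bilinear", align_corners=False)
    patch_pe = patch_pe.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, D)
    return torch.cat([cls_pe, patch_pe], dim=1)

pe = torch.randn(1, 197, 768)                     # 14x14 grid + [CLS] (ViT-B/16 at 224)
pe_384 = interpolate_pos_embed(pe, new_grid=24)   # fine-tuning at 384: 384/16 = 24
print(pe_384.shape)                               # torch.Size([1, 577, 768])
```

Treating the patch embeddings as a 2D image before resizing is what preserves their spatial layout; `timm` applies the same idea automatically when loading a checkpoint at a new resolution.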
