Attention & Transformers
The Transformer — introduced in 'Attention Is All You Need' (Vaswani et al., 2017) — discarded recurrence entirely in favor of attention: a mechanism that lets every position in a sequence directly attend to every other position. This parallel computation unlocked training on vastly larger datasets and gave rise to Bidirectional Encoder Representations from Transformers (BERT), Generative Pre-trained Transformer (GPT), T5, and all modern large language models.
Theory
Attention is a soft, differentiable lookup. You have a query — what you're looking for — and a set of keys, each paired with a value. The attention mechanism scores how well each key matches the query, turns those scores into weights that sum to 1, and returns a weighted blend of the values. The heatmap above shows those weights: brighter cells mean "this position paid more attention to that position." Each row is a query; each column is a key.
Scaled Dot-Product Attention
Attention maps a set of queries $Q$, keys $K$, and values $V$ to an output:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where:
- $Q \in \mathbb{R}^{n \times d_k}$ — queries (one per output position)
- $K \in \mathbb{R}^{m \times d_k}$ — keys (one per input position)
- $V \in \mathbb{R}^{m \times d_v}$ — values (one per input position)
- $d_k$ — key/query dimension (controls dot-product scale)
The $1/\sqrt{d_k}$ scaling prevents dot products from growing too large as $d_k$ increases, which would push the softmax into near-zero-gradient regions.
The scaling is forced by the geometry of high-dimensional dot products. For unit-variance components, the variance of $q \cdot k$ grows linearly with $d_k$. Without scaling, softmax receives very large inputs, producing near-one-hot distributions where almost all weight goes to a single key. Gradients through a near-one-hot softmax are near zero — the network can't learn. Dividing by $\sqrt{d_k}$ stabilizes the variance regardless of embedding size.
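A quick empirical sketch of this claim, assuming unit-variance components as in standard random initialization:

```python
import torch

torch.manual_seed(0)

# Var(q·k) grows linearly with d_k for unit-variance components;
# dividing by sqrt(d_k) keeps the score variance near 1 at any size.
for d_k in (16, 64, 256, 1024):
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)
    print(f"d_k={d_k:5d}  var(q.k)={dots.var():8.1f}  "
          f"var(scaled)={(dots / d_k ** 0.5).var():.3f}")
```

The unscaled variance tracks $d_k$ almost exactly, while the scaled variance stays near 1 — exactly the stabilization the division buys.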
Think of attention as a soft dictionary lookup. You have a query (what you're looking for), keys (what each entry is about), and values (what each entry contains). The dot product $QK^\top$ measures relevance. Softmax converts scores to a probability distribution, and we take a weighted average of values. Unlike a hard lookup, we retrieve a blend of multiple entries. See Attention Pattern Visualization in the Analysis section for an interactive demo of how learned attention weights look in practice.
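The lookup analogy can be made concrete in a few lines. This is a minimal standalone sketch; the tensor values are toy examples chosen for illustration:

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """Soft lookup: score each key against each query, softmax, blend values."""
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))  # relevance
    weights = F.softmax(scores, dim=-1)                        # rows sum to 1
    return weights @ V, weights

# The query closely matches key 0, so the output leans toward value 0 —
# but unlike a hard dict lookup, values 1 and 2 still contribute a little.
K = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
V = torch.tensor([[10.0], [20.0], [30.0]])
Q = torch.tensor([[2.0, 0.0]])

out, w = scaled_dot_product_attention(Q, K, V)
print(w)    # largest weight on key 0
print(out)  # a blend of all three values, closest to 10.0
```

The output is not exactly 10.0 — it is pulled toward the other values in proportion to their (small) softmax weights, which is the "soft" in soft lookup.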
Why Separate Q, K, and V?
A common question: why not just compute attention directly between the raw token embeddings, setting $Q = K = V = X$? Why project to three separate matrices?
Asymmetry of asking vs. answering. If $Q = K = X$, every token would attend most strongly to itself (highest cosine similarity), and attention patterns would be dominated by token identity rather than relationships. Separate projections allow a token to ask "what do I need?" (Q) independently of "what do I offer?" (K).
Decoupling retrieval from relevance. The value projection $W^V$ is independent of $W^Q$ and $W^K$. This means what you retrieve can be entirely different from what you used to decide whether to retrieve it. A token might use its syntactic role (key) to be found, but contribute its semantic content (value) once attended to.
Low-rank compression. Each projection maps $\mathbb{R}^{d_{\text{model}}} \to \mathbb{R}^{d_k}$ (typically $d_k = d_{\text{model}}/h \ll d_{\text{model}}$). This prevents any single feature dimension from dominating attention scores across all heads — each head operates in its own lower-dimensional subspace, which acts as an inductive bias toward specialization.
In practice: $W^Q$, $W^K$, $W^V$ are all learned, and each develops a distinct structure. $W^Q$ and $W^K$ tend to capture relational features (subject-verb agreement, coreference); $W^V$ tends to capture content features that get aggregated.
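A small experiment illustrates the first point. Here random projections stand in for learned ones (an assumption for illustration): with $Q = K = X$, every row of the attention matrix is dominated by its own token, while separate projections flatten the pattern:

```python
import math
import torch

torch.manual_seed(0)
d = 64
X = torch.randn(8, d)  # 8 "token embeddings", stand-ins for a real sequence

# No projections: Q = K = X. Self-scores x.x = ||x||^2 dominate every row.
w_raw = torch.softmax(X @ X.T / math.sqrt(d), dim=-1)

# Random projections standing in for learned W^Q, W^K: the diagonal loses
# its privileged status, so attention can reflect relationships instead.
Wq = torch.randn(d, d) / math.sqrt(d)
Wk = torch.randn(d, d) / math.sqrt(d)
w_proj = torch.softmax((X @ Wq) @ (X @ Wk).T / math.sqrt(d), dim=-1)

print("mean self-weight, Q = K = X:", round(w_raw.diagonal().mean().item(), 3))
print("mean self-weight, projected:", round(w_proj.diagonal().mean().item(), 3))
```

With raw embeddings the average self-weight is close to 1 (near-diagonal attention); with projections it falls toward the uniform value $1/8$.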
Multi-Head Attention
A single attention head can only attend to one "subspace" at a time. Multi-head attention runs $h$ attention heads in parallel, each with its own projection matrices:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)W^O$$

where $\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$, $W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, and $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$.
Typically $d_k = d_v = d_{\text{model}}/h$, so total computation matches a single large head. With $d_{\text{model}} = 512$ and $h = 8$, each head operates in 64 dimensions. Different heads learn to attend to different aspects simultaneously: syntax, coreference, long-range dependencies, etc.
The total parameter count for multi-head attention is $4d_{\text{model}}^2$ (four weight matrices $W^Q$, $W^K$, $W^V$, $W^O$, each $d_{\text{model}} \times d_{\text{model}}$).
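A quick sanity check of this count, assuming bias-free projections as in the walkthrough implementation:

```python
import torch.nn as nn

d_model = 512

# Four bias-free projections: W^Q, W^K, W^V, W^O, each d_model x d_model
projections = [nn.Linear(d_model, d_model, bias=False) for _ in range(4)]
n_params = sum(p.numel() for proj in projections for p in proj.parameters())

print(n_params)          # 1048576
print(4 * d_model ** 2)  # 1048576 — matches the formula
```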
Modern Variants: MQA and GQA
Standard multi-head attention stores one set of K and V matrices per head. At inference time these must be cached for every generated token (see KV Cache in Production). For large models with long contexts this becomes the primary memory bottleneck.
Multi-Query Attention (MQA) (Shazeer, 2019): all heads share a single $K$ and $V$ projection, but each head retains its own $Q$:

$$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW^K, VW^V)$$

The KV cache shrinks by a factor of $h$. Quality drops slightly because keys and values carry no per-head specialization.
Grouped-Query Attention (GQA) (Ainslie et al., 2023): a middle ground — the $h$ query heads are divided into $g$ groups, each group sharing one K and V projection:

$$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_{g(i)}^K, VW_{g(i)}^V)$$

The KV cache shrinks by a factor of $h/g$. With $g = h$ this recovers MHA; with $g = 1$ this recovers MQA. In practice $g = h/4$ retains most of MHA's quality while cutting the cache to 25%.
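The grouping can be sketched directly in PyTorch — an illustrative toy, not a production kernel. The shape values are arbitrary, and `F.scaled_dot_product_attention` assumes PyTorch 2.x:

```python
import torch
import torch.nn.functional as F

# GQA sketch: h query heads share h_kv key/value heads. Only the h_kv
# K and V heads need to be stored in the cache.
B, T, h, h_kv, d_k = 2, 10, 8, 2, 32

q = torch.randn(B, h, T, d_k)     # one Q per query head
k = torch.randn(B, h_kv, T, d_k)  # only h_kv K heads are computed/cached
v = torch.randn(B, h_kv, T, d_k)

# Expand KV heads so each group of h // h_kv query heads sees the same K, V
k_exp = k.repeat_interleave(h // h_kv, dim=1)  # (B, h, T, d_k)
v_exp = v.repeat_interleave(h // h_kv, dim=1)

out = F.scaled_dot_product_attention(q, k_exp, v_exp)
print(out.shape)  # torch.Size([2, 8, 10, 32])
```

Here the per-layer KV cache holds `h_kv / h = 1/4` of the MHA tensors, while the attention output keeps the full `h`-head shape.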
GQA is now the default in most production LLMs: Llama 2 70B, Llama 3, Mistral 7B, Gemma, and Falcon all use GQA. Understanding this is essential for reading modern model cards.
| Variant | KV heads | Cache size | Quality |
|---|---|---|---|
| MHA | $h$ | baseline | best |
| GQA ($g = h/4$) | $h/4$ | 25% | near-MHA |
| MQA | 1 | $1/h$ | slightly lower |
Positional Encoding
Since attention is permutation-invariant, we inject position information by adding sinusoidal encodings to the input embeddings:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$
These allow the model, in principle, to generalize to sequence lengths not seen during training. Modern models typically use learned positional embeddings (GPT-2 onwards) or RoPE (Rotary Position Embedding), which encodes position via rotation matrices applied to Q and K.
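A minimal sketch of the sinusoidal encodings (the walkthrough model later uses learned position embeddings instead):

```python
import math
import torch

def sinusoidal_positions(max_len: int, d_model: int) -> torch.Tensor:
    """Sin/cos positional encodings as in the original Transformer paper."""
    pos = torch.arange(max_len).unsqueeze(1).float()   # (max_len, 1)
    div = torch.exp(torch.arange(0, d_model, 2).float()
                    * (-math.log(10000.0) / d_model))  # 1 / 10000^(2i/d)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)  # even dimensions
    pe[:, 1::2] = torch.cos(pos * div)  # odd dimensions
    return pe

pe = sinusoidal_positions(512, 256)
print(pe.shape)   # torch.Size([512, 256])
print(pe[0, :4])  # position 0: sin(0)=0, cos(0)=1 alternating -> [0., 1., 0., 1.]
```

Each dimension pair oscillates at a different wavelength, so every position gets a unique fingerprint and relative offsets correspond to fixed linear transformations.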
Causal Masking
For autoregressive generation (GPT-style), each position must not attend to future tokens. We apply a causal mask $M$ before the softmax:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$

where $M_{ij} = -\infty$ if $j > i$ (future position), else $M_{ij} = 0$. After softmax, future positions receive weight 0.
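A standalone sketch of the mask, built with `tril`; the same boolean tensor can be passed as `mask` to the `MultiHeadAttention` class in the walkthrough, whose `masked_fill(mask == 0, -inf)` treats zeros as blocked:

```python
import torch

T = 5
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))  # True where j <= i
scores = torch.randn(T, T)                               # stand-in QK^T / sqrt(d_k)
masked = scores.masked_fill(~causal, float("-inf"))      # -inf above the diagonal
weights = torch.softmax(masked, dim=-1)

print(weights[0])  # position 0 sees only itself: [1., 0., 0., 0., 0.]
print((weights.triu(diagonal=1) == 0).all())  # no weight on future tokens
```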
The computational cost of full self-attention is $O(T^2 d)$ for sequence length $T$ — quadratic in sequence length. For sequences of a few hundred tokens this is fine; at tens of thousands of tokens it becomes prohibitive. FlashAttention (Dao et al., 2022) computes the same result in tiles without ever materializing the $T \times T$ score matrix, reducing extra memory from $O(T^2)$ to $O(T)$ and enabling context lengths of 100K+ tokens.
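PyTorch 2.x exposes these fused kernels (FlashAttention on supported GPUs, with a standard fallback on CPU, so this sketch runs anywhere) through `F.scaled_dot_product_attention`; scaling by $1/\sqrt{d_k}$ and causal masking are built in:

```python
import torch
import torch.nn.functional as F

B, h, T, d_k = 1, 8, 1024, 64
q = torch.randn(B, h, T, d_k)
k = torch.randn(B, h, T, d_k)
v = torch.randn(B, h, T, d_k)

# is_causal=True applies the lower-triangular mask inside the kernel
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 1024, 64])
```

As a sanity check, position 0 can only attend to itself under the causal mask, so its output equals `v` at position 0.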
Transformer Encoder Block
Each encoder layer applies multi-head self-attention followed by a feedforward network, with residual connections and layer normalization:

$$\begin{aligned}
z &= \mathrm{LayerNorm}(x + \mathrm{MultiHead}(x, x, x)) \\
\mathrm{out} &= \mathrm{LayerNorm}(z + \mathrm{FFN}(z))
\end{aligned}$$

The FFN expands to $d_{ff}$ in the hidden layer: $\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$, with $d_{ff} = 4\,d_{\text{model}}$ (2048 for $d_{\text{model}} = 512$).
Walkthrough
Machine Translation: WMT En-De
The Workshop on Machine Translation (WMT) English-German dataset contains approximately 4.5 million sentence pairs.
- Training: 4.5M sentence pairs
- Validation: newstest2013 (3,000 sentences)
- Test: newstest2014 (3,003 sentences)
- Baseline (phrase-based statistical machine translation, SMT): ~20 BLEU (Bilingual Evaluation Understudy)
- Original Transformer (base): 27.3 BLEU
- Our mini Transformer: ~22 BLEU (trained on 100K pairs)
Tokenization
Modern transformers use subword tokenization rather than word-level tokens, balancing vocabulary size vs. coverage:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
# Train BPE tokenizer on the training corpus
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(
vocab_size=32000, # typical for multilingual models
special_tokens=["[UNK]", "[PAD]", "[BOS]", "[EOS]"],
min_frequency=2,
)
files = ["train_en.txt", "train_de.txt"]
tokenizer.train(files, trainer)
tokenizer.save("bpe_tokenizer.json")
# Example tokenization:
# "Unbelievable" -> ["Un", "believ", "able"] (3 tokens)
# "cats" -> ["cats"] (1 token, frequent word)
# "Katze" -> ["Katz", "e"] (2 tokens in German BPE)

Building the Mini Transformer
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int = 512, num_heads: int = 8, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, T, d_model) -> (B, h, T, d_k)
B, T, _ = x.shape
return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
def forward(self, q, k, v, mask=None):
B = q.size(0)
Q = self.split_heads(self.W_q(q)) # (B, h, T_q, d_k)
K = self.split_heads(self.W_k(k)) # (B, h, T_k, d_k)
V = self.split_heads(self.W_v(v)) # (B, h, T_k, d_k)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
attn = self.dropout(F.softmax(scores, dim=-1))
out = torch.matmul(attn, V) # (B, h, T_q, d_k)
# Concatenate heads and project
out = out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
return self.W_o(out) # (B, T_q, d_model)
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Pre-norm formulation (more stable than original post-norm)
x = x + self.dropout(self.self_attn(self.norm1(x), self.norm1(x), self.norm1(x), mask))
x = x + self.dropout(self.ffn(self.norm2(x)))
return x
class TransformerEncoder(nn.Module):
def __init__(self, vocab_size, d_model=256, num_heads=4, num_layers=3,
d_ff=1024, max_len=512, dropout=0.1):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
self.pos_embed = nn.Embedding(max_len, d_model)
self.layers = nn.ModuleList([
TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.d_model = d_model
# Scale embeddings by sqrt(d_model) as in the original paper
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.embedding.weight, mean=0, std=self.d_model ** -0.5)
def forward(self, x, mask=None):
B, T = x.shape
positions = torch.arange(T, device=x.device).unsqueeze(0) # (1, T)
h = self.dropout(
self.embedding(x) * math.sqrt(self.d_model) + self.pos_embed(positions)
)
for layer in self.layers:
h = layer(h, mask)
        return self.norm(h)

Code Implementation
train.py

"""
Transformer Training Pipeline — Sequence Classification / Translation
=====================================================================
Architecture: Transformer Encoder with classification head (or Seq2Seq)
Optimizer: Adam with warmup LR schedule (as in original paper)
Target: Demonstrate attention mechanism on a classification task
"""
import argparse
import math
import mlflow
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
# --- Model components defined above (MultiHeadAttention, TransformerEncoderLayer) ---
class TransformerClassifier(nn.Module):
"""Transformer encoder + mean pooling + classification head."""
def __init__(self, vocab_size: int, num_classes: int, d_model: int = 256,
num_heads: int = 4, num_layers: int = 4, d_ff: int = 1024,
max_len: int = 512, dropout: float = 0.1):
super().__init__()
self.encoder = TransformerEncoder(
vocab_size, d_model, num_heads, num_layers, d_ff, max_len, dropout)
self.classifier = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_model // 2, num_classes),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
enc = self.encoder(x, mask) # (B, T, d_model)
# Mean pool over non-padding positions
if mask is not None:
# mask: (B, T), True where valid
enc = enc * mask.unsqueeze(-1).float()
pooled = enc.sum(1) / mask.float().sum(1, keepdim=True)
else:
pooled = enc.mean(1) # (B, d_model)
return self.classifier(pooled)
def get_lr(step: int, d_model: int, warmup: int = 4000) -> float:
"""Transformer warmup learning rate schedule."""
step = max(step, 1)
return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
def train(cfg: dict) -> None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hp = cfg["hyperparams"]
# Data: assume tokenized and batched sequences with labels
# train_loader, test_loader = get_loaders(hp["batch_size"])
VOCAB_SIZE = hp.get("vocab_size", 32000)
NUM_CLASSES = hp.get("num_classes", 2)
D_MODEL = hp.get("d_model", 256)
model = TransformerClassifier(
vocab_size=VOCAB_SIZE, num_classes=NUM_CLASSES,
d_model=D_MODEL, num_heads=hp["num_heads"],
num_layers=hp["num_layers"], d_ff=hp["d_ff"],
dropout=hp["dropout"],
).to(device)
# Count parameters
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {n_params:,}")
    # d_model=256, 4 layers, 4 heads, 32K vocab -> ~11.5M params (embeddings dominate)
optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1) # label smoothing from paper
mlflow.set_experiment("transformer-classify")
with mlflow.start_run():
mlflow.log_params(hp)
best_acc = 0.0
step = 0
for epoch in range(hp["epochs"]):
model.train()
total_loss, correct, total = 0.0, 0, 0
for X, y, mask in train_loader:
X, y, mask = X.to(device), y.to(device), mask.to(device)
step += 1
# Warmup learning rate schedule
lr = get_lr(step, D_MODEL, warmup=hp.get("warmup_steps", 4000))
for pg in optimizer.param_groups:
pg["lr"] = lr
optimizer.zero_grad()
logits = model(X, mask)
loss = criterion(logits, y)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
correct += (logits.argmax(1) == y).sum().item()
total += len(y)
# Evaluate
model.eval()
tcorrect, ttotal = 0, 0
with torch.no_grad():
for X, y, mask in test_loader:
X, y, mask = X.to(device), y.to(device), mask.to(device)
tcorrect += (model(X, mask).argmax(1) == y).sum().item()
ttotal += len(y)
test_acc = tcorrect / ttotal
mlflow.log_metrics({
"train_loss": total_loss / len(train_loader),
"train_acc": correct / total,
"test_acc": test_acc,
"lr": lr,
}, step=epoch)
print(f"Ep {epoch+1:3d} | step={step:5d} | lr={lr:.6f} "
f"| train_acc={correct/total:.4f} | test_acc={test_acc:.4f}")
if test_acc > best_acc:
best_acc = test_acc
torch.save(model.state_dict(), "best_model.pt")
print(f"Best accuracy: {best_acc:.4f}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", default="../../config.yaml")
import yaml
train(yaml.safe_load(open(parser.parse_args().config)))
if __name__ == "__main__":
    main()

Analysis & Evaluation
Where Your Intuition Breaks
"More attention heads means more capacity, and more capacity always helps." In practice, heads beyond 8–16 rarely improve performance on standard tasks and often hurt on small datasets. The benefit of multiple heads is diversity of attention patterns (each head specializes in different relationships), not quantity. Past the point where all meaningful relationship types are covered, additional heads add noise and parameter cost without adding signal.
Attention Pattern Visualization
Different heads in a trained model specialize in distinct linguistic relationships. The interactive diagram below shows four representative patterns from a sentiment classifier trained on IMDB — click each head to explore:
- Head 0 learns local syntax — strong diagonal with adjacency spillover. Tokens attend to their nearest neighbors, capturing phrase structure.
- Head 1 learns negation — "not" attends with weight 0.62 to "good". The model has discovered that these two tokens form a semantic unit that inverts sentiment.
- Head 2 learns subject reference — [EOS] attends with weight 0.55 back to "movie", the topic of the sentence. Useful for classification pooling.
- Head 3 learns global context — flat, spread-out attention. Acts as a mixing layer, propagating information across the full sequence.
This emergent specialization is the key motivation for multi-head attention — a single head can only capture one relationship type at a time.
Transformer vs BiLSTM: IMDB Comparison
| Model | Parameters | Test Acc | Inference (ms/batch) |
|---|---|---|---|
| BiLSTM (256) | 4.2M | 90.3% | 8.2ms |
| Transformer (256, 4L) | 4.8M | 91.7% | 5.1ms (parallelizable) |
| BERT-base fine-tuned | 110M | 95.1% | 23.4ms |
The Transformer's key advantage over the BiLSTM is not just accuracy — it's that all positions attend to all others in parallel during training. A BiLSTM must process token 1, then token 2, then token 3 sequentially. The Transformer processes all tokens simultaneously, making it dramatically faster to train on modern hardware.
Warmup Learning Rate Schedule
The original Transformer paper uses a specific warmup schedule critical for stable training:
# Learning rate over 50,000 steps with d_model=256, warmup=4000
steps = list(range(1, 50001))
lrs = [get_lr(s, d_model=256, warmup=4000) for s in steps]
# Peak LR at step 4000: ~0.00099 (= 256**-0.5 * 4000**-0.5)
# Step 1:     0.00000025 (very small — prevents early instability)
# Step 4000:  0.00098821 (peak)
# Step 50000: 0.00027951 (decayed by inverse sqrt)

Without warmup, the model often diverges in the first few hundred steps because the attention weights are random and produce large, noisy gradients.
The warmup schedule is effectively mandatory when training a Transformer from scratch: without it, Adam at its default learning rate of 1e-3 typically diverges within the first few hundred steps. For fine-tuning pre-trained models, a small constant LR (1e-4 or 5e-5) with cosine annealing is the more common choice.
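The fine-tuning alternative can be sketched with PyTorch's built-in `CosineAnnealingLR`; the `nn.Linear` here is just a stand-in model to drive the scheduler:

```python
import torch.nn as nn
import torch.optim as optim

# Small constant base LR decayed by cosine annealing over 10k steps
model = nn.Linear(10, 2)  # stand-in model
opt = optim.Adam(model.parameters(), lr=5e-5)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10_000, eta_min=1e-6)

lrs = []
for _ in range(10_000):
    opt.step()  # (real forward/backward would go here)
    sched.step()
    lrs.append(sched.get_last_lr()[0])

print(f"start: {lrs[0]:.2e}")  # ~5e-5
print(f"end:   {lrs[-1]:.2e}")  # decays to eta_min = 1e-6
```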
Production-Ready Code
"""serve_api/app.py — Transformer classification inference endpoint."""
import math
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from tokenizers import Tokenizer
app = FastAPI(title="Transformer Classifier API", version="1.0.0")
# Load tokenizer and model at startup
tokenizer = Tokenizer.from_file("bpe_tokenizer.json")
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
tokenizer.enable_truncation(max_length=512)
# (Model classes MultiHeadAttention, TransformerEncoderLayer,
# TransformerEncoder, TransformerClassifier defined above)
model = TransformerClassifier(
vocab_size=32000, num_classes=2, d_model=256,
num_heads=4, num_layers=4, d_ff=1024,
)
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()
LABELS = {0: "negative", 1: "positive"}
class TextRequest(BaseModel):
text: str
@app.post("/predict")
def predict(req: TextRequest):
enc = tokenizer.encode(req.text)
ids = torch.tensor(enc.ids).unsqueeze(0)
mask = torch.tensor(enc.attention_mask, dtype=torch.bool).unsqueeze(0)
with torch.inference_mode():
logits = model(ids, mask)
probs = F.softmax(logits, dim=-1).squeeze()
pred = probs.argmax().item()
return JSONResponse({
"prediction": LABELS[pred],
"confidence": round(probs[pred].item(), 4),
"token_count": len(enc.ids),
    })

KV Cache
During autoregressive generation, you produce one token at a time. At step $t$, the new token must attend to all $t-1$ previous tokens. Without caching, you would recompute $K$ and $V$ for every previous token at every step — $O(T^2)$ redundant key/value computations to generate $T$ tokens.
The KV cache stores the computed keys and values for each past token. At step $t$, you compute $k_t$ and $v_t$ only for the new token, then concatenate them to the cache:
# Pseudocode: autoregressive generation with KV cache
past_k, past_v = [], [] # cache per layer
for step in range(max_new_tokens):
q = compute_query(x_new) # only new token
k = compute_key(x_new)
v = compute_value(x_new)
past_k.append(k)
past_v.append(v)
K_full = torch.cat(past_k, dim=1) # all tokens so far
V_full = torch.cat(past_v, dim=1)
out = attention(q, K_full, V_full) # attend over full context
    x_new = sample_next_token(out)

Cache memory cost. For each layer and each head, you store a $(T, d_k)$ tensor for K and another for V. Total:

$$\text{cache size} = 2 \times n_{\text{layers}} \times n_{\text{heads}} \times d_k \times T \times \text{bytes per element}$$
For Llama 2-7B (32 layers, 32 heads, $d_k = 128$, fp16) at 4096 tokens: $2 \times 32 \times 32 \times 128 \times 4096 \times 2$ bytes $= 2$ GiB per request. This is why GQA matters — reducing KV heads from 32 to 8 cuts the cache to 512 MiB per request, enabling 4× more concurrent users on the same GPU.
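The arithmetic can be checked directly; this sketch just plugs the Llama 2-7B-style shape values above into the cache-size formula:

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_k: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """Per-request cache: a K and a V tensor per layer, KV head, and token."""
    return 2 * n_layers * n_kv_heads * d_k * seq_len * bytes_per_elem

# 32 layers, d_k = 128, fp16 (2 bytes), 4096-token context
full_mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_k=128, seq_len=4096)
gqa_8 = kv_cache_bytes(n_layers=32, n_kv_heads=8, d_k=128, seq_len=4096)

print(f"MHA (32 KV heads): {full_mha / 2**30:.2f} GiB")  # 2.00 GiB
print(f"GQA (8 KV heads):  {gqa_8 / 2**30:.2f} GiB")     # 0.50 GiB
```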
Production deployment checklist for Transformer models:
- KV cache: enabled by default in vLLM, TGI, and TensorRT-LLM. Use GQA to reduce cache size — for many-user deployments, cache memory is the binding constraint, not compute.
- Paged attention (vLLM): shares cache pages across requests, dramatically improving GPU utilization for variable-length batches.
- ONNX export: 2–3× inference speedup over PyTorch eager mode. Export with dynamic axes to handle variable sequence lengths.
- Hardware-specific optimization: TensorRT (NVIDIA) or OpenVINO (Intel) for maximum throughput on dedicated hardware.
- Batching and versioning: Triton Inference Server handles dynamic batching, model versioning, and GPU memory management automatically.