
Attention & Transformers

The Transformer — introduced in 'Attention Is All You Need' (Vaswani et al., 2017) — discarded recurrence entirely in favor of attention: a mechanism that lets every position in a sequence directly attend to every other position. This parallel computation unlocked training on vastly larger datasets and gave rise to Bidirectional Encoder Representations from Transformers (BERT), Generative Pre-trained Transformer (GPT), T5, and all modern large language models.

Theory

[Figure: attention-weight heatmap for the sentence "The movie was not good [EOS]", Head 0 (local syntax: attends to adjacent tokens). Rows are query tokens; columns are key tokens.]

Attention is a soft, differentiable lookup. You have a query — what you're looking for — and a set of keys, each paired with a value. The attention mechanism scores how well each key matches the query, turns those scores into weights that sum to 1, and returns a weighted blend of the values. The heatmap above shows those weights: brighter cells mean "this position paid more attention to that position." Each row is a query; each column is a key.
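The lookup described above can be written in a few lines; a toy sketch with made-up vectors and values:

```python
import torch

# One query against three key/value pairs (toy 2-d keys, scalar values)
q = torch.tensor([1.0, 0.0])
K = torch.tensor([[1.0, 0.0],   # matches the query well
                  [0.0, 1.0],   # orthogonal: poor match
                  [0.5, 0.5]])  # partial match
V = torch.tensor([[10.0], [20.0], [30.0]])

weights = torch.softmax(K @ q, dim=0)   # scores -> distribution summing to 1
out = weights @ V                       # weighted blend of the values
print(weights, out)  # weights roughly [0.51, 0.19, 0.31]; out near 18.0
```

Unlike a hash-table lookup, no single value is returned: the best-matching key gets the largest weight, but every value contributes to the blend.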

Scaled Dot-Product Attention

Attention maps a set of queries Q, keys K, and values V to an output:

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V

where:

  • Q \in \mathbb{R}^{n \times d_k} — queries (one per output position)
  • K \in \mathbb{R}^{m \times d_k} — keys (one per input position)
  • V \in \mathbb{R}^{m \times d_v} — values (one per input position)
  • d_k — key/query dimension (controls dot-product scale)

The \sqrt{d_k} scaling is forced by the geometry of high-dimensional dot products. For zero-mean, unit-variance components, the variance of each entry of QK^\top grows linearly with d_k. Without scaling, softmax receives very large inputs, producing near-one-hot distributions where almost all weight goes to a single key. Gradients through a near-one-hot softmax are near zero — the network can't learn. Dividing by \sqrt{d_k} stabilizes the variance regardless of embedding size.
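The variance claim is easy to check numerically; a minimal sketch with random unit-variance vectors (function name and sample count are illustrative):

```python
import torch

torch.manual_seed(0)

def score_std(d_k: int, n: int = 100_000) -> float:
    """Empirical std of q.k for random N(0, 1) vectors of dimension d_k."""
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    return (q * k).sum(dim=-1).std().item()

# Unscaled dot products grow like sqrt(d_k): std ~8 at d_k=64, ~16 at d_k=256.
print(score_std(64), score_std(256))
# Dividing by sqrt(d_k) brings them back to unit scale.
print(score_std(64) / 64 ** 0.5)  # ~1.0
```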

💡Intuition

Think of attention as a soft dictionary lookup. You have a query (what you're looking for), keys (what each entry is about), and values (what each entry contains). The dot product QK-transpose measures relevance. Softmax converts scores to a probability distribution, and we take a weighted average of values. Unlike a hard lookup, we retrieve a blend of multiple entries. See Attention Pattern Visualization in the Analysis section for an interactive demo of how learned attention weights look in practice.

Why Separate Q, K, and V?

A common question: why not just compute attention directly between the raw token embeddings x_i and x_j? Why project to three separate matrices?

Asymmetry of asking vs. answering. If Q = K = x, every token would attend most strongly to itself (highest cosine similarity), and attention patterns would be dominated by token identity rather than relationships. Separate projections allow a token to ask "what do I need?" (Q) independently of "what do I offer?" (K).

Decoupling retrieval from relevance. The value projection V is independent of K. This means what you retrieve can be entirely different from what you used to decide whether to retrieve it. A token might use its syntactic role (key) to be found, but contribute its semantic content (value) once attended to.

Low-rank compression. Each projection maps d_{model} \to d_k (typically d_k = d_{model}/h). This prevents any single feature dimension from dominating attention scores across all heads — each head operates in its own lower-dimensional subspace, which acts as an inductive bias toward specialization.

In practice: W^Q, W^K, W^V are all learned, and each develops a distinct structure. W^Q and W^K tend to capture relational features (subject-verb agreement, coreference); W^V tends to capture content features that get aggregated.

Multi-Head Attention

A single attention head can only attend to one "subspace" at a time. Multi-head attention runs h attention heads in parallel, each with its own projection matrices:

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O
\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)

where W_i^Q, W_i^K \in \mathbb{R}^{d_{model} \times d_k}, W_i^V \in \mathbb{R}^{d_{model} \times d_v}, and W^O \in \mathbb{R}^{h d_v \times d_{model}}.

Typically d_k = d_v = d_{model} / h, so total computation matches a single large head. With h = 8 and d_{model} = 512, each head operates in 64 dimensions. Different heads learn to attend to different aspects simultaneously: syntax, co-reference, long-range dependencies, etc.

The total parameter count for multi-head attention: 4 d_{model}^2 (four bias-free weight matrices W^Q, W^K, W^V, W^O).
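A quick sanity check against PyTorch's built-in module (biases disabled so the count matches the formula exactly):

```python
import torch.nn as nn

d_model, h = 512, 8
mha = nn.MultiheadAttention(d_model, h, bias=False)

# in_proj_weight packs W^Q, W^K, W^V as (3*d_model, d_model); out_proj is W^O.
n_params = sum(p.numel() for p in mha.parameters())
print(n_params, 4 * d_model ** 2)  # 1048576 1048576
```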

Modern Variants: MQA and GQA

Standard multi-head attention stores one set of K and V matrices per head. At inference time these must be cached for every generated token (see KV Cache in Production). For large models with long contexts this becomes the primary memory bottleneck.

Multi-Query Attention (MQA) (Shazeer, 2019): all h heads share a single K and V projection, but each head retains its own Q:

\text{head}_i = \text{Attention}(Q W_i^Q,\; K W^K,\; V W^V)

KV cache shrinks by a factor of h. Quality drops slightly because keys and values carry no per-head specialization.

Grouped-Query Attention (GQA) (Ainslie et al., 2023): a middle ground — h query heads are divided into G groups, each group sharing one K and V:

\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_{\lceil iG/h \rceil}^K,\; V W_{\lceil iG/h \rceil}^V)

KV cache shrinks by a factor of h/G. With G = h this recovers MHA; with G = 1 this recovers MQA. In practice G = h/4 retains most of MHA's quality while cutting the cache to 25%.

GQA is now the default in most production LLMs: Llama 2 70B, Llama 3, Mistral 7B, Gemma, and Falcon all use GQA. Understanding this is essential for reading modern model cards.

| Variant | KV heads | Cache size | Quality |
|---|---|---|---|
| MHA | h | baseline | best |
| GQA (G = h/4) | h/4 | 25% | near-MHA |
| MQA | 1 | 1/h | slightly lower |
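The grouping itself is only a few lines of tensor bookkeeping; a shape-level sketch (dimensions are illustrative) that expands the G KV heads to match the h query heads:

```python
import torch

B, T, h, G, d_k = 2, 10, 8, 2, 64
q = torch.randn(B, h, T, d_k)   # one query set per head
k = torch.randn(B, G, T, d_k)   # only G key heads
v = torch.randn(B, G, T, d_k)   # only G value heads

# Each block of h // G consecutive query heads shares one KV head.
k_full = k.repeat_interleave(h // G, dim=1)   # (B, h, T, d_k)
v_full = v.repeat_interleave(h // G, dim=1)

# From here it is standard multi-head attention.
attn = torch.softmax(q @ k_full.transpose(-2, -1) / d_k ** 0.5, dim=-1)
out = attn @ v_full
print(out.shape)  # torch.Size([2, 8, 10, 64])
```

Only k and v are cached, so the stored tensors are h/G = 4 times smaller than in MHA while the output shape is unchanged.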

Positional Encoding

Since attention is permutation-invariant, we inject position information by adding sinusoidal encodings to the input embeddings:

PE_{(pos, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \quad PE_{(pos, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)

These allow the model to generalize to sequence lengths not seen during training. Modern models (GPT-2 onwards) typically use learned positional embeddings or RoPE (Rotary Position Embedding), which encodes position via rotation matrices applied to Q and K.
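The sinusoidal formula transcribes directly; a minimal sketch:

```python
import math
import torch

def sinusoidal_pe(max_len: int, d_model: int) -> torch.Tensor:
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)); PE[pos, 2i+1] = cos(...)."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
    # 10000^(-2i/d_model) computed in log space for numerical stability
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                    * (-math.log(10000.0) / d_model))               # (d_model/2,)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

pe = sinusoidal_pe(512, 256)
# Position 0 is all zeros in even dims (sin 0) and all ones in odd dims (cos 0).
```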

Causal Masking

For autoregressive generation (GPT-style), each position must not attend to future tokens. We apply a causal mask before the softmax:

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) V

where M_{ij} = -\infty if j > i (future position), else 0. After softmax, future positions receive weight 0.
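The mask is just a triangular matrix; a minimal sketch:

```python
import torch

T = 5
# 0 on and below the diagonal, -inf strictly above it
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

scores = torch.randn(T, T)
attn = torch.softmax(scores + mask, dim=-1)
# Row i has nonzero weight only for columns j <= i, and each row sums to 1.
print(attn[0])  # tensor([1., 0., 0., 0., 0.])
```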

ℹ️Note

The computational cost of full self-attention is O(n^2 d_k) — quadratic in sequence length. For n = 2048 this is fine; for n = 100{,}000 it becomes prohibitive. FlashAttention (Dao et al., 2022) achieves the same result with O(n) memory by computing attention in tiles, enabling context lengths of 100K+ tokens.
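For comparison, PyTorch 2.x exposes a fused kernel through `torch.nn.functional.scaled_dot_product_attention`, which dispatches to a FlashAttention-style implementation when one is available; a minimal usage sketch:

```python
import torch
import torch.nn.functional as F

B, h, T, d_k = 2, 8, 128, 64
q, k, v = (torch.randn(B, h, T, d_k) for _ in range(3))

# Fused scaled dot-product attention; is_causal=True applies the causal mask
# internally without materializing the full (T, T) score matrix.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```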

Transformer Encoder Block

Each encoder layer applies multi-head self-attention followed by a feedforward network, with residual connections and layer normalization:

x_1 = \text{LayerNorm}(x + \text{MultiHead}(x, x, x))
x_2 = \text{LayerNorm}(x_1 + \text{FFN}(x_1))

The FFN expands to 4 d_{model} in the hidden layer: \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2, with W_1 \in \mathbb{R}^{d_{model} \times 4 d_{model}}.

Walkthrough

Machine Translation: WMT En-De

The Workshop on Machine Translation (WMT) English-German dataset contains approximately 4.5 million sentence pairs.

Tokenization

Modern transformers use subword tokenization rather than word-level tokens, balancing vocabulary size vs. coverage:

python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
 
# Train BPE tokenizer on the training corpus
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(
    vocab_size=32000,        # typical for multilingual models
    special_tokens=["[UNK]", "[PAD]", "[BOS]", "[EOS]"],
    min_frequency=2,
)
files = ["train_en.txt", "train_de.txt"]
tokenizer.train(files, trainer)
tokenizer.save("bpe_tokenizer.json")
 
# Example tokenization:
# "Unbelievable" -> ["Un", "believ", "able"]   (3 tokens)
# "cats"         -> ["cats"]                   (1 token, frequent word)
# "Katze"        -> ["Katz", "e"]              (2 tokens in German BPE)

Building the Mini Transformer

python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int = 512, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model    = d_model
        self.num_heads  = num_heads
        self.d_k        = d_model // num_heads
 
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
 
    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, d_model) -> (B, h, T, d_k)
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
 
    def forward(self, q, k, v, mask=None):
        B = q.size(0)
        Q = self.split_heads(self.W_q(q))  # (B, h, T_q, d_k)
        K = self.split_heads(self.W_k(k))  # (B, h, T_k, d_k)
        V = self.split_heads(self.W_v(v))  # (B, h, T_k, d_k)
 
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn   = self.dropout(F.softmax(scores, dim=-1))
        out    = torch.matmul(attn, V)     # (B, h, T_q, d_k)
 
        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        return self.W_o(out)               # (B, T_q, d_model)
 
 
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn       = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1   = nn.LayerNorm(d_model)
        self.norm2   = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
 
    def forward(self, x, mask=None):
        # Pre-norm formulation (more stable than original post-norm);
        # normalize once and reuse for Q, K, and V
        h = self.norm1(x)
        x = x + self.dropout(self.self_attn(h, h, h, mask))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x
 
 
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model=256, num_heads=4, num_layers=3,
                 d_ff=1024, max_len=512, dropout=0.1):
        super().__init__()
        self.embedding  = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_embed  = nn.Embedding(max_len, d_model)
        self.layers     = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm       = nn.LayerNorm(d_model)
        self.dropout    = nn.Dropout(dropout)
        self.d_model    = d_model
        # Scale embeddings by sqrt(d_model) as in the original paper
        self._init_weights()
 
    def _init_weights(self):
        nn.init.normal_(self.embedding.weight, mean=0, std=self.d_model ** -0.5)
 
    def forward(self, x, mask=None):
        B, T = x.shape
        positions = torch.arange(T, device=x.device).unsqueeze(0)  # (1, T)
        h = self.dropout(
            self.embedding(x) * math.sqrt(self.d_model) + self.pos_embed(positions)
        )
        for layer in self.layers:
            h = layer(h, mask)
        return self.norm(h)

Code Implementation

train.py
python
"""
Transformer Training Pipeline — Sequence Classification / Translation
=====================================================================
Architecture: Transformer Encoder with classification head (or Seq2Seq)
Optimizer:    Adam with warmup LR schedule (as in original paper)
Target:       Demonstrate attention mechanism on a classification task
"""
import argparse
import math
 
import mlflow
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
 
 
# --- Model components defined above (MultiHeadAttention, TransformerEncoderLayer) ---
 
 
class TransformerClassifier(nn.Module):
    """Transformer encoder + mean pooling + classification head."""
    def __init__(self, vocab_size: int, num_classes: int, d_model: int = 256,
                 num_heads: int = 4, num_layers: int = 4, d_ff: int = 1024,
                 max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.encoder = TransformerEncoder(
            vocab_size, d_model, num_heads, num_layers, d_ff, max_len, dropout)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes),
        )
 
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        enc = self.encoder(x, mask)             # (B, T, d_model)
        # Mean pool over non-padding positions
        if mask is not None:
            # mask: (B, T), True where valid
            enc = enc * mask.unsqueeze(-1).float()
            pooled = enc.sum(1) / mask.float().sum(1, keepdim=True)
        else:
            pooled = enc.mean(1)                # (B, d_model)
        return self.classifier(pooled)
 
 
def get_lr(step: int, d_model: int, warmup: int = 4000) -> float:
    """Transformer warmup learning rate schedule."""
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
 
 
def train(cfg: dict) -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hp     = cfg["hyperparams"]
 
    # Data: assume tokenized and batched sequences with labels
    # train_loader, test_loader = get_loaders(hp["batch_size"])
 
    VOCAB_SIZE  = hp.get("vocab_size", 32000)
    NUM_CLASSES = hp.get("num_classes", 2)
    D_MODEL     = hp.get("d_model", 256)
 
    model = TransformerClassifier(
        vocab_size=VOCAB_SIZE, num_classes=NUM_CLASSES,
        d_model=D_MODEL, num_heads=hp["num_heads"],
        num_layers=hp["num_layers"], d_ff=hp["d_ff"],
        dropout=hp["dropout"],
    ).to(device)
 
    # Count parameters
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model parameters: {n_params:,}")
    # d_model=256, 4 layers, 4 heads -> ~3.3M encoder/classifier params
    # plus ~8.2M for the 32k x 256 embedding table
 
    optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # label smoothing from paper
 
    mlflow.set_experiment("transformer-classify")
    with mlflow.start_run():
        mlflow.log_params(hp)
        best_acc = 0.0
        step     = 0
 
        for epoch in range(hp["epochs"]):
            model.train()
            total_loss, correct, total = 0.0, 0, 0
 
            for X, y, mask in train_loader:
                X, y, mask = X.to(device), y.to(device), mask.to(device)
                step += 1
 
                # Warmup learning rate schedule
                lr = get_lr(step, D_MODEL, warmup=hp.get("warmup_steps", 4000))
                for pg in optimizer.param_groups:
                    pg["lr"] = lr
 
                optimizer.zero_grad()
                logits = model(X, mask)
                loss   = criterion(logits, y)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
 
                total_loss += loss.item()
                correct    += (logits.argmax(1) == y).sum().item()
                total      += len(y)
 
            # Evaluate
            model.eval()
            tcorrect, ttotal = 0, 0
            with torch.no_grad():
                for X, y, mask in test_loader:
                    X, y, mask = X.to(device), y.to(device), mask.to(device)
                    tcorrect += (model(X, mask).argmax(1) == y).sum().item()
                    ttotal   += len(y)
            test_acc = tcorrect / ttotal
 
            mlflow.log_metrics({
                "train_loss": total_loss / len(train_loader),
                "train_acc":  correct / total,
                "test_acc":   test_acc,
                "lr":         lr,
            }, step=epoch)
            print(f"Ep {epoch+1:3d} | step={step:5d} | lr={lr:.6f} "
                  f"| train_acc={correct/total:.4f} | test_acc={test_acc:.4f}")
 
            if test_acc > best_acc:
                best_acc = test_acc
                torch.save(model.state_dict(), "best_model.pt")
 
        print(f"Best accuracy: {best_acc:.4f}")
 
 
def main():
    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="../../config.yaml")
    args = parser.parse_args()
    with open(args.config) as f:
        train(yaml.safe_load(f))
 
 
if __name__ == "__main__":
    main()

Analysis & Evaluation

Where Your Intuition Breaks

Misconception: "More attention heads means more capacity — and more capacity always helps." In practice, heads beyond 8–16 rarely improve performance on standard tasks and often hurt on small datasets. The benefit of multiple heads is diversity of attention patterns (each head specializes in different relationships), not quantity. Past the point where all meaningful relationship types are covered, additional heads add noise and parameter cost without adding signal.

Attention Pattern Visualization

Different heads in a trained model specialize in distinct linguistic relationships. Below are four representative patterns from a sentiment classifier trained on IMDB:

  • Head 0 learns local syntax — strong diagonal with adjacency spillover. Tokens attend to their nearest neighbors, capturing phrase structure.
  • Head 1 learns negation — "not" attends with weight 0.62 to "good". The model has discovered that these two tokens form a semantic unit that inverts sentiment.
  • Head 2 learns subject reference — [EOS] attends with weight 0.55 back to "movie", the topic of the sentence. Useful for classification pooling.
  • Head 3 learns global context — flat, spread-out attention. Acts as a mixing layer, propagating information across the full sequence.

This emergent specialization is the key motivation for multi-head attention — a single head can only capture one relationship type at a time.

Transformer vs BiLSTM: IMDB Comparison

| Model | Parameters | Test Acc | Inference (ms/batch) |
|---|---|---|---|
| BiLSTM (256) | 4.2M | 90.3% | 8.2 ms |
| Transformer (256, 4L) | 4.8M | 91.7% | 5.1 ms (parallelizable) |
| BERT-base fine-tuned | 110M | 95.1% | 23.4 ms |
💡Intuition

The Transformer's key advantage over the BiLSTM is not just accuracy — it's that all positions attend to all others in parallel during training. A BiLSTM must process token 1, then token 2, then token 3 sequentially. The Transformer processes all tokens simultaneously, making it dramatically faster to train on modern hardware.

Warmup Learning Rate Schedule

The original Transformer paper uses a specific warmup schedule critical for stable training:

python
# Learning rate over 50,000 steps with d_model=256, warmup=4000
steps = list(range(1, 50001))
lrs   = [get_lr(s, d_model=256, warmup=4000) for s in steps]
 
# Peak LR at step 4000: d_model^-0.5 * warmup^-0.5 = ~0.00099
# Step 1:     ~0.00000025 (very small — prevents early instability)
# Step 4000:  ~0.00099    (peak)
# Step 50000: ~0.00028    (decayed by inverse sqrt)

Without warmup, the model often diverges in the first few hundred steps because the attention weights are random and produce large, noisy gradients.

⚠️Warning

The warmup schedule is not optional when training a Transformer from scratch. Without it, Adam at its default learning rate of 1e-3 typically causes the model to diverge within the first few hundred steps. Alternatively, use a constant small LR (1e-4 or 5e-5) with cosine annealing, which is more commonly used for fine-tuning pre-trained models.

Production-Ready Code

python
"""serve_api/app.py — Transformer classification inference endpoint."""
import math
from pathlib import Path
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from tokenizers import Tokenizer
 
app = FastAPI(title="Transformer Classifier API", version="1.0.0")
 
# Load tokenizer and model at startup
tokenizer = Tokenizer.from_file("bpe_tokenizer.json")
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
tokenizer.enable_truncation(max_length=512)
 
# (Model classes MultiHeadAttention, TransformerEncoderLayer,
#  TransformerEncoder, TransformerClassifier defined above)
 
model = TransformerClassifier(
    vocab_size=32000, num_classes=2, d_model=256,
    num_heads=4, num_layers=4, d_ff=1024,
)
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()
 
LABELS = {0: "negative", 1: "positive"}
 
 
class TextRequest(BaseModel):
    text: str
 
 
@app.post("/predict")
def predict(req: TextRequest):
    enc    = tokenizer.encode(req.text)
    ids    = torch.tensor(enc.ids).unsqueeze(0)
    mask   = torch.tensor(enc.attention_mask, dtype=torch.bool).unsqueeze(0)
 
    with torch.inference_mode():
        logits = model(ids, mask)
        probs  = F.softmax(logits, dim=-1).squeeze()
 
    pred = probs.argmax().item()
    return JSONResponse({
        "prediction":    LABELS[pred],
        "confidence":    round(probs[pred].item(), 4),
        "token_count":   len(enc.ids),
    })

KV Cache

During autoregressive generation, you produce one token at a time. At step t, the new token must attend to all t-1 previous tokens. Without caching, you would recompute K and V for every previous token at every step — O(n^2) total computation to generate n tokens.

KV cache stores the computed keys and values for each past token. At step t, only compute K and V for the new token, then concatenate to the cache:

python
# Pseudocode: autoregressive generation with KV cache
past_k, past_v = [], []   # cache per layer
 
for step in range(max_new_tokens):
    q = compute_query(x_new)          # only new token
    k = compute_key(x_new)
    v = compute_value(x_new)
 
    past_k.append(k)
    past_v.append(v)
 
    K_full = torch.cat(past_k, dim=1)  # all tokens so far
    V_full = torch.cat(past_v, dim=1)
 
    out = attention(q, K_full, V_full) # attend over full context
    x_new = sample_next_token(out)

Cache memory cost. For each layer and each head, you store a (\text{seq\_len} \times d_k) tensor for K and another for V. Total:

\text{Cache size} = 2 \times n_{\text{layers}} \times n_{\text{heads}} \times \text{seq\_len} \times d_k \times \text{bytes\_per\_element}

For Llama 2-7B (32 layers, 32 heads, d_k = 128, fp16) at 4096 tokens: \approx 2\,\text{GB} per request. This is why GQA matters — reducing KV heads from 32 to 8 cuts the cache to 0.5 GB per request, enabling 4× more concurrent users on the same GPU.
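That arithmetic is worth automating when reading model cards; a small helper (function name is hypothetical):

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_k: int,
                   seq_len: int, bytes_per_el: int = 2) -> int:
    """2 (K and V) x layers x KV heads x seq_len x head dim x element size."""
    return 2 * n_layers * n_kv_heads * seq_len * d_k * bytes_per_el

# Llama 2-7B style config: 32 layers, 32 KV heads, d_k=128, fp16, 4096 tokens
mha = kv_cache_bytes(32, 32, 128, 4096)   # full multi-head KV cache
gqa = kv_cache_bytes(32, 8, 128, 4096)    # same model with 8 KV heads (GQA)
print(mha / 2**30, gqa / 2**30)  # 2.0 0.5 (GiB per sequence)
```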

🚀Production

Production deployment checklist for Transformer models:

  • KV cache: enabled by default in vLLM, TGI, and TensorRT-LLM. Use GQA to reduce cache size — for many-user deployments, cache memory is the binding constraint, not compute.
  • Paged attention (vLLM): shares cache pages across requests, dramatically improving GPU utilization for variable-length batches.
  • ONNX export: 2–3× inference speedup over PyTorch eager mode. Export with dynamic axes to handle variable sequence lengths.
  • Hardware-specific optimization: TensorRT (NVIDIA) or OpenVINO (Intel) for maximum throughput on dedicated hardware.
  • Batching and versioning: Triton Inference Server handles dynamic batching, model versioning, and GPU memory management automatically.
