SFT: Supervised Fine-Tuning
Pretraining gives a language model broad knowledge but no sense of what it should produce — a raw pretrained model continues text, not instructions. Supervised Fine-Tuning (SFT) is the first alignment step: fine-tuning on curated (prompt, completion) pairs teaches the model what kind of output to generate. Every modern alignment pipeline — RLHF, DPO, GRPO — begins from an SFT checkpoint. Understanding SFT is the prerequisite for every lesson in this module.
Theory
Pretraining teaches a model to complete any text; SFT teaches it which completions to choose. The diagram above shows the key mechanism: the loss mask blocks gradient updates on prompt tokens (the question is given, not learned), and only the response tokens receive training signal. This is how "predict the next token" becomes "follow the instruction."
The SFT Objective
SFT is standard next-token prediction, but with a loss mask that restricts gradient updates to completion tokens only. Given a dataset of (prompt $x$, completion $y$) pairs, let $s = (x, y)$ be the concatenated token sequence:

$$\mathcal{L}_{\text{SFT}}(\theta) = -\frac{1}{\sum_t m_t} \sum_{t} m_t \log p_\theta(s_t \mid s_{<t})$$

where $m_t \in \{0, 1\}$ is the loss mask: $m_t = 1$ for completion tokens, $m_t = 0$ for prompt tokens.
The loss mask is not optional — it's what distinguishes fine-tuning from pretraining. Without masking, the model is penalized for not predicting the prompt, creating a circular objective: the model would learn to "expect" the prompt token sequence and assign high loss when the prompt changes. Masking concentrates gradient signal entirely on the response, which is the only part the model should be learning to generate.
Think of SFT as studying worked solutions: the problem statement is given (no learning signal there), and the student learns by practicing writing out the answer. Gradient flows only through the answer tokens.
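The masking arithmetic can be sketched in a few lines. The per-token loss values below are made up for illustration; the point is that prompt positions contribute nothing to the average:

```python
# A toy sketch of the masked SFT loss, with hypothetical per-token
# cross-entropy values. Prompt tokens get mask 0, completion tokens mask 1.
token_losses = [2.1, 1.8, 2.4, 0.9, 0.7, 0.5]  # per-token cross-entropy
loss_mask    = [0,   0,   0,   1,   1,   1]    # 0 = prompt, 1 = completion

# Only the masked-in (completion) tokens contribute to the average.
sft_loss = sum(l * m for l, m in zip(token_losses, loss_mask)) / sum(loss_mask)
print(round(sft_loss, 4))  # 0.7
```

Note that the first three (prompt) losses are the largest in this example, so masking them out changes the loss substantially; without the mask, the optimizer would spend most of its effort on tokens the model will never need to generate.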
Data Format: Chat Template
Models use a structured chat template to demarcate speaker turns. The loss mask is applied at the template level: all tokens up to and including the assistant start token are masked.
A typical ChatML format:
<|im_start|>system
You are a helpful assistant.
<|im_end|>
<|im_start|>user
What is 7 * 8?
<|im_end|>
<|im_start|>assistant
56
<|im_end|>
Everything before and including <|im_start|>assistant has $m_t = 0$. The assistant's response and the closing <|im_end|> have $m_t = 1$.
Common mistake: training without the chat template, or using the wrong template for the base model. The model fine-tunes but produces garbled output at inference because the token patterns don't match the template it was trained with.
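In Hugging Face-style trainers, the mask is implemented through the labels tensor: prompt positions are set to -100, PyTorch's CrossEntropyLoss ignore index. A minimal sketch, with made-up token ids and prompt length:

```python
# Sketch of label construction: prompt positions are set to -100
# (PyTorch's CrossEntropyLoss ignore_index), so loss is computed only
# on completion tokens. The token ids and prompt_len here are invented.
input_ids = [101, 5, 9, 42, 7, 13, 102]  # hypothetical token ids
prompt_len = 4  # tokens up to and including <|im_start|>assistant

labels = [-100] * prompt_len + input_ids[prompt_len:]
print(labels)  # [-100, -100, -100, -100, 7, 13, 102]
```

Inspecting a batch of labels like this is the fastest way to confirm your collator is masking the template correctly.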
LoRA: Parameter-Efficient Fine-Tuning
Full fine-tuning updates all parameters, which is impractical at 7B+ scale. Low-Rank Adaptation (LoRA, Hu et al., 2022) freezes the pretrained weight $W \in \mathbb{R}^{d \times k}$ and learns a low-rank residual:

$$W' = W + \frac{\alpha}{r} BA$$

where:
- $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, rank $r \ll \min(d, k)$
- $\alpha$ is a scaling hyperparameter; $\alpha / r$ scales the effective learning-rate contribution of the adapter
- $A$ is initialized from a random Gaussian; $B$ is initialized to zero, so $BA = 0$ at training start
Trainable parameters per weight matrix: $r(d + k)$ vs. $dk$ for full fine-tuning. At $d = k = 4096$, $r = 8$: ~65K vs. ~16.8M, a 256× reduction.
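A quick sanity check of those counts for a single 4096 × 4096 weight matrix at rank 8:

```python
# Checking the parameter counts from the text: one 4096 x 4096 weight
# matrix with LoRA rank r = 8.
d, k, r = 4096, 4096, 8
lora_params = r * (d + k)  # B is d x r, A is r x k
full_params = d * k
print(lora_params)                 # 65536 (~65K)
print(full_params)                 # 16777216 (~16.8M)
print(full_params // lora_params)  # 256x reduction
```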
Gradient Flow in LoRA
During the forward pass, $h = Wx + \frac{\alpha}{r} BAx$ (but $W$ is frozen). During backprop, only $A$ and $B$ receive gradients:

$$\frac{\partial \mathcal{L}}{\partial B} = \frac{\alpha}{r}\, \frac{\partial \mathcal{L}}{\partial h}\, (Ax)^\top, \qquad \frac{\partial \mathcal{L}}{\partial A} = \frac{\alpha}{r}\, B^\top \frac{\partial \mathcal{L}}{\partial h}\, x^\top$$
No optimizer states for the frozen matrices — memory drops proportionally.
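The zero initialization of $B$ has a concrete consequence worth seeing numerically: at the first training step the adapted layer computes exactly what the frozen base layer computes. A tiny sketch with invented 2 × 2 matrices:

```python
# Tiny numeric sketch of the LoRA forward pass at initialization.
# B starts at zero, so W'x = Wx: the adapter is invisible until trained.
def matvec(M, x):
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

W = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight (d=2, k=2)
A = [[0.1, -0.2]]             # rank r=1 adapter, shape r x k
B = [[0.0], [0.0]]            # shape d x r, initialized to zero
alpha, r = 16, 1
x = [1.0, 1.0]

base  = matvec(W, x)             # Wx
delta = matvec(B, matvec(A, x))  # B(Ax)
h = [b + (alpha / r) * d_ for b, d_ in zip(base, delta)]
print(h == base)  # True: BA = 0 at training start
```

This is why LoRA training is stable from step one: the model's behavior is initially unchanged, and the adapter drifts away from zero only as gradients accumulate in $B$.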
Walkthrough
Fine-Tuning for Code Generation
Dataset: 10,000 (instruction, code) pairs formatted in ChatML.
Step 1 — Tokenize with chat template:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")  # use a tokenizer that ships the chat template you train with
chat = [
{"role": "system", "content": "You are a Python expert."},
{"role": "user", "content": "Write a function to reverse a string."},
{"role": "assistant", "content": "def reverse(s):\n return s[::-1]"},
]
tokens = tokenizer.apply_chat_template(
chat, tokenize=True, return_tensors="pt", add_generation_prompt=False
)

Step 2 — Train with TRL (handles masking automatically):
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
lora_cfg = LoraConfig(
r=8, lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
)
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=lora_cfg,
args=SFTConfig(
output_dir="./sft-output",
num_train_epochs=2,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
warmup_ratio=0.05,
lr_scheduler_type="cosine",
),
)
trainer.train()

Step 3 — Merge LoRA weights for deployment:
merged = trainer.model.merge_and_unload()
merged.save_pretrained("./sft-merged")

Expected behavior: training loss starts around 2.0–2.5 and converges to 0.8–1.2 on a clean code dataset after 2 epochs. If loss stays above 1.8, verify the mask is applied; you may be computing loss on prompt tokens.
Analysis & Evaluation
Where Your Intuition Breaks
Misconception: more SFT data always produces a better model, because quantity drives capability. In reality, SFT data quality dominates quantity. A model fine-tuned on 1,000 carefully curated (prompt, response) pairs consistently outperforms one trained on 100,000 scraped completions containing errors, inconsistencies, and off-policy behavior. The model learns the response distribution from the training set: noisy data teaches a noisy distribution. This is why Llama 2's alignment team found that 27,540 high-quality SFT annotations outperformed millions of examples from third-party datasets on downstream tasks.
SFT vs. Prompting
| Scenario | SFT | Prompting |
|---|---|---|
| Consistent output format | Strong | Fragile with few-shot only |
| Teach behavior not in pretraining data | Required | Won't generalize |
| Dataset under 500 examples | Use LoRA | Prefer few-shot |
| Distill long reasoning chains | Chain-of-thought distillation | Context window limits |
| Rapid prototyping | Requires training run | Instant |
| Reduce inference cost (smaller model) | Fine-tuned small model works | Needs larger model |
LoRA Rank Trade-offs
| Rank | Trainable params (7B model) | Typical use case |
|---|---|---|
| r=4 | ~4M | Simple formatting, style |
| r=8 | ~8M | Instruction following (default) |
| r=16 | ~16M | Domain adaptation |
| r=64 | ~64M | Complex new capabilities |
Starting point: r=8, alpha=16 (alpha = 2r). Increase rank if validation loss plateaus early; reduce rank or add dropout if overfitting.
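The table's counts can be roughly reproduced. The sketch below assumes LoRA applied to four attention projections (q, k, v, o) in each of 32 layers of a 7B model with hidden size 4096; the actual count depends on which modules you target (the walkthrough above targets only q_proj and v_proj, which halves these numbers):

```python
# Rough check of the rank table, assuming LoRA on four attention
# projections per layer, 32 layers, hidden size 4096 (so each projection
# is a 4096 x 4096 matrix). Counts vary with the target_modules choice.
layers, modules, d = 32, 4, 4096
for r in (4, 8, 16, 64):
    trainable = layers * modules * r * (d + d)  # r(d + k) per matrix
    print(f"r={r}: ~{trainable / 1e6:.1f}M trainable params")
```

Under these assumptions the loop prints roughly 4M, 8M, 17M, and 67M, in line with the table.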
Common Failure Modes
Prompt bleeding: model echoes user input. Cause: forgot to mask prompt tokens. Fix: verify masking by inspecting labels in the data collator — masked positions should be -100 (PyTorch's ignore index).
Format collapse: model produces the same structural pattern for all inputs. Cause: too many near-identical examples. Fix: diversify dataset format.
Catastrophic forgetting: model loses general capability after fine-tuning. LoRA largely prevents this since base weights are frozen, but full fine-tuning is vulnerable. Fix: use LoRA or add small regularization toward the base checkpoint.
The three SFT settings that matter most in practice:
- Learning rate: 1e-4 to 3e-4 for LoRA; 1e-5 to 5e-5 for full fine-tuning. Too high = loss spikes; too low = no adaptation.
- Epochs: 1–3 is almost always enough. Beyond 3 epochs overfits or degrades general capability.
- Chat template: must match the base model's exact expected format. A mismatched template is a silent failure: training converges but inference outputs are wrong. Always verify with tokenizer.apply_chat_template before training.
When SFT alone is not enough: if the task requires learning from relative preferences — e.g., "be helpful but not verbose, prefer concise answers" — demonstrations don't encode the ranking. That requires DPO or GRPO, covered in the next two lessons.