
SFT: Supervised Fine-Tuning

Pretraining gives a language model broad knowledge but no sense of what it should produce: a raw pretrained model continues text rather than following instructions. Supervised Fine-Tuning (SFT) is the first alignment step: fine-tuning on curated (prompt, completion) pairs teaches the model what kind of output to generate. Most alignment pipelines, including those built on RLHF, DPO, or GRPO, start from an SFT checkpoint, so understanding SFT is a prerequisite for the rest of this module.

Theory

[Figure: SFT token masking, with loss computed only on assistant tokens. System tokens (<|system|> You are helpful) and user tokens (<|user|> What is 7*8?) have $M_t = 0$. The assistant marker <|assistant|> is also masked. The response tokens (56, <|end|>) have $M_t = 1$, and gradient flows only through them.]

Pretraining teaches a model to complete any text; SFT teaches it which completions to choose. The diagram above shows the key mechanism: the loss mask blocks gradient updates on prompt tokens (the question is given, not learned), and only the response tokens receive training signal. This is how "predict the next token" becomes "follow the instruction."

The SFT Objective

SFT is standard next-token prediction, but with a loss mask that restricts gradient updates to completion tokens only. Given a dataset of (prompt $x$, completion $y$) pairs:

$$\mathcal{L}_{\text{SFT}} = -\mathbb{E}_{(x,y)\sim\mathcal{D}} \left[ \sum_{t=1}^{T} M_t \cdot \log \pi_\theta(y_t \mid x, y_{<t}) \right]$$

where $M_t \in \{0, 1\}$ is the loss mask: $M_t = 1$ for completion tokens, $M_t = 0$ for prompt tokens.

The loss mask is what turns plain next-token prediction into instruction following. Without it, part of the gradient is spent on modeling prompt tokens, which the model never has to generate at inference time, so it is pushed toward predicting the user's wording rather than responding to it. Masking concentrates the gradient signal on the response, which is the only part the model should be learning to generate.
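To make the masked objective concrete, here is a minimal sketch for a single sequence in plain Python. The per-token probabilities are made-up illustrative values, and `masked_nll` is a hypothetical helper name, not a library function.

```python
import math

def masked_nll(token_log_probs, mask):
    """Sum of -log p over positions where mask == 1 (one sequence)."""
    if len(token_log_probs) != len(mask):
        raise ValueError("log-probs and mask must have the same length")
    return -sum(lp * m for lp, m in zip(token_log_probs, mask))

# Hypothetical probabilities the model assigns to each correct next token.
probs = [0.10, 0.20, 0.50, 0.90, 0.80]
log_probs = [math.log(p) for p in probs]

# First three positions are prompt tokens, last two are completion tokens.
mask = [0, 0, 0, 1, 1]

loss_masked = masked_nll(log_probs, mask)                # completion only
loss_unmasked = masked_nll(log_probs, [1] * len(mask))   # pretraining-style

print(f"masked: {loss_masked:.4f}  unmasked: {loss_unmasked:.4f}")
```

The masked value depends only on the two completion tokens, while the unmasked value is dominated by the low-probability prompt tokens. Training libraries commonly average over the unmasked positions instead of summing, which changes the scale of the loss but not which tokens contribute.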

💡Intuition

Think of SFT as studying worked solutions: the problem statement is given (no learning signal there), and the student learns by practicing writing out the answer. Gradient flows only through the answer tokens.

Data Format: Chat Template

Models use a structured chat template to demarcate speaker turns. The loss mask is applied at the template level: all tokens up to and including the assistant start token are masked.

A typical ChatML format:

<|im_start|>system
You are a helpful assistant.
<|im_end|>
<|im_start|>user
What is 7 * 8?
<|im_end|>
<|im_start|>assistant
56
<|im_end|>

Everything up to and including <|im_start|>assistant has $M_t = 0$. The assistant's response and the closing <|im_end|> have $M_t = 1$.
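The masking rule can be sketched as a small label-building function. This is an illustration under simplifying assumptions: each ChatML marker is treated as one token, tokens are kept as strings for readability (real pipelines work on integer ids), and `build_labels` is a hypothetical name. The value -100 is the index that PyTorch's cross-entropy ignores by default.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy by default

def build_labels(tokens, assistant_start="<|im_start|>assistant",
                 turn_end="<|im_end|>"):
    """Keep labels only for assistant responses and their closing end marker;
    every other position becomes IGNORE_INDEX."""
    labels = []
    in_assistant = False
    for tok in tokens:
        if tok == assistant_start:
            labels.append(IGNORE_INDEX)   # the start marker itself is masked
            in_assistant = True
        elif in_assistant:
            labels.append(tok)            # response token or closing marker
            if tok == turn_end:
                in_assistant = False      # back to masked until next assistant turn
        else:
            labels.append(IGNORE_INDEX)   # system / user tokens
    return labels

tokens = [
    "<|im_start|>system", "You", "are", "a", "helpful", "assistant.", "<|im_end|>",
    "<|im_start|>user", "What", "is", "7", "*", "8?", "<|im_end|>",
    "<|im_start|>assistant", "56", "<|im_end|>",
]
labels = build_labels(tokens)
print(labels)
```

Because the function toggles on every assistant start marker, the same logic would also cover multi-turn conversations, where each assistant turn is supervised and each user turn is masked.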

Common mistake: training without the chat template, or using the wrong template for the base model. The model fine-tunes but produces garbled output at inference because the token patterns don't match the template it was trained with.

LoRA: Parameter-Efficient Fine-Tuning

Full fine-tuning updates all parameters, which is impractical at 7B+ scale. Low-Rank Adaptation (LoRA, Hu et al., 2022) freezes the pretrained weights $W \in \mathbb{R}^{d \times k}$ and learns a low-rank residual:

$$W' = W + \Delta W = W + \frac{\alpha}{r}\, A B^\top$$

where:

  • $A \in \mathbb{R}^{d \times r}$, $B \in \mathbb{R}^{k \times r}$, rank $r \ll \min(d, k)$
  • $\alpha$ is a scaling hyperparameter; $\alpha/r$ scales the effective learning rate contribution
  • $A$ is initialized from $\mathcal{N}(0, \sigma^2)$; $B$ is initialized to zero so $\Delta W = 0$ at training start

Trainable parameters per weight matrix: $r(d+k)$ vs. $dk$ for full fine-tuning. At $r=8$, $d=k=4096$: ~65K vs. ~16.8M, a 256× reduction.
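The arithmetic behind those figures is easy to reproduce. The helper below is a hypothetical name used only for this check.

```python
def lora_param_counts(d, k, r):
    """Return (LoRA params, full fine-tuning params, reduction factor)
    for one d x k weight matrix adapted at rank r."""
    lora = r * (d + k)   # A is d x r, B is k x r
    full = d * k         # every entry of W is trainable
    return lora, full, full / lora

lora, full, ratio = lora_param_counts(d=4096, k=4096, r=8)
print(f"LoRA: {lora:,}  full: {full:,}  reduction: {ratio:.0f}x")
```

For square matrices the reduction factor simplifies to $d / (2r)$, so doubling the rank halves the saving.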

Gradient Flow in LoRA

During the forward pass, $W' h = Wh + \frac{\alpha}{r} A B^\top h$, with $W$ frozen. During backprop, only $A$ and $B$ receive gradients:

$$\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial W'}\, B\, \frac{\alpha}{r}, \qquad \frac{\partial \mathcal{L}}{\partial B} = \left(\frac{\partial \mathcal{L}}{\partial W'}\right)^\top A\, \frac{\alpha}{r}$$
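These gradient expressions can be checked numerically. The sketch below uses a toy scalar loss $\mathcal{L} = g^\top W' h$ chosen only because its gradient with respect to $W'$ is the simple outer product $g h^\top$. All matrices, vectors, and helper names are hypothetical illustration values, written in plain Python to stay dependency-free.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

def transpose(X):
    return [list(col) for col in zip(*X)]

def scaled(X, c):
    return [[c * v for v in row] for row in X]

def loss(W, A, B, h, g, scale):
    """Toy scalar loss L = g^T (W + scale * A B^T) h."""
    delta = matmul(A, transpose(B))                      # d x k
    return sum(g[i] * (W[i][j] + scale * delta[i][j]) * h[j]
               for i in range(len(W)) for j in range(len(h)))

def numeric_grad(f, M, eps=1e-6):
    """Central finite differences of f() with respect to each entry of M."""
    grad = [[0.0] * len(M[0]) for _ in M]
    for i in range(len(M)):
        for j in range(len(M[0])):
            old = M[i][j]
            M[i][j] = old + eps
            up = f()
            M[i][j] = old - eps
            down = f()
            M[i][j] = old                                # restore the entry
            grad[i][j] = (up - down) / (2 * eps)
    return grad

def max_abs_diff(X, Y):
    return max(abs(x - y) for rx, ry in zip(X, Y) for x, y in zip(rx, ry))

# Hypothetical small shapes: d=3, k=4, r=2.
alpha, r = 16, 2
scale = alpha / r
W = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]
A = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8]]
B = [[1.0, 0.0], [0.5, -0.5], [-1.0, 2.0], [0.3, 0.7]]
h = [1.0, -2.0, 0.5, 3.0]
g = [0.2, -0.4, 1.0]

dL_dWp = [[gi * hj for hj in h] for gi in g]             # g h^T, shape d x k
grad_A = scaled(matmul(dL_dWp, B), scale)                # d x r
grad_B = scaled(matmul(transpose(dL_dWp), A), scale)     # k x r

f = lambda: loss(W, A, B, h, g, scale)
err_A = max_abs_diff(grad_A, numeric_grad(f, A))
err_B = max_abs_diff(grad_B, numeric_grad(f, B))
print(f"max error dL/dA: {err_A:.2e}, dL/dB: {err_B:.2e}")
```

One consequence of the formula for $\partial \mathcal{L} / \partial A$: with $B$ initialized to zero, the gradient on $A$ is exactly zero at the first step, so $B$ moves first and $A$ starts receiving signal only once $B$ is non-zero.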

No optimizer states are kept for the frozen $W$ matrices, so optimizer memory shrinks in proportion to the number of trainable parameters.
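A rough back-of-the-envelope estimate shows the scale of the saving. The sketch assumes an Adam-style optimizer that stores two fp32 values per trainable parameter. The 7B parameter count and the adapter size (rank 8 on four 4096×4096 projections in each of 32 layers) are illustrative assumptions, not measurements.

```python
def adam_state_bytes(n_trainable, bytes_per_value=4):
    """Adam keeps two moment estimates per trainable parameter."""
    return 2 * n_trainable * bytes_per_value

full_ft = adam_state_bytes(7_000_000_000)        # every weight trainable
adapter_params = 8 * (4096 + 4096) * 4 * 32      # assumed LoRA adapter size
lora = adam_state_bytes(adapter_params)

print(f"full fine-tuning: {full_ft / 1e9:.0f} GB of optimizer state")
print(f"LoRA:             {lora / 1e6:.0f} MB of optimizer state")
```

This counts optimizer state only. The frozen weights still have to sit in memory for the forward pass, and activation memory is largely unchanged, so the total saving is smaller than this ratio suggests.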

Walkthrough

Fine-Tuning for Code Generation

Dataset: 10,000 (instruction, code) pairs formatted in ChatML.

Step 1 — Tokenize with chat template:

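python
from transformers import AutoTokenizer

# The tokenizer must ship a chat template. Instruction-tuned checkpoints
# usually do; base checkpoints may not, in which case one has to be set first.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
chat = [
    {"role": "system", "content": "You are a Python expert."},
    {"role": "user", "content": "Write a function to reverse a string."},
    {"role": "assistant", "content": "def reverse(s):\n    return s[::-1]"},
]
# Render the conversation with the model's template and tokenize it.
# add_generation_prompt=False because the assistant turn is already present.
tokens = tokenizer.apply_chat_template(
    chat, tokenize=True, return_tensors="pt", add_generation_prompt=False
)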

Step 2 — Train with TRL (handles masking automatically):

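python
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

# Base model to adapt (same checkpoint as the tokenizer in Step 1).
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")
# Hypothetical path: a JSONL file with one {"messages": [...]} record per line.
dataset = load_dataset("json", data_files="train.jsonl", split="train")

lora_cfg = LoraConfig(
    r=8, lora_alpha=16,                    # adapter scale alpha / r = 2
    target_modules=["q_proj", "v_proj"],   # adapt query and value projections
    lora_dropout=0.05,
)

# Prompt masking behavior depends on the TRL version and dataset format.
# Check your version's SFTConfig options (for example assistant_only_loss)
# and inspect a collated batch to confirm that prompt labels are -100.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=lora_cfg,
    args=SFTConfig(
        output_dir="./sft-output",
        num_train_epochs=2,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,     # effective batch of 16 per device
        learning_rate=2e-4,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
    ),
)
trainer.train()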
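The batch and schedule settings above imply a concrete step count and learning-rate curve. The sketch below works them out for the 10,000-example dataset, assuming a single device and no sequence packing. `lr_at` is a hypothetical helper that mirrors the usual linear-warmup plus cosine-decay shape; it is not TRL's own scheduler code, and rounding the warmup step count up is an assumption.

```python
import math

def lr_at(step, total_steps, warmup_steps, peak_lr):
    """Linear warmup to peak_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

examples, epochs = 10_000, 2
effective_batch = 4 * 4                                   # batch size x accumulation
steps_per_epoch = math.ceil(examples / effective_batch)
total_steps = steps_per_epoch * epochs
warmup_steps = math.ceil(0.05 * total_steps)              # assumed rounding

for step in (0, warmup_steps, total_steps // 2, total_steps):
    lr = lr_at(step, total_steps, warmup_steps, peak_lr=2e-4)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

With these settings the run takes 1,250 optimizer steps, of which roughly the first 63 are warmup. Knowing the step count in advance helps when choosing logging and evaluation intervals.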

Step 3 — Merge LoRA weights for deployment:

python
merged = trainer.model.merge_and_unload()
merged.save_pretrained("./sft-merged")
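Conceptually, merging folds the low-rank update into the dense weight so that inference needs no extra adapter branch. The following plain-Python sketch illustrates that equivalence in the notation of these notes ($W' = W + \frac{\alpha}{r} A B^\top$). It is not PEFT's internal code, and all values and helper names are hypothetical.

```python
def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def merge_lora(W, A, B, alpha, r):
    """Return W' = W + (alpha / r) * A B^T as a new dense matrix."""
    scale = alpha / r
    return [[W[i][j] + scale * sum(A[i][t] * B[j][t] for t in range(r))
             for j in range(len(W[0]))] for i in range(len(W))]

def adapter_forward(W, A, B, h, alpha, r):
    """Unmerged forward pass: W h + (alpha / r) * A (B^T h)."""
    scale = alpha / r
    bt_h = [sum(B[j][t] * h[j] for j in range(len(h))) for t in range(r)]
    low_rank = [sum(A[i][t] * bt_h[t] for t in range(r)) for i in range(len(A))]
    return [b + scale * l for b, l in zip(matvec(W, h), low_rank)]

# Hypothetical toy shapes: d=2, k=3, r=1.
W = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]
A = [[0.5], [-2.0]]
B = [[1.0], [0.0], [3.0]]
h = [2.0, 1.0, -1.0]
alpha, r = 2, 1

merged = merge_lora(W, A, B, alpha, r)
y_merged = matvec(merged, h)                      # single dense matmul
y_adapter = adapter_forward(W, A, B, h, alpha, r) # base plus adapter branch
print(y_merged, y_adapter)
```

Both paths produce the same output, which is why a merged checkpoint can be served like an ordinary dense model. The trade-off is that a merged model can no longer swap adapters at runtime.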

Expected behavior: training loss starts around 2.0–2.5, converges to 0.8–1.2 on a clean code dataset after 2 epochs. If loss stays above 1.8, verify the mask is applied — you may be computing loss on prompt tokens.

Analysis & Evaluation

Where Your Intuition Breaks

A common intuition is that more SFT data always produces a better model, that quantity drives capability. In practice, SFT data quality dominates quantity. A model fine-tuned on 1,000 carefully curated (prompt, response) pairs can outperform one trained on 100,000 scraped completions containing errors and inconsistencies. The model learns the response distribution of its training set, so noisy data teaches a noisy distribution. The Llama 2 team reported the same pattern: they stopped SFT annotation after collecting 27,540 high-quality examples, having found that a small curated set gave better results than millions of examples from third-party datasets.

SFT vs. Prompting

| Scenario | SFT | Prompting |
|---|---|---|
| Consistent output format | Strong | Fragile with few-shot only |
| Teach behavior not in pretraining data | Required | Won't generalize |
| Dataset under 500 examples | Use LoRA | Prefer few-shot |
| Distill long reasoning chains | Chain-of-thought distillation | Context window limits |
| Rapid prototyping | Requires training run | Instant |
| Reduce inference cost (smaller model) | Fine-tuned small model works | Needs larger model |

LoRA Rank Trade-offs

| Rank | Trainable params (7B model) | Typical use case |
|---|---|---|
| r=4 | ~4M | Simple formatting, style |
| r=8 | ~8M | Instruction following (default) |
| r=16 | ~16M | Domain adaptation |
| r=64 | ~64M | Complex new capabilities |

Starting point: r=8, alpha=16 (alpha = 2r). Increase rank if validation loss plateaus early; reduce rank or add dropout if overfitting.
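The parameter counts in the rank table can be approximated from model shape. The sketch assumes a 7B-class decoder with 32 layers and LoRA applied to four 4096×4096 attention projections per layer. Those shapes and the helper names are assumptions for illustration; the table's source does not state which modules it counts.

```python
def lora_trainable_params(r, d=4096, k=4096, n_layers=32, modules_per_layer=4):
    """Total LoRA parameters when each adapted matrix is d x k."""
    return r * (d + k) * n_layers * modules_per_layer

def lora_scale(alpha, r):
    """Multiplier applied to the low-rank update A B^T."""
    return alpha / r

for r in (4, 8, 16, 64):
    total = lora_trainable_params(r)
    print(f"r={r:<2}  params={total / 1e6:5.1f}M  scale(alpha=2r)={lora_scale(2 * r, r)}")

# Same rank, but only q_proj and v_proj adapted as in the walkthrough config.
qv_only = lora_trainable_params(8, modules_per_layer=2)
print(f"r=8 on q_proj/v_proj only: {qv_only / 1e6:.1f}M")
```

Under these assumptions the totals come to about 4.2M, 8.4M, 16.8M, and 67.1M, in line with the table. Adapting only `q_proj` and `v_proj` halves each figure. Keeping alpha = 2r holds the update scale at 2 as the rank changes, which is one reason that convention makes rank sweeps easier to compare.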

Common Failure Modes

Prompt bleeding: model echoes user input. Cause: forgot to mask prompt tokens. Fix: verify masking by inspecting labels in the data collator — masked positions should be -100 (PyTorch's ignore index).

Format collapse: model produces the same structural pattern for all inputs. Cause: too many near-identical examples. Fix: diversify dataset format.

Catastrophic forgetting: model loses general capability after fine-tuning. LoRA reduces the risk because the base weights stay frozen and the update is constrained to low rank, but it does not eliminate it; full fine-tuning is more vulnerable. Fix: use LoRA or add small regularization toward the base checkpoint.
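For the prompt-bleeding check in particular, a quick diagnostic is to measure what fraction of label positions in a collated batch actually contribute to the loss. The sketch below uses a made-up batch of label ids; `supervised_fraction` is a hypothetical helper, and real batches would come from your data collator.

```python
IGNORE_INDEX = -100  # PyTorch's default ignore index for cross-entropy

def supervised_fraction(labels):
    """Fraction of label positions that contribute to the loss."""
    flat = [tok for row in labels for tok in row]
    if not flat:
        raise ValueError("empty batch")
    kept = sum(1 for tok in flat if tok != IGNORE_INDEX)
    return kept / len(flat)

# Hypothetical batch: prompt positions masked, completion ids kept.
batch_labels = [
    [-100, -100, -100, -100, 3510, 2],
    [-100, -100, 318, 257, 1332, 2],
]
frac = supervised_fraction(batch_labels)
print(f"{frac:.0%} of positions are supervised")
```

A fraction at or near 100% on chat data suggests prompt tokens are in the loss. Padding positions are normally set to -100 as well, so the expected value depends on prompt length and padding as much as on the masking itself.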

🚀Production

The three SFT settings that matter most in practice:

  • Learning rate: 1e-4 to 3e-4 for LoRA; 1e-5 to 5e-5 for full fine-tuning. Too high = loss spikes; too low = no adaptation.
  • Epochs: 1–3 is almost always enough. Training beyond 3 epochs tends to overfit or degrade general capability.
  • Chat template: must match the base model's exact expected format. A mismatched template is a silent failure — training converges but inference outputs are wrong. Always verify with tokenizer.apply_chat_template before training.

When SFT alone is not enough: if the task requires learning from relative preferences — e.g., "be helpful but not verbose, prefer concise answers" — demonstrations don't encode the ranking. That requires DPO or GRPO, covered in the next two lessons.
