
KL Divergence, f-Divergences & Total Variation Distance

KL divergence quantifies how much one probability distribution differs from another and is the central object in variational inference, policy optimization, and training objectives. Its asymmetry gives rise to a family of divergence measures — f-divergences — each with distinct geometric and statistical properties.

Concepts

KL divergence is asymmetric: KL(P‖Q) diverges when Q has zero probability where P doesn't. JS divergence is symmetric and bounded by log 2. Adjust the distributions to see the asymmetry.

[Interactive demo: sliders for $P = \mathcal{N}(\mu_1, \sigma_1^2)$ and $Q = \mathcal{N}(\mu_2, \sigma_2^2)$ with an overlap plot. Example readout: KL(P‖Q) = 0.628, KL(Q‖P) = 1.325, TV(P,Q) = 0.472, JS(P,Q) = 0.169.]

Set σ₂ small while shifting μ₂ away from μ₁: KL(P‖Q) explodes but KL(Q‖P) stays bounded. JS is always ≤ log 2 ≈ 0.693.

Training a language model by minimizing cross-entropy loss is exactly minimizing the KL divergence from the true data distribution to the model's output distribution. KL divergence is the workhorse quantity of machine learning: it appears in variational inference (ELBO), policy gradients (trust region), knowledge distillation, and RLHF — always measuring how much one probability distribution deviates from another, and in which direction.

KL Divergence

For distributions $P$ and $Q$ on the same alphabet $\mathcal{X}$, the KL divergence (Kullback–Leibler divergence) is:

$$\text{KL}(P \| Q) = \sum_{x \in \mathcal{X}} p(x) \log \frac{p(x)}{q(x)} = \mathbb{E}_P\!\left[\log\frac{p(X)}{q(X)}\right].$$

For continuous distributions with densities: $\text{KL}(P \| Q) = \int p(x)\log\frac{p(x)}{q(x)}\,dx$.

Convention: $p\log(p/0) = +\infty$ when $q(x)=0$ but $p(x)>0$; $0\log(0/q) = 0$.

Non-negativity (Gibbs' inequality): $\text{KL}(P \| Q) \geq 0$ with equality iff $P = Q$ almost everywhere.

Proof: $-\text{KL}(P\|Q) = \mathbb{E}_P[\log(q/p)] \leq \log\mathbb{E}_P[q/p] = \log\sum_x p(x)\frac{q(x)}{p(x)} = \log 1 = 0$, by Jensen's inequality applied to the concave function $\log$.
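These definitions and conventions are easy to check numerically. A minimal sketch using NumPy (the example distributions are arbitrary, chosen only to show non-negativity, asymmetry, and the support-mismatch blow-up):

```python
import numpy as np

def kl(p, q):
    """KL(P || Q) for discrete distributions, using the 0*log(0/q) = 0 convention."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    mask = p > 0
    if np.any(q[mask] == 0):
        return np.inf  # P has mass where Q has none: p*log(p/0) = +inf
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.4, 0.4, 0.2])
print(kl(p, q), kl(q, p))          # both non-negative, and not equal: KL is asymmetric
print(kl(p, p))                    # 0.0: equality iff P = Q
print(kl([1.0, 0.0], [0.0, 1.0]))  # inf: supports don't match
```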

Non-negativity follows from Jensen's inequality applied to the concave log — this is not a special property of KL but a consequence of convexity. The asymmetry is not a defect to be fixed: forward and reverse KL encode genuinely different optimization criteria with different practical consequences, and the specific direction used in VAEs, RLHF, and knowledge distillation shapes the learned solution in qualitatively different ways.

Asymmetry: $\text{KL}(P \| Q) \neq \text{KL}(Q \| P)$ in general. The two directions encode different penalties:

  • $\text{KL}(P \| Q)$ (forward KL, "M-projection"): $Q$ must cover all regions where $P$ has mass — otherwise KL diverges. Minimizing this yields mean-seeking approximations (cover all modes).
  • $\text{KL}(Q \| P)$ (reverse KL, "I-projection"): $Q$ is penalized for having mass where $P$ has none, but not for missing mass of $P$. Minimizing this yields mode-seeking approximations (concentrate on one mode).

KL divergence for exponential families: for $p(x;\eta) = h(x)\exp(\eta^T T(x) - A(\eta))$ and $p(x;\eta')$:

$$\text{KL}(P_\eta \| P_{\eta'}) = A(\eta') - A(\eta) - (\eta' - \eta)^T \nabla A(\eta).$$

This is the Bregman divergence generated by the log-partition function $A(\eta)$.
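The Bregman identity can be verified for a concrete exponential family. A sketch for the Bernoulli distribution, where $A(\eta) = \log(1 + e^\eta)$ and $\nabla A(\eta) = \sigma(\eta)$ is the mean parameter (the natural-parameter values are arbitrary test points):

```python
import numpy as np

def A(eta):
    """Log-partition of the Bernoulli in natural parameters."""
    return np.log1p(np.exp(eta))

def A_grad(eta):
    """grad A(eta) = sigmoid(eta) = mean parameter mu."""
    return 1.0 / (1.0 + np.exp(-eta))

def kl_bregman(eta1, eta2):
    """KL(P_eta1 || P_eta2) via the Bregman divergence of A."""
    return A(eta2) - A(eta1) - (eta2 - eta1) * A_grad(eta1)

def kl_direct(mu1, mu2):
    """Direct Bernoulli KL in mean parameters."""
    return mu1 * np.log(mu1 / mu2) + (1 - mu1) * np.log((1 - mu1) / (1 - mu2))

eta1, eta2 = 0.5, -1.0
print(kl_bregman(eta1, eta2), kl_direct(A_grad(eta1), A_grad(eta2)))  # agree
```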

Gaussian–Gaussian KL: for $P = \mathcal{N}(\mu_1, \Sigma_1)$ and $Q = \mathcal{N}(\mu_2, \Sigma_2)$:

$$\text{KL}(P \| Q) = \frac{1}{2}\left[\log\frac{\det\Sigma_2}{\det\Sigma_1} + \text{tr}(\Sigma_2^{-1}\Sigma_1) + (\mu_1-\mu_2)^T\Sigma_2^{-1}(\mu_1-\mu_2) - d\right].$$

For scalar Gaussians: $\text{KL}(\mathcal{N}(\mu_1,\sigma_1^2) \| \mathcal{N}(\mu_2,\sigma_2^2)) = \log(\sigma_2/\sigma_1) + (\sigma_1^2 + (\mu_1-\mu_2)^2)/(2\sigma_2^2) - 1/2$.
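The scalar closed form makes the asymmetry from the demo above concrete. A sketch (the parameter values are illustrative: a narrow, shifted $Q$ makes forward KL explode while the reverse direction stays moderate):

```python
import numpy as np

def kl_gauss(mu1, s1, mu2, s2):
    """KL(N(mu1, s1^2) || N(mu2, s2^2)) in nats, from the scalar closed form."""
    return np.log(s2 / s1) + (s1**2 + (mu1 - mu2)**2) / (2 * s2**2) - 0.5

print(kl_gauss(0, 1, 0, 1))    # 0: identical distributions
print(kl_gauss(0, 1, 3, 0.1))  # forward KL explodes: narrow Q misses P's mass
print(kl_gauss(3, 0.1, 0, 1))  # reverse direction stays bounded
```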

Connection to Entropy and Cross-Entropy

$$\text{KL}(P \| Q) = H(P, Q) - H(P),$$

where $H(P) = -\mathbb{E}_P[\log p]$ is the entropy of $P$ and $H(P,Q) = -\mathbb{E}_P[\log q]$ is the cross-entropy of $Q$ relative to $P$.

Cross-entropy loss in classification: if $P$ is the true label distribution (one-hot) and $Q = \hat p$ is the model's predicted distribution, then $H(P, Q) = -\log\hat p(y)$ for the true class $y$. The KL divergence $\text{KL}(P \| Q) = H(P,Q) - H(P)$ measures how much worse the model is than the best possible predictor. Since $H(P)$ is a constant (zero for a deterministic one-hot label), minimizing cross-entropy is equivalent to minimizing KL divergence.
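The decomposition $H(P,Q) = H(P) + \text{KL}(P\|Q)$ can be checked directly. A sketch with arbitrary example distributions, including the one-hot case where cross-entropy reduces to $-\log \hat p(y)$:

```python
import numpy as np

def entropy(p):
    """H(P) = -sum p log p, skipping zero-probability outcomes."""
    p = p[p > 0]
    return -np.sum(p * np.log(p))

def cross_entropy(p, q):
    """H(P, Q) = -E_P[log q]."""
    mask = p > 0
    return -np.sum(p[mask] * np.log(q[mask]))

p = np.array([0.7, 0.2, 0.1])  # "soft" label distribution
q = np.array([0.5, 0.3, 0.2])  # model prediction
kl = cross_entropy(p, q) - entropy(p)
print(entropy(p), cross_entropy(p, q), kl)  # H(P,Q) = H(P) + KL(P||Q), KL >= 0

one_hot = np.array([0.0, 1.0, 0.0])              # hard label: H(P) = 0,
print(cross_entropy(one_hot, q), -np.log(q[1]))  # so H(P,Q) = KL = -log q(y)
```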

f-Divergences

A general f-divergence is defined for a convex function $f$ with $f(1) = 0$:

$$D_f(P \| Q) = \mathbb{E}_Q\!\left[f\!\left(\frac{p(X)}{q(X)}\right)\right] = \int q(x)\, f\!\left(\frac{p(x)}{q(x)}\right)\,dx.$$

By Jensen's inequality ($f$ convex): $D_f(P\|Q) \geq f(\mathbb{E}_Q[p/q]) = f(1) = 0$.

Special cases (generator $f(t)$ → divergence):

  • $t\log t$ → $\text{KL}(P \| Q)$
  • $-\log t$ → $\text{KL}(Q \| P)$ (reverse KL)
  • $(t-1)^2$ → $\chi^2$ divergence
  • $(\sqrt{t}-1)^2$ → squared Hellinger distance
  • $\tfrac{1}{2}\lvert t-1\rvert$ → total variation distance
  • $\tfrac{1}{2}\left[t\log t - (t+1)\log\tfrac{t+1}{2}\right]$ → Jensen–Shannon divergence (the $\tfrac{1}{2}$ matches the $\tfrac{1}{2}$-weighted definition of JS below)
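A single generic routine recovers each of these from its generator, which is a useful consistency check. A sketch (the two three-point distributions are arbitrary; `f_div` assumes $q > 0$ everywhere):

```python
import numpy as np

def f_div(p, q, f):
    """D_f(P || Q) = sum_x q(x) f(p(x)/q(x)); assumes q > 0 everywhere."""
    t = p / q
    return float(np.sum(q * f(t)))

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.3, 0.5])

kl_fwd = f_div(p, q, lambda t: t * np.log(t))
kl_rev = f_div(p, q, lambda t: -np.log(t))
tv     = f_div(p, q, lambda t: 0.5 * np.abs(t - 1))
js     = f_div(p, q, lambda t: 0.5 * (t * np.log(t) - (t + 1) * np.log((t + 1) / 2)))

# Cross-check against the direct formulas.
m = (p + q) / 2
tv_direct = 0.5 * np.abs(p - q).sum()
js_direct = 0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m))
print(kl_fwd, kl_rev)          # forward vs reverse KL differ
print(tv, tv_direct)           # TV generator matches 0.5 * ||P - Q||_1
print(js, js_direct, np.log(2))  # JS generator matches mixture form; JS <= log 2
```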

Total Variation Distance

$$\text{TV}(P, Q) = \frac{1}{2}\sum_x |p(x) - q(x)| = \frac{1}{2}\|P - Q\|_1.$$

TV is the largest difference in probability that $P$ and $Q$ assign to any event — equivalently, the best advantage any test can achieve over random guessing when distinguishing $P$ from $Q$ from a single sample:

$$\text{TV}(P, Q) = \max_{A \subseteq \mathcal{X}} |P(A) - Q(A)|.$$

Pinsker's inequality: $\text{TV}(P,Q) \leq \sqrt{\frac{1}{2}\text{KL}(P\|Q)}$.

Hellinger distance: $H^2(P,Q) = \frac{1}{2}\sum_x (\sqrt{p(x)} - \sqrt{q(x)})^2$. It sandwiches TV:

$$H^2(P,Q) \leq \text{TV}(P,Q) \leq \sqrt{2}\, H(P,Q).$$
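Pinsker's inequality and the Hellinger sandwich can be stress-tested on random distributions. A sketch using Dirichlet samples (the alphabet size, sample count, and seed are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
for _ in range(1000):
    p = rng.dirichlet(np.ones(5))  # random distribution on 5 outcomes
    q = rng.dirichlet(np.ones(5))
    kl = np.sum(p * np.log(p / q))
    tv = 0.5 * np.abs(p - q).sum()
    h2 = 0.5 * np.sum((np.sqrt(p) - np.sqrt(q))**2)
    assert tv <= np.sqrt(kl / 2) + 1e-12                 # Pinsker
    assert h2 - 1e-12 <= tv <= np.sqrt(2 * h2) + 1e-12   # Hellinger sandwich
print("all bounds hold")
```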

Jensen-Shannon Divergence

$$\text{JS}(P,Q) = \frac{1}{2}\text{KL}(P \| M) + \frac{1}{2}\text{KL}(Q \| M), \quad M = \frac{P+Q}{2}.$$

Properties: symmetric, bounded by $\log 2$ (i.e., 1 bit), zero iff $P = Q$. The JS distance $\sqrt{\text{JS}(P,Q)}$ is a metric.

GAN connection: the Jensen–Shannon divergence arises naturally in the original GAN objective. The optimal discriminator $D^*(x) = p(x)/(p(x)+q(x))$ makes the GAN value function equal to $2\,\text{JS}(P_\text{data} \| P_G) - \log 4$. Minimizing the GAN objective is equivalent to minimizing the JS divergence between data and generator distributions.
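The $2\,\text{JS} - \log 4$ identity can be verified on a small discrete example. A sketch (the two three-point "data" and "generator" distributions are arbitrary):

```python
import numpy as np

p = np.array([0.6, 0.3, 0.1])  # stand-in for the data distribution
q = np.array([0.2, 0.3, 0.5])  # stand-in for the generator distribution

d_star = p / (p + q)           # optimal discriminator D*(x)
# GAN value at the optimal discriminator: E_p[log D*] + E_q[log(1 - D*)]
value = np.sum(p * np.log(d_star)) + np.sum(q * np.log(1 - d_star))

m = (p + q) / 2
js = 0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m))
print(value, 2 * js - np.log(4))  # the two coincide
```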

Worked Example

Example 1: Mode-Seeking vs Mean-Seeking

Approximate a bimodal $P$ (mixture of two Gaussians) with a unimodal $Q = \mathcal{N}(\mu, \sigma^2)$.

Minimizing $\text{KL}(P\|Q)$ (forward, as in maximum likelihood): $Q$ must cover all of $P$'s mass. The optimal $Q$ has mean near the average of the two modes and large variance to cover both — a broad distribution between the modes.

Minimizing $\text{KL}(Q\|P)$ (reverse, as in mean-field variational inference): $Q$ is penalized only for having mass where $P$ is near zero. The optimal $Q$ collapses onto one mode (mode-seeking). Standard mean-field variational inference minimizes reverse KL, which explains why it tends to underestimate uncertainty.

Example 2: Cross-Entropy in Language Models

For a language model predicting next-token probabilities over vocabulary $V$:

The cross-entropy per token is $H(P, Q) = -\mathbb{E}_{x \sim P}[\log Q(x)]$, where $P$ is the true data distribution and $Q$ is the model. Since $\text{KL}(P\|Q) = H(P,Q) - H(P)$ and $H(P)$ is fixed, minimizing cross-entropy is minimizing KL. If English text has a true entropy of about 1.2 bits/token, a model achieving cross-entropy of 2 bits/token has $\text{KL}(P\|Q) = 2 - 1.2 = 0.8$ bits/token of overhead above the true entropy.

Perplexity $= 2^{H}$, with $H$ the cross-entropy in bits. Perplexity 8 means the model is as uncertain as a uniform distribution over 8 equally likely tokens — a useful interpretable metric.
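A minimal sketch of the perplexity relationship (the 3 bits/token loss is a hypothetical value, not a measurement):

```python
import numpy as np

cross_entropy_bits = 3.0          # hypothetical per-token loss in bits
perplexity = 2 ** cross_entropy_bits
print(perplexity)                 # 8.0: as uncertain as 8 equally likely tokens

# Sanity check: a uniform distribution over k tokens has entropy log2(k) bits,
# so its perplexity is exactly k.
k = 8
u = np.full(k, 1 / k)
h = -np.sum(u * np.log2(u))
print(2 ** h)
```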

Example 3: f-Divergence in Generative Models

Different GAN variants correspond to different f-divergences:

  • Standard GAN: Jensen–Shannon
  • $f$-GAN (Nowozin et al.): any f-divergence via the variational dual form $D_f(P\|Q) = \sup_{T} \mathbb{E}_P[T] - \mathbb{E}_Q[f^*(T)]$, where $f^*$ is the Fenchel conjugate of $f$
  • Wasserstein GAN: not an f-divergence, but an optimal transport distance — more stable training for distributions with disjoint support (where KL $= \infty$ and JS $= \log 2$ are uninformative)

Connections

Where Your Intuition Breaks

The "direction" of KL divergence matters more than the magnitude. Minimizing $\text{KL}(q \| p)$ (reverse KL) and minimizing $\text{KL}(p \| q)$ (forward KL) solve different problems and can give radically different solutions. Reverse KL forces $q$ to be zero wherever $p$ is near-zero, producing mode-seeking approximations that can collapse to a single mode even when $p$ is multimodal. Forward KL forces $q$ to spread mass wherever $p$ has mass, producing mean-seeking approximations that average over modes. In variational inference, reverse KL is standard because it is tractable — but the mode-seeking consequence means VI systematically underestimates posterior variance in complex posteriors. Choosing the KL direction is a modeling decision, not a mathematical convention.

💡Intuition

KL divergence is the "cost" of using the wrong code. If the true distribution is $P$ but you design an optimal code for $Q$, you use $H(P,Q)$ bits per symbol instead of the optimal $H(P)$. The KL divergence $\text{KL}(P\|Q)$ is exactly this overhead. Shannon's source coding theorem says no code can compress $P$ below $H(P)$ bits/symbol; a code optimized for $Q$ instead costs $H(P,Q) = H(P) + \text{KL}(P\|Q)$ when used on $P$. This makes cross-entropy loss the correct training objective for any classification problem where you are trying to learn the true conditional distribution — you are minimizing the coding overhead from using the model distribution instead of the true distribution.

💡Intuition

Reverse KL is the basis of variational inference. Variational inference minimizes $\text{KL}(q \,\|\, \pi(\cdot \mid x))$, which is tractable because $q$ is a simple family (Gaussian, mean-field). The reverse direction means $q$ avoids regions where $\pi$ is small but ignores regions of $\pi$ that $q$ doesn't cover. This systematically underestimates posterior uncertainty in multimodal posteriors. Forward KL (as in expectation propagation) covers all modes but is harder to optimize. The choice between forward and reverse KL is a fundamental tradeoff in approximate inference.

⚠️Warning

KL divergence is infinite when the supports don't match. $\text{KL}(P\|Q) = +\infty$ whenever $P$ assigns positive probability to an event that $Q$ assigns zero probability. This makes KL impractical for comparing distributions with different supports — common in generative modeling when real and generated distributions occupy different manifolds. This is the original motivation for Wasserstein GANs: the Wasserstein distance is always finite even for distributions with disjoint support, because it uses the geometry of the underlying space rather than pointwise ratios.
