Neural-Path/Notes

Automatic Differentiation: Forward Mode, Reverse Mode & Computation Graphs

Automatic differentiation computes exact gradients of programs by mechanically applying the chain rule to elementary operations. It is neither symbolic differentiation (produces a formula) nor numerical differentiation (finite differences with truncation error). Forward mode propagates derivatives through the computation graph one input at a time; reverse mode backpropagates from the output, computing all partial derivatives in a single pass.

Concepts

Computation graph for f(x,y) = (x+y)·sin(y) at x=1, y=π/2. Forward mode propagates dual numbers up; reverse mode propagates adjoints down.

  • x = 1, seed ẋ = 1
  • y = π/2 ≈ 1.571, seed ẏ = 0
  • w₁ = x + y ≈ 2.571, ẇ₁ = ẋ + ẏ = 1
  • w₂ = sin y = 1, ẇ₂ = ẏ cos y = 0
  • f = w₁·w₂ ≈ 2.571, ḟ = ẇ₁w₂ + w₁ẇ₂ = 1
Forward mode: one pass → one column of Jacobian. Seed ẋ=1; all other seeds=0. Result: ∂f/∂x = 1. Cost per column: O(1) × forward pass.

When you train a neural network, the backward pass reduces to one operation: compute how much each weight contributed to the loss. Automatic differentiation does this exactly — not approximately. It doesn't derive a closed-form symbolic formula (which grows exponentially in complexity) and doesn't use finite differences (which have truncation error). It mechanically applies the chain rule to every elementary operation recorded during the forward pass, in reverse.

Computation Graphs

Any differentiable program can be represented as a directed acyclic graph (DAG) where nodes are intermediate values and edges are data flow. For $f(x, y) = (x+y)\sin(y)$:

  • Nodes: $x$, $y$, $w_1 = x+y$, $w_2 = \sin(y)$, $f = w_1 w_2$
  • Edges: $x \to w_1$, $y \to w_1$, $y \to w_2$, $w_1 \to f$, $w_2 \to f$

The Jacobian of $f: \mathbb{R}^n \to \mathbb{R}^m$ is the $m \times n$ matrix $J_{ij} = \partial f_i / \partial x_j$. Computing $J$ via finite differences costs $O(mn)$ function evaluations. Both AD modes compute $J$ exactly at a cost of $O(1)$ passes over the graph per column (forward mode) or per row (reverse mode).

The DAG structure is the data structure that makes reverse-mode AD tractable. Each node stores its output value and a reference to the operation that produced it. During the backward pass, each node receives a gradient from its downstream dependents and multiplies by its local derivative before passing upstream. The acyclic structure guarantees this reversal is well-defined: each edge contributes exactly once, and no gradient is needed before all downstream gradients are accumulated.
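This bookkeeping can be sketched in a few lines of Python. The `Node` type below is a toy illustration (not any framework's API): each node records its value, its parent nodes, and the local derivative with respect to each parent, so the trace order itself gives a topological order for the later backward pass.

```python
import math

# Toy computation-graph node: value + parents + local derivatives.
class Node:
    def __init__(self, value, parents=(), local_grads=()):
        self.value = value
        self.parents = parents          # upstream Node references
        self.local_grads = local_grads  # d(self)/d(parent), per parent

def add(a, b):
    return Node(a.value + b.value, (a, b), (1.0, 1.0))

def mul(a, b):
    return Node(a.value * b.value, (a, b), (b.value, a.value))

def sin(a):
    return Node(math.sin(a.value), (a,), (math.cos(a.value),))

# Build the DAG for f(x, y) = (x + y) * sin(y) at x = 1, y = pi/2
x, y = Node(1.0), Node(math.pi / 2)
w1 = add(x, y)    # edges x -> w1, y -> w1
w2 = sin(y)       # edge  y -> w2
f = mul(w1, w2)   # edges w1 -> f, w2 -> f
print(f.value)    # ~2.571
```

Note that `y` appears as a parent of two nodes, which is exactly why adjoints must be accumulated rather than assigned during the backward pass.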

Forward Mode: Dual Numbers

Dual numbers: augment each value with a derivative component. Represent $x$ as a pair $\langle x, \dot x \rangle$ where $\dot x = \partial x / \partial x_j$ (the derivative with respect to seed direction $x_j$).

Arithmetic rules:

$$\langle a, \dot a \rangle + \langle b, \dot b \rangle = \langle a+b,\; \dot a + \dot b \rangle$$
$$\langle a, \dot a \rangle \times \langle b, \dot b \rangle = \langle ab,\; a\dot b + b\dot a \rangle$$
$$\sin\langle a, \dot a \rangle = \langle \sin a,\; \dot a \cos a \rangle$$

Algorithm: set $\dot x_j = 1$ (seed direction $j$), $\dot x_k = 0$ for $k \neq j$. Execute the forward pass — at each operation, compute both the primal value and its derivative. One forward pass computes one column of $J$: $\partial f / \partial x_j$ for all outputs $f$.
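A minimal dual-number sketch of this algorithm (toy code; the `Dual` class and `dsin` helper are illustrative, not a library API):

```python
import math

# Each Dual carries a primal value and its derivative component.
class Dual:
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __add__(self, other):
        return Dual(self.val + other.val, self.dot + other.dot)
    def __mul__(self, other):
        return Dual(self.val * other.val,
                    self.val * other.dot + other.val * self.dot)

def dsin(a):
    return Dual(math.sin(a.val), a.dot * math.cos(a.val))

def f(x, y):
    return (x + y) * dsin(y)

# Seed x (x-dot = 1): one forward pass yields the column df/dx.
out_x = f(Dual(1.0, 1.0), Dual(math.pi / 2, 0.0))
# Seed y: a second pass yields df/dy.
out_y = f(Dual(1.0, 0.0), Dual(math.pi / 2, 1.0))
print(out_x.dot, out_y.dot)  # df/dx = 1; df/dy = sin(y) + (x+y)cos(y) = 1
```

Seeding with a general vector $v$ instead of a basis vector turns the same pass into a JVP.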

Cost: computing the full $m \times n$ Jacobian requires $n$ forward passes — $O(n)$ times the cost of one forward pass. Ideal when: few inputs, many outputs (tall Jacobian, e.g., forward kinematics).

Jacobian-vector product (JVP): for a tangent vector $v \in \mathbb{R}^n$, the JVP $Jv$ can be computed in one forward pass by seeding $\dot x_j = v_j$.

Reverse Mode: Backpropagation

Adjoint variables (sensitivities): $\bar w_i = \partial f / \partial w_i$ (how much the output changes per unit change in $w_i$, holding everything else fixed).

Reverse accumulation: starting from $\bar f = 1$, propagate adjoints backward through the graph using the adjoint rules.

For each operation $w_k = g(w_i, w_j, \ldots)$:

$$\bar w_i \mathrel{+}= \bar w_k \cdot \frac{\partial g}{\partial w_i}, \qquad \bar w_j \mathrel{+}= \bar w_k \cdot \frac{\partial g}{\partial w_j}.$$

(Using $\mathrel{+}=$: each node accumulates contributions from all downstream nodes using it.)

Algorithm: (1) Forward pass: execute program, storing all intermediate values. (2) Backward pass: traverse graph in reverse topological order, computing and accumulating adjoints.
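The two-pass algorithm can be sketched with a toy tape (illustrative code, not a framework API): the forward pass appends each operation to a list, and since the trace order is a topological order, walking the list in reverse and accumulating adjoints with `+=` implements the backward pass.

```python
import math

# Toy tape-based reverse-mode AD.
class Var:
    def __init__(self, value, tape=None):
        self.value, self.grad, self.tape = value, 0.0, tape

def record(tape, out, inputs, local_grads):
    tape.append((out, inputs, local_grads))
    return out

def add(a, b):
    return record(a.tape, Var(a.value + b.value, a.tape), (a, b), (1.0, 1.0))

def mul(a, b):
    return record(a.tape, Var(a.value * b.value, a.tape), (a, b),
                  (b.value, a.value))

def sin(a):
    return record(a.tape, Var(math.sin(a.value), a.tape), (a,),
                  (math.cos(a.value),))

def backward(out, tape):
    out.grad = 1.0                       # seed: df/df = 1
    for node, inputs, local_grads in reversed(tape):
        for inp, g in zip(inputs, local_grads):
            inp.grad += node.grad * g    # accumulate adjoints

tape = []
x, y = Var(1.0, tape), Var(math.pi / 2, tape)
f = mul(add(x, y), sin(y))
backward(f, tape)
print(x.grad, y.grad)  # df/dx = 1, df/dy = 1
```

One backward pass fills in the gradients of both inputs, illustrating the one-row-per-pass cost of reverse mode.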

Cost: one forward pass + one backward pass (typically $\leq 3\times$ the forward-pass cost). Computes the full gradient $\nabla_x f \in \mathbb{R}^n$ of a scalar $f$ in one backward pass, regardless of $n$.

Vector-Jacobian product (VJP): for a cotangent vector $u \in \mathbb{R}^m$, the VJP $u^T J$ can be computed in one backward pass by seeding $\bar f = u$.

Memory-Compute Tradeoffs

Rematerialization (gradient checkpointing): reverse mode must store all intermediate values from the forward pass to compute adjoints. For a depth-$L$ network this costs $O(L)$ memory. Gradient checkpointing stores only $O(\sqrt{L})$ values and recomputes the rest during the backward pass, reducing memory at the cost of roughly one extra forward pass of recomputation (each block is recomputed once).
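A toy sketch of the idea on a chain of identical scalar "layers" $f_i(a) = \sin(a)$ (the chain length and block size are made up for illustration; this is not a framework API). Only every $k$-th activation is kept; each $k$-layer block is recomputed from its checkpoint during the backward pass.

```python
import math

L, k = 100, 10          # toy depth and checkpoint spacing (~sqrt(L))
layer = math.sin        # every layer is sin(a)
dlayer = math.cos       # its derivative

def checkpointed_grad(x0):
    # Forward: keep only L/k checkpoints instead of L activations.
    checkpoints = {0: x0}
    a = x0
    for i in range(L):
        a = layer(a)
        if (i + 1) % k == 0:
            checkpoints[i + 1] = a
    # Backward: recompute each block from its checkpoint, then chain.
    grad = 1.0
    for start in range(L - k, -1, -k):
        acts = [checkpoints[start]]
        for _ in range(k):                # recompute one k-layer block
            acts.append(layer(acts[-1]))
        for i in range(k - 1, -1, -1):    # backprop through the block
            grad *= dlayer(acts[i])
    return grad

def full_grad(x0):
    # Reference: plain reverse mode storing all L activations.
    acts = [x0]
    for _ in range(L):
        acts.append(layer(acts[-1]))
    grad = 1.0
    for i in range(L - 1, -1, -1):
        grad *= dlayer(acts[i])
    return grad

print(abs(checkpointed_grad(0.5) - full_grad(0.5)))  # ~0
```

Each block is recomputed exactly once, so peak memory is $O(L/k + k)$ activations while the extra compute is about one forward pass.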

Jacobian-free Newton-Krylov: for implicit systems $F(\theta) = 0$, the Jacobian $J_F$ is never stored. GMRES only needs matrix-vector products $J_F v$, each available as a single JVP, so the Newton system $J_F \delta = -F$ is solved without ever forming $J_F$.

Second-order derivatives (Hessians): obtain $\nabla^2 f$ products by differentiating the gradient computation. A Hessian-vector product $Hv$ is one forward-mode pass through the gradient function (forward-over-reverse), with no full Hessian ever stored. JAX's jax.jvp(jax.grad(f)) implements this efficiently.
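A sketch of the forward-over-reverse pattern on a toy function. Here an analytic gradient written with dual-compatible operations stands in for a recorded backward pass, and the `Dual` type is illustrative, not a library API; the test function $f(x_0, x_1) = x_0^2 x_1 + x_1^3$ is made up.

```python
# Forward-mode duals pushed through a gradient function give H @ v.
class Dual:
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __add__(self, o):
        return Dual(self.val + o.val, self.dot + o.dot)
    def __mul__(self, o):
        return Dual(self.val * o.val, self.val * o.dot + o.val * self.dot)

def grad_f(x):
    # Analytic gradient of f(x0, x1) = x0^2 * x1 + x1^3:
    # df/dx0 = 2 x0 x1,  df/dx1 = x0^2 + 3 x1^2
    two, three = Dual(2.0), Dual(3.0)
    return [two * x[0] * x[1], x[0] * x[0] + three * x[1] * x[1]]

def hvp(x, v):
    # Seed the duals with v; the dot parts of grad_f(x) are H @ v.
    xs = [Dual(xi, vi) for xi, vi in zip(x, v)]
    return [g.dot for g in grad_f(xs)]

# Hessian at (1, 2) is [[2*x1, 2*x0], [2*x0, 6*x1]] = [[4, 2], [2, 12]]
print(hvp([1.0, 2.0], [1.0, 0.0]))  # first Hessian column: [4.0, 2.0]
```

A true AD system would generate `grad_f` automatically by reverse mode; differentiating it with duals is exactly what jax.jvp over jax.grad does.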

Symbolic vs Numerical vs Automatic Differentiation

| Method | Result | Error | Cost | Use case |
| --- | --- | --- | --- | --- |
| Symbolic | Closed-form expression | Exact | Exponential expression growth | CAS (Mathematica) |
| Numerical (FD) | Approximation $[f(x+h)-f(x)]/h$ | $O(h) + O(\varepsilon/h)$ | $O(n)$ evaluations | Gradient checking |
| AD, forward | Exact (to machine precision) | $O(\varepsilon_{\text{mach}})$ | $O(n) \times$ forward | JVPs, few inputs |
| AD, reverse | Exact (to machine precision) | $O(\varepsilon_{\text{mach}})$ | $O(1) \times$ forward | Gradients for ML |

Worked Example

Example 1: Reverse Mode for $f(x,y) = (x+y)\sin(y)$

Forward pass at $(x, y) = (1, \pi/2)$:

  • $w_1 = x + y = 1 + \pi/2 \approx 2.571$
  • $w_2 = \sin(y) = \sin(\pi/2) = 1$
  • $f = w_1 w_2 \approx 2.571$

Backward pass with $\bar f = 1$:

  • $\bar w_1 = \bar f \cdot w_2 = 1 \cdot 1 = 1$
  • $\bar w_2 = \bar f \cdot w_1 = 1 \cdot 2.571 = 2.571$
  • $\bar y \mathrel{+}= \bar w_1 \cdot \partial w_1/\partial y = 1 \cdot 1 = 1$ (from $w_1 = x+y$)
  • $\bar y \mathrel{+}= \bar w_2 \cdot \cos(y) = 2.571 \cdot 0 = 0$ (from $w_2 = \sin y$, at $y = \pi/2$)
  • $\bar x = \bar w_1 \cdot 1 = 1$

Result: $\partial f/\partial x = \sin(\pi/2) = 1$ ✓; $\partial f/\partial y = \sin(\pi/2) + (1+\pi/2)\cos(\pi/2) = 1 + 0 = 1$ ✓.
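The hand-computed adjoints can be sanity-checked with central differences:

```python
import math

# Numerical check of the worked example: f(x, y) = (x + y) * sin(y)
def f(x, y):
    return (x + y) * math.sin(y)

h = 1e-6
dfdx = (f(1 + h, math.pi / 2) - f(1 - h, math.pi / 2)) / (2 * h)
dfdy = (f(1, math.pi / 2 + h) - f(1, math.pi / 2 - h)) / (2 * h)
print(round(dfdx, 6), round(dfdy, 6))  # both ~1.0, matching the adjoints
```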

Example 2: Memory Scaling with Gradient Checkpointing

A transformer with $L = 96$ layers: storing all activations for reverse mode uses $O(L) \times$ batch memory. For batch size 32 with 4096-token sequences and hidden size 4096 in float16: $\sim 96 \times 4096 \times 4096 \times 32 \times 2$ bytes $\approx 100$ GB — GPU out of memory.

Gradient checkpointing ($\sqrt{L} \approx 10$ checkpoints): store every 10th layer's activations and recompute each 10-layer block during the backward pass. Memory falls to roughly $10\times$ one layer's activations ($\approx 10$ GB here), at the cost of roughly one extra forward pass of recomputation. This is a standard technique in large-scale LLM training, used for GPT-3 and LLaMA among others.

Example 3: Jacobian-Vector Products for Neural Tangent Kernels

The Neural Tangent Kernel $K(x, x') = J_\theta f(x)\, J_\theta f(x')^T$ involves Jacobians $J_\theta f \in \mathbb{R}^{n_{\text{out}} \times n_{\text{params}}}$. For $n_{\text{params}} = 10^9$, storing $J$ is impossible.

Instead, work matrix-free: a JVP $J_\theta f(x)\, v$ costs one forward-mode pass, and a VJP costs one backward pass. Each entry $K_{ab}(x, x') = \sum_p \partial f_a/\partial\theta_p(x) \cdot \partial f_b/\partial\theta_p(x')$ is a dot product of two parameter gradients, each obtainable from one VJP, so $J$ is never materialized. JAX's jax.linearize and jax.vjp enable this pattern efficiently.
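For a scalar-output model this reduces to a dot product of two gradients. A toy illustration, assuming a made-up linear-in-parameters model $f(\theta, x) = \theta \cdot \phi(x)$ with a hypothetical feature map `phi` (for such a model, the parameter gradient from a VJP with seed 1 is simply $\phi(x)$):

```python
import math

def phi(x):
    return [x, x * x, math.sin(x)]  # assumed toy feature map

def grad_theta(x):
    # For f(theta, x) = theta . phi(x), the VJP with seed 1 is phi(x).
    return phi(x)

def ntk_entry(x, xp):
    # K(x, x') = grad_theta f(x) . grad_theta f(x'), no Jacobian stored.
    gx, gxp = grad_theta(x), grad_theta(xp)
    return sum(a * b for a, b in zip(gx, gxp))

print(ntk_entry(1.0, 2.0))  # phi(1).phi(2) = 2 + 4 + sin(1)*sin(2)
```

For a real network the two gradients would come from reverse-mode AD, but the kernel entry is still just their dot product.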

Connections

Where Your Intuition Breaks

Reverse-mode AD is often described as computing "all gradients in one backward pass." This holds for scalar outputs ($m = 1$), where one backward pass gives the full gradient $\nabla f \in \mathbb{R}^n$. But for matrix-valued outputs — full Jacobians $J \in \mathbb{R}^{m \times n}$ — reverse mode computes one row per backward pass ($m$ total), while forward mode computes one column per forward pass ($n$ total). For $m \ll n$, reverse mode wins; for $n \ll m$, forward mode wins. Nearly all ML training has $m = 1$ (scalar loss), making reverse mode dominant — but per-sample gradients, meta-learning Jacobians, or NTK computations require the full $J$, and the cost scales accordingly.

💡Intuition

Reverse mode is exactly backpropagation. The backpropagation algorithm discovered independently in the 1960s-70s (Linnainmaa, Werbos, Rumelhart-Hinton-Williams) is reverse-mode automatic differentiation applied to neural networks. There is nothing special about neural networks in the algorithm — it applies to any differentiable program. PyTorch's autograd, JAX's jax.grad, and TensorFlow's tf.GradientTape are all implementations of reverse-mode AD on dynamic computation graphs. Understanding AD is understanding backprop precisely.

💡Intuition

Forward vs reverse mode is a cost tradeoff determined by input/output dimensions. For a function $f: \mathbb{R}^n \to \mathbb{R}^m$: forward mode computes one column of $J$ per pass ($n$ passes for the full $J$); reverse mode computes one row per pass ($m$ passes). Neural network training: $n$ = millions of parameters, $m = 1$ (scalar loss). So reverse mode computes the full gradient in one pass; forward mode would require millions of passes. For $m \gg n$ (e.g., computing the Jacobian of a physics simulator's outputs with respect to a few inputs), forward mode wins. This is why scientific ML (PINNs, neural ODEs for control) sometimes prefers forward mode.

⚠️Warning

Gradient checking is noisy — use it sparingly and with the right step size. The central difference $[f(x+h) - f(x-h)]/(2h)$ has truncation error $O(h^2)$ and rounding error $O(\varepsilon_{\text{mach}}/h)$. The optimal $h = \varepsilon_{\text{mach}}^{1/3} \approx 10^{-5}$ for float64 gives absolute error $\sim 10^{-10}$. For float32: $h \approx 5 \times 10^{-3}$, error only $\sim 10^{-5}$. Gradient checking in float32 with $h = 10^{-5}$ gives entirely rounding-error-dominated results — the check is meaningless. Always run gradient checks in float64, and only use them to debug custom operators, not as a routine test.
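A quick float64 sweep makes the U-shaped error curve visible (using $f = \sin$ at $x = 1$, where the true derivative is $\cos 1$; the specific step sizes are arbitrary):

```python
import math

# Central-difference error vs step size h, in float64.
def central_diff(f, x, h):
    return (f(x + h) - f(x - h)) / (2 * h)

true = math.cos(1.0)
for h in [1e-1, 1e-3, 1e-5, 1e-7, 1e-9, 1e-11]:
    err = abs(central_diff(math.sin, 1.0, h) - true)
    print(f"h = {h:.0e}   error = {err:.1e}")
# Error shrinks as O(h^2) until rounding error O(eps/h) takes over
# near h ~ eps^(1/3) ~ 1e-5, then grows again for smaller h.
```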
