Automatic Differentiation: Forward Mode, Reverse Mode & Computation Graphs
Automatic differentiation computes exact gradients of programs by mechanically applying the chain rule to elementary operations. It is neither symbolic differentiation (produces a formula) nor numerical differentiation (finite differences with truncation error). Forward mode propagates derivatives through the computation graph one input at a time; reverse mode backpropagates from the output, computing all partial derivatives in a single pass.
Concepts
Computation graph for f(x,y) = (x+y)·sin(y) at x=1, y=π/2. Forward mode propagates dual numbers up; reverse mode propagates adjoints down.
When you train a neural network, the backward pass reduces to one operation: compute how much each weight contributed to the loss. Automatic differentiation does this exactly — not approximately. It doesn't derive a closed-form symbolic formula (which grows exponentially in complexity) and doesn't use finite differences (which have truncation error). It mechanically applies the chain rule to every elementary operation recorded during the forward pass, in reverse.
Computation Graphs
Any differentiable program can be represented as a directed acyclic graph (DAG) where nodes are intermediate values and edges are data flow. For $f(x, y) = (x + y) \cdot \sin(y)$:
- Nodes: $x$, $y$, $v_1 = x + y$, $v_2 = \sin(y)$, $v_3 = v_1 \cdot v_2$
- Edges: $x \to v_1$, $y \to v_1$, $y \to v_2$, $v_1 \to v_3$, $v_2 \to v_3$
The Jacobian of $f: \mathbb{R}^n \to \mathbb{R}^m$ is the $m \times n$ matrix $J_{ij} = \partial f_i / \partial x_j$. Computing $J$ via finite differences costs $n + 1$ function evaluations and is only approximate. Both AD modes compute $J$ exactly at the cost of one pass over the graph per column (forward mode) or per row (reverse mode).
The DAG structure is the data structure that makes reverse-mode AD tractable. Each node stores its output value and a reference to the operation that produced it. During the backward pass, each node receives a gradient from its downstream dependents and multiplies by its local derivative before passing upstream. The acyclic structure guarantees this reversal is well-defined: each edge contributes exactly once, and no gradient is needed before all downstream gradients are accumulated.
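As a concrete sketch, the DAG for $f(x, y) = (x + y) \cdot \sin(y)$ can be built with a minimal node type that records its value, its parents, and the local derivatives with respect to each parent (the class and field names here are illustrative, not any framework's API):

```python
import math

# Each node stores: its computed value, references to the parent nodes it was
# computed from, and the local derivative of its value w.r.t. each parent.
class Node:
    def __init__(self, value, parents=(), local_grads=()):
        self.value = value
        self.parents = parents          # upstream Nodes this value depends on
        self.local_grads = local_grads  # d(self)/d(parent) for each parent
        self.grad = 0.0                 # adjoint, filled in by a backward pass

# Build the DAG for f(x, y) = (x + y) * sin(y) at x = 1, y = pi/2.
x = Node(1.0)
y = Node(math.pi / 2)
v1 = Node(x.value + y.value, (x, y), (1.0, 1.0))           # v1 = x + y
v2 = Node(math.sin(y.value), (y,), (math.cos(y.value),))   # v2 = sin(y)
v3 = Node(v1.value * v2.value, (v1, v2), (v2.value, v1.value))  # v3 = v1 * v2

print(v3.value)  # f(1, pi/2) = (1 + pi/2) * 1 ≈ 2.5708
```

Because the graph is acyclic, a reverse topological order (here simply `v3, v2, v1, y, x`) is guaranteed to exist, which is what the backward pass walks.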
Forward Mode: Dual Numbers
Dual numbers: augment each value with a derivative component. Represent $v$ as a pair $(v, \dot v)$ where $\dot v = \partial v / \partial x_j$ (the derivative with respect to seed direction $j$).
Arithmetic rules:
- $(u, \dot u) + (v, \dot v) = (u + v,\; \dot u + \dot v)$
- $(u, \dot u) \cdot (v, \dot v) = (uv,\; \dot u v + u \dot v)$
- $g\big((u, \dot u)\big) = \big(g(u),\; \dot u \, g'(u)\big)$ for any elementary function $g$, e.g. $\sin(u, \dot u) = (\sin u,\; \dot u \cos u)$
Algorithm: set $\dot x_j = 1$ (seed direction $j$), $\dot x_i = 0$ for $i \neq j$. Execute the forward pass — at each operation, compute both the primal value and its derivative. One forward pass computes one column of $J$: $\partial f_i / \partial x_j$ for all outputs $i$.
Cost: computing the full Jacobian requires $n$ forward passes — $O(n)$ times the cost of one forward pass. Ideal when: few inputs, many outputs ($n \ll m$, a tall Jacobian, e.g., forward kinematics).
Jacobian-vector product (JVP): for a tangent vector $v \in \mathbb{R}^n$, the JVP $Jv$ can be computed in one forward pass by seeding $\dot x = v$.
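The dual-number rules above fit in a few lines of operator overloading. This is a toy sketch, not a production AD system:

```python
import math

# Each Dual carries (primal value, tangent); every operation applies the
# chain rule to the tangent component alongside the primal computation.
class Dual:
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __add__(self, other):  # sum rule
        return Dual(self.val + other.val, self.dot + other.dot)
    def __mul__(self, other):  # product rule
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

def dsin(u):  # sin rule: (sin u, u' * cos u)
    return Dual(math.sin(u.val), u.dot * math.cos(u.val))

def f(x, y):  # f(x, y) = (x + y) * sin(y)
    return (x + y) * dsin(y)

# Seed the x-direction: one forward pass yields df/dx at (1, pi/2).
out = f(Dual(1.0, 1.0), Dual(math.pi / 2, 0.0))
print(out.val, out.dot)  # value ≈ 2.5708, df/dx = sin(pi/2) = 1.0
```

Seeding $(\dot x, \dot y) = (v_x, v_y)$ instead of a unit vector computes the JVP $Jv$ in the same single pass.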
Reverse Mode: Backpropagation
Adjoint variables (sensitivities): $\bar v = \partial f / \partial v$ (how much the output $f$ changes per unit change in $v$, holding everything else fixed).
Reverse accumulation: starting from $\bar f = 1$, propagate adjoints backward through the graph using the adjoint rules:
For each operation $v = g(u, w)$: $\bar u \mathrel{+}= \bar v \cdot \partial g / \partial u$ and $\bar w \mathrel{+}= \bar v \cdot \partial g / \partial w$.
(Using $\mathrel{+}=$: each node accumulates contributions from all downstream nodes that use it.)
Algorithm: (1) Forward pass: execute program, storing all intermediate values. (2) Backward pass: traverse graph in reverse topological order, computing and accumulating adjoints.
Cost: one forward pass + one backward pass (typically $2$–$3\times$ the forward pass cost). Computes the full gradient $\nabla f \in \mathbb{R}^n$ for a scalar output — one backward pass regardless of $n$.
Vector-Jacobian product (VJP): for a cotangent $u \in \mathbb{R}^m$, the VJP $u^\top J$ can be computed in one backward pass by seeding $\bar f = u$.
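The two-pass algorithm can be sketched with an explicit tape (illustrative, not any framework's API). The forward pass records, for each operation, its output index and the local derivatives with respect to each input; the backward pass walks the tape in reverse, accumulating adjoints with `+=`:

```python
import math

def f_with_tape(x, y):
    # Forward pass for f(x, y) = (x + y) * sin(y), recording the tape.
    vals = [x, y]                       # v0 = x, v1 = y
    tape = []
    vals.append(x + y)                  # v2 = x + y
    tape.append((2, [(0, 1.0), (1, 1.0)]))
    vals.append(math.sin(y))            # v3 = sin(y)
    tape.append((3, [(1, math.cos(y))]))
    vals.append(vals[2] * vals[3])      # v4 = v2 * v3
    tape.append((4, [(2, vals[3]), (3, vals[2])]))
    return vals, tape

def backward(vals, tape):
    adj = [0.0] * len(vals)
    adj[-1] = 1.0                       # seed: d(output)/d(output) = 1
    for out_i, pulls in reversed(tape): # reverse topological order
        for in_i, local in pulls:
            adj[in_i] += adj[out_i] * local  # accumulate, never overwrite
    return adj

vals, tape = f_with_tape(1.0, math.pi / 2)
adj = backward(vals, tape)
print(adj[0], adj[1])  # df/dx = 1.0, df/dy = 1.0 (cos(pi/2) = 0 kills one term)
```

Seeding `adj[-1]` with a cotangent value other than 1 turns the same backward pass into a VJP.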
Memory-Compute Tradeoffs
Rematerialization (gradient checkpointing): reverse mode must store all intermediate values from the forward pass to compute adjoints. For a depth-$L$ network: $O(L)$ memory. Gradient checkpointing stores only $O(\sqrt{L})$ values and recomputes the rest during backward, reducing memory to $O(\sqrt{L})$ at the cost of roughly one extra forward pass.
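A minimal sketch of the recompute-from-checkpoint idea, on a chain of identical scalar "layers" (the layer function `g` and the store-every-`k` schedule are illustrative assumptions):

```python
import math

def g(v):        # one "layer" (illustrative elementwise step)
    return math.tanh(v) + v

def grad_g(v):   # d g / d v
    return (1.0 - math.tanh(v) ** 2) + 1.0

L, k = 100, 10

# Plain reverse mode: store all L+1 activations.
acts = [0.5]
for _ in range(L):
    acts.append(g(acts[-1]))
full_stored = len(acts)

# Checkpointing: keep only every k-th activation during the forward pass.
ckpts = {0: 0.5}
v = 0.5
for i in range(1, L + 1):
    v = g(v)
    if i % k == 0:
        ckpts[i] = v

# Backward pass: rematerialize each layer input from the nearest checkpoint.
adj = 1.0
for i in reversed(range(L)):
    base = (i // k) * k                # nearest checkpoint at or before i
    u = ckpts[base]
    for _ in range(i - base):          # recompute within the k-layer block
        u = g(u)
    adj *= grad_g(u)                   # chain rule through layer i

print(full_stored, len(ckpts))  # 101 activations stored vs 11 checkpoints
```

The recomputation in the inner loop is the extra forward work traded for the roughly $k\times$ memory reduction.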
Jacobian-free Newton-Krylov: for implicit functions $F(u) = 0$, the Jacobian $J = \partial F / \partial u$ is never stored. GMRES only needs matrix-vector products $Jv$, each computed with one forward-mode pass (a JVP), so the Newton system $J \, \delta u = -F(u)$ is solved without forming $J$.
Second-order derivatives (Hessians): compute by differentiating the backward pass. Hessian-vector product $Hv = \nabla_x\big(\nabla f(x)^\top v\big)$: one reverse pass (compute $\nabla f$) plus one forward-mode pass over it (directional derivative in direction $v$) — no full Hessian stored. JAX implements this forward-over-reverse pattern efficiently as jax.jvp(jax.grad(f), (x,), (v,)).
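The forward-over-reverse structure can be sketched in miniature. Here the reverse-mode gradient is hand-written for illustration (a real AD system would generate it), and a tiny dual number pushes the direction $v$ through it:

```python
# For f(x) = sum(x_i^4): grad f = 4 x^3 elementwise, H = diag(12 x^2).
# Pushing duals (x_i, v_i) through the gradient computes H @ v directly.
class Dual:
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __mul__(self, other):  # product rule on the tangent component
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

def grad_f(x):
    # The gradient reverse mode would produce for f(x) = sum(x_i^4),
    # written with overloaded ops so dual numbers flow through it.
    four = Dual(4.0)
    return [four * xi * xi * xi for xi in x]

def hvp(x, v):
    # Forward mode over grad_f: the tangent of the gradient is exactly H @ v,
    # with no n-by-n Hessian ever materialized.
    duals = grad_f([Dual(xi, vi) for xi, vi in zip(x, v)])
    return [d.dot for d in duals]

print(hvp([1.0, 2.0], [1.0, 0.5]))  # diag(12, 48) @ (1, 0.5) = [12.0, 24.0]
```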
Symbolic vs Numerical vs Automatic Differentiation
| Method | Result | Error | Cost | Use case |
|---|---|---|---|---|
| Symbolic | Closed-form expression | Exact | Exponential expression growth | CAS (Mathematica) |
| Numerical (FD) | Approximation | $O(h)$ truncation + $O(\varepsilon/h)$ rounding | $n+1$ evaluations | Gradient checking |
| AD Forward | Numeric derivative values | Exact (to machine precision) | $n\times$ forward cost | JVPs, few inputs |
| AD Reverse | Numeric derivative values | Exact (to machine precision) | $\sim 2$–$3\times$ forward cost | Gradients for ML |
Worked Example
Example 1: Reverse Mode for $f(x, y) = (x + y) \cdot \sin(y)$
Forward pass at $(x, y) = (1, \pi/2)$: $v_1 = x + y = 1 + \pi/2$, $v_2 = \sin(y) = 1$, $v_3 = v_1 v_2 = 1 + \pi/2 \approx 2.571$.
Backward pass with $\bar v_3 = 1$:
- $\bar v_1 = \bar v_3 \cdot v_2 = 1$, $\bar v_2 = \bar v_3 \cdot v_1 = 1 + \pi/2$ (from $v_3 = v_1 v_2$)
- $\bar x = \bar v_1 = 1$; $\bar y = \bar v_1 + \bar v_2 \cos(y) = 1 + (1 + \pi/2) \cdot 0 = 1$ (from $v_1 = x + y$ and $v_2 = \sin(y)$, with $\cos(\pi/2) = 0$)
Result: $\partial f / \partial x = \sin(y) = 1$ ✓; $\partial f / \partial y = \sin(y) + (x + y)\cos(y) = 1$ ✓.
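The hand-derived values can be sanity-checked against central finite differences (exactly the gradient-checking technique discussed below):

```python
import math

# f(x, y) = (x + y) * sin(y); both partials at (1, pi/2) should be 1.
def f(x, y):
    return (x + y) * math.sin(y)

h = 1e-6
x, y = 1.0, math.pi / 2
dfdx = (f(x + h, y) - f(x - h, y)) / (2 * h)  # central difference in x
dfdy = (f(x, y + h) - f(x, y - h)) / (2 * h)  # central difference in y
print(dfdx, dfdy)  # both ≈ 1.0
```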
Example 2: Memory Scaling with Gradient Checkpointing
A transformer with $L$ layers: storing all activations for reverse mode uses $O(L \cdot \text{batch} \cdot \text{seq} \cdot d_{\text{model}})$ memory. For batch size 32 with 4096-token sequences in float16, activation memory on a large model runs to tens of GB — GPU OOM.
Gradient checkpointing ($O(\sqrt{L})$ checkpoints): store every 10th layer, recompute the intermediate 10-layer blocks during backward. Memory drops to roughly one-tenth of the full activation footprint; extra forward compute: about one additional forward pass ($\sim 33\%$ more). This is the standard technique in LLM training — GPT-3, LLaMA, and essentially all large models use it.
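A back-of-envelope version of this memory calculation. The dimensions and the per-layer tensor multiplier below are hypothetical round numbers for a GPT-3-scale model, not measured figures:

```python
import math

# Rough activation-memory estimate. tensors_per_layer is an assumed
# multiplier covering attention and MLP intermediates; real counts depend
# on the architecture and which tensors the framework saves.
def activation_bytes(layers, batch, seq, d_model,
                     bytes_per_el=2, tensors_per_layer=8):
    return layers * batch * seq * d_model * bytes_per_el * tensors_per_layer

L, batch, seq, d = 96, 32, 4096, 12288     # assumed GPT-3-like dimensions
full = activation_bytes(L, batch, seq, d)
# Checkpointing keeps ~sqrt(L) layers' worth of activations resident.
ckpt = activation_bytes(math.isqrt(L) + 1, batch, seq, d)
print(f"full: {full / 2**30:.0f} GiB, checkpointed: {ckpt / 2**30:.0f} GiB")
```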
Example 3: Jacobian-Vector Products for Neural Tangent Kernels
The Neural Tangent Kernel requires Jacobians $J = \partial f(x; \theta) / \partial \theta \in \mathbb{R}^{m \times p}$. For networks with millions of parameters $p$, storing $J$ explicitly is impossible.
Instead, use JVPs: $Jv$ for any vector $v$ costs one forward-mode pass. The NTK entry $\Theta(x, x') = J(x) \, J(x')^\top$ can be computed by composing VJPs (which extract rows of $J(x')$) with JVPs, without ever materializing $J$. JAX's jax.linearize and jax.vjp enable this pattern efficiently.
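The compose-VJP-with-JVP pattern, sketched for a toy model where the answer is known in closed form (all function names here are illustrative, not JAX's API): for the linear scalar model $f(x; \theta) = \sum_j \theta_j x_j$, the parameter Jacobian is $x^\top$, so $\Theta(x, x') = \langle x, x' \rangle$.

```python
# Toy model: f(x; theta) = sum_j theta_j * x_j, so df/dtheta = x.
def vjp(x, cotangent):
    # Reverse pass: cotangent * (df/dtheta) = cotangent * x (a row of J).
    return [cotangent * xj for xj in x]

def jvp(x, tangent):
    # Forward pass: directional derivative of f along tangent in theta-space.
    return sum(xj * tj for xj, tj in zip(x, tangent))

def ntk_entry(x1, x2):
    # Theta(x1, x2) = J(x1) J(x2)^T: the VJP extracts the row of J(x2),
    # then one JVP contracts it against J(x1); J itself is never stored.
    row = vjp(x2, 1.0)
    return jvp(x1, row)

print(ntk_entry([1.0, 2.0], [3.0, 4.0]))  # <x1, x2> = 11.0
```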
Where Your Intuition Breaks
Reverse-mode AD is often described as computing "all gradients in one backward pass." This holds for scalar outputs ($m = 1$), where one backward pass gives the full gradient $\nabla f \in \mathbb{R}^n$. But for matrix-valued outputs — full Jacobians — reverse mode computes one row per backward pass ($m$ total), while forward mode computes one column per forward pass ($n$ total). For $m \ll n$, reverse mode wins; for $n \ll m$, forward mode wins. Nearly all ML training has $m = 1$ (scalar loss), making reverse mode dominant — but per-sample gradients, meta-learning Jacobians, or NTK computations require the full $J$, and the cost scales accordingly.
Reverse mode is exactly backpropagation. The backpropagation algorithm, discovered independently several times (Linnainmaa 1970, Werbos 1974, Rumelhart-Hinton-Williams 1986), is reverse-mode automatic differentiation applied to neural networks. There is nothing special about neural networks in the algorithm — it applies to any differentiable program. PyTorch's autograd, JAX's jax.grad, and TensorFlow's tf.GradientTape are all implementations of reverse-mode AD on dynamic computation graphs. Understanding AD is understanding backprop precisely.
Forward vs reverse mode is a cost tradeoff determined by input/output dimensions. For a function $f: \mathbb{R}^n \to \mathbb{R}^m$: forward mode computes one column of $J$ per pass ($n$ passes for the full Jacobian); reverse mode computes one row per pass ($m$ passes for the full Jacobian). Neural network training: $n$ in the millions of parameters, $m = 1$ (scalar loss). So reverse mode computes the full gradient in one pass; forward mode would require millions of passes. For $n \ll m$ (e.g., computing the Jacobian of a physics simulator output with few inputs), forward mode wins. This is why scientific ML (PINNs, neural ODEs for control) sometimes prefers forward mode.
Gradient checking is noisy — use it sparingly and with the right step size. The forward-difference gradient has truncation error $O(h)$ and rounding error $O(\varepsilon / h)$. The optimal $h \approx \sqrt{\varepsilon} \approx 10^{-8}$ for float64 gives absolute error $\approx 10^{-8}$. For float32: $\varepsilon \approx 10^{-7}$, so the best achievable error is only $\approx 3 \times 10^{-4}$. Gradient checking in float32 with a tiny $h$ like $10^{-8}$ gives entirely rounding-error-dominated results — the check is meaningless. Always run gradient checks in float64, and only use them to debug custom operators, not as a routine test.
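The truncation/rounding tradeoff is easy to see empirically in float64. Forward-difference error for $f = \sin$ at $x = 1$:

```python
import math

# Forward-difference derivative of sin at x = 1; true answer is cos(1).
# Truncation error O(h) shrinks as h shrinks; rounding error O(eps/h) grows;
# the sweet spot sits near sqrt(eps) ~ 1e-8 in float64.
def fd_error(h, x=1.0):
    approx = (math.sin(x + h) - math.sin(x)) / h
    return abs(approx - math.cos(x))

for h in (1e-2, 1e-5, 1e-8, 1e-12):
    print(f"h={h:.0e}  error={fd_error(h):.2e}")
# Typical behavior: the error falls to ~1e-8 near h = 1e-8, then grows
# again at h = 1e-12 as rounding error dominates.
```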