
Bridge: Autodiff in PyTorch/JAX, Mixed Precision & Numerical Stability in Training

PyTorch and JAX implement reverse-mode automatic differentiation with different design philosophies. PyTorch records computation graphs eagerly during the forward pass. JAX traces Python functions into an intermediate representation that can be just-in-time compiled with XLA. Mixed-precision training, gradient clipping, and numerical stability practices follow directly from the theory of floating-point arithmetic and condition numbers.

Concepts

PyTorch and JAX both compute exact gradients automatically — but by different strategies. PyTorch builds the computation graph as operations execute (eager mode), making debugging natural but precluding graph-level optimization. JAX traces through the function first to build an XLA computation graph, then compiles and JIT-accelerates it. The same AD theory underlies both; the difference is when the graph is built and how it is executed.

PyTorch Autograd: Eager Tape-Based AD

PyTorch builds a computation graph dynamically. Each operation on a tensor that requires gradients stores a grad_fn attribute on its output tensor, and that grad_fn points back to the nodes that produced the operation's inputs. The graph is effectively a tape: a record of operations in forward-pass order.

requires_grad: only tensors with requires_grad=True participate in gradient computation. An operation is recorded only when at least one of its inputs requires gradients.

.backward(): traverses the graph in reverse topological order, calling each grad_fn to compute and accumulate gradients. For a scalar loss, loss.backward() populates .grad for every leaf tensor with requires_grad=True.

The reverse topological ordering is forced by the chain rule: downstream gradients must be fully accumulated before upstream gradients can be computed. This means the tape built during the forward pass encodes the exact traversal order for the backward pass — there is no search, no interpretation, just replay in reverse.
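The following toy implements scalar reverse-mode AD in plain Python to illustrate replaying a tape in reverse. It is a simplified model, not PyTorch's actual C++ engine. The `Var` class and the `backward` function are hypothetical names, and only `+` and `*` are supported. Each operation appends its output and local derivatives to a shared tape. `backward` walks that tape from the end, so every node's gradient is complete before it is pushed to that node's inputs.

```python
class Var:
    """A scalar recorded on a shared tape; .grad is filled in by backward()."""

    def __init__(self, value, tape):
        self.value = value
        self.grad = 0.0
        self.tape = tape

    def __add__(self, other):
        out = Var(self.value + other.value, self.tape)
        # Local derivatives: d(out)/d(self) = 1, d(out)/d(other) = 1
        self.tape.append((out, [(self, 1.0), (other, 1.0)]))
        return out

    def __mul__(self, other):
        out = Var(self.value * other.value, self.tape)
        # Local derivatives: d(out)/d(self) = other, d(out)/d(other) = self
        self.tape.append((out, [(self, other.value), (other, self.value)]))
        return out


def backward(loss):
    """Replay the tape in reverse: the tape is already in topological order,
    so reversing it guarantees each output gradient is final before use."""
    loss.grad = 1.0
    for out, parents in reversed(loss.tape):
        for parent, local_grad in parents:
            parent.grad += out.grad * local_grad  # accumulate, like .grad in PyTorch


tape = []
x, y = Var(3.0, tape), Var(4.0, tape)
f = x * y + x          # f = xy + x, so df/dx = y + 1 = 5 and df/dy = x = 3
backward(f)
print(x.grad, y.grad)  # 5.0 3.0
```

Because `x` is used twice, its gradient arrives along two paths and is summed. This is the same accumulation that PyTorch performs into `.grad`.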

torch.no_grad(): disables graph construction entirely. For inference this avoids storing intermediate activations for a backward pass, which reduces memory and compute by roughly $\sim 30\%$.
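The sketch below demonstrates the behavior described above on toy tensors: `requires_grad`, the recorded `grad_fn`, `.backward()` filling `.grad` on leaves, and `torch.no_grad()`. It uses only standard `torch` calls.

```python
import torch

w = torch.tensor([2.0, -1.0], requires_grad=True)  # leaf tensor, tracked
x = torch.tensor([3.0, 5.0])                        # requires_grad=False, not tracked

loss = (w * x).sum() ** 2      # each op records a grad_fn on its output
print(loss.grad_fn)            # e.g. <PowBackward0 ...>: the last recorded op

loss.backward()                # reverse traversal; fills .grad on leaves
print(w.grad)                  # d/dw (w.x)^2 = 2 (w.x) x = 2 * 1 * [3, 5] = [6., 10.]
print(x.grad)                  # None: x never required gradients

with torch.no_grad():          # nothing inside this block is recorded
    y = (w * x).sum()
print(y.requires_grad, y.grad_fn)  # False None
```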

Gradient accumulation: .backward() adds into each leaf's .grad rather than overwriting it, so optimizer.zero_grad() must be called before the backward pass (or passes) of each optimizer step. Skipping it deliberately across several microbatches accumulates their gradients, which simulates a larger batch size.
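The following sketch uses a toy linear model with MSE loss to check that accumulating over four microbatches reproduces the full-batch gradient. Each microbatch loss is divided by the number of microbatches, which makes the sum match a mean-reduced loss. Without that division the accumulated gradient would be four times too large.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()              # mean over the batch
X, Y = torch.randn(8, 4), torch.randn(8, 1)

# Reference: gradient of the mean loss over the full batch of 8
opt.zero_grad()
loss_fn(model(X), Y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 microbatches of 2 examples each
accum_steps = 4
opt.zero_grad()                                          # once per optimizer step
for xb, yb in zip(X.chunk(accum_steps), Y.chunk(accum_steps)):
    (loss_fn(model(xb), yb) / accum_steps).backward()    # adds into .grad
print(torch.allclose(model.weight.grad, full_grad, atol=1e-6))  # True
opt.step()                                               # one update for all 4 microbatches
```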

Custom gradients: torch.autograd.Function allows registering custom forward and backward functions. The backward method receives the gradient with respect to each output and must return one gradient per forward input. This is essential for non-differentiable operations such as rounding (use a straight-through estimator), external solvers (implicit differentiation), and quantized operations.
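Here is a straight-through estimator for rounding, written as a `torch.autograd.Function`. The class name `RoundSTE` is hypothetical. The forward pass rounds, and the backward pass returns the upstream gradient unchanged. That is the usual STE approximation; the true derivative of rounding is zero almost everywhere.

```python
import torch

class RoundSTE(torch.autograd.Function):
    """Round in the forward pass; straight-through gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; pretend d round(x)/dx = 1
        return grad_output


x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = RoundSTE.apply(x)                  # custom Functions are invoked via .apply
(y * torch.tensor([1.0, 2.0, 3.0])).sum().backward()
print(y.detach())   # values [0., 2., -1.]
print(x.grad)       # [1., 2., 3.]: the upstream gradient passed straight through
```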

JAX: Functional Transforms on Pure Functions

JAX applies function transforms to pure Python functions that use JAX arrays:

  • jax.grad(f): returns a function that computes $\nabla_x f(x)$; only for scalar-valued $f$
  • jax.jvp(f, primals, tangents): forward-mode AD; returns $(f(x), J_f(x)\,v)$
  • jax.vjp(f, *primals): reverse-mode AD; returns $f(x)$ together with the pullback $\lambda \mapsto \lambda^T J_f(x)$
  • jax.jacfwd(f) / jax.jacrev(f): full Jacobian via forward or reverse mode (jacfwd is usually cheaper for tall Jacobians with more outputs than inputs, jacrev for wide ones; jax.jacobian is an alias of jacrev)
  • jax.jit(f): just-in-time compile with XLA for CPU/GPU/TPU
  • jax.vmap(f): vectorize $f$ over a batch dimension (auto-batching)
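The sketch below applies each transform to a toy function $f: \mathbb{R}^3 \to \mathbb{R}^2$ and checks the transforms against one another. The JVP should equal $J_f(x)\,v$, the VJP should equal $u^T J_f(x)$, and `jacfwd` and `jacrev` should agree. The functions `f` and `loss` are illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):                        # R^3 -> R^2
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

def loss(x):                     # scalar-valued, so jax.grad applies
    return jnp.sum(f(x) ** 2)

x = jnp.array([1.0, 2.0, 0.5])
v = jnp.array([1.0, 0.0, 0.0])   # tangent (input space)
u = jnp.array([1.0, 1.0])        # cotangent (output space)

g = jax.grad(loss)(x)                  # gradient, shape (3,)
y, jv = jax.jvp(f, (x,), (v,))         # forward mode: J_f(x) v, shape (2,)
y, vjp_fn = jax.vjp(f, x)              # reverse mode: returns the pullback
(uJ,) = vjp_fn(u)                      # u^T J_f(x), one entry per primal
J = jax.jacfwd(f)(x)                   # full Jacobian, shape (2, 3)

print(jnp.allclose(jv, J @ v), jnp.allclose(uJ, u @ J))   # True True
print(jnp.allclose(J, jax.jacrev(f)(x)))                  # True
batched = jax.vmap(loss)(jnp.stack([x, 2 * x]))           # shape (2,)
fast_grad = jax.jit(jax.grad(loss))                       # compiled gradient
print(batched.shape, jnp.allclose(fast_grad(x), g))       # (2,) True
```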

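Composability: transforms nest freely. jax.grad(jax.jit(f)) works, jax.vmap(jax.grad(f)) computes per-sample gradients, and jax.jit(jax.vmap(jax.grad(loss))) compiles a batched per-sample gradient computation. Composing transforms is the key design principle, but the order matters: jax.grad(jax.vmap(loss)) fails, because the vmapped loss returns a vector rather than a scalar.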

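Side effects: JAX requires pure functions. Arrays are immutable, so functional updates such as x.at[i].set(v) replace in-place mutation. Python side effects inside a traced function run only once, at trace time. jax.lax.while_loop and jax.lax.scan express loops that compile to a single XLA loop instead of being unrolled by Python during tracing.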
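The following sketch shows what purity means in practice. It uses a functional array update, a Python side effect that runs only at trace time, and a `jax.lax.scan` loop. The `trace_log` list is a deliberately impure, hypothetical helper that counts traces.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)
# x[0] = 1.0 would raise a TypeError: JAX arrays are immutable.
x2 = x.at[0].set(1.0)             # returns a new array; x is unchanged
print(x, x2)                      # [0. 0. 0.] [1. 0. 0.]

trace_log = []                    # Python-side state (impure on purpose)

@jax.jit
def step(v):
    trace_log.append("traced")    # runs while tracing, not on every call
    return v * 2.0

step(jnp.ones(3))
step(jnp.ones(3))                 # same shape/dtype: reuses the compiled version
print(len(trace_log))             # 1

def cumulative_sum(xs):
    def body(carry, x):           # (carry, x) -> (new_carry, per-step output)
        carry = carry + x
        return carry, carry
    _, ys = jax.lax.scan(body, jnp.zeros_like(xs[0]), xs)  # one compiled loop
    return ys

print(cumulative_sum(jnp.arange(1.0, 5.0)))   # [ 1.  3.  6. 10.]
```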

Mixed Precision Training

Standard recipe (Micikevicius et al. 2018; variants of it are used in most modern LLM training):

  1. Forward pass in float16/bfloat16: compute activations in half precision. This halves activation memory and gives roughly 2–8× speedups on modern GPUs, whose Tensor Cores operate on float16/bfloat16.

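  2. Loss scaling: multiply the loss by a scale factor $S$ (e.g. $S = 2^{10}$) before the backward pass. By the chain rule every gradient is then scaled by $S$, which moves small values out of the float16 underflow region. Magnitudes below the smallest normal value $2^{-14} \approx 6 \times 10^{-5}$ lose precision as subnormals, and those below $2^{-24} \approx 6 \times 10^{-8}$ become zero. After the backward pass, divide the gradients by $S$ before the optimizer step.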

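  3. Float32 master weights: maintain float32 copies of the weights for the optimizer step. Weight updates $\Delta\theta \sim \eta \nabla L$ are often $\sim 10^{-6}$. That is far below float16's spacing of $2^{-10} \approx 10^{-3}$ near 1.0, so adding such an update to a float16 weight rounds back to the same weight. Float32, with a spacing of $2^{-23} \approx 1.2 \times 10^{-7}$ near 1.0, accumulates these updates accurately.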

  4. Dynamic loss scaling: if gradients contain Inf/NaN (overflow), halve $S$ and skip the step. If $T$ consecutive steps have no overflow, double $S$. PyTorch automates this in torch.cuda.amp.GradScaler, which recent releases spell torch.amp.GradScaler("cuda").
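The numbers in this recipe can be checked with NumPy's float16 type. The sketch below uses CPU NumPy as a stand-in for GPU float16 arithmetic, which is a simplifying assumption. It shows three things: a tiny gradient flushing to zero and then surviving once scaled, a $10^{-6}$ update vanishing against a float16 weight of 1.0, and a minimal dynamic loss scaler. `DynamicLossScaler` is a hypothetical class that follows the halve-and-skip / double-after-$T$ rule. It uses $T = 3$ so that the doubling is visible. PyTorch's GradScaler instead defaults to an initial scale of $2^{16}$ and a growth interval of 2000 steps.

```python
import numpy as np

# 1) Underflow and loss scaling: float16's smallest subnormal is 2**-24 ~ 6e-8
true_grad = 1e-8
print(np.float16(true_grad))              # 0.0: flushed to zero
S = 2.0 ** 10
scaled = np.float16(true_grad * S)        # ~1e-5: representable as a subnormal
print(float(scaled) / S)                  # ~1e-8, recovered after unscaling in higher precision

# 2) Master weights: float16 spacing near 1.0 is 2**-10, float32's is 2**-23
w16, w32 = np.float16(1.0), np.float32(1.0)
print(w16 + np.float16(1e-6) == w16)      # True: the update is lost
print(w32 + np.float32(1e-6) == w32)      # False: float32 keeps it

# 3) Dynamic loss scaling (hypothetical minimal scaler)
class DynamicLossScaler:
    def __init__(self, scale, growth_interval):
        self.scale, self.growth_interval, self.good_steps = scale, growth_interval, 0

    def update(self, grads_fp16):
        """Adjust the scale; return True if the optimizer step should be taken."""
        if not np.all(np.isfinite(grads_fp16)):       # overflow: Inf/NaN present
            self.scale /= 2.0
            self.good_steps = 0
            return False                               # skip this step
        self.good_steps += 1
        if self.good_steps == self.growth_interval:   # T clean steps in a row
            self.scale *= 2.0
            self.good_steps = 0
        return True

scaler = DynamicLossScaler(scale=2.0 ** 16, growth_interval=3)
raw_grad = np.array([3.0, 1e-6])          # "true" gradients, kept in float64
history = []
for _ in range(6):
    with np.errstate(over="ignore"):      # silence the overflow warning on the cast
        g16 = (raw_grad * scaler.scale).astype(np.float16)  # scaled backward in float16
    took_step = scaler.update(g16)
    history.append((scaler.scale, took_step))
print(history)   # scale oscillates between 2**14 and 2**15 once it settles
```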

bfloat16 vs float16:

  • float16: 5 exponent bits, 10 mantissa bits; max 65504, so large activations can overflow
  • bfloat16: 8 exponent bits, the same range as float32, so it rarely overflows; 7 mantissa bits (vs 10 for float16) give less precision but more stable LLM training
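`torch.finfo` exposes these limits directly. This sketch compares the three formats and shows a value that overflows in float16 but stays finite, though coarsely rounded, in bfloat16.

```python
import torch

for dt in (torch.float16, torch.bfloat16, torch.float32):
    fi = torch.finfo(dt)
    # max: largest finite value; tiny: smallest normal; eps: spacing just above 1.0
    print(dt, fi.max, fi.tiny, fi.eps)

big = torch.tensor(70000.0)
print(big.to(torch.float16))    # inf: exceeds float16's max of 65504
print(big.to(torch.bfloat16))   # finite, rounded to 70144 (only 8 significant bits)
```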

Gradient Clipping and Exploding Gradients

For deep networks and RNNs, the gradient norm can grow exponentially with depth: $\|\nabla L\| = O(\|W\|^L)$ for a depth-$L$ network whose weight matrices have spectral norm $\|W\|$.
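The sketch below assumes a deep linear network. In that setting, backpropagation multiplies the gradient by $W^T$ once per layer. Real networks add nonlinearities, so the $O(\|W\|^L)$ bound need not be tight for them. The code builds $W = sQ$ with $Q$ orthogonal, so every singular value equals $s$, and confirms that the gradient norm grows by exactly $s^{\text{depth}}$.

```python
import numpy as np

rng = np.random.default_rng(0)
d, depth, s = 16, 50, 1.1                          # width, depth, spectral norm of W
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))   # random orthogonal matrix
W = s * Q                                          # all singular values equal s

g = rng.standard_normal(d)                         # gradient arriving at the output
norms = [np.linalg.norm(g)]
for _ in range(depth):                             # layer h_out = W h_in sends g to W^T g
    g = W.T @ g
    norms.append(np.linalg.norm(g))

print(norms[-1] / norms[0], s ** depth)            # both about 117.39
```

With $s = 0.9$ the same loop shows the mirror-image problem: the gradient shrinks geometrically and vanishes.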

Global gradient clipping: if $\|\nabla L\|_2 > \tau$, where the norm is taken over all parameters jointly, rescale the gradient:

$$\hat g = \nabla L \cdot \min\!\left(1, \frac{\tau}{\|\nabla L\|_2}\right).$$

This preserves the gradient direction while bounding its magnitude. It is standard in transformer training, with $\tau = 1.0$ (GPT-2, GPT-3, LLaMA). With mixed precision, clip after unscaling, so that $\tau$ applies to the true gradient rather than to the gradient multiplied by $S$.
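The following sketch checks `torch.nn.utils.clip_grad_norm_` against the formula above on a toy model. The loss is inflated by a factor of 100 only to make clipping likely. The second part shows the unscale-then-clip ordering for mixed precision. `amp_step` is a hypothetical helper that needs a CUDA device, so here it is only defined, not executed. `clip_grad_norm_` adds a small epsilon to the norm, so the two results agree to a tolerance rather than bit-for-bit.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
x, y = torch.randn(32, 8), torch.randn(32, 1)
tau = 1.0

loss = 100.0 * torch.nn.functional.mse_loss(model(x), y)   # inflated on purpose
loss.backward()
grads = [p.grad.clone() for p in model.parameters()]

# Manual global clipping: one norm over all parameters, one shared factor
total = torch.sqrt(sum((g ** 2).sum() for g in grads))
manual = [g * min(1.0, tau / total.item()) for g in grads]

# Library version: clips .grad in place and returns the pre-clipping norm
returned = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=tau)
print(torch.allclose(returned, total))
print(all(torch.allclose(p.grad, m, atol=1e-5) for p, m in zip(model.parameters(), manual)))


def amp_step(model, optimizer, scaler, x, y, tau=1.0):
    """Mixed-precision step (CUDA only): unscale before clipping."""
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()        # gradients are multiplied by S
    scaler.unscale_(optimizer)           # divide by S in place; records Inf/NaN
    torch.nn.utils.clip_grad_norm_(model.parameters(), tau)  # clip true gradients
    scaler.step(optimizer)               # skipped if Inf/NaN was found
    scaler.update()                      # dynamic adjustment of S
```

On a CUDA machine, `amp_step` would be called with a `torch.cuda.amp.GradScaler()` and with the model and data moved to the GPU.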

Gradient norm monitoring: log $\|\nabla L\|_2$ at every step. A healthy training run shows stable gradient norms, while sudden spikes indicate a numerical instability or a bad batch. Chronically high gradient norms suggest the learning rate is too large or the model is approaching a sharp region of the loss landscape.
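Here is a minimal monitoring sketch. It assumes the global norm is already available at each step, for example as the value returned by `torch.nn.utils.clip_grad_norm_`. `GradNormMonitor` is hypothetical, and its thresholds (EMA factor 0.9, spike ratio 5) are illustrative rather than recommended defaults.

```python
import math

class GradNormMonitor:
    """Flags norms that are non-finite or exceed `ratio` times an exponential
    moving average (EMA) of recent healthy norms."""

    def __init__(self, beta=0.9, ratio=5.0):
        self.beta, self.ratio, self.ema = beta, ratio, None

    def check(self, grad_norm):
        if not math.isfinite(grad_norm):
            return "nonfinite"
        if self.ema is not None and grad_norm > self.ratio * self.ema:
            return "spike"              # keep the spike out of the EMA
        if self.ema is None:
            self.ema = grad_norm
        else:
            self.ema = self.beta * self.ema + (1 - self.beta) * grad_norm
        return "ok"

monitor = GradNormMonitor()
norms = [1.0, 1.1, 0.9, 1.0, 12.0, 1.0, float("nan")]
statuses = [monitor.check(n) for n in norms]
print(statuses)   # ['ok', 'ok', 'ok', 'ok', 'spike', 'ok', 'nonfinite']
```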

Worked Example

Example 1: Debugging NaN Loss in Mixed Precision Training

Sequence of steps when loss becomes NaN:

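  1. Check the GradScaler history (scaler.get_scale()). If the loss scale has collapsed below $2^{-10}$, overflow is occurring at every scale. That usually means the forward pass itself produces Inf/NaN, and at such a small scale genuine gradients underflow as well. Switching to bfloat16, whose range matches float32, often fixes this.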

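  2. Find the first layer with NaN activations: register a forward hook on every submodule (iterate over model.named_modules()) that checks torch.isnan(output).any() and logs the module name. The hook must return None, because a forward hook that returns a value replaces the module's output.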

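  3. Common culprits: LayerNorm receiving large inputs, which makes the mean/variance computation overflow. Another is a softmax computed as exp(x) / sum(exp(x)) without subtracting the maximum logit; in float16 this overflows once logits exceed $\ln 65504 \approx 11$, whereas the built-in torch.softmax subtracts the maximum. A third is log or sqrt of near-zero values in loss functions.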

  4. Fix: (a) add a numerical epsilon inside logs, e.g. loss = -torch.log(p + 1e-6), or compute log-probabilities directly with torch.log_softmax; (b) compute layer norm reductions in float32; (c) temporarily tighten gradient clipping (lower $\tau$).
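The following sketch combines steps 2 and 3 on a toy model. `NaiveSoftmax` is a hypothetical module that deliberately omits max-subtraction. The first linear layer's weights are set so that its logits reach 200, which overflows `exp` even in float32. Hooks are registered on every named submodule and return `None`, so they observe without changing anything. The last lines show the float16 threshold: $e^{11} \approx 59874$ is representable, while $e^{12} \approx 162755$ is not.

```python
import torch
import torch.nn as nn

class NaiveSoftmax(nn.Module):
    """exp without max-subtraction: overflows for large logits (intentional bug)."""
    def forward(self, z):
        e = torch.exp(z)
        return e / e.sum(dim=-1, keepdim=True)

def add_nonfinite_hooks(model):
    """Record the names of submodules whose output contains NaN/Inf."""
    flagged = []
    def make_hook(name):
        def hook(module, inputs, output):     # returns None: output is untouched
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                flagged.append(name)
        return hook
    handles = [m.register_forward_hook(make_hook(name))
               for name, m in model.named_modules() if name]  # skip the root ("")
    return flagged, handles

model = nn.Sequential(nn.Linear(4, 4), NaiveSoftmax(), nn.Linear(4, 2))
with torch.no_grad():
    model[0].weight.fill_(10.0)   # logits = 10 * 4 * 5 = 200
    model[0].bias.zero_()
flagged, handles = add_nonfinite_hooks(model)
model(torch.full((1, 4), 5.0))
print(flagged)                    # ['1', '2']: the first entry is the culprit
for h in handles:
    h.remove()                    # detach the hooks when done debugging

t16 = torch.exp(torch.tensor([11.0, 12.0])).half()
print(t16)                        # 59872. and inf in float16
```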

Example 2: Per-Sample Gradients with JAX vmap

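Computing per-sample gradients is needed for differential privacy and influence functions. A naive PyTorch implementation loops over the batch and runs one backward pass per example, giving $O(n)$ sequential passes; PyTorch 2.x offers the same vmap-of-grad pattern through torch.func. With JAX: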

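```python
# in_axes=(None, 0, 0): share params across examples; map over axis 0 of x and y
per_sample_grads = jax.vmap(jax.grad(loss_single), in_axes=(None, 0, 0))(params, x_batch, y_batch)
```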

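jax.vmap vectorizes the gradient computation over the batch dimension, so there is no Python loop and the whole batch runs as one vectorized computation. Memory still grows with batch size, since the result holds one parameter-sized gradient per example. In practice this is typically far faster than the per-example loop, on the order of 10× at batch size 128, though the exact gain depends on model and hardware.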

DP-SGD (differential privacy): clip each per-sample gradient to norm at most $C$, then add Gaussian noise $\mathcal{N}(0, \sigma^2 C^2 I)$ to their sum. This requires per-sample gradient norms, which jax.vmap(jax.grad(loss_single)) computes efficiently.
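The sketch below covers both steps end to end. It assumes a toy linear model with squared-error loss; `loss_single` here is a hypothetical stand-in for the function in the snippet above. The values $C = 1$ and $\sigma = 1.1$ are illustrative. This is not a complete DP-SGD implementation, because privacy accounting and batch sampling are omitted.

```python
import jax
import jax.numpy as jnp

def loss_single(params, x, y):
    """Squared error of a linear model on ONE example (toy model)."""
    w, b = params
    return (jnp.dot(w, x) + b - y) ** 2

kx, ky, kn = jax.random.split(jax.random.PRNGKey(0), 3)
n, d = 128, 5
x_batch = jax.random.normal(kx, (n, d))
y_batch = jax.random.normal(ky, (n,))
params = (jnp.zeros(d), jnp.array(0.0))

# Share params (None); map over the leading axis of x and y (0, 0)
per_sample = jax.jit(jax.vmap(jax.grad(loss_single), in_axes=(None, 0, 0)))
gw, gb = per_sample(params, x_batch, y_batch)        # shapes (n, d) and (n,)

# Sanity check against a single-example gradient
gw0, gb0 = jax.grad(loss_single)(params, x_batch[0], y_batch[0])
print(gw.shape, jnp.allclose(gw[0], gw0), jnp.allclose(gb[0], gb0))

# DP-SGD aggregation: clip each example to norm C, sum, add N(0, sigma^2 C^2 I)
C, sigma = 1.0, 1.1
flat = jnp.concatenate([gw, gb[:, None]], axis=1)   # (n, d + 1) per-example gradients
norms = jnp.linalg.norm(flat, axis=1, keepdims=True)
clipped = flat * jnp.minimum(1.0, C / (norms + 1e-12))
noisy_sum = clipped.sum(axis=0) + sigma * C * jax.random.normal(kn, (d + 1,))
dp_grad = noisy_sum / n                              # noisy average used for the update
print(float(jnp.max(jnp.linalg.norm(clipped, axis=1))))   # at most C (up to rounding)
```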

Example 3: Custom Backward for Implicit Differentiation

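Consider a neural network layer that solves an optimization problem, $z^* = \arg\min_z g(x, z)$, where the forward pass runs an iterative solver. The backward pass needs $\partial z^*/\partial x$.

At the optimum, the stationarity condition $F(x, z^*) = \nabla_z g(x, z^*) = 0$ holds. By the implicit function theorem, $\partial z^*/\partial x = -(J_z F)^{-1} J_x F$ evaluated at $z = z^*$, where $J_z F = \nabla^2_{zz} g$ is the Hessian. For a root-finding or fixed-point layer that solves $g(x, z) = 0$ directly, as in DEQs, take $F = g$. In PyTorch: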

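```python
import torch

# Hypothetical objective: g(x, z) = sum(z**2 / 2 + z**4 / 4 - x * z).
# Its stationarity condition F(x, z) = grad_z g = z + z**3 - x = 0 defines z*(x).
def F(x, z):
    return z + z**3 - x

class ImplicitLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):  # x: 1-D tensor
        z = torch.zeros_like(x)
        for _ in range(50):  # iterative solver (Newton's method); autograd does not record it
            z = z - F(x, z) / (1 + 3 * z**2)
        ctx.save_for_backward(x, z)
        return z

    @staticmethod
    def backward(ctx, grad_output):
        x, z_star = ctx.saved_tensors
        J_z = torch.autograd.functional.jacobian(lambda z: F(x, z), z_star)
        J_x = torch.autograd.functional.jacobian(lambda xx: F(xx, z_star), x)
        # One linear solve: J_z^T v = grad_output
        v = torch.linalg.solve(J_z.T, grad_output)
        # Vector-Jacobian product: (dz*/dx)^T grad_output = -J_x^T v
        return -J_x.T @ v
```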

One linear solve in the backward pass replaces unrolling the solver (which could require hundreds of AD steps). This is used in DEQs, differentiable optimization layers (cvxpylayers), and physics simulations.

Connections

Where Your Intuition Breaks

Mixed precision training is often described as "train in float16 for speed, accumulate in float32 for accuracy." But the core difficulty is the gradient underflow problem. Gradient values for early layers in a deep network are often below float16's smallest normal value $2^{-14} \approx 6 \times 10^{-5}$, where they lose precision, and below $2^{-24} \approx 6 \times 10^{-8}$ they underflow to zero. Loss scaling multiplies the loss by a large constant before the backward pass so that gradients are scaled up, avoiding underflow. But if the scale is too large, gradients overflow to inf. Dynamic loss scaling adjusts the scale up and down based on overflow detection. This is not because the approach is elegant, but because no static scale works across all models and learning rate schedules. The float16/float32 split is the easy part; tuning loss scaling is where most mixed precision issues arise in practice.

💡Intuition

PyTorch and JAX make different tradeoffs between flexibility and performance. PyTorch's eager execution mode allows Python control flow (if/for) to affect the graph — easy to debug but no graph-level optimization. JAX traces through Python to build a static computation graph, then compiles with XLA — allows aggressive fusion, rematerialization, and cross-device parallelism. For research: PyTorch is more debuggable. For production training at scale: JAX/XLA tends to be faster due to better compiler optimization. Both implement the same AD theory; the difference is execution model.
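Control flow is where the two execution models differ most visibly. In the sketch below, which uses hypothetical toy functions, a Python `if` on an array value works eagerly but fails under `jax.jit`. The failure occurs because the value is abstract during tracing. Expressing the choice with `jnp.where` (or `jax.lax.cond`) keeps it inside the compiled graph. In PyTorch eager mode the Python version would simply run.

```python
import jax
import jax.numpy as jnp

def relu_python(x):
    if x > 0:                   # Python branch on an array value
        return x
    return 0.0 * x

def relu_traced(x):
    return jnp.where(x > 0, x, 0.0)   # the choice is part of the traced graph

print(relu_python(jnp.array(2.0)))    # 2.0: eager call, concrete value
try:
    jax.jit(relu_python)(jnp.array(2.0))
except Exception as err:              # a ConcretizationTypeError (or subclass) in recent JAX
    print(type(err).__name__)
print(jax.jit(relu_traced)(jnp.array(-3.0)))   # 0.0
```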

💡Intuition

Loss scaling is not a hack; it is a necessary correction for float16 gradients. Gradients of magnitude $\sim 10^{-5}$ fall below float16's smallest normal value ($2^{-14} \approx 6\times10^{-5}$) and lose precision as subnormals. Anything below $2^{-24} \approx 6\times10^{-8}$ flushes to zero. A loss scale of $S = 2^{10}$ raises every gradient's exponent by 10, moving most values back into the representable normal range. After the backward pass, dividing by $S$ restores the correct scale. The key insight is that the loss scale only needs to prevent underflow during the backward pass. The optimizer step happens on float32 master weights, where full precision is available.

⚠️Warning

torch.cuda.amp.autocast (torch.autocast("cuda") in current releases) does not make all operations float16. Autocast runs numerically sensitive operations in float32 automatically, for example layer norm, softmax, log, exp, sum, and most loss functions. It uses float16 for matrix multiplications and convolutions, where Tensor Cores are beneficial. The casting rules are operation-specific and version-dependent. If you manually cast tensors to float16 inside an autocast region, for example before a custom operation that autocast does not know about, you may bypass these choices and introduce instability. Let autocast manage dtypes, and override only when you know that a custom operation needs a particular dtype.
