Bridge: Autodiff in PyTorch/JAX, Mixed Precision & Numerical Stability in Training
PyTorch and JAX implement reverse-mode automatic differentiation with different design philosophies: PyTorch builds computation graphs eagerly during the forward pass; JAX traces through Python to build XLA computations that are compiled and JIT-accelerated. Mixed-precision training, gradient clipping, and numerical stability practices apply directly from the theory of floating-point arithmetic and condition numbers.
Concepts
PyTorch and JAX both compute exact gradients automatically — but by different strategies. PyTorch builds the computation graph as operations execute (eager mode), making debugging natural but precluding graph-level optimization. JAX traces through the function first to build an XLA computation graph, then compiles and JIT-accelerates it. The same AD theory underlies both; the difference is when the graph is built and how it is executed.
PyTorch Autograd: Eager Tape-Based AD
PyTorch builds a computation graph dynamically: each tensor operation records a grad_fn on its output tensor. The computation graph is a tape — a record of operations in forward-pass order.
requires_grad: only tensors with requires_grad=True participate in gradient computation. The graph is built only for these tensors.
.backward(): traverses the graph in reverse topological order, calling each grad_fn to compute and accumulate gradients. For a scalar loss, loss.backward() populates .grad for all leaf tensors.
The reverse topological ordering is forced by the chain rule: downstream gradients must be fully accumulated before upstream gradients can be computed. This means the tape built during the forward pass encodes the exact traversal order for the backward pass — there is no search, no interpretation, just replay in reverse.
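The build-then-replay behavior can be seen in a few lines of PyTorch (a minimal sketch; the tensor values are arbitrary):

```python
import torch

# Leaf tensor that participates in the graph
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Each op records a grad_fn on its output (the "tape")
y = x * x          # y.grad_fn is MulBackward0
loss = y.sum()     # scalar loss

loss.backward()    # replay the tape in reverse topological order

print(x.grad)      # d(sum(x^2))/dx = 2x -> tensor([2., 4., 6.])
```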
torch.no_grad(): disables graph construction entirely — for inference, this reduces memory and compute because no grad_fn records or intermediate activations are stored for a backward pass.
Gradient accumulation: optimizer.zero_grad() must be called before each backward pass; otherwise gradients accumulate (useful for gradient accumulation over multiple microbatches to simulate larger batch sizes).
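The microbatch-accumulation pattern looks like this in practice (a toy sketch with a small linear model and random data):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
accum_steps = 4                      # simulate a 4x larger batch

opt.zero_grad()                      # clear once per effective batch
for _ in range(accum_steps):
    x = torch.randn(8, 4)
    loss = model(x).pow(2).mean() / accum_steps   # average over microbatches
    loss.backward()                  # .grad accumulates across microbatches
accum_norm = model.weight.grad.norm()
opt.step()                           # one update with the accumulated gradient
opt.zero_grad()
```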
Custom gradients: torch.autograd.Function allows registering custom forward and backward functions. The backward method receives output gradients and must return input gradients. Essential for: integer operations (non-differentiable, use straight-through estimator), external solvers (implicit differentiation), quantized operations.
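As a concrete example, a straight-through estimator for rounding can be written with torch.autograd.Function (the class name RoundSTE is this example's own, not a library API):

```python
import torch

class RoundSTE(torch.autograd.Function):
    """Round to the nearest integer; pretend the op is identity in the backward pass."""
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)          # non-differentiable forward

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output             # straight-through: pass the gradient unchanged

x = torch.tensor([0.3, 1.7], requires_grad=True)
y = RoundSTE.apply(x).sum()            # round(0.3) + round(1.7) = 0 + 2
y.backward()
print(x.grad)                          # tensor([1., 1.]): gradient passes straight through
```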
JAX: Functional Transforms on Pure Functions
JAX applies function transforms to pure Python functions that use JAX arrays:
- jax.grad(f): returns a function computing the gradient of f — only for scalar-valued f
- jax.jvp(f, primals, tangents): forward-mode AD; returns f(primals) and the Jacobian-vector product
- jax.vjp(f, primals): reverse-mode AD; returns f(primals) and a function computing vector-Jacobian products
- jax.jacobian(f): full Jacobian (uses forward or reverse mode depending on shape)
- jax.jit(f): just-in-time compile with XLA for GPU/TPU
- jax.vmap(f): vectorize over a batch dimension (auto-batching)
Composability: jax.grad(jax.jit(f)) works; jax.vmap(jax.grad(f)) computes per-sample gradients. Composing transforms is the key design principle: jax.jit(jax.grad(jax.vmap(loss))) compiles a batched gradient computation.
Side effects: JAX requires pure functions (no in-place mutation, no Python state). jax.lax.while_loop and jax.lax.scan enable loops without Python overhead.
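A small illustration of jax.lax.scan, here computing a cumulative sum (the function names cumsum and step are this example's own):

```python
import jax
import jax.numpy as jnp

def cumsum(xs):
    # carry is the running total; scan threads it through xs with no Python loop
    def step(carry, x):
        carry = carry + x
        return carry, carry            # (new carry, per-step output)
    _, ys = jax.lax.scan(step, 0.0, xs)
    return ys

xs = jnp.arange(1.0, 5.0)              # [1., 2., 3., 4.]
print(cumsum(xs))                      # [ 1.  3.  6. 10.]

# The loop stays differentiable and jit-able:
grad_last = jax.grad(lambda xs: cumsum(xs)[-1])(xs)   # d(total)/dxs = all ones
```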
Mixed Precision Training
Standard recipe (Micikevicius et al. 2018, widely used in modern LLM training):
- Forward pass in float16/bfloat16: compute activations in half precision. 2× memory reduction and 2–8× speedup on modern GPUs (Tensor Cores operate on float16/bfloat16).
- Loss scaling: multiply the loss by a scale factor S before the backward pass. Gradients scaled by S move away from the float16 underflow region (values below about 6×10⁻⁸ flush to zero). After backward, divide the gradients by S before the optimizer step.
- Float32 master weights: maintain float32 copies of the weights for the optimizer step. Weight updates lr·∇L are often too small relative to the weight to represent in float16 — the addition would round away. Float32 accumulates updates accurately.
- Dynamic loss scaling: if the gradients contain Inf/NaN (overflow), halve S and skip the step. If a fixed number of consecutive steps pass without overflow, double S. PyTorch automates this with torch.cuda.amp.GradScaler.
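The underflow-and-rescue arithmetic can be demonstrated with NumPy casts alone (no GPU needed; the gradient value and scale factor here are arbitrary choices):

```python
import numpy as np

g = np.float32(1e-8)                 # a small early-layer gradient
print(np.float16(g))                 # 0.0: underflows in float16

S = np.float32(2.0 ** 14)            # loss scale
scaled = np.float16(g * S)           # ~1.6e-4: representable after scaling
print(scaled)

restored = np.float32(scaled) / S    # unscale in float32 before the optimizer step
print(restored)                      # ~1e-8 recovered
```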
bfloat16 vs float16:
- float16: 5 exponent bits — max value 65504, can overflow for large activations
- bfloat16: 8 exponent bits — same range as float32 (max ~3.4×10³⁸), rarely overflows; 7 mantissa bits (vs 10 for float16) — less precision but more stable for LLM training
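NumPy has no bfloat16 type, but its finfo values for float16 and float32 show the range gap (bfloat16 shares float32's 8-bit exponent, so its range matches f32 below):

```python
import numpy as np

f16 = np.finfo(np.float16)
f32 = np.finfo(np.float32)
print(f16.max)    # 65504.0: easy to overflow with large activations
print(f16.tiny)   # ~6.1e-05: smallest normal float16
print(f32.max)    # ~3.4e+38: bfloat16 shares this exponent range
```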
Gradient Clipping and Exploding Gradients
For deep networks and RNNs, the gradient norm can grow exponentially with depth: ‖∇L‖ can scale like σ^L for a depth-L network with weight matrices of spectral norm σ > 1.
Global gradient clipping: if ‖g‖ > c, scale the gradient: g ← c · g / ‖g‖.
This preserves the gradient direction while bounding its magnitude. Standard in transformer training: c = 1.0 (GPT-2, GPT-3, LLaMA). Gradient clipping is applied after unscaling the gradients if using mixed precision.
Gradient norm monitoring: log ‖g‖ at each step. A good training run shows stable gradient norms; sudden spikes indicate a numerical instability or a bad batch. Chronically high gradient norms indicate the learning rate is too large or the model is approaching a sharp loss region.
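In PyTorch, torch.nn.utils.clip_grad_norm_ handles both jobs at once: it rescales gradients in place and returns the pre-clip global norm for logging. A toy sketch (the model and loss are arbitrary):

```python
import torch

model = torch.nn.Linear(10, 10)
loss = model(torch.randn(32, 10)).pow(2).sum() * 1e3   # deliberately large loss
loss.backward()

# Rescales grads in place; the return value is the PRE-clip global norm,
# which is the quantity to log at every step
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(total_norm))

# Recompute the global norm after clipping: bounded by max_norm
clipped = torch.sqrt(sum(p.grad.norm() ** 2 for p in model.parameters()))
print(float(clipped))
```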
Worked Example
Example 1: Debugging NaN Loss in Mixed Precision Training
Sequence of steps when the loss becomes NaN:
- Check the GradScaler history: if the loss scale keeps collapsing (repeated halvings toward its minimum), gradients are underflowing throughout — use bfloat16 instead.
- Find the first layer with NaN activations: add forward hooks, e.g. model.register_forward_hook(lambda m, i, o: torch.isnan(o).any().item()).
- Common culprits: LayerNorm receiving large inputs (the mean/variance computation overflows); softmax of large logits in float16; log or sqrt of near-zero values in loss functions.
- Fixes: (a) add a numerical epsilon in the loss, e.g. loss = -torch.log(p + 1e-6); (b) use a layer norm implementation with float32 reductions; (c) tighten the gradient clipping threshold temporarily.
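The forward-hook check can be packaged as a hook installer; attach_nan_hooks below is a hypothetical helper, not a PyTorch API:

```python
import torch

def attach_nan_hooks(model):
    """Register a forward hook on every submodule that flags the first NaN output."""
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor) and torch.isnan(output).any():
            raise RuntimeError(f"NaN in output of {module.__class__.__name__}")
    for m in model.modules():
        m.register_forward_hook(hook)

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
attach_nan_hooks(model)
out = model(torch.randn(2, 4))   # clean input: runs without raising
```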
Example 2: Per-Sample Gradients with JAX vmap
Computing per-sample gradients (needed for differential privacy and influence functions): with PyTorch, you must loop over batch elements, paying Python-loop overhead proportional to the batch size. With JAX:

```python
per_sample_grads = jax.vmap(jax.grad(loss_single), in_axes=(None, 0, 0))(params, x_batch, y_batch)
```

jax.vmap vectorizes the grad computation over the batch dimension (in_axes=(None, 0, 0) shares params and maps over the batch axes of x_batch and y_batch) — no Python loop, one vectorized backward pass instead of one per sample. This is substantially faster than the PyTorch loop approach for batch size 128.
DP-SGD (differential privacy): clip each per-sample gradient to norm at most C, then add Gaussian noise to the sum. This requires per-sample gradients, which jax.vmap(jax.grad(loss_single)) computes efficiently.
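A runnable sketch of per-sample gradients plus DP-SGD-style clipping, assuming a scalar per-example loss loss_single (the model and data are toy values; the noise-addition step is omitted):

```python
import jax
import jax.numpy as jnp

def loss_single(w, x, y):
    return (jnp.dot(x, w) - y) ** 2    # per-example squared error

w = jnp.array([1.0, -1.0])
x_batch = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
y_batch = jnp.array([0.0, 1.0, 0.0])

# in_axes=(None, 0, 0): share the params, map over the batch dimension
per_sample = jax.vmap(jax.grad(loss_single), in_axes=(None, 0, 0))(w, x_batch, y_batch)
print(per_sample.shape)                # (3, 2): one gradient per example

# DP-SGD-style clipping to norm C
C = 1.0
norms = jnp.linalg.norm(per_sample, axis=1, keepdims=True) + 1e-12
clipped = per_sample * jnp.minimum(1.0, C / norms)
```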
Example 3: Custom Backward for Implicit Differentiation
A neural network layer that solves an optimization problem: the forward pass runs an iterative solver to find z*(x) satisfying an optimality condition g(x, z*) = 0. The backward pass needs ∂z*/∂x.
By the implicit function theorem: ∂z*/∂x = -(∂g/∂z)⁻¹ (∂g/∂x), evaluated at (x, z*). In PyTorch:
```python
class ImplicitLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        z_star = solve(x)  # run iterative solver
        ctx.save_for_backward(x, z_star)
        return z_star

    @staticmethod
    def backward(ctx, grad_output):
        x, z_star = ctx.saved_tensors
        # Solve J_z g^T v = grad_output (one linear solve)
        v = solve_linear(J_z_g_T, grad_output)
        # Compute -J_x g^T v
        return -J_x_g_T @ v
```

One linear solve in the backward pass replaces unrolling the solver (which could require hundreds of AD steps). This is used in DEQs, differentiable optimization layers (cvxpylayers), and physics simulations.
Connections
Where Your Intuition Breaks
Mixed precision training is often described as "train in float16 for speed, accumulate in float32 for accuracy." But the core difficulty is the gradient underflow problem: gradient values for early layers in a deep network are often below float16's smallest representable value, 2⁻²⁴ ≈ 6×10⁻⁸, and underflow to zero in float16. Loss scaling multiplies the loss by a large constant before the backward pass so that gradients are scaled up, avoiding underflow. But if the scale is too large, gradients overflow to inf. Dynamic loss scaling automatically adjusts the scale up and down based on overflow detection — not because this is elegant, but because no static scale works across all models and learning rate schedules. The float16/float32 split is the easy part; tuning loss scaling is where most mixed precision issues arise in practice.
PyTorch and JAX make different tradeoffs between flexibility and performance. PyTorch's eager execution mode allows Python control flow (if/for) to affect the graph — easy to debug but no graph-level optimization. JAX traces through Python to build a static computation graph, then compiles with XLA — allows aggressive fusion, rematerialization, and cross-device parallelism. For research: PyTorch is more debuggable. For production training at scale: JAX/XLA tends to be faster due to better compiler optimization. Both implement the same AD theory; the difference is execution model.
Loss scaling is not a hack — it is a necessary correction for float16 gradients. Without loss scaling, gradients of magnitude below about 6×10⁻⁸ flush to zero in float16 (the smallest positive float16 value is 2⁻²⁴ ≈ 6×10⁻⁸). A loss scale of S = 2¹⁰ shifts the gradient distribution right by 10 bits, keeping most values in the representable range. After the backward pass, dividing by S restores the correct scale. The key insight: the loss scale only needs to prevent underflow during the backward pass; the optimizer step always happens in float32 where full precision is available.
torch.cuda.amp.autocast does not make all operations float16. Autocast promotes operations involving large reductions (layer norm, batch norm statistics, softmax) to float32 automatically, while using float16 for matrix multiplications and convolutions where Tensor Cores are beneficial. The promotion rules are operation-specific and version-dependent. If you cast tensors manually to float16 inside an autocast region, you may bypass these promotions and introduce instability. Let autocast manage dtypes; only override when you have specific knowledge that a custom operation needs a particular dtype.