
Bridge: Loss Landscapes, Natural Gradient & Equivariant Networks

This lesson draws the threads of Module 04 together, connecting the mathematics of manifolds, curvature, and differential forms directly to three active frontiers of ML: the geometry of neural network loss landscapes, natural gradient and second-order optimization methods, and equivariant architectures that respect the symmetry group of the data domain.

Concepts

Figure: two minima with equal training loss but different sharpness (loss vs. parameter θ, one sharp valley and one flat valley). Toggle test loss to see the generalization gap, and the SAM radius ρ to see why flat minima are preferred.

Two neural networks can achieve the same training loss but generalize completely differently — the one that found a flat minimum (a broad valley in the loss landscape) tends to win. That flatness is a property of curvature: the Hessian eigenvalues at the solution. The entire machinery of this module — Jacobians, geodesics, Riemannian metrics, group symmetry — comes together here: curvature explains generalization, the Fisher metric explains why Adam works, and symmetry groups explain why some architectures need far less training data than others.

The Loss Landscape as a Riemannian Manifold

A neural network with parameter vector $\boldsymbol{\theta} \in \mathbb{R}^n$ defines a loss function $\mathcal{L} : \mathbb{R}^n \to \mathbb{R}$. The loss landscape is the graph of $\mathcal{L}$ — a hypersurface in $\mathbb{R}^{n+1}$.

Local geometry at a critical point $\boldsymbol{\theta}^*$ (where $\nabla\mathcal{L} = 0$):

$$\mathcal{L}(\boldsymbol{\theta}^* + \boldsymbol{\delta}) \approx \mathcal{L}(\boldsymbol{\theta}^*) + \frac{1}{2}\boldsymbol{\delta}^T H \boldsymbol{\delta},$$

where $H = \nabla^2\mathcal{L}(\boldsymbol{\theta}^*)$ is the Hessian. The eigenspectrum of $H$ determines the local curvature:

  • Minimum: all eigenvalues $> 0$
  • Saddle point: mixed-sign eigenvalues (overwhelmingly dominant in high dimensions)
  • Flat region: many near-zero eigenvalues — "flat minimum," associated with better generalization

The second-order Taylor expansion necessarily involves the Hessian: it is the unique symmetric matrix capturing all second-order changes. No simpler geometric invariant distinguishes a flat minimum from a sharp one; the Hessian eigenspectrum is the minimal descriptor. This is why modern theories of generalization (PAC-Bayes bounds, sharpness-aware minimization, flat-minima theory) are ultimately phrased in terms of Hessian eigenvalues.
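
To make the eigenspectrum test concrete, here is a minimal numpy sketch (a hypothetical two-parameter loss with one stiff and one flat direction) that classifies a critical point from a finite-difference Hessian:

```python
import numpy as np

# Hypothetical toy loss with one stiff and one flat direction:
# L(x, y) = 5 x^2 + 0.01 y^2, minimized at the origin.
def loss(theta):
    x, y = theta
    return 5.0 * x**2 + 0.01 * y**2

def hessian(f, theta, eps=1e-4):
    """Finite-difference Hessian via central differences."""
    n = len(theta)
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            def shifted(di, dj):
                t = np.array(theta, dtype=float)
                t[i] += di
                t[j] += dj
                return f(t)
            H[i, j] = (shifted(eps, eps) - shifted(eps, -eps)
                       - shifted(-eps, eps) + shifted(-eps, -eps)) / (4 * eps**2)
    return H

H = hessian(loss, [0.0, 0.0])
eigs = np.linalg.eigvalsh(H)       # sorted ascending
print(eigs)  # ~[0.02, 10.0]: all positive -> a minimum; the tiny one is a flat direction
```

Both eigenvalues are positive, so the origin is a minimum; the near-zero eigenvalue marks the flat direction in which perturbations barely change the loss.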

Flat vs sharp minima. Empirically, SGD with small batch sizes finds flatter minima (smaller $\lambda_{\max}(H)$) than large-batch GD. Flat minima generalize better because small perturbations in $\boldsymbol{\theta}$ (from model weight noise, data distribution shift, or quantization) change $\mathcal{L}$ less. This is the geometric intuition behind sharpness-aware minimization (SAM), which explicitly penalizes the maximal loss in a ball around $\boldsymbol{\theta}$.

PAC-Bayes and curvature. In simplified form, PAC-Bayes generalization bounds give:

$$\mathcal{L}_{\text{test}} \lesssim \mathcal{L}_{\text{train}} + \sqrt{\frac{\operatorname{tr}(H)}{n_{\text{data}}}}.$$

The trace of the Hessian (sum of eigenvalues = sum of curvatures) directly bounds generalization. Low curvature → better generalization.
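
In practice $\operatorname{tr}(H)$ is estimated without ever forming $H$, using Hutchinson's identity $\mathbb{E}[\mathbf{z}^T H \mathbf{z}] = \operatorname{tr}(H)$ for random $\mathbf{z}$ with $\mathbb{E}[\mathbf{z}\mathbf{z}^T] = I$, combined with Hessian-vector products. A sketch on a toy quadratic loss (the matrix `M` is a stand-in for a real network's Hessian):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy quadratic loss L(theta) = 0.5 theta^T M theta, so the Hessian is M exactly.
# M is a stand-in for a real network's Hessian.
M = np.diag([10.0, 1.0, 0.1, 0.01])

def grad(theta):
    return M @ theta

def hvp(theta, v, eps=1e-5):
    """Hessian-vector product via finite differences of the gradient."""
    return (grad(theta + eps * v) - grad(theta - eps * v)) / (2 * eps)

theta = np.zeros(4)
# Hutchinson: average z^T H z over Rademacher vectors z (entries +/- 1).
zs = rng.choice([-1.0, 1.0], size=(1000, 4))
est = np.mean([z @ hvp(theta, z) for z in zs])
print(est)  # tr(M) = 11.11 (exact here: M is diagonal and z_i^2 = 1)
```

In a real network, `hvp` would be a Pearlmutter-style Hessian-vector product through autodiff; the estimator's structure is identical.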

Natural Gradient and Second-Order Methods

The problem with Euclidean gradient descent. The standard gradient step $\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \eta\nabla\mathcal{L}$ treats all parameter directions equally — a step of size $\eta$ in the direction $\mathbf{e}_i$ is indistinguishable from one in the direction $\mathbf{e}_j$. But the effect of a parameter change on the distribution $p(\cdot;\boldsymbol{\theta})$ varies enormously across directions.

Natural gradient. The Fisher information matrix $\mathcal{I}(\boldsymbol{\theta})$ defines the correct metric on parameter space, measuring how much a parameter perturbation changes the output distribution (in KL divergence):

$$\mathrm{KL}\big(p(\cdot;\boldsymbol{\theta})\,\big\|\, p(\cdot;\boldsymbol{\theta}+\boldsymbol{\delta})\big) \approx \frac{1}{2}\boldsymbol{\delta}^T\mathcal{I}(\boldsymbol{\theta})\boldsymbol{\delta}.$$

Natural gradient descent steepest-descends in this metric:

$$\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \eta\,\mathcal{I}(\boldsymbol{\theta})^{-1}\nabla\mathcal{L}.$$

Benefits:

  • Invariant to reparameterization of $\boldsymbol{\theta}$
  • Much faster convergence in practice (for well-conditioned Fisher matrices)
  • Reduces to Newton's method when Fisher = Hessian (at a local minimum of negative log-likelihood)
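
A minimal sketch of the update on a one-parameter Bernoulli model (a hypothetical toy where the Fisher is the scalar $\mathcal{I}(\theta) = p(1-p)$), showing how dividing by the Fisher undoes sigmoid saturation:

```python
import numpy as np

# One-parameter Bernoulli model: p(y=1) = sigmoid(theta).
# In the logit parameterization the Fisher is the scalar I(theta) = p(1-p).
def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def nll_grad(theta, y):
    return sigmoid(theta) - y             # d/dtheta of -log p(y; theta)

def natural_grad(theta, y):
    p = sigmoid(theta)
    return nll_grad(theta, y) / (p * (1.0 - p))   # I(theta)^{-1} grad

theta, y, eta = -4.0, 1.0, 0.5
for _ in range(10):
    theta -= eta * natural_grad(theta, y)
print(sigmoid(theta))  # essentially 1: dividing by the Fisher undoes the tiny saturated gradient
```

Plain gradient descent from $\theta = -4$ crawls, because the Euclidean gradient $p - y$ is tiny in the saturated regime; the Fisher division rescales the step to match the step's actual effect on the output distribution.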

Practical obstacle: $\mathcal{I}(\boldsymbol{\theta})$ is $n \times n$ where $n$ is the number of parameters (possibly billions). Direct inversion is infeasible.

Approximations:

  • Kronecker-Factored Approximate Curvature (K-FAC): $\mathcal{I} \approx A \otimes B$ for each layer, exploiting the block structure of the Fisher for fully-connected layers. Inversion costs $O(d_{\text{in}}^3 + d_{\text{out}}^3)$ per layer instead of $O(n^3)$ for the full matrix.
  • Diagonal approximations (Adagrad, Adam): approximate $\mathcal{I}$ with its diagonal. Fast to invert, but loses off-diagonal correlations.
  • KFAC-reduce, FOOF, EKFAC: improved Kronecker approximations used in large-scale training.

Connection to Adam. Adam maintains $v_t = \beta_2 v_{t-1} + (1-\beta_2)(\nabla\mathcal{L})^2$ — an EMA of squared gradients, which approximates the diagonal of the (empirical) Fisher information matrix. Adam is thus an implicit natural-gradient method with a diagonal Fisher approximation.
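
A sketch of that diagonal preconditioning on a toy quadratic with a 100x curvature gap between coordinates (plain GD with the same step size would diverge on the stiff coordinate, since $1 - \eta \cdot 100 = -9$):

```python
import numpy as np

# Toy quadratic with a 100x curvature gap: L = 0.5 * (100 x^2 + y^2).
curv = np.array([100.0, 1.0])

def grad(theta):
    return curv * theta

theta = np.array([1.0, 1.0])
v, beta2, eta, eps = np.zeros(2), 0.999, 0.1, 1e-8
for t in range(1, 201):
    g = grad(theta)
    v = beta2 * v + (1 - beta2) * g**2                 # EMA of squared gradients
    v_hat = v / (1 - beta2**t)                         # bias correction
    theta = theta - eta * g / (np.sqrt(v_hat) + eps)   # diagonal preconditioning
print(theta)  # both coordinates end near zero, despite the 100x curvature gap
```

Dividing by $\sqrt{\hat{v}_t}$ equalizes the effective step across directions of very different curvature — the diagonal-Fisher view of why Adam tolerates badly scaled parameters.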

Equivariant Networks and Symmetry

The problem. Many ML tasks have inherent symmetries: a rotated image should give the same classification, a permuted graph should give the same node embeddings, a 3D point cloud should be invariant to rigid transformations.

Standard networks do not respect these symmetries. A generic CNN is translation-equivariant by construction (via weight sharing), but not rotation-equivariant. Achieving rotation equivariance via data augmentation is wasteful and yields only approximate equivariance.

Group-equivariant networks (Cohen & Welling, 2016). Let $G$ be a group acting on the input space $\mathcal{X}$ (e.g., $G = SO(2)$ for 2D rotations). A function $f : \mathcal{X} \to \mathcal{Y}$ is $G$-equivariant if:

$$f(g \cdot x) = \rho(g) \cdot f(x) \quad \text{for all } g \in G,\ x \in \mathcal{X},$$

where $\rho$ is a group representation on $\mathcal{Y}$. For classification (scalar output), equivariance becomes invariance: $f(g \cdot x) = f(x)$.
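
A quick numerical check of the definition, using the (assumed) example map $f(\mathbf{x}) = \|\mathbf{x}\|^2\,\mathbf{x}$, which is $SO(2)$-equivariant because the scalar factor is rotation-invariant:

```python
import numpy as np

# Example equivariant map (an illustration, not a trained network):
# f(x) = ||x||^2 * x satisfies f(R x) = R f(x) because ||x||^2 is rotation-invariant.
def f(x):
    return (x @ x) * x

def rot(alpha):
    c, s = np.cos(alpha), np.sin(alpha)
    return np.array([[c, -s], [s, c]])

x = np.array([1.0, 2.0])
R = rot(0.7)
print(np.allclose(f(R @ x), R @ f(x)))  # True: rotate-then-map equals map-then-rotate
```

Here $\rho(g) = R$ itself: the output transforms by the same rotation as the input.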

Key theorem (Schur's lemma / representation theory). The most general equivariant linear map between two representation spaces decomposes into components along irreducible representations of $G$ — these are the "symmetry-adapted basis" analogues of eigenvectors.

Examples of equivariant architectures:

| Symmetry group | Architecture | Application |
| --- | --- | --- |
| Translation $\mathbb{Z}^2$ | CNN (standard) | Image processing |
| Rotation + translation $SE(2)$ | G-CNNs, steerable CNNs | Medical imaging |
| Rotation + translation $SE(3)$ | SchNet, DimeNet, SE(3)-Transformers | Molecular property prediction |
| Permutation $S_n$ | Graph Neural Networks | Graph-structured data |
| Lorentz group | Lorentz Equivariant Networks | Particle physics |

Why it matters. Equivariant networks:

  • Require far less data augmentation (symmetry is exact, not approximate)
  • Have fewer effective parameters (weight sharing across symmetry-related configurations)
  • Generalize better when the test distribution includes all symmetry-related views

Differential Geometry and Generative Models

Score-based models (DDPM, SDE). A diffusion process adds noise to data: $d\mathbf{x} = \mathbf{f}(\mathbf{x},t)\,dt + g(t)\,d\mathbf{w}$. The reverse process (generating samples from noise) satisfies:

$$d\mathbf{x} = \big[\mathbf{f}(\mathbf{x},t) - g(t)^2\nabla_\mathbf{x}\log p_t(\mathbf{x})\big]\,dt + g(t)\,d\bar{\mathbf{w}}.$$

The term $\nabla_\mathbf{x}\log p_t(\mathbf{x})$ is the score function — the gradient of the log-density, a vector field on the data manifold. Score matching trains a neural network $s_\theta(\mathbf{x},t) \approx \nabla_\mathbf{x}\log p_t(\mathbf{x})$ without computing the intractable normalization constant.
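
To see the score function drive sampling, here is a sketch with a 1D Gaussian target, where the score is known in closed form, plugged into unadjusted Langevin dynamics (a simplified stand-in for the full reverse SDE):

```python
import numpy as np

rng = np.random.default_rng(1)

# Target density p = N(mu, sigma^2); its score is available in closed form:
# grad_x log p(x) = -(x - mu) / sigma^2.
mu, sigma = 3.0, 0.5

def score(x):
    return -(x - mu) / sigma**2

# Unadjusted Langevin dynamics driven by the score:
# x <- x + (tau/2) * score(x) + sqrt(tau) * noise.
x = rng.standard_normal(10_000)        # initialize far from the target
tau = 0.01
for _ in range(2_000):
    x = x + 0.5 * tau * score(x) + np.sqrt(tau) * rng.standard_normal(x.shape)
print(x.mean(), x.std())  # ~3.0 and ~0.5: samples follow the target density
```

A diffusion model replaces the closed-form `score` with the learned network $s_\theta(\mathbf{x}, t)$; the sampling mechanics are the same idea.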

Riemannian diffusion. When data lives on a Riemannian manifold (e.g., protein backbone angles on a torus, rotation matrices on $SO(3)$), the diffusion process uses the manifold's Laplace-Beltrami operator instead of the flat Laplacian. This produces geometrically correct generative models for structured data.

Worked Example

Example 1: SAM — Sharpness-Aware Minimization

SAM perturbs parameters toward the worst-case loss in a neighborhood:

$$\hat{\boldsymbol{\epsilon}} = \arg\max_{\|\boldsymbol{\epsilon}\| \leq \rho} \mathcal{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \approx \rho \cdot \frac{\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})}{\|\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\|}.$$

It then updates with the gradient at the perturbed parameters: $\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \eta\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta} + \hat{\boldsymbol{\epsilon}})$.

This requires two forward-backward passes per step but consistently finds flatter minima. SAM and its variants set state-of-the-art results on ImageNet, and the method connects directly to the curvature/generalization theory above.
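
A sketch of both ingredients in 1D: the worst-case objective that SAM minimizes (computed exactly by grid search here), evaluated at two hypothetical minima with equal loss but different curvature, plus the two-pass update itself:

```python
import numpy as np

# Two hypothetical 1D losses with equal minimum value L(0) = 0 but different curvature.
sharp = lambda x: 50.0 * x**2
flat = lambda x: 0.5 * x**2

def sam_objective(loss, x, rho, n=201):
    """Worst-case loss in the rho-ball around x (exact by grid search in 1D)."""
    eps = np.linspace(-rho, rho, n)
    return float(np.max(loss(x + eps)))

def sam_step(loss, x, eta=0.1, rho=0.2, h=1e-6):
    """One SAM update: perturb along the gradient direction, then descend there."""
    g = (loss(x + h) - loss(x - h)) / (2 * h)          # gradient at x
    eps_hat = rho * np.sign(g)                         # rho * g/|g| in 1D
    g_pert = (loss(x + eps_hat + h) - loss(x + eps_hat - h)) / (2 * h)
    return x - eta * g_pert

# The sharp minimum is heavily penalized by the worst-case objective:
print(sam_objective(sharp, 0.0, rho=0.2))  # about 50 * 0.2^2 = 2.0
print(sam_objective(flat, 0.0, rho=0.2))   # about 0.5 * 0.2^2 = 0.02
```

Both minima have identical training loss, but the SAM objective at the sharp one is 100x larger — exactly the curvature ratio — so SAM's dynamics prefer the flat basin.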

Example 2: K-FAC in Practice

For a layer with input $\mathbf{a} \in \mathbb{R}^{d_{\text{in}}}$ and output $\mathbf{s} = W\mathbf{a} \in \mathbb{R}^{d_{\text{out}}}$, the gradient is $\nabla_W \mathcal{L} = \mathbf{g}\mathbf{a}^T$ where $\mathbf{g} = \partial\mathcal{L}/\partial\mathbf{s}$.

The Fisher block for $W$, under the K-FAC assumption that $\mathbf{a}$ and $\mathbf{g}$ are statistically independent, is:

$$\mathcal{I}_W = \mathbb{E}\big[\operatorname{vec}(\mathbf{g}\mathbf{a}^T)\operatorname{vec}(\mathbf{g}\mathbf{a}^T)^T\big] \approx \mathbb{E}[\mathbf{g}\mathbf{g}^T] \otimes \mathbb{E}[\mathbf{a}\mathbf{a}^T] = G \otimes A.$$

K-FAC inverts the factors separately, $\mathcal{I}_W^{-1} \approx G^{-1} \otimes A^{-1}$, giving the natural-gradient update:

$$\Delta W = -\eta\, G^{-1} \nabla_W \mathcal{L}\, A^{-1} = -\eta\, G^{-1} (\mathbf{g}\mathbf{a}^T) A^{-1}.$$

Cost: $O(d_{\text{out}}^3 + d_{\text{in}}^3)$ per layer instead of $O((d_{\text{out}} d_{\text{in}})^3)$ for the full Fisher block.
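
A numerical check of the factored inverse: solving with the full Kronecker product agrees with the two small solves $G^{-1} M A^{-1}$ (random SPD factors stand in for the Fisher statistics):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 4, 3

def random_spd(d):
    """Random symmetric positive-definite matrix (stand-in for A = E[aa^T], G = E[gg^T])."""
    X = rng.standard_normal((d, d))
    return X @ X.T + d * np.eye(d)

A = random_spd(d_in)
G = random_spd(d_out)
M = rng.standard_normal((d_out, d_in))   # stands in for grad_W L

# Naive: invert the full (d_out*d_in) x (d_out*d_in) matrix G kron A.
naive = np.linalg.solve(np.kron(G, A), M.ravel()).reshape(d_out, d_in)

# K-FAC: two small solves instead of one big one -- G^{-1} M A^{-1}.
kfac = np.linalg.solve(G, np.linalg.solve(A.T, M.T).T)

print(np.allclose(naive, kfac))  # True
```

The identity $(G \otimes A)^{-1}\operatorname{vec}(M) = \operatorname{vec}(G^{-1} M A^{-1})$ (with row-major `vec`, matching numpy's `ravel`) is what turns one huge inverse into two cheap ones.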

Example 3: Message Passing is Permutation Equivariant

Graph Neural Networks (GNNs) compute node representations via message passing:

$$\mathbf{h}_v^{(l+1)} = \mathrm{UPDATE}^{(l)}\!\left(\mathbf{h}_v^{(l)},\ \mathrm{AGGREGATE}^{(l)}\!\left(\{\mathbf{h}_u^{(l)} : u \in \mathcal{N}(v)\}\right)\right).$$

If AGGREGATE is permutation-invariant (e.g., sum, mean, max), then the layer $\mathbf{H}^{(l+1)} = f_\theta(\mathbf{H}^{(l)}, A)$ is equivariant to relabeling of nodes: for any permutation matrix $P$, $f_\theta(P\mathbf{H}^{(l)}, PAP^T) = P f_\theta(\mathbf{H}^{(l)}, A)$. This is equivariance to the permutation group $S_n$ — a mathematical guarantee that GNNs treat isomorphic graphs identically.
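
A direct numerical check of this guarantee for a single mean-aggregation layer (a hypothetical minimal GNN layer, numpy only): relabeling the nodes before or after the layer gives the same result.

```python
import numpy as np

rng = np.random.default_rng(0)

def gnn_layer(H, A, W_self, W_nbr):
    """One message-passing layer: h_v <- relu(W_self h_v + W_nbr (mean over neighbors))."""
    deg = np.maximum(A.sum(axis=1, keepdims=True), 1.0)
    agg = (A @ H) / deg                        # mean aggregation: permutation-invariant
    return np.maximum(H @ W_self.T + agg @ W_nbr.T, 0.0)

n, d = 5, 3
H = rng.standard_normal((n, d))
A = (rng.random((n, n)) < 0.4).astype(float)
A = np.triu(A, 1)
A = A + A.T                                    # symmetric adjacency, no self-loops
W_self, W_nbr = rng.standard_normal((d, d)), rng.standard_normal((d, d))

P = np.eye(n)[rng.permutation(n)]              # random node relabeling

out_then_perm = P @ gnn_layer(H, A, W_self, W_nbr)
perm_then_out = gnn_layer(P @ H, P @ A @ P.T, W_self, W_nbr)
print(np.allclose(out_then_perm, perm_then_out))  # True: f(PH, PAP^T) = P f(H, A)
```

The weights never see the node ordering — only per-node transforms and a symmetric aggregation — which is exactly why the equivariance holds for every permutation, not just this sampled one.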

Connections

Where Your Intuition Breaks

Flat minima should generalize better — that is the prevailing geometric intuition, supported by PAC-Bayes theory. But flatness is not reparameterization-invariant. If you rescale some parameters by a large constant $\lambda$ (e.g., by inserting a BatchNorm layer that absorbs the scale), the Hessian eigenvalues change dramatically while the function computed by the network is identical. A "flat minimum" in one parameterization is a "sharp minimum" in another. The Fisher-Rao metric (which is reparameterization-invariant) provides a better notion of flatness, but it depends on the data distribution and is expensive to compute. The ongoing debate about what makes a minimum "truly flat" is unresolved — sharpness-aware minimization works empirically, but the theoretical foundation for why flatness matters (and which flatness measure) remains an open question.

💡Intuition

The Riemannian perspective on Adam. Adam maintains $\hat{v}_t \approx \mathbb{E}[(\nabla\mathcal{L})^2]$ — the diagonal of the Fisher information matrix. The update $\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \eta\,\nabla\mathcal{L}/\sqrt{\hat{v}_t}$ is natural gradient descent with a diagonal Fisher approximation. This is why Adam adapts step sizes per parameter — it's computing (approximately) the gradient in the geometry where each parameter change is measured by its effect on the model output distribution.

💡Intuition

Equivariance = free data augmentation. If you train a rotation-equivariant network, you never need to include rotated copies of each image in the training set — the equivariance is built into the architecture. Equivalently, the effective training set size is $|G|$ times larger, where $|G|$ is the size of the symmetry group. For continuous groups like $SO(2)$ (all 2D rotations), this factor is infinite — a genuinely infinite amount of "free" data augmentation baked into the architecture via the differential geometry of the symmetry group.

⚠️Warning

Natural gradient in deep learning requires careful approximation. The exact Fisher matrix is dense and $n \times n$ for $n$ parameters. Even storing it requires $O(n^2)$ memory — impossible for billion-parameter models. K-FAC, Shampoo, and SOAP are the leading practical approximations. They all involve matrix factorizations in each layer, and convergence proofs often rely on the true Fisher being well-approximated by the Kronecker product — an assumption that holds well for large batch sizes but can break for small batches or unusual loss landscapes.
