Loss Landscape Sharpness
Sharp minima in the loss landscape generalise poorly — small perturbations to the parameters cause large changes in loss. Flat minima generalise well because they’re robust to parameter noise. The geometry of where you converge matters as much as the loss value you reach.
Intuition
Imagine two valleys in a mountain range. Valley A is a narrow, deep canyon — standing at the bottom, one step in any direction sends you up a steep wall. Valley B is a wide, gentle basin — you can wander around the bottom and the altitude barely changes. Both valleys have the same minimum altitude (same training loss), but Valley B is far more forgiving.
Why does flatness help generalisation? The test set is like a slightly shifted version of the training landscape. If you’re in a sharp minimum, that small shift moves the loss surface enough to put you on a steep slope — test loss is much higher than train loss. If you’re in a flat minimum, the same shift barely changes anything — test loss stays close to train loss.
Batch size has a surprising connection: large batches tend to find sharp minima (the gradient is very accurate, so the optimiser follows the narrow path to a sharp valley), while small batches tend to find flat minima (the gradient noise prevents the optimiser from entering narrow valleys — it bounces off the steep walls). This is one reason small-batch SGD often generalises better than large-batch training despite being less compute-efficient.
Manifestation
- Low training loss but high test loss — the model found a minimum but it doesn’t generalise (this alone could be overfitting; sharpness is one mechanism)
- Sensitivity to weight perturbation — add small Gaussian noise to weights and measure the loss increase; large increase = sharp minimum
- Large-batch training generalises worse than small-batch on the same architecture and data
- The Hessian has large eigenvalues at convergence — this directly measures the curvature (sharpness) of the minimum
- Sharpness-aware methods (SAM) improve test accuracy — if SAM helps significantly, the original optimiser was likely converging to sharp minima
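The weight-perturbation probe listed above can be sketched numerically. This is a minimal NumPy illustration, not a recipe for real networks: the two toy quadratic losses, the noise scale `sigma`, and the sample count are all assumptions chosen so the sharp/flat contrast is obvious.

```python
import numpy as np

rng = np.random.default_rng(0)

def sharpness_probe(loss_fn, params, sigma=0.01, n_samples=100):
    """Estimate sharpness as the mean loss increase when the
    parameters are perturbed by Gaussian noise of scale sigma."""
    base = loss_fn(params)
    increases = []
    for _ in range(n_samples):
        noise = rng.normal(0.0, sigma, size=params.shape)
        increases.append(loss_fn(params + noise) - base)
    return float(np.mean(increases))

# Two toy minima with the same loss value (0 at w = 0) but very
# different curvature: large Hessian eigenvalues (sharp) vs small (flat).
sharp = lambda w: 1000.0 * np.sum(w**2)
flat = lambda w: 0.1 * np.sum(w**2)

w = np.zeros(10)
print(sharpness_probe(sharp, w))  # large loss increase -> sharp minimum
print(sharpness_probe(flat, w))   # small loss increase -> flat minimum
```

The same idea applies to a trained network: snapshot the weights, add noise at a few scales, and compare the loss increase between runs or checkpoints rather than reading the number in isolation.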
Where It Appears
- NN training (nn-training/): the central setting — optimiser choice, batch size, learning rate, and weight decay all influence whether the model converges to a sharp or flat minimum
- Transformer (transformer/): large language model training uses carefully tuned learning rates and warmup to avoid sharp minima — cosine annealing with warmup is partly motivated by this
- Policy gradient (policy-gradient/): sharp policy optima can cause fragile policies that collapse when the environment changes slightly — PPO’s trust region implicitly favours flatter optima
Solutions at a Glance
| Solution | Mechanism | Where documented |
|---|---|---|
| Small-batch SGD | Gradient noise prevents convergence to sharp minima | (standard practice) |
| SAM (Sharpness-Aware Minimisation) | Optimises the worst-case loss in a neighbourhood of the current parameters | (Foret et al., 2021) |
| Weight decay | Shrinks weights toward zero, biasing toward simpler (flatter) solutions | atomic-concepts/regularisation/weight-decay.md |
| Learning rate warmup + cosine annealing | Controlled LR trajectory avoids sharp early minima | atomic-concepts/optimisation-primitives/learning-rate-warmup.md, cosine-annealing.md |
| Stochastic Weight Averaging (SWA) | Averages weights over the training trajectory, landing in flatter regions | (Izmailov et al., 2018) |
| Dropout | Adds noise during training that effectively smooths the loss landscape | atomic-concepts/regularisation/dropout.md |
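The SWA row in the table can be sketched as a running average of weight snapshots taken along the training trajectory. This is a hypothetical minimal version, not the full method from Izmailov et al. (2018), which also recomputes batch-norm statistics after averaging.

```python
import numpy as np

class SWA:
    """Running average of weight snapshots collected along
    the training trajectory (a minimal sketch)."""

    def __init__(self):
        self.avg = None
        self.n = 0

    def update(self, weights):
        """Fold one snapshot into the running average."""
        self.n += 1
        if self.avg is None:
            self.avg = weights.astype(float).copy()
        else:
            self.avg += (weights - self.avg) / self.n

# Three hypothetical checkpoints scattered around a basin:
swa = SWA()
for snapshot in [np.array([1.0, 3.0]), np.array([3.0, 1.0]), np.array([2.0, 2.0])]:
    swa.update(snapshot)
print(swa.avg)  # [2. 2.]
```

The averaged point tends to sit nearer the centre of a flat basin than any individual SGD iterate, which is the mechanism the table describes.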
Historical Context
Hochreiter & Schmidhuber (1997) first proposed that flat minima generalise better, based on a minimum description length argument. The idea was revived by Keskar et al. (2017), who showed empirically that large-batch training converges to sharper minima and generalises worse. Foret et al. (2021) introduced SAM, which directly optimises for flat minima by performing a gradient ascent step (to find the worst nearby point) followed by a gradient descent step (to minimise that worst case). SAM and its variants have become popular in practice, especially for vision transformers. The theoretical foundations are still debated — Dinh et al. (2017) showed that sharpness is not invariant to reparameterisation, complicating the clean “flat = good” story — but the empirical evidence remains strong.
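The two SAM steps described above (ascend to the worst nearby point, then descend using the gradient taken there) can be sketched in NumPy on a toy problem. The learning rate, the radius `rho`, and the anisotropic quadratic loss are illustrative assumptions, not values from the paper.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.05, rho=0.05):
    """One Sharpness-Aware Minimisation update:
    (1) take a normalised gradient ascent step of length rho to
        approximate the worst point in a ball around w,
    (2) descend at w using the gradient evaluated at that worst point."""
    g = grad_fn(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # first-order worst-case direction
    g_worst = grad_fn(w + eps)                   # gradient at the perturbed point
    return w - lr * g_worst

# Toy quadratic loss 0.5 * w^T H w with one sharp and one flat direction.
H = np.diag([10.0, 0.1])
loss = lambda w: 0.5 * w @ H @ w
grad = lambda w: H @ w

w0 = np.array([1.0, 1.0])
w1 = sam_step(w0, grad)
print(loss(w0), loss(w1))  # the step reduces the loss
```

Because the descent gradient is taken at the perturbed point, minima whose neighbourhoods contain high-loss points keep generating large updates, which biases the trajectory toward flatter basins.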