LogSumExp Trick
The numerical stability trick of subtracting the maximum value before exponentiating: log Σ_i exp(x_i) = c + log Σ_i exp(x_i − c), where c = max_i x_i. Prevents overflow/underflow in softmax, cross-entropy, attention, and any computation involving log Σ_i exp(x_i). Used everywhere in deep learning — every call to F.cross_entropy, F.softmax, or F.log_softmax uses this internally.
Intuition
The problem is simple: exp(x) overflows for x > ~709 (float64) or x > ~88 (float32), and underflows to zero for large negative values. Neural network logits routinely hit these ranges — a confident classifier might output logits of 50 or −50, and while exp(50) ≈ 5×10²¹ still fits in float32, a logit above ~88 produces inf outright.
The fix: factor out exp(c), where c = max_i x_i. Since Σ_i exp(x_i) = exp(c) Σ_i exp(x_i − c), taking the log gives log Σ_i exp(x_i) = c + log Σ_i exp(x_i − c). After subtraction, the largest exponent is exactly 0, and all others are negative. No overflow is possible. Some small terms may underflow to zero, but that’s fine — they contribute negligibly to the sum anyway.
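To make the shift concrete, here is a minimal NumPy sketch (the helper name `lse_shifted` is ours, for illustration): the shifted computation matches the naive one where the naive one still works, and keeps working where it overflows:

```python
import numpy as np

def lse_shifted(x):
    # logsumexp with the max-shift: the largest exponent becomes exactly 0
    c = np.max(x)
    return c + np.log(np.sum(np.exp(x - c)))

small = np.array([1.0, 2.0, 3.0])
naive = np.log(np.sum(np.exp(small)))       # safe here: values are tiny
assert np.isclose(lse_shifted(small), naive)

big = np.array([1000.0, 1001.0, 1002.0])
# naive path overflows: np.exp(1002.0) is inf even in float64
print(lse_shifted(big))                     # finite, ≈ 1002.4076
```

The shifted version costs one extra max per reduction, which is negligible next to the exponentials.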
This isn’t an approximation. It’s an algebraically exact rewriting that’s numerically stable. The result is identical in exact arithmetic but avoids the intermediate infinity that would corrupt the computation. This is why PyTorch’s F.cross_entropy takes raw logits — it applies this trick internally to compute log-softmax safely.
The identity (holds for any constant c, but c = max_i x_i is optimal):

log Σ_i exp(x_i) = c + log Σ_i exp(x_i − c)
Proof — factor out exp(c):

Σ_i exp(x_i) = Σ_i exp(c) · exp(x_i − c) = exp(c) · Σ_i exp(x_i − c)

Taking the log of both sides gives log Σ_i exp(x_i) = c + log Σ_i exp(x_i − c).
Log-softmax (the form used in cross-entropy):

log softmax(x)_i = x_i − log Σ_j exp(x_j) = (x_i − c) − log Σ_j exp(x_j − c)
Softmax (via log-softmax for stability):

softmax(x)_i = exp(log softmax(x)_i) = exp(x_i − c) / Σ_j exp(x_j − c)
Note: computing softmax by first computing log-softmax and then exponentiating is the numerically stable path. Computing exp(x_i) / Σ_j exp(x_j) directly risks overflow in both numerator and denominator.
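A small NumPy sketch of this failure mode (variable names are ours, for illustration): the direct ratio produces inf/inf = nan for large logits, while exponentiating the shifted log-softmax stays finite:

```python
import numpy as np

logits = np.array([1000.0, 1001.0, 1002.0])

# Direct softmax: both numerator and denominator overflow to inf,
# and inf/inf is nan
with np.errstate(over='ignore', invalid='ignore'):
    direct = np.exp(logits) / np.exp(logits).sum()

# Via log-softmax: shift by the max first, then exponentiate
c = logits.max()
log_probs = (logits - c) - np.log(np.exp(logits - c).sum())
stable = np.exp(log_probs)

print(direct)   # [nan nan nan]
print(stable)   # ≈ [0.0900 0.2447 0.6652], sums to 1
```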
```python
import torch
import torch.nn.functional as F

# ── PyTorch's built-in (already uses the trick internally) ──────
logits = model(x)                           # (B, C)
log_probs = F.log_softmax(logits, dim=-1)   # (B, C) — stable
probs = F.softmax(logits, dim=-1)           # (B, C) — stable
lse = torch.logsumexp(logits, dim=-1)       # (B,) — stable

# ── Cross-entropy uses log-softmax internally ───────────────────
loss = F.cross_entropy(logits, targets)     # stable: log_softmax + nll
# NEVER do: loss = -torch.log(F.softmax(logits, -1)) — loses precision

# ── Explicit logsumexp for custom computations ──────────────────
# Useful when you need log(Σ exp(x)) in a custom loss or metric
# Example: log-sum-exp pooling, log-partition functions, CTC loss

# WARNING: never compute torch.log(torch.sum(torch.exp(x))) manually.
# Use torch.logsumexp(x, dim=...) — it handles the shift for you.

# ── Numerical comparison ────────────────────────────────────────
x = torch.tensor([1000.0, 1001.0, 1002.0])
# torch.log(torch.exp(x).sum())             # → inf (overflow)
torch.logsumexp(x, dim=0)                   # → tensor(1002.4076) ✓
```

Manual Implementation
```python
import numpy as np

def logsumexp(x, axis=-1, keepdims=False):
    """
    Numerically stable log-sum-exp.
    Equivalent to np.log(np.sum(np.exp(x), axis=axis)).
    x: array of any shape
    """
    c = x.max(axis=axis, keepdims=True)      # shift constant
    out = c + np.log(np.exp(x - c).sum(axis=axis, keepdims=True))
    if not keepdims:
        out = out.squeeze(axis=axis)
    return out

def log_softmax(logits):
    """
    Numerically stable log-softmax.
    logits: (B, C) raw scores
    Returns: (B, C) log-probabilities
    """
    c = logits.max(axis=-1, keepdims=True)   # (B, 1)
    shifted = logits - c                     # (B, C) — max is 0
    log_sum = np.log(np.exp(shifted).sum(axis=-1, keepdims=True))  # (B, 1)
    return shifted - log_sum                 # (B, C)

def softmax(logits):
    """
    Numerically stable softmax via log-softmax.
    logits: (B, C) raw scores
    Returns: (B, C) probabilities summing to 1
    """
    return np.exp(log_softmax(logits))       # (B, C)

def cross_entropy(logits, targets):
    """
    Full numerically stable cross-entropy from raw logits.
    logits: (B, C) raw scores
    targets: (B,) integer class indices
    """
    B = logits.shape[0]
    lp = log_softmax(logits)                 # (B, C)
    return -lp[np.arange(B), targets].mean() # scalar
```

Popular Uses
- Cross-entropy loss (every classifier, every LLM): F.cross_entropy is log-softmax + NLL, and log-softmax is built on logsumexp (see nn-training/)
- Softmax attention (all transformers): attention scores are softmax of scaled dot products — logsumexp ensures stability with large sequence lengths (see transformer/)
- Contrastive losses (SimCLR, CLIP): InfoNCE is a cross-entropy over similarity scores that can span a huge range (see contrastive-self-supervising/)
- Log-partition functions (CRFs, energy-based models): computing log Z = log Σ_y exp(score(y)) requires logsumexp for stability
- Mixture density networks: log-likelihood of Gaussian mixtures involves logsumexp over component log-probabilities
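The mixture case is worth a sketch. Below is a toy two-component Gaussian mixture with made-up parameters (all names are ours): for a point far in the tail, every per-component density underflows to zero, so log-of-sum would be −inf, but the logsumexp over log-densities stays finite:

```python
import numpy as np

def mixture_log_likelihood(x, weights, means, stds):
    # log p(x) = logsumexp_k [ log w_k + log N(x; mu_k, sigma_k) ]
    log_comp = (np.log(weights)
                - 0.5 * np.log(2 * np.pi * stds**2)
                - 0.5 * ((x - means) / stds) ** 2)
    c = log_comp.max()                       # max-shift before exponentiating
    return c + np.log(np.exp(log_comp - c).sum())

w = np.array([0.5, 0.5])
mu = np.array([0.0, 1.0])
sd = np.array([1.0, 1.0])
x = 60.0                                     # deep in both tails

# Probability-space densities underflow to exactly 0.0 in float64:
densities = np.exp(-0.5 * ((x - mu) / sd) ** 2) / np.sqrt(2 * np.pi * sd**2)
assert densities.sum() == 0.0                # log(0) would give -inf

print(mixture_log_likelihood(x, w, mu, sd))  # finite, ≈ -1742.11
```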
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| Naive exp-sum-log | Never in production | Overflows/underflows for logits outside roughly [−88, 88] (float32). Only acceptable for small, bounded values |
| Online softmax (Flash Attention) | Memory-efficient attention on long sequences | Computes softmax in a single pass without materialising the full attention matrix; uses the same max-subtraction idea in streaming form |
| Log-space arithmetic | Chaining multiple log-sum-exp operations | Stay in log-space throughout to avoid repeated exp/log — useful for HMMs, CTC, beam search |
| Mixed precision (fp16/bf16) | Training speed | Narrower exponent range makes logsumexp even more critical — exp overflows at x ≈ 11 for fp16 vs x ≈ 88 for fp32 |
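The “stay in log-space” row can be made concrete with one forward-recursion step of a hypothetical 2-state HMM (the matrices and values below are made up for illustration): probabilities are kept as logs and combined with logsumexp instead of sum-of-products, so repeated steps never underflow:

```python
import numpy as np

def logsumexp(x, axis=-1):
    # Max-shifted logsumexp along one axis
    c = x.max(axis=axis, keepdims=True)
    return (c + np.log(np.exp(x - c).sum(axis=axis, keepdims=True))).squeeze(axis)

# One forward step, entirely in log-space:
#   log alpha'_j = logsumexp_i( log alpha_i + log A[i, j] ) + log b_j
log_alpha = np.array([-500.0, -505.0])       # log P(prefix, state = i)
log_A = np.log(np.array([[0.9, 0.1],         # transition matrix
                         [0.2, 0.8]]))
log_b = np.array([-3.0, -1.0])               # log-emission probs for current obs

# In probability space these alphas would be ~1e-217 and shrink every
# step, eventually underflowing float64; log-space sidesteps that.
log_alpha_next = logsumexp(log_alpha[:, None] + log_A, axis=0) + log_b
print(log_alpha_next)                        # finite log-probabilities
```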
Historical Context
The logsumexp trick is one of the oldest numerical methods in scientific computing, predating deep learning by decades. It appears in any field that works with probabilities in log-space: statistical mechanics (partition functions), speech recognition (HMMs), and computational biology (sequence alignment scores). The technique is a special case of the more general principle “never exponentiate a large number if you’re going to take the log of the result.”
In deep learning, its importance grew with the scale of models. Early networks had small logit ranges where naive softmax worked fine. Modern LLMs with vocabulary sizes of 32K-128K and attention over sequences of 8K-1M tokens produce logits that routinely overflow float32 without the shift. The trick is now so fundamental that it’s fused into hardware-level implementations — Flash Attention (Dao et al., 2022) uses online softmax (a streaming variant of the same idea) to compute attention without materialising the full score matrix.
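The streaming variant mentioned above can be sketched in a few lines. This is a simplified scalar version of the online-softmax idea, not Flash Attention’s actual kernel: maintain a running max m and a running sum s of exp(x − m), rescaling s whenever a new maximum arrives, so nothing larger than exp(0) is ever computed:

```python
import numpy as np

def streaming_logsumexp(stream):
    # Online logsumexp: one pass, O(1) memory, never exponentiates
    # a positive number.
    m = -np.inf          # running max
    s = 0.0              # running sum of exp(x_i - m)
    for x in stream:
        if x > m:
            s = s * np.exp(m - x) + 1.0   # rescale old sum to the new max
            m = x
        else:
            s += np.exp(x - m)
    return m + np.log(s)

print(streaming_logsumexp([1000.0, 1001.0, 1002.0]))  # ≈ 1002.4076
```

The same rescale-on-new-max update, applied blockwise to attention scores and their weighted values, is what lets Flash Attention compute softmax attention without ever materialising the full score matrix.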