
LogSumExp Trick

The numerical stability trick of subtracting the maximum value before exponentiating: $\log \sum_i \exp(x_i) = c + \log \sum_i \exp(x_i - c)$ where $c = \max_i x_i$. Prevents overflow/underflow in softmax, cross-entropy, attention, and any computation involving $\log \sum \exp$. Used everywhere in deep learning — every call to F.cross_entropy, F.softmax, or F.log_softmax uses this internally.

The problem is simple: $\exp(x)$ overflows for $x > 709$ (float64) or $x > 88$ (float32), and underflows to zero for large negative values. Neural network logits routinely reach magnitudes where this matters — a confident classifier might output logits of 50 or -50, and $e^{50} \approx 5 \times 10^{21}$ while $e^{-50} \approx 2 \times 10^{-22}$.
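
A quick sketch of those thresholds in NumPy (the exact cutoffs follow from the largest finite values, roughly 3.4e38 for float32 and 1.8e308 for float64):

import numpy as np
print(np.exp(np.float32(88.0)))    # ~1.7e38 (still finite)
print(np.exp(np.float32(89.0)))    # inf (overflow; NumPy warns)
print(np.exp(np.float64(709.0)))   # ~8.2e307 (still finite)
print(np.exp(np.float64(710.0)))   # inf (overflow)
print(np.exp(np.float32(-104.0)))  # 0.0 (underflow to zero)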

The fix: factor out $e^{c}$ where $c = \max_i x_i$. Since $\sum e^{x_i} = e^c \sum e^{x_i - c}$, taking the log gives $c + \log \sum e^{x_i - c}$. After subtraction, the largest exponentiated term is $e^0 = 1$, and all others are $\leq 1$. No overflow is possible. Some small terms may underflow to zero, but that’s fine — they contribute negligibly to the sum anyway.
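
For example, a minimal NumPy sketch of the same shift on logits near 1000, where the naive form overflows even in float64:

import numpy as np
x = np.array([1000.0, 1001.0, 1002.0])
# naive: np.log(np.exp(x).sum()) -> inf, since exp(1002) overflows float64
c = x.max()                                  # c = 1002
shifted = x - c                              # [-2, -1, 0]; largest exponent is 0
print(c + np.log(np.exp(shifted).sum()))     # 1002.4076..., finite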

This isn’t an approximation. It’s an algebraically exact rewriting that’s numerically stable: the two forms are mathematically identical, but the shifted one avoids the intermediate infinity that would corrupt the computation. This is why PyTorch’s F.cross_entropy takes raw logits — it applies this trick internally to compute log-softmax safely.
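
A quick check of that equivalence (a sketch with made-up shapes: (B, C) logits and (B,) integer targets):

import torch
import torch.nn.functional as F
logits = torch.randn(4, 10) * 30                        # exaggerated raw scores
targets = torch.randint(0, 10, (4,))
a = F.cross_entropy(logits, targets)                    # from raw logits
b = F.nll_loss(F.log_softmax(logits, dim=-1), targets)  # explicit log-softmax + NLL
print(torch.allclose(a, b))                             # True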

The identity (holds for any constant $c$, but $c = \max_i x_i$ is optimal):

$$\log \sum_{i=1}^{K} e^{x_i} = c + \log \sum_{i=1}^{K} e^{x_i - c}, \quad c = \max_i x_i$$

Proof — factor out $e^c$:

$$\log \sum e^{x_i} = \log \left(e^c \sum e^{x_i - c}\right) = c + \log \sum e^{x_i - c}$$

Log-softmax (the form used in cross-entropy):

$$\log \text{softmax}(x_j) = x_j - \log \sum_i e^{x_i} = (x_j - c) - \log \sum_i e^{x_i - c}$$

Softmax (via log-softmax for stability):

$$\text{softmax}(x_j) = \exp\bigl(\log \text{softmax}(x_j)\bigr) = \frac{e^{x_j - c}}{\sum_i e^{x_i - c}}$$

Note: computing softmax by first computing log-softmax and then exponentiating is the numerically stable path. Computing $e^{x_j} / \sum e^{x_i}$ directly risks overflow in both numerator and denominator.

import torch
import torch.nn.functional as F
# ── PyTorch's built-in (already uses the trick internally) ──────
logits = model(x) # (B, C)
log_probs = F.log_softmax(logits, dim=-1) # (B, C) — stable
probs = F.softmax(logits, dim=-1) # (B, C) — stable
lse = torch.logsumexp(logits, dim=-1) # (B,) — stable
# ── Cross-entropy uses log-softmax internally ───────────────────
loss = F.cross_entropy(logits, targets) # stable: log_softmax + nll
# NEVER do: loss = -torch.log(F.softmax(logits, -1)) — loses precision
# ── Explicit logsumexp for custom computations ──────────────────
# Useful when you need log(Σ exp(x)) in a custom loss or metric
# Example: log-sum-exp pooling, log-partition functions, CTC loss
# WARNING: never compute torch.log(torch.sum(torch.exp(x))) manually.
# Use torch.logsumexp(x, dim=...) — it handles the shift for you.
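# Illustrative sketch (hypothetical shapes): logsumexp pooling over the time axis
# seq_scores: (B, T, D) per-timestep scores -> pooled: (B, D)
# pooled = torch.logsumexp(seq_scores, dim=1)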
# ── Numerical comparison ────────────────────────────────────────
x = torch.tensor([1000.0, 1001.0, 1002.0])
# torch.log(torch.exp(x).sum()) # → inf (overflow)
torch.logsumexp(x, dim=0) # → tensor(1002.4076) ✓
import numpy as np

def logsumexp(x, axis=-1, keepdims=False):
    """
    Numerically stable log-sum-exp.
    Equivalent to np.log(np.sum(np.exp(x), axis=axis)).
    x: array of any shape
    """
    c = x.max(axis=axis, keepdims=True)                           # shift constant
    out = c + np.log(np.exp(x - c).sum(axis=axis, keepdims=True))
    if not keepdims:
        out = out.squeeze(axis=axis)
    return out

def log_softmax(logits):
    """
    Numerically stable log-softmax.
    logits: (B, C) raw scores
    Returns: (B, C) log-probabilities
    """
    c = logits.max(axis=-1, keepdims=True)                        # (B, 1)
    shifted = logits - c                                          # (B, C) — max is 0
    log_sum = np.log(np.exp(shifted).sum(axis=-1, keepdims=True)) # (B, 1)
    return shifted - log_sum                                      # (B, C)

def softmax(logits):
    """
    Numerically stable softmax via log-softmax.
    logits: (B, C) raw scores
    Returns: (B, C) probabilities summing to 1
    """
    return np.exp(log_softmax(logits))                            # (B, C)

def cross_entropy(logits, targets):
    """
    Full numerically stable cross-entropy from raw logits.
    logits: (B, C) raw scores
    targets: (B,) integer class indices
    """
    B = logits.shape[0]
    lp = log_softmax(logits)                                      # (B, C)
    return -lp[np.arange(B), targets].mean()                      # scalar
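
A few sanity checks of the NumPy helpers above (shapes and magnitudes here are illustrative):

scores = np.random.randn(4, 5) * 100                      # wide-range fake logits
assert np.isfinite(logsumexp(scores)).all()                # no overflow despite large values
assert np.allclose(softmax(scores).sum(axis=-1), 1.0)      # probabilities sum to 1
assert np.allclose(log_softmax(scores),                    # matches the identity
                   scores - logsumexp(scores, keepdims=True))
assert np.isfinite(cross_entropy(scores, np.array([0, 1, 2, 3])))
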
  • Cross-entropy loss (every classifier, every LLM): F.cross_entropy is log-softmax + NLL, and log-softmax is built on logsumexp (see nn-training/)
  • Softmax attention (all transformers): attention scores are softmax of scaled dot products — logsumexp ensures stability with large sequence lengths (see transformer/)
  • Contrastive losses (SimCLR, CLIP): InfoNCE is a cross-entropy over similarity scores that can span a huge range (see contrastive-self-supervising/)
  • Log-partition functions (CRFs, energy-based models): computing $\log Z = \log \sum_x \exp(-E(x))$ requires logsumexp for stability
  • Mixture density networks: log-likelihood of Gaussian mixtures involves logsumexp over component log-probabilities
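
For the last item above, a minimal sketch (an illustrative 1-D Gaussian mixture; pi, mu, and sigma are made-up parameters, and it reuses the logsumexp helper defined earlier):

# log p(x) = logsumexp_k( log pi_k + log N(x | mu_k, sigma_k) )
pi = np.array([0.5, 0.3, 0.2])                  # mixture weights, sum to 1
mu = np.array([-2.0, 0.0, 3.0])                 # component means
sigma = np.array([0.5, 1.0, 2.0])               # component std-devs
x = np.array([0.7, -1.5])                       # observations, shape (N,)
comp_logpdf = (-0.5 * ((x[:, None] - mu) / sigma) ** 2
               - np.log(sigma) - 0.5 * np.log(2 * np.pi))    # (N, K) component log-densities
log_px = logsumexp(np.log(pi) + comp_logpdf, axis=-1)         # (N,) stable mixture log-likelihood
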
| Alternative | When to use | Tradeoff |
| --- | --- | --- |
| Naive exp-sum-log | Never in production | Overflows/underflows for logits outside [-88, 88] (float32). Only acceptable for small, bounded values |
| Online softmax (Flash Attention) | Memory-efficient attention on long sequences | Computes softmax in a single pass without materialising the full attention matrix; uses the same max-subtraction idea in streaming form |
| Log-space arithmetic | Chaining multiple log-sum-exp operations | Stay in log-space throughout to avoid repeated exp/log — useful for HMMs, CTC, beam search |
| Mixed precision (fp16/bf16) | Training speed | Narrower exponent range makes logsumexp even more critical — overflow at $e^{11}$ for fp16 vs $e^{88}$ for fp32 |
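
A sketch of the streaming form mentioned in the table: keep a running max and a running rescaled sum so each chunk is processed once (names and chunking here are illustrative, not Flash Attention's actual kernel):

def online_logsumexp(chunks):
    """Streaming logsumexp over an iterable of 1-D NumPy arrays."""
    m, s = -np.inf, 0.0                          # running max, running sum of exp(x - m)
    for x in chunks:
        m_new = max(m, x.max())
        s = s * np.exp(m - m_new) + np.exp(x - m_new).sum()   # rescale old sum to new max
        m = m_new
    return m + np.log(s)

x = np.array([1000.0, 1001.0, 1002.0])
print(online_logsumexp([x[:2], x[2:]]))          # 1002.4076..., matches torch.logsumexp above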

The logsumexp trick is one of the oldest numerical methods in scientific computing, predating deep learning by decades. It appears in any field that works with probabilities in log-space: statistical mechanics (partition functions), speech recognition (HMMs), and computational biology (sequence alignment scores). The technique is a special case of the more general principle “never exponentiate a large number if you’re going to take the log of the result.”

In deep learning, its importance grew with the scale of models. Early networks had small logit ranges where naive softmax worked fine. Modern LLMs with vocabulary sizes of 32K-128K and attention over sequences of 8K-1M tokens produce score ranges whose exponentials routinely overflow without the shift, especially in reduced precision. The trick is now so fundamental that it’s fused into hardware-level implementations — Flash Attention (Dao et al., 2022) uses online softmax (a streaming variant of the same idea) to compute attention without materialising the full $N \times N$ score matrix.