Cross-Entropy Loss

Measures how well a predicted probability distribution matches a target distribution. The standard loss function for classification — both multi-class (image classification, next-token prediction) and binary (spam detection, real/fake discrimination).
Intuition
Imagine you’re placing bets. Cross-entropy asks: “if the true answer follows distribution P, how many bits do I waste by using my predicted distribution Q to encode it?” A perfect prediction wastes zero extra bits. The worse Q matches P, the more bits wasted.
For classification: the true distribution P is a one-hot vector (all probability on the correct class). Cross-entropy then simplifies to “how much probability did you put on the right answer?” If you put probability 0.9 on the correct class, your loss is −log(0.9) ≈ 0.105. If you put 0.01, your loss is −log(0.01) ≈ 4.6. The log means the penalty grows explosively as confidence in the wrong answer increases — this is the key property that makes it work well for training.
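These numbers are easy to verify directly; a quick sketch in plain NumPy, with illustrative probabilities only:

```python
import numpy as np

# Per-sample loss is -log(p_correct): the penalty grows explosively
# as the probability placed on the right answer shrinks.
losses = {p: -np.log(p) for p in [0.9, 0.5, 0.1, 0.01]}
# 0.9 → 0.105, 0.5 → 0.693, 0.1 → 2.303, 0.01 → 4.605
```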
Note: cross-entropy on one-hot targets is mathematically identical to negative log-likelihood (NLL). The terms are used interchangeably in practice.
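A quick NumPy check of this equivalence, using arbitrary toy logits:

```python
import numpy as np

logits = np.array([2.0, 0.5, -1.0])
c = 0  # correct class index

# Numerically stable log-softmax
shifted = logits - logits.max()
log_probs = shifted - np.log(np.exp(shifted).sum())

nll = -log_probs[c]                      # NLL of the true class
one_hot = np.array([1.0, 0.0, 0.0])
ce = -(one_hot * log_probs).sum()        # full cross-entropy sum

# With a one-hot target, every term except the true class vanishes,
# so the two quantities are identical.
```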
General form (discrete distributions):

H(P, Q) = −Σᵢ P(i) log Q(i)

Classification (one-hot target, class c is correct):

L = −log softmax(z)_c

where z are the raw logits (unnormalised scores) from the network.

Expanding the log-softmax:

L = −z_c + log Σⱼ exp(z_j)

This is the form actually computed — it’s numerically stable and avoids computing softmax explicitly.

Binary cross-entropy (single probability p, target y ∈ {0, 1}):

L = −[y log p + (1 − y) log(1 − p)]

With label smoothing (soften one-hot targets by mixing with uniform):

qᵢ = (1 − ε)·𝟙[i = c] + ε/K

where K is the number of classes and ε is typically 0.1. Prevents the model from becoming overconfident.
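A sketch of what the smoothed target looks like, using illustrative values K = 5, c = 2, ε = 0.1:

```python
import numpy as np

K, c, eps = 5, 2, 0.1
one_hot = np.eye(K)[c]
smoothed = (1 - eps) * one_hot + eps / K

# smoothed = [0.02, 0.02, 0.92, 0.02, 0.02] — still a valid distribution,
# but the model can no longer drive the loss to zero by pushing logits to ±∞.
```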
```python
import torch
import torch.nn.functional as F

# ── Standard classification (logits → loss) ──────────────────────
# F.cross_entropy takes RAW LOGITS, not probabilities.
# It does log-softmax + NLL internally in a numerically stable way.
# NEVER apply softmax before this — you'll get wrong gradients and
# numerical issues.
logits = model(x)                         # (B, n_classes) — raw scores
targets = labels                          # (B,) — integer class indices
loss = F.cross_entropy(logits, targets)   # scalar

# ── With label smoothing ─────────────────────────────────────────
loss = F.cross_entropy(logits, targets, label_smoothing=0.1)

# ── Binary classification (single logit per sample) ──────────────
logit = model(x)                          # (B, 1) or (B,) — single score
target = labels.float()                   # (B,) — 0.0 or 1.0
loss = F.binary_cross_entropy_with_logits(logit, target)
# Again: takes raw logits, applies sigmoid internally.
```

Manual Implementation
```python
import numpy as np

def cross_entropy_manual(logits, targets):
    """
    Equivalent to F.cross_entropy.
    logits:  (B, C) raw scores — NOT probabilities
    targets: (B,) integer class indices
    """
    B, C = logits.shape

    # Numerically stable log-softmax: subtract max to prevent overflow in exp()
    shifted = logits - logits.max(axis=1, keepdims=True)              # (B, C)
    log_sum_exp = np.log(np.exp(shifted).sum(axis=1, keepdims=True))  # (B, 1)
    log_probs = shifted - log_sum_exp                                 # (B, C)

    # Pick the log-prob of the correct class for each sample
    loss_per_sample = -log_probs[np.arange(B), targets]               # (B,)
    return loss_per_sample.mean()


def binary_cross_entropy_manual(logits, targets):
    """
    Equivalent to F.binary_cross_entropy_with_logits.
    logits:  (B,) raw scores
    targets: (B,) float 0.0 or 1.0
    """
    # Numerically stable form: max(0, logit) - logit*target + log(1 + exp(-|logit|))
    return (np.maximum(0, logits)
            - logits * targets
            + np.log1p(np.exp(-np.abs(logits)))).mean()
```

Popular Uses
- Image classification (ResNet, ViT): predict one class from K options
- Language modelling / next-token prediction (GPT, LLaMA): cross-entropy over the full vocabulary at every position — this is THE training objective for LLMs
- GAN discriminators (vanilla GAN): binary cross-entropy for real/fake classification
- Knowledge distillation: cross-entropy between student and teacher softened distributions (with temperature)
- Contrastive learning (SimCLR, CLIP): InfoNCE loss is cross-entropy where the “classes” are the positive pair vs. all negatives in the batch
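The InfoNCE connection can be sketched in a few lines of NumPy. The `info_nce` function and its arguments here are illustrative, not any library's API; it assumes the rows of `a` and `b` are L2-normalised embeddings with `(a[i], b[i])` as positive pairs:

```python
import numpy as np

def info_nce(a, b, temperature=0.1):
    # Similarity of every a[i] against every b[j]: a (B, B) logit matrix
    # where the "correct class" for row i is column i (its positive pair).
    logits = (a @ b.T) / temperature

    # Standard stable log-softmax over each row
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Cross-entropy with targets on the diagonal
    idx = np.arange(len(a))
    return -log_probs[idx, idx].mean()
```

When positives align perfectly (similarity 1 on the diagonal, 0 elsewhere) the loss is near zero; misaligned pairs drive it up, exactly as with ordinary classification.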
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| MSE loss | Regression, diffusion noise prediction | Doesn’t penalise confident wrong answers as harshly; not suitable for classification |
| Focal loss | Class-imbalanced classification (object detection) | Down-weights easy examples, focuses on hard ones. Adds a modulating factor |
| Hinge loss | SVMs, hinge GAN | Margin-based — only penalises if the correct class score isn’t above the margin. No probability interpretation |
| CTC loss | Sequence-to-sequence without alignment (speech recognition) | Marginalises over all valid alignments between input and output sequences |
| KL divergence | Soft target distributions (distillation, VAE regularisation) | Cross-entropy minus the target’s own entropy. Identical gradient when the target is fixed |
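The identity in the last row, H(P, Q) = KL(P‖Q) + H(P), can be checked numerically; P and Q below are arbitrary toy distributions:

```python
import numpy as np

p = np.array([0.7, 0.2, 0.1])   # fixed target distribution
q = np.array([0.5, 0.3, 0.2])   # model distribution

cross_entropy = -(p * np.log(q)).sum()
kl = (p * np.log(p / q)).sum()
entropy = -(p * np.log(p)).sum()

# Since H(P) does not depend on q, minimising cross-entropy w.r.t. q
# is the same optimisation problem as minimising KL(P‖Q).
```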
Historical Context
Cross-entropy comes from information theory (Shannon, 1948), where it measures the expected message length when using code Q to encode messages from distribution P. It entered machine learning through logistic regression in statistics and was formalised as the standard classification loss through maximum likelihood estimation — minimising cross-entropy is equivalent to maximising the likelihood of the data under the model.
The key practical innovation was the “logits” formulation: computing log-softmax + NLL in a single fused operation (the logsumexp trick) rather than computing softmax first and then taking the log. This avoids numerical underflow/overflow and is the reason every modern framework has cross_entropy take raw logits rather than probabilities.
Label smoothing (Szegedy et al., 2016, “Rethinking the Inception Architecture”) was a simple but effective addition that prevents the model from learning to output extreme logits, improving generalisation and calibration.