Focal Loss

A modification of cross-entropy that down-weights easy examples and focuses training on hard, misclassified ones. Designed to handle extreme class imbalance in dense object detection (e.g. 100,000 background anchors vs. 10 object anchors per image), but useful in any imbalanced classification setting.

In standard cross-entropy, every correctly classified example still contributes loss and gradient. When 99.9% of your samples are easy negatives (background in object detection), these easy examples dominate the total loss. Each one contributes a small amount, but there are so many of them that they drown out the signal from the rare, hard positives.

Focal loss adds a modulating factor (1 - p_t)^\gamma that smoothly scales down the loss for confident predictions. When the model predicts a sample correctly with probability 0.95, the modulating factor is 0.05^\gamma. With \gamma = 2 (the standard), that's 0.0025: the loss contribution is reduced by 400x. But for a hard example where the model assigns only 0.3 probability, the factor is 0.7^2 = 0.49, nearly full strength.

The key insight: this is not the same as hard example mining (which uses a sharp threshold). Focal loss smoothly re-weights all examples, so moderately hard examples still contribute proportionally. The hyperparameter \gamma controls how aggressively to down-weight easy examples: \gamma = 0 recovers standard cross-entropy, \gamma = 5 is extremely aggressive. In practice, \gamma = 2 works well across most settings.
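To make the down-weighting concrete, here is a small sketch (plain Python, values chosen for illustration) tabulating the modulating factor (1 - p_t)^\gamma for a few confidence levels; note the \gamma = 0 row is all 1s, i.e. plain cross-entropy:

```python
# Modulating factor (1 - p_t)^gamma for various confidences and gammas.
confidences = (0.3, 0.5, 0.9, 0.95)
for gamma in (0, 1, 2, 5):
    row = [(1 - p_t) ** gamma for p_t in confidences]
    print(f"gamma={gamma}:", ["%.4g" % w for w in row])
```

At \gamma = 2, an example predicted at 0.95 is scaled by 0.0025 while one predicted at 0.3 keeps 0.49 of its loss, which is exactly the 400x-vs-half contrast described above.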

Standard cross-entropy (for reference):

\text{CE}(p_t) = -\log(p_t)

where p_t is the model's predicted probability for the true class.

Focal loss:

\text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)

  • (1 - p_t)^\gamma is the modulating factor. When p_t \to 1 (easy, correct), this approaches 0. When p_t \to 0 (hard, wrong), this approaches 1
  • \alpha_t is a class-balancing weight (e.g. \alpha = 0.25 for the rare class). This is independent of the focal mechanism
  • \gamma is the focusing parameter. Standard value: 2

Binary focal loss (the most common form, used in detection):

\text{FL} = -\alpha_t (1 - p_t)^\gamma \log(p_t)

where p_t = \sigma(x) if y = 1, else p_t = 1 - \sigma(x), and \sigma is the sigmoid.

Gradient (instructive — shows why it works):

\frac{\partial \text{FL}}{\partial x} = \alpha_t \Bigl[\gamma (1 - p_t)^{\gamma-1} \log(p_t) - \frac{(1 - p_t)^\gamma}{p_t}\Bigr] \cdot p_t(1 - p_t) = \alpha_t (1 - p_t)^\gamma \bigl(\gamma\, p_t \log(p_t) + p_t - 1\bigr) \quad \text{(for } y = 1\text{)}

The focal weighting multiplies the gradient as well as the loss: every term carries a power of (1 - p_t), so easy examples (p_t \to 1) receive vanishingly small gradients, not just small loss values.
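The compact form \alpha_t (1 - p_t)^\gamma (\gamma\, p_t \log(p_t) + p_t - 1) can be checked against autograd; a quick sketch for the y = 1 case, with hypothetical values x = 0.5, \alpha = 0.25, \gamma = 2:

```python
import torch

x = torch.tensor(0.5, requires_grad=True)  # raw logit
alpha, gamma = 0.25, 2.0
p = torch.sigmoid(x)                       # p_t for y = 1
loss = -alpha * (1 - p) ** gamma * torch.log(p)
loss.backward()

# Analytic gradient: alpha * (1 - p_t)^gamma * (gamma * p_t * log(p_t) + p_t - 1)
with torch.no_grad():
    analytic = alpha * (1 - p) ** gamma * (gamma * p * torch.log(p) + p - 1)
print(x.grad.item(), analytic.item())  # the two values should agree
```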

import torch
import torch.nn.functional as F

# ── Binary focal loss (from logits) ─────────────────────────────
# No built-in PyTorch function. torchvision has one:
# from torchvision.ops import sigmoid_focal_loss
# But here's the manual version (preferred for clarity):
def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """
    logits:  (B,) or (B, 1) raw scores (before sigmoid)
    targets: (B,) float, 0.0 or 1.0
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')  # (B,)
    p_t = torch.exp(-bce)                                    # probability of correct class
    focal_weight = (1 - p_t) ** gamma                        # (B,) — down-weight easy examples
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # (B,)
    loss = alpha_t * focal_weight * bce                      # (B,)
    return loss.mean()
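A quick sanity check on the binary version (the construction is restated without the final mean so the snippet runs standalone, and the logits are made up): an easy positive contributes orders of magnitude less than a hard one.

```python
import torch
import torch.nn.functional as F

def focal_loss_per_example(logits, targets, alpha=0.25, gamma=2.0):
    # Same construction as focal_loss above, minus the mean, to inspect per-example terms.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = torch.exp(-bce)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return alpha_t * (1 - p_t) ** gamma * bce

logits  = torch.tensor([4.0, -1.0])   # easy positive (p ~ 0.98), hard positive (p ~ 0.27)
targets = torch.tensor([1.0, 1.0])
fl = focal_loss_per_example(logits, targets)
print(fl)  # the easy example's loss is negligible next to the hard one's
```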
# ── Multi-class focal loss ───────────────────────────────────────
def focal_loss_multiclass(logits, targets, gamma=2.0):
    """
    logits:  (B, C) raw class scores
    targets: (B,) integer class indices
    """
    ce = F.cross_entropy(logits, targets, reduction='none')  # (B,)
    p_t = torch.exp(-ce)                                     # (B,)
    loss = ((1 - p_t) ** gamma * ce).mean()
    return loss
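One cheap correctness check, using made-up logits: with \gamma = 0 the multi-class version must reduce exactly to mean cross-entropy (the function is restated so the snippet runs standalone).

```python
import torch
import torch.nn.functional as F

def focal_loss_multiclass(logits, targets, gamma=2.0):
    ce = F.cross_entropy(logits, targets, reduction='none')
    p_t = torch.exp(-ce)
    return ((1 - p_t) ** gamma * ce).mean()

torch.manual_seed(0)
logits = torch.randn(8, 5)                 # hypothetical batch of 8, 5 classes
targets = torch.randint(0, 5, (8,))
fl0 = focal_loss_multiclass(logits, targets, gamma=0.0)
ce = F.cross_entropy(logits, targets)      # default 'mean' reduction
print(torch.allclose(fl0, ce))  # gamma = 0 is plain cross-entropy
```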
# ── NumPy reference implementation ───────────────────────────────
import numpy as np

def focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0):
    """
    Binary focal loss from raw logits.
    logits:  (B,) raw scores before sigmoid
    targets: (B,) float, 0.0 or 1.0
    """
    # Numerically stable BCE: max(0, x) - x*t + log(1 + exp(-|x|))
    bce = (np.maximum(0, logits) - logits * targets
           + np.log1p(np.exp(-np.abs(logits))))              # (B,)
    # p_t = probability assigned to the TRUE class
    p_t = np.exp(-bce)                                       # (B,)
    # Focal modulating factor: easy examples (high p_t) get near-zero weight
    focal_weight = (1 - p_t) ** gamma                        # (B,)
    # Class-balancing weight
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # (B,)
    return (alpha_t * focal_weight * bce).mean()
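The stable-BCE identity is worth verifying: for moderate logits it matches the naive formula, and for extreme logits it stays finite where the naive form overflows. A quick check with hypothetical values:

```python
import numpy as np

def stable_bce(logits, targets):
    # max(0, x) - x*t + log(1 + exp(-|x|)) equals -[t*log(p) + (1-t)*log(1-p)], p = sigmoid(x)
    return np.maximum(0, logits) - logits * targets + np.log1p(np.exp(-np.abs(logits)))

x = np.array([-2.0, 0.5, 3.0])
t = np.array([1.0, 0.0, 1.0])
p = 1 / (1 + np.exp(-x))
naive = -(t * np.log(p) + (1 - t) * np.log(1 - p))
print(np.allclose(stable_bce(x, t), naive))   # matches the textbook formula

big = stable_bce(np.array([1000.0]), np.array([0.0]))
print(np.isfinite(big))                       # finite where naive BCE would overflow
```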
  • Dense object detection (RetinaNet): the original use case. Made single-stage detectors competitive with two-stage by handling the massive foreground/background imbalance
  • Instance segmentation (Mask R-CNN variants): focal loss on the mask classification head
  • Medical imaging: lesion detection where positive pixels are rare (e.g. tumor segmentation, retinal disease)
  • Fraud/anomaly detection: any binary classification where positives are <1% of the data
  • Multi-label classification: when each label has different frequencies, focal loss per label prevents common labels from dominating
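For the multi-label case, the bullet above amounts to applying the binary form independently per label; a minimal sketch (the function name and shapes are illustrative, not from any library):

```python
import torch
import torch.nn.functional as F

def multilabel_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """
    logits:  (B, C) one raw score per label
    targets: (B, C) float, 0.0 or 1.0 per label
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')  # (B, C)
    p_t = torch.exp(-bce)                                    # prob of correct decision per label
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # per-label class balancing
    # Mean over labels and batch; rare labels keep their gradient because
    # confident predictions on common labels are down-weighted.
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

torch.manual_seed(0)
loss = multilabel_focal_loss(torch.randn(4, 6), torch.randint(0, 2, (4, 6)).float())
print(loss.item())
```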
| Alternative | When to use | Tradeoff |
| --- | --- | --- |
| Cross-entropy + class weights | Moderate imbalance (up to ~10:1) | Simple to implement; doesn't adapt to example difficulty, only class frequency |
| Hard example mining (OHEM) | Very extreme imbalance, want strict selection | Uses only the hardest examples per batch; less stable than focal loss's smooth re-weighting |
| Hinge loss | GAN discriminators, SVMs | Also ignores easy examples (zero loss beyond margin) but sharp cutoff rather than smooth |
| Dice / IoU loss | Segmentation with imbalanced masks | Directly optimises overlap metric; doesn't decompose into per-pixel terms |
| Class-balanced sampling | Moderate imbalance, simpler approach | Re-sample batches to equalize class frequency; orthogonal to loss function choice |
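The class-weights alternative is built into PyTorch directly: for moderate imbalance, a per-class weight tensor passed to cross_entropy may be all you need (the weights and data here are illustrative):

```python
import torch
import torch.nn.functional as F

# 2-class problem where class 1 is rare: down-weight class 0 (hypothetical ratio).
weights = torch.tensor([0.1, 1.0])
torch.manual_seed(0)
logits = torch.randn(16, 2)
targets = torch.randint(0, 2, (16,))
loss = F.cross_entropy(logits, targets, weight=weights)
# Weights rescale by class membership only; unlike focal loss, an easy and a
# hard example of the same class are still treated identically.
print(loss.item())
```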

Focal loss was introduced by Lin et al. (2017) in the RetinaNet paper (“Focal Loss for Dense Object Detection”). It answered a standing question: why did single-stage detectors (like SSD and YOLO) lag behind two-stage detectors (like Faster R-CNN)? The two-stage approach uses region proposals to filter out most background before classification, implicitly handling imbalance. Focal loss achieved the same effect purely through the loss function, making RetinaNet the first single-stage detector to match two-stage accuracy.

The idea of re-weighting examples by difficulty has precursors in boosting (AdaBoost, 1997) and curriculum learning, but focal loss's contribution was showing that a simple, differentiable modulating factor applied to cross-entropy was sufficient. The specific form (1 - p_t)^\gamma was chosen empirically — the authors tried several functional forms, and this one worked best. The approach has since been adopted far beyond detection into any domain with severe class imbalance.