# Focal Loss
A modification of cross-entropy that down-weights easy examples and focuses training on hard, misclassified ones. Designed to handle extreme class imbalance in dense object detection (e.g. 100,000 background anchors vs. 10 object anchors per image), but useful in any imbalanced classification setting.
## Intuition

In standard cross-entropy, every correctly classified example still contributes loss and gradient. When 99.9% of your samples are easy negatives (background in object detection), these easy examples dominate the total loss. Each one contributes a small amount, but there are so many of them that they drown out the signal from the rare, hard positives.
Focal loss adds a modulating factor $(1 - p_t)^\gamma$ that smoothly scales down the loss for confident predictions. When the model predicts a sample correctly with probability 0.95, the modulating factor is $(1 - 0.95)^\gamma$. With $\gamma = 2$ (the standard), that's $0.05^2 = 0.0025$ — the loss contribution is reduced by 400x. But for a hard example where the model assigns only 0.3 probability, the factor is $(1 - 0.3)^2 = 0.49$ — about half strength, hundreds of times more than the easy example.
The key insight: this is not the same as hard example mining (which uses a sharp threshold). Focal loss smoothly re-weights all examples, so moderately hard examples still contribute proportionally. The hyperparameter $\gamma$ controls how aggressively to down-weight easy examples: $\gamma = 0$ recovers standard cross-entropy, while large values (e.g. $\gamma = 5$) are extremely aggressive. In practice, $\gamma = 2$ works well across most settings.
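A few lines of plain Python (a hypothetical illustration, not from the paper) make the scaling concrete:

```python
# Modulating factor (1 - p_t)**gamma for an easy vs. a hard example.
def modulating_factor(p_t, gamma):
    """Weight multiplying the cross-entropy term for an example
    classified correctly with probability p_t."""
    return (1 - p_t) ** gamma

easy, hard = 0.95, 0.3  # p_t for a confident vs. a struggling prediction
for gamma in (0, 2, 5):
    print(gamma, modulating_factor(easy, gamma), modulating_factor(hard, gamma))
# gamma = 0: both factors are 1.0 — plain cross-entropy.
# gamma = 2: easy example scaled to ~0.0025 (400x down); hard one keeps ~0.49.
```

Note how at $\gamma = 0$ every example keeps full weight, which is exactly why $\gamma = 0$ recovers cross-entropy.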
Standard cross-entropy (for reference):

$$\mathrm{CE}(p_t) = -\log(p_t)$$

where $p_t$ is the model's predicted probability for the true class.
Focal loss:

$$\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^\gamma \log(p_t)$$

- $(1 - p_t)^\gamma$ is the modulating factor. When $p_t \to 1$ (easy, correct), this approaches 0. When $p_t \to 0$ (hard, wrong), this approaches 1
- $\alpha_t$ is a class-balancing weight (e.g. $\alpha = 0.25$ for the rare class). This is independent of the focal mechanism
- $\gamma$ is the focusing parameter. Standard value: 2
Binary focal loss (the most common form, used in detection):

$$\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^\gamma \log(p_t)$$

where $p_t = p$ if $y = 1$, else $p_t = 1 - p$, and $p = \sigma(x)$ is the sigmoid of the logit $x$.
Gradient (instructive — shows why it works). For the unweighted case ($\alpha_t = 1$), with labels $y \in \{-1, +1\}$ and logit $x$:

$$\frac{\partial \mathrm{FL}}{\partial x} = y \, (1 - p_t)^\gamma \left( \gamma \, p_t \log(p_t) + p_t - 1 \right)$$

The $(1 - p_t)^\gamma$ factor is the focal weighting applied to the gradient itself, ensuring easy examples get small gradients.
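The gradient formula can be sanity-checked numerically. A quick sketch in plain Python (my own check, with $\gamma = 2$, the $\alpha$ weight omitted, and labels $y \in \{-1, +1\}$):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def focal(x, y, gamma=2.0):
    # Focal loss for a single logit x and label y in {-1, +1} (alpha omitted).
    p = sigmoid(x)
    p_t = p if y == 1 else 1 - p
    return -((1 - p_t) ** gamma) * math.log(p_t)

def focal_grad(x, y, gamma=2.0):
    # Analytic gradient from the formula above.
    p = sigmoid(x)
    p_t = p if y == 1 else 1 - p
    return y * (1 - p_t) ** gamma * (gamma * p_t * math.log(p_t) + p_t - 1)

# Central finite difference agrees with the analytic form.
x, y, eps = 0.7, 1, 1e-6
numeric = (focal(x + eps, y) - focal(x - eps, y)) / (2 * eps)
print(abs(numeric - focal_grad(x, y)))  # difference is tiny — the formulas agree
```

This also confirms the sign convention: for $y = 1$, pushing the logit up lowers the loss, and the $(1 - p_t)^\gamma$ factor shrinks the whole gradient once $p_t$ is close to 1.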
```python
import torch
import torch.nn.functional as F

# ── Binary focal loss (from logits) ─────────────────────────────
# No built-in PyTorch function. torchvision has one:
#   from torchvision.ops import sigmoid_focal_loss
# But here's the manual version (preferred for clarity):

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """
    logits:  (B,) or (B, 1) raw scores (before sigmoid)
    targets: (B,) float, 0.0 or 1.0
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')  # (B,)
    p_t = torch.exp(-bce)                                    # probability of correct class
    focal_weight = (1 - p_t) ** gamma                        # (B,) — down-weight easy examples
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # (B,)
    loss = alpha_t * focal_weight * bce                      # (B,)
    return loss.mean()

# ── Multi-class focal loss ───────────────────────────────────────
def focal_loss_multiclass(logits, targets, gamma=2.0):
    """
    logits:  (B, C) raw class scores
    targets: (B,) integer class indices
    """
    ce = F.cross_entropy(logits, targets, reduction='none')  # (B,)
    p_t = torch.exp(-ce)                                     # (B,)
    loss = ((1 - p_t) ** gamma * ce).mean()
    return loss
```
## Manual Implementation

```python
import numpy as np

def focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0):
    """
    Binary focal loss from raw logits.
    logits:  (B,) raw scores before sigmoid
    targets: (B,) float, 0.0 or 1.0
    """
    # Numerically stable BCE: max(0, x) - x*t + log(1 + exp(-|x|))
    bce = (np.maximum(0, logits) - logits * targets
           + np.log1p(np.exp(-np.abs(logits))))  # (B,)

    # p_t = probability assigned to the TRUE class
    p_t = np.exp(-bce)  # (B,)

    # Focal modulating factor: easy examples (high p_t) get near-zero weight
    focal_weight = (1 - p_t) ** gamma  # (B,)

    # Class-balancing weight
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # (B,)

    return (alpha_t * focal_weight * bce).mean()
```
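As a quick sanity check of the mechanism (a hypothetical toy batch, not from the original paper), compare how plain BCE and the focal-weighted loss apportion a batch of 1,000 easy negatives and 5 hard positives:

```python
import numpy as np

logits = np.concatenate([np.full(1000, -4.0),   # easy negatives: p ≈ 0.018
                         np.full(5, -1.0)])     # hard positives: p ≈ 0.27
targets = np.concatenate([np.zeros(1000), np.ones(5)])

# Numerically stable per-example BCE, same form as focal_loss_binary above
bce = np.maximum(0, logits) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
p_t = np.exp(-bce)
focal = (1 - p_t) ** 2 * bce   # gamma = 2; alpha omitted to isolate the effect

print(bce[:1000].sum(), bce[1000:].sum())      # easy negatives dominate plain BCE
print(focal[:1000].sum(), focal[1000:].sum())  # focal weighting flips the balance
```

Under plain BCE the 1,000 easy negatives together out-weigh the 5 hard positives; with the $\gamma = 2$ modulating factor their total contribution collapses and the hard positives dominate.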
## Popular Uses

- Dense object detection (RetinaNet): the original use case. Made single-stage detectors competitive with two-stage detectors by handling the massive foreground/background imbalance
- Instance segmentation (Mask R-CNN variants): focal loss on the mask classification head
- Medical imaging: lesion detection where positive pixels are rare (e.g. tumor segmentation, retinal disease)
- Fraud/anomaly detection: any binary classification where positives are <1% of the data
- Multi-label classification: when each label has different frequencies, focal loss per label prevents common labels from dominating
## Alternatives

| Alternative | When to use | Tradeoff |
|---|---|---|
| Cross-entropy + class weights | Moderate imbalance (up to ~10:1) | Simple to implement; doesn’t adapt to example difficulty, only class frequency |
| Hard example mining (OHEM) | Very extreme imbalance, want strict selection | Uses only the hardest examples per batch; less stable than focal loss’s smooth re-weighting |
| Hinge loss | GAN discriminators, SVMs | Also ignores easy examples (zero loss beyond margin) but sharp cutoff rather than smooth |
| Dice / IoU loss | Segmentation with imbalanced masks | Directly optimises overlap metric; doesn’t decompose into per-pixel terms |
| Class-balanced sampling | Moderate imbalance, simpler approach | Re-sample batches to equalize class frequency; orthogonal to loss function choice |
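To make the first row's tradeoff concrete, here is a minimal sketch of class-weighted BCE (NumPy, with a hypothetical `pos_weight` parameter): the weight depends only on the label, so it corrects class frequency but not example difficulty.

```python
import numpy as np

def weighted_bce(logits, targets, pos_weight=10.0):
    """Class-weighted BCE: every positive example is scaled by pos_weight,
    regardless of how confidently it is classified."""
    # Numerically stable per-example BCE from raw logits
    bce = np.maximum(0, logits) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    w = np.where(targets == 1.0, pos_weight, 1.0)
    return (w * bce).mean()

# An easy positive (logit 5.0) and a hard positive (logit -1.0) both get the
# same 10x multiplier — unlike focal loss, difficulty plays no role.
```

This is why class weights alone stop helping at extreme imbalance: easy examples of the common class are still counted at full strength.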
## Historical Context

Focal loss was introduced by Lin et al. (2017) in the RetinaNet paper (“Focal Loss for Dense Object Detection”). The core problem it solved was why single-stage detectors (like SSD and YOLO) lagged behind two-stage detectors (like Faster R-CNN): the two-stage approach uses region proposals to filter out most background before classification, implicitly handling the imbalance. Focal loss achieved the same effect purely through the loss function, making RetinaNet the first single-stage detector to match two-stage accuracy.
The idea of re-weighting examples by difficulty has precursors in boosting (AdaBoost, 1997) and curriculum learning, but focal loss’s contribution was showing that a simple, differentiable modulating factor applied to cross-entropy was sufficient. The specific form was chosen empirically — the authors tried several functional forms, and this one worked best. The approach has since been adopted far beyond detection into any domain with severe class imbalance.