"""
Unified Contrastive / Self-Supervised Learning Algorithm
==========================================================
A single skeleton covering: SimCLR, CLIP, MoCo, BYOL, and
supervised contrastive learning (SupCon).
The core idea shared by ALL contrastive methods:
1. Create two views of the same thing (PLUGGABLE: how?)
2. Encode both views into embeddings (shared or separate encoders)
3. Pull matching pairs together (ALWAYS THE SAME)
4. Push non-matching pairs apart (PLUGGABLE: how, or even whether?)
The loss is always some form of:
"the embedding of view_A should be more similar to its matching
view_B than to any other view_B in the batch."
The pluggable components are:
1. create_views() — how pairs are generated (augmentation, modality, ...)
2. compute_loss() — the contrastive objective (InfoNCE, CLIP, cosine, ...)
3. update_momentum() — whether/how a momentum encoder is maintained
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import NamedTuple
# ─── Shared Data Types ────────────────────────────────────────────
class ViewPair(NamedTuple):
    """A batch of paired views ready for encoding.

    The pairing itself is the supervision: view_a[i] and view_b[i]
    are positives for each other; no class labels are required.
    """
    # (B, ...) — first view of each sample
    view_a: torch.Tensor
    # (B, ...) — second view; positive match for view_a at the same index
    view_b: torch.Tensor
    # (B,) — class labels; filled with -1 when training is unsupervised
    labels: torch.Tensor
# ═══════════════════════════════════════════════════════════════════
# CORE: SIMILARITY AND PROJECTION (shared building blocks)
# ═══════════════════════════════════════════════════════════════════
#
# Nearly all contrastive methods project encoder output through
# a small "projection head" before computing similarity.
#
# Why? The encoder's representation needs to be general (for
# downstream tasks). The projection head absorbs the contrastive
# objective's biases — it learns to throw away info that doesn't
# help the contrastive task, protecting the encoder.
# At deployment, the projection head is discarded.
class ProjectionHead(nn.Module):
    """Two-layer MLP mapping encoder output into the contrastive space.

    Absorbs the contrastive objective's biases so the encoder stays
    general; discarded at deployment time.
    """

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        layers = [
            nn.Linear(d_in, d_hidden),
            nn.BatchNorm1d(d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project (B, d_in) features to (B, d_out)."""
        return self.net(x)
class PredictionHead(nn.Module):
    """Asymmetric-branch MLP used by BYOL / SimSiam.

    Applied to exactly one branch (the online one); the resulting
    asymmetry is part of what prevents representational collapse.
    """

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        stages = (
            nn.Linear(d_in, d_hidden),
            nn.BatchNorm1d(d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Map (B, d_in) projections to (B, d_out) predictions."""
        return self.net(x)
# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════
class ContrastiveAlgorithm(ABC):
    """
    Universal contrastive training step.

    Subclasses override compute_loss() (the objective) and optionally
    post_step() (momentum/queue maintenance); the encode → project →
    normalise → backprop skeleton is shared by every variant.
    """

    def __init__(self, encoder, projection, optimizer):
        self.encoder = encoder
        self.projection = projection
        self.optimizer = optimizer

    def encode_and_project(self, x):
        """Run encoder then projection head; return L2-normalised embeddings."""
        projected = self.projection(self.encoder(x))   # (B, d_proj)
        return F.normalize(projected, dim=-1)          # unit sphere

    # ── pluggable pieces ──────────────────────────────────────────
    @abstractmethod
    def compute_loss(self, z_a, z_b, labels=None):
        """Return the scalar contrastive loss for a batch of embeddings."""
        ...

    def post_step(self):
        """Hook run after each optimizer step (EMA update, queue upkeep, ...)."""
        pass

    # ── core training step, identical across variants ─────────────
    def train_step(self, views: ViewPair):
        """One gradient step on a batch of paired views; returns loss as float."""
        emb_a = self.encode_and_project(views.view_a)  # (B, d_proj)
        emb_b = self.encode_and_project(views.view_b)  # (B, d_proj)
        loss = self.compute_loss(emb_a, emb_b, views.labels)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.post_step()
        return loss.item()
# ═══════════════════════════════════════════════════════════════════
# VIEW CREATION (how pairs are constructed — the other axis)
# ═══════════════════════════════════════════════════════════════════
#
# This is where contrastive methods get their signal. No labels
# needed — the pairing IS the supervision.
#
# • Same-image augmentations (SimCLR, MoCo, BYOL):
# Two random crops/flips/colour jitters of the same image.
# "Two views of the same scene should map to the same point."
#
# • Cross-modal pairs (CLIP):
# An image and its text caption.
# "An image and its description should map to the same point."
#
# • Supervised pairs (SupCon):
# Two augmentations of the same image, with class labels.
# All samples of the same class are positives.
class AugmentationPairCreator:
    """SimCLR / MoCo / BYOL-style pairing: two independent augmentations.

    The stochastic transform is applied twice per sample, so the two
    views differ even though they come from the same image.
    """

    def __init__(self, transform):
        # Stochastic augmentation pipeline; applied once per view.
        self.transform = transform

    def __call__(self, batch, labels=None):
        def augment_all():
            return torch.stack([self.transform(sample) for sample in batch])

        first = augment_all()
        second = augment_all()
        if labels is None:
            labels = torch.full((len(batch),), -1)   # -1 marks "unsupervised"
        return ViewPair(first, second, labels)
class CrossModalPairCreator:
    """CLIP-style pairing: the dataset already supplies (image, text) pairs.

    Labels are simply row indices — image_i's positive is text_i.
    """

    def __call__(self, images, texts):
        pair_ids = torch.arange(len(images))
        return ViewPair(images, texts, pair_ids)
# ═══════════════════════════════════════════════════════════════════
# TRAINING LOOP (standard supervised-style iteration)
# ═══════════════════════════════════════════════════════════════════
def train(algo: ContrastiveAlgorithm, dataloader, view_creator,
          n_epochs, device="cpu"):
    """Standard epoch loop: build view pairs, take a step, log mean loss."""
    for epoch in range(n_epochs):
        epoch_loss, n = 0.0, 0
        for batch, labels in dataloader:
            pair = view_creator(batch.to(device), labels.to(device))
            step_loss = algo.train_step(pair)
            # Weight by batch size so the epoch average is per-sample.
            epoch_loss += step_loss * batch.size(0)
            n += batch.size(0)
        print(f"Epoch {epoch+1:3d}/{n_epochs} │ loss {epoch_loss/n:.4f}")
# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════
# ── 1. SimCLR (in-batch negatives, InfoNCE loss) ────────────────
class SimCLR(ContrastiveAlgorithm):
    """
    In-batch-negative contrastive learning (InfoNCE).

    Each sample's positive is its other augmentation; the remaining
    2B−2 embeddings in the batch serve as negatives, which is why the
    method wants very large batches (4096+).
    """

    def __init__(self, *args, temperature=0.07, **kw):
        super().__init__(*args, **kw)
        self.tau = temperature

    def compute_loss(self, z_a, z_b, labels=None):
        """InfoNCE over the concatenated 2B embeddings."""
        batch = z_a.size(0)
        stacked = torch.cat([z_a, z_b], dim=0)        # (2B, d)
        logits = stacked @ stacked.T / self.tau       # (2B, 2B)
        # A sample must never be its own candidate: kill the diagonal.
        diag = torch.eye(2 * batch, device=logits.device, dtype=torch.bool)
        logits = logits.masked_fill(diag, float("-inf"))
        # Row i's positive sits at i+B, and row i+B's at i.
        targets = torch.cat([torch.arange(batch, 2 * batch),
                             torch.arange(batch)], dim=0).to(logits.device)
        # Cross-entropy over each row = −log softmax at the positive index.
        return F.cross_entropy(logits, targets)
# ── 2. CLIP (cross-modal contrastive, symmetric loss) ───────────
class CLIP(ContrastiveAlgorithm):
    """
    Two-encoder cross-modal contrastive learning.

    Every (image_i, text_i) pair is a positive; every cross pairing is
    a negative. The loss is the mean of the image→text and text→image
    cross-entropies, with a temperature learned in log space.

    NOTE(review): self.log_tau only trains if the caller put it in the
    optimizer's parameter groups — confirm at construction sites.
    """

    def __init__(self, image_encoder, text_encoder,
                 image_proj, text_proj, optimizer):
        # Deliberately skips the base __init__: CLIP keeps a separate
        # encoder and projection head per modality.
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.image_proj = image_proj
        self.text_proj = text_proj
        self.optimizer = optimizer
        # Learned inverse temperature, stored as a log for stability.
        self.log_tau = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))

    def train_step(self, views: ViewPair):
        """One gradient step on a batch of (image, text) pairs."""
        image_feats = self.image_encoder(views.view_a)
        text_feats = self.text_encoder(views.view_b)
        z_img = F.normalize(self.image_proj(image_feats), dim=-1)
        z_txt = F.normalize(self.text_proj(text_feats), dim=-1)
        loss = self.compute_loss(z_img, z_txt)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def compute_loss(self, z_img, z_txt, labels=None):
        """Symmetric InfoNCE across modalities; diagonal entries are positives."""
        scale = self.log_tau.exp()
        logits = z_img @ z_txt.T * scale              # (B, B)
        diag = torch.arange(z_img.size(0), device=logits.device)
        i2t = F.cross_entropy(logits, diag)           # image → text
        t2i = F.cross_entropy(logits.T, diag)         # text → image
        return (i2t + t2i) / 2
# ── 3. MoCo (momentum encoder + queue of negatives) ─────────────
class MoCo(ContrastiveAlgorithm):
    """
    Momentum Contrast: maintains a large queue of negative embeddings
    from previous batches, encoded by a slowly-updating momentum encoder.

    Solves SimCLR's problem: you no longer need giant batch sizes for
    enough negatives. A queue of 65536 negatives is typical.

    The momentum encoder is NOT trained by gradient — it's an EMA of
    the online encoder, which keeps the queued keys consistent (all
    were produced by similar-ish networks).

    Fix vs. the original: _enqueue now wraps around the queue end, so
    the batch size no longer has to divide the queue size (the original
    slice assignment crashed on overflow).

    NOTE(review): only `parameters()` are copied/EMA-updated; BatchNorm
    running stats (buffers) of the momentum nets drift from the online
    nets — confirm whether that matters for your encoders.
    """

    def __init__(self, encoder, projection, optimizer,
                 momentum_encoder, momentum_projection,
                 queue_size=65536, momentum=0.999, temperature=0.07):
        """
        Args:
            encoder, projection: online (gradient-trained) networks.
            momentum_encoder, momentum_projection: EMA copies; frozen.
            queue_size: number of negative keys kept (K).
            momentum: EMA coefficient m in θ_m ← m·θ_m + (1−m)·θ.
            temperature: softmax temperature for InfoNCE.
        """
        super().__init__(encoder, projection, optimizer)
        self.m_encoder = momentum_encoder
        self.m_projection = momentum_projection
        self.m = momentum
        self.tau = temperature
        self.queue_size = queue_size
        # Initialise the queue with random unit vectors so the first
        # batches have well-formed (if meaningless) negatives.
        # Assumes the projection head exposes `.net` ending in a Linear,
        # as ProjectionHead does.
        d = projection.net[-1].out_features
        self.queue = F.normalize(torch.randn(queue_size, d), dim=-1)
        self.queue_ptr = 0
        # Start the momentum nets as exact frozen copies of the online nets.
        for p, mp in zip(encoder.parameters(), momentum_encoder.parameters()):
            mp.data.copy_(p.data)
            mp.requires_grad_(False)
        for p, mp in zip(projection.parameters(), momentum_projection.parameters()):
            mp.data.copy_(p.data)
            mp.requires_grad_(False)

    def compute_loss(self, z_a, z_b, labels=None):
        """InfoNCE with the positive key at logit index 0.

        z_a: online-encoder queries. z_b: momentum-encoder positive keys
        (computed in train_step). The queue supplies the negative keys.
        """
        queue = self.queue.to(z_a.device)
        # Positive logits: (B, 1)
        pos = (z_a * z_b).sum(dim=-1, keepdim=True) / self.tau
        # Negative logits: (B, K)
        neg = (z_a @ queue.T) / self.tau
        logits = torch.cat([pos, neg], dim=1)         # (B, 1 + K)
        targets = torch.zeros(z_a.size(0), dtype=torch.long, device=z_a.device)
        return F.cross_entropy(logits, targets)

    def train_step(self, views: ViewPair):
        """Query with the online net, key with the momentum net, then enqueue."""
        z_a = self.encode_and_project(views.view_a)
        # Keys must carry no gradient — the momentum nets are EMA-only.
        with torch.no_grad():
            h_b = self.m_encoder(views.view_b)
            z_b = F.normalize(self.m_projection(h_b), dim=-1)
        loss = self.compute_loss(z_a, z_b)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Enqueue current keys, evicting the oldest.
        self._enqueue(z_b.detach())
        self.post_step()
        return loss.item()

    def _enqueue(self, keys):
        """FIFO-insert `keys` into the queue, wrapping past the end.

        Handles ptr + B > queue_size (batch not dividing the queue) by
        splitting the write across the wrap point, and a batch at least
        as large as the whole queue by keeping only the newest keys.
        """
        keys = keys.cpu()
        B = keys.size(0)
        if B >= self.queue_size:
            # Batch alone fills the queue: keep only the newest keys.
            self.queue = keys[-self.queue_size:].clone()
            self.queue_ptr = 0
            return
        ptr = self.queue_ptr
        end = ptr + B
        if end <= self.queue_size:
            self.queue[ptr:end] = keys
        else:
            split = self.queue_size - ptr
            self.queue[ptr:] = keys[:split]
            self.queue[:end - self.queue_size] = keys[split:]
        self.queue_ptr = end % self.queue_size

    def post_step(self):
        """Momentum update: θ_m ← m·θ_m + (1−m)·θ for encoder and projection."""
        for p, mp in zip(self.encoder.parameters(), self.m_encoder.parameters()):
            mp.data.lerp_(p.data, 1 - self.m)
        for p, mp in zip(self.projection.parameters(), self.m_projection.parameters()):
            mp.data.lerp_(p.data, 1 - self.m)
# ── 4. BYOL (no negatives at all) ──────────────────────────────
class BYOL(ContrastiveAlgorithm):
    """
    Bootstrap Your Own Latent — negative-free self-supervision.

    Online branch:  encoder → projector → PREDICTOR.
    Target branch:  EMA encoder → EMA projector (no predictor, no grad).

    The online branch is trained to PREDICT the target branch's output;
    the predictor, the EMA update, and the stop-gradient together give
    enough asymmetry to avoid collapse to a constant embedding.
    """

    def __init__(self, encoder, projection, predictor, optimizer,
                 momentum_encoder, momentum_projection, momentum=0.996):
        super().__init__(encoder, projection, optimizer)
        self.predictor = predictor
        self.m_encoder = momentum_encoder
        self.m_projection = momentum_projection
        self.m = momentum
        # Start the target nets as frozen copies of the online nets.
        for online, target in ((encoder, momentum_encoder),
                               (projection, momentum_projection)):
            for p, mp in zip(online.parameters(), target.parameters()):
                mp.data.copy_(p.data)
                mp.requires_grad_(False)

    def compute_loss(self, p_a, z_b, labels=None):
        """Negative cosine similarity (rescaled to [0, 4]) — no negatives."""
        cosine = (p_a * z_b).sum(dim=-1).mean()
        return 2 - 2 * cosine

    def _online(self, x):
        # Online path: encode → project → predict → renormalise.
        projected = self.encode_and_project(x)
        return F.normalize(self.predictor(projected), dim=-1)

    def _target(self, x):
        # Target path: EMA nets, no predictor, no gradient.
        with torch.no_grad():
            hidden = self.m_encoder(x)
            return F.normalize(self.m_projection(hidden), dim=-1)

    def train_step(self, views: ViewPair):
        """Symmetric step: each view predicts the target embedding of the other."""
        total = 0.0
        # Forward-pass order matches the original exactly:
        # online(a), target(b), online(b), target(a).
        for online_view, target_view in ((views.view_a, views.view_b),
                                         (views.view_b, views.view_a)):
            prediction = self._online(online_view)
            target = self._target(target_view)
            total = total + self.compute_loss(prediction, target.detach())
        loss = total / 2
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.post_step()
        return loss.item()

    def post_step(self):
        """EMA update of the target networks toward the online networks."""
        for online, target in ((self.encoder, self.m_encoder),
                               (self.projection, self.m_projection)):
            for p, mp in zip(online.parameters(), target.parameters()):
                mp.data.lerp_(p.data, 1 - self.m)
# ── 5. SupCon (supervised contrastive) ──────────────────────────
class SupCon(ContrastiveAlgorithm):
    """
    Supervised Contrastive Learning: labels define the positives.

    All samples with the SAME label are positives for each other,
    generalising SimCLR (where only the augmented twin is positive).

    Fix vs. the original: compute_loss advertised labels=None but
    crashed on it (AttributeError on None.repeat). With labels=None,
    each sample is now its own class, which reduces exactly to
    SimCLR-style instance discrimination.

    NOTE(review): sentinel labels of -1 (as AugmentationPairCreator
    emits for unlabeled data) are all EQUAL, so they make every sample
    a positive of every other — pass labels=None for unlabeled batches.
    """

    def __init__(self, *args, temperature=0.07, **kw):
        super().__init__(*args, **kw)
        self.tau = temperature

    def compute_loss(self, z_a, z_b, labels=None):
        """Mean NLL of positives under a softmax over all non-self pairs.

        Args:
            z_a, z_b: (B, d) L2-normalised embeddings of the two views.
            labels:   (B,) class labels, or None to treat each sample as
                      its own class (instance discrimination).
        Returns:
            Scalar loss tensor.
        """
        B = z_a.size(0)
        N = 2 * B
        if labels is None:
            # Instance discrimination: only the augmented twin is positive.
            labels = torch.arange(B, device=z_a.device)
        z = torch.cat([z_a, z_b], dim=0)              # (N, d)
        sim = (z @ z.T) / self.tau                    # (N, N)
        # Positive mask: 1 where labels match, self-pairs excluded.
        labels = labels.repeat(2)                     # (N,)
        pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        pos_mask.fill_diagonal_(0)
        # Subtract the row max for numerical stability; it cancels out
        # in the log-probability below, so the loss value is unchanged.
        sim_max, _ = sim.max(dim=1, keepdim=True)
        sim = sim - sim_max.detach()
        # Log-sum-exp over all non-self entries of each row.
        self_mask = torch.eye(N, device=sim.device)
        exp_sim = torch.exp(sim) * (1 - self_mask)
        log_sum_exp = torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)
        # Average log-probability over each row's positives; rows with no
        # positives contribute 0 thanks to the epsilon guard.
        log_prob = sim - log_sum_exp
        mean_log_prob = (pos_mask * log_prob).sum(dim=1) / (pos_mask.sum(dim=1) + 1e-8)
        return -mean_log_prob.mean()