# Unified Contrastive / Self-Supervised Learning Algorithm — Implementation

"""
Unified Contrastive / Self-Supervised Learning Algorithm
==========================================================
A single skeleton covering: SimCLR, CLIP, MoCo, BYOL, and
supervised contrastive learning (SupCon).

The core idea shared by ALL contrastive methods:

  1. Create two views of the same thing          (PLUGGABLE: how?)
  2. Encode both views into embeddings           (shared or separate encoders)
  3. Pull matching pairs together                (ALWAYS THE SAME)
  4. Push non-matching pairs apart               (PLUGGABLE: how, or even whether?)

The loss is always some form of:
  "the embedding of view_A should be more similar to its matching
   view_B than to any other view_B in the batch."

The pluggable components are:
  1. create_views()     — how pairs are generated (augmentation, modality, ...)
  2. compute_loss()     — the contrastive objective (InfoNCE, CLIP, cosine, ...)
  3. update_momentum()  — whether/how a momentum encoder is maintained
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import NamedTuple


# ─── Shared Data Types ────────────────────────────────────────────

class ViewPair(NamedTuple):
    """One batch of paired views, ready to be fed to the encoders.

    ``view_b[i]`` is the positive match for ``view_a[i]``; every other
    pairing in the batch is (implicitly) a negative.
    """
    view_a: torch.Tensor    # (B, ...) first view of each sample
    view_b: torch.Tensor    # (B, ...) matching second view (the positive)
    labels: torch.Tensor    # (B,) class labels; filled with -1 when unsupervised


# ═══════════════════════════════════════════════════════════════════
# CORE: SIMILARITY AND PROJECTION  (shared building blocks)
# ═══════════════════════════════════════════════════════════════════
#
# Nearly all contrastive methods project encoder output through
# a small "projection head" before computing similarity.
#
# Why? The encoder's representation needs to be general (for
# downstream tasks). The projection head absorbs the contrastive
# objective's biases — it learns to throw away info that doesn't
# help the contrastive task, protecting the encoder.
# At deployment, the projection head is discarded.

class ProjectionHead(nn.Module):
    """Two-layer MLP mapping encoder output into the contrastive space.

    Discarded at deployment time — it exists only to absorb the biases
    of the contrastive objective and protect the encoder representation.
    """

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        layers = [
            nn.Linear(d_in, d_hidden),
            nn.BatchNorm1d(d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        ]
        # Kept as `self.net` — MoCo reads net[-1].out_features to size its queue.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project (B, d_in) → (B, d_out)."""
        return self.net(x)


class PredictionHead(nn.Module):
    """Extra MLP for asymmetric methods (BYOL, SimSiam).

    Applied to exactly ONE branch of the siamese pair; this broken
    symmetry is part of what prevents representational collapse.
    """

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        layers = [
            nn.Linear(d_in, d_hidden),
            nn.BatchNorm1d(d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Predict (B, d_in) → (B, d_out)."""
        return self.net(x)


# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM  (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════

class ContrastiveAlgorithm(ABC):
    """
    The universal contrastive learning training step.

    Every variant inherits this and only overrides:
      - compute_loss(z_a, z_b, labels) → scalar loss   (required)
      - post_step()                                     (optional hook)
      - train_step / encoding, for multi-encoder variants
    """

    def __init__(self, encoder, projection, optimizer):
        self.encoder = encoder        # backbone producing representations
        self.projection = projection  # head mapping into contrastive space
        self.optimizer = optimizer    # caller-built optimizer over both modules

    def encode_and_project(self, x):
        """Encoder → projection head → L2-normalise onto the unit sphere."""
        representation = self.encoder(x)             # (B, d_enc)
        projected = self.projection(representation)  # (B, d_proj)
        return F.normalize(projected, dim=-1)

    # ── The pluggable pieces ──────────────────────────────────────

    @abstractmethod
    def compute_loss(self, z_a, z_b, labels=None):
        """The contrastive objective. Returns a scalar loss tensor."""
        ...

    def post_step(self):
        """Hook for momentum updates, queue maintenance, etc. No-op by default."""
        pass

    # ── Core training step (IDENTICAL for every variant) ──────────

    def train_step(self, views: ViewPair):
        """Run one optimisation step on a ViewPair; return the loss as a float."""
        # 1. Encode both views with the (shared) encoder.
        z_a = self.encode_and_project(views.view_a)  # (B, d_proj)
        z_b = self.encode_and_project(views.view_b)  # (B, d_proj)

        # 2. Variant-specific contrastive objective.
        loss = self.compute_loss(z_a, z_b, views.labels)

        # 3. Standard gradient step.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # 4. Non-gradient bookkeeping (momentum encoders, queues, ...).
        self.post_step()

        return loss.item()


# ═══════════════════════════════════════════════════════════════════
# VIEW CREATION  (how pairs are constructed — the other axis)
# ═══════════════════════════════════════════════════════════════════
#
# This is where contrastive methods get their signal. No labels
# needed — the pairing IS the supervision.
#
#   • Same-image augmentations  (SimCLR, MoCo, BYOL):
#       Two random crops/flips/colour jitters of the same image.
#       "Two views of the same scene should map to the same point."
#
#   • Cross-modal pairs  (CLIP):
#       An image and its text caption.
#       "An image and its description should map to the same point."
#
#   • Supervised pairs  (SupCon):
#       Two augmentations of the same image, with class labels.
#       All samples of the same class are positives.

class AugmentationPairCreator:
    """SimCLR / MoCo / BYOL style pairing: two random augmentations per image.

    The transform is assumed stochastic, so the two stacked views differ.
    """

    def __init__(self, transform):
        self.transform = transform    # stochastic augmentation pipeline

    def __call__(self, batch, labels=None):
        # -1 is the "unsupervised" placeholder expected by ViewPair.labels.
        if labels is None:
            labels = torch.full((len(batch),), -1)
        first_view = torch.stack([self.transform(sample) for sample in batch])
        second_view = torch.stack([self.transform(sample) for sample in batch])
        return ViewPair(first_view, second_view, labels)


class CrossModalPairCreator:
    """CLIP style pairing: image i and text i are already matched in the data."""

    def __call__(self, images, texts):
        # Labels are just row indices — pair i matches pair i.
        indices = torch.arange(len(images))
        return ViewPair(images, texts, indices)


# ═══════════════════════════════════════════════════════════════════
# TRAINING LOOP  (standard supervised-style iteration)
# ═══════════════════════════════════════════════════════════════════

def train(algo: ContrastiveAlgorithm, dataloader, view_creator,
          n_epochs, device="cpu"):
    """Standard epoch loop: create views, step the algorithm, report mean loss.

    `view_creator` turns a raw (batch, labels) pair into a ViewPair; the
    per-epoch loss printed is a sample-weighted average over the epoch.
    """
    for epoch in range(1, n_epochs + 1):
        running_loss = 0.0
        samples_seen = 0
        for batch, labels in dataloader:
            pair = view_creator(batch.to(device), labels.to(device))
            step_loss = algo.train_step(pair)
            batch_size = batch.size(0)
            running_loss += step_loss * batch_size
            samples_seen += batch_size

        print(f"Epoch {epoch:3d}/{n_epochs} │ loss {running_loss/samples_seen:.4f}")


# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS  (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════

# ── 1. SimCLR  (in-batch negatives, InfoNCE loss) ────────────────

class SimCLR(ContrastiveAlgorithm):
    """
    Contrastive learning with in-batch negatives.

    For each sample i, its positive is the other augmentation of the
    same image. Every OTHER sample in the batch is a negative.
    Loss: InfoNCE = −log[ exp(sim(z_i, z_i+)) / Σ_j exp(sim(z_i, z_j)) ]

    Requires large batch sizes (4096+) to get enough negatives.
    """

    def __init__(self, *args, temperature=0.07, **kw):
        super().__init__(*args, **kw)
        self.tau = temperature  # softmax temperature for similarities

    def compute_loss(self, z_a, z_b, labels=None):
        """NT-Xent / InfoNCE over the concatenated 2B embeddings."""
        batch = z_a.size(0)
        n = 2 * batch

        # Every pairwise similarity among the 2B embeddings.
        embeddings = torch.cat([z_a, z_b], dim=0)               # (2B, d)
        logits = (embeddings @ embeddings.T) / self.tau         # (2B, 2B)

        # Self-similarity must never win the softmax.
        logits.fill_diagonal_(float("-inf"))

        # Row i's positive sits B positions away (mod 2B):
        # (i, i+B) for the first half, (i, i−B) for the second.
        targets = (torch.arange(n, device=logits.device) + batch) % n

        # Cross-entropy along each row IS the InfoNCE loss.
        return F.cross_entropy(logits, targets)


# ── 2. CLIP  (cross-modal contrastive, symmetric loss) ───────────

class CLIP(ContrastiveAlgorithm):
    """
    Cross-modal contrastive: image encoder + text encoder.
    Each (image, text) pair is a positive. All cross-pairs are negatives.

    Loss is symmetric: image→text CE + text→image CE, averaged.
    The temperature is a LEARNED parameter (log-scaled for stability).

    This is how you align two different modalities in a shared space.
    """

    def __init__(self, image_encoder, text_encoder,
                 image_proj, text_proj, optimizer):
        # Two independent towers — the single-encoder base __init__ does
        # not apply, so attributes are wired directly.
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.image_proj = image_proj
        self.text_proj = text_proj
        self.optimizer = optimizer
        # Learned temperature, parameterised in log space; exp() keeps it
        # positive. NOTE(review): CLIP is not an nn.Module, so this
        # Parameter trains only if the caller adds it to the optimizer —
        # verify at the call site.
        self.log_tau = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))

    def train_step(self, views: ViewPair):
        """One optimisation step over a batch of (image, text) pairs."""
        img_features = self.image_encoder(views.view_a)
        z_img = F.normalize(self.image_proj(img_features), dim=-1)
        txt_features = self.text_encoder(views.view_b)
        z_txt = F.normalize(self.text_proj(txt_features), dim=-1)

        loss = self.compute_loss(z_img, z_txt)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def compute_loss(self, z_img, z_txt, labels=None):
        """Symmetric InfoNCE over the B×B image/text similarity matrix."""
        scale = self.log_tau.exp()
        logits = (z_img @ z_txt.T) * scale                       # (B, B)

        # Correct match sits on the diagonal: image_i ↔ text_i.
        diagonal = torch.arange(z_img.size(0), device=logits.device)

        image_to_text = F.cross_entropy(logits, diagonal)
        text_to_image = F.cross_entropy(logits.T, diagonal)
        return 0.5 * (image_to_text + text_to_image)


# ── 3. MoCo  (momentum encoder + queue of negatives) ─────────────

class MoCo(ContrastiveAlgorithm):
    """
    Momentum Contrast: maintains a large queue of negative embeddings
    from previous batches, encoded by a slowly-updating momentum encoder.

    Solves SimCLR's problem: you no longer need giant batch sizes for
    enough negatives. A queue of 65536 negatives is typical.

    The momentum encoder is NOT trained by gradient — it's an EMA of
    the online encoder. This provides consistent targets across the
    queue (all negatives were encoded by similar-ish networks).
    """

    def __init__(self, encoder, projection, optimizer,
                 momentum_encoder, momentum_projection,
                 queue_size=65536, momentum=0.999, temperature=0.07):
        """
        Args:
            encoder, projection: online (gradient-trained) modules.
            momentum_encoder, momentum_projection: EMA copies; frozen here.
            queue_size: number of negative keys kept (K).
            momentum: EMA coefficient m in θ_m ← m·θ_m + (1−m)·θ.
            temperature: InfoNCE temperature.
        """
        super().__init__(encoder, projection, optimizer)
        self.m_encoder = momentum_encoder
        self.m_projection = momentum_projection
        self.m = momentum
        self.tau = temperature
        self.queue_size = queue_size

        # Initialise queue with random unit vectors (CPU-resident ring buffer).
        # Assumes the projection head exposes a Sequential `net` whose last
        # layer is the output Linear — true for ProjectionHead above.
        d = projection.net[-1].out_features
        self.queue = F.normalize(torch.randn(queue_size, d), dim=-1)
        self.queue_ptr = 0  # next write position in the ring buffer

        # Momentum branch starts as an exact copy and never gets gradients.
        for p, mp in zip(encoder.parameters(), momentum_encoder.parameters()):
            mp.data.copy_(p.data)
            mp.requires_grad_(False)
        for p, mp in zip(projection.parameters(), momentum_projection.parameters()):
            mp.data.copy_(p.data)
            mp.requires_grad_(False)

    def compute_loss(self, z_a, z_b, labels=None):
        """InfoNCE with the positive key at index 0 and queued negatives.

        z_a: online-encoder queries; z_b: momentum-encoder positive keys
        (computed in train_step); the queue supplies the negative keys.
        """
        queue = self.queue.to(z_a.device)

        # Positive logits: (B, 1)
        pos = (z_a * z_b).sum(dim=-1, keepdim=True) / self.tau

        # Negative logits: (B, K)
        neg = (z_a @ queue.T) / self.tau

        # The positive always occupies column 0, so the CE target is 0.
        logits = torch.cat([pos, neg], dim=1)                    # (B, 1 + K)
        targets = torch.zeros(z_a.size(0), dtype=torch.long, device=z_a.device)
        return F.cross_entropy(logits, targets)

    def train_step(self, views: ViewPair):
        """One MoCo step: query with online encoder, key with momentum encoder."""
        # Online encoder: queries (gradients flow here).
        z_a = self.encode_and_project(views.view_a)

        # Momentum encoder: keys (no gradients — targets must be stable).
        with torch.no_grad():
            h_b = self.m_encoder(views.view_b)
            z_b = F.normalize(self.m_projection(h_b), dim=-1)

        loss = self.compute_loss(z_a, z_b)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Enqueue current keys, dequeue oldest; then EMA-update the momentum branch.
        self._enqueue(z_b.detach())
        self.post_step()

        return loss.item()

    def _enqueue(self, keys):
        """Insert `keys` into the ring buffer, wrapping around its end.

        BUGFIX: the original did `self.queue[ptr:ptr+B] = keys`, which
        raises a shape-mismatch RuntimeError whenever ptr + B exceeds
        queue_size (i.e. unless the batch size divides the queue size).
        """
        keys = keys.cpu()
        B = keys.size(0)

        if B >= self.queue_size:
            # Batch is at least as large as the queue: keep only the newest keys.
            self.queue = keys[-self.queue_size:].clone()
            self.queue_ptr = 0
            return

        ptr = self.queue_ptr
        end = ptr + B
        if end <= self.queue_size:
            self.queue[ptr:end] = keys
        else:
            # Wrap-around: fill to the end, then continue from the front.
            tail = self.queue_size - ptr
            self.queue[ptr:] = keys[:tail]
            self.queue[:end - self.queue_size] = keys[tail:]
        self.queue_ptr = end % self.queue_size

    def post_step(self):
        """Momentum update: θ_m ← m·θ_m + (1−m)·θ."""
        for p, mp in zip(self.encoder.parameters(), self.m_encoder.parameters()):
            mp.data.lerp_(p.data, 1 - self.m)
        for p, mp in zip(self.projection.parameters(), self.m_projection.parameters()):
            mp.data.lerp_(p.data, 1 - self.m)


# ── 4. BYOL  (no negatives at all) ──────────────────────────────

class BYOL(ContrastiveAlgorithm):
    """
    Bootstrap Your Own Latent: learns WITHOUT negative pairs.

    This seems like it shouldn't work (what prevents collapse to a
    constant?). The asymmetry is the key:
      • Online branch:  encoder → projector → PREDICTOR → output
      • Target branch:  momentum_encoder → projector → output  (no predictor)

    The predictor + momentum update + stop-gradient create enough
    asymmetry that collapse is avoided. The online branch tries to
    PREDICT the target branch's output, not just match it.
    """

    def __init__(self, encoder, projection, predictor, optimizer,
                 momentum_encoder, momentum_projection, momentum=0.996):
        super().__init__(encoder, projection, optimizer)
        self.predictor = predictor
        self.m_encoder = momentum_encoder
        self.m_projection = momentum_projection
        self.m = momentum

        # Target network starts as an exact copy of the online network and
        # is frozen — it only ever moves via the EMA in post_step().
        for online_p, target_p in zip(encoder.parameters(),
                                      momentum_encoder.parameters()):
            target_p.data.copy_(online_p.data)
            target_p.requires_grad_(False)
        for online_p, target_p in zip(projection.parameters(),
                                      momentum_projection.parameters()):
            target_p.data.copy_(online_p.data)
            target_p.requires_grad_(False)

    def compute_loss(self, p_a, z_b, labels=None):
        """Negative cosine similarity (both inputs unit-norm) — no negatives."""
        return 2 - 2 * (p_a * z_b).sum(dim=-1).mean()

    def _predict_online(self, x):
        # Online branch: encoder → projector → predictor → normalise.
        z = self.encode_and_project(x)
        return F.normalize(self.predictor(z), dim=-1)

    def _project_target(self, x):
        # Target branch: momentum encoder → projector. No predictor, no grad.
        with torch.no_grad():
            h = self.m_encoder(x)
            return F.normalize(self.m_projection(h), dim=-1)

    def train_step(self, views: ViewPair):
        """Symmetric BYOL step: each view predicts the other's target embedding."""
        pred_from_a = self._predict_online(views.view_a)
        target_of_b = self._project_target(views.view_b)
        pred_from_b = self._predict_online(views.view_b)
        target_of_a = self._project_target(views.view_a)

        loss = (self.compute_loss(pred_from_a, target_of_b.detach()) +
                self.compute_loss(pred_from_b, target_of_a.detach())) / 2

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.post_step()

        return loss.item()

    def post_step(self):
        """EMA update of the target network: θ_t ← m·θ_t + (1−m)·θ_online."""
        for online_p, target_p in zip(self.encoder.parameters(),
                                      self.m_encoder.parameters()):
            target_p.data.lerp_(online_p.data, 1 - self.m)
        for online_p, target_p in zip(self.projection.parameters(),
                                      self.m_projection.parameters()):
            target_p.data.lerp_(online_p.data, 1 - self.m)


# ── 5. SupCon  (supervised contrastive) ──────────────────────────

class SupCon(ContrastiveAlgorithm):
    """
    Supervised Contrastive Learning: uses labels to define positives.
    All samples with the SAME label are positives for each other.

    Generalises SimCLR: instead of just the augmented pair being
    positive, every same-class sample in the batch is positive.
    Outperforms standard cross-entropy on many benchmarks.
    """

    def __init__(self, *args, temperature=0.07, **kw):
        super().__init__(*args, **kw)
        self.tau = temperature  # softmax temperature for similarities

    def compute_loss(self, z_a, z_b, labels=None):
        """SupCon (L_out) loss over the concatenated 2B embeddings.

        Args:
            z_a, z_b: (B, d) L2-normalised embeddings of the two views.
            labels:   (B,) integer class labels — REQUIRED for this variant.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if labels is None (use SimCLR for the unsupervised case).
        """
        # Fail loudly instead of the opaque AttributeError that
        # labels.repeat(2) would raise on None.
        if labels is None:
            raise ValueError(
                "SupCon requires class labels; use SimCLR for the unsupervised case.")

        z = torch.cat([z_a, z_b], dim=0)                        # (2B, d)
        B = z_a.size(0)
        N = 2 * B

        sim = (z @ z.T) / self.tau                               # (N, N)

        # Positive mask: 1 where labels match, 0 elsewhere. The two views are
        # concatenated [z_a; z_b], so the labels simply repeat.
        # NOTE(review): a batch where every label is the -1 placeholder (as
        # produced by AugmentationPairCreator without labels) would make
        # everything a positive — callers must supply real labels.
        labels = labels.repeat(2)                                # (2B,)
        pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        pos_mask.fill_diagonal_(0)                               # exclude self

        # Subtract the per-row max for numerical stability (cancels in softmax).
        sim_max, _ = sim.max(dim=1, keepdim=True)
        sim = sim - sim_max.detach()

        # Log-sum-exp over all non-self entries (the softmax denominator).
        self_mask = torch.eye(N, device=sim.device)
        exp_sim = torch.exp(sim) * (1 - self_mask)
        log_sum_exp = torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)

        # Mean log-probability over each anchor's positives; the epsilon
        # guards anchors whose class appears only once in the batch
        # (they contribute ~0 rather than dividing by zero).
        log_prob = sim - log_sum_exp
        mean_log_prob = (pos_mask * log_prob).sum(dim=1) / (pos_mask.sum(dim=1) + 1e-8)
        return -mean_log_prob.mean()