# Unified Variational Autoencoder (VAE) Algorithm — Implementation

"""
Unified Variational Autoencoder (VAE) Algorithm
=================================================
A single skeleton covering: vanilla VAE, β-VAE, VQ-VAE,
conditional VAE (CVAE), and the VAE as used in latent diffusion
(Stable Diffusion's image compressor).

The core idea shared by ALL VAE variants:

  1. ENCODE:    x → q(z|x)              (compress input to a distribution)
  2. SAMPLE:    z ~ q(z|x)              (draw a latent code)
  3. DECODE:    z → p(x|z)              (reconstruct from the code)
  4. LOSS:      reconstruction + regularisation

  loss = E_q[ −log p(x|z) ]  +  KL[ q(z|x) ‖ p(z) ]
         ├── reconstruction ──┘    └── regularisation ──┘
         "how good is the           "how close is the
          reconstruction?"           learned posterior to
                                     the prior?"

The pluggable components are:
  1. encode()              — how the posterior q(z|x) is parameterised
  2. sample_latent()       — how z is drawn (reparameterisation, codebook, ...)
  3. decode()              — how p(x|z) is parameterised
  4. reconstruction_loss() — MSE, BCE, perceptual, ...
  5. regularisation_loss() — KL divergence, commitment loss, ...
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from abc import ABC, abstractmethod


# ═══════════════════════════════════════════════════════════════════
# CORE: THE REPARAMETERISATION TRICK
# ═══════════════════════════════════════════════════════════════════
#
# The fundamental problem: we need to SAMPLE z ~ q(z|x), but
# sampling is not differentiable — you can't backprop through
# a random draw.
#
# The trick: instead of sampling z ~ N(μ, σ²), compute
#   z = μ + σ · ε,    where ε ~ N(0, I)
#
# Now the randomness (ε) is external to the computation graph,
# and gradients flow through μ and σ normally.
#
# This single trick is what makes VAEs trainable. Without it,
# you'd need REINFORCE-style gradient estimators (high variance).

def reparameterise(mu, log_var):
    """Draw z = μ + σ·ε with ε ~ N(0, I); differentiable in μ and log σ²."""
    sigma = torch.exp(0.5 * log_var)      # σ = exp(½ · log σ²)
    noise = torch.randn_like(sigma)       # ε is outside the graph
    return mu + sigma * noise


# ═══════════════════════════════════════════════════════════════════
# CORE: KL DIVERGENCE  (analytic for two Gaussians)
# ═══════════════════════════════════════════════════════════════════
#
# When both q(z|x) = N(μ, σ²) and p(z) = N(0, I) are Gaussian,
# the KL has a closed form — no sampling needed:
#
#   KL = -½ Σ (1 + log σ² − μ² − σ²)
#
# This is per-sample; we average over the batch.

def kl_divergence_gaussian(mu, log_var):
    """KL[ N(μ, σ²) ‖ N(0, I) ], summed over the last dim, averaged over batch."""
    var = log_var.exp()
    # Same closed form, sign-flipped inside: ½ Σ (μ² + σ² − log σ² − 1)
    per_sample = 0.5 * (mu.pow(2) + var - log_var - 1).sum(dim=-1)
    return per_sample.mean()


# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM  (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════

class VAEAlgorithm(ABC):
    """
    The universal VAE training step.

    Every variant inherits this and only overrides:
      - encode(x) → latent params
      - sample_latent(latent_params) → z
      - decode(z, condition=None) → reconstruction
      - reconstruction_loss(x_recon, x) → scalar
      - regularisation_loss(latent_params) → scalar

    Note: the abstract `decode` accepts an optional `condition` so its
    signature matches how `train_step` actually calls it (and how every
    concrete subclass in this file implements it).
    """

    def __init__(self, encoder, decoder, optimizer):
        self.encoder = encoder      # maps x → latent parameters
        self.decoder = decoder      # maps z → reconstruction
        self.optimizer = optimizer  # must cover encoder + decoder params

    # ── The pluggable pieces ──────────────────────────────────────

    @abstractmethod
    def encode(self, x):
        """Return latent parameters (e.g. a (mu, log_var) tuple)."""
        ...

    @abstractmethod
    def sample_latent(self, latent_params):
        """Sample z from the encoded distribution. Must be differentiable."""
        ...

    @abstractmethod
    def decode(self, z, condition=None):
        """Reconstruct x from z (optionally conditioned on `condition`)."""
        ...

    def reconstruction_loss(self, x_recon, x):
        """Default: MSE. Override for BCE, perceptual loss, etc."""
        return F.mse_loss(x_recon, x, reduction="mean")

    @abstractmethod
    def regularisation_loss(self, latent_params):
        """Return the regularisation term (KL, commitment, ...)."""
        ...

    # ── Core training step (IDENTICAL for every variant) ──────────

    def train_step(self, x, condition=None):
        """
        THE core VAE training loop:
          1. Encode → distribution over z
          2. Sample z (reparameterised — differentiable)
          3. Decode → reconstruction
          4. Loss = reconstruction + regularisation
          5. One optimizer step

        Returns a dict of float metrics: "loss", "recon", "reg".
        """
        # 1. Encode
        latent_params = self.encode(x)

        # 2. Sample (must be differentiable)
        z = self.sample_latent(latent_params)

        # 3. Decode (two-arg call kept conditional for subclasses that
        #    only implemented decode(z))
        x_recon = self.decode(z) if condition is None else self.decode(z, condition)

        # 4. Loss  (both terms are PLUGGABLE)
        loss_recon = self.reconstruction_loss(x_recon, x)
        loss_reg = self.regularisation_loss(latent_params)
        loss = loss_recon + loss_reg

        # 5. Gradient step  (always the same)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item(),
                "recon": loss_recon.item(),
                "reg": loss_reg.item()}

    # ── Generation (decode from prior) ────────────────────────────

    @torch.no_grad()
    def generate(self, n_samples, device="cpu"):
        """Sample z ~ p(z), then decode (no gradients tracked)."""
        z = self.sample_prior(n_samples, device)
        return self.decode(z)

    @abstractmethod
    def sample_prior(self, n_samples, device):
        """Sample from the prior p(z)."""
        ...


# ═══════════════════════════════════════════════════════════════════
# TRAINING LOOP
# ═══════════════════════════════════════════════════════════════════

def train(algo: "VAEAlgorithm", dataloader, n_epochs, device="cpu"):
    """
    Run `n_epochs` of `algo.train_step` over `dataloader`, printing
    sample-weighted average metrics per epoch.

    Batches may be either bare tensors or `(x, label, ...)` sequences —
    only `x` is used here (conditions are the CVAE subclass's concern).
    The previous `x, *rest = batch` unpacking silently took the FIRST ROW
    of a bare-tensor batch; this version handles both shapes correctly.
    """
    for epoch in range(n_epochs):
        totals = {"loss": 0.0, "recon": 0.0, "reg": 0.0}
        n = 0
        for batch in dataloader:
            # Accept `x` or `(x, ...)` style batches.
            x = batch[0] if isinstance(batch, (tuple, list)) else batch
            x = x.to(device)
            metrics = algo.train_step(x)
            batch_size = x.size(0)
            for k in totals:
                totals[k] += metrics[k] * batch_size  # sample-weighted sum
            n += batch_size

        if n == 0:
            # Empty dataloader: nothing to average, avoid ZeroDivisionError.
            print(f"Epoch {epoch+1:3d}/{n_epochs} │ no batches")
            continue

        avg = {k: v / n for k, v in totals.items()}
        print(f"Epoch {epoch+1:3d}/{n_epochs} │ "
              f"loss {avg['loss']:.4f}  "
              f"recon {avg['recon']:.4f}  "
              f"reg {avg['reg']:.4f}")


# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS  (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════

# ── 1. Vanilla VAE  (Gaussian posterior, Gaussian prior) ─────────

class VanillaVAE(VAEAlgorithm):
    """
    The original VAE (Kingma & Welling 2013).

    Encoder outputs (μ, log σ²), sample via reparameterisation,
    decode, and balance reconstruction against the KL divergence.

    The ELBO (Evidence Lower Bound) is:
      log p(x) ≥ E_q[log p(x|z)] − KL[q(z|x) ‖ p(z)]
    so maximising the ELBO is exactly minimising our loss.
    """

    def __init__(self, encoder, decoder, optimizer, d_latent):
        super().__init__(encoder, decoder, optimizer)
        self.d_latent = d_latent  # dimensionality of z

    def encode(self, x):
        # The encoder's final layer emits 2·d_latent units: [μ ‖ log σ²].
        params = self.encoder(x)                                # (B, 2·d_latent)
        return tuple(params.chunk(2, dim=-1))

    def sample_latent(self, latent_params):
        return reparameterise(*latent_params)

    def decode(self, z, condition=None):
        # Unconditional decoder; `condition` is accepted but unused here.
        return self.decoder(z)

    def regularisation_loss(self, latent_params):
        return kl_divergence_gaussian(*latent_params)

    def sample_prior(self, n_samples, device):
        # Standard-normal prior p(z) = N(0, I).
        return torch.randn(n_samples, self.d_latent, device=device)


# ── 2. β-VAE  (disentangled representations) ────────────────────

class BetaVAE(VanillaVAE):
    """
    Vanilla VAE with the KL term scaled by a factor β.

    β > 1 presses the posterior toward a simple factorised form: each
    latent dimension is pushed to be independent, encouraging
    "disentangled" representations where single dimensions map to
    meaningful factors of variation (size, rotation, colour, ...).

    β < 1 loosens the regularisation — reconstructions improve but the
    latent space is less structured.

    β trades reconstruction quality against latent structure; β = 1
    recovers the original, statistically principled VAE.
    """

    def __init__(self, *args, beta=4.0, **kw):
        super().__init__(*args, **kw)
        self.beta = beta  # KL scaling factor

    def regularisation_loss(self, latent_params):
        # Same analytic Gaussian KL as the parent, scaled by β.
        return self.beta * super().regularisation_loss(latent_params)


# ── 3. VQ-VAE  (discrete latent codes via vector quantisation) ───

class VectorQuantiser(nn.Module):
    """
    Replaces continuous z with the nearest vector from a learned codebook.

    Given encoder output z_e, find the closest codebook entry:
      z_q = codebook[argmin ‖z_e − e_k‖]

    Not differentiable (argmin), so we use the straight-through estimator:
    forward pass uses z_q, backward pass pretends z_q = z_e.

    Loss terms (van den Oord et al. 2017, Eq. 3):
      codebook loss   ‖sg[z_e] − e‖²  — gradients reach only the codebook,
                                        pulling codes toward encoder outputs
      commitment loss ‖z_e − sg[e]‖²  — gradients reach only the encoder,
                                        committing it to its chosen code;
                                        down-weighted by β = 0.25
    """

    def __init__(self, n_codes, d_code):
        super().__init__()
        self.n_codes = n_codes
        self.d_code = d_code
        self.codebook = nn.Embedding(n_codes, d_code)
        # Small uniform init keeps initial code usage roughly balanced.
        self.codebook.weight.data.uniform_(-1 / n_codes, 1 / n_codes)

    def forward(self, z_e):
        """
        Args:
            z_e: (B, ..., d_code) — encoder output (any spatial shape)
        Returns:
            z_q: (B, ..., d_code) — quantised (straight-through)
            indices: (N,) — flat codebook indices, N = numel(z_e) / d_code
            vq_loss: scalar — codebook loss + 0.25 · commitment loss
        """
        flat = z_e.reshape(-1, self.d_code)                      # (N, d)

        # Squared distances via ‖a − b‖² = ‖a‖² − 2·a·b + ‖b‖²
        dist = (flat.pow(2).sum(1, keepdim=True)
                - 2 * flat @ self.codebook.weight.T
                + self.codebook.weight.pow(2).sum(1, keepdim=True).T)

        indices = dist.argmin(dim=-1)                            # (N,)
        z_q = self.codebook(indices).view_as(z_e)                # (B, ..., d)

        # detach() plays the role of stop-gradient (sg): each term updates
        # exactly one side. NOTE: the commitment term (encoder side) carries
        # the 0.25 weight, the codebook term carries weight 1.
        codebook_loss = F.mse_loss(z_q, z_e.detach())            # move codes → encoder outputs
        commit_loss = F.mse_loss(z_q.detach(), z_e)              # commit encoder → chosen codes

        # Straight-through: forward uses z_q, backward uses z_e
        z_q_st = z_e + (z_q - z_e).detach()

        return z_q_st, indices, codebook_loss + 0.25 * commit_loss


class VQVAE(VAEAlgorithm):
    """
    VQ-VAE: replaces the Gaussian posterior with discrete codes.

    Instead of z ~ N(μ, σ²), the encoder output is snapped to the
    nearest entry in a learned codebook. No KL divergence — the
    regularisation comes from the information bottleneck of
    discrete codes.

    This is the architecture behind:
      • DALL-E 1 (image tokens for autoregressive generation)
      • Stable Diffusion's VAE (continuous variant with KL, but same idea)
      • AudioLM, SoundStream (audio tokenisation)
    """

    def __init__(self, encoder, decoder, optimizer, quantiser: VectorQuantiser):
        super().__init__(encoder, decoder, optimizer)
        self.quantiser = quantiser

    def encode(self, x):
        # Wrapped in a 1-tuple so the interface matches the Gaussian variants.
        return (self.encoder(x),)

    def sample_latent(self, latent_params):
        (z_e,) = latent_params
        z_q, indices, vq_loss = self.quantiser(z_e)
        # Stash VQ outputs: regularisation_loss is called later in train_step.
        self._vq_loss = vq_loss
        self._indices = indices
        return z_q

    def decode(self, z, condition=None):
        return self.decoder(z)

    def regularisation_loss(self, latent_params):
        # No KL — the VQ codebook/commitment losses regularise instead.
        return self._vq_loss

    def sample_prior(self, n_samples, device):
        # Uniformly random codebook indices, looked up as embeddings.
        rand_idx = torch.randint(0, self.quantiser.n_codes,
                                 (n_samples,), device=device)
        return self.quantiser.codebook(rand_idx)


# ── 4. Conditional VAE  (generation conditioned on a label/text) ─

class CVAE(VanillaVAE):
    """
    Conditional VAE: both encoder and decoder receive a condition c
    (e.g. class label, text embedding).

    Encoder:  q(z | x, c)   — "what latent code explains x given c?"
    Decoder:  p(x | z, c)   — "generate x from z and c"
    Prior:    p(z) = N(0, I) — same as vanilla (or can be conditioned too)

    At generation time: choose c, sample z ~ p(z), call decode(z, c).
    (The inherited decode ignored the condition entirely, which broke
    conditional generation; decode is now overridden to concatenate c.)
    """

    def __init__(self, encoder, decoder, optimizer, d_latent,
                 cond_encoder=None):
        super().__init__(encoder, decoder, optimizer, d_latent)
        self.cond_encoder = cond_encoder    # optional: embed raw labels → vectors

    def decode(self, z, condition=None):
        """p(x|z,c): decode z with the condition concatenated (if given)."""
        if condition is not None:
            z = torch.cat([z, condition], dim=-1)
        return self.decoder(z)

    def train_step(self, x, condition=None):
        if self.cond_encoder is not None and condition is not None:
            condition = self.cond_encoder(condition)

        # Encode with the condition concatenated onto x
        latent_params = self.encode(torch.cat([x, condition], dim=-1)
                                    if condition is not None else x)
        z = self.sample_latent(latent_params)

        # Decode through the overridable hook (was: self.decoder(z_cond),
        # which bypassed decode() and any subclass overrides)
        x_recon = self.decode(z, condition)

        loss_recon = self.reconstruction_loss(x_recon, x)
        loss_reg = self.regularisation_loss(latent_params)
        loss = loss_recon + loss_reg

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item(),
                "recon": loss_recon.item(),
                "reg": loss_reg.item()}


# ── 5. KL-regularised autoencoder  (Stable Diffusion's VAE) ─────

class KLAE(VanillaVAE):
    """
    The image compressor used in Latent Diffusion / Stable Diffusion.

    Structurally a VAE, but with two important differences:
      • Very small KL weight (≈1e-6) — almost a plain autoencoder,
        just enough regularisation to keep the latent space smooth
      • Perceptual + adversarial reconstruction loss instead of MSE,
        producing much sharper reconstructions

    The purpose is not generation (diffusion handles that) but
    COMPRESSION: map 512×512×3 images → 64×64×4 latents, making
    the diffusion model 64× cheaper to train and run.
    """

    def __init__(self, encoder, decoder, optimizer, d_latent,
                 kl_weight=1e-6, perceptual_loss_fn=None, disc=None,
                 disc_optimizer=None):
        super().__init__(encoder, decoder, optimizer, d_latent)
        self.kl_weight = kl_weight
        self.perceptual_fn = perceptual_loss_fn    # e.g. LPIPS
        self.disc = disc                           # patch discriminator
        self.disc_optimizer = disc_optimizer

    def regularisation_loss(self, latent_params):
        # Tiny KL weight: just enough smoothing, otherwise a plain AE.
        return self.kl_weight * super().regularisation_loss(latent_params)

    def reconstruction_loss(self, x_recon, x):
        total = F.mse_loss(x_recon, x)

        # Optional perceptual term (e.g. LPIPS) for sharper reconstructions.
        if self.perceptual_fn is not None:
            total = total + self.perceptual_fn(x_recon, x)

        # Generator-side adversarial term: raise the disc score on recons.
        if self.disc is not None:
            total = total - self.disc(x_recon).mean()

        return total

    def train_step(self, x, condition=None):
        # Autoencoder update first (the universal step).
        metrics = super().train_step(x)

        # Then, optionally, a hinge-loss discriminator update.
        if self.disc is not None and self.disc_optimizer is not None:
            with torch.no_grad():
                # Fresh reconstruction, detached from the AE graph.
                z = self.sample_latent(self.encode(x))
                x_recon = self.decode(z)

            disc_loss = (F.relu(1 - self.disc(x)).mean()
                         + F.relu(1 + self.disc(x_recon)).mean())

            self.disc_optimizer.zero_grad()
            disc_loss.backward()
            self.disc_optimizer.step()
            metrics["disc_loss"] = disc_loss.item()

        return metrics