# Unified Diffusion Models: Core Algorithm — Implementation

"""
Unified Diffusion Models: Core Algorithm
==========================================
A single skeleton covering DDPM, DDIM, v-prediction (velocity prediction),
classifier-free guidance, and noise schedule variants.

The core idea shared by ALL diffusion models:

  TRAINING:   x_t = √ᾱ_t · x_0 + √(1−ᾱ_t) · ε      (add noise)
              loss = ‖ target − model(x_t, t) ‖²       (predict something)

  SAMPLING:   start from pure noise x_T ~ N(0, I)
              for t = T, T−1, ..., 1:
                  x_{t−1} = denoise_step(x_t, t)       (iteratively denoise)

That's it. Training is embarrassingly simple — one line of noise addition,
one forward pass, MSE loss. All the complexity is in the sampling loop
and the schedule.

The pluggable components are:
  1. noise schedule        — how fast noise is added (linear, cosine, ...)
  2. compute_target() /
     recover_x0()          — what the model outputs (noise ε, signal x_0, velocity v)
  3. denoise_step()        — how to go from x_t → x_{t−1} (DDPM, DDIM, ...)
  4. guidance_fn           — how to steer generation (classifier-free, etc.)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from abc import ABC, abstractmethod


# ═══════════════════════════════════════════════════════════════════
# NOISE SCHEDULES  (how much noise at each timestep)
# ═══════════════════════════════════════════════════════════════════
#
# The schedule defines β_t (noise added at step t) and from it:
#   α_t = 1 − β_t
#   ᾱ_t = ∏ α_i  (cumulative — total signal remaining at step t)
#
# These two numbers are all you need:
#   √ᾱ_t  = how much of the original image survives
#   √(1−ᾱ_t) = how much noise has been added
#
# A good schedule adds noise slowly at first (preserve structure)
# and quickly later (destroy details).

class NoiseSchedule:
    """
    Container for every per-timestep quantity derived from a β schedule.

    Args:
        betas: 1-D tensor of per-step noise variances β_t (length T).

    Derived tensors (all length T):
      alphas                    α_t = 1 − β_t
      alpha_bar                 ᾱ_t = ∏_{i≤t} α_i
      alpha_bar_prev            ᾱ_{t−1}, with the value before step 0 set to 1
      sqrt_alpha_bar            √ᾱ_t
      sqrt_one_minus_alpha_bar  √(1 − ᾱ_t)
      sqrt_alphas               √α_t
      posterior_var             β_t (1 − ᾱ_{t−1}) / (1 − ᾱ_t)   (0 at t = 0)
    """

    def __init__(self, betas):
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        # Shift right by one and seed with 1.0: nothing has been noised yet
        # before the first step.
        alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)

        self.T = len(betas)
        self.betas = betas
        self.alphas = alphas
        self.alpha_bar = alpha_bar
        self.alpha_bar_prev = alpha_bar_prev

        # Square roots read by the forward process and the samplers.
        self.sqrt_alpha_bar = torch.sqrt(alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bar)
        self.sqrt_alphas = torch.sqrt(alphas)

        # Variance of the true posterior q(x_{t−1} | x_t, x_0).
        self.posterior_var = betas * (1 - alpha_bar_prev) / (1 - alpha_bar)

    def to(self, device):
        """Move every tensor attribute to `device` in place; return self for chaining."""
        for name, value in list(vars(self).items()):
            if torch.is_tensor(value):
                setattr(self, name, value.to(device))
        return self


def linear_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """
    Original DDPM schedule: β_t rises linearly from `beta_start` to `beta_end`.

    Simple but sub-optimal: it is too noisy at high t. With the defaults,
    ᾱ_T ends up around 4e-5, so the final stretch of the chain is nearly
    pure noise.
    """
    betas = torch.linspace(beta_start, beta_end, T)
    return NoiseSchedule(betas)


def cosine_schedule(T=1000, s=0.008):
    """
    Cosine schedule of Nichol & Dhariwal (2021), "Improved DDPM".

    ᾱ is defined directly as f(t) / f(0), where
        f(t) = cos²( ((t / T) + s) / (1 + s) · π/2 ).
    β_t = 1 − ᾱ_t / ᾱ_{t−1} is then recovered and clipped at 0.999 so the
    last steps never become singular. Noise is added more gradually than in
    the linear schedule, which preserves structure for longer.

    Args:
        T: number of diffusion steps.
        s: small offset applied to t/T inside the cosine.
    """
    t = torch.arange(T + 1, dtype=torch.float32)
    f = torch.cos(((t / T) + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = f / f[0]                                   # normalise so ᾱ_0 = 1
    ratio = alpha_bar[1:] / alpha_bar[:-1]                 # α_t = ᾱ_t / ᾱ_{t−1}
    clipped_betas = (1 - ratio).clamp(max=0.999)
    return NoiseSchedule(clipped_betas)


# ═══════════════════════════════════════════════════════════════════
# CORE: FORWARD PROCESS  (adding noise — always the same)
# ═══════════════════════════════════════════════════════════════════
#
# The "forward process" is not learned. It's just math:
#   q(x_t | x_0) = N(x_t; √ᾱ_t · x_0, (1−ᾱ_t) · I)
#
# Which means: x_t = √ᾱ_t · x_0 + √(1−ᾱ_t) · ε,   ε ~ N(0, I)
#
# This lets us jump DIRECTLY to any timestep t without simulating
# the chain step by step — critical for efficient training.

def q_sample(x_0, t, schedule, noise=None):
    """
    Add noise to x_0 to get x_t in closed form (no loop over timesteps).

    Args:
        x_0: clean data with the batch dimension first; any number of
            trailing dimensions (vectors, sequences, images, video, ...).
        t: integer timesteps, one per batch element (shape (B,)).
        schedule: object exposing per-timestep `sqrt_alpha_bar` and
            `sqrt_one_minus_alpha_bar` tensors (e.g. NoiseSchedule).
        noise: optional ε with the same shape as x_0; drawn from N(0, I)
            when omitted.

    Returns:
        x_t = √ᾱ_t · x_0 + √(1−ᾱ_t) · ε, with the same shape as x_0.
    """
    if noise is None:
        noise = torch.randn_like(x_0)

    # Reshape the per-sample coefficients to (B, 1, ..., 1) to match the rank
    # of x_0. A hard-coded (B, 1, 1, 1) only works for 4-D images. For other
    # ranks it silently broadcasts to the wrong shape, e.g. (B, D) -> (B, 1, B, D).
    bshape = (-1,) + (1,) * (x_0.dim() - 1)
    sqrt_ab = schedule.sqrt_alpha_bar[t].view(bshape)
    sqrt_omab = schedule.sqrt_one_minus_alpha_bar[t].view(bshape)

    return sqrt_ab * x_0 + sqrt_omab * noise                       # x_t


# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM  (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════

class DiffusionAlgorithm(ABC):
    """
    The universal diffusion training and sampling framework.

    Every variant inherits this and implements:
      - compute_target(x_0, noise, t)  → what the model should predict
      - recover_x0(model_out, x_t, t)  → convert prediction back to x_0
      - denoise_step(x_t, t, model_out) → go from x_t to x_{t−1}
    Variants may also override sample() (DDIM does, to skip timesteps).

    Args:
        model: callable invoked as ``model(x_t, t)``; usually an nn.Module.
        schedule: precomputed noise schedule; ``T`` sets the number of steps.
        optimizer: optimizer that steps ``model``'s parameters.
    """

    def __init__(self, model, schedule: "NoiseSchedule", optimizer):
        self.model = model
        self.schedule = schedule
        self.optimizer = optimizer

    # ── The three pluggable pieces ────────────────────────────────

    @abstractmethod
    def compute_target(self, x_0, noise, t):
        """What the model should predict. Returns tensor same shape as x_0."""
        ...

    @abstractmethod
    def recover_x0(self, model_out, x_t, t):
        """Convert the model's output back to an x_0 estimate."""
        ...

    @abstractmethod
    def denoise_step(self, x_t, t, model_out):
        """Single reverse step: x_t → x_{t−1}."""
        ...

    # ── Training step (IDENTICAL for every variant) ───────────────

    def train_step(self, x_0):
        """
        One optimisation step on a clean batch x_0; returns the float loss.

          1. Sample a random timestep per example
          2. Add noise to get x_t (closed-form forward process)
          3. Model predicts the variant-specific target from x_t
          4. MSE loss, gradient-norm clipping at 1.0, optimizer step
        """
        B = x_0.shape[0]
        device = x_0.device

        # 1. Random timestep per sample
        t = torch.randint(0, self.schedule.T, (B,), device=device)

        # 2. Add noise (forward process — always the same)
        noise = torch.randn_like(x_0)
        x_t = q_sample(x_0, t, self.schedule, noise)

        # 3. Model predicts (target depends on variant)
        model_out = self.model(x_t, t)
        target = self.compute_target(x_0, noise, t)

        # 4. Simple MSE loss (always the same)
        loss = F.mse_loss(model_out, target)

        # 5. Gradient step (always the same)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return loss.item()

    # ── Sampling loop (IDENTICAL structure for every variant) ─────

    @torch.no_grad()
    def sample(self, shape, device="cpu", guidance_fn=None):
        """
        Generate samples by iteratively denoising from pure noise.

        Args:
            shape: (B, C, H, W), the shape of the images to generate.
            device: device the initial noise is created on.
            guidance_fn: optional ``guidance_fn(model_out, x_t, t)`` that
                returns a replacement for the model output at each step.

        The model is switched to eval mode for the loop (so dropout and
        batch-norm behave as at inference time). Its previous mode is
        restored afterwards. Plain callables without a ``training`` flag
        are left untouched.
        """
        was_training = getattr(self.model, "training", False)
        if was_training:
            self.model.eval()
        try:
            x_t = torch.randn(shape, device=device)             # start from pure noise

            for t_val in reversed(range(self.schedule.T)):
                t = torch.full((shape[0],), t_val, device=device, dtype=torch.long)

                model_out = self.model(x_t, t)

                # Optional: apply guidance (classifier-free, etc.)
                if guidance_fn is not None:
                    model_out = guidance_fn(model_out, x_t, t)

                x_t = self.denoise_step(x_t, t, model_out)      # x_t → x_{t−1}

            return x_t                                           # x_0
        finally:
            if was_training:
                self.model.train()


# ═══════════════════════════════════════════════════════════════════
# TRAINING LOOP  (iterate over dataset — standard supervised setup)
# ═══════════════════════════════════════════════════════════════════
#
# Diffusion training is identical to supervised learning: iterate
# over batches from a dataset. No RL, no replay buffer, no rollouts.
# Just (data, noise, MSE). This simplicity is a key reason diffusion
# won over GANs.

def train(algo: "DiffusionAlgorithm", dataloader, n_epochs, device="cpu",
          ema=None):
    """
    Standard epoch loop: one ``algo.train_step`` per batch from ``dataloader``.

    Args:
        algo: diffusion algorithm whose ``train_step(x_0)`` returns a float loss.
        dataloader: iterable of batches. Each batch is either a data tensor
            or a (data, labels, ...) tuple/list; only the data is used.
        n_epochs: number of passes over ``dataloader``.
        device: device each data batch is moved to.
        ema: optional EMA, updated from ``algo.model`` after every step.

    Returns:
        List of per-epoch mean losses, weighted by batch size. An epoch
        that yielded no batches reports NaN instead of dividing by zero.
    """
    history = []
    for epoch in range(n_epochs):
        epoch_loss = 0.0
        n = 0
        for batch in dataloader:
            # Only strip labels from tuple/list batches. Star-unpacking a bare
            # tensor would split it along the batch dimension and train on a
            # single example.
            x_0 = batch[0] if isinstance(batch, (list, tuple)) else batch
            x_0 = x_0.to(device)
            loss = algo.train_step(x_0)
            epoch_loss += loss * x_0.size(0)
            n += x_0.size(0)

            # EMA update (nearly universal — averages weights for better samples)
            if ema is not None:
                ema.update(algo.model)

        mean_loss = epoch_loss / n if n else float("nan")
        history.append(mean_loss)
        print(f"Epoch {epoch+1:3d}/{n_epochs} │ loss {mean_loss:.4f}")
    return history


# ═══════════════════════════════════════════════════════════════════
# EMA  (Exponential Moving Average of model weights)
# ═══════════════════════════════════════════════════════════════════
#
# Nearly all diffusion models use an EMA copy for sampling.
# Training weights are noisy; the smoothed EMA weights produce
# noticeably better samples. Decay of 0.9999 is standard.

class EMA:
    """
    Exponential moving average of a model's state_dict.

    Floating-point entries are smoothed as
        shadow ← shadow + (1 − decay) · (current − shadow).
    Non-floating entries (e.g. BatchNorm's integer ``num_batches_tracked``
    counter) are copied verbatim, because lerp is not defined for integer
    tensors.

    Args:
        model: module whose ``state_dict()`` is tracked.
        decay: smoothing factor; larger means a slower-moving average.
    """

    def __init__(self, model, decay=0.9999):
        self.decay = decay
        # Independent copies so later optimizer steps don't mutate the average.
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model):
        """Blend ``model``'s current state into the shadow copy."""
        for k, v in model.state_dict().items():
            shadow = self.shadow[k]
            if shadow.is_floating_point():
                shadow.lerp_(v, 1 - self.decay)
            else:
                shadow.copy_(v)

    def apply(self, model):
        """
        Load the averaged weights into ``model``.

        This overwrites the model's current (training) weights. Save
        ``model.state_dict()`` first if training will resume afterwards.
        """
        model.load_state_dict(self.shadow)


# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS  (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════

# ── 1. DDPM  (ε-prediction, stochastic sampling) ────────────────

class DDPM(DiffusionAlgorithm):
    """
    Original Denoising Diffusion Probabilistic Model.
    Model predicts the NOISE ε that was added.
    Sampling is stochastic: fresh noise is added at every reverse step
    except the final one (t = 0).
    """

    @staticmethod
    def _gather(values, t, like):
        """Index per-timestep ``values`` at ``t``; reshape to (B, 1, ..., 1) to broadcast against ``like``."""
        return values[t].view(-1, *([1] * (like.dim() - 1)))

    def compute_target(self, x_0, noise, t):
        return noise                                             # predict ε

    def recover_x0(self, model_out, x_t, t):
        # x_t = √ᾱ·x_0 + √(1−ᾱ)·ε  →  x_0 = (x_t − √(1−ᾱ)·ε) / √ᾱ
        sqrt_ab = self._gather(self.schedule.sqrt_alpha_bar, t, x_t)
        sqrt_omab = self._gather(self.schedule.sqrt_one_minus_alpha_bar, t, x_t)
        return (x_t - sqrt_omab * model_out) / sqrt_ab

    def denoise_step(self, x_t, t, predicted_noise):
        """
        Sample x_{t−1} from p(x_{t−1} | x_t).

        Mean: μ = (x_t − β_t / √(1−ᾱ_t) · ε̂) / √α_t.
        Noise with std √posterior_var[t] is added per sample wherever t > 0,
        so a batch may mix timesteps. Rows with t = 0 get the mean only.
        """
        s = self.schedule
        b = self._gather(s.betas, t, x_t)
        sqrt_omab = self._gather(s.sqrt_one_minus_alpha_bar, t, x_t)
        sqrt_a = self._gather(s.sqrt_alphas, t, x_t)

        # Mean of p(x_{t−1} | x_t)
        mean = (x_t - b * predicted_noise / sqrt_omab) / sqrt_a

        # Per-sample gate rather than looking only at t[0]. When every row is
        # at t = 0 no random numbers are drawn, same as the final step before.
        nonzero = t > 0
        if bool(nonzero.any()):
            sigma = self._gather(s.posterior_var, t, x_t).sqrt()
            gate = nonzero.to(x_t.dtype).view(-1, *([1] * (x_t.dim() - 1)))
            mean = mean + gate * sigma * torch.randn_like(x_t)

        return mean


# ── 2. DDIM  (ε-prediction, deterministic sampling) ──────────────

class DDIM(DiffusionAlgorithm):
    """
    Denoising Diffusion Implicit Models.

    Same training as DDPM (predict ε, same loss), but with η = 0 the reverse
    process is DETERMINISTIC: no noise is added during sampling.

    Key benefits:
      • Deterministic: same noise → same image (interpolation, editing)
      • Faster: can skip timesteps (e.g. 50 steps instead of 1000)
      • η parameter interpolates between DDIM (η=0) and DDPM (η=1)

    Extra args (beyond DiffusionAlgorithm's):
        eta: stochasticity η; 0 = deterministic.
        timestep_spacing: optional descending sequence of timesteps to visit
            while sampling, e.g. [999, 949, ..., 49, 0]. Each update jumps
            straight to the NEXT entry of this sequence. After the last entry
            it jumps to ᾱ = 1 (clean data). None or empty = all T steps.
    """

    def __init__(self, *args, eta=0.0, timestep_spacing=None, **kw):
        super().__init__(*args, **kw)
        self.eta = eta                                            # 0=deterministic, 1=DDPM
        # Optional: subsample timesteps for faster sampling
        self.sample_steps = timestep_spacing  # e.g. [999, 949, 899, ..., 49, 0]

    @staticmethod
    def _gather(values, t, like):
        """Index per-timestep ``values`` at ``t``; reshape to (B, 1, ..., 1) to broadcast against ``like``."""
        return values[t].view(-1, *([1] * (like.dim() - 1)))

    def compute_target(self, x_0, noise, t):
        return noise                                             # identical to DDPM

    def recover_x0(self, model_out, x_t, t):
        # x_0 = (x_t − √(1−ᾱ_t)·ε) / √ᾱ_t
        sqrt_ab = self._gather(self.schedule.sqrt_alpha_bar, t, x_t)
        sqrt_omab = self._gather(self.schedule.sqrt_one_minus_alpha_bar, t, x_t)
        return (x_t - sqrt_omab * model_out) / sqrt_ab

    def denoise_step(self, x_t, t, predicted_noise, t_prev=None):
        """
        One DDIM update from timestep ``t`` to timestep ``t_prev``.

        Args:
            t_prev: int timestep to land on. None means t − 1 (a single step).
                A negative value means "before the first step": ᾱ_prev = 1,
                so the result is the x_0 estimate.
        """
        s = self.schedule
        ab = self._gather(s.alpha_bar, t, x_t)
        if t_prev is None:
            ab_prev = self._gather(s.alpha_bar_prev, t, x_t)
        else:
            t_prev = int(t_prev)
            if t_prev < 0:
                ab_prev = torch.ones_like(ab)
            else:
                # Use ᾱ of the actual destination step. alpha_bar_prev[t]
                # would always mean t − 1, which is wrong when steps are skipped.
                ab_prev = s.alpha_bar[t_prev].expand_as(ab)

        # Predict x_0 from x_t and ε
        x0_pred = self.recover_x0(predicted_noise, x_t, t)

        # DDIM formula: deterministic + optional stochastic term
        sigma = self.eta * ((1 - ab_prev) / (1 - ab) * (1 - ab / ab_prev)).sqrt()
        # Clamp guards against tiny negative values from float round-off.
        dir_xt = (1 - ab_prev - sigma ** 2).clamp(min=0).sqrt() * predicted_noise
        x_prev = ab_prev.sqrt() * x0_pred + dir_xt

        # sigma is exactly 0 when landing on ᾱ_prev = 1, so the final jump
        # stays noise-free even if the last visited t is not 0.
        if self.eta > 0 and t[0].item() > 0:
            x_prev = x_prev + sigma * torch.randn_like(x_t)

        return x_prev

    @torch.no_grad()
    def sample(self, shape, device="cpu", guidance_fn=None):
        """
        Sample, optionally visiting only ``self.sample_steps`` (fast sampling).

        Each visited timestep is paired with the next one in the sequence, and
        the update jumps directly there. The model runs in eval mode for the
        loop and its previous mode is restored afterwards.
        """
        steps = [int(v) for v in self.sample_steps] if self.sample_steps is not None else []
        if not steps:
            steps = list(reversed(range(self.schedule.T)))
        next_steps = steps[1:] + [-1]                            # -1 → ᾱ_prev = 1

        was_training = getattr(self.model, "training", False)
        if was_training:
            self.model.eval()
        try:
            x_t = torch.randn(shape, device=device)
            for t_val, t_next in zip(steps, next_steps):
                t = torch.full((shape[0],), t_val, device=device, dtype=torch.long)
                model_out = self.model(x_t, t)
                if guidance_fn is not None:
                    model_out = guidance_fn(model_out, x_t, t)
                x_t = self.denoise_step(x_t, t, model_out, t_prev=t_next)
            return x_t
        finally:
            if was_training:
                self.model.train()


# ── 3. v-prediction ──────────────────────────────────────────────

class VPrediction(DiffusionAlgorithm):
    """
    Instead of predicting noise ε or signal x_0, predict the
    "velocity" v = √ᾱ · ε − √(1−ᾱ) · x_0.

    Benefits:
      • More numerically stable at t≈0 and t≈T
      • Better for high-resolution generation
      • Used in Stable Diffusion 2.x, Imagen Video

    Sampling uses a deterministic DDIM-style step (η = 0).
    """

    @staticmethod
    def _gather(values, t, like):
        """Index per-timestep ``values`` at ``t``; reshape to (B, 1, ..., 1) to broadcast against ``like``."""
        return values[t].view(-1, *([1] * (like.dim() - 1)))

    def compute_target(self, x_0, noise, t):
        sqrt_ab = self._gather(self.schedule.sqrt_alpha_bar, t, x_0)
        sqrt_omab = self._gather(self.schedule.sqrt_one_minus_alpha_bar, t, x_0)
        return sqrt_ab * noise - sqrt_omab * x_0                 # velocity v

    def recover_x0(self, model_out, x_t, t):
        # x_0 = √ᾱ · x_t − √(1−ᾱ) · v
        sqrt_ab = self._gather(self.schedule.sqrt_alpha_bar, t, x_t)
        sqrt_omab = self._gather(self.schedule.sqrt_one_minus_alpha_bar, t, x_t)
        return sqrt_ab * x_t - sqrt_omab * model_out

    def denoise_step(self, x_t, t, predicted_v):
        """Recover x_0 and ε from v, then take a deterministic DDIM step to t − 1."""
        s = self.schedule
        sqrt_ab = self._gather(s.sqrt_alpha_bar, t, x_t)
        sqrt_omab = self._gather(s.sqrt_one_minus_alpha_bar, t, x_t)

        x0_pred = sqrt_ab * x_t - sqrt_omab * predicted_v
        eps_pred = sqrt_omab * x_t + sqrt_ab * predicted_v

        ab_prev = self._gather(s.alpha_bar_prev, t, x_t)
        return ab_prev.sqrt() * x0_pred + (1 - ab_prev).sqrt() * eps_pred


# ═══════════════════════════════════════════════════════════════════
# CLASSIFIER-FREE GUIDANCE  (the standard way to steer generation)
# ═══════════════════════════════════════════════════════════════════
#
# Problem: how do you make the model generate "a photo of a cat"
# rather than just any image?
#
# Solution: train ONE model that can be both conditional and
# unconditional (randomly drop the conditioning during training).
# At sampling time, extrapolate AWAY from the unconditional output:
#
#   output = uncond + w · (cond − uncond)
#
# w > 1 amplifies the effect of the condition. w=7.5 is a common default.
# This is the mechanism behind most text-to-image diffusion models
# (DALL-E 2, Stable Diffusion, Imagen, etc.)

class ClassifierFreeGuidance:
    """
    Combine conditional and unconditional predictions of a single model.

    The wrapped model is called as ``model(x_t, t, cond)``. The result is
        uncond + w · (cond − uncond),
    where ``uncond`` is the prediction under ``null_cond``. With w = 1 this
    is the plain conditional output; w > 1 pushes further toward the
    condition.

    NOTE(review): ``__call__`` takes ``(x_t, t, cond)``, which does not match
    the ``guidance_fn(model_out, x_t, t)`` hook used by
    ``DiffusionAlgorithm.sample``. Bind ``cond`` in an adapter before
    plugging this into the sampling loop.
    """

    def __init__(self, model, null_cond, guidance_scale=7.5):
        self.model = model
        self.null_cond = null_cond   # "empty" conditioning, e.g. an empty-prompt embedding
        self.w = guidance_scale

    def __call__(self, x_t, t, cond):
        # Two forward passes, conditional first, then with the null condition.
        conditional, unconditional = (
            self.model(x_t, t, c) for c in (cond, self.null_cond)
        )
        # Extrapolate past the conditional prediction, away from the unconditional one.
        return unconditional + self.w * (conditional - unconditional)