"""
Unified Diffusion Models: Core Algorithm
==========================================
A single skeleton covering: DDPM, DDIM, v-prediction (a.k.a. velocity
prediction), classifier-free guidance, and noise schedule variants.
The core idea shared by ALL diffusion models:
TRAINING: x_t = √ᾱ_t · x_0 + √(1−ᾱ_t) · ε (add noise)
loss = ‖ target − model(x_t, t) ‖² (predict something)
SAMPLING: start from pure noise x_T ~ N(0, I)
for t = T, T−1, ..., 1:
x_{t−1} = denoise_step(x_t, t) (iteratively denoise)
That's it. Training is embarrassingly simple — one line of noise addition,
one forward pass, MSE loss. All the complexity is in the sampling loop
and the schedule.
The pluggable components are:
1. noise_schedule() — how fast noise is added (linear, cosine, ...)
2. predict() — what the model outputs (noise ε, signal x_0, velocity v)
3. denoise_step() — how to go from x_t → x_{t−1} (DDPM, DDIM, ...)
4. guidance() — how to steer generation (classifier-free, etc.)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from abc import ABC, abstractmethod
# ═══════════════════════════════════════════════════════════════════
# NOISE SCHEDULES (how much noise at each timestep)
# ═══════════════════════════════════════════════════════════════════
#
# The schedule defines β_t (noise added at step t) and from it:
# α_t = 1 − β_t
# ᾱ_t = ∏ α_i (cumulative — total signal remaining at step t)
#
# These two numbers are all you need:
# √ᾱ_t = how much of the original image survives
# √(1−ᾱ_t) = how much noise has been added
#
# A good schedule adds noise slowly at first (preserve structure)
# and quickly later (destroy details).
class NoiseSchedule:
    """Precompute every schedule-derived quantity once, up front.

    betas: (T,) tensor of per-step noise variances β_t.
    """

    def __init__(self, betas):
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)  # ᾱ_t = ∏ α_i
        self.T = len(betas)
        self.betas = betas                              # (T,)
        self.alphas = alphas                            # (T,)
        self.alpha_bar = alpha_bar                      # (T,)
        # ᾱ_{t−1}, with ᾱ_{−1} defined as 1 (no noise before step 0).
        self.alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)
        # Quantities used repeatedly during training and sampling.
        self.sqrt_alpha_bar = alpha_bar.sqrt()
        self.sqrt_one_minus_alpha_bar = (1 - alpha_bar).sqrt()
        self.sqrt_alphas = alphas.sqrt()
        # σ²_t = β_t·(1−ᾱ_{t−1})/(1−ᾱ_t): variance of q(x_{t−1}|x_t,x_0).
        self.posterior_var = betas * (1 - self.alpha_bar_prev) / (1 - alpha_bar)

    def to(self, device):
        """Move every tensor attribute to `device`; returns self (chainable)."""
        for name, value in list(vars(self).items()):
            if torch.is_tensor(value):
                setattr(self, name, value.to(device))
        return self
def linear_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Original DDPM schedule: β_t rises linearly from beta_start to
    beta_end. Simple but sub-optimal — too noisy at high t."""
    betas = torch.linspace(beta_start, beta_end, T)
    return NoiseSchedule(betas)
def cosine_schedule(T=1000, s=0.008):
    """
    Cosine schedule (Nichol & Dhariwal 2021). Noise is added more
    gradually, preserving structure longer → better image quality.
    """
    # ᾱ(t) follows a squared cosine; the small offset s keeps β_0 > 0.
    grid = torch.arange(T + 1, dtype=torch.float32) / T
    f = torch.cos((grid + s) / (1 + s) * (math.pi / 2)) ** 2
    alpha_bar = f / f[0]
    # β_t = 1 − ᾱ_t/ᾱ_{t−1}; clamp avoids a degenerate step near t = T.
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return NoiseSchedule(betas.clamp(max=0.999))
# ═══════════════════════════════════════════════════════════════════
# CORE: FORWARD PROCESS (adding noise — always the same)
# ═══════════════════════════════════════════════════════════════════
#
# The "forward process" is not learned. It's just math:
# q(x_t | x_0) = N(x_t; √ᾱ_t · x_0, (1−ᾱ_t) · I)
#
# Which means: x_t = √ᾱ_t · x_0 + √(1−ᾱ_t) · ε, ε ~ N(0, I)
#
# This lets us jump DIRECTLY to any timestep t without simulating
# the chain step by step — critical for efficient training.
def q_sample(x_0, t, schedule, noise=None):
    """Jump directly to x_t via x_t = √ᾱ_t·x_0 + √(1−ᾱ_t)·ε (closed form).

    x_0:   (B, C, H, W) clean images
    t:     (B,) long tensor of per-sample timesteps
    noise: optional ε; drawn fresh from N(0, I) when omitted
    """
    eps = torch.randn_like(x_0) if noise is None else noise
    signal_scale = schedule.sqrt_alpha_bar[t].view(-1, 1, 1, 1)
    noise_scale = schedule.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1, 1)
    return signal_scale * x_0 + noise_scale * eps
# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════
class DiffusionAlgorithm(ABC):
    """
    The universal diffusion training and sampling framework.

    Every variant inherits this and only overrides:
      - compute_target(x_0, noise, t)   → what the model should predict
      - recover_x0(model_out, x_t, t)   → convert prediction back to x_0
      - denoise_step(x_t, t, model_out) → go from x_t to x_{t−1}
    """

    def __init__(self, model, schedule: NoiseSchedule, optimizer):
        self.model = model
        self.schedule = schedule
        self.optimizer = optimizer

    # ── The three pluggable pieces ────────────────────────────────
    @abstractmethod
    def compute_target(self, x_0, noise, t):
        """What the model should predict. Returns tensor same shape as x_0."""
        ...

    @abstractmethod
    def recover_x0(self, model_out, x_t, t):
        """Convert the model's output back to an x_0 estimate."""
        ...

    @abstractmethod
    def denoise_step(self, x_t, t, model_out):
        """Single reverse step: x_t → x_{t−1}."""
        ...

    # ── Training step (IDENTICAL for every variant) ───────────────
    def train_step(self, x_0):
        """
        THE core training algorithm:
          sample t  →  add noise  →  predict target  →  MSE  →  optimizer step.
        Returns the scalar loss for logging.
        """
        batch, dev = x_0.shape[0], x_0.device
        # One random timestep per sample: each batch covers many noise levels.
        t = torch.randint(0, self.schedule.T, (batch,), device=dev)
        eps = torch.randn_like(x_0)
        noisy = q_sample(x_0, t, self.schedule, eps)
        prediction = self.model(noisy, t)
        # Target depends on the variant (ε, x_0, or v); loss never does.
        loss = F.mse_loss(prediction, self.compute_target(x_0, eps, t))
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping stabilises training.
        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        return loss.item()

    # ── Sampling loop (IDENTICAL structure for every variant) ─────
    @torch.no_grad()
    def sample(self, shape, device="cpu", guidance_fn=None):
        """Generate samples of `shape` = (B, C, H, W) by iterative denoising."""
        x = torch.randn(shape, device=device)  # start from pure noise x_T
        batch = shape[0]
        for step in range(self.schedule.T - 1, -1, -1):
            t = torch.full((batch,), step, device=device, dtype=torch.long)
            out = self.model(x, t)
            # Optional steering hook (e.g. classifier-free guidance).
            if guidance_fn is not None:
                out = guidance_fn(out, x, t)
            x = self.denoise_step(x, t, out)  # x_t → x_{t−1}
        return x  # ≈ x_0
# ═══════════════════════════════════════════════════════════════════
# TRAINING LOOP (iterate over dataset — standard supervised setup)
# ═══════════════════════════════════════════════════════════════════
#
# Diffusion training is identical to supervised learning: iterate
# over batches from a dataset. No RL, no replay buffer, no rollouts.
# Just (data, noise, MSE). This simplicity is a key reason diffusion
# won over GANs.
def train(algo: "DiffusionAlgorithm", dataloader, n_epochs, device="cpu",
          ema=None):
    """Standard supervised training loop over a dataset.

    algo:       a DiffusionAlgorithm subclass instance
    dataloader: yields (x_0, ...) batches; extra items (labels) are ignored
    ema:        optional EMA tracker, updated once per optimizer step —
                decay values like 0.9999 assume per-step (not per-epoch)
                updates

    Diffusion training is identical to supervised learning: iterate over
    batches. No RL, no replay buffer, no rollouts.
    """
    for epoch in range(n_epochs):
        epoch_loss = 0.0
        n = 0
        for x_0, *_ in dataloader:  # ignore labels if present
            x_0 = x_0.to(device)
            loss = algo.train_step(x_0)
            epoch_loss += loss * x_0.size(0)
            n += x_0.size(0)
            # EMA update per training step (nearly universal — averages
            # weights for better samples).
            if ema is not None:
                ema.update(algo.model)
        # Guard: avoid ZeroDivisionError on an empty dataloader.
        avg = epoch_loss / max(n, 1)
        print(f"Epoch {epoch+1:3d}/{n_epochs} │ loss {avg:.4f}")
# ═══════════════════════════════════════════════════════════════════
# EMA (Exponential Moving Average of model weights)
# ═══════════════════════════════════════════════════════════════════
#
# Nearly all diffusion models use an EMA copy for sampling.
# Training weights are noisy; the smoothed EMA weights produce
# noticeably better samples. Decay of 0.9999 is standard.
class EMA:
    """Exponential moving average of model weights.

    Nearly all diffusion models sample from an EMA copy: training weights
    are noisy, the smoothed weights produce noticeably better samples.
    Decay of 0.9999 is standard (assumes per-step updates).
    """

    def __init__(self, model, decay=0.9999):
        self.decay = decay
        # detach() so shadow tensors never participate in autograd.
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    def update(self, model):
        """Blend current weights into the shadow: s ← decay·s + (1−decay)·w."""
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].lerp_(v.detach(), 1 - self.decay)
            else:
                # BUG FIX: integer buffers (e.g. BatchNorm's
                # num_batches_tracked) cannot be lerped — copy them verbatim.
                self.shadow[k].copy_(v)

    def apply(self, model):
        """Load the averaged weights into `model` (e.g. before sampling)."""
        model.load_state_dict(self.shadow)
# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════
# ── 1. DDPM (ε-prediction, stochastic sampling) ────────────────
class DDPM(DiffusionAlgorithm):
    """
    Original Denoising Diffusion Probabilistic Model.

    The network predicts the NOISE ε that was mixed into x_0, and the
    reverse process re-injects fresh noise at every step (stochastic).
    """

    def compute_target(self, x_0, noise, t):
        # ε-prediction: the regression target is the noise itself.
        return noise

    def recover_x0(self, model_out, x_t, t):
        # Invert  x_t = √ᾱ·x_0 + √(1−ᾱ)·ε   for x_0.
        sched = self.schedule
        signal_scale = sched.sqrt_alpha_bar[t].view(-1, 1, 1, 1)
        noise_scale = sched.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1, 1)
        return (x_t - noise_scale * model_out) / signal_scale

    def denoise_step(self, x_t, t, predicted_noise):
        sched = self.schedule
        beta = sched.betas[t].view(-1, 1, 1, 1)
        noise_scale = sched.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1, 1)
        alpha_sqrt = sched.sqrt_alphas[t].view(-1, 1, 1, 1)
        # Posterior mean of p(x_{t−1} | x_t) under ε-prediction.
        x_prev = (x_t - beta * predicted_noise / noise_scale) / alpha_sqrt
        # Every step except the last adds fresh noise (stochastic sampler).
        if t[0].item() > 0:
            std = sched.posterior_var[t].sqrt().view(-1, 1, 1, 1)
            x_prev = x_prev + std * torch.randn_like(x_t)
        return x_prev
# ── 2. DDIM (ε-prediction, deterministic sampling) ──────────────
class DDIM(DiffusionAlgorithm):
    """
    Denoising Diffusion Implicit Models.

    Training is identical to DDPM (predict ε, same MSE loss); only the
    reverse process differs.
    Key benefits:
      • Deterministic at η=0: same noise → same image (interpolation, editing)
      • Faster: can skip timesteps (e.g. 50 steps instead of 1000)
      • η interpolates between DDIM (η=0) and DDPM-like (η=1) sampling

    eta:              0 = deterministic, 1 = DDPM-like stochasticity
    timestep_spacing: optional descending list of timesteps to visit,
                      e.g. [999, 949, 899, ..., 49, 0]
    """

    def __init__(self, *args, eta=0.0, timestep_spacing=None, **kw):
        super().__init__(*args, **kw)
        self.eta = eta  # 0=deterministic, 1=DDPM
        # Optional: subsample timesteps for faster sampling.
        self.sample_steps = timestep_spacing

    def compute_target(self, x_0, noise, t):
        return noise  # identical to DDPM

    def recover_x0(self, model_out, x_t, t):
        # Invert  x_t = √ᾱ·x_0 + √(1−ᾱ)·ε   for x_0.
        sqrt_ab = self.schedule.sqrt_alpha_bar[t].view(-1, 1, 1, 1)
        sqrt_omab = self.schedule.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1, 1)
        return (x_t - sqrt_omab * model_out) / sqrt_ab

    def denoise_step(self, x_t, t, predicted_noise, t_prev=None):
        """
        One DDIM step from timestep t to timestep t_prev.

        t_prev: integer index of the DESTINATION timestep. None means t−1
        (contiguous stepping, matching the base-class hook); −1 means a
        final jump to x_0 (ᾱ_prev = 1).

        BUG FIX: previously this always used ᾱ_{t−1}, which is wrong when
        the sampler skips timesteps — the destination ᾱ must come from the
        actual next step in the subsampled sequence.
        """
        s = self.schedule
        ab = s.alpha_bar[t].view(-1, 1, 1, 1)
        if t_prev is None:
            ab_prev = s.alpha_bar_prev[t].view(-1, 1, 1, 1)
        elif t_prev < 0:
            ab_prev = torch.ones_like(ab)  # destination is fully clean x_0
        else:
            ab_prev = s.alpha_bar[t_prev] * torch.ones_like(ab)
        # Predict x_0 from x_t and ε.
        x0_pred = self.recover_x0(predicted_noise, x_t, t)
        # σ controls how much fresh noise re-enters (η=0 → none).
        sigma = self.eta * ((1 - ab_prev) / (1 - ab) * (1 - ab / ab_prev)).sqrt()
        dir_xt = (1 - ab_prev - sigma ** 2).sqrt() * predicted_noise
        x_prev = ab_prev.sqrt() * x0_pred + dir_xt
        if self.eta > 0 and t[0].item() > 0:
            x_prev = x_prev + sigma * torch.randn_like(x_t)
        return x_prev

    @torch.no_grad()
    def sample(self, shape, device="cpu", guidance_fn=None):
        """Override to support timestep subsampling (fast sampling)."""
        x_t = torch.randn(shape, device=device)
        steps = (list(self.sample_steps) if self.sample_steps
                 else list(reversed(range(self.schedule.T))))
        for i, t_val in enumerate(steps):
            t = torch.full((shape[0],), t_val, device=device, dtype=torch.long)
            model_out = self.model(x_t, t)
            if guidance_fn is not None:
                model_out = guidance_fn(model_out, x_t, t)
            # Pass the actual next timestep in the (possibly subsampled)
            # sequence; −1 marks the final jump to x_0.
            t_prev = steps[i + 1] if i + 1 < len(steps) else -1
            x_t = self.denoise_step(x_t, t, model_out, t_prev=t_prev)
        return x_t
# ── 3. v-prediction ──────────────────────────────────────────────
class VPrediction(DiffusionAlgorithm):
    """
    Instead of predicting noise ε or signal x_0, predict the
    "velocity" v = √ᾱ·ε − √(1−ᾱ)·x_0.

    Benefits:
      • More numerically stable at t≈0 and t≈T
      • Better for high-resolution generation
      • Used in Stable Diffusion 2.x, Imagen Video
    """

    def _coeffs(self, t):
        # (√ᾱ_t, √(1−ᾱ_t)), broadcast-shaped for image batches.
        a = self.schedule.sqrt_alpha_bar[t].view(-1, 1, 1, 1)
        b = self.schedule.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1, 1)
        return a, b

    def compute_target(self, x_0, noise, t):
        a, b = self._coeffs(t)
        return a * noise - b * x_0  # velocity v

    def recover_x0(self, model_out, x_t, t):
        a, b = self._coeffs(t)
        return a * x_t - b * model_out

    def denoise_step(self, x_t, t, predicted_v):
        # Recover both x_0 and ε from v, then take a deterministic
        # DDIM-style step down to the previous ᾱ level.
        a, b = self._coeffs(t)
        x0_pred = a * x_t - b * predicted_v
        eps_pred = b * x_t + a * predicted_v
        ab_prev = self.schedule.alpha_bar_prev[t].view(-1, 1, 1, 1)
        return ab_prev.sqrt() * x0_pred + (1 - ab_prev).sqrt() * eps_pred
# ═══════════════════════════════════════════════════════════════════
# CLASSIFIER-FREE GUIDANCE (the standard way to steer generation)
# ═══════════════════════════════════════════════════════════════════
#
# Problem: how do you make the model generate "a photo of a cat"
# rather than just any image?
#
# Solution: train ONE model that can be both conditional and
# unconditional (randomly drop the conditioning during training).
# At sampling time, extrapolate AWAY from the unconditional output:
#
# output = uncond + w · (cond − uncond)
#
# w > 1 amplifies the effect of the condition. w=7.5 is a common default.
# This is the mechanism behind all text-to-image models (DALL-E 2,
# Stable Diffusion, Imagen, etc.)
class ClassifierFreeGuidance:
    """
    Classifier-free guidance: run one model both conditionally and
    unconditionally, then extrapolate AWAY from the unconditional output:

        output = uncond + w · (cond − uncond)

    w > 1 amplifies the condition; w=7.5 is a common default. This is the
    mechanism behind text-to-image models (DALL-E 2, Stable Diffusion, ...).

    The model must accept a conditioning signal (e.g. text embeddings)
    and a null/empty condition for the unconditional pass.

    NOTE / FIX: `DiffusionAlgorithm.sample` invokes its guidance hook as
    guidance_fn(model_out, x_t, t), while a bare instance of this class is
    called as (x_t, t, cond) — so instances were not directly pluggable.
    Use `as_guidance_fn(cond)` to obtain a hook-compatible callable.
    """

    def __init__(self, model, null_cond, guidance_scale=7.5):
        self.model = model
        self.null_cond = null_cond  # e.g. empty text embedding
        self.w = guidance_scale

    def __call__(self, x_t, t, cond):
        # Run model twice: once conditional, once unconditional
        cond_out = self.model(x_t, t, cond)
        uncond_out = self.model(x_t, t, self.null_cond)
        # Extrapolate toward the conditional prediction
        return uncond_out + self.w * (cond_out - uncond_out)

    def as_guidance_fn(self, cond):
        """Adapter matching the sampler's guidance_fn(model_out, x_t, t) hook.

        The incoming model_out is ignored (it was produced without
        conditioning); both guided passes are recomputed here.
        """
        def fn(model_out, x_t, t):
            return self(x_t, t, cond)
        return fn