# Unified Generative Adversarial Networks (GANs): Core Algorithm — Implementation

"""
Unified Generative Adversarial Networks (GANs): Core Algorithm
================================================================
A single skeleton covering: vanilla GAN, DCGAN, WGAN, WGAN-GP,
StyleGAN (conceptual), conditional GAN (cGAN), and pix2pix-style
paired translation.

The core idea shared by ALL GANs:

  Two networks play a minimax game:

  GENERATOR (G):     z → fake_data          (tries to fool D)
  DISCRIMINATOR (D): data → real_or_fake    (tries to catch G)

  G wants D to output "real" for fakes.
  D wants to output "real" for real data, "fake" for G's output.

  At equilibrium: G produces data indistinguishable from real,
  and D can't tell the difference (outputs 0.5 for everything).

The pluggable components are:
  1. generator_loss()      — how G is penalised (minimax, non-saturating, Wasserstein)
  2. discriminator_loss()  — how D is penalised (BCE, hinge, Wasserstein)
  3. discriminator_regularisation() — gradient penalty, spectral norm, R1, ...
  4. conditioning          — whether/how class labels or inputs are injected
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod


# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM  (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════
#
# Every GAN variant follows this exact two-step loop:
#
#   Step 1: Update D — make it better at telling real from fake
#   Step 2: Update G — make it better at fooling D
#
# The adversarial dynamic IS the algorithm. Unlike every other
# file in this series, there are TWO networks with OPPOSING losses
# trained in alternation. This is what makes GANs both powerful
# and unstable.

class GANAlgorithm(ABC):
    """
    The universal GAN training step.

    Every variant inherits this and only overrides:
      - discriminator_loss(real_scores, fake_scores) → scalar
      - generator_loss(fake_scores) → scalar
      - discriminator_regularisation(D, real, fake) → scalar  (optional)

    Args:
        G: generator; called as ``G(z)`` with the output of ``sample_noise``.
        D: discriminator / critic; called as ``D(data)``.
        g_optimizer: optimizer stepped after the generator loss is back-propagated.
        d_optimizer: optimizer stepped after the discriminator loss is back-propagated.
        d_latent: size of the latent vector drawn by ``sample_noise``.
        n_d_steps: number of D updates per G update (often 1, WGAN uses 5).

    Raises:
        ValueError: if ``n_d_steps`` is smaller than 1.
    """

    def __init__(self, G, D, g_optimizer, d_optimizer,
                 d_latent=128, n_d_steps=1):
        # train_step reports scores from the last D iteration and averages the
        # D loss over n_d_steps, so at least one D iteration must run.
        if n_d_steps < 1:
            raise ValueError(f"n_d_steps must be >= 1, got {n_d_steps!r}")

        self.G = G
        self.D = D
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        self.d_latent = d_latent
        self.n_d_steps = n_d_steps     # D updates per G update (often 1, WGAN uses 5)

    # ── The pluggable pieces ──────────────────────────────────────

    @abstractmethod
    def discriminator_loss(self, real_scores, fake_scores):
        """How D is trained. Returns scalar to MINIMISE."""
        ...

    @abstractmethod
    def generator_loss(self, fake_scores):
        """How G is trained. Returns scalar to MINIMISE."""
        ...

    def discriminator_regularisation(self, D, real_data, fake_data):
        """Optional: gradient penalty, R1, etc. Default: none.

        Returns a scalar tensor that train_step adds to the discriminator
        loss before back-propagation. The default is a zero on real_data's
        device.
        """
        return torch.tensor(0.0, device=real_data.device)

    def sample_noise(self, batch_size, device):
        """Sample latent z ~ N(0, I). Override for different priors.

        Returns a tensor of shape (batch_size, d_latent) on ``device``.
        """
        return torch.randn(batch_size, self.d_latent, device=device)

    # ── Core training step (IDENTICAL for every variant) ──────────

    def train_step(self, real_data):
        """Run ``n_d_steps`` discriminator updates followed by one generator update.

        Args:
            real_data: batch of real samples; its first dimension is used as
                the batch size and its device as the device for the noise.

        Returns:
            dict of Python floats:
              - "g_loss": generator loss of this step.
              - "d_loss": discriminator loss (including regularisation),
                averaged over the ``n_d_steps`` updates.
              - "d_real": mean D score on real data from the last D update.
              - "d_fake": mean D score on the fakes used for the G update.
        """
        device = real_data.device
        batch_size = real_data.size(0)

        # ── Step 1: Update discriminator ──────────────────────────
        # D sees real data and fake data, learns to distinguish them.
        d_loss_total = 0.0
        for _ in range(self.n_d_steps):
            z = self.sample_noise(batch_size, device)
            with torch.no_grad():
                fake_data = self.G(z)                            # don't track G grads

            real_scores = self.D(real_data)                      # D(x)
            fake_scores = self.D(fake_data)                      # D(G(z))

            d_loss = self.discriminator_loss(real_scores, fake_scores)
            d_reg = self.discriminator_regularisation(self.D, real_data, fake_data)
            d_loss_full = d_loss + d_reg

            # zero_grad right before backward also discards any gradients the
            # previous generator update left on D's parameters.
            self.d_optimizer.zero_grad()
            d_loss_full.backward()
            self.d_optimizer.step()

            d_loss_total += d_loss_full.item()

        # ── Step 2: Update generator ──────────────────────────────
        # G generates fakes and wants D to think they're real.
        z = self.sample_noise(batch_size, device)
        fake_data = self.G(z)
        fake_scores = self.D(fake_data)                          # D(G(z)), grads flow to G

        g_loss = self.generator_loss(fake_scores)

        self.g_optimizer.zero_grad()
        g_loss.backward()
        self.g_optimizer.step()                                  # only G's optimizer steps here

        return {"g_loss": g_loss.item(),
                "d_loss": d_loss_total / self.n_d_steps,
                "d_real": real_scores.mean().item(),
                "d_fake": fake_scores.mean().item()}

    # ── Generation ────────────────────────────────────────────────

    @torch.no_grad()
    def generate(self, n_samples, device="cpu"):
        """Draw ``n_samples`` latent vectors on ``device`` and return ``G(z)``.

        Runs without gradient tracking. G's train/eval mode is left untouched.
        """
        z = self.sample_noise(n_samples, device)
        return self.G(z)


# ═══════════════════════════════════════════════════════════════════
# TRAINING LOOP
# ═══════════════════════════════════════════════════════════════════

def train(algo: "GANAlgorithm", dataloader, n_epochs, device="cpu"):
    """Run ``algo.train_step`` over ``dataloader`` for ``n_epochs`` and print epoch averages.

    Args:
        algo: object exposing ``train_step(real)`` that returns a dict with
            the keys "g_loss", "d_loss", "d_real" and "d_fake".
        dataloader: iterable of batches. A batch is either the real-data
            tensor itself or a tuple/list whose first element is that tensor
            (further elements, e.g. labels, are ignored).
        n_epochs: number of passes over ``dataloader``.
        device: device each real batch is moved to before the step. Only the
            batch is moved; ``algo``'s networks are not touched here.

    Per-batch metrics are weighted by batch size, so the printed numbers are
    per-sample averages. An epoch that yields no samples prints a notice
    instead of dividing by zero.
    """
    metric_names = ("g_loss", "d_loss", "d_real", "d_fake")

    for epoch in range(n_epochs):
        totals = dict.fromkeys(metric_names, 0.0)
        n = 0
        for batch in dataloader:
            # Unpacking a bare tensor with `real, *_ = batch` would split it
            # along dim 0 and keep only the first sample, so only containers
            # are unpacked.
            real = batch[0] if isinstance(batch, (list, tuple)) else batch
            real = real.to(device)
            metrics = algo.train_step(real)

            batch_size = real.size(0)
            for k in totals:
                totals[k] += metrics[k] * batch_size
            n += batch_size

        if n == 0:
            print(f"Epoch {epoch+1:3d}/{n_epochs} │ no samples, nothing to train on")
            continue

        avg = {k: v / n for k, v in totals.items()}
        print(f"Epoch {epoch+1:3d}/{n_epochs} │ "
              f"G {avg['g_loss']:.4f}  "
              f"D {avg['d_loss']:.4f}  "
              f"D(real) {avg['d_real']:.3f}  "
              f"D(fake) {avg['d_fake']:.3f}")


# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS  (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════

# ── 1. Vanilla GAN  (original minimax / non-saturating) ──────────

class VanillaGAN(GANAlgorithm):
    """
    The original GAN objective (Goodfellow et al. 2014), expressed as
    binary cross-entropy on D's raw logits.

    Minimax formulation:
      D maximises:  E[log D(x)] + E[log(1 − D(G(z)))]
      G minimises:  E[log(1 − D(G(z)))]

    The generator loss implemented here is the non-saturating variant,
    −E[log D(G(z))]. When D confidently rejects the fakes (D(G(z)) ≈ 0),
    log(1 − D(G(z))) is nearly flat and gives G almost no gradient, whereas
    −log D(G(z)) is steep in exactly that regime.
    """

    def discriminator_loss(self, real_scores, fake_scores):
        """BCE with target 1 on real logits plus BCE with target 0 on fake logits."""
        label_real = torch.ones_like(real_scores)
        label_fake = torch.zeros_like(fake_scores)

        loss_on_real = F.binary_cross_entropy_with_logits(real_scores, label_real)
        loss_on_fake = F.binary_cross_entropy_with_logits(fake_scores, label_fake)
        return loss_on_real + loss_on_fake

    def generator_loss(self, fake_scores):
        """Non-saturating loss: BCE of the fake logits against the "real" label."""
        wanted_label = torch.ones_like(fake_scores)
        return F.binary_cross_entropy_with_logits(fake_scores, wanted_label)


# ── 2. WGAN  (Wasserstein distance, weight clipping) ─────────────

class WGAN(GANAlgorithm):
    """
    Wasserstein GAN: replaces the JS divergence with the Wasserstein
    (Earth Mover's) distance.

    D becomes a "critic" — it outputs an unbounded score (not a
    probability). The loss is simply the difference in mean scores:

      D maximises:  E[D(x)] − E[D(G(z))]   (real scores > fake scores)
      G minimises: −E[D(G(z))]              (push fake scores up)

    Requires D to be Lipschitz-continuous. Original WGAN enforces this
    by clamping weights to [−c, c] after each update. This works but
    is crude — the capacity of D is artificially limited.

    Args:
        clip_value: the bound c; every critic parameter is clamped to
            [−c, c] after each critic optimizer step. Must be positive.
        n_d_steps: defaults to 5 for this variant unless given explicitly.
        Remaining positional/keyword arguments are forwarded to GANAlgorithm.

    Raises:
        ValueError: if ``clip_value`` is not positive.
    """

    def __init__(self, *args, clip_value=0.01, **kw):
        kw.setdefault("n_d_steps", 5)                           # WGAN uses 5 critic steps
        super().__init__(*args, **kw)
        # With c <= 0 the clamp range is empty or inverted and the clamp
        # would overwrite every weight with a constant.
        if clip_value <= 0:
            raise ValueError(f"clip_value must be positive, got {clip_value!r}")
        self.clip_value = clip_value

    def discriminator_loss(self, real_scores, fake_scores):
        """Critic loss: the negation of E[D(x)] − E[D(G(z))]."""
        # Wasserstein: maximise E[D(x)] − E[D(G(z))]
        # Minimise the negation:
        return fake_scores.mean() - real_scores.mean()

    def generator_loss(self, fake_scores):
        """Generator loss: −E[D(G(z))]."""
        return -fake_scores.mean()

    def _clip_critic_weights(self):
        """Clamp every parameter of D in place to [−clip_value, clip_value]."""
        with torch.no_grad():
            for p in self.D.parameters():
                p.clamp_(-self.clip_value, self.clip_value)

    def train_step(self, real_data):
        """Run ``n_d_steps`` clipped critic updates, then one generator update.

        Follows the same two-phase loop as ``GANAlgorithm.train_step`` and
        returns the same metric keys ("g_loss", "d_loss", "d_real", "d_fake").
        The loop is spelled out here because clipping has to happen after
        every critic optimizer step: clipping only once at the end of the
        whole step would leave all but the last critic update, and the
        generator update, working with an unconstrained critic.
        """
        device = real_data.device
        batch_size = real_data.size(0)

        # ── Critic updates, each followed by weight clipping ──────
        d_loss_total = 0.0
        for _ in range(self.n_d_steps):
            z = self.sample_noise(batch_size, device)
            with torch.no_grad():
                fake_data = self.G(z)                            # don't track G grads

            real_scores = self.D(real_data)
            fake_scores = self.D(fake_data)

            d_loss = self.discriminator_loss(real_scores, fake_scores)
            d_reg = self.discriminator_regularisation(self.D, real_data, fake_data)
            d_loss_full = d_loss + d_reg

            self.d_optimizer.zero_grad()
            d_loss_full.backward()
            self.d_optimizer.step()
            self._clip_critic_weights()                          # enforce Lipschitz constraint

            d_loss_total += d_loss_full.item()

        # ── Generator update against the clipped critic ───────────
        z = self.sample_noise(batch_size, device)
        fake_data = self.G(z)
        fake_scores = self.D(fake_data)

        g_loss = self.generator_loss(fake_scores)

        self.g_optimizer.zero_grad()
        g_loss.backward()
        self.g_optimizer.step()

        return {"g_loss": g_loss.item(),
                "d_loss": d_loss_total / self.n_d_steps,
                "d_real": real_scores.mean().item(),
                "d_fake": fake_scores.mean().item()}


# ── 3. WGAN-GP  (gradient penalty instead of clipping) ───────────

class WGANGP(GANAlgorithm):
    """
    WGAN with Gradient Penalty.

    Uses the same Wasserstein critic/generator losses as WGAN, but replaces
    weight clipping with a soft Lipschitz constraint: a penalty on how far
    the critic's input-gradient norm deviates from 1.

    The penalty is evaluated on random per-sample blends of real and fake
    data, x̂ = αx + (1−α)G(z) with α ~ U[0, 1), and equals
    gp_weight · E[(‖∇D(x̂)‖₂ − 1)²].

    Args:
        gp_weight: multiplier applied to the gradient penalty.
        n_d_steps: defaults to 5 for this variant unless given explicitly.
        Remaining positional/keyword arguments are forwarded to GANAlgorithm.
    """

    def __init__(self, *args, gp_weight=10.0, **kw):
        if "n_d_steps" not in kw:
            kw["n_d_steps"] = 5
        super().__init__(*args, **kw)
        self.gp_weight = gp_weight

    def discriminator_loss(self, real_scores, fake_scores):
        """Critic loss: mean fake score minus mean real score."""
        return fake_scores.mean() - real_scores.mean()

    def generator_loss(self, fake_scores):
        """Generator loss: negated mean critic score on the fakes."""
        return -fake_scores.mean()

    def discriminator_regularisation(self, D, real_data, fake_data):
        """Return gp_weight times the mean squared deviation of ‖∇D(x̂)‖₂ from 1."""
        n_batch = real_data.size(0)

        # One mixing coefficient per sample, broadcast over all other dims.
        mix_shape = (n_batch,) + (1,) * (real_data.dim() - 1)
        alpha = torch.rand(mix_shape, device=real_data.device)

        x_hat = alpha * real_data + (1 - alpha) * fake_data
        x_hat = x_hat.requires_grad_(True)

        critic_out = D(x_hat)

        # create_graph=True keeps the gradient itself differentiable, so the
        # penalty can be back-propagated into D's parameters.
        (input_grad,) = torch.autograd.grad(
            outputs=critic_out,
            inputs=x_hat,
            grad_outputs=torch.ones_like(critic_out),
            create_graph=True,
            retain_graph=True,
        )

        per_sample_norm = input_grad.reshape(n_batch, -1).norm(2, dim=1)
        mean_sq_deviation = (per_sample_norm - 1).pow(2).mean()
        return self.gp_weight * mean_sq_deviation


# ── 4. Hinge GAN  (spectral norm + hinge loss) ──────────────────

class HingeGAN(GANAlgorithm):
    """
    Hinge-loss GAN, the objective used in SAGAN, BigGAN and StyleGAN-XL.

    D loss:  E[max(0, 1 − D(x))] + E[max(0, 1 + D(G(z)))]
    G loss: −E[D(G(z))]

    A sample stops contributing to D's loss once it clears the margin
    (score > 1 for real, < −1 for fake), which keeps D from growing
    arbitrarily confident and starving G of gradients.

    Usually combined with spectral normalisation on D's weights; that is
    part of the network definition and is not handled by this class.
    """

    def discriminator_loss(self, real_scores, fake_scores):
        """Sum of the mean hinge penalties on real and on fake scores."""
        real_shortfall = (1 - real_scores).relu()    # > 0 only while D(x) < 1
        fake_excess = (1 + fake_scores).relu()       # > 0 only while D(G(z)) > −1
        return real_shortfall.mean() + fake_excess.mean()

    def generator_loss(self, fake_scores):
        """Negated mean score D assigns to the fakes."""
        mean_fake_score = fake_scores.mean()
        return -mean_fake_score


# ── 5. Conditional GAN  (cGAN: class-conditional generation) ─────

class ConditionalGAN(GANAlgorithm):
    """
    Both G and D receive a condition (class label, text, etc.).

    G(z, c) → fake:  "generate a cat" vs "generate a dog"
    D(x, c) → score: "is this a real cat?" not just "is this real?"

    The condition can be injected by concatenation, projection,
    or adaptive normalisation (class-conditional BatchNorm, as in
    BigGAN and StyleGAN).

    This class wraps any base loss variant — the conditioning
    mechanism is orthogonal to the loss function choice. The losses and
    the discriminator regularisation are delegated to ``base_loss``.

    Args:
        G, D: conditional networks, called as ``G(z, c)`` and ``D(x, c)``.
        g_optimizer, d_optimizer: optimizers for the two update phases.
        base_loss: GANAlgorithm instance supplying the loss functions.
        cond_encoder: optional callable mapping raw conditions to ``c``;
            when None the conditions are passed through unchanged.
        **kw: forwarded to GANAlgorithm (``d_latent``, ``n_d_steps``).

    NOTE: only ``base_loss``'s loss and regularisation methods are used;
    anything it does inside its own ``train_step`` is not run here.
    """

    def __init__(self, G, D, g_optimizer, d_optimizer,
                 base_loss: GANAlgorithm, cond_encoder=None, **kw):
        super().__init__(G, D, g_optimizer, d_optimizer, **kw)
        self.base = base_loss
        self.cond_encoder = cond_encoder

    def discriminator_loss(self, real_scores, fake_scores):
        """Delegate to the wrapped variant's discriminator loss."""
        return self.base.discriminator_loss(real_scores, fake_scores)

    def generator_loss(self, fake_scores):
        """Delegate to the wrapped variant's generator loss."""
        return self.base.generator_loss(fake_scores)

    def discriminator_regularisation(self, D, real_data, fake_data):
        """Delegate to the wrapped variant's regulariser (e.g. a gradient penalty).

        ``D`` is expected to be a one-argument callable; train_step_conditional
        passes a closure that already has the condition bound.
        """
        return self.base.discriminator_regularisation(D, real_data, fake_data)

    def _encode_condition(self, conditions):
        """Return ``cond_encoder(conditions)``, or ``conditions`` when no encoder is set."""
        if self.cond_encoder is None:
            return conditions
        return self.cond_encoder(conditions)

    def train_step_conditional(self, real_data, conditions):
        """Run the D update(s) and one G update, both conditioned on ``conditions``.

        Args:
            real_data: batch of real samples.
            conditions: conditioning input for the same batch; encoded with
                ``cond_encoder`` when one was provided.

        Returns:
            dict of Python floats with "g_loss", "d_loss" (including
            regularisation, averaged over the D updates), "d_real" (last D
            update) and "d_fake" (generator phase).
        """
        device = real_data.device
        batch_size = real_data.size(0)
        # A value below 1 used to be ignored by this method; keep running at
        # least one D update so such configurations still work.
        n_d_steps = max(1, self.n_d_steps)

        # ── Update D ──────────────────────────────────────────────
        d_loss_total = 0.0
        for _ in range(n_d_steps):
            # Encode afresh for every backward pass: a trainable encoder's
            # graph is freed by backward(), so a shared `c` would make the
            # next backward fail.
            c = self._encode_condition(conditions)

            z = self.sample_noise(batch_size, device)
            with torch.no_grad():
                fake_data = self.G(z, c)

            real_scores = self.D(real_data, c)
            fake_scores = self.D(fake_data, c)

            d_loss = self.discriminator_loss(real_scores, fake_scores)
            # Bind the condition so the regulariser can call D with data only.
            d_reg = self.discriminator_regularisation(
                lambda x, c=c: self.D(x, c), real_data, fake_data)
            d_loss_full = d_loss + d_reg

            self.d_optimizer.zero_grad()
            d_loss_full.backward()
            self.d_optimizer.step()

            d_loss_total += d_loss_full.item()

        # ── Update G ──────────────────────────────────────────────
        c = self._encode_condition(conditions)
        z = self.sample_noise(batch_size, device)
        fake_data = self.G(z, c)
        fake_scores = self.D(fake_data, c)

        g_loss = self.generator_loss(fake_scores)
        self.g_optimizer.zero_grad()
        g_loss.backward()
        self.g_optimizer.step()

        return {"g_loss": g_loss.item(), "d_loss": d_loss_total / n_d_steps,
                "d_real": real_scores.mean().item(),
                "d_fake": fake_scores.mean().item()}


# ── 6. Pix2Pix-style  (paired image translation) ────────────────

class PairedTranslationGAN(GANAlgorithm):
    """
    Image-to-image translation with paired data (Pix2Pix).

    G takes an input IMAGE (not noise) and produces an output image:
      G(x_input) → x_output

    D sees (input, output) pairs and judges whether the output is
    real or generated. Typically a PatchGAN discriminator that
    classifies overlapping patches rather than the whole image.

    Loss adds a reconstruction term (L1) so G doesn't just produce
    realistic-looking images that ignore the input.

    Args:
        G: generator, called as ``G(input_data)``.
        D: discriminator, called on input and output concatenated along dim 1.
        g_optimizer, d_optimizer: optimizers for the two update phases.
        l1_weight: multiplier on the L1 reconstruction term of the G loss.
        **kw: forwarded to GANAlgorithm (``d_latent``, ``n_d_steps``).
    """

    def __init__(self, G, D, g_optimizer, d_optimizer,
                 l1_weight=100.0, **kw):
        super().__init__(G, D, g_optimizer, d_optimizer, **kw)
        self.l1_weight = l1_weight

    def discriminator_loss(self, real_scores, fake_scores):
        """BCE-with-logits: target 1 for real pairs, target 0 for generated pairs."""
        real_loss = F.binary_cross_entropy_with_logits(
            real_scores, torch.ones_like(real_scores))
        fake_loss = F.binary_cross_entropy_with_logits(
            fake_scores, torch.zeros_like(fake_scores))
        return real_loss + fake_loss

    def generator_loss(self, fake_scores):
        """Adversarial part of the G loss: BCE of fake logits against the "real" label."""
        return F.binary_cross_entropy_with_logits(
            fake_scores, torch.ones_like(fake_scores))

    def train_step_paired(self, input_data, target_data):
        """Run the D update(s) and one G update on a batch of (input, target) pairs.

        Args:
            input_data: batch of source images fed to G.
            target_data: the corresponding ground-truth outputs.

        Returns:
            dict of Python floats: "g_loss" (adversarial + weighted L1),
            "d_loss" (averaged over the D updates) and "g_l1" (unweighted L1).
        """
        # A value below 1 used to be ignored by this method; keep running at
        # least one D update so such configurations still work.
        n_d_steps = max(1, self.n_d_steps)

        # D sees (input, target) pairs; the real pair is the same for every D update.
        real_pair = torch.cat([input_data, target_data], dim=1)

        d_loss_total = 0.0
        for _ in range(n_d_steps):
            # G takes input, produces output (no noise)
            with torch.no_grad():
                fake_output = self.G(input_data)

            real_scores = self.D(real_pair)
            fake_scores = self.D(torch.cat([input_data, fake_output], dim=1))

            d_loss = self.discriminator_loss(real_scores, fake_scores)
            self.d_optimizer.zero_grad()
            d_loss.backward()
            self.d_optimizer.step()

            d_loss_total += d_loss.item()

        # G: adversarial + L1 reconstruction
        fake_output = self.G(input_data)
        fake_scores = self.D(torch.cat([input_data, fake_output], dim=1))

        g_loss_adv = self.generator_loss(fake_scores)
        g_loss_l1 = F.l1_loss(fake_output, target_data)
        g_loss = g_loss_adv + self.l1_weight * g_loss_l1

        self.g_optimizer.zero_grad()
        g_loss.backward()
        self.g_optimizer.step()

        return {"g_loss": g_loss.item(), "d_loss": d_loss_total / n_d_steps,
                "g_l1": g_loss_l1.item()}