Unified Variational Inference / VAE
Introduction
This one has a nice structural symmetry with the other files, but the central tension is different. Where Q-learning variants swap out the target, policy gradient variants swap out the advantage, and contrastive variants swap out the collapse-prevention mechanism, VAE variants all navigate the same single tradeoff: reconstruction vs. regularisation.
The key insights this file tries to make clear:
The reparameterisation trick is the entire reason VAEs work. Without it, you can’t backprop through sampling, and you’re stuck with high-variance REINFORCE estimators. The trick — externalise the randomness as ε ~ N(0,I) and write z = μ + σ·ε — is one line of code but arguably the most important idea in the file.
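That one line can be sketched directly. A minimal NumPy version, assuming a seeded RNG stands in for the framework’s sampler (in practice this runs inside an autodiff graph):

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterise(mu, log_var):
    """z = mu + sigma * eps, with eps ~ N(0, I).

    The randomness lives entirely in eps, which does not depend on the
    encoder's parameters, so gradients flow through mu and log_var.
    """
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

mu, log_var = np.zeros(4), np.zeros(4)   # sigma = exp(0.5 * 0) = 1
z = reparameterise(mu, log_var)
```

Had we instead sampled z ~ N(μ, σ²) directly, the sample would be a stochastic node with no gradient path back to μ and σ.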
VQ-VAE is the most radical departure. It replaces the entire Gaussian machinery (reparameterisation, KL divergence) with a codebook lookup + straight-through estimator. This is a fundamentally different kind of bottleneck — discrete rather than continuous — and it’s what enabled treating images as token sequences for autoregressive generation.
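The codebook lookup itself is just a nearest-neighbour snap. A sketch with made-up shapes (the straight-through part is a framework-level gradient trick, noted in the comment; NumPy has no autodiff):

```python
import numpy as np

def quantise(z_e, codebook):
    """Snap each encoder output vector to its nearest codebook entry.

    z_e: (n, d) continuous encoder outputs; codebook: (K, d) code vectors.
    In a real framework the forward pass uses z_q but gradients are copied
    straight through to z_e:  z_q = z_e + stop_gradient(z_q - z_e).
    """
    # Squared distance from every z_e row to every code vector: (n, K)
    d2 = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
    idx = d2.argmin(axis=1)
    return idx, codebook[idx]

codebook = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, -1.0]])
z_e = np.array([[0.9, 1.2], [0.1, -0.2]])
idx, z_q = quantise(z_e, codebook)   # idx: [1, 0]
```

The integer `idx` sequence is exactly what a downstream transformer models autoregressively.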
KL-AE (Stable Diffusion’s VAE) is barely a VAE at all. The KL weight is ~1e-6, so it’s essentially a perceptual autoencoder with a faint regularisation whisper to keep the latent space smooth enough for diffusion to work. The heavy lifting for generation is done by the diffusion model in latent space; the VAE is just the compressor.
The “HOW VAEs CONNECT TO OTHER ALGORITHMS” section at the bottom ties this file back to the others in the series — VAE → diffusion (Stable Diffusion pipeline), VQ-VAE → transformer (DALL-E 1), VAE → contrastive (generative vs. discriminative representations). These aren’t isolated algorithms; they compose.
Summary: What changes vs. what stays the same
Always the same (core loop)
- Encode x → latent distribution params
- Sample z (differentiably) (PLUGGABLE)
- Decode z → reconstruction
- Loss = reconstruction + regularisation (both PLUGGABLE)
- Gradient step
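The loss computation in that loop can be sketched end to end, assuming a toy linear encoder/decoder in NumPy (real variants are neural networks trained by autodiff; the gradient step is omitted):

```python
import numpy as np

rng = np.random.default_rng(1)

def vae_loss(x, W_enc, W_dec, beta=1.0):
    # 1. Encode x -> latent distribution params (mu; log_var fixed at 0 here)
    mu = x @ W_enc
    log_var = np.zeros_like(mu)
    # 2. Sample z differentiably via the reparameterisation trick
    z = mu + np.exp(0.5 * log_var) * rng.standard_normal(mu.shape)
    # 3. Decode z -> reconstruction
    x_hat = z @ W_dec
    # 4. Loss = reconstruction + regularisation (closed-form Gaussian KL)
    recon = ((x - x_hat) ** 2).mean()
    kl = 0.5 * (mu**2 + np.exp(log_var) - 1.0 - log_var).sum(axis=1).mean()
    return recon + beta * kl, recon, kl

x = rng.standard_normal((16, 8))             # batch of 16 eight-dim inputs
W_enc = rng.standard_normal((8, 2)) * 0.1    # 2-dim latent
W_dec = rng.standard_normal((2, 8)) * 0.1
loss, recon, kl = vae_loss(x, W_enc, W_dec)
```

Swapping variants means swapping step 2 (e.g. codebook lookup) or step 4 (β · KL, commitment loss, perceptual terms) while the loop shape stays fixed.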
What varies by variant
| Variant | Latent space | Regularisation | Reconstruction loss |
|---|---|---|---|
| VAE | Gaussian (μ, σ²) | KL to N(0,I) | MSE or BCE |
| β-VAE | Gaussian (μ, σ²) | β · KL | MSE or BCE |
| VQ-VAE | Discrete codes | Commitment loss | MSE |
| CVAE | Gaussian (μ, σ²) | KL to N(0,I) | MSE or BCE |
| KL-AE (Latent Diff.) | Gaussian (μ, σ²) | tiny KL (≈1e-6) | MSE + perceptual + adversarial |
Motives for each variant
| Variant | Problem Solved | Intuition for Solution |
|---|---|---|
| VAE | Plain autoencoders learn arbitrary, disconnected latent codes — you can’t sample or interpolate in latent space meaningfully | Force the encoder to output a DISTRIBUTION, not a point. The KL term keeps this distribution close to N(0,I), ensuring the latent space is smooth, connected, and sampleable. The reparameterisation trick makes it trainable |
| β-VAE | Vanilla VAE latent dims tend to entangle multiple factors of variation in the same dimensions, making the space hard to interpret or control | Increase KL weight (β > 1) to force a more factorised posterior — each dim is pushed harder toward independent N(0,1). This encourages disentanglement: one dim for size, one for colour, etc. Costs some reconstruction quality |
| VQ-VAE | Continuous latent spaces have “holes” — regions that decode to garbage — and don’t map naturally to discrete, token-based generation (autoregressive models) | Replace the Gaussian with a DISCRETE codebook: the encoder’s output is snapped to the nearest code vector. Every point in the codebook decodes to something valid. The discrete codes can then be modelled by a transformer for generation (DALL-E 1, AudioLM) |
| CVAE | Vanilla VAE generates random samples from p(z) — no control over WHAT is generated | Give both encoder and decoder access to a condition c (class label, text, …). The latent z now captures only the variation NOT explained by c. At generation: choose c, sample z, decode |
| KL-AE (Latent Diff.) | Pixel-space diffusion at high resolution (512×512) is prohibitively expensive — attention is O(T²) on pixels | Train a VAE to compress images to a small spatial latent (64×64). Use near-zero KL weight so the space is almost autoencoder-quality (sharp), with perceptual + adversarial loss for high fidelity. Diffusion then runs in this 64× cheaper latent space |
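The CVAE row’s “give both encoder and decoder access to c” amounts to concatenation. A sketch with hypothetical shapes (8-dim x, one-hot c over 3 classes, 2-dim z; linear maps stand in for networks):

```python
import numpy as np

rng = np.random.default_rng(2)

def cvae_encode(x, c, W_enc):
    """Encoder sees [x, c]; z only has to model what c doesn't explain."""
    return np.concatenate([x, c]) @ W_enc        # -> mu (log_var omitted)

def cvae_decode(z, c, W_dec):
    """Decoder also sees c, so generation is controlled by choosing c."""
    return np.concatenate([z, c]) @ W_dec

W_enc = rng.standard_normal((8 + 3, 2))
W_dec = rng.standard_normal((2 + 3, 8))
c = np.eye(3)[1]                   # condition: class 1
z = rng.standard_normal(2)         # sample z from the prior N(0, I)
x_gen = cvae_decode(z, c, W_dec)   # generation: choose c, sample z, decode
```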
The reconstruction-regularisation tradeoff
The two loss terms pull in opposite directions:
- Reconstruction wants the encoder to pack as much information as possible into z, using every bit of the latent space.
- KL regularisation wants the encoder to make q(z|x) ≈ N(0,I), which means IGNORING information and producing the same z regardless of input.
This tension is the central design knob of every VAE variant:
- β-VAE: crank regularisation → disentangled but blurry
- KL-AE: crank reconstruction → sharp but less structured
- VQ-VAE: sidestep entirely — discrete codes impose structure without KL, and commitment loss is a gentler leash
- CVAE: offload information to the condition, so z can be simpler while the decoder still reconstructs well
In the extreme: reg=0 gives a plain autoencoder (no generation); reg=∞ makes the encoder ignore the input entirely (posterior collapse — z becomes uninformative).
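The pull of the KL term can be made concrete with its closed-form value against N(0,1) (a sketch; diagonal Gaussian posterior assumed, computed per latent dimension):

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ) for one latent dim."""
    return 0.5 * (mu**2 + np.exp(log_var) - 1.0 - log_var)

# An encoder that ignores its input (always mu=0, sigma=1) pays zero KL —
# the posterior-collapse extreme:
kl_collapsed = kl_to_standard_normal(0.0, 0.0)
# Packing information into z (confident, input-dependent codes: large |mu|,
# small sigma) is what the KL term charges for:
kl_informative = kl_to_standard_normal(2.0, np.log(0.01))
```

Reconstruction pushes toward the second regime, KL toward the first; β just sets the exchange rate between them.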
How VAEs connect to other algorithms in this series
- VAE → Diffusion: Stable Diffusion’s “VAE” is a KL-AE that compresses images to latents. Diffusion (DDPM/DDIM) then runs in that latent space. The two are chained: VAE encodes, diffusion models the latent distribution, VAE decodes.
- VAE → Contrastive: Both learn representations. VAEs optimise reconstruction (generative). Contrastive methods optimise similarity structure (discriminative). VAE latents capture everything needed to reconstruct; contrastive embeddings capture what’s needed to distinguish.
- VQ-VAE → Transformer: VQ-VAE turns images into discrete tokens. A transformer then models the token sequence autoregressively — the exact same architecture used for language, now generating images token by token (DALL-E 1).