
Unified Variational Inference / VAE

This one has a nice structural symmetry with the other files, but the central tension is different. Where Q-learning variants swap out the target, policy gradient variants swap out the advantage, and contrastive variants swap out the collapse-prevention mechanism, VAE variants all navigate the same single tradeoff: reconstruction vs. regularisation.

The key insights this file tries to make clear:

The reparameterisation trick is the entire reason VAEs work. Without it, you can’t backprop through sampling, and you’re stuck with high-variance REINFORCE estimators. The trick — externalize the randomness as ε ~ N(0,I) and write z = μ + σ·ε — is one line of code but arguably the most important idea in the file.
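As a minimal sketch (plain NumPy, names illustrative), the trick really is one line:

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterise(mu, log_var, rng):
    """Sample z ~ N(mu, sigma^2) with the randomness externalised:
    eps carries all the noise, so gradients flow through mu and sigma."""
    sigma = np.exp(0.5 * log_var)        # encoders usually output log-variance for stability
    eps = rng.standard_normal(mu.shape)  # eps ~ N(0, I), independent of any parameters
    return mu + sigma * eps              # the one line

mu, log_var = np.zeros(4), np.zeros(4)   # toy encoder outputs (sigma = 1)
z = reparameterise(mu, log_var, rng)
# dz/dmu = 1 and dz/dsigma = eps are well-defined pathwise derivatives --
# exactly what sampling z directly from N(mu, sigma^2) cannot give you.
```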

VQ-VAE is the most radical departure. It replaces the entire Gaussian machinery (reparameterisation, KL divergence) with a codebook lookup + straight-through estimator. This is a fundamentally different kind of bottleneck — discrete rather than continuous — and it’s what enabled treating images as token sequences for autoregressive generation.
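A sketch of the codebook lookup (NumPy, hypothetical sizes; the straight-through part is noted in comments because it needs an autograd framework to mean anything):

```python
import numpy as np

rng = np.random.default_rng(1)
codebook = rng.standard_normal((8, 2))   # 8 code vectors of dim 2 (toy sizes)

def quantise(z_e, codebook):
    """Snap each encoder output to its nearest codebook entry."""
    dists = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # (N, K)
    idx = dists.argmin(axis=1)
    return codebook[idx], idx

z_e = rng.standard_normal((3, 2))        # stand-in for encoder outputs
z_q, idx = quantise(z_e, codebook)

# Straight-through estimator, in an autograd framework:
#   z_q = z_e + stop_gradient(z_q - z_e)
# Forward pass sees the quantised z_q; backward pass copies the decoder's
# gradient from z_q onto z_e as if the lookup were the identity.
```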

KL-AE (Stable Diffusion’s VAE) is barely a VAE at all. The KL weight is ~1e-6, so it’s essentially a perceptual autoencoder with a faint regularisation whisper to keep the latent space smooth enough for diffusion to work. The heavy lifting for generation is done by the diffusion model in latent space; the VAE is just the compressor.

The “HOW VAEs CONNECT TO OTHER ALGORITHMS” section at the bottom ties this file back to the others in the series — VAE → diffusion (Stable Diffusion pipeline), VQ-VAE → transformer (DALL-E 1), VAE → contrastive (generative vs. discriminative representations). These aren’t isolated algorithms; they compose.

Summary: What changes vs. what stays the same

  • Encode x → latent distribution params
  • Sample z (differentiably) (PLUGGABLE)
  • Decode z → reconstruction
  • Loss = reconstruction + regularisation (both PLUGGABLE)
  • Gradient step
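The loop above, with illustrative stand-in functions for the pluggable pieces (everything here is a toy, not any variant's real architecture):

```python
import numpy as np

rng = np.random.default_rng(0)

def encode(x):                     # x -> latent distribution params
    return x.mean(keepdims=True), np.zeros(1)              # (mu, log_var)

def sample(mu, log_var):           # PLUGGABLE: Gaussian reparameterisation here
    return mu + np.exp(0.5 * log_var) * rng.standard_normal(mu.shape)

def decode(z):                     # z -> reconstruction
    return np.repeat(z, 4)

def loss(x, x_hat, mu, log_var, beta=1.0):  # both terms PLUGGABLE
    recon = ((x - x_hat) ** 2).mean()                          # MSE reconstruction
    kl = -0.5 * (1 + log_var - mu**2 - np.exp(log_var)).sum()  # KL to N(0, I)
    return recon + beta * kl

x = np.ones(4)
mu, log_var = encode(x)
z = sample(mu, log_var)
x_hat = decode(z)
total = loss(x, x_hat, mu, log_var)
# A real implementation would follow with a gradient step on encoder/decoder params.
```

Swapping `sample` for a codebook lookup gives VQ-VAE; swapping `beta` gives β-VAE; threading a condition through `encode`/`decode` gives CVAE.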
| Variant | Latent space | Regularisation | Reconstruction loss |
|---|---|---|---|
| VAE | Gaussian (μ, σ²) | KL to N(0,I) | MSE or BCE |
| β-VAE | Gaussian (μ, σ²) | β · KL | MSE or BCE |
| VQ-VAE | Discrete codes | Commitment loss | MSE |
| CVAE | Gaussian (μ, σ²) | KL to N(0,I) | MSE or BCE |
| KL-AE (Latent Diff.) | Gaussian (μ, σ²) | tiny KL (≈1e-6) | MSE + perceptual + adversarial |
| Variant | Problem Solved | Intuition for Solution |
|---|---|---|
| VAE | Plain autoencoders learn arbitrary, disconnected latent codes — you can't sample or interpolate in latent space meaningfully | Force the encoder to output a DISTRIBUTION, not a point. The KL term keeps this distribution close to N(0,I), ensuring the latent space is smooth, connected, and sampleable. The reparameterisation trick makes it trainable |
| β-VAE | Vanilla VAE latent dims tend to entangle multiple factors of variation in the same dimensions, making the space hard to interpret or control | Increase KL weight (β > 1) to force a more factorised posterior — each dim is pushed harder toward independent N(0,1). This encourages disentanglement: one dim for size, one for colour, etc. Costs some reconstruction quality |
| VQ-VAE | Continuous latent spaces have "holes" — regions that decode to garbage — and don't map naturally to discrete, token-based generation (autoregressive models) | Replace the Gaussian with a DISCRETE codebook: the encoder's output is snapped to the nearest code vector. Every point in the codebook decodes to something valid. The discrete codes can then be modelled by a transformer for generation (DALL-E 1, AudioLM) |
| CVAE | Vanilla VAE generates random samples from p(z) — no control over WHAT is generated | Give both encoder and decoder access to a condition c (class label, text, …). The latent z now captures only the variation NOT explained by c. At generation: choose c, sample z, decode |
| KL-AE (Latent Diff.) | Pixel-space diffusion at high resolution (512×512) is prohibitively expensive — attention is O(T²) on pixels | Train a VAE to compress images to a small spatial latent (64×64). Use near-zero KL weight so the space is almost autoencoder-quality (sharp), with perceptual + adversarial loss for high fidelity. Diffusion then runs in this 64× cheaper latent space |
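The CVAE row's "choose c, sample z, decode" recipe can be sketched with a hypothetical stand-in decoder (the network itself is just a placeholder here):

```python
import numpy as np

rng = np.random.default_rng(0)

def decode(z, c):
    """Hypothetical CVAE decoder stand-in: the condition c enters alongside z,
    so z only has to model the variation c doesn't already explain."""
    return np.concatenate([z, c]).sum() * np.ones(4)   # placeholder for a network

c = np.eye(3)[1]              # one-hot condition: "generate class 1"
z = rng.standard_normal(2)    # sample from the prior p(z) = N(0, I)
x_gen = decode(z, c)          # controlled generation: class fixed by c, style by z
```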

The reconstruction-regularisation tradeoff


The two loss terms pull in opposite directions:

  • Reconstruction wants the encoder to pack as much information as possible into z, using every bit of the latent space.
  • KL regularisation wants the encoder to make q(z|x) ≈ N(0,I), which means IGNORING information and producing the same z regardless of input.
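The closed-form KL for a diagonal Gaussian makes the pull concrete: any input-dependent mean or non-unit variance pays a price, and the zero-cost posterior is exactly N(0,I) regardless of x. A small sketch (illustrative values):

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over dims."""
    return -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var))

# Encoding information (input-dependent mean, tight variance) costs KL:
informative = kl_to_standard_normal(np.array([2.0]), np.log(np.array([0.1])))
# Ignoring the input and emitting exactly N(0, 1) costs nothing:
ignoring = kl_to_standard_normal(np.array([0.0]), np.array([0.0]))
# informative is strictly positive; ignoring is zero.
```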

This tension is the central design knob of every VAE variant:

  • β-VAE: crank regularisation → disentangled but blurry
  • KL-AE: crank reconstruction → sharp but less structured
  • VQ-VAE: sidestep entirely — discrete codes impose structure without KL, and commitment loss is a gentler leash
  • CVAE: offload information to the condition, so z can be simpler while the decoder still reconstructs well

In the extreme: reg = 0 is a plain autoencoder (no generation); reg = ∞ ignores the input entirely (posterior collapse — z becomes uninformative).

How VAEs connect to other algorithms in this series

  • VAE → Diffusion: Stable Diffusion’s “VAE” is a KL-AE that compresses images to latents. Diffusion (DDPM/DDIM) then runs in that latent space. The two are chained: VAE encodes, diffusion models the latent distribution, VAE decodes.
  • VAE → Contrastive: Both learn representations. VAEs optimise reconstruction (generative). Contrastive methods optimise similarity structure (discriminative). VAE latents capture everything needed to reconstruct; contrastive embeddings capture what’s needed to distinguish.
  • VQ-VAE → Transformer: VQ-VAE turns images into discrete tokens. A transformer then models the token sequence autoregressively — the exact same architecture used for language, now generating images token by token (DALL-E 1).