Gumbel-Softmax
A continuous, differentiable relaxation of categorical sampling. Instead of drawing a hard one-hot sample from a categorical distribution (which blocks gradients), Gumbel-softmax adds Gumbel noise to the logits and applies a temperature-controlled softmax to produce a soft approximation. Also called the concrete distribution. Used in discrete VAEs (DALL-E), learned routing, and architecture search.
Intuition
Categorical sampling is fundamentally non-differentiable — you pick one category and ignore the rest. The Gumbel-softmax trick works in two stages. First, the Gumbel-max trick: adding Gumbel-distributed noise to logits and taking argmax is mathematically equivalent to sampling from the categorical distribution. This is useful on its own (it converts sampling into an optimisation problem), but argmax is still non-differentiable.
Second, replace argmax with softmax at low temperature. As the temperature τ → 0, softmax approaches argmax — the output becomes a one-hot vector. As τ → ∞, it becomes uniform — every category gets equal weight. At intermediate temperatures, you get a soft approximation that’s differentiable everywhere. Gradients flow through the softmax, through the logits, and back to the parameters that produced them.
The tradeoff is bias vs. variance vs. discreteness. At high temperature, gradients are smooth and low-variance, but the “sample” is far from one-hot — your downstream network sees a blurry mixture instead of a crisp selection. At low temperature, the sample is nearly discrete but gradients become sharp and noisy. In practice, you anneal temperature during training: start warm for stable gradients, cool down for crisper samples.
Gumbel-max trick — exact categorical sampling via argmax:

$$z = \operatorname*{arg\,max}_i \left(\ell_i + g_i\right)$$

where $\ell_i$ are the logits, and Gumbel noise is generated as $g_i = -\log(-\log u_i)$, $u_i \sim \mathrm{Uniform}(0, 1)$.
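To see the equivalence concretely, here is a small illustrative NumPy check (not from the original text; the logits are arbitrary): sample many times via Gumbel-max and compare the empirical frequencies against softmax of the logits.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([1.0, 0.0, -1.0, 2.0])  # arbitrary unnormalised log-probs
K = logits.shape[0]

# Target categorical distribution: softmax(logits)
probs = np.exp(logits - logits.max())
probs /= probs.sum()

# Gumbel-max sampling: argmax(logits + g), g = -log(-log(u))
n = 200_000
u = rng.uniform(1e-12, 1.0, size=(n, K))
g = -np.log(-np.log(u))
samples = np.argmax(logits + g, axis=1)  # (n,) exact categorical samples

freqs = np.bincount(samples, minlength=K) / n
print(np.round(probs, 3))
print(np.round(freqs, 3))  # empirical frequencies should closely match probs
```

No softmax is ever evaluated during sampling — the categorical draw is performed entirely by perturb-and-argmax.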
Gumbel-softmax relaxation — replace argmax with softmax at temperature $\tau$:

$$y_i = \frac{\exp\!\left((\ell_i + g_i)/\tau\right)}{\sum_{j=1}^{K} \exp\!\left((\ell_j + g_j)/\tau\right)}$$

The output $y$ is a probability vector on the $(K-1)$-simplex. As $\tau \to 0$, $y$ converges to a one-hot vector; as $\tau \to \infty$, $y$ converges to the uniform distribution $(1/K, \ldots, 1/K)$.
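The two limits are easy to verify numerically. The sketch below (illustrative values; not from the original) fixes one Gumbel noise draw and evaluates the relaxation at a cold and a warm temperature: the cold sample is sharp, the warm sample is near-uniform, and the argmax never changes with τ.

```python
import numpy as np

def gumbel_softmax_sample(logits, g, tau):
    """Temperature-scaled softmax of the Gumbel-perturbed logits."""
    z = (logits + g) / tau
    z = z - z.max()            # numerical stability
    y = np.exp(z)
    return y / y.sum()

rng = np.random.default_rng(42)
logits = np.array([2.0, 0.5, -1.0, 0.0])
u = rng.uniform(1e-12, 1.0, size=4)
g = -np.log(-np.log(u))        # one fixed Gumbel draw, shared across taus

y_cold = gumbel_softmax_sample(logits, g, tau=0.05)  # sharp, near one-hot
y_warm = gumbel_softmax_sample(logits, g, tau=50.0)  # flat, near uniform

print(np.round(y_cold, 3))
print(np.round(y_warm, 3))
```

Because τ only rescales the perturbed logits, it controls sharpness but never which category wins — the argmax of the sample is the argmax of `logits + g` at every temperature.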
Straight-through variant — for when downstream code needs a hard one-hot:

$$y_{\text{ST}} = \mathrm{onehot}\!\left(\operatorname*{arg\,max}_i y_i\right) - \mathrm{sg}(y) + y$$

Forward: $y_{\text{ST}}$ uses the hard one-hot (the $-\,\mathrm{sg}(y) + y$ terms cancel in value). Backward: $\mathrm{sg}$ (stop-gradient) blocks its argument, so gradients flow through $y$.
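The straight-through composition can be sketched in NumPy (illustrative; `sg` stands in for an autograd stop-gradient such as PyTorch's `detach`, which NumPy does not have):

```python
import numpy as np

rng = np.random.default_rng(1)
logits = np.array([[0.5, 1.5, -0.3]])
u = rng.uniform(1e-12, 1.0, size=logits.shape)
g = -np.log(-np.log(u))

# Soft Gumbel-softmax sample at tau = 1
z = (logits + g) - (logits + g).max(axis=1, keepdims=True)
y_soft = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)

# Hard one-hot from the soft sample's argmax
idx = np.argmax(y_soft, axis=1)
y_hard = np.zeros_like(y_soft)
y_hard[np.arange(len(idx)), idx] = 1.0

# Straight-through composition: the value equals the hard one-hot,
# but in an autograd framework sg(y_soft) is a constant, so the
# gradient of y_st w.r.t. the logits is the gradient of y_soft.
sg = lambda a: a.copy()   # stand-in for stop-gradient / detach
y_st = y_hard - sg(y_soft) + y_soft

print(y_st)  # numerically equal to y_hard (up to float rounding)
```

The subtraction-and-add looks like a no-op, and in value it is; its only purpose is to route the backward pass through the soft term.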
```python
import numpy as np
import torch
import torch.nn.functional as F

# ── Gumbel-softmax sampling ─────────────────────────────────────
logits = model(x)  # (B, K) unnormalised log-probs

# PyTorch has a built-in:
y_soft = F.gumbel_softmax(logits, tau=1.0, hard=False)  # (B, K) soft sample
y_hard = F.gumbel_softmax(logits, tau=0.5, hard=True)   # (B, K) hard one-hot
# hard=True uses the straight-through trick internally:
# forward is one-hot, backward goes through the soft version

# ── Typical usage: select from an embedding table ───────────────
embeddings = embedding_table.weight  # (K, D)
# Soft selection (differentiable weighted average):
z = y_soft @ embeddings  # (B, D)
# Hard selection (one-hot, but gradients still flow via ST):
z = y_hard @ embeddings  # (B, D)

# ── Temperature annealing (in training loop) ────────────────────
# Start high (τ=1.0) for smooth gradients, anneal to low (τ=0.1)
tau = max(0.1, 1.0 * np.exp(-anneal_rate * step))
y = F.gumbel_softmax(logits, tau=tau, hard=True)

# WARNING: tau=0 is undefined (division by zero). Keep tau >= 0.1.
# WARNING: With hard=False, downstream layers see a soft mixture,
# which may behave differently than a true discrete selection at test time.
```

Manual Implementation
```python
import numpy as np

def gumbel_softmax(logits, tau=1.0, hard=False):
    """
    Gumbel-softmax sampling.

    logits: (B, K) unnormalised log-probabilities
    tau:    temperature (> 0, lower = more discrete)
    hard:   if True, return one-hot with straight-through gradient

    Returns: (B, K) soft sample on the simplex (or hard one-hot)
    """
    B, K = logits.shape

    # Sample Gumbel(0, 1) noise: g = -log(-log(u)), u ~ Uniform(0, 1)
    u = np.random.uniform(low=1e-8, high=1.0 - 1e-8, size=(B, K))  # (B, K)
    g = -np.log(-np.log(u))                                        # (B, K)

    # Add noise to logits and apply temperature-scaled softmax
    noisy_logits = (logits + g) / tau  # (B, K)

    # Numerically stable softmax
    shifted = noisy_logits - noisy_logits.max(axis=1, keepdims=True)  # (B, K)
    exp_shifted = np.exp(shifted)                                     # (B, K)
    y_soft = exp_shifted / exp_shifted.sum(axis=1, keepdims=True)     # (B, K)

    if hard:
        # One-hot from argmax (forward value)
        indices = np.argmax(y_soft, axis=1)  # (B,)
        y_hard = np.zeros_like(y_soft)       # (B, K)
        y_hard[np.arange(B), indices] = 1.0
        # In autograd: y_hard - y_soft.detach() + y_soft (ST trick)
        return y_hard  # numpy has no autograd, so just return hard
    return y_soft

def gumbel_noise(shape):
    """Sample Gumbel(0, 1) noise."""
    u = np.random.uniform(1e-8, 1.0 - 1e-8, size=shape)
    return -np.log(-np.log(u))
```

Popular Uses
- DALL-E 1 (dVAE): uses Gumbel-softmax to train a discrete VAE over image tokens, enabling autoregressive generation with a transformer
- Discrete latent variable models: any model with categorical latent variables that needs gradient-based training (discrete VAE, VQ without commitment loss)
- Neural architecture search (DARTS): soft selection over candidate operations during search, hardened at evaluation time
- Learned routing / mixture of experts (Switch Transformer): differentiable expert selection (though most MoE systems use top-k with load balancing instead)
- Hard attention: differentiable approximation to discrete attention selection
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| Straight-through estimator | Simple quantisation (VQ-VAE) | No temperature tuning needed, simpler to implement, but more biased — no smooth annealing path |
| REINFORCE / score function | Truly discrete forward pass needed, unbiased gradients | Unbiased but very high variance; needs baselines and many samples to be practical |
| Reparameterisation trick | Continuous latent variables | Exact gradients, no bias — but only for continuous distributions, not categorical |
| Top-k with STE | Sparse selection (mixture of experts) | Selects k experts with STE gradients; simpler than Gumbel for routing but no probabilistic interpretation |
| Categorical sampling + REINFORCE | Exact discrete samples needed for evaluation | True samples (no relaxation bias) but gradient estimation is much noisier |
Historical Context
The Gumbel-softmax trick was introduced simultaneously by Jang, Gu & Poole (2017, “Categorical Reparameterization with Gumbel-Softmax”) and Maddison, Mnih & Teh (2017, “The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables”). Both papers built on the classical Gumbel-max trick from extreme value theory (Gumbel, 1954) and showed that replacing argmax with softmax created a practical reparameterisation for categorical variables.
The technique saw its highest-profile use in DALL-E 1 (Ramesh et al., 2021), which trained a discrete VAE with Gumbel-softmax relaxation to tokenise images into a grid of discrete codes. Later discrete representation methods (VQ-GAN, SoundStream) often preferred VQ with straight-through estimation for its simplicity, but Gumbel-softmax remains the go-to when you need a principled probabilistic relaxation of categorical choices.