
Gumbel-Softmax

A continuous, differentiable relaxation of categorical sampling. Instead of drawing a hard one-hot sample from a categorical distribution (which blocks gradients), Gumbel-softmax adds Gumbel noise to the logits and applies a temperature-controlled softmax to produce a soft approximation. Also called the concrete distribution. Used in discrete VAEs (DALL-E), learned routing, and architecture search.

Categorical sampling is fundamentally non-differentiable — you pick one category and ignore the rest. The Gumbel-softmax trick works in two stages. First, the Gumbel-max trick: adding Gumbel-distributed noise to logits and taking argmax is mathematically equivalent to sampling from the categorical distribution. This is useful on its own (it converts sampling into an optimisation), but argmax is still non-differentiable.

Second, replace argmax with softmax at low temperature. As temperature τ → 0, softmax approaches argmax — the output becomes a one-hot vector. As τ → ∞, it becomes uniform — every category gets equal weight. At intermediate temperatures, you get a soft approximation that’s differentiable everywhere. Gradients flow through the softmax, through the logits, and back to the parameters that produced them.

The tradeoff is bias vs. variance vs. discreteness. At high temperature, gradients are smooth and low-variance, but the “sample” is far from one-hot — your downstream network sees a blurry mixture instead of a crisp selection. At low temperature, the sample is nearly discrete but gradients become sharp and noisy. In practice, you anneal temperature during training: start warm for stable gradients, cool down for crisper samples.
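One common choice of schedule is exponential decay clamped at a floor. A minimal sketch, where tau_0, tau_min, and rate are hypothetical hyperparameters you would tune for your model:

```python
import numpy as np

tau_0, tau_min, rate = 1.0, 0.1, 1e-4    # hypothetical hyperparameters

def temperature(step):
    """Exponential anneal with a floor: warm early for stable gradients, cool later."""
    return max(tau_min, tau_0 * np.exp(-rate * step))

print([round(temperature(t), 3) for t in range(0, 50_001, 10_000)])
# -> [1.0, 0.368, 0.135, 0.1, 0.1, 0.1]
```

The floor matters: without it, τ would keep shrinking and gradients would blow up as the softmax sharpens toward argmax.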

Gumbel-max trick — exact categorical sampling via argmax:

k = argmax_i [ log π_i + g_i ],   where g_i ~ Gumbel(0, 1)

and Gumbel noise is generated as g_i = -log(-log(u_i)), with u_i ~ Uniform(0, 1).
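The Gumbel-max identity is easy to verify empirically: perturbing log-probabilities with Gumbel noise and taking argmax reproduces the categorical frequencies. A minimal NumPy sketch (the distribution π and the sample count are arbitrary choices for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([0.5, 0.3, 0.2])   # target categorical distribution (arbitrary example)
n = 200_000

# Gumbel(0, 1) noise via inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1)
u = rng.uniform(1e-12, 1.0, size=(n, 3))
g = -np.log(-np.log(u))

# Gumbel-max: argmax over perturbed log-probs is an exact categorical sample
samples = np.argmax(np.log(pi) + g, axis=1)
freqs = np.bincount(samples, minlength=3) / n
print(freqs)  # empirical frequencies close to [0.5, 0.3, 0.2]
```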

Gumbel-softmax relaxation — replace argmax with softmax at temperature τ\tau:

y_i = exp((log π_i + g_i) / τ) / Σ_j exp((log π_j + g_j) / τ),   for i = 1, …, K

The output y is a probability vector on the K-simplex. As τ → 0, y converges to a one-hot vector; as τ → ∞, y converges to the uniform vector (1/K, …, 1/K).
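Both limits can be checked numerically for a single fixed Gumbel draw (the logits and noise values below are arbitrary illustrative numbers):

```python
import numpy as np

logits = np.array([2.0, 0.5, -1.0])   # arbitrary example logits (log pi up to a constant)
g = np.array([0.3, -0.1, 1.2])        # one fixed Gumbel(0, 1) draw

def gumbel_softmax_sample(tau):
    z = (logits + g) / tau
    z = z - z.max()                   # numerically stable softmax
    e = np.exp(z)
    return e / e.sum()

y_cold = gumbel_softmax_sample(0.01)    # tau -> 0: essentially one-hot
y_hot = gumbel_softmax_sample(1000.0)   # tau -> inf: essentially uniform
print(y_cold.round(4))  # ~[1, 0, 0]
print(y_hot.round(4))   # each entry ~1/3
```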

Straight-through variant — for when downstream code needs a hard one-hot:

y_hard = one_hot(argmax(y_soft))
y_ST = y_hard - y_soft.detach() + y_soft

Forward: uses the hard one-hot. Backward: gradients flow through y_soft.

import torch
import torch.nn.functional as F
import numpy as np

# ── Gumbel-softmax sampling ─────────────────────────────────────
logits = model(x)  # (B, K) unnormalised log-probs
# PyTorch has a built-in:
y_soft = F.gumbel_softmax(logits, tau=1.0, hard=False)  # (B, K) soft sample
y_hard = F.gumbel_softmax(logits, tau=0.5, hard=True)   # (B, K) hard one-hot
# hard=True uses the straight-through trick internally:
# forward is one-hot, backward goes through the soft version

# ── Typical usage: select from an embedding table ───────────────
embeddings = embedding_table.weight  # (K, D)
# Soft selection (differentiable weighted average):
z = y_soft @ embeddings  # (B, D)
# Hard selection (one-hot, but gradients still flow via ST):
z = y_hard @ embeddings  # (B, D)

# ── Temperature annealing (in training loop) ────────────────────
# Start high (tau=1.0) for smooth gradients, anneal to low (tau=0.1);
# anneal_rate and step come from the training loop
tau = max(0.1, 1.0 * np.exp(-anneal_rate * step))
y = F.gumbel_softmax(logits, tau=tau, hard=True)
# WARNING: tau=0 is undefined (division by zero). Keep tau >= 0.1.
# WARNING: with hard=False, downstream layers see a soft mixture,
# which may behave differently from a true discrete selection at test time.
import numpy as np

def gumbel_softmax(logits, tau=1.0, hard=False):
    """
    Gumbel-softmax sampling.

    logits: (B, K) unnormalised log-probabilities
    tau:    temperature (> 0, lower = more discrete)
    hard:   if True, return one-hot with straight-through gradient
    Returns: (B, K) soft sample on the simplex (or hard one-hot)
    """
    B, K = logits.shape
    # Sample Gumbel(0, 1) noise: g = -log(-log(u)), u ~ Uniform(0, 1)
    u = np.random.uniform(low=1e-8, high=1.0 - 1e-8, size=(B, K))  # (B, K)
    g = -np.log(-np.log(u))                                        # (B, K)
    # Add noise to logits and apply temperature-scaled softmax
    noisy_logits = (logits + g) / tau                               # (B, K)
    # Numerically stable softmax
    shifted = noisy_logits - noisy_logits.max(axis=1, keepdims=True)
    exp_shifted = np.exp(shifted)
    y_soft = exp_shifted / exp_shifted.sum(axis=1, keepdims=True)   # (B, K)
    if hard:
        # One-hot from argmax (forward value)
        indices = np.argmax(y_soft, axis=1)  # (B,)
        y_hard = np.zeros_like(y_soft)
        y_hard[np.arange(B), indices] = 1.0
        # In autograd: y_hard - y_soft.detach() + y_soft (ST trick)
        return y_hard  # numpy has no autograd, so just return the hard sample
    return y_soft

def gumbel_noise(shape):
    """Sample Gumbel(0, 1) noise."""
    u = np.random.uniform(1e-8, 1.0 - 1e-8, size=shape)
    return -np.log(-np.log(u))
  • DALL-E 1 (dVAE): uses Gumbel-softmax to train a discrete VAE over image tokens, enabling autoregressive generation with a transformer
  • Discrete latent variable models: any model with categorical latent variables that needs gradient-based training (discrete VAE, VQ without commitment loss)
  • Neural architecture search (DARTS): soft selection over candidate operations during search, hardened at evaluation time
  • Learned routing / mixture of experts (Switch Transformer): differentiable expert selection (though most MoE systems use top-k with load balancing instead)
  • Hard attention: differentiable approximation to discrete attention selection
Alternatives — when to use them, and the tradeoff:
  • Straight-through estimator — use for simple quantisation (VQ-VAE). No temperature tuning needed and simpler to implement, but more biased: no smooth annealing path.
  • REINFORCE / score function — use when you need a truly discrete forward pass with unbiased gradients. Unbiased but very high variance; needs baselines and many samples to be practical.
  • Reparameterisation trick — use for continuous latent variables. Exact gradients, no bias, but only for continuous distributions, not categorical.
  • Top-k with STE — use for sparse selection (mixture of experts). Selects k experts with STE gradients; simpler than Gumbel for routing but no probabilistic interpretation.
  • Categorical sampling + REINFORCE — use when exact discrete samples are needed for evaluation. True samples (no relaxation bias) but gradient estimation is much noisier.

The Gumbel-softmax trick was introduced simultaneously by Jang, Gu & Poole (2017, “Categorical Reparameterization with Gumbel-Softmax”) and Maddison, Mnih & Teh (2017, “The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables”). Both papers built on the classical Gumbel-max trick from extreme value theory (Gumbel, 1954) and showed that replacing argmax with softmax created a practical reparameterisation for categorical variables.

The technique saw its highest-profile use in DALL-E 1 (Ramesh et al., 2021), which trained a discrete VAE with Gumbel-softmax relaxation to tokenise images into a grid of discrete codes. Later discrete representation methods (VQ-GAN, SoundStream) often preferred VQ with straight-through estimation for its simplicity, but Gumbel-softmax remains the go-to when you need a principled probabilistic relaxation of categorical choices.