Straight-Through Estimator
Approximates the gradient of a non-differentiable discrete operation (like argmax, rounding, or quantisation) by pretending it was the identity function in the backward pass. The forward pass does the hard discrete operation; the backward pass passes gradients straight through as if nothing happened. Key to VQ-VAE and binary neural networks.
Intuition
Imagine a staircase function: input goes up smoothly, output jumps between flat steps. The true gradient is zero almost everywhere (on the flat parts) and undefined at the jumps — completely useless for learning. The straight-through estimator says: “for the backward pass, pretend the staircase was a ramp.” The gradient of a ramp is 1 everywhere, so gradients flow through unchanged.
This is clearly wrong — the “gradient” you compute doesn’t match the actual function. But it’s wrong in a useful way. It tells the upstream network: “if you increased your output a little, the loss would change by this much, assuming the discrete operation didn’t interfere.” In practice, the parameters adjust to put the continuous values closer to the discrete values they’ll snap to, and training works surprisingly well.
The key insight: a biased gradient that points roughly in the right direction is far more useful than the true gradient of zero. VQ-VAE (van den Oord et al., 2017) relies entirely on this trick — the encoder outputs a continuous vector, it gets snapped to the nearest codebook entry, and gradients flow back through the snap as if it never happened.
Forward pass — apply a non-differentiable function $q$:

$$y = q(x)$$

True gradient — zero almost everywhere:

$$\frac{\partial y}{\partial x} = 0 \;\;\text{a.e.}, \quad \text{undefined at the jumps}$$

STE approximation — replace the backward pass with identity:

$$\frac{\partial L}{\partial x} := \frac{\partial L}{\partial y} \quad \left(\text{i.e. pretend } \frac{\partial y}{\partial x} = 1\right)$$

VQ-VAE application — $z_e$ is the encoder output, $z_q$ is the nearest codebook vector:

$$z_q = e_k, \quad k = \arg\min_j \|z_e - e_j\|_2, \qquad \nabla_{z_e} L := \nabla_{z_q} L$$
The codebook vectors are updated separately, not through the STE: either via an exponential moving average (EMA) of the encoder outputs assigned to each code, or via a dedicated codebook loss ‖sg[z_e] − e‖². A separate commitment loss keeps the encoder outputs close to the codes they snap to.
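A minimal NumPy sketch of such an EMA update (the function name, decay value, and Laplace-smoothing constant are illustrative, loosely following the VQ-VAE EMA formulation):

```python
import numpy as np

def ema_codebook_update(codebook, cluster_size, ema_sum, z_e, indices,
                        decay=0.99, eps=1e-5):
    """One EMA codebook update step (no gradients through the codebook).
    codebook:     (K, D) current codebook
    cluster_size: (K,)   EMA count of assignments per code
    ema_sum:      (K, D) EMA sum of encoder outputs assigned to each code
    z_e:          (B, D) encoder outputs from this batch
    indices:      (B,)   nearest-code index for each row of z_e
    """
    K = codebook.shape[0]
    one_hot = np.eye(K)[indices]          # (B, K) hard assignments
    batch_size = one_hot.sum(axis=0)      # (K,)   how often each code was hit
    batch_sum = one_hot.T @ z_e           # (K, D) sum of assigned encoder outputs

    cluster_size = decay * cluster_size + (1 - decay) * batch_size
    ema_sum = decay * ema_sum + (1 - decay) * batch_sum

    # Laplace smoothing avoids division by zero for unused codes
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + K * eps) * n
    codebook = ema_sum / smoothed[:, None]
    return codebook, cluster_size, ema_sum
```

Each code drifts toward the running mean of the encoder outputs assigned to it; codes that are never selected stay roughly where they are.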
```python
import torch

# ── The core STE pattern in PyTorch ─────────────────────────────
# The trick: z_q = z_e + (quantised - z_e).detach()
# Forward:  z_q = quantised (because z_e - z_e = 0, then + quantised)
# Backward: ∂z_q/∂z_e = 1 (because quantised.detach() is a constant)

z_e = encoder(x)                               # (B, D) continuous
distances = torch.cdist(z_e, codebook.weight)  # (B, K)
indices = distances.argmin(dim=-1)             # (B,) nearest codes
z_q = codebook.weight[indices]                 # (B, D) discrete

# STE: copy gradients from z_q to z_e
z_q_st = z_e + (z_q - z_e).detach()            # (B, D)
# z_q_st has the VALUE of z_q but the GRADIENT path of z_e

reconstruction = decoder(z_q_st)               # gradients flow to encoder

# ── For simple rounding / binarisation ──────────────────────────
x_hard = torch.round(x_soft)                   # no gradient
x_st = x_soft + (x_hard - x_soft).detach()     # STE version

# WARNING: the STE gradient is biased. If the encoder and codebook
# drift apart, the approximation degrades. VQ-VAE uses a commitment
# loss (β‖z_e - z_q.detach()‖²) to keep them close.
```
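The staircase-versus-ramp intuition can be checked in a few lines of NumPy: a toy scalar example (the target value and learning rate are arbitrary) where the true gradient is zero almost everywhere but the STE gradient still climbs to the right step:

```python
import numpy as np

# Toy problem: drive round(w) to a target of 3 using the STE.
# Loss L = (round(w) - 3)^2. The true dL/dw is 0 almost everywhere,
# so we substitute d round(w)/dw := 1 and use dL/dw ≈ 2*(round(w) - 3).
w, lr, target = 0.2, 0.1, 3.0
for _ in range(100):
    hard = np.round(w)            # non-differentiable forward pass
    grad = 2.0 * (hard - target)  # STE backward: identity through round()
    w -= lr * grad                # w climbs the "staircase" anyway
# After training, round(w) == 3 and the loss is zero.
```

Gradient descent with the true (zero) gradient would never move `w`; the biased STE gradient moves it until the discrete output matches the target.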
Manual Implementation

```python
import numpy as np

def straight_through_quantise(z_e, codebook):
    """
    VQ-VAE style quantisation with straight-through gradient.
    z_e:      (B, D) continuous encoder outputs
    codebook: (K, D) codebook vectors
    Returns:
        z_q     (B, D) quantised values (forward = discrete, backward = identity)
        indices (B,)   which codebook entry was selected
    """
    B, D = z_e.shape
    K = codebook.shape[0]

    # Find nearest codebook entry for each encoder output
    # ‖z_e - e_k‖² = ‖z_e‖² + ‖e_k‖² - 2·z_e·e_k^T
    dist = (
        np.sum(z_e ** 2, axis=1, keepdims=True)           # (B, 1)
        + np.sum(codebook ** 2, axis=1, keepdims=True).T  # (1, K)
        - 2 * z_e @ codebook.T                            # (B, K)
    )                                                     # (B, K)
    indices = np.argmin(dist, axis=1)                     # (B,)
    z_q = codebook[indices]                               # (B, D)

    # STE: in a real backward pass, we'd set ∂z_q/∂z_e = I.
    # In numpy (no autograd), this means: when computing upstream
    # gradients, treat z_q as if it were z_e.
    # grad_z_e = grad_z_q (copy gradient unchanged)

    return z_q, indices

def ste_backward(grad_output):
    """
    The STE backward pass is literally the identity.
    grad_output: (B, D) gradient flowing into the quantisation
    Returns:     (B, D) gradient flowing to the encoder — unchanged
    """
    return grad_output  # that's it. That's the whole trick.
```
Popular Uses

- VQ-VAE / VQ-VAE-2 (van den Oord et al.): quantise continuous encoder outputs to discrete codebook entries — the STE is the only way gradients reach the encoder (see variational-inference-vae/)
- Binary / ternary neural networks (BinaryConnect, XNOR-Net): weights are binarised to {-1, +1} in the forward pass; STE passes gradients to the full-precision latent weights
- Hard attention mechanisms: argmax selection of attention positions, with STE for gradient flow
- Neural discrete representation learning (dVAE in DALL-E 1): discrete tokens for image generation use STE or Gumbel-softmax
- Learned quantisation (neural compression): differentiable rounding for entropy coding
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| Gumbel-softmax | Categorical selection with smoother gradients | Lower bias than STE (approaches true gradient as τ→0) but requires a temperature schedule and is a soft approximation during training |
| Reparameterisation trick | Continuous latent variables | Exact gradients, zero bias — but only works for continuous distributions, not discrete operations |
| REINFORCE / score function | Any discrete operation, unbiased gradients needed | Unbiased but extremely high variance; impractical for high-dimensional discrete spaces like codebooks |
| EMA codebook update | Updating codebook vectors (used alongside STE) | Avoids backprop through codebook entirely; more stable than gradient-based codebook updates |
| Finite differences | Debugging, gradient checking | Unbiased but scales as O(D) per parameter — only useful for verification, never training |
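For contrast with the STE, here is a minimal NumPy sketch of a Gumbel-softmax sample (the temperature, seed, and logits are arbitrary); note that the forward pass is only softly discrete during training:

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_softmax(logits, tau=1.0):
    """Soft, differentiable sample from a categorical distribution.
    logits: (K,) unnormalised log-probabilities
    tau:    temperature; as tau → 0 the sample approaches a hard one-hot
    """
    # Gumbel(0, 1) noise via the inverse-CDF trick
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))
    y = (logits + g) / tau
    y = np.exp(y - y.max())        # numerically stable softmax
    return y / y.sum()             # (K,) soft one-hot

probs = gumbel_softmax(np.log(np.array([0.7, 0.2, 0.1])))
```

Unlike the STE, every entry of `probs` carries a smooth gradient with respect to the logits, at the cost of the output not being exactly discrete until the temperature is annealed toward zero.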
Historical Context
The straight-through estimator was introduced by Hinton in his 2012 Coursera lectures as a practical trick for training networks with discrete hidden units. Bengio, Léonard & Courville (2013, “Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation”) formalised and analysed it, showing that despite its bias, it worked well in practice for binary and discrete stochastic neurons.
Its most influential application came with VQ-VAE (van den Oord et al., 2017), which used STE to train a discrete autoencoder that produced high-quality codebook representations. This architecture became foundational — DALL-E 1 used a discrete VAE (dVAE) with Gumbel-softmax relaxation as an alternative, and later work on audio (SoundStream, Encodec) and language-image models relied on VQ codebooks with STE-based training.