# Stop-Gradient (Detach)
Blocks gradient flow through a value while keeping it in the forward pass. The forward computation proceeds normally, but during backpropagation the detached tensor is treated as a constant. This is how you say “use this value, but don’t learn through it.”
## Intuition

Consider a student trying to hit a moving target. If the target is allowed to move in response to the student’s attempts, the whole system can cheat — the target drifts toward where the student already is, and nobody actually improves. Stop-gradient pins the target in place so the student must genuinely learn to reach it.
This is exactly what happens in target networks (Q-learning, BYOL, MoCo). The target network provides a stable reference. Without stop-gradient, gradients would flow back into the target network, causing it to shift toward the online network’s current predictions — the loss drops but nothing meaningful is learned. The system collapses.
The pattern generalises: any time your computation graph has a component you want to hold fixed during a gradient step — a frozen encoder, a looked-up codebook entry, a cached log-probability from a previous policy — you need stop-gradient.
Standard gradient flow: for $z = f(x)$ and a loss $L(z)$,

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z}\,\frac{\partial z}{\partial x}$$

With stop-gradient on $z$, writing $\mathrm{sg}(\cdot)$ for the detached value:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z}\,\frac{\partial\,\mathrm{sg}(z)}{\partial x} = 0$$
The value is still used in the forward pass — only backpropagation is affected.
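To make this concrete, here is a toy sketch in plain Python (the function and its argument names are illustrative, not from any library). For $y = x^2$ and loss $L = y \cdot \mathrm{sg}(y)$, the forward value is $x^4$ either way, but stop-gradient changes the derivative because $\mathrm{sg}(y)$ is treated as a constant during backpropagation:

```python
# Toy illustration: L = y * sg(y) with y = x**2.
# Forward value is identical with or without stop-gradient; the gradient differs.
def loss_and_grad(x, use_stop_gradient):
    y = x ** 2
    dy_dx = 2 * x
    L = y * y                  # forward pass: x**4 in both cases
    if use_stop_gradient:
        # sg(y) is a constant in backward: dL/dx = sg(y) * dy/dx = 2 x^3
        dL_dx = y * dy_dx
    else:
        # product rule through both factors: dL/dx = 2 y * dy/dx = 4 x^3
        dL_dx = 2 * y * dy_dx
    return L, dL_dx

L_sg, g_sg = loss_and_grad(3.0, use_stop_gradient=True)
L_full, g_full = loss_and_grad(3.0, use_stop_gradient=False)
# Same forward value (81.0), but the stop-gradient version gets half the gradient
```

The loss value is identical in both branches; only the credit assignment changes.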
VQ-VAE straight-through estimator:

$$z = z_e + \mathrm{sg}(z_q - z_e)$$

Forward pass: $z = z_q$ (the quantised code). Backward pass: $\frac{\partial z}{\partial z_e} = 1$ (gradients pass straight through to the encoder, skipping the non-differentiable lookup).
PPO probability ratio:

$$r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\mathrm{sg}\!\left(\pi_{\theta_{\text{old}}}(a_t \mid s_t)\right)}$$

The denominator is detached — it was computed before the update and must stay fixed.
```python
import torch
import torch.nn.functional as F

# ── Basic detach: freeze a target value ──────────────────────────
target = target_network(next_state)               # (B, n_actions)
target = target.detach()                          # no gradients flow into target_network
loss = F.mse_loss(online_network(state), target)  # only online_network gets gradients

# ── VQ-VAE straight-through ──────────────────────────────────────
z_e = encoder(x)                # (B, D) — continuous
z_q = codebook_lookup(z_e)      # (B, D) — quantised (non-differentiable)
z = z_e + (z_q - z_e).detach()  # forward: z_q, backward: z_e

# ── torch.no_grad() for entire blocks ────────────────────────────
with torch.no_grad():
    # Nothing inside here builds a computation graph
    target_features = momentum_encoder(x)  # (B, D) — used in MoCo/BYOL

# ── PPO: detach old log-probs ────────────────────────────────────
with torch.no_grad():
    old_log_probs = policy(states).log_prob(actions)  # (B,) — fixed reference
new_log_probs = policy(states).log_prob(actions)      # (B,) — has gradients
ratio = (new_log_probs - old_log_probs).exp()         # (B,)
```

Warning: `.detach()` and `torch.no_grad()` are not the same. `.detach()` removes a single tensor from the graph, but the surrounding forward ops still build a graph. `torch.no_grad()` disables graph construction entirely for everything inside the block — faster but you can’t selectively detach.
## Manual Implementation

```python
import numpy as np

def stop_gradient(x):
    """Conceptual stop-gradient: returns a copy that won't propagate gradients."""
    return x.copy()  # in a real autograd, this node has zero Jacobian

def straight_through_vq(z_e, codebook):
    """VQ-VAE forward: quantised value; backward: gradients bypass to z_e."""
    # Find nearest codebook entry
    dists = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=2)  # (B, K)
    indices = dists.argmin(axis=1)  # (B,)
    z_q = codebook[indices]         # (B, D)

    # Straight-through: forward uses z_q, backward uses z_e
    z_forward = z_q          # what the decoder sees
    z_backward_proxy = z_e   # what gets gradients
    # In autograd: z = z_e + sg(z_q - z_e) = z_q forward, dz/dz_e = 1 backward
    return z_forward, z_backward_proxy, indices

def target_network_loss(online_out, target_out):
    """MSE loss where target is detached (treated as constant)."""
    target_fixed = stop_gradient(target_out)  # no grad
    return ((online_out - target_fixed) ** 2).mean()
```

## Popular Uses
- Q-learning target networks (DQN, Double DQN): detach target Q-values so the target network provides a stable bootstrap
- MoCo momentum encoder: the key encoder is never trained by gradients — only updated via EMA of the query encoder
- BYOL target network: same pattern as MoCo — stop-gradient on the target branch prevents representation collapse
- VQ-VAE codebook (straight-through estimator): bypass the non-differentiable quantisation step
- PPO old policy: log-probabilities from the previous policy iteration are detached constants in the ratio
- Knowledge distillation: teacher network outputs are detached — only the student learns
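Several of the uses above (MoCo, BYOL, DQN variants) pair stop-gradient with an exponential moving average (EMA) of the online network’s parameters. A minimal NumPy sketch of that update, with `ema_update` and the momentum value chosen for illustration rather than taken from any particular codebase:

```python
import numpy as np

def ema_update(target_params, online_params, m=0.99):
    """EMA / momentum update: target <- m * target + (1 - m) * online.

    The target network receives no gradients; it trails the online
    network smoothly. m close to 1 means slower drift.
    """
    return [m * t + (1 - m) * o for t, o in zip(target_params, online_params)]

target = [np.zeros(3)]  # toy single-parameter "network"
online = [np.ones(3)]
for _ in range(10):
    target = ema_update(target, online, m=0.9)
# target has drifted toward online (to 1 - 0.9**10 of the way) with no backprop
```

After $n$ steps toward a fixed online value, the target sits at $1 - m^n$ of the gap, which is why momentum updates feel smoother than a periodic hard copy.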
## Alternatives

| Alternative | When to use | Tradeoff |
|---|---|---|
| EMA update (no gradients at all) | Target/momentum networks (MoCo, BYOL, DQN) | Smoother than periodic copy; target drifts slowly rather than jumping |
| Frozen parameters (`requires_grad=False`) | Permanently frozen layers (fine-tuning) | More explicit but doesn’t work for values that need to be in the graph for other paths |
| torch.no_grad() context | Inference or computing fixed targets in bulk | Faster than .detach() but all-or-nothing — can’t selectively block |
| Custom autograd Function | Need full control over backward pass (straight-through variants) | Most flexible but harder to debug |
| Periodic hard copy | Simple target network updates (original DQN) | Simpler than EMA; causes discontinuities every N steps |
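As a concrete instance of the custom-backward row, here is a framework-free NumPy sketch of the rule a custom autograd `Function` for a clipped straight-through sign estimator would implement (the function names are illustrative; a real PyTorch version would put these in `forward`/`backward` methods):

```python
import numpy as np

def ste_sign_forward(x):
    """Forward: the non-differentiable sign() is used as-is."""
    return np.sign(x)

def ste_sign_backward(grad_out, x, clip=1.0):
    """Backward: pretend sign() was the identity where |x| <= clip.

    This is the clipped straight-through rule: gradient passes through
    unchanged inside the clip region, and is zeroed outside it.
    """
    return grad_out * (np.abs(x) <= clip)

x = np.array([-2.0, -0.3, 0.5, 1.7])
y = ste_sign_forward(x)                    # [-1., -1., 1., 1.]
g = ste_sign_backward(np.ones_like(x), x)  # [0., 1., 1., 0.]
```

The forward output is binary, yet the encoder still receives a usable gradient signal — the same trick the VQ-VAE expression `z_e + (z_q - z_e).detach()` plays with codebook lookups.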
## Historical Context

Stop-gradient as an explicit tool became prominent with DQN (Mnih et al., 2015), which used a periodically-copied target network to stabilise Q-learning. The straight-through estimator dates to Bengio et al. (2013) for training networks with discrete/binary activations.
The technique became central to self-supervised learning with MoCo (He et al., 2020) and BYOL (Grill et al., 2020), where stop-gradient on the target branch was shown to be critical for preventing representation collapse. SimSiam (Chen & He, 2021) demonstrated that stop-gradient alone (without momentum or negative pairs) can prevent collapse, sparking significant theoretical interest in why it works.