Stop-Gradient (Detach)

Blocks gradient flow through a value while keeping it in the forward pass. The forward computation proceeds normally, but during backpropagation the detached tensor is treated as a constant. This is how you say “use this value, but don’t learn through it.”

Consider a student trying to hit a moving target. If the target is allowed to move in response to the student’s attempts, the whole system can cheat — the target drifts toward where the student already is, and nobody actually improves. Stop-gradient pins the target in place so the student must genuinely learn to reach it.

This is exactly what happens in target networks (Q-learning, BYOL, MoCo). The target network provides a stable reference. Without stop-gradient, gradients would flow back into the target network, causing it to shift toward the online network’s current predictions — the loss drops but nothing meaningful is learned. The system collapses.

The pattern generalises: any time your computation graph has a component you want to hold fixed during a gradient step — a frozen encoder, a looked-up codebook entry, a cached log-probability from a previous policy — you need stop-gradient.

Standard gradient flow:

$$\frac{\partial \mathcal{L}}{\partial \theta} = \frac{\partial \mathcal{L}}{\partial f_\theta(x)} \cdot \frac{\partial f_\theta(x)}{\partial \theta}$$

With stop-gradient on $f_\theta(x)$:

$$\frac{\partial \mathcal{L}}{\partial \theta} = 0 \quad \text{(gradient blocked)}$$

The value $f_\theta(x)$ is still used in the forward pass — only backpropagation is affected.
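This is easy to see in a two-line PyTorch experiment. A minimal sketch: the same product $\theta^2 \cdot \theta$ gives different gradients depending on whether the $\theta^2$ factor is detached.

```python
import torch

theta = torch.tensor(2.0, requires_grad=True)

# Without detach: loss = theta^3, so d(loss)/d(theta) = 3 * theta^2 = 12
loss = (theta ** 2) * theta
loss.backward()
print(theta.grad)  # tensor(12.)

# With detach: theta^2 is a constant (value 4.0), so d(loss)/d(theta) = 4
theta.grad = None
y_const = (theta ** 2).detach()  # same forward value, but constant to autograd
loss = y_const * theta
loss.backward()
print(theta.grad)  # tensor(4.)
```

The detached factor keeps its forward value of 4.0; it simply stops contributing a gradient path.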

VQ-VAE straight-through estimator:

$$z = z_e + \text{sg}(z_q - z_e)$$

Forward pass: $z = z_q$ (the quantised code). Backward pass: $\partial z / \partial z_e = 1$ (gradients pass straight through to the encoder, skipping the non-differentiable lookup).
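Both properties can be checked directly. A quick sketch, using `torch.round` as a stand-in for the non-differentiable codebook lookup:

```python
import torch

z_e = torch.randn(4, 8, requires_grad=True)
z_q = torch.round(z_e)  # non-differentiable stand-in for quantisation

# Straight-through identity: z = z_e + sg(z_q - z_e)
z = z_e + (z_q - z_e).detach()

# Forward pass: z is (numerically) the quantised value
assert torch.allclose(z, z_q)

# Backward pass: dz/dz_e = 1, so the encoder receives the decoder's gradient unchanged
z.sum().backward()
assert torch.equal(z_e.grad, torch.ones_like(z_e))
```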

PPO probability ratio:

$$r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}$$

The denominator $\pi_{\theta_\text{old}}$ is detached — it was computed before the update and must stay fixed.

```python
import torch
import torch.nn.functional as F

# ── Basic detach: freeze a target value ──────────────────────────
target = target_network(next_state)               # (B, n_actions)
target = target.detach()                          # no gradients flow into target_network
loss = F.mse_loss(online_network(state), target)  # only online_network gets gradients

# ── VQ-VAE straight-through ──────────────────────────────────────
z_e = encoder(x)                # (B, D) — continuous
z_q = codebook_lookup(z_e)      # (B, D) — quantised (non-differentiable)
z = z_e + (z_q - z_e).detach()  # forward: z_q, backward: z_e

# ── torch.no_grad() for entire blocks ────────────────────────────
with torch.no_grad():
    # Nothing inside here builds a computation graph
    target_features = momentum_encoder(x)  # (B, D) — used in MoCo/BYOL

# ── PPO: detach old log-probs ────────────────────────────────────
with torch.no_grad():
    old_log_probs = policy(states).log_prob(actions)  # (B,) — fixed reference, stored before the update
new_log_probs = policy(states).log_prob(actions)      # (B,) — has gradients
ratio = (new_log_probs - old_log_probs).exp()         # (B,)
```

Warning: .detach() and torch.no_grad() are not the same. .detach() cuts a single tensor out of the graph; downstream operations still build a graph through any other inputs that require gradients. torch.no_grad() disables graph construction entirely for everything inside the block — faster and more memory-efficient, but all-or-nothing: you can’t selectively detach.
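The distinction is visible in the `requires_grad` flags. A minimal sketch:

```python
import torch

x = torch.randn(3, requires_grad=True)

# .detach(): only this one tensor is cut off; the graph continues through x.
a = (x * 2).detach()  # constant to autograd
b = a + x             # still differentiable — gradient flows through x, not a
assert not a.requires_grad
assert b.requires_grad

# torch.no_grad(): nothing inside builds a graph at all.
with torch.no_grad():
    c = x * 2 + x
assert not c.requires_grad
```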

```python
import numpy as np

def stop_gradient(x):
    """Conceptual stop-gradient: returns a copy that won't propagate gradients."""
    return x.copy()  # in a real autograd, this node has zero Jacobian

def straight_through_vq(z_e, codebook):
    """VQ-VAE forward: quantised value; backward: gradients bypass to z_e."""
    # Find nearest codebook entry
    dists = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=2)  # (B, K)
    indices = dists.argmin(axis=1)  # (B,)
    z_q = codebook[indices]         # (B, D)
    # Straight-through: forward uses z_q, backward uses z_e
    z_forward = z_q                 # what the decoder sees
    z_backward_proxy = z_e          # what gets gradients
    # In autograd: z = z_e + sg(z_q - z_e) = z_q forward, dz/dz_e = 1 backward
    return z_forward, z_backward_proxy, indices

def target_network_loss(online_out, target_out):
    """MSE loss where target is detached (treated as constant)."""
    target_fixed = stop_gradient(target_out)  # no grad
    return ((online_out - target_fixed) ** 2).mean()
```
  • Q-learning target networks (DQN, Double DQN): detach target Q-values so the target network provides a stable bootstrap
  • MoCo momentum encoder: the key encoder is never trained by gradients — only updated via EMA of the query encoder
  • BYOL target network: same pattern as MoCo — stop-gradient on the target branch prevents representation collapse
  • VQ-VAE codebook (straight-through estimator): bypass the non-differentiable quantisation step
  • PPO old policy: log-probabilities from the previous policy iteration are detached constants in the ratio
  • Knowledge distillation: teacher network outputs are detached — only the student learns
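The last item can be made concrete in a few lines. A minimal sketch of a distillation step, assuming toy `Linear` modules stand in for real teacher/student networks, with the usual temperature-scaled KL objective:

```python
import torch
import torch.nn.functional as F

teacher = torch.nn.Linear(8, 4)  # stands in for a pretrained teacher
student = torch.nn.Linear(8, 4)
x = torch.randn(16, 8)

with torch.no_grad():            # teacher outputs are detached constants
    teacher_logits = teacher(x)

T = 2.0                          # softmax temperature (illustrative value)
loss = F.kl_div(
    F.log_softmax(student(x) / T, dim=-1),
    F.softmax(teacher_logits / T, dim=-1),
    reduction="batchmean",
) * T * T

loss.backward()
# Only the student receives gradients; the teacher is untouched.
assert all(p.grad is not None for p in student.parameters())
assert all(p.grad is None for p in teacher.parameters())
```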
| Alternative | When to use | Tradeoff |
|---|---|---|
| EMA update (no gradients at all) | Target/momentum networks (MoCo, BYOL, DQN) | Smoother than periodic copy; target drifts slowly rather than jumping |
| Frozen parameters (`requires_grad=False`) | Permanently frozen layers (fine-tuning) | More explicit but doesn’t work for values that need to be in the graph for other paths |
| `torch.no_grad()` context | Inference or computing fixed targets in bulk | Faster than `.detach()` but all-or-nothing — can’t selectively block |
| Custom autograd `Function` | Need full control over backward pass (straight-through variants) | Most flexible but harder to debug |
| Periodic hard copy | Simple target network updates (original DQN) | Simpler than EMA; causes discontinuities every N steps |
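The EMA alternative can be sketched in a few lines; `ema_update`, the momentum value, and the toy `Linear` modules below are illustrative, not a fixed API:

```python
import torch

@torch.no_grad()  # the target network is never updated by gradients
def ema_update(online: torch.nn.Module, target: torch.nn.Module, m: float = 0.99):
    """target <- m * target + (1 - m) * online, parameter by parameter."""
    for p_o, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(m).add_(p_o, alpha=1 - m)

# Usage sketch: target starts as a copy of online, then trails it smoothly.
online = torch.nn.Linear(4, 4)
target = torch.nn.Linear(4, 4)
target.load_state_dict(online.state_dict())
for p in target.parameters():
    p.requires_grad_(False)   # belt-and-braces: no gradients ever reach the target
ema_update(online, target)
```

With m close to 1 the target lags the online network by many steps, which is exactly the slow drift the table describes.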

Stop-gradient as an explicit tool became prominent with DQN (Mnih et al., 2015), which used a periodically-copied target network to stabilise Q-learning. The straight-through estimator dates to Bengio et al. (2013) for training networks with discrete/binary activations.

The technique became central to self-supervised learning with MoCo (He et al., 2020) and BYOL (Grill et al., 2020), where stop-gradient on the target branch was shown to be critical for preventing representation collapse. SimSiam (Chen & He, 2021) demonstrated that stop-gradient alone (without momentum or negative pairs) can prevent collapse, sparking significant theoretical interest in why it works.