Gradient Clipping

Capping gradient norms to prevent exploding gradients: if the total gradient norm exceeds a threshold, scale all gradients down proportionally. Essential for RNNs, transformers, and reinforcement learning, where gradient magnitudes can spike by orders of magnitude on a single batch.

Neural network training occasionally produces catastrophically large gradients. In an RNN unrolled over 100 timesteps, a gradient can multiply through the same weight matrix 100 times — if the largest singular value is even slightly above 1, the gradient grows exponentially. One bad batch can produce a gradient 1000x larger than normal, launching the parameters into a region from which the model never recovers.
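The multiplicative blowup is easy to demonstrate with a toy NumPy sketch (the matrix and timestep count here are illustrative, not from any real model): an orthogonal recurrent matrix scaled so every singular value is 1.1 grows the gradient norm by exactly 1.1x per timestep.

```python
import numpy as np

rng = np.random.default_rng(0)

# Orthogonal matrix scaled by 1.1: every singular value is exactly 1.1
Q, _ = np.linalg.qr(rng.standard_normal((64, 64)))
W = 1.1 * Q

g = rng.standard_normal(64)
g /= np.linalg.norm(g)          # unit-norm gradient at the last timestep

for _ in range(100):            # backprop through 100 timesteps
    g = W.T @ g                 # same matrix, applied at every step

print(np.linalg.norm(g))        # 1.1**100 ≈ 13781: a ~10,000x blowup
```

With a singular value of 0.9 instead, the same loop shrinks the gradient toward zero: the vanishing-gradient side of the same phenomenon.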

Gradient clipping is a safety valve: before applying the gradient, measure its total norm. If it’s below the threshold, do nothing. If it exceeds the threshold, scale the entire gradient vector down so its norm equals the threshold. This preserves the direction (which is still useful information) while capping the step size.

The key subtlety: clipping operates on the global norm across all parameters, not per-parameter. Per-parameter clipping would distort the gradient direction — some parameters would be clipped and others not, producing a direction the loss function never suggested. Global norm clipping scales everything uniformly, so the direction is preserved exactly.

Global norm clipping (gradient vector g, max norm c):

g' = \begin{cases} g & \text{if } \|g\| \le c \\ g \cdot \frac{c}{\|g\|} & \text{if } \|g\| > c \end{cases}

where \|g\| = \sqrt{\sum_i g_i^2} is the L2 norm over all parameters concatenated.

Equivalently: g' = g \cdot \min\!\left(1, \frac{c}{\|g\|}\right)
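A quick worked check of the min form (numbers chosen for easy arithmetic): with g = (3, 4), the norm is 5, so clipping at c = 1 scales by min(1, 1/5) = 0.2.

```python
import numpy as np

g = np.array([3.0, 4.0])                        # ||g|| = 5
c = 1.0
g_clipped = g * min(1.0, c / np.linalg.norm(g)) # scale by 0.2
print(g_clipped)                                # [0.6 0.8]: norm is now c, direction unchanged
```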

Value clipping (less common, per-element): g'_i = \text{clamp}(g_i, -c, c). This DOES distort direction and is rarely used in modern practice.
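To see the distortion concretely, here is a small comparison on an illustrative two-component gradient where one component dominates:

```python
import numpy as np

g = np.array([10.0, 0.5])

# Global norm clipping: both components scale together, direction preserved
norm_clipped = g * min(1.0, 1.0 / np.linalg.norm(g))

# Value clipping: only the large component is clamped, direction rotates
value_clipped = np.clip(g, -1.0, 1.0)          # [1.0, 0.5]

unit = lambda v: v / np.linalg.norm(v)
print(unit(norm_clipped))    # identical to unit(g)
print(unit(value_clipped))   # tilted toward the small component
```

Here value clipping turns a 20:1 component ratio into 2:1, a direction the loss surface never suggested.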

Common defaults: c = 1.0 for transformers (GPT, BERT, T5). c = 0.5 for some RL algorithms (PPO). c = 5.0 was common for LSTMs.

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

loss = model(batch).loss
loss.backward()

# ── Global norm clipping (the standard approach) ────────────────
# Clips all parameter gradients so their combined L2 norm ≤ max_norm.
# Returns the total norm BEFORE clipping — useful for logging.
# WARNING: call clip_grad_norm_ AFTER backward(), BEFORE optimizer.step()
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0  # 1.0 is the transformer default
)
optimizer.step()
optimizer.zero_grad()

# ── Log the gradient norm to detect spikes ──────────────────────
# If total_norm is frequently >> max_norm, you may need a lower LR
# or there's a data issue. If it never triggers, the threshold is too high.
print(f"grad norm: {total_norm:.2f}")  # or log to TensorBoard/W&B

# ── Value clipping (rarely used, shown for completeness) ────────
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
import numpy as np

def clip_grad_norm(grads, max_norm):
    """
    Global norm gradient clipping.

    grads: list of numpy arrays (one per parameter), already computed
    max_norm: maximum allowed L2 norm
    Returns: clipped gradients (same structure), total norm before clipping
    """
    # Compute global L2 norm across all parameters
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    # Scale factor: 1.0 if within budget, else shrink proportionally
    clip_coef = min(1.0, max_norm / (total_norm + 1e-6))
    clipped = [g * clip_coef for g in grads]  # preserve direction
    return clipped, total_norm

# Example
g1 = np.random.randn(768, 768) * 10  # large gradient on a weight matrix
g2 = np.random.randn(768) * 10       # large gradient on a bias
clipped, norm = clip_grad_norm([g1, g2], max_norm=1.0)
# norm ≈ 7700 (≈ 10 · √(768² + 768)), each gradient scaled by ~1/7700
# → clipped global norm = 1.0
  • Transformer training (GPT, BERT, T5, LLaMA): max_norm=1.0 is near-universal; without clipping, attention gradients spike on long sequences
  • RNN/LSTM training: the original motivation for gradient clipping — recurrent gradient paths multiply through the same matrix, causing exponential blowup
  • Reinforcement learning (PPO, SAC, DQN): reward spikes and rare transitions cause high-variance gradients; clipping at 0.5-1.0 stabilises training
  • GAN training: discriminator gradients can explode when the generator produces near-perfect or degenerate outputs
  • Multi-task / multi-loss training: different loss terms can have vastly different gradient scales; clipping prevents one term from dominating
| Alternative | When to use | Tradeoff |
| --- | --- | --- |
| Gradient penalty (WGAN-GP) | GAN training | Penalises large gradients in the loss rather than post-hoc clipping; smoother but adds compute |
| Adaptive gradient methods (Adam) | General training | Adam's per-parameter scaling implicitly normalises gradient magnitude, but doesn't prevent global spikes |
| Weight decay | Preventing parameter explosion | Controls weight magnitude, not gradient magnitude — complementary to clipping |
| Skip / reduce step | When any clipping is too aggressive | If grad norm > threshold, skip the update entirely or reduce the LR. Used in some LLM training (PaLM) |
| Layer-wise clipping | Heterogeneous architectures | Clip each layer independently — useful when layers have very different gradient scales, but distorts global direction |
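The skip-step alternative can be sketched framework-agnostically in the same NumPy style as the earlier example; the `skip_threshold` value and the plain SGD-style update below are illustrative, not any specific codebase's recipe:

```python
import numpy as np

def step_with_skip(params, grads, lr=1e-3, max_norm=1.0, skip_threshold=10.0):
    """Clip to max_norm as usual, but discard the update entirely on extreme spikes."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm > skip_threshold:           # likely a bad batch or data issue
        return params, total_norm, False      # skip: parameters untouched
    coef = min(1.0, max_norm / (total_norm + 1e-6))
    new_params = [p - lr * coef * g for p, g in zip(params, grads)]
    return new_params, total_norm, True

# Example: a spiked gradient (norm 100 > threshold 10) is discarded entirely
params = [np.zeros(3)]
_, norm, applied = step_with_skip(params, [np.array([100.0, 0.0, 0.0])])
```

Logging how often updates are skipped serves the same diagnostic purpose as logging the pre-clip norm: frequent skips point at the data or the learning rate, not the threshold.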

Gradient clipping was introduced by Pascanu et al. (2013, “On the difficulty of training recurrent neural networks”), who provided the theoretical analysis of exploding gradients in RNNs and proposed clipping as the solution. The insight was simple: gradient explosion is a multiplicative phenomenon (eigenvalues > 1 compounding over time steps), and the fix is a hard ceiling on the result.

The technique became universal with the rise of transformers. Although self-attention doesn’t have the same recurrent multiplication problem, the softmax and residual connections can still produce gradient spikes, especially on long sequences or early in training. The GPT-2 and GPT-3 training recipes hardcoded max_norm=1.0, and virtually every LLM training codebase has followed suit. It costs essentially nothing (one norm computation per step) and prevents the rare catastrophic update that can derail a multi-week training run.