
Weight Decay

Penalises large weights by adding $\lambda \|\mathbf{w}\|^2$ to the loss, or, equivalently, by shrinking the weights by a small factor each step. Keeps weights small, reduces overfitting, and improves generalisation. The most universal regulariser in deep learning — used in virtually every modern model.

Think of weight decay as a spring pulling every weight back toward zero. Without it, weights are free to grow as large as they want to fit the training data — potentially memorising noise. The spring provides a restoring force: the further a weight drifts from zero, the harder it gets pulled back. Only weights that significantly reduce the loss can justify being large.

This has a smoothing effect on the learned function. Large weights create sharp, spiky decision boundaries that overfit to individual training examples. Small weights produce smoother functions that generalise better. Weight decay is essentially Occam’s razor in optimisation form: prefer the simplest (smallest-weight) model that explains the data.
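The shrinking effect is easiest to see in closed form with ridge regression, the statistical ancestor of weight decay. The following is a toy illustration (all numbers made up): a degree-8 polynomial fit to twelve noisy samples of $y = x$, where growing $\lambda$ steadily shrinks the coefficient norm.

```python
import numpy as np

rng = np.random.default_rng(0)

# Twelve noisy samples of y = x, fitted with degree-8 polynomial features:
# plenty of capacity to overfit the noise.
x = np.linspace(-1.0, 1.0, 12)
X = np.vander(x, 9)                       # columns x^8 ... x^0
y = x + 0.1 * rng.standard_normal(12)

for lam in (0.0, 0.1, 10.0):
    # Ridge closed form: w = (X^T X + lam * I)^{-1} X^T y
    w = np.linalg.solve(X.T @ X + lam * np.eye(9), X.T @ y)
    print(f"lambda = {lam:5.1f}   ||w|| = {np.linalg.norm(w):.4f}")
```

The coefficient norm falls monotonically as $\lambda$ grows: the penalised fit trades a little training error for a much flatter, smoother function.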

A critical subtlety: classical L2 regularisation (add $\lambda \|\mathbf{w}\|^2$ to the loss) and decoupled weight decay (subtract $\lambda \mathbf{w}$ from the weight each step) are identical for SGD but different for Adam. Adam scales gradients by their running variance, which also scales the L2 penalty — effectively applying less regularisation to weights with large gradients. AdamW fixes this by applying weight decay directly to the weights, outside the adaptive learning rate. This is why AdamW is the modern default.

L2 regularisation — add penalty to loss:

$$\mathcal{L}_{\text{reg}} = \mathcal{L} + \frac{\lambda}{2} \|\mathbf{w}\|^2$$

The gradient becomes:

$$\nabla_{\mathbf{w}} \mathcal{L}_{\text{reg}} = \nabla_{\mathbf{w}} \mathcal{L} + \lambda \mathbf{w}$$

SGD with L2 — the update step:

$$\mathbf{w} \leftarrow \mathbf{w} - \eta(\nabla_{\mathbf{w}} \mathcal{L} + \lambda \mathbf{w}) = (1 - \eta\lambda)\mathbf{w} - \eta \nabla_{\mathbf{w}} \mathcal{L}$$

The factor $(1 - \eta\lambda)$ shrinks weights toward zero each step — this is why it’s called “weight decay.”
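To make the equivalence concrete, here is a one-step numerical check (a small sketch with made-up numbers, not from the original text) showing that for plain SGD the two routes — folding the L2 penalty into the gradient, or shrinking then stepping — give the same update:

```python
import numpy as np

eta, lam = 0.1, 0.01
w = np.array([2.0, -3.0])
grad = np.array([0.5, 0.5])            # gradient of the data loss alone

# Route 1: L2 penalty folded into the gradient.
w_l2 = w - eta * (grad + lam * w)

# Route 2: explicit shrink-then-step ("weight decay").
w_decay = (1 - eta * lam) * w - eta * grad

assert np.allclose(w_l2, w_decay)      # identical for plain SGD
print(w_l2)
```

The equivalence breaks as soon as the optimiser rescales gradients per-parameter, which is exactly the Adam case discussed below.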

Decoupled weight decay (AdamW):

$$\mathbf{w} \leftarrow (1 - \lambda)\mathbf{w} - \eta \cdot \text{Adam}(\nabla_{\mathbf{w}} \mathcal{L})$$

The decay is applied directly to the weights, not routed through Adam’s adaptive scaling.

Key difference: with Adam + L2, the effective regularisation per weight is $\lambda / \sqrt{v_t + \epsilon}$ (weakened for weights with high-variance gradients). With AdamW, every weight gets exactly $\lambda$ decay regardless of gradient history.
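The gap can be shown directly by isolating the decay contribution for two weights with very different gradient histories. This is a simplified sketch (it ignores Adam's momentum mixing, and the $v$ values are made up) of the scaling described above:

```python
import numpy as np

lr, lam, eps = 1e-3, 0.1, 1e-8
w = np.array([1.0, 1.0])          # two identical weights...
v = np.array([100.0, 0.0001])     # ...with very different second-moment histories

# Adam + L2: the penalty gradient lam*w is divided by sqrt(v) like any gradient,
# so the noisy (high-variance) weight is regularised ~1000x more weakly.
l2_decay_step = lr * (lam * w) / (np.sqrt(v) + eps)
print(l2_decay_step)              # approximately [1e-05, 1e-02]

# AdamW: decay bypasses the adaptive scaling entirely — uniform for both weights.
adamw_decay_step = lr * lam * w
print(adamw_decay_step)           # [1e-04, 1e-04]
```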

import torch

# ── AdamW (decoupled weight decay) — the modern default ────────
# weight_decay is the decoupled decay factor λ (PyTorch scales it by lr internally).
optimiser = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

# ── SGD with weight decay ──────────────────────────────────────
# For SGD, weight_decay and L2 are equivalent.
optimiser = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)

# ── Adam with L2 (NOT the same as AdamW) ───────────────────────
# torch.optim.Adam's weight_decay parameter does L2, not decoupled decay.
# WARNING: this is almost never what you want. Use AdamW instead.
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.01)

# ── Exclude bias and LayerNorm from weight decay ───────────────
# Common practice: only regularise weight matrices, not biases or norms.
decay_params = [p for n, p in model.named_parameters() if "weight" in n and p.dim() >= 2]
no_decay_params = [p for n, p in model.named_parameters() if "weight" not in n or p.dim() < 2]
optimiser = torch.optim.AdamW([
    {"params": decay_params, "weight_decay": 0.01},
    {"params": no_decay_params, "weight_decay": 0.0},
], lr=1e-3)
import numpy as np

def sgd_step_with_weight_decay(params, grads, lr, weight_decay):
    """
    SGD + L2 regularisation (equivalent to decoupled weight decay for SGD).
    params: list of arrays (model weights)
    grads: list of arrays (gradients, same shapes)
    """
    for w, g in zip(params, grads):
        w -= lr * (g + weight_decay * w)              # shrink + step

def adamw_step(params, grads, m_states, v_states, lr, weight_decay, beta1, beta2, eps, t):
    """
    Decoupled weight decay (AdamW). One step.
    m_states, v_states: running moment estimates, same shapes as params.
    t: step number (1-indexed for bias correction).
    """
    for w, g, m, v in zip(params, grads, m_states, v_states):
        m[:] = beta1 * m + (1 - beta1) * g            # first moment
        v[:] = beta2 * v + (1 - beta2) * g ** 2       # second moment
        m_hat = m / (1 - beta1 ** t)                  # bias correction
        v_hat = v / (1 - beta2 ** t)                  # bias correction
        w -= lr * m_hat / (np.sqrt(v_hat) + eps)      # Adam step
        w -= weight_decay * w                         # decoupled decay
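A quick self-contained sanity check on the decay line (a toy sketch with made-up numbers): with a zero data gradient, Adam's moment estimates stay at zero and the adaptive step vanishes, so only the decay term acts — pure geometric shrinkage toward zero.

```python
import numpy as np

# With zero data gradient the Adam term vanishes (m and v stay zero),
# leaving only the decoupled decay line: pure geometric shrinkage.
w = np.array([1.0, -2.0])
lam = 0.1
for _ in range(10):
    w -= lam * w                              # the decay line in isolation
np.testing.assert_allclose(w, (1 - lam) ** 10 * np.array([1.0, -2.0]))
print(w)
```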
  • LLM pretraining (GPT, LLaMA, Mistral): AdamW with weight_decay=0.1, excluding biases and layer norms
  • Vision transformers (ViT, DeiT): weight_decay=0.05 is typical; critical for ViT which overfits easily without it
  • Fine-tuning: moderate weight decay guards against overfitting on small datasets; some recipes instead decay toward the pretrained weights rather than zero, to limit catastrophic forgetting
  • CNN training (ResNet): SGD with weight_decay=1e-4, the classic recipe
  • GAN training: careful tuning needed — too much decay can destabilise the discriminator/generator balance
Alternatives — when to use each, and the tradeoff:
  • L1 regularisation: when you want sparse weights (feature selection). Drives weights exactly to zero; less common in deep learning, more in linear models.
  • Dropout: MLPs, older architectures. Regularises activations, not weights; introduces a train/test discrepancy.
  • Early stopping: when validation loss is monitored. Implicitly limits effective model capacity; no hyperparameter to tune beyond patience.
  • Max-norm constraint: RNNs, or when weight explosion is a concern. Clips weight magnitude directly; harder to tune than smooth decay.
  • Spectral normalisation: GAN discriminators. Constrains the spectral norm specifically; more targeted than general weight decay.
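To illustrate the first alternative (L1 regularisation): its proximal update, soft-thresholding, snaps small weights exactly to zero, whereas L2's multiplicative shrink only scales them down. A toy sketch with made-up numbers:

```python
import numpy as np

w = np.array([0.005, -0.5, 2.0])
lam = 0.01

# L1 proximal step: soft-thresholding. Anything within lam of zero becomes zero.
w_l1 = np.sign(w) * np.maximum(np.abs(w) - lam, 0.0)

# L2 shrink: every weight is scaled, none reaches zero exactly.
w_l2 = (1 - lam) * w

print(w_l1)   # the 0.005 weight is exactly zero; the others shrink by lam
print(w_l2)   # all weights merely scaled by 0.99
```

This is why L1 is the tool for feature selection: sparsity is exact, not approximate.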

Weight decay dates back to ridge regression in statistics (Hoerl & Kennard, 1970) and was adopted early in neural network training as a standard regulariser. For decades, “L2 regularisation” and “weight decay” were treated as synonymous because they are equivalent under SGD.

The critical insight came from Loshchilov & Hutter (2019, “Decoupled Weight Decay Regularization”), who showed that L2 regularisation and weight decay diverge under adaptive optimisers like Adam. Their AdamW variant — applying decay directly to weights rather than through the gradient — was a simple fix that meaningfully improved generalisation. AdamW is now the default optimiser for virtually all transformer training, and “weight_decay=0.01 to 0.1” is standard in most recipes.