
Learning Rate Warmup

Linearly increasing the learning rate from ~0 to the target value over the first N steps. Prevents early instability when Adam’s second-moment estimates are not yet calibrated. Standard practice in transformer training (GPT, BERT, ViT).

At step 0, Adam’s running estimate of the squared gradient ($v$) is initialised to zero. Because Adam divides by $\sqrt{v} + \epsilon$, those first updates are divided by something tiny, producing wildly large parameter changes. The model can diverge in the first few hundred steps before the optimizer has seen enough gradients to form reliable estimates.
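A toy trace of Adam’s update rule makes this concrete (a sketch, not the full optimizer; `adam_step_size` and its defaults are illustrative). With only one gradient observed, the bias-corrected ratio $\hat{m}/\sqrt{\hat{v}}$ reduces to roughly $g/|g|$, so even a vanishingly small gradient yields a near-full-size parameter step:

```python
import numpy as np

def adam_step_size(grads, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Return |update| at each step of Adam on a 1-D gradient sequence."""
    m = v = 0.0
    sizes = []
    for t, g in enumerate(grads, start=1):
        m = b1 * m + (1 - b1) * g            # first moment
        v = b2 * v + (1 - b2) * g * g        # second moment
        m_hat = m / (1 - b1 ** t)            # bias correction
        v_hat = v / (1 - b2 ** t)
        sizes.append(abs(lr * m_hat / (np.sqrt(v_hat) + eps)))
    return sizes

# A tiny gradient (1e-6) still produces a near-full-size first step of ~lr,
# because the optimizer has no history to tell it the gradient is small.
print(adam_step_size([1e-6])[0])
```

The point is that early step sizes are decoupled from gradient magnitude, which is exactly the unreliability a small global learning rate papers over.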

Warmup is the fix: start with a near-zero learning rate and ramp it up linearly over, say, 1000 steps. During that ramp, even though the per-parameter scaling in Adam is unreliable, the small global learning rate keeps the actual updates small. By the time the learning rate reaches its full value, the second-moment estimates have accumulated enough history to be trustworthy.

This is why warmup is most critical for Adam-family optimizers. Plain SGD with momentum doesn’t have the second-moment issue, so warmup helps less there (though it can still smooth the initial transient). The deeper or more attention-heavy the model, the more warmup matters — large transformers routinely use 1-5% of total steps for warmup.

Linear warmup schedule (step $t$, warmup steps $T_w$, target learning rate $\eta_{\max}$):

$$\eta(t) = \eta_{\max} \cdot \frac{t}{T_w}, \quad t \le T_w$$

After warmup, the learning rate is typically handed off to a decay schedule (cosine annealing, linear decay, etc.):

$$\eta(t) = \begin{cases} \eta_{\max} \cdot \frac{t}{T_w} & t \le T_w \\ \text{decay}(\eta_{\max},\, t - T_w) & t > T_w \end{cases}$$

Common default: $T_w$ = 1-5% of total training steps. GPT-3 warmed up the learning rate over its first 375 million training tokens (out of 300 billion). ViT used 10k warmup steps.

import torch

# ── Using PyTorch's built-in schedulers ─────────────────────────
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Linear warmup for 1000 steps, then constant
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-8 / 3e-4, total_iters=1000  # ramp from ~0 to 3e-4
)

# Warmup + cosine decay (the standard transformer recipe)
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-8 / 3e-4, total_iters=1000
)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=99_000  # remaining steps after warmup
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[1000]
)

# ── In the training loop ────────────────────────────────────────
for step, batch in enumerate(dataloader):
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # call AFTER optimizer.step()
import numpy as np

def warmup_lr(step, lr_max, warmup_steps):
    """
    Returns the learning rate at a given step during linear warmup.
    step: current training step (0-indexed)
    lr_max: target learning rate after warmup
    warmup_steps: number of steps to ramp over
    """
    if step < warmup_steps:
        return lr_max * (step / warmup_steps)  # linear ramp
    return lr_max  # constant after

def warmup_cosine_lr(step, lr_max, warmup_steps, total_steps, lr_min=0.0):
    """
    Linear warmup followed by cosine decay, the standard transformer schedule.
    """
    if step < warmup_steps:
        return lr_max * (step / warmup_steps)  # linear ramp
    # Cosine decay phase
    progress = (step - warmup_steps) / (total_steps - warmup_steps)  # 0 → 1
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * progress))

# Example: 100k steps, 2k warmup, peak lr 3e-4
lrs = [warmup_cosine_lr(t, 3e-4, 2000, 100_000, 1e-5) for t in range(100_000)]
# lrs[0] = 0, lrs[2000] = 3e-4, lrs[-1] ≈ 1e-5
  • LLM pre-training (GPT, LLaMA, Chinchilla): linear warmup + cosine decay is the de facto standard schedule
  • Vision transformers (ViT, DeiT, Swin): warmup is critical because self-attention layers amplify the early-step instability
  • BERT / masked language modelling: original BERT paper used 10k warmup steps out of 1M total
  • Fine-tuning large models: shorter warmup (100-500 steps) helps stabilise the first few gradient updates on a new task
  • Diffusion models (DDPM, Stable Diffusion): warmup used in UNet training to prevent early divergence
| Alternative | When to use | Tradeoff |
| --- | --- | --- |
| No warmup (constant LR) | SGD with momentum on CNNs | Works for simpler optimizers; Adam without warmup risks early divergence |
| Exponential warmup | When linear ramp is too slow | Faster ramp but harder to tune; less common in practice |
| RAdam | Drop-in Adam replacement | Automatically corrects the variance bias in Adam’s early steps, removing the need for explicit warmup. Slightly higher compute per step |
| Gradual unfreezing | Transfer learning / fine-tuning | Warms up capacity rather than learning rate: unfreeze layers one at a time. Complementary to LR warmup |
| Learning rate probing | Unknown good LR range | LR range test (Smith 2017) sweeps the LR to find the right max before setting the warmup target |
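For the exponential-warmup alternative there is no single canonical formula; one hypothetical parameterisation (names and the default `lr_init` are illustrative) interpolates geometrically between a small initial LR and the target, so the LR grows by a constant multiplicative factor each step:

```python
def exp_warmup_lr(step, lr_max, warmup_steps, lr_init=1e-7):
    """Geometric ramp from lr_init to lr_max (one of several possible forms)."""
    if step >= warmup_steps:
        return lr_max  # constant after warmup
    # Each step multiplies the LR by (lr_max / lr_init) ** (1 / warmup_steps).
    return lr_init * (lr_max / lr_init) ** (step / warmup_steps)
```

Compared with the linear ramp, this spends more steps at very small learning rates and then climbs steeply near the end, which is one reason it is harder to tune.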

The need for warmup was first identified empirically in the transformer paper (Vaswani et al., 2017, “Attention Is All You Need”), which used a specific schedule: warmup over 4000 steps followed by inverse-square-root decay. The authors didn’t explain it theoretically — it was a practical fix that made training converge.
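That original schedule (often called the Noam schedule) can be sketched directly from the paper’s formula; `d_model=512` and `warmup=4000` are the base-transformer defaults:

```python
def transformer_lr(step, d_model=512, warmup=4000):
    """Vaswani et al. (2017): linear warmup, then inverse-square-root decay."""
    step = max(step, 1)  # the formula is undefined at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# The two terms cross exactly at step == warmup, which is where the LR peaks.
```

Note that here the peak learning rate is not set directly; it falls out of `d_model` and `warmup`, which is part of why later work moved to the explicit warmup-plus-decay recipes above.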

The theoretical justification came from Liu et al. (2020, “On the Variance of the Adaptive Learning Rate and Beyond”), who showed that Adam’s adaptive learning rate has excessively high variance in early training because the second-moment estimate is biased toward zero. Their RAdam optimizer corrects this analytically, but linear warmup remains the dominant practical solution because it’s simpler, well-understood, and composes cleanly with any decay schedule.