
Weight Initialisation (Xavier / Kaiming)

Controls the scale of initial weights so that signals neither explode nor vanish as they propagate through layers. Without proper initialisation, a 50-layer network’s activations can overflow to infinity or collapse to zero before training even begins.

Imagine passing a message through a chain of people, where each person multiplies the message by a random number. If that number is typically greater than 1, the message grows exponentially. If less than 1, it shrinks to nothing. The fix: choose multipliers that keep the message roughly the same size on average.

That is exactly what happens with layer activations. Each layer multiplies its input by a weight matrix. If the variance of the weights is too high, activations explode; too low, they vanish. The solution is to set the initial weight variance so that the variance of the output equals the variance of the input — “variance preservation.”

Xavier init achieves this for linear/tanh activations by setting $\text{Var}(w) = 2/(\text{fan\_in} + \text{fan\_out})$. Kaiming init adjusts for ReLU, which zeros out half the activations: to compensate, it doubles the variance to $2/\text{fan\_in}$. This single factor-of-2 correction is what allows deep ReLU networks to train stably.

The problem: for a layer $y = Wx$ with $W \in \mathbb{R}^{n_\text{out} \times n_\text{in}}$:

$$\text{Var}(y_j) = n_\text{in} \cdot \text{Var}(w) \cdot \text{Var}(x)$$

To preserve variance ($\text{Var}(y) = \text{Var}(x)$), we need $\text{Var}(w) = 1/n_\text{in}$.
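The variance relation above is easy to check numerically. A minimal sketch (layer sizes and batch size are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out = 512, 512
var_w = 1.0 / n_in  # the variance-preserving choice Var(w) = 1/n_in

x = rng.standard_normal((4096, n_in))                # inputs with Var(x) ≈ 1
W = rng.standard_normal((n_out, n_in)) * np.sqrt(var_w)
y = x @ W.T

# Var(y_j) = n_in * Var(w) * Var(x) = 512 * (1/512) * 1 = 1
print(y.var())  # ≈ 1.0
```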

Xavier / Glorot init (linear, tanh, sigmoid — symmetric activations):

$$\text{Var}(w) = \frac{2}{n_\text{in} + n_\text{out}}$$

Compromises between preserving variance in both forward and backward passes.

Kaiming / He init (ReLU — kills half the activations):

$$\text{Var}(w) = \frac{2}{n_\text{in}}$$

The factor of 2 compensates for ReLU zeroing out negative values, which halves the variance.
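A quick sanity check of that factor of 2: for a zero-mean input, ReLU zeros the negative half, so the mean square (and hence the propagated variance) drops by exactly half.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)  # x ~ N(0, 1), so E[x^2] = 1

# ReLU keeps only the positive half, so by symmetry
# E[relu(x)^2] = E[x^2] / 2 = 0.5
print(np.mean(np.maximum(0, x) ** 2))  # ≈ 0.5
```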

Uniform vs normal sampling:

  • Uniform $[-a, a]$: $\text{Var} = a^2/3$, so $a = \sqrt{3 \cdot \text{Var}(w)}$
  • Normal $\mathcal{N}(0, \sigma^2)$: $\sigma = \sqrt{\text{Var}(w)}$
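Both sampling schemes hit the same target variance; only the bound/std conversion differs. A sketch with an arbitrary Xavier target:

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, fan_out = 256, 128
target_var = 2.0 / (fan_in + fan_out)  # Xavier target variance

a = np.sqrt(3.0 * target_var)          # uniform bound: Var of U[-a, a] is a^2/3
w_uniform = rng.uniform(-a, a, size=100_000)
w_normal = rng.normal(0.0, np.sqrt(target_var), size=100_000)

print(w_uniform.var(), w_normal.var(), target_var)  # all three ≈ equal
```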

Transformer scaled init (common practice):

$$\sigma = \frac{1}{\sqrt{d_\text{model}}} \quad \text{or} \quad \sigma = \frac{0.02}{\sqrt{2 \cdot n_\text{layers}}}$$

GPT-2 scales residual projections by $1/\sqrt{n_\text{layers}}$ to prevent the residual sum from growing.
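A sketch of that scaled init applied to a residual (output) projection. The layer name `c_proj` follows GPT-2 convention, and the depth is illustrative:

```python
import math
import torch
import torch.nn as nn

n_layers = 12   # illustrative depth
d_model = 768
base_std = 0.02

c_proj = nn.Linear(d_model, d_model)  # residual (output) projection
# 2 * n_layers residual additions (attention + MLP per block),
# so shrink the per-projection std by sqrt(2 * n_layers)
nn.init.normal_(c_proj.weight, mean=0.0, std=base_std / math.sqrt(2 * n_layers))
nn.init.zeros_(c_proj.bias)
```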

```python
import torch
import torch.nn as nn

# ── Xavier (Glorot) — for tanh/sigmoid layers ────────────────────
linear = nn.Linear(256, 128)
nn.init.xavier_uniform_(linear.weight)  # U[-a, a], a = sqrt(6/(fan_in+fan_out))
nn.init.xavier_normal_(linear.weight)   # N(0, 2/(fan_in+fan_out)), overwrites the above

# ── Kaiming (He) — for ReLU layers ───────────────────────────────
conv = nn.Conv2d(64, 128, 3)
nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')
nn.init.kaiming_uniform_(conv.weight, mode='fan_in', nonlinearity='relu')
# mode='fan_out' preserves variance in the backward pass instead

# ── Transformer-style scaled init ────────────────────────────────
d_model = 768
layer = nn.Linear(d_model, d_model)
nn.init.normal_(layer.weight, mean=0.0, std=1 / (d_model ** 0.5))
nn.init.zeros_(layer.bias)

# ── Check what PyTorch does by default ───────────────────────────
default_linear = nn.Linear(256, 128)  # Kaiming uniform with a=sqrt(5) by default
# The default draws from U[-1/sqrt(fan_in), 1/sqrt(fan_in)],
# so the weight variance is ≈ 1/(3 * fan_in) = 1/768
print(default_linear.weight.var().item())  # ≈ 0.0013
```

Warning: PyTorch’s nn.Linear default is Kaiming uniform — this is fine for ReLU but suboptimal for other activations. If using GELU or SiLU, consider explicit initialisation.
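One way to do that explicit initialisation is with `Module.apply`, which walks every submodule. A sketch; Xavier here is one reasonable choice for GELU/SiLU networks, not a prescribed one:

```python
import torch
import torch.nn as nn

def init_weights(module):
    """Explicitly re-initialise every Linear layer."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)  # or a scaled normal for deep stacks
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = nn.Sequential(nn.Linear(256, 256), nn.GELU(), nn.Linear(256, 10))
model.apply(init_weights)  # recursively applies init_weights to all submodules
```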

```python
import numpy as np

def xavier_uniform(fan_in, fan_out):
    """Xavier/Glorot uniform init. Best for tanh/sigmoid."""
    a = np.sqrt(6.0 / (fan_in + fan_out))
    return np.random.uniform(-a, a, size=(fan_out, fan_in))  # (fan_out, fan_in)

def xavier_normal(fan_in, fan_out):
    """Xavier/Glorot normal init."""
    std = np.sqrt(2.0 / (fan_in + fan_out))
    return np.random.randn(fan_out, fan_in) * std  # (fan_out, fan_in)

def kaiming_normal(fan_in, fan_out):
    """Kaiming/He normal init. Best for ReLU."""
    std = np.sqrt(2.0 / fan_in)
    return np.random.randn(fan_out, fan_in) * std  # (fan_out, fan_in)

def kaiming_uniform(fan_in, fan_out):
    """Kaiming/He uniform init. Best for ReLU."""
    a = np.sqrt(6.0 / fan_in)  # sqrt(3 * 2/fan_in)
    return np.random.uniform(-a, a, size=(fan_out, fan_in))  # (fan_out, fan_in)

# Verify variance preservation through a 50-layer ReLU network
x = np.random.randn(32, 256)  # (B, D)
for _ in range(50):
    W = kaiming_normal(256, 256)  # (256, 256)
    x = x @ W.T                   # (B, 256)
    x = np.maximum(0, x)          # ReLU
print(f"Activation std after 50 layers: {x.std():.4f}")  # should stay O(1)
```
  • ResNets (He et al.): Kaiming init enabled training of 100+ layer networks with ReLU; without it, these networks don’t converge
  • Transformers (GPT, BERT, LLaMA): scaled normal init with $\sigma = 0.02$ or $1/\sqrt{d}$; residual projections scaled by $1/\sqrt{n_\text{layers}}$
  • GANs: proper init is critical — DCGAN specifies a zero-mean normal with std 0.02 for all weights
  • LSTM / GRU: orthogonal init for recurrent weights preserves gradient norms across time steps
  • nn-training entry: the init_weights variant axis demonstrates Xavier vs Kaiming vs scaled init
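The orthogonal recurrent init mentioned above can be sketched with PyTorch's parameter naming for LSTMs (`weight_hh_*` are the recurrent matrices):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=128, hidden_size=128)
for name, param in lstm.named_parameters():
    if 'weight_hh' in name:              # recurrent weights: orthogonal
        nn.init.orthogonal_(param)
    elif 'weight_ih' in name:            # input weights: Xavier
        nn.init.xavier_uniform_(param)
    elif 'bias' in name:
        nn.init.zeros_(param)

# orthogonal_ on the tall (4*hidden, hidden) matrix gives orthonormal
# columns, i.e. W.T @ W = I, which preserves norms through the recurrence
W = lstm.weight_hh_l0
err = (W.T @ W - torch.eye(128)).abs().max()
print(err.item())  # ≈ 0
```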
| Alternative | When to use | Tradeoff |
|---|---|---|
| Orthogonal init | RNNs, very deep networks | Preserves gradient norms exactly; slightly more expensive to compute |
| LSUV (Layer-Sequential Unit Variance) | Networks with unusual architectures | Data-driven: passes a batch and rescales each layer; more robust but slower |
| Fixup init | ResNets without BatchNorm | Rescales residual branches based on depth; avoids need for normalisation layers |
| Zero init (for residual branches) | Transformer residual projections, ReZero | Identity at init — each layer starts as a no-op; stable training but slower early progress |
| Pretrained weights | Transfer learning, fine-tuning | Bypasses init entirely; best when sufficient pretraining data exists |
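The zero-init alternative for residual branches is simple enough to verify directly; the whole block is an identity at init:

```python
import torch
import torch.nn as nn

# Zero-initialising the residual projection makes the block a no-op at init
proj = nn.Linear(768, 768)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)

x = torch.randn(4, 768)
out = x + proj(x)  # residual block: proj contributes nothing yet
print(torch.equal(out, x))  # True
```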

The variance preservation idea was formalised by Glorot & Bengio (2010) as “Xavier initialisation,” derived for linear and tanh activations. He et al. (2015) extended this to ReLU networks as “Kaiming initialisation” — the factor-of-2 correction was the key insight that enabled training of very deep residual networks.

Before these principled approaches, practitioners used heuristics like $\mathcal{N}(0, 0.01)$ or $\mathcal{N}(0, 0.02)$, which happened to work for shallow networks but failed catastrophically for deep ones. Modern normalisation layers (BatchNorm, LayerNorm) reduce sensitivity to initialisation but don’t eliminate it — scaled init in Transformers remains important, especially for training stability at large model sizes.