"""
Core Neural Network Training — Minimal Complete Example
========================================================
A single file covering the fundamental requirements for training
a neural network: architecture, initialisation, training loop,
and logging. Uses a simple MLP on synthetic data so the focus
stays on the machinery, not the task.
The five stages every training script goes through:
1. Define the model — architecture, activations, normalisation
2. Initialise weights — set a good starting point
3. Set up optimiser — how gradients become weight updates
4. Training loop — forward → loss → backward → step
5. Logging / monitoring — track what matters, catch problems early
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
# ═══════════════════════════════════════════════════════════════════
# 1. DEFINE THE MODEL
# ═══════════════════════════════════════════════════════════════════
#
# An MLP is just a stack of: Linear → Norm → Activation → repeat
#
# LayerNorm is the simplest normalisation choice:
# • BatchNorm — normalises across the batch. Fast, effective, but
# behaviour changes between train/eval and breaks with batch=1.
# • LayerNorm — normalises across features within each sample.
# Consistent train/eval behaviour. Default choice for transformers.
# • RMSNorm — like LayerNorm but skips the mean-centering step.
# Slightly cheaper, used in LLaMA-style models.
#
# ReLU is the simplest activation choice:
# • ReLU — max(0, x). Simple, sparse, but "dead neuron" risk.
# • GELU — smooth approximation of ReLU, default in transformers.
# • SiLU — x · σ(x), aka Swish. Popular in vision and LLMs.
class MLP(nn.Module):
    """Multi-layer perceptron: Linear → (LayerNorm) → ReLU → (Dropout) blocks.

    The final Linear layer is left bare so the network emits raw logits
    (or regression values) with no norm/activation/dropout stacked on top.
    """

    def __init__(self, dims: list[int], norm=True, dropout=0.0):
        """
        Args:
            dims: layer widths, e.g. [784, 256, 128, 10] —
                  first entry is the input dim, last the output dim
            norm: insert a LayerNorm after every hidden Linear
            dropout: dropout probability (0 disables dropout entirely)
        """
        super().__init__()
        modules: list[nn.Module] = []
        last_idx = len(dims) - 2  # index of the output Linear in the loop below
        for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            modules.append(nn.Linear(d_in, d_out))
            if idx == last_idx:
                # Output layer: raw logits, nothing on top.
                continue
            if norm:
                modules.append(nn.LayerNorm(d_out))
            modules.append(nn.ReLU())
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stack; shape (B, dims[0]) → (B, dims[-1])."""
        return self.net(x)
# ═══════════════════════════════════════════════════════════════════
# 2. INITIALISE WEIGHTS
# ═══════════════════════════════════════════════════════════════════
#
# Why it matters: bad init → vanishing/exploding gradients → dead network.
#
# Rule of thumb: match init to activation function.
# • Kaiming (He) — designed for ReLU. Scales by √(2/fan_in).
# • Xavier (Glorot) — designed for tanh/sigmoid. Scales by √(1/fan_in).
# • If using GELU/SiLU, Kaiming is a reasonable default.
#
# Biases are almost always initialised to zero.
# LayerNorm/BatchNorm params are fine at their defaults (weight=1, bias=0).
def init_weights(model: nn.Module):
    """Kaiming-initialise every Linear layer in *model*; zero its biases.

    Kaiming (He) init matches the ReLU activations used by MLP.
    Norm layers are left at their defaults (weight=1, bias=0).
    """
    linear_layers = (m for m in model.modules() if isinstance(m, nn.Linear))
    for layer in linear_layers:
        nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
# ═══════════════════════════════════════════════════════════════════
# 3. SET UP OPTIMISER (and optional LR schedule)
# ═══════════════════════════════════════════════════════════════════
#
# Common choices:
# • SGD + momentum — simple, well-understood, needs LR tuning
# • Adam — adaptive LR per-param, good default
# • AdamW — Adam with decoupled weight decay (preferred)
#
# Learning rate schedule:
# • Constant — fine for quick experiments
# • Cosine annealing — smooth decay, widely used
# • Warmup + decay — standard for transformers
def make_optimizer(model, lr=1e-3, weight_decay=1e-2):
    """Build an AdamW optimiser (decoupled weight decay) over all model params."""
    params = model.parameters()
    return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
def make_scheduler(optimizer, total_steps):
    """Cosine-annealing LR schedule that decays over *total_steps* steps."""
    schedule_cls = torch.optim.lr_scheduler.CosineAnnealingLR
    return schedule_cls(optimizer, T_max=total_steps)
# ═══════════════════════════════════════════════════════════════════
# 4. TRAINING LOOP
# ═══════════════════════════════════════════════════════════════════
#
# The loop is always:
# for each epoch:
# for each batch:
# 1. Forward pass — prediction = model(input)
# 2. Compute loss — how wrong are we?
# 3. Backward pass — compute gradients
# 4. Optimiser step — update weights
# 5. Zero gradients — reset for next batch
# 6. (Optional) LR scheduler step
# 7. Log metrics
#
# Common gotchas:
# • Forgetting zero_grad → gradients accumulate across batches
# • Forgetting model.eval() / torch.no_grad() during validation
# • Stepping the scheduler per-batch vs per-epoch (read the docs)
def train(model, train_loader, val_loader, optimizer, scheduler,
          n_epochs, device, log_dir="runs/experiment"):
    """Run the full train/validate loop for *n_epochs* epochs.

    Per batch: forward → cross-entropy loss → backward → grad-clip →
    optimiser step → scheduler step (per-batch, matching a cosine
    schedule whose T_max was set in steps, not epochs).
    Per epoch: log train loss/accuracy/LR plus validation metrics to
    TensorBoard and print a one-line summary. The model graph is
    logged once after training.
    """
    writer = SummaryWriter(log_dir)
    for epoch in range(n_epochs):
        # ── Training phase ────────────────────────────────────────
        model.train()  # enable dropout, etc.
        running_loss, correct, seen = 0.0, 0, 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)                   # (B, n_classes)
            loss = F.cross_entropy(logits, targets)  # scalar
            optimizer.zero_grad()  # clear old gradients first
            loss.backward()        # populate fresh gradients
            # Clipping guards against exploding gradients.
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()   # apply the weight update
            scheduler.step()   # per-batch LR update (cosine)
            # Accumulate running metrics for the epoch summary.
            batch = inputs.size(0)
            running_loss += loss.item() * batch
            correct += (logits.argmax(-1) == targets).sum().item()
            seen += batch
        # ── Per-epoch logging (summary) ───────────────────────────
        train_loss, train_acc = running_loss / seen, correct / seen
        writer.add_scalar("train/loss_epoch", train_loss, epoch)
        writer.add_scalar("train/accuracy", train_acc, epoch)
        writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch)
        # ── Validation phase ──────────────────────────────────────
        val_loss, val_acc = evaluate(model, val_loader, device)
        writer.add_scalar("val/loss", val_loss, epoch)
        writer.add_scalar("val/accuracy", val_acc, epoch)
        # ── Console output ────────────────────────────────────────
        print(f"Epoch {epoch+1:3d}/{n_epochs} │ "
              f"train loss {train_loss:.4f} acc {train_acc:.3f} │ "
              f"val loss {val_loss:.4f} acc {val_acc:.3f} │ "
              f"lr {optimizer.param_groups[0]['lr']:.2e}")
    # ── Log model graph (once) ────────────────────────────────────
    sample = next(iter(train_loader))[0][:1].to(device)
    writer.add_graph(model, sample)
    writer.close()
    print(f"\nTensorBoard logs saved to: {log_dir}")
    print(f"Run: tensorboard --logdir {Path(log_dir).parent}")
@torch.no_grad()
def evaluate(model, loader, device):
    """Compute mean cross-entropy loss and accuracy of *model* over *loader*.

    Runs in eval mode (dropout/norm in inference behaviour) with
    gradients disabled. Returns (mean_loss, accuracy) as Python floats.
    """
    model.eval()
    loss_sum, hits, count = 0.0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        # Sum (not mean) per batch so the final division weights
        # every sample equally even with a ragged last batch.
        loss_sum += F.cross_entropy(logits, targets, reduction="sum").item()
        hits += (logits.argmax(-1) == targets).sum().item()
        count += inputs.size(0)
    return loss_sum / count, hits / count
# ═══════════════════════════════════════════════════════════════════
# 5. PUTTING IT ALL TOGETHER
# ═══════════════════════════════════════════════════════════════════
def main():
    """Wire all five stages together on a synthetic classification task."""
    # ── Config ────────────────────────────────────────────────────
    input_dim, hidden_dim, n_classes = 64, 128, 10
    n_samples, batch_size, n_epochs = 5000, 128, 20
    lr = 3e-4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # ── Synthetic dataset (replace with real data) ────────────────
    X = torch.randn(n_samples, input_dim)
    Y = torch.randint(0, n_classes, (n_samples,))
    split = int(0.8 * n_samples)
    train_ds = TensorDataset(X[:split], Y[:split])
    val_ds = TensorDataset(X[split:], Y[split:])
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    # ── Create model ──────────────────────────────────────────────
    model = MLP(dims=[input_dim, hidden_dim, hidden_dim, n_classes],
                norm=True, dropout=0.1)
    init_weights(model)
    model = model.to(device)
    # ── Optimiser & schedule ──────────────────────────────────────
    total_steps = n_epochs * len(train_loader)  # cosine decays per-batch
    optimizer = make_optimizer(model, lr=lr)
    scheduler = make_scheduler(optimizer, total_steps)
    # ── Train ─────────────────────────────────────────────────────
    train(model, train_loader, val_loader, optimizer, scheduler,
          n_epochs, device, log_dir="runs/mlp_demo")


if __name__ == "__main__":
    main()