# Core Neural Network Training — Minimal Complete Example — Implementation

"""
Core Neural Network Training — Minimal Complete Example
========================================================
A single file covering the fundamental requirements for training
a neural network: architecture, initialisation, training loop,
and logging. Uses a simple MLP on synthetic data so the focus
stays on the machinery, not the task.

The five stages every training script goes through:
  1. Define the model       — architecture, activations, normalisation
  2. Initialise weights     — set a good starting point
  3. Set up optimiser       — how gradients become weight updates
  4. Training loop          — forward → loss → backward → step
  5. Logging / monitoring   — track what matters, catch problems early
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path


# ═══════════════════════════════════════════════════════════════════
# 1. DEFINE THE MODEL
# ═══════════════════════════════════════════════════════════════════
#
# An MLP is just a stack of:   Linear → Norm → Activation → repeat
#
# LayerNorm is the simplest normalisation choice:
#   • BatchNorm — normalises across the batch. Fast, effective, but
#     behaviour changes between train/eval and breaks with batch=1.
#   • LayerNorm — normalises across features within each sample.
#     Consistent train/eval behaviour. Default choice for transformers.
#   • RMSNorm — like LayerNorm but skips the mean-centering step.
#     Slightly cheaper, used in LLaMA-style models.
#
# ReLU is the simplest activation choice:
#   • ReLU     — max(0, x). Simple, sparse, but "dead neuron" risk.
#   • GELU     — smooth approximation of ReLU, default in transformers.
#   • SiLU     — x · σ(x), aka Swish. Popular in vision and LLMs.

class MLP(nn.Module):
    """Fully connected net: Linear → (LayerNorm) → ReLU → (Dropout), repeated.

    The final Linear layer is left bare, so the network emits raw logits
    or regression values.
    """

    def __init__(self, dims: list[int], norm=True, dropout=0.0):
        """
        Args:
            dims:    layer widths in order, e.g. [784, 256, 128, 10];
                     dims[0] is the input size, dims[-1] the output size
            norm:    insert a LayerNorm after every hidden Linear layer
            dropout: dropout probability on hidden layers (0 = no dropout)
        """
        super().__init__()
        output_idx = len(dims) - 2     # index of the last (output) Linear
        stack = []
        for idx, (width_in, width_out) in enumerate(zip(dims, dims[1:])):
            stack.append(nn.Linear(width_in, width_out))
            if idx >= output_idx:
                # Output layer: nothing follows the Linear.
                continue
            if norm:
                stack.append(nn.LayerNorm(width_out))
            stack.append(nn.ReLU())
            if dropout > 0:
                stack.append(nn.Dropout(dropout))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Pass x through the layer stack and return the last layer's output."""
        return self.net(x)


# ═══════════════════════════════════════════════════════════════════
# 2. INITIALISE WEIGHTS
# ═══════════════════════════════════════════════════════════════════
#
# Why it matters: bad init → vanishing/exploding gradients → dead network.
#
# Rule of thumb: match init to activation function.
#   • Kaiming (He)  — designed for ReLU. Scales by √(2/fan_in).
#   • Xavier (Glorot) — designed for tanh/sigmoid. Scales by √(2/(fan_in + fan_out)).
#   • If using GELU/SiLU, Kaiming is a reasonable default.
#
# Biases are almost always initialised to zero.
# LayerNorm/BatchNorm params are fine at their defaults (weight=1, bias=0).

def init_weights(model: nn.Module):
    """Re-initialise every nn.Linear inside `model`, in place.

    Weights get Kaiming-normal init with the ReLU gain; biases, where
    present, are set to zero. All other module types are left untouched.
    """
    linear_layers = (mod for mod in model.modules() if isinstance(mod, nn.Linear))
    for layer in linear_layers:
        nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
        if layer.bias is None:
            continue
        nn.init.zeros_(layer.bias)


# ═══════════════════════════════════════════════════════════════════
# 3. SET UP OPTIMISER (and optional LR schedule)
# ═══════════════════════════════════════════════════════════════════
#
# Common choices:
#   • SGD + momentum — simple, well-understood, needs LR tuning
#   • Adam           — adaptive LR per-param, good default
#   • AdamW          — Adam with decoupled weight decay (preferred)
#
# Learning rate schedule:
#   • Constant        — fine for quick experiments
#   • Cosine annealing — smooth decay, widely used
#   • Warmup + decay   — standard for transformers

def make_optimizer(model, lr=1e-3, weight_decay=1e-2):
    """Build an AdamW optimiser over all of `model`'s parameters."""
    adamw_kwargs = {"lr": lr, "weight_decay": weight_decay}
    return torch.optim.AdamW(model.parameters(), **adamw_kwargs)

def make_scheduler(optimizer, total_steps):
    """Build a cosine-annealing LR schedule with T_max = `total_steps`."""
    cosine_cls = torch.optim.lr_scheduler.CosineAnnealingLR
    return cosine_cls(optimizer, T_max=total_steps)


# ═══════════════════════════════════════════════════════════════════
# 4. TRAINING LOOP
# ═══════════════════════════════════════════════════════════════════
#
# The loop is always:
#   for each epoch:
#     for each batch:
#       1. Forward pass    — prediction = model(input)
#       2. Compute loss    — how wrong are we?
#       3. Backward pass   — compute gradients
#       4. Optimiser step  — update weights
#       5. Zero gradients  — reset for next batch (or clear them just before backward)
#       6. (Optional) LR scheduler step
#       7. Log metrics
#
# Common gotchas:
#   • Forgetting zero_grad → gradients accumulate across batches
#   • Forgetting model.eval() / torch.no_grad() during validation
#   • Stepping the scheduler per-batch vs per-epoch (read the docs)

def train(model, train_loader, val_loader, optimizer, scheduler,
          n_epochs, device, log_dir="runs/experiment"):
    """Train `model` as a classifier and log metrics to TensorBoard.

    Each epoch makes one pass over `train_loader` (forward → cross-entropy
    loss → backward → gradient clipping → optimiser step → scheduler step),
    then scores the model on `val_loader` with `evaluate`. Per-epoch
    scalars are written under the tags ``train/loss_epoch``,
    ``train/accuracy``, ``lr``, ``val/loss`` and ``val/accuracy``, and a
    one-line summary is printed. After the last epoch the model graph is
    logged once.

    Args:
        model:        network mapping a batch x to logits
        train_loader: iterable of (x, y) batches used for optimisation
        val_loader:   iterable of (x, y) batches passed to `evaluate`
        optimizer:    optimiser whose `step()` updates the model's weights
        scheduler:    LR scheduler; stepped once per batch, not per epoch
        n_epochs:     number of full passes over `train_loader`
        device:       device each batch is moved to before the forward pass
        log_dir:      directory the TensorBoard event files are written to

    Returns:
        None. The results are the trained weights and the written logs.
    """
    writer = SummaryWriter(log_dir)

    # try/finally guarantees the event file is flushed and closed even if
    # training raises or is interrupted part-way through.
    try:
        for epoch in range(n_epochs):

            # ── Training phase ────────────────────────────────────
            model.train()                          # enable dropout, etc.
            epoch_loss = 0.0
            n_correct = 0
            n_total = 0

            for x, y in train_loader:
                x, y = x.to(device), y.to(device)

                # Forward
                logits = model(x)                  # (B, n_classes)
                loss = F.cross_entropy(logits, y)  # scalar (batch mean)

                # Backward
                optimizer.zero_grad()              # clear old gradients
                loss.backward()                    # compute new gradients

                # Gradient clipping — prevents exploding gradients
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()                   # update weights
                scheduler.step()                   # update LR (per batch)

                # Accumulate metrics. The loss is a batch mean, so weight
                # it by batch size to get a correct per-sample average
                # even when the last batch is smaller.
                epoch_loss += loss.item() * x.size(0)
                n_correct += (logits.argmax(-1) == y).sum().item()
                n_total += x.size(0)

            # ── Per-epoch logging (summary) ───────────────────────
            train_loss = epoch_loss / n_total
            train_acc = n_correct / n_total
            writer.add_scalar("train/loss_epoch", train_loss, epoch)
            writer.add_scalar("train/accuracy", train_acc, epoch)
            writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch)

            # ── Validation phase ──────────────────────────────────
            val_loss, val_acc = evaluate(model, val_loader, device)
            writer.add_scalar("val/loss", val_loss, epoch)
            writer.add_scalar("val/accuracy", val_acc, epoch)

            # ── Console output ────────────────────────────────────
            print(f"Epoch {epoch+1:3d}/{n_epochs} │ "
                  f"train loss {train_loss:.4f}  acc {train_acc:.3f} │ "
                  f"val loss {val_loss:.4f}  acc {val_acc:.3f} │ "
                  f"lr {optimizer.param_groups[0]['lr']:.2e}")

        # ── Log model graph (once) ────────────────────────────────
        # A single-sample slice of the first batch is used as the example
        # input for the graph.
        sample = next(iter(train_loader))[0][:1].to(device)
        writer.add_graph(model, sample)
    finally:
        writer.close()

    print(f"\nTensorBoard logs saved to: {log_dir}")
    print(f"Run: tensorboard --logdir {Path(log_dir).parent}")


@torch.no_grad()
def evaluate(model, loader, device):
    """Score `model` on every batch in `loader` with gradients disabled.

    Switches the model to eval mode; it is not switched back afterwards.

    Returns:
        (mean cross-entropy per sample, fraction of correct argmax predictions)
    """
    model.eval()                                   # disable dropout, etc.
    loss_sum, hits, seen = 0.0, 0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        logits = model(inputs)
        # Sum rather than average per batch, so a smaller final batch
        # is weighted correctly in the overall mean.
        batch_loss = F.cross_entropy(logits, targets, reduction="sum")
        predictions = logits.argmax(-1)
        loss_sum += batch_loss.item()
        hits += (predictions == targets).sum().item()
        seen += inputs.size(0)
    return loss_sum / seen, hits / seen


# ═══════════════════════════════════════════════════════════════════
# 5. PUTTING IT ALL TOGETHER
# ═══════════════════════════════════════════════════════════════════

def main():
    """Run the end-to-end demo: synthetic data → MLP → AdamW + cosine LR → train."""
    # ── Hyperparameters ───────────────────────────────────────────
    n_samples, input_dim, n_classes = 5000, 64, 10
    hidden_dim = 128
    batch_size, n_epochs, lr = 128, 20, 3e-4
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # ── Synthetic dataset (replace with real data) ────────────────
    # Features are Gaussian noise and labels are drawn uniformly at
    # random, independently of the features. First 80% is the train split.
    features = torch.randn(n_samples, input_dim)
    labels = torch.randint(0, n_classes, (n_samples,))
    n_train = int(0.8 * n_samples)
    train_set = TensorDataset(features[:n_train], labels[:n_train])
    val_set = TensorDataset(features[n_train:], labels[n_train:])
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    # ── Model: two hidden layers, initialised before moving to device ──
    layer_widths = [input_dim, hidden_dim, hidden_dim, n_classes]
    model = MLP(dims=layer_widths, norm=True, dropout=0.1)
    init_weights(model)
    model = model.to(device)

    # ── Optimiser & schedule ──────────────────────────────────────
    # The scheduler is stepped once per batch, so its horizon is the
    # total number of batches across all epochs.
    steps_per_epoch = len(train_loader)
    optimizer = make_optimizer(model, lr=lr)
    scheduler = make_scheduler(optimizer, n_epochs * steps_per_epoch)

    # ── Train ─────────────────────────────────────────────────────
    train(model, train_loader, val_loader, optimizer, scheduler,
          n_epochs, device, log_dir="runs/mlp_demo")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()