# Unified Policy Gradient Algorithm — Implementation

"""
Unified Policy Gradient Algorithm
===================================
A single skeleton that covers: REINFORCE, REINFORCE with baseline,
Advantage Actor-Critic (A2C, i.e. vanilla policy gradient with a GAE
critic), and PPO (clip). Every variant can add an entropy bonus.

The core idea shared by ALL policy gradient methods:
  ∇J(θ) ≈ E[ Ψ · ∇log π(a|s) ]

  where Ψ is some measure of "how good was this action."
  Every variant just changes what Ψ is and how the gradient
  is used.

The pluggable components are:
  1. compute_advantages()  — what Ψ is (returns, advantages, GAE, ...)
  2. policy_loss()         — how the gradient signal is shaped
                             (vanilla, clipped surrogate, ...)
  3. value_loss()          — how the value baseline is trained (if any)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import NamedTuple


# ─── Shared Data Types ────────────────────────────────────────────

class Rollout(NamedTuple):
    """Fixed-length batch of transitions gathered with the current policy.

    The batch is a flat time series of T steps. It may span several
    episodes and may stop part-way through the last one, so it is not
    guaranteed to consist of complete trajectories.

    Fields (intended shapes; T = number of steps):
        states:    (T, *state_shape)
        actions:   (T,) for discrete actions, or (T, action_dim)
        rewards:   (T,)
        dones:     (T,) 1.0 at episode boundaries, else 0.0
        log_probs: (T,) log π_old(a|s) recorded at collection time
        values:    (T,) V(s) estimates recorded at collection time
                   (all zeros when no baseline is used)
    """

    states: torch.Tensor
    actions: torch.Tensor
    rewards: torch.Tensor
    dones: torch.Tensor
    log_probs: torch.Tensor
    values: torch.Tensor


# ─── Utility: Generalised Advantage Estimation (GAE) ──────────────
#
# GAE is the standard way to compute advantages in modern policy
# gradient methods. It interpolates between:
#   λ=0 → TD(0) advantage:  A = r + γV(s') - V(s)     (low variance, high bias)
#   λ=1 → Monte Carlo:      A = R_t - V(s)             (high variance, low bias)
#
# Common defaults (also the defaults used below) are λ=0.95, γ=0.99.

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95, last_value=0.0):
    """Compute GAE(γ, λ) advantages and the matching value targets.

    Args:
        rewards: length-T per-step rewards.
        values: length-T V(s_t) estimates recorded during collection.
        dones: length-T flags, 1 where the episode ended after step t.
        gamma: discount factor γ.
        lam: GAE λ (0 → one-step TD advantage, 1 → Monte Carlo advantage).
        last_value: V(s_T), the value of the state that follows the final
            step. It is used to bootstrap when the batch is cut off
            mid-episode. The default 0.0 keeps the previous behaviour,
            which treats the cut as terminal. It has no effect when the
            last done flag is 1.

    Returns:
        (advantages, returns), both of length T, where
        returns = advantages + values (since A = R - V).
    """
    values = torch.as_tensor(values)
    T = len(rewards)
    # Allocate on the same device as `values` so `advantages + values`
    # works when the critic runs on an accelerator.
    advantages = torch.zeros(T, device=values.device)
    gae = 0.0
    next_val = last_value                      # V(s_{t+1}) for the current t
    for t in reversed(range(T)):
        not_done = 1 - dones[t]                # 0 at an episode end: no bootstrap
        delta = rewards[t] + gamma * next_val * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_val = values[t]
    returns = advantages + values              # A = R - V, so R = A + V
    return advantages, returns


# ─── Utility: Simple discounted returns (no baseline) ─────────────

def compute_discounted_returns(rewards, dones, gamma=0.99):
    """Return the per-step discounted return G_t = r_t + γ·G_{t+1}.

    The running sum is reset wherever ``dones[t]`` is 1, so returns never
    leak across episode boundaries. The sum starts from 0 after the last
    step, so an episode cut off by the end of the batch is treated as if
    it ended there.

    Args:
        rewards: length-T sequence of per-step rewards.
        dones: length-T sequence, 1 where an episode ended at that step.
        gamma: discount factor.

    Returns:
        Float tensor of shape (T,).
    """
    n = len(rewards)
    out = torch.zeros(n)
    running = 0.0
    for idx in range(n - 1, -1, -1):
        keep = 1 - dones[idx]                  # 0 at an episode end → drop the tail
        running = rewards[idx] + gamma * running * keep
        out[idx] = running
    return out


# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM  (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════

class PolicyGradient(ABC):
    """
    The universal policy gradient training step.

    Every variant inherits this and only overrides:
      - compute_advantages(rollout) -> (advantages, returns)
      - policy_loss(log_probs, old_log_probs, advantages) -> Tensor
      - value_loss(values, returns) -> Tensor  (optional)

    Variants that need a different update schedule (e.g. several minibatch
    epochs over one rollout) may also override ``update``.
    """

    def __init__(self, policy, optimizer, entropy_coeff=0.01,
                 value_net=None, value_optimizer=None):
        """
        Args:
            policy: callable mapping a batch of states to a distribution-like
                object exposing ``log_prob(actions)`` and ``entropy()``.
            optimizer: optimizer over the policy's parameters.
            entropy_coeff: weight of the entropy bonus subtracted from the
                policy loss.
            value_net: optional critic mapping states to values. Its output's
                trailing dim is squeezed (presumably (T, 1) — confirm).
            value_optimizer: optimizer for ``value_net``. The value update
                runs only when both ``value_net`` and this are provided.
        """
        self.policy = policy
        self.optimizer = optimizer
        self.entropy_coeff = entropy_coeff
        self.value_net = value_net
        self.value_optimizer = value_optimizer

    # ── The three pluggable pieces ────────────────────────────────

    @abstractmethod
    def compute_advantages(self, rollout: "Rollout"):
        """Return (advantages, returns) tensors of shape (T,)."""
        ...

    @abstractmethod
    def policy_loss(self, log_probs, old_log_probs, advantages):
        """Return the scalar policy loss to minimise."""
        ...

    def value_loss(self, values, returns):
        """Default: MSE. Override if you want clipped value loss, etc."""
        return F.mse_loss(values, returns)

    # ── Helpers ───────────────────────────────────────────────────

    @staticmethod
    def _normalize_advantages(advantages):
        """Standardise advantages to zero mean and (approximately) unit std.

        With fewer than two elements the sample standard deviation is
        undefined (torch returns NaN), which would turn the whole update
        into NaN. In that case the advantages are returned unchanged.
        """
        if advantages.numel() < 2:
            return advantages
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # ── Core single-pass update step (PPO overrides this) ─────────

    def update(self, rollout: "Rollout"):
        """Take one policy step, then one value step if a critic is configured.

        Returns:
            dict of floats: ``policy_loss`` (including the entropy term),
            ``value_loss`` (0.0 without a critic) and ``entropy``.
        """
        advantages, returns = self.compute_advantages(rollout)

        # Normalise advantages (nearly universal, stabilises training)
        advantages = self._normalize_advantages(advantages)

        # ── Policy update ─────────────────────────────────────────
        dist = self.policy(rollout.states)                  # action distribution
        # NOTE(review): (T,) for discrete actions. A multi-dim continuous
        # distribution may return (T, action_dim) here — confirm.
        log_probs = dist.log_prob(rollout.actions)
        entropy = dist.entropy().mean()                     # scalar

        loss_pi = self.policy_loss(log_probs, rollout.log_probs, advantages)
        loss_pi = loss_pi - self.entropy_coeff * entropy    # encourage exploration

        self.optimizer.zero_grad()
        loss_pi.backward()
        nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=0.5)
        self.optimizer.step()

        # ── Value update (if we have a value network) ─────────────
        loss_v = torch.tensor(0.0)
        if self.value_net is not None and self.value_optimizer is not None:
            v = self.value_net(rollout.states).squeeze(-1)  # (T,)
            loss_v = self.value_loss(v, returns.detach())

            self.value_optimizer.zero_grad()
            loss_v.backward()
            nn.utils.clip_grad_norm_(self.value_net.parameters(), max_norm=0.5)
            self.value_optimizer.step()

        return {"policy_loss": loss_pi.item(),
                "value_loss": loss_v.item(),
                "entropy": entropy.item()}


# ═══════════════════════════════════════════════════════════════════
# TRAINING LOOP  (collect rollout → update → repeat)
# ═══════════════════════════════════════════════════════════════════
#
# Key difference from Q-learning: policy gradient methods are
# ON-POLICY — you must collect fresh data with the CURRENT policy,
# use it for one (or a few) updates, then throw it away.
# (This is why they are typically less sample-efficient than off-policy
# methods such as Q-learning, which can replay old transitions.)

def collect_rollout(policy, value_net, env, n_steps, device="cpu"):
    """Roll out the current policy in the environment for n_steps.

    Always starts from ``env.reset()``, performs exactly ``n_steps``
    environment steps (resetting whenever an episode ends), and packs the
    transitions into a ``Rollout``. The last episode in the batch may be
    incomplete.

    Args:
        policy: maps a (1, *state_shape) batch to a distribution exposing
            ``sample()`` and ``log_prob()``.
        value_net: optional critic. When None, every recorded value is 0.
        env: Gymnasium-style environment. ``reset()`` returns
            ``(obs, info)`` and ``step()`` returns
            ``(obs, reward, terminated, truncated, info)``.
        n_steps: number of transitions to collect (must be >= 1, otherwise
            stacking the empty lists fails).
        device: device for the state tensors fed to the networks.

    Note:
        ``dones`` records both termination and truncation as 1.0, so the
        advantage code treats a time-limit cut like a true terminal state
        and does not bootstrap V(s') there.
        TODO(review): store truncation separately if that bias matters.
    """
    states, actions, rewards, dones, log_probs, values = [], [], [], [], [], []

    s, _ = env.reset()
    for _ in range(n_steps):
        # torch.tensor always copies, so an env that reuses its observation
        # buffer cannot alter what was already recorded.
        s_t = torch.tensor(s, dtype=torch.float32, device=device)
        with torch.no_grad():
            dist = policy(s_t.unsqueeze(0))                 # batch of one
            a = dist.sample()
            lp = dist.log_prob(a)
            # Test `is not None` rather than truthiness: container modules
            # such as nn.Sequential define __len__, so an empty one is falsy.
            if value_net is not None:
                v = value_net(s_t.unsqueeze(0)).squeeze()
            else:
                v = torch.tensor(0.0)

        s_next, r, done, trunc, _ = env.step(a.squeeze(0).cpu().numpy())

        states.append(s_t)
        actions.append(a.squeeze(0))
        rewards.append(r)
        dones.append(float(done or trunc))
        log_probs.append(lp.squeeze(0))
        values.append(v)

        s = s_next
        if done or trunc:
            s, _ = env.reset()

    return Rollout(
        states=torch.stack(states),
        actions=torch.stack(actions),
        # Pin float32: envs may return Python floats or NumPy float64
        # rewards, and the rollout dtype should not depend on which.
        rewards=torch.tensor(rewards, dtype=torch.float32),
        dones=torch.tensor(dones, dtype=torch.float32),
        log_probs=torch.stack(log_probs),
        values=torch.stack(values),
    )


def train(algo: PolicyGradient, policy, value_net, env,
          n_iters=500, rollout_len=2048, device="cpu"):
    """Alternate on-policy collection and a single ``algo.update`` call.

    Each iteration gathers a fresh rollout of ``rollout_len`` steps with the
    current ``policy`` (and ``value_net`` for baseline values), updates
    ``algo`` on it, and then discards it. Metrics are printed every 10th
    iteration.
    """
    log_every = 10
    for iteration in range(1, n_iters + 1):
        batch = collect_rollout(policy, value_net, env, rollout_len, device)
        stats = algo.update(batch)

        if iteration % log_every:
            continue
        print(f"Iter {iteration:4d} │ "
              f"π loss {stats['policy_loss']:+.4f}  "
              f"V loss {stats['value_loss']:.4f}  "
              f"entropy {stats['entropy']:.3f}")


# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS  (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════

# ── 1. REINFORCE (vanilla, no baseline) ──────────────────────────

class REINFORCE(PolicyGradient):
    """
    Ψ = G_t (discounted return from time t)

    The simplest possible policy gradient. It has high variance because
    the raw return G_t sums every future reward until the episode ends, so
    the credit for one action is mixed with the noise of all later ones.

    Notes:
      - G_t comes from ``compute_discounted_returns``, which starts from 0
        after the last collected step. An episode cut off by the end of the
        rollout therefore gets a truncated return.
      - ``PolicyGradient.update`` standardises the advantages before calling
        ``policy_loss``. Subtracting their mean acts like a constant baseline.
    """

    def __init__(self, policy, optimizer, gamma=0.99, **kw):
        """Store the discount ``gamma``; other kwargs go to PolicyGradient."""
        super().__init__(policy, optimizer, **kw)
        self.gamma = gamma

    def compute_advantages(self, rollout):
        """Return (G, G). With no baseline, the advantage is the return itself."""
        returns = compute_discounted_returns(rollout.rewards, rollout.dones, self.gamma)
        return returns, returns

    def policy_loss(self, log_probs, old_log_probs, advantages):
        """Classic REINFORCE loss −E[G_t · log π(a|s)]; ``old_log_probs`` is unused.

        Advantages are detached, as in the other variants, so gradient only
        flows through ``log_probs``.
        """
        return -(log_probs * advantages.detach()).mean()


# ── 2. REINFORCE with learned baseline ───────────────────────────

class REINFORCEBaseline(PolicyGradient):
    """
    Ψ = G_t − V(s_t)

    Same Monte Carlo return as plain REINFORCE, minus the value estimate
    recorded at collection time. Subtracting a state-dependent baseline
    leaves the expected gradient unchanged but reduces its variance. The
    signal becomes "how much better than expected was this action" instead
    of "how good was everything that followed".
    """

    def __init__(self, policy, optimizer, gamma=0.99, **kwargs):
        """Store the discount ``gamma``; other kwargs go to PolicyGradient."""
        super().__init__(policy, optimizer, **kwargs)
        self.gamma = gamma

    def compute_advantages(self, rollout):
        """Return (G_t − V(s_t), G_t), where G_t is the discounted MC return."""
        mc_returns = compute_discounted_returns(
            rollout.rewards, rollout.dones, self.gamma)
        return mc_returns - rollout.values, mc_returns

    def policy_loss(self, log_probs, old_log_probs, advantages):
        """−E[A · log π(a|s)] with A detached; ``old_log_probs`` is unused."""
        weighted = log_probs * advantages.detach()
        return weighted.mean().neg()


# ── 3. A2C  (Advantage Actor-Critic) ─────────────────────────────

class A2C(PolicyGradient):
    """
    Ψ = GAE(δ_t)  where δ_t = r + γV(s') − V(s)

    Replaces the full Monte Carlo return with a generalised advantage
    estimate. This trades a little bias (tuned by λ) for lower variance.
    Apart from how Ψ is computed, it matches REINFORCE with baseline.
    """

    def __init__(self, policy, optimizer, gamma=0.99, lam=0.95, **kwargs):
        """Store discount ``gamma`` and GAE ``lam``; other kwargs go to the base."""
        super().__init__(policy, optimizer, **kwargs)
        self.gamma, self.lam = gamma, lam

    def compute_advantages(self, rollout):
        """Return (advantages, returns) from GAE over the rollout."""
        r = rollout
        return compute_gae(r.rewards, r.values, r.dones,
                           gamma=self.gamma, lam=self.lam)

    def policy_loss(self, log_probs, old_log_probs, advantages):
        """−E[A · log π(a|s)] with A detached; ``old_log_probs`` is unused."""
        weighted = log_probs * advantages.detach()
        return weighted.mean().neg()


# ── 4. PPO  (Proximal Policy Optimisation, clipped) ──────────────

class PPO(PolicyGradient):
    """
    Same advantages as A2C (GAE), but changes how the gradient is used.

    Problem: vanilla policy gradient takes one big step, which can
    destroy the policy (performance collapses and never recovers).

    Solution: instead of −log π · A, use a clipped surrogate objective
    that prevents the policy ratio π/π_old from moving too far from 1.

        L = min( ratio · A,  clip(ratio, 1−ε, 1+ε) · A )

    This is the key insight of PPO: constrain the update size without
    the complexity of TRPO's KL constraint.
    """

    def __init__(self, policy, optimizer, gamma=0.99, lam=0.95,
                 clip_eps=0.2, n_policy_epochs=4, minibatch_size=64, **kw):
        """
        Args:
            gamma, lam: GAE discount and λ.
            clip_eps: ε in the clipped surrogate.
            n_policy_epochs: passes over each rollout per ``update`` call.
            minibatch_size: transitions per gradient step.
            **kw: forwarded to PolicyGradient (entropy_coeff, value_net, ...).
        """
        super().__init__(policy, optimizer, **kw)
        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.n_policy_epochs = n_policy_epochs
        self.minibatch_size = minibatch_size

    def compute_advantages(self, rollout):
        """Return (advantages, returns) from GAE over the rollout."""
        return compute_gae(rollout.rewards, rollout.values,
                           rollout.dones, self.gamma, self.lam)

    def policy_loss(self, log_probs, old_log_probs, advantages):
        """Negative clipped surrogate −E[min(r·A, clip(r, 1±ε)·A)]."""
        ratio = (log_probs - old_log_probs).exp()           # π_new / π_old
        clipped = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps)
        return -torch.min(ratio * advantages, clipped * advantages).mean()

    # ── PPO overrides update() to do multiple epochs on the same data ──

    def update(self, rollout: "Rollout"):
        """Run ``n_policy_epochs`` shuffled minibatch passes over one rollout.

        PPO reuses the same rollout for several epochs of minibatch updates.
        This makes it more sample-efficient than vanilla PG, while the
        clipping keeps each step conservative.

        Returns:
            dict of floats (``policy_loss``, ``value_loss``, ``entropy``),
            averaged over all minibatch steps. All are 0.0 when no step ran
            (empty rollout or ``n_policy_epochs == 0``).
        """
        advantages, returns = self.compute_advantages(rollout)
        # std() of a single element is NaN; skip normalisation in that case.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        T = len(rollout.states)
        totals = {"policy_loss": 0.0, "value_loss": 0.0, "entropy": 0.0}
        n_updates = 0

        for _ in range(self.n_policy_epochs):
            indices = torch.randperm(T)
            for start in range(0, T, self.minibatch_size):
                idx = indices[start:start + self.minibatch_size]
                step = self._minibatch_step(rollout, idx, advantages, returns)
                for key, val in step.items():
                    totals[key] += val
                n_updates += 1

        if n_updates == 0:
            return totals
        return {k: v / n_updates for k, v in totals.items()}

    def _minibatch_step(self, rollout, idx, advantages, returns):
        """Apply one policy step and one critic step on rows ``idx``; return float metrics."""
        dist = self.policy(rollout.states[idx])
        log_probs = dist.log_prob(rollout.actions[idx])
        entropy = dist.entropy().mean()

        # Policy loss (clipped surrogate) plus entropy bonus
        loss_pi = self.policy_loss(
            log_probs, rollout.log_probs[idx], advantages[idx])
        loss_pi = loss_pi - self.entropy_coeff * entropy

        self.optimizer.zero_grad()
        loss_pi.backward()
        nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=0.5)
        self.optimizer.step()

        # Value loss (only when a critic and its optimizer are configured)
        loss_v = torch.tensor(0.0)
        if self.value_net is not None and self.value_optimizer is not None:
            v = self.value_net(rollout.states[idx]).squeeze(-1)
            loss_v = self.value_loss(v, returns[idx].detach())

            self.value_optimizer.zero_grad()
            loss_v.backward()
            nn.utils.clip_grad_norm_(self.value_net.parameters(), max_norm=0.5)
            self.value_optimizer.step()

        return {"policy_loss": loss_pi.item(),
                "value_loss": loss_v.item(),
                "entropy": entropy.item()}