# Unified Q-Learning Algorithm — Implementation

"""
Unified Q-Learning Algorithm
=============================
A single skeleton that covers: Q-Learning, DQN, Double DQN, Dueling DQN,
CQL, IQL, SAC (Q-critic), and more.

The core loop is IDENTICAL across all variants. Only three pluggable
components change:
  1. compute_target()   — how the bootstrap target y is built
  2. compute_loss()     — the objective (MSE, Huber, + regularizers)
  3. data source        — online replay buffer vs. offline dataset
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional
from replay_buffers import Batch, ReplayBuffer, NStepReplayBuffer

# Batch is a NamedTuple of tensors, one row per sample in the mini-batch:
#   s      (B, *state_shape)  — states
#   a      (B,)               — actions taken
#   r      (B,)               — rewards (1-step or n-step discounted return)
#   s_next (B, *state_shape)  — next states (1 or n steps ahead)
#   done   (B,)               — 1.0 if episode ended, 0.0 otherwise
#
# A replay buffer exposes two methods:
#   add(s, a, r, s_next, done)  — store a single transition
#   sample(batch_size) -> Batch — draw a random mini-batch for training
# See replay_buffers.py for ReplayBuffer (1-step) and NStepReplayBuffer.


# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM  (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════

class QAlgorithm(ABC):
    """
    Shared training step for value-based (Q-learning style) algorithms.

    Subclasses supply the pieces that differ between variants:
      - compute_target(batch) -> Tensor               (required)
      - compute_loss(q_a, targets, batch) -> Tensor   (optional, default MSE)

    Args:
        Q: online Q-network. Called as ``Q(states)``; its output is indexed
            along dim 1 by the action, i.e. one value per discrete action.
        Q_target: target Q-network, kept in sync with ``Q`` by this class.
        optimizer: stepped once per ``update`` call. Presumably built over
            ``Q.parameters()`` — this class does not check.
        gamma: discount factor; stored for subclasses to use in their targets.
        target_update_freq: number of ``update`` calls between hard copies of
            ``Q`` into ``Q_target``. Only consulted when ``tau`` is None.
        tau: if not None, Polyak-average ``Q`` into ``Q_target`` after every
            update instead of doing periodic hard copies.
    """

    def __init__(self, Q: nn.Module, Q_target: nn.Module,
                 optimizer: torch.optim.Optimizer, gamma: float = 0.99,
                 target_update_freq: int = 1000, tau: Optional[float] = None) -> None:
        self.Q = Q                # online  Q-network
        self.Q_target = Q_target  # target  Q-network
        self.optimizer = optimizer
        self.gamma = gamma
        self.target_update_freq = target_update_freq
        self.tau = tau            # if set, use Polyak averaging instead of hard copy
        self._step: int = 0       # number of completed update() calls

    # ── The two pluggable pieces ──────────────────────────────────

    @abstractmethod
    def compute_target(self, batch: Batch) -> torch.Tensor:
        """
        Return the bootstrap target y for each sample in ``batch``.

        ``update`` calls this inside ``torch.no_grad()``, so no gradient
        flows through whatever is computed here.
        """
        ...

    def compute_loss(self, q_a: torch.Tensor, targets: torch.Tensor,
                     batch: Batch) -> torch.Tensor:
        """
        Return the scalar training loss.

        Default: plain MSE between ``q_a`` and ``targets``. Override to use a
        different objective (e.g. Huber) or to add regularisers (e.g. CQL).
        ``batch`` is unused here but passed so overrides can access it.
        """
        return F.mse_loss(q_a, targets)

    # ── Core update step (IDENTICAL for every variant) ────────────

    def update(self, batch: Batch) -> float:
        """Run one gradient step on ``batch`` and return the loss as a float."""
        # 1. Current Q-values for the actions actually taken
        q_all = self.Q(batch.s)                              # (B, |A|)
        # gather() only accepts int64 indices. Cast so batches that store
        # actions as int32 / float also work; this is a no-op for int64.
        a_idx = batch.a.long().unsqueeze(-1)                 # (B, 1)
        q_a = q_all.gather(1, a_idx).squeeze(-1)             # (B,)

        # 2. Compute bootstrap target  (this is the part that varies)
        with torch.no_grad():
            targets = self.compute_target(batch)             # (B,)

        # 3. Loss  (varies: plain MSE, Huber, +CQL penalty, expectile, ...)
        loss = self.compute_loss(q_a, targets, batch)

        # 4. Gradient step  (always the same)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # 5. Target network update  (always the same)
        self._step += 1
        self._update_target_network()

        return loss.item()

    def _update_target_network(self) -> None:
        """Sync ``Q_target`` with ``Q``: soft update if ``tau`` is set, else periodic hard copy."""
        if self.tau is not None:
            # Polyak / soft update  (used by SAC, TD3, etc.)
            # Done under no_grad rather than through `.data`, so autograd's
            # bookkeeping of the in-place write stays intact.
            # NOTE(review): only parameters are averaged; module buffers
            # (e.g. BatchNorm running statistics) are not touched here.
            with torch.no_grad():
                for p, pt in zip(self.Q.parameters(), self.Q_target.parameters()):
                    pt.copy_(self.tau * p + (1 - self.tau) * pt)
        elif self._step % self.target_update_freq == 0:
            # Hard copy  (used by DQN, CQL, etc.)
            self.Q_target.load_state_dict(self.Q.state_dict())


# ═══════════════════════════════════════════════════════════════════
# ONLINE TRAINING LOOP  (DQN-style: interact → store → sample → learn)
# ═══════════════════════════════════════════════════════════════════

def train_online(algo: QAlgorithm, env: Any,
                 replay_buffer: ReplayBuffer | NStepReplayBuffer,
                 n_steps: int, batch_size: int = 256, warmup: int = 1000,
                 eps_schedule: Optional[Callable[[int], float]] = None) -> None:
    """
    Interact with ``env`` for ``n_steps`` steps, storing every transition in
    ``replay_buffer`` and training ``algo`` on sampled mini-batches.

    Args:
        algo: algorithm whose ``Q`` network picks greedy actions and whose
            ``update(batch)`` is called once per environment step after warmup.
        env: environment whose ``reset()`` returns a 2-tuple starting with the
            observation and whose ``step(a)`` returns
            ``(s_next, r, done, trunc, info)``; ``action_space.sample()`` is
            used for random actions.
        replay_buffer: must support ``add``, ``sample`` and ``len()``.
        n_steps: total number of environment steps to take.
        batch_size: mini-batch size passed to ``replay_buffer.sample``.
        warmup: learning starts once the buffer holds at least this many items.
        eps_schedule: maps the step index ``t`` to the exploration rate ε.
            When None, a constant ε = 0.1 is used.
    """
    # Build observation tensors on the Q-network's device so greedy action
    # selection also works when the network has been moved off the CPU.
    first_param = next(algo.Q.parameters(), None)
    device = first_param.device if first_param is not None else None

    s, _ = env.reset()
    for t in range(n_steps):
        # ε-greedy (or swap in Boltzmann, UCB, etc.)
        eps = eps_schedule(t) if eps_schedule is not None else 0.1
        if torch.rand(1).item() < eps:
            a = env.action_space.sample()
        else:
            with torch.no_grad():
                # as_tensor avoids the extra copy (and the warning) that
                # torch.tensor() produces when `s` is already a tensor.
                obs = torch.as_tensor(s, device=device).unsqueeze(0)
                a = algo.Q(obs).argmax(-1).item()

        s_next, r, done, trunc, _ = env.step(a)
        # NOTE(review): truncation is stored as done=1.0, which makes the
        # variants in this file skip bootstrapping at time-limit cut-offs.
        # Kept as-is because the buffer's add() takes a single end-of-episode
        # flag — confirm against replay_buffers.py before changing.
        replay_buffer.add(s, a, r, s_next, float(done or trunc))
        s = s_next
        if done or trunc:
            s, _ = env.reset()

        # Learn from replay
        if len(replay_buffer) >= warmup:
            batch = replay_buffer.sample(batch_size)
            algo.update(batch)


# ═══════════════════════════════════════════════════════════════════
# OFFLINE TRAINING LOOP  (CQL / IQL-style: just iterate over dataset)
# ═══════════════════════════════════════════════════════════════════

def train_offline(algo: QAlgorithm, dataset: Any, n_steps: int,
                  batch_size: int = 256) -> None:
    """
    Train ``algo`` for ``n_steps`` updates on a fixed dataset.

    Each iteration draws one mini-batch via ``dataset.sample(batch_size)``
    and feeds it to ``algo.update``; no environment interaction takes place.
    """
    for _ in range(n_steps):
        # One draw from the static data, one gradient step
        algo.update(dataset.sample(batch_size))


# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS  (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════

# ── 1. Vanilla DQN ───────────────────────────────────────────────

class DQN(QAlgorithm):
    """
    Vanilla DQN.

    Target: y = r + γ · (1 − done) · max_a' Q_target(s', a')
    """

    def compute_target(self, batch: Batch) -> torch.Tensor:
        # Greedy value of each next state under the target network
        best_next, _ = self.Q_target(batch.s_next).max(dim=-1)     # (B,)
        # (1 − done) zeroes the bootstrap term where the episode ended
        not_done = 1 - batch.done
        return batch.r + self.gamma * not_done * best_next


# ── 2. Double DQN ────────────────────────────────────────────────

class DoubleDQN(QAlgorithm):
    """
    Double DQN: action selection and action evaluation use different networks.

    Target: y = r + γ · (1 − done) · Q_target(s', argmax_a' Q_online(s', a'))
    """

    def compute_target(self, batch: Batch) -> torch.Tensor:
        s_next = batch.s_next
        # Selection: greedy action according to the ONLINE network
        greedy = self.Q(s_next).argmax(dim=-1, keepdim=True)          # (B, 1)
        # Evaluation: value of that action according to the TARGET network
        evaluated = self.Q_target(s_next).gather(1, greedy).squeeze(-1)
        bootstrap = self.gamma * (1 - batch.done) * evaluated
        return batch.r + bootstrap


# ── 3. CQL  (Conservative Q-Learning, offline) ──────────────────

class CQL(QAlgorithm):
    """
    Conservative Q-Learning for discrete actions.

    Uses the DQN bootstrap target, and augments the TD loss with a term that
    lowers Q across all actions while raising Q on the actions present in
    the batch:

    loss = TD_loss + α · [ log Σ_a exp Q(s,a)  −  Q(s, a_data) ]

    ``cql_alpha`` (keyword-only, default 1.0) is the weight α of that term;
    all other arguments are forwarded to QAlgorithm unchanged.
    """

    def __init__(self, *args: Any, cql_alpha: float = 1.0, **kw: Any) -> None:
        super().__init__(*args, **kw)
        self.cql_alpha = cql_alpha

    def compute_target(self, batch: Batch) -> torch.Tensor:
        # Same bootstrap target as DQN: greedy value under the target network
        best_next = self.Q_target(batch.s_next).max(-1).values
        return batch.r + self.gamma * (1 - batch.done) * best_next

    def compute_loss(self, q_a: torch.Tensor, targets: torch.Tensor,
                     batch: Batch) -> torch.Tensor:
        td_loss = F.mse_loss(q_a, targets)

        # The base class hands over only Q(s, a_data), so Q over every action
        # is obtained with another forward pass of the online network.
        all_action_q = self.Q(batch.s)                       # (B, |A|)
        push_down = all_action_q.logsumexp(dim=-1).mean()    # soft-max over all actions
        push_up = q_a.mean()                                 # actions seen in the data
        conservative_gap = push_down - push_up

        return td_loss + self.cql_alpha * conservative_gap


# ── 4. IQL  (Implicit Q-Learning, offline) ──────────────────────

class IQL(QAlgorithm):
    """
    Implicit Q-Learning.

    Keeps a separate state-value network V(s), fitted by expectile regression
    toward Q_target(s, a_data), and bootstraps the Q target from V(s') so that
    no maximum over actions is ever taken.

    Each ``update`` call performs two optimisation steps, in this order:
      1. V step : expectile regression, using ``v_optimizer``
      2. Q step : the inherited QAlgorithm update, with V(s') in the target

    Keyword-only arguments:
      V           — state-value network
      v_optimizer — optimizer stepped for the V loss
      expectile   — weight given to positive residuals (default 0.7);
                    negative or zero residuals get 1 − expectile
    """

    def __init__(self, *args: Any, V: nn.Module, v_optimizer: torch.optim.Optimizer,
                 expectile: float = 0.7, **kw: Any) -> None:
        super().__init__(*args, **kw)
        self.V = V
        self.v_optimizer = v_optimizer
        self.expectile = expectile

    def compute_target(self, batch: Batch) -> torch.Tensor:
        # y = r + γ · (1 − done) · V(s')
        next_value = self.V(batch.s_next).squeeze(-1)        # (B,)
        return batch.r + self.gamma * (1 - batch.done) * next_value

    def _expectile_value_loss(self, batch: Batch) -> torch.Tensor:
        """Asymmetric squared error between Q_target(s, a_data) and V(s)."""
        # The regression target comes from the target network and is constant
        # with respect to the V step.
        with torch.no_grad():
            action_index = batch.a.unsqueeze(-1)
            q_data = self.Q_target(batch.s).gather(1, action_index).squeeze(-1)

        residual = q_data - self.V(batch.s).squeeze(-1)
        upper, lower = self.expectile, 1 - self.expectile
        weight = torch.where(residual > 0, upper, lower)
        return (weight * residual.pow(2)).mean()

    def update(self, batch: Batch) -> float:
        # ── Step 1: fit V ────────────────────────────────────
        value_loss = self._expectile_value_loss(batch)
        self.v_optimizer.zero_grad()
        value_loss.backward()
        self.v_optimizer.step()

        # ── Step 2: inherited Q update; only the Q loss is returned ──
        return super().update(batch)


# ── 5. Soft Q-Learning / SAC critic ─────────────────────────────

class SoftQ(QAlgorithm):
    """
    Soft Q-learning target, as used for the critic in SAC.

    Target: y = r + γ · (1 − done) · (Q_target(s', a') − α log π(a'|s'))
    with a' drawn from the policy at s'.

    Keyword-only arguments:
      policy — object whose ``sample(states)`` returns ``(actions, log_probs)``
      alpha  — entropy temperature α (default 0.2)
    """

    def __init__(self, *args: Any, policy: Any, alpha: float = 0.2, **kw: Any) -> None:
        super().__init__(*args, **kw)
        self.policy = policy
        self.alpha = alpha

    def compute_target(self, batch: Batch) -> torch.Tensor:
        s_next = batch.s_next
        # Draw a' ~ π(·|s') together with its log-probability
        a_next, log_prob = self.policy.sample(s_next)
        # Target-network value of the sampled action
        q_sampled = self.Q_target(s_next).gather(1, a_next.unsqueeze(-1)).squeeze(-1)
        # Entropy-regularised ("soft") value of the next state
        soft_value = q_sampled - self.alpha * log_prob
        return batch.r + self.gamma * (1 - batch.done) * soft_value


# ── 6. N-Step DQN ──────────────────────────────────────────────

class NStepDQN(QAlgorithm):
    """
    DQN with n-step returns.

    Target: y = r_nstep + γⁿ · (1 − done) · max_a' Q_target(s_n, a')

    This class does not accumulate rewards itself: it expects ``batch.r`` to
    already hold the n-step discounted return and ``batch.s_next`` the state
    n steps ahead (the job of NStepReplayBuffer). Its only difference from
    DQN is that the bootstrap term is discounted by γⁿ instead of γ.

    ``n_step`` (keyword-only, default 3) must match the buffer's horizon.
    """

    def __init__(self, *args: Any, n_step: int = 3, **kw: Any) -> None:
        super().__init__(*args, **kw)
        self.n_step = n_step

    def compute_target(self, batch: Batch) -> torch.Tensor:
        # Discount for a bootstrap that sits n steps in the future
        discount_n = self.gamma ** self.n_step
        # Greedy value of the n-step-ahead state under the target network
        best_next = self.Q_target(batch.s_next).max(dim=-1).values
        return batch.r + discount_n * (1 - batch.done) * best_next