"""
Replay Buffers for Q-Learning
==============================
Standard 1-step and N-step replay buffers used by the online training loop.
# sidebar.label: Replay Buffer Implementation
"""
import torch
from typing import Any, NamedTuple, Tuple
# ─── Shared Data Types ────────────────────────────────────────────
class Batch(NamedTuple):
    """A minibatch of transitions, with tensors stacked along batch dim B."""
    s: torch.Tensor       # states (B, *state_shape)
    a: torch.Tensor       # actions (B,) or (B, action_dim)
    r: torch.Tensor       # rewards (B,)
    s_next: torch.Tensor  # next states (B, *state_shape)
    done: torch.Tensor    # terminal mask (B,) — 1.0 if done
Transition = Tuple[Any, int, float, Any, float] # (s, a, r, s_next, done)
# ─── Replay Buffer ───────────────────────────────────────────────
class ReplayBuffer:
    """Uniform-sampling 1-step experience replay.

    Backed by a fixed-size list used as a ring buffer: once `capacity`
    transitions are held, each new one overwrites the oldest.
    """

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.buf: list[Transition] = []
        self.pos: int = 0  # next overwrite slot once the buffer is full

    def add(self, s: Any, a: int, r: float, s_next: Any, done: float) -> None:
        """Insert one transition, evicting the oldest when at capacity."""
        item = (s, a, r, s_next, done)
        if len(self.buf) == self.capacity:
            self.buf[self.pos] = item
        else:
            self.buf.append(item)
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int) -> Batch:
        """Draw `batch_size` transitions uniformly at random, with replacement."""
        picks = torch.randint(len(self.buf), (batch_size,))
        # Transpose the sampled rows into per-field columns: s, a, r, s_next, done.
        cols = list(zip(*(self.buf[j] for j in picks)))
        return Batch(
            s=torch.tensor(cols[0], dtype=torch.float32),
            a=torch.tensor(cols[1], dtype=torch.long),
            r=torch.tensor(cols[2], dtype=torch.float32),
            s_next=torch.tensor(cols[3], dtype=torch.float32),
            done=torch.tensor(cols[4], dtype=torch.float32),
        )

    def __len__(self) -> int:
        return len(self.buf)
# ─── N-Step Replay Buffer ───────────────────────────────────────
class NStepReplayBuffer:
    """Replay buffer that stores n-step transitions.

    Raw transitions accumulate in a rolling window; once the window holds
    `n_step` entries (or the episode ends), the oldest entry is folded into

        (s_0, a_0, r_0 + γ·r_1 + ... + γ^(k-1)·r_(k-1), s_k, done)

    with k ≤ n_step, truncating at episode boundaries, and moved into the
    fixed-capacity ring buffer used for sampling.
    """

    def __init__(self, capacity: int, n_step: int, gamma: float) -> None:
        self.capacity = capacity
        self.n_step = n_step
        self.gamma = gamma
        self.buf: list[Transition] = []
        self.pos: int = 0  # next overwrite slot once the buffer is full
        self.pending: list[Transition] = []  # rolling window awaiting rollup

    def add(self, s: Any, a: int, r: float, s_next: Any, done: float) -> None:
        """Queue a raw transition; emit n-step transitions as they mature."""
        self.pending.append((s, a, r, s_next, done))
        if done:
            # Episode over: every queued transition gets a truncated return.
            while self.pending:
                self._flush_one()
        elif len(self.pending) >= self.n_step:
            # The oldest queued transition now has a full n-step window.
            self._flush_one()

    def _flush_one(self) -> None:
        """Roll up the oldest pending transition and move it into storage."""
        window = self.pending[: self.n_step]
        s_0, a_0 = window[0][0], window[0][1]
        ret = 0.0
        discount = 1.0  # running γ^i, avoids recomputing powers
        for _, _, r_i, s_i_next, done_i in window:
            ret += discount * r_i
            if done_i:
                # Terminal inside the window: stop here and mark done.
                self._store(s_0, a_0, ret, s_i_next, done_i)
                del self.pending[0]
                return
            discount *= self.gamma
        # No terminal reached: bootstrap from the last next-state in the window.
        self._store(s_0, a_0, ret, window[-1][3], 0.0)
        del self.pending[0]

    def _store(self, s: Any, a: int, r: float, s_next: Any, done: float) -> None:
        """Write one finished n-step transition into the ring buffer."""
        item = (s, a, r, s_next, done)
        if len(self.buf) == self.capacity:
            self.buf[self.pos] = item
        else:
            self.buf.append(item)
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int) -> Batch:
        """Draw `batch_size` stored transitions uniformly, with replacement."""
        picks = torch.randint(len(self.buf), (batch_size,))
        # Transpose the sampled rows into per-field columns: s, a, r, s_next, done.
        cols = list(zip(*(self.buf[j] for j in picks)))
        return Batch(
            s=torch.tensor(cols[0], dtype=torch.float32),
            a=torch.tensor(cols[1], dtype=torch.long),
            r=torch.tensor(cols[2], dtype=torch.float32),
            s_next=torch.tensor(cols[3], dtype=torch.float32),
            done=torch.tensor(cols[4], dtype=torch.float32),
        )

    def __len__(self) -> int:
        return len(self.buf)