Generalised Advantage Estimation
Exponentially weighted average of multi-step TD errors that smoothly interpolates between high-bias (TD) and high-variance (Monte Carlo) advantage estimates. The standard advantage estimator in PPO and A2C: A_t = sum over l of (gamma * lambda)^l * delta_{t+l}. Also known as GAE or GAE(lambda).
Intuition
Policy gradient methods need to know "how much better was this action than average?" — that's the advantage. The simplest estimate is the one-step TD error: delta = r + gamma * V(s') - V(s). This is low variance (just one random step) but biased (V might be wrong). At the other extreme, you could use the full Monte Carlo return minus V(s) — unbiased but noisy because it sums up randomness from every future step.
GAE provides a knob, lambda, that blends these extremes. When lambda = 0, you get pure one-step TD (low variance, high bias). When lambda = 1, you get Monte Carlo (high variance, zero bias from bootstrapping). In practice, lambda = 0.95 works well across most tasks — it uses mostly multi-step information while still damping the variance from distant future rewards.
The mechanism is simple: compute the TD error at each timestep, then form a weighted sum where each successive TD error is discounted by (gamma * lambda). This is computed efficiently with a backward pass through the trajectory, accumulating: A_t = delta_t + (gamma * lambda) * A_{t+1}. This single backward sweep is all you need.
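The claim that the single backward sweep reproduces the exponentially weighted sum can be checked directly. The sketch below uses arbitrary synthetic numbers and assumes a single episode with no terminations:

```python
import numpy as np

rng = np.random.default_rng(0)
T, gamma, lam = 5, 0.99, 0.95
rewards = rng.normal(size=T)
values = rng.normal(size=T + 1)  # values[T] is the bootstrap V(s_T)

# One-step TD errors, then the backward recursion A_t = delta_t + gamma * lambda * A_{t+1}
deltas = rewards + gamma * values[1:] - values[:-1]
adv = np.zeros(T)
gae = 0.0
for t in range(T - 1, -1, -1):
    gae = deltas[t] + gamma * lam * gae
    adv[t] = gae

# Direct definition: A_t = sum over l of (gamma * lambda)^l * delta_{t+l}
direct = np.array([sum((gamma * lam) ** l * deltas[t + l]
                       for l in range(T - t)) for t in range(T)])
assert np.allclose(adv, direct)  # identical up to float rounding
```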
One-step TD error (the building block):

delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)

N-step advantage:

A_t^(n) = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1} + gamma^n * V(s_{t+n}) - V(s_t)

GAE (exponentially-discounted sum of TD errors):

A_t^GAE(gamma, lambda) = sum over l >= 0 of (gamma * lambda)^l * delta_{t+l}

This is equivalent to an exponentially-weighted average of n-step advantages:

A_t^GAE(gamma, lambda) = (1 - lambda) * sum over n >= 1 of lambda^(n-1) * A_t^(n)

Recursive computation (the form you actually implement):

A_t = delta_t + gamma * lambda * A_{t+1}

with A_T = 0 at the trajectory boundary. Sweep backward from t = T-1 to 0.

Special cases:
- lambda = 0: A_t = delta_t (one-step TD, high bias, low variance)
- lambda = 1: A_t = sum over l of gamma^l * delta_{t+l} = G_t - V(s_t) (Monte Carlo advantage, zero bias, high variance)
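The special cases can be verified numerically on a toy episode. This is a sketch with arbitrary made-up rewards and value estimates, using the backward recursion:

```python
# Toy 3-step episode with made-up rewards and value estimates
gamma = 0.9
rewards = [1.0, 0.0, 2.0]
values = [0.5, 0.4, 0.3, 0.0]  # V(s_0..s_2) plus bootstrap V(s_3) = 0

def gae(lam):
    adv, g = [0.0] * 3, 0.0
    for t in reversed(range(3)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD error
        g = delta + gamma * lam * g                             # recursion
        adv[t] = g
    return adv

# lambda = 0 recovers the raw one-step TD errors delta_t
print(gae(0.0))  # [0.86, -0.13, 1.7] up to float rounding
# lambda = 1 recovers the Monte Carlo advantage G_t - V(s_t)
G0 = rewards[0] + gamma * rewards[1] + gamma**2 * rewards[2]
print(gae(1.0)[0], G0 - values[0])  # both 2.12 up to float rounding
```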
```python
import torch

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """
    Compute GAE advantages for a batch of trajectories.
    rewards: (T,) rewards at each timestep
    values:  (T+1,) value estimates — values[T] is V(s_T), the bootstrap
    dones:   (T,) 1.0 if episode ended at this step, else 0.0
    Returns: advantages (T,), returns (T,)
    """
    T = len(rewards)
    advantages = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        # If done, next state value is 0 (no bootstrapping past episode end)
        next_value = values[t + 1] * (1 - dones[t])
        delta = rewards[t] + gamma * next_value - values[t]  # TD error
        gae = delta + gamma * lam * (1 - dones[t]) * gae     # accumulate
        advantages[t] = gae
    returns = advantages + values[:T]  # (T,) — GAE return targets for value fn
    return advantages, returns
```
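Once computed (and normalised, as in the usage snippet below), the advantages drive PPO's clipped surrogate objective. A minimal numpy sketch; the function name and the clip_eps default of 0.2 are illustrative, not taken from a specific library:

```python
import numpy as np

def ppo_policy_loss(new_logp, old_logp, advantages, clip_eps=0.2):
    """Clipped surrogate loss from PPO; inputs are per-action log-probs."""
    ratio = np.exp(new_logp - old_logp)               # pi_new / pi_old
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return -np.minimum(unclipped, clipped).mean()     # negate: we minimise
```

Because the advantages are normalised, the fixed clip range operates on a consistent scale from batch to batch.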
```python
# ── Usage in PPO ──────────────────────────────────────────────
# advantages are normalised before computing the policy loss:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
```

Manual Implementation
```python
import numpy as np

def compute_gae_numpy(rewards, values, dones, gamma=0.99, lam=0.95):
    """
    Equivalent to the PyTorch version above, in pure numpy.
    rewards: (T,) array of rewards
    values:  (T+1,) array of value estimates (last entry is bootstrap value)
    dones:   (T,) array of done flags (1.0 = terminal)
    Returns: advantages (T,), returns (T,)
    """
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float64)  # float64 avoids accumulation error
    gae = 0.0

    # Backward sweep: A_t = delta_t + gamma * lambda * A_{t+1}
    for t in range(T - 1, -1, -1):
        mask = 1.0 - dones[t]                   # 0 at episode boundary
        next_val = values[t + 1] * mask         # no bootstrap past done
        delta = rewards[t] + gamma * next_val - values[t]  # TD error
        gae = delta + gamma * lam * mask * gae             # recursive GAE
        advantages[t] = gae

    returns = advantages + values[:T]  # V(s) + A(s) = target return
    return advantages, returns
```

Popular Uses
Section titled “Popular Uses”- PPO (see
policy-gradient/): GAE is the default advantage estimator. PPO collects a trajectory, computes GAE, then runs multiple epochs of clipped policy updates on those advantages - A2C / A3C (see
policy-gradient/): uses GAE (or n-step returns, a special case) to compute advantages for the policy gradient - TRPO: the original GAE paper (Schulman et al., 2016) was co-developed with TRPO and designed specifically for trust-region policy optimisation
- IMPALA, APPO: distributed actor-critic methods that compute GAE on collected trajectories before sending updates to the learner
- RL from Human Feedback (RLHF): PPO applied to language models uses GAE with a KL penalty reward to fine-tune LLMs
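The KL-penalty reward mentioned for RLHF can be sketched as follows. This is an illustrative shaping: the helper name, the beta default, and placing the reward-model score on the final token are assumptions, though all are common in practice:

```python
import numpy as np

def kl_shaped_rewards(rm_score, logp_policy, logp_ref, beta=0.1):
    """Per-token rewards for RLHF PPO: KL penalty plus terminal RM score."""
    kl = logp_policy - logp_ref   # per-token log-ratio (KL estimate)
    rewards = -beta * kl          # penalise drift from the reference model
    rewards[-1] += rm_score       # reward model scores the full response
    return rewards
```

GAE is then computed over these per-token rewards exactly as above, treating each generated response as one episode.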
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| One-step TD error | Simple environments, fast critic learning | Lowest variance but highest bias; fine when V is accurate |
| N-step returns | Fixed horizon, simple implementation | Discrete choice of n rather than smooth interpolation; no lambda to tune |
| Monte Carlo returns | Short episodes, need unbiased signal | Zero bootstrap bias but variance grows linearly with episode length |
| V-trace (IMPALA) | Off-policy correction needed (distributed RL) | Truncated importance weights prevent high-variance off-policy corrections |
| Retrace(lambda) | Safe off-policy advantage estimation | Provably convergent off-policy; more complex than GAE |
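For contrast with GAE's smooth lambda weighting, the n-step estimator from the table uses a hard cutoff at n. A sketch under stated assumptions (the helper name is illustrative; values must include a bootstrap entry at index T):

```python
import numpy as np

def n_step_advantage(rewards, values, n, gamma=0.99):
    """A_t^(n) = r_t + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * V(s_{t+n}) - V(s_t)."""
    T = len(rewards)
    adv = np.zeros(T)
    for t in range(T):
        k = min(n, T - t)                         # truncate at trajectory end
        ret = sum(gamma ** l * rewards[t + l] for l in range(k))
        ret += gamma ** k * values[t + k]         # bootstrap from the critic
        adv[t] = ret - values[t]
    return adv
```

GAE(lambda) is exactly the (1 - lambda)-weighted geometric mixture of these estimators, which is why n-step returns appear as a special case.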
Historical Context
GAE was introduced by Schulman, Moritz, Levine, Jordan, and Abbeel in 2016 ("High-Dimensional Continuous Control Using Generalized Advantage Estimation"). It built on the much older idea of eligibility traces and TD(lambda) from Sutton (1988), but recast specifically for policy gradient methods where you need advantage estimates rather than value estimates.
The practical impact was enormous: GAE solved the “what advantage estimator should I use?” question with a principled, tunable answer. The lambda = 0.95, gamma = 0.99 defaults from the original paper remain the standard in virtually every PPO implementation today, including those used for RLHF in ChatGPT and similar systems. The recursive backward computation is trivially parallelisable across independent episodes and adds negligible overhead.