Pooling (Mean / Max / CLS)

Reducing a sequence of vectors to a single vector. When a model produces one vector per token but you need one vector for the whole input — for classification, retrieval, or sentence embeddings — you need a pooling strategy. Mean pooling averages all positions, max pooling takes the element-wise maximum, and CLS pooling uses a dedicated special token.
Intuition
A transformer outputs a matrix of shape (T, d) — one vector per token. But many tasks need a single vector: “is this review positive?”, “how similar are these two sentences?”, “which document matches this query?” Pooling compresses the sequence dimension away.
Mean pooling treats every token as equally important and averages them. This works surprisingly well because the average captures the “centre of mass” of the sequence’s meaning in embedding space. The main pitfall is padding: if your batch has sequences of different lengths padded to the same length, you must mask out the padding tokens before averaging, or the zero-padding will dilute the representation.
Max pooling takes the element-wise maximum across positions. For each dimension of the output vector, it picks the highest activation from any token. This captures the most salient feature for each dimension, regardless of where it appears. It’s less popular than mean pooling for language but dominant in CNN architectures (e.g. max-pool over filter outputs).
CLS pooling uses a special [CLS] token prepended to the input. The model is trained to aggregate sequence-level information into this token through attention. BERT popularised this approach, but it has a subtle problem: the CLS token’s representation depends entirely on what the model has learned to put there, and without explicit training (e.g. next-sentence prediction), it may not contain a useful summary. Modern sentence embedding models (E5, GTE) generally prefer mean pooling over CLS.
Given a sequence of hidden states $h_1, \dots, h_T \in \mathbb{R}^d$ and a mask $m \in \{0,1\}^T$ indicating real (non-padding) tokens:

Mean pooling:

$$v = \frac{\sum_{t=1}^{T} m_t \, h_t}{\sum_{t=1}^{T} m_t}$$

Max pooling (element-wise across the sequence):

$$v_j = \max_{t \,:\, m_t = 1} h_{t,j}$$

Masked positions should be set to $-\infty$ before the max, not 0.

CLS pooling:

$$v = h_1 \quad \text{(the output at the prepended [CLS] position)}$$

No mask needed since CLS is always present.
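The padding pitfall in the mean-pooling formula is easy to see numerically. A tiny NumPy sketch (hypothetical values: one sequence of 2 real tokens padded to length 4) shows how unmasked averaging dilutes the representation:

```python
import numpy as np

# One sequence, T=4, d=2: two real tokens followed by two zero-padding tokens.
hidden = np.array([[[2.0, 4.0],
                    [4.0, 8.0],
                    [0.0, 0.0],    # padding
                    [0.0, 0.0]]])  # padding
mask = np.array([[1, 1, 0, 0]])    # 1 = real token, 0 = padding

naive = hidden.mean(axis=1)  # averages over all 4 positions, padding included
masked = (hidden * mask[:, :, None]).sum(axis=1) / mask.sum(axis=1, keepdims=True)

print(naive)   # [[1.5 3. ]] — diluted by the zero padding
print(masked)  # [[3. 6.]]  — the true average of the 2 real tokens
```

The naive mean shrinks every dimension by the fraction of padding in the sequence, so sequences of different lengths in the same batch become incomparable.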
```python
import torch
import torch.nn as nn

hidden = model(input_ids)  # (B, T, d) — transformer output
mask = attention_mask      # (B, T) — 1 for real tokens, 0 for padding

# ── Mean pooling (with proper masking) ──────────────────────────
mask_expanded = mask.unsqueeze(-1).float()        # (B, T, 1)
sum_hidden = (hidden * mask_expanded).sum(dim=1)  # (B, d)
count = mask_expanded.sum(dim=1).clamp(min=1e-9)  # (B, 1) — avoid div by 0
mean_pooled = sum_hidden / count                  # (B, d)
# NEVER just do hidden.mean(dim=1) — it includes padding tokens.

# ── Max pooling (with proper masking) ───────────────────────────
hidden_masked = hidden.masked_fill(~mask.unsqueeze(-1).bool(), -1e9)
max_pooled = hidden_masked.max(dim=1).values      # (B, d)

# ── CLS pooling ─────────────────────────────────────────────────
cls_pooled = hidden[:, 0, :]  # (B, d)

# ── Adaptive pooling for CNNs (spatial dimensions) ──────────────
# Reduces any spatial size to target size
pool = nn.AdaptiveAvgPool2d((1, 1))
features = pool(conv_output)    # (B, C, H, W) → (B, C, 1, 1)
features = features.flatten(1)  # (B, C)
```

Manual Implementation
```python
import numpy as np

def mean_pool(hidden, mask):
    """
    Mean pooling with mask.
    hidden: (B, T, d)
    mask:   (B, T) with 1=real, 0=padding
    Returns: (B, d)
    """
    mask_exp = mask[:, :, None]                   # (B, T, 1)
    summed = (hidden * mask_exp).sum(axis=1)      # (B, d)
    counts = mask_exp.sum(axis=1).clip(min=1e-9)  # (B, 1)
    return summed / counts                        # (B, d)

def max_pool(hidden, mask):
    """
    Max pooling with mask.
    hidden: (B, T, d)
    mask:   (B, T) with 1=real, 0=padding
    Returns: (B, d)
    """
    masked = np.where(mask[:, :, None], hidden, -1e9)  # (B, T, d)
    return masked.max(axis=1)                          # (B, d)

def cls_pool(hidden):
    """CLS token pooling. hidden: (B, T, d). Returns: (B, d)."""
    return hidden[:, 0, :]  # (B, d)
```

Popular Uses
- Sentence embeddings (Sentence-BERT, E5, GTE): mean pooling over transformer outputs to produce fixed-size sentence vectors for retrieval and similarity
- Text classification (BERT): CLS token fed to a classification head for sentiment analysis, NLI, etc.
- CNN image classification (ResNet, EfficientNet): global average pooling over spatial dimensions before the final linear classifier — replaced the fully connected layers from AlexNet/VGG
- Contrastive learning (see contrastive-self-supervising/): SimCLR and CLIP pool image/text representations to fixed-size vectors before computing similarity
- Reinforcement learning (see policy-gradient/): pooling over observation features (e.g. multiple entities in the environment) before feeding to the policy network
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| Weighted mean (attention pooling) | Need learned, input-dependent aggregation | Small MLP computes per-token weights; more expressive but adds parameters |
| Last token pooling | Autoregressive models (GPT-style) | Uses the final non-padding token; natural for causal models where only the last position has seen everything |
| Multi-head attention pooling (Perceiver) | Need multiple summary vectors | Learned query tokens attend to the sequence; powerful but heavier |
| No pooling (per-token output) | Token-level tasks (NER, translation) | Keep all T outputs; pooling would destroy the per-token information needed |
| Concatenation | Very short sequences (2-4 items) | Preserves all information but output size scales with T; only viable for fixed, short lengths |
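Two rows from the table above can be sketched concretely. The following is a minimal PyTorch illustration (not from the source; it assumes right-padding and a 1/0 attention mask) of last-token pooling and a learned weighted-mean (attention) pooler:

```python
import torch
import torch.nn as nn

def last_token_pool(hidden, mask):
    """Last-token pooling for causal (GPT-style) models, assuming right padding.
    hidden: (B, T, d); mask: (B, T) with 1=real, 0=padding. Returns (B, d)."""
    last_idx = mask.sum(dim=1).long() - 1      # (B,) index of last real token
    batch_idx = torch.arange(hidden.size(0))   # (B,)
    return hidden[batch_idx, last_idx]         # (B, d)

class AttentionPool(nn.Module):
    """Weighted mean: a small learned scorer produces per-token weights."""
    def __init__(self, d):
        super().__init__()
        self.scorer = nn.Linear(d, 1)

    def forward(self, hidden, mask):
        scores = self.scorer(hidden).squeeze(-1)               # (B, T)
        scores = scores.masked_fill(mask == 0, float("-inf"))  # ignore padding
        weights = scores.softmax(dim=1).unsqueeze(-1)          # (B, T, 1)
        return (hidden * weights).sum(dim=1)                   # (B, d)
```

With `mask == torch.ones(B, T)`, `AttentionPool` reduces to a learned re-weighting of mean pooling; with a one-hot weight it reduces to selecting a single token, which is why it sits between mean pooling and last-token pooling in expressiveness.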
Historical Context
Pooling in neural networks dates to Fukushima’s Neocognitron (1980) and became standard through LeCun’s convolutional networks in the 1990s, where max pooling over local spatial regions provided translation invariance. Global average pooling as a replacement for fully connected layers was proposed by Lin et al. (2014, “Network in Network”) and adopted by GoogLeNet/Inception, dramatically reducing parameter counts.
For transformers, CLS token pooling was introduced by BERT (Devlin et al., 2019), borrowing the concept from earlier classification architectures. However, Reimers & Gurevych (2019, “Sentence-BERT”) showed that mean pooling over all tokens significantly outperformed CLS pooling for sentence similarity tasks, likely because CLS is only optimised for the pre-training objective, not for producing a general sentence representation. This finding made mean pooling the default for embedding models, and modern retrieval systems (E5, GTE, Nomic) all use mean pooling.