Scaled Dot-Product Attention
The core computation of the transformer: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V. It computes a weighted average of value vectors, where each weight is determined by how well a query matches a key. Scaling by sqrt(d_k) prevents softmax saturation at high dimensions.
Intuition
Think of attention as a soft database lookup. You have a query (“what am I looking for?”), a set of keys (“what does each position advertise?”), and values (“what does each position actually contain?”). The dot product between the query and each key measures relevance — a high dot product means “this key matches my query well.” Softmax converts these raw scores into weights that sum to 1, and the output is a weighted average of the values.
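A toy sketch of the lookup analogy (all numbers and shapes invented for the example): one query, three key/value pairs. The matching key earns the largest weight, and the output is pulled toward its value.

```python
import numpy as np

q = np.array([1.0, 0.0, 1.0, 0.0])              # "what am I looking for?"
K = np.array([[1.0, 0.0, 1.0, 0.0],             # key 0: matches q exactly
              [0.0, 1.0, 0.0, 1.0],             # key 1: orthogonal to q
              [0.5, 0.5, 0.5, 0.5]])            # key 2: partial match
V = np.array([[10.0], [20.0], [30.0]])          # one scalar value per position

scores = K @ q / np.sqrt(len(q))                # relevance of each key to q
weights = np.exp(scores) / np.exp(scores).sum() # softmax: weights sum to 1
out = weights @ V                               # weighted average of the values
```

Key 0 gets the largest weight, so the output sits closer to 10 than a plain average of the values would.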
The scaling by sqrt(d_k) is essential but easy to overlook. When d_k is large (e.g. 128), the dot products between random vectors grow proportionally to sqrt(d_k) — this is because the sum of d_k random products has variance proportional to d_k. Without scaling, the softmax inputs are large, pushing softmax into saturation where it outputs near-one-hot vectors. Saturated softmax has near-zero gradients, killing learning. Dividing by sqrt(d_k) keeps the variance of the dot products at ~1 regardless of dimension, keeping softmax in its useful gradient range.
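A quick numerical check of the variance claim (sample sizes and dimensions chosen arbitrarily): unscaled dot products of unit-variance random vectors have variance roughly d_k, and dividing by sqrt(d_k) brings it back to roughly 1.

```python
import numpy as np

rng = np.random.default_rng(0)
for d_k in (16, 256):
    q = rng.standard_normal((10_000, d_k))   # entries: mean 0, variance 1
    k = rng.standard_normal((10_000, d_k))
    dots = (q * k).sum(axis=-1)              # 10,000 sample dot products
    # Unscaled variance grows linearly with d_k (approximately d_k);
    # dividing by sqrt(d_k) restores variance of approximately 1.
    print(d_k, dots.var(), (dots / np.sqrt(d_k)).var())
```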
The key insight is that attention is a dynamic, input-dependent operation. Unlike a linear layer (which applies the same weights to every input), attention recomputes its weights for every new input. This is what gives transformers their power — each token can selectively attend to the most relevant parts of the input.
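To make the contrast concrete, here is a minimal numpy sketch (toy values, learned projections omitted): a linear layer's weight matrix is fixed once trained, whereas the attention weight matrix is recomputed from each input.

```python
import numpy as np

def attn_weights(X):
    # Self-attention weights with Q = K = X (projections omitted for brevity)
    scores = X @ X.T / np.sqrt(X.shape[-1])
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

W = np.eye(3)                    # a linear layer: the same W for every input
X1 = np.array([[1., 0., 0.],
               [0., 1., 0.]])
X2 = np.array([[1., 1., 0.],
               [2., 0., 1.]])

# X1 @ W and X2 @ W both use the fixed W; the attention "weights"
# below are different matrices, recomputed from each input.
A1, A2 = attn_weights(X1), attn_weights(X2)
```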
Scaled dot-product attention (single head):

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

where Q ∈ ℝ^(T_q × d_k), K ∈ ℝ^(T_k × d_k), V ∈ ℝ^(T_k × d_v).
The attention weights matrix: A = softmax(QK^T / sqrt(d_k)), of shape (T_q, T_k).
Each row of A sums to 1 — it’s a distribution over key positions for each query position.
With causal mask (autoregressive / decoder-only models): query position i may attend only to key positions j ≤ i. Implemented by setting masked positions to −∞ before softmax, so they become 0 after exponentiation.
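A small demonstration of the masking trick (T = 4, all raw scores set equal so the surviving weights are easy to predict): everything above the diagonal becomes exactly zero, and each row i is uniform over positions 0..i.

```python
import numpy as np

T = 4
scores = np.zeros((T, T))                          # equal scores, for clarity
mask = np.triu(np.ones((T, T), dtype=bool), k=1)   # True strictly above diagonal
scores[mask] = -np.inf                             # exp(-inf) = 0 after softmax

e = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)
# Row 0: [1, 0, 0, 0]; row 1: [1/2, 1/2, 0, 0]; ...; row 3: [1/4, 1/4, 1/4, 1/4]
```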
Why sqrt(d_k)? If entries of Q and K are independent with mean 0 and variance 1, then q · k = Σ_i q_i k_i has variance d_k. Dividing by sqrt(d_k) restores unit variance.
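Spelled out, using independence of the entries (each term q_i k_i has mean 0 and unit variance):

```latex
\operatorname{Var}(q \cdot k)
  = \operatorname{Var}\Big(\sum_{i=1}^{d_k} q_i k_i\Big)
  = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i)
  = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2]
  = d_k,
\qquad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1.
```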
```python
import torch
import torch.nn.functional as F

# ── Using PyTorch's built-in (fused, memory-efficient) ──────────
# Available since PyTorch 2.0. Uses FlashAttention under the hood.
Q = ...  # (B, T_q, d_k) — query vectors
K = ...  # (B, T_k, d_k) — key vectors
V = ...  # (B, T_k, d_v) — value vectors

out = F.scaled_dot_product_attention(Q, K, V)  # (B, T_q, d_v)

# With causal mask (autoregressive):
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# With custom attention mask:
# mask should be (T_q, T_k) or (B, T_q, T_k), True = attend, False = ignore
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)

# With dropout (training only):
out = F.scaled_dot_product_attention(Q, K, V, dropout_p=0.1)

# ── Manual version (for understanding, NOT for production) ──────
scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, T_q, T_k)
scores = scores / (Q.size(-1) ** 0.5)          # scale by sqrt(d_k)
# NEVER forget the scaling — without it, training will be unstable
# for d_k > ~16.

if causal:
    mask = torch.triu(torch.ones(T_q, T_k, dtype=torch.bool), diagonal=1)
    scores.masked_fill_(mask, float('-inf'))

weights = F.softmax(scores, dim=-1)            # (B, T_q, T_k)
out = torch.matmul(weights, V)                 # (B, T_q, d_v)
```

Manual Implementation
```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, causal=False):
    """
    Equivalent to F.scaled_dot_product_attention.
    Q: (B, T_q, d_k)
    K: (B, T_k, d_k)
    V: (B, T_k, d_v)
    Returns: (B, T_q, d_v)
    """
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(0, 2, 1)  # (B, T_q, T_k)
    scores = scores / np.sqrt(d_k)     # prevent softmax saturation

    if causal:
        T_q, T_k = scores.shape[1], scores.shape[2]
        mask = np.triu(np.ones((T_q, T_k), dtype=bool), k=1)
        scores[:, mask] = -1e9         # ~negative infinity

    # Numerically stable softmax: subtract max before exp
    scores_max = scores.max(axis=-1, keepdims=True)                # (B, T_q, 1)
    exp_scores = np.exp(scores - scores_max)                       # (B, T_q, T_k)
    weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)  # (B, T_q, T_k)

    return weights @ V  # (B, T_q, d_v)
```

Popular Uses
- Transformer self-attention (see transformer/): Q, K, V all come from the same sequence — each token attends to all others. This is the mechanism that lets GPT, BERT, and ViT process sequences
- Transformer cross-attention (encoder-decoder models, Stable Diffusion): Q from one sequence, K and V from another. Used for conditioning (e.g. text conditions image generation in diffusion/)
- Multi-head attention: run scaled dot-product attention in parallel across multiple heads with different projections, then concatenate. MHA, MQA, and GQA (see transformer/) all use this as their inner operation
- Contrastive learning (see contrastive-self-supervising/): CLIP’s similarity matrix is essentially attention scores without the value multiplication — cosine similarity is a normalised dot product
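The cross-attention case differs from self-attention only in where Q comes from, which shows up directly in the shapes. A sketch with invented sizes (the helper function is a plain unmasked attention, defined inline so the example is self-contained):

```python
import numpy as np

def attention(Q, K, V):
    """Plain scaled dot-product attention, batch-first shapes."""
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
B, T_text, T_img, d = 2, 7, 50, 16
text = rng.standard_normal((B, T_text, d))   # queries: the conditioning side
image = rng.standard_normal((B, T_img, d))   # keys/values: the attended side

out = attention(text, image, image)
# Output length follows the queries: (B, T_text, d). The key/value
# length T_img only sets how many positions each query averages over.
```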
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| Linear attention (Katharopoulos et al., 2020) | Need O(T) complexity instead of O(T^2) | Approximates softmax attention with kernel feature maps; lower quality for long-range dependencies |
| Local / sliding window attention (Mistral, Longformer) | Very long sequences where full attention is too expensive | Each token only attends to a fixed window; misses global patterns unless combined with global tokens |
| Additive attention (Bahdanau et al., 2015) | Historical; original seq2seq attention | Uses a small MLP instead of dot product; slower but was the first attention mechanism for NMT |
| State-space models (Mamba, S4) | Sequential processing with subquadratic cost | Replace attention entirely with recurrent-style computation; competitive on language but less proven on other modalities |
| Cross-attention with learned queries (Perceiver) | Need to reduce sequence length | Fixed set of learned queries attend to the input; compresses arbitrary-length input to fixed size |
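As a sketch of the first alternative: linear attention replaces softmax(QK^T)V with φ(Q)(φ(K)^T V), so the T_q × T_k score matrix is never materialised. The elu(x) + 1 feature map follows Katharopoulos et al. (2020); everything else here is a minimal illustration, not a tuned implementation.

```python
import numpy as np

def phi(x):
    """elu(x) + 1: a positive feature map (Katharopoulos et al., 2020)."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    """Q: (T_q, d), K: (T_k, d), V: (T_k, d_v) -> (T_q, d_v), no T_q x T_k matrix."""
    Qf, Kf = phi(Q), phi(K)
    KV = Kf.T @ V              # (d, d_v): summed key-value outer products
    Z = Qf @ Kf.sum(axis=0)    # (T_q,): per-query normaliser
    return (Qf @ KV) / Z[:, None]
```

By associativity this equals row-normalised φ(Q)φ(K)^T applied to V, but the cost is O(T·d·d_v) instead of O(T^2·d).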
Historical Context
Dot-product attention appeared in Luong et al. (2015) as a simpler alternative to Bahdanau’s additive attention for machine translation. The critical “scaled” version was introduced in “Attention Is All You Need” (Vaswani et al., 2017), which showed that attention alone — without any recurrence or convolution — was sufficient for state-of-the-art sequence modelling. The scaling factor was a practical necessity they discovered during development: without it, the softmax gradients vanished for the dimensions they were using (d_k = 64).
The major practical evolution has been in implementation efficiency. FlashAttention (Dao et al., 2022) reorganised the computation to be IO-aware, reducing memory usage from O(T^2) to O(T) and providing 2-4x wall-clock speedups. This is now the default backend in PyTorch’s F.scaled_dot_product_attention. Multi-query attention (MQA) and grouped-query attention (GQA) are complementary optimisations that reduce the KV cache size for inference, but the core scaled dot-product computation remains identical.