Scaled Dot-Product Attention

The core computation of the transformer: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V. Computes a weighted average of value vectors, where each weight is determined by how well a query matches a key. Scaling by sqrt(d_k) prevents softmax saturation at high dimensions.

Think of attention as a soft database lookup. You have a query (“what am I looking for?”), a set of keys (“what does each position advertise?”), and values (“what does each position actually contain?”). The dot product between the query and each key measures relevance — a high dot product means “this key matches my query well.” Softmax converts these raw scores into weights that sum to 1, and the output is a weighted average of the values.
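The lookup analogy can be made concrete with a tiny NumPy sketch (toy numbers, a single query against three key/value pairs):

```python
import numpy as np

q = np.array([1.0, 0.0])                 # query: "what am I looking for?"
K = np.array([[1.0, 0.0],                # key 0: matches q closely
              [0.0, 1.0],                # key 1: orthogonal to q
              [0.5, 0.5]])               # key 2: partial match
V = np.array([[10.0], [20.0], [30.0]])   # values: what each position contains

scores = K @ q / np.sqrt(q.shape[0])     # relevance of each key to the query
weights = np.exp(scores) / np.exp(scores).sum()  # softmax: non-negative, sums to 1
out = weights @ V                        # weighted average of the values
```

Key 0 gets the largest weight, so the output is pulled toward its value, but the other values still contribute; the lookup is soft, not a hard argmax.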

The scaling by sqrt(d_k) is essential but easy to overlook. When d_k is large (e.g. 128), the dot products between random vectors grow proportionally to sqrt(d_k) — this is because the sum of d_k random products has variance proportional to d_k. Without scaling, the softmax inputs are large, pushing softmax into saturation where it outputs near-one-hot vectors. Saturated softmax has near-zero gradients, killing learning. Dividing by sqrt(d_k) keeps the variance of the dot products at ~1 regardless of dimension, keeping softmax in its useful gradient range.
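The saturation effect is easy to see numerically. A quick check with random vectors at d_k = 128 (illustrative shapes and seed):

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 128
q = rng.standard_normal(d_k)
K = rng.standard_normal((10, d_k))  # 10 keys, unit-variance entries

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

raw = softmax(K @ q)                    # unscaled: scores have std ~ sqrt(d_k)
scaled = softmax(K @ q / np.sqrt(d_k))  # scaled: scores have std ~ 1
# The unscaled weights are typically near one-hot; the scaled ones are spread out.
print(raw.max(), scaled.max())
```

Dividing the scores by a constant greater than 1 always flattens the softmax, so the scaled maximum weight is strictly smaller whenever the scores are not all equal.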

The key insight is that attention is a dynamic, input-dependent operation. Unlike a linear layer (which applies the same weights to every input), attention recomputes its weights for every new input. This is what gives transformers their power — each token can selectively attend to the most relevant parts of the input.

Scaled dot-product attention (single head):

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

where $Q \in \mathbb{R}^{T_q \times d_k}$, $K \in \mathbb{R}^{T_k \times d_k}$, $V \in \mathbb{R}^{T_k \times d_v}$.

The attention weights matrix:

$$A = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) \in \mathbb{R}^{T_q \times T_k}$$

Each row of A sums to 1 — it’s a distribution over key positions for each query position.

With causal mask (autoregressive / decoder-only models):

$$A_{ij} = \begin{cases} \text{softmax}_j\!\left(\dfrac{q_i \cdot k_j}{\sqrt{d_k}}\right) & \text{if } j \leq i \\ 0 & \text{if } j > i \end{cases}$$

Implemented by setting masked scores to $-\infty$ before the softmax, so they become 0 after exponentiation.
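A three-line check of the masking trick with toy scores (note that the stable-softmax max is taken over finite entries, so the arithmetic stays well defined):

```python
import numpy as np

scores = np.array([2.0, 1.0, -np.inf])  # last position masked out
e = np.exp(scores - scores.max())       # exp(-inf) == 0 exactly
weights = e / e.sum()                   # masked position gets weight 0
```

The remaining weights renormalise over the unmasked positions and still sum to 1.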

Why $\sqrt{d_k}$? If the entries of $Q$ and $K$ are independent with mean 0 and variance 1, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has variance $d_k$. Dividing by $\sqrt{d_k}$ restores unit variance.
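A quick Monte-Carlo check of the variance claim (sample size and seed are illustrative):

```python
import numpy as np

rng = np.random.default_rng(42)
d_k, n = 64, 100_000
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))

dots = (Q * K).sum(axis=1)          # n independent dot products
print(dots.var())                   # ~ d_k
print((dots / np.sqrt(d_k)).var())  # ~ 1 after scaling
```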

```python
import torch
import torch.nn.functional as F

# ── Using PyTorch's built-in (fused, memory-efficient) ──────────
# Available since PyTorch 2.0. Can dispatch to FlashAttention when available.
Q = ...  # (B, T_q, d_k) — query vectors
K = ...  # (B, T_k, d_k) — key vectors
V = ...  # (B, T_k, d_v) — value vectors

out = F.scaled_dot_product_attention(Q, K, V)  # (B, T_q, d_v)

# With causal mask (autoregressive):
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# With custom attention mask:
# mask should be (T_q, T_k) or (B, T_q, T_k), True = attend, False = ignore
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)

# With dropout (training only):
out = F.scaled_dot_product_attention(Q, K, V, dropout_p=0.1)

# ── Manual version (for understanding, NOT for production) ──────
scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, T_q, T_k)
scores = scores / (Q.size(-1) ** 0.5)          # scale by sqrt(d_k)
# NEVER forget the scaling — without it, training will be unstable
# for d_k > ~16.
if causal:
    T_q, T_k = scores.shape[-2], scores.shape[-1]
    mask = torch.triu(torch.ones(T_q, T_k, dtype=torch.bool,
                                 device=scores.device), diagonal=1)
    scores.masked_fill_(mask, float('-inf'))
weights = F.softmax(scores, dim=-1)  # (B, T_q, T_k)
out = torch.matmul(weights, V)       # (B, T_q, d_v)
```
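A sanity check that the manual path matches the fused kernel, assuming PyTorch >= 2.0 (random tensors, illustrative shapes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, d_k, d_v = 2, 5, 16, 8
Q = torch.randn(B, T, d_k)
K = torch.randn(B, T, d_k)
V = torch.randn(B, T, d_v)

# Manual causal attention, exactly as above.
scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float('-inf'))
manual = F.softmax(scores, dim=-1) @ V

fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
assert torch.allclose(manual, fused, atol=1e-5)
```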
```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, causal=False):
    """
    NumPy equivalent of F.scaled_dot_product_attention.
    Q: (B, T_q, d_k)   K: (B, T_k, d_k)   V: (B, T_k, d_v)
    Returns: (B, T_q, d_v)
    """
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(0, 2, 1)  # (B, T_q, T_k)
    scores = scores / np.sqrt(d_k)     # prevent softmax saturation
    if causal:
        T_q, T_k = scores.shape[1], scores.shape[2]
        mask = np.triu(np.ones((T_q, T_k), dtype=bool), k=1)
        scores[:, mask] = -1e9  # ~negative infinity
    # Numerically stable softmax: subtract max before exp
    scores_max = scores.max(axis=-1, keepdims=True)  # (B, T_q, 1)
    exp_scores = np.exp(scores - scores_max)         # (B, T_q, T_k)
    weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
    return weights @ V  # (B, T_q, d_v)
```
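Usage, plus a check of the causal property: perturbing a future token must not change earlier outputs. (The function is repeated in compact form so the snippet runs standalone; shapes and seed are illustrative.)

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, causal=False):
    # Compact copy of the function above, so this snippet is self-contained.
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
    if causal:
        mask = np.triu(np.ones(scores.shape[1:], dtype=bool), k=1)
        scores[:, mask] = -1e9
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(1)
Q = rng.standard_normal((1, 4, 8))
K = rng.standard_normal((1, 4, 8))
V = rng.standard_normal((1, 4, 8))

out = scaled_dot_product_attention(Q, K, V, causal=True)

# Perturb only the last position's key/value; with a causal mask,
# outputs at positions 0..2 must be unchanged.
K2, V2 = K.copy(), V.copy()
K2[:, -1] += 1.0
V2[:, -1] += 1.0
out2 = scaled_dot_product_attention(Q, K2, V2, causal=True)
assert np.allclose(out[:, :3], out2[:, :3])
```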
  • Transformer self-attention (see transformer/): Q, K, V all come from the same sequence — each token attends to all others. This is the mechanism that lets GPT, BERT, and ViT process sequences
  • Transformer cross-attention (encoder-decoder models, Stable Diffusion): Q from one sequence, K and V from another. Used for conditioning (e.g. text conditions image generation in diffusion/)
  • Multi-head attention: run scaled dot-product attention in parallel across multiple heads with different projections, then concatenate. MHA, MQA, and GQA (see transformer/) all use this as their inner operation
  • Contrastive learning (see contrastive-self-supervising/): CLIP’s similarity matrix is essentially attention scores without the value multiplication — cosine similarity is a normalised dot product
| Alternative | When to use | Tradeoff |
| --- | --- | --- |
| Linear attention (Katharopoulos et al., 2020) | Need O(T) complexity instead of O(T^2) | Approximates softmax attention with kernel feature maps; lower quality for long-range dependencies |
| Local / sliding window attention (Mistral, Longformer) | Very long sequences where full attention is too expensive | Each token only attends to a fixed window; misses global patterns unless combined with global tokens |
| Additive attention (Bahdanau et al., 2015) | Historical; original seq2seq attention | Uses a small MLP instead of dot product; slower but was the first attention mechanism for NMT |
| State-space models (Mamba, S4) | Sequential processing with subquadratic cost | Replace attention entirely with recurrent-style computation; competitive on language but less proven on other modalities |
| Cross-attention with learned queries (Perceiver) | Need to reduce sequence length | Fixed set of learned queries attend to the input; compresses arbitrary-length input to fixed size |
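For intuition, the sliding-window variant in the table just swaps the causal mask for a banded one. A sketch in NumPy (window size and sequence length are illustrative):

```python
import numpy as np

T, window = 6, 3           # illustrative sequence length and window size
i = np.arange(T)[:, None]  # query positions (column vector)
j = np.arange(T)[None, :]  # key positions (row vector)

# Causal sliding window: token i attends to tokens j with i - window < j <= i.
allowed = (j <= i) & (j > i - window)  # (T, T) boolean mask, True = attend
```

A mask of this shape can be passed as `attn_mask` to `F.scaled_dot_product_attention` (True = attend), or its complement filled with `-inf` in a manual implementation.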

Dot-product attention appeared in Luong et al. (2015) as a simpler alternative to Bahdanau's additive attention for machine translation. The critical "scaled" version was introduced in "Attention Is All You Need" (Vaswani et al., 2017), which showed that attention alone — without any recurrence or convolution — was sufficient for state-of-the-art sequence modelling. The scaling factor addressed a practical problem: at the dimension they used (d_k = 64), unscaled dot products pushed the softmax into regions with vanishing gradients.

The major practical evolution has been in implementation efficiency. FlashAttention (Dao et al., 2022) reorganised the computation to be IO-aware, reducing memory usage from O(T^2) to O(T) and providing 2-4x wall-clock speedups. This is now the default backend in PyTorch’s F.scaled_dot_product_attention. Multi-query attention (MQA) and grouped-query attention (GQA) are complementary optimisations that reduce the KV cache size for inference, but the core scaled dot-product computation remains identical.