Multi-head attention (MHA) is like building several secondary indexes on the same table, each optimized for a different query pattern. One head is one specialized "attention channel": some heads learn adjacent-word relations, some learn syntactic dependencies, some learn coreference. The heads run in parallel and their outputs are concatenated, much like a MapReduce job where several reducers each compute a different aggregation.
Pain point: a single attention head's softmax produces only one probability distribution per query, so it can only "focus one way" at a time. But a word's relation to its context is multi-dimensional (syntax + semantics + position), and one head can't capture all of it.
Mechanism: split the d_model dimension (say 512) into h heads (say 8, so each has d_head = 512/8 = 64). Each head projects the input with its own learned matrices W_i^Q, W_i^K, W_i^V into a low-dimensional subspace and runs attention there; the h head outputs are then concatenated and mixed by one output projection W^O. The core formula (per head):
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_{\text{head}}}}\right)V$$
Each symbol: Q (query, "what I'm looking for"), K (key, "what I can offer"), V (value, "the content actually carried"). QKᵀ scores the similarity of every token pair (a larger dot product means more relevant); softmax turns each row of scores into weights; the output is the weighted sum of the V vectors. Why divide by √d_head? If the components of q and k are independent with mean 0 and variance 1, their dot product has variance d_head, so wider heads produce larger scores that push softmax into its saturated region (one weight near 1, the rest near 0), where gradients vanish. Dividing by √d_head brings the variance back to about 1, which keeps training stable. This is the design from the 2017 paper Attention Is All You Need.
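A quick numerical check of the √d_head argument, as a minimal NumPy sketch. It assumes query and key components are i.i.d. standard normal (a simplifying assumption, not a claim about trained models), and `softmax` is a small helper defined only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d_head, n_pairs = 64, 20_000

# Assumption: q and k components are i.i.d. N(0, 1)
q = rng.standard_normal((n_pairs, d_head))
k = rng.standard_normal((n_pairs, d_head))

raw = np.sum(q * k, axis=1)       # q·k for each pair: variance ≈ d_head
scaled = raw / np.sqrt(d_head)    # q·k / √d_head: variance ≈ 1
print(f"var(q·k)           ≈ {raw.var():.1f}")
print(f"var(q·k / √d_head) ≈ {scaled.var():.2f}")

def softmax(x):
    e = np.exp(x - x.max())       # subtract the max for numerical stability
    return e / e.sum()

# One query scored against 10 keys: larger scores make the softmax peakier
scores = rng.standard_normal((10, d_head)) @ rng.standard_normal(d_head)
p_unscaled = softmax(scores).max()
p_scaled = softmax(scores / np.sqrt(d_head)).max()
print(f"largest weight, unscaled: {p_unscaled:.3f}")
print(f"largest weight, scaled:   {p_scaled:.3f}")
```

Because multiplying all scores by a factor above 1 can only sharpen a softmax, the unscaled largest weight is never smaller than the scaled one; that sharpening is the saturation the division guards against.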
```python
import torch
import torch.nn as nn

# d_model=512 split into 8 heads, each of dim 512/8 = 64
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)

# self-attention: Q = K = V = x
out, attn = mha(x, x, x, need_weights=True, average_attn_weights=False)
print(out.shape)   # torch.Size([2, 10, 512]): shape unchanged, information mixed
print(attn.shape)  # torch.Size([2, 8, 10, 10]): 8 heads, each a 10x10 attention matrix
# after training, heads tend to specialize: adjacency, syntax, ...
```
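For intuition about what the module computes internally, here is a from-scratch NumPy sketch of multi-head self-attention. It is illustrative, not PyTorch's implementation: it uses random, untrained weights, omits the biases nn.MultiheadAttention adds by default, and applies one full d_model×d_model projection before splitting into heads (equivalent to giving each head its own d_model×d_head slices W_i^Q, W_i^K, W_i^V). The function name `multi_head_self_attention` is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # stable softmax
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, w_o, n_heads):
    """x: (seq, d_model); each w_*: (d_model, d_model). Returns (output, weights)."""
    seq, d_model = x.shape
    d_head = d_model // n_heads

    def split_heads(t):
        # (seq, d_model) -> (n_heads, seq, d_head): each head gets its own slice
        return t.reshape(seq, n_heads, d_head).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)      # (n_heads, seq, seq)
    weights = softmax(scores, axis=-1)                        # one distribution per head per query
    heads = weights @ v                                       # (n_heads, seq, d_head)
    concat = heads.transpose(1, 0, 2).reshape(seq, d_model)   # concatenate the heads
    return concat @ w_o, weights                              # W^O mixes the heads

rng = np.random.default_rng(0)
d_model, n_heads, seq = 512, 8, 10
x = rng.standard_normal((seq, d_model))
# random, untrained projections: purely illustrative
w_q, w_k, w_v, w_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
out, weights = multi_head_self_attention(x, w_q, w_k, w_v, w_o, n_heads)
print(out.shape)      # (10, 512)
print(weights.shape)  # (8, 10, 10)
```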
Think of the KV cache as the "hot data" read repeatedly during inference. MHA gives every query head its own dedicated KV replica (memory explodes); MQA makes all query heads share a single KV head (like all workers sharing one connection pool); GQA compromises with a few groups, each sharing one KV head. In essence, it is deduplicating and sharing the KV.
Pain point: during autoregressive decoding, generating each new token requires reading the entire KV cache from HBM. The bigger the KV cache, the more memory bandwidth becomes the bottleneck: slow decoding is usually not about FLOPs but about moving the KV. And KV cache size is proportional to the number of KV heads.
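To see why every step touches the whole cache, here is a minimal single-head NumPy sketch of autoregressive decoding with a KV cache; the random weights and stand-in token embeddings are hypothetical. Each step computes K and V only for the new token, but attention then reads every cached row, and the final check confirms the result equals full causal attention computed in one pass.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
d, seq = 64, 6
w_q, w_k, w_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
tokens = rng.standard_normal((seq, d))  # stand-in embeddings (hypothetical)

# Decoding with a KV cache: project ONLY the new token, append its K and V,
# then read the ENTIRE cache to attend.
k_cache, v_cache, outputs = [], [], []
for t in range(seq):
    x_t = tokens[t]
    k_cache.append(x_t @ w_k)
    v_cache.append(x_t @ w_v)
    K, V = np.stack(k_cache), np.stack(v_cache)   # (t+1, d): whole cache is read
    attn = softmax((x_t @ w_q) @ K.T / np.sqrt(d))
    outputs.append(attn @ V)
    print(f"step {t}: read {K.shape[0]} cached K/V rows")
decoded = np.stack(outputs)

# Reference: full causal attention over the whole sequence at once
Q, K, V = tokens @ w_q, tokens @ w_k, tokens @ w_v
scores = Q @ K.T / np.sqrt(d)
scores[np.triu_indices(seq, k=1)] = -np.inf       # mask future positions
full = softmax(scores) @ V
print(np.allclose(decoded, full))  # True
```

The per-step read grows with the sequence, so at long context the cache traffic, not the arithmetic on one new token, dominates each step.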
```python
# KV cache size is proportional to the KV head count.
# Estimate the KV cache of one sequence on a 70B-class model
# (80 layers, d_head = 128, fp16):
def kv_cache_gb(n_kv_heads, seq, d_head=128, layers=80, dtype_bytes=2):
    # K and V are both stored -> x2; dtype_bytes=2 means fp16 (2 bytes per number)
    return 2 * n_kv_heads * d_head * seq * layers * dtype_bytes / 1e9

seq = 8192
print("MHA (64 KV heads):", round(kv_cache_gb(64, seq), 1), "GB")  # 21.5
print("GQA ( 8 KV heads):", round(kv_cache_gb(8, seq), 1), "GB")   # 2.7
print("MQA ( 1 KV head): ", round(kv_cache_gb(1, seq), 1), "GB")   # 0.3
# GQA cuts the KV cache to 1/8 with quality close to MHA -> the practical pick
```
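The grouping itself can be sketched in a few lines of NumPy; `grouped_query_attention` is a hypothetical helper, not a library API. Only n_kv K/V heads are stored, and at compute time each shared KV head is repeated for its group of query heads, so setting n_kv equal to the query-head count recovers MHA and n_kv = 1 gives MQA.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(q, k, v):
    """q: (n_q_heads, seq, d_head); k, v: (n_kv_heads, seq, d_head).
    n_kv_heads == n_q_heads is MHA, 1 is MQA, anything in between is GQA."""
    n_q, n_kv = q.shape[0], k.shape[0]
    assert n_q % n_kv == 0, "query heads must divide evenly into KV groups"
    group = n_q // n_kv
    # query head h uses KV head h // group: repeat each shared head `group` times
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

rng = np.random.default_rng(0)
n_q, seq, d_head = 8, 5, 16
q = rng.standard_normal((n_q, seq, d_head))
results = {}
for n_kv in (8, 2, 1):  # MHA, GQA, MQA
    k = rng.standard_normal((n_kv, seq, d_head))  # only these n_kv heads are cached
    v = rng.standard_normal((n_kv, seq, d_head))
    out = grouped_query_attention(q, k, v)
    results[n_kv] = (out.shape, k.size + v.size)
    print(f"n_kv={n_kv}: output {out.shape}, cached K/V numbers per layer: {k.size + v.size}")
```

The output shape never changes; only the cached K/V shrinks. Expanding with np.repeat keeps the sketch simple, whereas optimized kernels can read the shared head directly instead of materializing copies.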
Check num_key_value_heads on the model card. If it equals num_attention_heads, the model uses MHA (memory-hungry, slow decoding at long context); if it is much smaller (e.g. 8 vs 64), it uses GQA (cheaper and faster for long-document tasks); if it is 1, it is MQA. This single number predicts your long-context cost curve.

Sliding-window attention is like a bounded buffer in stream processing: each token sees only the most recent w tokens, not the whole stream. But information isn't trapped: it propagates forward layer by layer, like a gossip protocol spreading hop by hop, or a CNN stacking layers to grow its receptive field.
Pain point: standard attention is O(n²), since every token scores its similarity against all n tokens. At n = 100K that is ~10¹⁰ pairs; for long sequences the score matrix simply can't be computed or stored.
Mechanism: restrict each token to attend only to the previous w tokens (the window width), dropping the cost to O(n·w), linear in sequence length. So how does distant information arrive? Through depth: each layer pushes information forward by up to w tokens, so after L layers the effective receptive field is ≈ L·w. The Mistral 7B paper states exactly this: "after k layers, information can move forward up to k×W tokens." Longformer (2020) adds a few "global tokens" (e.g. [CLS]) that attend to, and are attended by, every position, acting as cross-segment bridges.
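A minimal NumPy sketch of the O(n·w) computation, assuming a single head and the causal window [i−w+1, i]; `sliding_window_attention` is a hypothetical helper. Each query is scored against at most w keys, and the check confirms the output equals full attention with everything outside the band masked. The last lines compare score counts for n = 100K, assuming w = 4096 (the window size reported for Mistral 7B).

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sliding_window_attention(q, k, v, window):
    """Causal sliding-window attention, one head. q, k, v: (seq, d).
    Each query scores at most `window` keys, so total work is O(seq * window)."""
    seq, d = q.shape
    out = np.empty_like(v)
    for i in range(seq):
        lo = max(0, i - window + 1)                # visible keys: [i-window+1, i]
        scores = k[lo:i + 1] @ q[i] / np.sqrt(d)   # at most `window` scores
        out[i] = softmax(scores) @ v[lo:i + 1]
    return out

rng = np.random.default_rng(0)
seq, d, window = 12, 8, 4
q, k, v = (rng.standard_normal((seq, d)) for _ in range(3))
banded = sliding_window_attention(q, k, v, window)

# Reference: full n x n scores with everything outside the band masked out
i, j = np.arange(seq)[:, None], np.arange(seq)[None, :]
band = (j <= i) & (j > i - window)
full_scores = np.where(band, q @ k.T / np.sqrt(d), -np.inf)
reference = softmax(full_scores) @ v
print(np.allclose(banded, reference))  # True

# Score counts for n = 100K: full attention vs. a 4096-token window
n, w = 100_000, 4096
print(f"full: {n * n:.1e} scores, sliding window: at most {n * w:.1e} scores")
```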
```python
import torch

seq, window = 10, 3
# sliding-window mask: token i can only see positions [i-window+1, i]
i = torch.arange(seq).unsqueeze(1)
j = torch.arange(seq).unsqueeze(0)
mask = (j <= i) & (j > i - window)  # causal + window -> a diagonal band
print(mask.int())  # 1 = visible, 0 = masked; a band, not the full lower triangle

# effective receptive field grows linearly with depth: each layer reaches
# window-1 positions further back, so L layers span L*(window-1)+1 tokens (~ L*window)
for L in (1, 8, 32):
    print(f"after {L} layers, receptive field = {L * (window - 1) + 1} tokens")
```
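To check the depth argument directly, this NumPy sketch composes the band mask across layers: token a can receive information from token b after L layers if some chain of visible hops connects them. With the causal window [i−w+1, i], each layer reaches w−1 positions further back, so the exact reach is L·(w−1)+1 tokens, which is where the ≈ L·w figure comes from.

```python
import numpy as np

seq, window = 64, 3
i, j = np.arange(seq)[:, None], np.arange(seq)[None, :]
mask = ((j <= i) & (j > i - window)).astype(int)  # causal band: i sees [i-window+1, i]

# reach[a, b] == 1 iff information from token b can arrive at token a.
# One more layer: a attends c (mask[a, c]) and c already carried b (reach[c, b]).
reach = np.eye(seq, dtype=int)
spans = {}
for L in range(1, 9):
    reach = ((mask @ reach) > 0).astype(int)
    spans[L] = int(reach[-1].sum())  # tokens visible to the last position
    print(f"after {L} layers: last token sees {spans[L]} tokens, "
          f"L*(window-1)+1 = {L * (window - 1) + 1}")
```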
Like an external merge sort: the data is too big for memory, so you process it in chunks that fit, avoiding repeated trips to slow storage. FlashAttention is the "cache-aware algorithm" for attention: it never materializes the n×n attention matrix in slow GPU memory, computing block by block inside fast on-chip SRAM. Same math, different memory-access order.
The key counterintuitive point: attention is memory-bound, not compute-bound. The bottleneck isn't the multiplications; it's shuttling the n×n matrix between HBM (large GPU memory, slow) and SRAM (on-chip cache, fast but tiny). The naive version computes S = QKᵀ and writes it to HBM, reads it back for the softmax and writes the result, then reads that back again to multiply by V, so it is dragged down by slow memory throughout.
FlashAttention (Dao et al., 2022) tiles Q, K, V into small blocks that each fit in SRAM and computes block by block, never writing the full n×n matrix back to HBM. The catch is that softmax normalizes each row by its global max and sum, yet the kernel sees only one block at a time. The fix is online softmax: maintain a running max and a running sum, and when each new block arrives, rescale the accumulated result by the correction factor exp(old_max − new_max). This is numerically stable and mathematically exact (not an approximation): outputs match the naive version up to floating-point rounding. The payoff: far fewer HBM accesses, reported speedups of roughly 2–4×, and memory that grows linearly in n rather than quadratically.
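The online-softmax recurrence can be checked with a short NumPy sketch. It is not the FlashAttention kernel: it runs on the CPU, tiles only along the key axis (the real kernel also tiles Q and keeps its tiles in SRAM), and `tiled_attention` / `naive_attention` are hypothetical helper names. What it does show is that processing K/V block by block, with a running max, a running sum, and the exp(old_max − new_max) correction, reproduces the naive result without ever forming the full n×n score matrix.

```python
import numpy as np

def naive_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[-1])              # full n x n score matrix
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=16):
    """Walk over K/V in blocks, keeping per-row running statistics,
    so only (n, block) score tiles are ever formed."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    row_max = np.full(n, -np.inf)      # running max per query row
    row_sum = np.zeros(n)              # running softmax denominator
    acc = np.zeros((n, v.shape[1]))    # running unnormalized output
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T * scale                         # (n, block) tile
        new_max = np.maximum(row_max, s.max(axis=1))
        correction = np.exp(row_max - new_max)       # exp(old_max - new_max); 0 on first block
        p = np.exp(s - new_max[:, None])
        row_sum = row_sum * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ vb
        row_max = new_max
    return acc / row_sum[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 32)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v), naive_attention(q, k, v)))  # True
```

The block size does not change the result (a block of 10 that leaves a ragged final tile gives the same output), which is what lets the real kernel pick tile sizes to fit SRAM.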
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # available in recent PyTorch 2.x releases

# needs a CUDA GPU that supports the Flash kernel; q/k/v shape (batch, heads, seq, d_head)
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# force the FlashAttention kernel: the 4096x4096 score matrix is never materialized in HBM
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 4096, 64]); matches naive attention up to fp16 rounding
# naive attention stores 2*8*4096*4096 fp16 scores (~537 MB for the score matrix alone); Flash allocates none
```