AI/ML Explained: Attention Variants

Day 12 · 2026-05-29 · Difficulty ★★★★☆
For: engineers with coding experience, non-AI background

Multi-Head Attention (MHA)

baseline arch · subspaces
One-line analogy

MHA is like building several secondary indexes on the same table—each index optimized for a different query pattern. One head = one specialized "attention channel": some learn adjacent-word relations, some learn syntactic dependencies, some learn coreference. Run them in parallel, then concatenate—much like a MapReduce where several reducers each compute a different aggregation.

Problem it solves + how it works

Pain point: a single attention's softmax produces only one probability distribution per query—it can only "focus one way" at a time. But a word's relation to its context is multi-dimensional (syntax + semantics + position); one head can't hold all of it.

Mechanism: split the d_model dimension (say 512) into h heads (say 8, each d_head = 512/8 = 64). Each head runs its own attention in its own low-dimensional subspace; the head outputs are then concatenated and one output projection Wᴼ mixes them. The core formula (per head):

Attention(Q,K,V) = softmax( Q·Kᵀ / √d_head ) · V

Each symbol: Q (query, "what I'm looking for"), K (key, "what I can offer"), V (value, "the content actually carried"). Q·Kᵀ scores every token pair's similarity (larger dot product = more relevant); softmax turns the scores into weights; the output is the weighted sum of the V's. Why divide by √d_head? A dot product sums d_head terms, so with roughly unit-variance components its variance grows like d_head. Large scores push softmax into its saturated region (one weight near 1, the rest near 0), where gradients vanish. Dividing by √d_head pulls the variance back to ~1 so training stays stable. This is the design from 2017's Attention Is All You Need.
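The variance argument can be checked numerically. This is a minimal sketch using only the standard library; it assumes q and k have independent unit-variance Gaussian components, which is an idealization rather than something measured on a real model. The helper name `dot_variance` is made up for this illustration.

```python
import math
import random

random.seed(0)  # fixed seed so repeated runs give the same numbers

def dot_variance(d, trials=2000, scale=1.0):
    """Sample variance of scale * (q . k) for random unit-variance q, k of dimension d."""
    dots = []
    for _ in range(trials):
        q = [random.gauss(0, 1) for _ in range(d)]
        k = [random.gauss(0, 1) for _ in range(d)]
        dots.append(scale * sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / trials
    return sum((x - mean) ** 2 for x in dots) / trials

d_head = 64
raw_var = dot_variance(d_head)                                  # expected to land near d_head
scaled_var = dot_variance(d_head, scale=1 / math.sqrt(d_head))  # expected to land near 1
print(f"variance of raw dot products:  {raw_var:.1f}")
print(f"variance after / sqrt(d_head): {scaled_var:.2f}")
```

Scores with variance near 64 have a typical spread of about 8, which is enough for softmax to put almost all weight on one key; after scaling, the spread is about 1.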

d_model=512 split into 8 parallel heads

input x 512-d
↓ split into 8, each projected to 64-d
head1·64 | head2·64 | … | head8·64
↓ each head runs softmax(QKᵀ/√64)·V independently
attn1 | attn2 | … | attn8
↓ concat back to 512-d → output proj Wᴼ mixes
output 512-d (shape unchanged, info now cross-fused)
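The split, attend, concat flow in the diagram can be traced by hand on a toy input. This is a dependency-free sketch, not how a library implements it: it assumes tiny dimensions (d_model = 4, h = 2) and deliberately skips the learned per-head projections and the output projection Wᴼ, so each head simply attends over its own slice of the input. The helper names are hypothetical.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q_rows, k_rows, v_rows):
    """Scaled dot-product attention over lists of vectors; returns (outputs, weights)."""
    d = len(q_rows[0])
    outputs, weights = [], []
    for q in q_rows:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_rows]
        w = softmax(scores)
        weights.append(w)
        outputs.append([sum(wi * v[c] for wi, v in zip(w, v_rows))
                        for c in range(len(v_rows[0]))])
    return outputs, weights

def split_heads(rows, h):
    """Slice each d_model-wide vector into h contiguous chunks of width d_model // h."""
    d_head = len(rows[0]) // h
    return [[row[i * d_head:(i + 1) * d_head] for row in rows] for i in range(h)]

# toy input: 3 tokens, d_model = 4, split into h = 2 heads of d_head = 2
x = [[1.0, 0.0, 0.5, 0.5],
     [0.0, 1.0, 0.5, -0.5],
     [1.0, 1.0, -0.5, 0.5]]
h = 2
heads = split_heads(x, h)
head_out = [attention(hx, hx, hx) for hx in heads]  # each head attends on its own slice

# concat: stitch the per-head outputs back into d_model-wide vectors
concat = [sum((head_out[i][0][t] for i in range(h)), []) for t in range(len(x))]
print(len(concat), len(concat[0]))  # 3 4
```

Each head produces its own 3×3 weight matrix (`head_out[i][1]`), and the two matrices differ because the heads see different slices; that is the "different perspective per head" idea in miniature.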
Code example
import torch, torch.nn as nn
# d_model=512 split into 8 heads, each of dim 512/8 = 64
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)

# self-attention: Q=K=V are all x itself
out, attn = mha(x, x, x, need_weights=True, average_attn_weights=False)
print(out.shape)    # (2, 10, 512) shape unchanged, info mixed
print(attn.shape)   # (2, 8, 10, 10) 8 heads, each a 10x10 attention matrix
# each head learns a different pattern: adjacency, syntax, ...
Common misconception + your scenario
"More heads = stronger model" is wrong. Head count is constrained (d_model must be divisible by h), and too many heads make each d_head too small, hurting expressiveness. Research also finds that many heads are redundant after training (pruning them barely changes results), so capacity doesn't improve by splitting indefinitely.
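A quick piece of arithmetic makes the trade-off concrete. The sketch below assumes the standard layout (h projections of shape d_model × d_head for each of Q, K, V, plus one d_model × d_model output projection, biases ignored); `head_budget` is a hypothetical helper name.

```python
def head_budget(d_model, h):
    """Per-head width and total projection parameters (W_Q, W_K, W_V, W_O; biases ignored)."""
    if d_model % h != 0:
        raise ValueError("d_model must be divisible by the head count")
    d_head = d_model // h
    # Q, K, V: h projections of shape (d_model, d_head) each; W_O: (d_model, d_model)
    params = 3 * h * d_model * d_head + d_model * d_model
    return d_head, params

for h in (1, 8, 64, 512):
    d_head, params = head_budget(512, h)
    print(f"h={h:3d}  d_head={d_head:3d}  projection params={params}")
```

The parameter count stays at 4 × 512² for every h; only d_head changes. More heads buy more attention patterns but give each one a narrower subspace to work in.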
📌 Cross-disciplinary scenario: MHA is a "parallel multi-perspective" compute paradigm—like a committee reviewing the same sentence simultaneously from syntactic, semantic, and coreference angles, then aggregating. This "many experts in parallel + aggregate at the end" structure is isomorphic to the distributed scatter-gather and multi-expert reviews you already know.
Takeaway + question
💡 "Multi-head" isn't piling on compute: it lets the model learn different relations in parallel across subspaces. Total FLOPs are about the same as for one big head; the difference is expressiveness.
🤔 If a head is a "specialized perspective," what does the redundancy of heads after training tell you about how neural nets allocate capacity?

Grouped / Multi-Query Attention (GQA / MQA)

inference opt · KV cache
One-line analogy

Think of the KV cache as the "hot data" repeatedly read at inference. MHA gives every query head its own dedicated KV replica (memory explodes); MQA makes all heads share one KV (like all workers sharing a single connection pool); GQA compromises—a few groups, each sharing one. In essence it's "deduplicating / sharing" the KV.

Problem it solves + how it works

Pain point: during autoregressive decoding, generating each new token requires reading the entire KV cache from HBM. The bigger the KV cache, the more memory bandwidth becomes the bottleneck—slow decoding is usually not about FLOPs but about "moving the KV." And KV cache size ∝ number of KV heads.
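To see why "moving the KV" dominates, here is a back-of-envelope sketch. It assumes each decode step reads the whole KV cache once, and it uses a made-up bandwidth figure (about 2 TB/s) for a hypothetical accelerator; it ignores the time to read model weights and to do any compute, so it is only a lower bound.

```python
def kv_read_ms(kv_cache_gb, hbm_bandwidth_gb_s):
    """Lower bound on per-token decode time from reading the whole KV cache once."""
    return kv_cache_gb / hbm_bandwidth_gb_s * 1000

ASSUMED_BANDWIDTH_GB_S = 2000  # hypothetical accelerator, roughly 2 TB/s of HBM bandwidth

for label, cache_gb in (("large cache", 20.0), ("8x smaller", 2.5)):
    ms = kv_read_ms(cache_gb, ASSUMED_BANDWIDTH_GB_S)
    print(f"{label}: {cache_gb} GB -> at least {ms:.2f} ms per token")
```

Under these assumptions the floor on per-token latency scales directly with cache size, which is why shrinking the number of KV heads speeds up decoding even though the FLOPs barely change.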

  • MHA: KV heads = Q heads (e.g. 64). Most expressive, but largest KV cache;
  • MQA (Shazeer 2019): KV heads = 1, shared by all Q heads. KV cache shrinks to 1/64, but quality drops;
  • GQA (Ainslie 2023): KV heads = g (e.g. 8), each group of Q heads shares one. Interpolates between MHA and MQA; the paper reports quality close to MHA at speed close to MQA. Llama 2 70B, the Llama 3 family, and Mistral 7B all use GQA.
KV head sharing spectrum (8 Q heads in all)

MHA 8 Q ↔ 8 KV: QKV ×8 largest KV cache
GQA 8 Q ↔ 2 KV: Q Q Q QKV | Q Q Q QKV 4× smaller
MQA 8 Q ↔ 1 KV: Q Q Q Q Q Q Q QKV 8× smaller, quality drops
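The sharing spectrum above boils down to one integer division. The sketch below assumes query heads are grouped contiguously (heads 0..3 share KV head 0, and so on), which is a common convention but an assumption here; `kv_head_for` is a hypothetical helper name.

```python
def kv_head_for(q_head, n_q_heads, n_kv_heads):
    """Index of the KV head that query head `q_head` reads from."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("query heads must divide evenly into KV groups")
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size

n_q = 8
for name, n_kv in (("MHA", 8), ("GQA", 2), ("MQA", 1)):
    mapping = [kv_head_for(q, n_q, n_kv) for q in range(n_q)]
    print(f"{name}: {mapping}")
# MHA: [0, 1, 2, 3, 4, 5, 6, 7]
# GQA: [0, 0, 0, 0, 1, 1, 1, 1]
# MQA: [0, 0, 0, 0, 0, 0, 0, 0]
```

Only the distinct KV indices need to be cached, so the cache shrinks by the group size while every query head still keeps its own Q projection.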
Code example
# KV cache size is proportional to KV head count.
# Estimate the KV cache of one sequence on a 70B-class model:
def kv_cache_gb(n_kv_heads, seq, d_head=128, layers=80, dtype=2):
    # K and V each stored → ×2; dtype=2 means fp16, 2 bytes/number
    return 2 * n_kv_heads * d_head * seq * layers * dtype / 1e9

seq = 8192
print("MHA (64 KV head):", round(kv_cache_gb(64, seq), 1), "GB")  # ~21.5
print("GQA ( 8 KV head):", round(kv_cache_gb(8,  seq), 1), "GB")  # ~2.7
print("MQA ( 1 KV head):", round(kv_cache_gb(1,  seq), 1), "GB")  # ~0.3
# GQA cuts KV cache to 1/8 with almost no quality loss → the practical pick
Common misconception + your scenario
"GQA is for saving training compute"—wrong. GQA/MQA mainly save inference-time KV-cache memory and bandwidth; training gains are small. It's an inference optimization dressed up as architecture—same "what the model knows," just less KV to move.
📌 Model-selection scenario: check num_key_value_heads on the model card. If it equals num_attention_heads, it's MHA (expensive memory, slow decode at long context); if far smaller (e.g. 8 vs 64), it's GQA (cheaper and faster for long-document tasks). This single number predicts your long-context cost curve.
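The model-card check can be written as a few lines over a config dictionary. This is a sketch: the two field names are the ones mentioned above, the plain dicts stand in for a real config file, and the fallback for a missing KV-head field is an assumption, not a rule every model follows.

```python
def attention_kind(config):
    """Classify a model's attention from two config fields."""
    n_q = config["num_attention_heads"]
    # assumption: a config that omits the KV-head field uses one KV head per query head
    n_kv = config.get("num_key_value_heads", n_q)
    if n_kv == n_q:
        return "MHA"
    if n_kv == 1:
        return "MQA"
    return f"GQA ({n_q // n_kv} query heads per KV head)"

print(attention_kind({"num_attention_heads": 64, "num_key_value_heads": 64}))  # MHA
print(attention_kind({"num_attention_heads": 64, "num_key_value_heads": 8}))   # GQA (8 query heads per KV head)
print(attention_kind({"num_attention_heads": 64, "num_key_value_heads": 1}))   # MQA
```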
Takeaway + question
💡 GQA/MQA don't change what the model "knows," only how much KV must be moved at inference—it's a memory-access optimization dressed as architecture.
🤔 When the KV cache becomes the "memory wall" for long context, do you cut the number of KV heads, or cut the KV dimension per token? (That's the other road, taken by DeepSeek's MLA.)

Sliding Window Attention (SWA)

sparse attn · linear cost
One-line analogy

Like a bounded buffer in stream processing: each token only sees the most recent w tokens, not the whole stream. But information isn't trapped—it propagates layer by layer forward, like a gossip protocol spreading hop by hop, or a CNN stacking depth to grow its receptive field.

Problem it solves + how it works

Pain point: standard attention is O(n²), because every token scores similarity against all n tokens. At n=100K that's ~10¹⁰ pairs, too many to compute or store for long sequences.

Mechanism: restrict each token to attend only to the previous w tokens (w = window width), dropping cost to O(n·w), which is linear in sequence length. So how does distant info arrive? Through depth: each layer pushes information forward by about w tokens, so after L layers the effective receptive field ≈ L·w. The Mistral 7B paper makes the same point: after k layers, information can move forward by up to k×W tokens. Longformer (2020) adds a few "global tokens" (e.g. [CLS]) as cross-segment bridges.
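Counting scored pairs shows the quadratic versus linear gap directly. This sketch assumes causal attention (a token never looks ahead) and a window that includes the token itself; the window width 4096 is an illustrative choice, and the function names are hypothetical.

```python
def causal_pairs(n):
    """Query-key pairs scored by full causal attention: token i sees tokens 0..i."""
    return n * (n + 1) // 2

def window_pairs(n, w):
    """Pairs scored when token i sees only the last w tokens (itself included)."""
    return sum(min(i + 1, w) for i in range(n))

n, w = 100_000, 4096  # w is an illustrative window width
print(f"full causal:   {causal_pairs(n):,} pairs")
print(f"window w={w}: {window_pairs(n, w):,} pairs")
```

Doubling n roughly quadruples the first number but only doubles the second, since each token's cost is capped at w.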

Window w=3: info spreads hop by hop via depth

token:  t0  t1  t2  t3  t4  t5  t6  t7
layer 1: each token sees only the last 3 tokens (itself included) → diagonal band
layer 2: t6 indirectly sees t3 (relayed via t4/t5)
layer L: effective receptive field ≈ L × 3 tokens

how far it can reach → grows linearly with depth
1 layer 3 | 8 layers 24 | 32 layers 96
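The hop-by-hop spread can be simulated by tracking which input positions are able to influence the last token. This sketch assumes the window counts the token itself, so each extra layer adds w − 1 new positions and the exact reach is L·(w − 1) + 1, which the L·w figure approximates for realistic window sizes. `reach_after` is a hypothetical helper.

```python
def reach_after(layers, window, seq):
    """How many distinct input tokens can influence the last position after `layers` layers."""
    last = seq - 1
    sources = {last}  # positions whose input can reach the last token so far
    for _ in range(layers):
        # one more layer: every current source could itself see `window` tokens (itself included)
        sources = {j for i in sources for j in range(max(0, i - window + 1), i + 1)}
    return len(sources)

seq, window = 64, 3
for layers in (1, 2, 8):
    print(f"{layers} layer(s): {reach_after(layers, window, seq)} tokens")
# 1 layer(s): 3 tokens
# 2 layer(s): 5 tokens
# 8 layer(s): 17 tokens
```

Reach says only that a path exists; it says nothing about how much of the original signal survives each relay, which is where long-range retrieval tends to degrade.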
Code example
import torch
seq, window = 10, 3
# sliding-window mask: token i can only see [i-window+1, i]
i = torch.arange(seq).unsqueeze(1)
j = torch.arange(seq).unsqueeze(0)
mask = (j <= i) & (j > i - window)   # causal + window → a diagonal band
print(mask.int())   # 1=visible 0=masked; visible region is a band, not full triangle

# effective receptive field grows linearly with depth
for L in (1, 8, 32):
    print(f"after {L} layers, receptive field ~ {L*window} tokens")
Common misconception + your scenario
"Sliding window = can only handle short text"—wrong. By stacking layers, the receptive field covers very long sequences. Its real weakness is "single-hop long-range exact retrieval": pulling a fact at token #100K precisely to the current position must be relayed across many layers and is easily distorted—hence the frequent pairing with global tokens or alternating with full-attention layers.
📌 Judgment scenario: when a model claiming 128K context does poorly on "very long-range association" tasks, first ask whether its attention is pure SWA—if so, distant info propagates slowly through depth, and "needle-in-a-haystack + reasoning" is a structural blind spot, not a prompt failure.
Takeaway + question
💡 Local attention + depth = trading "layers" for "reach," the same trick a CNN uses to grow its receptive field with depth.
🤔 Since information propagates hop by hop through layers, what failure mode does that imply for needle-in-a-haystack tasks? How should you trade depth against window width?

FlashAttention: IO-Aware Exact Attention

memory-access opt · exact
One-line analogy

Like an external merge sort: the data is too big for memory, so you process it in blocks that fit in cache, avoiding repeated trips to slow storage. FlashAttention is the "cache-aware algorithm" for attention—it never materializes the n×n attention matrix in memory, computing block by block inside fast on-chip cache. Same math, different memory-access order.

Problem it solves + how it works

The key counterintuition: attention is memory-bound, not compute-bound. The bottleneck isn't multiplications, it's shuttling that n×n matrix between HBM (big GPU memory, slow) and SRAM (on-chip cache, fast but tiny). The naive version: compute S=QKᵀ, write to HBM → read back for softmax → read back again to multiply by V—dragged down by slow memory throughout.

FlashAttention (Dao 2022) tiles Q, K, V into small blocks, each fitting in SRAM, and computes block by block, never writing the full n×n back to HBM. The catch is that softmax must normalize over an entire row's global max and sum, yet you only see one block at a time. The fix is online softmax: maintain a running max and running sum, and as each new block arrives, rescale the accumulated result by a correction factor exp(old_max − new_max). This is numerically stable, and the result is mathematically the same as the naive version (exact, not approximate); in floating point the two agree up to rounding differences from the changed summation order. Payoff: memory accesses plummet → 2–4× speedup, memory linear in n rather than quadratic.
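The running-max and running-sum bookkeeping is easier to trust once you see it reproduce the naive result. Below is a pure-Python sketch for a single query; it illustrates only the online-softmax accumulation, not the real kernel (no tiling over queries, no SRAM management), and the function names and toy data are made up for the example.

```python
import math

def naive_attention(q, keys, values):
    """Reference: softmax over all scores at once, then a weighted sum of values."""
    scale = 1 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e * v[c] for e, v in zip(exps, values)) / total
            for c in range(len(values[0]))]

def online_attention(q, keys, values, block=2):
    """Same result, but keys/values are visited one block at a time."""
    scale = 1 / math.sqrt(len(q))
    run_max = float("-inf")       # running max of the scores seen so far
    run_sum = 0.0                 # running softmax denominator
    acc = [0.0] * len(values[0])  # running unnormalized output
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(run_max, max(scores))
        correction = math.exp(run_max - new_max)  # rescales everything accumulated so far
        exps = [math.exp(s - new_max) for s in scores]
        run_sum = run_sum * correction + sum(exps)
        acc = [a * correction + sum(e * v[c] for e, v in zip(exps, v_blk))
               for c, a in enumerate(acc)]
        run_max = new_max
    return [a / run_sum for a in acc]

q = [1.0, 2.0]
keys = [[0.5, 1.0], [2.0, -1.0], [3.0, 3.0], [-1.0, 0.5], [0.0, 4.0]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
naive = naive_attention(q, keys, values)
online = online_attention(q, keys, values, block=2)
print(max(abs(a - b) for a, b in zip(naive, online)))  # rounding-level gap, not a different answer
```

The blocked version never holds more than one block of scores at a time, yet it lands on the same output; any remaining gap comes from floating-point summation order.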

Bottleneck is movement, not compute

HBM big·slow ↔ shuttle n×n matrix repeatedly ↔ SRAM small·fast compute

Naive: QKᵀ→write HBM→read→softmax→write→read→×V (materializes full n×n)
Flash: tile Q/K/V into SRAM → in-block compute + online-softmax accumulate → n×n never hits HBM

same output, memory O(n²)→O(n), wall-clock 2–4× faster
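The size of the matrix that the naive path materializes is simple arithmetic. This sketch assumes fp16 scores (2 bytes each) and one score tensor of shape (batch, heads, seq, seq); the batch and head counts are illustrative, and `score_matrix_mb` is a hypothetical helper.

```python
def score_matrix_mb(batch, heads, seq, bytes_per_elem=2):
    """Size of the full attention-score tensor (batch, heads, seq, seq) in MB."""
    return batch * heads * seq * seq * bytes_per_elem / 1e6

for seq in (4096, 32768):
    print(f"seq={seq}: {score_matrix_mb(2, 8, seq):,.0f} MB of fp16 scores")
```

An 8× longer sequence makes the score tensor 64× larger, and the naive path writes and re-reads it several times per layer; the tiled path avoids creating it in HBM at all.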
Code example
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
# needs a CUDA GPU and a PyTorch release recent enough to ship torch.nn.attention
# q/k/v shape: (batch, heads, seq, d_head)
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q); v = torch.randn_like(q)

# force the FlashAttention kernel: never materialize 4096x4096 in HBM
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)   # (2, 8, 4096, 64) same math as naive, equal up to fp16 rounding
# naive materializes hundreds of MB of intermediate matrix; Flash ~0
Common misconception + your scenario
"FlashAttention is a sparse / approximate attention that loses precision" is wrong. It is exact: it computes the same softmax attention as the naive version, and the outputs differ only by floating-point rounding. It optimizes how you compute (memory-access order), not what you compute (the math). That's entirely unlike sparse attention such as SWA, which genuinely changes the math.
📌 Selection scenario: when the same model runs much faster on inference frameworks like vLLM / SGLang, IO-aware kernels like FlashAttention are a key source. Understanding that "the speedup comes from memory-access optimization, not a different model" helps you judge whether a performance claim is real or quietly trading away precision.
Takeaway + question
💡 Same math, different memory-access order, and it's 2–4× faster—performance bottlenecks are often in memory, not FLOPs.
🤔 How many times in your past system work did you "think it was CPU-bound when it was actually I/O-bound"? How does that intuition transfer to diagnosing AI systems?

Further Reading

Deep Questions

1. GQA (changes the math/architecture) and FlashAttention (changes memory access, not the math) both make attention faster, yet they belong to two completely different classes of optimization. Why does the distinction matter?
This is the key watershed for understanding the whole field. GQA/MQA/SWA change "what you compute": they alter the model's mathematical definition. GQA reduces KV heads (trading a bit of expressiveness for memory); SWA narrows each token's view (trading direct long-range links for linear cost). These optimizations change the model's output and must be decided at training time (or via uptraining), a true quality–cost trade-off. FlashAttention changes "how you compute": the math is untouched, only the GPU memory-access order changes, and the output matches the naive version up to floating-point rounding. It's close to a free lunch: no modeling-quality cost, and any model can adopt it directly. Why it matters: (a) can it be added after the fact? Flash can accelerate an existing model; GQA cannot, it needs retraining or uptraining; (b) different evaluation: Flash only needs numerical equivalence checks and timing, while GQA/SWA require downstream quality evals because they can cost quality; (c) composability: the two classes are orthogonal and combine (Mistral = GQA + SWA + Flash stacked). For your distributed intuition: Flash is like "swapping in a better I/O scheduler," GQA like "denormalizing your DB schema"; one leaves semantics intact, the other changes them.
2. The KV cache is the "memory wall" of long context. GQA cuts the number of KV heads; DeepSeek's MLA cuts the KV dimension per token—two roads compressing different axes of the same thing. Which do you think has more of a future?
KV cache size ∝ (KV head count) × (per-head dim) × seq × layers (times 2 for K and V, times bytes per number), so "compressing KV" has two orthogonal axes. GQA works the "head count" axis: multiple query heads share fewer KV heads. It is simple, training-stable, uptrainable from an MHA checkpoint, and the current industry default. The cost is expressiveness loss from sharing, and the compression ratio is bounded (KV heads can't go below 1, and at MQA's 1 the quality clearly drops). MLA (Multi-head Latent Attention) works the "dimension" axis: it jointly compresses K and V into one low-dimensional latent vector in the cache and projects back when used, essentially "lossy compress + decompress" of the KV, with a compression ratio far beyond GQA, and DeepSeek reports quality holding or even improving. The cost is implementation complexity and special handling for RoPE compatibility. My read: short-term, GQA stays mainstream for simplicity + ecosystem maturity; but the "dimension axis" has a higher theoretical ceiling (redundancy can be shared across heads) and looks more promising long-term. Deeper still: both assume the KV holds compressible redundancy, which points at a fundamental question: how much "memory bandwidth" does attention actually need? That in turn leads to Mamba/SSM-style routes that replace the KV cache with a fixed-size state (Day 34).
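The two axes can be compared as bytes cached per token. The sketch below is illustrative only: the layer count and head width echo the 70B-class example used earlier, the latent width of 512 is a hypothetical value chosen for the comparison, and the MLA-style formula ignores any extra per-token state a real implementation keeps (for example for RoPE).

```python
def kv_bytes_per_token(n_kv_heads, d_head, layers, dtype_bytes=2):
    """Head-count axis (MHA/GQA/MQA): K and V are stored for every KV head."""
    return 2 * n_kv_heads * d_head * layers * dtype_bytes

def latent_bytes_per_token(d_latent, layers, dtype_bytes=2):
    """Dimension axis (MLA-style sketch): one shared latent vector per token per layer."""
    return d_latent * layers * dtype_bytes

layers, d_head = 80, 128
D_LATENT = 512  # hypothetical latent width, chosen only for illustration

print("MHA, 64 KV heads:", kv_bytes_per_token(64, d_head, layers), "bytes/token")
print("GQA,  8 KV heads:", kv_bytes_per_token(8, d_head, layers), "bytes/token")
print("latent, d=512   :", latent_bytes_per_token(D_LATENT, layers), "bytes/token")
```

On the head-count axis the floor is one KV head; on the dimension axis the floor is set by how small the latent can get before quality suffers, which is an empirical question rather than a structural limit.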
3. Why did an optimization like FlashAttention—"reorder memory access and it's 2-4× faster"—take years to appear? What does it reveal about hardware–software co-design?
On the surface it looks "late"; underneath, the abstraction layers hid the hardware reality. Researchers write attention at the PyTorch/math layer (matmul, softmax), where the split between "compute" and "memory traffic" is invisible, so the bottleneck looks like it ought to be FLOPs. But the GPU truth is that compute (TFLOPs) has grown far faster than memory bandwidth this past decade, so more and more kernels are memory-bound: compute sits idle while time is spent waiting on HBM. FlashAttention's insight isn't a new algorithm (tiling is a classic technique, and the online-softmax trick was published in 2018) but that someone was willing to drop to the CUDA layer and redesign the kernel around the GPU's memory hierarchy (HBM/SRAM sizes). The lessons: (a) abstraction has a cost: higher-level abstractions make you "blind" to the real bottleneck, and major optimizations often come from piercing the abstraction and rewriting close to the hardware; (b) bottlenecks drift: hardware evolution (the compute-vs-bandwidth scissors) turns yesterday's non-bottleneck into today's, so optimization must chase hardware trends; (c) IO-awareness is a general method: the same idea later spread to many kernels. For you: this matches your distributed experience exactly. A slow DB query is nine times out of ten not CPU but disk/network I/O; the real experts optimize data movement, not computation.
4. Stringing today's four variants together: from MHA in 2017 to now, what's the main thread of attention's evolution? Where does it go next?
The thread in one line: a constant tug-of-war between expressiveness and scalability (long sequences + low cost), gradually pushing cost from quadratic toward linear. (1) MHA (2017) set the paradigm, but O(n²) compute + a KV cache that grows linearly with sequence length doomed it on long sequences. (2) Sparsification (SWA/Longformer, 2020) attacks "what you compute" to cut computational complexity, O(n²)→O(n·w), at the cost of relaying long-range info through depth. (3) KV sharing (MQA 2019 / GQA 2023) attacks the inference bottleneck, cutting KV-cache memory and bandwidth so long-context decoding becomes affordable. (4) FlashAttention (2022) switches dimensions—doesn't touch the math, cuts memory access close to the hardware, making MHA itself faster and leaner. The four are orthogonal and stackable; today's frontier models are basically "GQA + local/global mix + Flash kernel" combined. Two lines ahead: (a) keep compressing KV—MLA (dimension compression), KV quantization, even KV eviction (dropping unimportant history); (b) step outside attention—Mamba/SSM replaces the ever-growing KV cache with a fixed-size state, turning "memory" from O(n) to O(1). An open bet: will the future be dominated by "attention + these optimizations," or by hybrid architectures (some layers SSM, some attention)? Right now hybrids have strong momentum.
5. Attention's "cost" forces architects into lossy trade-offs (GQA loses a little quality, SWA loses long range). This "sacrifice perfection for scale" engineering philosophy—what isomorphic shadows does it have in the distributed systems you know?
This is almost the same worldview projected onto different domains. CAP / eventual consistency in distributed systems: give up strong consistency for availability and partition tolerance—isomorphic to GQA "giving up a little expressiveness for memory," both admitting "perfection doesn't scale, you must give something up selectively." Caching and sampling: a CDN/Redis cache trades a bit of staleness for huge throughput; SWA trades a bit of long-range precision for linear cost—both lossy but worth it. LSM-tree / columnar denormalization: sacrifice storage redundancy and immediate consistency for read/write throughput, the same ledger as "compress KV cache for long context." Approximate algorithms: HyperLogLog, Bloom filters trade bounded error for orders-of-magnitude space savings—in spirit "exact is too expensive, go approximate"—while FlashAttention is precisely the counterexample (insist on exact, optimize only memory access), a reminder that "lossy" isn't the only path; sometimes a different implementation lets you have both. The deeper point: scale itself rewrites what "correct" means—on small data you chase the exact solution, at scale "good enough and scalable" usually beats "perfect but unscalable." As you design "AI super-individual" systems, this judgment—knowing on which axis you can be lossy, how lossy, and how to compensate—is worth more than memorizing any single technique.