MHA 就像在同一张表上建多个二级索引——每个索引为一种查询模式优化。一个 head = 一个专用"关注通道":有的学相邻词关系,有的学句法依赖,有的学指代。并行跑完再把结果拼起来,类似一次 MapReduce 里多个 reducer 各算一种聚合。
痛点:单个 attention 的 softmax 对每个 query 只能产生一个概率分布——也就是只能"用一种方式聚焦"一次。但一个词和上下文的关系是多维的(语法 + 语义 + 位置),单头表达不下。
机制:把 d_model 维(比如 512)切成 h 个 head(比如 8 个,每个 d_head = 512/8 = 64),每个 head 在自己的低维子空间里独立做一次注意力,最后拼接 + 一个输出投影 WO 混合。核心公式(每个 head):
Attention(Q,K,V) = softmax( Q·KT / √d_head ) · V
每个符号:Q(query,"我在找什么")、K(key,"我能提供什么")、V(value,"实际要传的内容")。Q·KT 算每对 token 的相似度(点积越大越相关);softmax 把它变成权重;再加权求和 V。为什么除 √d_head?——维度越高点积数值越大,会把 softmax 推向饱和区(一个权重接近 1、其余接近 0),梯度几乎消失。除以 √d_head 把方差拉回 ~1,训练才稳定。这是 2017 年 Attention Is All You Need 的设计。
import torch, torch.nn as nn # d_model=512 拆成 8 个 head,每个 head 维度 = 512/8 = 64 mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True) x = torch.randn(2, 10, 512) # (batch, seq_len, d_model) # self-attention: Q=K=V 都是 x 自己 out, attn = mha(x, x, x, need_weights=True, average_attn_weights=False) print(out.shape) # (2, 10, 512) 形状不变,信息已混合 print(attn.shape) # (2, 8, 10, 10) 8 个 head 各一张 10×10 注意力矩阵 # 每个 head 学到不同关注模式:有的盯相邻词,有的盯句法依赖
把 KV cache 想成推理时被反复读取的"热数据"。MHA 给每个 query head 配一份专属 KV 副本(显存爆);MQA 让所有 head 共享一份 KV(像所有 worker 共用一个连接池);GQA 折中——分几组、每组共享一份。本质是给 KV 做"去重 / 共享"。
痛点:自回归解码时,每生成一个新 token,都要把整个 KV cache 从 HBM 读一遍。KV cache 越大,内存带宽越成为瓶颈——解码慢的主因往往不是算力,是"搬 KV"。而 KV cache 大小 ∝ KV head 数量。
# KV cache 大小 ∝ KV head 数。估算 70B 级模型单条序列的 KV cache: def kv_cache_gb(n_kv_heads, seq, d_head=128, layers=80, dtype=2): # K 和 V 各存一份 → ×2; dtype=2 表示 fp16 每个数 2 字节 return 2 * n_kv_heads * d_head * seq * layers * dtype / 1e9 seq = 8192 print("MHA (64 KV head):", round(kv_cache_gb(64, seq), 1), "GB") # ~21.5 print("GQA ( 8 KV head):", round(kv_cache_gb(8, seq), 1), "GB") # ~2.7 print("MQA ( 1 KV head):", round(kv_cache_gb(1, seq), 1), "GB") # ~0.3 # GQA 把 KV cache 砍到 1/8,几乎不掉质量 → 长上下文的现实选择
num_key_value_heads。若它 = num_attention_heads,是 MHA(长上下文显存贵、解码慢);若远小于(如 8 vs 64),是 GQA(长文档任务更省、更快)。这一个数字直接预示你跑长上下文的成本曲线。像流处理里的有界缓冲(bounded buffer):每个 token 只看最近 w 个 token,不再看全程。但信息不会被困死——靠层层向前传递,类似 gossip 协议逐跳扩散,或 CNN 用深度堆出大感受野。
痛点:标准注意力是 O(n²)——每个 token 和全部 n 个 token 算相似度。n=100K 时就是约 1010 对,长序列直接算不动、存不下。
机制:限制每个 token 只注意前 w 个(窗口宽度),复杂度降到 O(n·w)——对序列长度线性。那远处的信息怎么传过来?靠深度:每过一层,信息向前推进 w 个 token;叠 L 层后,有效感受野 ≈ L·w。Mistral 7B 论文正是这么说的——"k 层之后信息可前移最多 k×W 个 token"。Longformer (2020) 进一步加少量"全局 token"(如 [CLS])当跨段桥梁。
import torch seq, window = 10, 3 # 滑动窗口掩码:token i 只能看 [i-window+1, i] 这 window 个位置 i = torch.arange(seq).unsqueeze(1) j = torch.arange(seq).unsqueeze(0) mask = (j <= i) & (j > i - window) # 因果 + 窗口 → 一条对角带 print(mask.int()) # 1=可见 0=屏蔽,可见区是对角带不是整个三角 # 有效感受野随层数线性增长 for L in (1, 8, 32): print(f"{L} 层后感受野 ≈ {L*window} tokens")
像外部归并排序:数据大到装不进内存,就分块、每块装进缓存处理,避免把全量反复搬进慢速存储。FlashAttention 是 attention 的"cache-aware 算法"——从不在显存里物化那张 n×n 的注意力矩阵,分块在片上高速缓存里算完。同样的数学,换访存顺序。
关键反直觉:attention 是 memory-bound(访存受限),不是 compute-bound(算力受限)。瓶颈不在乘法次数,而在那张 n×n 矩阵在 HBM(GPU 大显存,慢) 和 SRAM(片上缓存,快但极小) 之间反复搬运。朴素实现:算 S=QKᵀ 写回 HBM → 读回做 softmax → 再读回乘 V,全程被慢显存拖死。
FlashAttention (Dao 2022) 把 Q、K、V 切成小块(tile),每块装进 SRAM,一块一块算,永不把完整 n×n 写回 HBM。难点是 softmax 要按整行的全局最大值和总和归一化——而你一次只看到一块。解法是 online softmax(在线 softmax):维护"running max"和"running sum",每来一个新块就用修正因子 exp(旧max − 新max) 重新缩放已累积的结果。数值稳定、且结果与朴素实现逐元素完全相同(exact,非近似)。收益:访存暴减 → 2–4× 提速,显存随 n 线性而非平方。
import torch import torch.nn.functional as F from torch.nn.attention import SDPBackend, sdpa_kernel # 需要支持 CUDA 的 GPU; q/k/v 形状 (batch, heads, seq, d_head) q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16) k = torch.randn_like(q); v = torch.randn_like(q) # 强制走 FlashAttention 内核:从不在 HBM 物化 4096×4096 注意力矩阵 with sdpa_kernel(SDPBackend.FLASH_ATTENTION): out = F.scaled_dot_product_attention(q, k, v, is_causal=True) print(out.shape) # (2, 8, 4096, 64) 结果与朴素实现逐元素相同 # 朴素实现要物化数百 MB 的中间矩阵;Flash 近乎为 0