AI/ML 详解:Attention 变种

Day 12 · 2026-05-29 · 难度 ★★★★☆
面向:有编程经验的非 AI 方向工程师

多头注意力Multi-Head Attention · MHA

基线架构子空间
一句话类比

MHA 就像在同一张表上建多个二级索引——每个索引为一种查询模式优化。一个 head = 一个专用"关注通道":有的学相邻词关系,有的学句法依赖,有的学指代。并行跑完再把结果拼起来,类似一次 MapReduce 里多个 reducer 各算一种聚合。

它解决什么问题 + 工作机制

痛点:单个 attention 的 softmax 对每个 query 只能产生一个概率分布——也就是只能"用一种方式聚焦"一次。但一个词和上下文的关系是多维的(语法 + 语义 + 位置),单头表达不下。

机制:把 d_model 维(比如 512)切成 h 个 head(比如 8 个,每个 d_head = 512/8 = 64),每个 head 在自己的低维子空间里独立做一次注意力,最后拼接 + 一个输出投影 WO 混合。核心公式(每个 head):

Attention(Q,K,V) = softmax( Q·KT / √d_head ) · V

每个符号:Q(query,"我在找什么")、K(key,"我能提供什么")、V(value,"实际要传的内容")。Q·KT 算每对 token 的相似度(点积越大越相关);softmax 把它变成权重;再加权求和 V。为什么除 √d_head?——维度越高点积数值越大,会把 softmax 推向饱和区(一个权重接近 1、其余接近 0),梯度几乎消失。除以 √d_head 把方差拉回 ~1,训练才稳定。这是 2017 年 Attention Is All You Need 的设计。

d_model=512 拆成 8 个 head 并行

输入 x 512 维
↓ 切成 8 份,各自投影到 64 维
head1·64head2·64head8·64
↓ 每个 head 独立算 softmax(QKᵀ/√64)·V
注意力1注意力2注意力8
↓ 拼接回 512 维 → 输出投影 Wᴼ 混合
输出 512 维 (形状不变,信息已交叉融合)
代码示例
import torch, torch.nn as nn
# d_model=512 拆成 8 个 head,每个 head 维度 = 512/8 = 64
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)

# self-attention: Q=K=V 都是 x 自己
out, attn = mha(x, x, x, need_weights=True, average_attn_weights=False)
print(out.shape)    # (2, 10, 512) 形状不变,信息已混合
print(attn.shape)   # (2, 8, 10, 10) 8 个 head 各一张 10×10 注意力矩阵
# 每个 head 学到不同关注模式:有的盯相邻词,有的盯句法依赖
常见误区 + 实践场景
"head 越多模型越强"——错。head 数受约束(d_model 必须能被 h 整除),且 head 太多每个 d_head 太小、表达力反而下降。研究发现训练后很多 head 是冗余的,剪掉对效果几乎无影响——说明容量不是越分越好。
📌 跨学科场景:MHA 是一种"并行多视角"计算范式——像一个委员会同时从语法、语义、指代多个角度审同一句话再汇总。这个"多专家并行 + 末端聚合"的结构,和你熟悉的分布式 scatter-gather、组织决策中的多专家会审是同构的。
Takeaway + 思考题
💡 "多头"不是堆算力,是让模型在多个子空间里并行学不同关系——总算量和单个大头差不多,差别在表达力。
🤔 如果一个 head 是一个"专用视角",训练后冗余的 head 说明了神经网络在容量分配上的什么习性?

分组 / 多查询注意力GQA / MQA

推理优化KV cache
一句话类比

把 KV cache 想成推理时被反复读取的"热数据"。MHA 给每个 query head 配一份专属 KV 副本(显存爆);MQA 让所有 head 共享一份 KV(像所有 worker 共用一个连接池);GQA 折中——分几组、每组共享一份。本质是给 KV 做"去重 / 共享"。

它解决什么问题 + 工作机制

痛点:自回归解码时,每生成一个新 token,都要把整个 KV cache 从 HBM 读一遍。KV cache 越大,内存带宽越成为瓶颈——解码慢的主因往往不是算力,是"搬 KV"。而 KV cache 大小 ∝ KV head 数量

  • MHA:KV head 数 = Q head 数(如 64)。最强表达,但 KV cache 最大;
  • MQA(Shazeer 2019):KV head 数 = 1,所有 Q head 共享。KV cache 缩小到 1/64,但质量会掉;
  • GQA(Ainslie 2023):KV head 数 = g(如 8),每组 Q head 共享一份。在 MHA 和 MQA 之间插值——论文证明它能"质量接近 MHA、速度接近 MQA"。Llama 2/3、Mistral 都默认用 GQA。
KV head 共享谱(Q head 数都是 8)

MHA 8 Q ↔ 8 KV:QKV ×8 KV cache 最大
GQA 8 Q ↔ 2 KV:Q Q Q QKVQ Q Q QKV 缩 4×
MQA 8 Q ↔ 1 KV:Q Q Q Q Q Q Q QKV 缩 8×,质量掉
代码示例
# KV cache 大小 ∝ KV head 数。估算 70B 级模型单条序列的 KV cache:
def kv_cache_gb(n_kv_heads, seq, d_head=128, layers=80, dtype=2):
    # K 和 V 各存一份 → ×2; dtype=2 表示 fp16 每个数 2 字节
    return 2 * n_kv_heads * d_head * seq * layers * dtype / 1e9

seq = 8192
print("MHA (64 KV head):", round(kv_cache_gb(64, seq), 1), "GB")  # ~21.5
print("GQA ( 8 KV head):", round(kv_cache_gb(8,  seq), 1), "GB")  # ~2.7
print("MQA ( 1 KV head):", round(kv_cache_gb(1,  seq), 1), "GB")  # ~0.3
# GQA 把 KV cache 砍到 1/8,几乎不掉质量 → 长上下文的现实选择
常见误区 + 实践场景
"GQA 是为了省训练算力"——错。GQA/MQA 省的主要是推理时的 KV cache 显存和内存带宽,训练时收益很小。它是一种推理优化,披着架构的外衣——同样的"模型会什么",只是改了 KV 要搬多少。
📌 选型决策场景:看模型卡里的 num_key_value_heads。若它 = num_attention_heads,是 MHA(长上下文显存贵、解码慢);若远小于(如 8 vs 64),是 GQA(长文档任务更省、更快)。这一个数字直接预示你跑长上下文的成本曲线。
Takeaway + 思考题
💡 GQA/MQA 改的不是模型"会什么",而是推理时 KV 要搬多少——是访存优化披着架构外衣。
🤔 当 KV cache 成为长上下文的"内存墙",下一步该压 KV head 数,还是压每个 token 的 KV 维度?(这正是 DeepSeek 的 MLA 走的另一条路)

滑动窗口注意力Sliding Window Attention · SWA

稀疏注意力线性复杂度
一句话类比

像流处理里的有界缓冲(bounded buffer):每个 token 只看最近 w 个 token,不再看全程。但信息不会被困死——靠层层向前传递,类似 gossip 协议逐跳扩散,或 CNN 用深度堆出大感受野。

它解决什么问题 + 工作机制

痛点:标准注意力是 O(n²)——每个 token 和全部 n 个 token 算相似度。n=100K 时就是约 1010 对,长序列直接算不动、存不下。

机制:限制每个 token 只注意前 w 个(窗口宽度),复杂度降到 O(n·w)——对序列长度线性。那远处的信息怎么传过来?靠深度:每过一层,信息向前推进 w 个 token;叠 L 层后,有效感受野 ≈ L·w。Mistral 7B 论文正是这么说的——"k 层之后信息可前移最多 k×W 个 token"。Longformer (2020) 进一步加少量"全局 token"(如 [CLS])当跨段桥梁。

窗口 w=3:信息靠层数逐跳扩散

token:  t0  t1  t2  t3  t4  t5  t6  t7
第1层:  每个 token 只看左边 3 个 → 对角带
第2层:  t6 间接看到 t3(经 t4/t5 传递)
第L层:  有效感受野 ≈ L × 3 个 token

L 层后能覆盖多远 → 随深度线性增长
1 层 3 | 8 层 24 | 32 层 96
代码示例
import torch
seq, window = 10, 3
# 滑动窗口掩码:token i 只能看 [i-window+1, i] 这 window 个位置
i = torch.arange(seq).unsqueeze(1)
j = torch.arange(seq).unsqueeze(0)
mask = (j <= i) & (j > i - window)   # 因果 + 窗口 → 一条对角带
print(mask.int())   # 1=可见 0=屏蔽,可见区是对角带不是整个三角

# 有效感受野随层数线性增长
for L in (1, 8, 32):
    print(f"{L} 层后感受野 ≈ {L*window} tokens")
常见误区 + 实践场景
"滑动窗口 = 只能处理短文本"——错。靠堆层数,感受野能覆盖很长序列。但它的真实弱点是"单跳远距离精确检索":要把第 100K 个 token 的某个事实精确取到当前位置,得靠多层接力传递,容易在传递中失真——所以常配全局 token 或与 full attention 层交替使用。
📌 判断力场景:当一个号称 128K 上下文的模型在"跨越很远的关联"任务上表现差,先想想它的注意力是不是纯 SWA——若是,远距离信息靠层数慢慢传,"大海捞针 + 推理"本就是它的结构性短板,不是 prompt 没写好。
Takeaway + 思考题
💡 局部注意力 + 深度 = 用"层数"换"广度",和 CNN 用深度堆感受野是同一招。
🤔 信息靠层数逐跳传播,这对"大海捞针"任务意味着什么样的失败模式?深度和窗口宽度该怎么权衡?

FlashAttentionIO-Aware Exact Attention

访存优化精确
一句话类比

外部归并排序:数据大到装不进内存,就分块、每块装进缓存处理,避免把全量反复搬进慢速存储。FlashAttention 是 attention 的"cache-aware 算法"——从不在显存里物化那张 n×n 的注意力矩阵,分块在片上高速缓存里算完。同样的数学,换访存顺序。

它解决什么问题 + 工作机制

关键反直觉:attention 是 memory-bound(访存受限),不是 compute-bound(算力受限)。瓶颈不在乘法次数,而在那张 n×n 矩阵在 HBM(GPU 大显存,慢)SRAM(片上缓存,快但极小) 之间反复搬运。朴素实现:算 S=QKᵀ 写回 HBM → 读回做 softmax → 再读回乘 V,全程被慢显存拖死。

FlashAttention (Dao 2022) 把 Q、K、V 切成小块(tile),每块装进 SRAM,一块一块算,永不把完整 n×n 写回 HBM。难点是 softmax 要按整行的全局最大值和总和归一化——而你一次只看到一块。解法是 online softmax(在线 softmax):维护"running max"和"running sum",每来一个新块就用修正因子 exp(旧max − 新max) 重新缩放已累积的结果。数值稳定、且结果与朴素实现逐元素完全相同(exact,非近似)。收益:访存暴减 → 2–4× 提速,显存随 n 线性而非平方。

瓶颈在搬运,不在算力

HBM 大·慢 ↔ 反复搬 n×n 矩阵 ↔ SRAM 小·快 算力

朴素: QKᵀ→写 HBM→读→softmax→写→读→×V (物化整张 n×n)
Flash: Q/K/V 切块进 SRAM → 块内算 + online softmax 累加 → n×n 永不落 HBM

同样的输出,显存 O(n²)→O(n),墙钟 2–4× 快
代码示例
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
# 需要支持 CUDA 的 GPU; q/k/v 形状 (batch, heads, seq, d_head)
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q); v = torch.randn_like(q)

# 强制走 FlashAttention 内核:从不在 HBM 物化 4096×4096 注意力矩阵
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)   # (2, 8, 4096, 64) 结果与朴素实现逐元素相同
# 朴素实现要物化数百 MB 的中间矩阵;Flash 近乎为 0
常见误区 + 实践场景
"FlashAttention 是一种稀疏 / 近似注意力,会损失精度"——错。它是精确(exact)的,输出和朴素实现逐元素相同。它优化的是怎么算(访存顺序),不是算什么(数学)。这点和 SWA 这种真改了数学的稀疏注意力完全不同。
📌 选型场景:当你发现同一个模型换到 vLLM / SGLang 这类推理框架后快很多,FlashAttention 这类 IO-aware 内核是关键来源之一。理解"速度差来自访存优化、不是模型变了",能帮你判断一个性能宣传是真本事还是偷换了精度。
Takeaway + 思考题
💡 同样的数学,换个访存顺序就快 2–4 倍——性能瓶颈常在 memory,不在 FLOPs
🤔 你过去优化系统时,有多少次"以为是 CPU 瓶颈、其实是 I/O 瓶颈"?这个直觉怎么迁移到判断 AI 系统的性能问题?

深入资源Further Reading

深入思考Deep Questions

1. GQA(改数学/架构)和 FlashAttention(改访存/不改数学)都让注意力更快,但属于完全不同的两类优化。这个区分为什么重要?
这是理解整个领域的关键分水岭。GQA/MQA/SWA 改的是"算什么"——它们改变了模型的数学定义:GQA 让 KV head 变少(牺牲一点表达力换显存),SWA 让每个 token 看的范围变窄(牺牲远距离直连换线性复杂度)。这类优化会改变模型输出,必须在训练时就决定(或用 uptraining 微调),是一种 quality–cost trade-offFlashAttention 改的是"怎么算"——数学定义一字不动,只重排了 GPU 上的访存顺序,输出逐元素完全相同。它是免费午餐:没有任何精度代价,任何模型都能直接套用。为什么重要:(a) 能不能事后加——Flash 可以给现成模型直接加速;GQA 不行,得重训或 uptrain;(b) 评估方式不同——Flash 只需对拍数值和测速;GQA/SWA 必须跑下游质量评测,因为它们会掉点;(c) 叠加——两类正交,可同时用(Mistral = GQA + SWA + Flash 三者叠加)。BigCat 你的分布式直觉:Flash 像"换个更好的 I/O 调度器",GQA 像"改数据库 schema 做反范式化"——一个不动语义、一个动语义。
2. KV cache 是长上下文的"内存墙"。GQA 砍 KV head 数,DeepSeek 的 MLA 砍每个 token 的 KV 维度——这两条路线本质上在压缩同一个东西的不同轴,你怎么看哪条更有前途?
KV cache 大小 ≈ (KV head 数) × (每 head 维度) × seq × layers,所以"压 KV"有两个正交的轴。GQA 走"head 数"轴:让多个 query head 共享少数 KV head,简单、训练稳定、可从 MHA checkpoint uptrain,是当下工业默认。代价是共享带来的表达力损失,且压缩比受限(KV head 不能少于 1,少到 MQA 就明显掉点)。MLA(Multi-head Latent Attention)走"维度"轴:把 K、V 联合压缩成一个低维 latent 向量存进 cache,用时再投影回来——相当于对 KV 做了"有损压缩 + 解压",压缩比可以远超 GQA,且 DeepSeek 报告质量不降反升。代价是实现复杂、和 RoPE 位置编码的兼容需要特殊处理。怎么看:短期 GQA 因为简单 + 生态成熟仍是主流;但"维度轴"理论上限更高(信息可以跨 head 共享冗余),长期更性感。更深的一层:两者都假设 KV 里有冗余可压——这本身指向一个根本问题,注意力到底需要多少"记忆带宽"?这又把话题引向 Mamba/SSM 这种用固定大小状态替代 KV cache 的路线(Day 34)。
3. 为什么 FlashAttention 这种"重排访存就快 2-4 倍"的优化,过了好几年才出现?它揭示了软硬件协同设计的什么规律?
表面看是"晚了",深层是抽象层把硬件现实藏起来了。研究者写 attention 时用的是 PyTorch/数学层的抽象——矩阵乘、softmax,这层抽象里"算"和"存"是不可见的,看起来瓶颈就该是 FLOPs。但 GPU 的真相是:算力(TFLOPs)这十年涨得远比内存带宽快,导致越来越多 kernel 是 memory-bound——算力闲着,时间全花在等 HBM。FlashAttention 的洞察不是新算法(online softmax 1980 年代就有),而是有人愿意下到 CUDA 层,按 GPU 的内存层级(HBM/SRAM 大小)重新设计 kernel。揭示的规律:(a) 抽象有代价——越高层的抽象越容易让你对真实瓶颈"失明",重大优化常来自捅破抽象、贴着硬件重写;(b) 瓶颈会漂移——硬件演进(算力 vs 带宽剪刀差)会让昨天的非瓶颈变成今天的瓶颈,优化要追着硬件趋势走;(c) IO-awareness 是通法——同样的思路后来推广到很多 kernel。对 BigCat:这和你在分布式系统里的经验完全一致——数据库查询慢,十有八九不是 CPU 而是磁盘/网络 I/O;真正的高手优化的是数据移动,不是计算。
4. 把今天四个变种串起来:从 2017 的 MHA 到现在,attention 的演化主线是什么?这条主线接下来会走向哪里?
主线是一句话:在"表达力"和"扩展性(长序列 + 低成本)"之间反复拉扯,逐步把成本从平方降向线性。(1) MHA(2017)奠定范式,但 O(n²) 计算 + 随序列线性增长的 KV cache,注定长序列吃不消。(2) 稀疏化(SWA/Longformer,2020)从"算什么"下手砍计算复杂度,O(n²)→O(n·w),代价是远距离要靠深度接力。(3) KV 共享(MQA 2019 / GQA 2023)从推理瓶颈下手砍 KV cache 显存与带宽,让长上下文解码可负担。(4) FlashAttention(2022)换个维度——不改数学,贴着硬件砍访存,把 MHA 本身变快变省。四者正交、可叠加,今天的前沿模型基本是"GQA + 局部/全局混合 + Flash 内核"的合体。接下来两条线:(a) 继续压 KV——MLA(维度压缩)、KV 量化、甚至 KV 驱逐(evict 不重要的历史);(b) 跳出注意力——Mamba/SSM 用固定大小的状态彻底取代随序列增长的 KV cache,把"记忆"从 O(n) 变成 O(1)。一个开放赌注:未来主流是"注意力 + 这些优化"继续主导,还是混合架构(部分层用 SSM、部分层用注意力)胜出?目前看混合架构势头很猛。
5. 注意力的"成本"逼着架构师做有损权衡(GQA 掉一点质量、SWA 丢远距离)。这种"为了规模牺牲完美"的工程哲学,在你熟悉的分布式系统里有哪些同构的影子?
这几乎是同一套世界观在不同领域的投影。分布式系统的 CAP / 最终一致性:为了可用性和分区容忍,放弃强一致——和 GQA"为了显存放弃一点表达力"同构,都是承认"完美不可扩展,必须选择性放弃"。缓存与采样:CDN/Redis 缓存是"用一点陈旧换巨大吞吐",SWA 是"用一点远距离精度换线性复杂度"——都是有损但值得。LSM-tree / 列存的反范式化:为读写吞吐牺牲存储冗余和即时一致,和"压 KV cache 换长上下文"是一类账。近似算法:HyperLogLog、Bloom filter 用可控误差换数量级的空间——精神上正是"exact 太贵就转 approximate",而 FlashAttention 偏偏是反例(坚持 exact、只优化访存),提醒你"有损"不是唯一出路,有时换个实现就能既要又要。更深的意涵:规模本身会重写什么叫"对"——小数据上你追求精确解,大规模下"够好且能扩展"往往打败"完美但不可扩展"。BigCat 你做"AI 超级个体"系统设计时,这个判断力——知道在哪条轴上可以有损、有损多少、怎么补偿——比记住任何单个技巧都值钱。