标准 Transformer 有两个「先天病」:注意力是 O(n²)(序列翻倍,计算翻 4 倍),以及稠密计算(每个 token 都激活全部参数)。今天 4 个概念是学术界对这两个病的两条攻击路线——MoE 攻「稠密」让参数变大但计算不变;SSM / Mamba 攻「O(n²)」用 RNN 思想换回线性复杂度;长上下文则是「在不换架构的前提下,怎么把 n 撑大」。它们回答的都是同一个问题:当我们想要更大、更长,硬件账单怎么办?
MoE 就是给神经网络做数据库分库分表 + 查询路由。稠密模型像一台单机数据库,每个请求都扫全表;MoE 把一个巨大的前馈层(FFN)拆成 N 个「专家分片」,再放一个路由器(router),按 token 内容只把请求转发给最相关的 1-2 个分片。总容量随分片数线性涨,单次查询的计算量却几乎不变——这正是分库分表的核心收益:用空间换吞吐,不用算力换容量。
痛点:模型变强最直接的办法是堆参数,但稠密模型里参数量和计算量是绑死的——参数翻倍,每个 token 的浮点运算(FLOPs)也翻倍,推理成本线性爆炸。我们想要「容量大但算得便宜」,这二者在稠密架构里不可兼得。
MoE 的机制是条件计算(conditional computation):把 Transformer 每层的 FFN 替换成 N 个并列的小 FFN(专家),加一个轻量路由器。每个 token 进来,路由器算出一个分数分布,只选 top-k 个专家(通常 k=1 或 2)激活,其余专家这一步完全不参与运算。于是:
关键设计难点是负载均衡:路由器若总把 token 发给同几个「明星专家」,剩下的专家训不动、显存还白占。Shazeer 等人 2017 年的奠基论文引入了辅助的均衡损失(auxiliary loss)来惩罚这种倾斜;Switch Transformer(2021)进一步把 k 简化到 1,证明单专家路由也能保质量、还更省路由开销。2024 年的 Mixtral 8×7B 把这套思想做成了开源旗舰:8 个专家、每 token 激活 2 个。
import torch, torch.nn.functional as F # 极简 MoE 层:N 个专家 + top-k 路由(理解机制用,非生产实现) class MoE(torch.nn.Module): def __init__(self, d=512, n_exp=8, k=2): super().__init__() self.k = k self.router = torch.nn.Linear(d, n_exp) # 路由器:token → 每个专家的分数 self.experts = torch.nn.ModuleList([torch.nn.Linear(d, d) for _ in range(n_exp)]) def forward(self, x): # x: [tokens, d] scores = self.router(x) # 每个 token 对所有专家打分 w, idx = scores.topk(self.k, dim=-1) # 只取 top-k 个专家 w = F.softmax(w, dim=-1) # 在被选中的 k 个里归一化权重 out = torch.zeros_like(x) for j in range(self.k): # 加权求和被激活的专家输出 for e in range(len(self.experts)): m = idx[:, j] == e # 路由到专家 e 的 token 掩码 if m.any(): out[m] += w[m, j:j+1] * self.experts[e](x[m]) return out # N=8 容量,但每 token 只算 2 个专家
SSM 处理序列的方式,本质是一个带固定大小内存的流式处理器(streaming processor)。Transformer 像把整个日志文件读进内存再两两比对(O(n²));SSM 像 Kafka 消费者——逐条读入,把历史压缩进一个固定维度的「状态变量」,读完一条就更新状态、丢弃原文。内存不随流长度增长,这正是流式系统对批处理的核心优势。
痛点:Transformer 的注意力要存所有历史 token 的 Key/Value(KV cache 随长度线性涨),计算两两相关性(O(n²))。序列到几万、几十万 token 时,又慢又吃显存。能不能像 RNN 一样「O(n) 线性、内存恒定」,又不丢 Transformer 的并行训练能力?
SSM 借自控制论的经典方程。核心是维护一个隐状态 h,每来一个输入 xt 就更新一次:
直觉:A 是「遗忘/保留矩阵」——决定旧记忆怎么衰减;B 是「写入门」——决定当前输入注入多少;C 是「读出门」——决定从状态里提取什么。这和 LSTM 的门控思想一脉相承,但 SSM 的妙处在于:当 A、B、C 与输入无关(线性时不变)时,整个递推可以数学上展开成一个卷积,于是训练时能像 CNN 一样全序列并行,推理时又能像 RNN 一样逐步 O(1) 更新——鱼和熊掌兼得。
难点是:naive 的 A 矩阵在长序列上数值会爆炸或消失(和 RNN 的梯度问题同源)。Gu 等人 2021 年的 S4 论文用一种特殊的结构化 A 矩阵(基于 HiPPO 理论的初始化 + 低秩修正)解决了这个稳定性问题,让 SSM 第一次在长序列基准(Long Range Arena)上全面超过 Transformer,并把生成速度做到快几十倍。
import torch # SSM 的「循环模式」——展示恒定内存的流式更新(教学版,非 S4 完整实现) def ssm_scan(x, A, B, C): # x: [seq_len, d] A,B,C 是学到的状态矩阵 h = torch.zeros(A.shape[0]) # 隐状态:固定维度,不随 seq_len 变 ys = [] for t in range(x.shape[0]): # 逐 token 流式处理 h = A @ h + B @ x[t] # 旧状态衰减 + 新输入写入 ys.append(C @ h) # 从状态读出当前输出 return torch.stack(ys) # 关键:无论 seq_len 是 1k 还是 1M,h 的大小不变 → 内存 O(1) # 训练时这个递推可数学展开成卷积,从而全序列并行(此处省略)
S4 的状态矩阵是静态配置——像一个写死规则的缓存策略,不管来什么数据都用同一套「保留/丢弃」规则。Mamba 把它升级成内容感知的动态缓存:让「写入门 B、读出门 C、遗忘步长」都变成当前输入的函数。等于缓存策略能根据数据自己判断「这条重要,多留一会儿;那条是废话,赶紧忘」——从固定 TTL 升级成自适应 TTL。
痛点:上一节说 SSM 的硬伤是「无差别压缩历史」。因为 A、B、C 是固定的,模型没法根据内容决定该记住谁、忽略谁——这叫缺乏「内容感知推理(content-based reasoning)」。比如做「跳过所有空格、只记实词」这种任务,固定 SSM 做不到。
Mamba(Gu & Dao, 2023)的核心创新叫选择性机制(selection mechanism):把 SSM 的参数 B、C 以及离散化步长 Δ 改成由输入 xt 现算的函数。这样每个 token 都能动态控制——「我要往状态里写多少、读多少、把旧记忆衰减多快」。直觉上,模型获得了选择性记忆的能力:遇到关键信息就重置/强写状态,遇到噪声就让它流过去。
但这里有个工程悖论:参数一旦依赖输入,就不再是「线性时不变」,那套「展开成卷积来并行训练」的技巧就失效了。Mamba 的第二个贡献是硬件感知的并行扫描(parallel scan)算法——用一种类似前缀和的并行原语,配合精心设计的 GPU 显存读写(思路类似 FlashAttention,把中间状态留在高速 SRAM 里),让这个「输入依赖的递推」依然能在 GPU 上高效并行。论文报告 Mamba 在语言建模上吞吐量约为同规模 Transformer 的 5 倍,并随序列长度线性扩展、可处理到百万级长度。
2024 年起的实践共识是混合架构:把少量注意力层和大量 Mamba 层交替堆叠,用注意力补「精确召回」短板,用 Mamba 拿「长序列效率」——取两者之长。
# 官方实现:pip install mamba-ssm(需 CUDA) import torch from mamba_ssm import Mamba batch, seq_len, dim = 2, 4096, 512 x = torch.randn(batch, seq_len, dim).to("cuda") model = Mamba( d_model=dim, # 输入/输出维度 d_state=16, # 隐状态维度——固定大小,与 seq_len 无关 d_conv=4, # 局部卷积窗口(捕捉短程模式) expand=2, # 内部扩张倍数 ).to("cuda") y = model(x) # 输出与输入同形 [2, 4096, 512] print(y.shape) # B、C、Δ 在内部由 x 现算 → 选择性记忆 # 序列再长 d_state 也是 16,靠并行扫描在 GPU 上高效跑
「把上下文从 4K 撑到 1M」不是简单调一个配置参数,更像给一个为小数据量设计的系统做水平扩展:你不能只改一行 max_length,得同时解决算力(O(n²) 爆炸)、显存(KV cache 线性涨)、和泛化(模型没见过这么远的位置)三件事——就像数据库从单机扩到分布式,索引、缓存、一致性全都要重做。
痛点:原始 Transformer 在固定长度(如 2K、4K)上训练,想直接喂 100K token 会出现三重障碍。长上下文是一组工程技术的合集,分别攻击这三个:
第三点的主流解法围绕 RoPE(旋转位置编码)展开(Day 13 详讲过其机制)。RoPE 把位置信息编码成「旋转角度」,频率越高的维度转得越快。直接外推到训练没见过的长度时,高频维度「转过头」导致模型懵掉。位置插值(Position Interpolation)的巧思是:与其让模型见识没见过的大角度,不如把所有位置等比例「压缩」回训练见过的范围——好比把一把为 30cm 设计的尺子,重新标刻度让它量 3 米,刻度变密但范围都在「认识的区间」内,只需少量微调即可适配。
import torch # 位置插值的核心思想:把超长位置「压缩」回训练长度区间 train_len = 4096 # 模型原始训练长度 target_len = 32768 # 想扩展到的目标长度 scale = train_len / target_len # 缩放因子 = 1/8 pos = torch.arange(target_len).float() # 关键一步:位置 ×scale → 把 [0, 32768) 压回 [0, 4096) 的"已知区间" pos_interpolated = pos * scale # 模型只需"认识"它见过的位置范围 # 这些插值后的位置再喂给 RoPE 计算旋转角度(RoPE 机制见 Day 13) # 实践中:插值后用少量长文本微调,即可稳定适配新长度 print(pos_interpolated.max()) # ≈ 4095,全落在训练见过的范围内