AI/ML 详解:训练基础设施

Day 25 · 2026-06-11 · 难度:进阶
面向:有分布式系统经验的非 AI 方向工程师

分布式训练Distributed Training

系统并行
一句话类比

一个 70B 模型单卡装不下,就像一张表太大单机存不下——你的本能是分片(sharding)。但训练有三个正交的切分维度:切数据(像多个只读副本)、切模型的层(像微服务流水线)、切单层里的大矩阵(像把一个大 JOIN 拆到多机)。理解分布式训练,本质是理解这三个轴怎么组合

它解决什么问题 + 工作机制

痛点是显存(VRAM),不是算力。一个 70B 模型用 fp16,光参数就 140GB;加上梯度和 Adam 优化器状态,训练时每个参数要占 ~16 字节(下张卡的 ZeRO 会拆解这个数字),总共 1TB 出头——一张 80GB 的 H100 根本放不下。三种并行各解决一块:

  • 数据并行 Data Parallel(DP):每张卡存一份完整模型,只把训练 batch 切开。各卡独立前向反向,再用 all-reduce 把梯度求平均同步。类比:多个只读副本各处理一批请求,定期对账。前提是模型本身塞得进单卡
  • 流水线并行 Pipeline Parallel(PP):把模型按层切成几段,分到不同卡,像装配线 / CPU 流水线。代价是 bubble(气泡)——第 1 段算完在等第 2 段,靠把 batch 拆成 micro-batch 填满流水线来缓解;
  • 张量并行 Tensor Parallel(TP):把单层内的大矩阵乘法横竖切到多卡(第 4 张卡详解)。通信最密集,只在一台机器的高速 NVLink 内做。
三个正交的切分维度

Data 并行 切 batch,复制模型 → all-reduce 同步梯度
Pipeline 并行 切层(纵向)→ 卡A:层1-8 │ 卡B:层9-16 …
Tensor 并行 切单层矩阵(横向)→ 一个 matmul 拆到多卡

现代万卡训练 = 三者组合,叫 3D 并行
代码示例
# 最简单的数据并行:PyTorch DDP。用 torchrun 启动多进程
import torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")        # NCCL = NVIDIA 的 GPU 集合通信库
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = MyTransformer().to(rank)
model = DDP(model, device_ids=[rank])  # 反向时自动 all-reduce 梯度

for batch in shard_loader(rank):     # 每个 rank 拿不同数据切片
    loss = model(batch).loss
    loss.backward()                  # ← 梯度在这一步跨卡求平均
    optimizer.step(); optimizer.zero_grad()
# 启动:torchrun --nproc_per_node=8 train.py
常见误区 + 实践场景
"加卡 = 线性加速"——错。加卡会引入通信开销(梯度同步、参数收集),而 all-reduce 的成本随卡数增长。8 卡常能到 7x,但 512 卡可能只有 ~380x(约 74% 效率)。这就是 Megatron 论文专门报告"scaling efficiency"而不是只说倍数的原因——通信永远在和你抢带宽。
📌 超级个体场景:你大概率不会从头训 70B,但会微调。判断"要不要多卡"的第一性原理就是这张卡的逻辑——先估算「参数+梯度+优化器状态」总显存,超过单卡就必须 sharding。把它当容量规划(capacity planning)问题,和你规划数据库分片是同一套思维。
Takeaway + 思考题
💡 分布式训练的瓶颈是显存与通信,不是算力。三种并行是三个正交轴,按"什么装不下"来选。
🤔 你熟悉的分布式存储里「复制 vs 分片」的权衡,怎么精确映射到「数据并行 vs ZeRO 分片」?

零冗余优化器ZeRO (Zero Redundancy Optimizer)

显存分片
一句话类比

纯数据并行有个巨大浪费:每张卡都存一份完整的参数、梯度、优化器状态——这是 N 份冗余副本,等于分布式存储里"全量复制到每个节点"。ZeRO 的洞察是:这些状态大部分时间用不到,干嘛每张卡都存满?于是把它们切片(shard)到各卡,每卡只存 1/N,用到时临时 all-gather 拼回来。本质是用通信换显存——和用 erasure coding 替代全量复制是同一个 trade-off。

它解决什么问题 + 工作机制

先把"每参数 16 字节"拆开看——这是理解 ZeRO 的关键。混合精度训练里,每个参数要存:

混合精度训练,每个参数的显存账单

fp16 参数 2 字节
fp16 梯度 2 字节
fp32 参数副本 4 字节
fp32 动量 m 4 字节 ├ Adam 优化器状态
fp32 方差 v 4 字节

合计 ≈ 16 字节/参数。70B 模型 ≈ 1.1 TB——其中 12 字节是优化器状态,最肥

ZeRO 分三级,渐进式地切掉冗余,省得越多、通信越重:

  • ZeRO-1:切优化器状态(那 12 字节)。最肥的部分先分片,几乎不增加通信,性价比最高;
  • ZeRO-2:再切梯度。反向算完即把梯度 reduce-scatter 到负责的卡,本地只留自己那片;
  • ZeRO-3:连参数本身也切。前向/反向走到某层时才 all-gather 出完整参数,用完立刻丢弃。每卡显存降到约 16/N 字节/参数——理论上加卡就能线性降显存。

这是微软 DeepSpeed 的 ZeRO(Rajbhandari et al. 2019, arXiv 1910.02054)。它最优雅的地方:在数据并行的编程模型下(每卡看到的还是"完整模型"的假象),偷偷把内存冗余消掉了。

代码示例
# DeepSpeed 用一份 JSON config 开启 ZeRO,业务代码几乎不变
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,                 # 切 优化器+梯度+参数
        "offload_optimizer": {"device": "cpu"},  # 还能下放到 CPU 内存
        "stage3_prefetch_bucket_size": 5e7        # 预取,藏通信延迟
    }
}
import deepspeed
engine, _, _, _ = deepspeed.initialize(
    model=model, config=ds_config, model_parameters=model.parameters())
for batch in loader:
    loss = engine(batch).loss
    engine.backward(loss)   # ZeRO 在这里做 reduce-scatter + all-gather
    engine.step()
常见误区 + 实践场景
"ZeRO-3 最省显存,那就无脑开 stage 3"——错。stage 越高,每一步前向/反向都要 all-gather 参数,通信量暴涨。如果你的瓶颈本来是显存就值;如果模型其实塞得下、瓶颈在速度,stage 3 反而拖慢。规则:先用最低能装下的 stage。这和"别为了省一点磁盘把热数据也压缩"是同一种工程判断。
📌 超级个体场景:用 4 张消费级 GPU(如 24GB ×4)微调一个 13B 模型时,ZeRO-2 + 优化器 CPU offload 常常是"勉强能跑"和"OOM 崩溃"的分界线。理解 16 字节账单,你就能事前算出开哪个 stage,而不是撞墙试错。
Takeaway + 思考题
💡 ZeRO = 把数据并行里的冗余副本换成分片 + 按需收集,用通信带宽换显存
🤔 ZeRO-3 的"用到才 all-gather、用完就丢",和虚拟内存的按需分页(demand paging)在思想上是不是同一招?它们各自的"缺页代价"是什么?

全分片数据并行FSDP (Fully Sharded Data Parallel)

PyTorch原生
一句话类比

FSDP 就是 PyTorch 官方把 ZeRO-3 写进了标准库。机制几乎等价,区别像第三方库 vs 内置实现:DeepSpeed/ZeRO 是外挂的(类似 requests),FSDP 是和 PyTorch 的 autograd、dispatcher、显存分配器深度同源的官方版本(类似内置的 urllib,但更顺手)。对大多数纯 PyTorch 项目,FSDP 是"不引入新生态就能拿到 ZeRO-3"的默认选择。

它解决什么问题 + 工作机制

解决的痛点和 ZeRO-3 一样——消除数据并行的内存冗余——但要无缝融进原生 PyTorch。机制可以一句话概括:平时分片、用时聚合、用完释放

  • 包裹(wrap)成单元:把模型按层/模块切成若干 FSDP unit。平时每个 unit 的参数只在本卡留 1/N 分片;
  • 前向到某 unit:先 all-gather 拼出这一层的完整参数 → 算这一层 → 立刻 free 掉非本地的那部分,只留回 1/N;
  • 反向同理:再 all-gather 一次算梯度,然后 reduce-scatter 把梯度切片分发回各自负责的卡;
  • 通信-计算重叠:算第 i 层时,后台预取第 i+1 层的参数,把 all-gather 的延迟藏在计算背后——这是它能跑快的关键。

权威参考是 PyTorch FSDP 论文(Zhao et al. 2023, arXiv 2304.11277),里面讲清了它怎么和 PyTorch 核心组件协同设计。auto_wrap_policy 决定切分粒度:unit 切太细 → all-gather 次数多、通信碎;切太粗 → 单次聚合的峰值显存高。粒度本身就是显存与通信的连续权衡。

FSDP 单层的生命周期(每层重复)

分片态: 本卡只有 1/N 参数
↓ all-gather(拼出完整层)
完整态: 算 forward / backward
↓ free 非本地分片 + reduce-scatter 梯度
回到分片态,显存释放

峰值显存 ≈ 全模型分片(1/N) + 当前这一层的完整参数(不是整个模型)
代码示例
# PyTorch 原生 FSDP,无需第三方库
import torch, functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# 按 Transformer block 自动切分单元——粒度决定通信/显存权衡
policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock})

model = FSDP(
    model,
    auto_wrap_policy=policy,
    device_id=torch.cuda.current_device())

# 之后的训练循环和单卡几乎一模一样——分片对业务代码透明
for batch in loader:
    loss = model(batch).loss
    loss.backward()      # FSDP 在此处理 all-gather / reduce-scatter
    optimizer.step(); optimizer.zero_grad()
常见误区 + 实践场景
"FSDP 是 DDP 的无脑升级,换个 wrapper 就行"——要小心。FSDP 改变了模型的状态语义:参数是分片的,直接 model.state_dict() 存档、或写一段假设"本卡有完整参数"的自定义逻辑,都可能出错——得用 FSDP 提供的 state_dict 接口先把分片聚合回来。分片是透明的,但不是隐形的:碰参数内部时它就会冒出来。
📌 超级个体场景:2026 年微调大模型,FSDP 几乎是 HuggingFace accelerate / trl 背后的默认引擎。哪怕你只写一行 accelerate launch,理解"它在每层 all-gather 又 free"能让你看懂为什么峰值显存和你算的参数量对不上——多出来的就是当前层的完整副本。
Takeaway + 思考题
💡 FSDP = 原生 PyTorch 版的 ZeRO-3:平时分片、用时聚合、用完释放,通信藏在计算背后。
🤔 FSDP 把"完整参数"做成了一个短暂、即用即弃的对象——这和你写缓存时"按需物化、立即失效"的模式是不是一回事?

Megatron-LM 张量并行Megatron-LM / Tensor Parallelism

张量并行3D 并行
一句话类比

ZeRO 和 FSDP 都假设单层算得动——它们切的是"谁存参数",但算的时候还是把整层拼回单卡。可如果单层的矩阵本身就大到一张卡放不下呢?那就得把单个矩阵乘法本身切开,让多张卡协同算一个 matmul。类比:一行记录大到单机存不下,你只能把这一行的字段也拆到多机(垂直拆单个算子)。这就是张量并行(Tensor Parallel),Megatron-LM 是它的经典实现。

它解决什么问题 + 工作机制

Megatron-LM(Shoeybi et al. 2019, arXiv 1909.08053)的精妙之处:怎么切 Transformer 的两大算子,让切分边界刚好落在非线性函数之外,从而一次前向只需一次同步。看 MLP 层 Y = GeLU(X·A)·B

  • 第一个矩阵 A 按「列」切成 [A₁, A₂]:X·A₁ 和 X·A₂ 分到两卡独立算。因为 GeLU 是逐元素的,GeLU(X·A₁) 不需要看 A₂ 的结果——非线性可以在各卡本地完成,无需中间同步
  • 第二个矩阵 B 按「行」切成 [B₁; B₂]:两卡各算出部分和,最后用一次 all-reduce 求和得到完整 Y。

为什么是"列切 A、行切 B"而不是反过来?因为这个组合让唯一的非线性(GeLU)落在切分内部,整层前向只在末尾同步一次。换个切法就得在 GeLU 前后各同步一次——通信翻倍。自注意力则更自然:按 attention head 切,每个 head 本就独立计算,天生适合并行。

Megatron MLP 张量并行:Y = GeLU(X·A)·B

输入 X(两卡都有完整 X)
↓ 卡1: X·A₁ │ 卡2: X·A₂(A 按列切)
卡1: GeLU(X·A₁) 卡2: GeLU(X·A₂) ← 非线性本地完成,不同步
↓ 卡1: (…)·B₁ │ 卡2: (…)·B₂(B 按行切)
部分和₁ + 部分和₂ → all-reduce → 完整 Y
整层只在最后同步一次(前向 1 次 all-reduce)

张量并行通信极重(每层都要 all-reduce),所以只在一台机器内的 NVLink 上做(带宽 ~900GB/s),绝不跨节点。最终的万卡训练是 3D 并行

3D 并行:三个轴正交组合

Tensor 并行(TP) → 机内 NVLink,切单层矩阵,通信最重
Pipeline 并行(PP) → 跨机,切层,通信中等
Data 并行(DP) → 跨机/跨集群,切数据,通信最稀

原则:通信越重的并行,放在带宽越高的层级。这是按「带宽层级」做的放置优化
代码示例
# 张量并行的核心思想,用纯 PyTorch 示意「列切 + 行切 + all-reduce」
import torch, torch.distributed as dist
import torch.nn.functional as F

def tp_mlp(X, A_shard, B_shard):
    # A 按列切:本卡只持有 A 的一部分列 → 算出部分隐藏维
    H = F.gelu(X @ A_shard)        # GeLU 逐元素,本地完成,无需同步
    # B 按行切:本卡持有对应的行 → 算出部分和
    Y_partial = H @ B_shard
    # 唯一一次同步:把各卡部分和加起来 = 完整输出
    dist.all_reduce(Y_partial, op=dist.ReduceOp.SUM)
    return Y_partial            # 整层前向只通信了一次
# 真实生产用 NVIDIA Megatron-LM / Megatron-Core,已封装好这套切分
常见误区 + 实践场景
"张量并行能像数据并行一样横向扩到几百卡"——错。TP 每层都要 all-reduce,带宽就是天花板:跨节点走 InfiniBand(~50GB/s)比机内 NVLink(~900GB/s)慢一个量级,TP 一旦跨机性能就崩。所以 TP 几乎只用 8(一台 8 卡机内),更大规模靠 PP 和 DP 往外叠。并行策略的本质是按通信量匹配带宽层级
📌 超级个体场景:你基本不会自己配 3D 并行,但读懂大模型技术报告里"TP=8, PP=4, DP=16, 共 512 卡"这行时,就能立刻反推它的集群拓扑和瓶颈在哪——这是判断一个开源模型"好不好复现 / 微调成本多高"的关键信息。
Takeaway + 思考题
💡 张量并行切的是单个算子,Megatron 的巧思是让切分边界避开非线性,把每层同步压到一次;3D 并行 = 按带宽层级摆放三种并行。
🤔 "把通信最重的并行放在带宽最高的层级"——这条原则和你做存储/缓存分层(热数据上内存、冷数据下磁盘)是不是同一个第一性原理?
工程对应 → super-individual D12: Fine-tuning(多卡微调时 FSDP / ZeRO 的实战配置与显存调优)

深入资源Further Reading

深入思考Deep Questions

1. 数据并行的 all-reduce、ZeRO 的 reduce-scatter + all-gather、张量并行的 all-reduce——这些都是「集合通信原语」。它们和你熟悉的分布式系统通信模式有什么对应?
它们就是 MPI / NCCL 的集合通信(collective communication)原语,和分布式系统里的 gossip、quorum、聚合查询是同一族思想。all-reduce = 每个节点都贡献一份数据、最终每个节点都拿到全局聚合结果(如求和/求平均),等价于"分布式求和后广播回所有人"——数据并行同步梯度、张量并行合并部分和都用它。reduce-scatter = 聚合的同时把结果切片分发,每个节点只拿到自己负责的那一片(ZeRO/FSDP 用它把梯度切回各卡)。all-gather = 反向操作,每个节点把自己的分片广播出去,最终人人持有完整数据(FSDP 临时拼回整层参数)。关键洞察:all-reduce ≈ reduce-scatter + all-gather——这正是为什么 ZeRO 能"免费"地把数据并行的同步拆成两半,在中间插入分片。底层算法(ring all-reduce)的通信量与节点数无关、只和数据量有关,这是它能扩展到上千卡的数学基础。你做过分布式聚合查询,就已经懂这套了,只是换了名字。
2. ZeRO-3 / FSDP「用到才 all-gather、用完就丢」和操作系统的虚拟内存按需分页,本质是不是同一招?代价分别是什么?
思想内核确实一致:不把全部数据常驻昂贵存储,而是按需把"页"换入、用完换出。虚拟内存把不活跃的页换到磁盘,进程访问时触发缺页中断再换回;FSDP 把不活跃层的参数"换出"成 1/N 分片,前向走到该层时 all-gather"换入"完整参数,算完再"换出"。但代价的性质不同:(a) 缺页代价的来源不同——虚拟内存的代价是磁盘 I/O 延迟(随机访问、毫秒级),FSDP 的代价是跨卡通信带宽(NVLink/IB,微秒级但量大);(b) 可预测性不同——虚拟内存的缺页是被动、不可预测的(取决于程序访问模式),而 FSDP 的"换入顺序"是完全确定的(就按层的前向/反向顺序),所以它能做预取(prefetch)——算第 i 层时后台拉第 i+1 层,把通信延迟完全藏在计算背后。这是 FSDP 能高效的根本原因:确定性的访问序列让预取变得完美,而 OS 的通用预取只能猜。(c) ZeRO 还能把状态进一步 offload 到 CPU 内存甚至 NVMe——这时就真的变成了和虚拟内存几乎一样的多级换页,延迟代价也随之逼近磁盘级。
3. 为什么张量并行(TP)只能在机内做、数据并行(DP)却能跨集群扩展?这背后是什么定量原则?
核心是「通信量 / 计算量」的比值,决定了一种并行对带宽有多敏感。张量并行:每一层前向都要至少一次 all-reduce,通信发生得极其频繁(每层、每步),且和计算紧耦合——一次同步没完成,下一层算不了。它对带宽和延迟都极度敏感,所以必须放在带宽最高的层级(机内 NVLink ~900GB/s)。一旦跨节点走 InfiniBand(~50GB/s,慢约 20 倍),TP 立刻被通信拖垮。数据并行:只在每步反向结束后同步一次梯度,一步里就一次通信,而且可以和反向计算重叠(算完一层的梯度就立刻开始传,不用等整个反向结束)。通信频率低、可隐藏,所以即使带宽较低也能扛,能扩到跨机房上千卡。流水线并行居中:只在层段边界传激活值,通信量中等。这就推出了 3D 并行的放置定律把通信最密集的并行放在带宽最高的层级——TP 进 NVLink 域(通常 8 卡),PP 跨节点,DP 跨集群。这和你做存储分层(热数据进内存、温数据进 SSD、冷数据进对象存储)是同一个第一性原理:按访问/通信强度匹配资源带宽。理解这条,你就能看一眼"TP=8, PP=4, DP=N"反推出整个集群的物理拓扑。
4. 这四种技术(DP/ZeRO/FSDP/TP)——如果你要训练一个模型,决策树应该怎么走?
按"什么装不下"逐级升级,这是最干净的决策框架:第一步问:单卡能装下完整模型(参数+梯度+优化器)吗?能 → 直接 DDP 数据并行,最简单、通信最少、扩展性最好。这是首选,别过度工程。装不下,第二步问:是优化器状态太肥,还是模型本身太大?如果是前者(参数其实塞得下,是 Adam 的 12 字节/参数压垮了显存)→ 上 ZeRO / FSDP,按需选 stage:先 ZeRO-1(只切优化器状态,几乎不加通信),不够再 ZeRO-2(切梯度),还不够 ZeRO-3 / FSDP(连参数也切)。原则是用能装下的最低 stage第三步:如果连单层都大到一张卡放不下(百亿/千亿级的巨大 FFN 或 embedding)→ 必须 张量并行(Megatron),但只在机内 8 卡做。第四步:模型层数太多、单机放不下整个模型 → 加流水线并行跨节点切层。最终的超大规模就是全叠起来:TP(机内)× PP(跨节点)× DP/ZeRO(跨集群)= 3D 并行。关键心法:每一级都是"上一级装不下"才引入,因为每加一种并行都增加通信和复杂度。这和数据库扩展完全同构——先单机、再读副本(DP)、再分库分表(sharding/TP)、能不分就不分。
5. 这些技术让"训练万亿参数模型"成为可能,但它们都是在「绕过单卡显存墙」。如果显存墙是根本约束,这是否预示着架构层面(而非系统层面)的解法?
这是个深刻的跨层问题。DP/ZeRO/FSDP/TP 全是系统层的"搬运工"——它们不改变"训练这个模型需要多少总显存",只是把它摊薄到更多卡 + 用通信换空间。摊薄有上限:卡越多,通信占比越高,最终通信会吃掉所有收益(这也是 scaling efficiency 永远 <100% 的根因)。所以真正的突破往往来自架构层——直接降低"每个有效参数的训练成本":(a) MoE(专家混合):每个 token 只激活一小部分专家,参数总量巨大但单次计算/激活显存小,用"稀疏激活"绕开稠密矩阵的显存墙(Day 34 主题);(b) 更省的注意力(FlashAttention 不省参数但省激活显存,把 O(n²) 中间矩阵不落地);(c) 状态空间模型 / 线性注意力:从根上把 O(n²) 降到 O(n)。系统层和架构层是互补而非替代:架构决定"理论需要多少资源",系统决定"怎么把它铺到真实硬件上"。BigCat 你可以这样看——分布式训练是把"墙"推远的工程学,而架构创新是重新设计房子让它根本不需要那么大的墙。两条路线的张力,正是 2026 年大模型基础设施最核心的研究主线之一。