一个 70B 模型单卡装不下,就像一张表太大单机存不下——你的本能是分片(sharding)。但训练有三个正交的切分维度:切数据(像多个只读副本)、切模型的层(像微服务流水线)、切单层里的大矩阵(像把一个大 JOIN 拆到多机)。理解分布式训练,本质是理解这三个轴怎么组合。
痛点是显存(VRAM),不是算力。一个 70B 模型用 fp16,光参数就 140GB;加上梯度和 Adam 优化器状态,训练时每个参数要占 ~16 字节(下张卡的 ZeRO 会拆解这个数字),总共 1TB 出头——一张 80GB 的 H100 根本放不下。三种并行各解决一块:
# 最简单的数据并行:PyTorch DDP。用 torchrun 启动多进程 import torch, torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP dist.init_process_group("nccl") # NCCL = NVIDIA 的 GPU 集合通信库 rank = dist.get_rank() torch.cuda.set_device(rank) model = MyTransformer().to(rank) model = DDP(model, device_ids=[rank]) # 反向时自动 all-reduce 梯度 for batch in shard_loader(rank): # 每个 rank 拿不同数据切片 loss = model(batch).loss loss.backward() # ← 梯度在这一步跨卡求平均 optimizer.step(); optimizer.zero_grad() # 启动:torchrun --nproc_per_node=8 train.py
纯数据并行有个巨大浪费:每张卡都存一份完整的参数、梯度、优化器状态——这是 N 份冗余副本,等于分布式存储里"全量复制到每个节点"。ZeRO 的洞察是:这些状态大部分时间用不到,干嘛每张卡都存满?于是把它们切片(shard)到各卡,每卡只存 1/N,用到时临时 all-gather 拼回来。本质是用通信换显存——和用 erasure coding 替代全量复制是同一个 trade-off。
先把"每参数 16 字节"拆开看——这是理解 ZeRO 的关键。混合精度训练里,每个参数要存:
ZeRO 分三级,渐进式地切掉冗余,省得越多、通信越重:
这是微软 DeepSpeed 的 ZeRO(Rajbhandari et al. 2019, arXiv 1910.02054)。它最优雅的地方:在数据并行的编程模型下(每卡看到的还是"完整模型"的假象),偷偷把内存冗余消掉了。
# DeepSpeed 用一份 JSON config 开启 ZeRO,业务代码几乎不变 ds_config = { "train_micro_batch_size_per_gpu": 4, "bf16": {"enabled": True}, "zero_optimization": { "stage": 3, # 切 优化器+梯度+参数 "offload_optimizer": {"device": "cpu"}, # 还能下放到 CPU 内存 "stage3_prefetch_bucket_size": 5e7 # 预取,藏通信延迟 } } import deepspeed engine, _, _, _ = deepspeed.initialize( model=model, config=ds_config, model_parameters=model.parameters()) for batch in loader: loss = engine(batch).loss engine.backward(loss) # ZeRO 在这里做 reduce-scatter + all-gather engine.step()
FSDP 就是 PyTorch 官方把 ZeRO-3 写进了标准库。机制几乎等价,区别像第三方库 vs 内置实现:DeepSpeed/ZeRO 是外挂的(类似 requests),FSDP 是和 PyTorch 的 autograd、dispatcher、显存分配器深度同源的官方版本(类似内置的 urllib,但更顺手)。对大多数纯 PyTorch 项目,FSDP 是"不引入新生态就能拿到 ZeRO-3"的默认选择。
解决的痛点和 ZeRO-3 一样——消除数据并行的内存冗余——但要无缝融进原生 PyTorch。机制可以一句话概括:平时分片、用时聚合、用完释放。
权威参考是 PyTorch FSDP 论文(Zhao et al. 2023, arXiv 2304.11277),里面讲清了它怎么和 PyTorch 核心组件协同设计。auto_wrap_policy 决定切分粒度:unit 切太细 → all-gather 次数多、通信碎;切太粗 → 单次聚合的峰值显存高。粒度本身就是显存与通信的连续权衡。
# PyTorch 原生 FSDP,无需第三方库 import torch, functools from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy # 按 Transformer block 自动切分单元——粒度决定通信/显存权衡 policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}) model = FSDP( model, auto_wrap_policy=policy, device_id=torch.cuda.current_device()) # 之后的训练循环和单卡几乎一模一样——分片对业务代码透明 for batch in loader: loss = model(batch).loss loss.backward() # FSDP 在此处理 all-gather / reduce-scatter optimizer.step(); optimizer.zero_grad()
model.state_dict() 存档、或写一段假设"本卡有完整参数"的自定义逻辑,都可能出错——得用 FSDP 提供的 state_dict 接口先把分片聚合回来。分片是透明的,但不是隐形的:碰参数内部时它就会冒出来。accelerate / trl 背后的默认引擎。哪怕你只写一行 accelerate launch,理解"它在每层 all-gather 又 free"能让你看懂为什么峰值显存和你算的参数量对不上——多出来的就是当前层的完整副本。ZeRO 和 FSDP 都假设单层算得动——它们切的是"谁存参数",但算的时候还是把整层拼回单卡。可如果单层的矩阵本身就大到一张卡放不下呢?那就得把单个矩阵乘法本身切开,让多张卡协同算一个 matmul。类比:一行记录大到单机存不下,你只能把这一行的字段也拆到多机(垂直拆单个算子)。这就是张量并行(Tensor Parallel),Megatron-LM 是它的经典实现。
Megatron-LM(Shoeybi et al. 2019, arXiv 1909.08053)的精妙之处:怎么切 Transformer 的两大算子,让切分边界刚好落在非线性函数之外,从而一次前向只需一次同步。看 MLP 层 Y = GeLU(X·A)·B:
为什么是"列切 A、行切 B"而不是反过来?因为这个组合让唯一的非线性(GeLU)落在切分内部,整层前向只在末尾同步一次。换个切法就得在 GeLU 前后各同步一次——通信翻倍。自注意力则更自然:按 attention head 切,每个 head 本就独立计算,天生适合并行。
张量并行通信极重(每层都要 all-reduce),所以只在一台机器内的 NVLink 上做(带宽 ~900GB/s),绝不跨节点。最终的万卡训练是 3D 并行:
# 张量并行的核心思想,用纯 PyTorch 示意「列切 + 行切 + all-reduce」 import torch, torch.distributed as dist import torch.nn.functional as F def tp_mlp(X, A_shard, B_shard): # A 按列切:本卡只持有 A 的一部分列 → 算出部分隐藏维 H = F.gelu(X @ A_shard) # GeLU 逐元素,本地完成,无需同步 # B 按行切:本卡持有对应的行 → 算出部分和 Y_partial = H @ B_shard # 唯一一次同步:把各卡部分和加起来 = 完整输出 dist.all_reduce(Y_partial, op=dist.ReduceOp.SUM) return Y_partial # 整层前向只通信了一次 # 真实生产用 NVIDIA Megatron-LM / Megatron-Core,已封装好这套切分