Distributed Training
systems · parallelism
One-line analogy
A 70B model won't fit on one GPU — just like a table too big for one machine. Your instinct is sharding. But training has three orthogonal axes to split along: split the data (like read replicas), split the model's layers (like a microservice pipeline), or split the big matrix inside a single layer (like splitting one huge JOIN across machines). Understanding distributed training is really about how these three axes combine.
What it solves + how it works
The pain point is VRAM, not compute. A 70B model in fp16 is 140GB of parameters alone; add gradients and Adam optimizer state and each parameter costs ~16 bytes during training (the ZeRO card breaks this number down) — over 1TB total. An 80GB H100 can't hold it. The three parallelisms each solve one piece:
- Data Parallel (DP): every GPU holds a full copy of the model, only the training batch is split. Each GPU runs forward/backward independently, then all-reduce averages the gradients. Analogy: multiple read replicas each handling a batch, reconciling periodically. Requires the full model, plus its gradients and optimizer state, to fit on one GPU;
- Pipeline Parallel (PP): split the model by layers into stages across GPUs, like an assembly line / CPU pipeline. The cost is the bubble: stages sit idle waiting on each other (stage 2 waits for stage 1's activations on the way forward, stage 1 waits for stage 2's gradients on the way back); you shrink it by chopping the batch into micro-batches so different stages work on different micro-batches at the same time;
- Tensor Parallel (TP): split the big matrix multiply inside a single layer across GPUs (detailed in card 4). It is the most communication-heavy of the three, so it is kept within one machine, over its high-speed NVLink.
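For the pipeline bubble, a toy model makes the micro-batch effect concrete. The sketch below assumes a GPipe-style forward schedule (one of several; real frameworks also use schedules such as 1F1B) in which every stage takes exactly one time step per micro-batch. With p stages and m micro-batches the forward pass takes m + p - 1 steps and each stage is busy for m of them, so the idle share works out to (p - 1)/(m + p - 1).

```python
from typing import List, Optional


def forward_schedule(p: int, m: int) -> List[List[Optional[int]]]:
    """grid[t][s] = micro-batch that stage s processes at time step t, or None if idle.

    Toy GPipe-style forward pass: micro-batch j reaches stage s at step j + s,
    and every stage needs exactly one step per micro-batch.
    """
    steps = m + p - 1
    grid: List[List[Optional[int]]] = [[None] * p for _ in range(steps)]
    for j in range(m):
        for s in range(p):
            grid[j + s][s] = j
    return grid


def bubble_fraction(p: int, m: int) -> float:
    """Share of (stage, step) slots that sit idle in the toy schedule."""
    grid = forward_schedule(p, m)
    idle = sum(cell is None for row in grid for cell in row)
    return idle / (len(grid) * p)


for m in (1, 4, 16, 64):
    print(f"4 stages, {m:>2} micro-batches: {bubble_fraction(4, m):.0%} idle")
```

It prints 75%, 43%, 16% and 4% idle: more micro-batches shrink the bubble. The backward pass adds a mirror-image bubble that this forward-only toy does not count.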
Three orthogonal splitting axes
Data parallel: split the batch, replicate the model → all-reduce gradients
Pipeline parallel: split layers (vertical) → GPU A: layers 1-8 │ GPU B: 9-16 …
Tensor parallel: split a single layer's matrix (horizontal) → one matmul across GPUs
Modern thousand-GPU training = all three combined, called 3D parallelism
Code example
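# Simplest data parallel: PyTorch DDP. Launch one process per GPU with torchrun
import os
import torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
dist.init_process_group("nccl")  # NCCL = NVIDIA's GPU collective comm library
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun: this process's GPU on its node
torch.cuda.set_device(local_rank)
model = MyTransformer().to(local_rank)  # MyTransformer, dataset: your own (placeholders)
model = DDP(model, device_ids=[local_rank])  # auto all-reduce of grads on backward
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
sampler = DistributedSampler(dataset)  # each rank gets a different data slice
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
for epoch in range(3):
    sampler.set_epoch(epoch)  # reshuffle differently each epoch
    for batch in loader:
        loss = model(batch).loss
        loss.backward()  # ← gradients averaged across GPUs here
        optimizer.step(); optimizer.zero_grad()
dist.destroy_process_group()
# Launch: torchrun --nproc_per_node=8 train.py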
Common pitfall + practical scenario
"More GPUs = linear speedup": wrong. Adding GPUs adds communication overhead (gradient sync, parameter gathering), and that overhead does not shrink as you scale: with ring all-reduce each GPU still sends about 2(N−1)/N times the gradient size every step (approaching 2×), while the latency of each collective grows with GPU count. 8 GPUs often hit 7x, but 512 GPUs might only reach ~380x (~74% efficiency). That's why the Megatron paper reports "scaling efficiency" rather than raw multipliers: communication is always fighting you for bandwidth.
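A toy, single-process simulation of ring all-reduce (a textbook version of one algorithm NCCL can use, not NCCL's actual code) shows where the per-GPU traffic comes from. A reduce-scatter phase leaves each of the n workers owning one fully summed chunk, then an all-gather phase passes the finished chunks around the ring. Each worker sends 2(n - 1) chunks of S/n elements, so its traffic approaches 2S instead of shrinking as n grows.

```python
from typing import List, Tuple


def ring_allreduce(grads: List[List[float]]) -> Tuple[List[List[float]], List[int]]:
    """Simulate ring all-reduce among n workers; worker w only ever sends to (w + 1) % n.

    Returns (final buffers, elements sent by each worker). Toy version: the gradient
    length must split evenly into n chunks.
    """
    n, size = len(grads), len(grads[0])
    assert size % n == 0, "gradient length must divide evenly into n chunks"
    c = size // n
    buf = [[g[k * c:(k + 1) * c] for k in range(n)] for g in grads]  # buf[w][k] = chunk k
    sent = [0] * n

    # Phase 1, reduce-scatter: at step t worker w sends chunk (w - t) % n to its neighbor,
    # which adds it in. After n - 1 steps worker w holds the full sum of chunk (w + 1) % n.
    for t in range(n - 1):
        msgs = [(w, (w - t) % n, list(buf[w][(w - t) % n])) for w in range(n)]
        for w, k, data in msgs:  # messages were snapshotted, so sends are "simultaneous"
            dst = (w + 1) % n
            buf[dst][k] = [a + b for a, b in zip(buf[dst][k], data)]
            sent[w] += len(data)

    # Phase 2, all-gather: at step t worker w forwards finished chunk (w + 1 - t) % n,
    # and the neighbor overwrites its stale copy.
    for t in range(n - 1):
        msgs = [(w, (w + 1 - t) % n, list(buf[w][(w + 1 - t) % n])) for w in range(n)]
        for w, k, data in msgs:
            buf[(w + 1) % n][k] = data
            sent[w] += len(data)

    return [[x for chunk in b for x in chunk] for b in buf], sent


grads = [[float(w + 1)] * 8 for w in range(4)]  # 4 workers, 8-element gradients
results, sent = ring_allreduce(grads)
print(results[0])  # → [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
print(sent)        # → [12, 12, 12, 12]
```

The reduce-scatter phase on its own is the primitive that hands each GPU only its own summed gradient shard, which is what ZeRO-2 and FSDP rely on.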
📌 Super-individual scenario: you probably won't train a 70B from scratch, but you will fine-tune. The first-principles test for "do I need multiple GPUs" is exactly this card's logic — estimate the combined VRAM of «params + gradients + optimizer state»; over one GPU means you must shard. Treat it as capacity planning, the same mindset as planning database shards.
Takeaway + question
💡 The bottleneck in distributed training is VRAM and communication, not compute. The three parallelisms are three orthogonal axes — choose by "what doesn't fit."
🤔 In distributed storage you know the "replication vs sharding" trade-off — how exactly does it map onto "data parallel vs ZeRO sharding"?
Zero Redundancy Optimizer (ZeRO)
VRAM · sharding
One-line analogy
Pure data parallel has a huge waste: every GPU stores a full copy of the parameters, gradients, and optimizer state — N redundant copies, like "replicate the full dataset to every node" in distributed storage. ZeRO's insight: these states are idle most of the time, so why keep them full on every GPU? Shard them across GPUs so each holds only 1/N, and all-gather on demand when needed. It's fundamentally trading communication for VRAM — the same trade-off as using erasure coding instead of full replication.
What it solves + how it works
First, unpack the "16 bytes per parameter" — that's the key to ZeRO. In mixed-precision training each parameter stores:
Mixed-precision training: VRAM bill per parameter
fp16 parameter 2 bytes
fp16 gradient 2 bytes
fp32 param copy 4 bytes ┐
fp32 momentum m 4 bytes ├ Adam optimizer state
fp32 variance v 4 bytes ┘
≈ 16 bytes/param. 70B model ≈ 1.1 TB — the 12 bytes of optimizer state is the fattest part
ZeRO comes in three stages, progressively cutting redundancy — more savings, more communication:
- ZeRO-1: shard the optimizer state (those 12 bytes). Cut the fattest part first, almost no extra communication — best bang for the buck;
- ZeRO-2: also shard the gradients. After backward, reduce-scatter gradients to the owning GPU, keep only your slice locally;
- ZeRO-3: shard the parameters themselves too. When forward/backward reaches a layer, all-gather the full params, and discard them immediately after use. Per-GPU model-state memory drops to about 16/N bytes/param, so in theory it shrinks in proportion to 1/N as you add GPUs (activations come on top).
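The three stages fit in a back-of-the-envelope calculator. The per-GPU formulas follow the ZeRO paper's accounting for mixed-precision Adam (2 + 2 + 12 bytes per parameter across N data-parallel GPUs); they count model states only, so activations, temporary buffers, and fragmentation come on top.

```python
def zero_model_state_gb(params_billion: float, n_gpus: int, stage: int) -> float:
    """Per-GPU model-state memory in GB (1e9 bytes) for mixed-precision Adam under ZeRO.

    fp16 params: 2 B/param, fp16 grads: 2 B/param, fp32 optimizer state: 12 B/param.
    Activations and buffers are not included.
    """
    n = n_gpus
    bytes_per_param = {
        0: 2 + 2 + 12,        # plain data parallel: everything replicated
        1: 2 + 2 + 12 / n,    # ZeRO-1: shard optimizer state
        2: 2 + (2 + 12) / n,  # ZeRO-2: also shard gradients
        3: (2 + 2 + 12) / n,  # ZeRO-3: also shard parameters
    }[stage]
    # (params_billion * 1e9 params) * (bytes/param) / (1e9 bytes/GB)
    return params_billion * bytes_per_param


for stage in range(4):
    print(f"70B on 64 GPUs, ZeRO-{stage}: {zero_model_state_gb(70, 64, stage):7.1f} GB/GPU")
```

For 70B on 64 GPUs this gives about 1120, 293, 155 and 17.5 GB per GPU for stages 0 through 3, so only stage 3 brings model states under an 80GB H100. The ZeRO paper's analysis also puts stages 1 and 2 at the same communication volume as plain data parallel, and stage 3 at about 1.5×.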
This is Microsoft DeepSpeed's ZeRO (Rajbhandari et al. 2019, arXiv 1910.02054). Its elegance: under the data-parallel programming model (each GPU still sees the illusion of a "full model"), it quietly eliminates the memory redundancy.
Code example
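# DeepSpeed enables ZeRO via one JSON-style config; business code barely changes
import deepspeed
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "bf16": {"enabled": True},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},  # DeepSpeed builds the optimizer
    "zero_optimization": {
        "stage": 3,  # shard optimizer state + grads + params
        "offload_optimizer": {"device": "cpu"},  # can even offload optimizer state to CPU RAM
        "stage3_prefetch_bucket_size": 50_000_000  # prefetch params to hide comm latency
    }
}
# model / loader: your own nn.Module and DataLoader (placeholders)
engine, _, _, _ = deepspeed.initialize(
    model=model, config=ds_config, model_parameters=model.parameters())
for batch in loader:
    loss = engine(batch).loss  # forward: ZeRO-3 all-gathers each layer's params on demand
    engine.backward(loss)  # backward: re-gathers params, reduce-scatters grads
    engine.step()  # optimizer update on the sharded (here CPU-offloaded) state
# Launch: deepspeed train.py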
Common pitfall + practical scenario
"ZeRO-3 saves the most VRAM, so just always turn on stage 3": wrong. At stage 3, every forward and backward must all-gather each layer's parameters (stages 1 and 2 never do), so communication grows to about 1.5× plain data parallel in the ZeRO paper's analysis, and each layer waits on the network unless prefetching hides it. Worth it if your bottleneck is VRAM; if the model actually fits and your bottleneck is speed, stage 3 slows you down. Rule: use the lowest stage that fits. Same engineering judgment as "don't compress hot data just to save a little disk."
📌 Super-individual scenario: fine-tuning a 13B model on 4 consumer GPUs (e.g. 24GB ×4), ZeRO-3 + optimizer CPU offload is often the line between "barely runs" and "OOM crash." ZeRO-2 is not enough here: it still keeps a full fp16 copy of the weights on every GPU, and 13B × 2 bytes = 26GB already exceeds 24GB. Understand the 16-byte bill and you can compute beforehand which stage to enable, instead of trial-and-error against the wall.
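Running the 16-byte bill for this exact case makes the stage choice concrete. A rough sketch using the ZeRO paper's per-parameter accounting, assuming optimizer offload moves all 12 bytes/param of Adam state to CPU RAM while fp16 weights and gradients stay on the GPU; activations are excluded:

```python
def gpu_model_state_gb(params_billion: float, n_gpus: int, stage: int,
                       offload_optimizer: bool = False) -> float:
    """Per-GPU model-state GB: fp16 params 2 B, fp16 grads 2 B, fp32 Adam state 12 B.

    Assumption: optimizer offload moves all 12 B/param to CPU RAM. Activations excluded.
    """
    params = 2 / n_gpus if stage >= 3 else 2   # parameters sharded only at stage 3
    grads = 2 / n_gpus if stage >= 2 else 2    # gradients sharded from stage 2
    optim = 0 if offload_optimizer else (12 / n_gpus if stage >= 1 else 12)
    return params_billion * (params + grads + optim)  # billions of params x bytes = GB


budget_gb = 24
for stage in (2, 3):
    need = gpu_model_state_gb(13, 4, stage, offload_optimizer=True)
    verdict = "fits before activations" if need < budget_gb else "OOM"
    print(f"13B on 4 x {budget_gb}GB, ZeRO-{stage} + CPU offload: {need:.1f} GB/GPU -> {verdict}")
```

Stage 2 comes out at about 32.5 GB/GPU because the full fp16 weights (26 GB) stay on every GPU; stage 3 drops to about 13 GB, leaving the rest of the 24GB for activations.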
Takeaway + question
💡 ZeRO = replace data parallel's redundant copies with sharding + on-demand gathering, trading communication bandwidth for VRAM.
🤔 ZeRO-3's "all-gather only when needed, discard after use" — is it the same trick as virtual memory's demand paging? What's the "page-fault cost" for each?
Fully Sharded Data Parallel (FSDP)
PyTorch · native
One-line analogy
FSDP is PyTorch officially baking ZeRO-3 into the framework itself. The mechanism is nearly equivalent; the difference is like a third-party library vs a built-in: DeepSpeed/ZeRO is bolted on from outside (like requests), while FSDP is co-designed with PyTorch's autograd, dispatcher, and memory allocator. For most pure-PyTorch projects, FSDP is the default way to get ZeRO-3 without adopting a new ecosystem.
What it solves + how it works
It solves the same pain as ZeRO-3 — eliminate data-parallel memory redundancy — but must blend seamlessly into native PyTorch. The mechanism in one line: shard at rest, gather in use, free after use.
- Wrap into units: split the model into FSDP units by layer/module. At rest, each unit's params keep only a 1/N shard on this GPU;
- Forward into a unit: all-gather to reconstruct that layer's full params → compute the layer → immediately free the non-local parts, keeping only the 1/N shard;
- Backward likewise: all-gather once more to compute gradients, then reduce-scatter the gradient shards back to their owning GPUs;
- Comm-compute overlap: while computing layer i, prefetch layer i+1's params in the background, hiding all-gather latency behind compute — the key to running fast.
The authoritative reference is PyTorch FSDP (Zhao et al. 2023, arXiv 2304.11277), which details the co-design with PyTorch internals. auto_wrap_policy sets granularity: units too fine → many all-gathers, fragmented comm; too coarse → high peak VRAM per gather. Granularity itself is a continuous VRAM-vs-communication trade-off.
Lifecycle of one FSDP layer (repeats per layer)
sharded: GPU has only 1/N of params
↓ all-gather (reconstruct full layer)
full: run forward / backward
↓ free non-local shards + reduce-scatter grads
back to sharded, VRAM released
Peak parameter VRAM ≈ whole model sharded (1/N) + the full params of the unit being computed (plus the next unit when prefetching), not the whole model
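The granularity trade-off and the peak-VRAM rule of thumb can be put into numbers. A rough sketch under stated assumptions: unit sizes are given in parameters, parameters are bf16 (2 bytes), the sharded copy stays resident while the full copy is gathered into a separate buffer, and prefetching keeps at most the next unit gathered as well. It counts parameter memory only (no gradients, optimizer state, or activations) and forward-pass all-gathers only; backward repeats them.

```python
from typing import List, Tuple


def fsdp_forward_params(unit_params: List[int], n_gpus: int, prefetch: bool = True,
                        bytes_per_param: int = 2) -> Tuple[int, float]:
    """Toy FSDP estimate for one forward pass.

    Returns (all-gathers issued, peak parameter bytes on one GPU). Every unit keeps a
    1/n shard at rest; while unit i runs, its full copy is gathered, and with prefetch
    unit i + 1 is gathered too. Full copies are freed after use.
    """
    sharded = sum(unit_params) * bytes_per_param / n_gpus  # resident on every GPU
    peak = sharded
    for i, p in enumerate(unit_params):
        in_flight = p
        if prefetch and i + 1 < len(unit_params):
            in_flight += unit_params[i + 1]
        peak = max(peak, sharded + in_flight * bytes_per_param)
    return len(unit_params), peak


for label, units in (("32 fine units", [200_000_000] * 32),
                     ("4 coarse units", [1_600_000_000] * 4)):
    gathers, peak = fsdp_forward_params(units, n_gpus=8)
    print(f"{label}: {gathers} all-gathers, peak ≈ {peak / 1e9:.1f} GB of params per GPU")
```

For the same hypothetical 6.4B-parameter model on 8 GPUs, 32 units of 0.2B give 32 all-gathers and about 2.4 GB peak, while 4 units of 1.6B give 4 all-gathers and about 8.0 GB peak: fewer, larger transfers in exchange for a much higher peak.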
Code example
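# PyTorch-native FSDP, no third-party library. Launch with torchrun, one process per GPU
import os, functools
import torch, torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # LOCAL_RANK is set by torchrun
# Auto-split units by Transformer block; granularity sets the comm/VRAM trade-off
# (model, TransformerBlock, loader: your own module, block class, DataLoader; placeholders)
policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock})
model = FSDP(
    model,
    auto_wrap_policy=policy,
    device_id=torch.cuda.current_device())
# Create the optimizer AFTER wrapping, so it sees FSDP's sharded parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# The training loop afterward looks almost like single-GPU; sharding is transparent
for batch in loader:
    loss = model(batch).loss
    loss.backward()  # FSDP re-gathers params and reduce-scatters grads here
    optimizer.step(); optimizer.zero_grad()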
Common pitfall + practical scenario
"FSDP is a drop-in upgrade over DDP, just swap the wrapper" — be careful. FSDP changes the model's state semantics: parameters are sharded, so calling model.state_dict() directly, or any custom logic assuming "this GPU has the full params," can break — you must use FSDP's state_dict APIs to gather shards first. Sharding is transparent, but not invisible: it surfaces the moment you touch parameter internals.
📌 Super-individual scenario: fine-tuning large models in 2026, FSDP is one of the two standard engines (alongside DeepSpeed) behind HuggingFace accelerate / trl. Even if you only write one line of accelerate launch, understanding "it all-gathers and frees per layer" lets you see why peak VRAM doesn't match your parameter-count estimate: part of the extra is the full copy of the layer currently gathered.
Takeaway + question
💡 FSDP = native-PyTorch ZeRO-3: shard at rest, gather in use, free after use, with comm hidden behind compute.
🤔 FSDP makes "full parameters" a short-lived, use-and-discard object — is that the same pattern as your caching code's "materialize on demand, invalidate immediately"?
Megatron-LM Tensor Parallelism
tensor parallel · 3D parallel
One-line analogy
ZeRO and FSDP both assume a single layer is computable — they shard "who stores params," but reassemble the full layer on one GPU to compute. What if a single layer's matrix is itself too big for one GPU? Then you must split the single matrix multiply itself, letting multiple GPUs cooperate on one matmul. Analogy: one record too big for a single machine — you have to split that record's fields across machines (vertically splitting a single operator). This is tensor parallelism, and Megatron-LM is its classic implementation.
What it solves + how it works
Megatron-LM (Shoeybi et al. 2019, arXiv 1909.08053) is clever about how it splits the Transformer's two big operators: the split boundaries are placed so the nonlinearity always runs on complete local data, and one forward pass through the block needs only one sync. Take the MLP layer Y = GeLU(X·A)·B:
- Split the first matrix A by columns into [A₁, A₂]: X·A₁ and X·A₂ compute independently on two GPUs. Because GeLU is element-wise, GeLU(X·A₁) doesn't need A₂'s result — the nonlinearity completes locally on each GPU, no intermediate sync;
- Split the second matrix B by rows into [B₁; B₂]: each GPU computes a partial sum, then one all-reduce sums them into the full Y.
Why "column-split A, row-split B" and not the reverse? Because this combo keeps the only nonlinearity (GeLU) inside the split, so the whole layer's forward syncs just once at the end. Row-splitting A instead would leave each GPU with only a partial sum of X·A, and since GeLU(a + b) ≠ GeLU(a) + GeLU(b), you would need an all-reduce before GeLU plus another sync at the end, doubling communication. Self-attention is even more natural: split by attention head, since each head computes independently; it is born for parallelism.
Megatron MLP tensor parallel: Y = GeLU(X·A)·B
input X (both GPUs hold full X)
↓ GPU1: X·A₁ │ GPU2: X·A₂ (A split by columns)
GPU1: GeLU(X·A₁) GPU2: GeLU(X·A₂) ← nonlinearity local, no sync
↓ GPU1: (…)·B₁ │ GPU2: (…)·B₂ (B split by rows)
partial sum₁ + partial sum₂ → all-reduce → full Y
The whole layer syncs only once at the end (1 all-reduce per forward)
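A tiny single-process check of the algebra, in plain Python with made-up toy matrices (2 tokens, model dim 2, hidden dim 4) and two simulated GPUs; adding the two partial outputs stands in for the all-reduce. It also runs the reverse split for contrast: row-splitting A leaves each GPU with a partial sum of X·A, and pushing partial sums through GeLU before adding them gives the wrong answer.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gelu(m: Matrix) -> Matrix:
    # exact GeLU, element-wise: 0.5 * x * (1 + erf(x / sqrt(2)))
    return [[0.5 * v * (1 + math.erf(v / math.sqrt(2))) for v in row] for row in m]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def cols(m: Matrix, lo: int, hi: int) -> Matrix:
    return [row[lo:hi] for row in m]


def close(p: Matrix, q: Matrix, tol: float = 1e-9) -> bool:
    return all(abs(u - v) < tol for rp, rq in zip(p, q) for u, v in zip(rp, rq))


X = [[1.0, -2.0], [0.5, 3.0]]                           # 2 tokens, model dim 2 (replicated)
A = [[1.0, -1.0, 0.5, 2.0], [0.0, 1.0, -1.5, 0.5]]      # 2 x 4, hidden dim 4
B = [[1.0, 0.0], [2.0, 1.0], [-1.0, 1.0], [0.5, -2.0]]  # 4 x 2

reference = matmul(gelu(matmul(X, A)), B)  # what one GPU would compute

# Megatron split on 2 simulated GPUs: each gets 2 columns of A and the matching rows of B.
partial = [matmul(gelu(matmul(X, cols(A, lo, lo + 2))), B[lo:lo + 2]) for lo in (0, 2)]
megatron = add(partial[0], partial[1])  # stands in for the single all-reduce

# Reverse split: A by rows (so X by columns); each GPU now holds a partial sum of X·A.
partial_rev = [matmul(gelu(matmul(cols(X, k, k + 1), A[k:k + 1])), B) for k in (0, 1)]
reverse = add(partial_rev[0], partial_rev[1])

print(close(reference, megatron))  # → True
print(close(reference, reverse))   # → False: GeLU(a + b) != GeLU(a) + GeLU(b)
```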
Tensor parallel communication is heavy (an all-reduce every layer), so it's kept within a single machine's NVLink (~900GB/s bandwidth) and almost never stretched across nodes. The final thousand-GPU training is 3D parallelism:
3D parallelism: three orthogonal axes combined
Tensor parallel (TP) → intra-machine NVLink, split a layer's matrix, heaviest comm
Pipeline parallel (PP) → across nodes, split layers, medium comm
Data parallel (DP) → across nodes/clusters, split data, sparsest comm
Principle: place the heaviest-comm parallelism at the highest-bandwidth tier — a placement optimization by "bandwidth tier"
Code example
# The core idea of tensor parallel: column-split + row-split + all-reduce, in plain PyTorch
import torch, torch.distributed as dist
import torch.nn.functional as F
def tp_mlp(X, A_shard, B_shard):
    # X: (tokens, d), replicated on every TP rank
    # A split by columns: this GPU holds A_shard of shape (d, h/tp) → its slice of the hidden dims
    H = F.gelu(X @ A_shard)  # GeLU is element-wise, local, no sync needed
    # B split by rows: this GPU holds the matching rows B_shard of shape (h/tp, d) → a partial sum
    Y_partial = H @ B_shard
    # The one sync: add the partial sums across GPUs (in place) = full output
    dist.all_reduce(Y_partial, op=dist.ReduceOp.SUM)
    return Y_partial  # the whole layer's forward communicated once
# Forward only: dist.all_reduce is not autograd-aware, so training also needs an all-reduce
# of X's gradient in backward. Production uses NVIDIA Megatron-LM / Megatron-Core,
# which packages this splitting for both passes
Common pitfall + practical scenario
"Tensor parallel scales out to hundreds of GPUs like data parallel" — wrong. TP all-reduces every layer, so bandwidth is the ceiling: cross-node InfiniBand (~50GB/s) is an order of magnitude slower than intra-machine NVLink (~900GB/s), and TP collapses once it spans machines. So TP is almost always 8 (within one 8-GPU box); larger scale stacks PP and DP on top. The essence of a parallelism strategy is matching communication volume to bandwidth tiers.
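A back-of-the-envelope sketch of what the bandwidth gap costs, using the card's round numbers (~900GB/s NVLink, ~50GB/s InfiniBand) and the ring all-reduce cost of 2(n - 1)/n times the message size per GPU. It is idealized: it ignores latency and assumes those bandwidths are fully achievable. The message shape (one 4,096-token sequence, hidden size 8,192, bf16) is a hypothetical illustration, not any specific model's.

```python
def ring_allreduce_ms(message_bytes: float, n_gpus: int, bandwidth_gb_s: float) -> float:
    """Idealized ring all-reduce time in milliseconds (bandwidth term only, no latency)."""
    per_gpu_traffic = 2 * (n_gpus - 1) / n_gpus * message_bytes
    return per_gpu_traffic / (bandwidth_gb_s * 1e9) * 1e3


tokens, hidden, bytes_per_elem = 4096, 8192, 2  # hypothetical layer-output shape, bf16
message = tokens * hidden * bytes_per_elem       # one TP all-reduce of that output
for link, bw in (("NVLink ~900GB/s", 900), ("InfiniBand ~50GB/s", 50)):
    print(f"TP=8 over {link}: {ring_allreduce_ms(message, 8, bw):.2f} ms per all-reduce")
```

That is about 0.13 ms vs 2.35 ms, an 18× gap, and TP pays it on every all-reduce in every layer, forward and backward.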
📌 Super-individual scenario: you rarely configure 3D parallelism yourself, but when you read "TP=8, PP=4, DP=16, 512 GPUs total" in a model's tech report, you can immediately back out its cluster topology and where the bottleneck is — key for judging how reproducible / cheap-to-fine-tune an open model is.
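Backing out the topology is simple arithmetic. A sketch assuming 8 GPUs per node and a rank layout where TP is the fastest-varying index (so each TP group is consecutive ranks inside one box), then DP, then PP; that ordering is an illustrative assumption, since frameworks let you configure it.

```python
def topology(tp: int, pp: int, dp: int, gpus_per_node: int = 8) -> dict:
    """World size, node count, and whether every TP group fits inside one node."""
    world = tp * pp * dp
    return {
        "world_size": world,
        "nodes": world // gpus_per_node,
        # with TP fastest-varying, TP groups are consecutive ranks; they stay
        # inside one node exactly when tp divides gpus_per_node
        "tp_within_node": gpus_per_node % tp == 0,
    }


def coords(rank: int, tp: int, pp: int, dp: int) -> tuple:
    """Global rank -> (pp_index, dp_index, tp_index); TP varies fastest, then DP, then PP."""
    assert 0 <= rank < tp * pp * dp, "rank out of range"
    return rank // (tp * dp), (rank // tp) % dp, rank % tp


print(topology(tp=8, pp=4, dp=16))
# → {'world_size': 512, 'nodes': 64, 'tp_within_node': True}
print([coords(r, tp=8, pp=4, dp=16) for r in (0, 7, 8, 511)])
# → [(0, 0, 0), (0, 0, 7), (0, 1, 0), (3, 15, 7)]
```

Under this layout ranks 0-7 form one TP group on one box, DP replicas differ in the middle index and all-reduce gradients across nodes, and the outer index picks the pipeline stage, so each stage spans 128 ranks (16 nodes) and stage-to-stage activations cross InfiniBand.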
Takeaway + question
💡 Tensor parallel splits a single operator; Megatron's trick is making the split boundary avoid the nonlinearity, compressing each layer's sync to once. 3D parallel = placing the three parallelisms by bandwidth tier.
🤔 "Put the heaviest-comm parallelism at the highest-bandwidth tier" — is this the same first principle as your storage/cache tiering (hot data in RAM, cold data on disk)?