The previous 9 issues covered model architecture + inference — how a trained LLM works. Today is about how it actually gets trained, in four gears: Cross-Entropy (the scoring function — how badly you're wrong), AdamW (the step engine — per-parameter adaptive step size), LR Schedule (the scheduler — when to go fast, when to go slow), and Gradient Clipping (the circuit breaker — so a single step doesn't blow up the model). These four together are why GPT/Llama/Claude can train stably across trillions of tokens without diverging.
Cross-entropy is the model's "SLA penalty fee" during training — how far the predicted probability distribution is from the truth, that's how much you owe. Backend analogy: an error-rate monitor, but not a coarse "right/wrong" binary — it's billed at fine granularity by how much confidence you assigned to the correct answer. 99% confidence on the right token = almost free; 0.1% confidence on the right token = heavy fine — and the gradient tells the model "next time, give this token more probability."
The core LLM training task is next-token prediction: given the previous N tokens, the model outputs a probability distribution over the vocabulary (~50K–150K entries); the truth is one specific token. You need a loss that turns "predicted distribution vs. single ground-truth point" into a differentiable scalar.
Formula: H(p, q) = -Σ_x p(x) log q(x). Each symbol: p is the true distribution (one-hot for LLMs: 1 at the correct token, 0 elsewhere); q is the model's predicted distribution (softmax output). Since p is one-hot, the sum collapses to one term: loss = -log q(correct token).
Intuition: model assigns probability 0.99 to the correct token → loss = -log(0.99) ≈ 0.01 (nearly free); probability 0.001 → loss = -log(0.001) ≈ 6.9 (heavy fine); probability 0 → loss = ∞ (theoretically punished to the sky — which is why softmax never outputs a true zero). The beauty of negative log: "more confident and correct" is rewarded gently, "more confident and wrong" is punished sharply — naturally aligned with the desired calibration property.
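The numbers above are easy to reproduce. A minimal sketch using only the standard library; `token_loss` is a hypothetical helper name, not part of any library:

```python
import math

def token_loss(prob_correct: float) -> float:
    """Cross-entropy for a one-hot target: -log of the probability on the correct token."""
    return -math.log(prob_correct)

# Natural log, so the unit is nats.
for p in (0.99, 0.5, 0.001, 1 / 50000):
    print(f"p={p:.5f}  loss={token_loss(p):.3f} nats")
```

The last entry is the uniform guess over a 50,000-token vocabulary: it is the loss an untrained model should start near.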
Why not MSE? On classification, the gradient of MSE ((predicted - true)²) with respect to the logits has to pass through the softmax derivative, which contributes factors of the form q(1 − q). Those factors shrink toward zero whenever the softmax saturates, including when the model is confidently wrong, so training is slow exactly when a large correction is needed. Cross-entropy paired with softmax has a much cleaner gradient: ∂loss/∂logit = (softmax_output - one_hot), literally just "prediction − truth", with no vanishing factor. This convenient pairing is a main reason LLMs train with cross-entropy.
    import torch
    import torch.nn.functional as F

    # Vocab size 50000, batch=2; the model outputs one logit vector per position
    logits = torch.randn(2, 50000)     # raw model output (pre-softmax)
    targets = torch.tensor([42, 7])    # ids of the true next tokens

    # PyTorch's cross_entropy fuses log_softmax + NLL, which is numerically stable
    loss = F.cross_entropy(logits, targets)
    print(loss.item())  # random logits land roughly near ln(50000) ≈ 10.8 (uniform-guess baseline)

    # Manual equivalent, to make the formula concrete:
    log_probs = F.log_softmax(logits, dim=-1)
    manual_loss = -log_probs[torch.arange(2), targets].mean()
    assert torch.isclose(loss, manual_loss)

    # In LLM training, reduction='mean' (the default) averages loss across all token positions.
    # A well-trained LLM on web text reaches loss ≈ 2.0–2.5 (perplexity = exp(loss) ≈ 7–12)
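The gradient identity ∂loss/∂logit = softmax_output − one_hot can be checked numerically without any framework. The following is a sketch under simple assumptions (a tiny 4-entry "vocabulary", hypothetical helper names), comparing the closed form against central finite differences:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """-log of the softmax probability assigned to the target index."""
    return -math.log(softmax(logits)[target])

def analytic_grad(logits, target):
    """Closed form: softmax output minus the one-hot target."""
    q = softmax(logits)
    return [q_i - (1.0 if i == target else 0.0) for i, q_i in enumerate(q)]

def numeric_grad(logits, target, h=1e-6):
    """Central finite differences, one logit at a time."""
    grads = []
    for i in range(len(logits)):
        up, down = list(logits), list(logits)
        up[i] += h
        down[i] -= h
        grads.append((cross_entropy(up, target) - cross_entropy(down, target)) / (2 * h))
    return grads

logits = [2.0, -1.0, 0.5, 0.0]   # toy 4-token vocabulary
target = 2
a = analytic_grad(logits, target)
n = numeric_grad(logits, target)
max_gap = max(abs(x - y) for x, y in zip(a, n))
print("analytic:", [round(x, 4) for x in a])
print("max gap vs finite differences:", max_gap)
```

The gradient entries sum to zero: whatever probability mass is pushed onto the correct token is taken from the others.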
AdamW is "a load balancer with per-key adaptive rate limits" — same idea as per-key adaptive rate limiting in distributed systems. If a parameter has had consistently large gradients lately ("hot key" — big impact on loss), AdamW automatically gives it a smaller step; if gradients have been small ("cold key"), it amplifies the step. Every parameter's learning rate is dynamically calibrated by its own gradient history.
Naive SGD uses one global lr. But across Transformer layers and parameters, gradient magnitudes vary by orders of magnitude — one lr either blows up large-gradient parameters or starves small-gradient ones. Adam (Kingma & Ba 2014) insight: use each parameter's own gradient statistics to normalize its step. AdamW (Loshchilov & Hutter 2017, ICLR 2019) fixed a long-overlooked Adam bug — decoupling weight decay from the gradient so regularization still works under adaptive lr. Today every modern LLM trains with AdamW.
Update rule (each step):
m_t = β₁·m_{t-1} + (1-β₁)·g_t ← 1st moment: EMA of recent gradient direction (momentum)
v_t = β₂·v_{t-1} + (1-β₂)·g_t² ← 2nd moment: EMA of recent squared-gradient magnitude
m̂_t = m_t / (1-β₁^t), v̂_t = v_t / (1-β₂^t) ← bias correction (both EMAs start at zero)
θ_t = θ_{t-1} - lr · m̂_t / (√v̂_t + ε) - lr · wd · θ_{t-1} ← update

Intuition: the numerator m̂ is the smoothed recent gradient direction, providing momentum (cutting through noise); the denominator √v̂ is the typical magnitude of recent gradients, normalizing each parameter to its own scale (large-grad params divided by a large number = small step; small-grad params divided by a small number = larger step). The last term -lr·wd·θ is decoupled weight decay: each step pulls parameters slightly toward zero to prevent weight blow-up. Placing this term outside the gradient is the entire change AdamW makes over Adam, and the AdamW paper reports that it improves generalization.
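The update rule above can be written out for a single scalar parameter. This is a minimal sketch, not a library implementation: `adamw_step` is a hypothetical name, and the default hyperparameters are simply the typical LLM values quoted in this issue.

```python
import math

def adamw_step(theta, grad, m, v, t, lr=3e-4, beta1=0.9, beta2=0.95,
               eps=1e-8, wd=0.1):
    """One AdamW update for one scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # 1st moment (EMA of gradient)
    v = beta2 * v + (1 - beta2) * grad * grad   # 2nd moment (EMA of squared gradient)
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)
    # Adaptive step plus decoupled weight decay, both computed from the old theta.
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps) - lr * wd * theta
    return theta, m, v

# Two parameters whose gradients differ by five orders of magnitude.
# Weight decay is switched off here to isolate the adaptive part of the step.
hot, _, _ = adamw_step(theta=1.0, grad=100.0, m=0.0, v=0.0, t=1, wd=0.0)
cold, _, _ = adamw_step(theta=1.0, grad=0.001, m=0.0, v=0.0, t=1, wd=0.0)
print(1.0 - hot, 1.0 - cold)   # both steps are close to lr
```

Despite the 100,000× gap in gradient size, both parameters move by about lr: that is the per-parameter normalization at work.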
Memory cost: each parameter stores m and v, so optimizer state = 2× model parameters. A 70B model in FP32 = 280 GB parameters, optimizer state = 560 GB, total 840 GB before counting gradients and activations. This is why training large models requires ZeRO/FSDP to shard optimizer state across many GPUs. Typical hyperparams: lr=1e-4 to 3e-4 (pretraining), 1e-5 to 1e-4 (fine-tuning), β₁=0.9, β₂=0.95 (LLMs prefer 0.95 over the default 0.999 for stability), wd=0.1, ε=1e-8.
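The memory arithmetic is simple enough to script. A small sketch, assuming every value is stored at the same precision and ignoring gradients and activations; the function name is hypothetical:

```python
def adamw_memory_gb(n_params: float, bytes_per_value: int = 4) -> dict:
    """Parameter and AdamW-state memory in GB (1 GB = 1e9 bytes)."""
    params_gb = n_params * bytes_per_value / 1e9
    optimizer_gb = 2 * params_gb          # one m and one v per parameter
    return {
        "params_gb": params_gb,
        "optimizer_state_gb": optimizer_gb,
        "total_gb": params_gb + optimizer_gb,
    }

mem = adamw_memory_gb(70e9)   # 70B parameters in FP32
print(mem)
```

Changing `n_params` or `bytes_per_value` shows how quickly the optimizer state outgrows a single accelerator.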
    import torch
    from torch.optim import AdamW

    model = ...  # your nn.Module

    # Standard convention: don't apply weight decay to biases or LayerNorm scales
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if p.dim() >= 2:
            decay.append(p)       # matrix weights: apply wd
        else:
            no_decay.append(p)    # biases / norm scales: don't

    optimizer = AdamW(
        [{"params": decay, "weight_decay": 0.1},
         {"params": no_decay, "weight_decay": 0.0}],
        lr=3e-4,
        betas=(0.9, 0.95),  # β₂=0.95 is standard for GPT/Llama-style training
        eps=1e-8,
    )

    # Training loop:
    for batch in data:
        optimizer.zero_grad()
        loss = model(**batch).loss
        loss.backward()     # compute gradients
        optimizer.step()    # apply AdamW update rule
Pitfall: the Adam(weight_decay=...) implementation is the broken version. Adam(..., weight_decay=0.1) from old PyTorch tutorials looks the same, but it folds the decay into the gradient as L2 regularization, where the adaptive denominator rescales it, and it gives meaningfully worse models. Always use torch.optim.AdamW.

The LR schedule is almost a one-to-one mapping to TCP slow-start + congestion avoidance: probe the network gently first (warmup), accelerate to steady-state throughput, and back off gracefully on congestion signals. Also like the gradual ramp-up of a database connection pool, or a new hire's probation → main contribution → handover. In one line: "when to go fast, when to go slow" is not a constant; it's a function of time.
Early in training, weights are randomly initialized and the loss landscape is extremely steep — starting at peak lr means the first gradient is huge → weights blow up → training diverges (loss → NaN). Late in training, you want delicate fine-tuning toward a local minimum, and a large lr makes updates oscillate past the optimum. Using one lr for the whole run loses on both ends, hence the need for a time-varying curve.
Mainstream recipe = Warmup + Cosine Decay (GPT-3, Chinchilla, Llama all use it):
Phase 1, linear warmup: lr(t) = lr_peak · t / T_warmup. This lets Adam's 2nd moment v accumulate enough samples before being used as a normalizer, avoiding early-step explosions caused by an under-sampled √v.

Phase 2, cosine decay: lr(t) = lr_min + ½(lr_peak − lr_min)(1 + cos(π · progress)), where progress = (t − T_warmup) / (T_total − T_warmup) runs from 0 to 1. The cosine curve is slow at both ends and fast in the middle: it lingers near the peak to search broadly, then slows in the final phase to avoid overshooting the minimum.

Why cosine? The SGDR paper (Loshchilov & Hutter 2016) found empirically that cosine annealing beat step decay. A later guess: cosine has zero derivative at both ends, giving a "gentle start" and a "gentle finish." The community also uses linear decay and trapezoidal schedules; gaps are typically <1%, but cosine became the default by first-mover momentum. The point is "must have warmup + must have decay," not "cosine vs linear": missing either end breaks LLM training.
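The two phases combine into one small function. A minimal sketch in plain Python; `lr_at` is a hypothetical name, and the defaults (peak 3e-4, floor at 10% of peak, 2,000 warmup steps out of 100,000) are illustrative assumptions rather than values from any particular paper:

```python
import math

def lr_at(step, peak_lr=3e-4, min_lr=3e-5, warmup_steps=2_000, total_steps=100_000):
    """Linear warmup to peak_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # progress goes from 0 (end of warmup) to 1 (end of training)
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    progress = min(progress, 1.0)   # hold at min_lr if training runs past total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

for s in (0, 1_000, 2_000, 51_000, 100_000):
    print(f"step={s:>7}  lr={lr_at(s):.2e}")
```

Plotting `lr_at` over all steps gives the familiar shape: a short ramp, a long plateau-like start of the cosine, and a gentle landing at the floor.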
    from transformers import get_cosine_schedule_with_warmup
    from torch.optim import AdamW

    # model and loader are assumed to be defined elsewhere
    optimizer = AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95))
    total_steps = 100_000
    warmup_steps = int(0.02 * total_steps)  # 2% warmup

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
        # decays to 0 with the default num_cycles=0.5 (half a cosine wave);
        # to keep a 10% floor, write a custom schedule (e.g. with LambdaLR)
    )

    for step, batch in enumerate(loader):
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # ← call every step; lr follows the curve
        optimizer.zero_grad()
        if step % 100 == 0:
            print(f"step={step} lr={scheduler.get_last_lr()[0]:.2e} loss={loss.item():.3f}")

    # You'll see lr ramp from ~0 to 3e-4, then decay; loss drops surprisingly fast during warmup
Gradient clipping = the training process's circuit breaker / API rate limiter. After computing gradients each step, check if the total length exceeds a threshold — if so, scale everything proportionally back to the threshold; direction unchanged. Same structure as HA systems' "intercept individual requests that exceed timeout to prevent overloading downstream." Pascanu et al. 2013 proposed it for RNNs to prevent exploding gradients; today every LLM training run defaults to clip_norm = 1.0.
LLM training spans hundreds of thousands of steps. The loss curve descends smoothly almost the whole time, but occasionally there's a loss spike — some batch happens to contain a rare combination (special characters, corrupted text, an out-of-distribution pattern), causing activations in some layer to balloon → backprop produces gradients dozens or hundreds of times the usual magnitude → one update shoves weights far away → subsequent training never recovers → training diverges, days of compute wasted. Clipping is the circuit breaker for this disaster.
Algorithm (most common: global norm clipping):
1. After backprop, collect the gradient of every parameter: g₁, g₂, ..., gₙ
2. Compute the global norm ‖g‖ = √(Σ‖gᵢ‖²) (concatenate all parameter gradients into a giant vector, take its length)
3. Compare against the threshold c (typically 1.0). If ‖g‖ > c, scale every gᵢ by c / ‖g‖; otherwise leave alone
4. Call optimizer.step()

Key invariant: g' = g · min(1, c/‖g‖). Direction is completely preserved (all components scaled by the same factor); only the over-long step is reined back to a safe length.
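The steps above fit in a few lines of plain Python. This is an illustrative sketch with a hypothetical function name, operating on nested lists instead of tensors; it follows the invariant g' = g · min(1, c/‖g‖) and is not a copy of any library's implementation:

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """grads: one list of floats per parameter. Returns (clipped grads, pre-clip norm)."""
    # Global norm: treat all gradients as one long vector.
    total_norm = math.sqrt(sum(g * g for param_grad in grads for g in param_grad))
    # Same scale factor for every component, so direction is preserved.
    scale = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
    clipped = [[g * scale for g in param_grad] for param_grad in grads]
    return clipped, total_norm

grads = [[3.0, 0.0], [0.0, 4.0]]           # global norm is 5
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
small, small_norm = clip_by_global_norm([[0.3, 0.4]], max_norm=1.0)  # norm 0.5, untouched
print(norm, clipped)
```

Returning the pre-clip norm mirrors the common practice of logging it as a health signal.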
Diagnostic signal: log the pre-clip grad_norm during training. A healthy distribution is a narrow peak (0.1-0.5) with occasional spikes (10-100) showing clip saved you. If grad_norm consistently sits at or above the threshold, so that almost every step is clipped, lr is too high or the data has issues: investigate. If grad_norm explodes to 1e6+ and loss is still NaN after clipping, it's likely fp16 numerical overflow: switch to bf16 or enable loss scaling. GPT-3, Chinchilla, Llama all use clip_norm = 1.0, which has become a no-discussion default.
    import torch
    from torch.nn.utils import clip_grad_norm_

    CLIP_VALUE = 1.0  # GPT-3/Llama default; rarely needs tuning

    # model, loader, optimizer and scheduler are assumed to be set up as in the earlier snippets
    for step, batch in enumerate(loader):
        loss = model(**batch).loss
        loss.backward()  # compute gradients

        # Key line: in-place clip of all parameter gradients, returns pre-clip norm
        grad_norm = clip_grad_norm_(
            model.parameters(),
            max_norm=CLIP_VALUE,
        )

        optimizer.step()  # update with clipped gradients
        scheduler.step()
        optimizer.zero_grad()

        # Always monitor grad_norm distribution: most important training health signal
        if step % 10 == 0:
            print(f"step={step} grad_norm={grad_norm.item():.3f} loss={loss.item():.3f}")

    # healthy: 0.1-0.5 with occasional spikes > 10 saved by clip
    # alarm: persistently at or above 1.0 (lr too high); sudden NaN/Inf (numerical overflow)
H(p, q) = -Σ p(x) log q(x) has a precise meaning in information theory: it is the average code length per symbol (in bits when the log is base 2) if you use arithmetic coding with model distribution q on data drawn from true distribution p. Minimizing cross-entropy = making the model the optimal compressor of real data. GPT-class training reaches per-token loss of ~1.5-2.5 nats (≈ 2.2-3.6 bits/token). Raw UTF-8 text costs 8 bits/byte and a token covers roughly 4 bytes of English text (an approximation that varies by tokenizer), so this works out to under 1 bit/byte: a well-trained LLM compresses human text to roughly 10% of its original size, substantially better than gzip (typically around a third on English text).

Deeper: to predict the next token down to this low a loss, the model must genuinely "understand" context: syntax, world knowledge, reasoning chains, character intent. So "understanding is compression" is not metaphor: compress → must predict → must understand structure. Marcus Hutter (inventor of AIXI; not to be confused with Frank Hutter, co-author of the AdamW paper) has long argued "AI evaluation = compression ratio," and his Hutter Prize still rewards the best text-compression algorithms. BigCat, your backend intuition for Parquet/dictionary encoding/protobuf varint solves the same problem: find structure in data, encode in fewer bits. LLMs just made "find structure" general-purpose, and that's their essential difference from traditional compressors.
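The unit conversions behind these figures are one-liners. A small sketch with hypothetical helper names, converting a per-token loss in nats into bits per token and into perplexity:

```python
import math

def nats_to_bits(loss_nats: float) -> float:
    """Convert a cross-entropy loss from nats (natural log) to bits (log base 2)."""
    return loss_nats / math.log(2)

def perplexity(loss_nats: float) -> float:
    """Perplexity is exp(loss) when the loss is measured in nats."""
    return math.exp(loss_nats)

for loss in (1.5, 2.0, 2.5):
    print(f"loss={loss} nats  ->  {nats_to_bits(loss):.2f} bits/token, "
          f"perplexity {perplexity(loss):.1f}")
```

Dividing bits per token by the raw bits a token occupies gives a compression ratio, but that last step depends on the tokenizer's average bytes per token, so treat any single percentage as approximate.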