An activation function is the network's non-linear switch. Without it, stacking 100 layers is exactly equivalent to one layer. It's like chaining 100 middlewares that only do linear transforms (pure forwarding, pure scaling): the algebra collapses them all into a single matrix. ReLU = max(0, x), the same as SQL's GREATEST(0, x): zero out negatives, pass positives. It is the most primitive gate there is.
The core pain point is linear collapsibility: a layer computes y = Wx + b (matmul + bias). Two stacked linear layers give W₂(W₁x + b₁) + b₂ = (W₂W₁)x + (W₂b₁ + b₂), which is still a single linear (affine) map. No matter how deep the stack, its expressive power equals one layer, fitting only lines/hyperplanes. Activations inject non-linearity between layers so the network can approximate essentially any continuous function (the intuition behind the universal approximation theorem).
Early nets used Sigmoid / Tanh (S-shaped curves), but they saturate at the extremes: when the input is very large or very small, the gradient approaches 0. Even at its best (x = 0), Sigmoid's derivative is only 0.25, and because backprop multiplies these factors layer by layer, gradients decay exponentially with depth, so deep nets can't learn (vanishing gradients). ReLU's revolution: in the positive region the gradient is a constant 1, so the activation itself never shrinks it, which made deep nets trainable. The cost is Dying ReLU, its most famous failure mode: if a neuron's input is negative for essentially every example, its output is stuck at 0, its gradient is 0, its weights stop updating, and it stays dead.
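To see saturation numerically, here is a minimal sketch using PyTorch autograd. It compares the local derivatives of Sigmoid and ReLU at a few inputs, then multiplies Sigmoid's best-case derivative (0.25) across layers. It deliberately ignores the weight matrices, which also scale the backward signal, so it only isolates the activation's own contribution.

```python
import torch

# A few pre-activation values, from very negative to very positive
x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0], requires_grad=True)

# Local derivative of Sigmoid: sigma(x) * (1 - sigma(x)), at most 0.25 (at x = 0)
torch.sigmoid(x).sum().backward()
sigmoid_grad = x.grad.clone()

# Local derivative of ReLU: 1 for x > 0, 0 for x < 0 (PyTorch uses 0 at exactly x = 0)
x.grad = None
torch.relu(x).sum().backward()
relu_grad = x.grad.clone()

print(sigmoid_grad)  # tiny at |x| = 10, peaks at 0.25 for x = 0
print(relu_grad)     # tensor([0., 0., 0., 1., 1.])

# Best case for Sigmoid: every layer sits at x = 0 and contributes a factor of 0.25
for depth in (5, 10, 20):
    print(depth, 0.25 ** depth)  # shrinks geometrically with depth
```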
```python
import torch
import torch.nn as nn

# Intuition: two stacked "no-activation" linears == one linear
x = torch.randn(4, 8)
lin = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 8))
# Two Linears in series are still one affine map (Wx + b);
# a single equivalent nn.Linear(8, 8) could replace them

relu_net = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),  # ← inject non-linearity; now depth "counts"
    nn.Linear(16, 8),
)
print(relu_net(x).shape)  # torch.Size([4, 8])

# ReLU itself is this simple:
def relu(x):
    return torch.clamp(x, min=0)  # max(0, x)
```
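The claim that two stacked Linears collapse into one can be checked directly. A minimal sketch, with biases included: W₂(W₁x + b₁) + b₂ = (W₂W₁)x + (W₂b₁ + b₂), so a single nn.Linear(8, 8) loaded with those merged weights reproduces the two-layer output up to float rounding. The names `W`, `b`, and `merged` are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 8)
lin = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 8))
l1, l2 = lin[0], lin[1]

with torch.no_grad():
    # Fold the two affine maps into one: W = W2 @ W1, b = W2 @ b1 + b2
    W = l2.weight @ l1.weight          # (8, 16) @ (16, 8) -> (8, 8)
    b = l2.weight @ l1.bias + l2.bias  # (8,)
    merged = nn.Linear(8, 8)
    merged.weight.copy_(W)
    merged.bias.copy_(b)

    same = torch.allclose(lin(x), merged(x), atol=1e-5)

print(same)  # True: without activations, extra depth adds no expressive power
```

Putting an nn.ReLU() between the two layers breaks this folding, because max(0, ·) cannot be pushed through a matrix product.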
ReLU is a hard if-else router: negatives are cut to zero. GELU is a probabilistic soft router: it smoothly scales each input by how likely it is to be worth keeping. SwiGLU goes further with a learnable dynamic gate, like a smart load balancer that decides how much traffic to let through based on request content; the gate opening is set jointly by the data and the learned parameters.
ReLU's hard cutoff has two flaws: ① it's non-differentiable at 0 (a kink); ② Dying ReLU. GELU (Hendrycks & Gimpel, 2016) replaces the kink with a smooth curve. It's defined as GELU(x) = x · Φ(x), where Φ(x) is the standard normal cumulative distribution function (CDF). Intuitively, Φ(x) is the probability that a standard normal random variable is less than x. The larger x is, the closer this probability is to 1 (pass almost everything); the more negative x is, the closer to 0 (block almost everything), with a smooth transition in between. So GELU weights each input by its value instead of hard-gating by sign like ReLU. Small negative inputs still produce a nonzero output and gradient, which makes dead units much less likely, and the gradient is smooth everywhere.
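The definition GELU(x) = x · Φ(x) can be written out by hand, because the standard normal CDF has the closed form Φ(x) = ½(1 + erf(x/√2)). A minimal sketch that compares this hand-written `gelu_exact` (a name chosen for illustration) with `F.gelu`, and with PyTorch's tanh-based approximation (`approximate="tanh"`), which some models use:

```python
import math
import torch
import torch.nn.functional as F

def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    # Phi(x) = standard normal CDF = 0.5 * (1 + erf(x / sqrt(2)))
    phi = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * phi

x = torch.linspace(-4, 4, steps=81)
ours = gelu_exact(x)
ref = F.gelu(x)                         # default: exact erf form
approx = F.gelu(x, approximate="tanh")  # tanh approximation

print(torch.allclose(ours, ref, atol=1e-6))  # True: matches the built-in
print((approx - ref).abs().max().item())     # small approximation error

# Negative inputs are damped, not zeroed: at x = -1, GELU = -Phi(1)
print(gelu_exact(torch.tensor(-1.0)).item())  # about -0.1587, whereas relu(-1) = 0
```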
SwiGLU (from Shazeer 2020, "GLU Variants Improve Transformer") isn't a single activation but a redesign of the Transformer's feed-forward layer (FFN). A plain FFN computes W₂ · act(W₁x). GLU-style structures add gating: two linear projections, one serving as the "content" and one passed through an activation as the "gate", are multiplied element-wise, (W_v·x) ⊙ Swish(W_g·x), and W₂ then projects the result back down. Here ⊙ is element-wise multiplication and Swish(z) = z · σ(z). Because the gate opening is decided dynamically by the input, this adds a layer of data-dependent modulation beyond a fixed activation. It is the standard FFN in LLaMA, PaLM, and other modern LLMs.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# GELU: built into PyTorch, one line
x = torch.randn(2, 8)
y = F.gelu(x)  # x * Φ(x), a smooth ReLU

# A SwiGLU-style FFN (same idea as LLaMA)
class SwiGLU_FFN(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)  # gate
        self.w_val = nn.Linear(dim, hidden, bias=False)   # content
        self.w_out = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        # Swish(gate) decides how much value passes per dimension
        # F.silu == Swish == x*sigmoid(x); * is the element-wise gating multiply
        return self.w_out(F.silu(self.w_gate(x)) * self.w_val(x))
```
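One practical consequence of the extra gate projection is that a SwiGLU FFN has three weight matrices instead of two. LLaMA compensates by shrinking the hidden width to about 2/3 of the usual 4·dim (it also rounds that width up to a hardware-friendly multiple, which is omitted here). A minimal sketch of the parameter arithmetic; it redefines `SwiGLU_FFN` so the block runs on its own, and dim = 512 is an arbitrary illustrative value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU_FFN(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_val = nn.Linear(dim, hidden, bias=False)
        self.w_out = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w_out(F.silu(self.w_gate(x)) * self.w_val(x))

def n_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

dim = 512
# Plain FFN: W2 · act(W1 x) with the conventional 4x expansion (two matrices)
plain = nn.Sequential(
    nn.Linear(dim, 4 * dim, bias=False),
    nn.GELU(),
    nn.Linear(4 * dim, dim, bias=False),
)
# SwiGLU: three matrices, so hidden ~ (2/3) * 4 * dim keeps the count roughly equal
hidden = int(2 * 4 * dim / 3)  # 1365
swiglu = SwiGLU_FFN(dim, hidden)

print(n_params(plain))   # 2097152 = 2 * 512 * 2048
print(n_params(swiglu))  # 2096640 = 3 * 512 * 1365

x = torch.randn(2, 10, dim)  # [batch, seq_len, dim]
print(swiglu(x).shape)       # torch.Size([2, 10, 512])
```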
Normalization = a standardization (z-score) applied before data enters the next layer, pulling activations back to a uniform "mean 0, variance 1" scale, like normalizing features before a query so one feature with a huge magnitude doesn't drown out the rest. The difference lies in which dimension the statistics are computed over: BatchNorm computes them across the samples in the batch (like a rate limiter keyed off a shared global counter, which gets jittery at small batch sizes); LayerNorm computes them per sample (like per-request local normalization that ignores its neighbors).
Pain point: in deep nets, each layer's input distribution drifts as earlier layers update, forcing later layers to chase a moving target, which makes training slow and unstable. Normalization forces each layer's activations back to a stable distribution. The core formula is y = γ · (x − μ) / √(σ² + ε) + β. Term by term: μ is the mean and σ² the variance (together they standardize the input to 0 mean, 1 variance); ε is a tiny constant that avoids division by zero; γ and β are a learnable scale and shift. That last step is crucial: normalize first, then let the model learn back whatever scale it needs, so normalization doesn't sacrifice expressive power.
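The formula maps onto a handful of tensor operations. A minimal sketch that applies y = γ · (x − μ) / √(σ² + ε) + β over the last dimension and checks it against `nn.LayerNorm`. One detail worth knowing: PyTorch uses the biased variance (divide by D, not D − 1). γ and β are set to non-default values (an arbitrary choice) to show they are applied after standardizing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, eps = 16, 1e-5
x = torch.randn(4, D) * 3 + 7  # deliberately off-center and large-scale

ln = nn.LayerNorm(D, eps=eps)
with torch.no_grad():
    ln.weight.fill_(2.0)  # gamma
    ln.bias.fill_(0.5)    # beta

def layer_norm_manual(x, gamma, beta, eps):
    mu = x.mean(dim=-1, keepdim=True)                  # mean
    var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)   # biased variance sigma^2
    x_hat = (x - mu) / torch.sqrt(var + eps)           # 0 mean, 1 variance
    return gamma * x_hat + beta                        # learnable scale and shift

with torch.no_grad():
    y_ref = ln(x)
    y_ours = layer_norm_manual(x, ln.weight, ln.bias, eps)

print(torch.allclose(y_ours, y_ref, atol=1e-5))  # True
print(y_ref.mean(dim=-1))  # about 0.5 per row: beta sets the new mean
```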
Why do Transformers / LLMs use LayerNorm instead of BatchNorm? Three hard reasons: ① sequence lengths vary, so batches are padded and samples aren't aligned; statistics pooled across the batch mix real tokens with padding and are ill-defined. ② BatchNorm is statistically noisy at small batch sizes and behaves differently in training and inference (training uses the current batch's statistics, inference uses running averages). ③ At inference the batch may be a single request, so there are no meaningful batch statistics at all and BatchNorm must rely entirely on the running averages collected during training. LayerNorm computes per sample, so training and inference are identical and fully decoupled from batch size and sequence length, a natural fit for NLP. BatchNorm remains the workhorse in CNN-based vision.
```python
import torch
import torch.nn as nn

x = torch.randn(32, 512)  # [batch=32, features=512]
bn = nn.BatchNorm1d(512)  # across the batch: per-feature statistics
ln = nn.LayerNorm(512)    # across features: per-sample statistics

# Verify which dimension the statistics cover
# BatchNorm: each "column" (feature) pulled to 0 mean → mean(dim=0) ≈ 0
print(bn(x).mean(dim=0).abs().max())  # ≈ 0
# LayerNorm: each "row" (sample) pulled to 0 mean → mean(dim=1) ≈ 0
print(ln(x).mean(dim=1).abs().max())  # ≈ 0

# At inference with batch=1: BatchNorm needs eval() to use its running statistics;
# LayerNorm doesn't care
bn.eval()
print(bn(torch.randn(1, 512)).shape)  # torch.Size([1, 512])
```
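Reasons ② and ③ above can be made concrete. In training mode, BatchNorm's output for a given sample depends on which other samples share its batch, while LayerNorm's does not; and with a single sample, training-mode BatchNorm has no statistics to compute. A minimal sketch (tensor names are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
sample = torch.randn(1, 512)
batch_a = torch.cat([sample, torch.randn(7, 512)])      # sample + 7 neighbors
batch_b = torch.cat([sample, torch.randn(7, 512) * 5])  # same sample, different neighbors

bn = nn.BatchNorm1d(512)  # training mode by default
ln = nn.LayerNorm(512)

with torch.no_grad():
    # Row 0 is the identical input in both batches
    bn_diff = (bn(batch_a)[0] - bn(batch_b)[0]).abs().max().item()
    ln_diff = (ln(batch_a)[0] - ln(batch_b)[0]).abs().max().item()

print(bn_diff)  # clearly non-zero: output depends on the neighbors
print(ln_diff)  # ~0: per-sample statistics ignore the rest of the batch

# Training-mode BatchNorm cannot compute statistics from a single sample
try:
    bn(sample)
except ValueError as err:
    print("BatchNorm1d, train mode, batch=1:", err)
```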
RMSNorm is a stripped-down LayerNorm: its authors argued that LayerNorm's "subtract the mean" (re-centering) step is dispensable, so they cut it and kept only "divide by magnitude" (re-scaling). It's like removing a field from an RPC protocol that nobody actually uses: you save serialization overhead without losing anything that matters in practice. The LLaMA family and many other modern LLMs have switched to it.
LayerNorm does two things: re-centering (subtract the mean μ to center the data) and re-scaling (divide by the standard deviation to unify magnitude). Zhang & Sennrich (2019) hypothesize that only re-scaling matters and that re-centering is redundant. RMSNorm therefore divides only by the Root Mean Square (RMS): y = γ · x / RMS(x), where RMS(x) = √( (1/D)·Σxᵢ² ) and D is the vector's dimension. Intuitively, RMS measures how large the vector is overall (average the squared components, then take the square root). It never computes or subtracts the mean of x, and it has no β bias.
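A small worked example of the formula with γ = 1: for x = [3, 4], RMS(x) = √((9 + 16)/2) = √12.5 ≈ 3.536, so y ≈ [0.849, 1.131]. The sketch below computes that and also checks one consequence of dropping re-centering: when the input already has zero mean, RMS equals the (biased) standard deviation, so RMSNorm and LayerNorm (without γ, β) agree; for off-center input they differ. `rms_norm` is a hypothetical helper, and ε is set to 0 here only to make the two exactly comparable.

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    # y = x / RMS(x), RMS(x) = sqrt(mean(x_i^2)) over the last dim (gamma = 1)
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms

print(rms_norm(torch.tensor([[3.0, 4.0]])))  # approximately [[0.8485, 1.1314]]

# Zero-mean input: RMS == std, so RMSNorm matches LayerNorm
torch.manual_seed(0)
z = torch.randn(4, 64)
z = z - z.mean(dim=-1, keepdim=True)
print(torch.allclose(rms_norm(z), F.layer_norm(z, (64,), eps=0.0), atol=1e-5))  # True

# Off-center input: LayerNorm removes the mean, RMSNorm keeps it
w = z + 3.0
print(torch.allclose(rms_norm(w), F.layer_norm(w, (64,), eps=0.0), atol=1e-5))  # False
```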
Payoff: less computation and faster runs (the paper reports running-time reductions of roughly 7% to 64% across the models it tested), with quality comparable to, and sometimes slightly better than, LayerNorm. In LLMs with hundreds of billions of parameters, where normalization runs in every block for every token, this per-call saving adds up to a meaningful drop in total training and inference cost, which is the practical reason for its wide adoption.
A related design choice is where normalization goes. The original Transformer used Post-LN (norm after the residual add), which trains unstably when deep and depends on careful learning-rate warmup. Modern designs have largely moved to Pre-LN (norm applied to the sub-layer input, before the residual add), which gives steadier gradients and greatly reduces, and in some reported setups removes, the need for that warmup. This is one of the key changes behind why today's Transformers are easier to train.
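The difference comes down to where the norm sits relative to the residual add. A minimal sketch of the two orderings around a generic sub-layer (a small MLP stands in for attention or the FFN; the class names are hypothetical): Post-LN computes norm(x + f(x)), Pre-LN computes x + f(norm(x)). The final check illustrates one commonly cited reason Pre-LN is steadier: its residual path carries x through untouched, so if the sub-layer contributes nothing, the block is an exact identity.

```python
import torch
import torch.nn as nn

class PostLNBlock(nn.Module):
    # Original Transformer ordering: add the residual first, then normalize
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.f = sublayer

    def forward(self, x):
        return self.norm(x + self.f(x))

class PreLNBlock(nn.Module):
    # Modern ordering: normalize the sub-layer input, keep the residual path clean
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.f = sublayer

    def forward(self, x):
        return x + self.f(self.norm(x))

dim = 32
mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
x = torch.randn(2, 5, dim) * 4 + 1  # [batch, seq_len, dim], off-scale on purpose

post, pre = PostLNBlock(dim, mlp), PreLNBlock(dim, mlp)
print(post(x).shape, pre(x).shape)  # both torch.Size([2, 5, 32])

def zero_sublayer(t):
    # Stand-in for a sub-layer that contributes nothing
    return torch.zeros_like(t)

print(torch.equal(PreLNBlock(dim, zero_sublayer)(x), x))      # True: identity path
print(torch.allclose(PostLNBlock(dim, zero_sublayer)(x), x))  # False: still rescaled
```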
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # only γ, no β
        self.eps = eps

    def forward(self, x):
        # 1 / RMS = 1 / sqrt(mean(x^2) + eps); note: no mean subtraction
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight  # divide by magnitude, then scale

x = torch.randn(2, 512)
print(RMSNorm(512)(x).shape)  # torch.Size([2, 512])
# torch 2.4+ also has a built-in nn.RMSNorm(512); prefer the official one in production
```
Putting it together, a network repeats Norm → Linear → activation over and over. Another angle: the activation decides how the signal transforms, and normalization decides at what scale the signal enters the next transform; one governs shape, the other governs magnitude. Grasp this and you see why swapping the activation can force a normalization re-tune: the two are coupled through the same goal of keeping gradients healthy.
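The repeated Norm → Linear → activation pattern can be sketched as a small stack. This is a minimal illustration, not a production architecture: depth, width, and the input scale are arbitrary assumptions, and RMSNorm is written by hand (γ = 1) so the block doesn't depend on a recent PyTorch version. It records the RMS of the signal entering each Linear, which the norm pins to about 1 however badly scaled the raw input is: the norm handles magnitude, the activation handles shape.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def rms_normalize(x, eps=1e-6):
    # RMSNorm without a learnable gamma (gamma = 1)
    return x * x.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()

torch.manual_seed(0)
dim, depth = 64, 8
linears = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

h = torch.randn(16, dim) * 100.0  # deliberately huge input scale
input_rms = []
with torch.no_grad():
    for lin in linears:
        h_norm = rms_normalize(h)  # magnitude: back to RMS ~1
        input_rms.append(h_norm.pow(2).mean(dim=-1).sqrt().mean().item())
        h = F.gelu(lin(h_norm))    # shape: non-linear transform

print([round(r, 3) for r in input_rms])  # every entry ~1.0
```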