The standard Transformer has two congenital ailments: attention is O(n²) (double the sequence length, quadruple the compute), and computation is dense (every token activates every parameter). Today's four concepts (MoE, SSM, Mamba, long context) are the field's responses. MoE attacks density, growing parameters while holding per-token compute fixed. SSMs and Mamba attack O(n²), borrowing RNN ideas to recover linear complexity. Long-context techniques ask how to stretch n without changing the architecture. All of them answer the same question: when we want models that are bigger and read longer, how do we keep the hardware bill under control?
MoE is database sharding plus query routing for a neural net. A dense model is like a single-node database that scans the whole table on every request. MoE splits one giant feed-forward layer (FFN) into N "expert shards" and adds a router that forwards each token only to its 1–2 most relevant shards. Total capacity grows linearly with the number of shards, while per-query compute stays roughly fixed. That is sharding's core payoff: capacity scales out without each request doing more work. The price is storage, since every shard must still be kept in memory.
Pain point: the most direct route to a stronger model is more parameters, but in a dense model parameter count and compute are welded together. Double the parameters and you double the FLOPs per token, so inference cost grows in lockstep with model size. We want huge capacity that is cheap to run, and a dense architecture cannot deliver both.
MoE's mechanism is conditional computation: replace a Transformer layer's FFN with N parallel small FFNs (experts) plus a lightweight router. For each token, the router computes a score for every expert and activates only the top-k (usually k=1 or 2). The other experts do no work for that token. So total parameters scale with N (every expert is stored), while per-token compute scales only with k (only the activated experts run).
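To make "parameters scale with N, compute scales with k" concrete, here is a small arithmetic sketch. The dimensions are assumptions taken from Mixtral 8×7B's public config (d_model 4096, expert hidden size 14336, 32 layers, 8 experts, top-2), and each expert is assumed to be a bias-free SwiGLU FFN with three weight matrices. Only the expert FFNs are counted; the router and shared layers are left out.

```python
def ffn_params(d_model: int, d_hidden: int) -> int:
    """Parameters in one SwiGLU FFN: gate, up, and down projections (no biases)."""
    return 3 * d_model * d_hidden

def moe_ffn_params(d_model, d_hidden, n_layers, n_experts, k):
    per_expert = ffn_params(d_model, d_hidden)
    total = n_layers * n_experts * per_expert  # every expert must be stored
    active = n_layers * k * per_expert         # only the top-k experts run per token
    return total, active

# Assumed Mixtral-like dimensions
total, active = moe_ffn_params(d_model=4096, d_hidden=14336, n_layers=32, n_experts=8, k=2)
print(f"expert params stored:         {total / 1e9:.1f}B")   # 45.1B
print(f"expert params used per token: {active / 1e9:.1f}B")  # 11.3B
```

Adding the shared attention and embedding weights on top of these expert counts gives Mixtral's headline figures of roughly 47B total and 13B active parameters per token.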
The key design difficulty is load balancing: if the router keeps sending tokens to a few "star experts," the rest receive almost no training signal yet still occupy memory. Shazeer et al.'s foundational 2017 paper added an auxiliary balancing loss to penalize this skew. Switch Transformer (Fedus et al., 2021) went further and simplified routing to k=1, showing that single-expert routing preserves quality while cutting routing overhead. Mistral AI's Mixtral 8×7B (released December 2023) turned the design into an open-source flagship: 8 experts per layer, 2 activated per token.
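One widely used form of such an auxiliary loss comes from Switch Transformer: α · N · Σᵢ fᵢ · Pᵢ, where fᵢ is the fraction of tokens whose top-1 choice is expert i and Pᵢ is the mean router probability given to expert i. This is a sketch of that formulation under those assumptions, not Shazeer et al.'s original 2017 losses. The function name is hypothetical, and α = 1.0 is used in the demo only so the numbers are easy to read; the Switch paper uses a small α, around 0.01.

```python
import torch
import torch.nn.functional as F

def switch_balance_loss(router_logits: torch.Tensor, alpha: float = 0.01) -> torch.Tensor:
    """Switch-Transformer-style load-balancing loss for top-1 routing.

    router_logits: [tokens, n_experts] raw router scores.
    f_i: fraction of tokens dispatched to expert i (argmax, not differentiable)
    P_i: mean router probability for expert i (differentiable, carries the gradient)
    """
    n_experts = router_logits.shape[-1]
    probs = F.softmax(router_logits, dim=-1)                         # [tokens, N]
    f = F.one_hot(probs.argmax(dim=-1), n_experts).float().mean(0)   # dispatch fractions
    P = probs.mean(dim=0)                                            # average probabilities
    return alpha * n_experts * torch.sum(f * P)

N, T = 8, 32
# Balanced router: token t strongly prefers expert t % N
balanced = 10.0 * F.one_hot(torch.arange(T) % N, N).float()
# Collapsed router: every token strongly prefers expert 0
collapsed = 10.0 * F.one_hot(torch.zeros(T, dtype=torch.long), N).float()

print(switch_balance_loss(balanced, alpha=1.0).item())   # ≈ 1.0: load spread evenly
print(switch_balance_loss(collapsed, alpha=1.0).item())  # ≈ 8.0: everything on one expert
```

Because fᵢ and Pᵢ are both large for the same experts when routing collapses, the product rewards the router for spreading probability mass, and the gradient flows through Pᵢ.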
```python
import torch
import torch.nn.functional as F

# Minimal MoE layer: N experts + top-k routing (for intuition, not production)
class MoE(torch.nn.Module):
    def __init__(self, d=512, n_exp=8, k=2, d_hidden=1024):
        super().__init__()
        self.k = k
        self.router = torch.nn.Linear(d, n_exp)  # router: token → one score per expert
        # each expert is a small FFN: Linear → GELU → Linear
        self.experts = torch.nn.ModuleList([
            torch.nn.Sequential(torch.nn.Linear(d, d_hidden), torch.nn.GELU(),
                                torch.nn.Linear(d_hidden, d))
            for _ in range(n_exp)
        ])

    def forward(self, x):  # x: [tokens, d]
        scores = self.router(x)               # score every token against all experts
        w, idx = scores.topk(self.k, dim=-1)  # keep only the top-k experts per token
        w = F.softmax(w, dim=-1)              # normalize weights over the chosen k
        out = torch.zeros_like(x)
        for j in range(self.k):               # weighted sum of the activated experts
            for e, expert in enumerate(self.experts):
                m = idx[:, j] == e            # tokens whose j-th choice is expert e
                if m.any():
                    out[m] += w[m, j:j+1] * expert(x[m])
        return out  # N=8 experts of capacity, but each token runs only k=2 of them

y = MoE()(torch.randn(10, 512))
print(y.shape)  # torch.Size([10, 512])
```
An SSM processes a sequence like a streaming processor with a fixed-size memory. A Transformer loads the whole log into memory and compares every pair of entries (O(n²)). An SSM is more like a Kafka consumer: it reads one record at a time, folds it into a fixed-size "state variable," and then discards the raw record. Memory does not grow with the length of the stream, which is exactly streaming's core advantage over batch processing.
Pain point: attention must store every past token's Key/Value (the KV cache grows linearly with length) and compute a score for every pair of tokens (O(n²)). At tens or hundreds of thousands of tokens it is both slow and memory-hungry. Can we get an RNN's linear time and constant memory while keeping the Transformer's parallel training?
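SSMs borrow classic state-space equations from control theory. At the core is a hidden state h that is updated once per input x_t and read out to produce the output y_t:

$$h_t = A\,h_{t-1} + B\,x_t, \qquad y_t = C\,h_t$$

(Strictly, the matrices in this recurrence are obtained by discretizing a continuous-time system $h'(t) = A\,h(t) + B\,x(t)$ with a step size Δ; that step size comes back in Mamba.)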
Intuition: A is the "forget/retain matrix" (how old memory decays), B is the "write gate" (how much of the current input is injected), and C is the "read gate" (what to extract from the state). This echoes LSTM gating, but SSMs have a special property. When A, B, C are input-independent (linear time-invariant, LTI), the whole recurrence can be unrolled into a convolution, so training runs fully in parallel like a CNN, while inference updates the state step by step at O(1) cost per token like an RNN. It gets the best of both worlds.
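The recurrence-to-convolution equivalence can be checked numerically. Unrolling h_t = A h_{t-1} + B x_t from h_{-1} = 0 gives y_t = Σ_{j=0}^{t} C A^j B x_{t-j}, i.e. a convolution of x with the kernel K_j = C A^j B. The sketch below assumes a small random stable A, one input and one output channel, and float64 so the two paths agree tightly. The direct sum is O(L²) and only for illustration; S4 computes the kernel far more efficiently and applies it with an FFT.

```python
import torch

torch.manual_seed(0)
N, L = 4, 16                 # state size and sequence length (arbitrary choices)
dt = torch.float64           # double precision so both modes match tightly
A = 0.9 * torch.eye(N, dtype=dt) + 0.02 * torch.randn(N, N, dtype=dt)  # assumed stable A
B = torch.randn(N, 1, dtype=dt)  # one input channel
C = torch.randn(1, N, dtype=dt)  # one output channel
x = torch.randn(L, dtype=dt)

# 1) Recurrent mode (inference): h_t = A h_{t-1} + B x_t, y_t = C h_t
h = torch.zeros(N, 1, dtype=dt)
y_rec = []
for t in range(L):
    h = A @ h + B * x[t]
    y_rec.append(C @ h)
y_rec = torch.cat(y_rec).squeeze(-1)  # [L]

# 2) Convolution mode (training): y_t = sum_{j=0}^{t} K_j x_{t-j}, with K_j = C A^j B
K = torch.stack([(C @ torch.linalg.matrix_power(A, j) @ B).squeeze() for j in range(L)])
y_conv = torch.stack([sum(K[j] * x[t - j] for j in range(t + 1)) for t in range(L)])

print(torch.allclose(y_rec, y_conv))  # True
```

The kernel K depends only on A, B, C, never on x. That is exactly why this trick requires time invariance.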
The difficulty: with a naive A matrix, the state explodes or vanishes numerically over long sequences, because an early input's contribution gets multiplied by ever higher powers of A (the same root cause as RNN gradient problems). Gu et al.'s 2021 S4 paper fixed this with a specially structured A (initialized from HiPPO theory and parameterized with a low-rank correction). For the first time, an SSM outperformed Transformer variants across the Long Range Arena long-sequence benchmark, while also generating dozens of times faster.
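A quick way to see the explode/vanish problem: the input seen at step 0 reaches the state at step L multiplied by A^L. The sketch assumes the simplest case, A = ρ·I, so that factor is just ρ^L, and shows how sensitive it is to ρ sitting slightly below, at, or above 1. It does not implement S4's HiPPO initialization.

```python
import torch

# The input written at step 0 is scaled by A^L by the time it reaches step L.
# With A = rho * I (an assumption for clarity), that factor is rho^L.
L = 1000
weights = {}
for rho in (0.99, 1.0, 1.01):
    A = rho * torch.eye(2, dtype=torch.float64)
    weights[rho] = torch.linalg.matrix_power(A, L)[0, 0].item()
    print(f"rho={rho}: weight of the step-0 input after {L} steps = {weights[rho]:.3g}")
# rho=0.99 gives about 4.3e-05 (vanishes); rho=1.01 gives about 2.1e+04 (explodes)
```

A 1% change in A's spectral radius moves the long-range weight by nine orders of magnitude. That is why a carefully structured A matters.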
```python
import torch

# The SSM "recurrent mode": constant-memory streaming (teaching version, not full S4)
def ssm_scan(x, A, B, C):
    # x: [seq_len, d]; A: [N, N], B: [N, d], C: [d_out, N] are learned state matrices
    h = torch.zeros(A.shape[0], dtype=x.dtype)  # hidden state: fixed size N, does not grow with seq_len
    ys = []
    for t in range(x.shape[0]):  # streaming, token by token
        h = A @ h + B @ x[t]     # decay the old state + write the new input
        ys.append(C @ h)         # read the current output from the state
    return torch.stack(ys)       # [seq_len, d_out]

# Key: whether seq_len is 1k or 1M, h's size is unchanged → O(1) state memory
# In training, this recurrence unrolls into a convolution → fully parallel (omitted)
y = ssm_scan(torch.randn(1000, 8), 0.9 * torch.eye(16), torch.randn(16, 8), torch.randn(8, 16))
print(y.shape)  # torch.Size([1000, 8])
```
S4's state matrices are a static config, like a hard-coded cache policy that applies the same keep/drop rule regardless of the data. Mamba upgrades this to a content-aware dynamic cache: it makes the write gate B, the read gate C, and the step size that controls forgetting all functions of the current input. The cache policy can now judge for itself ("this matters, keep it longer; that is noise, forget it fast"), upgrading from a fixed TTL to an adaptive one.
Pain point: an SSM with fixed A, B, C compresses history indiscriminately. Because the same update is applied to every token, the model cannot decide based on content what to remember and what to ignore. This is the missing capability of content-based reasoning. A task like "skip every whitespace or filler token and keep only the content words" is very hard for a fixed SSM, since the spacing between content words changes from input to input.
Mamba's (Gu & Dao, 2023) core innovation is the selection mechanism: the SSM's B, C, and discretization step Δ become functions computed from the input x_t. Every token can therefore control how much it writes into the state, what is read out, and how fast old memory decays. Intuitively, the model gains selective memory. A large Δ on key information resets the state and writes the new input strongly, while a small Δ on noise lets it flow past with the state almost unchanged.
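The selection mechanism can be sketched as a sequential scan in which Δ, B, and C are recomputed from each token. Several simplifications are assumptions here: the projection names `W_delta`, `W_B`, `W_C`, `b_delta` are hypothetical, Δ uses a full D×D projection (Mamba uses a low-rank one), A is a fixed negative diagonal initialized S4D-Real style (−1..−N per channel), Ā = exp(Δ·A), and the input term uses the simplified Δ·B·x form. It illustrates the idea and is not the official fused kernel.

```python
import torch
import torch.nn.functional as F

def selective_scan(x, A, W_delta, b_delta, W_B, W_C):
    """Sequential sketch of a Mamba-style selective SSM (not the official kernel).

    x:        [L, D] input with D channels
    A:        [D, N] negative real diagonal state matrix, one row per channel (fixed)
    W_delta:  [D, D], b_delta: [D]  -> step size delta_t = softplus(x_t @ W_delta + b_delta)
    W_B, W_C: [D, N]                -> write gate B_t = x_t @ W_B, read gate C_t = x_t @ W_C
    """
    L, D = x.shape
    h = torch.zeros(D, A.shape[1], dtype=x.dtype)     # state [D, N]: size independent of L
    ys = []
    for t in range(L):
        delta = F.softplus(x[t] @ W_delta + b_delta)  # [D], > 0, depends on x_t
        B_t = x[t] @ W_B                              # [N], depends on x_t
        C_t = x[t] @ W_C                              # [N], depends on x_t
        A_bar = torch.exp(delta[:, None] * A)         # [D, N] in (0, 1); large delta -> near 0 (reset)
        write = delta[:, None] * B_t[None, :] * x[t][:, None]  # [D, N] simplified input term
        h = A_bar * h + write                         # forgetting and writing are input-controlled
        ys.append(h @ C_t)                            # [D], read out with input-dependent C_t
    return torch.stack(ys)                            # [L, D]

torch.manual_seed(0)
L, D, N = 64, 4, 8
x = torch.randn(L, D)
A = -torch.arange(1, N + 1, dtype=torch.float32).repeat(D, 1)  # assumed S4D-Real-style init
W_delta, W_B, W_C = 0.1 * torch.randn(D, D), torch.randn(D, N), torch.randn(D, N)

y = selective_scan(x, A, W_delta, torch.zeros(D), W_B, W_C)
print(y.shape)  # torch.Size([64, 4])

# Force a tiny delta for every token: A_bar ≈ 1 and the write term ≈ 0, so inputs "flow past"
y_ignore = selective_scan(x, A, W_delta, torch.full((D,), -20.0), W_B, W_C)
print(y_ignore.abs().max().item())  # very small: the state barely moves from zero
```

Pushing Δ toward zero makes a token effectively invisible to the state, while a large Δ wipes the old state. This is the "keep it / forget it fast" knob from the cache analogy, now set per token.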
But here is the engineering paradox. Once the parameters depend on the input, the system is no longer linear time-invariant, so the unroll-into-a-convolution trick for parallel training stops working. Mamba's second contribution is a hardware-aware parallel scan: a prefix-sum-style parallel primitive combined with carefully designed GPU memory I/O. In the spirit of FlashAttention, the expanded state is kept in fast on-chip SRAM rather than written out to slower GPU memory, so the input-dependent recurrence still runs efficiently in parallel on a GPU. The paper reports about 5× higher inference throughput than a Transformer of similar size, linear scaling with sequence length, and performance that keeps improving on real data up to million-length sequences.
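Why is an input-dependent recurrence still parallelizable? With a diagonal state matrix, each update has the elementwise form h_t = a_t ⊙ h_{t-1} + b_t. Two consecutive steps compose into one step of the same form, (a₁, b₁) then (a₂, b₂) = (a₁a₂, a₂b₁ + b₂), and that composition is associative, so a prefix scan applies. The sketch below uses a simple Hillis–Steele scan for clarity, an assumption made for readability. Mamba's actual kernel is a fused, work-efficient CUDA implementation.

```python
import torch

def combine(a1, b1, a2, b2):
    # Compose "h -> a1*h + b1" followed by "h -> a2*h + b2" into a single step.
    return a1 * a2, a2 * b1 + b2

def parallel_linear_scan(a, b):
    """All h_t of h_t = a_t * h_{t-1} + b_t (with h_{-1} = 0), via a Hillis-Steele scan.

    Each while-iteration is one fully parallel tensor op over all positions;
    only ceil(log2 L) iterations are needed instead of L sequential steps.
    """
    a, b = a.clone(), b.clone()
    L, offset = a.shape[0], 1
    while offset < L:
        # Position t absorbs the composed prefix ending at t - offset (computed from the old a, b).
        new_a, new_b = combine(a[:-offset], b[:-offset], a[offset:], b[offset:])
        a = torch.cat([a[:offset], new_a])
        b = torch.cat([b[:offset], new_b])
        offset *= 2
    return b  # applying the composed step to h_{-1} = 0 leaves just its b

def sequential_linear_scan(a, b):
    h, out = torch.zeros_like(b[0]), []
    for t in range(a.shape[0]):
        h = a[t] * h + b[t]
        out.append(h)
    return torch.stack(out)

torch.manual_seed(0)
L, D = 1000, 16
a = torch.rand(L, D, dtype=torch.float64)   # per-step decay factors in [0, 1)
b = torch.randn(L, D, dtype=torch.float64)  # per-step inputs written into the state
print(torch.allclose(parallel_linear_scan(a, b), sequential_linear_scan(a, b)))  # True
```

Nothing in the scan requires a_t or b_t to be constant across time, which is precisely what the convolution trick needed and what a selective SSM gives up.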
Since 2024, a popular practical direction has been hybrid architectures, with AI21's Jamba as one example: interleave a few attention layers with many Mamba layers. Attention patches Mamba's weakness at precise recall, and Mamba supplies long-sequence efficiency.
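A hybrid stack is mostly a layer-schedule decision. The sketch below builds such a schedule as a plain list of layer types. The function name is hypothetical. The 1-in-8 ratio is an assumption modeled on Jamba's reported 1:7 attention-to-Mamba ratio, and putting the attention layer in the middle of each group is an arbitrary choice.

```python
def hybrid_layer_plan(n_layers: int, attn_every: int = 8) -> list[str]:
    """Layer types for a hybrid stack with one attention layer per `attn_every` layers.

    attn_every=8 mirrors a 1:7 attention-to-Mamba ratio (assumed); the attention
    layer sits in the middle of each group of `attn_every` layers (arbitrary choice).
    """
    return ["attention" if i % attn_every == attn_every // 2 else "mamba"
            for i in range(n_layers)]

plan = hybrid_layer_plan(32)
print(plan.count("attention"), plan.count("mamba"))  # 4 28
```

Only the 4 attention layers keep a KV cache that grows with sequence length. The 28 Mamba layers carry fixed-size states.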
```python
# Official implementation: pip install mamba-ssm (needs a CUDA GPU)
import torch
from mamba_ssm import Mamba

batch, seq_len, dim = 2, 4096, 512
x = torch.randn(batch, seq_len, dim).to("cuda")
model = Mamba(
    d_model=dim,  # input/output dimension
    d_state=16,   # SSM state size per channel: fixed, independent of seq_len
    d_conv=4,     # local convolution width (captures short-range patterns)
    expand=2,     # block expansion factor (inner dimension = expand * d_model)
).to("cuda")
y = model(x)      # output has the same shape as the input: [2, 4096, 512]
print(y.shape)
# B, C, and Δ are computed from x internally → selective memory
# However long the sequence, d_state stays 16, and the layer runs efficiently via a parallel scan
```

Stretching context from 4K to 1M is not a matter of flipping a config flag. It is more like horizontally scaling a system designed for small data. You cannot just change the max_length line, because you must simultaneously solve compute (attention's O(n²) blowup), memory (the KV cache grows linearly), and generalization (the model never saw positions this far out). Scaling a database from single-node to distributed is similar: indexing, caching, and consistency all have to be redone at once.
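Pain point: the original Transformer is trained at a fixed length (e.g. 2K or 4K tokens), and feeding it 100K tokens directly hits three barriers. The first is compute (attention cost grows as O(n²)), the second is memory (the KV cache grows linearly with length), and the third is position generalization (positions beyond the training length were never seen). Long context is a collection of techniques, each attacking one of these barriers.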
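The first two barriers are easy to quantify. The sketch assumes a hypothetical 7B-class configuration (32 layers, 32 KV heads of dimension 128, no grouped-query sharing, fp16 at 2 bytes per element) and a batch of one sequence. It compares the linearly growing KV cache against the quadratically growing number of attention-score entries.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes to cache K and V for one sequence: 2 tensors x layers x heads x head_dim x tokens."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

# Hypothetical 7B-class config (assumption): 32 layers, 32 KV heads of dim 128, fp16
cfg = dict(n_layers=32, n_kv_heads=32, head_dim=128)

for n in (4_096, 32_768, 131_072):
    gib = kv_cache_bytes(n, **cfg) / 2**30
    scores = n * n  # attention-score entries per head per layer: the O(n^2) compute term
    print(f"n={n:>7,}: KV cache {gib:5.1f} GiB, scores per head per layer {scores:.1e}")
# KV cache: 2.0 -> 16.0 -> 64.0 GiB (linear); score entries grow 1024x from 4K to 128K (quadratic)
```

At 0.5 MiB of cache per token under these assumptions, a 128K-token context needs 64 GiB for the KV cache alone, before counting the weights.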
The mainstream fix for the third barrier revolves around RoPE (Rotary Position Embedding), whose mechanism Day 13 covered. RoPE encodes position as a rotation angle, with each pair of dimensions rotating at its own frequency. Extrapolating directly to unseen lengths produces rotation angles the model never encountered in training (the slowest, low-frequency dimensions do not even complete one full turn within the training window), and attention scores become unstable. Position Interpolation (Chen et al., 2023) sidesteps this. Rather than show the model unseen angles, it compresses all positions proportionally back into the trained range. Picture mapping 3 meters onto a 30cm ruler: the ticks get denser, but every reading stays within the known interval, so only a little fine-tuning is needed.
```python
import torch

# Core idea of position interpolation: "compress" very long positions back into the trained range
train_len = 4096                # model's original training length
target_len = 32768              # length we want to extend to
scale = train_len / target_len  # scaling factor = 1/8

pos = torch.arange(target_len).float()
# Key step: position × scale → squeeze [0, 32768) back into the "known" range [0, 4096)
pos_interpolated = pos * scale  # the model only sees positions inside the range it was trained on
# These interpolated positions feed RoPE to compute rotation angles (RoPE: see Day 13)
# In practice: after interpolating, fine-tune briefly on long text to adapt stably
print(pos_interpolated.max())   # tensor(4095.8750): everything stays below 4096
```
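The interpolated positions can be pushed through RoPE to see the effect on rotation angles. This is a sketch under stated assumptions: head dimension 128, base 10000, adjacent-pair rotation as in the original RoPE formulation (some codebases pair the first and second halves of the vector instead), and hypothetical helper names `rope_angles` and `apply_rope`. It compares the largest angle per frequency under plain training, direct extrapolation, and interpolation.

```python
import torch

def rope_angles(positions, head_dim=128, base=10000.0, scale=1.0):
    """RoPE rotation angle for every (position, dimension pair).

    Pair i rotates at inv_freq_i = base^(-2i / head_dim) radians per position.
    scale < 1 is Position Interpolation; scale = 1 is plain RoPE.
    """
    inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim)
    return (positions.double() * scale)[:, None] * inv_freq[None, :]  # [n_pos, head_dim // 2]

def apply_rope(x, angles):
    """Rotate adjacent pairs (x0, x1), (x2, x3), ... of x [n_pos, head_dim] by the given angles."""
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    cos, sin = torch.cos(angles), torch.sin(angles)
    out = torch.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out

train_len, target_len = 4096, 32768
trained = rope_angles(torch.arange(train_len))
extrapolated = rope_angles(torch.arange(target_len))
interpolated = rope_angles(torch.arange(target_len), scale=train_len / target_len)

# Slowest pair (last column): the largest angle each scheme asks the model to handle
for name, ang in [("trained", trained), ("extrapolated", extrapolated), ("interpolated", interpolated)]:
    print(f"{name:>12}: {ang[:, -1].max().item():.2f} rad")
# roughly 0.47, 3.78, 0.47 rad: extrapolation leaves the trained range, interpolation stays inside it

q = torch.randn(target_len, 128, dtype=torch.float64)
q_rot = apply_rope(q, interpolated)  # queries rotated at all 32K interpolated positions
```

The slowest pair only reaches about 0.47 rad during training, so direct extrapolation to 32K hands it angles it has never seen. Interpolation keeps every pair inside its trained range, at the cost of packing neighboring positions 8× closer together.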