KV cache is memoization for autoregressive generation, structurally identical to "don't recompute solved subproblems" in dynamic programming. Backend analogues: keyed state in a stream processor (Flink's keyed state), materialized views in a database, the session state on an HTTP keep-alive connection. Without it, generating each token re-runs the full forward pass over every prior token, so the total attention cost of generating n tokens grows from O(n²) to O(n³).
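To make the O(n²) vs O(n³) claim concrete, here is a rough count of query·key dot products for a single attention head. It is a minimal sketch that ignores the MLP and every other per-token cost. With a cache, step t scores one new query against t keys. Without one, step t recomputes all t prefix positions, and position i scores i keys.

```python
def dot_products_with_cache(n: int) -> int:
    # Step t: only the new token's query, scored against t cached keys.
    return sum(t for t in range(1, n + 1))


def dot_products_without_cache(n: int) -> int:
    # Step t: recompute the whole prefix; position i (1-based) attends to i keys.
    return sum(sum(i for i in range(1, t + 1)) for t in range(1, n + 1))


for n in (10, 100, 1000):
    cached, uncached = dot_products_with_cache(n), dot_products_without_cache(n)
    print(f"n={n:5d}  cached={cached:>12,}  uncached={uncached:>14,}  ratio={uncached / cached:.0f}x")
```

The ratio works out to (n + 2) / 3, so the penalty for skipping the cache grows linearly with sequence length.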
LLM inference is autoregressive: generating token #100 requires attention against tokens 1–99. The attention formula is $\mathrm{softmax}(QK^\top / \sqrt{d})\,V$, where d is the per-head dimension, Q (Query) is the current token's "asking vector", K (Key) is each historical token's "label vector", and V (Value) is each historical token's "content vector". Intuition: Q is a query, K is an index, V is the indexed payload.
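The "Q is a query, K is an index, V is the payload" intuition can be checked with a toy single-head example in NumPy. The keys, values, and the scale factor of 10 are made-up illustrative numbers. When the query aligns with one key far more than with the others, softmax puts almost all weight on that key, and the output is essentially that key's value row, much like a dictionary lookup.

```python
import numpy as np


def attention(q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """softmax(q · K^T / sqrt(d)) · V for a single query vector q."""
    d = q.shape[-1]
    scores = K @ q / np.sqrt(d)              # one similarity score per key
    weights = np.exp(scores - scores.max())  # numerically stable softmax
    weights /= weights.sum()
    return weights @ V                       # weighted blend of value rows


K = 10.0 * np.eye(4)                          # 4 keys, each a distinct "label"
V = np.arange(8, dtype=float).reshape(4, 2)   # value rows [0,1], [2,3], [4,5], [6,7]
q = 10.0 * np.eye(4)[2]                       # a query that matches key #2

print(attention(q, K, V))  # ≈ V[2] = [4, 5]
```

With a zero query every score ties, so the output becomes the plain average of the value rows. Attention is a soft lookup that interpolates between those two extremes.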
Key observation: when generating token #100, the K and V of tokens 1–99 do not change, because causal attention makes each token's K/V depend only on that token and the ones before it. So K/V is computed once and reused at every later step; only the new token's Q/K/V is new work. Storing all historical K/V is the KV cache.
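The key observation can be checked numerically. Below is a minimal single-layer, single-head NumPy sketch in which random inputs and projection weights stand in for a real model. It compares full causal attention, recomputed over the whole sequence, against an incremental loop that appends each token's K/V to a cache once and never recomputes it. The two give the same outputs.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def causal_attention(Q, K, V):
    """No cache: recompute attention for every position over its full prefix."""
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                        # (n, n)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)   # j > i is the future
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ V


class KVCache:
    """Grows by one K row and one V row per generated token."""

    def __init__(self, d: int):
        self.K = np.empty((0, d))
        self.V = np.empty((0, d))

    def step(self, q, k, v):
        # Only the new token's K/V are appended; cached rows are reused untouched.
        self.K = np.vstack([self.K, k])
        self.V = np.vstack([self.V, v])
        scores = self.K @ q / np.sqrt(q.shape[-1])       # (t,)
        return softmax(scores) @ self.V


rng = np.random.default_rng(0)
n, d = 6, 8
X = rng.normal(size=(n, d))                              # stand-in token embeddings
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
Q, K, V = X @ Wq, X @ Wk, X @ Wv

full = causal_attention(Q, K, V)
cache = KVCache(d)
incremental = np.stack([cache.step(Q[t], K[t], V[t]) for t in range(n)])
print(np.allclose(full, incremental))  # True
```

In a real multi-layer model the same argument holds layer by layer. The causal mask keeps every earlier position's hidden states, and therefore its K/V, fixed.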
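Per-token KV cache size: 2 × n_layers × n_heads × head_dim × 2 bytes. The first 2 counts K and V, the last 2 is the bytes per FP16 element, and with grouped-query attention n_heads means the number of KV heads. Every decode step must stream this whole cache, plus all model weights, out of GPU memory while doing only a couple of FLOPs per weight read. This is why LLM inference is bottlenecked by memory bandwidth, not compute: a counter-intuitive but critical fact. The other three optimizations on this page (speculative decoding, continuous batching, quantization) all exist to "squeeze more out of bandwidth".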
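Plugging numbers into the per-token KV cache formula shows how quickly the cache grows. The config values are assumed from Llama-3.1-8B's published config (32 layers, 8 KV heads under grouped-query attention, head_dim 128), so check the model's config.json before relying on them.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    # 2 = one K tensor + one V tensor per layer; bytes_per_elem = 2 for FP16/BF16.
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * seq_len * batch


per_token = kv_cache_bytes(32, 8, 128, seq_len=1)
print(per_token // 1024, "KiB per token")
print(kv_cache_bytes(32, 8, 128, seq_len=8192) / 2**30, "GiB for one 8K-token sequence")
print(kv_cache_bytes(32, 32, 128, seq_len=8192) / 2**30, "GiB if all 32 heads kept their own K/V")
```

Multiply by the batch size, and the cache soon rivals the roughly 16 GB of FP16 weights. All of it must be read again on every decode step.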
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
m = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", torch_dtype=torch.float16
).cuda()

input_ids = tok("The capital of France is", return_tensors="pt").input_ids.cuda()
past_key_values = None  # ← KV cache, starts empty
generated = []

for _ in range(20):
    with torch.no_grad():
        out = m(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
    next_id = out.logits[:, -1, :].argmax(-1, keepdim=True)  # greedy pick, shape (1, 1)
    past_key_values = out.past_key_values  # ← accumulate cache
    input_ids = next_id                    # ← next round feeds only the new token; the rest is in cache
    if next_id.item() == tok.eos_token_id:
        break
    generated.append(next_id.item())

print(tok.decode(generated))
# Without the cache you must feed the whole sequence every step (use_cache=False):
# same 20 tokens, but each step recomputes K/V for the entire prefix, and the gap grows with length.
```
The same trick as CPU branch prediction + speculative execution: let a cheap "small model" run ahead and propose several tokens, then have the "big model" verify them in parallel in a single pass. If the guesses hold, you've banked free tokens; at the first wrong guess, you discard it and everything after it, keeping the big model's own token at that position. Backend analogues: optimistic concurrency (write first, conflict-check later) and predicate pushdown (use a cheap filter to drop most rows before the expensive operator).
During decode, each new token requires streaming the entire 70B model's weights from VRAM to the compute units: 140 GB in FP16, which takes roughly 40-50 ms even at H100-class memory bandwidth. The bottleneck is bandwidth, not compute, so the GPU's FLOPs sit largely idle. Speculative decoding's insight: since one forward pass costs a fixed amount of weight traffic, verifying 5 candidate tokens in that same pass costs almost nothing extra.
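Here is a back-of-the-envelope check of the "bandwidth, not compute" claim. The hardware figures are assumptions for illustration, roughly H100 SXM datasheet numbers (about 3.35 TB/s of HBM bandwidth and about 989 TFLOPS dense FP16). The sketch also ignores the KV cache and the fact that 140 GB must be split across several GPUs.

```python
from typing import Tuple


def decode_step_ms(n_params: float, bytes_per_param: float,
                   mem_bw_bytes_s: float, peak_flops: float) -> Tuple[float, float]:
    """Lower bounds for one decode step at batch size 1."""
    memory_ms = n_params * bytes_per_param / mem_bw_bytes_s * 1e3  # every weight read once
    compute_ms = 2 * n_params / peak_flops * 1e3                   # ~2 FLOPs per weight (mul + add)
    return memory_ms, compute_ms


HBM_BW = 3.35e12     # bytes/s, assumed
PEAK_FP16 = 989e12   # FLOP/s dense, assumed

mem_ms, cmp_ms = decode_step_ms(70e9, 2, HBM_BW, PEAK_FP16)
print(f"memory-bound: {mem_ms:.1f} ms   compute-bound: {cmp_ms:.2f} ms   ratio ~{mem_ms / cmp_ms:.0f}x")
```

Scoring 5 candidates in the same pass reuses the same weight reads and only multiplies the tiny compute term, which is why verification is nearly free.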
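Three steps:
1. Draft: the small model autoregressively proposes K candidate tokens, cheaply.
2. Verify: the big model runs one forward pass over the context plus all K candidates, producing its own prediction at every position in parallel.
3. Accept or roll back: keep the longest run of candidates the big model agrees with, take the big model's own token at the first disagreement (or one bonus token if all K pass), and discard the rest. With sampling instead of greedy decoding, "agrees" becomes a rejection-sampling test that preserves the big model's output distribution.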
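The draft, verify, accept-or-roll-back loop can be sketched in plain Python. The sketch covers greedy decoding only, since sampling needs a rejection-sampling acceptance rule instead of an exact-match check. The "models" are hypothetical toy functions that map a token prefix to a next token. For readability the verify step calls the target once per position, whereas a real implementation scores all positions in one batched forward pass.

```python
from typing import Callable, List, Tuple

Model = Callable[[List[int]], int]  # hypothetical: prefix -> greedy next token


def speculative_step(prefix: List[int], draft: Model, target: Model, k: int) -> List[int]:
    # 1. Draft: the cheap model proposes k tokens autoregressively.
    proposal, ctx = [], list(prefix)
    for _ in range(k):
        tok = draft(ctx)
        proposal.append(tok)
        ctx.append(tok)
    # 2. Verify: the target's own choice at each of the k + 1 positions
    #    (a single batched forward pass in a real system).
    target_choice = [target(prefix + proposal[:i]) for i in range(k + 1)]
    # 3. Accept the longest agreeing run, then append the target's token at the
    #    first mismatch (or the bonus token after k accepted drafts).
    n_ok = 0
    while n_ok < k and proposal[n_ok] == target_choice[n_ok]:
        n_ok += 1
    return proposal[:n_ok] + [target_choice[n_ok]]


def speculative_generate(prefix, draft, target, k, n_new) -> Tuple[List[int], int]:
    out, rounds = list(prefix), 0
    while len(out) - len(prefix) < n_new:
        out += speculative_step(out, draft, target, k)
        rounds += 1
    return out[len(prefix):len(prefix) + n_new], rounds


def greedy_generate(prefix, model, n_new):
    out = list(prefix)
    for _ in range(n_new):
        out.append(model(out))
    return out[len(prefix):]


def target(ctx: List[int]) -> int:
    # Toy "big model": any deterministic function of the prefix.
    return (7 * sum(ctx) + len(ctx)) % 10


def draft(ctx: List[int]) -> int:
    # Toy "small model": agrees with the target unless the context length is a multiple of 4.
    return target(ctx) if len(ctx) % 4 else (target(ctx) + 1) % 10


spec, rounds = speculative_generate([1], draft, target, k=4, n_new=20)
assert spec == greedy_generate([1], target, 20)  # identical to plain greedy decoding
print(rounds, "big-model rounds for 20 tokens")
```

Because every emitted token is one the target itself would have chosen, the output matches plain greedy decoding exactly. Only the number of big-model rounds shrinks.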
Math: with per-token acceptance rate α and K draft tokens, the expected yield per round is $(1-\alpha^{K+1})/(1-\alpha)$ tokens, assuming each draft token is accepted independently with probability α. At α=0.7, K=5, that's ~2.9 tokens per big-model call, a theoretical ceiling of 2.9x speedup. In practice you see 1.5-3x, with the remainder lost to draft-model overhead. Medusa / EAGLE replace the standalone draft with lightweight extra heads on the big model itself, shrinking draft overhead and reportedly pushing speedup to 3-4x.
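Here is a quick numeric check of the expected-yield formula. It keeps the formula's assumptions: each draft token is accepted independently with probability α < 1, and every round ends with exactly one token from the big model.

```python
import random


def expected_tokens_per_round(alpha: float, k: int) -> float:
    # Geometric series 1 + α + α² + ... + α^k = (1 - α^(k+1)) / (1 - α), for α < 1
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def simulate(alpha: float, k: int, rounds: int = 200_000, seed: int = 0) -> float:
    rng = random.Random(seed)
    total = 0
    for _ in range(rounds):
        accepted = 0
        while accepted < k and rng.random() < alpha:  # each draft token accepted w.p. α
            accepted += 1
        total += accepted + 1                         # +1: the big model's own token
    return total / rounds


print(f"formula:   {expected_tokens_per_round(0.7, 5):.3f}")
print(f"simulated: {simulate(0.7, 5):.3f}")
```

Raising K has diminishing returns: each extra draft token adds only α^K more expected tokens, while its drafting cost stays the same.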
```python
# HuggingFace transformers' built-in assisted generation is speculative decoding
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

big = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-70B", torch_dtype=torch.float16, device_map="auto"
)
draft = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", torch_dtype=torch.float16, device_map="auto"
)
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-70B")
prompt = tok("Write a haiku about caching:", return_tensors="pt").to("cuda")

# assistant_model = the draft; big and draft must share a family (same tokenizer + vocab)
out = big.generate(
    **prompt,
    assistant_model=draft,  # ← the key flag
    max_new_tokens=100,
    do_sample=False,
)
print(tok.decode(out[0], skip_special_tokens=True))
# With greedy decoding, output matches running without the assistant; speed typically 1.5-3x.
# vLLM / TensorRT-LLM / SGLang have production-grade impls (PagedAttention + fused speculative decoding)
```
This is the paradigm jump from a blocking thread pool to an async event loop, the same core idea as Node.js / Netty / asyncio. Static batching has a weakest-link problem: the whole batch waits for its slowest request to finish before the next batch can start. Continuous batching is the OS scheduler: whoever finishes first yields its slot, and a new request fills it immediately. Orca (OSDI 2022) introduced the idea as iteration-level scheduling, and vLLM (2023) made it mainstream, with benchmarks against static batching reporting 5-23x higher throughput depending on the workload.
LLM requests vary hugely in output length: one needs 10 tokens, another needs 2000. The old approach was request-level (static) batching: gather 8 requests and run them together. Because the batch advances in lockstep, everyone waits for the 2000-token one, the slots of already-finished requests sit idle, and throughput tanks.
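Continuous batching (a.k.a. iteration-level scheduling) drops scheduling granularity from "whole request" down to "every decode step". After each step, finished sequences leave the batch and waiting requests take their slots, so no slot idles while the longest request runs.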
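The difference in scheduling granularity can be simulated by counting decode steps. This toy model rests on several assumptions: every in-flight request emits one token per step, prefill cost and memory limits are ignored, and the output lengths are made up.

```python
from collections import deque
from typing import List


def static_batching_steps(lengths: List[int], slots: int) -> int:
    # Request-level batching: each batch runs until its longest request finishes.
    return sum(max(lengths[i:i + slots]) for i in range(0, len(lengths), slots))


def continuous_batching_steps(lengths: List[int], slots: int) -> int:
    # Iteration-level scheduling: refill free slots before every decode step.
    waiting = deque(lengths)
    running: List[int] = []  # tokens still to generate, per in-flight request (each >= 1)
    steps = 0
    while waiting or running:
        while waiting and len(running) < slots:
            running.append(waiting.popleft())
        running = [r - 1 for r in running if r > 1]  # one decode step; finished ones leave
        steps += 1
    return steps


lengths = [10, 2000, 15, 12, 2000, 8, 30, 25]
print("static:    ", static_batching_steps(lengths, slots=4))
print("continuous:", continuous_batching_steps(lengths, slots=4))
```

Under static batching, the second 2000-token request cannot start until the first batch drains. Continuous batching starts it as soon as a short request frees a slot, so both long requests overlap.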
The implementation challenge is KV cache memory management: the naive approach pre-allocates max-length KV memory per request, wasting 60-80% of it. vLLM's PagedAttention instead slices the KV cache into fixed-size blocks (16 tokens per block by default) that are allocated on demand and tracked through a per-request block table. This borrows directly from OS virtual-memory paging, and it let continuous batching pack far more concurrent requests into the same memory. It's arguably the most important LLM-serving systems advance of 2023; today vLLM / SGLang / TensorRT-LLM / TGI all use the same pattern.
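Here is a toy version of the block-table idea. It is loosely modeled on PagedAttention's design but is not vLLM's code: the class name, its methods, and the 16-token block size are illustrative assumptions. Each sequence owns a list of physical block ids and grabs a new block only when its last one fills up. When a sequence finishes, its blocks return to the free pool, so at most one partially filled block per sequence goes to waste.

```python
from typing import Dict, List, Tuple


class PagedKVAllocator:
    """Hypothetical toy allocator: each sequence's block table maps logical
    token positions to fixed-size physical blocks, like an OS page table."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free_blocks: List[int] = list(range(num_blocks))  # physical block pool
        self.block_tables: Dict[int, List[int]] = {}           # seq_id -> block ids
        self.num_tokens: Dict[int, int] = {}                   # seq_id -> tokens stored

    def append_token(self, seq_id: int) -> None:
        table = self.block_tables.setdefault(seq_id, [])
        n = self.num_tokens.get(seq_id, 0)
        if n % self.block_size == 0:  # no block yet, or the last one is full
            if not self.free_blocks:
                raise MemoryError("out of KV blocks (a real scheduler would preempt or queue)")
            table.append(self.free_blocks.pop())
        self.num_tokens[seq_id] = n + 1

    def slot(self, seq_id: int, pos: int) -> Tuple[int, int]:
        # Logical position -> (physical block id, offset inside that block).
        return self.block_tables[seq_id][pos // self.block_size], pos % self.block_size

    def free(self, seq_id: int) -> None:
        # A finished sequence returns all of its blocks to the pool at once.
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.num_tokens.pop(seq_id, None)


alloc = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(20):
    alloc.append_token(seq_id=0)  # 20 tokens -> 2 blocks (12 slots unused in the last one)
for _ in range(5):
    alloc.append_token(seq_id=1)  # 5 tokens -> 1 block
print(alloc.block_tables, "free:", len(alloc.free_blocks))
alloc.free(0)
print("free after seq 0 finishes:", len(alloc.free_blocks))
```

Blocks need not be contiguous, so a freed block can serve any other sequence right away. That is what lets the scheduler admit new requests without fragmentation.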
```python
# vLLM has continuous batching on by default: just submit requests
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    gpu_memory_utilization=0.9,
    max_num_seqs=256,  # ← max concurrent requests; the scheduler does in-flight scheduling
)

prompts = [
    "Explain caching in one paragraph.",          # short output
    "Write a 1000-word essay on consciousness.",  # long output
    "What is 2+2?",                               # very short
    # ... dozens or hundreds of mixed-length requests at once
]
params = SamplingParams(max_tokens=1024, temperature=0.7)
outputs = llm.generate(prompts, params)
for o in outputs:
    print(o.outputs[0].text[:80])

# Inside the engine, short requests finish early and free their slots for waiting ones;
# llm.generate() itself returns once all are done (the online server returns each as it finishes).
# Same hardware vs raw HuggingFace transformers: often 10x+ throughput, workload-dependent.
# This is why production LLM serving stacks use vLLM/SGLang/TensorRT-LLM, not raw transformers.
```
Quantization is JPEG compression for neural networks: deliberately accept controlled, lossy precision loss in exchange for 2-4x less memory and memory traffic. Backend analogues: fixed-point instead of float (a classic embedded-systems trick), columnar compression (Parquet's dictionary encoding squeezing strings into ints), protobuf varint (small numbers in fewer bytes). All bet on the same thing: the data's distribution lets us recover it with fewer bits.
Llama-3 70B in FP16 needs 140 GB of VRAM for its weights alone, more than a single H100 (80 GB) can hold, and H100s cost ~$30K each. Quantized to INT4, the weights shrink to ~35 GB, and a consumer RTX 4090 (24 GB) with CPU offload can run it (slowly), a hardware cost gap of well over 10x. This is why quantization is the entry ticket to local / edge LLM deployment.
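The memory figures follow from parameters × bits ÷ 8. This sketch counts weights only. Real deployments add the KV cache, activations, and some overhead for quantization scales and any layers left unquantized, so actual footprints come out somewhat higher.

```python
def weight_gb(n_params: float, bits_per_weight: float) -> float:
    # Weights only: parameters × bits ÷ 8 bits/byte, in decimal GB.
    return n_params * bits_per_weight / 8 / 1e9


for bits, name in [(16, "FP16"), (8, "INT8"), (4, "INT4")]:
    print(f"70B @ {name}: {weight_gb(70e9, bits):6.1f} GB    8B @ {name}: {weight_gb(8e9, bits):5.1f} GB")
```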
Core formula: x_int = round((x_float - zero) / scale), dequantized via x_float ≈ x_int * scale + zero. Intuition: linearly map floats to a small integer range; scale sets the "step size", and zero sets "where zero lands". scale and zero are fitted to the actual value range of each tensor, channel, or group of weights (not FP16's full [-65504, 65504] range), and that range is then squeezed into INT8's [-128, 127]. This loses precision, but model weights follow a bell curve concentrated near zero, so the loss is far smaller than you'd expect.
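The formula can be implemented directly in NumPy. The sketch below assumes asymmetric min/max INT8 quantization of a single tensor, with scale and zero chosen so that the tensor's min and max land on -128 and 127. Normally distributed numbers stand in for a real weight row. The round-trip error never exceeds half a step (scale / 2).

```python
import numpy as np


def quantize_int8(x: np.ndarray):
    """Asymmetric min/max INT8 quantization (assumes x is not constant):
    x_int = round((x - zero) / scale),  x ≈ x_int * scale + zero."""
    lo, hi = float(x.min()), float(x.max())
    scale = (hi - lo) / 255.0   # 256 integer levels span [lo, hi]
    zero = lo + 128 * scale     # the float value that maps to integer 0
    x_int = np.clip(np.round((x - zero) / scale), -128, 127).astype(np.int8)
    return x_int, scale, zero


def dequantize(x_int: np.ndarray, scale: float, zero: float) -> np.ndarray:
    return x_int.astype(np.float64) * scale + zero


rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.02, size=4096)  # bell-shaped stand-in for a weight row
q, scale, zero = quantize_int8(w)
err = np.abs(dequantize(q, scale, zero) - w)
print(f"scale={scale:.2e}  max error={err.max():.2e}  (bound: scale/2={scale / 2:.2e})")
```

A single outlier stretches [lo, hi] and inflates the step size for every other weight. This is why practical schemes use per-channel or per-group scales.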
Weight-only and weight+activation (W·A) quantization are two distinct paths. Weight-only saves VRAM, speeds up memory transfer (exactly the inference bottleneck), and is simple to implement. Activation quantization additionally unlocks INT8 tensor cores for a real compute speedup, but it is harder to keep accurate. Common practice: local inference uses weight-only INT4 (AWQ/GPTQ), while high-throughput cloud serving uses W8A8 / FP8.
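Here is a minimal weight-only sketch in NumPy. It uses symmetric per-output-channel INT8, a simpler scheme than AWQ or GPTQ, both of which choose scales more carefully. Weights are stored as int8 plus one scale per row and dequantized on the fly just before the matmul, while activations stay in floating point. Only the stored weight bytes shrink and the arithmetic remains floating point, which is the "saves VRAM and bandwidth, not compute" trade-off.

```python
import numpy as np


def quantize_weight_per_row(w: np.ndarray):
    # Symmetric: one scale per output row, integer range [-127, 127], no zero-point.
    scale = np.abs(w).max(axis=1) / 127.0
    w_q = np.round(w / scale[:, None]).astype(np.int8)
    return w_q, scale


def linear_weight_only(x: np.ndarray, w_q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Dequantize on the fly, then an ordinary floating-point matmul (activations untouched).
    w_deq = w_q.astype(np.float32) * scale[:, None].astype(np.float32)
    return x @ w_deq.T


rng = np.random.default_rng(0)
w = rng.normal(size=(512, 256)).astype(np.float32)  # (out_features, in_features)
x = rng.normal(size=(4, 256)).astype(np.float32)    # a small batch of activations
w_q, scale = quantize_weight_per_row(w)

y_ref = x @ w.T
y_q = linear_weight_only(x, w_q, scale)
print("bytes fp16 vs int8:", w.astype(np.float16).nbytes, w_q.nbytes)
print("max relative error:", float(np.abs(y_q - y_ref).max() / np.abs(y_ref).max()))
```

Production kernels fuse the dequantize step into the matmul so that full-precision weights never touch memory. The separate step here is only for clarity.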
```python
# Load a 4-bit quantized model directly with bitsandbytes: one config block
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

qcfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",       # normal-float 4 (QLoRA paper), designed for normally distributed weights
    bnb_4bit_use_double_quant=True,  # quantize the quantization constants too: small extra win
)
m = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=qcfg,
    device_map="auto",
)
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
ids = tok("Define entropy in one sentence.", return_tensors="pt").to("cuda")
print(tok.decode(m.generate(**ids, max_new_tokens=100)[0], skip_special_tokens=True))
print(f"{m.get_memory_footprint() / 1e9:.1f} GB")
# 8B model drops from ~16 GB to ~5 GB: fits a 12 GB RTX 3060; quality loss is usually small.
# bitsandbytes is primarily a CUDA library; on an Apple Silicon Mac, use a different 4-bit stack.
```