一个 70B 参数的模型,FP16 精度下光权重就要 140 GB 显存——单张消费级显卡装不下。模型压缩回答的是同一个问题的四个角度:能不能用更小的模型、更少的权重、更小的改动、更低的精度,把这堆参数塞进你负担得起的硬件。四种技术正交,可叠加:蒸馏(换小模型)、剪枝(删权重)、LoRA(只训增量)、量化(降精度)。本期讲「为什么这么做有效」的机制与数学。
蒸馏像资深工程师带新人 code review。差的师傅只说「这题选 A」(硬标签,hard label);好的师傅会说「A 八成对,B 还行有一成可能,C 基本不用考虑」(软标签,soft label)。后者传递的信息量大得多——新人不只学到答案,还学到了类别之间的相似度结构。蒸馏就是让小模型(student)去模仿大模型(teacher)输出的整个概率分布,而不只是最终答案。
大模型准但贵,小模型快但笨。能不能让小模型「继承」大模型的判断力?关键洞察来自 Hinton 2015:大模型 softmax 输出里,那些非正确类的微小概率(猫=0.9、狗=0.08、汽车=0.0001)藏着大量「暗知识」——它告诉你「猫和狗很像,和汽车完全不像」。这种类间关系,硬标签(猫=1,其余=0)完全丢掉了。
机制核心是用温度(temperature,T)软化 softmax。普通 softmax 把最大值压得接近 1、其余接近 0;除以一个 T>1 再做 softmax,分布会变「平」,那些小概率被放大、暴露出来:
这里 zi 是第 i 类的原始打分(logit),T 是温度旋钮。T=1 是普通 softmax;T 越大分布越平滑,类间相似度越清晰。训练时让 student 在同样的高温下匹配 teacher 的软分布(用 KL 散度衡量两个分布的差距),通常再加一点真实硬标签做锚。下图直观感受软硬标签的信息差:
import torch.nn.functional as F def distill_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7): # 1) 软目标:师生都用温度 T 软化,再比对分布(KL 散度) s_soft = F.log_softmax(student_logits / T, dim=-1) t_soft = F.softmax(teacher_logits / T, dim=-1) # teacher 不回传梯度 kd = F.kl_div(s_soft, t_soft, reduction="batchmean") * (T * T) # ↑ 乘 T²:温度软化会缩小梯度,乘回来保持量级 # 2) 硬目标:和真实标签的普通交叉熵,做锚点防跑偏 ce = F.cross_entropy(student_logits, labels) # 3) 加权组合:alpha 偏重模仿老师,(1-alpha) 偏重真值 return alpha * kd + (1 - alpha) * ce
剪枝就是删掉用不到的数据库索引,或者 dead code elimination。一个训练好的网络里,大量权重的绝对值接近 0——它们对最终输出几乎没贡献,就像一个从没被查询命中的索引、一段永远走不到的分支。把它们置零(甚至物理删除),模型变小变快,精度几乎不掉。问题只在于:怎么判断哪些权重是「死」的。
神经网络天然过参数化(over-parameterized)——参数远多于任务实际需要,这是训练能收敛的代价。训练完之后,很多参数就成了冗余。最朴素也最有效的判据是幅值剪枝(magnitude pruning):权重绝对值越小,删掉影响越小。
直觉:w 是连接强度,|w|≈0 意味着这条连接几乎不传递信号,剪掉等于删一根没用的线。但一刀剪太狠精度会崩,所以标准做法是迭代式「剪一点 → 微调恢复 → 再剪一点」(train-prune-finetune 循环),让网络逐步适应稀疏结构。
更深的发现是 Frankle & Carbin 2018 的彩票假说(Lottery Ticket Hypothesis):一个大的随机初始化网络里,本就藏着一个小的「中奖」子网络——单独把它拎出来、用原始初始化训练,能达到和完整大网络相当的精度。剪枝某种意义上是在「刮开彩票」找到那个子网络。两种剪枝粒度的工程含义截然不同:
import torch.nn.utils.prune as prune import torch.nn as nn layer = nn.Linear(1024, 1024) # 幅值剪枝:把这层 40% 绝对值最小的权重置零 prune.l1_unstructured(layer, name="weight", amount=0.4) print((layer.weight == 0).float().mean()) # ≈ 0.40 稀疏度 # 关键:剪完要微调几个 epoch 让网络恢复(此处省略训练循环) # ... train(model) ... # 让存活的权重补偿被删掉的 # 满意后固化:移除 mask,让置零永久生效 prune.remove(layer, "weight")
全量微调像把整个代码仓库 fork 一份重写——70B 参数全部更新,每个任务存一份 140GB 副本,灾难。LoRA 像 git diff / patch 文件:原始权重冻结不动(base repo),只额外训练一个小小的增量补丁。每个任务只存这个补丁(几 MB~几十 MB),用时叠加到主干上。一个底座,无数个轻量 patch,按需切换。
全量微调一个大模型,要为每个下游任务保存一整套权重,存储和切换成本爆炸。LoRA(Hu et al. 2021)的关键假设:微调时权重的变化量 ΔW 本质是「低秩」的——它没那么复杂,可以用两个瘦长矩阵的乘积来逼近。原本要更新一个 d×k 的大矩阵,现在拆成 B(d×r)× A(r×k),其中秩 r 极小(常取 8、16):
参数量从 d×k 降到 r×(d+k)。举例 d=k=4096、r=8:全量是 1600 万参数,LoRA 只有 6.5 万——缩小约 250 倍。冻结的 W 提供「通用能力」,小小的 BA 提供「任务特化」。推理时把 BA 合并进 W,零额外延迟。下图是维度直觉:
QLoRA(Dettmers et al. 2023)更进一步:把冻结的底座量化到 4-bit 存显存,只有那个小 LoRA 适配器保持高精度训练。这样单张 48GB 显卡就能微调 65B 模型。它的三个关键发明:NF4(4-bit NormalFloat,针对正态分布权重信息论最优的数据类型)、双重量化(连量化常数本身也量化,再省一点)、分页优化器(用显存换内存应对峰值)。一句话:LoRA 省「要训的参数」,QLoRA 再省「冻结底座占的显存」。
from transformers import AutoModelForCausalLM, BitsAndBytesConfig from peft import LoraConfig, get_peft_model # QLoRA:底座 4-bit 量化加载(NF4 + 双重量化) bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True) base = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-3-8b", quantization_config=bnb) # 只在注意力的 q/v 投影上挂 LoRA 适配器,秩 r=8 cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], lora_dropout=0.05) model = get_peft_model(base, cfg) model.print_trainable_parameters() # → trainable: ~4M / 8B (≈0.05%),其余全部冻结
量化就是换更小的数据类型存数——把 FP16(16 位浮点)压成 INT8 甚至 INT4,类似把 PNG 转成 JPEG、把 DOUBLE 字段改成 SMALLINT。用精度换空间和速度:模型权重不需要那么多有效数字,砍掉低位信息,体积直接减半、再减半,访存和算力也跟着降。代价是引入量化噪声,关键是控制它别毁掉模型。
大模型推理的瓶颈常常不是算力,而是把几百 GB 权重从显存搬进计算单元的带宽。权重存得越小,搬得越快、装得越下。量化的数学核心是一个线性映射:把一段连续的浮点范围 [min, max],均匀映射到 2b 个整数格子上。
直觉:scale 是「每个整数格子代表多大的浮点跨度」,b 是位宽(INT8 是 8、INT4 是 4)。存的时候把浮点除以 scale 取整成小整数 xq;用的时候再乘回 scale 还原成近似浮点 x̂。x̂ 和原始 x 的差就是量化误差。位宽越低、格子越少、误差越大——这就是「精度换空间」的本质。下图是把连续值塞进 4 个格子(2-bit)的直观感受:
两个实操关键。其一异常值(outlier):LLM 权重/激活里偶有极大值,会把 [min,max] 撑得很宽,导致大多数正常值挤在少数格子里、精度尽失。LLM.int8()、GPTQ 等方法的核心都在处理异常值。其二训练后量化 vs 量化感知:
import torch def quantize_int8(w): # 对称量化:用绝对值最大值定 scale,零点对齐 0 scale = w.abs().max() / 127.0 # INT8 范围 [-127,127] w_q = torch.round(w / scale).clamp(-127, 127).to(torch.int8) return w_q, scale # 存 int8 权重 + 一个 fp scale def dequantize(w_q, scale): return w_q.to(torch.float32) * scale # 用时还原近似浮点 w = torch.randn(4096, 4096) # fp32 权重:64 MB w_q, s = quantize_int8(w) # int8:16 MB,缩到 1/4 err = (w - dequantize(w_q, s)).abs().mean() print(f"平均量化误差 {err:.5f}") # 噪声很小,模型基本无感