生产环境中的知识蒸馏:让小模型完成大模型的任务
一家医疗公司每天用 GPT-4 处理 10,000 份文档,年度账单高达 5 万美元。在用前沿模型的输出对一个 270 亿参数的开源模型进行微调后,相同的工作量仅需 5,000 美元——节省了 90%。这个小模型在他们的特定任务上还比前沿模型高出 60%,因为它已经见过数千个完全正确行为的示例。
这就是现代形式的知识蒸馏:你一次性支付前沿模型 API 费用来生成训练数据,然后永远运行一个小型专用模型。这个算法之所以成立,是因为当你拥有权重时推理成本很低,而且在有足够示例的情况下,特定任务的模型能在窄任务上胜过通用模型。
但"收集输出、微调、上线"并不是完整的方案。大多数尝试蒸馏的团队都会遇到三堵隐形墙之一:劣质的合成数据导致学生学到错误行为,缺乏可靠信号来判断学生何时真正就绪,或者生产环境中出现无声的质量崩溃,直到用户抱怨才被发现。本文涵盖决定蒸馏是否成功的流程决策。
蒸馏如何真正转移知识
最初的框架——训练一个小模型来模仿大模型——低估了其中发生的事情。关键机制是软标签转移。
当教师模型将一份文档分类为"账单查询"时,它不会输出一个干净的 1.0 概率。它可能输出 0.72 的账单概率、0.18 的退款请求概率和 0.10 的账户查询概率。这些次要概率编码了模型学到的语义关系:账单和退款请求彼此之间的相似度高于与账户管理的相似度。仅在硬标签(0 或 1)上训练的学生完全失去了这个信号。
温度缩放控制着有多少分布信息被转移。在温度 T=1 时,教师的输出相对集中。在 T=3 或 T=5 时,概率质量分散到各类别,使学生更容易看到类间关系。根据经验,对于大多数分类任务,最优温度在 2 到 5 之间;温度更高则需要更长的训练才能有效。
训练损失结合了两个组件:学生和教师软输出之间的 KL 散度(以 T² 加权)加上硬标签上的标准交叉熵。T² 缩放在温度升高时保持梯度幅度。在实践中,从 50/50 的权重分配开始,然后再调整——如果你的模型能访问真实标签,不要为了软目标而放弃它们。
对于基于 LLM 的蒸馏,等价方式是令牌级别的概率匹配,但大多数团队跳过这一步,转而采用更简单的方法:将前沿模型的输出视为真实标签,并使用标准监督训练对学生进行微调。你会失去一些理论上的好处,但获得了更简单的流程,而且实证结果足够强大,这已成为主流方法。
