跳到主要内容

生产环境中的知识蒸馏:让小模型完成大模型的任务

· 阅读需 9 分钟
Tian Pan
Software Engineer

一家医疗公司每天用 GPT-4 处理 10,000 份文档,年度账单高达 5 万美元。在用前沿模型的输出对一个 270 亿参数的开源模型进行微调后,相同的工作量仅需 5,000 美元——节省了 90%。这个小模型在他们的特定任务上还比前沿模型高出 60%,因为它已经见过数千个完全正确行为的示例。

这就是现代形式的知识蒸馏:你一次性支付前沿模型 API 费用来生成训练数据,然后永远运行一个小型专用模型。这个算法之所以成立,是因为当你拥有权重时推理成本很低,而且在有足够示例的情况下,特定任务的模型能在窄任务上胜过通用模型。

但"收集输出、微调、上线"并不是完整的方案。大多数尝试蒸馏的团队都会遇到三堵隐形墙之一:劣质的合成数据导致学生学到错误行为,缺乏可靠信号来判断学生何时真正就绪,或者生产环境中出现无声的质量崩溃,直到用户抱怨才被发现。本文涵盖决定蒸馏是否成功的流程决策。

蒸馏如何真正转移知识

最初的框架——训练一个小模型来模仿大模型——低估了其中发生的事情。关键机制是软标签转移

当教师模型将一份文档分类为"账单查询"时,它不会输出一个干净的 1.0 概率。它可能输出 0.72 的账单概率、0.18 的退款请求概率和 0.10 的账户查询概率。这些次要概率编码了模型学到的语义关系:账单和退款请求彼此之间的相似度高于与账户管理的相似度。仅在硬标签(0 或 1)上训练的学生完全失去了这个信号。

温度缩放控制着有多少分布信息被转移。在温度 T=1 时,教师的输出相对集中。在 T=3 或 T=5 时,概率质量分散到各类别,使学生更容易看到类间关系。根据经验,对于大多数分类任务,最优温度在 2 到 5 之间;温度更高则需要更长的训练才能有效。

训练损失结合了两个组件:学生和教师软输出之间的 KL 散度(以 T² 加权)加上硬标签上的标准交叉熵。T² 缩放在温度升高时保持梯度幅度。在实践中,从 50/50 的权重分配开始,然后再调整——如果你的模型能访问真实标签,不要为了软目标而放弃它们。

对于基于 LLM 的蒸馏,等价方式是令牌级别的概率匹配,但大多数团队跳过这一步,转而采用更简单的方法:将前沿模型的输出视为真实标签,并使用标准监督训练对学生进行微调。你会失去一些理论上的好处,但获得了更简单的流程,而且实证结果足够强大,这已成为主流方法。

构建蒸馏数据集

近期蒸馏研究中最重要的一课:示例的质量主导数量。LIMA 论文证明,1,000 个精心策划的指令跟随示例,与早期工作中需要 52,000 个未精选示例才能达到的性能相当。一项网络代理蒸馏研究使用 2,322 个经过筛选的教师轨迹,生产出一个 90 亿参数的模型,在代理任务上优于 GPT-4o 和 Claude 3.5 Sonnet。

这意味着你的数据流程与模型架构同等重要。三阶段过滤是最低标准:

阶段一:生成

设计你的提示以引导你想要转移的行为,而不仅仅是任何行为。对于推理任务,使用思维链提示并保留推理轨迹——这些是你要蒸馏的一部分。对于分类,变化示例以覆盖你真实的输入分布。不要从单一狭窄的提示模板生成;学生会过拟合到模板的特征上。

阶段二:自动过滤

对生成的输出进行以下处理:

  • 使用辅助 AI 评分器(不同的模型,或使用评分提示的同一模型)标记错误答案
  • 使用启发式规则删除自相矛盾、截断的输出和拒绝响应
  • 基于熵的过滤,删除教师高度不确定的输出——这些是不可靠的监督信号
  • 去重以避免学生记住重复的表达

阶段三:人工抽查

抽取过滤后数据集的 5-10% 并审查。自动过滤器会遗漏系统性错误——例如,教师模型在特定类别输入上自信地出错。如果在抽查中发现错误模式,在训练前需要针对该模式进行定向过滤。

最终数据集应该是合成前沿输出(约 70%)和来自任务分布的真实示例(30%)的混合。真实示例将学生锚定到实际用户输入;合成示例提供覆盖范围。仅在合成数据上训练会产生在前沿生成的测试集上表现良好、但在真实输入上退化的模型。

在用户发现之前检测质量崩溃

知识蒸馏引入了标准监督训练中不存在的失败模式。以下三种常见到足以值得专项监控:

容量差距失败:学生太小,无法表示教师学到的内容。这表现为训练损失提前达到平台期,而验证准确率始终比教师低 10-15 个百分点。修复方法通常是更大的学生,或者渐进式蒸馏——先训练一个中等大小的模型,然后将其蒸馏到目标大小。

置信度校准错误:学生产生的输出与教师的 top-1 预测匹配,但置信度校准不准。这在生产中很危险,因为下游系统通常使用置信度分数来做路由决策(例如,"如果置信度 < 0.8,升级到人工审核")。在教师不确定的情况下输出 0.95 置信度的学生会把所有内容路由到自动路径。在评估期间明确测量期望校准误差(ECE)——仅靠准确率无法发现这个问题。

分布漂移:学生是在时间 T 你的数据分布上的前沿输出上训练的。当用户的输入发生变化时,学生的性能会退化,而教师(你仍可以按需查询)会通过下一个蒸馏周期适应。持续跟踪参考保留集上的准确率。如果参考集准确率从基线下降超过 5 个百分点,触发重新训练。

熵崩溃是最近研究中记录的更微妙的失败:模型失去在一系列输入中产生多样化输出的能力,不仅仅是在特定情况下的准确率退化。你可以通过跟踪模型在一段时间内对保留评估集的输出分布熵来检测它。如果熵在多个评估周期内单调递减,而对应的准确率没有提升,说明模型正在收敛到狭窄的行为库。

蒸馏模型何时可以上线?

团队常犯的问题是:"学生是否与教师的准确率相匹配?"正确的问题是:"学生是否在这个任务重要的维度上满足了生产标准?"

这些维度因情况而异,但最低可行评估涵盖:

分布内数据的准确率:目标是达到教师在标准基准上性能的 95%+。如果你的任务没有现有基准,你需要在开始蒸馏之前从带标签的真实示例构建一个——而不是从合成数据构建。

尾部案例的准确率:抽取 200-500 个代表最困难输入的示例——罕见实体、不寻常格式、边缘语义。分别在教师和学生上评估这个集合。学生经常在典型输入上与教师性能相当,但在尾部退化 15-20 个百分点。如果你的尾部案例是高风险的,这个差距比总体准确率更重要。

校准:ECE 低于 5% 是合理的生产目标。校准不准的学生会侵蚀下游系统所依赖的置信度信号的可信度。

延迟和成本:这些通常是你进行蒸馏的原因。在宣布成功之前,在真实负载下测量 P50 和 P95 延迟——批量推理基准对生产请求模式具有误导性。

鲁棒性:在一小组对抗性改写和域外输入上测试。你不需要详尽的对抗测试,但你确实需要知道你的学生对教师能优雅处理的输入变化是否脆弱。

最终部署协议应该首先是影子模式:将所有生产流量路由到两个模型,只服务冠军(教师或之前的学生),并记录两者的输出进行比较。两周影子数据后,比较分歧上的输出——学生和教师不同的情况。如果学生在分歧上的错误率超过 30-40%,它还没有准备好。如果错误率低于这个,以 10% 流量进行金丝雀部署并监控业务指标。

蒸馏循环

蒸馏不是一次性项目。使其在生产中发挥作用的流程是一个循环:

  1. 识别任务——你在推理成本上支出不成比例,或者小型专用模型可以胜过通用模型的任务
  2. 生成和过滤该任务的前沿模型示例
  3. 训练和评估学生,对照你的生产标准
  4. 影子部署并在真实流量上验证,然后才提升
  5. 监控生产中的校准、准确率和熵漂移
  6. 重新触发蒸馏,当你的分布变化足以导致学生性能退化时

从蒸馏中获得最多杠杆的团队将其视为基础设施,而不是研究项目。数据集构建工具、评估框架和部署机制可在任务间复用。一旦到位,蒸馏一个新任务看起来像是几天的数据生成和一次微调运行,而不是数月的工程工作。

成本数学只会随规模改善。如果你在任何有稳定、可定义成功标准且有足够量来证明蒸馏运行合理的任务上支付前沿模型价格,你正在白白浪费可观的节省。

Let's stay in touch and Follow me for more thoughts and updates