兄弟们,今天聊聊模型蒸馏这个“省钱大户”。最近几篇论文和社区实践让我觉得,蒸馏技术正从“炫技”走向“真香”。
先说个具体案例:有人用 Llama-3-70B 当教师,蒸馏出一个 7B 的“学生”。关键操作不是简单软标签复制,而是用了多任务蒸馏+对比学习。在推理任务上,学生模型在 GSM8K 上掉点不到 2%,但推理速度(GPU 上 batch=1)从 70B 的 15 tokens/s 飙到 7B 的 120 tokens/s。算下来,成本节省了 90% 以上。
技术细节上,建议关注“注意力迁移”。传统做法只匹配 logits,现在更流行匹配中间层注意力图,甚至加个“温度缩放”处理软标签的置信度分布。比如 Google 的 DistiBERT 就靠这个在 GLUE 上追平了 BERT-base。
几个实用坑:1. 别盲目缩小模型,学生容量不足时会欠拟合;2. 蒸馏时用“数据增强”比纯原始数据效果好 5%-10%;3. 混合损失设计很重要——KL 散度 + 任务损失 + 特征匹配,缺一个可能崩。
现在社区里像 Llama-Factory 都集成了蒸馏模块,命令一行搞定。想省卡、降延迟的兄弟,这周可以试试。 |