兄弟们,最近不少人在后台问我模型蒸馏怎么玩儿,今天直接开聊。干过部署的都知道,GPT-4那种怪物本地跑不动,但业务又需要接近的能力——蒸馏就是干这个的:用大模型(Teacher)教小模型(Student)输出,把知识压缩进更小的参数里。
核心逻辑其实不复杂:Teacher生成软标签(带概率分布的logits),Student去拟合这些分布,损失函数用KL散度+温度参数T。T越高,概率分布越平滑,Student能学到更多暗知识。实操中,我一般把T设在3-5之间,太低学生只会抄答案,太高容易丢失关键细节。
部署上最爽的是:蒸馏后的模型参数量可以砍到原版的10%-30%,推理速度提升5-10倍,显存占用直接打骨折。比如我用Llama-2-13B蒸馏出7B版本,在单张RTX 4090上跑出接近原版90%的准确率,延迟从秒级降到毫秒级。
但别迷信蒸馏——它只能压缩知识,补不了数据短板。训练数据脏、Teacher本身有偏见,蒸馏出来的学生只会更蠢。而且过大的T值会让学生学成“和稀泥”,该硬分类的地方软塌塌。
扔个问题:你们在实际部署中,有没有遇到过蒸馏后模型“变傻”但指标不降的情况?来评论区battle一下技术细节。 |