知识蒸馏的损失设计与学习策略
知识蒸馏将教师模型的知识迁移到学生模型。请解释知识蒸馏的核心思想、软标签损失、温度参数、特征层蒸馏和关系蒸馏。
回答
我是大山
知识蒸馏(Knowledge Distillation)由Hinton等人于2015年提出,核心思想是'让学生模仿教师'。
核心思想: 训练一个小型学生模型(Student)来复制大型教师模型(Teacher)的行为,使学生达到接近教师的性能,同时推理速度更快、资源需求更少。
软标签损失(Soft Label Loss)——蒸馏损失: 软标签(Soft Labels)是教师模型输出的类别概率分布,包含类别间的相对关系信息(如猫vs狗vs车的相似度)。
使用温度参数T软化概率分布: q_i = exp(z_i/T) / Σ_j exp(z_j/T)
- T=1:标准Softmax。
- T>1(如T=4):分布更平滑,揭示更多类别间关系。
- T很大时:所有类别的概率接近均匀。
完整损失函数: L = α·L_hard(y, σ(z_s)) + β·L_soft(q_t(T), q_s(T))
- L_hard:学生预测与真实标签的交叉熵(硬标签)。
- L_soft:学生软预测(高温T)与教师软标签的KL散度。
- α, β:权重系数,通常温度使用相同的T。
特征层蒸馏(Feature-based Distillation):
- 不仅仅对输出层蒸馏,也让中间特征层对齐。
- FitNets(2015):学生网络的学习中间层特征与教师的某层特征匹配(Hints Training)。
- 常用方法:
- 选择师生网络对应的中间层。
- 使用适配器(1×1卷积或MLP)对齐维度。
- 最小化MSE或余弦距离。
- 优点:中间层知识更丰富,适合迁移学习。
关系蒸馏(Relation-based Distillation):
- 关注样本间的关系而非单个样本的输出。
- 代表性方法——RKD(Relational Knowledge Distillation):
- 距离关系:教师和学生输出空间中样本对的距离应一致。
- 角度关系:教师和学生输出空间中三元组的角度应一致。
- 优点:捕获了数据的结构信息,对数据分布变化更鲁棒。
自蒸馏(Self-Distillation):
- 教师和学生是同一个架构,教师是历史版本或加深版。
- 在训练过程中逐步使用模型自身的预测作为软标签。
- 有趣发现:自蒸馏也能提升性能(如BYOT, Born-Again Networks)。
实践技巧:
- T通常设为4-8。
- 蒸馏损失权重α通常比β小。
- 教师模型需提前训练好且固定。
- 在数据量充足时,直接训练学生+蒸馏比单独训练学生效果好。
- 近年大模型领域,蒸馏是缩小模型规模的主要方法。