CodeWalk

知识蒸馏的损失设计与学习策略

作者:我是大山 · 2026-05-30 12:55

知识蒸馏将教师模型的知识迁移到学生模型。请解释知识蒸馏的核心思想、软标签损失、温度参数、特征层蒸馏和关系蒸馏。

回答

我是大山

知识蒸馏(Knowledge Distillation)由Hinton等人于2015年提出,核心思想是'让学生模仿教师'。

核心思想: 训练一个小型学生模型(Student)来复制大型教师模型(Teacher)的行为,使学生达到接近教师的性能,同时推理速度更快、资源需求更少。

软标签损失(Soft Label Loss)——蒸馏损失: 软标签(Soft Labels)是教师模型输出的类别概率分布,包含类别间的相对关系信息(如猫vs狗vs车的相似度)。

使用温度参数T软化概率分布: q_i = exp(z_i/T) / Σ_j exp(z_j/T)

  • T=1:标准Softmax。
  • T>1(如T=4):分布更平滑,揭示更多类别间关系。
  • T很大时:所有类别的概率接近均匀。

完整损失函数: L = α·L_hard(y, σ(z_s)) + β·L_soft(q_t(T), q_s(T))

  • L_hard:学生预测与真实标签的交叉熵(硬标签)。
  • L_soft:学生软预测(高温T)与教师软标签的KL散度。
  • α, β:权重系数,通常温度使用相同的T。

特征层蒸馏(Feature-based Distillation)

  • 不仅仅对输出层蒸馏,也让中间特征层对齐。
  • FitNets(2015):学生网络的学习中间层特征与教师的某层特征匹配(Hints Training)。
  • 常用方法:
    1. 选择师生网络对应的中间层。
    2. 使用适配器(1×1卷积或MLP)对齐维度。
    3. 最小化MSE或余弦距离。
  • 优点:中间层知识更丰富,适合迁移学习。

关系蒸馏(Relation-based Distillation)

  • 关注样本间的关系而非单个样本的输出。
  • 代表性方法——RKD(Relational Knowledge Distillation)
    1. 距离关系:教师和学生输出空间中样本对的距离应一致。
    2. 角度关系:教师和学生输出空间中三元组的角度应一致。
  • 优点:捕获了数据的结构信息,对数据分布变化更鲁棒。

自蒸馏(Self-Distillation)

  • 教师和学生是同一个架构,教师是历史版本或加深版。
  • 在训练过程中逐步使用模型自身的预测作为软标签。
  • 有趣发现:自蒸馏也能提升性能(如BYOT, Born-Again Networks)。

实践技巧

  1. T通常设为4-8。
  2. 蒸馏损失权重α通常比β小。
  3. 教师模型需提前训练好且固定。
  4. 在数据量充足时,直接训练学生+蒸馏比单独训练学生效果好。
  5. 近年大模型领域,蒸馏是缩小模型规模的主要方法。