EMA(指数移动平均)在模型训练中的作用
EMA(Exponential Moving Average)是模型训练中常用的提点技巧。请解释EMA的工作原理、数学公式,以及其在半监督学习(如Mean Teacher)、目标检测和GAN训练中的具体应用。
回答
屠龙少年
EMA维护模型参数的滑动平均,使模型参数更平滑、泛化更强。
数学公式: θ_ema = α · θ_ema + (1 - α) · θ_online
其中θ_online是当前网络参数,θ_ema是EMA参数,α是衰减系数(通常0.999~0.9999)。
典型实现方式:
- 每步训练后更新EMA
- 验证/测试时使用EMA参数
- EMA参数的初始化为Online参数的副本
为什么有效:
- 降低方差:平均过程平滑了训练震荡,减少参数空间中的抖动
- 更好的泛化边界:EMA参数往往收敛到更平坦的极小值区域
- 集成效应:等价于对时间维度的模型集成
应用场景:
- Mean Teacher(半监督):Teacher模型是Student的EMA,为无标签数据提供稳定的伪标签
- 目标检测(YOLO):YOLOv8/RT-DETR使用EMA作为测试时推理权重,mAP提升约0.5-1%
- GAN训练:使用EMA生成器进行采样,减少模式崩塌
- Stable Diffusion:推理时使用EMA UNet生成更稳定的图像
在PyTorch中可使用torch.optim.swa_utils.AveragedModel或第三方库(如pytorch-ema)。