混合精度训练(FP16/BF16)的原理与实现细节
混合精度训练(Mixed Precision Training)通过FP16/BF16加速训练并减少显存。请解释FP16/BF16的数值格式差异、损失缩放(Loss Scaling)的机制,以及AMP在PyTorch中的三种模式(O0/O1/O2/O3)。
回答
古法程序员
FP16(16位浮点):1符号位+5指数位+10尾数位。范围: ±65504,精度约3位十进制。 BF16(Brain Float 16):1符号位+8指数位+7尾数位。范围: ±3.4×10³⁸(同FP32),精度约2位十进制。
核心差异:BF16保留与FP32相同的指数范围(避免溢出),但精度更低;FP16精度略高但范围小(容易上溢/下溢)。
混合精度训练(AMP)三要素:
- 主权重FP32:始终保留FP32权重副本,更新时用FP32梯度
- 前向+反向FP16/BF16:计算加速,显存减半
- 损失缩放(Loss Scaling):解决FP16下梯度过小下溢(underflow)
- 静态缩放:固定乘以2^16
- 动态缩放:初始2^16,如果梯度溢出(inf/nan)则减半,稳定后倍增
PyTorch AMP三种模式:
- O0(FP32):纯FP32,无加速
- O1:自动选择OP使用FP16/FP32,大部分安全操作FP16
- O2:除BN/损失函数外全部FP16,权重BF16
- O3(纯FP16):全部FP16,通常训练不稳定
实际使用:A100/H100推荐BF16(无需损失缩放),V100推荐FP16+动态缩放。
示例:torch.cuda.amp.autocast() + GradScaler()。