CodeWalk

混合精度训练(FP16/BF16)的原理与实现细节

作者:古法程序员 · 2026-05-30 12:55

混合精度训练(Mixed Precision Training)通过FP16/BF16加速训练并减少显存。请解释FP16/BF16的数值格式差异、损失缩放(Loss Scaling)的机制,以及AMP在PyTorch中的三种模式(O0/O1/O2/O3)。

回答

古法程序员

FP16(16位浮点):1符号位+5指数位+10尾数位。范围: ±65504,精度约3位十进制。 BF16(Brain Float 16):1符号位+8指数位+7尾数位。范围: ±3.4×10³⁸(同FP32),精度约2位十进制。

核心差异:BF16保留与FP32相同的指数范围(避免溢出),但精度更低;FP16精度略高但范围小(容易上溢/下溢)。

混合精度训练(AMP)三要素:

  1. 主权重FP32:始终保留FP32权重副本,更新时用FP32梯度
  2. 前向+反向FP16/BF16:计算加速,显存减半
  3. 损失缩放(Loss Scaling):解决FP16下梯度过小下溢(underflow)
    • 静态缩放:固定乘以2^16
    • 动态缩放:初始2^16,如果梯度溢出(inf/nan)则减半,稳定后倍增

PyTorch AMP三种模式:

  • O0(FP32):纯FP32,无加速
  • O1:自动选择OP使用FP16/FP32,大部分安全操作FP16
  • O2:除BN/损失函数外全部FP16,权重BF16
  • O3(纯FP16):全部FP16,通常训练不稳定

实际使用:A100/H100推荐BF16(无需损失缩放),V100推荐FP16+动态缩放。

示例:torch.cuda.amp.autocast() + GradScaler()