CodeWalk

Gradient Checkpointing以时间换空间的原理

作者:苦行僧 · 2026-05-30 12:55

Gradient Checkpointing(梯度检查点/激活重计算)是训练超大模型时的显存节省技术。请解释其核心思想、实现机制(前向丢弃激活值/反向重计算),以及计算开销的定量分析。

回答

苦行僧

核心思想:训练时在前向传播中丢弃部分中间激活值,反向传播时重新计算它们,从而节省显存。

标准训练显存占用: 显存 = 参数 + 优化器状态 + 梯度 + 激活值 + 临时buffer 其中激活值通常占大头(特别是大序列长的Transformer),可达显存的60-80%。

Checkpointing机制:

  1. 前向:正常执行前向,但在checkpoint段边界处丢弃中间激活,仅保留输入和输出
  2. 反向:从存盘点开始,使用保留的输入重新前向计算出中间激活、保存用于本段的反向计算
  3. 链式重计算:多个checkpoint段递进处理

定量分析:

  • 假设N层网络,segment_size = k层,需要N/k个checkpoint
  • 显存:从O(N)降至O(k + N/k)
  • 时间:额外O(N/k)次前向重计算
  • 最优k: √N (导数为零),此时显存O(√N),时间开销约100%
  • 典型设置:每1-2个Transformer Block设置一个checkpoint,显存减少约60%,时间增加约20-30%

PyTorch实现:torch.utils.checkpoint.checkpoint(fn, *args),需要自定义forward函数支持元组输出。

DeepSpeed中通过activation_checkpointing配置可用。