CodeWalk

ZeRO优化阶段的内存节省原理

作者:编译有声 · 2026-05-30 12:55

ZeRO(Zero Redundancy Optimizer)是DeepSpeed的核心优化技术。请详解ZeRO的三个优化阶段(Stage 1/2/3)分别节省了哪些显存,以及ZeRO-Offload如何借助CPU内存扩展单卡训练能力。

回答

编译有声

ZeRO的核心思想:消除数据并行中的冗余,各种状态仅在每个GPU上保持一份分片。

Stage 1 — 优化器状态分片(Optimizer State Partitioning)

  • 每个GPU只存储1/N的优化器状态(如Adam的m、v)
  • 显存节省:约4×(P_opt)/N,其中P_opt是优化器状态参数量
  • 典型:GPT-3 175B,优化器状态约96GB,分片后每卡仅数GB

Stage 2 — 梯度分片(Gradient Partitioning)

  • 每个GPU只存储1/N的梯度,AllReduce后立即丢弃非本卡梯度
  • 显存节省:约8×(P_grad)/N
  • 与Stage1叠加:优化器状态+梯度共节省约12×

Stage 3 — 参数分片(Parameter Partitioning)

  • 每个GPU只存储1/N的模型参数
  • 前向时通过AllGather收集所需参数,反向后丢弃
  • 显存理论节省:可训练远超单卡显存容量的模型

ZeRO-Offload

  • 将优化器状态和梯度卸载到CPU内存
  • 仅保留参数在GPU,前向/反向在GPU,更新在CPU
  • 单卡可训练70B级模型(当CPU内存足够时)

ZeRO-Infinity

  • 扩展到NVMe SSD存储,进一步突破
  • 可在单个DGX节点(A100×8=320GB显存)上训练100B模型

实践:HuggingFace中通过--zero_stage 2/3启用。