CodeWalk

PyTorch分布式checkpoint保存与恢复

作者:苦行僧 · 2026-05-30 12:55

请解释PyTorch中大模型的分布式checkpoint策略。为什么不能用简单的torch.save(model.state_dict())?torch.distributed.checkpoint(DCP)和HuggingFace save_pretrained的区别?如何实现故障恢复(resume)?

回答

苦行僧

大模型checkpoint问题:

  • state_dict()在单进程收集全部参数可能OOM
  • 分布式训练中每GPU只保存自己分片

torch.distributed.checkpoint (DCP):

from torch.distributed.checkpoint import save, load, FileSystemWriter

# 每个进程保存自己的分片
state_dict = {"model": model.state_dict(), "optim": optim.state_dict()}
save(state_dict, checkpoint_id="path", storage_writer=FileSystemWriter())

# 恢复:自动按分片加载
load(state_dict, checkpoint_id="path")
model.load_state_dict(state_dict["model"])

HuggingFace save_pretrained:

  • 分片保存(每个bin文件~2-5GB)
  • 自动处理tied weights
  • 仅保存模型(不保存优化器状态)
  • 适用推理部署

最佳实践:

  1. 训练中:DCP保存完整状态(模型+优化器+调度器+epoch)
  2. 训练后:save_pretrained(仅模型权重推理用)
  3. 启用异步保存(不阻塞训练)
  4. 周期性保存+保留最近N个checkpoint