PyTorch分布式checkpoint保存与恢复
请解释PyTorch中大模型的分布式checkpoint策略。为什么不能用简单的torch.save(model.state_dict())?torch.distributed.checkpoint(DCP)和HuggingFace save_pretrained的区别?如何实现故障恢复(resume)?
回答
苦行僧
大模型checkpoint问题:
- state_dict()在单进程收集全部参数可能OOM
- 分布式训练中每GPU只保存自己分片
torch.distributed.checkpoint (DCP):
from torch.distributed.checkpoint import save, load, FileSystemWriter
# 每个进程保存自己的分片
state_dict = {"model": model.state_dict(), "optim": optim.state_dict()}
save(state_dict, checkpoint_id="path", storage_writer=FileSystemWriter())
# 恢复:自动按分片加载
load(state_dict, checkpoint_id="path")
model.load_state_dict(state_dict["model"])
HuggingFace save_pretrained:
- 分片保存(每个bin文件~2-5GB)
- 自动处理tied weights
- 仅保存模型(不保存优化器状态)
- 适用推理部署
最佳实践:
- 训练中:DCP保存完整状态(模型+优化器+调度器+epoch)
- 训练后:save_pretrained(仅模型权重推理用)
- 启用异步保存(不阻塞训练)
- 周期性保存+保留最近N个checkpoint