PyTorch FSDP全分片数据并行原理
请解释PyTorch FSDP(Fully Sharded Data Parallel)的原理。FSDP与DDP有什么不同?sharding_strategy和auto_wrap_policy的作用是什么?
回答
孤独的心
FSDP将模型参数/梯度/优化器状态分片到各设备。
vs DDP:
- DDP每GPU完整模型副本,显存高
- FSDP每GPU仅1/N模型,显存低
- 通信:DDP梯度All-Reduce,FSDP Gather+Reduce Scatter
sharding_strategy:
- FULL_SHARD:ZeRO-3(全分片)
- HYBRID_SHARD:节点内DDP+跨节点FSDP
- NO_SHARD:退化为DDP
auto_wrap_policy:决定哪些子模块单独分片(如每层Transformer)。
适用:训练7B/13B/70B大模型,单卡无法容纳时。