FSDP(Fully Sharded Data Parallel)的设计
FSDP是PyTorch对标ZeRO-3的原生实现。请解释FSDP的分片机制、通信调度(前向AllGather/反向ReduceScatter)、以及它与ZeRO-3在实现层面的异同。
回答
小字辈
FSDP是PyTorch在1.11中引入的分布式训练策略,完全对齐ZeRO-3的语义。
核心设计:
- 参数分片:每个GPU只保存模型参数的1/N分片
- 前向AllGather:前向计算需要某层参数时,所有GPU通信收集完整参数
- 反向ReduceScatter:反向计算出梯度后,ReduceScatter将梯度求和并分片到各个GPU
- 丢弃参数:该层计算完成后立即丢弃非本GPU分片的参数
通信调度优化:
- 分片粒度(wrapping policy):每个Transformer Block作为一个FlattenParamsWrapper
- 预取(prefetch):在前向/反向时提前发起下一层的AllGather请求,隐藏通信延迟
- 混合分片:支持不同层的不同分片策略
与ZeRO-3的异同: | 维度 | FSDP | ZeRO-3 | |------|------|--------| | 实现 | PyTorch原生 | DeepSpeed | | 通信后端 | NCCL/GLOO | NCCL/CUDA-Aware MPI | | 分片粒度 | 模块级(灵活) | 参数级 | | CPU Offload | 支持 | 支持(ZeRO-Offload) | | 易用性 | 无需修改模型结构 | 需要DeepSpeed引擎包装 | | 混合精度 | 原生支持 | 需要额外配置 |
性能:FSDP在相同微批大小下的吞吐量通常接近或略超ZeRO-3,且与torch.compile兼容更好。
用法:FullyShardedDataParallel(model, sharding_strategy=ShardingStrategy.FULL_SHARD)对应ZeRO-3。