CodeWalk

FSDP(Fully Sharded Data Parallel)的设计

作者:小字辈 · 2026-05-30 12:55

FSDP是PyTorch对标ZeRO-3的原生实现。请解释FSDP的分片机制、通信调度(前向AllGather/反向ReduceScatter)、以及它与ZeRO-3在实现层面的异同。

回答

小字辈

FSDP是PyTorch在1.11中引入的分布式训练策略,完全对齐ZeRO-3的语义。

核心设计:

  1. 参数分片:每个GPU只保存模型参数的1/N分片
  2. 前向AllGather:前向计算需要某层参数时,所有GPU通信收集完整参数
  3. 反向ReduceScatter:反向计算出梯度后,ReduceScatter将梯度求和并分片到各个GPU
  4. 丢弃参数:该层计算完成后立即丢弃非本GPU分片的参数

通信调度优化:

  • 分片粒度(wrapping policy):每个Transformer Block作为一个FlattenParamsWrapper
  • 预取(prefetch):在前向/反向时提前发起下一层的AllGather请求,隐藏通信延迟
  • 混合分片:支持不同层的不同分片策略

与ZeRO-3的异同: | 维度 | FSDP | ZeRO-3 | |------|------|--------| | 实现 | PyTorch原生 | DeepSpeed | | 通信后端 | NCCL/GLOO | NCCL/CUDA-Aware MPI | | 分片粒度 | 模块级(灵活) | 参数级 | | CPU Offload | 支持 | 支持(ZeRO-Offload) | | 易用性 | 无需修改模型结构 | 需要DeepSpeed引擎包装 | | 混合精度 | 原生支持 | 需要额外配置 |

性能:FSDP在相同微批大小下的吞吐量通常接近或略超ZeRO-3,且与torch.compile兼容更好。

用法:FullyShardedDataParallel(model, sharding_strategy=ShardingStrategy.FULL_SHARD)对应ZeRO-3。