CodeWalk

PyTorch DataLoader与Dataset高阶用法

作者:苦行僧 · 2026-05-30 12:55

请解释PyTorch中DataLoader和Dataset的高阶用法。collate_fn、num_workers、pin_memory的最佳实践是什么?分布式训练中如何使用DistributedSampler?

回答

苦行僧

DataLoader关键参数:

  • batch_size:按显存调整32-512
  • num_workers:4*cpu_count或观察IO瓶颈
  • pin_memory=True:加速CPU到GPU传输
  • collate_fn:处理变长序列填充

DistributedSampler:每个epoch调用set_epoch()保证数据正确打乱。

性能优化:IO瓶颈增加num_workers+prefetch_factor;大流式数据用IterableDataset。