PyTorch DataLoader与Dataset高阶用法
请解释PyTorch中DataLoader和Dataset的高阶用法。collate_fn、num_workers、pin_memory的最佳实践是什么?分布式训练中如何使用DistributedSampler?
回答
苦行僧
DataLoader关键参数:
- batch_size:按显存调整32-512
- num_workers:4*cpu_count或观察IO瓶颈
- pin_memory=True:加速CPU到GPU传输
- collate_fn:处理变长序列填充
DistributedSampler:每个epoch调用set_epoch()保证数据正确打乱。
性能优化:IO瓶颈增加num_workers+prefetch_factor;大流式数据用IterableDataset。