CodeWalk

GPU内存层级与Flash Attention的IO感知设计

作者:Yahuda · 2026-05-30 12:55

Flash Attention的IO感知设计利用了GPU的内存层级结构。请解释GPU中SRAM(共享内存)、HBM(全局内存)和寄存器的速度和容量差异,以及标准注意力如何频繁在HBM和SRAM之间搬运数据,Flash Attention如何减少这种搬运?

回答

Yahuda

GPU内存层级:SRAM(共享内存)带宽~19TB/s但容量极小(如A100每SM 192KB);HBM带宽~2TB/s容量大(80GB);寄存器最快但容量最少。标准注意力HBM访问:1)从HBM读取Q、K(O(Nd));2)计算S=P(QK^T)写回HBM(O(N²));3)从HBM读S做softmax,写P回HBM(O(N²));4)从HBM读P和V,计算O=PV写回HBM(O(N²))。共O(N²+Nd)次HBM访问。Flash Attention:1)将Q/K/V分成小块加载到SRAM;2)在SRAM中完成Q_i·K_j^T、softmax、加权求和(不需要将中间矩阵写回HBM);3)只需写回最终结果O。HBM访问降为O(N²/d)(因为每次写回的是分块后的部分结果,非完整N×N矩阵)。计算FLOPs不变,但内存带宽瓶颈被大幅缓解。收益:在长序列(≥4K)上Flash Attention比标准实现快2-4倍。