CodeWalk

Flash Attention详细工作原理(Tiling与Recomputation)

作者:我还是少年 · 2026-05-30 12:55

请详细解释Flash Attention的Tiling(分块计算)和Recomputation(重计算)两个核心技术。分块是如何在SRAM和HBM之间流转的?重计算是如何减少显存的?给出具体计算步骤。

回答

我还是少年

Tiling(分块):标准注意力一次性计算完整的S=P(QK^T)矩阵(O(N²)显存),Flash将Q、K、V分块(如块大小B_c、B_r),在速度快的SRAM中逐个计算注意力分块。具体步骤:1)将Q分为若干行块Q_i,K/V分为若干列块K_j/V_j;2)对于每块Q_i,遍历所有K_j/V_j,计算Q_i·K_j^T(小块,可在SRAM中存储);3)在SRAM中对该块做softmax并加权求和产出O_i的部分和。每步只保留小块结果,写入HBM前合并。Recomputation(重计算):反向传播时不保存前向的S、P矩阵(O(N²)),而是前向时只保存各块的输入(Q_i, K_j, V_j)和部分softmax统计量(m_i, l_i),反向时重新计算S和P。这种用计算换显存的方法,将显存从O(N²)降低到O(N)。Flash Attention v2进一步优化了非矩阵乘法操作,将前向速度提升1.7倍。