CodeWalk

Flash Attention v1和v2的原理与性能提升

作者:我是大山 · 2026-05-30 12:55

Flash Attention是加速Transformer训练的核心技术。请解释其核心思想(tiling/recomputation)、IO感知算法设计,以及v2版本相比v1的关键改进。为什么它能显著减少显存占用?

回答

我是大山

核心思想:利用GPU的存储层级(SRAM速度快但小,HBM速度慢但大),将标准注意力计算拆分为分块计算(tiling),在SRAM中完成注意力分块计算后再写回HBM,避免频繁的HBM读写。IO感知:标准注意力有O(N²)的HBM访问量(计算S、P、O三次读写);Flash Attention通过kernel融合一次完成,将HBM访问降到O(N²/d)量级。Recomputation:反向传播时不保存中间注意力矩阵,反向时重算,减少显存占用从O(N²)降为O(N)。v2改进:减少不必要非矩阵乘法操作、调整并行策略、优化warp调度,速度提升2倍。Flash Attention使BERT-large训练从O(N²)显存变为线性,16K序列训练成为可能。