Flash Attention:减少 Attn 计算的 IO 调用

2026-01-13

传统 Attention 的内存调度过程

Flash Attention 的优化思路

Memory 从访问速度和内存大小的角度来看,可以分为三层:

  • GPU SRAM

Flash Attention 的目标则是减少内存 IO 量

算法核心

矩阵分块

softmax 的分块实现

我们来考察 safe softmax,其计算流程可以描述为

  • $x = [x_1,x_2,\dots,x_n]$
  • 求出数组的最大元素:$m(x):=\max x_i$
  • 求出 safe exponential $p(x):=[e^{x_1-m(x)},e^{x_2-m(x)},\dots,e^{x_n-m(x)}]$
  • 求和:$l(x)=\sum p(x)$
  • 计算每个元素的 softmax:$softmax(x)=\frac{p(x)}{l(x)}$

接下来,我们考察,如果有两个 list $x^{(1)}=[x_1^{(1)},x_2^{(1)},\dots,x_n^{(1)}]$ 和 $x^{(2)}=[x_1^{(2)},x_2^{(2)},\dots,x_n^{(2)}]$,看看我们怎么把他们合并成一个 $x=[\dots]$

  • $m(x)=\max(m(x^{(1)}), m(x^{(2)}))$
  • $p(x)=[e^{m(x^{(1)})-m(x)}p(x^{(1)}), e^{m(x^{(2)})-m(x)}p(x^{(2)})]$.