Flash Attention:减少 Attn 计算的 IO 调用
传统 Attention 的内存调度过程
Flash Attention 的优化思路
Memory 从访问速度和内存大小的角度来看,可以分为三层:
- GPU SRAM
Flash Attention 的目标则是减少内存 IO 量
算法核心
矩阵分块
softmax 的分块实现
我们来考察 safe softmax,其计算流程可以描述为
- $x = [x_1,x_2,\dots,x_n]$
- 求出数组的最大元素:$m(x):=\max x_i$
- 求出 safe exponential $p(x):=[e^{x_1-m(x)},e^{x_2-m(x)},\dots,e^{x_n-m(x)}]$
- 求和:$l(x)=\sum p(x)$
- 计算每个元素的 softmax:$softmax(x)=\frac{p(x)}{l(x)}$
接下来,我们考察,如果有两个 list $x^{(1)}=[x_1^{(1)},x_2^{(1)},\dots,x_n^{(1)}]$ 和 $x^{(2)}=[x_1^{(2)},x_2^{(2)},\dots,x_n^{(2)}]$,看看我们怎么把他们合并成一个 $x=[\dots]$
- $m(x)=\max(m(x^{(1)}), m(x^{(2)}))$
- $p(x)=[e^{m(x^{(1)})-m(x)}p(x^{(1)}), e^{m(x^{(2)})-m(x)}p(x^{(2)})]$.