传统 Attention 的计算流程里,我们
- 先计算 ,然后把中间计算结果 写入 HBM (Global Memory).
- 接着我们又从 HBM 读取 (有时根据任务的不同,还会读取 mask)计算 softmax,得到 attention score ,又把 写回 HBM.
- 最后,我们又从 HBM 中读取 和 ,计算 attention 最终结果.
Flash Attention 做的优化是:不把 写回 HBM 了,这样可以减少 I/O.
注意这里减少的 IO 是 矩阵的读写,而不是 的读写.后面这三个矩阵的读写目前还优化不掉.
正向传播的优化
反向传播的推导
在正向传播的过程中,由于我们没有保存中间结果 ,所以,无法直接计算出 .但是我们保存了中间结果 ,我们可以利用这些来反过来计算出