传统 Attention 的计算流程里,我们

  1. 先计算 S=QKS=QK^\top,然后把中间计算结果 SS 写入 HBM (Global Memory).
  2. 接着我们又从 HBM 读取 SS(有时根据任务的不同,还会读取 mask)计算 softmax,得到 attention score PP,又把 PP 写回 HBM.
  3. 最后,我们又从 HBM 中读取 PPVV,计算 attention 最终结果.

Flash Attention 做的优化是:不把 S,PS,P 写回 HBM 了,这样可以减少 I/O.

注意这里减少的 IO 是 S,PS,P 矩阵的读写,而不是 Q,K,VQ,K,V 的读写.后面这三个矩阵的读写目前还优化不掉.

正向传播的优化

反向传播的推导

在正向传播的过程中,由于我们没有保存中间结果 S=softmax(QKd)\mathbf{S}=\text{softmax}(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}),所以,无法直接计算出 S\mathbf{S}.但是我们保存了中间结果 mij,ijm_{ij},\ell_{ij},我们可以利用这些来反过来计算出 Sij\mathbf{S}_{ij}