Sage Attention (v1, v2) 的实现主要是 csrc/ 和 sageattention/triton/ 目录下,其中 sageattention/ 目录下主要是 Triton 实现和 CUDA 实现的 wrapper.主入口是 sageattention/core.py 中的 sageattn() 函数,职责是根据 CUDA 和 GPU 架构进行分发.
我们先从 Triton 实现入手阅读 Sage Attention 的实现
Quant Kernel
主要关联文件:quant_per_block.py
Before Triton
函数签名是
1 | def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"): |
对于传入的 q,k,如果 km (mean of k) 不为空,那么会先执行如下 smooth- 过程.
1 | if km is not None: |
接着根据 tensor_layout 计算 stride 和 tensor dimension
1 | # For HND layout ([batch, heads, seq, dim]): |
然后分配 scale 张量和 int8 张量:
1 | q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32) |
最后 launch kernel
1 | grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b) # [num_blocks, num_heads, batch_size] |
- Grid is 3D:
(num_blocks, num_heads, batch_size) sm_scale * 1.44269504: Convert natural log to base- log- Attention kernel uses exp2(), so we need log2 scale
数据准备
先看函数签名:
1 |
|
这里的参数……
Input指 矩阵,输入数据为 FP16/BF16Output指输出的 INT8 张量Scale是 per-block 的 scale factorL是 actually sequence length,用于 maskingstride_*则表明了张量的底层数据分布,表示张量的形状sm_scale是 softmax scaleC是 head dimBLK表示量化的 block size
Kernel 的 issue grid 是 .
于是我们取出对应的 index;offs_n 计算当前线程对应的 block
1 | off_blk = tl.program_id(0) # Block index along sequence dimension |
- Input/Output pointers: 2D matrix
[BLK, C]for this block - Scale pointer: Single scalar per block (not per element!)
- 通过传入不同的 Strides 同时支持 both
HNDandNHDlayouts
接着 load block,传入一个 mask 保证对于任意长度的 attention 均有效.然后乘上 softmax scale
1 | x = tl.load(input_ptrs, mask=offs_n[:, None] < L) |
注意力量化
后面正式进入量化算法流程,按照论文的说法,先计算 scale
1 | scale = tl.max(tl.abs(x)) / 127. |
接着 quantize with rounding,并转化为 INT8
1 | x_int8 = x / scale |
最后保存结果,写入 pointers
1 | tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) |