The SageAttention (v1, v2) implementation lives mainly under csrc/ and sageattention/triton/, with sageattention/ holding the wrappers around the Triton and CUDA implementations. The main entry point is the sageattn() function in sageattention/core.py, whose job is to dispatch to a concrete kernel based on the CUDA version and GPU architecture.

Let's start reading the SageAttention implementation from the Triton side.

Quant Kernel

Relevant file: quant_per_block.py

Before Triton

The function signature is:

def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):

For the incoming q and k, if km (the mean of k) is not None, the following smooth-K step runs first.

if km is not None:
    k = k - km
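Why subtracting the mean is safe: for each query row, q · km adds the same constant to every logit, and softmax is invariant under a per-row constant shift, so smoothing K does not change the attention weights. A quick plain-PyTorch sketch (shapes are made up) to verify this:

```python
import torch

torch.manual_seed(0)
q = torch.randn(8, 16)            # [seq_q, head_dim]
k = torch.randn(32, 16)           # [seq_k, head_dim]
km = k.mean(dim=0, keepdim=True)  # per-channel mean of K

# q @ km.T is a constant per query row, and softmax is shift-invariant
# per row, so smoothing K leaves the attention weights unchanged
p_ref = torch.softmax(q @ k.T, dim=-1)
p_smooth = torch.softmax(q @ (k - km).T, dim=-1)
assert torch.allclose(p_ref, p_smooth, atol=1e-5)
```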

Next, the strides and tensor dimensions are computed according to tensor_layout:

# For HND layout ([batch, heads, seq, dim]):
b, h_qo, qo_len, head_dim = q.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
# For NHD layout ([batch, seq, heads, dim]):
b, qo_len, h_qo, head_dim = q.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
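To make the stride bookkeeping concrete, here is a small sketch (hypothetical shapes): q.stride() returns element strides per dimension, so picking the batch/head/seq strides from the right dimensions lets the same kernel walk either layout:

```python
import torch

b, h, s, d = 2, 4, 8, 16

q_hnd = torch.randn(b, h, s, d)  # HND: [batch, heads, seq, dim]
q_nhd = torch.randn(b, s, h, d)  # NHD: [batch, seq, heads, dim]

# HND: batch/head/seq strides come straight off dims 0, 1, 2
assert q_hnd.stride(0) == h * s * d
assert q_hnd.stride(1) == s * d
assert q_hnd.stride(2) == d

# NHD: the head stride comes from dim 2 and the seq stride from dim 1,
# so the kernel still sees the same (batch, head, seq) addressing scheme
assert q_nhd.stride(2) == d
assert q_nhd.stride(1) == h * d
```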

Then the scale tensors and INT8 tensors are allocated:

q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32)
k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32)
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)

Finally, the kernel is launched:

grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b)  # [num_blocks, num_heads, batch_size]
quant_per_block_int8_kernel[grid](
    q, q_int8, q_scale, qo_len,
    stride_bz_q, stride_h_q, stride_seq_q,     # Input strides
    stride_bz_qo, stride_h_qo, stride_seq_qo,  # Output strides
    q_scale.stride(0), q_scale.stride(1),      # Scale strides
    sm_scale=(sm_scale * 1.44269504),          # Convert ln to log2
    C=head_dim, BLK=BLKQ
)
  • Grid is 3D: (num_blocks, num_heads, batch_size)
  • sm_scale * 1.44269504: converts from natural log to base-2 log
    • The attention kernel uses exp2(), so it needs a log2 scale
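The constant 1.44269504 is log2(e): since exp(x) = 2^(x · log2 e), folding it into sm_scale lets the attention kernel use the cheaper exp2. A quick numeric check:

```python
import math

LOG2_E = 1.44269504  # the constant folded into sm_scale

# exp(x) == 2 ** (x * log2(e)), so pre-scaling the logits by LOG2_E
# lets the kernel replace exp() with exp2()
x = 0.73
assert math.isclose(math.exp(x), 2.0 ** (x * LOG2_E), rel_tol=1e-7)
```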

Data Preparation

First, the function signature:

@triton.jit
def quant_per_block_int8_kernel(Input, Output, Scale, L,
                                stride_iz, stride_ih, stride_in,
                                stride_oz, stride_oh, stride_on,
                                stride_sz, stride_sh,
                                sm_scale,
                                C: tl.constexpr, BLK: tl.constexpr):

The parameters here are:

  • Input: the Q or K matrix; the input data is FP16/BF16
  • Output: the output INT8 tensor
  • Scale: the per-block scale factors
  • L: the actual sequence length, used for masking
  • stride_*: the strides describing the tensors' underlying memory layout
  • sm_scale: the softmax scale
  • C: the head dim
  • BLK: the quantization block size

The kernel is launched on the grid $(N_{\text{blocks}}, N_{\text{heads}}, \text{size}_{\text{batch}})$.

So we read out the corresponding indices; offs_n computes the sequence offsets of the block this program handles:

off_blk = tl.program_id(0)  # Block index along sequence dimension
off_h = tl.program_id(1)    # Head index
off_b = tl.program_id(2)    # Batch index

offs_n = off_blk * BLK + tl.arange(0, BLK)
# [seq_start, seq_start+1, ..., seq_start+BLK-1]
offs_k = tl.arange(0, C)
# [0, 1, 2, ..., head_dim-1]

input_ptrs = (
    Input +
    off_b * stride_iz +
    off_h * stride_ih +
    offs_n[:, None] * stride_in +
    offs_k[None, :])
output_ptrs = (
    Output +
    off_b * stride_oz +
    off_h * stride_oh +
    offs_n[:, None] * stride_on +
    offs_k[None, :])
scale_ptrs = (
    Scale +
    off_b * stride_sz +
    off_h * stride_sh +
    off_blk)
  • Input/Output pointers: 2D matrix [BLK, C] for this block
  • Scale pointer: Single scalar per block (not per element!)
  • Passing in different strides lets the same kernel support both HND and NHD layouts
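The pointer arithmetic above is just flat indexing: base + b·stride_bz + h·stride_h + n·stride_seq + k. A sketch (plain PyTorch, hypothetical sizes) checking that this address computation recovers a [BLK, C] tile from the flattened tensor:

```python
import torch

b, h, s, d = 2, 4, 8, 16
BLK = 4
x = torch.randn(b, h, s, d)
flat = x.reshape(-1)  # simulate a raw pointer into contiguous memory

off_b, off_h, off_blk = 1, 2, 1
offs_n = off_blk * BLK + torch.arange(BLK)  # rows of this block
offs_k = torch.arange(d)                    # head-dim columns

# same address computation as the kernel's input_ptrs
idx = (off_b * x.stride(0) + off_h * x.stride(1)
       + offs_n[:, None] * x.stride(2) + offs_k[None, :])
tile = flat[idx]

assert torch.equal(tile, x[off_b, off_h, offs_n])
```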

Next, load the block, passing a mask so the kernel stays valid for arbitrary sequence lengths, then multiply by the softmax scale:

x = tl.load(input_ptrs, mask=offs_n[:, None] < L)

x = x.to(tl.float32)
x *= sm_scale

Attention Quantization

Now comes the quantization proper. Following the paper, the scale is computed first:

scale = tl.max(tl.abs(x)) / 127.

Then quantize with rounding and convert to INT8:

x_int8 = x / scale
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
x_int8 = x_int8.to(tl.int8)
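The += 0.5 * tl.where(...) line implements round-half-away-from-zero, since the subsequent int8 cast truncates toward zero. A plain-PyTorch sketch of the same per-block quantization step, including the error bound it guarantees:

```python
import torch

torch.manual_seed(0)
x = torch.randn(128, 64, dtype=torch.float32)  # one [BLK, C] block

# one scale per block, mapping the max magnitude onto 127
scale = x.abs().max() / 127.0

x_q = x / scale
x_q = x_q + 0.5 * torch.where(x_q >= 0, 1.0, -1.0)  # round half away from zero
x_int8 = x_q.to(torch.int8)  # the cast truncates toward zero

# dequantized values stay within half a quantization step of the input
err = (x_int8.to(torch.float32) * scale - x).abs().max()
assert err <= 0.5 * scale + 1e-6
```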

Finally, store the results through the output and scale pointers:

tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
tl.store(scale_ptrs, scale)
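Putting the pieces together, here is a vectorized plain-PyTorch equivalent of the whole per-block quantization path (HND layout only, seq assumed to be a multiple of the block size so no masking is needed; function and variable names are mine, not the library's):

```python
import torch

def per_block_int8_ref(x, blk=128, sm_scale=1.0):
    """Reference per-block INT8 quantization for an HND tensor
    [batch, heads, seq, dim]; assumes seq is a multiple of blk."""
    b, h, s, d = x.shape
    xb = (x.float() * sm_scale).reshape(b, h, s // blk, blk, d)
    scale = xb.abs().amax(dim=(-1, -2)) / 127.0       # [b, h, num_blocks]
    q = xb / scale[..., None, None]
    q = q + 0.5 * torch.where(q >= 0, 1.0, -1.0)      # round half away from zero
    return q.to(torch.int8).reshape(b, h, s, d), scale

x = torch.randn(1, 2, 256, 64, dtype=torch.float16)
x_int8, scale = per_block_int8_ref(x, blk=128)
assert x_int8.dtype == torch.int8
assert scale.shape == (1, 2, 2)  # one scale per 128-token block
```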