FP4 推理框架
Motivation
在 NVIDIA Tensor Core 上,FP4 的计算速度比 FP16 快得多.所以希望提出 FP4 推理即插即用的模块用 FP4 做进一步提速.
Challenges for FP4 Inference
FP4 的推理主要面临几个问题:
- FP4 量化的数值范围非常有限,只有 15 个数值.
这导致 per-tensor 和 per-token 的量化方式无法保证模型精度.
- Attention Map P=softmax(dQK⊤) 中的元素范围都较小,大多在 [0,1] 之间,且靠近 0.
如果直接用 FP4 进行量化,那么大部分元素都会直接变成 0.所以,更常见的是引入 scaling factor s,然后 P≈FP4(P/s)×s.
但这样也有问题:因为 P 的元素值都很小,s 通常就要取到 10−3 的量级,而在 NVIDIA Hopper/Blackwell 等架构里,原生支持的 quantization scaling factor 的数据类型必须为 FP8,导致 scaling factor 在储存时会产生舍入误差.
具体方案
FP4 Microscaling 量化
对于 X∈RN×d 的矩阵,我们将其切分成若干个 R1×n 的 block,每一个 block 记为 Xij,并且,一个 block 内的 n 个元素共享同一个 FP8 scaling factor sij.于是 FP4 的量化与反量化可以表示为
Quantization ϕ(⋅)Dequantization ϕ−1(⋅)sij=max(Xij)/6X^ij=⌈Xij/sij⌋Xij′=sijX^ij
这里 sij 本身也会构成一个矩阵。
FP4 Microscaling MatMul
实现一个新算子 FP4MM(A′,sA,B′,sB),其输出 C 等价于 ϕ−1(A′,sA) 与 ϕ−1(B′,sB) 之间的(相对)高精度的 MatMul.
Attention 计算
作者对 Attention 里的 QK⊤ 和 PV 做了 FP4MM.这里的 Attention 计算方式沿用了 Flash Attention 的计算过程(即 tiling + online softmax 那一套)
除此之外,为了进一步提高精度,对 Q,K 做了 smoothing 处理
数据类型选择
做实验发现以下配置得到的精度最高:
- NVFP4 (E2M1)
- n=16
- scaling factor 为 FP8 (E4M3)
Attention Map 的两阶段量化
实验发现,直接对 P~ 进行 NVFP4 量化的精度误差非常大,主要原因在于,NVFP4 原生实现中,要求 scaling factor 是 E4M3 FP8 而非 FP32 格式.
为了进一步研究反量化精度误差的来源,在研究 P~ 的数值分布后认为:由于 online softmax 计算出来的 P~ 的值在 [0,1] 之间,所以其 scaling factor sij=max(P~ij)/6 之数值范围通常落在 [0,1/6] 之间,导致 E4M3 FP8 并没有发挥出值域范围大的优势,也增加了 accuracy loss.
所以提出两阶段量化:先将 P~ij 的范围放缩到 P~ijq∈[0,448×6],再对 P~ijq 进行量化:
sPP~ijq(sPq,P^ijq)O=rowmax(P~ij)/(448×6)=P~ij/sP=ϕ(P~ijq)=FP4MM(P^ijq,sPq,V^,sV)×sP
其中
- P~ij,P~ijq,sP∈FP32
- sPq,sV∈FP8 E4M3
- P^q,V^ 则是 FP4 格式.
这一套下来,P~ij≈P^ijq×sPq×sP
Empirical Result: 这个两阶段量化可以充分利用 sP 的 E4M3 数值范围,进而减小 P~ 的量化误差和 sP 的数值表示误差
算法流程
【输入】
- Q,K,V∈FP16N×d
- 分块大小 Bq,Bkv
先仿照 Sage Attention 的做法,对 K 做 smoothing:K←K−mean(K)
然后,将 Q 切分为 Tm=N/Bq 块 {Qi},每一块 Qi 的形状为 FP16Bq×d;同理,将 K,V 也进行切块,切成 {Ki},{Vi},形状为 FP16Bkv×d,数量为 Tn=N/Bkv.
-
对于每一块 Qi,i∈[1,Tm]:
-
先进行 smoothing,然后直接 FP4 量化:qˉi=mean(Qi),(sQi,Q^i)=ϕ(Qi−qˉi).这里的 qˉi∈FP16
-
接着,遍历 Kj,Vj,j∈[1,Tn].
这一层循环里,我们对 Qi 计算 Attention Map P,并计算 partial output O
-
对 Kj,Vj 进行 FP4 量化:(sKj,Kj^)=ϕ(Kj),(sVj,V^j)=ϕ(Vj)
-
计算 Sij=QiKj⊤.
这里,因为我们之前其实把 Qi 拆成了
Qi=(Qi−qˉi)+qˉi
所以
Sij=QiKj⊤=(Qi−qˉi)Kj⊤+qˉiKj⊤
因此,这里我们需要同理使用 FP4MM 和 GEMV(本质是标量乘矩阵):
Sij=FP4MM(Q^i,sQi,K^j,sKj)+GEMV(qˉi,Kj⊤)
-
然后,我们使用 Online Attention 的方法,在线计算 Sij rowmax 和 ℓij=∑exp(⋅):
mi,jP~ijℓij=max(mi,j−1,rowmax(Sij))=exp(Sij−mi,j)=exp(mi,j−1−mij)⋅ℓi,j−1+rowsum(P~ij)
CUDA Kernel 的实现优化
INT8 训练框架
Challenges for INT8 Training
对于 INT8 训练来说,其挑战在于:
- Attention Map 的梯度很容易受量化误差的影响,导致在计算 input 的梯度时产生累加误差.