FP4 推理框架

Motivation

在 NVIDIA Tensor Core 上,FP4 的计算速度比 FP16 快得多.所以希望提出 FP4 推理即插即用的模块用 FP4 做进一步提速.

Challenges for FP4 Inference

FP4 的推理主要面临几个问题:

  1. FP4 量化的数值范围非常有限,只有 1515 个数值.

    这导致 per-tensor 和 per-token 的量化方式无法保证模型精度.

  2. Attention Map P=softmax(QKd)P=\text{softmax}(\frac{QK^\top}{\sqrt{d}}) 中的元素范围都较小,大多在 [0,1][0,1] 之间,且靠近 00

    如果直接用 FP4 进行量化,那么大部分元素都会直接变成 00.所以,更常见的是引入 scaling factor ss,然后 PFP4(P/s)×sP\approx \texttt{FP4}(P/s)\times s

    但这样也有问题:因为 PP 的元素值都很小,ss 通常就要取到 10310^{-3} 的量级,而在 NVIDIA Hopper/Blackwell 等架构里,原生支持的 quantization scaling factor 的数据类型必须为 FP8,导致 scaling factor 在储存时会产生舍入误差.

具体方案

FP4 Microscaling 量化

对于 XRN×dX\in\mathbb{R}^{N\times d} 的矩阵,我们将其切分成若干个 R1×n\mathbb{R}^{1\times n} 的 block,每一个 block 记为 XijX_{ij},并且,一个 block 内的 nn 个元素共享同一个 FP8 scaling factor sijs_{ij}.于是 FP4 的量化与反量化可以表示为

Quantization ϕ()sij=max(Xij)/6X^ij=Xij/sijDequantization ϕ1()Xij=sijX^ij\begin{array}{r|l} \text{Quantization }\phi(\cdot) & s_{ij}=\max(X_{ij})/6\\ &\hat X_{ij}=\lceil X_{ij}/s_{ij} \rfloor \\ \hline \text{Dequantization }\phi ^{-1}(\cdot) & X'_{ij}=s_{ij}\hat{X}_{ij} \end{array}

这里 sijs_{ij} 本身也会构成一个矩阵。

FP4 Microscaling MatMul

实现一个新算子 FP4MM(A,sA,B,sB)\texttt{FP4MM}(A',s_A,B',s_B),其输出 CC 等价于 ϕ1(A,sA)\phi^{-1}(A',s_A)ϕ1(B,sB)\phi^{-1}(B',s_B) 之间的(相对)高精度的 MatMul.

Attention 计算

作者对 Attention 里的 QK\mathbf{QK}^\topPV\mathbf{PV} 做了 FP4MM\tt{FP4MM}.这里的 Attention 计算方式沿用了 Flash Attention 的计算过程(即 tiling + online softmax 那一套)

除此之外,为了进一步提高精度,对 Q,K\mathbf{Q},\mathbf{K} 做了 smoothing 处理

数据类型选择

做实验发现以下配置得到的精度最高:

  • NVFP4 (E2M1)
  • n=16n=16
  • scaling factor 为 FP8 (E4M3)

Attention Map 的两阶段量化

实验发现,直接对 P~\tilde{\mathbf{P}} 进行 NVFP4 量化的精度误差非常大,主要原因在于,NVFP4 原生实现中,要求 scaling factor 是 E4M3 FP8 而非 FP32 格式.

为了进一步研究反量化精度误差的来源,在研究 P~\tilde{\mathbf{P}} 的数值分布后认为:由于 online softmax 计算出来的 P~\tilde{\mathbf{P}} 的值在 [0,1][0,1] 之间,所以其 scaling factor sij=max(P~ij)/6s_{ij}=\max(\tilde{\mathbf{P}}_{ij})/6 之数值范围通常落在 [0,1/6][0,1/6] 之间,导致 E4M3 FP8 并没有发挥出值域范围大的优势,也增加了 accuracy loss.

所以提出两阶段量化:先将 P~ij\tilde{\mathbf{P}}_{ij} 的范围放缩到 P~ijq[0,448×6]\tilde{\mathbf{P}}^q_{ij}\in [0,448\times 6],再对 P~ijq\tilde{\mathbf{P}}^q_{ij} 进行量化:

sP=rowmax(P~ij)/(448×6)P~ijq=P~ij/sP(sPq,P^ijq)=ϕ(P~ijq)O=FP4MM(P^ijq,sPq,V^,sV)×sP\begin{aligned} \mathbf{s}_{\mathbf{P}}&=\texttt{rowmax}(\tilde{\mathbf{P}}_{ij})/(448\times 6) \\ \tilde{\mathbf{P}}^q_{ij}&=\tilde{\mathbf{P}}_{ij} / \mathbf{s}_{\mathbf{P}} \\ (\mathbf{s}_{\mathbf{P}^q}, \hat{\mathbf{P}}^q_{ij}) &= \phi(\tilde{\mathbf{P}}^q_{ij}) \\ \mathbf{O} &= \texttt{FP4MM}(\hat{\mathbf{P}}^q_{ij}, \mathbf{s}_{\mathbf{P}^q},\hat{\mathbf{V}}, \mathbf{s}_{\mathbf{V}}) \times \mathbf{s}_{\mathbf{P}} \end{aligned}

其中

  • P~ij,P~ijq,sPFP32\tilde{\mathbf{P}}_{ij},\tilde{\mathbf{P}}^q_{ij},\mathbf{s}_{\mathbf{P}}\in \texttt{FP32}
  • sPq,sVFP8 E4M3\mathbf{s}_{\mathbf{P}^q},\mathbf{s}_{\mathbf{V}}\in\texttt{FP8 E4M3}
  • P^q,V^\hat{\mathbf{P}}^q,\hat{\mathbf{V}} 则是 FP4\tt FP4 格式.

这一套下来,P~ijP^ijq×sPq×sP\tilde{\mathbf{P}}_{ij}\approx\hat{\mathbf{P}}_{ij}^q\times \mathbf{s}_{\mathbf{P}^q}\times \mathbf{s}_{\mathbf{P}}

Empirical Result: 这个两阶段量化可以充分利用 sP\mathbf{s}_{\mathbf{P}} 的 E4M3 数值范围,进而减小 P~\tilde{\mathbf{P}} 的量化误差和 sP\mathbf{s}_{\mathbf{P}} 的数值表示误差

算法流程

【输入】

  • Q,K,VFP16N×dQ,K,V \in \texttt{FP16}^{N\times d}
  • 分块大小 Bq,BkvB_q,B_{kv}

先仿照 Sage Attention 的做法,对 KK 做 smoothing:KKmean(K)K\gets K-\text{mean}(K)

然后,将 QQ 切分为 Tm=N/BqT_m=N/B_q{Qi}\{\mathbf{Q}_i\},每一块 Qi\mathbf{Q}_i 的形状为 FP16Bq×d\texttt{FP16}^{B_q\times d};同理,将 K,VK,V 也进行切块,切成 {Ki},{Vi}\{\mathbf{K_i}\},\{\mathbf{V_i}\},形状为 FP16Bkv×d\texttt{FP16}^{B_{kv}\times d},数量为 Tn=N/BkvT_n=N/B_{kv}

  1. 对于每一块 Qi,i[1,Tm]\mathbf{Q}_i, i \in [1,T_m]

    1. 先进行 smoothing,然后直接 FP4 量化:qˉi=mean(Qi),(sQi,Q^i)=ϕ(Qiqˉi)\bar q_i=\text{mean}(\mathbf{Q}_i), (s_{\mathbf{Q}_i},\hat{\mathbf{Q}}_i)=\phi(\mathbf{Q}_i-\bar q_i).这里的 qˉiFP16\bar q_i\in \texttt{FP16}

    2. 接着,遍历 Kj,Vj,j[1,Tn]\mathbf{K}_j,\mathbf{V}_j,j\in[1,T_n]

      这一层循环里,我们对 Qi\mathbf{Q}_i 计算 Attention Map P\mathbf{P},并计算 partial output O\mathbf O

      1. Kj,Vj\mathbf{K}_j,\mathbf{V}_j 进行 FP4 量化:(sKj,Kj^)=ϕ(Kj),(sVj,V^j)=ϕ(Vj)(s_{\mathbf{K}_j},\hat{\mathbf{K}_j})=\phi(\mathbf{K}_j),(s_{\mathbf{V}_j},\hat{\mathbf{V}}_j)=\phi(\mathbf{V}_j)

      2. 计算 Sij=QiKj\mathbf{S}_{ij}=\mathbf{Q}_i\mathbf{K}_j^\top

        这里,因为我们之前其实把 Qi\mathbf{Q}_i 拆成了

        Qi=(Qiqˉi)+qˉi\mathbf{Q}_i=(\mathbf{Q}_i-\bar q_i)+\bar q_i

        所以

        Sij=QiKj=(Qiqˉi)Kj+qˉiKj\mathbf{S}_{ij}=\mathbf{Q}_i\mathbf{K}_j^\top=(\mathbf{Q}_i-\bar q_i)\mathbf{K}_j^\top + \bar q_i \mathbf{K}_j^\top

        因此,这里我们需要同理使用 FP4MM\tt FP4MMGEMV\tt GEMV(本质是标量乘矩阵):

        Sij=FP4MM(Q^i,sQi,K^j,sKj)+GEMV(qˉi,Kj)\mathbf{S}_{ij}=\texttt{FP4MM}(\hat{\mathbf{Q}}_i,s_{\mathbf{Q}_i},\hat{\mathbf{K}}_j,s_{\mathbf{K}_j})+\texttt{GEMV}(\bar q_i,\mathbf{K}_j^\top)

      3. 然后,我们使用 Online Attention 的方法,在线计算 Sij\mathbf{S}_{ij} rowmax 和 ij=exp()\ell_{ij}=\sum \exp(\cdot)

      mi,j=max(mi,j1,rowmax(Sij))P~ij=exp(Sijmi,j)ij=exp(mi,j1mij)i,j1+rowsum(P~ij)\begin{aligned}m_{i,j}&=\max(m_{i,j-1}, \texttt{rowmax}(\mathbf{S}_{ij})) \\ \tilde{\mathbf{P}}_{ij}&=\exp(\mathbf{S}_{ij}-m_{i,j}) \\ \ell_{ij}&= \exp(m_{i,j-1}-m_{ij})\cdot\ell_{i,j-1} + \texttt{rowsum}(\tilde{\mathbf{P}}_{ij})\end{aligned}

CUDA Kernel 的实现优化


INT8 训练框架

Challenges for INT8 Training

对于 INT8 训练来说,其挑战在于:

  1. Attention Map 的梯度很容易受量化误差的影响,导致在计算 input 的梯度时产生累加误差.