SFT

这篇文章的核心在于提出微调方法,使得可以针对不同的下游任务,用较少的参数量完成高效微调. 在不同下游任务上进行微调,可以看作是一次训练,给定数据集 Z={(xi,yi)}\mathcal{Z}=\{ (x_i, y_i) \},其中 xi,yix_i, y_i 都是 a sequence of tokens. 模型的 pretrained parameter 是 Φ0\Phi_0,目标是训练 Φ=Φ0+ΔΦ\Phi=\Phi_0+\Delta\Phi:

maxΦ(x,y)Zt=1ylog(Pϕ(ytx,y<t))\max_\Phi \sum_{(x,y)\in\mathcal{Z}}\sum_{t=1}^{|y|} \log\Big( P_\phi(y_t|x,y_{\lt t}) \Big)

这个 log\log stuff 就是经典 LLM 训练目标函数

对于不同的下游任务,其领域数据集 Z\mathcal{Z} 不尽相同,导致微调所学习的 ΔΦ\Delta\Phi 也不尽相同. 此外如果是传统的微调的话,参数规模 ΔΦ=Φ0|\Delta\Phi|=|\Phi_0|,相当于重新训练.

LoRA

Aghajanyan 的文章显示 pretrained LLM 的权重矩阵有 low instrisic dimension 的特点. 于是提出猜想:权重的更新也具有 low rank 的特点. 即

ΔW=BARd×k,BRd×r,ARr×k\Delta W=BA \in\mathbb{R}^{d\times k}, B\in\mathbb{R}^{d\times r}, A\in\mathbb{R}^{r\times k}

其中 rmin(d,k)r\ll \min(d,k). 记输入为 xx,于是 forward 的路径就是

activation=Wx=W0x+BAx\text{activation}=Wx=W_0x+BAx

在微调时,我们固定 pretrained param W0W_0 不动,而只训练 A,BA,B 两个矩阵. 对 AA 进行 random Gaussian 初始化,BB 初始化为 00,并且 scale ΔW\Delta W by α/r\alpha/r 其中 α\alpha 是一个固定的超参.

LoRA does not Introduce Inference Latency

针对不同任务的微调完成后,我们得到 mapping taski(Ai,Bi)\text{task}_i\mapsto (A_i,B_i). 在推理之前,我们直接把 BiAiB_iA_i 算出来加到 W0W_0 上,即

WW+BiAiW'\gets W+B_iA_i

W,WW',W 都是 d×kd\times k 的矩阵,因此没有 inference latency. 而在不同任务之间切换也很简单,假设从 ii 切换到 jj,只需要先减去 BiAiB_iA_i 再把 BjAjB_jA_j 加回来即可

WWBiAi+BjAjW'\gets W-B_iA_i+B_jA_j

Low-Rank Updates

论文进一步探讨了 properties of low-rank adaptation learnt from downstream tasks. 实验在 GPT-3 175B 上进行

这一部分主要回答三个问题:

  1. Transformer 架构里,什么样的 weight matrix 最值得 FT?
  2. 最优的微调更新 ΔW\Delta W 是否真的具有 low rank 的性质?是的话,最佳的 rank 是多少?
  3. ΔW\Delta WWW 的相关性研究

Weight Matrix Choices

  • 如果单纯微调 Wq,WkW_q, W_k 会导致性能下降
  • 如果联合微调 Wq,WvW_q, W_v 性能是最好的

并且,对于 Wq,WvW_q, W_v 联合微调的情况下,r=4r=4 的效果也很好了,表明低秩也可以学习到足够的信息

Optimal Rank

  • r=4r=4 already gives satisfying boost, and increasing rr does not help more.

Relationship between Updates and Pretrained Param

  • ΔW\Delta W amplifies some features that already in but not emphasized in WW.