Query, Key, Value
Q K V \tt QKV QKV 可以类比为“查字典”的过程,对于要查找的单词 Q \tt Q Q ,在字典 D = { ( K , V ) } \mathcal{D}=\{(\mathtt{K,V})\} D = {( K , V )} 里查找有关联的条目,K \tt K K 就是字典收录的单词,V \tt V V 就是这个收录的单词的意义。
那么我们怎么通过 ( K , V ) \tt (K,V) ( K , V ) 来确定要查找的单词 Q \tt Q Q 是什么含义呢?现实世界里,假如我们知道 evaluation 的意思,那么我们可以大致猜测出来 evaluate 的意思,因为他们长得比较像 。
类比地,在 LLM 的世界里,“长得像”是通过计算 Score 得到的。我们定义 a ( q , k i ) a(\mathbf{q}, \mathbf{k}_i) a ( q , k i ) 表示查询单词 q \bf q q 和字典单词 k \bf k k 的“相似程度”,有了相似程度,我们就可以乘上意义 v \bf v v 来猜测 q \bf q q 这个单词的含义了。
所以,我们正式给出 Attention 的定义
Attention Pooling over Dictionary
考虑字典 D = { ( K , V ) } \mathcal{D}=\{(\mathtt{K,V})\} D = {( K , V )} ,对于每个查询 q \tt q q ,其在字典 D \cal D D 上的 Attention 定义为
A t t e n t i o n ( q , D ) : = ∑ i = 1 ∣ D ∣ a ( q , k i ) v i
\mathtt{Attention}(\mathbf{q}, \mathcal{D}):=\sum_{i=1}^{|\mathcal D|}a(\mathbf{q},\mathbf k_i)\mathbf v_i
Attention ( q , D ) := i = 1 ∑ ∣ D ∣ a ( q , k i ) v i
这里的 a ( q , k i ) a(\mathbf q,\mathbf k_i) a ( q , k i ) 也可以理解为概率 p ( k i ∣ q ) p(\mathbf k_i|\mathbf q) p ( k i ∣ q ) ,即给定查询 p \bf p p 的情况下,k i \mathbf k_i k i 有 p ( k i ∣ q ) p(\mathbf k_i|\mathbf q) p ( k i ∣ q ) 的可能性,其含义(v i \mathbf v_i v i )与 q \bf q q 相同。
既然是概率,我们肯定是希望 ∑ i = 1 ∣ D ∣ a ( q , k i ) = 1 \sum_{i=1}^{|\mathcal{D}|} a(\bold{q}, \bold{k}_i)=1 ∑ i = 1 ∣ D ∣ a ( q , k i ) = 1 . 通常,为了确保,我们可以使用 softmax 操作
a ( q , k i ) ← exp ( a ( q , k i ) ) ∑ j = 1 ∣ D ∣ exp ( a ( q , k j ) )
a(\bold{q}, \bold{k}_i) \gets \frac{\exp\Big(a(\bold{q}, \bold{k}_i)\Big)}{\sum_{j=1}^{|\mathcal{D}|} \exp\Big(a(\bold{q}, \bold{k}_j)\Big)}
a ( q , k i ) ← ∑ j = 1 ∣ D ∣ exp ( a ( q , k j ) ) exp ( a ( q , k i ) )
softmax 的好处有很多。首先,softmax 之后,所有元素都在 ( 0 , 1 ) (0,1) ( 0 , 1 ) 之间;其次,所有元素的和为 1 1 1 ;第三,softmax 可导且梯度不会消失。
Attention Scoring Function
Scaled Dot Product Attention
在 Attention 里,我们先来分析一下 a ( ⋅ , ⋅ ) a(\cdot, \cdot) a ( ⋅ , ⋅ )
a ( q , k i ) = − 1 2 ∥ q − k i ∥ 2 = q ⊤ k i − 1 2 ∥ k i ∥ 2 − 1 2 ∥ q ∥ 2
a(\mathbf{q}, \mathbf{k}_i) = -\frac{1}{2} \|\mathbf{q} - \mathbf{k}_i\|^2 = \mathbf{q}^\top \mathbf{k}_i -\frac{1}{2} \|\mathbf{k}_i\|^2 -\frac{1}{2} \|\mathbf{q}\|^2
a ( q , k i ) = − 2 1 ∥ q − k i ∥ 2 = q ⊤ k i − 2 1 ∥ k i ∥ 2 − 2 1 ∥ q ∥ 2 a ( ) a() a ( ) 中 − 1 2 ∥ q ∥ 2 -\frac{1}{2}\|\bold{q}\|^2 − 2 1 ∥ q ∥ 2 在所有 a ( q , k j ) a(\bold{q}, \bold{k}_j) a ( q , k j ) 都是一样的,所以在 softmax 之后:
a ( q , k i ) = exp ( … i ) × exp ( − 1 2 ∥ q ∥ 2 ) ∑ j = 1 ∣ D ∣ exp ( … j ) × exp ( − 1 2 ∥ q ∥ 2 )
a(\bold{q}, \bold{k}_i)=\frac{\exp(\dots_i)\times \cancel{\exp(-\frac{1}{2}\|\bold{q}\|^2)}}{\sum_{j=1}^{|\mathcal{D}|} \exp(\dots_j)\times \cancel{\exp(-\frac{1}{2}\|\bold{q}\|^2)}}
a ( q , k i ) = ∑ j = 1 ∣ D ∣ exp ( … j ) × exp ( − 2 1 ∥ q ∥ 2 ) exp ( … i ) × exp ( − 2 1 ∥ q ∥ 2 ) 会被约掉。此外,由于我们在模型中会作 batch and layer norm,所以,我们可以近似地认为 ∥ k i ∥ \|\bold{k}_i\| ∥ k i ∥ 也是常数,和 q \bold{q} q 类似的理由,我们可以把 − 1 2 ∥ k i ∥ 2 -\frac{1}{2}\|\bold{k}_i\|^2 − 2 1 ∥ k i ∥ 2 也省去.
最后,我们希望把 a ( ) a() a ( ) 控制在一定的 scale 内 . 这样,在 exp ( a ( ) ) \exp(a()) exp ( a ( )) 的时候就不会数值爆炸了. 由于 layer norm 和 batch norm,我们可以假设 q , k i ∈ R d \bold{q}, \bold{k_i}\in \R^d q , k i ∈ R d 并且 zero mean and unit variance. 即
E [ q j ] = 0 , Var ( q j ) = 1 E [ k i j ] = 0 , Var ( k i j ) = 1
\mathbb{E}[\bold{q}_{j}]=0, \text{Var}(\bold{q}_{j})=1\\
\mathbb{E}[\bold{k}_{ij}]=0, \text{Var}(\bold{k}_{ij})=1
E [ q j ] = 0 , Var ( q j ) = 1 E [ k ij ] = 0 , Var ( k ij ) = 1 所以对于其点积 q ⋅ k i = ∑ j = 1 d q j k i j \bold{q}\cdot \bold{k}_{i}=\sum_{j=1}^d \bold{q}_{j}\bold{k}_{ij} q ⋅ k i = ∑ j = 1 d q j k ij 有
E [ q ⋅ k i ] = E [ ∑ j = 1 d q j k i j ] = ∑ j = 1 d E [ q j k i j ] = ∑ j = 1 d E [ q j ] E [ k i j ] = 0 Var [ q ⋅ k i ] = Var [ ∑ j = 1 d q j k i j ] = ∑ j = 1 d Var [ q j k i j ] = ∑ j = 1 d Var [ q j ] Var [ k i j ] = d
\begin{aligned}
\mathbb{E}[\bold{q}\cdot \bold{k}_{i}]&=\mathbb{E}\Big[\sum_{j=1}^d \bold{q}_{j}\bold{k}_{ij}\Big]\\
&=\sum_{j=1}^d \mathbb{E}[\bold{q}_{j}\bold{k}_{ij}]\\
&=\sum_{j=1}^d \mathbb{E}[\bold{q}_{j}]\mathbb{E}[\bold{k}_{ij}]\\
&=0
\end{aligned}\\
\begin{aligned}
\text{Var}[\bold{q}\cdot \bold{k}_{i}]&=\text{Var}\Big[\sum_{j=1}^d \bold{q}_{j}\bold{k}_{ij}\Big]\\
&=\sum_{j=1}^d \text{Var}[\bold{q}_{j}\bold{k}_{ij}]\\
&=\sum_{j=1}^d \text{Var}[\bold{q}_{j}]\text{Var}[\bold{k}_{ij}]\\
&=d
\end{aligned}\\
E [ q ⋅ k i ] = E [ j = 1 ∑ d q j k ij ] = j = 1 ∑ d E [ q j k ij ] = j = 1 ∑ d E [ q j ] E [ k ij ] = 0 Var [ q ⋅ k i ] = Var [ j = 1 ∑ d q j k ij ] = j = 1 ∑ d Var [ q j k ij ] = j = 1 ∑ d Var [ q j ] Var [ k ij ] = d
由于 q j , k i j \bold{q}_j,\bold{k}_{ij} q j , k ij 是互相独立的,所以 E [ ] , Var [ ] \mathbb{E}[],\text{Var}[] E [ ] , Var [ ] 可以直接线性相加 & 拆开为乘积
而我们希望,点积和 q , k i \bold{q},\bold{k}_i q , k i 一样,都是 zero mean and unit variance. 所以我们在额外给点积乘上系数 1 d \frac{1}{\sqrt{d}} d 1 ,这样,其方差就变为 1 1 1 了。所以,我们最后的 Scaled Dot Product Attention 就是
Attn ( q , k i ) = exp ( q ⊤ k i d ) ∑ j = 1 ∣ D ∣ exp ( q ⊤ k j d )
\text{Attn}(\bold{q}, \bold{k}_i)=\frac{\exp\Big( \frac{\bold{q}^\top \bold{k}_i}{\sqrt{d}} \Big)}{\sum\limits_{j=1}^{|\mathcal{D}|} \exp\Big( \frac{\bold{q}^\top \bold{k}_j}{\sqrt{d}} \Big)}
Attn ( q , k i ) = j = 1 ∑ ∣ D ∣ exp ( d q ⊤ k j ) exp ( d q ⊤ k i )
Masked Softmax
输入一个三维 tensor 和 valid_lens(一维或二维),使得一些元素在 softmax 操作之后其权重变为 0 0 0 .
要让 softmax 后的权重变为 0 0 0 ,我们只需要让原来的元素设定为一个足够小的值即可(我们这里选择 − 10 6 -10^6 − 1 0 6 )
以下假设 input tensor 的形状为 ( h , w , d ) (h,w,d) ( h , w , d ) ,我们的步骤就是
先把 input tensor 变形为 ( h w , d ) (hw,d) ( h w , d ) ,这样我们只需要 ( h w , ) (hw,) ( h w , ) 的向量用来限制每一行保留几个数.
使用 X[~mask] = value 设置那些要屏蔽的位置为特定值
对 X 使用 softmax ( ) \texttt{softmax}() softmax ( ) 最后重新放缩回 ( h , w , d ) (h,w,d) ( h , w , d )
记 v = valid_lens \bold{v}=\texttt{valid\_lens} v = valid_lens ,所以我们的重心在于把 v \bold{v} v 计算为 ( h w , ) (hw,) ( h w , ) 的向量:
如果 v \bold{v} v 为一维,那么其形状为 ( h , ) (h,) ( h , ) ,我们直接把 v \bold{v} v 扩展到 v ′ : ( h w , ) \bold{v}':(hw,) v ′ : ( h w , ) .
否而直接 rearrange 为 v ′ : ( h w , ) \bold{v}':(hw,) v ′ : ( h w , )
那么怎么计算出 m a s k mask ma s k 呢?我们构造出一个 ( h w , d ) (hw,d) ( h w , d ) 的 arange 矩阵 m a s k mask ma s k 满足 m a s k [ … , i ] = i − 1 mask[\dots, i]=i-1 ma s k [ … , i ] = i − 1 ,所以根据 PyTorch 的内置比较运算符,只需要 m a s k < v ′ mask\lt \bold{v}' ma s k < v ′ 就可以了
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 import torchimport torch.nn.functional as Ffrom einops import rearrange, repeat, reducefrom d2l import torch as d2lfrom typing_extensions import Optional def masked_softmax ( x: torch.Tensor, remain: Optional [torch.Tensor], ) -> torch.Tensor: if remain is None : return F.softmax(x, dim=-1 ) else : shape = x.shape value = -1e6 remain = ( repeat(remain, "h -> (h w)" , b=remain.shape[0 ], w=shape[1 ]) if remain.dim() == 1 else rearrange(remain, "h w -> (h w)" ) ) x = rearrange(x, "h w d -> (h w) d" ) assert x.shape[0 ] == remain.shape[0 ] print (f"{x=} , {x.shape=} " ) word_len = x.shape[-1 ] mask = repeat(torch.arange(word_len), "d -> w d" , w=x.shape[0 ]) mask = mask < repeat(remain, "n -> n d" , d=word_len) x[~mask] = value return rearrange( F.softmax(x, dim=-1 ), "(h w) d -> h w d" , h=shape[0 ], w=shape[1 ], )
Batched Matrix Multiplication
BMM ( Q , K ) \texttt{BMM}(\bold{Q},\bold{K}) BMM ( Q , K ) 的作用是将
Q ∈ R n × a × b = [ Q 1 , Q 2 , … , Q n ] K ∈ R n × b × c = [ K 1 , K 2 , … , K n ]
\bold{Q}\in \R^{n\times a\times b}=[\bold{Q}_1,\bold{Q}_2,\dots, \bold{Q}_n]\\
\bold{K}\in \R^{n\times b\times c}=[\bold{K}_1,\bold{K}_2,\dots, \bold{K}_n]
Q ∈ R n × a × b = [ Q 1 , Q 2 , … , Q n ] K ∈ R n × b × c = [ K 1 , K 2 , … , K n ] 计算得出
[ Q 1 K 1 , Q 2 K 2 , … , Q n K n ] ∈ R n × a × c
[\bold{Q}_1\bold{K}_1,\bold{Q}_2\bold{K}_2, \dots, \bold{Q}_n\bold{K}_n]\in \R^{n\times a\times c}
[ Q 1 K 1 , Q 2 K 2 , … , Q n K n ] ∈ R n × a × c API 是 torch.bmm(Q, K).
Additive Attention
Additive Attention 的优势是计算量更小:
Attn ( q , k ) = v ⊤ tanh ( W q q + W k k ) ∈ R
\text{Attn}(\bold{q}, \bold{k})=\bold{v}^\top \tanh\Big( \bold{W}_q\bold{q}+\bold{W}_k\bold{k} \Big) \in \R
Attn ( q , k ) = v ⊤ tanh ( W q q + W k k ) ∈ R 其中 q ∈ R q , k ∈ R k , v ∈ R h \bold{q}\in\R^q, \bold{k}\in\R^k, \bold{v}\in\R^h q ∈ R q , k ∈ R k , v ∈ R h 且 W q ∈ R h × q , W k ∈ R h × k \bold{W}_q\in\R^{h\times q}, \bold{W}_k\in\R^{h\times k} W q ∈ R h × q , W k ∈ R h × k .