Query, Key, Value

QKV\tt QKV 可以类比为“查字典”的过程,对于要查找的单词 Q\tt Q,在字典 D={(K,V)}\mathcal{D}=\{(\mathtt{K,V})\} 里查找有关联的条目,K\tt K 就是字典收录的单词,V\tt V 就是这个收录的单词的意义。

那么我们怎么通过 (K,V)\tt (K,V) 来确定要查找的单词 Q\tt Q 是什么含义呢?现实世界里,假如我们知道 evaluation 的意思,那么我们可以大致猜测出来 evaluate 的意思,因为他们长得比较像

类比地,在 LLM 的世界里,“长得像”是通过计算 Score 得到的。我们定义 a(q,ki)a(\mathbf{q}, \mathbf{k}_i) 表示查询单词 q\bf q 和字典单词 k\bf k 的“相似程度”,有了相似程度,我们就可以乘上意义 v\bf v 来猜测 q\bf q 这个单词的含义了。

所以,我们正式给出 Attention 的定义

Attention Pooling over Dictionary

考虑字典 D={(K,V)}\mathcal{D}=\{(\mathtt{K,V})\},对于每个查询 q\tt q,其在字典 D\cal D 上的 Attention 定义为

Attention(q,D):=i=1Da(q,ki)vi \mathtt{Attention}(\mathbf{q}, \mathcal{D}):=\sum_{i=1}^{|\mathcal D|}a(\mathbf{q},\mathbf k_i)\mathbf v_i

这里的 a(q,ki)a(\mathbf q,\mathbf k_i) 也可以理解为概率 p(kiq)p(\mathbf k_i|\mathbf q),即给定查询 p\bf p 的情况下,ki\mathbf k_ip(kiq)p(\mathbf k_i|\mathbf q) 的可能性,其含义(vi\mathbf v_i)与 q\bf q 相同。

既然是概率,我们肯定是希望 i=1Da(q,ki)=1\sum_{i=1}^{|\mathcal{D}|} a(\bold{q}, \bold{k}_i)=1. 通常,为了确保,我们可以使用 softmax 操作

a(q,ki)exp(a(q,ki))j=1Dexp(a(q,kj)) a(\bold{q}, \bold{k}_i) \gets \frac{\exp\Big(a(\bold{q}, \bold{k}_i)\Big)}{\sum_{j=1}^{|\mathcal{D}|} \exp\Big(a(\bold{q}, \bold{k}_j)\Big)}

softmax 的好处有很多。首先,softmax 之后,所有元素都在 (0,1)(0,1) 之间;其次,所有元素的和为 11;第三,softmax 可导且梯度不会消失。

Attention Scoring Function

Scaled Dot Product Attention

在 Attention 里,我们先来分析一下 a(,)a(\cdot, \cdot)

a(q,ki)=12qki2=qki12ki212q2 a(\mathbf{q}, \mathbf{k}_i) = -\frac{1}{2} \|\mathbf{q} - \mathbf{k}_i\|^2 = \mathbf{q}^\top \mathbf{k}_i -\frac{1}{2} \|\mathbf{k}_i\|^2 -\frac{1}{2} \|\mathbf{q}\|^2

a()a()12q2-\frac{1}{2}\|\bold{q}\|^2 在所有 a(q,kj)a(\bold{q}, \bold{k}_j) 都是一样的,所以在 softmax 之后:

a(q,ki)=exp(i)×exp(12q2)j=1Dexp(j)×exp(12q2) a(\bold{q}, \bold{k}_i)=\frac{\exp(\dots_i)\times \cancel{\exp(-\frac{1}{2}\|\bold{q}\|^2)}}{\sum_{j=1}^{|\mathcal{D}|} \exp(\dots_j)\times \cancel{\exp(-\frac{1}{2}\|\bold{q}\|^2)}}

会被约掉。此外,由于我们在模型中会作 batch and layer norm,所以,我们可以近似地认为 ki\|\bold{k}_i\| 也是常数,和 q\bold{q} 类似的理由,我们可以把 12ki2-\frac{1}{2}\|\bold{k}_i\|^2 也省去.

最后,我们希望a()a() 控制在一定的 scale 内. 这样,在 exp(a())\exp(a()) 的时候就不会数值爆炸了. 由于 layer norm 和 batch norm,我们可以假设 q,kiRd\bold{q}, \bold{k_i}\in \R^d 并且 zero mean and unit variance. 即

E[qj]=0,Var(qj)=1E[kij]=0,Var(kij)=1 \mathbb{E}[\bold{q}_{j}]=0, \text{Var}(\bold{q}_{j})=1\\ \mathbb{E}[\bold{k}_{ij}]=0, \text{Var}(\bold{k}_{ij})=1

所以对于其点积 qki=j=1dqjkij\bold{q}\cdot \bold{k}_{i}=\sum_{j=1}^d \bold{q}_{j}\bold{k}_{ij}

E[qki]=E[j=1dqjkij]=j=1dE[qjkij]=j=1dE[qj]E[kij]=0Var[qki]=Var[j=1dqjkij]=j=1dVar[qjkij]=j=1dVar[qj]Var[kij]=d \begin{aligned} \mathbb{E}[\bold{q}\cdot \bold{k}_{i}]&=\mathbb{E}\Big[\sum_{j=1}^d \bold{q}_{j}\bold{k}_{ij}\Big]\\ &=\sum_{j=1}^d \mathbb{E}[\bold{q}_{j}\bold{k}_{ij}]\\ &=\sum_{j=1}^d \mathbb{E}[\bold{q}_{j}]\mathbb{E}[\bold{k}_{ij}]\\ &=0 \end{aligned}\\ \begin{aligned} \text{Var}[\bold{q}\cdot \bold{k}_{i}]&=\text{Var}\Big[\sum_{j=1}^d \bold{q}_{j}\bold{k}_{ij}\Big]\\ &=\sum_{j=1}^d \text{Var}[\bold{q}_{j}\bold{k}_{ij}]\\ &=\sum_{j=1}^d \text{Var}[\bold{q}_{j}]\text{Var}[\bold{k}_{ij}]\\ &=d \end{aligned}\\

由于 qj,kij\bold{q}_j,\bold{k}_{ij} 是互相独立的,所以 E[],Var[]\mathbb{E}[],\text{Var}[] 可以直接线性相加 & 拆开为乘积

而我们希望,点积和 q,ki\bold{q},\bold{k}_i 一样,都是 zero mean and unit variance. 所以我们在额外给点积乘上系数 1d\frac{1}{\sqrt{d}},这样,其方差就变为 11 了。所以,我们最后的 Scaled Dot Product Attention 就是

Attn(q,ki)=exp(qkid)j=1Dexp(qkjd) \text{Attn}(\bold{q}, \bold{k}_i)=\frac{\exp\Big( \frac{\bold{q}^\top \bold{k}_i}{\sqrt{d}} \Big)}{\sum\limits_{j=1}^{|\mathcal{D}|} \exp\Big( \frac{\bold{q}^\top \bold{k}_j}{\sqrt{d}} \Big)}
Masked Softmax

输入一个三维 tensor 和 valid_lens(一维或二维),使得一些元素在 softmax 操作之后其权重变为 00.

要让 softmax 后的权重变为 00,我们只需要让原来的元素设定为一个足够小的值即可(我们这里选择 106-10^6

以下假设 input tensor 的形状为 (h,w,d)(h,w,d),我们的步骤就是

  1. 先把 input tensor 变形为 (hw,d)(hw,d),这样我们只需要 (hw,)(hw,) 的向量用来限制每一行保留几个数.
  2. 使用 X[~mask] = value 设置那些要屏蔽的位置为特定值
  3. X 使用 softmax()\texttt{softmax}() 最后重新放缩回 (h,w,d)(h,w,d)

v=valid_lens\bold{v}=\texttt{valid\_lens},所以我们的重心在于把 v\bold{v} 计算为 (hw,)(hw,) 的向量:

  1. 如果 v\bold{v} 为一维,那么其形状为 (h,)(h,),我们直接把 v\bold{v} 扩展到 v:(hw,)\bold{v}':(hw,).
  2. 否而直接 rearrange 为 v:(hw,)\bold{v}':(hw,)

那么怎么计算出 maskmask 呢?我们构造出一个 (hw,d)(hw,d) 的 arange 矩阵 maskmask 满足 mask[,i]=i1mask[\dots, i]=i-1,所以根据 PyTorch 的内置比较运算符,只需要 mask<vmask\lt \bold{v}' 就可以了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
import torch.nn.functional as F
from einops import rearrange, repeat, reduce
from d2l import torch as d2l
from typing_extensions import Optional


def masked_softmax(
x: torch.Tensor, # 输入张量
remain: Optional[torch.Tensor], # 要保留的元素
) -> torch.Tensor:
if remain is None: # 如果全都保留,那么直接使用 softmax
return F.softmax(x, dim=-1)
else:
shape = x.shape
value = -1e6

# construct the mask
# reshape 或者 repeat 为 (hw,)
remain = (
repeat(remain, "h -> (h w)", b=remain.shape[0], w=shape[1])
if remain.dim() == 1
else rearrange(remain, "h w -> (h w)")
)

# flatten the input tensor to rows of data
# 把输入张量也变形为 (hw, d)
x = rearrange(x, "h w d -> (h w) d")
assert x.shape[0] == remain.shape[0]
print(f"{x=}, {x.shape=}")

# generate the attention mask
# 计算 mask 并且赋值.
word_len = x.shape[-1]
mask = repeat(torch.arange(word_len), "d -> w d", w=x.shape[0])
mask = mask < repeat(remain, "n -> n d", d=word_len)
x[~mask] = value

# 计算 softmax 并且重新 reshape 到原来的形状
return rearrange(
F.softmax(x, dim=-1),
"(h w) d -> h w d",
h=shape[0],
w=shape[1],
)
Batched Matrix Multiplication

BMM(Q,K)\texttt{BMM}(\bold{Q},\bold{K}) 的作用是将

QRn×a×b=[Q1,Q2,,Qn]KRn×b×c=[K1,K2,,Kn] \bold{Q}\in \R^{n\times a\times b}=[\bold{Q}_1,\bold{Q}_2,\dots, \bold{Q}_n]\\ \bold{K}\in \R^{n\times b\times c}=[\bold{K}_1,\bold{K}_2,\dots, \bold{K}_n]

计算得出

[Q1K1,Q2K2,,QnKn]Rn×a×c [\bold{Q}_1\bold{K}_1,\bold{Q}_2\bold{K}_2, \dots, \bold{Q}_n\bold{K}_n]\in \R^{n\times a\times c}

API 是 torch.bmm(Q, K).

Additive Attention

Additive Attention 的优势是计算量更小:

Attn(q,k)=vtanh(Wqq+Wkk)R \text{Attn}(\bold{q}, \bold{k})=\bold{v}^\top \tanh\Big( \bold{W}_q\bold{q}+\bold{W}_k\bold{k} \Big) \in \R

其中 qRq,kRk,vRh\bold{q}\in\R^q, \bold{k}\in\R^k, \bold{v}\in\R^hWqRh×q,WkRh×k\bold{W}_q\in\R^{h\times q}, \bold{W}_k\in\R^{h\times k}.