Positional Encoding

在 RNN 架构里,由于我们是顺序处理 tokens,因此这里天然地自带 token position 的信息. 然而,对于 Attention 它是 parallel computing 的. 每一个 token 通过矩阵乘法,是同时计算出与其他 token 的相似度 score 的. 我们怎么在 Attention 里保留 position 呢?

Transformer 使用了叫做 Positional Encoding 的技术. 假设我们的输入张量为 XRn×d\bold{X}\in\R^{n\times d},表示 nn 个 token 每一个 token 的词向量维度为 dd. Positional Encoding 矩阵 PRn×d\mathbf{P}\in\R^{n\times d} 经由以下方式计算得到后,加到输入矩阵 X\mathbf{X} 上,即 XP+X\mathbf{X}\gets \mathbf{P+X}

Pi,2j=sin(iC2j/d)Pi,2j+1=cos(iC2j/d)C=10000 \begin{aligned} \bold{P}_{i,2j}&=\sin\Big( \frac{i}{C^{2j/d}} \Big)\\ \bold{P}_{i,2j+1}&=\cos\Big( \frac{i}{C^{2j/d}} \Big)\\ C&=10000 \end{aligned}

CC 是常数,Transformer 的论文里取了 1000010000.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class PositionalEncoding(nn.Module):
def __init__(self, num_hiddens: int, dropout: float, max_len: int = 1000):
super().__init__()
self.dropout = nn.Dropout(dropout)
self.P = torch.zeros((1, max_len, num_hiddens))

base = rearrange(
torch.arange(max_len, dtype=torch.float32),
"seq -> seq 1",
)
base = base / torch.pow(
10000,
torch.arange(0, num_hiddens, 2, dtype=torch.float32),
)

self.P[:, :, 0::2] = torch.sin(base)
self.P[:, :, 1::2] = torch.cos(base)

def forward(self, X: torch.Tensor) -> torch.Tensor:
X = X + self.P[:, : X.shape[1], :].to(X.device)
return self.dropout(X)
Relative Positional Encoding

Relative Positional Encoding 的巧妙之处在于,给定任意一个固定的位置 ii,我们可以仅通过一个线性变换得到一个新位置 i+δi+\delta.

我们用数学推导一下。令 ωj=C2j/d\omega_j=C^{-2j/d},则 (pi,2j,pi,2j+1)(p_{i,2j},p_{i,2j+1}) 可以直接线性变换到不同行的同列上 (pi+δ,2j,pi+δ,2j+1)(p_{i+\delta, 2j}, p_{i+\delta, 2j+1})

[cos(δωj)sin(δωj)sin(δωj)cos(δωj)][pi,2jpi,2j+1]=[cos(δωj)sin(δωj)sin(δωj)cos(δωj)][sin(iωj)cos(iωj)]=[cos(δωj)sin(iωj)+sin(δωj)cos(iωj)sin(δωj)sin(iωj)+cos(δωj)cos(iωj)]=[sin((i+δ)ωj)cos((i+δ)ωj)]=[pi+δ,2jpi+δ,2j+1] \begin{aligned} \begin{bmatrix} \cos(\delta\omega_j) &\sin(\delta\omega_j)\\ -\sin(\delta\omega_j) &\cos(\delta\omega_j) \end{bmatrix} \begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix} &= \begin{bmatrix} \cos(\delta\omega_j) &\sin(\delta\omega_j)\\ -\sin(\delta\omega_j) &\cos(\delta\omega_j) \end{bmatrix} \begin{bmatrix} \sin(i\omega_j)\\ \cos(i\omega_j) \end{bmatrix}\\ &= \begin{bmatrix} \cos(\delta\omega_j)\sin(i\omega_j) + \sin(\delta\omega_j)\cos(i\omega_j)\\ -\sin(\delta\omega_j)\sin(i\omega_j)+\cos(\delta\omega_j)\cos(i\omega_j) \end{bmatrix}\\ &= \begin{bmatrix} \sin((i+\delta)\omega_j)\\ \cos((i+\delta)\omega_j) \end{bmatrix}\\ &= \begin{bmatrix} p_{i+\delta, 2j}\\ p_{i+\delta,2j+1} \end{bmatrix} \end{aligned}