MoE 架构


MoE 代码实现:以 MiniMind 为例

Experts

首先定义专家模块,Experts 是

Experts(FeedForward)
Experts(FeedForward)
Code
1
2
3
4
5
6
7
8
9
10
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w3 = nn.Linear(config.hidden_dim, condig.dim, bias=False)
self.dropout = nn.Dropout(config.dropout)

def forward(self, x):
return self.dropout(self.w3(F.silu(self.w1(x)) * self.w2(x)))

Router

然后,我们来实现 Router 路由器。Router 接收一个 Token,计算出概率取 Top K 后转发给对应的专家。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()

self.config = config
self.top_k = config.num_experts_per_token
self.n_routed_experts = config.n_routed_experts

self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob

self.gating_dim = config.dim
self.weight = nn.Parameter(
torch.empty((self.n_routed_experts, self.gating_dim))
)

self.reset_parameter()

def reset_parameter(self) -> None:
import torch.nn.init as init

init.kaiming_uniform_(self.weight, a=math.sqrt(5))

def forward(self, tokens: torch.Tensor):
batch, seq_len, d = tokens.shape
tokens = einops.rearrange(tokens, "batch seq dim -> (batch seq) dim")
logits = F.linear(tokens, self.weight, None)

if self.scoring_func == "softmax":
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(
"Unsupported scoring function",
)

topk_weight, topk_idx = torch.topk(
scores,
k=self.top_k,
dim=-1,
sorted=False,
)

if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator

if self.train and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(batch, -1)

if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(batch, seq_len, -1)

ce = torch.zeros(
batch,
self.n_routed_experts,
device=tokens.device,
)
ce.scatter_add_(
1,
topk_idx_for_aux_loss,
torch.ones(
batch,
seq_len * aux_topk,
device=tokens.device
)
).div_(seq_len * aux_topk / self.n_routed_experts)

aux_loss = (
(ce * scores_for_seq_aux.mean(dim=-1))
.sum(dim=1)
.mean()
* self.alpha
)
else:
mask_ce = F.one_hot(
topk_idx_for_aux_loss.view(-1),
num_classes=self.n_routed_experts,
)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0.0

return topk_idx, topk_weight, aux_loss

以下分析约定:

  • Batch size BB,即一个 batch 包含 BB 句句子
  • Sequence Length SS,即一句句子包含 SS 个 Token
  • 词嵌入向量维度 HH
  • 专家数量 EE
  • 每个 Token 被转发到 Top KK 个专家
参数含义
  1. top_k 为每一个 Token 选择几个专家进行训练
  2. n_routed_experts 为专家总数
  3. scoring_func 用于计算 Token 和专家之间的评分
  4. aux_loss_alpha 辅助损失的 alpha 参数
  5. seq_aux 控制是否在序列级别上计算辅助损失
  6. norm_topk_prob 是否对概率进行归一化

由于我们对每一个 Token 计算它被发送到某一个专家的概率,在 __init__() 函数里,我们令 Token 的维度为 d=d= self.gating_dim,专家数量为 n=n= self.n_routed_experts,那么我们希望 Router 输出的矩阵大小就是

Router:RB×S×HRB×S×E \text{Router}:\R^{B\times S\times H}\mapsto \R^{B\times S\times E}

因此这里的 self.weightRH×E\R^{H\times E} 大小的矩阵,负责计算一个 Token 的被转发到 Expert 的概率。

前向传播 forward()

Tensor 输入是 B×S×HB\times S\times H,因为我们只关心 Token 发送到哪个专家,所以我们首先把 Tensor 拍成二维 BS×HBS\times H,并用 self.weight 计算概率,用 softmax() 归一化。F.linear(x, A, bias=None) 计算

y=xA y=xA^\intercal

因此,这里的 score 大小为 BS×EBS\times E

接着,我们调用 torch.topk() 选取前 KK 个专家,返回 score 中对应的权重和下标。此时 topk_weight, topk_idx 大小均为 BS×KBS\times K


然后对选择出来的 KK 个专家的权重再进行一次归一化(除以 denominator)。不过这一步是可选的


接着进入 if self.train and self.alpha > 0.0: 判断,目的是为了平衡专家之间的负载。首先,代码确保只在训练模式以及需要平衡负载的时候才会启用。

如果需要计算 sequence level 的 loss,那么 topk_idx 首先被拍平成 (B,SK)(B,SK)

因为要对 sequence level 计算专家负载损失,所以我们先定义每一句句子上专家的负载损失,其大小为 (B,E)(B,E).

然后遍历 topk_idx 中的每一个元素,用 scatter_add_()topk_idx 中的每一个元素添加到对应的 sequence 里对应的专家中。

这个流程结束之后,ce (count experts) 保存的就是每一句句子的专家负载。随后,我们对每一句句子都归一化其专家负载:一句句子会产生 SKSK 个 counting,总数为 SKSK

我们沿 SS 轴对 scores_for_seq_aux 计算平均值 scores.mean(dim=1) 这就表示每句句子与每个 expert 之间的得分(token 与 experts 得分的平均),其大小变为 (B,E)(B, E).

然后再将 cescores.mean() 对应位置相乘,ce 可以理解为每句句子中 expert 的频率(这个 expert 在这句句子里总是被分配处理 token),scores.mean() 可以理解为每句句子中 expert 对于每个 token 的重要程度(这个 expert 总是被 MoEGate 认为与 token 关联很大),大小变为 (B,E)(B,E),但此时仍然是每句句子与 expert 的关联。

因此再对 EE 求和取平均 .sum(dim=1).mean(),把 expert 与不同句子之间的频率与关联度整合起来,即对于这些句子 expert 的频率与关联度,也即 expert 的负载,其大小先变为 (E,)(E,),再变为 (1,)(1,)。最后再乘上标量 self.alpha. 这就是我们的 sequence level 的 aux loss.

值得一提的是,这里还额外乘了一个 EE,推测是为了让梯度不至于太小

Why It Makes Sense

因为我们的训练目标是让 Loss 尽可能的小,对于这个 Aux Loss 而言,如果某一个专家的负载特别大,那么就说明

  1. 每一个 Batch 内的所有 Tokens,这个专家对应的 scores 都会比较大
  2. 由于 torch.topk 总是都把这个专家选上,因此 ce 计算出来的加权也比较大

所以根据排序不等式,负载越不均衡,计算出来的 aux_loss 也就越大。那么反过来说,如果 aux_loss 越小,说明专家之间的负载越均衡。


如果不关心 sequence level 的 loss,那么我们直接把这一个 batch 的所有 token 拍到一起(总共 BSB\cdot S 个 token,总共被分配 BSKBSK 个专家),我们把 (BSK,1)(BSK, 1) 的专家重新 encode 为 one-hot vector,即 (BSK,E)(BSK,E)

那么我们对 BSKBSK 轴求平均,fi 向量大小变为 (1,E)(1,E),就计算出来了某个专家处理的 token 数占所有 token 的比值(即频率)。因为 one-hot vector 不是 00 就是 11.

类似的,我们也对 scores (大小为 (BS,E)(BS,E)) 做 token level 的计算,直接按 BSBS 轴取平均即可,Pi 大小变为 (1,E)(1,E),即每个专家在所有 token 上的平均得分(关联度)

同样地,我们直接将 fiPi 对应位置相乘,求和乘上 self.alpha,这一步的目的和 sequence level 的 aux loss 是一样的。


MoE Feed Forward (MoEFFN)

然后我们来把他们组合到一起:首先,我们为 MoEFFN 定义好门控和专家,以及共享专家(无论如何都要处理 token)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()

self.config = config
self.experts = nn.ModuleList(
[
FeedForward(config)
for _ in range(config.n_routed_experts)
]
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)

然后编写训练和推理。训练和推理的区别在于

  1. 推理模式下,Token 只转发给最优的 Expert。但在训练模式下,Token 会被转发给每一个 Expert

Training

我们来考察一下训练时的代码。

这里,x.repeat_interleave() 重复输入数据,目的是让一个 token 可以多次被不同的 expert 处理,提升 expert 的泛化性

然后 y 就是计算 token 经过专家计算后输出的结果,并且将类型转为半精度浮点数 float16,此时的张量形状为 (BS×K,H)(BS\times K, H),经过 .view(*topk_weight.shape, -1) 之后变为

1
2
3
4
5
6
7
8
9
10
if self.training:
# 训练模式下,重复输入数据
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(
y.dtype
) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)

Inferencing

我们来考察前向传播过程。

提取出 x (输入的 Batch of sequence of tokens) 的维数信息之后,先经由 self.gate(x) 计算每一个 token 对应的专家,然后直接拍成 a sequence of tokens (BS,H)(BS,H)topk_idx 则直接拍成 (BSK,)(BSK,)

1
2
3
4
5
6
7
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)

随后进入 y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) 进行计算。

  • 这里 .argsort() 的逻辑是:因为 flat_expert_indices 的值是专家的编号,通过 argsort()idxs 把同一个专家要处理的 token 下标聚集到一起。
  • tokens_per_expert 以前缀和的方式,计算每一个专家处理的 token 下标的范围。
  • token_idxsidxs 出发计算某一个 idxs[i] 对应的是第 token_idxs[i] 个 token。这是因为每一个 token 都会分给 top KK 个专家,因此从下标来说,iK(i+1)K1i*K\to (i+1)*K-1 (idxs 保存的正好都是下标) 对应的都是第 ii 个 token,因此直接整数出除法可以计算出对应第几个 token.
1
2
3
4
5
6
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok

接着,我们枚举每一个专家,拿出它需要处理的所有 tokens (即代码里的 token_idxs[start_idx : end_idx] 以及 x[exp_token_idx])

我们把这些 token 经过 expert(expert_tokens) 计算、输出,得到 expert_out,乘上(对于这个 token 而言)每一个 expert 的权重。通过 scatter_add_()expert_cache 包含了每个 token 位置的加权专家输出总和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue

expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# 使用 scatter_add_ 进行 sum 操作
expert_cache.scatter_add_(
0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out
)
return expert_cache

除此之外,还需要加上共享专家的输出。不过这里的话,如果在推理模式,self.aux_loss 其实没作用

1
2
3
4
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y