MoE 架构
MoE 代码实现:以 MiniMind 为例
Experts
首先定义专家模块,Experts 是
Code
1 2 3 4 5 6 7 8 9 10
| class FeedForward(nn.Module): def __init__(self, config: LMConfig): super().__init__() self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False) self.w2 = nn.Linear(config.dim, config.hidden_dim, bias=False) self.w3 = nn.Linear(config.hidden_dim, condig.dim, bias=False) self.dropout = nn.Dropout(config.dropout)
def forward(self, x): return self.dropout(self.w3(F.silu(self.w1(x)) * self.w2(x)))
|
Router
然后,我们来实现 Router 路由器。Router 接收一个 Token,计算出概率取 Top K 后转发给对应的专家。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
| class MoEGate(nn.Module): def __init__(self, config: LMConfig): super().__init__()
self.config = config self.top_k = config.num_experts_per_token self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func self.alpha = config.aux_loss_alpha self.seq_aux = config.seq_aux self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim self.weight = nn.Parameter( torch.empty((self.n_routed_experts, self.gating_dim)) )
self.reset_parameter()
def reset_parameter(self) -> None: import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, tokens: torch.Tensor): batch, seq_len, d = tokens.shape tokens = einops.rearrange(tokens, "batch seq dim -> (batch seq) dim") logits = F.linear(tokens, self.weight, None)
if self.scoring_func == "softmax": scores = logits.softmax(dim=-1) else: raise NotImplementedError( "Unsupported scoring function", )
topk_weight, topk_idx = torch.topk( scores, k=self.top_k, dim=-1, sorted=False, )
if self.top_k > 1 and self.norm_topk_prob: denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 topk_weight = topk_weight / denominator
if self.train and self.alpha > 0.0: scores_for_aux = scores aux_topk = self.top_k topk_idx_for_aux_loss = topk_idx.view(batch, -1)
if self.seq_aux: scores_for_seq_aux = scores_for_aux.view(batch, seq_len, -1) ce = torch.zeros( batch, self.n_routed_experts, device=tokens.device, ) ce.scatter_add_( 1, topk_idx_for_aux_loss, torch.ones( batch, seq_len * aux_topk, device=tokens.device ) ).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = ( (ce * scores_for_seq_aux.mean(dim=-1)) .sum(dim=1) .mean() * self.alpha ) else: mask_ce = F.one_hot( topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts, ) ce = mask_ce.float().mean(0) Pi = scores_for_aux.mean(0) fi = ce * self.n_routed_experts aux_loss = (Pi * fi).sum() * self.alpha else: aux_loss = 0.0
return topk_idx, topk_weight, aux_loss
|
以下分析约定:
- Batch size B,即一个 batch 包含 B 句句子
- Sequence Length S,即一句句子包含 S 个 Token
- 词嵌入向量维度 H
- 专家数量 E
- 每个 Token 被转发到 Top K 个专家
参数含义
top_k
为每一个 Token 选择几个专家进行训练
n_routed_experts
为专家总数
scoring_func
用于计算 Token 和专家之间的评分
aux_loss_alpha
辅助损失的 alpha 参数
seq_aux
控制是否在序列级别上计算辅助损失
norm_topk_prob
是否对概率进行归一化
由于我们对每一个 Token 计算它被发送到某一个专家的概率,在 __init__()
函数里,我们令 Token 的维度为 d= self.gating_dim
,专家数量为 n= self.n_routed_experts
,那么我们希望 Router 输出的矩阵大小就是
Router:RB×S×H↦RB×S×E因此这里的 self.weight
是 RH×E 大小的矩阵,负责计算一个 Token 的被转发到 Expert 的概率。
前向传播 forward()
Tensor 输入是 B×S×H,因为我们只关心 Token 发送到哪个专家,所以我们首先把 Tensor 拍成二维 BS×H,并用 self.weight
计算概率,用 softmax()
归一化。F.linear(x, A, bias=None)
计算
y=xA⊺因此,这里的 score
大小为 BS×E
接着,我们调用 torch.topk()
选取前 K 个专家,返回 score
中对应的权重和下标。此时 topk_weight
, topk_idx
大小均为 BS×K
然后对选择出来的 K 个专家的权重再进行一次归一化(除以 denominator
)。不过这一步是可选的
接着进入 if self.train and self.alpha > 0.0:
判断,目的是为了平衡专家之间的负载。首先,代码确保只在训练模式以及需要平衡负载的时候才会启用。
如果需要计算 sequence level 的 loss,那么 topk_idx
首先被拍平成 (B,SK)
因为要对 sequence level 计算专家负载损失,所以我们先定义每一句句子上专家的负载损失,其大小为 (B,E).
然后遍历 topk_idx
中的每一个元素,用 scatter_add_()
将 topk_idx
中的每一个元素添加到对应的 sequence 里对应的专家中。
这个流程结束之后,ce
(count experts) 保存的就是每一句句子的专家负载。随后,我们对每一句句子都归一化其专家负载:一句句子会产生 SK 个 counting,总数为 SK
我们沿 S 轴对 scores_for_seq_aux
计算平均值 scores.mean(dim=1)
这就表示每句句子与每个 expert 之间的得分(token 与 experts 得分的平均),其大小变为 (B,E).
然后再将 ce
与 scores.mean()
对应位置相乘,ce
可以理解为每句句子中 expert 的频率(这个 expert 在这句句子里总是被分配处理 token),scores.mean()
可以理解为每句句子中 expert 对于每个 token 的重要程度(这个 expert 总是被 MoEGate
认为与 token 关联很大),大小变为 (B,E),但此时仍然是每句句子与 expert 的关联。
因此再对 E 求和取平均 .sum(dim=1).mean()
,把 expert 与不同句子之间的频率与关联度整合起来,即对于这些句子 expert 的频率与关联度,也即 expert 的负载,其大小先变为 (E,),再变为 (1,)。最后再乘上标量 self.alpha
. 这就是我们的 sequence level 的 aux loss.
值得一提的是,这里还额外乘了一个 E,推测是为了让梯度不至于太小
Why It Makes Sense
因为我们的训练目标是让 Loss 尽可能的小,对于这个 Aux Loss 而言,如果某一个专家的负载特别大,那么就说明
- 每一个 Batch 内的所有 Tokens,这个专家对应的
scores
都会比较大
- 由于
torch.topk
总是都把这个专家选上,因此 ce
计算出来的加权也比较大
所以根据排序不等式,负载越不均衡,计算出来的 aux_loss
也就越大。那么反过来说,如果 aux_loss
越小,说明专家之间的负载越均衡。
如果不关心 sequence level 的 loss,那么我们直接把这一个 batch 的所有 token 拍到一起(总共 B⋅S 个 token,总共被分配 BSK 个专家),我们把 (BSK,1) 的专家重新 encode 为 one-hot vector,即 (BSK,E)
那么我们对 BSK 轴求平均,fi
向量大小变为 (1,E),就计算出来了某个专家处理的 token 数占所有 token 的比值(即频率)。因为 one-hot vector 不是 0 就是 1.
类似的,我们也对 scores
(大小为 (BS,E)) 做 token level 的计算,直接按 BS 轴取平均即可,Pi
大小变为 (1,E),即每个专家在所有 token 上的平均得分(关联度)
同样地,我们直接将 fi
与 Pi
对应位置相乘,求和乘上 self.alpha
,这一步的目的和 sequence level 的 aux loss 是一样的。
MoE Feed Forward (MoEFFN)
然后我们来把他们组合到一起:首先,我们为 MoEFFN 定义好门控和专家,以及共享专家(无论如何都要处理 token)
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| class MOEFeedForward(nn.Module): def __init__(self, config: LMConfig): super().__init__() self.config = config self.experts = nn.ModuleList( [ FeedForward(config) for _ in range(config.n_routed_experts) ] ) self.gate = MoEGate(config) if config.n_shared_experts is not None: self.shared_experts = FeedForward(config)
|
然后编写训练和推理。训练和推理的区别在于
- 推理模式下,Token 只转发给最优的 Expert。但在训练模式下,Token 会被转发给每一个 Expert
Training
我们来考察一下训练时的代码。
这里,x.repeat_interleave()
重复输入数据,目的是让一个 token 可以多次被不同的 expert 处理,提升 expert 的泛化性
然后 y
就是计算 token 经过专家计算后输出的结果,并且将类型转为半精度浮点数 float16
,此时的张量形状为 (BS×K,H),经过 .view(*
topk_weight.shape
, -1)
之后变为
1 2 3 4 5 6 7 8 9 10
| if self.training: x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0) y = torch.empty_like(x, dtype=torch.float16) for i, expert in enumerate(self.experts): y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to( y.dtype ) y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) y = y.view(*orig_shape)
|
Inferencing
我们来考察前向传播过程。
提取出 x
(输入的 Batch of sequence of tokens) 的维数信息之后,先经由 self.gate(x)
计算每一个 token 对应的专家,然后直接拍成 a sequence of tokens (BS,H),topk_idx
则直接拍成 (BSK,)
1 2 3 4 5 6 7
| identity = x orig_shape = x.shape bsz, seq_len, _ = x.shape
topk_idx, topk_weight, aux_loss = self.gate(x) x = x.view(-1, x.shape[-1]) flat_topk_idx = topk_idx.view(-1)
|
随后进入 y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
进行计算。
- 这里
.argsort()
的逻辑是:因为 flat_expert_indices
的值是专家的编号,通过 argsort()
,idxs
把同一个专家要处理的 token 下标聚集到一起。
tokens_per_expert
以前缀和的方式,计算每一个专家处理的 token 下标的范围。
token_idxs
从 idxs
出发计算某一个 idxs[i]
对应的是第 token_idxs[i]
个 token。这是因为每一个 token 都会分给 top K 个专家,因此从下标来说,i∗K→(i+1)∗K−1 (idxs
保存的正好都是下标) 对应的都是第 i 个 token,因此直接整数出除法可以计算出对应第几个 token.
1 2 3 4 5 6
| @torch.no_grad() def moe_infer(self, x, flat_expert_indices, flat_expert_weights): expert_cache = torch.zeros_like(x) idxs = flat_expert_indices.argsort() tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) token_idxs = idxs // self.config.num_experts_per_tok
|
接着,我们枚举每一个专家,拿出它需要处理的所有 tokens (即代码里的 token_idxs[start_idx : end_idx]
以及 x[exp_token_idx]
)
我们把这些 token 经过 expert(expert_tokens)
计算、输出,得到 expert_out
,乘上(对于这个 token 而言)每一个 expert 的权重。通过 scatter_add_()
,expert_cache
包含了每个 token 位置的加权专家输出总和。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| for i, end_idx in enumerate(tokens_per_expert): start_idx = 0 if i == 0 else tokens_per_expert[i - 1] if start_idx == end_idx: continue
expert = self.experts[i] exp_token_idx = token_idxs[start_idx:end_idx] expert_tokens = x[exp_token_idx] expert_out = expert(expert_tokens).to(expert_cache.dtype) expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) expert_cache.scatter_add_( 0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out ) return expert_cache
|
除此之外,还需要加上共享专家的输出。不过这里的话,如果在推理模式,self.aux_loss
其实没作用
1 2 3 4
| if self.config.n_shared_experts is not None: y = y + self.shared_experts(identity) self.aux_loss = aux_loss return y
|