理解Mixtral Moe模型原理与代码实现

火山方舟人工智能与算法内容安全与风控

“ 花里胡哨带英文的的图是从下地址截取的,文字和代码都是笔者写的,所以可能有错误的地方


        
          
https://huggingface.co/blog/vtabbott/mixtral  

      

mixtral基础结构跟正常的decoder结构模型一致,可以划分成3个部分,输入embedding层、N个decoder block、lm解码头。如下图所示:

picture.image

每个decoder layer包含2个大模块,attention和mlp,整体上跟llama这些模型一致picture.image

接下来这张图,是描述attention的结构,mixtral中比较特别的是,使用的sliding windown attention、grouped query attentionpicture.image

其中sliding attention是指,指定一个sliding window,每个token往前只能观察到sliding window内的token信息。(decoder 结构是单向注意力,也就是每个位置的token只能观察到当前位置及之前位置的输入信息),这个可以很容易的通过attention mask来实现。下图是复制transformers中生成attention mask的代码示例(指定sliding window=3):picture.image

接下来看grouped query attention,如下图picture.image

multi-head attention每个注意力头都有自己的query、key、value;multi-query attention:在所有的注意力头上共享key、value,训练过程,不会明显影响训练过程,训练速度基本不变,会引起非常细微的模型效果损失。但是推理速度更快,显存占用更低,因为推理时,反复加载很大的kv cache,内存开销比较大,性能收到内存受限。(GPU的内存由多个大小不同,读写速度不同的内存组成,对于A100-40G,SRAM内存大小为20MB、HBM内存大小为40GB。一般memory bound的问题,比如flash attention这些策略就是减少对HBM内存的读写次数);grouped query attention:介于multi head和multi query之间,具有多个key、value

最后是Mixtral的核心模块,Sparse Mixture of Experts (SMoE),如下图所示。SMoE具有多个层(“专家”)可用。对于每个输入,将对最相关专家的输出进行加权求和。

picture.image

还是看代码来理解把:每个expert是个什么东西呢?


        
          
class MixtralBLockSparseTop2MLP(nn.Module):  
    def \_\_init\_\_(self, config: MixtralConfig):  
        super().__init__()  
        self.ffn_dim = config.intermediate_size  
        self.hidden_dim = config.hidden_size  
  
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)  
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  
  
        self.act_fn = ACT2FN[config.hidden_act]  
  
    def forward(self, hidden\_states):  
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)  
        current_hidden_states = self.w2(current_hidden_states)  
        return current_hidden_states  

      

mlp层一般分为2种,一个是传统的:FFN(x) = f(x * w1 + b1) * w2 + b2;激活函数为Gelu、Swish

一种是使用GLU门控的: FFN(x) = (f(x * w1) X (x * V) ) * w2;同样激活函数可以用Gelu、Swish;中间有个X 是点乘,这种有3个训练权重,比传统的多一个训练权重,像llama是用的swiGLU的mlp,但是chatglm 6b就是用的gelu。

上面的每个expert的代码也是基于GLU门控的。

有了8个专家之后,就可以来看sparseMOE模块了,分为2步,第一步是针对每个token,选topk个专家,第二步是,加权得到每个token的mlp特征


        
          
  
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)  
  
获取到每个token的mlp层输入特征   
batch_size, sequence_length, hidden_dim = hidden_states.shape  
hidden_states = hidden_states.view(-1, hidden_dim)  
  
得到每个专家的打分,维度是batch * sequence, num_experts,取topk个专家  
router_logits = self.gate(hidden_states)  
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)  
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)  
  
取到topk个专家的打分,需要计算在归一化一下,用于对后面的expert计算出来的  
结果进行加权  
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  
routing_weights = routing_weights.to(hidden_states.dtype)  
  
routing_weights、selected_experts 维度是一致的,取了topk   (bs * sl, topk)  
  
  
final_hidden_states = torch.zeros(  
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device  
        )  
  
如果不做后面的维度切换,那expert_mask的维度是 (bs*sl, topk, n_experts),但是后面要遍历n_experts来计算,所以颠倒一下,得到(n_experts, topk, bs * sl);   
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)  
  
for expert_idx in range(self.num_experts):  
    expert_layer = self.experts[expert_idx]  
    idx, top_x = torch.where(expert_mask[expert_idx])  
      
这样取到expert_mask[expert_idx],从上面的注释可以知道维度是  
    [topk, bs * sl];torch.where的结果,第一个结果代表选到了哪一行,第二个代表选择了哪一列  
      
    对应到实际意义,top_x表示取的列,也就是取哪些token  
    而行表示,取到的这些token,根据路由gate计算,当前expert是排行第几;  
    所以这里变量名字可能有点混淆,  
      
      
    if top_x.shape[0] == 0:  
    没有token需要当前的expert计算  
        continue  
      
    tensor index使用list比tensor快  
    top_x_list = top_x.tolist()  
    idx_list = idx.tolist()  
  
    前面hidden states已经转成了 [bs * sl, hs],根据top_x  
    可以找到需要计算的token  
    current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)  
      
    找到这个expert对应的权重 乘进去  
    上面计算的权重是routing_weights,维度是bs * sl, topk  
    根据top_x_list 对应的token,idx_list表示topk中第几个  
    可以直接取到相应的权重  
    current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]  
  
    合到最终的特征里边去  
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))  
      
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)  
  

      

picture.image

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论