“ 花里胡哨带英文的的图是从下地址截取的,文字和代码都是笔者写的,所以可能有错误的地方
https://huggingface.co/blog/vtabbott/mixtral
mixtral基础结构跟正常的decoder结构模型一致,可以划分成3个部分,输入embedding层、N个decoder block、lm解码头。如下图所示:
每个decoder layer包含2个大模块,attention和mlp,整体上跟llama这些模型一致
接下来这张图,是描述attention的结构,mixtral中比较特别的是,使用的sliding windown attention、grouped query attention
其中sliding attention是指,指定一个sliding window,每个token往前只能观察到sliding window内的token信息。(decoder 结构是单向注意力,也就是每个位置的token只能观察到当前位置及之前位置的输入信息),这个可以很容易的通过attention mask来实现。下图是复制transformers中生成attention mask的代码示例(指定sliding window=3):
接下来看grouped query attention,如下图
multi-head attention每个注意力头都有自己的query、key、value;multi-query attention:在所有的注意力头上共享key、value,训练过程,不会明显影响训练过程,训练速度基本不变,会引起非常细微的模型效果损失。但是推理速度更快,显存占用更低,因为推理时,反复加载很大的kv cache,内存开销比较大,性能收到内存受限。(GPU的内存由多个大小不同,读写速度不同的内存组成,对于A100-40G,SRAM内存大小为20MB、HBM内存大小为40GB。一般memory bound的问题,比如flash attention这些策略就是减少对HBM内存的读写次数);grouped query attention:介于multi head和multi query之间,具有多个key、value
最后是Mixtral的核心模块,Sparse Mixture of Experts (SMoE),如下图所示。SMoE具有多个层(“专家”)可用。对于每个输入,将对最相关专家的输出进行加权求和。
还是看代码来理解把:每个expert是个什么东西呢?
class MixtralBLockSparseTop2MLP(nn.Module):
def \_\_init\_\_(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden\_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
mlp层一般分为2种,一个是传统的:FFN(x) = f(x * w1 + b1) * w2 + b2;激活函数为Gelu、Swish
一种是使用GLU门控的: FFN(x) = (f(x * w1) X (x * V) ) * w2;同样激活函数可以用Gelu、Swish;中间有个X 是点乘,这种有3个训练权重,比传统的多一个训练权重,像llama是用的swiGLU的mlp,但是chatglm 6b就是用的gelu。
上面的每个expert的代码也是基于GLU门控的。
有了8个专家之后,就可以来看sparseMOE模块了,分为2步,第一步是针对每个token,选topk个专家,第二步是,加权得到每个token的mlp特征
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
获取到每个token的mlp层输入特征
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
得到每个专家的打分,维度是batch * sequence, num_experts,取topk个专家
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
取到topk个专家的打分,需要计算在归一化一下,用于对后面的expert计算出来的
结果进行加权
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
routing_weights、selected_experts 维度是一致的,取了topk (bs * sl, topk)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
如果不做后面的维度切换,那expert_mask的维度是 (bs*sl, topk, n_experts),但是后面要遍历n_experts来计算,所以颠倒一下,得到(n_experts, topk, bs * sl);
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
这样取到expert_mask[expert_idx],从上面的注释可以知道维度是
[topk, bs * sl];torch.where的结果,第一个结果代表选到了哪一行,第二个代表选择了哪一列
对应到实际意义,top_x表示取的列,也就是取哪些token
而行表示,取到的这些token,根据路由gate计算,当前expert是排行第几;
所以这里变量名字可能有点混淆,
if top_x.shape[0] == 0:
没有token需要当前的expert计算
continue
tensor index使用list比tensor快
top_x_list = top_x.tolist()
idx_list = idx.tolist()
前面hidden states已经转成了 [bs * sl, hs],根据top_x
可以找到需要计算的token
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
找到这个expert对应的权重 乘进去
上面计算的权重是routing_weights,维度是bs * sl, topk
根据top_x_list 对应的token,idx_list表示topk中第几个
可以直接取到相应的权重
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
合到最终的特征里边去
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)