Mixtral-8x7B 模型挖坑

技术

写在前面

大家好,我是刘聪NLP。

2024年第一天,给大家带来一篇前一段时间大火的Mixtral-8x7B分析文章-《Mixtral-8x7B 模型挖坑》,来自知乎@孟繁续(已授权)。


          
知乎原文:https://zhuanlan.zhihu.com/p/674751021  

      

MistralAI很高冷的给开源社区扔了一条磁力链,基于Mixture of Experts的混合专家模型Mixtral-8x7B和指令微调的Mixtral-8x7B-Instruct来了。此前曾爆料GPT4就是基于MoE技术的大模型,MistralAI证明通过不到8个7B的参数量,不到2个7B模型的计算量,就能超越LLaMA 2 70B的效果,甚至部分超越了GPT-3.5的水平,随即这两个模型引爆社交网络。截至目前,官网展示了Mixtral-8x7B的模型效果:

picture.image 图1. Mistral 8x7B超越LLaMA 2 70B和GPT-3.5


          
Mistral官网:https://mistral.ai/news/mixtral-of-experts/  
Mixtral-8x7B: https://huggingface.co/mistralai/Mixtral-8x7B-v0.1  
Mixtral-8x7B-Instruct: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1  

      

模型的命名方式也充满野心, 新的7B模型只叫了个小小杯,效果这么好的8x7B MoE模型叫了个小杯,而在La plateforme中可以申请调用一个中杯模型的API(也许是8x13b、8x34B?),推测大杯和超大杯应该也在路上了。

picture.image

picture.image 图2.小小杯-小杯-中杯效果对比


          
La plateforme: https://mistral.ai/news/la-plateforme/  

      

结构介绍

Mixtral-8x7B和LLaMA结构唯一的区别,在于将MLP layer复制成了8个expert layers并在一起,通过一个gate layer,对每个token选择top-2的专家模型进行计算,这里结合transformers中的代码和图示理解会比较好:


          
# 注意:为了容易理解,我对代码进行了简化,同时不考虑batch size,实际使用时还是要用官方代码  
class MixtralSparseMoeBlock(nn.Module):  
    def \_\_init\_\_(self, config):  
        super().__init__()  
        self.gate = nn.Linear(self.hidden_dim, 8)  
        self.experts = nn.ModuleList([MLP(config) for _ in range(8)])  
  
    def forward(self, x):  
        # 对每个token计算8个expert的权重,并将权重归一化  
        router_logits = self.gate(x)  
        routing_weights = F.softmax(router_logits, dim=1)  
        # 每个token选择top-2 experts的权重、索引, 并将索引转为size=(len(tokens), 8)的独热编码  
        routing_weights, selected_experts = torch.top2(routing_weights, dim=-1)  
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=8)  
        # 重新将top-2 expert的权重归一化(因为删掉其余6个expert,权重的和不等于1了)  
 routing_weights /= routing_weights.sum(dim=-1)  
        # 创建形状和x一致,初始值为0的矩阵,用来存储每个expert的输出  
 final_hidden_states = torch.zeros_like(x)  
        for expert_idx in range(8):  
            # 选择当前使用的expert  
            expert_layer = self.experts[expert_idx]  
            # 选择当前expert对应的index  
            idx_list, top_x_list = torch.where(expert_mask[expert_idx])  
            # 选择需要计算的状态  
            current_state = x[top_x_list]  
            # 选择当前expert对每个token的权重  
            current_routing_weights = routing_weights.t()[top_x_list, idx_list]  
            # 将选择的状态输入给专家模型计算,并乘上权重  
            current_hidden_states = expert_layer(current_state) * current_routing_weights  
            # 将每个expert的输出按照索引加到最终结果里  
            final_hidden_states.index_add_(0, top_x_list, current_hidden_states)  
        return final_hidden_states  
  

      

picture.image MixtralSparseMoeBlock

如图所示为Mixtral的MoE FFN的示意图,

首先,对于输入,先乘上一个 的gate layer,得到的表示(router),用softmax对其归一化之后,选出top-2的专家的权重和索引,将索引转为稀疏矩阵expert_mask。

接着循环对8个experts重复执行以下操作:选择expert(如图中),每个expert只需要处理自己是top-2的tokens,利用expert_mask得到这些tokens(值得注意的是,每个expert需要处理的长度不一,但8个experts需要处理的序列长度之和刚好等于2倍出入序列长度,因为每个tokens都会被两个experts处理)。

最后将每个expert的输出加权平均得到输出。

由于官方没有对Mixtral-8x7B的训练方法做任何说明,我希望通过实验验证一些问题,试图了解一点可能的训练流程。大部分实验基于Mixtral-8x7B-v0.1模型,MMLU的代码基于chain-of-thought-hub,用了4bit量化,所以测出来的结果会低于官方公布的结果。


          
chain-of-thought-hub: https://github.com/FranxYao/chain-of-thought-hub/tree/main  

      

问题1. 训练的时候也是用的top-2 experts吗?

有这个问题是考虑,有没有可能训练的时候激活了全部的experts,推理时为了平衡速度和效果才选择的top-2 experts。这个问题的答案比较容易获得,只需要更改激活experts的数量,在MMLU上测试其效果就可以:

picture.image 激活x/8个experts的效果

实验发现,在top-1效果下降很多,top-3略微上升,继续增多被激活的experts数量,效果会下降。如果训练时激活了更多的experts,这条曲线应该是平稳中有一点上升,而不会下降。所以这个猜测是多余的,模型确实是用top-2训练的(杠一下的话,top-3也不是没有可能)。

问题2. 8个专家模型的贡献一样吗?

为了解答这个问题,我们使用select_experts代码,逐一删除每个专家模型,并删除对应的gate layer,然后在MMLU上测试所得模型的能力:

picture.image 删除一个expert后,在MMLU数据集上的表现

可以发现,删除expert3后,模型直接fail,但是删除experts-0,2,4,5,6,7专家后的表现差不多,可以说一句,Mixtral-8x7B的负载均衡是没做好的。这里放一个大家很熟悉的表情包,显然experts是在坑里干活的那个。

picture.image

另外每次激活top-2和top-1的趋势是一致的,为了节省计算量和时间,后续的实验出特殊说明,都用了top-1。


          
select_experts代码: https://github.com/fxmeng/mixtral_spliter/blob/11550c700540fb623dea22e42ee55a57536af1b8/select_experts.py  

      

问题3. 每个专家模型预训练阶段是否针对不同的任务进行的训练?

为了解答这个问题,需要在不同的数据集上进行删除每个expert的测试,除了MMLU,我还测试了MT-Bench,之后还会测测Math,Code的能力。MT-Bench的曲线和MMLU的几乎一致:

picture.image 移除某个专家模型后在MT Bench的效果(激活top-1 expert)

暂时看起来每个专家模型预训练的任务并没有太大差异,不过这个结论需要等更多任务测试结果出来才敢确定。

问题4. 专家模型数量对效果的影响

这里才用了贪心法,逐一删除experts,每一行代表基于上一次的最优模型,删除每一个expert的结果。为了看起来容易,每一列的顺序调整成了删除的顺序[2, 5, 6, 5, 7, 0, 1, 3]。这个顺序一定程度上反映了对应专家模型的重要性,越往后越高。可以发现在任何时候,试图删除expert 3,掉点都会很严重。

picture.image

按照[2, 5, 6, 5, 7, 0, 1, 3]的顺序删除或添加专家模型,专家数量和效果的关系在下图中直观的展示出来了:

picture.image 效果随着模型数量变化趋势

可以看出来,每个模型删除的时候都有比较明显的掉点,说明除了在坑里的expert 3,其他专家也是有他们的作用的。

得到的每个数量贪心最优的子模型分享在了Mixtral-1-7x7b-instruct,Mixtral-1-7x7b。如果硬盘,显存不够可以按照实际硬件情况,选择这个贪心最优子模型来学习,使用Mixtral。调用的方法和原始的模型没有区别,生成子结构的代码也放在每个checkpoint里了,如果你本地有完整版的模型,运行代码就可以自己生成这个子模型了。未来也许会finetune一下3-6个experts的子模型,应该会有比较明显的提升,毕竟模型还是有些脆弱的。


          
Mixtral-1-7x7b-instruct: https://huggingface.co/collections/fxmeng/mixtral-1-7x7b-instruct-v01-658bf0268efe4503a06da312  
Mixtral-1-7x7b: https://huggingface.co/collections/fxmeng/mixtral-1-7x7b-v01-658bf0a9674aa14b9b07d2d4  

      

问题5. 在预训练阶段是否先独立训练了8个7B的模型,然后再把FFN合在一起训?

在Twitter上看到一张图,比较Mixtral-8x7B和Mistral-7B的attention layer中的QKVO矩阵的cosine similarity,发现相似度比较高,由此得出结论是Mixtral是将Mistral复制了8份。这个结论下的有点容易,所以我打算把实验扩展一下。

picture.image https://x.com/tianle\_cai/status/1734188749117153684?s=20

首先需要知道这个相似度代表了什么,我测了

  • LLaMA-7B、LLaMA-2 7B以及OpenLLaMA-7B的相似度接近于0;
  • LLaMA-2 和 LLaMA-2-Chat的相似度接近1;
  • Mistral-7B-v0.1,和Mistral-7B-Instruct-v0.1,Mistral-7B-Instruct-v0.2的相似度也为1。

LLaMA、LLaMA-2,OpenLLaMA的对比说明对于相同架构的模型,训练数据相差也不大的情况下,每次独立的训练,相似度极低。

LLaMA-2 和LLaMA-2-Chat的对比,以及Mistral-7B-v0.1,和Mistral-7B-Instruct-v0.1,Mistral-7B-Instruct-v0.2的对比说明,对于训好的模型finetune,即使模型的输出分布发生了很大的变化,相似度也极高。

那可以得出Mixtral-8x7B和Mistral-7B的attention layers之间肯定是有继承关系的。

那FFN呢?8个experts参数都是哪里来的?我在FFN上也测试了和Mistral-7B的相似度,结果如下:

picture.image Expert 0

picture.image Expert 1

picture.image Expert 2

picture.image Expert 3

picture.image Expert 4

picture.image Expert 5

picture.image Expert 6

picture.image Expert 7

可以看出相似度均在40%左右,说明训练每个experts的FFN和Mistral-7B也是有关系的。我又测试了每个expert之间,每层的平均相似度如下:

picture.image

可以发现expert 3和其他experts的相似度比较低,关于这个expert的特殊性,之后还将继续探索。

此外experts彼此之间的相似度比和Mistral-7B的低。

这里做一个猜测:Mixtral-8x7B的模型是利用Mistral-7B训练过程中early-stage的checkpoint训练而来的。Attention layer是直接继承的checkpoint,FFN是复制了8份,然后加上gate layer继续完成后续的预训练的。这样就能解释Attention和MLP有一定的相似度,但不太高。这不会是独立训练8个模型或者在模型收敛后finetune一下能得到的结果。

问题6. 只有单卡3090能不能finetune Mixtral?

很多时候我们可能需要调试代码,或者自学过程中,没有A100,能不能finetune Mixtral 8x7B?答案是不能。不过别急,只要用我拆开的模型,就能流畅的在单张3090上进行微调。下面是代码(基于learn-llm项目),实测finetune 2个专家模型只需13G显存,3个专家模型需要18G显存,4个专家模型需要21G显存:


          
from trl import SFTTrainer  
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments,AutoConfig  
from datasets import load_dataset  
import torch  
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training  
  
torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16  
  
## load model in 4bit  
bnb_config = BitsAndBytesConfig(  
    load_in_4bit=True,  
    bnb_4bit_quant_type="nf4",  
    bnb_4bit_compute_dtype=torch_dtype,  
    bnb_4bit_use_double_quant=True,  
)  
config = AutoConfig.from_pretrained('fxmeng/Mixtral-2x7B-Instruct-v0.1')  
config.use_cache = False  
config.gradient_checkpointing = True  
  
model = AutoModelForCausalLM.from_pretrained('fxmeng/Mixtral-2x7B-Instruct-v0.1',  
                                             config=config,  
                                             quantization_config=bnb_config,  
                                             trust_remote_code=False,  
                                             torch_dtype=torch_dtype,  
                                             device_map="auto")  
  
tokenizer = AutoTokenizer.from_pretrained('fxmeng/Mixtral-2x7B-Instruct-v0.1',  
                                          trust_remote_code=False,  
                                          use_fast=True)  
tokenizer.pad_token = tokenizer.eos_token  
  
# lora  
peft_config = LoraConfig(  
    r=64,  
    lora_alpha=16,  
    lora_dropout=0.1,  
    target_modules=['k\_proj', 'w3', 'v\_proj', 'gate', 'lm\_head', 'q\_proj', 'w2', 'w1', 'o\_proj'],  
    bias="none",  
    task_type="CAUSAL\_LM",  
    inference_mode=False  
)  
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)  
model = get_peft_model(model, peft_config)  
  
# training setting  
training_args = TrainingArguments(  
    do_train=True,  
    do_eval=True,  
    output_dir="./checkpoints",  
    dataloader_drop_last=True,  
    evaluation_strategy="steps",  
    save_strategy="steps",  
    logging_strategy="steps",  
    num_train_epochs=1,  
    eval_steps=10,  
    save_steps=10,  
    logging_steps=10,  
    per_device_train_batch_size=1,  
    per_device_eval_batch_size=1,  
    optim="paged\_adamw\_8bit",  
    learning_rate=1e-4,  
    lr_scheduler_type='constant',  
    warmup_steps=50,  
    gradient_accumulation_steps=1,  
    gradient_checkpointing=True,  
    weight_decay=0.05,  
    report_to="wandb",  
    load_best_model_at_end=True,  
    save_total_limit=1,  
    bf16=torch_dtype==torch.bfloat16,  
    fp16=torch_dtype!=torch.bfloat16,  
)  
  
# training and evaluation use the same dataset  
dataset = load_dataset("imdb", split="train")  
  
trainer = SFTTrainer(  
    model=model,  
    args=training_args,  
    train_dataset=dataset,  
    eval_dataset=dataset,  
    dataset_text_field="text",  
    max_seq_length=1024,  
    tokenizer=tokenizer,  
    data_collator=None,  
    packing=None  
)  
  
trainer.train()  
trainer.save_model('mixtral\_2x7b')  
  

      

4个experts用来学习和调试代码的话,效果看起来还不错:


          
from transformers import AutoTokenizer  
import transformers  
import torch  
  
model = "fxmeng/Mixtral-4x7B-Instruct-v0.1"  
tokenizer = AutoTokenizer.from_pretrained(model)  
pipeline = transformers.pipeline(  
    "text-generation",  
    model=model,  
    model_kwargs={"torch\_dtype": torch.float16, "load\_in\_4bit": True},  
)  
  
messages = [{"role": "user", "content": "Explain what a Mixture of Experts is in less than 100 words."}]  
prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)  
outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)  
print(outputs[0]["generated\_text"])  
  
from transformers import AutoTokenizer  
import transformers  
import torch  
  
model = "fxmeng/Mixtral-4x7B-Instruct-v0.1"  
  
tokenizer = AutoTokenizer.from_pretrained(model)  
pipeline = transformers.pipeline(  
    "text-generation",  
    model=model,  
    model_kwargs={"torch\_dtype": torch.float16, "load\_in\_4bit": True},  
)  
  

      

          
learn-llm github: https://github.com/hengjiUSTC/learn-llm/tree/main  

      

问题7. MoE每层之间的expert有无关联?

按照前文的猜测,在一个模型被训练到一定阶段后,被复制成了8份,此时每一层参数的每个expert参数一样,然后继续训练,此时每个expert不再是独立的个体,因此最终每层之间的expert不再有关联。表现为打乱每层expert的顺序,不影响模型的输出。因此可以说Mixtral-8x7B不是8个experts,而是32x8个experts。那么问题就来了,这256个experts对模型输出的影响是多大?这个问题如果还像之前一样一个一个的挖掉,计算量就太大了。因此我使用了另一种方法:将整个数据集输入模型,不再需要进行生成,而是记录下每个token在每一层激活的expert。然后就能得到一个32x8的矩阵,每个元素表示这个第层,第个expert被激活的次数。然后每层按照每个expert被激活的次数从大到小,可以认为其重要性逐渐降低。


          
# 作为top-1 expert被激活的次数排序  
[[1, 2, 3, 6, 0, 4, 7, 5],  
 [2, 1, 5, 6, 7, 4, 3, 0],  
 [6, 5, 2, 4, 0, 1, 7, 3],  
 [6, 4, 0, 3, 7, 2, 5, 1],  
 [3, 6, 1, 4, 0, 7, 2, 5],  
 [4, 3, 6, 7, 0, 5, 1, 2],  
 [5, 3, 6, 1, 4, 0, 2, 7],  
 [5, 2, 1, 4, 6, 0, 7, 3],  
 [5, 7, 3, 1, 6, 4, 2, 0],  
 [4, 6, 5, 0, 1, 2, 3, 7],  
 [4, 0, 3, 1, 6, 5, 2, 7],  
 [5, 1, 4, 0, 7, 6, 2, 3],  
 [4, 1, 3, 6, 0, 2, 5, 7],  
 [5, 1, 0, 7, 6, 3, 4, 2],  
 [3, 7, 6, 0, 5, 2, 1, 4],  
 [7, 0, 6, 2, 5, 1, 3, 4],  
 [7, 1, 5, 3, 6, 2, 0, 4],  
 [0, 7, 5, 6, 1, 4, 2, 3],  
 [3, 7, 2, 1, 4, 0, 6, 5],  
 [0, 4, 2, 7, 6, 5, 1, 3],  
 [3, 7, 5, 6, 0, 1, 2, 4],  
 [3, 5, 0, 1, 6, 4, 2, 7],  
 [7, 1, 4, 0, 3, 6, 5, 2],  
 [1, 3, 5, 6, 0, 2, 7, 4],  
 [5, 2, 3, 1, 7, 6, 4, 0],  
 [1, 4, 0, 3, 6, 5, 2, 7],  
 [2, 1, 4, 3, 5, 0, 7, 6],  
 [3, 5, 2, 0, 7, 4, 1, 6],  
 [2, 3, 4, 1, 7, 5, 0, 6],  
 [4, 1, 5, 2, 6, 7, 3, 0],  
 [0, 7, 1, 4, 5, 3, 2, 6],  
 [2, 6, 5, 0, 7, 1, 3, 4]]  

      

          
# 作为top-2 expert被激活的次数排序  
[[5, 2, 7, 3, 4, 0, 6, 1],  
 [3, 5, 2, 4, 7, 1, 0, 6],  
 [5, 1, 6, 7, 2, 0, 3, 4],  
 [0, 3, 4, 5, 7, 1, 6, 2],  
 [0, 3, 7, 2, 5, 6, 4, 1],  
 [1, 7, 2, 5, 4, 6, 3, 0],  
 [6, 0, 7, 3, 2, 5, 4, 1],  
 [3, 2, 4, 7, 0, 6, 5, 1],  
 [4, 0, 3, 2, 7, 5, 6, 1],  
 [5, 2, 4, 3, 1, 0, 7, 6],  
 [5, 1, 7, 3, 6, 2, 4, 0],  
 [6, 0, 2, 1, 7, 5, 4, 3],  
 [0, 1, 2, 5, 4, 3, 6, 7],  
 [1, 4, 2, 0, 5, 3, 7, 6],  
 [5, 1, 0, 3, 2, 4, 6, 7],  
 [2, 5, 4, 6, 1, 0, 7, 3],  
 [7, 6, 4, 1, 5, 2, 0, 3],  
 [1, 2, 7, 0, 6, 5, 4, 3],  
 [6, 4, 0, 1, 5, 7, 3, 2],  
 [7, 5, 6, 2, 0, 3, 1, 4],  
 [2, 1, 3, 0, 5, 6, 7, 4],  
 [0, 1, 3, 7, 2, 4, 6, 5],  
 [1, 3, 7, 2, 0, 4, 6, 5],  
 [5, 2, 1, 0, 3, 7, 4, 6],  
 [2, 5, 1, 6, 7, 3, 0, 4],  
 [6, 7, 2, 5, 0, 1, 3, 4],  
 [2, 5, 0, 4, 6, 7, 3, 1],  
 [6, 3, 5, 4, 2, 1, 0, 7],  
 [0, 1, 7, 6, 2, 5, 4, 3],  
 [3, 4, 0, 2, 7, 1, 5, 6],  
 [4, 3, 2, 0, 7, 6, 5, 1],  
 [5, 7, 3, 4, 2, 0, 6, 1]]  

      

将模型的每一层按照2 x top1 + top2次数排序,能得到一个新的模型,其输出和原模型一致,但是每层的第一个expert都是这层里面最重要的,最后一个expert都是这层里面最不重要的。然后就可以用问题4中的方法,根据自己的硬件条件,选择对应数量的experts。使用新的方法获得的子结构比之前把每个expert当作一个整体,使用贪心法得到的子结构精度损失明显减小:

picture.image

写在最后

本文简单介绍了一下Mixtral 8x7B模型的原理,随后通过5个问题,试图还原其训练流程,了解模型特性。用贪心算法搜出了1-7个专家的子结构,推测Mixtral-8x7B是利用Mistral-7B训练过程中early-stage的checkpoint训练而来的。受限于笔者水平,肯定有不足之处,一家之言,恳请大家批评指正。

欢迎多多关注公众号「NLP工作站」,欢迎加入交流群,有问题的朋友也欢迎加我微信「logCong」私聊,交个朋友吧,一起学习,一起进步。我们的口号是“生命不止,学习不停”。

PS:新书已出《ChatGPT原理与实战》,欢迎购买。

往期推荐:

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动基于 DataLeap 的 DataOps 实践
随着数字化转型的推进以及业务数仓建设不断完善,大数据开发体量及复杂性逐步上升,如何保证数据稳定、正确、持续产出成为数据开发者核心诉求,也成为平台建设面临的挑战之一。本次分享主要介绍字节对于DataOps的理解 以及 DataOps在内部业务如何落地实践。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论