写在前面
大家好,我是刘聪NLP。
2024年第一天,给大家带来一篇前一段时间大火的Mixtral-8x7B分析文章-《Mixtral-8x7B 模型挖坑》,来自知乎@孟繁续(已授权)。
知乎原文:https://zhuanlan.zhihu.com/p/674751021
MistralAI很高冷的给开源社区扔了一条磁力链,基于Mixture of Experts的混合专家模型Mixtral-8x7B和指令微调的Mixtral-8x7B-Instruct来了。此前曾爆料GPT4就是基于MoE技术的大模型,MistralAI证明通过不到8个7B的参数量,不到2个7B模型的计算量,就能超越LLaMA 2 70B的效果,甚至部分超越了GPT-3.5的水平,随即这两个模型引爆社交网络。截至目前,官网展示了Mixtral-8x7B的模型效果:
图1. Mistral 8x7B超越LLaMA 2 70B和GPT-3.5
Mistral官网:https://mistral.ai/news/mixtral-of-experts/
Mixtral-8x7B: https://huggingface.co/mistralai/Mixtral-8x7B-v0.1
Mixtral-8x7B-Instruct: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
模型的命名方式也充满野心, 新的7B模型只叫了个小小杯,效果这么好的8x7B MoE模型叫了个小杯,而在La plateforme中可以申请调用一个中杯模型的API(也许是8x13b、8x34B?),推测大杯和超大杯应该也在路上了。
图2.小小杯-小杯-中杯效果对比
La plateforme: https://mistral.ai/news/la-plateforme/
结构介绍
Mixtral-8x7B和LLaMA结构唯一的区别,在于将MLP layer复制成了8个expert layers并在一起,通过一个gate layer,对每个token选择top-2的专家模型进行计算,这里结合transformers中的代码和图示理解会比较好:
# 注意:为了容易理解,我对代码进行了简化,同时不考虑batch size,实际使用时还是要用官方代码
class MixtralSparseMoeBlock(nn.Module):
def \_\_init\_\_(self, config):
super().__init__()
self.gate = nn.Linear(self.hidden_dim, 8)
self.experts = nn.ModuleList([MLP(config) for _ in range(8)])
def forward(self, x):
# 对每个token计算8个expert的权重,并将权重归一化
router_logits = self.gate(x)
routing_weights = F.softmax(router_logits, dim=1)
# 每个token选择top-2 experts的权重、索引, 并将索引转为size=(len(tokens), 8)的独热编码
routing_weights, selected_experts = torch.top2(routing_weights, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=8)
# 重新将top-2 expert的权重归一化(因为删掉其余6个expert,权重的和不等于1了)
routing_weights /= routing_weights.sum(dim=-1)
# 创建形状和x一致,初始值为0的矩阵,用来存储每个expert的输出
final_hidden_states = torch.zeros_like(x)
for expert_idx in range(8):
# 选择当前使用的expert
expert_layer = self.experts[expert_idx]
# 选择当前expert对应的index
idx_list, top_x_list = torch.where(expert_mask[expert_idx])
# 选择需要计算的状态
current_state = x[top_x_list]
# 选择当前expert对每个token的权重
current_routing_weights = routing_weights.t()[top_x_list, idx_list]
# 将选择的状态输入给专家模型计算,并乘上权重
current_hidden_states = expert_layer(current_state) * current_routing_weights
# 将每个expert的输出按照索引加到最终结果里
final_hidden_states.index_add_(0, top_x_list, current_hidden_states)
return final_hidden_states
MixtralSparseMoeBlock
如图所示为Mixtral的MoE FFN的示意图,
首先,对于输入,先乘上一个 的gate layer,得到的表示(router),用softmax对其归一化之后,选出top-2的专家的权重和索引,将索引转为稀疏矩阵expert_mask。
接着循环对8个experts重复执行以下操作:选择expert(如图中),每个expert只需要处理自己是top-2的tokens,利用expert_mask得到这些tokens(值得注意的是,每个expert需要处理的长度不一,但8个experts需要处理的序列长度之和刚好等于2倍出入序列长度,因为每个tokens都会被两个experts处理)。
最后将每个expert的输出加权平均得到输出。
由于官方没有对Mixtral-8x7B的训练方法做任何说明,我希望通过实验验证一些问题,试图了解一点可能的训练流程。大部分实验基于Mixtral-8x7B-v0.1模型,MMLU的代码基于chain-of-thought-hub,用了4bit量化,所以测出来的结果会低于官方公布的结果。
chain-of-thought-hub: https://github.com/FranxYao/chain-of-thought-hub/tree/main
问题1. 训练的时候也是用的top-2 experts吗?
有这个问题是考虑,有没有可能训练的时候激活了全部的experts,推理时为了平衡速度和效果才选择的top-2 experts。这个问题的答案比较容易获得,只需要更改激活experts的数量,在MMLU上测试其效果就可以:
激活x/8个experts的效果
实验发现,在top-1效果下降很多,top-3略微上升,继续增多被激活的experts数量,效果会下降。如果训练时激活了更多的experts,这条曲线应该是平稳中有一点上升,而不会下降。所以这个猜测是多余的,模型确实是用top-2训练的(杠一下的话,top-3也不是没有可能)。
问题2. 8个专家模型的贡献一样吗?
为了解答这个问题,我们使用select_experts代码,逐一删除每个专家模型,并删除对应的gate layer,然后在MMLU上测试所得模型的能力:
删除一个expert后,在MMLU数据集上的表现
可以发现,删除expert3后,模型直接fail,但是删除experts-0,2,4,5,6,7专家后的表现差不多,可以说一句,Mixtral-8x7B的负载均衡是没做好的。这里放一个大家很熟悉的表情包,显然experts是在坑里干活的那个。
另外每次激活top-2和top-1的趋势是一致的,为了节省计算量和时间,后续的实验出特殊说明,都用了top-1。
select_experts代码: https://github.com/fxmeng/mixtral_spliter/blob/11550c700540fb623dea22e42ee55a57536af1b8/select_experts.py
问题3. 每个专家模型预训练阶段是否针对不同的任务进行的训练?
为了解答这个问题,需要在不同的数据集上进行删除每个expert的测试,除了MMLU,我还测试了MT-Bench,之后还会测测Math,Code的能力。MT-Bench的曲线和MMLU的几乎一致:
移除某个专家模型后在MT Bench的效果(激活top-1 expert)
暂时看起来每个专家模型预训练的任务并没有太大差异,不过这个结论需要等更多任务测试结果出来才敢确定。
问题4. 专家模型数量对效果的影响
这里才用了贪心法,逐一删除experts,每一行代表基于上一次的最优模型,删除每一个expert的结果。为了看起来容易,每一列的顺序调整成了删除的顺序[2, 5, 6, 5, 7, 0, 1, 3]。这个顺序一定程度上反映了对应专家模型的重要性,越往后越高。可以发现在任何时候,试图删除expert 3,掉点都会很严重。
按照[2, 5, 6, 5, 7, 0, 1, 3]的顺序删除或添加专家模型,专家数量和效果的关系在下图中直观的展示出来了:
效果随着模型数量变化趋势
可以看出来,每个模型删除的时候都有比较明显的掉点,说明除了在坑里的expert 3,其他专家也是有他们的作用的。
得到的每个数量贪心最优的子模型分享在了Mixtral-1-7x7b-instruct,Mixtral-1-7x7b。如果硬盘,显存不够可以按照实际硬件情况,选择这个贪心最优子模型来学习,使用Mixtral。调用的方法和原始的模型没有区别,生成子结构的代码也放在每个checkpoint里了,如果你本地有完整版的模型,运行代码就可以自己生成这个子模型了。未来也许会finetune一下3-6个experts的子模型,应该会有比较明显的提升,毕竟模型还是有些脆弱的。
Mixtral-1-7x7b-instruct: https://huggingface.co/collections/fxmeng/mixtral-1-7x7b-instruct-v01-658bf0268efe4503a06da312
Mixtral-1-7x7b: https://huggingface.co/collections/fxmeng/mixtral-1-7x7b-v01-658bf0a9674aa14b9b07d2d4
问题5. 在预训练阶段是否先独立训练了8个7B的模型,然后再把FFN合在一起训?
在Twitter上看到一张图,比较Mixtral-8x7B和Mistral-7B的attention layer中的QKVO矩阵的cosine similarity,发现相似度比较高,由此得出结论是Mixtral是将Mistral复制了8份。这个结论下的有点容易,所以我打算把实验扩展一下。
https://x.com/tianle\_cai/status/1734188749117153684?s=20
首先需要知道这个相似度代表了什么,我测了
- LLaMA-7B、LLaMA-2 7B以及OpenLLaMA-7B的相似度接近于0;
- LLaMA-2 和 LLaMA-2-Chat的相似度接近1;
- Mistral-7B-v0.1,和Mistral-7B-Instruct-v0.1,Mistral-7B-Instruct-v0.2的相似度也为1。
LLaMA、LLaMA-2,OpenLLaMA的对比说明对于相同架构的模型,训练数据相差也不大的情况下,每次独立的训练,相似度极低。
LLaMA-2 和LLaMA-2-Chat的对比,以及Mistral-7B-v0.1,和Mistral-7B-Instruct-v0.1,Mistral-7B-Instruct-v0.2的对比说明,对于训好的模型finetune,即使模型的输出分布发生了很大的变化,相似度也极高。
那可以得出Mixtral-8x7B和Mistral-7B的attention layers之间肯定是有继承关系的。
那FFN呢?8个experts参数都是哪里来的?我在FFN上也测试了和Mistral-7B的相似度,结果如下:
Expert 0
Expert 1
Expert 2
Expert 3
Expert 4
Expert 5
Expert 6
Expert 7
可以看出相似度均在40%左右,说明训练每个experts的FFN和Mistral-7B也是有关系的。我又测试了每个expert之间,每层的平均相似度如下:
可以发现expert 3和其他experts的相似度比较低,关于这个expert的特殊性,之后还将继续探索。
此外experts彼此之间的相似度比和Mistral-7B的低。
这里做一个猜测:Mixtral-8x7B的模型是利用Mistral-7B训练过程中early-stage的checkpoint训练而来的。Attention layer是直接继承的checkpoint,FFN是复制了8份,然后加上gate layer继续完成后续的预训练的。这样就能解释Attention和MLP有一定的相似度,但不太高。这不会是独立训练8个模型或者在模型收敛后finetune一下能得到的结果。
问题6. 只有单卡3090能不能finetune Mixtral?
很多时候我们可能需要调试代码,或者自学过程中,没有A100,能不能finetune Mixtral 8x7B?答案是不能。不过别急,只要用我拆开的模型,就能流畅的在单张3090上进行微调。下面是代码(基于learn-llm项目),实测finetune 2个专家模型只需13G显存,3个专家模型需要18G显存,4个专家模型需要21G显存:
from trl import SFTTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments,AutoConfig
from datasets import load_dataset
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
## load model in 4bit
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
)
config = AutoConfig.from_pretrained('fxmeng/Mixtral-2x7B-Instruct-v0.1')
config.use_cache = False
config.gradient_checkpointing = True
model = AutoModelForCausalLM.from_pretrained('fxmeng/Mixtral-2x7B-Instruct-v0.1',
config=config,
quantization_config=bnb_config,
trust_remote_code=False,
torch_dtype=torch_dtype,
device_map="auto")
tokenizer = AutoTokenizer.from_pretrained('fxmeng/Mixtral-2x7B-Instruct-v0.1',
trust_remote_code=False,
use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
# lora
peft_config = LoraConfig(
r=64,
lora_alpha=16,
lora_dropout=0.1,
target_modules=['k\_proj', 'w3', 'v\_proj', 'gate', 'lm\_head', 'q\_proj', 'w2', 'w1', 'o\_proj'],
bias="none",
task_type="CAUSAL\_LM",
inference_mode=False
)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, peft_config)
# training setting
training_args = TrainingArguments(
do_train=True,
do_eval=True,
output_dir="./checkpoints",
dataloader_drop_last=True,
evaluation_strategy="steps",
save_strategy="steps",
logging_strategy="steps",
num_train_epochs=1,
eval_steps=10,
save_steps=10,
logging_steps=10,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
optim="paged\_adamw\_8bit",
learning_rate=1e-4,
lr_scheduler_type='constant',
warmup_steps=50,
gradient_accumulation_steps=1,
gradient_checkpointing=True,
weight_decay=0.05,
report_to="wandb",
load_best_model_at_end=True,
save_total_limit=1,
bf16=torch_dtype==torch.bfloat16,
fp16=torch_dtype!=torch.bfloat16,
)
# training and evaluation use the same dataset
dataset = load_dataset("imdb", split="train")
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=dataset,
dataset_text_field="text",
max_seq_length=1024,
tokenizer=tokenizer,
data_collator=None,
packing=None
)
trainer.train()
trainer.save_model('mixtral\_2x7b')
4个experts用来学习和调试代码的话,效果看起来还不错:
from transformers import AutoTokenizer
import transformers
import torch
model = "fxmeng/Mixtral-4x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
"text-generation",
model=model,
model_kwargs={"torch\_dtype": torch.float16, "load\_in\_4bit": True},
)
messages = [{"role": "user", "content": "Explain what a Mixture of Experts is in less than 100 words."}]
prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated\_text"])
from transformers import AutoTokenizer
import transformers
import torch
model = "fxmeng/Mixtral-4x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
"text-generation",
model=model,
model_kwargs={"torch\_dtype": torch.float16, "load\_in\_4bit": True},
)
learn-llm github: https://github.com/hengjiUSTC/learn-llm/tree/main
问题7. MoE每层之间的expert有无关联?
按照前文的猜测,在一个模型被训练到一定阶段后,被复制成了8份,此时每一层参数的每个expert参数一样,然后继续训练,此时每个expert不再是独立的个体,因此最终每层之间的expert不再有关联。表现为打乱每层expert的顺序,不影响模型的输出。因此可以说Mixtral-8x7B不是8个experts,而是32x8个experts。那么问题就来了,这256个experts对模型输出的影响是多大?这个问题如果还像之前一样一个一个的挖掉,计算量就太大了。因此我使用了另一种方法:将整个数据集输入模型,不再需要进行生成,而是记录下每个token在每一层激活的expert。然后就能得到一个32x8的矩阵,每个元素表示这个第层,第个expert被激活的次数。然后每层按照每个expert被激活的次数从大到小,可以认为其重要性逐渐降低。
# 作为top-1 expert被激活的次数排序
[[1, 2, 3, 6, 0, 4, 7, 5],
[2, 1, 5, 6, 7, 4, 3, 0],
[6, 5, 2, 4, 0, 1, 7, 3],
[6, 4, 0, 3, 7, 2, 5, 1],
[3, 6, 1, 4, 0, 7, 2, 5],
[4, 3, 6, 7, 0, 5, 1, 2],
[5, 3, 6, 1, 4, 0, 2, 7],
[5, 2, 1, 4, 6, 0, 7, 3],
[5, 7, 3, 1, 6, 4, 2, 0],
[4, 6, 5, 0, 1, 2, 3, 7],
[4, 0, 3, 1, 6, 5, 2, 7],
[5, 1, 4, 0, 7, 6, 2, 3],
[4, 1, 3, 6, 0, 2, 5, 7],
[5, 1, 0, 7, 6, 3, 4, 2],
[3, 7, 6, 0, 5, 2, 1, 4],
[7, 0, 6, 2, 5, 1, 3, 4],
[7, 1, 5, 3, 6, 2, 0, 4],
[0, 7, 5, 6, 1, 4, 2, 3],
[3, 7, 2, 1, 4, 0, 6, 5],
[0, 4, 2, 7, 6, 5, 1, 3],
[3, 7, 5, 6, 0, 1, 2, 4],
[3, 5, 0, 1, 6, 4, 2, 7],
[7, 1, 4, 0, 3, 6, 5, 2],
[1, 3, 5, 6, 0, 2, 7, 4],
[5, 2, 3, 1, 7, 6, 4, 0],
[1, 4, 0, 3, 6, 5, 2, 7],
[2, 1, 4, 3, 5, 0, 7, 6],
[3, 5, 2, 0, 7, 4, 1, 6],
[2, 3, 4, 1, 7, 5, 0, 6],
[4, 1, 5, 2, 6, 7, 3, 0],
[0, 7, 1, 4, 5, 3, 2, 6],
[2, 6, 5, 0, 7, 1, 3, 4]]
# 作为top-2 expert被激活的次数排序
[[5, 2, 7, 3, 4, 0, 6, 1],
[3, 5, 2, 4, 7, 1, 0, 6],
[5, 1, 6, 7, 2, 0, 3, 4],
[0, 3, 4, 5, 7, 1, 6, 2],
[0, 3, 7, 2, 5, 6, 4, 1],
[1, 7, 2, 5, 4, 6, 3, 0],
[6, 0, 7, 3, 2, 5, 4, 1],
[3, 2, 4, 7, 0, 6, 5, 1],
[4, 0, 3, 2, 7, 5, 6, 1],
[5, 2, 4, 3, 1, 0, 7, 6],
[5, 1, 7, 3, 6, 2, 4, 0],
[6, 0, 2, 1, 7, 5, 4, 3],
[0, 1, 2, 5, 4, 3, 6, 7],
[1, 4, 2, 0, 5, 3, 7, 6],
[5, 1, 0, 3, 2, 4, 6, 7],
[2, 5, 4, 6, 1, 0, 7, 3],
[7, 6, 4, 1, 5, 2, 0, 3],
[1, 2, 7, 0, 6, 5, 4, 3],
[6, 4, 0, 1, 5, 7, 3, 2],
[7, 5, 6, 2, 0, 3, 1, 4],
[2, 1, 3, 0, 5, 6, 7, 4],
[0, 1, 3, 7, 2, 4, 6, 5],
[1, 3, 7, 2, 0, 4, 6, 5],
[5, 2, 1, 0, 3, 7, 4, 6],
[2, 5, 1, 6, 7, 3, 0, 4],
[6, 7, 2, 5, 0, 1, 3, 4],
[2, 5, 0, 4, 6, 7, 3, 1],
[6, 3, 5, 4, 2, 1, 0, 7],
[0, 1, 7, 6, 2, 5, 4, 3],
[3, 4, 0, 2, 7, 1, 5, 6],
[4, 3, 2, 0, 7, 6, 5, 1],
[5, 7, 3, 4, 2, 0, 6, 1]]
将模型的每一层按照2 x top1 + top2次数排序,能得到一个新的模型,其输出和原模型一致,但是每层的第一个expert都是这层里面最重要的,最后一个expert都是这层里面最不重要的。然后就可以用问题4中的方法,根据自己的硬件条件,选择对应数量的experts。使用新的方法获得的子结构比之前把每个expert当作一个整体,使用贪心法得到的子结构精度损失明显减小:
写在最后
本文简单介绍了一下Mixtral 8x7B模型的原理,随后通过5个问题,试图还原其训练流程,了解模型特性。用贪心算法搜出了1-7个专家的子结构,推测Mixtral-8x7B是利用Mistral-7B训练过程中early-stage的checkpoint训练而来的。受限于笔者水平,肯定有不足之处,一家之言,恳请大家批评指正。
欢迎多多关注公众号「NLP工作站」,欢迎加入交流群,有问题的朋友也欢迎加我微信「logCong」私聊,交个朋友吧,一起学习,一起进步。我们的口号是“生命不止,学习不停”。
PS:新书已出《ChatGPT原理与实战》,欢迎购买。
往期推荐:
- 大模型微调技巧 | 高质量指令数据筛选方法-MoDS
- 辟谣!微软撤回声称ChatGPT为20B参数的论文,并给出解释。
- 如何看待微软论文声称 ChatGPT 是 20B (200亿) 参数量的模型?
- 大模型微调技巧-在Embeeding上加入噪音提高指令微调效果
- 如何从数据集中自动识别高质量的指令数据
- BaiChuan2技术报告细节分享&个人想法
- 大模型LLM微调经验总结&项目更新
- 打造LLM界的Web UI
- 是我们在训练大模型,还是大模型在训练我们?
- Llama2技术细节&开源影响
- 大模型时代-行业落地再思考
- 垂直领域大模型的一些思考及开源模型汇总
- 如何评估大模型-LLMs的好坏?
- 总结|Prompt在NER场景的应用