大家好,我是刘聪NLP。
刚刚刷到的,小米也开源了大模型-MiMo-7B。
HF: https://huggingface.co/XiaomiMiMo
Paper: https://github.com/XiaomiMiMo/MiMo
这是一个参数级别为7B的系列模型,共包括4个,Base,Base-Zero,SFT和RL。
- MiMo-7B-Base:预训练模型
- MiMo-7B-SFT:监督微调模型
- MiMo-7B-Base-Zero:基于MiMo-7B-Base直接强化学习的模型
- MiMo-7B-RL:基于MiMo-7B-SFT强化学习的模型
模型结构是Dense模型,Transfomer-Decoder,但引入采用MTP(Multi-Token Prediction)机制,以加速推理。
class MiMoMTPLayers(nn.Module):
def \_\_init\_\_(self, config):
super().\_\_init\_\_()
self.input\_layernorm = Qwen2RMSNorm(config.hidden\_size, eps=config.rms\_norm\_eps)
self.post\_attention\_layernorm = Qwen2RMSNorm(config.hidden\_size, eps=config.rms\_norm\_eps)
self.token\_layernorm = Qwen2RMSNorm(config.hidden\_size, eps=config.rms\_norm\_eps)
self.hidden\_layernorm = Qwen2RMSNorm(config.hidden\_size, eps=config.rms\_norm\_eps)
self.input\_proj = nn.Linear(config.hidden\_size * 2, config.hidden\_size, bias=False)
self.final\_layernorm = Qwen2RMSNorm(config.hidden\_size, eps=config.rms\_norm\_eps)
self.self\_attn = Qwen2Attention(config, layer\_idx=0)
self.mlp = Qwen2MLP(config)
def forward(self, input\_embeds,
hidden\_states,
attention\_mask,
position\_ids,
past\_key\_values: Optional[Cache]=None,
output\_attentions: Optional[bool]=False,
use\_cache: Optional[bool]=False,
position\_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cache\_position=None,
**kwargs):
input\_embeds = self.token\_layernorm(input\_embeds)
previous\_hidden\_states = self.hidden\_layernorm(hidden\_states)
hidden\_states = self.input\_proj(torch.cat([previous\_hidden\_states, input\_embeds], dim=-1))
residual = hidden\_states
hidden\_states = self.input\_layernorm(hidden\_states)
hidden\_states, \_ = self.self\_attn(hidden\_states,
attention\_mask=attention\_mask,
position\_ids=position\_ids,
past\_key\_values=past\_key\_values,
output\_attentions=output\_attentions,
use\_cache=use\_cache,
cache\_position=cache\_position,
position\_embedding=position\_embedding,
**kwargs)
hidden\_states = residual + hidden\_states
residual = hidden\_states
hidden\_states = self.post\_attention\_layernorm(hidden\_states)
hidden\_states = self.mlp(hidden\_states)
hidden\_states = residual + hidden\_states
hidden\_states = self.final\_layernorm(hidden\_states)
return hidden\_states
class MiMoModel(Qwen2Model):
config\_class = MiMoConfig
def \_\_init\_\_(self, config: MiMoConfig):
super().\_\_init\_\_(config)
self.mtp\_layers = nn.ModuleList([MiMoMTPLayers(config) for \_ in range(config.num\_nextn\_predict\_layers)])
class MiMoForCausalLM(Qwen2ForCausalLM):
config\_class = MiMoConfig
def \_\_init\_\_(self, config: MiMoConfig):
super(Qwen2ForCausalLM, self).\_\_init\_\_(config)
self.model = MiMoModel(config)
self.vocab\_size = config.vocab\_size
self.lm\_head = nn.Linear(config.hidden\_size, config.vocab\_size, bias=False)
self.post\_init()
模型预训练阶段,整体数据量为25T Tokens,
预训练分为3个阶段:
- 阶段1:通用数据训练,约19T Tokens数据,在此过程中对知识密度和推理深度不足的数据(广告、新闻、招聘信息等)进行了降采样,对专业领域的高质量数据进行上采样。
- 阶段2:代码和数学数据针对性训练,约4T Tokens数据,为了不影响模型通用能力,混入部分通用数据,比例为7:3。
- 阶段3:合成推理数据训练,约2T Tokens数据,主要为了提高模型解决复杂任务的能力,加入了合成的数学、代码和创造性写作数据,并将上下文长度从8192扩展到32768。
后训练阶段,主要就是SFT和RL。
SFT的数据大概有50K,主要来源就是开源数据和蒸馏数据,为了保证数据的多样性,经过了一系列数据清洗,最后在MiMo-7B-Base模型上训练。训练学习率为3e-5,Batch为128,样本会被pack到32k长度训练。
RL阶段包括100K的数学数据和30K的代码数据,均经过严格处理,去除大部分简单数据(SFT模型 rollout 16次都对的数据)。
采用修改的GRPO方法进行训练,对于每个问题
,算法从旧策略
中采样一组响应
,并通过最大化以下目标函数来更新策略
,
其中,
和
是超参数。
是优势值(advantage),通过同一组响应的奖励
计算得出:
改进优化点主要是移除 KL 损失、动态采样、Clip-Higher,同时在训练过程中使用基于测试复杂度驱动的奖励函数和简单的数据重采样方法。
效果上,base模型如下(没跟Qwen3比,做实验时Qwen3还没出):
在数学(AIME 24-25)和 代码(LiveCodeBench v5)上,MiMo-7B-RL, o1-mini 和 QwQ-32B-Preview。
最后,小米也开源啦,正式开卷大模型赛道了。
7B模型还是小了点,期望开源更多更大的模型,一会儿就去测一下,看看行不行!
PS:看到这里,如果觉得不错,可以来个点赞 、在看 、关注 。 给公众号添加【星标⭐️】不迷路!您的支持是我坚持的最大动力!
欢迎多多关注公众号「刘聪NLP」,加入交流群,交个朋友吧,一起学习,一起进步!