啥?小米也开源大模型啦!

大模型向量数据库机器学习

大家好,我是刘聪NLP。

刚刚刷到的,小米也开源了大模型-MiMo-7B。

HF: https://huggingface.co/XiaomiMiMo

Paper: https://github.com/XiaomiMiMo/MiMo

这是一个参数级别为7B的系列模型,共包括4个,Base,Base-Zero,SFT和RL。

  • MiMo-7B-Base:预训练模型
  • MiMo-7B-SFT:监督微调模型
  • MiMo-7B-Base-Zero:基于MiMo-7B-Base直接强化学习的模型
  • MiMo-7B-RL:基于MiMo-7B-SFT强化学习的模型

picture.image

模型结构是Dense模型,Transfomer-Decoder,但引入采用MTP(Multi-Token Prediction)机制,以加速推理。

picture.image

  
class MiMoMTPLayers(nn.Module):  
    def \_\_init\_\_(self, config):  
        super().\_\_init\_\_()  
        self.input\_layernorm = Qwen2RMSNorm(config.hidden\_size, eps=config.rms\_norm\_eps)  
        self.post\_attention\_layernorm = Qwen2RMSNorm(config.hidden\_size, eps=config.rms\_norm\_eps)  
        self.token\_layernorm = Qwen2RMSNorm(config.hidden\_size, eps=config.rms\_norm\_eps)  
        self.hidden\_layernorm = Qwen2RMSNorm(config.hidden\_size, eps=config.rms\_norm\_eps)  
        self.input\_proj = nn.Linear(config.hidden\_size * 2, config.hidden\_size, bias=False)  
        self.final\_layernorm = Qwen2RMSNorm(config.hidden\_size, eps=config.rms\_norm\_eps)  
        self.self\_attn = Qwen2Attention(config, layer\_idx=0)  
        self.mlp = Qwen2MLP(config)  
  
    def forward(self, input\_embeds,  
                    hidden\_states,  
                    attention\_mask,  
                    position\_ids,  
                    past\_key\_values: Optional[Cache]=None,  
                    output\_attentions: Optional[bool]=False,  
                    use\_cache: Optional[bool]=False,  
                    position\_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  
                    cache\_position=None,  
                    **kwargs):  
        input\_embeds = self.token\_layernorm(input\_embeds)  
        previous\_hidden\_states = self.hidden\_layernorm(hidden\_states)  
        hidden\_states = self.input\_proj(torch.cat([previous\_hidden\_states, input\_embeds], dim=-1))  
        residual = hidden\_states  
        hidden\_states = self.input\_layernorm(hidden\_states)  
        hidden\_states, \_ = self.self\_attn(hidden\_states,  
                                       attention\_mask=attention\_mask,  
                                       position\_ids=position\_ids,  
                                       past\_key\_values=past\_key\_values,  
                                       output\_attentions=output\_attentions,  
                                       use\_cache=use\_cache,  
                                       cache\_position=cache\_position,  
                                       position\_embedding=position\_embedding,  
                                       **kwargs)  
        hidden\_states = residual + hidden\_states  
        residual = hidden\_states  
        hidden\_states = self.post\_attention\_layernorm(hidden\_states)  
        hidden\_states = self.mlp(hidden\_states)  
        hidden\_states = residual + hidden\_states  
        hidden\_states = self.final\_layernorm(hidden\_states)  
        return hidden\_states  
  
  
class MiMoModel(Qwen2Model):  
    config\_class = MiMoConfig  
  
    def \_\_init\_\_(self, config: MiMoConfig):  
        super().\_\_init\_\_(config)  
        self.mtp\_layers = nn.ModuleList([MiMoMTPLayers(config) for \_ in range(config.num\_nextn\_predict\_layers)])  
  
  
class MiMoForCausalLM(Qwen2ForCausalLM):  
    config\_class = MiMoConfig  
    def \_\_init\_\_(self, config: MiMoConfig):  
        super(Qwen2ForCausalLM, self).\_\_init\_\_(config)  
        self.model = MiMoModel(config)  
        self.vocab\_size = config.vocab\_size  
        self.lm\_head = nn.Linear(config.hidden\_size, config.vocab\_size, bias=False)  
  
        self.post\_init()  

模型预训练阶段,整体数据量为25T Tokens,

预训练分为3个阶段:

  • 阶段1:通用数据训练,约19T Tokens数据,在此过程中对知识密度和推理深度不足的数据(广告、新闻、招聘信息等)进行了降采样,对专业领域的高质量数据进行上采样。
  • 阶段2:代码和数学数据针对性训练,约4T Tokens数据,为了不影响模型通用能力,混入部分通用数据,比例为7:3。
  • 阶段3:合成推理数据训练,约2T Tokens数据,主要为了提高模型解决复杂任务的能力,加入了合成的数学、代码和创造性写作数据,并将上下文长度从8192扩展到32768。

后训练阶段,主要就是SFT和RL。

SFT的数据大概有50K,主要来源就是开源数据和蒸馏数据,为了保证数据的多样性,经过了一系列数据清洗,最后在MiMo-7B-Base模型上训练。训练学习率为3e-5,Batch为128,样本会被pack到32k长度训练。

RL阶段包括100K的数学数据和30K的代码数据,均经过严格处理,去除大部分简单数据(SFT模型 rollout 16次都对的数据)。

采用修改的GRPO方法进行训练,对于每个问题

,算法从旧策略

中采样一组响应

,并通过最大化以下目标函数来更新策略

其中,

是超参数。

是优势值(advantage),通过同一组响应的奖励

计算得出:

改进优化点主要是移除 KL 损失、动态采样、Clip-Higher,同时在训练过程中使用基于测试复杂度驱动的奖励函数和简单的数据重采样方法。

效果上,base模型如下(没跟Qwen3比,做实验时Qwen3还没出):

picture.image

在数学(AIME 24-25)和 代码(LiveCodeBench v5)上,MiMo-7B-RL, o1-mini 和 QwQ-32B-Preview。

picture.image

最后,小米也开源啦,正式开卷大模型赛道了。

7B模型还是小了点,期望开源更多更大的模型,一会儿就去测一下,看看行不行!

PS:看到这里,如果觉得不错,可以来个点赞在看关注 。 给公众号添加【星标⭐️】不迷路!您的支持是我坚持的最大动力!

欢迎多多关注公众号「刘聪NLP」,加入交流群,交个朋友吧,一起学习,一起进步!

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
亿万用户下高可用融合直播的应用实践
直播融合 CDN 调度系统承担了公司内所有直播流量的接入工作,对高并发高带宽场景支持友好,有完善的体系进行容灾降级、质量优化、成本优化。本次演讲将带大家了解直播融合 CDN 调度系统的整体架构及在抖音上的应用。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论