Matryoshka 与 Mamba2的融合:MatMamba在语言与图像模型上的突破 !

大模型机器学习数据库

点击下方卡片,关注 「AI视界引擎」 公众号

( 添加时备注:方向+学校/公司+昵称/姓名 )

picture.image

picture.image

状态空间模型(SSMs)如Mamba2是 Transformer 的有前景的替代方案,具有更快的理论训练和推理时间 - 尤其是对于长上下文长度。

最近关于Matryoshka表示学习的工作 - 以及其在MatFormer等工作中应用于 Transformer Backbone 的应用 - 展示了如何在通用弹性模型中引入嵌套的小型子模型的层次结构。

在本工作中,作者提出了MatMamba:一种结合了Matryoshka风格学习与Mamba2的状态空间模型,通过修改块以包含嵌套维度来实现联合训练和自适应推理。

MatMamba允许在各种模型大小上实现高效和自适应部署。

作者训练了一个单一的大型MatMamba模型,并能够免费获得多个较小的嵌套模型 - 同时保持或改进了从零训练的 Baseline 较小模型的性能。

作者在35M到1.4B的参数大小的语言和图像模型上进行了训练。

作者的ImageNet和FineWeb结果表明,MatMamba模型与 Transformer 具有可比扩展性,同时具有更高效推理特性。

这使得MatMamba成为在可用的推理计算基础上以弹性方式部署大规模模型的实际可行选项 。

代码和模型已在https://github.com/ScaledFoundations/MatMamba开源。

1 Introduction

深度学习从业者通常会训练同一类模型的不同大小,以适应各种可用的推理计算范围。例如,Llama 3.2系列包括10亿、30亿、100亿和900亿的变化。这些模型单独都非常强大,但由于独立训练,它们不一定具有相同的度量空间,这可以在推理应用中非常有益,如预测解码,混合云边缘推理,或者一般的输入或计算自适应处理。此外,由于训练这些模型非常昂贵,作者通常只选择训练几个大小。在部署设置可以最佳支持中间模型 的情况下,但不得不选择更不准确的10亿模型。

压缩和蒸馏等方法旨在解决这些问题,但需要额外的训练(有时数据不可用),并且可能会降低准确性。因此,在中间粒度提供自适应推理的方法非常有用。这已经在 Transformer 和卷积神经网络中进行了探索。本工作的核心重点是试图在新型架构 Mamba2 中实现开箱即用的自适应推理。

像Mamba2这样的状态空间模型(Dao & Gu,2024年)以及其他一些相关的新型架构在努力提高 Transformer 的效率的同时,保持其准确和通用的序列处理架构的力量方面具有巨大的潜力。Mamba2具有与Transformer相当的扩展特性,同时在较长的时间步长下具有显著的加速效果。

在这项工作中,作者提出了MatMamba,它是一种嵌套的Matryoshka结构,位于Mamba2块内部。MatMamba可以在部署时从相同的一组权重中提取数百个嵌套子模型,而无需进行任何额外的训练。MatMamba是一种通用的序列处理架构,可以应用于任何类型的模型(编码器/解码器),模态(语言/视觉/声音/动作),损失函数,或与Transformer或Mamba2层兼容的学习算法。

哲学上与MatMamba最接近的工作是MatFormer 它在Transformer层中的FFN块上施加嵌套结构。作者使用与MatMamba块中的学习参数依赖于块的隐层维数相同的概念,对Mamba2块中的任何可学习参数施加嵌套结构。

形式上,MatMamba块由嵌套组合的个Mamba2块组成,即,其中意味着子块的所有参数都存在于中。作者使用个带有梯度累积的前向传播和参数更新的单向传播进行训练(见图1)。

picture.image

通过联合训练所有个粒度,最小的子块被激励去表示最重要的信息,就像在Matryoshka表示学习(Kusupati等人,2022年)中一样。现在作者可以灵活地使用个嵌套子块。此外,作者还可以灵活地将块沿着任何维数(甚至超出明确优化的粒度)进行切割。

使用Mix'n'Match(第3.4节),作者可以在不同粒度的多个层上进行此操作,以灵活地从单个较大模型中提取出组合数量巨大的多个模型。作者观察到这些提取的模型保留了较大模型的度量空间,并且在各种测试任务上都是准确的 - 实际上,这使作者能够在模型性能和计算之间进行权衡。

作者训练了基于MatMamba的视觉模型(MatMamba-Vision),并发现:

(a)在所有的粒度下,MatMamba-Vision模型与 Baseline Mamba2基于模型具有相同的扩展能力;

(b)使用Mix'n'Match,作者可以在明确优化的粒度之间灵活地提取子模型。这些子模型跨越(有时甚至超过)帕累托最优的计算-精度曲线;

(c)MatMamba-Vision模型在更高分辨率下比ViTs显著更快,使其成为长文和高质量视觉任务的 promising 候选者,同时支持嵌套子模型的自适应视觉处理。

此外,MatMamba-Vision模型可以作为弹性图像编码器用于自适应图像检索。作者可以用最大的模型来编码视觉数据集,因为较小的子模型共享其度量空间,作者可以将它们作为 Query 编码器,在计算量大幅度降低的同时,准确率几乎无损。

作者还训练了不同大小的基于MatMamba的解码语言模型(MatMamba-LM),参数范围从1.3亿到140亿,粒度 Level 为4。在这里,作者也观察到类似的现象,即对于所有嵌套粒度,MatMamba-LM模型与Mamba2 Baseline 具有相同的架构时,其规模也相当。作者还观察到不同模型之间嵌套粒度之间的有趣的一致缩放行为。

通过MatMamba,作者首次将Matryoshka式学习的适应性以及像Mamba2这样的状态空间模型(SSMs)的有效性结合在一起。

作者的研究贡献如下:

  1. 作者提出了MatMamba,它将一个嵌套的Matryoshka结构应用在Mamba2状态空间模型上。作者同时优化了所有嵌套的粒度,以训练一个单一的弹性模型。
  2. 作者证明了在语言和视觉任务中,从35M到1.4B参数的各种模型大小的MatMamba模型与 Baseline Mamba2模型具有相同的扩展性。
  3. 使用Mix'n'Match与MatMamba,可以灵活地提取出数百个子模型以进行自适应推理。这些子模型保留了原始模型的度量空间。
  4. MatMamba-Vision模型在高分辨率时的准确性和速度与ViTs相当,且明显优于ViTs,使其非常适合长文本/高分辨率以及自适应视觉处理。

2 Related Work

随着AI模型在各种准确性和资源约束下需求的不断增长,为每个用例训练不同的模型变得不切实际。相反,这些自适应部署需求通常通过在模型中引入弹性来解决。关于可收缩网络和一次性为所有网络的工作带来了在单个通用模型中训练多个子模型的想法。嵌套dropout 将这一想法扩展到学习有序表示,进一步通过Matryoshka表示学习(MRL)实现弹性。MRL通过引入少量嵌套粒度(因此得名Matryoshka),它们在大小上以指数 Level 相互分离,并使用与全向量相同的优化目标损失函数简化训练过程,从而实现弹性。MRL还可以平滑地扩展到训练期间未见到的粒度,从而允许根据需求提取子向量。

Matryoshka信息打包和学习已在广泛应用中实现适应性,不仅体现在输出空间,还体现在输入和模型权重。MatFormer 是MRL直接转换到 Transformer 层中的每个隐藏激活向量。MatFormer 显示了类似于Transformer的扩展趋势,同时提供了能够自适应提取位于准确性-计算帕累托曲线上的子模型的能力。近年来的一些工作在MatFormer支持下,开发了在条件计算上实现动态路由以实现部署中的性能提升。此外,Matryoshka打包也被用于灵活的 Token 化(Cai等人,2024年;Hu等人,2024年)以及扩散模型(Gu等人,2023年)。

近年来,Transformer(Vaswani等人,2017年)已成为神经网络中基本的序列处理模块。最近出现了一系列旨在实现更快且性能与Transformer相当的高效序列处理架构的工作,如Mamba(Gu和Dao,2023年)、Mamba2(Dao和Gu,2024年)等,此外还有其他非常相关的作品,如线性关注(Katharopoulos等人,2020年)、测试时训练、RWKV(Peng等人,2023年)、 Griffin(De等人,2024年)、Jamba(Lieber等人,2024年)、XLSTM(Beck等人,2024年)、HGRN2(Qin等人,2024年)、RetNet 、RecurrentGemma(Botev等人,2024年)。Waleffe等人(2024年)详细研究了如何训练大规模的基于Mamba的语言模型。

MambaVision、MambaND(Li等人,2024年)、Vision Mamba、VideoMamba(Li等人,2024年)和 Sonic 等作品都展示了如何使用Mamba层处理视觉数据和其他模态。刘等人(2024年)对基于Mamba的视觉模型进行了详细调查。

3 MatMamba

Mamba2 Preliminaries

MatMamba 是基于 Mamba2 的。作者对 Mamba2 块进行了简单的修改,以实现马特罗什卡结构。关于 Mamba2 内部结构的详细描述,请参见原始论文 Dao 和 Gu (2024)。然而,在本工作中,作者将 Mamba2 块视为一个输入线性投影(,可以分解为 ,,,,),一个具有核大小 4 的因果 1D 卷积层(权重是 , 和 的 ConCat 组合,在组内应用),一个块 + 选择性扫描操作(SSM),以及一个输出投影层()。与 Transformer 类似,这个块输入一个 形状的张量 - 是批量大小, 是序列长度, 是维数 - 经过序列变换后产生一个 形状的输出。对于输入张量 ,Mamba2 块 包括以下步骤:

Training

为了训练一个由MatMamba块组成的模型,针对个选择的粒度,作者执行次前向传播来计算一个联合损失函数。对于输入,模型,目标和损失函数:

picture.image

在本文中,作者训练了g=4个嵌套子模型,每个子模型的权重λi均为0.25(即g)。如图1所示,在每次前向传播过程中,作者累加梯度。参数更新通过一次反向传播完成。在整个过程中,模型和权重保持不变,从而使内存使用与普通Mamba2块相同。在本工作中,作者使用具有g=4嵌套粒度的MatMamba模型,相应的m_i列表为[d_model, d_model/2, d_model/4, d_model/8],即每个子模型的维数减半。

与MatFormer(Devvrit等人,2023年)和Flextron(Cai等人,2024b)类似,作者指出,也可以对现有预训练模型进行微调以生成嵌套结构。

然而,在本工作中,作者关注从零开始训练,以研究MatMamba模型的扩展特性。

Mix'n'Match

作者可以将来自MatFormer(Devvrit等人,2023年)的Mix'n'Match策略应用到MatMamba中,以灵活地从MatMamba中提取任何子模型进行推理。具体来说,对于具有L层的模型f,作者每个层需要选择一个维数m_{i}。值得注意的是,m_{i}可以是显式优化的g粒度之一(例如,从135M-MatMamba-Vision模型中选择[1024, 512, 256, 128]中的一个,见第4.1节),也可以选择不是显式优化的插值维数(例如,选择任何随机有效值,如768或384,这些值没有显式训练)。

Elastic Inference

在部署MatMamba模型进行推理时,作者通常需要将单个通用模型存储在内存中。如果计算不受限制(或推理工作负载可预测),则可以使用完整的模型来获得最准确的结果。然而,根据动态约束(例如可用的推理计算、能耗、系统负载、所需精度等),作者可以在运行时对网络中选定的切片进行前向传播。

结合云边推理,作者有令人激动的可能性,例如在边缘设备上存储较小的模型 ,并在必要时使用在云端的大型模型 ,或者使用较小的模型作为推测解码的草稿模型,与较大的验证模型相结合(Leviathan等人,2023)。作者还可以潜在地实现输入自适应子模型选择(例如,对于更难的输入,使用更大的模型)。这些可能性之所以能够实现,是因为MatMamba具有一致且嵌套的Matryoshka结构,其中所有子模型共享相同的度量空间。

在本节中,作者展示了基于MatMamba的模型在两种模态(MatMamba-Vision 和 MatMamba-LM)上的有效性。对于视觉,作者展示了图像分类(第4.1.1节)和自适应图像检索(第4.1.2节)的结果。对于语言,作者训练了解码语言模型(第4.2节)。作者从35M到1.4B参数的多种尺度训练了模型。为了进行公平的比较,作者还独立训练了具有与每个MatMamba粒度子模型相同架构的 Baseline Mamba2模型。

请注意,在本工作中,作者并未针对所选模型大小在语言或视觉上实现最先进的结果。相反,作者关注诸如嵌套结构一致性、参数减少、子模型的推理速度/内存使用以及使用MatMamba块构建的简单网络的扩展性等属性。

MatMamba-Vision

图2中的MatMamba-Vision包含一个patch embedding,后面跟着L个MatMamba块,采用单向SSM扫描。作者做出一个关键的设计选择,就是使用[CLS] Token 作为后缀而不是传统的前缀。这使得它可以关注整个序列的信息。作者发现这种简单的架构在图像分类和自适应检索方面都有效。作者在ImageNet-1k数据集上训练了两种模型变体(35M,,和135M,,见表1),其中patch大小为16,层为20层。与MambaVision等视觉任务中的SSM研究(如Hatamizadeh和Kautz,2024;Li等人,2024;Zhu等人,2024)相比 - 所有这些在Mamba层上都有重大设计变化,如双向扫描、额外的投影、扫描顺序的变化,或将SSM层与注意力和卷积层结合 - 作者保持了网络架构尽可能简单。

picture.image

picture.image

作者使用FFCV(Leclerc等人,2023年)数据加载器进行高效的训练。作者应用了像RandAug(Cubuk等人,2020年),Random Erasing(Zhong等人,2020年),Mixup(Zhang,2017年),Cutmix(Yun等人,2019年)以及遵循DEiT-3(Touvron等人,2022年),AugReg(Steiner等人,2021年)和更好的ViT基础(Beyer等人,2022年)的各种设置。完整的实验设置可以在附录A中查看。

4.1.1 Image Classification

在图3中,作者可以看到对于35M和135M的MatMamba-Vision模型,明确优化的子模型与具有相同嵌套子模型架构的4个独立训练的基准模型非常接近。然而,作者不需要4个单独的模型,作者可以在单个模型中灵活地获得所有性能 Level /参数计数。

picture.image

自适应推理 using Mix'n'Match: 此外(如图3所示),在各种组合的粒度下使用Mix'n'Match可以得到模型,这些模型可以平滑地插值(有时甚至超过)在连接明确优化的粒度线的精度。这表明了强大的自适应性,因为作者可以在精度计算曲线上提取出一个组合上很大的子模型集。作者可以灵活地为部署约束优化子模型选择,同时只使用单个嵌套的通用模型的权重。

在图4中,作者还研究了与彼此和ViT-B/16模型相比,MatMamba-Vision模型的推理速度折衷。作者发现,在或低于512px的分辨率下,由于GPU并行性和优化如FlashAttention,ViT是最快的模型。然而,当作者将分辨率提高到1024px及以上时,Mamba风格模型在吞吐量和延迟方面开始优于ViT。作者还研究了推理内存使用,并发现MatMamba-Vision在分辨率增加时,其扩展略好于优化后的ViT-B/16。

这两个观察结果都提供了有前景的证据,表明基于MatMamba的模型可以在单个加速器上处理更高分辨率下的更长的视觉数据序列(相反,像RingAttention Liu等人(2023)那样在Transformer中扩展上下文长度需要多个互联加速器进行单个前向传播在长序列长度时)。

picture.image

4.1.2 Adaptive Image Retrieval

图像检索旨在利用预训练编码器生成的表示,根据语义相似性定位相似的图像(Chen等人,2022)。标准方法涉及使用相同的编码器对数据库和 Query 图像进行编码,然后执行最近邻检索。使用强大的数据库图像编码器是可行的,但 Query 编码器必须对实时应用高效。此外, Query 编码场景可以有所不同,例如设备上的处理与云处理,以及 Query 负载和复杂性的变化。固定编码器的现有解决方案通常在不同设置下牺牲准确度或成本。

由于其灵活性,MatMamba-Vision是 Query 编码的有望候选者。然而,检索还要求子模型在各种粒度上保持固定数据库(使用更大编码器编码)和 Query 嵌入之间的距离关系。仅使用较小的基准Mamba2模型仅进行 Query 编码可能导致显著的距离保留问题和检索准确率低(如图5所示)。

picture.image

作者在ImageNet-1K上评估了 Baseline 和MatMamba-Vision编码器在35M和135M参数尺度的图像检索性能。使用[CLS] Token 表示,作者计算了1-最近邻(NN)精度。图5表明,从MatMamba中提取的子模型可以有效地保留距离并提供更大的灵活性。

例如,MatMamba-Vision-135M可以在最小准确率损失小于0.5%的情况下,将计算成本降低55%。虽然带有后缀[CLS]的因果模型在检索方面的准确性与双向编码器可能不相上下,但这朝着更好的长文本编码器迈出了有益的一步,同时实现了自适应 Query 处理。

MatMamba-LM

作者使用MatMamba块(MatMamba-LM)训练解码语言模型。

在图6中,作者可以看到MatMamba-LM模型与Mamba2模型在训练 Token (token)以及最大粒度下呈线性增长。在图7中,作者同样看到对于所有粒度,每个粒度的最终训练模型以及使用相同架构训练的 Baseline 模型都呈线性增长。此外,作者还观察到图6中每个嵌套粒度的验证损失(val loss)在最大模型()和最小模型()之间呈现出相似的距离(通常为0.4的delta),对于中间模型具有一致的间隙。

考虑到在大且多样化的数据集上的验证损失是语言模型性能的最强 Agent (而不是在噪声数据集上的少样本评估),这些扩展趋势提供了非常令人信服的证据,即可以利用一个嵌套的MatMamba-LM模型在各种部署中,而无需独立训练4个模型。

picture.image

picture.image

在图7中,作者展示了在所有4种MatMamba-LM变体上使用Mix'n'Match进行自适应推理的结果。作者观察到在和这两个粒度之间实现了平滑插值(例如,在和之间)。

然而,对于较低的粒度,尽管显式优化的粒度与预期一致,但未显式训练的Mix'n'Match模型性能略有下降。

作者观察到在训练的早期阶段,所有粒度的Mix'n'Match趋势都恰好位于性能-计算曲线上。然而,在训练的后期阶段,显式优化的粒度比Mix'n'Match粒度改善得更快(几乎就像 Anchor 点)。

可能存在可以修复此问题的机制:例如,使用最大子模型的输出进行自监督蒸馏,训练时使用超过的粒度,或使用Flextron中使用的 Agent 模型结构,这应该可以使Mix'n'Match趋势平滑。

然而,这需要更深入的理解,作者将在未来的工作中进行更全面的探索。

5 Conclusions

在这项工作中,作者提出了MatMamba,它是一种将嵌套Matryoshka结构应用于Mamba2状态空间模型的方法。它结合了Mamba风格模型的最佳特性(尤其是在较长序列上的推理速度更快)和Matryoshka风格的学习。

单个MatMamba模型包含数百个嵌套和准确的子模型,可以灵活提取进行推理。MatMamba-Vision和MatMamba-LM模型与独立训练的Mamba2 Baseline 性能和准确性相匹配。

MatMamba模型使作者能够选择所需的性能-计算权衡,同时保持一个单一的Matryoshka风格模型,而不是针对特定场景的多个不同模型。

这使得可以使用有趣的使用案例,如使用较小的草案模型和较大的验证器模型进行推测解码,输入自适应子模型选择,以及基于可用计算的混合云边缘推理使用相同模型。

参考文献

[0]. MatMamba: A Matryoshka State Space Model.

点击上方卡片,关注 「AI视界引擎」 公众号

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论