点击下方卡片,关注「集智书童」公众号
本文主要解决了什么问题
Token冗余问题 :基于Mamba的视觉模型中存在过多的token表示图像,增加了计算成本并降低了推理速度。
Block冗余问题 :多个SSM模块的存在导致了吞吐量瓶颈,影响了模型效率。
训练与推理不一致性 :早期的token剪枝方法会导致训练与推理之间的不一致性或引入额外计算开销。
本文的核心创新是什么
重新排列策略 :提出了一种在训练过程中重新排列剪枝token的方法,以解决基于Mamba模型的训练与推理不一致性问题,且无需额外计算开销。
动态块选择 :允许每张图像根据推理需求动态选择SsM block数量,从而减少不必要的计算。
端到端优化框架 :设计了包含分类损失、监督损失和蒸馏损失的联合损失函数,确保模型在剪枝和块选择后仍能保持高性能。
结果相较于以前的方法有哪些提升
显著降低FLOPs :在Vim-S上实现了35.2%的FLOPs减少,仅损失了1.7%的精度。
泛化能力强 :DyVM在不同Mamba视觉模型架构(如VideoMamba、MambaReg)和不同任务(如图像分类、视频理解、语义分割)上均表现出色。
吞吐量提升 :实验表明DyVM在各种设备上均能加速模型推理,尤其是对于较大模型效果更显著。
局限性总结
性能下降不可避免 :尽管FLOPs大幅降低,但模型性能仍有轻微下降(如1.7%的精度损失)。
依赖特定结构 :DyVM的设计针对基于Mamba的模型,可能不适用于其他类型的视觉模型。
复杂性增加 :引入了可学习的预测器和多阶段剪枝策略,增加了模型复杂性和训练难度。
基于Mamba的视觉模型因其计算效率高于基于注意力的模型而受到广泛关注。然而,这些模型中仍然存在空间冗余,表现为token冗余和block冗余。对于token冗余,作者通过分析发现,早期的token剪枝方法会导致训练与推理之间的不一致性,或为推理引入额外的计算量。
因此,作者通过在将其输入下一个Mamba block之前重新排列剪枝序列,将token剪枝定制以适应Mamba结构。对于block冗余,作者允许每张图像根据Mamba基于视觉模型的推理速度很大程度上受SsM block数量的影响这一经验观察动态选择SsM block。DyVM,动态视觉Mamba(DyVM),有效地减少了FLOPs,同时性能略有下降。作者在Vim-S上实现了35.2%的FLOPs减少,仅损失了1.7%的精度。该方法在不同Mamba视觉模型架构和不同视觉任务上都具有良好的泛化能力。
作者的代码将在 https://github.com/NUS-HPC-AI-Lab/DyVM 上公开。
- 引言
Vision Mambas [9, 19, 25, 38] 在图像分类 [21, 42]、视频理解 [16] 和图像分割 [22, 28, 33, 36] 等视觉任务上取得了显著的性能表现。其核心思想是通过状态空间模型(SSMs)[5] 对视觉 Token 之间的交互进行建模。
空间冗余,已被广泛证明存在于视觉Transformer(ViT)[10, 18, 23, 30, 32]中,也可能存在于视觉Mamba中。这种冗余出现在token Level ,因为用过多的视觉token表示图像[20, 34],从而增加计算成本并降低推理速度。在图1(a)中,作者从ImageNet-1K数据集的100个类别中随机挑选每类10张图像,并计算每个像素的注意力分数。统计数据表明,所有像素中有94.6%的注意力分数低于70%,表明其对模型性能的贡献最小。尽管这个问题在ViT中已被充分讨论并有效解决[12, 26],但对于视觉Mamba来说仍不充分。
为解决视觉token数量过多的问题,token剪枝已被证明在ViT场景下是一种有效的解决方案。通过屏蔽不需要token的注意力分数,作者可以在训练中模拟token剪枝,从而实现训练与推理之间的一致性。
然而,简单的 Mask 方法与基于Mamba的模型不兼容。为了证明这一点,作者以Vim [42]为例,Vim是一种典型的视觉Mamba模型。当作者尝试 Mask 剪枝后的token,如图2(a)所示,这会导致训练和推理过程中输出表示的不一致性,从而削弱了模型在视觉token剪枝后的性能。这种失败可以归因于Mamba的类似循环结构,其中来自先前状态的信息通过隐藏状态传播,而简单的 Mask 会破坏这一过程。
除了token Level 的冗余之外,作者还注意到由多个SSM模块引起的吞吐量 Bottleneck 。例如,Vim [42] 在每一层都实现了前向和后向SSM,以增强空间感知能力。在图1 (b)中,作者比较了Vim在每层使用两个模块、一个模块以及不使用模块时的计算成本和推理吞吐量。可以观察到,尽管减少SSM模块对FLOPs的影响较小,但它显著提高了推理吞吐量,移除一个模块后实现了1.36倍的改进,移除两个模块后实现了2.83倍的速度提升。这一发现证实了作者的假设,即Vision Mambas中过多的SSM模块会损害效率。因此,在推理过程中识别并禁用这些冗余模块至关重要。
基于上述分析,作者引入动态视觉Mamba(DyVM),旨在在token和block层面减少冗余。从token的角度来看,DyVM在特定层使用预测器来识别并剪枝信息量较低的token。为缓解训练与推理不一致的问题,作者在训练过程中对剪枝的token进行重新排列,使其在 Mask 后紧随保留的token,确保保留的token在推理阶段不受剪枝token的影响。此外,在block层面,DyVM动态选择适当数量的SSM block来处理每张图像。这种数据依赖的方法特别提升了每个样本的吞吐量。
大量实验表明,DyVM能够显著降低不同尺寸的Vision Mambas的FLOPs,同时仅以微小的性能下降为代价。作者在Vim-S上实现了3.5.2%的FLOPs降低,且仅损失1.7%的精度。与视觉 Token 剪枝 Baseline 方法,即HiddenAlign [40]相比,DyVM展现出更优的性能效率权衡。此外,在VideoMamba [16]和MambaReg [31]上的实验验证了DyVM的一般化能力。
- 相关工作
视觉Mamba。序列建模的进步显著影响了计算机视觉领域,ViT [4]等模型被应用于视觉任务。近年来,Mamba [5]等状态空间模型(SSMs)因能有效处理长序列而受到关注。后续研究[11, 13, 17, 24, 29]使用基于Mamba的 Backbone 网络在视觉基准测试中取得了优异性能。Vim [42]通过位置嵌入和集成双向SSM的Mamba模块处理图像块。VMamba [21]通过在四个方向交叉扫描块来处理二维依赖,解决了Mamba的一维局限性。PlainMamba [37]利用之字形扫描和方向感知更新增强了特征融合和泛化能力。
Token剪枝。Token剪枝通过移除重要性较低的token来降低计算成本,在最小化架构变更的同时加速推理。对于视觉Transformer(ViT)[4],EViT [18]等方法使用类token注意力分数,DynamicViT [26]采用预测层,ToMe [2]合并相似token,PATCHMERGER [27]引入合并模块,T2T-ViT [39]聚合邻近token。然而,由于结构差异,这些方法不适用于视觉Mamba。HiddenAlign [40]探索了Mamba的token级剪枝,但会带来额外的推理成本。相比之下,DyVM结合token级和模块级剪枝,在不增加计算量的情况下实现全面改进。
- 方法
3.1. 初步
状态空间模型(SSMs)[6-8]通过在隐藏状态
中传播信息,将输入序列
映射到输出序列
。
是演化参数,
和
是投影参数。
SSM针对连续输入,而Mamba [5]通过引入时间尺度参数
和零阶保持器(ZOH)提供了一个离散版本:
相应地,方程1的离散版本可以表述为:
Mamba模型通过全局卷积计算输出,具体如下:
(a) 纯粹的token Mask 方法导致训练与推理之间存在不一致性,主要原因是进化转换的数量不同(即
)。
(b) 隐藏对齐 [40] 在推理过程中保留了进化变换(即
),从而在牺牲额外计算力的代价下实现了训练与推理之间的一致性。
(c) DyVM 在训练过程中重新排列token以实现一致性,与HiddenAlign相比,在推理过程中消除了额外的计算。
为了处理2D图像,Vision Mambas [13, 16, 21, 42] 将其转换为一系列token,这与视觉Transformer [4]中的做法相同。随后,采用一系列带有SSM(状态空间模型)的Mamba层来构建token之间的关系。不同的方法在层的设计上有所不同。
3.2. 现有 Mask 方法
现有的视觉Transformer的token剪枝方法[26]在训练时使用 Mask 来模拟移除token,但这些方法与基于Mamba的模型结构不直接兼容。
平铺 Mask 法 最直接的方法是在训练期间将嵌入设置为零来 Mask Token ,如图2(a)所示。尽管它阻挡了某些 Token 的信息,同时允许其余 Token 的信息传播,但在训练和推理过程中会导致不一致性,从而损害模型性能。下面,作者提供详细分析。
对于长度为
的输入 Token 序列
,保留
个 Token 而剪枝其余 Token ,令
表示保留 Token 的索引(若
,则
)。在 Mask Token 的情况下,训练期间Mamba块每个序列位置的输出可按如下方式计算:
在推理过程中,冗余的token会被直接丢弃,而保留的token会被连接。假设作者保留具有相同索引的token,那么在推理过程中Mamba模块每个序列位置的输出可以按如下方式计算:
训练和推理输出在进化转换次数
的数量上高度不一致,因为
。只有在
在训练序列中是连续索引的罕见情况下,才能实现等式。
隐藏对齐 [40] 之前的隐藏对齐(HA)方法 Aware 到了早期方法中的这种不一致性,并提出了一种新的方法,如图2(b)所示。在训练过程中,HA使用与普通 Token Mask 方法相同的 Token Mask 方法。在推理过程中,对于每个被剪枝的 Token ,HA保留其演化变换(即
),同时剪枝其对应的投影(即
和
)。通过这种方法,训练和推理的输出都是一致的,并计算为:
然而,与普通的token剪枝方法相比,这种方法在推理过程中引入了额外的计算。在大多数情况下,HA的推理成本(公式11)大于普通的token Mask 方法的推理成本(公式10),因为
,并且只有在-1是连续索引的情况下才可能达到等式。
3.3. 动态视觉Mamba
作者的分析提出了一个关键问题:token剪枝能否在不增加额外计算开销的情况下实现训练-推理一致性?作者提出了动态视觉Mamba(DyVM),该模型在token和块 Level 上减少了视觉Mamba的空间冗余。token剪枝通过
个剪枝阶段逐步进行,每个阶段根据前一阶段的 Mask 继续屏蔽token。
在每个阶段,训练过程中对token进行重新排序以模拟推理时的顺序,从而在消除额外计算的同时实现一致性。动态块选择在每一层执行,预测每个样本应通过哪些块。作者通过图3展示了DyVM的流程。
Token剪枝 在每个阶段
,作者剪枝固定比例的token,并通过维护一个二元 Mask
来进行,该 Mask 表示每个token是保留还是丢弃。
中的所有元素初始值均为1。遵循DynamicViT [26]的方法,作者在每个剪枝阶段实现一个预测器
并随后进行softmax操作,以生成批量输入序列 H E RBLD 中每个token剪枝和保留的概率。
其中
表示保留第
批中的第
个token的概率,而
表示剪枝该token的概率。然后,
通过当前策略
进行更新。
此处,作者采用Gumbel-Softmax [14]技巧使采样过程可微,从而实现端到端训练。剪枝效果通过将 Token 乘以更新的 Mask
实现。为解决上述训练-推理不一致问题,作者提出在将 Token 传递给SSM模块进行训练时,重新排列 Token 的位置,如图2(c)所示。
具体而言,对于每个序列
,作者首先将类别 Token
和其他 Token 分开。作者仅对保留 Token 的序列进行操作
如果
,由
给出)。作者不是在原地 Mask Token ,而是将保留的 Token 聚合为一个连续块,同时保持它们的相对顺序。然后,作者将类别 Token
重新插入到中间位置,这是Vim模型中常用的类别 Token 位置:
类似地,作者将剪枝后的token分组到另一个连续的块中。令
表示剪枝后token的索引(若
,则
)。剪枝后的token块为:
最后,作者将保留和剪枝的块连接起来:
这种公式化方法消除了通过隐藏状态无意传播信息的问题,保持了训练和推理之间的一致性,并提升了模型的性能和稳定性。现在,训练和推理的输出都是一致的,并计算如下:
动态块选择在图1(b)中,作者展示了吞吐量随着活动扫描块数量的增加而降低。因此,作者提出为每个样本动态选择扫描块,以在块 Level 进一步减少冗余。具体来说,一个样本可以经过前向和后向块,其中一个前向和后向块,甚至两个块都不经过。这是通过每个Vim层内的一个块选择器实现的,该选择器以层l中的类别 Token
作为输入,预测每个块的分数。随后,通过Gumbel-sigmoid函数将分数矩阵转换为二进制 Mask :
最后,通过将输出与 Mask 相乘的方式,对每个样本的冗余块进行禁用:
是层 l 中的正向和反向块的重新排列的输入序列。
分别表示层 l 中正向块和反向块的 Mask 输出。需要注意的是,由于每个块中存在非零偏差项,因此无法直接在正向或反向输入上应用 Mask 。
3.4. 训练与推理
训练。DyVM的训练目标由五个部分组成:一个分类损失,两个用于约束剪枝率的监督损失,以及两个用于校准模型性能的蒸馏损失。
首先,作者计算模型预测值
与真实标签
之间的标准交叉熵损失,将其作为分类损失。
其次,为了监督 Token 剪枝比例,作者设定目标 Token 比例
,并期望在
次剪枝阶段后保留
个 Token 。给定一组
次剪枝阶段,其目标比例
,作者计算均方误差损失:
其中
表示第
批在第
个剪枝阶段后的 Mask 的第
个值。为了监督模块选择比例,作者计算所有层(总共
层)中活跃模块的平均比例,并使用预定义的模块比例
计算均方误差损失。
最后,作者通过使用原始 Backbone 网络作为教师模型,进一步校准在 Token 剪枝和块选择后的模型行为。首先,作者最小化模型输出
与教师模型输出
之间的Kullback-Leibler(KL)散度损失。
此外,作者通过计算均方误差(MSE)损失,使所有保留的token接近教师模型的token。
联合损失是上述五个损失的加权总和:
推理。在推理过程中,剪枝的token会被丢弃,并且直接跳过块以提高更高的效率。对于token剪枝,给定目标比率
,在
次剪枝阶段后,作者保留
个token,丢弃其余的token。保留的token的索引是通过根据保留概率对token进行排序并选择前
个token获得的。
对于块选择,以正向块为例,只有块 Mask 值为1的样本被发送到层
的正向块中。形式上,通过层
正向块的样本索引为:
反向块遵循相同的逻辑。因此,在评估过程中,卷积和SsM扫描计算更少,从而加速了推理过程。
- 实验
4.1. 模型与数据集
模型。作者在Vim模型(Vim-T、Vim-S、Vim-B)[42]上实现了DyVM,并将其与HiddenAlign(HA)[40]作为 Baseline 进行比较。为了展示DyVM的泛化能力,作者将其集成到VideoMamba(VideoMambaT、VideoMamba-S)[16]和MambaReg(MambaReg-S、MambaReg-B)[31]中进行图像分类。此外,作者通过评估集成DyVM的VideoMamba框架在视频理解任务上的表现,对DyVM的跨模态泛化能力进行了评估。此外,遵循HA的实验设置,作者使用UperNet[35]作为 Baseline 框架,对DyVM在语义分割任务上的性能进行了评估。
数据集。对于图像分类任务,作者在ImageNet-1K [3]上进行了实验,该数据集包含1281167张图像,分为1000个类别。对于视频理解任务,作者在Kinetics400 [15]上进行了实验,该数据集涵盖了400个人类动作类别,包含650000个视频。对于语义分割任务,作者在ADE20K [41]上进行了实验,这是一个大规模数据集,包含20000张图像,涵盖150个语义类别。
4.2. 实验设置
在图像分类任务中,作者通过微调 Backbone 模型30个epoch来训练DyVM。作者为Tiny尺寸模型设置学习率为3e-5,为Small和Base尺寸模型设置学习率为5e-5。作者使用余弦学习率调度器,并设置5个epoch的预热阶段。Tiny、Small和Base尺寸模型对应的批大小分别为128、64和32。对于token剪枝,作者设置
个token剪枝阶段,剪枝率
,其中
为目标token比例。对于模块选择,作者在所有层中设置一个单一的目标模块比例
。具体来说,作者初始化模块选择模块以保留所有样本,确保其行为与原始模型高度相似。在计算联合损失时,作者使用
、
和
。其他训练设置和细节可以在补充材料中找到。
4.3. 主要结果
与 Baseline 的比较。表1展示了在图像分类设置下将DyVM应用于Vim的结果。DyVM成功降低了Vim在不同模型尺寸上的FLOPs,同时保持了令人满意的表现。与HA方法相比,DyVM在Vim-T和Vim-S上实现了相同或更好的性能,同时FLOPs降低幅度更大。DyVM在VideoMamba和MambaReg上也有良好的泛化表现,显著降低了FLOPs,性能略有下降。
不同比例组合的结果。图4展示了在Vim-T和Vim-S上不同token-block比例组合的Top-1准确率和FLOPs。Token剪枝通过缩短序列长度减少了FLOPs,但随着比例的增加,会导致更大的准确率损失。将其与块选择结合可以实现相似的FLOPs减少,同时性能下降更小。较大的模型由于具有更多的空间冗余,更能容忍激进的剪枝,表现出更小的性能退化。
扩展到更大的token数量 在表2中,作者展示了DyVM应用于VideoMamba在K400视频理解数据集上的结果。作者持续减少了可观的FLOPs,同时保持了相当的性能。这表明DyVM在更大的token数量设置下是稳健的。
语义分割结果。当DyVM集成到UperNet中用于ADE20K的语义分割时,它表现出对预测任务的强大适应性。如表3所示,该框架在保持竞争性分割精度(mIoU)的同时降低了计算成本。
- 分析
可学习 Token 剪枝和块选择。DyVM引入了可学习的 Token 和块评分预测器,用于 Token 剪枝和块选择。作者通过消融研究验证了它们的有效性。对于 Token 剪枝,作者将可学习预测器与另外两种剪枝策略进行了比较:随机选择和固定位置(静态)。对于块选择,由于样本可以自由通过任何块,作者仅将其与随机选择进行了比较。如表
和
所示,可学习预测器通过精确识别冗余 Token 和块实现了最高的准确率。
Token 剪枝的阶段。在DyVM中, Token 剪枝方法采用多阶段方法。另一种方法可以使用较少的阶段和较高的每阶段剪枝率来达到相同的最终剪枝率。因此,作者进行了一项消融研究,以检验在最终剪枝率固定的情况下,剪枝阶段的数量如何影响模型质量。结果报告在表5中,表明更多的剪枝阶段可以实现更高的准确率。
DyVM的token mask预测器将token本身作为输入来决定保留或剪枝。为了评估不同输入的影响,作者分析了Mamba生成的变量
,如第3.1节所述),这些变量是输入相关的。一项消融研究将预测器的输入分别替换为每个变量。如表6所示,预测器在直接token输入时表现最佳,表明额外的token转换是不必要的。
不同损失的影响。为验证不同训练损失的效果,作者对移除了蒸馏损失(公式25和26)的Vim-S模型进行实验,并将结果报告在表7中。结果表明,这两种损失均略微提升了模型性能。由于两种监督损失(公式23和24)对于控制剪枝率至关重要,因此未进行实验。
吞吐量。为了评估DyVM是否提高了模型的吞吐量,作者在各种设备上进行了测试。表8报告了测试结果,这些结果表明DyVM可以在所有设备上实现加速。值得注意的是,对于较大模型(例如Vim-B),改进效果更为显著,这与表1中的FLOPs分析结果一致。
可视化。作者通过可视化预测的token剪枝和模块选择策略来展示DyVM的有效性。
对于token剪枝,作者展示了隐藏注意力热力图[1]以及每个阶段的保留token(图5)。红色区域表示高注意力分数,而蓝色区域表示低注意力分数。在非活动区域的冗余token在各个阶段被剪枝,而接收高注意力分数的判别性特征被保留。对于模块选择,策略在不同图像间有所不同(图6),突出了DyVM为每个样本定制路径的能力。这些可视化强调了DyVM在减少Vim模型空间冗余方面的有效性。
- 结论
在这项工作中,作者提出了一种名为DyVM的新方法,旨在提高基于Mamba的视觉模型的效率。DyVM的重新排列策略成功解决了训练与推理之间的不一致性,且没有额外的计算开销。DyVM有效降低了Vim的FLOPs,并保持了相当的性能。此外,作者还做出了早期努力来减少扫描块的数量,并启发未来的研究在设计新的视觉Mamba架构时保持良好的平衡。此外,DyVM展现出强大的泛化能力,并提高了其他基于Mamba的视觉任务模型的效率。
参考
[1]. Dynamic Vision Mamba
扫码加入👉「集智书童」交流群
(备注:方向+学校/公司+昵称)