视觉 Mamba 再进化,引入多头扫描与路径注意力机制提升图像建模性能 !

技术

点击下方卡片,关注 「AI视界引擎」 公众号

picture.image

最近,以Mamba为代表的态空间模型(SSMs)在长距离依赖建模方面显示出线性复杂度的巨大潜力。随后,依次介绍了Vision Mamba及其后续架构,它们在视觉任务上表现出色。

将Mamba应用于视觉任务的关键步骤是按顺序构建2D视觉特征。为了通过1D选择性扫描在2D图像空间内有效地组织和构建视觉特征,作者提出了一个新颖的多头扫描(MHS)模块。

从前一层提取的嵌入被投射到多个低维子空间中。随后,在每个子空间内,沿着不同的扫描路径执行选择性扫描。从多头扫描过程中获得的子嵌入随后被整合,并最终投射回高维空间。

此外,作者融入了一个扫描路径注意力(SRA)机制,以增强模块识别复杂结构的能力。

为了验证作者模块的有效性,作者仅用作者提出的模块替换了VM-UNet中的2D-Selective-Scan(SS2D)块,并且作者从零开始训练作者的模型,不使用任何预训练权重。

结果表明,在减少原始VM-UNet参数的同时,性能有了显著提升。

本研究的相关代码可在https://github.com/PixDeep/MHS-VM公开获取。

1 Introduction

近年来,深度学习的不断演进推动了计算机视觉领域的实质性进步。视觉表示学习是计算机视觉中的关键步骤。此前,两种基础模型——卷积神经网络(CNN)[20, 7, 9, 15]和视觉 Transformer (ViTs)[4, 14, 23]——在一系列视觉任务中被广泛采用。这两种模型在生成富有表现力的视觉表示方面取得了显著成就,但ViTs通常在性能上超越CNNs,这得益于它们的全局接受域和注意力机制实现的动态加权。然而,随着图像尺寸的增加,注意力机制的复杂性呈二次增长,这对于像语义分割和目标检测这样的密集预测任务来说,计算负担沉重。因此,设计具有线性复杂度的视觉模型,同时保留全局接受域和动态权重的优势,是至关重要的。

计算机视觉领域最近重新燃起了对状态空间模型(SSMs)[5, 19, 25]的兴趣,这些模型传统上在模拟长距离依赖方面表现出色,具有线性的计算扩展性。SSMs的进步将其应用扩展到复杂的视觉任务中,在这些任务中,它们有望比已确立的体系结构(如CNNs和ViTs)提供效率上的优势。Mamba模型[5]利用数据相关SSM层,在大型真实世界数据集上超越Transformers,并在序列长度上保持线性计算复杂性,展示了其强大实力。这一成就凸显了Mamba作为语言建模变革性架构的潜力,表明SSMs可以成为Transformers在原始应用领域之外的有力竞争者。受到Mamba成功的启发,Vision Mamba(Vim)[32]应运而生,这是向无需依赖注意力机制,将纯SSM方法应用于视觉任务迈出的重要一步。Vim通过整合双向选择性状态空间模型和位置嵌入,增强了视觉表示学习,实现了在ImageNet分类以及密集预测任务(如目标检测和语义分割)上超越DeiT[23]的卓越性能。同时,VMamba[13]将SSM范式扩展到视觉任务,专注于在视觉表示学习中提高效率和可扩展性。VMamba引入了一个跨扫描模块(CSM),以在2D图像空间中实现1D选择性扫描,在保持全局接受域的同时,将计算复杂性从二次降低到线性。这一设计加上架构的改进,使得VMamba成为了一个在性能上与流行的视觉模型(如ResNet[7],ViT[4],Swin[14])相匹配甚至超越的有竞争力的 Backbone 模型。此外,U-Mamba[18]解决了在生物医学图像分割中对高效长距离依赖建模的需求。通过将CNNs在局部特征提取方面的优势与SSMs在捕捉长距离上下文方面的优势相结合,U-Mamba展示了混合架构在改善跨多种生物医学成像任务的分割性能方面的潜力。

这些发展共同表明,基于SSM的模型正在成熟,成为视觉表示学习通用且高效的 Backbone 网络,与传统架构相比,在某些情况下甚至表现更佳。它们能够高效地处理长距离依赖关系,同时保持或提高性能,这为计算机视觉领域未来的研究和应用提供了一个有希望的轨迹,特别是随着视觉数据规模的不断扩大和对高效处理需求的持续增长。实验结果表明,基于SSM的模型在包括图像分类[2, 10]、语义分割、目标检测[10, 26]、图像恢复[6]和图像生成[8, 30, 22]等在内的各种视觉任务上取得了令人鼓舞的性能。

本文重点关注基于纯SSM的模型在视觉任务上的表现。VM-UNet [21]是首个将VMamba融入到U-Net架构中的模型,专注于医学图像分割并采用纯SSM方法。尽管有一些进步,但纯SSM模型在视觉任务上尚未显著超越传统的CNN和ViT方法。特别是当纯SSM模型仅用于较小数据集时,观察到其泛化能力仍有提升空间。本文提出了优化SSM视觉模型关键因素的思路。在将SSM用于各种视觉任务时,一个关键步骤是将2D图像数据转换为1D序列。这种转换对于发挥Mamba处理序列数据的能力至关重要。通过这种方式重新构建图像数据,Mamba可以有效地分析和解释视觉信息,从而在多个视觉识别任务中提高性能。传统的方法,如1D序列中的单向扫描,难以同时捕获2D图像中的多方向依赖关系,限制了其感受野,可能在目标检测和分割等密集预测任务中影响准确性。作者尝试通过引入多头扫描(MHS)模块来捕获2D图像块中的多方向依赖关系。一个单一的扫描头通过遵循嵌入子空间内的特定扫描模式来捕获结构信息。此外,作者还提出了一种新颖的机制,该机制促使模块提取的特征沿着扫描路径隐式地融入位置信息。

简而言之,本文侧重于提升VM-UNet [21]中VSS块的2D-Selective-Scan(SS2D)性能。

本文的主要贡献总结如下:

引入了一种多头扫描(MHS)机制以增强视觉表示学习。

  • 引入了更为丰富的扫描模式阵列,以捕捉视觉数据中呈现的多样化视觉模式。
  • 引入了一种扫描路径注意力(SRA)机制,使模型能够衰减或筛选掉无关紧要的特征,从而增强其捕捉图像中复杂结构的能力。
  • 作者开发了一个易于使用的模块,在小规模医学图像数据集上的分割任务中展示了增强的泛化能力。

2 Preliminaries

在基于结构化状态空间(SSM)的现代模型中,如结构化状态空间序列模型(S4)和Mamba,采用了一个经典的连续系统。该系统将一维输入函数或序列(表示为 ),通过中间隐式状态 映射到产生输出 。这个过程可以通过一个线性常微分方程(ODE)来封装:

其中 是状态矩阵,而 和 表示投影参数。

离散化。 为了将这个连续系统适应于深度学习应用,S4和Mamba通过引入时间尺度参数 对其进行离散化。它们使用固定的离散化规则(如零阶保持(ZOH))将连续参数 和 转换为离散对应物 和 。这种离散化表达为:

一旦离散化,基于SSM的模型可以通过线性递归

或全局卷积进行计算,

其中 表示一个结构化卷积核, 是输入序列 的长度。这种方法利用卷积同时在整个序列中整合输出,从而提高计算效率和可扩展性。

3 Methods

这段文字首先介绍了作者模块的整体架构。随后,作者详细阐述了核心组件的细节。

图1:作者提出的多 Head 扫描(MHS)模块的架构。在所示模块中,有三个扫描头,这个数量可以根据实际需求进行调整。这种设计便于其应用,作者可以立即将VM-UNet中的VSS块的SS2D模块替换为作者的MHS模块。

picture.image

所 Proposal 模块的整体架构如图1所示。作者在图中展示了两套略有差异的架构。右侧的那个变体在尾部部分去除了投影,其性能影响将在随后的消融研究部分讨论。关注作者模块的结构,它由三个主要组成部分构成: Head 、中间部分和尾部,每个部分在整体架构中都扮演着关键角色。随后,作者将对这些组件进行全面的解释。### 子空间和扫描模式

Head 部分关注于将从前一层提取的嵌入转换成个子嵌入,这些子嵌入被描述为个等维度的平行子空间中。这可以简洁地用以下方程表示:

其中,,,。默认情况下,嵌入所在空间的维度是每个子空间的倍,即,但这不是强制性的。

中间部分由个扫描头和嵌入部分融合(ESF)子模块组成。随后,作者将详细解释这一部分。

由于Mamba被设计用于处理具有因果关系的1D因果数据,如语言序列。然而,2D视觉数据的固有非因果关系给因果处理方法的应用带来了重大挑战。一个想法是将2D视觉数据通过多方向1D序列进行解构。作者引入扫描模式来捕捉图像块之间的方向依赖性。如图2所示,本文探讨了三种额外的扫描模式,用于捕捉2D图像中图像块的复杂方向依赖性。除了之前工作使用的第一种扫描模式外,另外三种模式更关注图像块之间的邻接关系。第二种模式是在水平或垂直方向上的连续扫描。第三种是沿对角线连续扫描。最后一种是沿螺旋线从外向内连续扫描。如果不需要关注图像块之间的邻接,可以引入更多模式。在作者当前的架构中,每个子空间部署一个扫描模式。

在输入到Mamba模块之前,作者按照图3所示的遍历路径将图像(嵌入图)展平为图像块的序列。作者可以看到,每一行显示沿着遍历路径展开的图像块的1D序列。这些图像块之间的不同1D因果关系被结合起来,以近似2D视觉数据中的复杂结构关系。

picture.image

为了更清楚地理解作者的方法,作者引入了两个关键概念,这两个概念是作者图像块嵌入方法的核心。首先,作者将沿特定扫描路线提取的连续图像块嵌入序列称为嵌入扫描序列。这个序列捕捉了沿着扫描路线遍历时图像块的空间和上下文信息。其次,作者考虑了沿着预定义路径展平的图像块嵌入的排列,这些排列共同构成了一个2D平面,作者称之为嵌入部分。这个平面为图像块嵌入提供了一个结构化表示,使模型能够有效地处理和分析视觉数据中的空间关系。如图2所示的绿色遍历路径用于为每条扫描路线生成嵌入部分。

图2:四种扫描模式的说明。从左到右:图像块沿着四种扫描模式的遍历路径。沿着虚线标记的数字表示沿着遍历路径的图像块的遍历顺序。

对于每种扫描模式,作者部署了多条扫描路径来捕捉不同方向上的 Patch 间依赖关系。这些嵌入扫描序列随后被输入到一个单独的Mamba模块中。为了根据不同的扫描模式独立捕获嵌入,作者不共享每个Mamba模块之间的权重。这个过程可以用以下公式表达:

其中 表示第 个Mamba模块,它接收 个嵌入扫描序列 作为输入,这些序列是从 条扫描路径中收集的,并生成 个输出嵌入扫描序列 。 表示输入嵌入扫描序列的长度。这些输出嵌入扫描序列随后根据各自的扫描路径重新排列,以创建嵌入段落。

Embedding Section Fusion

对于每个扫描头,在其子空间内进行若干扫描路径。扫描路径记录了图像块的实际遍历路径。在相同的扫描模式中,作者可以从位于不同位置(如角落)的图像块开始扫描,因此将存在多条扫描路径。默认情况下,作者在一个扫描头中使用四条路径。一个例子示于图4中。这些扫描路径用于收集通过Mamba块转换的嵌入序列。例如,四条扫描路径产生四个嵌入序列。每个序列沿着预定义的遍历路径堆叠嵌入向量,被重新组织成一个嵌入段。此后,使用ESF融合这些嵌入段以产生最终的图像块嵌入。除了直接相加 之外,作者还进一步引入了两种替代方案。

picture.image

池化混合方法。 这里作者部分借鉴了CBAM [27]中空间注意力模块的设计。如图5(a)所示,作者首先沿着分段轴计算平均值和最大值,从而得到两个池化段,并将它们连接起来生成一个简洁的特征描述符,然后将它们传递给线性层以有效地输出最终的嵌入段。这个过程可以表述为:

picture.image

其中 , 和 分别是沿分段轴的平均池化和最大池化。在连接的特征描述符 上,作者应用线性投影以生成融合的嵌入段。

基于CV的缩放。 在Transformer模型中,位置编码对于向模型提供序列中标记的相对或绝对位置信息起着至关重要的作用。作者不是使用显式的位置编码,而是引入了一种机制来增强对图像块空间位置的隐性感知能力。具体来说,为了使模型能够筛选出对扫描路径不敏感的琐碎特征,作者引入了一种扫描路径注意力机制。对于每种扫描模式,作者从不同的角落块开始选择条扫描路径,从而得到个扫描序列。每个块的嵌入向量的每个分量都有个值。为了量化提取的嵌入对扫描路径的感知程度,作者计算了这些值的变异系数(CV)的一个变体。变异系数是衡量样本或总体中数据围绕平均值相对波动性或离散程度的一个度量。如果来自条扫描路径的个值更一致,那么CV相对较小;相反,如果这些值更分散,那么CV相对较大。在某种程度上,通过这些扫描路径提取的嵌入对位置信息有很好的感知,并值得保留或增强。作者以上过程表达为:

其中 是CV的一个变体, 表示张量的逐元素乘积, 是一个单调函数,比如Sigmoid和ReLU。引入这个单调函数是为了促使Mamba块提取对位置敏感的特征。简化流程图如图5(b)所示。

在实际实验中,作者引入了一个参数来过滤掉具有较低CV值的特征。

这个函数当x<tx<t时返回00,否则返回xtx-t。参数tt可以作为超参数或可学习参数设置。这种策略可以被认为是一种新型的正则化技术,用于防止过拟合并提高泛化能力。< p=""></t时返回时返回0,否则返回,否则返回x-t。参数。参数t$可以作为超参数或可学习参数设置。这种策略可以被认为是一种新型的正则化技术,用于防止过拟合并提高泛化能力。<>

此外,将最后两种方案合并也是可行的,从而形成一个更复杂的方案:

图5:ESF子模块的两个方案说明。(a) 池化混合;(b) CV引导的缩放。

然而,在ISIC18数据集上的实际实验并没有显示出显著的性能提升,同时它们也表现出计算开销的增加。注意,上述聚合操作(例如平均、最小值、最大值和标准差)是沿着截面轴进行的。

Projection

最后,尾端部分被设计用来将子嵌入转换回高维空间。为此,作者首先将子嵌入进行拼接,然后通过层归一化对它们进行归一化,最后将归一化的结果投射回高维空间。在实践中,如果这些子空间维度之和 等于输入嵌入空间的维度 ,则可以可选地从架构中排除最后的投射步骤。下一节中的消融研究将阐明性能比较。

作者还用前馈网络(ffn)替换了投射组件;然而,在小型数据集上进行的实验表明,这种修改增加了参数数量,但没有实质性提高分割性能。

Pseudo-code for MHS Module

给定输入图像 ,其中 、 和 分别表示图像的高度、宽度和通道数。首先,应用一个 Patch 嵌入层将输入图像划分为大小为 的非重叠 Patch 。随后,作者的MHS模块与 Patch 合并层交替使用,将它们转换成一系列嵌入空间 ,,,其中 默认为 ,与 VM-UNet 中的设置相同。作者的一个模块以 通道嵌入图为输入,并产生具有相等通道数的输出嵌入图。具体来说,所提出的MHS模块的伪代码在 算法1 中呈现。

picture.image

算法1 MHS模块的伪代码

4 Experimental results

在本节中,作者针对医学图像分割任务进行了作者的模块实验,特别是在相对较小的数据集上。为了评估作者的模块的性能,作者将其应用于VM-UNet,这是一个最近提出的基于Mamba的医学图像分割框架。

需要强调的是,作者没有加入额外的机制来提高模型的性能,正如VM-UNet v2 [31] 和 UltraLight VM-UNet [29] 的情况一样。作者的努力集中在进一步挖掘视觉Mamba的潜力上。为了验证作者的模块的有效性,作者保留了VM-UNet的网络架构的其余部分不变,仅用作者的MHS模块替换了VSS块中的SS2D模块(见图1)。

Datasets and Experimental Setups

作者具体评估了作者的模块在三个公开可用的医疗图像数据集上的性能:ISIC17 [1],ISIC18 [3],以及Synapse [11]数据集。这些数据集在当前的分割研究中被广泛使用,在本研究中,它们用于基准测试作者提出的模块的竞争力表现。

作者使用PyTorch 2.0实现了MHS-VM,并以与VM-UNet相同的参数训练网络。批量大小设置为32,训练周期设置为300。采用AdamW [17]优化器,初始学习率为1e-3。使用CosineAnnealingLR [16]作为调度器,最大迭代次数为50,最低学习率为1e-5。所有实验都在一个NVIDIA RTX 4090 GPU上完成。

Ablation Study

在本节中,作者使用ISIC18数据集进行了一系列消融实验,以探索所提出的多头扫描(MHS)模块的各种配置。为了独立评估作者模块的性能,作者从零开始训练所有网络,而没有使用任何预训练权重。

ESF的影响。这个子模块专门设计用来在单个扫描头内集成来自不同扫描路线提取的嵌入部分。作者评估了上述三种方案,结果表明第三种方案优于其他方案,取得了最高的性能。实验结果如表1所示。在这个实验中,作者部署了三个与图2中最后三种扫描模式相对应的扫描头,并在等式(9)中将作为一个超参数。在后续实验中,作者主要在ESF子模块中使用第三种方案(CV引导的缩放)。

picture.image

投影的影响。作者进一步移除了MHS模块尾部部分的投影。移除这个组件导致参数数量和计算开销大幅减少,但性能仍然相当(见表2)。与之前的设置一致,作者继续在这个比较实验中使用三个扫描头。

picture.image

尾部的投影用于合并子嵌入的组件。考虑到后续层中的 Head 也包含了投影,作者可以消除当前层的尾部投影,从而大幅减少参数数量。然而,如果子空间维度的总和与嵌入空间的维度不匹配,投影不能随意省略。

扫描头的数量。在之前的实验中,作者专注于使用三个扫描头来收集2D图像中的特征。接下来,作者探讨更多的扫描头是否能够提高作者模型的性能。实证评估表明,虽然增加扫描头的数量可以略微提高模型性能,但改进非常小。然而,参数数量略有减少。

表1:对各种ESF方案的消融研究。

表2:对投影的消融研究。(带或不带的模型)实现了轻微的提升,同时参数数量显著增加。增加子空间的维度对于具有更大数据量的数据集可能带来优势,这是未来需要进一步研究的领域。

此外,作者还进行了实验,以探索增加并行子空间维度的效果。有了投影,子空间的维度不受限制,从而允许进一步探索其潜力。如表3所示,在ISIC18数据集上的实验表明,性能…(此处原文未给出完整句子)。作为参考,原始的VM-Unet包含27.4276M个参数和4.1119G FLOPs。在采用四个 Head 的情况下,作者使用了涵盖图2中所有扫描模式的配置,每个模式与四个扫描头中的一个相关联。与之前的实验一致,作者仍然在等式(9)中设置。此外,作者还使用了表中最少参数的第二个模型在ISIC17数据集上进行实验,获得了78.97%(mIoU)和88.25%(DSC)的分数。对于这个数据集,当设置为0时,模型表现良好。为了优化不同数据集上的性能,可能需要校准超参数的值。

picture.image

Comparative Results

作者采用了表2中展示的最轻量级模型与原始VM-UNet进行比较。表4中展示的对比实验结果强调了作者VM-UNet变体在这些数据集上取得了更优的性能。需要强调的是,作者仅将VM-UNet中的SS2D块替换为作者的 Proposal 模块,从而将模型性能的观察改进独立出来,这明确指出了作者模块的有效性。

picture.image

VM-UNet-T模型是通过用从VMMaba-T [13]获得的预训练权重初始化VM-UNet进行训练的。VM-UNet模型和作者的模型(简称MHS-UNet)是从头开始训练300个周期。值得注意的是,与原始VM-UNet相比,更新后的网络不仅在性能上更胜一筹,而且在参数数量和计算开销上也显著减少了48.00%和55.89%。作者模型的性能要么接近要么甚至超过了用预训练权重初始化的VM-UNet-T。然而,与用从VLambda-S [13]获得的预训练权重初始化VM-UNet进行训练的VM-UNet-S的性能还存在一定差距。

5 Conclusion and Future Work

为了提高Mamba在视觉任务中的性能,作者基于视觉Mamba引入了一种多头扫描模块,简称MHS-VM。作者将路径嵌入投影到多个平行的子空间中,并引入了不同的扫描模式来捕捉2D图像中路径之间的复杂依赖关系。为了验证作者模块的有效性,作者将VM-UNet中的SS2D模块替换为作者的模块,保持网络其余架构不变,并基于此框架进行消融研究。

与原始VM-UNet相比,融入作者模块的改良网络具有更少的参数数量和计算开销,作者在三个公开数据集上进行的实验证实了作者的模块提高了预测性能。

虽然将1D选择性扫描用于2D视觉任务是一个有前景且值得进一步探索的方法,但值得注意的是,本文讨论的模式并非穷尽的。发现并实现额外的扫描模式为未来的研究提供了丰富的可能性,为视觉识别和处理领域带来更加复杂和细腻的应用途径。

作者将改善架构和平行编程,并进一步探索融合局部扫描和全局扫描的分层表示,从而使得作者的模块成为未来更多视觉任务的通用基础架构。

参考

[1].MHS-VM: Multi-Head Scanning in Parallel Subspaces for Vision Mamba.

点击上方卡片,关注 「AI视界引擎」 公众号

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论