点击下方卡片,关注 「AI视界引擎」 公众号
UNet及其变体已被广泛应用于医学图像分割。然而,这些模型,尤其是基于Transformer架构的模型,因其参数众多和计算负载大而带来挑战,这使得它们不适合移动健康应用。
最近,以Mamba为代表的态空间模型(SSMs)作为与CNN和Transformer架构竞争的有力替代品出现。基于此,作者采用Mamba作为UNet中CNN和Transformer的轻量级替代方案,旨在解决实际医疗环境中计算资源限制带来的挑战。为此,作者提出了轻量级Mamba UNet(LightM-UNet),它在一个轻量级框架内集成了Mamba和UNet。
具体来说,LightM-UNet以纯Mamba风格采用残差视觉Mamba层来提取深层语义特征并建模长距离空间依赖关系,具有线性计算复杂性。在两个真实世界的2D/3D数据集上进行的广泛实验表明,LightM-UNet超越了现有最先进的研究。
值得注意的是,与著名的nnU-Net相比,LightM-UNet在大幅减少参数和计算成本(分别减少116倍和21倍)的同时,实现了更优的分割性能。这突显了Mamba在促进模型轻量化方面的潜力。
1 Introduction
UNet [16]作为一种成熟的医学图像分割算法,在涉及医学器官和病变的广泛分割任务中得到了应用,涵盖了各种医学图像的模态。它对称的U型编码器-解码器架构,结合了积分跳跃连接,为分割模型奠定了基础,催生了大量基于U型结构的工作[8, 15, 18]。然而,作为一种基于卷积神经网络(CNN-based)的模型,UNet面临着卷积操作固有的局部性,这限制了它理解明确的全局和长距离语义信息交互的能力[2]。一些研究尝试通过采用扩张卷积层[5],自我注意力机制[19]和图像金字塔[25]来缓解这个问题。然而,这些方法在建模长距离依赖方面仍然表现出限制。
在努力赋予UNet捕捉全局信息的能力的过程中,近期的研究[2, 7, 6]探讨了整合Transformer架构[22],利用自注意力机制将图像视为一系列连续的 Patch ,从而捕获全局信息。尽管这种方法有效,但基于Transformer的解决方案由于自注意力机制导致了图像尺寸的二次复杂度,特别是对于需要密集预测的任务(如医疗图像分割),这带来了相当大的计算开销。这在现实医疗环境中忽视了计算约束的必要性,未能满足移动医疗分割任务中参数低和计算负载最小的模型需求[18]。总之,一个未解决的问题依然存在:“如何在不增加额外参数和计算负担的情况下,赋予UNet容纳长距离依赖的能力?”
近期,状态空间模型(SSMs)在研究行人中引起了广泛关注。在现代SSM研究[10]的基础上,现代SSM(例如Mamba[4])不仅建立了长距离依赖关系,而且输入规模的线性复杂性也使其成为轻量级UNet道路上CNN和Transformer的强大竞争对手。一些当代尝试,如U-Mamba[14],提出了一个混合的CNN-SSM块,结合了卷积层的局部特征提取能力与SSM在捕捉纵向依赖关系方面的专长。然而,U-Mamba[14]引入了大量参数和计算负载(1.7353亿个参数和18,057.20 GFLOPs),这使得在移动医疗环境中部署用于医学分割任务变得具有挑战性。因此,在这项研究中,作者引入了LightM-UNet,这是一个基于Mamba的轻量级U形分割模型,它在显著降低参数和计算成本的同时实现了最先进的性能(如图1所示)。这项工作的贡献有三个层面。
作者介绍了LightM-UNet,这是UNet与Mamba的轻量级融合,其参数数量仅为100万。通过对2D和3D真实世界数据集的验证,LightM-UNet超越了现有的最先进模型。与著名模型nnU-Net [8]和同期模型U-Mamba [14]相比,LightM-UNet分别将参数数量减少了116和224。
在技术层面,作者提出了“残差视觉曼巴层(RVM层)”以纯曼巴方式从图像中提取深层特征。在引入的新参数和计算开销最小的情况下,作者通过使用“残差连接”和“调整因子”进一步增强了SSM对视觉图像中长距离空间依赖关系建模的能力。
在Insightly中,与同时期的尝试[14, 23, 17]将UNet与Mamba集成不同,作者提倡在UNet内部将Mamba作为CNN和Transformer的轻量级替代品,旨在解决实际医疗环境中计算资源限制所带来的挑战。据作者所知,这代表着首次尝试将Mamba引入UNet中,作为一种轻量级优化策略。
2 Methodologies
尽管LightM-UNet支持医学图像分割的2D和3D版本,但为了方便起见,本文档使用的是LightM-UNet的3D版本来描述该方法。
Architecture Overview
所提出的LightM-UNet的总体架构如图2所示。给定一个输入图像 ,其中 、、 和 分别表示3D医疗图像的通道数、高度、宽度和切片数。
LightM-UNet首先使用深度可分卷积(DWConv)层进行浅层特征提取,生成浅层特征图 ,其中32表示固定的滤波器数量。随后,LightM-UNet结合三个连续的编码器块(Encoder Blocks)从图像中提取深层特征。在每个编码器块之后,特征图中的通道数翻倍,而分辨率减半。
因此,在-th编码器块处,LightM-UNet提取深层特征 ,其中 。
在此之后,LightM-UNet使用瓶颈块(Bottleneck Block)来建模长距离空间依赖关系,同时保持特征图的大小不变。之后,LightM-UNet整合三个连续的解码器块(Decoder Blocks)进行特征解码和图像分辨率恢复。在每个解码器块之后,特征图中的通道数减半,分辨率加倍。
最后,最后一个解码器块的输出达到与原始图像相同的分辨率,包含32个特征通道。LightM-UNet使用DWConv层将通道数映射到分割目标数,并应用SoftMax激活函数生成图像 Mask 。与UNet的设计一致,LightM-UNet也采用跳跃连接(skip connections)为解码器提供多 Level 特征图。
Encoder Block
为了最小化参数数量和计算成本,LightM-UNet采用了仅包含Mamba结构的编码器块来从图像中提取深层特征。
具体来说,给定一个特征图,其中,,,,以及,编码器块首先将特征图展平并转置成的形状,其中。
随后,编码器块使用个连续的RVM层来捕捉全局信息,在最后一个RVM层中通道数增加。此后,编码器块重新调整并转置特征图的形状为,紧接着进行最大池化操作以降低特征图的分辨率。
最终,第个编码器块输出新的特征图,其形状为。
2.2.1 Residual Vision Mamba Layer (RVM Layer)
LightM-UNet提出了RVM层以增强原始的SSM块,用于图像深层语义特征提取。具体来说,LightM-UNet利用先进的残差连接和调整因子进一步增强了SSM的长距离空间建模能力,几乎不引入新的参数和计算复杂性。
如图2(a)所示,对于给定的输入深层特征,RVM层首先采用LayerNorm,然后是VSSM来捕捉空间长距离依赖。随后,它在残差连接中使用调整因子以获得更好的性能。这个过程可以用以下数学方式表示:
紧随其后,RVM层使用另一个LayerNorm来规范化,随后利用一个投射层将转换为一个更深的特征。上述过程可以表述为:
视觉状态空间模块(VSS模块)遵循[13]中概述的方法,LightM-UNet引入了VSS模块(如图2(b)所示)进行长距离空间建模。VSS模块以特征作为输入,并将其引导到两个并行分支中。在第一个分支中,VSS模块使用线性层将特征通道扩展到,其中表示预定义的通道扩展因子。
随后,它应用了DWConv、SiLU激活函数[20],然后是SSM和层归一化。在第二个分支中,VSS模块同样使用线性层将特征通道扩展到,之后是SiLU激活函数。随后,VSS模块通过哈达玛积从两个分支聚合特征,并将通道数投射回,生成与输入形状相同的输出。上述过程可以公式化为:
其中 表示哈达玛积。
Bottleneck Block
类似于Transformer,当网络深度变得过大时,Mamba也会遇到收敛挑战。因此,LightM-UNet通过结合四个连续的RVM层来构建瓶颈,以进一步建模空间长期依赖关系,从而解决这个问题。在这些瓶颈区域中,特征通道的数量和分辨率保持不变。
Decoder Block
LightM-UNet使用了解码器块(Decoder Blocks)来解码特征图并恢复图像分辨率。具体来说,给定来自跳跃连接的和来自前一个块的输出的,解码器块首先通过加法操作进行特征融合。
随后,它使用一个深度卷积(DWConv)、一个残差连接以及一个ReLU激活函数来解码特征图。另外,一个调整因子被加到残差连接上以增强解码能力。这个过程可以用数学方式表达为:
解码器块最终采用双线性插值方法将预测恢复到原始分辨率。
3 Experiments
数据集和实验设置。为了评估作者模型的性能,作者选择了两个公开可用的医学图像数据集:LiTs数据集[1],包含3D CT图像;以及Montgomery&Shenzhen数据集[9],包含2D X光图像。这些数据集在现有的分割研究[12, 24]中被广泛使用,并在本研究中分别用来验证2D和3D版本的LightM-UNet的性能。数据被随机划分为训练集、验证集和测试集,比例分别为7:1:2。
LightM-UNet是使用PyTorch框架实现的,三个编码器块中的RVM层数量分别设置为1、2和2。所有实验都是在单个Quadro RTX 8000 GPU上进行的。采用SGD作为优化器,初始学习率为1e-4。PolyLRScheduler作为调度器,共训练了100个周期。
此外,损失函数被设计为交叉熵损失和Dice损失的简单组合。对于LiTs数据集,图像被归一化并调整至128 128 128的大小,批量大小为2。对于Montgomery&Shenzhen数据集,图像被归一化并调整至512 512的大小,批量大小为12。
为了评估LightM-UNet,作者将其与两种基于CNN的分割网络(nnU-Net和 SegResNet)、两种基于Transformer的网络(UNETR和 SwinUNETR)以及一种基于Mamba的网络(U-Mamba)进行了比较,这些网络通常用于医学图像分割竞赛中。此外,作者还采用了平均交并比(mIoU)和Dice相似度得分(DSC)作为评估指标。
比较结果。表1展示的对比实验结果表明,作者的LightM-UNet在LiTS数据集上达到了全面的最新性能。值得注意的是,与像nnU-Net这样的大型模型相比,LightM-UNet不仅表现出更优越的性能,同时显著减少了参数数量和计算成本,分别降低了47.39和15.82。
与同期的U-Mamba相比,LightM-UNet在平均mIoU方面提高了2.11%的性能。特别是对于通常太小而难以轻易检测的肿瘤,LightM-UNet在mIoU上实现了3.63%的改进。重要的是,作为一种将Mamba方法融入UNet架构的方法,与U-Mamba相比,LightM-UNet仅使用了少1.07%的参数和2.53%的计算资源。
蒙哥马利和深圳数据集[9]的实验结果总结在表2中。LightM-UNet再次实现了最优性能,并且显著超越了其他基于Transformer和Mamba的文献。
此外,LightM-UNet以其极低的参数数量脱颖而出,仅使用了1.09M个参数。与nnU-Net和U-Mamba相比,参数分别减少了99.14%和99.55%。为了更清晰地可视化实验发现,请参考图1。
图3展示了分割结果示例,表明与其他模型相比,LightM-UNet具有更平滑的分割边缘,并且不会对小型目标(如肿瘤)产生错误的识别。
使用卷积或自注意力机制的VSSM会导致性能损失。此外,卷积和自注意力引入了大量参数和计算开销。进一步地,作者观察到基于Transformer和基于VSSM的结果都优于基于卷积的结果,这证明了建模长距离依赖的好处。
作者进一步去除了RVM层中的调整因子和残差连接。实验结果表明,在移除这两个组件后,模型的参数数量和计算开销几乎没有减少,但模型的性能显著下降(mIoU下降了0.44%和0.69%)。这验证了作者在不引入额外参数和计算开销的情况下提升模型性能的基本原则。关于Montgomery&Shenzhen数据集的额外消融分析可以在补充材料中找到。
4 Conclusion
在这项研究中,作者介绍了LightM-UNet,一个基于Mamba的轻量级网络,该网络在2D和3D分割任务中均达到了最先进性能,同时仅包含1M个参数,相较于最新的基于Transformer的架构,参数减少了超过99%,GFLOPS也显著降低。作者通过在统一框架内进行严格的消融研究来验证作者的方法,这是首次尝试将Mamba作为UNet的轻量级策略。作者的未来工作包括设计更轻量的网络,并在多个器官的更多数据集上进行验证,推动它们在移动健康及更广泛领域的应用。
参考
[1].LightM-UNet: Mamba Assists in Lightweight.
点击上方卡片,关注 「AI视界引擎」 公众号