点击下方卡片,关注「AI视界引擎」公众号
( 添加时备注:方向+学校/公司+昵称/姓名 )
医学图像分割在临床诊断与治疗规划中至关重要。尽管基于Transformer的方法已取得显著成果,但其高昂的计算成本限制了其在临床场景中的部署。为解决这一问题,作者提出TM-UNet,一种新颖的轻量级框架,通过将token序列建模与高效的内存机制相结合,实现高效的医学图像分割。
具体而言,作者设计了多尺度token-内存(Multi-scale Token-Memory, MSTM)模块,该模块通过策略性空间扫描将2D空间特征转换为token序列,并利用矩阵内存单元(matrix memory cells)选择性地保留与传播具有判别性的上下文信息。
该新型token-内存机制充当动态知识存储器,以线性复杂度捕捉长程依赖关系,实现高效的全局推理而无需冗余计算。
此外,MSTM模块引入指数门控机制以识别token的有效性,并通过并行池化操作实现多尺度上下文提取,从而在不增加计算开销的前提下完成分层表征学习。
大量实验表明,TM-UNet在多种医学图像分割任务中均优于当前最先进的方法,同时显著降低了计算成本。
代码已开源,地址为:https://github.com/xq141839/TM-UNet。
- 引言
医学图像分割是现代医疗应用的基础,能够精确识别解剖结构和病灶区域,为诊断辅助、手术规划和治疗评估提供支持 [1]。尽管取得了显著进展,但由于成像噪声、边界模糊性以及不同模态间患者间解剖结构的差异性,实现鲁棒且准确的分割仍然具有挑战性。
U-Net [2] 的引入标志着生物医学图像分割领域的一个重要突破,其对称的编码器-解码器架构与 Shortcut (skip connections)能够有效保留多尺度解剖特征 [3, 4]。然而,卷积神经网络(Convolutional Neural Networks, CNNs)仍受限于其固有的局部感受野,这限制了全局上下文建模能力,阻碍了对复杂病灶的精确分割 [5, 6]。为克服这一局限,视觉Transformer(Vision Transformer, ViT)作为一种强大的替代方案应运而生,其通过自注意力(self-attention)机制能够捕捉长距离的空间依赖关系 [7]。基于token的表示方式使模型能够在整个医学图像上进行全局推理,显著提升了上下文理解能力与分割性能。
然而,尽管现有的医学视觉Transformer(Medical ViT)研究[3, 4]取得了进展,其实际部署仍面临显著的计算挑战,限制了其在真实临床场景中的应用。ViT架构中的自注意力机制(self-attention mechanism)相对于token数量呈现出二次计算复杂度,导致高内存占用和延迟,这与实时或资源受限的医疗环境不兼容[8]。此外,许多图像token通常对应非病理或背景区域,提供的判别性信息极少,造成冗余计算和效率低下。这些挑战促使人们亟需一种计算高效的机制,能够在保持建模长距离空间依赖能力的同时,选择性地关注具有信息量的token。
为弥补这一研究空白,作者提出了token-memory ( Token -记忆)概念,这是一种紧凑且动态的表示机制,使 Token 能够访问并更新存储在记忆单元中的共享上下文信息。与Transformer架构中反复重新计算成对关系不同,token-memory能够在时间维度上累积并传播显著的上下文线索,从而减少冗余并保持全局一致性。基于该概念,作者提出了TM-UNet,一种新颖的框架,将 Token 序列建模与多尺度 Token -记忆(Multi-scale Token-Memory, MSTM)模块相结合。
该模块通过策略性空间扫描将2D空间特征转换为 Token 序列,并利用具有协方差更新的矩阵记忆单元,在线性计算复杂度下捕捉长程依赖关系。此外,指数门控机制进一步强化具有判别性的 Token ,同时过滤冗余信息;并行池化操作则在不增加额外计算开销的前提下实现多尺度上下文提取。在多种分割任务上的大量实验表明,TM-UNet不仅达到了当前最优性能,还显著提升了计算效率。
- 方法论
Image
如图2所示,作者提出了TM-UNet框架,用于高效医学图像分割。为实现这一目标,作者设计了多尺度 Token -记忆(Multi-scale Token-Memory, MSTM)模块,以线性复杂度捕捉长距离依赖关系。
所提出的 MSTM 模块采用扩展型长短期记忆网络(Extended Long Short-Term Memory, xLSTM)[13],通过序列建模识别 token 的有效性。这种集成设计使作者的 TM-UNet 框架在实现优异分割性能的同时,具备出色的计算效率。
2.1. 前置知识
经典的 LSTM [14] 通过门控机制控制信息流,以克服循环神经网络的梯度消失问题:
其中
是内部单元状态,用于累积先前步骤的长期上下文信息。遗忘门、输入门和输出门
, 和
) 决定了从前一时刻记忆
和当前候选值
中保留多少信息。近期的 xLSTM [15] 通过指数门控机制对这一结构进行了增强,实现了更灵活且更具表现力的信息传播。其指数输入门定义如下:
其中
、
和
为可学习参数。这为信息接收提供了无界的动态范围,同时……
输出门保持sigmoid激活函数
以保证稳定性。遗忘门采用sigmoid或指数激活函数:
这种混合门控策略能够有效处理序列化医学图像特征。
2.2. 扩展的长短期记忆层
为了将第2.1节中提出的公式转化为能够捕捉长距离空间依赖的实际网络组件,采用双向矩阵LSTM(mLSTM)通过交替扫描方向来处理,以确保全面的空间建模。给定输入特征图
,该层通过两条并行路径处理特征,路径中包含线性投影和翻转操作。双向处理的公式表示为:
其中,翻转操作(flip operations)用于反转序列顺序,以实现双向空间建模,最终输出通过残差连接(residual connections)融合双向处理的结果。mLSTM 的计算提供了多种灵活的处理模式,具有不同的计算复杂度。对于序列长度
和隐藏维度
,作者采用分块模式(chunkwise mode),其复杂度为
,其中
表示分块大小(chunk size),该模式可实现线性扩展,适用于实际的医学图像处理场景。
2.3 多尺度 Token-Memory 块
多尺度特征表示对于医学图像分割至关重要,因为解剖结构和病灶区域在尺寸和空间分布上差异显著。在xLSTM层的基础上,所提出的MSTM模块实现了核心的token-memory机制,利用紧凑的记忆单元在空间token间累积并复用上下文信息。具体而言,一个多尺度池化组件
首先在不增加额外计算量的前提下,捕捉多个感受野尺度下的差异性上下文线索。随后,所得表示通过
层xLSTM进行处理,每层更新共享的基于矩阵的记忆单元,以存储显著的token交互信息,从而实现具有线性复杂度的自适应长程依赖建模。每层xLSTM后接一个
深度可分离卷积,以细化局部空间结构。该序列路径可表述为:
其中 DWConv 包含批归一化(batch normalization)和 ReLU 激活函数。这种交替结构在计算效率与表征能力之间取得了平衡。最终的嵌入表示以残差形式计算,以稳定训练过程并保留细粒度细节:
通过结合多尺度上下文感知与序列记忆建模,MSTM模块能够在保持轻量化效率的同时,实现对医学图像的分层理解,满足临床部署的需求。
2.4 TM-UNet框架的优化
完整的TM-UNet在U-Net框架内集成了所提出的MSTM模块,该框架针对高效的医学图像分割进行了优化。其采用混合编码器-解码器架构,在深层阶段战略性地嵌入MSTM模块,以捕捉长距离依赖关系,同时保持计算效率。编码器通过五个逐步下采样的阶段处理尺寸为
的输入图像。前三个阶段采用传统的卷积块,包含
卷积、批归一化(Batch Normalization)和ReLU激活函数,以高效提取局部解剖特征,生成分辨率分别为
的特征图。最后两个阶段引入MSTM模块,以增强全局上下文推理能力。
解码器与编码器结构对称,包含五个上采样阶段。每个阶段首先进行双线性上采样,随后通过卷积进行精细化处理,并通过通道维度的拼接(channel-wise concatenation)融合对应编码器层的 Shortcut (skip connections),以保留细粒度的空间细节。解码器最深层的两个阶段还引入了MSTM模块,以确保多尺度上下文建模的一致性并维持网络架构的对称性。网络采用联合分割损失(交叉熵损失与Dice损失)进行优化,兼顾像素级精度与区域级一致性:
其中
和
为权重因子。通过这种混合设计,TM-UNet 结合了卷积网络的空间归纳偏置与 MSTM 模块的全局依赖建模能力,在保持线性计算复杂度的同时,实现了精准高效的医学图像分割。
- 实验
3.1 数据集与实现细节
为验证所提出TM-UNet的有效性,作者在四个医学图像数据集上进行了全面评估:CVCCDB [16]、MoNuSeg [17]、ColonDB [18] 和 UDIAT [19]。所有数据集均按 7:1:2 的比例划分为训练集、验证集和测试集,且在训练和测试阶段所有图像均被缩放至
。所有实验均在单块 NVIDIA A100 GPU 上使用 PyTorch 完成。为保证公平比较,所有分割方法均采用相同的训练设置与配置。损失系数
和
分别设置为 2 和 1。优化器采用 Adam,初始学习率为
,并使用指数衰减策略调整学习率,衰减因子为 0.98。批量大小(batch size)和训练轮数(training epoch)分别设置为 16 和 200。
3.2 与最先进方法的对比
Image
为了全面评估 TM-UNet 的性能,作者将其与当前最先进的方法进行了广泛比较,包括经典的 U-Net 架构 [2, 9]、轻量级模型 [10, 5] 以及近期先进的方法 [11, 12]。如表1 所示,TM-UNet 在所有四个数据集上均持续优于其他方法,同时保持了卓越的计算效率。具体而言,TM-UNet 在 CVCCDB 上取得了
的 Dice 分数(相比高性能方法 Zig-RiR [12] 提升
),在 MoNuSeg 上达到
Dice(提升
),在 ColonDB 上获得
Dice(提升
),在 UDIAT 上取得
Dice 且 HD 值最低,仅为 39.74。此外,TM-UNet 仅需 6.55G FLOPs(相比 TinyU-Net [5] 降低
),并实现了最高的 FPS(87.55),相比最高效的 Baseline 方法提升
。这些结果表明,TM-UNet 有效解决了分割精度与计算成本之间长期存在的权衡问题,在显著提升效率的同时实现了 SOTA 性能,适用于实际临床部署。
3.3. 消融实验
Image
Image
为探究 MSTM 模块组件的有效性,作者在表2 中进行了全面的消融实验。从仅采用纯 encoder-decoder 架构的 Baseline 模型出发,作者逐步引入 xLSTM、多尺度池化(multi-scale pooling)和深度可分离卷积(depth-wise convolution)。xLSTM 将 Dice 分数从
提升至
,验证了其在捕捉长程依赖关系方面的有效性。多尺度池化进一步将性能提升至
,表明其在跨尺度上下文聚合方面的优势。完整版 TM-UNet 达到了
的 Dice 分数和
的 IoU,相较于 Baseline 模型分别提升了
和
,且仅带来微小的 FLOPs 增加。表3 分析了通道配置的影响:轻量级变体
在资源受限场景下实现了
的 Dice 分数,仅需 0.49G FLOPs,推理速度达 104.23 FPS;而作者的标准配置
在 87.55 FPS 下实现了最优平衡,Dice 分数达到
。上述结果验证了 xLSTM、多尺度池化和 DWConv 在 MSTM 中对 TM-UNet 性能提升的关键贡献。
- 结论
在本工作中,作者提出了TM-UNet,这是一种高效且平衡精度与计算成本的医学图像分割框架。通过在MSTM模块中引入所提出的token-memory机制,将U-Net与xLSTM相结合,TM-UNet以线性复杂度捕捉长距离依赖关系。大量实验表明,TM-UNet在保持优越计算效率的同时,性能优于现有最先进方法。
参考
[1]. TM-UNET: TOKEN-MEMORY ENHANCED SEOUENTIAL MODELING FOR EFFICIENT MEDICAL IMAGE SEGMENTATION
点击上方卡片,关注「AI视界引擎」公众号
