点击下方卡片,关注 「AI视界引擎」 公众号
( 添加时备注:方向+学校/公司+昵称/姓名 )
高效建模大面积2D上下文对于大型全景组织切片成像(WSI)和遥感等领域至关重要。基于Transformer的模型能够提供高并行性,但处理长序列时由于其二次复杂度而面临挑战。
最近,Mamba引入了一种具有线性复杂度和高并行性的选择性状态空间模型(SSM),这使得在1D序列中有效地建模广泛上下文成为可能。然而,将Mamba扩展到视觉任务中,而这些任务本质上涉及2D结构,则会因为1D序列处理的局限性而导致空间上的不一致。
另一方面,当前的2D SSMs本可以建模2D结构,但由于缺乏高效的并行算法,它们的计算速度变得难以承受。在这项工作中,作者提出了一种新颖的2D选择性SSM框架——2DMamba,它将图像的2D空间结构融入到Mamba中,并采用高度优化的硬件感知算子,兼顾了空间连续性和计算效率。
作者在WSI和自然图像上验证了该方法的灵活性。在10个公开数据集上进行的WSI分类和生存分析实验表明,2DMamba在AUC上提高了2.48%,Fl分数提高了3.11%,准确率提高了2.47%,C-指数提高了5.52%。
此外,将作者的方法与VMamba结合应用于自然影像,在ADE20k语义分割数据集上可获得0.5至0.7的mIoU改进,在ImageNet-1K分类数据集上则实现了0.2%的准确性提升。
unset
unset1. Introductionunset
unset
在2D视觉域中高效理解大规模上下文对于医学影像和遥感等领域至关重要。虽然循环神经网络(RNNs)[34]可以在长序列中建模宽泛的上下文,但其顺序结构限制了并行性,使其难以充分利用GPU资源。为了解决这一问题,Transformer [10, 38]因其高并行能力成为处理长序列的主流方法,尽管它们的复杂度为二次阶。作为解决方案,Mamba [8, 13]通过结合线性时间复杂度和并行性,成为了一种有前途的方法。Mamba是一种状态空间模型(SSM),这是一种控制理论中的数学框架,用于捕捉状态变量之间的动态交互[12, 19, 33]。它引入了一种选择机制,增强了SSMs的灵活性,使其能够捕捉到重要信息并忽略无关的上下文。
Mamba 是一种仅处理一维序列的语言模型,并被扩展至视觉领域 [24, 44]。鉴于视觉任务的二维性质,基于 Mamba 的方法在处理自然图像时尝试通过采用更连续的空间扫描模式 [40] 或同时使用多条扫描路径 [24, 40] 来引入二维图像结构。尽管取得了一定进展,但这些方法仍然依赖于 Mamba 的一维扫描过程,这可能会导致几何一致性的误导性表示,并如图1 所示产生空间差异。
一种替代方案是使用2D结构自监督模型(2D SSMs)来维持二维结构的空间连续性[2, 21, 22]。然而,与高度并行的Mamba架构不同,要实现这些方法的并行实施仍然面临很大挑战。因此,类似于传统的RNN,这些方法会经历非常缓慢的计算,几乎使其变得不可行。此外,它们缺乏Mamba的选择机制,导致性能不佳。
除了在通用视觉任务中的应用之外,Mamba 在计算病理学领域展现出了巨大的潜力,特别是在 Gigapixel 全视野显微图像(WSIs)的分类中,WSIs 是癌症诊断的黄金标准 。WSIs 是高分辨率的组织样本图像,通常在 40 倍放大下可达 100,000×100,000 像素,因此它们非常庞大且富含空间细节。由于其巨大的尺寸,WSIs 通常以多实例学习(MIL)的方式进行分析:传统的基于集合的 MIL 方法将 WSI 转换为“集合”(实例/ Patch )的集合,这些 Patch 通常独立聚合,忽略了 Patch 之间的空间感知能力。相比之下,基于 Mamba 的方法将 WSI 视作 Patch 序列 [11, 41],这能够更有效地整合信息,并可能增强诊断洞察力。然而,它们也依赖于 Mamba 的一维扫描性质,空间差异仍然是一个问题,如图1 所示。
在本文中,作者提出了一种名为2DMamba的新框架,它克服了Mamba的一维性质以及二维状态空间模型的顺序性限制。具体来说,作者提出了一种新颖的二维选择性状态空间模型架构,该架构可以直接扫描二维图像而无需首先将其展平为一维序列。作者还提出了一种新型硬件感知的二维选择性扫描操作符,该操作符将一维Mamba的并行性扩展到了二维。通过在两个非常不同的领域——用于MIL聚合的大像素WSI以及自然图像上实现作者的架构,验证了其多功能性。在针对WSI分类和生存分析的10个公开数据集上的广泛实验表明,作者的方法分别在AUC、F1、准确率和C-index上实现了最高达2.48%、3.11%、2.47%和5.52%的相对提升。此外,作者将扫描方法整合到当前最先进的方法VMamba [24] 中。在ADE20k语义分割数据集上,作者的方法在mIoU上超越SOTA方法0.5至0.7个百分点;在ImageNet-1K分类数据集上,作者的方法在准确率上超越SOTA方法0.2个百分点。
unset
unset2. Related l workunset
unset
状态空间模型(SSM)。SSM [20] 是一种有效的序列模型,通过定义隐藏状态及其转换来表示随时间演变的系统,使其特别适用于捕捉序列数据中的动态时序行为。Gu等人 [15] 将RNN、时序卷积和神经微分方程统一到一个线性状态空间层中,并利用HiPPO初始化展示了基于SSM的方法的潜力。S4 [14] 提出将参数矩阵归一化为对角结构。2D-SSM [2] 采用Roesser的2D-SSM递归公式 [21] 并将其应用于2D图像。然而,在Mamba之前,所有这些SSM方法都面临着由于状态之间的序列依赖性使得高效并行算法难以实现的问题。
Mamba. 为了加快SSM方法的速度,Mamba [13]引入了一种选择机制,使模型参数依赖于输入,并通过遗忘较少相关的状态来消除长依赖性。它还引入了一种硬件感知算法,大幅加速了状态计算。最初,Mamba应用于语言任务,Vim [44]引入了一个Vision Mamba模块,该模块使用两个独立的选择性SSM进行双向信息聚合。PlainMamba [40]使用了四方向选择性扫描,并采用了更具空间连续性的扫描路径。类似地,VMamba [24]和GroupMamba [36]也在分层网络中利用了这种四方向扫描,并优化了网络结构。然而,当前这些基于Mamba的模型仍局限于1D。
WSI 分类中的多实例学习应用。在 WSI 分类中,多实例学习方法是主流技术。它通过聚集 WSIs 中嵌入的特征来实现切片 Level 的表示。AB-MIL [18] 引入了一种基于注意力的聚集方法,其中注意力值由神经网络学习得到。在此基础上,CLAM [27] 提出了一种多分支池化机制以提高性能。DSMIL [23] 在双流架构中使用了多尺度 Patch 特征。TransMIL [37] 引入了多头自注意力层来捕捉形态学和空间关系,并使用 Nystrom 注意力 [39] 来缓解自注意力的二次复杂度。
DTFD-MIL [42] 通过引入伪袋来引入双层级的 MIL 架构。最近,S4-MIL [11] 和 MambaMIL [41] 使用 Mamba 来更好地捕获长序列 Patch 中的信息。然而,这些工作仍然未能充分利用 WSI 的二维空间信息。
unset
unset3. Methodunset
unset
作者提出了一种名为2DMamba的有效且高效的设计,并配套了一个用于WSI表示的框架:2DMambaMIL。
3.1.SSMinMamba and 1Dselective scan
作者重新审视了状态空间模型(SSM),这是一种用于捕捉动态系统行为的数学模型。SSM被设计为一个函数到函数的模型,适用于连续系统;经过离散化后,它变成了一个序列到序列的模型。
其中, 是时间 的潜在状态, 是输出,且 表示状态维度。参数 和 是时不变的,因此不会根据输入进行适应。这种设计限制了状态空间模型(SSMs)处理长序列输入时的情境感知能力。
Mamba块[13]引入了一种选择机制,允许SSM动态适应输入上下文。该机制将重要的输入聚合到隐藏状态中,而不重要的输入可以被忽略。从数学上讲,参数被表示为输入的函数:
其中,、 和 是可学习的线性函数,依赖于 和 表示离散化的时步。Mamba 块中的选择机制通常称为选择扫描。为了更好地与作者的二维方法区分开来,作者称这种扫描为一维(ID)选择扫描。
3.2.Architectureof2DMambaMm
2DMambaMIL的整体架构如图2所示。2DMambaMIL包括:层的2DMamba块和一个聚合模块。第一部分是由层2DMamba块堆叠而成。作者采用了原版Mamba块[13]的设计,并将原始的一维选择扫描替换为作者的二维变种。第二部分是一个聚合模块,这是一个基于注意力机制的模块,包含两个线性投影,生成滑动特征。
作者的2DMambaMIL首先将输入的WSI分割成patches ,其中且,和分别表示沿高度和宽度方向分割出的patches数量。这些patches根据其类型被不同的方式嵌入。组织patches使用预训练的病理特征抽取器进行嵌入。此外,作者提出使用一个可学习的token 来表示填充以获得2D矩形特征图的非组织patches。
这使得模型在训练过程中能够学习到非组织区域的适当表示。形式上,WSI被转换成一个具有矩形形状的特征图。
3.3.2DselectiveSSM architecture
作者详细阐述了2D选择性SSM架构。2DMamba的关键组件是2D选择性扫描操作。与传统的mamba不同,2DMamba直接从2D特征图中聚合几何和语义信息,而不是从展平的一维序列中聚合信息。特别地,2DMamba同时并行执行水平和垂直扫描。为了简化表示,在本节中省略了状态维度的上标。2D选择性扫描的操作参数保持与一维情况相同(参见公式(3)),其中下标为以索引2D输入,而不是。作者使用来表示图2中归一化、投影和卷积层之后的2D选择性扫描的输入。
如图2所示,作者首先独立地对每一行进行水平扫描。这一操作类似于在每一行上应用1维选择性扫描。具体而言,在水平扫描过程中获得的状态 ( h_{i,j}^{\text{hor}} ) 为:
对于第一列,作者假设 ,因此 。两个参数 和 依赖于 ,用于调节先前状态 和当前输入 的信息。在进行水平扫描后,作者独立地对每个 列应用垂直扫描。与水平扫描相比,在垂直扫描中作者将 替换为从水平扫描获得的结果 。
对于第一行,假设,则有。作者为垂直扫描reuse相同的。
如果作者省略和的下标,并展开式(5)和式(6),隐藏状态可以表示为以下递推公式:
其中, 表示位置与之间的曼哈顿距离,代表从到的一条路径,该路径先水平向右移动,再垂直向下移动。经过两次扫描后,输出通过参数从中聚合得到,类似于1D-Mamba:。对于每个位置,其聚合信息来自其左上角的位置。
2DMamba 聚合信息而不产生空间差异。如图1所示,橙色和蓝色块分别对应 和 。对于1DMamba 扫描和作者的2D扫描,橙色块的隐藏状态为:
请注意,在1D扫描中,的幂次为3,而在作者的2D扫描中,该幂次仅为1。由于这些相邻Patch在空间上相连,作者的公式反映了它们之间的邻接关系。相比之下,1D Mamba可能导致空间上的不一致性,并可能稀释的信息。
3.4. Hardware-aware 2D selective scan operator
作者提出了一种硬件感知的扫描操作符,可以加速二维选择性扫描。首先,作者回顾了GPU内存层次结构,并分析了二维选择性扫描的主要挑战。然后,作者详细介绍了作者的新型操作符。
GPU 内存层次结构。图3(d)展示了现代 GPU 的内存层次结构。绿色区域代表离片 GPU 内存,速度较低但容量较大,被称为高带宽内存(HBM)。橙色区域表示片上内存,速度快但容量小,称为静态随机存取存储器(SRAM)。在 GPU 算法中,数据从 HBM 转移到 SRAM 进行计算,计算结果再存回 HBM 以释放 SRAM 给后续计算。内存传输成本高昂。因此,在许多 GPU 算法中,尽管算法本身可能涉及大量计算,但内存访问 Bottleneck 限制了其性能表现。Mamba 的选择性扫描[13]同样受到内存限制。
Mamba的一维选择性扫描。vanilla Mamba模型之所以快速,是因为它通过一维平铺和缓存尊重了GPU的内存层次结构。如图3(a)所示,长特征序列在HBM中被划分为更小的块。每个块加载到SRAM中,沿着个独立的状态维度进行扫描,并根据公式(2)中指定的规则聚合为单一输出,最终再存储回HBM。状态维度上的中间结果不会被实际存储在HBM中,而是在反向传播过程中重新计算。整体的内存访问复杂度为,其中表示序列长度。
朴素的二维选择性扫描。将1D Mamba扫描扩展到2D并不是一件容易的事情。如图3(b)所示,对1D Mamba进行简单的二维扩展会分为两个步骤进行扫描。首先,特征图被分成行以进行行向的1D Mamba扫描。然后,接下来的垂直扫描必须独立应用于每一列,而每一列具有个独立的状态维度。因此,水平扫描器需要在HBM上产生个中间特征图。每个特征图随后会被分成列以进行垂直扫描。其内存访问复杂度为,这正如表3所示,导致了低吞吐量和高内存消耗。
面向硬件的2D选择性扫描。文中提出的面向硬件的2D选择性扫描操作符如图3(c)所示,通过2D分块和缓存优化了内存事务。不同于按行或列分块,作者将特征图分割成一个2D网格。在每一步中,作者只将一个小的子矩阵加载到SRAM。然后作者对N个独立的状态维度进行水平和垂直扫描,并将汇总结果写回到HBM。这避免了显式的状态维度材料化,并保持了整体的内存访问复杂度为,等同于传统的Mamba。
此外,vanilla Mamba 使用了 NVIDIA 的 CUB 库 [31] 进行 1D 并行扫描。然而,如图3(e) 所示,CUB 的 BlockScan 算法只支持全序列扫描。因此,对于多行特征图,它需要进行多次扫描。进一步地,对于一个二维特征图,CUB 的 BlockScan 要求其高度 和宽度 都是 32 的倍数,其中 32 是 NVIDIA GPU 上最小的线程调度粒度。
因此,小的特征图必须在计算前被填充,导致效率低下。例如,一个典型的 特征图每行和每列需要 18 个填充元素,浪费了高达 56% 的计算资源。为了解决这一限制,作者引入了 SegmentedBlockScan 算法,该算法如图3(f) 所示。它在行和列之间分配 GPU 线程,只需要 是 32 的倍数。这使得可以同时对多行/列进行扫描,并显著减少了小特征图的填充需求。例如,对于同样的 特征图,作者的方法每行和每列只需要 2 个填充元素。
unset
unset4. Experimentsunset
unset
4.1. Dataset
作者使用2DMambaMIL在5个公开的病理分类数据集上进行了评估,包括TCGA-BRCA [1]、BRACS [3]、PANDA [4]、DHMC [45]以及TCGA-NSCLC,同时还使用了5个公开的生存数据集,即TCGA-(KIRC、KIRP、LUAD、STAD、UCEC)。这些数据集涵盖了多种器官,包括乳腺、前列腺、肺、肾、胃和子宫。数据集中的切片数量范围从261到10614张。
对于所有这些数据集,作者都使用了20倍放大率。有关这些数据集的详细信息参见附录B。按照[24]的方法,作者还评估了2DMamba在两个自然图像数据集上的表现:ImageNet-1K分类和ADE20K语义分割。
4.2. Implementation Details
对于所有的病理学实验,作者使用在1亿张病理图像上预训练的UNI基础模型[6]来嵌入组织 Patch 。作者使用AdamW [26] 对模型进行优化,在批量大小为1的情况下进行20个epochs的训练。
初始学习率设置为0.0001,并使用余弦退火调度器进行调整。为了公平比较,作者使用基于SSM的单个模块,具有128维的SSM,并将所有Mamba基方法的状态维度设置为16。所有的病理学和自然图像实验分别在一块NVIDIA V100 GPU和八块NVIDIA A100 GPU上进行训练。
4.3. Results
4.3.1.WSI classification
作者首先将2DMambaMIL与八个其他当前最佳的MIL Baseline 方法在五个WSI分类数据集上进行比较。这些 Baseline 方法包括ABMIL [18]、CLAM [27]、DSMIL [23]、DTFDMIL [42]、TransMIL [37]、S4-MIL [11]、MambaMIL [41] 和 SRMambaMIL [41]。前四种方法是基于注意力的,TransMIL 是基于Transformer的方法,而最后三种方法则是基于一维顺序模式(1D SSM)的MIL方法。作者使用准确率(Acc)、F1分数(FI)和曲线下面积(AUC)这三种指标来评估WSI分类性能。
表1显示,作者的2DMambaMIL在多个数据集上超越了所有当前的SOTA方法,表明作者具有强大的泛化能力。与表现最佳的非Mamba方法相比,在准确率方面作者获得了最高达5.83%的提升,在F1分数方面提升了14.90%,在AUC方面提升了4.65%。2DMambaMIL在准确率方面也比基于SSM的方法高出了3.26%,在F1分数方面高出3.11%,在AUC方面高出2.48%,这表明保留WSI中的空间连续性具有显著优势。
4.3.2. WSI survival analysis
作者进一步将2DMambaMIL与之前提到的八个 Baseline 方法在五个WSI生存数据集中进行了比较。作者使用C指数来评估性能,该指数评估生存模型如何根据患者的生存时间对其排名。如表2所示,在所有评估的方法中,2DMambaMIL在所有数据集上都一致获得了最高的C指数得分,表明其预测性能更优。
具体来说,2DMambaMIL在KIRC、KIRP、LUAD、STAD和UCEC上的C指数相较于表现最好的 Baseline 分别取得了相对改进的0.6%、1.2%、5.5%、2.9%和1.0%。
4.3.3. Speed and GPU memory efficiency
作者的方法展示了高速度和较低的内存使用量。作者在三种不同输入特征尺寸下评估了浮点运算次数(FLOPs)、吞吐量以及GPU内存消耗:、 和 。
首先,作者比较了三种基于CUDA的扫描操作符:
Mamba使用的CUB 1D扫描、在第3.4节中引入的朴素2D扫描以及作者优化的2D扫描,在这三种不同输入大小且有16个独立状态维度的情况下。如表3所示,作者的2D扫描在所有输入尺寸下均显著优于朴素2D扫描的吞吐量和GPU内存效率,差距随着特征尺寸的增大而增大。
对于的输入尺寸,作者的2D扫描与Mamba的CUB扫描在吞吐量上相当。然而,随着输入尺寸的增加,其吞吐量相较于CUB扫描有所下降。这是由于2D数据更复杂的内存布局及作者加倍的计算量导致的。尽管如此,作者的2D扫描仍然保持线性内存消耗,与序列长度成正比。
作者随后评估了 Mamba 的 CUDA 实现,即作者 2DMamba 的 Python 实现,以及在 MIL 框架内的 2DMamba 的 CUDA 实现,这些评估基于三种不同的输入特征尺寸。
表3 显示,作者的基于 CUDA 的 2DMambaMIL 框架在所有指标上都优于基于 Python 的实现,这得益于作者针对硬件优化的二维扫描操作符。作者的方法的吞吐量保持在基础 Mamba 框架的 70% 至 90%。
4.3.4.Naturalimage classification
除了在病理图像上的有效性和效率外,作者的方法在自然图像分类上也表现出良好的泛化能力。作者将2DMamba应用于基于Mamba的方法VMamba [24]。作者用作者的2DMamba块替换其Mamba块,并将其命名为2DVMamba。
首先,作者在ImageNet-1K分类数据集上对其进行评估,并将其与Swin Transformer [25]、Vim [44]、EfficientVMamba [32]、LocalVMamba [17] 以及原始VMamba进行比较。表4显示,作者的2DVMamba比原始VMamba的准确性高出0.2%,并且超越了所有最先进的方法。
4.3.6.Ablationstudies
非组织填充。作者通过与一个简单的填充解决方案进行比较,即填充所有固定零 Token ,在PANDA和TCGA-BRCA数据集上消融了可学习的填充 Token 。表6显示,作者的可学习填充在准确率和AUC上的表现分别优于固定填充1.56%至4.25%,以及-0.62%至-1.58%。这表明,作者的可训练填充使扫描能够更有效地适应非组织区域。
4.3.5. Natural image segmentation
作者进一步评估了2DVMamba在ADE20K语义分割数据集上的性能。如表5所示,2DVMamba-T优于基准VMambaT,在单尺度mIoU上提高了0.7,在多尺度mIoU上提高了0.5。它超过了所有之前的 Baseline 。值得注意的是,与分类任务相比,分割性能的提升更为显著。这可能是因为分割是一个密集预测任务,保持patch间的空间连续性至关重要。
位置嵌入(Positional Embeddings,PE)。作者研究了PE对Mamba基多实例学习(MIL)的影响。作者在PANDA和TCGA-BRCA数据集上比较了带PE和不带PE的MambaMIL、SRMambaMIL和2DMambaMIL的表现。
由于WSIs的尺寸较大,采用绝对PE,如[10]所示,会导致MIL模型参数过多。因此,作者采用线性投影将每个Patch的2D坐标映射到PE,并将其添加到Patch嵌入中以整合位置信息。如表7所示,引入PE一般可以提高1D Mamba基方法的性能,表明额外的空间信息有助于缓解空间差异。相比之下,向作者的2DMambaMIL添加PE会降低其性能。这种下降发生的原因是作者2D形式有效地整合了空间信息,使得额外的PE变得多余。
4.3.7.Qualitative evaluation
作者定性比较了2DMambaMIL生成的注意力 Heatmap 与四种现有方法(ABMIL、CLAM、MambaMIL和SRMambaMIL)在分类和生存分析任务中的表现,重点关注病理学方面。
4.3.8. Visualization of Effective Receptive Fields
有效感受野(Effective Receptive Fields, ERF)是指输入空间中对某个输出单元激活有贡献的区域。作者对Swin-T、vanilla VMamaba-T以及作者的2DVMamba-T中央像素点的ERF进行了分析。如图5所示,Swin-T的ERF表现出局部模式,这与其局部结构一致。VMamaba-T则表现出更加全局的模式,但由于其四向一维扫描过程,存在明显的交叉信号。相比之下,2DVMamba则显示出更为全局且平滑的ERF,没有交叉信号,表明其在保持空间连续性方面更具优势。
在肿瘤区域和生物可解释性方面,结果表明2DMambaMIL在WSI分类和生存分析数据集中始终针对肿瘤区域,偶尔也会包含邻近肿瘤区域的部分像素。图4展示了在生存分析背景下肾透明细胞癌的一个案例。AB-MIL和SRMambaMIL主要关注非肿瘤区域,这并不需要用于风险预测,而CLAM也表现出对非肿瘤区域显著的关注。相比之下,2DMambaMIL和SRMambaMIL的注意力都集中在肿瘤区域。作者的方法展示出更异质性的注意力模式,具体聚焦于与生存高度相关的区域(用红色箭头指示),而SRMambaMIL的注意力分布则较为均匀,主要集中在一些不太与生存相关的区域(用紫色箭头指示)。
unset
unset参考unset
unset
[0]. 2DMamba: Efficient State Space Model for Image Representation with Applications on Giga-Pixel Whole Slide Image Classification .
点击上方卡片,关注 「AI视界引擎」 公众号