本文主要解决了什么问题
-
- Transformer模型中自注意力机制的计算成本随token数量呈平方级增长,导致GPU内存访问开销增加的问题。
-
- 现有token压缩技术与FlashAttention等融合注意力核不兼容的问题,因为前者通常依赖注意力图确定token重要性,而后者不提供中间注意力图。
-
- 现有token重要性评估方法要么需要额外训练(如引入可学习网络),要么依赖注意力图,限制了它们在无训练场景和FlashAttention等优化技术中的应用。
本文的核心创新是什么
-
- 提出了一种名为"Representation Shift"(表示偏移)的无需训练、模型无关的度量指标,通过计算token在通过神经网络层前后的表示差异(如L2距离)来评估token重要性。
-
- Representation Shift不依赖于注意力机制,能够与FlashAttention等融合注意力核无缝集成,无需注意力图或重新训练。
-
- 该方法具有通用性,不仅适用于Transformer模型,还能泛化到CNN和状态空间模型等其他架构。
结果相较于以前的方法有哪些提升
-
- 在视频-文本检索任务中,结合FlashAttention与Representation Shift的token压缩方法在UMT-B和UMT-L模型上分别实现了5.47倍和5.5倍的吞吐量提升,同时保持具有竞争力甚至更好的性能(平均R@1提升7.2%)。
-
- 在视频问答任务中,该方法在UMT-B/L上实现了约4倍/3.83倍的吞吐量提升,同时性能相当或更好(在MSRVTT和MSVD上分别观察到0.5%和0.7%的改进)。
-
- 在图像分类任务中,Representation Shift结合FlashAttention在DeiT-T/S/B上实现了1.2倍的更高吞吐量,准确率分别提升了+2.8%、+5.7%和+2.7%。
局限性总结
-
- 对于CNN架构,token压缩不能以直接的方式执行,需要采用特殊的剪枝变体(如逐行/列剪枝),可能限制其在某些CNN结构上的应用效果。
-
- 在极端剪枝情况下(如保留50%的token),性能仍然会显著下降(前50%token准确率78.0%,后50%仅51.7%),表明该方法在极高压缩率下存在局限性。
-
- 虽然Representation Shift不依赖注意力图,但仍需访问神经网络的中间表示,可能在一定程度上限制其在某些极端优化场景中的应用。
导读
Transformer模型在视觉、语言和视频领域均取得了显著成功。然而,随着任务复杂度的增加,模型规模和token数量的增长导致自注意力机制的计算成本呈平方级增长,同时GPU内存访问的开销也随之增加。为了降低自注意力机制的计算成本,已有研究提出了token压缩技术,通过丢弃冗余或信息量较低的token来减少计算量。同时,如FlashAttention等融合注意力核的开发旨在通过避免注意力图构建及其相关的I/O到HBM的传输来降低内存开销。然而,这种方法与大多数无需训练的token压缩方法不兼容,后者依赖于注意力图来确定token的重要性。在此,作者提出了Representation Shift,一种无需训练、模型无关的度量指标,用于衡量每个token表示的变化程度。该方法能够无缝地将token压缩与FlashAttention相结合,无需注意力图或重新训练。Representation-Shift进一步泛化至CNN和状态空间模型。大量实验表明,Representation Shift能够实现与FlashAttention兼容的有效token压缩,在视频-文本检索和视频问答任务中分别实现了高达5.5倍和4.4倍的显著加速。
代码 https://github.com/mlvlab/Representation-Shift
- 引言
Transformer模型最初是为自然语言处理(NLP)[55]而提出的,现已成为视觉领域的一种重要架构。在开创性工作ViTs[19]之后,大量后续研究将Transformer扩展到各种视觉任务,例如图像分类[15, 19, 37, 52, 53, 59]、目标检测[8, 23, 58, 63, 75, 78]、分割[12, 49, 76]以及视频理解[26-28, 32, 44, 51, 61, 62, 64]。尽管这些工作已被证明是有效的,但自注意力机制的四阶复杂度仍然是一个关键 Bottleneck ,限制了基于Transformer的架构的可扩展性。
为解决这一问题,研究行人提出了多种方法来加速Transformer模型,这些方法涵盖了视觉和自然语言处理(NLP)等多个领域。早期研究通过提出Sparse注意力机制[3, 25, 57, 66]和架构改进[15, 37, 48, 54, 59, 73]来减轻计算负担,这些方法通过低秩近似和Sparse注意力模式等技术来近似自注意力机制。然而,这些方法通常会导致与原始Transformer架构的结构偏差,从而使其与广泛使用的预训练模型不兼容。因此,在实践应用中,传统Transformer[19]仍然是主流选择,这得益于其在各个领域广泛可用的预训练模型。在此,FlashAttention[16]是一种加速预训练传统Transformer的 promising 方法,它在保持原始公式的同时优化了GPU内存访问。虽然FlashAttention最初主要关注大语言模型的长序列,但近期研究[1, 11, 42, 60, 64]也表明其在视觉Transformer上同样能实现显著加速。加速视觉Transformer的另一条研究路线是token压缩[4, 13, 24, 29, 33, 39, 41, 43, 46, 56, 69, 71],通过剪枝或合并token来降低计算成本。由于确定保留哪些token至关重要,先前研究将token重要性度量作为基本步骤。一些方法[41, 46, 69]引入额外的可学习网络来预测token重要性,而其他研究[13, 20, 29, 39, 56]则采用基于注意力的启发式方法作为token重要性的替代方案。尽管这些研究在视觉Transformer上展示了 promising 的加速效果,但采用可学习网络的方案需要额外的训练,使其无法以无训练的方式实现。此外,基于注意力的评分方法在注意力图不可用时(例如FlashAttention、CNN)会限制其应用。虽然FlashAttention单独就能提供显著加速,在DeiT-S上实现1.5倍的加速,在UMT-B上实现2.7倍的加速,但现有的token剪枝方法由于依赖可学习模块或注意力图,在无训练设置中无法进一步提高效率。
为解决这一问题,作者提出了一种基于表示偏移的无训练和模型无关的token重要性准则,该准则量化了token嵌入在层前后发生变化的情况(图2)。这种简单而有效的方法成功地捕捉了任何操作(如FFN、Attention和卷积)所放大信息量。通过利用表示偏移作为重要性度量,Representation-Shift能够有效识别并移除冗余token。由于Representation-Shift不依赖于注意力机制,因此它不仅适用于Transformer,还能泛化到CNN [21, 38, 65]和SSM [30, 36, 77]等架构,并能与FlashAttention等融合核操作无缝集成以实现高效推理。实验结果表明,Representation-Shift在准确性和效率方面均优于现有的基于注意力的token重要性方法,特别是在vanilla Transformer上。具体而言,在多个视频-文本检索基准测试中,作者使用UMT [32]实现了约5.5倍的吞吐量提升。此外,与先前的依赖注意力的方法不同,Representation-Shift还泛化到了先前不受支持的架构,如CNN和状态空间模型。
总之,作者的主要贡献如下:
- • 作者提出了一种名为表示偏移的新方法,用于估计token的重要性,该方法直接捕捉每个操作放大信息量的程度。这种与模型无关的重要性评分可以在无需训练的方式下计算,且开销可以忽略不计。
- • 据作者所知,这是首个同时适用于FlashAttention和CNNs的token减少方法。
- • 通过在视频和图像理解任务上进行的大量实验,作者证明了结合FlashAttention与基于表示迁移的 Token 剪枝能够显著提升推理速度。
- 相关工作
高效视觉Transformer。基于ViTs [19],自注意力 [55] 被引入以处理各种视觉任务。后续工作 [35, 70],如DeiT [52],进一步提高了视觉Transformer的数据效率。然而,尽管性能具有竞争力,自注意力相对于token数量的二次成本仍然是主要 Bottleneck 。为解决此问题,早期工作 [14, 22, 25, 45, 57, 66] 尝试寻找自注意力的有效近似。例如,Reformer [25] 通过哈希函数实现了
复杂度,而Linformer [57] 通过低秩矩阵近似自注意力,从而得到
的线性成本。Nystromformer [66] 和 performer [14] 也提出了自注意力的线性近似。与此同时,一些工作 [3, 9, 48, 73] 专注于Sparse化注意力图以降低复杂度。类似地,最近的视觉Transformer [15, 37, 54, 58, 59] 减少了 Key和Value token的数量。PVT [58, 59] 引入了空间降采样注意力,在注意力计算前对 Key和Value token进行下采样,而Swin [37]、Twins [15] 和 MaxViT [54] 也应用局部注意力以减少参考token的数量。此外,为在边缘设备中部署,一系列工作 [7, 20, 34, 40, 72] 已被提出。最近,为通过内存受限操作减少延迟,FlashAttention [16] 在快速SRAM内进行注意力计算,以最小化对慢速HBM的内存访问。在本工作中,作者旨在通过token压缩进一步提升FlashAttention的性能。
Token压缩。由于成本高度依赖于token数量,近期研究[4, 13, 24, 33, 39, 41, 43, 46, 56, 69, 71]明确聚焦于压缩token。为在压缩token后保留图像的核心信息,这些方法通常通过剪枝或合并不重要token来实现。重要性评估通常遵循两种主要方法。首先是引入额外的可学习网络来预测重要性。例如,AdaViT[41]和DynamicViT[46]引入额外的可学习决策网络来选择待压缩的token,而A-ViT[69]也需要训练额外参数来计算token的重要性。第二种方法是利用中间注意力分数作为衡量重要性的 Agent 函数。具体而言,EViT[33]和BAT[39]使用类别token的注意力分数来近似token的重要性,该分数反映了每个token对最终预测的影响。Zero-TPrune[56]通过受Page Rank[5]启发的注意力图和排序方法来衡量token的信息量。在视频领域,vid-TLDR[13]基于注意力分数的熵来捕获显著区域。尽管上述方法在压缩token时已证明在可接受的速度-精度权衡下有效,但它们需要额外的训练或注意力图。请注意,FlashAttention不提供中间注意力分数以最小化HBM的内存访问。因此,尽管FlashAttention比标准自注意力快得多,但以无训练方式应用先前的token压缩方法并不直接。
- 方法
3.1. 预备知识
在视觉Transformer[19, 52, 53]中,输入图像首先被分割成一组图像块
,称为token,其中
是token的数量,
是图像的分辨率,
是块的大小。这组token随后通过自注意力机制进行处理,定义为:
是一个可学习的投影矩阵。这个过程产生了
的二次成本。为了降低这种成本,最近的研究工作[13, 33, 41, 46, 69]明确地剪枝掉信息较少的token,从而得到一个缩减的token集合
,其中
是剪枝掉的token数量。
Token 的重要性,
,通常通过注意力图进行估计,
,这是自注意力过程的副产品。例如, Token 的重要性可以定义为相对于类别 Token
的注意力分数。
或者,可以看作是所有 Query 向量上的总结性注意力:
其中
。尽管这些基于注意力的分数已被证明是衡量 Token 信息性的有效替代指标,但在注意力图不可用的情况下(例如FlashAttention [16]),它们并不适用。在作者的初步实验(表1)中,FlashAttention在视觉和视频Transformer(例如DeiT [52]和UMT [32])中也比标准注意力带来了显著的加速。尽管结果令人鼓舞,但作者无法通过之前的基于注意力的 Token 压缩方法进一步优化它。在这里,作者旨在开发一种简单而有效的模型无关方法,以无训练方式量化 Token 的重要性。
3.2. Token 重要性表示迁移
在作者的初步实验中,作者观察到通过网络层的token表示偏移反映了它们对模型预测的贡献。在这里,作者首先定义表示偏移,然后提供定性和定量的结果来验证它。给定输入token
,重要性分数
的表示偏移定义为
其中
表示层变换(例如 Attention 和 MLP),而
是距离度量,如 L2 距离,即
。换句话说,表示偏移反映了每个 Token 被函数强调的程度。作者的中心假设是关键 Token 倾向于具有更高的表示偏移,因为网络鼓励它们强调核心信息或抑制冗余信号。相反,表示偏移最小的 Token 可能与目标任务无关。
为验证这一假设,作者使用DeiT-S [52]在图像分类(ImageNet1K [18])任务上,以及UMT-B [32]在视频文本检索(MSRVTT [68])任务上进行了玩具实验,并将结果总结在图3中。为了比较,作者首先分别使用基于注意力的指标和作者的表示偏移来评估token重要性,然后在每个层中移除得分最低的k个token([0,2,4,6,8])。作者为DeiT使用k=40,为UMT使用k=1100。对于基于注意力的评分,作者为DeiT选择[33,39]中使用的公式(2),为UMT选择公式(3),因为视频Transformer中通常不存在类别token。此外,对于表示偏移,作者计算注意力层前后token表示的L2距离,表示为
。如图3所示,基于表示偏移的剪枝相比于基于常用注意力评分的剪枝,实现了具有竞争力或更好的性能。作者证明了表示偏移是token重要性以及传统基于注意力评分的充分近似。值得注意的是,Representation-Shift引入了没有额外的可学习参数,并且即使在中间注意力图不可用的情况下(例如FlashAttention的情况),仍然适用。
作者还对DeiT中表征转换进行了定性分析(图4)。有趣的是,它捕捉到了前景目标,这与显著性检测的概念相吻合。换句话说,作者可以通过基于所提出的分数对与主要内容无关的token进行压缩来抑制它们的噪声。基于定量和定性分析,作者强调了表征转换对token重要性的有效性。在下一小节中,作者将对表征转换进行深入探讨。
3.3. 对表征迁移的探索
操作选择。给定
,视觉Transformer的注意力块通常计算如下
LN表示Layer Normalization。作者研究了表征偏移操作选择的影响,特别是针对三种情况:通过(i)注意力机制实现表征偏移,即
(ii)MLP实现表征偏移,即
,以及(iii)包含公式(5)和(6)的整个注意力模块实现表征偏移,即
。作者进行了消融实验,以评估每个指标作为 Token 重要性的替代方案的效能。在上一节相同设置下,作者根据计算出的L2距离分数,每层剪枝固定数量的 Token ,并评估其对整体模型性能的影响。图5a显示,通过仅使用MLP引导的 Token 剪枝通常在所有层和模型中表现优于其他指标。由于注意力层本质上促进了 Token 间的信息交换,其转换可能更为分散。相比之下,MLP独立地对每个 Token 进行操作,导致更具区分性的表征偏移,能够捕捉 Token 特有的贡献。基于这些发现,作者采用MLP的表征偏移作为 Token 重要性的主要衡量指标。
距离度量。作者进一步探究哪种距离度量
最适合用于估计表示偏移。一种直接的方法是 (i) L2范数,即
,它计算输入和输出表示之间的欧几里得距离,捕捉转换的绝对幅度。作者还研究了 (ii) L1范数的有效性,即
,它对离群值更鲁棒。此外,(iii) 余弦距离(Cos),即
,计算向量之间的角度差异,强调方向变化而非幅度。为了比较距离度量,作者计算MLP层前后表示偏移,即
,并丢弃 Token 。如图5b所示,与其他距离度量相比,L2距离在 Token 重要性方面始终产生更鲁棒的结果。作者的分析表明,余弦相似度在Transformer的深层中对评估 Token 重要性不是最优的。此外,尽管L1距离在第一层表现良好,但在后续层中始终相对于L2距离表现较差。因此,作者将使用L2距离作为表示偏移的默认距离度量。
- 实验
在本节中,作者将展示4.1节视频理解任务的成果,4.2节图像分类的结果,以及4.3节对所提出方法的分析。
4.1. 视频理解
设置。为验证表示迁移的有效性,作者首先基于表示迁移对多个视频任务进行了token剪枝,其中跨帧的token数量带来了巨大的计算成本。作者使用UMT [32],一个基于vanilla attention构建的视频Transformer,作为视频文本检索 [2, 6, 10, 31, 47, 68] 和视频问答 [67] 的 Baseline 。为了与基于attention的分数进行比较,作者还使用方程(3)中的平均attention分数,因为视频Transformer中没有类token。作者分别对UMT的前三层进行token数量减少20%和10%,通过应用基于这两种指标的剪枝,分别用于视频文本检索和视频问答。在表示迁移的情况下,作者使用FlashAttention,因为基于attention的分数与其不兼容。所有实验均以无训练的方式进行。
视频文本检索。在视频文本检索中,模型根据视频检索最相关的文本(视频到文本检索,V2T)或为文本 Query 找到最相关的视频(文本到视频检索,T2V)。作者在七个基准上报告了V2T和T2V结果的调和平均值:MSRVTT [68]、MSVD [10]、ActivityNet [6]、DiDeMo [2]、LSMDC [47]、SSV2-Label/Template [31]。为了比较效率,作者使用单个NVIDIA RTX A6000,批处理大小为20,并测量和提供FLOPs(G)和吞吐量(vid/s),其中视频由12帧组成,分辨率为
。基于没有 Token 剪枝的 Baseline 模型(Base),作者分别应用基于注意力分数的 Token 剪枝(Att)和表示偏移(Ours)。结果如表2所示。由于作者的表示偏移使 Token 剪枝能够与FlashAttention协同工作,它在UMT-B和UMTL中分别带来了
和
的加速。与基于传统注意力分数的标准注意力 Token 剪枝方法相比,Representation-Shift将吞吐量提高了近一倍。此外,尽管推理速度更快,但Representation-Shift显示出具有竞争力的甚至更好的性能,在注意力剪枝与ActivityNet的UMT-L相比时,实现了高达
的提升。平均而言,作者观察到在
𝟙
上
的改进。值得注意的是,应用表示偏移的 Token 剪枝比简单地缩小模型提供了更有利的速度-精度权衡,因为具有表示偏移的UMT-L(66 vid/s)的吞吐量比 Baseline UMT-B(32 vid/s)高约2倍,同时始终超越它。
作者进一步探索了表征迁移在其他token压缩工作中的应用,通过替换高效视频Transformer的token合并方法vid-TLDR[13]中的重要性度量。遵循vid-TLDR的原始配置,包括压缩率和层选择,作者在视频-文本检索任务上报告了结果。如表3所示,作者展示了表征迁移相对于其他token压缩的显著优势。最初,vid-TLDR采用基于注意力的度量来检测图像中的显著区域,因此与FlashAttention不兼容。然而,通过用作者的表征迁移替换重要性度量,作者可以结合vid-TLDR和FlashAttention的效率。具体而言,在相同的压缩率下,作者的表征迁移在UMT-B和UMT-L上实现了平均加速3.74倍和3.67倍,同时性能下降最小。
视频问答。作者还展示了所提出方法在视频问答(video QA)任务中的效率。在视频问答中,模型针对给定视频生成相关问题的答案。为了评估这一点,作者在MSRVTT-QA、MSVDQA基准测试[67]上评估了每种方法,并将结果汇总在表4中。类似于视频文本检索,作者比较了三种情况:未剪枝的 Baseline 模型(Base)、基于注意力机制的 Token 剪枝模型(Att)以及(Ours)。与Base模型相比,作者展示了有前景的提升,在UMT-B/L中实现了约4倍/3.83倍更高的吞吐量。此外,尽管比传统的基于注意力的剪枝更快,Representation-Shift实现了相当或更好的性能。值得注意的是,在UMT-L中,作者在MSRVTT和MSVD上分别观察到0.5%和0.7%的显著改进。
4.2. 图像分类
视觉Transformer。作者在ImageNet1K [18]上进行图像分类实验。对于视觉Transformer,作者使用未经额外训练的DeiT [52],并报告在512批大小的条件下top-1准确率和吞吐量。为了比较,作者使用EViT [33]中用于类 Token 的注意力分数(公式(2)),以及BAT [39]。对于表示迁移,作者使用与视频理解相同的设置(L2,MLP),并结合FlashAttention。在量化DeiT中[1,4,7]层中 Token 的重要性后,作者在每一层剪枝了20%的 Token 。如表5所示,尽管剪枝的 Token 比例相同,但Representation-Shift始终优于基于注意力的分数。具体而言,结合FlashAttention,表示迁移在DeiT-T/S/B上实现了1.2倍的更高吞吐量,准确率分别提升了+2.8%、+5.7%和+2.7%。作者认为表示迁移比传统注意力分数提供了更鲁棒的重要性分数,从而导致了显著的性能差距。
卷积神经网络和状态空间模型。由于表示偏移是一种与模型无关的方法,用于估计 Token 的重要性,因此它自然地扩展到先前在 Token 压缩中未被充分探索的其他架构。为此,作者首先在ImageNet1K上使用ResNet [21] 进行实验。在卷积神经网络中,作者测量每个阶段之前和之后的表示偏移,因为ResNet不包含多层感知机。由于ResNet中的卷积操作仅与2D网格结构工作,因此卷积神经网络中的 Token 剪枝不能以直接的方式执行。因此,作者考虑了两种 Token 剪枝的变体:i) 从每一行和每一列中移除最不重要的 Token ( Token 级,T-W),以及ii) 计算每一行和每一列的表示偏移平均值,然后从具有最低平均值的行和列逐行剪枝 Token ,类似于[50]。具体来说,通过每种方法,作者在第一阶段后移除8列和8行,在第二阶段后移除4列和4行。由于卷积神经网络中的 Token 压缩会改变分辨率,作者微调模型100个周期,包括10个冷却周期以细化这一变化。表6显示,使用表示偏移的两种剪枝方法在ResNet中都带来了显著的吞吐量提升。作者观察到两种剪枝方法至少提高了18%的速度。特别是,逐行剪枝显示出非常有竞争力的性能,与未经剪枝的 Baseline ResNet相比,实现了7112/3553(张/秒)的更高吞吐量,而ResNet-34/50的原有吞吐量为5811/2927(张/秒)。
作者在表7中使用视觉Mamba(ViM)[77]和状态空间模型(SSM)验证了表征迁移。总体而言,作者基本遵循了ToP-ViM [74]的设置,该模型旨在通过基于激活值剪枝token来加速SSM。作者观察到,在Top-ViM相似的吞吐量下,ViM-T的改进达到了+0.4%。这些结果表明,表征迁移是一种适用于各种架构的通用方法。
4.3. 分析
定性结果。为了更深入地理解表征迁移的行为,作者通过可视化提供定性分析。在图6中,给定图像样本(左侧),作者定性比较了文献[33, 39]中使用的基于注意力的公式(2)的注意力分数,以及作者提出的基于DeiT-B[52]的表征迁移方法(包含12层注意力层)。为了研究每种方法在早期、中期和深层中的行为,作者在模型的第1层、第5层和第9层进行评估。首先,在早期阶段
,注意力图通常表现出低可靠性,正如先前工作[13, 33]中讨论的,这并不是一个理想的用于 Token 重要性的方法。另一方面,作者的表征迁移方法即使在第一层也能成功检测前景目标。在中期层
,表征迁移仍然比注意力分数更好地捕获主要内容。最后,众所周知,在视觉Transformer中,随着层级的传递,全局信息会聚集在少数特定 Token 中,这些 Token 具有更高的注意力分数[17]。在这方面,模仿后期层(
)的注意力图,以避免信息损失并保留信息丰富的 Token ,会更好。总结来说,表征迁移减轻了早期层注意力分数的低依赖性,并在中期层找到显著区域,帮助模型捕获细粒度模式。此外,它还能够在后期层捕获具有High-Level语义的 Token 。
此外,作者在图7中展示了ResNet-50 [21]每个阶段的表示变化。结果表明,前景 Token 的嵌入在每个阶段的变化趋势比背景 Token 更为剧烈。换句话说,网络对前景 Token 进行更积极的更新,而背景 Token 由于重要性较低,仅经历微小的更新。因此,表示变化本质上反映了 Token 对任务的信息量,使得在不影响整体性能的情况下进行 Token 剪枝,如表6所示。
可靠性分析。为评估表示偏移作为重要性指标的可靠性,作者在ImageNet1K [18]上使用DeiT-S [52]进行极端剪枝实验,其中作者保留表示偏移得分排名前或后50%的token。如表8所示,在所有Transformer层(L1-L11)中,保留前50%的token始终显著高于保留后50%,这证明了重要性信号的鲁棒性。平均而言,前50%的选择达到了78.0%的准确率,而后50%仅达到51.7%,导致26.3%的显著性能差距。各层间的一致性差距验证了表示偏移能够有效识别信息量大的token,支持其可靠性。
- 结论
本文提出了一种基于表征偏移的无训练、模型无关的token重要性准则,有效量化了每个操作的信息贡献。与常规方法不同,Representation-Shift独立于注意力图,能够与FlashAttention无缝集成,同时实现具有竞争力的准确性和显著的推理速度提升。此外,其适用性不仅限于Transformer,还扩展到CNN,使其成为在保持性能的同时提高各种视觉模型效率的多功能方法。此外,作者通过定性分析证明,Representation-Shift在早期和中期层更有效地检测前景目标,在后期层更有效地检测信息性token,突显了其作为改进的token重要性准则用于高效token压缩的潜力。
参考
[1]. Representation Shift: Unifying Token Compression with FlashAttention
