点击下方卡片,关注
「集智书童」
公众号
导读
高效的关注机制对于大型模型至关重要,因为其时间复杂度为二次方。幸运的是,关注机制通常表现出Sparse性,即注意力图中许多值接近于零,这允许省略相应的计算。许多研究已经利用Sparse模式来加速关注机制。然而,大多数现有工作通过利用注意力图的特定Sparse模式来优化特定模型中的关注机制。保证各种模型速度提升和端到端性能的通用Sparse关注机制仍然难以实现。
在本文中,作者提出了SpargeAttn,这是一种适用于任何模型的通用Sparse和量化关注机制。SpargeAttn使用两阶段在线过滤器:在第一阶段,作者快速且准确地预测注意力图,从而允许跳过注意力中的某些矩阵乘法。在第二阶段,作者设计了一个在线softmax感知过滤器,它不会产生额外的开销,并进一步跳过一些矩阵乘法。
实验表明,SpargeAttn显著加速了包括语言、图像和视频生成在内的各种模型,而不会牺牲端到端指标。
代码:https://github.com/thu-ml/SpargeAttn
- 引言
随着大型模型中序列长度的增加,例如视频生成和语言模型中的45K-128K,注意力机制的耗时占据了大型模型推理延迟的很大一部分。幸运的是,注意力图
具有固有的Sparse性,因为softmax操作通常会产生许多接近零的值。Sparse注意力方法通过以下方式利用这种Sparse性来加速注意力:
-
- 构建一个“Sparse Mask ”,指示注意力图
中应计算的重要非零条目 2. 2. 仅对对应Sparse Mask 的部分计算注意力
根据Sparse Mask 的生成方式,Sparse注意力方法可以分为三类。基于模式的Sparse注意力方法依赖于基于经验观察的特定Sparse模式,动态Sparse注意力根据输入动态计算 Mask ,基于训练的方法直接训练具有原生Sparse注意力的模型。
局限性。(L1. 通用性)尽管现有的Sparse注意力方法在某些任务上已经显示出有希望的速度提升,但它们的通用性仍然有限。现有工作通常针对特定任务开发,例如语言建模,利用特定于任务的模式,如滑动窗口或注意力汇聚。然而,注意力模式在不同任务之间存在显著差异(见图2中的示例),使得这些模式难以推广。(L2. 可用性)此外,对于任何输入实现既准确又高效的Sparse注意力都很难。这是因为准确性要求精确预测注意力图中的Sparse区域,而效率要求这种预测的开销最小。然而,当前方法难以同时有效满足这两个要求。例如,MInference(Jiang等人,2024)需要较长的序列长度,如100K,才能实现明显的速度提升。
目标。作者旨在设计一种无需训练的Sparse注意力算子,以加速所有模型而不损失度量。
SpargeAttn。在本工作中,作者开发了SpargeAt tn,这是一种无需训练的Sparse注意力机制,可以广泛应用于各种任务,包括语言建模、文本到图像/视频,以及各种序列长度。作者提出了三种主要技术来提高其通用性、准确性和效率。首先,作者提出了一种通用的Sparse Mask 预测算法,该算法通过将每个
块压缩成一个单独的token来构建Sparse Mask 。重要的是,作者根据块内token的相似性进行选择性压缩,因此该算法可以准确预测各种任务中的Sparse Mask 。其次,作者提出了一种在GPU warp Level 的Sparse在线softmax算法,通过利用在线softmax中全局最大值和局部最大值之间的差异,进一步省略了一些
乘法。第三,作者将这种Sparse方法集成到8位量化SageAttention框架中,以进一步加速。
结果。作者在多种生成任务上评估了SpargeAttn,包括语言建模和文本到图像/视频,并使用全面的性能指标对模型质量进行了评估。SpargeAttn能够稳健地保持模型端到端性能,而现有的Sparse注意力 Baseline 则会导致性能下降。此外,SpargeAttn的速度比现有的密集和Sparse注意力模型快2.5倍到5倍。
- 相关工作
根据Sparse Mask 的构建方式,Sparse注意力方法可以分为三种类型:
(1)需要模式的方法依赖于注意力图的一些固定模式,例如滑动窗口或注意力汇聚点。H2O、InfLLM和DUOAttention依赖于滑动窗口模式。SampleAttention、MOA和StreamingLLM依赖于滑动窗口和注意力汇聚点模式。DitFastAttn依赖于滑动窗口模式和不同注意力图之间的相似性。此外,DitFastAttn仅限于简单的扩散 Transformer ,与语言模型和MMDiT模型(如Flux、Stable Difusion3和3.5以及CogVideoX)不兼容。由于模式在不同模型之间有所不同,这些方法可能不适用于所有模型。
(2)动态Sparse方法根据输入动态构建Sparse Mask ,无需预设模式,因此可能更加通用。现有工作可以进一步分为通道压缩和 Token 压缩。通道压缩方法包括SparQAttn和LokiAttn。它们通过降低维度来携带完整的注意力构建 Mask 。然而,由于维度已经很小,例如在常用的注意力中为64、128,因此加速潜力可能有限。 Token 压缩方法包括MInference和FlexPrefill。它们通过将每个 Token 块压缩为单个 Token 并在此较短的序列上计算注意力来构建 Mask 。然而,这种近似过于激进:如果压缩序列上没有大的注意力分数,则可能遗漏重要的
块。SeerAttention需要训练额外的注意力参数,这使用起来成本很高。此外,它们都是为语言模型设计的,它们对其他模型类型(如扩散模型)的适用性仍然不确定。
(3)基于训练的方法修改了注意力计算逻辑,需要重新训练整个模型,例如Reformer和FastAttention。这些方法的使用成本远高于无训练方法。
存在其他加速注意力的方法,例如优化 Kernel 实现、量化、分配工作负载以及设计线性时间注意力。它们与SpargeAttn正交。
- SpargeAttn
Spa rgeAt tn 包含一个两阶段在线过滤器以实现Sparse FlashAttention。首先,如图3中的步骤1和步骤2所示,作者设计了一种快速且准确的方法来预测注意力图中的Sparse块,从而跳过
和
的对应乘积。其次,如图3中的步骤3所示,作者设计了一种Sparse在线 softmax 方法以进一步跳过
的乘积。
3.1 SparseFlashAttention
SpargeAttn采用了FlashAttention(Dao,2024)的拼接策略,跳过了计算被过滤掉的块。考虑一个注意力操作
,其中
是softmax操作。设
为序列长度,
为每个头的维度;矩阵
和
每个的维度为
,而矩阵
和
的维度为
。FlashAttention建议将
和
从 Token 维度拼接成块
,分别具有块大小
。然后,它使用在线softmax(Milakov & Gimelshein,2018)逐步计算
的每个块,即
和
是
的向量,分别初始化为
和 0。
是一个类似于 softmax 的算子:
。最后,输出
可以通过
计算得到。
实现Sparse FlashAttention 是直观的。通过跳过
和
的某些块矩阵乘法,作者可以加速注意力计算。作者根据 FlashAttention 公式定义了以下Sparse注意力。
定义1(块 Mask )。令
和
为维度为
的二值 Mask ,其中每个值要么为0要么为1。这些 Mask 决定了在Sparse注意力机制中哪些计算将被跳过。
定义2(SparseFlashAttention)。基于 Mask 的SparseFlashAttention的计算规则定义如下:
当
时
,
和
被
跳
过
。
如
果
,
则
被
跳
过
。
3.2 选择性 Token 压缩以实现Sparse预测
关键思想。尽管注意力图在不同模型中有所差异,但作者观察到各种模型具有一个共同特征: Query 和键矩阵中大多数接近的 Token 显示出高度相似性(见图4)。因此,对于由高度相似 Token 组成的块,作者可以将这些 Token 合并成一个代表 Token 。基于这一观察,作者提出了一种无模式在线预测方法,用于识别
中的Sparse块,以跳过FlashAttention过程中
和
的一些计算。具体来说,作者首先将
和
中表现出高度自相似性的块压缩成 Token 。然后,作者使用压缩的
和
快速计算压缩后的注意力图
。最后,作者仅对那些在压缩注意力图中
累积高分对的
对,选择性地计算
。重要的是,仅压缩具有高度自相似性的 Token 块是至关重要的,因为省略非自相似块的计算可能会导致关键信息的丢失。这将在第4节和附录A.2中得到证实。
预测。如图3的第1步所示,作者首先计算每个
和
块的token的平均余弦相似度。接下来,通过计算token的平均值将每个块压缩成一个token。然后,使用压缩后的
和
计算压缩后的
。最后,为了防止非自相似块(即块相似度小于超参数
)的干扰,作者将
中对应的值设置为
,然后通过softmax获得压缩后的注意力图。此算法可以表示为:
,其中
用于衡量一个块内的余弦相似度。
对于
的每一行,即
,作者选择累积和达到
的前几个最大值的位置,其中
是一个超参数。这些位置在
中被设置为1,而所有其他位置被设置为0。
可以表示如下。
最后,作者需要确保涉及
或
的非自相似块的运算不被遗漏。因此,作者将
中对应于
的非自相似块的行中的所有值设为 1,并将
中对应于
的非自相似块的列中的所有值设为 1。
,
如
果
;
,
如
果
3.3. 第一阶段的 Mask
Mask 。
可以直接应用于FlashAttention中,以节省一些计算。在FlashAttention的内循环中,即在计算
与
之间的注意力时,当
时,作者可以跳过
。
和
,
如
果
3.4 Sparse变形在线softmax
关键思想。作者可以在在线softmax过程中进一步识别注意力图中足够小的值。如果
中的所有值都足够接近于零,则
将可忽略不计,可以省略。
为了确定哪个
(参见第3.1节)包含足够小的值以被省略,作者注意到在FlashAttention的每个内部循环中,
将被
缩放,然后加上
。
若
,则
。因此,
。此外,如果
成立,则
中的所有值都接近于 0。这导致
中的所有值都接近于 0。此条件意味着当
显著小于
时,
可以忽略不计。
若
当
足够小时,上述等价性成立。
因此,基于上述分析,作者提出了一种简单而有效的方法来进一步跳过
的计算。具体来说,在 FlashAttention 的内部循环中,
将被
个 GPU warps 分割为
其中
是 GPU warps 的索引。如果
,其中
足够小,那么
,作者将跳过
的计算,该计算用于更新
。
3.5 结合 SageAttention
为进一步加速Sparse注意力机制的实现,作者将SpargeAttn集成到SageAttention中,该方法提出了一种用于加速注意力的量化方法。由于量化操作和Sparse操作是正交的,Sparse计算可以直接应用于SageAttention。完整的算法展示在算法1中。具体来说,首先,作者需要在SageAttention的内循环开始处添加一个判断(算法1中的第10行),以决定是否跳过整个内循环。其次,作者在SageAttention的内循环中更新
之前添加另一个判断(算法1中的第15行),以决定是否跳过
的计算。此外,为了最小化注意力图预测开销,作者使用CUDA实现预测,并采用了一些 Kernel 融合技术。
3.6 模型层超参数确定
基于第3.2节和3.4节中的方法描述,SpargeAttn包含三个超参数:
、
和
。任何模型中每个注意力层的参数确定过程都很直接。作者的目标是确定一组超参数,这些参数不仅最大化注意力Sparse性,而且限制五个不同模型输入的注意力误差。为了评估注意力精度,作者采用严格的误差度量标准,即相对L1距离,定义为
。过程首先设置两个L1误差阈值
和
,例如,
,
。作者首先对
和
进行网格搜索,以确定最大化Sparse性的最佳配对,同时确保
。随后,作者对
进行另一轮网格搜索,以找到进一步最大化Sparse性的最佳值,同时保持
。
3.7. 希尔伯特曲线排列
关键思想。在提高Sparse注意力的性能过程中,如何在保持准确性的同时提高Sparse性是一个关键挑战。在作者的算法中,通过增加键和 Query 块的自我相似性,可以减少非自我相似块的数量。这使得更多的块可以参与到TopCdf选择中,从而提高Sparse性。由于注意力对 Token 的排列是计算不变的,因此问题简化为寻找一种排列,以增强相邻 Token 的相似性。
图像和视频模型受益于强大的先验知识:相邻像素很可能相似。为了更好地利用这一先验知识,作者提出了希尔伯特曲线排列,给定3D视觉 Token
,作者使用希尔伯特曲线填充3D空间,然后将 Token 沿曲线展平成形状
,其中
。图5展示了通过行主序和希尔伯特曲线展平的
视觉 Token 的示例。希尔伯特曲线有效地保持了局部性,遍历整个3D空间而不跨越行或列,从而增加了相邻 Token 的相似性和注意力Sparse性。
- 实验
4.1. 设置
模型。作者验证了SpargeAttn在多种代表性模型中的有效性,这些模型来自语言、图像和视频生成领域。具体来说,作者在Llama3.1(8B)上进行文本到文本的实验,在CogvideoX(2B)和Mochi上进行文本到视频的实验,在Flux和Stable-Diffusion3.5上进行文本到图像的实验。
数据集。文本到文本模型在四个零样本任务上进行评估:WikiText用于评估模型的预测信心,Longbench和InfiniteBench的En.MC用于全面评估长上下文理解能力,以及Needle-in-A-Haystack任务用于评估模型的检索能力。文本到视频模型使用open-sora Prompt 集进行评估。文本到图像模型在COCO标注上进行评估。
端到端指标。对于Llama3.1,作者使用WikiText上的困惑度(ppl.)、Longbench得分以及Needlein-A-Haystack任务中的检索准确率。对于文本到视频模型,遵循Zhao等(2025)的方法,作者在五个指标上评估生成视频的质量:CLIPSIM和CLIPTemp(CLIP-T)用于衡量文本-视频对齐;VQA-a和VQA-t用于评估视频的美感和技术质量,以及Flow-score(FScore)用于衡量时间一致性。对于文本到图像模型,生成的图像在三个方面与COCO数据集中的图像进行比较:FID用于忠实度评估,Clipscore(CLIP)用于文本-图像对齐,以及ImageReward(IR)用于人类偏好。
速度和Sparse度指标。作者使用TOPS(每秒太操作数)来评估Sparse注意力方法的速度。具体来说,
,其中
表示标准注意力计算中的总操作数,
是从给定的
到注意力输出延迟的时间。请注意,这个速度指标是完全公平的。这是因为对于一组输入,
是固定的,然后速度由
决定,
包括预测注意力图Sparse区域所需的时间。作者定义Sparse度为在完整注意力中,相对于所需的全部
加上
,跳过的
的Matmul加上
的比例。
实现与超参数。作者使用CUDA实现了SpargeAttn。如第3.6节所述,作者需要确定模型中的
。对于Llama3.1,作者使用
;对于CogvideoX和Mochi,使用
;对于Stable-Diffusion3.5和Flux,使用
。
Baseline 。目前,适用于不同模型类型的Sparse注意力方法有限。作者选择了block-sparse MInference和FlexPrefill(FlexPrefill,2025)作为作者的 Baseline 。为了改变这些 Baseline 的Sparse性,作者对MInference使用
和
,根据其论文,对FlexPrefill使用
和0.99。
4.2 质量与效率评估
端到端指标 。作者使用SpargeAttn与使用全注意力机制和 Baseline 方法相比,评估了各种模型的端到端指标。
表1展示了结果。作者可以观察到,与全注意力机制相比,SpargeAttn在各种模型上几乎不产生端到端指标损失,并且在端到端准确率方面超越了具有不同Sparse度的 Baseline 。
图6和图7展示了在Flux、Stable-Diffusion3.5和Mochi上的一些可见比较示例,表明SpargeAttn没有性能损失,并且优于 Baseline 。
注意力速度 。表1显示,与全注意力方法相比,SpargeAttn在注意力速度上更快,并且在注意力速度方面超越了具有不同Sparse度的 Baseline 方法。图9展示了在不同Sparse度下各种方法的核速度,突出了SpargeAttn的高效性及其相较于其他方法的显著优势。
端到端加速。表2展示了使用SpargeAttn在CogvideoX、Mochi和Llama3.1上的端到端延迟。值得注意的是,SpargeAttn在Mochi上实现了
的加速。
4.3 消融研究和关键洞见
Sparse块预测的开销。表3比较了动态Sparse块预测的开销。
SpargeAttn与注意力执行延迟的比较。结果表明,与注意力相比,预测开销最小,尤其是在较长的序列中。
希尔伯特曲线排列效应。作者通过比较三个指标来评估希尔伯特曲线排列对Mochi的影响: Query 或键块的平均块相似度、第3.6节中定义的L1误差以及Sparse性。表4显示,希尔伯特曲线排列在块自相似性和Sparse性方面始终表现出优异的性能,准确率上的差异微乎其微。更多分析和细节请参阅附录A.1。
消除自相似性判断器的影响 作者消除了自相似性判断器对Mochi的影响。如表5所示,作者发现自相似性判断器可以保证端到端的准确率。更多分析请参阅附录A.2。
Sparse性分析:从
和
出发。表6展示了在Llama 3.1上,仅使用
、仅使用
以及使用
在NeedleInAHaystack任务中,128K序列长度下的Sparse性情况。
SpargeAttn提升了LLM的性能。从表1、图8和图10中,作者可以观察到SpargeAttn在长上下文任务中提升了LLM的性能。这种提升可能源于Sparse注意力帮助LLM关注更多相关信息。
Sparse性随序列长度增加。如表7所示,作者在Llama3.1上发现Sparse性随序列长度增加。这表明,较长上下文可以使得SpargeAtt n实现更高的加速。
- 结论
本文提出了一种名为SpargeAttn的通用Sparse和量化注意力机制,该机制能够高效且准确地执行任何输入的注意力操作。SpargeAttn采用两阶段在线滤波器:在第一阶段,作者快速且准确地预测注意力图,从而在注意力操作中跳过一些矩阵乘法。在第二阶段,作者设计了一种在线softmax感知滤波器,该滤波器不会产生额外开销,并进一步跳过一些矩阵乘法。
实验表明,SpargeAttn能够加速各种模型,包括语言、图像和视频生成模型,同时不会牺牲端到端指标。
参考
[1]. SpargeAttn: Accurate Sparse Attention Accelerating Any Model Inference
扫码加入👉
「集智书童」
交流群
(备注:
方向
学校/公司+
昵称
)