3步起飞 | 如何让大模型推理飞起来?SpargeAttn稀疏注意力完美实现加速并做到端到端无损

大模型向量数据库机器学习

点击下方卡片,关注

「集智书童」

公众号

点击加入👉

「集智书童」

交流群

picture.image

picture.image

picture.image

picture.image

picture.image

picture.image

导读

高效的关注机制对于大型模型至关重要,因为其时间复杂度为二次方。幸运的是,关注机制通常表现出Sparse性,即注意力图中许多值接近于零,这允许省略相应的计算。许多研究已经利用Sparse模式来加速关注机制。然而,大多数现有工作通过利用注意力图的特定Sparse模式来优化特定模型中的关注机制。保证各种模型速度提升和端到端性能的通用Sparse关注机制仍然难以实现。

在本文中,作者提出了SpargeAttn,这是一种适用于任何模型的通用Sparse和量化关注机制。SpargeAttn使用两阶段在线过滤器:在第一阶段,作者快速且准确地预测注意力图,从而允许跳过注意力中的某些矩阵乘法。在第二阶段,作者设计了一个在线softmax感知过滤器,它不会产生额外的开销,并进一步跳过一些矩阵乘法。

实验表明,SpargeAttn显著加速了包括语言、图像和视频生成在内的各种模型,而不会牺牲端到端指标。

代码:https://github.com/thu-ml/SpargeAttn

  1. 引言

随着大型模型中序列长度的增加,例如视频生成和语言模型中的45K-128K,注意力机制的耗时占据了大型模型推理延迟的很大一部分。幸运的是,注意力图

具有固有的Sparse性,因为softmax操作通常会产生许多接近零的值。Sparse注意力方法通过以下方式利用这种Sparse性来加速注意力:

    1. 构建一个“Sparse Mask ”,指示注意力图

中应计算的重要非零条目 2. 2. 仅对对应Sparse Mask 的部分计算注意力

根据Sparse Mask 的生成方式,Sparse注意力方法可以分为三类。基于模式的Sparse注意力方法依赖于基于经验观察的特定Sparse模式,动态Sparse注意力根据输入动态计算 Mask ,基于训练的方法直接训练具有原生Sparse注意力的模型。

局限性。(L1. 通用性)尽管现有的Sparse注意力方法在某些任务上已经显示出有希望的速度提升,但它们的通用性仍然有限。现有工作通常针对特定任务开发,例如语言建模,利用特定于任务的模式,如滑动窗口或注意力汇聚。然而,注意力模式在不同任务之间存在显著差异(见图2中的示例),使得这些模式难以推广。(L2. 可用性)此外,对于任何输入实现既准确又高效的Sparse注意力都很难。这是因为准确性要求精确预测注意力图中的Sparse区域,而效率要求这种预测的开销最小。然而,当前方法难以同时有效满足这两个要求。例如,MInference(Jiang等人,2024)需要较长的序列长度,如100K,才能实现明显的速度提升。

picture.image

目标。作者旨在设计一种无需训练的Sparse注意力算子,以加速所有模型而不损失度量。

SpargeAttn。在本工作中,作者开发了SpargeAt tn,这是一种无需训练的Sparse注意力机制,可以广泛应用于各种任务,包括语言建模、文本到图像/视频,以及各种序列长度。作者提出了三种主要技术来提高其通用性、准确性和效率。首先,作者提出了一种通用的Sparse Mask 预测算法,该算法通过将每个

块压缩成一个单独的token来构建Sparse Mask 。重要的是,作者根据块内token的相似性进行选择性压缩,因此该算法可以准确预测各种任务中的Sparse Mask 。其次,作者提出了一种在GPU warp Level 的Sparse在线softmax算法,通过利用在线softmax中全局最大值和局部最大值之间的差异,进一步省略了一些

乘法。第三,作者将这种Sparse方法集成到8位量化SageAttention框架中,以进一步加速。

结果。作者在多种生成任务上评估了SpargeAttn,包括语言建模和文本到图像/视频,并使用全面的性能指标对模型质量进行了评估。SpargeAttn能够稳健地保持模型端到端性能,而现有的Sparse注意力 Baseline 则会导致性能下降。此外,SpargeAttn的速度比现有的密集和Sparse注意力模型快2.5倍到5倍。

  1. 相关工作

根据Sparse Mask 的构建方式,Sparse注意力方法可以分为三种类型:

(1)需要模式的方法依赖于注意力图的一些固定模式,例如滑动窗口或注意力汇聚点。H2O、InfLLM和DUOAttention依赖于滑动窗口模式。SampleAttention、MOA和StreamingLLM依赖于滑动窗口和注意力汇聚点模式。DitFastAttn依赖于滑动窗口模式和不同注意力图之间的相似性。此外,DitFastAttn仅限于简单的扩散 Transformer ,与语言模型和MMDiT模型(如Flux、Stable Difusion3和3.5以及CogVideoX)不兼容。由于模式在不同模型之间有所不同,这些方法可能不适用于所有模型。

(2)动态Sparse方法根据输入动态构建Sparse Mask ,无需预设模式,因此可能更加通用。现有工作可以进一步分为通道压缩和 Token 压缩。通道压缩方法包括SparQAttn和LokiAttn。它们通过降低维度来携带完整的注意力构建 Mask 。然而,由于维度已经很小,例如在常用的注意力中为64、128,因此加速潜力可能有限。 Token 压缩方法包括MInference和FlexPrefill。它们通过将每个 Token 块压缩为单个 Token 并在此较短的序列上计算注意力来构建 Mask 。然而,这种近似过于激进:如果压缩序列上没有大的注意力分数,则可能遗漏重要的

块。SeerAttention需要训练额外的注意力参数,这使用起来成本很高。此外,它们都是为语言模型设计的,它们对其他模型类型(如扩散模型)的适用性仍然不确定。

(3)基于训练的方法修改了注意力计算逻辑,需要重新训练整个模型,例如Reformer和FastAttention。这些方法的使用成本远高于无训练方法。

存在其他加速注意力的方法,例如优化 Kernel 实现、量化、分配工作负载以及设计线性时间注意力。它们与SpargeAttn正交。

  1. SpargeAttn

Spa rgeAt tn 包含一个两阶段在线过滤器以实现Sparse FlashAttention。首先,如图3中的步骤1和步骤2所示,作者设计了一种快速且准确的方法来预测注意力图中的Sparse块,从而跳过

的对应乘积。其次,如图3中的步骤3所示,作者设计了一种Sparse在线 softmax 方法以进一步跳过

的乘积。

picture.image

3.1 SparseFlashAttention

SpargeAttn采用了FlashAttention(Dao,2024)的拼接策略,跳过了计算被过滤掉的块。考虑一个注意力操作

,其中

是softmax操作。设

为序列长度,

为每个头的维度;矩阵

每个的维度为

,而矩阵

的维度为

。FlashAttention建议将

从 Token 维度拼接成块

,分别具有块大小

。然后,它使用在线softmax(Milakov & Gimelshein,2018)逐步计算

的每个块,即

的向量,分别初始化为

和 0。

是一个类似于 softmax 的算子:

。最后,输出

可以通过

计算得到。

实现Sparse FlashAttention 是直观的。通过跳过

的某些块矩阵乘法,作者可以加速注意力计算。作者根据 FlashAttention 公式定义了以下Sparse注意力。

定义1(块 Mask )。令

为维度为

的二值 Mask ,其中每个值要么为0要么为1。这些 Mask 决定了在Sparse注意力机制中哪些计算将被跳过。

定义2(SparseFlashAttention)。基于 Mask 的SparseFlashAttention的计算规则定义如下:

3.2 选择性 Token 压缩以实现Sparse预测

关键思想。尽管注意力图在不同模型中有所差异,但作者观察到各种模型具有一个共同特征: Query 和键矩阵中大多数接近的 Token 显示出高度相似性(见图4)。因此,对于由高度相似 Token 组成的块,作者可以将这些 Token 合并成一个代表 Token 。基于这一观察,作者提出了一种无模式在线预测方法,用于识别

中的Sparse块,以跳过FlashAttention过程中

的一些计算。具体来说,作者首先将

中表现出高度自相似性的块压缩成 Token 。然后,作者使用压缩的

快速计算压缩后的注意力图

。最后,作者仅对那些在压缩注意力图中

累积高分对的

对,选择性地计算

。重要的是,仅压缩具有高度自相似性的 Token 块是至关重要的,因为省略非自相似块的计算可能会导致关键信息的丢失。这将在第4节和附录A.2中得到证实。

picture.image

预测。如图3的第1步所示,作者首先计算每个

块的token的平均余弦相似度。接下来,通过计算token的平均值将每个块压缩成一个token。然后,使用压缩后的

计算压缩后的

。最后,为了防止非自相似块(即块相似度小于超参数

)的干扰,作者将

中对应的值设置为

,然后通过softmax获得压缩后的注意力图。此算法可以表示为:

,其中

用于衡量一个块内的余弦相似度。

对于

的每一行,即

,作者选择累积和达到

的前几个最大值的位置,其中

是一个超参数。这些位置在

中被设置为1,而所有其他位置被设置为0。

可以表示如下。

最后,作者需要确保涉及

的非自相似块的运算不被遗漏。因此,作者将

中对应于

的非自相似块的行中的所有值设为 1,并将

中对应于

的非自相似块的列中的所有值设为 1。

3.3. 第一阶段的 Mask

Mask 。

可以直接应用于FlashAttention中,以节省一些计算。在FlashAttention的内循环中,即在计算

之间的注意力时,当

时,作者可以跳过

3.4 Sparse变形在线softmax

关键思想。作者可以在在线softmax过程中进一步识别注意力图中足够小的值。如果

中的所有值都足够接近于零,则

将可忽略不计,可以省略。

为了确定哪个

(参见第3.1节)包含足够小的值以被省略,作者注意到在FlashAttention的每个内部循环中,

将被

缩放,然后加上

,则

。因此,

。此外,如果

成立,则

中的所有值都接近于 0。这导致

中的所有值都接近于 0。此条件意味着当

显著小于

时,

可以忽略不计。

足够小时,上述等价性成立。

因此,基于上述分析,作者提出了一种简单而有效的方法来进一步跳过

的计算。具体来说,在 FlashAttention 的内部循环中,

将被

个 GPU warps 分割为

其中

是 GPU warps 的索引。如果

,其中

足够小,那么

,作者将跳过

的计算,该计算用于更新

3.5 结合 SageAttention

为进一步加速Sparse注意力机制的实现,作者将SpargeAttn集成到SageAttention中,该方法提出了一种用于加速注意力的量化方法。由于量化操作和Sparse操作是正交的,Sparse计算可以直接应用于SageAttention。完整的算法展示在算法1中。具体来说,首先,作者需要在SageAttention的内循环开始处添加一个判断(算法1中的第10行),以决定是否跳过整个内循环。其次,作者在SageAttention的内循环中更新

之前添加另一个判断(算法1中的第15行),以决定是否跳过

的计算。此外,为了最小化注意力图预测开销,作者使用CUDA实现预测,并采用了一些 Kernel 融合技术。

picture.image

3.6 模型层超参数确定

基于第3.2节和3.4节中的方法描述,SpargeAttn包含三个超参数:

。任何模型中每个注意力层的参数确定过程都很直接。作者的目标是确定一组超参数,这些参数不仅最大化注意力Sparse性,而且限制五个不同模型输入的注意力误差。为了评估注意力精度,作者采用严格的误差度量标准,即相对L1距离,定义为

。过程首先设置两个L1误差阈值

,例如,

。作者首先对

进行网格搜索,以确定最大化Sparse性的最佳配对,同时确保

。随后,作者对

进行另一轮网格搜索,以找到进一步最大化Sparse性的最佳值,同时保持

3.7. 希尔伯特曲线排列

关键思想。在提高Sparse注意力的性能过程中,如何在保持准确性的同时提高Sparse性是一个关键挑战。在作者的算法中,通过增加键和 Query 块的自我相似性,可以减少非自我相似块的数量。这使得更多的块可以参与到TopCdf选择中,从而提高Sparse性。由于注意力对 Token 的排列是计算不变的,因此问题简化为寻找一种排列,以增强相邻 Token 的相似性。

图像和视频模型受益于强大的先验知识:相邻像素很可能相似。为了更好地利用这一先验知识,作者提出了希尔伯特曲线排列,给定3D视觉 Token

,作者使用希尔伯特曲线填充3D空间,然后将 Token 沿曲线展平成形状

,其中

。图5展示了通过行主序和希尔伯特曲线展平的

视觉 Token 的示例。希尔伯特曲线有效地保持了局部性,遍历整个3D空间而不跨越行或列,从而增加了相邻 Token 的相似性和注意力Sparse性。

picture.image

  1. 实验

4.1. 设置

模型。作者验证了SpargeAttn在多种代表性模型中的有效性,这些模型来自语言、图像和视频生成领域。具体来说,作者在Llama3.1(8B)上进行文本到文本的实验,在CogvideoX(2B)和Mochi上进行文本到视频的实验,在Flux和Stable-Diffusion3.5上进行文本到图像的实验。

数据集。文本到文本模型在四个零样本任务上进行评估:WikiText用于评估模型的预测信心,Longbench和InfiniteBench的En.MC用于全面评估长上下文理解能力,以及Needle-in-A-Haystack任务用于评估模型的检索能力。文本到视频模型使用open-sora Prompt 集进行评估。文本到图像模型在COCO标注上进行评估。

端到端指标。对于Llama3.1,作者使用WikiText上的困惑度(ppl.)、Longbench得分以及Needlein-A-Haystack任务中的检索准确率。对于文本到视频模型,遵循Zhao等(2025)的方法,作者在五个指标上评估生成视频的质量:CLIPSIM和CLIPTemp(CLIP-T)用于衡量文本-视频对齐;VQA-a和VQA-t用于评估视频的美感和技术质量,以及Flow-score(FScore)用于衡量时间一致性。对于文本到图像模型,生成的图像在三个方面与COCO数据集中的图像进行比较:FID用于忠实度评估,Clipscore(CLIP)用于文本-图像对齐,以及ImageReward(IR)用于人类偏好。

速度和Sparse度指标。作者使用TOPS(每秒太操作数)来评估Sparse注意力方法的速度。具体来说,

,其中

表示标准注意力计算中的总操作数,

是从给定的

到注意力输出延迟的时间。请注意,这个速度指标是完全公平的。这是因为对于一组输入,

是固定的,然后速度由

决定,

包括预测注意力图Sparse区域所需的时间。作者定义Sparse度为在完整注意力中,相对于所需的全部

加上

,跳过的

的Matmul加上

的比例。

实现与超参数。作者使用CUDA实现了SpargeAttn。如第3.6节所述,作者需要确定模型中的

。对于Llama3.1,作者使用

;对于CogvideoX和Mochi,使用

;对于Stable-Diffusion3.5和Flux,使用

Baseline 。目前,适用于不同模型类型的Sparse注意力方法有限。作者选择了block-sparse MInference和FlexPrefill(FlexPrefill,2025)作为作者的 Baseline 。为了改变这些 Baseline 的Sparse性,作者对MInference使用

,根据其论文,对FlexPrefill使用

和0.99。

4.2 质量与效率评估

端到端指标 。作者使用SpargeAttn与使用全注意力机制和 Baseline 方法相比,评估了各种模型的端到端指标。

picture.image

表1展示了结果。作者可以观察到,与全注意力机制相比,SpargeAttn在各种模型上几乎不产生端到端指标损失,并且在端到端准确率方面超越了具有不同Sparse度的 Baseline 。

picture.image

图6和图7展示了在Flux、Stable-Diffusion3.5和Mochi上的一些可见比较示例,表明SpargeAttn没有性能损失,并且优于 Baseline 。

picture.image

注意力速度 。表1显示,与全注意力方法相比,SpargeAttn在注意力速度上更快,并且在注意力速度方面超越了具有不同Sparse度的 Baseline 方法。图9展示了在不同Sparse度下各种方法的核速度,突出了SpargeAttn的高效性及其相较于其他方法的显著优势。

picture.image

端到端加速。表2展示了使用SpargeAttn在CogvideoX、Mochi和Llama3.1上的端到端延迟。值得注意的是,SpargeAttn在Mochi上实现了

的加速。

picture.image

4.3 消融研究和关键洞见

Sparse块预测的开销。表3比较了动态Sparse块预测的开销。

picture.image

SpargeAttn与注意力执行延迟的比较。结果表明,与注意力相比,预测开销最小,尤其是在较长的序列中。

希尔伯特曲线排列效应。作者通过比较三个指标来评估希尔伯特曲线排列对Mochi的影响: Query 或键块的平均块相似度、第3.6节中定义的L1误差以及Sparse性。表4显示,希尔伯特曲线排列在块自相似性和Sparse性方面始终表现出优异的性能,准确率上的差异微乎其微。更多分析和细节请参阅附录A.1。

picture.image

消除自相似性判断器的影响 作者消除了自相似性判断器对Mochi的影响。如表5所示,作者发现自相似性判断器可以保证端到端的准确率。更多分析请参阅附录A.2。

picture.image

Sparse性分析:从

出发。表6展示了在Llama 3.1上,仅使用

、仅使用

以及使用

在NeedleInAHaystack任务中,128K序列长度下的Sparse性情况。

picture.image

SpargeAttn提升了LLM的性能。从表1、图8和图10中,作者可以观察到SpargeAttn在长上下文任务中提升了LLM的性能。这种提升可能源于Sparse注意力帮助LLM关注更多相关信息。

picture.image

Sparse性随序列长度增加。如表7所示,作者在Llama3.1上发现Sparse性随序列长度增加。这表明,较长上下文可以使得SpargeAtt n实现更高的加速。

picture.image

  1. 结论

本文提出了一种名为SpargeAttn的通用Sparse和量化注意力机制,该机制能够高效且准确地执行任何输入的注意力操作。SpargeAttn采用两阶段在线滤波器:在第一阶段,作者快速且准确地预测注意力图,从而在注意力操作中跳过一些矩阵乘法。在第二阶段,作者设计了一种在线softmax感知滤波器,该滤波器不会产生额外开销,并进一步跳过一些矩阵乘法。

实验表明,SpargeAttn能够加速各种模型,包括语言、图像和视频生成模型,同时不会牺牲端到端指标。

参考

[1]. SpargeAttn: Accurate Sparse Attention Accelerating Any Model Inference

picture.image

扫码加入👉

「集智书童」

交流群

(备注:

方向

学校/公司+

昵称

picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
在火山引擎云搜索服务上构建混合搜索的设计与实现
本次演讲将重点介绍字节跳动在混合搜索领域的探索,并探讨如何在多模态数据场景下进行海量数据搜索。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论