摘要
长文本建模对于下一代语言模型至关重要,然而标准注意力机制的高计算成本带来了显著的计算挑战。稀疏注意力为提高效率的同时保持模型能力提供了一个有前景的方向。我们提出了 NSA(Native Sparse Attention),这是一种可原生训练的稀疏注意力机制,通过将算法创新与硬件对齐优化相结合,实现了高效的长文本建模。NSA 采用动态层次化的稀疏策略,将粗粒度的标记压缩与细粒度的标记选择相结合,既保留了全局上下文感知能力,又保持了局部精度。我们的方法在稀疏注意力设计上提出了两项关键创新:(1)通过算术强度平衡的算法设计实现显著加速,并针对现代硬件进行了实现优化。(2)支持端到端训练,减少预训练计算量,同时不牺牲模型性能。如图 1 所示,实验表明使用 NSA 预训练的模型在通用基准测试、长文本任务和基于指令的推理中均保持或超过了全注意力模型的性能。同时,NSA 在处理 64k 长度的序列时,在解码、前向传播和反向传播方面均实现了显著的加速,验证了其在整个模型生命周期中的效率。
- 引言
研究界越来越认识到长文本建模对于下一代大型语言模型是一项关键能力,这由多种现实世界应用推动,包括深入推理(DeepSeek-AI, 2025; Zelikman et al., 2022)、仓库级代码生成(Zhang et al., 2023a; Zhang et al.)和多轮自主代理系统(Park et al., 2023)。最近的突破,包括 OpenAI 的 o 系列模型、DeepSeek-R1(DeepSeek-AI, 2025)和 Gemini 1.5 Pro(Google et al., 2024),使模型能够处理整个代码库、长篇文档、在数千个标记上保持连贯的多轮对话以及执行长距离依赖的复杂推理。然而,随着序列长度的增加,标准注意力(Vaswani et al., 2017)机制的高复杂性(Zaheer et al., 2020)成为一个关键的延迟瓶颈。理论估计表明,在解码 64k 长度上下文时,注意力计算占总延迟的 70-80%,凸显了对更高效注意力机制的迫切需求。
利用 softmax 注意力的固有稀疏性(Ge et al., 2023; Jiang et al., 2023)是一种自然的高效长文本建模方法,通过选择性计算关键的查询-键对,可以在保留性能的同时显著减少计算开销。最近的进展通过多种策略展示了这种潜力:KV 缓存驱逐方法(Li et al., 2024; Zhang et al., 2023b; Zhou et al., 2024)、块状 KV 缓存选择方法(Tang et al., 2024; Xiao et al., 2024)以及基于采样、聚类或哈希的选择方法(Chen et al., 2024; Desai et al., 2024; Liu et al., 2024)。尽管这些方法很有前景,但现有的稀疏注意力方法在实际部署中往往不尽如人意。许多方法未能实现与其理论增益相当的加速;此外,大多数方法主要关注推理阶段,缺乏有效的训练时支持以充分利用注意力的稀疏模式。
为了克服这些局限性,部署有效的稀疏注意力必须解决两个关键挑战:
(1)硬件对齐的推理加速 :将理论计算减少转化为实际速度提升,需要在预填充和解码阶段进行硬件友好的算法设计,以缓解内存访问和硬件调度瓶颈; (2)训练感知的算法设计 :通过可训练的操作符实现端到端计算,以减少训练成本,同时保持模型性能。这些要求对于实际应用实现快速长文本推理或训练至关重要。在考虑这两个方面时,现有方法仍存在明显差距。
为了实现更有效和高效的稀疏注意力,我们提出了 NSA,这是一种具有层次化标记建模的可原生训练的稀疏注意力架构。如图 2 所示,NSA 通过将键和值组织成时间块,并通过三个注意力路径进行处理,从而减少了每个查询的计算量:压缩的粗粒度标记、选择性保留的细粒度标记以及用于局部上下文信息的滑动窗口。然后我们实现了专用内核以最大化其实际效率。NSA 引入了两个核心创新,分别对应上述关键要求:(1)硬件对齐系统:针对 Tensor Core 利用率和内存访问优化块状稀疏注意力,确保算术强度平衡。(2)训练感知设计:通过高效算法和反向操作符实现稳定的端到端训练。这种优化使 NSA 能够支持高效的部署和端到端训练。
我们通过在真实世界语言语料库上的全面实验来评估 NSA。在 27B 参数的 Transformer 骨干网络上进行预训练,使用 260B 个标记,我们在通用语言评估、长文本评估和链式推理评估中评估了 NSA 的性能。我们进一步在 A100 GPU 上与优化的 Triton(Tillet et al., 2019)实现进行了内核速度比较。实验结果表明,NSA 实现了与全注意力基线相当或更优的性能,同时优于现有的稀疏注意力方法。此外,与全注意力相比,NSA 在解码、前向和反向阶段均实现了显著的加速,随着序列长度的增加,加速比也随之增加。这些结果验证了我们的层次化稀疏注意力设计在平衡模型能力和计算效率方面的有效性。
2. 重新思考稀疏注意力方法
现代稀疏注意力方法在减少 Transformer 模型的理论计算复杂性方面取得了显著进展。然而,大多数方法主要在推理阶段应用稀疏性,同时保留预训练的全注意力骨干网络,可能会引入架构偏差,限制其充分利用稀疏注意力优势的能力。在介绍我们的原生稀疏架构之前,我们通过两个关键视角系统地分析了这些局限性。
2.1. 高效推理的幻象
尽管在注意力计算中实现了稀疏性,但许多方法未能实现相应的推理延迟减少,主要由于两个挑战:
- 阶段限制的稀疏性
:例如 H2O(Zhang et al., 2023b)的方法在自回归解码时应用稀疏性,但在预填充阶段需要进行计算密集型的预处理(例如注意力图计算、索引构建)。相比之下,MInference(Jiang et al., 2024)的方法仅关注预填充阶段的稀疏性。这些方法未能在所有推理阶段实现加速,因为至少有一个阶段的计算成本与全注意力相当。这种阶段专业化降低了这些方法在预填充主导的工作负载(如书籍摘要和代码补全)或解码主导的工作负载(如长链式推理(Wei et al., 2022))中的加速能力。
- 与先进注意力架构的不兼容性
:一些稀疏注意力方法未能适应现代解码高效架构,如多查询注意力(MQA)(Shazeer, 2019)和分组查询注意力(GQA)(Ainslie et al., 2023),这些架构通过在多个查询头之间共享 KV 来显著减少解码过程中的内存访问瓶颈。例如,在 Quest(Tang et al., 2024)等方法中,每个注意力头独立选择其 KV 缓存子集。尽管它在多头注意力(MHA)模型中展示了一致的计算稀疏性和内存访问稀疏性,但在基于 GQA 等架构的模型中,KV 缓存的内存访问量对应于同一 GQA 组内所有查询头的选择的并集。这种架构特性意味着,尽管这些方法可以减少计算操作,但所需的 KV 缓存内存访问量仍然相对较高。这一限制迫使人们在稀疏注意力方法中做出关键选择:虽然一些稀疏注意力方法减少了计算量,但其分散的内存访问模式与先进架构的高效内存访问设计相冲突。
这些局限性源于许多现有的稀疏注意力方法专注于 KV 缓存减少或理论计算减少,但在先进的框架或后端中难以实现显著的延迟减少。这促使我们开发结合先进架构和硬件高效实现的算法,以充分利用稀疏性来提高模型效率。
2.2. 可训练稀疏性的神话
我们对原生可训练稀疏注意力的追求是基于对仅推理方法的两个关键见解:性能退化
:在后处理中应用稀疏性迫使模型偏离其预训练的优化轨迹。正如 Chen et al.(2024)所展示的那样,前 20% 的注意力只能覆盖 70% 的总注意力分数,使得预训练模型中的检索头等结构在推理过程中容易被剪枝。训练效率需求
:高效处理长序列训练对于现代大语言模型开发至关重要。这包括在更长的文档上进行预训练以增强模型容量,以及后续的适应阶段,如长文本微调和强化学习。然而,现有的稀疏注意力方法主要针对推理,对训练中的计算挑战关注较少。这一局限性阻碍了通过高效训练开发更具能力的长文本模型。此外,将现有的稀疏注意力适应训练的尝试也暴露了挑战:
- 不可训练组件
:像 ClusterKV(Liu et al., 2024)(包括 k-means 聚类)和 MagicPIG(Chen et al., 2024)(包括基于 SimHash 的选择)等方法中的离散操作在计算图中造成了不连续性。这些不可训练的组件阻止了梯度通过标记选择过程流动,限制了模型学习最优稀疏模式的能力。
- 低效的反向传播
:一些理论上可训练的稀疏注意力方法在实际训练中存在效率问题。例如,HashAttention(Desai et al., 2024)中使用的基于标记粒度的选择策略导致在注意力计算过程中需要从 KV 缓存中加载大量单独的标记。这种非连续的内存访问阻止了像 FlashAttention 这样的快速注意力技术的高效适应,这些技术依赖于连续的内存访问和块状计算以实现高吞吐量。因此,实现被迫退回到低硬件利用率,显著降低了训练效率。
2.3. 原生稀疏性作为必然选择
推理效率和训练可行性的这些局限性促使我们对稀疏注意力机制进行根本性的重新设计。我们提出了 NSA,这是一个原生稀疏注意力框架,同时解决了计算效率和训练需求。在接下来的部分中,我们将详细介绍 NSA 的算法设计和操作符实现。
3. 方法论
我们的技术方法涵盖了算法设计和内核优化。在接下来的小节中,我们首先介绍方法论的背景。然后,我们介绍 NSA 的整体框架,接着是其关键算法组件。最后,我们详细介绍了我们针对硬件优化的内核设计,以最大化实际效率。
3.1. 背景
注意力机制在语言建模中被广泛使用,其中每个查询标记记 的输入序列,注意力操作定义为:
, (1)
其中 Attn 表示注意力函数:
(2)
这里,
表示
和
之间的注意力权重,
是键的特征维度。随着序列长度的增加,注意力计算在整体计算成本中的占比越来越高,对长文本处理构成了重大挑战。
算术强度
是计算操作与内存访问的比率。它本质上塑造了硬件上的算法优化。每个 GPU 都有一个由其峰值计算能力和内存带宽决定的关键算术强度,计算为这两个硬件限制的比率。对于计算任务,高于此关键阈值的算术强度变为计算受限(受 GPU FLOPS 限制),而低于此阈值则变为内存受限(受内存带宽限制)。
具体来说,对于因果自注意力机制,在训练和预填充阶段,批量矩阵乘法和注意力计算表现出高算术强度,使这些阶段在现代加速器上计算受限。相比之下,自回归解码变得受内存带宽限制,因为它每次前向传递生成一个标记,同时需要加载整个键值缓存,导致低算术强度。这导致了不同的优化目标——在训练和预填充阶段减少计算成本,而在解码阶段减少内存访问。
3.2. 整体框架
为了利用注意力的自然稀疏模式的潜力,我们提出用更紧凑且信息密集的表示键值对
,
替换方程(1)中原始的键值对
,
,针对每个查询
。具体来说,我们正式定义优化后的注意力输出如下:
(3)
(4)
其中,并将它们组合如下:
(5)
如图 2 所示,NSA 有三种映射策略 C = {cmp, slc, win},分别代表压缩、选择和滑动窗口的键和值。
是对应策略
的门控分数,由输入特征通过 MLP 和 Sigmoid 激活函数得出。设
表示重映射键/值的总数:
(6)
我们通过确保
来保持高稀疏比率。
3.3. 算法设计
在本小节中,我们介绍我们的重映射策略
和
的设计:标记压缩、标记选择和滑动窗口。
3.3.1. 标记压缩
通过将连续块的键或值聚合为块级表示,我们获得了捕获整个块信息的压缩键和值。形式上,压缩键表示定义为:
(7)
其中
是块长度,
是相邻块之间的滑动步长,
是一个可学习的 MLP,带有块内位置编码,将块中的键映射为单个压缩键。
是由压缩键组成的张量。通常,我们采用
以减少信息碎片化。压缩值表示
有类似的公式。压缩表示捕获了更粗粒度的高级语义信息,减少了注意力的计算负担。
3.3.2. 标记选择
仅使用压缩键和值可能会丢失重要的细粒度信息,这促使我们选择性地保留个别键和值。下面,我们描述我们高效的标记选择机制,以低计算开销识别和保留最相关的标记。
- 块状选择
:我们的选择策略以空间连续块为单位处理键和值序列,这一策略受到两个关键因素的驱动:硬件效率考虑和注意力分数的固有分布模式。块状选择对于在现代 GPU 上实现高效计算至关重要。这是因为现代 GPU 架构对于连续块访问表现出显著更高的吞吐量,相比之下,基于随机索引的读取则效率低下。此外,块状计算能够实现 Tensor Cores 的最佳利用。这种架构特性已经确立了块状内存访问和计算作为高性能注意力实现的基本原则,正如 FlashAttention 的基于块的设计所展示的那样。块状选择遵循注意力分数的固有分布模式。先前的研究(Jiang et al., 2024)已经表明,注意力分数往往表现出空间连续性,表明相邻的键倾向于共享相似的重要性水平。我们在第 6.2 节中的可视化也展示了这种空间连续模式。
- 重要性分数计算
:计算块重要性分数可能会引入显著的开销。幸运的是,压缩标记的注意力计算产生了中间注意力分数,我们可以利用这些分数来诱导选择块的重要性分数,公式为:
, (8)
其中
是
和压缩键
之间的注意力分数。设 𝑖′ 表示选择块大小。当压缩块和选择块共享相同的阻塞方案时,即
,我们可以通过
直接获得选择块的重要性分数
。对于阻塞方案不同的情况,我们根据它们的空间关系推导出选择块的重要性分数。给定
和
,我们有:
(9)
其中 [·] 表示用于访问向量元素的索引运算符。对于采用 GQA 或 MQA 的模型,其中键值缓存跨查询头共享,在解码过程中需要确保跨这些头的一致块选择,以最小化 KV 缓存加载。同一组内跨头共享的重要性分数正式定义为:
(10)
其中上标 (ℎ) 表示头索引,
是每个组内的查询头数量。这种聚合确保了同一组内跨头的一致块选择。
- Top-n块选择
:获得选择块的重要性分数后,我们保留按块重要性分数排名前n的稀疏块内的标记,公式为:
(11)
(12)
其中 rank(·) 表示按降序排列的排名位置,rank = 1 对应最高分数,
是所选块索引的集合,Cat 表示连接操作。
是由压缩键组成的张量。类似的公式适用于细粒度值
。所选键和值随后与
一起参与注意力计算,如方程(5)所定义。
3.3.3. 滑动窗口
在注意力机制中,局部模式通常适应得更快,并且可能会主导学习过程,从而防止模型有效地从压缩和选择标记中学习。为了解决这一问题,我们引入了一个专门的滑动窗口分支,明确处理局部上下文,允许其他分支(压缩和选择)专注于学习各自的特征,而不会被局部模式所捷径。具体来说,我们在一个大小为
的窗口中维护最近的标记
,并将不同信息源(压缩标记、选择标记和滑动窗口)的注意力计算隔离到单独的分支中。这些分支的输出随后通过学习的门控机制进行聚合。为了进一步防止跨注意力分支的捷径学习,并且只引入最小的计算开销,我们为三个分支提供独立的键和值。这种架构设计通过防止局部和长距离模式识别之间的梯度干扰,实现稳定的学习,同时引入最小的开销。
在获得所有三类键和值()之后,我们按照方程(5)计算最终的注意力输出。结合上述压缩、选择和滑动窗口机制,构成了 NSA 的完整算法框架。
3.4. 内核设计
为了在训练和预填充期间实现 FlashAttention 级别的加速,我们在 Triton 上实现了硬件对齐的稀疏注意力内核。鉴于 MHA 内存密集且解码效率低下,我们专注于采用共享 KV 缓存的架构,如 GQA 和 MQA,这些架构遵循当前最先进的 LLMs。虽然压缩和滑动窗口注意力计算与现有的 FlashAttention-2 内核兼容,但我们为稀疏选择注意力引入了专门的内核设计。如果按照 FlashAttention 的策略将时间上连续的查询块加载到 SRAM 中,由于块内的查询可能需要不连续的 KV 块,这将导致内存访问效率低下。为了解决这一问题,我们的关键优化在于不同的查询分组策略:对于查询序列上的每个位置,我们将同一 GQA 组内的所有查询头(它们共享相同的稀疏 KV 块)加载到 SRAM 中。图 3 展示了我们的前向传递实现。所提出的内核架构具有以下关键特点:
- 以组为中心的数据加载
:对于每个内循环,加载位于位置
的组中的所有头的查询
以及它们共享的稀疏键/值块索引
。
- 共享 KV 获取
:在内循环中,按顺序将由
索引的连续键/值块加载到 SRAM 中,分别为
和
,以最小化内存加载,其中
是满足
的内核块大小。
- 网格上的外循环
:由于内循环长度(与所选块数n成正比)对于不同的查询块几乎相同,我们将查询/输出循环放在 Triton 的网格调度器中,以简化和优化内核。
这种设计通过(1)通过组内共享消除冗余的 KV 传输,以及(2)在 GPU 流处理器之间平衡计算工作负载,实现了接近最优的算术强度。
- 实验
我们通过三个角度评估 NSA:(1)通用基准测试性能,(2)长文本基准测试性能,以及(3)链式推理性能,与全注意力基线和最先进的稀疏注意力方法进行比较。我们将稀疏计算范式的效率分析推迟到第 5 节,在那里我们提供了关于训练和推理速度的详细讨论。
4.1. 预训练设置
遵循最先进的 LLMs 的常见实践,我们的实验采用了一个结合了分组查询注意力(GQA)和专家混合(MoE)的骨干网络,总参数量为 27B,其中 3B 为活跃参数。该模型包含 30 层,隐藏维度为 2560。对于 GQA,我们将组数设置为 4,总共有 64 个注意力头。对于每个头,查询、键和值的隐藏维度分别配置为
和
。对于 MoE,我们使用 DeepSeekMoE(Dai et al., 2024; DeepSeek-AI, 2024)结构,包含 72 个路由专家和 2 个共享专家,并将 top-k 专家设置为 6。为了确保训练稳定,第一层的 MoE 被替换为 SwiGLU 形式的 MLP。
所提出的架构在计算成本和模型性能之间实现了有效的权衡。对于 NSA,我们设置压缩块大小 𝑖 = 32,滑动步长 d = 16,所选块大小 𝑖′ = 64,所选块数量 n = 16(包括固定激活的 1 个初始块和 2 个局部块),以及滑动窗口大小 w = 512。全注意力和稀疏注意力模型都在 8k 长度的文本上进行了 270B 个标记的预训练,随后使用 YaRN(Peng et al., 2024)在 32k 长度的文本上进行持续训练和监督微调,以实现长文本适应。两个模型都训练到完全收敛,以确保公平比较。如图 4 所示,我们的 NSA 和全注意力基线的预训练损失曲线显示出稳定且平滑的下降趋势,NSA 一致优于全注意力模型。
4.2. 基线方法
除了与全注意力进行比较外,我们还评估了几种最先进的推理阶段稀疏注意力方法:H2O(Zhang et al., 2023b)、infLLM(Xiao et al., 2024)、Quest(Tang et al., 2024)和 Exact-Top,后者首先计算完整的注意力分数,然后选择每个查询对应的 top-n 分数的键,接着在这些位置上计算注意力。这些方法涵盖了多种稀疏注意力范式,包括 KV 缓存驱逐、查询感知选择和精确的 top-n稀疏选择。
对于通用评估,由于大多数样本的长度在稀疏注意力基线的局部上下文窗口内,这些方法实际上等同于全注意力。因此,在这种设置中,我们仅展示 NSA 和全注意力基线之间的比较结果。在长文本评估中,我们在所有基线方法之间进行比较,并将所有稀疏注意力方法的稀疏性设置为相同,以确保公平比较。对于链式推理评估,这需要长文本监督微调,我们将比较限制为全注意力,因为稀疏注意力基线不支持训练。
4.3. 性能比较
通用评估
:我们在涵盖知识、推理和编程能力的广泛基准测试套件上评估了预训练的 NSA 和全注意力基线,包括 MMLU(Hendrycks et al., 2020)、MMLU-PRO(Wang et al., 2024)、CMMLU(Li et al., 2023)、BBH(Suzgun et al., 2022)、GSM8K(Cobbe et al., 2021)、MATH(Hendrycks et al., 2020)、DROP(Dua et al., 2019)、MBPP(Austin et al., 2021)和 HumanEval(Chen et al., 2021)。结果如表 1 所示。尽管具有稀疏性,NSA 实现了优于全注意力的总体性能,在 9 项指标中的 7 项上优于所有基线。这表明,尽管 NSA 在较短序列上可能无法充分发挥其效率优势,但其性能依然强劲。值得注意的是,NSA 在与推理相关的基准测试中取得了显著的提升(DROP:+0.042,GSM8K:+0.034),这表明我们的预训练有助于模型开发专门的注意力机制。这种稀疏注意力预训练机制迫使模型专注于最重要的信息,可能通过过滤掉来自不相关注意力路径的噪声来增强性能。在多样化评估中的一致性能也验证了 NSA 作为通用架构的稳健性。
长文本评估
:如图 5 所示,NSA 在 64k 上下文的针 haystack 测试中实现了所有位置的完美检索精度。这一性能归功于我们的层次化稀疏注意力设计,它结合了用于高效全局上下文扫描的压缩标记和用于精确局部信息检索的选择标记。粗粒度压缩以较低的计算成本识别相关的上下文块,而标记级的选择标记注意力则确保了关键细粒度信息的保留。这种设计使 NSA 能够同时保持全局感知和局部精度。
我们还在 LongBench(Bai et al., 2023)上对 NSA 进行了评估,并与最先进的稀疏注意力方法和全注意力基线进行了比较。为了确保一致的稀疏性,我们将所有稀疏注意力基线中每个查询激活的标记数量设置为 2560 个标记,这对应于 NSA 在处理 32k 序列长度时激活的平均标记数量。按照 StreamLLM(Xiao et al., 2023)的方法,这一标记预算包括前 128 个标记和 512 个局部标记。我们排除了 LongBench 中某些子集,因为它们在所有模型上的得分都很低,可能无法提供有意义的比较。如表 2 所示,NSA 实现了最高的平均分数 0.469,优于所有基线(比全注意力高出 +0.032,比 Exact-Top 高出 +0.046)。这一改进源于两个关键创新:(1)我们的原生稀疏注意力设计,它允许在预训练期间对稀疏模式进行端到端优化,促进了稀疏注意力模块与其他模型组件之间的同步适应;(2)层次化稀疏注意力机制在局部和全局信息处理之间实现了平衡。
值得注意的是,NSA 在需要在长文本上进行复杂推理的任务上表现出色,在多跳问答任务(HPQ 和 2Wiki)上比全注意力提高了 +0.087 和 +0.051,在代码理解(LCC:+0.069)上超过了基线,并且在段落检索(PassR-en:+0.075)上优于其他方法。这些结果验证了 NSA 在处理多样化长文本挑战方面的能力,其原生预训练的稀疏注意力在学习任务最优模式方面提供了额外的好处。
链式推理评估
:为了评估 NSA 对先进下游训练范式的兼容性,我们研究了其通过后训练获得链式推理数学推理能力的能力。鉴于强化学习在较小规模模型上的有限有效性,我们采用了 DeepSeek-R1 的知识蒸馏,进行了监督微调(SFT),使用 10B 个标记的 32k 长度数学推理轨迹。这产生了两个可比较的模型:全注意力-R(全注意力基线)和 NSA-R(我们的稀疏变体)。我们在具有挑战性的美国数学邀请赛(AIME 24)基准测试上评估了这两个模型。我们使用 0.7 的采样温度和 0.95 的 top-𝑖? 值,为每个问题生成 16 个回答,并获得平均分数。为了验证推理深度的影响,我们在两个生成上下文限制下进行了实验:8k 和 16k 标记,测量扩展推理链是否提高了准确性。模型预测的示例比较提供在附录 A 中。
如表 3 所示,NSA-R 在 8k 上下文设置下实现了比全注意力-R 更高的准确性(+0.075),这一优势在 16k 上下文中持续存在(+0.054)。这些结果验证了原生稀疏注意力的两个关键好处:(1)预训练的稀疏注意力模式能够高效地捕获对于复杂数学推导至关重要的长距离逻辑依赖;(2)我们架构的硬件对齐设计保持了足够的上下文密度,以支持推理深度的增加而不会发生灾难性遗忘。在上下文长度上的一致优势确认了当原生集成到训练管道中时,稀疏注意力对于先进推理任务的可行性。
5. 效率分析
我们在 8-GPU A100 系统上评估了 NSA 与全注意力的计算效率。在效率分析中,我们还将模型配置为 GQA 组 g = 4,每组头数 ℎ = 16,查询/键维度
= 192,以及值维度
= 128。按照第 4 节中的设置,我们将 NSA 的压缩块大小设置为
= 32,滑动步长设置为 d = 16,所选块大小设置为 𝑖′ = 64,所选块数量设置为 n = 16,以及滑动窗口大小设置为 w = 512。
5.1. 训练速度
我们将基于 Triton 的 NSA 注意力实现与全注意力实现进行了比较,以确保在相同的后端上进行公平的速度比较。如图 6 所示,随着上下文长度的增加,NSA 实现了越来越大的加速,在 64k 上下文长度时,前向传播加速比达到 9.0×,反向传播加速比达到 6.0×。值得注意的是,随着序列长度的增加,速度优势变得更加显著。这种加速源于我们为最大化稀疏注意力架构的效率而设计的硬件对齐算法:(1)块状内存访问模式通过合并加载最大化了 Tensor Core 的利用率;(2)内核中精心设计的循环调度消除了冗余的 KV 传输。
5.2. 解码速度
注意力的解码速度主要由内存访问瓶颈决定,这与 KV 缓存的加载量密切相关。在每个解码步骤中,我们的 NSA 仅需要加载最多
个压缩标记、
个选择标记以及 w 个邻近标记,其中 s是缓存的序列长度。如表 4 所示,随着解码长度的增加,我们的方法显著减少了延迟,在 64k 上下文长度时实现了高达 11.6× 的加速。随着序列长度的增加,这种内存访问效率的优势也变得更加显著。
- 讨论
在本节中,我们回顾了 NSA 的开发过程,并讨论了从探索不同稀疏注意力策略中获得的关键见解。尽管我们的方法展现了有希望的结果,但了解替代策略所面临的挑战以及分析注意力模式为我们提供了宝贵的背景,以便为未来的研究方向提供参考。我们首先考察了促使我们做出设计选择的替代标记选择策略所面临的挑战,随后提供了提供对注意力分布模式见解的可视化。
6.1. 替代标记选择策略面临的挑战
在设计 NSA 之前,我们探索了将现有的稀疏注意力方法适应到训练阶段。然而,这些尝试遇到了各种挑战,促使我们设计了一个不同的稀疏注意力架构:
基于关键聚类的策略
:我们考察了像 ClusterKV(Liu et al., 2024)这样的基于聚类的策略。这些方法将来自同一聚类的键和值存储在连续的内存区域中。尽管从理论上讲,这些方法可以用于训练和推理,但它们面临三个显著的挑战:(1)动态聚类机制引入了非平凡的计算开销;(2)由于聚类间的不平衡,操作符优化困难加剧,特别是在混合专家(MoE)系统中,倾斜的专家并行(EP)组执行时间导致持续的负载不平衡;(3)由于需要定期重新聚类和块序贯训练协议,实施约束随之而来。这些综合因素造成了巨大的瓶颈,显著限制了它们在实际部署中的有效性。
其他块状选择策略
:我们还考虑了与 NSA 不同的块状键、值选择策略,例如 Quest(Tang et al., 2024)和 InfLLM(Xiao et al., 2024)。这些方法依赖于计算每个块的重要性分数,并根据其与 𝑖?𝑖? 的相似性选择 top-𝑖? 块。然而,现有方法面临两个关键问题:(1)由于选择操作是非可微的,基于神经网络的重要性分数计算依赖于辅助损失,这增加了操作符开销,并且通常会降低模型性能;(2)基于启发式的无参数重要性分数计算策略遭受低召回率的困扰,导致次优性能。我们在具有相似架构的 3B 参数模型上评估了这两种方法,并将其损失曲线与 NSA 和全注意力进行了比较。对于基于辅助损失的选择方法,我们引入了每个块的额外查询和代表性键,以估计块的重要性分数。这些分数由每个块内原始查询和键之间的平均注意力分数监督。对于基于启发式的无参数选择方法,我们按照 Quest 的策略,使用查询和键块的逐坐标最小值和最大值的乘积直接进行选择,而不引入额外的参数。我们还探索了一种冷启动训练方法,即在转换为启发式块状选择之前,最初 1000 步应用全注意力。如图 7 所示,这两种方法的损失均劣于 NSA。
6.2. 可视化
为了探索 Transformer 注意力分布的潜在模式,并为我们的设计寻求灵感,我们在图 8 中可视化了我们预训练的 27B 全注意力模型的注意力图。可视化揭示了有趣的模式,其中注意力分数倾向于表现出块状聚类特征,附近的键往往具有相似的注意力分数。这一观察结果启发了我们的 NSA 设计,表明基于空间连续性的选择关键块可能是一个有前途的方法。块状聚类现象表明,序列中相邻的标记可能与查询标记共享某些语义关系,尽管这些关系的确切性质需要进一步调查。这一观察结果促使我们探索一种在连续标记块上运行的稀疏注意力机制,而不是在单个标记上运行,旨在提高计算效率并保留高注意力模式。
7. 相关工作
我们回顾了通过稀疏注意力提高注意力计算效率的现有方法。这些方法可以根据其核心策略大致分为三类:(1)固定稀疏模式,(2)动态标记修剪,(3)查询感知选择。我们从每个类别中介绍了一些代表性工作。
7.1. 固定稀疏模式
滑动窗口是一种常见的方法,允许查询仅在固定窗口内计算注意力。StreamingLLM(Xiao et al., 2023)通过维护上下文的两个关键部分来解决处理长文本流的挑战:一个注意力池(早期标记)和一个局部上下文窗口。尽管这些方法有效地减少了内存和计算成本,但它们固定的忽略上下文的模式限制了它们在需要完整上下文理解的任务上的性能。
7.2. 动态标记修剪
H2O(Zhang et al., 2023b)实现了一种自适应方法,用于在解码过程中减少 KV 缓存的内存使用。这种方法根据标记的近期效用(根据注意力分数)动态驱逐被认为对未来预测不那么重要的标记。SnapKV(Li et al., 2024)也引入了一种标记修剪策略,通过选择性地保留最关键的特征来减少 KV 缓存,从而实现高效的内存使用。SnapKV 通过分析预填充阶段的注意力权重并进行投票来识别重要特征,然后通过结合选定的压缩特征和最近的上下文更新 KV 缓存,以保持提示的一致性。
7.3. 查询感知选择
Quest(Tang et al., 2024)采用了一种块状选择策略,通过计算查询和键块的逐坐标最小值和最大值的乘积来估计每个块的重要性。这些分数用于选择 top-𝑖? 重要的键值块进行注意力计算。InfLLM(Xiao et al., 2024)结合了固定模式和检索,通过维护注意力池、局部上下文和可检索块来实现高效的长文本处理。该方法通过从每个块中选择代表性键来估计块的重要性。HashAttention(Desai et al., 2024)将关键标记识别表述为一个推荐问题,通过使用学习到的函数将查询和键映射到汉明空间来实现。ClusterKV(Liu et al., 2024)通过首先对键进行聚类,然后根据查询与聚类的相似性选择最相关的聚类进行注意力计算,从而实现稀疏性。
- 结论
我们提出了 NSA,这是一种用于高效长文本建模的硬件对齐稀疏注意力架构。通过将层次化标记压缩与块状标记选择相结合,并将其集成到可训练的架构中,我们的架构在保持全注意力性能的同时,实现了加速的训练和推理。NSA 通过在通用基准测试中匹配全注意力基线,在长文本评估中超越建模能力,以及增强推理能力,同时显著减少计算延迟并实现显著加速,从而推动了该领域的进步。