© 作者|陈杰
机构|中国人民大学
研究方向|自然语言处理、大语言模型
为什么需要 FlashAttention?
Transformer 模型在自然语言处理(NLP)和大语言模型(LLM)领域取得了巨大成功。然而,传统的 自注意力(Self-Attention) 模块在处理长序列时面临 的时间和空间复杂度,这极大限制了其在长上下文处理上的效率。当输入序列长度 增大时,注意力机制的计算和内存开销急剧上升,导致模型训练速度变慢,显存占用剧增。
FlashAttention 的解决方案
为了解决上述问题,FlashAttention 提出了一种 快速、节省内存、精确的注意力计算方法 ,通过减少内存读写次数,提高 GPU 的内存 IO 效率,加快训练速度,并增加上下文窗口长度。其核心目标是通过 硬件感知(IO-awareness) 的算法优化,将整个注意力计算过程在更高速的 SRAM (静态随机存取存储器)中完成,减少对 HBM(高带宽内存)的依赖。
- 快速:降低 HBM 访问次数,采用 Tiling 的方法,分块从 HBM 加载矩阵到 SRAM 进行融合计算。
- 节省内存:不再对中间矩阵进行存储,反向传播时重新计算来计算梯度。
- 精确:算法中无近似操作,只是进行分块计算。
标准注意力机制的内存访问分析
在传统的注意力机制中,首先需要将输入矩阵 从 HBM 中读取,随后计算注意力矩阵 ,再通过 函数得到概率矩阵 ,最后与 相乘得到输出 。这个过程的内存访问复杂度如下:
- 读取 ,计算 ,再将 存入 HBM:复杂度为 。
- 读取 ,计算 ,再存入 HBM:复杂度为 。
- 读取 ,计算 ,得到最终结果 :复杂度为 。
综上所述,标准注意力机制的内存访问总复杂度为:
由于 矩阵很大,需要在 HBM 中存储,这带来了很多 HBM 的访问次数,导致算法的时间延迟。
FlashAttention 如何优化
FlashAttention 将从输入 到输出 的整个过程进行融合,避免 矩阵的存储开销,实现延迟缩减。其中面临的挑战是,输入长度 通常很大,无法完全将完整的 存储在 SRAM 中,而我们需要让计算过程的结果完全在 SRAM 中,摆脱对 HBM 的依赖。
一种解决方案是采用分片操作,每次进行部分计算,但分片操作面临的挑战是,我们需要确保这些计算结果能在 SRAM 内进行交互,能基于存储的中间值更新得到最终结果。特别地,之前对于 的计算是以行为单位,当我们将输入进行分片后,无法对完整的行数据执行 操作(因为 函数在计算时需要考虑整行数据)。FlashAttention 采用 Tiling 的方式,实现分块 。
总体而言,FlashAttention 通过以下方法优化上述过程:
- 分块计算 :将 切分为小块,每次只加载一部分数据到 SRAM 中,并避免存储中间矩阵 ,从而减少 HBM 访问次数。
- 融合计算 :将从 到 的整个流程进行融合,不再单独存储中间结果。
- 分片 : 通过 Tiling 技术解决分块后 的计算问题,实现更高效的内存管理。
Tiling
我们介绍分片 在 FlashAttention 中的具体实现。
以向量 为例,记
将 分为 块,,其中
则
从而:
记 ,
则:
从而我们有:
因此,在分块计算时,我们可以迭代更新 ,以得到最终的结果。
具体而言,已知前个 分块的结果:
其中:
迭代过程具体如下:
更新:
从而:
这样,我们可基于已有的分块结果和新的分块信息更新得到新的分块结果,实现分片 。
前向传播步骤
- 读取 并计算 :
- 通过将 分块加载到 SRAM 中,减少 HBM 的读取次数。
- 使用分块矩阵乘法进行计算。
- 分块 :
- 由于输入被分片处理,无法对整行数据执行标准的 操作。
- FlashAttention 通过在每个块内进行 计算,并在后续块计算时调整累积值,实现准确的计算。
- 计算输出 :
- 累加各分块的结果,最终得到输出矩阵 。
算法复杂度分析
假设 的分块大小为 , 的分块大小为 ,对于 的每一分块,都需要加载 ,则有 FlashAttention 的内存访问复杂度为 。
在 SRAM 大小为 的条件下,我们有以下约束:
相应的, 有如下限制:
最终,还有一个中间态 需要存储,则有如下限制:
综上,限制如下:
进而推出:
那么在 的前提下,则有 FlashAttention 的 HBM 内存访问复杂度为:
在语言建模中,通常有 ,则有 。
因此,FlashAttention 采用分块计算,避免 矩阵存储开销,降低 HBM 访问次数,从而提升计算速度。
反向传播步骤
前向传播时我们为了减少 HBM 访存次数,并没有对 矩阵进行存储,而在反向传播计算梯度的时候需要这一信息。FlashAttention 采用重新计算对应梯度的方式,利用前向时存储的指数项之和 来进行梯度的计算。
下面我们计算损失函数 对 对应的梯度。
对应的梯度 ,其中 是已知的。 对应的梯度也可计算,由于 ,根据链式求导和矩阵求导法则有 ,具体如下:
对应梯度的计算则比较复杂,我们先计算 。由于 ,则有 如下表示:
由于,有:
接下来我们定义如下表示:
根据上述定义得到如下表示:
相应的 可表示为如下形式:
又因为 ,结合上述推导利用链式求导法则, 对应的梯度如下表示:
从而得到一个完整的包含前向和反向的 FlashAttention 算子。
FlashAttention 中,由于在 GPU 的不同线程块和 warps 上的任务切分比较粗糙,造成低利用率和不必要的共享内存读写。因此 FlashAttention-2 的目标是优化上述任务的切分。
并行度的提升
FlashAttention-2 针对 GPU 的并行计算特点,优化了任务的切分和并行度:
- 细化任务切分:将计算任务分配到更多的线程块和 warp 中,提高了计算资源的利用率。
- 减少非矩阵乘法的 FLOPs:优化了非矩阵乘法部分的计算,进一步提升了速度。
具体优化策略
FlashAttention-2 对 FlashAttention 中的前向传播进行两项修改:
- 移除 ,只在循环的最后相除得到正确值。
- 前向传播时只需要存储 用于反向传播。
FlashAttention-2 对 workers 并行处理,进行线程的分配。FlashAttention 的前向传播中,对于每一个块,是将 切分到 4 个不同的 warps 上,但是将 保持为对所有的 warps 是可见的。FlashAttention-2 则只对 进行切分,这样 后和共享的 相乘即可。
warps 是 NVIDIA GPU 并行计算的基本单元,一个 warp 通常包含 32 个线程,它们同时执行相同的指令,但对不同的数据进行操作。在 GPU 执行指令时,通常以 warps 为单位进行调度,这可以充分利用 GPU 的并行处理能力。
综合以上改进,FlashAttention-2 前向传播步骤如下:
FlashAttention 尚未充分利用硬件功能,FlashAttention-2 在 H100 GPU 上仅实现了 35% 的理论最大 FLOP 利用率。因此,FlashAttention-3 的目标是在 Hopper GPU(比如 H100, H800)上充分利用 WGMMA 和 TMA 的异步性加速 attention。实现更高效的 GPU 利用率、较低精度下更好的性能,并能够在大模型中使用更长的上下文。
硬件新特性
FlashAttention-3 充分利用了 Hopper 架构硬件的新特性:
- WGMMA:新一代张量核心,提供比 Ampere 架构更高的矩阵乘法吞吐量。
- 在相同时间内,Hopper GPU 可以进行更多的矩阵运算,从而提高整体计算性能。
- 用于执行 GEMMs。
- TMA:加速全局内存与共享内存之间的数据传输,处理索引计算和越界预测,从而释放寄存器。
- 优化数据传输的分块大小,提高数据传输的效率。
- 自动计算数据在内存中的位置,简化了编程过程。数据传输在合法范围内,避免错误,从而降低编程复杂性。
- 加载分块后的 。
- 低精度 FP8,使得张量核心吞吐量翻一倍。
具体技术
FlashAttention-3 采用三项主要技术:
- 生产-消费异步:采用 warp 专用流水线,将数据生产和消费分成不同的 warp。利用异步执行可以更好地隐藏内存和指令发出延迟。
- 在异步 block-wise 的 GEMM 下隐藏 ,使得分块 和 运算交错。
- 非矩阵乘法运算比矩阵乘法运算慢。特殊函数如指数运算(如 函数)的吞吐量甚至低于浮点乘加操作,这些运算是由多功能单元处理的,这是一个与浮点乘加或矩阵乘加不同的单元。
- 我们希望矩阵乘法和 能并行操作。当张量核心忙于矩阵乘法时,多功能单元可进行指数运算。
- 通过将低吞吐量的 操作与异步 WGMMA 指令重叠,FlashAttention-3 可以绕过 和 GEMM 之间的顺序依赖。
- 硬件加速的低精度 GEMM:利用硬件支持 FP8 处理。
- 利用 FP8 张量核心进行 GEMM,可使 TFLOPS/s 翻倍。
- 采用非相干处理,将查询和键与一个随机正交矩阵相乘来分散极端值,从而减少量化误差。特别地,使用 Hadamard 变换,它可以在每个注意力头中以 的时间复杂度完成,而不是 。
- 通过量化和非相干处理可以提高计算效率,并减轻精度降低带来的影响。
# https://github.com/Dao-AILab/flash-attention
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
FlashAttention 通过深入挖掘硬件潜力,优化注意力机制的内存和计算效率,为大模型训练带来了显著的性能提升。随着 FlashAttention-2 和 FlashAttention-3 的迭代升级,未来的 Transformer 模型将更高效地处理更长的上下文,从而在更多场景中发挥更大的潜力。
FlashAttention 系列论文:
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
交流群:点击“联系 作者”--备注“研究方向-公司或学校”
欢迎|论文宣传|合作交流
往期推荐
RecSys'24 | 通过额外的注意力来增强自注意力机制用于序列推荐
长按关注,更多精彩
点个在看你最好看