FlashAttention 系列技术详解:加速大模型训练的利器

大模型向量数据库云存储

© 作者|陈杰

机构|中国人民大学

研究方向|自然语言处理、大语言模型

picture.image

背景与动机

为什么需要 FlashAttention?

Transformer 模型在自然语言处理(NLP)和大语言模型(LLM)领域取得了巨大成功。然而,传统的 自注意力(Self-Attention) 模块在处理长序列时面临 的时间和空间复杂度,这极大限制了其在长上下文处理上的效率。当输入序列长度 增大时,注意力机制的计算和内存开销急剧上升,导致模型训练速度变慢,显存占用剧增。

FlashAttention 的解决方案

为了解决上述问题,FlashAttention 提出了一种 快速、节省内存、精确的注意力计算方法 ,通过减少内存读写次数,提高 GPU 的内存 IO 效率,加快训练速度,并增加上下文窗口长度。其核心目标是通过 硬件感知(IO-awareness) 的算法优化,将整个注意力计算过程在更高速的 SRAM (静态随机存取存储器)中完成,减少对 HBM(高带宽内存)的依赖。

  • 快速:降低 HBM 访问次数,采用 Tiling 的方法,分块从 HBM 加载矩阵到 SRAM 进行融合计算。
  • 节省内存:不再对中间矩阵进行存储,反向传播时重新计算来计算梯度。
  • 精确:算法中无近似操作,只是进行分块计算。
FlashAttention 的基本原理

标准注意力机制的内存访问分析

picture.image

在传统的注意力机制中,首先需要将输入矩阵 从 HBM 中读取,随后计算注意力矩阵 ,再通过 函数得到概率矩阵 ,最后与 相乘得到输出 。这个过程的内存访问复杂度如下:

  1. 读取 ,计算 ,再将 存入 HBM:复杂度为 。
  2. 读取 ,计算 ,再存入 HBM:复杂度为 。
  3. 读取 ,计算 ,得到最终结果 :复杂度为 。

综上所述,标准注意力机制的内存访问总复杂度为:

由于 矩阵很大,需要在 HBM 中存储,这带来了很多 HBM 的访问次数,导致算法的时间延迟。

FlashAttention 如何优化

FlashAttention 将从输入 到输出 的整个过程进行融合,避免 矩阵的存储开销,实现延迟缩减。其中面临的挑战是,输入长度 通常很大,无法完全将完整的 存储在 SRAM 中,而我们需要让计算过程的结果完全在 SRAM 中,摆脱对 HBM 的依赖。

一种解决方案是采用分片操作,每次进行部分计算,但分片操作面临的挑战是,我们需要确保这些计算结果能在 SRAM 内进行交互,能基于存储的中间值更新得到最终结果。特别地,之前对于 的计算是以行为单位,当我们将输入进行分片后,无法对完整的行数据执行 操作(因为 函数在计算时需要考虑整行数据)。FlashAttention 采用 Tiling 的方式,实现分块 。

总体而言,FlashAttention 通过以下方法优化上述过程:

  • 分块计算 :将 切分为小块,每次只加载一部分数据到 SRAM 中,并避免存储中间矩阵 ,从而减少 HBM 访问次数。
  • 融合计算 :将从 到 的整个流程进行融合,不再单独存储中间结果。
  • 分片 通过 Tiling 技术解决分块后 的计算问题,实现更高效的内存管理。
FlashAttention 的核心算法

Tiling

我们介绍分片 在 FlashAttention 中的具体实现。

以向量 为例,记

将 分为 块,,其中

从而:

记 ,

则:

从而我们有:

因此,在分块计算时,我们可以迭代更新 ,以得到最终的结果。

具体而言,已知前个 分块的结果:

其中:

迭代过程具体如下:

更新:

从而:

这样,我们可基于已有的分块结果和新的分块信息更新得到新的分块结果,实现分片 。

前向传播步骤

picture.image

  1. 读取 并计算
  • 通过将 分块加载到 SRAM 中,减少 HBM 的读取次数。
  • 使用分块矩阵乘法进行计算。
  • 分块
  • 由于输入被分片处理,无法对整行数据执行标准的 操作。
  • FlashAttention 通过在每个块内进行 计算,并在后续块计算时调整累积值,实现准确的计算。
  • 计算输出
  • 累加各分块的结果,最终得到输出矩阵 。

算法复杂度分析

假设 的分块大小为 , 的分块大小为 ,对于 的每一分块,都需要加载 ,则有 FlashAttention 的内存访问复杂度为 。

在 SRAM 大小为 的条件下,我们有以下约束:

相应的, 有如下限制:

最终,还有一个中间态 需要存储,则有如下限制:

综上,限制如下:

进而推出:

那么在 的前提下,则有 FlashAttention 的 HBM 内存访问复杂度为:

在语言建模中,通常有 ,则有 。

因此,FlashAttention 采用分块计算,避免 矩阵存储开销,降低 HBM 访问次数,从而提升计算速度。

反向传播步骤

前向传播时我们为了减少 HBM 访存次数,并没有对 矩阵进行存储,而在反向传播计算梯度的时候需要这一信息。FlashAttention 采用重新计算对应梯度的方式,利用前向时存储的指数项之和 来进行梯度的计算。

下面我们计算损失函数 对 对应的梯度。

对应的梯度 ,其中 是已知的。 对应的梯度也可计算,由于 ,根据链式求导和矩阵求导法则有 ,具体如下:

对应梯度的计算则比较复杂,我们先计算 。由于 ,则有 如下表示:

由于,有:

接下来我们定义如下表示:

根据上述定义得到如下表示:

相应的 可表示为如下形式:

又因为 ,结合上述推导利用链式求导法则, 对应的梯度如下表示:

从而得到一个完整的包含前向和反向的 FlashAttention 算子。

FlashAttention-2

FlashAttention 中,由于在 GPU 的不同线程块和 warps 上的任务切分比较粗糙,造成低利用率和不必要的共享内存读写。因此 FlashAttention-2 的目标是优化上述任务的切分。

并行度的提升

FlashAttention-2 针对 GPU 的并行计算特点,优化了任务的切分和并行度:

  • 细化任务切分:将计算任务分配到更多的线程块和 warp 中,提高了计算资源的利用率。
  • 减少非矩阵乘法的 FLOPs:优化了非矩阵乘法部分的计算,进一步提升了速度。

具体优化策略

FlashAttention-2 对 FlashAttention 中的前向传播进行两项修改:

  • 移除 ,只在循环的最后相除得到正确值。
  • 前向传播时只需要存储 用于反向传播。

picture.image

FlashAttention-2 对 workers 并行处理,进行线程的分配。FlashAttention 的前向传播中,对于每一个块,是将 切分到 4 个不同的 warps 上,但是将 保持为对所有的 warps 是可见的。FlashAttention-2 则只对 进行切分,这样 后和共享的 相乘即可。

warps 是 NVIDIA GPU 并行计算的基本单元,一个 warp 通常包含 32 个线程,它们同时执行相同的指令,但对不同的数据进行操作。在 GPU 执行指令时,通常以 warps 为单位进行调度,这可以充分利用 GPU 的并行处理能力。

综合以上改进,FlashAttention-2 前向传播步骤如下:

picture.image

FlashAttention-3

FlashAttention 尚未充分利用硬件功能,FlashAttention-2 在 H100 GPU 上仅实现了 35% 的理论最大 FLOP 利用率。因此,FlashAttention-3 的目标是在 Hopper GPU(比如 H100, H800)上充分利用 WGMMA 和 TMA 的异步性加速 attention。实现更高效的 GPU 利用率、较低精度下更好的性能,并能够在大模型中使用更长的上下文。

硬件新特性

FlashAttention-3 充分利用了 Hopper 架构硬件的新特性:

  • WGMMA:新一代张量核心,提供比 Ampere 架构更高的矩阵乘法吞吐量。
  • 在相同时间内,Hopper GPU 可以进行更多的矩阵运算,从而提高整体计算性能。
  • 用于执行 GEMMs。
  • TMA:加速全局内存与共享内存之间的数据传输,处理索引计算和越界预测,从而释放寄存器。
  • 优化数据传输的分块大小,提高数据传输的效率。
  • 自动计算数据在内存中的位置,简化了编程过程。数据传输在合法范围内,避免错误,从而降低编程复杂性。
  • 加载分块后的 。
  • 低精度 FP8,使得张量核心吞吐量翻一倍。

具体技术

FlashAttention-3 采用三项主要技术:

  • 生产-消费异步:采用 warp 专用流水线,将数据生产和消费分成不同的 warp。利用异步执行可以更好地隐藏内存和指令发出延迟。
  • 在异步 block-wise 的 GEMM 下隐藏 ,使得分块 和 运算交错。
  • 非矩阵乘法运算比矩阵乘法运算慢。特殊函数如指数运算(如 函数)的吞吐量甚至低于浮点乘加操作,这些运算是由多功能单元处理的,这是一个与浮点乘加或矩阵乘加不同的单元。
  • 我们希望矩阵乘法和 能并行操作。当张量核心忙于矩阵乘法时,多功能单元可进行指数运算。
  • 通过将低吞吐量的 操作与异步 WGMMA 指令重叠,FlashAttention-3 可以绕过 和 GEMM 之间的顺序依赖。
  • 硬件加速的低精度 GEMM:利用硬件支持 FP8 处理。
  • 利用 FP8 张量核心进行 GEMM,可使 TFLOPS/s 翻倍。
  • 采用非相干处理,将查询和键与一个随机正交矩阵相乘来分散极端值,从而减少量化误差。特别地,使用 Hadamard 变换,它可以在每个注意力头中以 的时间复杂度完成,而不是 。
  • 通过量化和非相干处理可以提高计算效率,并减轻精度降低带来的影响。
FlashAttention 调用方式

        
          
# https://github.com/Dao-AILab/flash-attention  
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func  

      
总结

FlashAttention 通过深入挖掘硬件潜力,优化注意力机制的内存和计算效率,为大模型训练带来了显著的性能提升。随着 FlashAttention-2 和 FlashAttention-3 的迭代升级,未来的 Transformer 模型将更高效地处理更长的上下文,从而在更多场景中发挥更大的潜力。

FlashAttention 系列论文:

  • FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
  • FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
  • FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

交流群:点击“联系 作者”--备注“研究方向-公司或学校”

欢迎|论文宣传|合作交流

往期推荐

阿里 | 多分支协作网络用于淘宝中点击率(CTR)预测

KDD'24 | DESC:在校准中考虑形状校准和值校准

RecSys'24 | 通过额外的注意力来增强自注意力机制用于序列推荐

长按关注,更多精彩

点个在看你最好看

picture.image

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论