FlashMLA
今天DeepSeek开源周第一天,开放了FlashMLA仓库, 1小时内星标2.7k!
FlashMLA 是一个高效的 MLA 解码内核,专为 Hopper GPU 优化,适用于可变长度序列。该项目目前发布了 BF16 和具有 64 块大小分页 kvcache 的功能。在 H800 SXM5 上,使用 CUDA 12.6,内存受限配置下可达 3000 GB/s,计算受限配置下可达 580 TFLOPS。
Github仓库地址:https://github.com/deepseek-ai/FlashMLA
这里提到两个比较关键的功能就是 BF16精度计算以及Paged kvcache缓存技术
好巧不巧,近期DeepSeek 发布了一篇新论文, 提出了一种改进版的注意力机制 NSA,即Native Sparse Attention,可以直译为「原生稀疏注意力」 ;但其实就在同一天,月之暗面也发布了一篇主题类似的论文, 提出了一种名为 MoBA 的注意力机制,即 Mixture of Block Attention,可以直译为「块注意力混合」 。注意机制最近这么火爆的背景下,不妨我们趁机复习下Kv Cache相关概念以及相关注意力机制模型。
KV Cache简介
这部分主要参考 LLM推理算法简述 ,可以快速回顾下KV Cache概念,关于更多LLM推理算法讲解大家可以阅读 https://zhuanlan.zhihu.com/p/685794495 。
LLM 推理服务的吞吐量指标主要受制于显存限制。研究团队发现现有系统由于缺乏精细的显存管理方法而浪费了 60% 至 80% 的显存,浪费的显存主要来自 KV Cache。因此,有效管理 KV Cache 是一个重大挑战。
什么是KV Cache?
Transformer 模型具有自回归推理的特点,即每次推理只会预测输出一个 token,当前轮输出 token 与历史输入 tokens 拼接,作为下一轮的输入 tokens,反复执行多次。该过程中,前后两轮的输入只相差一个 token,存在重复计算。KV Cache 技术实现了将可复用的键值向量结果保存下来,从而避免了重复计算。如下图所示,展示了有无 kv cache 的流程:
- Without KV Cache 每次需要计算全 Wq(X), Wk(X), Wv(X),每次需要计算全量 Attn
- With KV Cache ,第一步计算完整
,将
保存成
- With KV Cache ,第二步,取第二步的 Next Token 计算
,将
联合,计算出
如何优化KV cahce?
KV cache的峰值显存占用大小计算公式: 2 x Length x batch_size x [d x n_kv_heads] x Layers x k-bits ,由此我们可以看出影响KV cache的具体因素:
- k-bits: 数据类型,FP16 占2个bytes。(量化)
- 2: 代表 Key/Value 两个向量现
- Length: 输入输出序列的长度(循环队列管理窗口KV,减少长度kv)
- Layers:模型层数
- d x n_kv_heads:kv维度(MQA/GQA通过减少KV的头数减少显存占用)
- batch_size : KV Cache 与 batchsize 度呈线性关系,随着 batch size 的增大,KV cache 占用的显存开销快速增大,甚至会超过模型本身。
- 操作系统管理:现GPU的KV Cache的有效存储率较低低 (page-attention)
在bf16格式下的13B模型中,我们只有大约10G的空间来存储kv cache。
KV Cache 的引入也使得推理过程分为如下两个不同阶段,进而影响到后续的其他 优化方法 。
- 预填充阶段 (Prefill):发生在计算第一个输出 token 过程中,计算时需要为每个 Transformer layer 计算并保存 key cache 和 value cache;FLOPs 同 KV Cache 一致,存在大量 GEMM (GEneral Matrix-Matrix multiply) 操作,属于 Compute-bound 类型计算。
- 解码阶段 (Decoder):发生在计算第二个输出 token 至最后一个 token 过程中,这时 KV Cache 已存有历史键值结果,每轮推理只需读取 KV Cache,同时将当前轮计算出的新 Key、Value 追加写入至 Cache;GEMM 变为 GEMV (GEneral Matrix-Vector multiply) 操作,FLOPs 降低,推理速度相对预填充阶段变快,这时属于 Memory-bound 类型计算。
解码中的KV Cache
我们下面用一个例子更加详细的解释什么是KV Cache,了解一些背景的计算问题,以及KV Cache的概念。
无论是encoder-decoder结构,还是现在我们最接近AGI的decoder-only的LLM,解码生成时都是自回归auto-regressive的方式。也就是说,解码的时候,先根据当前输入
,生成下一个token,然后把生成的token拼接在
后面,获得新的输入
,再用
生成
,依此选择,直到生成结果。
比如我们输入“窗前明月光下一句是”,那么模型每次生成一个token,输入输出会是这样(方便起见,默认每个token都是一个字符)
step0: 输入=[BOS]窗前明月光下一句是;输出=疑
step1: 输入=[BOS]窗前明月光下一句是疑;输出=是
step2: 输入=[BOS]窗前明月光下一句是疑是;输出=地
step3: 输入=[BOS]窗前明月光下一句是疑是地;输出=上
step4: 输入=[BOS]窗前明月光下一句是疑是地上;输出=霜
step5: 输入=[BOS]窗前明月光下一句是疑是地上霜;输出=[EOS]
(其中[BOS]和[EOS]分别是开始和结束的标记字符)
我们看一下在计算的过程中,如何输入的token “是” 的最后是hidden state如何传递到后面的类Token预测模型,以及后面每一个token,使用新的输入列中最后一个时刻的输出。
我们可以看到,在每一个step的计算中,主要包含了上一轮step的内容,而且只在最后一步使用(一个token)。那么每一个计算也就包含了上一轮step的计算内容。
从公式来看是这样的,回想一下我们attention的计算:
注意对于decoder的时候,由于mask attention的存在,每个输入只能看到自己和前面的内容,而看不到后面的内容。
假设我们当前输入的长度是3,预测第4个字,那么每层attention所做的计算有:
预测完第4个字,放到输入里,继续预测第5个字,每层attention所做的计算有:
可以看到,在预测第5个字时,只有最后一步引入了新的计算,而
到
的计算部分是完全重复的。
但是模型在推理的时候可不管这些,无论你是否只是要最后一个字的输出,它都会把所有输入计算一遍,给出所有输出结果。
也就是说中间有很多我们不需要的计算,这样就造成了浪费。
而且随着生成的结果越来越多,输入的长度也越来越长,上面这个例子里,输入长度是step0的10个, 每步骤,直接step5到15个。如果输入的instruction是规范型任务,那么可能有800个step。这个情况下,step0就变得有800次,step1被重复了799次——这样浪费的计算资源显然不可忍受。
有没有什么方法可以重利用上一个step里已经计算过的结果,减少浪费呢?
答案就是KV Cache,利用一个缓存,把需要重复利用的时序计算结果保存下来,减少重复计算。
而
和
就是需要保存的对象。
想一想,下图就是缓存的过程,假设我们第一次输入的输入长度是3个,我们第一次预测输出预测第4个字,那么由于下图给你看的是每个输入步骤的缓存,每个时序步骤都需要存储一次,而我们依旧会有些重复计算的情况。则有:
kv_cache下标l表示模型层数。在进行第二次预测时,也就是预测第5个字的时候,在第l层的时候,由于前面我们缓存了每层的
,
值,那层就不需要算新的
,而不再算
,
。因为第l层的
,
本来经过FFN层之后进到
层,再经过新的投影变换,成为
层的
,
值,但是是
层的
,
值就已经保留了!
然后我们把本次新算出来的
,
值也存储起来。
然后我们再做下一次计算出的结果:
这样就节省了attention和FFN的很多重复计算。
transformers中,生成的时候传入use_cache=True就会开启KV Cache。
也可以简单看下GPT2中的实现,中文注释的部分就是使用缓存结果和更新缓存结果
Class GPT2Attention(nn.Module):
...
...
def
forward
(
self,
hidden\_states: Optional[Tuple[torch.FloatTensor]],
layer\_past: Optional[Tuple[torch.Tensor]] = None,
attention\_mask: Optional[torch.FloatTensor] = None,
head\_mask: Optional[torch.FloatTensor] = None,
encoder\_hidden\_states: Optional[torch.Tensor] = None,
encoder\_attention\_mask: Optional[torch.FloatTensor] = None,
use\_cache: Optional[bool] = False,
output\_attentions: Optional[bool] = False,
)
-> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if
encoder\_hidden\_states
is
not
None
:
if
not
hasattr(self,
"q\_attn"
):
raise
ValueError(
"If class is used as cross attention, the weights `q\_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2Attention(..., is\_cross\_attention=True)`."
)
query = self.q\_attn(hidden\_states)
key, value = self.c\_attn(encoder\_hidden\_states).split(self.split\_size, dim=
2
)
attention\_mask = encoder\_attention\_mask
else
:
query, key, value = self.c\_attn(hidden\_states).split(self.split\_size, dim=
2
)
query = self.\_split\_heads(query, self.num\_heads, self.head\_dim)
key = self.\_split\_heads(key, self.num\_heads, self.head\_dim)
value = self.\_split\_heads(value, self.num\_heads, self.head\_dim)
# 过去所存的值
if
layer\_past
is
not
None
:
past\_key, past\_value = layer\_past
key = torch.cat((past\_key, key), dim=
-2
)
# 把当前新的key加入
value = torch.cat((past\_value, value), dim=
-2
)
# 把当前新的value加入
if
use\_cache
is
True
:
present = (key, value)
# 输出用于保存
else
:
present =
None
if
self.reorder\_and\_upcast\_attn:
attn\_output, attn\_weights = self.\_upcast\_and\_reordered\_attn(query, key, value, attention\_mask, head\_mask)
else
:
attn\_output, attn\_weights = self.\_attn(query, key, value, attention\_mask, head\_mask)
attn\_output = self.\_merge\_heads(attn\_output, self.num\_heads, self.head\_dim)
attn\_output = self.c\_proj(attn\_output)
attn\_output = self.resid\_dropout(attn\_output)
outputs = (attn\_output, present)
if
output\_attentions:
outputs += (attn\_weights,)
return
outputs
# a, present, (attentions)
总的来说,KV Cache是以空间换时间的做法,通过使用快速的缓存存储,减少了重复计算。(注意,只能在decoder结构的模型可用,因为有mask attention的存在,使得前面的token可以不用关照后面的token)
但是,用了KV Cache之后也不是立刻万事大吉。
我们简单计算一下,对于输入长度为
,层数为
,hidden size为
的模型,需要缓存的参数量为
如果使用的是半精度浮点数,那么每个值所需要的空间就是
以Llama2 7B为例,有
,
,那么每个token所需的缓存空间就是524,288 bytes,约524k,假设
,则需要占用536,870,912 bytes,超过500M的空间。
这些参数的大小是batch size=1的情况,如果batch size增大,这个值是很容易就超过1G。
注意力相关模型
笔者之前回顾了一些注意力机制相关模型: 从MHA、MQA、GQA、MLA到NSA、MoBA ,写了下面一篇文章来解读这些注意力模型,
注意力机制进化史:从MHA到MoBA,新一代注意力机制的极限突破!
这里为了节约篇幅,只提一下模型简介,具体原理可以阅读上面综述。
MHA:Multi-Head Attention
论文标题:Attention Is All You Need 论文链接: https://arxiv.org/pdf/1706.03762
MHA在2017年就随着《Attention Is All You Need》一起提出,主要干的就是一个事:把原来一个attention计算,拆成多个小份的attention,并行计算,分别得出结果,最后再合回原来的维度。
MQA:Multi-Query Attention
论文标题:Fast Transformer Decoding: One Write-Head is All You Need 论文链接: https://arxiv.org/pdf/1911.02150
MQA就是减少所有所需要的键值缓存内存消耗的。
Google在2019年就提出了《Fast Transformer Decoding: One Write-Head is All You Need》提出了MQA,不过那时候主要是针对的人不多,那是大家主要还是关注在用Bert也开始创新上。
MQA的做法其实很简单。在MHA中,输入分别经过
的变换之后,都切成
份(
=头数),维度也从
降到
,分别进行attention计算再拼接。而MQA这一步,在运算过程中,首先对
进行切分(和MHA一样),而
则直接在在线变换的时候把维度压到
(而不是切分开),然后返回每个Query头分别和一份
进行attention计算,之后最终结果拼接起来。
简而言之,就是MHA中,每个注意力头的
是不一样的,而MQA这里,每个注意力头的
是一样的,值是共享的。而性别效果和MHA一样。
GQA:Grouped Query Attention
论文标题:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints 论文链接: https://arxiv.org/pdf/2305.13245
既然MQA对效果有点影响,MHA存储又有不下,那2023年GQA(Grouped-Query Attention)就提出了一个折中的办法,既能减少MQA效果的损失,又相比MHA需要更少的存储。
GQA是,
还是按原来MHA/MQA的做法不变。只使用一套共享的
就能效果不好吗,那就还是多个头。但是要不要太多,数量还是比
的头数少一些,这样相当于把多个头分成group,同一个group内的
共享,同不group的
所用的
不同。
MHA可以认为是
头数最大时的GQA(有几个Q就有几个K,V),而MQA可以认为是
头数最少时的GQA(K的头数只有1个)。
MLA:Multi-head Latent Attention
论文标题:DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model 论文链接: https://arxiv.org/abs/2405.04434
MLA(Memory-efficient Latent Attention) 的核心思想是将注意力输入
压缩到一个低维的潜在向量,记作
,其维度
远小于原始的
维度。这样,在计算注意力时,我们可以通过映射将该潜在向量恢复到高维空间,以重构键(keys)和值(values)。这种方法的优势在于,只需存储低维的潜在向量,从而大幅减少内存占用。
这一过程可以用以下公式描述:
- 是低维的潜在向量。
- 是一个压缩矩阵(down-projection matrix),用于将
的维度从
降维到
(其中 D 代表“降维”)。
- 和
是两个向上投影矩阵(up-projection matrices),分别用于将共享的潜在向量映射回高维空间,以恢复键(K)和值(V)。
多头注意力(MHA)、分组查询注意力(GQA)、多查询注意力(MQA)和多头潜在注意力(MLA)的简化示意图。通过将键(keys)和值(values)联合压缩到一个潜在向量中,MLA在推理过程中显著减少了KV缓存的大小。
NSA:Native Sparse Attention
Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention
NSA的技术方法涵盖算法设计与内核优化。其整体框架基于对注意力机制的重新定义,通过设计不同的映射策略构建更紧凑、信息更密集的键值对表示,以减少计算量。同时,针对硬件特性进行内核优化,提升实际运行效率。
MoBA: Mixture of Block Attention for Long-Context LLMs
论文标题:Mixture of Block Attention for Long-Context LLMs 论文地址: https://github.com/MoonshotAI/MoBA/blob/master/MoBA\_Tech\_Report.pdf
它基于MoE原理,应用于Transformer模型的注意力机制,通过将上下文划分为块,并采用门控机制选择性地将查询令牌路由到最相关的块,提高LLMs效率,使模型能处理更长更复杂的提示,同时降低资源消耗。
Page Attention
论文标题:Efficient Memory Management for Large Language Model Serving with PagedAttention 论文地址: https://arxiv.org/pdf/2309.06180
在 vLLM 库中 LLM 服务的性能受到内存瓶颈的影响。在自回归 decoder 中,所有输入到 LLM 的 token 会产生注意力 key 和 value 的张量,这些张量保存在 GPU 显存中以生成下一个 token。这些缓存 key 和 value 的张量通常被称为 KV cache,其具有以下特点:
- 显存占用大:在 LLaMA-13B 中,缓存单个序列最多需要 1.7GB 显存;
- 动态变化:KV 缓存的大小取决于序列长度,这是高度可变和不可预测的。因此,这对有效管理 KV cache 挑战较大。该研究发现,由于碎片化和过度保留,现有系统浪费了 60% - 80% 的显存。
为了解决这个问题,该研究引入了 PagedAttention,这是一种受操作系统中虚拟内存和分页经典思想启发的注意力算法。与传统的注意力算法不同,PagedAttention 允许在非连续的内存空间中存储连续的 key 和 value 。具体来说,PagedAttention 将每个序列的 KV cache 划分为块,每个块包含固定数量 token 的键和值。在注意力计算期间,PagedAttention 内核可以有效地识别和获取这些块。
具体来讲,Paged Attention 将每个序列的 KV Cache 分成若干块,每个块包含固定数量token 的键和值。在注意力计算期间,PagedAttention 内核可以有效地识别和获取这些块。因为块在内存中不需要连续,因而可以用一种更加灵活的方式管理 key 和 value ,就像在操作系统的虚拟内存中一样:可以将块视为页面,将 token 视为字节,将序列视为进程。序列的连续逻辑块通过块表映射到非连续物理块中。物理块在生成新 token 时按需分配。在 PagedAttention 中,内存浪费只会发生在序列的最后一个块中。这使得在实践中可以实现接近最佳的内存使用,仅浪费不到 4%。
FlashAttention
论文标题:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 论文地址: https://arxiv.org/pdf/2205.14135
此前针对Transformer类模型的优化工作致力于减少 FLOP 从而优化计算速度,忽略了显存IO访问上的优化。
attention计算所带来的O(n^2)复杂度的矩阵对HBM的重复读写是制约模型推理的一个主要瓶颈。
如下图所示GPU SRAM的读写(I/O)的速度为19TB/s 和GPU HBM 的读写(I/O)速度 1.5TB/s 相差十几倍,而对比存储容量也相差了好几个数量级。从图中可以看出Flash Attention的计算是从HBM中读取块,在SRAM中计算之后再写到HBM中。
要解决这些问题,需要做两件主要的事情:
- 在不访问整个输入的情况下计算 softmax
- 不为反向传播存储大的中间 attention 矩阵
为此 FlashAttention 提出了两种方法来分布解决上述问题:tiling 和 recomputation。
- tiling - 注意力计算被重新构造,将输入分割成块,并通过在输入块上进行多次传递来递增地执行softmax操作。
- recomputation - 存储来自前向的 softmax 归一化因子,以便在反向中快速重新计算芯片上的 attention,这比从HBM读取中间矩阵的标准注意力方法更快。
由于重新计算,这虽然导致FLOPS增加,但是由于大量减少HBM访问,FlashAttention运行速度更快。 该算法背后的主要思想是分割输入,将它们从慢速HBM加载到快速SRAM,然后计算这些块的 attention 输出。在将每个块的输出相加之前,将其按正确的归一化因子进行缩放,从而得到正确的结果。
参考资料
-
LLM(十七):从 FlashAttention 到 PagedAttention, 如何进一步优化 Attention 性能
-
LLM推理算法简述