DeepSeek开源周 Day01:从FlashMLA背后原理回顾KV Cache

大模型向量数据库机器学习

FlashMLA

今天DeepSeek开源周第一天,开放了FlashMLA仓库, 1小时内星标2.7k!

picture.image

FlashMLA 是一个高效的 MLA 解码内核,专为 Hopper GPU 优化,适用于可变长度序列。该项目目前发布了 BF16 和具有 64 块大小分页 kvcache 的功能。在 H800 SXM5 上,使用 CUDA 12.6,内存受限配置下可达 3000 GB/s,计算受限配置下可达 580 TFLOPS。

Github仓库地址:https://github.com/deepseek-ai/FlashMLA

这里提到两个比较关键的功能就是 BF16精度计算以及Paged kvcache缓存技术

好巧不巧,近期DeepSeek 发布了一篇新论文, 提出了一种改进版的注意力机制 NSA,即Native Sparse Attention,可以直译为「原生稀疏注意力」 ;但其实就在同一天,月之暗面也发布了一篇主题类似的论文, 提出了一种名为 MoBA 的注意力机制,即 Mixture of Block Attention,可以直译为「块注意力混合」 。注意机制最近这么火爆的背景下,不妨我们趁机复习下Kv Cache相关概念以及相关注意力机制模型。

KV Cache简介

这部分主要参考 LLM推理算法简述 ,可以快速回顾下KV Cache概念,关于更多LLM推理算法讲解大家可以阅读 https://zhuanlan.zhihu.com/p/685794495

LLM 推理服务的吞吐量指标主要受制于显存限制。研究团队发现现有系统由于缺乏精细的显存管理方法而浪费了 60% 至 80% 的显存,浪费的显存主要来自 KV Cache。因此,有效管理 KV Cache 是一个重大挑战。

什么是KV Cache?

Transformer 模型具有自回归推理的特点,即每次推理只会预测输出一个 token,当前轮输出 token 与历史输入 tokens 拼接,作为下一轮的输入 tokens,反复执行多次。该过程中,前后两轮的输入只相差一个 token,存在重复计算。KV Cache 技术实现了将可复用的键值向量结果保存下来,从而避免了重复计算。如下图所示,展示了有无 kv cache 的流程:

  • Without KV Cache 每次需要计算全 Wq(X), Wk(X), Wv(X),每次需要计算全量 Attn
  • With KV Cache ,第一步计算完整

,将

保存成

  • With KV Cache ,第二步,取第二步的 Next Token 计算

,将

联合,计算出

picture.image

图片来源:陆淳,https://zhuanlan.zhihu.com/p/685794495

如何优化KV cahce?

KV cache的峰值显存占用大小计算公式: 2 x Length x batch_size x [d x n_kv_heads] x Layers x k-bits ,由此我们可以看出影响KV cache的具体因素:

  • k-bits: 数据类型,FP16 占2个bytes。(量化)
  • 2: 代表 Key/Value 两个向量现
  • Length: 输入输出序列的长度(循环队列管理窗口KV,减少长度kv)
  • Layers:模型层数
  • d x n_kv_heads:kv维度(MQA/GQA通过减少KV的头数减少显存占用)
  • batch_size : KV Cache 与 batchsize 度呈线性关系,随着 batch size 的增大,KV cache 占用的显存开销快速增大,甚至会超过模型本身。
  • 操作系统管理:现GPU的KV Cache的有效存储率较低低 (page-attention)

picture.image

在bf16格式下的13B模型中,我们只有大约10G的空间来存储kv cache。

KV Cache 的引入也使得推理过程分为如下两个不同阶段,进而影响到后续的其他 优化方法

  • 预填充阶段 (Prefill):发生在计算第一个输出 token 过程中,计算时需要为每个 Transformer layer 计算并保存 key cache 和 value cache;FLOPs 同 KV Cache 一致,存在大量 GEMM (GEneral Matrix-Matrix multiply) 操作,属于 Compute-bound 类型计算。
  • 解码阶段 (Decoder):发生在计算第二个输出 token 至最后一个 token 过程中,这时 KV Cache 已存有历史键值结果,每轮推理只需读取 KV Cache,同时将当前轮计算出的新 Key、Value 追加写入至 Cache;GEMM 变为 GEMV (GEneral Matrix-Vector multiply) 操作,FLOPs 降低,推理速度相对预填充阶段变快,这时属于 Memory-bound 类型计算。

解码中的KV Cache

我们下面用一个例子更加详细的解释什么是KV Cache,了解一些背景的计算问题,以及KV Cache的概念。

无论是encoder-decoder结构,还是现在我们最接近AGI的decoder-only的LLM,解码生成时都是自回归auto-regressive的方式。也就是说,解码的时候,先根据当前输入

,生成下一个token,然后把生成的token拼接在

后面,获得新的输入

,再用

生成

,依此选择,直到生成结果。

比如我们输入“窗前明月光下一句是”,那么模型每次生成一个token,输入输出会是这样(方便起见,默认每个token都是一个字符)


            
 
   

 
 
            
                

              step0: 输入=[BOS]窗前明月光下一句是;输出=疑
              
   

 
              step1: 输入=[BOS]窗前明月光下一句是疑;输出=是
              
   

 
              step2: 输入=[BOS]窗前明月光下一句是疑是;输出=地
              
   

 
              step3: 输入=[BOS]窗前明月光下一句是疑是地;输出=上
              
   

 
              step4: 输入=[BOS]窗前明月光下一句是疑是地上;输出=霜
              
   

 
              step5: 输入=[BOS]窗前明月光下一句是疑是地上霜;输出=[EOS]
              
   

 
            
          

(其中[BOS]和[EOS]分别是开始和结束的标记字符)

我们看一下在计算的过程中,如何输入的token “是” 的最后是hidden state如何传递到后面的类Token预测模型,以及后面每一个token,使用新的输入列中最后一个时刻的输出。

我们可以看到,在每一个step的计算中,主要包含了上一轮step的内容,而且只在最后一步使用(一个token)。那么每一个计算也就包含了上一轮step的计算内容。

从公式来看是这样的,回想一下我们attention的计算:

注意对于decoder的时候,由于mask attention的存在,每个输入只能看到自己和前面的内容,而看不到后面的内容。

假设我们当前输入的长度是3,预测第4个字,那么每层attention所做的计算有:

预测完第4个字,放到输入里,继续预测第5个字,每层attention所做的计算有:

可以看到,在预测第5个字时,只有最后一步引入了新的计算,而

的计算部分是完全重复的。

但是模型在推理的时候可不管这些,无论你是否只是要最后一个字的输出,它都会把所有输入计算一遍,给出所有输出结果。

也就是说中间有很多我们不需要的计算,这样就造成了浪费。

而且随着生成的结果越来越多,输入的长度也越来越长,上面这个例子里,输入长度是step0的10个, 每步骤,直接step5到15个。如果输入的instruction是规范型任务,那么可能有800个step。这个情况下,step0就变得有800次,step1被重复了799次——这样浪费的计算资源显然不可忍受。

有没有什么方法可以重利用上一个step里已经计算过的结果,减少浪费呢?

答案就是KV Cache,利用一个缓存,把需要重复利用的时序计算结果保存下来,减少重复计算。

就是需要保存的对象。

想一想,下图就是缓存的过程,假设我们第一次输入的输入长度是3个,我们第一次预测输出预测第4个字,那么由于下图给你看的是每个输入步骤的缓存,每个时序步骤都需要存储一次,而我们依旧会有些重复计算的情况。则有:

picture.image

kv_cache下标l表示模型层数。在进行第二次预测时,也就是预测第5个字的时候,在第l层的时候,由于前面我们缓存了每层的

,

值,那层就不需要算新的

,而不再算

,

。因为第l层的

,

本来经过FFN层之后进到

层,再经过新的投影变换,成为

层的

,

值,但是是

层的

,

值就已经保留了!

然后我们把本次新算出来的

,

值也存储起来。

然后我们再做下一次计算出的结果:

这样就节省了attention和FFN的很多重复计算。

transformers中,生成的时候传入use_cache=True就会开启KV Cache。

也可以简单看下GPT2中的实现,中文注释的部分就是使用缓存结果和更新缓存结果


            
 
   

 
 
            
                

              Class GPT2Attention(nn.Module):
              
   

 
                  ...
              
   

 
                  ...
              
   

 
                  
              
 
 def
 
  
 
 forward
 
 
 (
 
   

 
         self,
 
   

 
         hidden\_states: Optional[Tuple[torch.FloatTensor]],
 
   

 
         layer\_past: Optional[Tuple[torch.Tensor]] = None,
 
   

 
         attention\_mask: Optional[torch.FloatTensor] = None,
 
   

 
         head\_mask: Optional[torch.FloatTensor] = None,
 
   

 
         encoder\_hidden\_states: Optional[torch.Tensor] = None,
 
   

 
         encoder\_attention\_mask: Optional[torch.FloatTensor] = None,
 
   

 
         use\_cache: Optional[bool] = False,
 
   

 
         output\_attentions: Optional[bool] = False,
 
   

 
     )
 
  -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
 
              
   

 
                      
              
 if
 
               encoder\_hidden\_states 
              
 is
 
              
 not
 
              
 None
 
              :
              
   

 
                          
              
 if
 
              
 not
 
               hasattr(self, 
              
 "q\_attn"
 
              ):
              
   

 
                              
              
 raise
 
               ValueError(
              
   

 
                                  
              
 "If class is used as cross attention, the weights `q\_attn` have to be defined. "
 
              
   

 
                                  
              
 "Please make sure to instantiate class with `GPT2Attention(..., is\_cross\_attention=True)`."
 
              
   

 
                              )
              
   

 
              
   

 
                          query = self.q\_attn(hidden\_states)
              
   

 
                          key, value = self.c\_attn(encoder\_hidden\_states).split(self.split\_size, dim=
              
 2
 
              )
              
   

 
                          attention\_mask = encoder\_attention\_mask
              
   

 
                      
              
 else
 
              :
              
   

 
                          query, key, value = self.c\_attn(hidden\_states).split(self.split\_size, dim=
              
 2
 
              )
              
   

 
              
   

 
                      query = self.\_split\_heads(query, self.num\_heads, self.head\_dim)
              
   

 
                      key = self.\_split\_heads(key, self.num\_heads, self.head\_dim)
              
   

 
                      value = self.\_split\_heads(value, self.num\_heads, self.head\_dim)
              
   

 
              
   

 
                      
              
 # 过去所存的值
 
              
   

 
                      
              
 if
 
               layer\_past 
              
 is
 
              
 not
 
              
 None
 
              :
              
   

 
                          past\_key, past\_value = layer\_past
              
   

 
                          key = torch.cat((past\_key, key), dim=
              
 -2
 
              )  
              
 # 把当前新的key加入
 
              
   

 
                          value = torch.cat((past\_value, value), dim=
              
 -2
 
              )  
              
 # 把当前新的value加入
 
              
   

 
              
   

 
                      
              
 if
 
               use\_cache 
              
 is
 
              
 True
 
              :
              
   

 
                          present = (key, value)  
              
 # 输出用于保存
 
              
   

 
                      
              
 else
 
              :
              
   

 
                          present = 
              
 None
 
              
   

 
              
   

 
                      
              
 if
 
               self.reorder\_and\_upcast\_attn:
              
   

 
                          attn\_output, attn\_weights = self.\_upcast\_and\_reordered\_attn(query, key, value, attention\_mask, head\_mask)
              
   

 
                      
              
 else
 
              :
              
   

 
                          attn\_output, attn\_weights = self.\_attn(query, key, value, attention\_mask, head\_mask)
              
   

 
              
   

 
                      attn\_output = self.\_merge\_heads(attn\_output, self.num\_heads, self.head\_dim)
              
   

 
                      attn\_output = self.c\_proj(attn\_output)
              
   

 
                      attn\_output = self.resid\_dropout(attn\_output)
              
   

 
              
   

 
                      outputs = (attn\_output, present)
              
   

 
                      
              
 if
 
               output\_attentions:
              
   

 
                          outputs += (attn\_weights,)
              
   

 
              
   

 
                      
              
 return
 
               outputs  
              
 # a, present, (attentions)
 
              
   

 
            
          

总的来说,KV Cache是以空间换时间的做法,通过使用快速的缓存存储,减少了重复计算。(注意,只能在decoder结构的模型可用,因为有mask attention的存在,使得前面的token可以不用关照后面的token)

但是,用了KV Cache之后也不是立刻万事大吉。

我们简单计算一下,对于输入长度为

,层数为

,hidden size为

的模型,需要缓存的参数量为

如果使用的是半精度浮点数,那么每个值所需要的空间就是

以Llama2 7B为例,有

,那么每个token所需的缓存空间就是524,288 bytes,约524k,假设

,则需要占用536,870,912 bytes,超过500M的空间。

这些参数的大小是batch size=1的情况,如果batch size增大,这个值是很容易就超过1G。

注意力相关模型

笔者之前回顾了一些注意力机制相关模型: 从MHA、MQA、GQA、MLA到NSA、MoBA ,写了下面一篇文章来解读这些注意力模型,

注意力机制进化史:从MHA到MoBA,新一代注意力机制的极限突破!

这里为了节约篇幅,只提一下模型简介,具体原理可以阅读上面综述。

MHA:Multi-Head Attention

论文标题:Attention Is All You Need 论文链接: https://arxiv.org/pdf/1706.03762

picture.image

MHA在2017年就随着《Attention Is All You Need》一起提出,主要干的就是一个事:把原来一个attention计算,拆成多个小份的attention,并行计算,分别得出结果,最后再合回原来的维度。picture.image

MQA:Multi-Query Attention

论文标题:Fast Transformer Decoding: One Write-Head is All You Need 论文链接: https://arxiv.org/pdf/1911.02150

MQA就是减少所有所需要的键值缓存内存消耗的。

Google在2019年就提出了《Fast Transformer Decoding: One Write-Head is All You Need》提出了MQA,不过那时候主要是针对的人不多,那是大家主要还是关注在用Bert也开始创新上。

picture.image

MQA的做法其实很简单。在MHA中,输入分别经过

的变换之后,都切成

份(

=头数),维度也从

降到

,分别进行attention计算再拼接。而MQA这一步,在运算过程中,首先对

进行切分(和MHA一样),而

则直接在在线变换的时候把维度压到

(而不是切分开),然后返回每个Query头分别和一份

进行attention计算,之后最终结果拼接起来。

简而言之,就是MHA中,每个注意力头的

是不一样的,而MQA这里,每个注意力头的

是一样的,值是共享的。而性别效果和MHA一样。

GQA:Grouped Query Attention

论文标题:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints 论文链接: https://arxiv.org/pdf/2305.13245

既然MQA对效果有点影响,MHA存储又有不下,那2023年GQA(Grouped-Query Attention)就提出了一个折中的办法,既能减少MQA效果的损失,又相比MHA需要更少的存储。

picture.image

GQA是,

还是按原来MHA/MQA的做法不变。只使用一套共享的

就能效果不好吗,那就还是多个头。但是要不要太多,数量还是比

的头数少一些,这样相当于把多个头分成group,同一个group内的

共享,同不group的

所用的

不同。

MHA可以认为是

头数最大时的GQA(有几个Q就有几个K,V),而MQA可以认为是

头数最少时的GQA(K的头数只有1个)。

MLA:Multi-head Latent Attention

论文标题:DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model 论文链接: https://arxiv.org/abs/2405.04434

MLA(Memory-efficient Latent Attention) 的核心思想是将注意力输入

压缩到一个低维的潜在向量,记作

,其维度

远小于原始的

维度。这样,在计算注意力时,我们可以通过映射将该潜在向量恢复到高维空间,以重构键(keys)和值(values)。这种方法的优势在于,只需存储低维的潜在向量,从而大幅减少内存占用。picture.image

这一过程可以用以下公式描述:

  • 是低维的潜在向量。
  • 是一个压缩矩阵(down-projection matrix),用于将

的维度从

降维到

(其中 D 代表“降维”)。

是两个向上投影矩阵(up-projection matrices),分别用于将共享的潜在向量映射回高维空间,以恢复键(K)和值(V)。

picture.image

多头注意力(MHA)、分组查询注意力(GQA)、多查询注意力(MQA)和多头潜在注意力(MLA)的简化示意图。通过将键(keys)和值(values)联合压缩到一个潜在向量中,MLA在推理过程中显著减少了KV缓存的大小。

NSA:Native Sparse Attention

Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention

论文地址: https://arxiv.org/abs/2502.11089

picture.image

NSA的技术方法涵盖算法设计与内核优化。其整体框架基于对注意力机制的重新定义,通过设计不同的映射策略构建更紧凑、信息更密集的键值对表示,以减少计算量。同时,针对硬件特性进行内核优化,提升实际运行效率。

MoBA: Mixture of Block Attention for Long-Context LLMs

论文标题:Mixture of Block Attention for Long-Context LLMs 论文地址: https://github.com/MoonshotAI/MoBA/blob/master/MoBA\_Tech\_Report.pdf

它基于MoE原理,应用于Transformer模型的注意力机制,通过将上下文划分为块,并采用门控机制选择性地将查询令牌路由到最相关的块,提高LLMs效率,使模型能处理更长更复杂的提示,同时降低资源消耗。

picture.image

Page Attention

论文标题:Efficient Memory Management for Large Language Model Serving with PagedAttention 论文地址: https://arxiv.org/pdf/2309.06180

在 vLLM 库中 LLM 服务的性能受到内存瓶颈的影响。在自回归 decoder 中,所有输入到 LLM 的 token 会产生注意力 key 和 value 的张量,这些张量保存在 GPU 显存中以生成下一个 token。这些缓存 key 和 value 的张量通常被称为 KV cache,其具有以下特点:

  • 显存占用大:在 LLaMA-13B 中,缓存单个序列最多需要 1.7GB 显存;
  • 动态变化:KV 缓存的大小取决于序列长度,这是高度可变和不可预测的。因此,这对有效管理 KV cache 挑战较大。该研究发现,由于碎片化和过度保留,现有系统浪费了 60% - 80% 的显存。

为了解决这个问题,该研究引入了 PagedAttention,这是一种受操作系统中虚拟内存和分页经典思想启发的注意力算法。与传统的注意力算法不同,PagedAttention 允许在非连续的内存空间中存储连续的 key 和 value 。具体来说,PagedAttention 将每个序列的 KV cache 划分为块,每个块包含固定数量 token 的键和值。在注意力计算期间,PagedAttention 内核可以有效地识别和获取这些块。

picture.image

https://arxiv.org/pdf/2309.06180

具体来讲,Paged Attention 将每个序列的 KV Cache 分成若干块,每个块包含固定数量token 的键和值。在注意力计算期间,PagedAttention 内核可以有效地识别和获取这些块。因为块在内存中不需要连续,因而可以用一种更加灵活的方式管理 key 和 value ,就像在操作系统的虚拟内存中一样:可以将块视为页面,将 token 视为字节,将序列视为进程。序列的连续逻辑块通过块表映射到非连续物理块中。物理块在生成新 token 时按需分配。在 PagedAttention 中,内存浪费只会发生在序列的最后一个块中。这使得在实践中可以实现接近最佳的内存使用,仅浪费不到 4%。

FlashAttention

论文标题:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 论文地址: https://arxiv.org/pdf/2205.14135

此前针对Transformer类模型的优化工作致力于减少 FLOP 从而优化计算速度,忽略了显存IO访问上的优化。

attention计算所带来的O(n^2)复杂度的矩阵对HBM的重复读写是制约模型推理的一个主要瓶颈。

如下图所示GPU SRAM的读写(I/O)的速度为19TB/s 和GPU HBM 的读写(I/O)速度 1.5TB/s 相差十几倍,而对比存储容量也相差了好几个数量级。从图中可以看出Flash Attention的计算是从HBM中读取块,在SRAM中计算之后再写到HBM中。

picture.image

要解决这些问题,需要做两件主要的事情:

  • 在不访问整个输入的情况下计算 softmax
  • 不为反向传播存储大的中间 attention 矩阵

为此 FlashAttention 提出了两种方法来分布解决上述问题:tiling 和 recomputation。

  • tiling - 注意力计算被重新构造,将输入分割成块,并通过在输入块上进行多次传递来递增地执行softmax操作。
  • recomputation - 存储来自前向的 softmax 归一化因子,以便在反向中快速重新计算芯片上的 attention,这比从HBM读取中间矩阵的标准注意力方法更快。

由于重新计算,这虽然导致FLOPS增加,但是由于大量减少HBM访问,FlashAttention运行速度更快。 该算法背后的主要思想是分割输入,将它们从慢速HBM加载到快速SRAM,然后计算这些块的 attention 输出。在将每个块的输出相加之前,将其按正确的归一化因子进行缩放,从而得到正确的结果。

参考资料

  • LLM(十七):从 FlashAttention 到 PagedAttention, 如何进一步优化 Attention 性能

  • LLM推理算法简述

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
大规模高性能计算集群优化实践
随着机器学习的发展,数据量和训练模型都有越来越大的趋势,这对基础设施有了更高的要求,包括硬件、网络架构等。本次分享主要介绍火山引擎支撑大规模高性能计算集群的架构和优化实践。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论