本文将对比使用MHA (Multi-Head Attention)、MQA (Multi-Query Attention)、GQA (Grouped-Query Attention)和MLA (Multi-Head Latent Attention)这4种注意力机制时,在decoder阶段使用KV cache生成单个token所需的额外缓存空间。#AI入门 #从0到1 #FromScratch #注意力机制
更多AI相关欢迎关注微信公众号"小窗幽记机器学习":
在具有L层、
个注意头和key维度
(即单个head的维度)的Transformer模型中,decoder阶段使用KV cache的话,生成一个token需要多少KV cache空间?
模型中有如下关系: d_model=hidden_size= embed_dim=d_h * n_h
MHA、MQA、GQA和MLA这几种都是Transformer架构中注意力机制的不同变体,主要区别在于如何处理键值对。
MHA (Multi-Head Attention):
标准的多头注意力机制,每个注意力头都有独立的查询(Q)、键(K)、值(V)矩阵。计算复杂度高但表达能力强,是原始Transformer使用的方法。
MQA (Multi-Query Attention):
多个查询头共享同一组键值对,即只有一个K和V矩阵,但有多个Q矩阵。这大幅减少了KV缓存的内存占用,提高了推理速度,但可能会损失一些表达能力。
GQA (Grouped-Query Attention):
MHA和MQA的折中方案,将查询头分成若干组,每组内的头共享同一组键值对。比如8个查询头可以分成2组,每组4个头共享KV。在保持较好性能的同时减少内存使用。
MLA (Multi-Head Latent Attention):
DeepSeek V2中引入的注意力机制,通过引入潜在空间来进一步优化计算效率。将高维的键值投影到低维潜在空间进行计算,然后再投影回原空间,在保持性能的同时显著降低计算和存储开销。
这些变体的发展趋势是在保持模型性能的前提下,不断优化计算效率和内存使用,特别是在大模型推理场景中越来越重要。
KV cache的缓存空间计算如下:
对于每个注意力头,需要存储:
- Key向量:
维度
- Value向量:
维度
因此每个头需要存储
个数值。
对于整个模型:
- 每层有
个注意力头
- 总共有L层
- 每个位置需要存储:L ×
× 2 ×
个数值
如果序列长度为n,那么KV cache的总存储空间为:
以浮点数精度计算存储大小:
- 如果使用FP32(4字节):
字节
- 如果使用FP16(2字节):
字节
这个缓存空间会随着序列长度n线性增长,这也是为什么长序列推理时内存消耗会快速增加的原因。在实际部署中,KV cache往往是推理时的内存瓶颈之一。
至于,MQA、GQA和MLA以下直接进行换算。
以DeepSeek R1/V3 为例,L=61,
=128,
=128,fp16,隐层维度
。那么以DeepSeek V3的模型规模为例,生成1个token额外需要多少KV cache空间?
MHA:
2*61*128*2*128=3997696 字节(Bytes)
1 MB=1024*1024字节=1048576 字节
3997696 字节 / 1048576 字节=3.8125(MB)
因此,对于DeepSeek V3模型,如果使用MHA,则每生成1个token需要3.8MB (约4MB)的存储空间。
小伙伴们可以自行估算,如果n=1000,那么此时再生成1个token,需要额外多少存储空间?
MQA:
KV cache数据量=
; 存储空间(fp16)=
=2261*128/1024=30.5KB
GQA:
假设head分成
组(也可以记为
),那么每组:
个head共享一个KV。当
时候,GQA等价于MQA;当
时,GQA等于MHA。
因此:
KV cache数据量=
;
假设每组8个head,那么
个head。
存储空间(fp16)=
MLA
KV cache数据量:
存储空间(fp16)=
| 注意力机制 | 每个token需要的cache的数据量 | 每个token需要的KV cache对应的存储空间 | 补充说明 | | --- | --- | --- | --- | | Multi-Head Attention(MHA) |
| 3.8MB |
| | Multi-Query Attention(MQA) |
| 30.5 KB | 变为MHA的
| | Grouped-Query Attention(GQA) |
| 488 KB | 相对比MHA,降低了8倍 | | Multi-Head Latent Attention(MLA) |
| 67.55KB | 相比于MHA,降低了57.6 倍 |
更多AI相关欢迎关注微信公众号"小窗幽记机器学习"