https://arxiv.org/pdf/2312.04985.pdf
SparQ Attention通过选择性地获取缓存历史来减少注意力块内的内存带宽需求。这种方法可以直接应用于现有的LLM推理过程,无需修改预训练设置或进行额外的微调。作者通过评估Llama 2和Pythia模型在多种下游任务上的表现,展示了SparQ Attention如何在不损失准确性的情况下,将注意力内存带宽需求减少多达8倍。
算法原理:SparQ Attention基于以下两个观察结果:
- 注意力机制中的softmax函数输出主要由少数分量主导,大部分分量接近于0。因此,只需从内存中获取具有最高注意力分数的tokens的key和value对,就可以在不影响任务性能的情况下显著减少内存传输。
- 通过稀疏化查询向量q,仅保留r个最大幅度分量,可以有效预测具有最大注意力分数的索引,而无需获取完整的K矩阵。
基于这两个观察结果,SparQ Attention算法包括以下三个步骤:
- 找到输入查询向量q中r个最大分量的索引,并仅沿着这些索引获取key缓存K。使用切片查询和键来计算近似注意力分数。
- 在近似注意力分数中找到前k个位置,并获取相应的完整key和value向量。使用前k个键和值计算注意力块的输出。
- 使用近似注意力分数估计分配给前k个位置的总分数α。根据近似分数权重,使用该总分数在前k个位置的注意力输出和平均值向量之间进行插值。
from torch import abs, softmax, sqrt, tensor, topk
def gather(t, dim, i):
dim += (dim < 0) * t.ndim
return t.gather(dim, i.expand(*t.shape[:dim], i.shape[dim], *t.shape[dim + 1 :]))
def attn(Q, K, V, M):
s = (Q @ K.transpose(-1, -2)) / sqrt(tensor(Q.shape[-1])) + M
return softmax(s, dim=-1) @ V
def sparq\_attn(Q, K, V, V\_mean, M, r, k):
# 1. Approximate attention scores using r largest components of Q
i1 = topk(abs(Q), r, -1).indices
Q_hat, K_hat = gather(Q, -1, i1), gather(K, -1, i1)
scale = sqrt(
Q.shape[-1]
* abs(Q_hat).sum(dim=-1, keepdim=True)
/ abs(Q).sum(dim=-1, keepdim=True)
)
s_hat = softmax(Q_hat @ K_hat.transpose(-1, -2) / scale + M, dim=-1)
# 2. Gather top k positions based on approximate attention scores & run attention
i2 = topk(s_hat, k, -1).indices
iKV = i2[..., 0, :, None]
K, V, M = gather(K, -2, iKV), gather(V, -2, iKV), gather(M, -1, i2)
y_ = attn(Q, K, V, M)
# 3. Estimate the total score of the top k, and interpolate with V\_mean
alpha = gather(s_hat, -1, i2).sum(-1, keepdim=True)
return alpha * y_ + (1 - alpha) * V_mean
结论:
SparQ Attention在各种任务和模型大小上表现出色,可实现2倍至8倍的压缩比,同时几乎不损失任务性能。此外,通过调整参数r和k,可以在近似精度和推理速度之间进行权衡。总体而言,SparQ Attention为提高LLM推理效率提供了一种有效且实用的方法。