SparQ Attention:高效大模型推理的新方法 (无须重新训练)

智能内容混合云计算

        
          
https://arxiv.org/pdf/2312.04985.pdf  

      

picture.image

SparQ Attention通过选择性地获取缓存历史来减少注意力块内的内存带宽需求。这种方法可以直接应用于现有的LLM推理过程,无需修改预训练设置或进行额外的微调。作者通过评估Llama 2和Pythia模型在多种下游任务上的表现,展示了SparQ Attention如何在不损失准确性的情况下,将注意力内存带宽需求减少多达8倍。

算法原理:SparQ Attention基于以下两个观察结果:

  1. 注意力机制中的softmax函数输出主要由少数分量主导,大部分分量接近于0。因此,只需从内存中获取具有最高注意力分数的tokens的key和value对,就可以在不影响任务性能的情况下显著减少内存传输。
  2. 通过稀疏化查询向量q,仅保留r个最大幅度分量,可以有效预测具有最大注意力分数的索引,而无需获取完整的K矩阵。

基于这两个观察结果,SparQ Attention算法包括以下三个步骤:

  1. 找到输入查询向量q中r个最大分量的索引,并仅沿着这些索引获取key缓存K。使用切片查询和键来计算近似注意力分数。
  2. 在近似注意力分数中找到前k个位置,并获取相应的完整key和value向量。使用前k个键和值计算注意力块的输出。
  3. 使用近似注意力分数估计分配给前k个位置的总分数α。根据近似分数权重,使用该总分数在前k个位置的注意力输出和平均值向量之间进行插值。

        
          
from torch import abs, softmax, sqrt, tensor, topk  
def gather(t, dim, i):  
  dim += (dim < 0) * t.ndim  
  return t.gather(dim, i.expand(*t.shape[:dim], i.shape[dim], *t.shape[dim + 1 :]))  
def attn(Q, K, V, M):  
  s = (Q @ K.transpose(-1, -2)) / sqrt(tensor(Q.shape[-1])) + M  
  return softmax(s, dim=-1) @ V  
def sparq\_attn(Q, K, V, V\_mean, M, r, k):  
  # 1. Approximate attention scores using r largest components of Q  
  i1 = topk(abs(Q), r, -1).indices  
  Q_hat, K_hat = gather(Q, -1, i1), gather(K, -1, i1)  
  scale = sqrt(  
  Q.shape[-1]  
  * abs(Q_hat).sum(dim=-1, keepdim=True)  
  / abs(Q).sum(dim=-1, keepdim=True)  
  )  
  s_hat = softmax(Q_hat @ K_hat.transpose(-1, -2) / scale + M, dim=-1)  
  # 2. Gather top k positions based on approximate attention scores & run attention  
  i2 = topk(s_hat, k, -1).indices  
  iKV = i2[..., 0, :, None]  
  K, V, M = gather(K, -2, iKV), gather(V, -2, iKV), gather(M, -1, i2)  
  y_ = attn(Q, K, V, M)  
  # 3. Estimate the total score of the top k, and interpolate with V\_mean  
  alpha = gather(s_hat, -1, i2).sum(-1, keepdim=True)  
  return alpha * y_ + (1 - alpha) * V_mean  

      

结论:

SparQ Attention在各种任务和模型大小上表现出色,可实现2倍至8倍的压缩比,同时几乎不损失任务性能。此外,通过调整参数r和k,可以在近似精度和推理速度之间进行权衡。总体而言,SparQ Attention为提高LLM推理效率提供了一种有效且实用的方法。

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
高性能存储虚拟化方案 NVMe over Fabric 在火山引擎的演进
在云计算中,虚拟化存储扮演着重要角色,其中 iSCSI 协议在业界开放、流行多年。近年来,拥有更优性能的 NVMe over Fabrics 协议也得到了发展。本次分享介绍了 NVMe over Fabrics 在云原生和虚拟化方向的演进工作和成果。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论