【LLM】MInference:浅尝通过动态稀疏注意力加速长上下文LLM推理框架

前言

长上下文 LLM 推理面临两大挑战:1)预填充阶段注意力延迟较长,2) KV 缓存存储和传输成本较高。之前有效的长上下文 LLM 方法主要集中在 KV 缓存压缩、静态稀疏注意力(例如模型压缩、SSM、线性注意力)或分布式服务。然而,这些方法很难以低成本和单个 A100 GPU 实现百万级token提示的可接受延迟。

为了解决这些问题,提出了 MInference,在单台 A100 提速10倍。MInference 是一种基于动态稀疏注意的长上下文 LLM 预填充阶段的无需训练的有效方法。

picture.image

MInference 1.0

1、Problem Formulation

在利用稀疏注意力计算来加速长上下文LLMs的预填充阶段时,首先需要对注意力矩阵进行定义。如下公式:

picture.image

  • :注意力矩阵,其中 是一个动态稀疏掩码(mask),其元素 是二进制值(0或1),用于指示注意力矩阵中哪些项是重要的。
  • 和 :分别是查询(query)和键(key)矩阵,它们的维度是 ,其中 是序列长度, 是头(head)的维度。
  • :是缩放因子,通常取 的平方根,用于平衡注意力分数。
  • :是一个常数(例如 ),用于确保当 时,通过函数处理后的 接近于零。

动态稀疏注意力 的目标是在保持尽可能多的注意力权重的同时,实现更大的加速并最小化开销。这可以形式化为以下两个最小化问题:

  1. 尽量减少稀疏注意力矩阵 和密集注意力矩阵 之间的差异,以保持模型性能。
  2. 尽量减少稀疏注意力计算的时间 以及估计近似动态稀疏模式的开销 。

用公式表示为:

其中, 和 分别代表动态稀疏注意力计算和估计近似动态稀疏模式所需的时间。

2、通过动态稀疏注意力加速长上下文LLM推理

2.1 Kernel-Aware Optimal Sparse Pattern Search

为了在有限的浮点运算次数(FLOPs)预算下实现最佳的精度,作者提出了一种Kernel-Aware Optimal Sparse Pattern Search(A型) 方法。该方法确定每个注意力头(attention head)使用的稀疏模式,以及实际计算中模式的最优设置。

picture.image Kernel-Aware(A型)稀疏模式搜索

首先,基于每个模式的目标FLOPs创建搜索空间,确保所有潜在候选者具有相似的计算成本。"Kernel-aware" 表示计算成本反映了GPU内核中的实际FLOPs,这对于实现最佳加速至关重要。

接下来,使用一个参考示例遍历搜索空间,以确定最优模式和设置。使用注意力输出的召回率 作为搜索最佳模式时的目标标准。

2.2 Sparsity Indices Approximation and Dynamic Sparse Attention Calculation

在推理阶段,对注意力矩阵进行在线估计,以动态确定稀疏索引的空间分布,然后使用优化的GPU内核进行稀疏注意力计算。

picture.image 如何分别针对垂直-斜杠头和块稀疏头构建稀疏索引并计算最终的动态稀疏注意力得分

  1. Vertical-Slash Head(垂直斜线)
  • 利用垂直和斜杠线的连续性,通过矩阵乘法生成估计的注意力矩阵 。
  • 确定垂直线 和斜杠线 的索引,并将它们转换为稀疏格式 。
  • 使用稀疏索引执行注意力权重和输出的块稀疏计算。
  • Block-Sparse Head(块稀疏)
  • 对 和 应用均值池化,获得 和 。
  • 计算估计的块级注意力权重 。
  • 构建稀疏索引 ,并用它来计算稀疏注意力权重和输出。

小结:通过上述方法,MInference 技术能够在不牺牲准确性的前提下,能够显著减少长上下文LLMs推理的延迟,特别是在预填充 阶段。

插拔使用

目前支持的中文模型有GLM-4和Qwen2 ,下面看看插拔使用MInference。

安装:


        
          
pip install minference  

      
  • for HF:

        
          
from transformers import pipeline  
+from minference import MInference  
  
pipe = pipeline("text-generation", model=model_name, torch_dtype="auto", device_map="auto")  
  
# Patch MInference Module  
+minference_patch = MInference("minference", model_name)  
+pipe.model = minference_patch(pipe.model)  
  
pipe(prompt, max_length=10)  

      
  • for vLLM(支持vllm==0.4.x):

        
          
from vllm import LLM, SamplingParams  
+ from minference import MInference  
  
llm = LLM(model_name, max_num_seqs=1, enforce_eager=True, max_model_len=128000)  
  
# Patch MInference Module  
+minference_patch = MInference("vllm", model_name)  
+llm = minference_patch(llm)  
  
outputs = llm.generate(prompts, sampling_params)  

      
  • using only the kernel:

        
          
from minference import vertical_slash_sparse_attention, block_sparse_attention, streaming_forward  
  
attn_output = vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)  
attn_output = block_sparse_attention(q, k, v, topk)  
attn_output = streaming_forward(q, k, v, init_num, local_window_num)  

      

参考文献

0
0
0
0
评论
未登录
暂无评论