前言
长上下文 LLM 推理面临两大挑战:1)预填充阶段注意力延迟较长,2) KV 缓存存储和传输成本较高。之前有效的长上下文 LLM 方法主要集中在 KV 缓存压缩、静态稀疏注意力(例如模型压缩、SSM、线性注意力)或分布式服务。然而,这些方法很难以低成本和单个 A100 GPU 实现百万级token提示的可接受延迟。
为了解决这些问题,提出了 MInference,在单台 A100 提速10倍。MInference 是一种基于动态稀疏注意的长上下文 LLM 预填充阶段的无需训练的有效方法。
MInference 1.0
1、Problem Formulation
在利用稀疏注意力计算来加速长上下文LLMs的预填充阶段时,首先需要对注意力矩阵进行定义。如下公式:
- :注意力矩阵,其中 是一个动态稀疏掩码(mask),其元素 是二进制值(0或1),用于指示注意力矩阵中哪些项是重要的。
- 和 :分别是查询(query)和键(key)矩阵,它们的维度是 ,其中 是序列长度, 是头(head)的维度。
- :是缩放因子,通常取 的平方根,用于平衡注意力分数。
- :是一个常数(例如 ),用于确保当 时,通过函数处理后的 接近于零。
动态稀疏注意力 的目标是在保持尽可能多的注意力权重的同时,实现更大的加速并最小化开销。这可以形式化为以下两个最小化问题:
- 尽量减少稀疏注意力矩阵 和密集注意力矩阵 之间的差异,以保持模型性能。
- 尽量减少稀疏注意力计算的时间 以及估计近似动态稀疏模式的开销 。
用公式表示为:
其中, 和 分别代表动态稀疏注意力计算和估计近似动态稀疏模式所需的时间。
2、通过动态稀疏注意力加速长上下文LLM推理
2.1 Kernel-Aware Optimal Sparse Pattern Search
为了在有限的浮点运算次数(FLOPs)预算下实现最佳的精度,作者提出了一种Kernel-Aware Optimal Sparse Pattern Search(A型) 方法。该方法确定每个注意力头(attention head)使用的稀疏模式,以及实际计算中模式的最优设置。
Kernel-Aware(A型)稀疏模式搜索
首先,基于每个模式的目标FLOPs创建搜索空间,确保所有潜在候选者具有相似的计算成本。"Kernel-aware" 表示计算成本反映了GPU内核中的实际FLOPs,这对于实现最佳加速至关重要。
接下来,使用一个参考示例遍历搜索空间,以确定最优模式和设置。使用注意力输出的召回率 作为搜索最佳模式时的目标标准。
2.2 Sparsity Indices Approximation and Dynamic Sparse Attention Calculation
在推理阶段,对注意力矩阵进行在线估计,以动态确定稀疏索引的空间分布,然后使用优化的GPU内核进行稀疏注意力计算。
如何分别针对垂直-斜杠头和块稀疏头构建稀疏索引并计算最终的动态稀疏注意力得分
- Vertical-Slash Head(垂直斜线)
- 利用垂直和斜杠线的连续性,通过矩阵乘法生成估计的注意力矩阵 。
- 确定垂直线 和斜杠线 的索引,并将它们转换为稀疏格式 。
- 使用稀疏索引执行注意力权重和输出的块稀疏计算。
- Block-Sparse Head(块稀疏)
- 对 和 应用均值池化,获得 和 。
- 计算估计的块级注意力权重 。
- 构建稀疏索引 ,并用它来计算稀疏注意力权重和输出。
小结:通过上述方法,MInference 技术能够在不牺牲准确性的前提下,能够显著减少长上下文LLMs推理的延迟,特别是在预填充 阶段。
插拔使用
目前支持的中文模型有GLM-4和Qwen2 ,下面看看插拔使用MInference。
安装:
pip install minference
- for HF:
from transformers import pipeline
+from minference import MInference
pipe = pipeline("text-generation", model=model_name, torch_dtype="auto", device_map="auto")
# Patch MInference Module
+minference_patch = MInference("minference", model_name)
+pipe.model = minference_patch(pipe.model)
pipe(prompt, max_length=10)
- for vLLM(支持vllm==0.4.x):
from vllm import LLM, SamplingParams
+ from minference import MInference
llm = LLM(model_name, max_num_seqs=1, enforce_eager=True, max_model_len=128000)
# Patch MInference Module
+minference_patch = MInference("vllm", model_name)
+llm = minference_patch(llm)
outputs = llm.generate(prompts, sampling_params)
- using only the kernel:
from minference import vertical_slash_sparse_attention, block_sparse_attention, streaming_forward
attn_output = vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)
attn_output = block_sparse_attention(q, k, v, topk)
attn_output = streaming_forward(q, k, v, init_num, local_window_num)
参考文献
-
paper:MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention,https://arxiv.org/pdf/2407.02490
