“
NLP前沿交流群成立,详见置顶推文
https://huggingface.co/blog/kv-cache-quantization
1.省流,how to use?
只需要在model.generate时传入cache_implementation="quantized", cache_config={"backend": "quanto", "nbits": 4}
参数,示例如下:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="cuda:0")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"backend": "quanto", "nbits": 4})
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
- kv cache是什么?
目前的自回归模型,解码截断一个 token 一个 token 生成,每个新的 token 预测都依赖于先前的上下文。这意味着,要预测生成中的第1000个 token,需要来自之前 999 个已预测的token + 预填充的prompt token的信息。如果继续预测第1001个token,还需要前 999 个标记中的相同信息 + 预填充的prompt token的信息,以及来自 token 数字 1000 的附加信息。这就是使用kv cache的加速原理,通过存储先前的计算以在后续token中复用,而不需要再次计算它们。
尽管 kv cache 加速了自回归大模型的生成,但它可能成为长上下文长度或高批量大小的内存瓶颈。可以估计一下,对于 7B Llama-2 模型,对于序列长度为 10000 个令牌的输入,需要多少内存来存储 kv 缓存。存储一个 token 的 kv 缓存所需的内存大致为 2 * 2 * num_layers * num_key_value_heads * head_dim ,其中第一个 2 表示 key & value,第二个 2 是需要的字节数(假设模型加载到 float16 中)。因此,如果我们有一个长度为 10000 个标记的上下文,我们需要
2 * 2 * 32 * 32 * 128 * 10000 ≈ 5GB
的内存仅用于存储之前的kv cache,这几乎是半精度存储模型参数所需内存的三分之一。
- 实现细节
主要受到 KIVI: A Tuning-Free Asymmetric 2bit Quantization for kv Cache
论文的启发,在KIVI中,对kvcache的,k沿着channel维度量化,v沿着token维度量化。但是集成到 Transformer 中的方法中,k和v都是按 token 量化的。量化每个token时的主要瓶颈是每次添加新的token(即每个生成步骤)时都需要对k和v进行量化和反量化。这可能会导致生成速度变慢。为了解决这个问题,实现过程中,保留固定大小的剩余缓存,以按原始精度存储k和v。当剩余缓存达到其最大容量时,存储的键和值将被量化,并且缓存内容将被丢弃。这个小技巧还可以保持准确性,因为最新的k和v的某些部分始终以其原始精度存储。设置剩余缓存长度时主要考虑的是内存效率的权衡。虽然残留缓存以其原始精度存储键和值,但这可能会导致总体内存使用量增加。发现使用剩余长度 128 作为baseline效果很好。
- 性能对比
int4 缓存的性能与 fp16 精度几乎相同,而使用 int2 时性能会下降。
将 LongBench 基准测试的性能与 KIVI 论文的结果进行比较时,得出同样的结论。在下表中的所有数据集中, Int4 quanto 精度相当,甚至略优于 fp16 。