hf transformers重磅更新:KV Cache量化技术!精度不减,显存大降2.5倍,单卡轻松驾驭100K大模型!

人工智能与算法视频云MySQL

NLP前沿交流群成立,详见置顶推文


        
          
https://huggingface.co/blog/kv-cache-quantization  

      

1.省流,how to use?

只需要在model.generate时传入cache_implementation="quantized", cache_config={"backend": "quanto", "nbits": 4}参数,示例如下:


        
          
import torch  
from transformers import AutoTokenizer, AutoModelForCausalLM  
  
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="cuda:0")  
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)  
  
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"backend": "quanto", "nbits": 4})  
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])  
  
out = model.generate(**inputs, do_sample=False, max_new_tokens=20)  
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])  
  

      
  1. kv cache是什么?

目前的自回归模型,解码截断一个 token 一个 token 生成,每个新的 token 预测都依赖于先前的上下文。这意味着,要预测生成中的第1000个 token,需要来自之前 999 个已预测的token + 预填充的prompt token的信息。如果继续预测第1001个token,还需要前 999 个标记中的相同信息 + 预填充的prompt token的信息,以及来自 token 数字 1000 的附加信息。这就是使用kv cache的加速原理,通过存储先前的计算以在后续token中复用,而不需要再次计算它们。

尽管 kv cache 加速了自回归大模型的生成,但它可能成为长上下文长度或高批量大小的内存瓶颈。可以估计一下,对于 7B Llama-2 模型,对于序列长度为 10000 个令牌的输入,需要多少内存来存储 kv 缓存。存储一个 token 的 kv 缓存所需的内存大致为 2 * 2 * num_layers * num_key_value_heads * head_dim ,其中第一个 2 表示 key & value,第二个 2 是需要的字节数(假设模型加载到 float16 中)。因此,如果我们有一个长度为 10000 个标记的上下文,我们需要

2 * 2 * 32 * 32 * 128 * 10000 ≈ 5GB

的内存仅用于存储之前的kv cache,这几乎是半精度存储模型参数所需内存的三分之一。

  1. 实现细节

主要受到 KIVI: A Tuning-Free Asymmetric 2bit Quantization for kv Cache 论文的启发,在KIVI中,对kvcache的,k沿着channel维度量化,v沿着token维度量化。但是集成到 Transformer 中的方法中,k和v都是按 token 量化的。量化每个token时的主要瓶颈是每次添加新的token(即每个生成步骤)时都需要对k和v进行量化和反量化。这可能会导致生成速度变慢。为了解决这个问题,实现过程中,保留固定大小的剩余缓存,以按原始精度存储k和v。当剩余缓存达到其最大容量时,存储的键和值将被量化,并且缓存内容将被丢弃。这个小技巧还可以保持准确性,因为最新的k和v的某些部分始终以其原始精度存储。设置剩余缓存长度时主要考虑的是内存效率的权衡。虽然残留缓存以其原始精度存储键和值,但这可能会导致总体内存使用量增加。发现使用剩余长度 128 作为baseline效果很好。

picture.image

  1. 性能对比

int4 缓存的性能与 fp16 精度几乎相同,而使用 int2 时性能会下降。

picture.image

将 LongBench 基准测试的性能与 KIVI 论文的结果进行比较时,得出同样的结论。在下表中的所有数据集中, Int4 quanto 精度相当,甚至略优于 fp16 。

picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动客户端性能优化最佳实践
在用户日益增长、需求不断迭代的背景下,如何保证 APP 发布的稳定性和用户良好的使用体验?本次分享将结合字节跳动内部应用的实践案例,介绍应用性能优化的更多方向,以及 APM 团队对应用性能监控建设的探索和思考。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论