浅谈LLM的长度外推

写在前面

大家好,我是刘聪NLP。

随着大模型应用的不断发展,知识外挂已经成为了重要手段。但只是外挂手段往往受限于模型本身可接受长度,以及模型外推能力。今天给大家带来一篇LLM的长度外推浅谈,来自@uuuuu(知乎)。


          
https://zhuanlan.zhihu.com/p/645770522  

      

涉及到更新至20230724的外推策略,NBCE,线性内插,NTK-Aware Scaled RoPE,Dynamically Scaled RoPE,consistent of Dynamically Scaled RoPE。

从第二个开始,基本上后一个都是基于前一个的基础上优化得到的,适用于所有使用ROPE的语言模型。

一、NBCE


          
NBCE:使用朴素贝叶斯扩展LLM的Context处理长度  
https://kexue.fm/archives/9617  

      

苏神最早提出的扩展LLM的context方法,基于bayes启发得到的公式:picture.image在问答下实测确实不错,在较长context下的阅读理解还算好用。

局限性是,无序性,即无法识别Context的输入顺序,这在续写故事等场景可能表现欠佳,做一些依赖每个context生成答案,比如提取文档摘要,效果较差。


          
outputs = model(input_ids=input_ids,  
                        attention_mask=attention_mask,  
                        return_dict=True,  
                        use_cache=True,  
                        past_key_values=past_key_values  
                       )  
past_key_values = outputs.past_key_values  
          
# ===== 核心代码开始 =====  
beta = 0.25  
probas = torch.nn.functional.softmax(outputs.logits[:, -1], dim=-1)  
logits = probas.log()  
k = (probas * logits).sum(dim=-1)[1:].argmax() + 1  
logits_max = logits[k]  
logits_uncond = logits[0]  
logits = (1 + beta) * logits_max - beta * logits_uncond  
# ===== 核心代码结束 =====  
          
# 构建分布,采样  
tau = 0.01  # tau = 1是标准的随机采样,tau->0则是贪心搜索  
probas = torch.nn.functional.softmax(logits[None] / tau , dim=-1)  
next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)      

      

此处代码,图片,文本均选自科学空间。

二、线性内插


          
https://kaiokendev.github.io/context  
https://lmsys.org/blog/2023-06-29-longchat/  
https://arxiv.org/abs/2306.15595  

      

llama基于rotary embedding在2048长度上预训练,该方法通过将position压缩到0~2048之间,从而达到长度外推的目的。

longchat将模型微调为上下文长度外扩为16384,压缩比为 8。例如,position_ids = 10000 的 token 变为position_ids = 10000 / 8 = 1250,相邻 token 10001 变为 10001 / 8 = 1250.125

该方法的缺陷是需要进行一定量的微调,让模型来适应这种改变。


          
import torch  
import transformers  
import transformers.models.llama.modeling_llama  
from einops import rearrange  
  
from functools import partial  
  
class CondenseRotaryEmbedding(torch.nn.Module):  
    def \_\_init\_\_(self, dim, ratio, max\_position\_embeddings=2048, base=10000, device=None):  
        super().__init__()  
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))  
        self.register_buffer("inv\_freq", inv_freq)  
          
        # Build here to make `torch.jit.trace` work.  
        self.ratio = ratio  
        max_position_embeddings *= ratio  
        print(f"Condensing Positional embeddings from {max\_position\_embeddings} to {max\_position\_embeddings // ratio}")  
        self.max_seq_len_cached = max_position_embeddings  
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) / ratio  
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        dtype = torch.get_default_dtype()  
        self.register_buffer("cos\_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)  
        self.register_buffer("sin\_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)  
  
    def forward(self, x, seq\_len=None):  
        # x: [bs, num\_attention\_heads, seq\_len, head\_size]  
        # This `if` block is unlikely to be run after we build sin/cos in `\_\_init\_\_`. Keep the logic here just in case.  
        if seq_len > self.max_seq_len_cached:  
            self.max_seq_len_cached = seq_len  
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) / self.ratio  
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
            # Different from paper, but it uses a different permutation in order to obtain the same calculation  
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)  
            self.register_buffer("cos\_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)  
            self.register_buffer("sin\_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)  
        return (  
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),  
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),  
        )  
  
def replace\_llama\_with\_condense(ratio):  
    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial(CondenseRotaryEmbedding, ratio=ratio)  

      

三、NTK-Aware Scaled RoPE


          
NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.  
https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/  
  
RoPE是一种β进制编码:https://spaces.ac.cn/archives/9675  

      

picture.image有意思的解释一下,RoPE 的行为就像一个时钟。12小时时钟基本上是一个维度为 3、底数为 60 的 RoPE。因此,每秒钟,分针转动 1/60 分钟,每分钟,时针转动 1/60。现在,如果将时间减慢 4 倍,那就是二使用的线性RoPE 缩放。不幸的是,现在区分每一秒,因为现在秒针几乎每秒都不会移动。因此,如果有人给你两个不同的时间,仅相差一秒,你将无法从远处区分它们。NTK-Aware RoPE 扩展不会减慢时间。一秒仍然是一秒,但它会使分钟减慢 1.5 倍,将小时减慢 2 倍。这样,您可以将 90 分钟容纳在一个小时中,将 24 小时容纳在半天中。所以现在你基本上有了一个可以测量 129.6k 秒而不是 43.2k 秒的时钟。由于在查看时间时不需要精确测量时针,因此与秒相比,更大程度地缩放小时至关重要。不想失去秒针的精度,但可以承受分针甚至时针的精度损失。


          
import transformers  
  
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__  
def ntk\_scaled\_init(self, dim, max\_position\_embeddings=2048, base=10000, device=None):  
  
    #The method is just these three lines  
    max_position_embeddings = 16384  
    a = 8 #Alpha value  
    base = base * a ** (dim / (dim-2)) #Base change formula  
  
    old_init(self, dim, max_position_embeddings, base, device)  
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init  

      

四、Dynamically Scaled RoPE


          
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/  

      

picture.image对于上面的方法二、三,都涉及到一个超参数α,用于调节缩放比例,该方法是通过序列长度动态选择正确的比例参数,效果可以看上图。

对于线性插值,前 2k 上下文的精确位置值,然后在模型逐个生成标记时重新计算每个新序列长度的位置向量。本质上,将比例设置为原始模型上下文长度/当前序列长度。

对于动态 NTK,α 的缩放设置为 (α * 当前序列长度 / 原始模型上下文长度) - (α - 1)。随着序列长度的增加动态缩放超参数。


          
import math  
import torch  
  
class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):  
    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):  
        super().__init__()  
        self.ntk = ntk  
        self.base = base  
        self.dim = dim  
        self.max_position_embeddings = max_position_embeddings  
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))  
        self.register_buffer("inv\_freq", inv_freq)  
  
        # Build here to make `torch.jit.trace` work.  
        self.max_seq_len_cached = max_position_embeddings  
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)  
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        dtype = torch.get_default_dtype()  
        self.register_buffer("cos\_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)  
        self.register_buffer("sin\_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)  
  
    def forward(self, x, seq_len=None):  
        # x: [bs, num\_attention\_heads, seq\_len, head\_size]  
        # This `if` block is unlikely to be run after we build sin/cos in `\_\_init\_\_`. Keep the logic here just in case.  
        if seq_len > self.max_seq_len_cached:  
            self.max_seq_len_cached = seq_len  
            if self.ntk:  
                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))  
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))  
                self.register_buffer("inv\_freq", inv_freq)  
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)  
            if not self.ntk:  
                t *= self.max_position_embeddings / seq_len  
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
            # Different from paper, but it uses a different permutation in order to obtain the same calculation  
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)  
            self.register_buffer("cos\_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)  
            self.register_buffer("sin\_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)  
        return (  
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),  
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),  
        )  

      

五、consistent of Dynamically Scaled RoPE


          
https://github.com/NormXU/Consistent-DynamicNTKRoPE  

      

picture.image


          
import math  
from typing import List, Optional, Tuple, Union  
  
import torch  
import torch.nn.functional as F  
import torch.utils.checkpoint  
from torch import nn  
from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb  
from transformers.models.llama.modeling_llama import LlamaAttention  
  
def forward(  
        self,  
        hidden\_states: torch.Tensor,  
        attention\_mask: Optional[torch.Tensor] = None,  
        position\_ids: Optional[torch.LongTensor] = None,  
        past\_key\_value: Optional[Tuple[torch.Tensor]] = None,  
        output\_attentions: bool = False,  
        use\_cache: bool = False,  
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:  
    bsz, q_len, _ = hidden_states.size()  
  
    if self.pretraining_tp > 1:  
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp  
        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)  
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)  
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)  
  
        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]  
        query_states = torch.cat(query_states, dim=-1)  
  
        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]  
        key_states = torch.cat(key_states, dim=-1)  
  
        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]  
        value_states = torch.cat(value_states, dim=-1)  
  
    else:  
        query_states = self.q_proj(hidden_states)  
        key_states = self.k_proj(hidden_states)  
        value_states = self.v_proj(hidden_states)  
  
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)  
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  
  
    kv_seq_len = key_states.shape[-2]  
    if past_key_value is not None:  
        kv_seq_len += past_key_value[0].shape[-2]  
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)  
  
    if past_key_value is not None:  
        # reuse k w/o RoPE  
        key_states = torch.cat([past_key_value[0], key_states], dim=2)  
  
    # apply RoPE after retrieving all keys and queries  
    query_states, rotated_key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)  
  
    if past_key_value is not None:  
        # reuse v, self\_attention  
        value_states = torch.cat([past_key_value[1], value_states], dim=2)  
  
    past_key_value = (key_states, value_states) if use_cache else None # cache the key w/o RoPE  
  
    # repeat k/v heads if n\_kv\_heads < n\_heads  
    rotated_key_states = repeat_kv(rotated_key_states, self.num_key_value_groups)  
    value_states = repeat_kv(value_states, self.num_key_value_groups)  
  
    attn_weights = torch.matmul(query_states, rotated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)  
  
    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):  
        raise ValueError(  
            f"Attention weights should be of size {(bsz, self.num\_heads, q\_len, kv\_seq\_len)}, but is"  
            f" {attn\_weights.size()}"  
        )  
  
    if attention_mask is not None:  
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):  
            raise ValueError(  
                f"Attention mask should be of size {(bsz, 1, q\_len, kv\_seq\_len)}, but is {attention\_mask.size()}"  
            )  
        attn_weights = attn_weights + attention_mask  
  
    # upcast attention to fp32  
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)  
    attn_output = torch.matmul(attn_weights, value_states)  
  
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):  
        raise ValueError(  
            f"`attn\_output` should be of size {(bsz, self.num\_heads, q\_len, self.head\_dim)}, but is"  
            f" {attn\_output.size()}"  
        )  
  
    attn_output = attn_output.transpose(1, 2).contiguous()  
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)  
  
    if self.pretraining_tp > 1:  
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)  
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)  
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])  
    else:  
        attn_output = self.o_proj(attn_output)  
  
    if not output_attentions:  
  
  
        attn_weights = None  
  
    return attn_output, attn_weights, past_key_value  
  
  
def replace\_llama\_attn\_with\_consistent\_ntk\_rope():  
    LlamaAttention.forward = forward  

      

总结

请多多关注知乎「刘聪NLP」,有问题的朋友也欢迎加我微信「logCong」私聊,交个朋友吧,一起学习,一起进步。我们的口号是“生命不止,学习不停”。

PS:交流群4天就加满了,没有进群的小伙伴不要着急,过几天搞个2群。

往期推荐:

0
0
0
0
评论
未登录
暂无评论