【LLM】LongRoPE:LLM上下文窗口扩展方法及非官方实现

技术

前言

目前,大多数LLMs的上下文窗口限制在4k个标记左右,这意味着模型在处理超过这个长度的文本时性能会下降。这种限制对于需要大量上下文信息的场景,虽然可以通过在更长的文本上进行微调来将预训练LLM的上下文窗口扩展上下文窗口,但要进一步扩展上下文窗口面临着三个主要挑战:

  1. 新位置索引的未训练引入了许多灾难性值,导致分布外问题,使得微调难以收敛。
  2. 微调通常需要相应长度的文本。然而,当前数据集中特别是超过1000k的长文本非常有限。此外,对超长文本进行训练计算成本高昂,需要大量的训练时间和GPU资源。
  3. 当扩展到极长的上下文窗口时,注意力会变得分散,因为它需要在大量的标记位置上进行分配,这会降低模型在原始短上下文上的性能。picture.image

paper:LongRoPE: Extending LLM Context Window Beyond 2 Million Tokens

link:https://arxiv.org/abs/2402.13753

更多细节请参考原文

LongRoPE

创新点

  1. 通过有效搜索识别并利用了位置插值中的两种非均匀性,为微调提供了更好的初始化,并在非微调情况下实现了8倍的扩展。
  2. 引入了一种渐进式扩展策略,首先对长度为256k的LLM进行微调,然后在微调后的扩展LLM上进行第二次位置插值,以实现2048k的上下文窗口。
  3. 在8k长度上重新调整LongRoPE,以恢复短上下文窗口的性能。

位置插值中的非均匀性问题

位置插值中的非均匀性问题是指在扩展大型语言模型(LLMs)的上下文窗口时,如何有效地为新增的token位置分配位置嵌入(positional embeddings),以便模型能够在更长的序列上保持或提升性能。在LongRoPE这篇文章中,作者们发现并利用了两种主要的非均匀性,以改进位置插值方法:

  1. RoPE维度的非均匀性
  • RoPE(Rotary Positional Embedding)是一种在Transformer架构中广泛使用的位置嵌入方法,它通过旋转角度来表示token的位置。
  • 不同的RoPE维度具有不同的旋转频率,这意味着低维度(高频率)和高维度(低频率)在表示位置信息时的重要性和敏感性不同。
  • 低维度对于位置信息的变化更敏感,因此在插值时应使用较小的缩放因子,以保持相邻位置token的区分度。
  • 高维度可以承受更大的插值,因为它们对于位置信息的变化不那么敏感。
  • Token位置的非均匀性
  • 在输入序列的开始部分,token接收到的注意力分数较高,这些位置的token对于模型理解上下文尤为重要。
  • 因此,序列初始的token位置应该使用较小的插值,或者不进行插值,以保留这些关键位置的原始RoPE信息。
  • 随着序列位置的增加,可以应用更大的插值因子,因为远离序列开始的token对于模型理解上下文的重要性逐渐降低。

LongRoPE采用了以下方法解决这些非均匀性问题:

  • 有效的位置插值 :通过进化搜索算法(evolutionary search)来寻找每个RoPE维度的最佳缩放因子(rescale factors),这些因子基于token位置进行调整。
  • 渐进式扩展策略 :首先对长度为256k的LLM进行微调,然后在微调后的模型上进行第二次位置插值,以实现2048k的上下文窗口,而无需直接在极长文本上进行微调。
  • 短上下文窗口性能恢复 :通过额外的进化搜索来调整RoPE缩放因子,以便在扩展到极长上下文窗口后,仍能保持在原始短上下文窗口内的高性能。

搜索算法

picture.image

LongRoPE非官方实现


        
          
import torch  
import torch.nn as nn  
import torch.optim as optim  
import random  
import numpy as np  
import gzip  
import io  
  
  
class RoPEPositionalEncoding(nn.Module):  
    """  
    Rotary Position Encoding (RoPE) module.  
    """  
  
    def __init__(self, d_model, max_len=5000, base=10000):  
        super().__init__()  
        self.d_model = d_model  
        self.max_len = max_len  
        self.base = base  
        self.theta = torch.tensor(  
            [base ** (-2 * (i // 2) / d_model) for i in range(d_model)]  
        )  
  
    def forward(self, positions):  
        angles = positions.unsqueeze(-1) * self.theta  
        return torch.stack([angles.cos(), angles.sin()], dim=-1).flatten(-2)  
  
  
def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat):  
    """  
    Perform non-uniform interpolation on position embeddings.  
  
    Args:  
        pos\_embed (torch.Tensor): Position embeddings.  
        extension\_ratio (float): Extension ratio for context window.  
        lambda\_factors (list): Lambda factors for interpolation.  
        n\_hat (int): Threshold for applying interpolation.  
  
    Returns:  
        torch.Tensor: Interpolated position embeddings.  
    """  
    d_model = pos_embed.shape[-1]  
    interpolated_pos = pos_embed.clone()  
  
    for i in range(d_model // 2):  
        mask = torch.arange(pos_embed.shape[-2]) < n_hat  
        scale = torch.where(  
            mask, torch.ones_like(pos_embed[..., 0]), 1 / lambda_factors[i]  
        )  
        interpolated_pos[..., i * 2] *= scale  
        interpolated_pos[..., i * 2 + 1] *= scale  
  
    return interpolated_pos  
  
  
def search_lambda_factors(  
    model,  
    data,  
    extension_ratio,  
    population_size,  
    num_mutations,  
    num_crossovers,  
    max_iterations,  
):  
    """  
    Search for optimal lambda factors using evolutionary search.  
  
    Args:  
        model (nn.Module): LongRoPE model.  
        data (list): List of input sequences.  
        extension\_ratio (float): Extension ratio for context window.  
        population\_size (int): Size of the population for evolutionary search.  
        num\_mutations (int): Number of mutations per iteration.  
        num\_crossovers (int): Number of crossovers per iteration.  
        max\_iterations (int): Maximum number of iterations for evolutionary search.  
  
    Returns:  
        list: Optimal lambda factors found by the search.  
    """  
    population = initialize_population(population_size, extension_ratio)  
  
    for i in range(max_iterations):  
        perplexities = evaluate_population(model, data, population)  
        parents = select_topk(population, perplexities, k=population_size // 2)  
        population = mutate(parents, num_mutations) + crossover(parents, num_crossovers)  
  
    return min(population, key=lambda x: evaluate_individual(model, data, x))  
  
  
def progressive_extension(model, data, base_length, target_length):  
    """  
    Progressively extend the context window of the model.  
  
    Args:  
        model (nn.Module): LongRoPE model.  
        data (list): List of input sequences.  
        base\_length (int): Base context window length.  
        target\_length (int): Target context window length.  
  
    Returns:  
        tuple: (Extended model, lambda factors, base lambda factors)  
    """  
    curr_model = model  
    curr_length = base_length  
  
    while curr_length < target_length:  
        lambda_factors, n_hat = search_lambda_factors(  
            curr_model, data, curr_length / base_length  
        )  
        curr_model = fine_tune(curr_model, data, curr_length, lambda_factors, n_hat)  
        curr_length *= 2  
  
    lambda_factors_base, _ = search_lambda_factors(  
        curr_model, data, curr_length / base_length, max_length=base_length  
    )  
  
    return curr_model, lambda_factors, lambda_factors_base  
  
  
class LongRoPEModel(nn.Module):  
    """  
    Long Range Rotary Position Encoding (LongRoPE) model.  
  
    This model extends the context window of transformer-based models beyond the  
    typical limit by using non-uniform interpolation of rotary position embeddings.  
    It enables the model to handle longer input sequences while maintaining the  
    ability to capture long-range dependencies.  
  
    Attributes:  
        d\_model (int): Dimension of the model.  
        n\_heads (int): Number of attention heads.  
        num\_layers (int): Number of transformer layers.  
        max\_len (int): Maximum sequence length.  
        rope (RoPEPositionalEncoding): Rotary Position Encoding (RoPE) module.  
        transformers (nn.ModuleList): List of transformer encoder layers.  
        lambda\_factors (list): Lambda factors for non-uniform interpolation.  
        lambda\_factors\_base (list): Lambda factors for the base model.  
        extension\_ratio (float): Extension ratio for the context window.  
        n\_hat (int): Threshold for applying interpolation.  
  
    Methods:  
        forward(input\_ids):  
            Perform forward pass on the input sequence.  
  
            Args:  
                input\_ids (torch.Tensor): Input sequence tensor.  
  
            Returns:  
                torch.Tensor: Output embeddings from the model.  
  
        extend\_context(data\_path, target\_length, max\_sequence\_length, tokenizer):  
            Extend the context window of the model.  
  
            Args:  
                data\_path (str): Path to the input data file.  
                target\_length (int): Target context window length.  
                max\_sequence\_length (int): Maximum sequence length for input data.  
                tokenizer: Tokenizer object for encoding input data.  
  
            Returns:  
                LongRoPEModel: Extended LongRoPE model.  
    """  
  
    def __init__(self, d_model, n_heads, num_layers, max_len=5000):  
        super().__init__()  
        self.d_model = d_model  
        self.num_layers = num_layers  
        self.rope = RoPEPositionalEncoding(d_model, max_len)  
        self.transformers = nn.ModuleList(  
            [  
                nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads)  
                for _ in range(num_layers)  
            ]  
        )  
        self.lambda_factors = None  
        self.lambda_factors_base = None  
  
    def forward(self, input_ids):  
        positions = torch.arange(input_ids.size(1), device=input_ids.device)  
        pos_embeddings = self.rope(positions)  
  
        if self.lambda_factors is not None:  
            pos_embeddings = non_uniform_interpolation(  
                pos_embeddings, self.extension_ratio, self.lambda_factors, self.n_hat  
            )  
  
        input_embeddings = input_ids + pos_embeddings  
  
        for transformer in self.transformers:  
            input_embeddings = transformer(input_embeddings)  
  
        return input_embeddings  
  
    def extend_context(self, data_path, target_length, max_sequence_length, tokenizer):  
        """  
        Extend the context window of the model.  
  
        Args:  
            data\_path (str): Path to the input data file.  
            target\_length (int): Target context window length.  
            max\_sequence\_length (int): Maximum sequence length for input data.  
            tokenizer: Tokenizer object for encoding input data.  
  
        Returns:  
            LongRoPEModel: Extended LongRoPE model.  
        """  
        if tokenizer is None:  
            raise ValueError("Tokenizer is required for extending context.")  
  
        self.extension_ratio = target_length / self.rope.max_len  
  
        data = load_data(data_path, tokenizer, max_sequence_length)  
        model, lambda_factors, lambda_factors_base = progressive_extension(  
            self, data, self.rope.max_len, target_length  
        )  
  
        self.lambda_factors = lambda_factors  
        self.lambda_factors_base = lambda_factors_base  
        self.n_hat = self.rope.max_len // 2  
  
        return model  
  
  
def load_data(data_path, tokenizer, max_sequence_length):  
    """  
    Load and preprocess the input data.  
  
    Args:  
        data\_path (str): Path to the input data file.  
        tokenizer: Tokenizer object for encoding input data.  
        max\_sequence\_length (int): Maximum sequence length for input data.  
  
    Returns:  
        list: List of preprocessed input sequences.  
    """  
    if data_path is None or tokenizer is None:  
        raise ValueError("Data path and tokenizer are required for loading data.")  
  
    if data_path.endswith(".gz"):  
        with gzip.open(data_path, "rt", encoding="utf-8") as file:  
            text_data = file.read()  
    else:  
        with open(data_path, "r", encoding="utf-8") as file:  
            text_data = file.read()  
  
    tokenized_data = tokenizer.encode(text_data)  
  
    sequences = [  
        tokenized_data[i : i + max_sequence_length]  
        for i in range(0, len(tokenized_data), max_sequence_length)  
    ]  
  
    tensor_data = [torch.tensor(seq, dtype=torch.long) for seq in sequences]  
  
    return tensor_data  
  
  
def initialize_population(population_size, extension_ratio):  
    """  
    Initialize the population for evolutionary search.  
  
    Args:  
        population\_size (int): Size of the population.  
        extension\_ratio (float): Extension ratio for context window.  
  
    Returns:  
        list: Initialized population.  
    """  
    population = []  
  
    population.append(torch.ones(512) * extension_ratio)  
  
    ntk_factors = torch.tensor([extension_ratio ** (2 * i / 512) for i in range(512)])  
    population.append(ntk_factors)  
  
    yarn_factors = torch.ones(512)  
    yarn_factors[:128] = 1.0  
    yarn_factors[128:256] = extension_ratio ** (1 / 3)  
    yarn_factors[256:] = extension_ratio  
    population.append(yarn_factors)  
  
    for _ in range(population_size - 3):  
        factors = torch.ones(512)  
        for i in range(512):  
            if random.random() < 0.1:  
                factors[i] = random.uniform(1, extension_ratio)  
        population.append(factors)  
  
    return population  
  
  
def evaluate_individual(model, data, individual):  
    """  
    Evaluate an individual lambda factor configuration.  
  
    Args:  
        model (nn.Module): LongRoPE model.  
        data (list): List of input sequences.  
        individual (list): Lambda factor configuration.  
  
    Returns:  
        float: Perplexity score for the individual.  
    """  
    model.lambda_factors = individual  
    perplexities = []  
  
    for seq in data:  
        input_ids = seq.unsqueeze(0)  
        output = model(input_ids)  
        perplexity = torch.exp(torch.mean(output))  
        perplexities.append(perplexity.item())  
  
    return np.mean(perplexities)  
  
  
def evaluate_population(model, data, population):  
    """  
    Evaluate the population of lambda factor configurations.  
  
    Args:  
        model (nn.Module): LongRoPE model.  
        data (list): List of input sequences.  
        population (list): Population of lambda factor configurations.  
  
    Returns:  
        list: Perplexity scores for each individual in the population.  
    """  
    perplexities = []  
    for individual in population:  
        perplexity = evaluate_individual(model, data, individual)  
        perplexities.append(perplexity)  
    return perplexities  
  
  
def select_topk(population, perplexities, k):  
    """  
    Select the top-k individuals from the population based on perplexity scores.  
  
    Args:  
        population (list): Population of lambda factor configurations.  
        perplexities (list): Perplexity scores for each individual in the population.  
        k (int): Number of top individuals to select.  
  
    Returns:  
        list: Top-k individuals from the population.  
    """  
    indices = np.argsort(perplexities)[:k]  
    return [population[i] for i in indices]  
  
  
def mutate(parents, num_mutations):  
    """  
    Perform mutation on the parent population.  
  
    Args:  
        parents (list): Parent population.  
        num\_mutations (int): Number of mutations to perform.  
  
    Returns:  
        list: Mutated population.  
    """  
    mutated_population = []  
    for _ in range(num_mutations):  
        parent = random.choice(parents)  
        child = parent.clone()  
        for i in range(512):  
            if random.random() < 0.1:  
                child[i] *= random.uniform(0.8, 1.2)  
        mutated_population.append(child)  
    return mutated_population  
  
  
def crossover(parents, num_crossovers):  
    """  
    Perform crossover on the parent population.  
  
    Args:  
        parents (list): Parent population.  
        num\_crossovers (int): Number of crossovers to perform.  
  
    Returns:  
        list: Crossover population.  
    """  
    crossover_population = []  
    for _ in range(num_crossovers):  
        parent1, parent2 = random.sample(parents, 2)  
        child = parent1.clone()  
        for i in range(512):  
            if random.random() < 0.5:  
                child[i] = parent2[i]  
        crossover_population.append(child)  
    return crossover_population  
  
  
def fine_tune(model, data, target_length, lambda_factors, n_hat, num_epochs=3):  
    """  
    Fine-tune the LongRoPE model.  
  
    Args:  
        model (nn.Module): LongRoPE model.  
        data (list): List of input sequences.  
        target\_length (int): Target context window length.  
        lambda\_factors (list): Lambda factors for interpolation.  
        n\_hat (int): Threshold for applying interpolation.  
        num\_epochs (int, optional): Number of fine-tuning epochs. Defaults to 3.  
  
    Returns:  
        nn.Module: Fine-tuned LongRoPE model.  
    """  
    model.lambda_factors = lambda_factors  
    model.n_hat = n_hat  
    optimizer = optim.Adam(model.parameters(), lr=1e-4)  
  
    for epoch in range(num_epochs):  
        for seq in data:  
            optimizer.zero_grad()  
  
            seq_len = seq.size(0)  
            if seq_len <= target_length:  
                input_ids = seq.unsqueeze(0)  
            else:  
                start_idx = random.randint(0, seq_len - target_length)  
                input_ids = seq[start_idx : start_idx + target_length].unsqueeze(0)  
  
            output = model(input_ids)  
            loss = torch.mean(output)  
  
            loss.backward()  
            optimizer.step()  
  
    return model  
  
  
# Example usage  
data_path = "path/to/your/dataset"  
d_model = 512  
n_heads = 8  
num_layers = 6  
base_length = 4096  
target_length = 2048 * 1024  
  
data = load_data(data_path)  
model = LongRoPEModel(d_model, n_heads, num_layers, base_length)  
model = model.extend_context(data, target_length)  
  
input_ids = torch.randn(2, target_length, d_model)  
output = model(input_ids)  
print(output.shape)  # Expected shape: (batch\_size, target\_length, d\_model)  

      
0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
亿万用户下高可用融合直播的应用实践
直播融合 CDN 调度系统承担了公司内所有直播流量的接入工作,对高并发高带宽场景支持友好,有完善的体系进行容灾降级、质量优化、成本优化。本次演讲将带大家了解直播融合 CDN 调度系统的整体架构及在抖音上的应用。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论