LLM推理加速-Medusa

技术

写在前面

大家好,我是刘聪NLP。

今天又给大家带来一篇UC伯克利关于推理速度加速的新工作,来自知乎@uuuuu。

自己近期也在做这方面的工作,具体看了下源码,并且这个是可以结合到现有的加速推理框架里面的,感觉还不错。


          
知乎:https://zhuanlan.zhihu.com/p/655809033  
Github: https://github.com/FasterDecoding/Medusa  
Blog: https://sites.google.com/view/medusa-llm  

      

一、前言

从GPT4的一些技术细节泄露后,对于投机采样【Speculative Decoding】策略加速推理的研究比较多,但是投机采样依赖一个小而强的模型来生成对于原始的模型来说比较简单的token,其次在一个系统中维护2个不同的模型,导致架构上的复杂性,最后使用投机采样的时候,会带来额外的解码开销,尤其是当使用一个比较高的采样温度值时。

先看一下加速效果。picture.image

二、Medusa: Marrying Simplicity with Efficiency

picture.image主要思想是在正常的LLM的基础上,增加几个解码头,并且每个头预测的偏移量是不同的,比如原始的头预测第i个token,而新增的medusa heads分别为预测第i+1,i+2...个token。如上图,并且每个头可以指定topk个结果,这样可以将所有的topk组装成一个一个的候选结果,最后选择最优的结果

picture.image

计算每个头组装之后的候选的最优解,其实这时候完全可以每个候选都走一次模型,算出概率,但是很显然不可能这样做,因为本来方案是为了加速,作者设计了一种tree attention的机制,可以做到只走一次模型来达到目的,如示例所示,第一个medusa heads的 top-2 预测和第二个medusa heads的 top-3 预测产生 2*3=6 个候选。假设原始的LLM输出是[0],第一个头是[1,2],第二个头是[3,4,5]。期望直接能把[0,1,2,3,4,5],输入模型就能得到一些概率的信息,但是不同的头对应的token的父节点是不同的,所以需要冗余一些token,方便添加mask变成[0,1,2,3,4,5,3,4,5],对应到右上的mask矩阵,每个节点只有父节点以及当前节点是1,这样其实就能得到一个树的路径关系。

输入模型得到概率之后,可以通过头计算得到树的路径信息,比如示例对应的路径index是[0,1,3] , [0,1,4], [0,1,5], [0,2,6],然后基于后验概率得到最优的候选片段。

三、来看下具体实现

首先需要训练几个新增的头,不同的头预测的label的偏移量不同,所以可以组装每个头的topk作为候选:


          
# Customized for training Medusa heads  
class CustomizedTrainer(Trainer):  
    def compute_loss(self, model, inputs, return_outputs=False):  
        # DDP will give us model.module  
        if hasattr(model, "module"):  
            medusa = model.module.medusa  
        else:  
            medusa = model.medusa  
  
        logits = model(  
            input_ids=inputs["input\_ids"], attention_mask=inputs["attention\_mask"]  
        )  
        labels = inputs["labels"]  
        # Shift so that tokens < n predict n  
        loss = 0  
        loss_fct = CrossEntropyLoss()  
        log = {}  
        for i in range(medusa):  
            medusa_logits = logits[i, :, : -(2 + i)].contiguous()  
            medusa_labels = labels[..., 2 + i :].contiguous()  
            medusa_logits = medusa_logits.view(-1, logits.shape[-1])  
            medusa_labels = medusa_labels.view(-1)  
            medusa_labels = medusa_labels.to(medusa_logits.device)  
            loss_i = loss_fct(medusa_logits, medusa_labels)  
            loss += loss_i  
            not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)  
            medusa_labels = medusa_labels[not_ignore]  
  
            # Add top-k accuracy  
            for k in range(1, 6):  
                _, topk = medusa_logits.topk(k, dim=-1)  
                topk = topk[not_ignore]  
                correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)  
                log[f"medusa{i}\_top{k}"] = correct.float().mean().item()  
  
            log[f"medusa{i}\_loss"] = loss_i.item()  
        self.log(log)  
        return (loss, logits) if return_outputs else loss  

      

forward函数最后需要计算每个头


          
def forward(  
    self,  
    input_ids=None,  
    attention_mask=None,  
    labels=None,  
    past_key_values=None,  
    output_orig=False,  
    position_ids=None,  
):  
    with torch.inference_mode():  
        # Pass input through the base model  
        outputs = self.base_model.model(  
            input_ids=input_ids,  
            attention_mask=attention_mask,  
            past_key_values=past_key_values,  
            position_ids=position_ids,  
        )  
        if output_orig:  
            orig = self.base_model.lm_head(outputs[0])  
    # Clone the output hidden states  
    hidden_states = outputs[0].clone()  
    medusa_logits = []  
    # TODO: Consider parallelizing this loop for efficiency?  
    for i in range(self.medusa):  
        medusa_logits.append(self.medusa_head[i](hidden_states))  
    if output_orig:  
        return torch.stack(medusa_logits, dim=0), outputs, orig  
    return torch.stack(medusa_logits, dim=0)  

      

然后看推理的时候怎么组装,为了精简,摘出来最核心部分,涉及到3个函数,generate_candidates,tree_decoding,evaluate_posterior。分别对应到预测每个头的topk的token然后笛卡尔积组装成候选片段;用基础的LLM模型预测每一条路径的概率;路径选择;


          
def medusa_generate(  
    self,  
    input_ids,  
    attention_mask=None,  
    temperature=0.0,  
    max_steps=512,  
    # The hyperparameters below are for the Medusa  
    # top-1 prediciton for the next token, top-7 predictions for the next token, top-6 predictions for the next next token.  
    medusa_choices=[1, 7, 6],  
    posterior_threshold=0.09,  # threshold validation of Medusa output  
    # another threshold hyperparameter, recommended to be sqrt(posterior\_threshold)  
    posterior_alpha=0.3,  
):  
    ...省略....  
    for idx in range(max_steps):  
        # Generate candidates with topk predictions from Medusa heads  
        candidates, tree_candidates = generate_candidates(  
            medusa_logits,  
            logits,  
            medusa_topk,  
            medusa_buffers["tree\_indices"],  
            temperature,  
        )  
  
        # Use tree attention to verify the candidates and get predictions  
        medusa_logits, logits, outputs = tree_decoding(  
            self,  
            tree_candidates,  
            past_key_values,  
            medusa_buffers["medusa\_position\_ids"],  
            input_ids,  
            medusa_buffers["retrieve\_indices"],  
        )  
  
        # Evaluate the posterior of the candidates to select the accepted candidate prefix  
        best_candidate, accept_length = evaluate_posterior(  
            logits, candidates, temperature, posterior_threshold, posterior_alpha  
        )  
  
        # Update the input\_ids and logits  
        input_ids, logits, medusa_logits, new_token = update_inference_inputs(  
            input_ids,  
            candidates,  
            best_candidate,  
            accept_length,  
            medusa_buffers["retrieve\_indices"],  
            outputs,  
            logits,  
            medusa_logits,  
            new_token,  
            past_key_values_data,  
            current_length_data,  
        )  
  
        yield {  
            "text": self.tokenizer.decode(  
                input_ids[0, input_len:],  
                skip_special_tokens=True,  
                spaces_between_special_tokens=False,  
                clean_up_tokenization_spaces=True,  
            )  
        }  
  
        if self.tokenizer.eos_token_id in input_ids[0, input_len:]:  
            break  

      

最后单独看下generate部分3个函数具体实现

候选片段生成,并且组装成可以解析成tree的序列


          
def generate_candidates(medusa_logits, logits, medusa_topk, tree_indices, temperature):  
  
    # Greedy decoding for original logits  
    candidates = [torch.argmax(logits[:, -1]).unsqueeze(0)]  
    for i in range(medusa_logits.shape[0]):  
        candidate_i = torch.topk(medusa_logits[i, 0, -1], medusa_topk[i]).indices  
        candidates.append(candidate_i)  
    candidates_flat = torch.cat(candidates)  
    candidates = torch.cartesian_prod(*candidates)    
    tree_candidates = candidates_flat[tree_indices].unsqueeze(0)  
    return candidates, tree_candidates  

      

tree decoding,上面的得到的拉平的序列,算一下概率,最后根据retrieve_indices还原到原始的笛卡尔积的路径,可以得到路径上每个位置的概率


          
def tree_decoding(  
    model,  
    tree_candidates,  
    past_key_values,  
    medusa_position_ids,  
    input_ids,  
    retrieve_indices,  
):  
    position_ids = medusa_position_ids + input_ids.shape[1]  
    # Decode the tree candidates using the model  
    tree_medusa_logits, outputs, tree_logits = model(  
        tree_candidates,  
        output_orig=True,  
        past_key_values=past_key_values,  
        position_ids=position_ids,  
    )  
    # Reorder the logits based on the retrieve\_indices for consistency  
    logits = tree_logits[0, retrieve_indices]  
    medusa_logits = tree_medusa_logits[:, 0, retrieve_indices]  
    return medusa_logits, logits, outputs  

      

候选确认,只保留了贪婪解码部分,这里的输入logits就是上面输出的logits,通过后验概率来选取最优的候选,候选长度不一定是头的数量


          
def evaluate_posterior(  
    logits, candidates, temperature, posterior_threshold, posterior_alpha  
):  
     
    # Greedy decoding based on temperature value  
    if temperature == 0:  
        # Find the tokens that match the maximum logits for each position in the sequence  
        posterior_mask = (  
            candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)  
        ).int()  
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)  
        accept_length = candidates_accept_length.max()  
        # Choose the best candidate  
        if accept_length == 0:  
            # Default to the first candidate if none are accepted  
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)  
        else:  
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)  
        return best_candidate, accept_length  
    ..省略..  
    return best_candidate, accept_length  

      

四、一些参数配置

picture.image

几个解码头?以及每个解码头解码top k? 选择最多的top k,可以让模型更大概率接受解码的结果,但是会增加解码的时间开销picture.image

在最后计算最优的候选的时候,可以设置posterior_threshold,帮助根据模型预测的结果来判断候选的token是否合理,阈值越高则越严格。

请多多关注知乎「刘聪NLP」,有问题的朋友也欢迎加我微信「logCong」私聊,交个朋友吧,一起学习,一起进步。我们的口号是“生命不止,学习不停”。

PS:新书已出《ChatGPT原理与实战》,欢迎购买~~。

往期推荐:

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
高性能存储虚拟化方案 NVMe over Fabric 在火山引擎的演进
在云计算中,虚拟化存储扮演着重要角色,其中 iSCSI 协议在业界开放、流行多年。近年来,拥有更优性能的 NVMe over Fabrics 协议也得到了发展。本次分享介绍了 NVMe over Fabrics 在云原生和虚拟化方向的演进工作和成果。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论