Proxy-Tuning:大模型无须调整权重的大幅提高效果的调优方法

容器Service Mesh

“ 这里只放了原理,学习一下方法,还有2个示例code没放过来


        
          
https://lightning.ai/lightning-ai/studios/improve-llms-with-proxy-tuning  
https://arxiv.org/abs/2401.08565  

      

Proxy-tuning 是一种在不改变模型权重的情况下调整LLM的方法。如果给定的LLM训练资源过于昂贵,或者用户无法访问LLM的权重,这种方法尤其具有吸引力。举几个具体的示例:

  • 假设目前还不存在Llama 2 70B Chat模型。相反,我们只有Llama 2 70B基础模型。Proxy-tuning使我们能够使这个基础模型表现得和聊天模型一样好,而无需改变基础模型的权重。
  • 在比如,我们有2个7B模型,通过proxy-tunning可以达到13B模型的效果。
  • 又或者我们把13B模型的code能力注入到7B的模型里边
Understanding Proxy-Tuning

Proxy-tuning 提供了一个目标 LLM,具有经过调优的版本的能力,而实际上并没有对其进行调优。如下图

picture.image

下面是步骤:

  • 选择一个比目标LLM(例如,未调整的70B Llama 2模型)更小更便宜的基础LLM(例如,未调整的7B Llama 2模型)
  • 微调这个较小的基础LLM模型,以获得一个小型的微调LLM模型(例如,对一个7B Llama 2模型进行指令微调,以获得一个微调后的7B模型)。
  • 计算基本模型(步骤1)和调整模型(步骤2)之间的输出差异。
  • 将这个差异加到目标LLM的输出上
  • 将第4步中修改后的输出进行规范化处理,然后生成答案。

如果上面的步骤过于抽象,可以考虑下面这个PyTorch伪代码的具体示例:


        
          
generated_tokens = []  
  
input_txt = (  
  "If I have 5 apples and eat 2, but then find 3 more"  
  " on my way home, how many do I have?"  
)  
input_ids = tokenizer.encode(input_text)  
  
for _ in range(max_length):  
    # Obtain logits  
    logits_base = model_base(input_ids).logits # Llama 7B Base  
    logits_tuned = model_tuned(input_ids).logits # Llama 7B Chat  
    logits_target = model_target(input_ids).logits # Llama 70B Base  
                                 
    # Apply proxy-tuning                              
    logits = (  
        logits_target + (logits_tuned - logits_base)  
    )  
  
    # Normalize logits and obtain token  
    predictions = torch.softmax(logits[:, -1, :], dim=-1)  
    next_token_id = torch.argmax(predictions).unsqueeze(0)  
    generated_tokens.append(next_token_id.item())  
      
generated_text = tokenizer.decode(generated_tokens)  
print(generated_text)  
  
# Output:   
# You start with 5 apples and eat 2,  
# so you have 5 - 2 = 3 apples left.  
# Then, you find 3 more apples on your way home,   
# so you have 3 + 3 = 6 apples in total.  

      

上面的代码片段获取了编码输入文本,并使用三种模型分别获得了每个输出token的logits。然后,使用之前描述的Proxy-tuning逻辑,计算了调整模型和基础模型之间logits的差异。然后,将这种差异应用到目标模型的logits上。随后,我们像往常一样对logits进行归一化,并对下一个标记进行采样。(这里,我们使用PyTorch的贪婪采样,选择具有最高概率的标记,但我们也可以使用其他采样技术,比如top-k采样或nucleus采样。)logits_tuned - logits_base argmax

picture.image根据原论文,Proxy-tuning效果非常好。

例如,在AlpacaFarm和GSM数据集上,Proxy-tuning 70B Llama 2基础模型导致了显著的性能提升(AlpacaFarm为88.0%,GSM为32.0%)。此外,在TruthfulQA数据集上,代理调整的模型比直接调整的模型更真实。这种方法在领域特定任务中也非常有效,比如编码和任务特定调整。

picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
DevOps 在字节移动研发中的探索和实践
在日益复杂的APP工程架构下,如何保证APP能高效开发,保障团队效能和工程质量?本次将结合字节内部应用的事件案例,介绍DevOps团队对移动研发效能建设的探索和思考。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论