【Trick】NEFTune:一种用于LLM的对抗训练和BERT适配

小程序计算MySQL

前言

对抗训练是一种训练阶段的trick,本文记录了一种用于LLM的训练阶段对embedding嵌入扰动的对抗训练算法,并进行小修改,以适配BERT的embedding扰动,旨在增加模型的泛化性和鲁棒性。

算法

picture.image

该算法的核心思想和一些常见的对抗训练算法一样,只不过原文章应用在生成式大模型中,都是在训练过程中,通过向嵌入层(embedding)的输出添加一定的噪声来增加模型的鲁棒性,从而提高模型的泛化能力。

原始代码


        
          
from torch.nn import functional as F  
  
def NEFTune(model, noise\_alpha=5)  
    def noised\_embed(orig\_embed, noise\_alpha):  
        def new\_func(x):  
            # during training, we add noise to the embedding  
            # during generation, we don't add noise to the embedding  
            if model.training:  
                embed_init = orig_embed(x)  
                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))  
                mag_norm = noise_alpha/torch.sqrt(dims)  
                return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)  
            else:  
                return orig_embed(x)  
        return new_func  
    ##### NOTE: this is for a LLaMA model #####   
    ##### For a different model, you need to change the attribute path to the embedding #####  
    model.base_model.model.model.embed_tokens.forward = noised_embed(model.base_model.model.model.embed_tokens, noise_alpha)  
    return model  

      
  • orig_embed:原始的embedding层
  • noise_alpha:噪声尺度
  • mag_norm:噪声缩放因子

该代码应用于LLama架构的生成式大模型进行扰动,文章中提到产生了一定的效果,代码也很简短,算是一种训练trick。

picture.image

改动适配BERT架构


        
          
#!/usr/bin/env python  
# \_*\_coding:utf-8\_*\_  
# Author   :    Junhui Yu  
  
import torch  
  
import torch.nn.functional as F  
  
def NEFTune(model, noise\_alpha=5):  
    def noised\_embed(orig\_embed, noise\_alpha):  
        def new\_func(x):  
            if model.training:  
                embed_init = F.embedding(x, orig_embed.weight)  
                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))  
                mag_norm = noise_alpha / torch.sqrt(dims)  
                return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)  
            else:  
                return F.embedding(x, orig_embed.weight)  
  
        return new_func  
  
    model.bert.embeddings.word_embeddings.forward = noised_embed(model.bert.embeddings.word_embeddings, noise_alpha)  
    return model  
  

      

上述代码对原始代码进行一点小改动,扰动BERT的embedding层,这一点与其他对抗训练如FGM、PGD等对抗训练思想一致。

可用于一些算法竞赛的提分小trick,具体还需要根据不同的数据进行尝试。

参考文献

【1】paper:https://arxiv.org/pdf/2310.05914.pdf

【2】code:https://github.com/neelsjain/NEFTune/tree/main

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
VikingDB:大规模云原生向量数据库的前沿实践与应用
本次演讲将重点介绍 VikingDB 解决各类应用中极限性能、规模、精度问题上的探索实践,并通过落地的案例向听众介绍如何在多模态信息检索、RAG 与知识库等领域进行合理的技术选型和规划。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论