unsloth杀疯了,3090即可训练超长上下文grpo!

大模型算法机器学习

unsloth 凌晨又放大招了,发布长上下文grpo算法。不仅显存消耗减少,而且 可以用之前10%的显存消耗,实现比其他所有的lora/qlora/flash attention实现的grpo,长10倍的上下文长度。

先说个题外话,unsloth 上周在招聘高级工程师,在x上发布了4道题。每道题有不同的得分,如果得分超过47分,只需要一次面试,就可以获得年薪50W刀的工作!(先到先得,一次面试,是避免题目不是你自己做的,后做出来的也有美刀补贴)

picture.image

回到这个最新的工作。对比,使用 TRL + FA2 设置的 GRPO,Llama 3.1(8B)在 20K 上下文长度下的训练需要 510.8GB 显存。然而,而 Unsloth 仅需 54.3GB。

仍然是可以在colab体验:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1\_(8B)-GRPO.ipynb

不仅有工程优化,还有算法优化!

算法优化:反向KL散度优化

相比较而言,优化后的反向KL。策略更新更保守,避免过度偏离参考模型,而且对长尾分布的建模更稳定。跟优势函数结合时梯度更平滑。

大佬们,球球了。帮帮小弟薅点deepseek 满血token咯。走我的火山引擎地址体验体验,咱俩都可以白嫖几百万token。火山部署的r1体验很好~

https://www.volcengine.com/experience/ark?utm\_term=202502dsinvite∾=DSASUQY5&rc=M4ZCMBFE

picture.image

工程优化:

  • 通过分块计算和内存复用技术,将长上下文(20K tokens)的VRAM需求降低8倍

        
 
   

 
 
        
            

          
 # 传统实现(需78.3GB)
 
          
   

 
          logits = model(input\_ids)  
          
 # 完整计算所有token的logits
 
          
   

 
          
   

 
          
 # Unsloth优化实现(仅需9.8GB)
 
          
   

 
          logits = optimized\_linear\_attention(
          
   

 
              query, key, value, 
          
   

 
              use\_memory\_efficient=True  
          
 # 启用内存优化
 
          
   

 
          )
          
   

 
        
      
  • 通过异步数据转移技术,仅增加1%训练时间 (支持中间梯度累积进一步优化)

        
 
   

 
 
        
            

          
 # 传统梯度检查点(需372GB)
 
          
   

 
          torch.utils.checkpoint.checkpoint(forward\_fn, inputs)
          
   

 
          
   

 
          
 # Unsloth异步卸载(节省372GB)
 
          
   

 
          async\_offload\_to\_cpu(activations)  
          
 # 异步转移中间激活到CPU
 
          
   

 
          retain\_partial\_graph()  
          
 # 保留部分计算图用于反向传播
 
          
   

 
        
      
  • 内存共享优化

        
 
   

 
 
        
            

          
 # 传统实现(额外16GB开销)
 
          
   

 
          inference\_engine = vLLMEngine()  
          
 # 独立内存空间
 
          
   

 
          training\_engine = PyTorchEngine()
          
   

 
          
   

 
          
 # Unsloth集成方案(0额外开销)
 
          
   

 
          unified\_memory\_pool = create\_shared\_memory\_space()
          
   

 
          vLLMEngine.attach(unified\_memory\_pool)
          
   

 
          PyTorchEngine.attach(unified\_memory\_pool)
          
   

 
        
      

其他的一些优化,比如:跟vLLM深度集成,动态的4bit量化(相比标准4bit量化提升3.2%准确率)

这个trick之前介绍open-r1的代码时,给大家说过,他们测试发现还是非常有必要的。picture.image

简单写写,学不完,根本学不完。

建议大家学习博客和代码: https://unsloth.ai/blog/grpo

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
大规模高性能计算集群优化实践
随着机器学习的发展,数据量和训练模型都有越来越大的趋势,这对基础设施有了更高的要求,包括硬件、网络架构等。本次分享主要介绍火山引擎支撑大规模高性能计算集群的架构和优化实践。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论