unsloth 凌晨又放大招了,发布长上下文grpo算法。不仅显存消耗减少,而且 可以用之前10%的显存消耗,实现比其他所有的lora/qlora/flash attention实现的grpo,长10倍的上下文长度。
先说个题外话,unsloth 上周在招聘高级工程师,在x上发布了4道题。每道题有不同的得分,如果得分超过47分,只需要一次面试,就可以获得年薪50W刀的工作!(先到先得,一次面试,是避免题目不是你自己做的,后做出来的也有美刀补贴)
回到这个最新的工作。对比,使用 TRL + FA2 设置的 GRPO,Llama 3.1(8B)在 20K 上下文长度下的训练需要 510.8GB 显存。然而,而 Unsloth 仅需 54.3GB。
仍然是可以在colab体验:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1\_(8B)-GRPO.ipynb
不仅有工程优化,还有算法优化!
算法优化:反向KL散度优化
传
统
前
向
散
度
:
强
调
覆
盖
所
有
的
模
式
反
向
散
度
:
强
调
避
免
的
过
度
泛
化
相比较而言,优化后的反向KL。策略更新更保守,避免过度偏离参考模型,而且对长尾分布的建模更稳定。跟优势函数结合时梯度更平滑。
大佬们,球球了。帮帮小弟薅点deepseek 满血token咯。走我的火山引擎地址体验体验,咱俩都可以白嫖几百万token。火山部署的r1体验很好~
https://www.volcengine.com/experience/ark?utm\_term=202502dsinvite∾=DSASUQY5&rc=M4ZCMBFE
工程优化:
- 通过分块计算和内存复用技术,将长上下文(20K tokens)的VRAM需求降低8倍
# 传统实现(需78.3GB)
logits = model(input\_ids)
# 完整计算所有token的logits
# Unsloth优化实现(仅需9.8GB)
logits = optimized\_linear\_attention(
query, key, value,
use\_memory\_efficient=True
# 启用内存优化
)
- 通过异步数据转移技术,仅增加1%训练时间 (支持中间梯度累积进一步优化)
# 传统梯度检查点(需372GB)
torch.utils.checkpoint.checkpoint(forward\_fn, inputs)
# Unsloth异步卸载(节省372GB)
async\_offload\_to\_cpu(activations)
# 异步转移中间激活到CPU
retain\_partial\_graph()
# 保留部分计算图用于反向传播
- 内存共享优化
# 传统实现(额外16GB开销)
inference\_engine = vLLMEngine()
# 独立内存空间
training\_engine = PyTorchEngine()
# Unsloth集成方案(0额外开销)
unified\_memory\_pool = create\_shared\_memory\_space()
vLLMEngine.attach(unified\_memory\_pool)
PyTorchEngine.attach(unified\_memory\_pool)
其他的一些优化,比如:跟vLLM深度集成,动态的4bit量化(相比标准4bit量化提升3.2%准确率)
这个trick之前介绍open-r1的代码时,给大家说过,他们测试发现还是非常有必要的。
简单写写,学不完,根本学不完。
建议大家学习博客和代码: https://unsloth.ai/blog/grpo