大家好,今天给大家带来一篇好友知乎@ybq的文章,《如何理解 LLM 中的 RL 算法》。
知乎:https://zhuanlan.zhihu.com/p/22331625359
随着最近 R1 爆火,我经常刷到一些有意思的话题,例如:
- SFT 无用,RL 才是通往智能化的正解;
- R1 并不像是传统的强化学习,更像是监督学习;
- ……
这些话题,或多或少有我曾经的疑惑在里面,所以今天写出来和大家分享一下,讨论下我们到底该如何理解 LLM 中所涉及到的 RL 算法。
事先声明:我的一切视角,是从如何训好一个 NLP 模型的角度出发的,所以我不在乎算法是 sft 或 rlhf,也不纠结监督学习和强化学习在理论上有何本质区别。我只关心,哪种 loss 能让模型达到更好的效果。
什么是 LLM 中的 RL
如果我们从 loss 函数的角度来看 sft 和 rlhf,会发现二者在本质上没有差别:无非都是一个条件概率公式嘛,围绕着 next_token 的 probability 做文章。只不过在实现细节上,sft 的 next_token 有一个明确的 target,距离这个 target 远 loss 就大,否则 loss 就小;rlhf 的 next_token 则是有一个 reward,如果这个 reward 高就鼓励它,reward 低就打压它。
至于其他区别,那仅仅是两种算法的习惯性用法不同而已。比如 reference_model,有人规定 sft 的时候不能加 reference_model 了吗?这里明确给出个结论:不仅能加,而且有效。我和 知乎@真中合欢 做过很鲁棒的实验,无论是 pretrain 或者 sft,只要让模型在不想学习的数据(有点脏但不得不用)上加 reference_model,就能有效果。
那么,既然两种算法在 loss 函数上没有本质区别,他们的区别又体现在哪里呢?我个人的观点是:explore 。这也是我对强化学习的理解:“自己玩,旁人来纠正”。具体来说,下列七个算法,除了算法 1 和算法 2,我认为均属于强化学习范畴。
除了特别严谨的强化学习论文,目前基本上都不区分 online / offline 和 on policy / off policy 这两个概念了,本文暂且视为是同一个概念。此外,我会用 ppo 作为默认强化算法,不再和 grpo 等进行区分。
- 指定 response 的 sft
- 指定 response 的 dpo (在算法 1 的基础上引入负例)
- offline reject sampling + sft
- offline reject sampling + dpo
- online reject sampling + sft (在算法 3 的基础上,把 explore 粒度从 epoch 变成 batch)
- online reject sampling + dpo
- ppo(兜兜转转一大圈,算法 6 不过是算法 7 的下位代替者罢了)
post-training 阶段的所有算法都在做一件事:输出当前文本下的 next_token,然后纠错 。只不过 sft 在强制学,rlhf 在 explore 学,强制学进步快,explore 学根基稳 。
因此,“直接对模型上 ppo 算法就能起效果”这一结论对算法从业者来说完全不吃惊。sft 本就不是训 LLM 的必备环节,只不过是能让模型提点最快的一种方案而已。但如果说 sft 完全无用也属实是过激了,毕竟只看 loss 函数的话完全可以这么理解:sft 就是在每个 token 粒度都有一个 reward_model 的 ppo 算法。
“explore 的学习方式”是否在理论上具有优越性,我没有充分的证据,我只是在实验阶段中有些经验而已:“如果不让模型用 explore 的方式进行训练,3 个 epoch 的 sft 真的背不下来一些知识,更多 epoch 则会过拟合十分严重,这在 math 集合上的实验结论十分明显。”(知乎@真中合欢 曾经和我分享过一些实验现象,说是他观察到 on-policy 得到的数据,在训 sft 的时候梯度噪声会更少,梯度噪声指梯度大但对模型更新无帮助)
如果用人的思维方式来分析,就很好理解了:一字不落的背下来一篇文章很难,但如果只背个大概,用自己的理解去复述这篇文章的内容,无关痛痒的说错几个字不去管,关健结论说错了就纠偏,自然背的会更快一些。
post training 算法的统一建模
deepseek 在去年的时候,就已经在技术报告里指出过,sft 和 rlhf 算法在 loss 函数的设计上没有本质区别。具体来说,deepseek 认为 post training 算法包括三要素:启动数据,reward function,token 粒度的 gradient coefficient。sft 的 Gradient Coefficient 是 1,ppo 的 Gradient Coefficient 是 Advantage。
具体内容如下图所示,大家也可以找原论文重新拜读一下,这里就不逐一分析了。
统一建模
sft
reject sampling sft
online reject sampling sft
dpo
ppo
RL 为什么难训
有了前面这些铺垫,我也可以说一下我对 rl 训练容易崩溃的一些理解了。我觉着 rl 不如 sft 稳定,问题出就出在 token 粒度的 reward 是否准确这一点上。
前面说了,sft 的训练过程,是每个 token 都有一个明确的 target 存在的,其优化目标很纯粹,增大这个 target 的概率。我很难想出这种训练方式会存在标签不合理的地方,即使是你正走在一条正确的道路上,却被强制拉到另一条正确的道路上,好像也没啥太大影响吧。
但 rl 不同,每个 token 的 reward 是由整个句子的 reward 回传回来的(带上 value function 的预测),试想一个句子“中国的首都不是南京,是北京 ”,因为太过啰嗦被打上了一个较低的 reward,那问题是“是南京 ”这三个 token 做错了什么,在上个 token 的回答是“不 ”的情况下,这三个 token 已经是当下最优的 token 了。此时,如果 value function 能救回来还好,但显然不太容易。这里注意,传统的 rl,每一个 action 是能有一个及时回报的,但 rlhf 算法中是没有的,它只有折扣累积回报(rlhf 中,每个 action 的及时回报,要么被设置成 0,要么被设置成 kl_penalty),这也进一步导致了 token 级别 reward 的不准确。
就这,还都是建立在整个 response 的 reward 打分准确的基础上,打不准就更头大了。如何给每个 token 一个正确的打分,那就是 ppo / grpo / rloo 等算法各自的努力方向了,它们的出发点和实现方式各不相同,甚至对 KL_penalty 施加的位置都不同,有的放进 reward 中,有的放进 advantage 中。熟优熟劣,就要靠各位的实验结论和理论推导了,我暂时没有结论。
啰哩啰嗦那么多,其实就是想说因为 label 不准, rl 天生比 sft 不太好训练,因此才需要那么多的调参工作。也正是因为 token 粒度的 reward 不准, rl 后的模型出现一些诡异回复也就不那么难理解了。再次提醒,不管什么算法,你只要把 reference_model 的 KL_penalty 开得足够大,都会稳如泰山。
更多理论知识,推荐阅读:https://zhuanlan.zhihu.com/p/19223907990
Reward hacking
非强化出身的我,早期常被 reward hacking 这个概念给唬到,总觉着背后有什么高大上的理论。其实,所谓的 reward hacking,归根结底就是训练者考虑不充分,既要又要导致的。
我很早做过一些和 R1 思路类似的 rule-based rl 实验,得到的实验现象别说 aha-moment 了,直接就是模型越训越短。这是 reward-hacking 吗?当然是,是训练原因导致的吗?不是,完全是因为 prompt 太简单了或是模型背过这道题,模型根本不需要 cot 过程就能直接说出来答案,说的越多就错的越狠。这一点,kimi 的技术报告提到过,如果模型不 cot 就能直接说出答案,需要删掉这些 prompt。
我还有过一版 rule-based rl 实验,reward 是通过模型来判别 ground_truth 是否出现在 response 里来确定的。训练过程中 reward 确实嘎嘎上涨,模型的 response 却全都是“ …… 这个题选<im_start><im_start> A <im_start>” 这种。这能怪模型 reward-hacking 了吗?怪不了一点,但凡多说一句“如果格式不符合标准,就打 0 分”,也就不会出现这种现象。
所以,reward hacking 其实就是模型以训练者不期望的方式找到了提高 reward 的方法 。训练者期待的是模型有条不紊的进行分析,模型找到的法子是“直接说答案吧,要不蒙一个选项吧,输出点乱码扰乱下 attention 吧,多复述一下 prompt 吧 ……” 我们想要的是模型按照某种方法提高 reward,但我们设计的 reward 函数却只在乎 reward,而不在乎“按照某种方法”,那么自然而然的就会不符合预期。
万变不离其宗,有多少人工就有多少智能。sft 要时刻留意数据质量,rlhf 则是要时刻留意 reward 的打分是否准确或者说是 reward 的设计是否合理,后者一点都不比洗数据轻松。
写在最后
写这篇文章的目的是想说,我们做技术的人目标只有一个:训出更好的模型。所有的算法应该都只是我们手上的工具而已,谁用的顺手就用谁。rl 并非高不可攀,sft 也永不过时。
PS:看到这里,如果觉得不错,可以来个点赞 、在看 、关注 。给公众号添加【星标⭐️】不迷路!您的支持是我坚持的最大动力!
欢迎多多关注公众号「NLP工作站」,加入交流群(3群也满了,等开4群吧),交个朋友吧,一起学习,一起进步!