LLM实践系列-细聊LLM的拒绝采样

大模型机器学习数据库

今天给大家带来一篇知乎好友@ybq的文章,《拒绝采样》。


        
          
知乎:https://zhuanlan.zhihu.com/p/3907736367  

      

最近学强化的过程中,总是遇到“拒绝采样”这个概念,我尝试科普一下,争取用最大白话的方式让每个感兴趣的同学都理解其中思想。

拒绝采样是 LLM 从统计学借鉴过来的一个概念。其实大家很早就接触过这个概念,每个刷过 leetcode 的同学大概率都遇到过这样一个问题:“如何用一枚骰子获得 1/7 的概率?”

答案很简单:把骰子扔两次,获得 6 * 6 = 36 种可能的结果,丢弃最后一个结果,剩下的 35 个结果平分成 7 份,对应的概率值便为 1/7 。使用这种思想,我们可以利用一枚骰子获得任意 1/N 的概率。

在这个问题中,我们可以看到拒绝采样的一些关键要素:

  • 采样 :从易于采样的分布(两个骰子的所有可能结果)中生成样本;
  • 缩放 :(扔两次骰子)获得更大的样本分布;
  • 拒绝 :丢弃(拒绝)不符合条件的样本(第36种情况);
  • 接受 :对于剩下的样本,重新调整概率(通过分组),获得目标概率分布。

用大白话来总结就是:我们想获得某个分布(1/7)的样本,但却没有办法。于是我们对另外一个分布(1/6)进行采样,但这个分布不能涵盖原始分布,需要我们缩放这个分布(扔两次)来包裹起来目标分布。然后,我们以某种规则拒绝明显不是目标分布的采样点,剩下的采样点就可以看作是从目标分布采样出来的了。

统计学的拒绝采样

我们再来看下统计学中是如何定义拒绝采样的:拒绝采样(Rejection Sampling)是一种 Monte Carlo 的统计方法,用于从复杂的目标概率分布中生成随机样本 。当直接从目标分布 中采样困难或不可行时,使用一个易于采样的提议分布 ,并根据某种接受概率来决定是否接受采样结果。具体过程如下

1.择提议分布 这个分布易于直接采样,并且覆盖目标分布的支持,即目标分布可能取值的所有区域;

  1. 确定缩放常数 找到一个常数 ,使得对于所有的 ,都有 。这个常数确保了 成为 的上界,这个系数可以让新分布能把旧分布“包裹”起来;
  2. 采样过程:
  • 步骤 a:从提议分布 中生成一个样本;
  • 步骤 b:从均匀分布 中采样一个随机数 ;
  • 步骤 c:计算接受概率 ;
  • 步骤 d:如果 ,则接受样本 ;否则,拒绝样本并返回步骤 a。

LLM 的拒绝采样

LLM 的拒绝采样操作起来非常简单:让自己的模型针对 prompt 生成多个候选 response,然后用 reward_model 筛选出来高质量的 response (也可以是 pair 对),拿来再次进行训练。

解剖这个过程:

  1. 提议分布是我们自己的模型,目标分布是最好的语言模型;
  2. prompt + response = 一个采样结果;
  3. do_sample 多次 = 缩放提议分布(也可以理解为扔多次骰子);
  4. 采样结果得到 reward_model 的认可 = 符合目标分布。

经过这一番操作,我们能获得很多的训练样本,“这些样本既符合最好的语言模型的说话习惯,又不偏离原始语言模型的表达习惯 ”,学习它们就能让我们的模型更接近最好的语言模型。

统计学与 LLM 的映射关系

统计学的拒绝采样有几个关键要素:

  1. 原始分布采样困难,提议分布采样简单;
  2. 提议分布缩放后能涵盖原始分布;
  3. 有办法判断从提议分布获取的样本是否属于原始分布,这需要我们知道原始分布的密度函数。

LLM 的拒绝采样也有几个对应的关键要素:

  1. 我们不知道最好的语言模型怎么说话,但我们知道自己的语言模型如何说话;
  2. 让自己的语言模型反复说话,得到的语料大概率会包括最好的语言模型的说话方式;
  3. reward_model 可以判断某句话是否属于最好的语言模型的说话方式。

目前为止,是不是看上去很有道理,很好理解。但其实这里有一个致命的逻辑漏洞:为什么我们的模型反复 do_sample,就一定能覆盖最好的语言模型呢?这不合逻辑啊,狗嘴里采样多少次也吐不出象牙啊。

紧接着,就需要我们引出另一个概念了:RLHF 的优化目标是什么?

RLHF 与拒绝采样

RLHF 的优化目标,并不是获得说话说的最好的模型,而是获得 reward_model 和 reference_model (被优化的模型)共同认可的模型。这个观点,可以从 RLHF 的最优解看出来:

其中, 是归一化分子,可以无视; 是 reference_model,同时也是被优化的模型的最初状态;

是 reward_model;证明过程可以自己行搜索,不再赘述。

现在,大家应该明白了吧:在 RLHF 的训练框架下,reward_model 认为谁是最好的语言模型,谁就是最好的语言模型,人类的观点并不重要。与此同时,即使 reward_model 告诉了我们最好的语言模型距离当前十公里,但 reference_model 每次只允许我们走两公里,所以 RLHF 需要反复迭代进行。

此时,我们再回过头来看拒绝采样,就能理解它的核心思想了:

  • 通过对原始模型 do_sample,我们获得了很多个样本,每个样本代表一个优化方向;
  • reward_model 知道最优的目标分布在哪个方向,它帮我们选择了一个正确的方向;
  • 沿着这个方向,我们小心翼翼的往前迈了一步(学习模型本就能生成出来的话,大概率不会让模型发生较大的改变);
  • 和 rlhf 一样,拒绝采样需要多次迭代才能到达 reward_model 认为的最好的语言分布。

灵感来源:https://mp.weixin.qq.com/s/2txfqHpyiW-ipKuQSWAsLA

统计学知识来源:ChatGPT_o1

写在最后

总结下来:拒绝采样虽然没有使用 ppo,但满满的都是 rlhf 的思想。

PS:看到这里,如果觉得不错,可以来个点赞在看关注 。给公众号添加【星标⭐️】不迷路!您的支持是我坚持的最大动力!

欢迎多多关注公众号「NLP工作站」,加入交流群,交个朋友吧,一起学习,一起进步!

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论