“ 大白话+代码 理解rlhf
rlhf用到的强化学习算法是ppo,ppo是一种特殊的a2c算法,而a2c算法是actor-critic算法的改良版,
actor-critic算法 中包含actor、critic2个模型,其中演员actor是最终需要的模型,负责选择动作,基于当前环境状态S, 选择动作A,这个P(A|S)被称为策略;而critic模型主要用于评估这个动作的收益,从S开始,选择动作A,直到结束能够得到的奖励和期望Q(S,A),这个也叫状态动作价值。
演员actor模型用 更新参数;critic模型用 更新参数;其中是环境给出的真实奖励反馈
A2C算法 中引入优势,对actor、critic模型改良,表示是否超出预期
前2步不变:演员actor模型基于环境,得到动作;环境基于,给出奖励及新状态
评论家critc模型需要计算优势,先估计价值,优势
actor模型基于 更新,只是把Q换成了adv优势
critic模型基于 更新
ppo 的思路是维持训练稳定性,让更新幅度不要太大,在ppo中actor的更新损失是
p是本次参数更新前的策略,p'是上一次参数更新前的策略,当上一次已经很大了,预测的概率很高了,就没必要使劲更新参数,从而保持训练稳定性,在ppo中设置截断策略,令,当adv > 0 ,r > 1.2不更新, adv < 0, r < 0.8不更新,
上loss,如果取到右边的值,相当于损失中只有常量了,就不产生任何的梯度
rlhf
step1: 采样, 根据prompt,利用policy可以得到 response(答案)、概率((数量词表维度) )、 values (每次解码每个token的价值)
step2:计算reward,reward基于 参考模型 - actor模型 之间概率之间的kl 散度, 最后一个位置为完成位置加上reward模型的打分
#https://github.com/huggingface/trl/blob/main/trl/trainer/ppo\_trainer.py#L1074
def compute\_rewards(
self,
scores: torch.FloatTensor,
logprobs: torch.FloatTensor,
ref\_logprobs: torch.FloatTensor,
masks: torch.LongTensor,
):
rewards, non_score_rewards = [], []
for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
# compute KL penalty (from difference in logprobs)
kl = self._kl_penalty(logprob, ref_logprob)
non_score_reward = -self.kl_ctl.value * kl
non_score_rewards.append(non_score_reward)
reward = non_score_reward.clone()
last_non_masked_index = mask.nonzero()[-1]
# reward is preference model score + KL penalty
reward[last_non_masked_index] += score
rewards.append(reward)
return torch.stack(rewards), torch.stack(non_score_rewards)
step3: 计算优势advantage, 先根据价值 和 rewards值来 估计优势adv
评论家critc模型需要计算优势,先估计价值,优势
#https://github.com/huggingface/trl/blob/830cadfc4c80bdced0d3753de392070e6760d1f5/trl/trainer/ppo\_trainer.py#L1122
def compute\_advantages(
self,
values: torch.FloatTensor,
rewards: torch.FloatTensor,
mask: torch.FloatTensor,
):
lastgaelam = 0
advantages_reversed = []
gen_len = rewards.shape[-1]
values = values * mask
rewards = rewards * mask
if self.config.whiten_rewards:
rewards = masked_whiten(rewards, mask, shift_mean=False)
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
# 这里对应到 根据next\_values以及当前values计算 $Adv(S\_{t},A\_{t})=V(S\_{t+1}) + R\_{t} - V(S\_{t})$
delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
returns = advantages + values
advantages = masked_whiten(advantages, mask)
advantages = advantages.detach()
return values, advantages, returns
step4: 计算actor和critic模型的损失
在上面ppo的公式中:写到如下公式
actor模型基于
critic模型为response中的每个token计算一个预期收益,第i个预期收益记为values[i],它预估的是
critic模型基于期望价值和实际价值的 mse:
https://github.com/huggingface/trl/blob/830cadfc4c80bdced0d3753de392070e6760d1f5/trl/trainer/ppo_trainer.py#L1179
vpredclipped = clip_by_value(
vpreds,
values - self.config.cliprange_value,
values + self.config.cliprange_value,
)
# critic loss
vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpredclipped - returns) ** 2
vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)
# actor loss
ratio = torch.exp(logprobs - old_logprobs)
pg_losses = -advantages * ratio
# 还有个截断策略
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)
loss = pg_loss + self.config.vf_coef * vf_loss
总结:
ref_model (参数不更新)
actor_model (参数更新,最终使用的模型)
reward_model (参数不更新)
value_model (用来估计期望价值,参数更新)