TRL's Implementation of the GRPO Trainer


Overview

TRL supports the GRPO Trainer for training language models, as described in the paper DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models.

Quick start

This example demonstrates how to train a model with the GRPO method. We train a Qwen 0.5B Instruct model on prompts from the TLDR dataset (the completion column is ignored!). You can preview the data in the dataset here:

[Dataset viewer: the trl-lib/tldr dataset]

Below is the script used to train the model:

# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

Run the script with the following command:

accelerate launch train_grpo.py

Distributed across 8 GPUs, training takes approximately 1 day.


How does it work?

Fine-tuning a language model via PPO roughly consists of three steps:

Rollout: The language model generates a response or continuation based on a query, which could be the start of a sentence.

Evaluation: The query and response are evaluated with a function, a model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. The optimization will aim at maximizing this value.

Optimization: This is the most complex part. In the optimization step, the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is being trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL divergence between the two outputs is used as an additional reward signal to make sure the generated responses do not deviate too far from the reference language model. The active language model is then trained with PPO.
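To make the optimization step concrete, here is a tiny self-contained toy (not TRL code) showing how a per-token KL penalty, derived from the policy and reference log-probabilities, is subtracted from the scalar reward; all values are dummies:

import torch

# Dummy per-token log-probabilities of one response under the trained policy
# and under the frozen reference model
logp = torch.tensor([-1.2, -0.8, -2.1])
ref_logp = torch.tensor([-1.0, -0.9, -1.8])

beta = 0.1                       # KL penalty coefficient (illustrative value)
per_token_kl = logp - ref_logp   # simple per-token KL estimate
reward = 0.75                    # scalar reward for this query/response pair

# The quantity PPO then maximizes: the reward minus the KL penalty
shaped_reward = reward - beta * per_token_kl.sum()
print(shaped_reward)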

The whole process is illustrated in the figure below:

[Figure: the rollout, evaluation, and optimization steps of PPO fine-tuning]

Looking deeper into the GRPO method

GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: generating completions, computing the advantage, estimating the KL divergence, and computing the loss.


Generating completions

At each training step, we sample a batch of prompts and generate a set of $G$ completions for each prompt (denoted $o_i$).
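In TRL, the group size $G$ is a configuration option. A minimal sketch, assuming your TRL version exposes num_generations (and max_completion_length) on GRPOConfig:

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    num_generations=8,           # G: number of completions sampled per prompt
    max_completion_length=128,   # cap on the length of each sampled completion
)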

Computing the advantage

For each of the $G$ sequences, we compute a reward using the reward model. To align with the comparative nature of reward models (they are typically trained on datasets of comparisons between outputs for the same question), the advantage is computed to reflect these relative comparisons. It is normalized as follows:

$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}, \quad \text{where } \mathbf{r} = (r_1, r_2, \ldots, r_G)$$

This approach gives the method its name: Group Relative Policy Optimization (GRPO).
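A quick numeric sketch of this normalization for a single group of rewards (dummy values; the small epsilon in the denominator is an implementation-level assumption for numerical stability):

import torch

# Rewards of the G completions generated for one prompt (dummy values)
rewards = torch.tensor([0.2, 1.0, 0.5, 0.8])

# Group-relative advantage: normalize by the group's mean and standard deviation
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
print(advantages)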

Estimating the KL divergence

The KL divergence is estimated using the approximator introduced by Schulman et al. (2020). The approximator is defined as follows:

$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \,\|\, \pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1$$
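On dummy per-token log-probabilities, the approximator can be computed like this (a sketch, not TRL internals):

import torch

logp = torch.tensor([-1.2, -0.8, -2.1])      # log pi_theta(o_t | q, o_<t)
ref_logp = torch.tensor([-1.0, -0.9, -1.8])  # log pi_ref(o_t | q, o_<t)

ratio = torch.exp(ref_logp - logp)           # pi_ref / pi_theta
per_token_kl = ratio - torch.log(ratio) - 1  # always >= 0
print(per_token_kl)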

Computing the loss

The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:

$$\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,<t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta\, \mathbb{D}_{\text{KL}}\left[\pi_\theta \,\|\, \pi_{\text{ref}}\right] \right]$$

where the first term represents the scaled advantage and the second term penalizes deviation from the reference policy through the KL divergence.
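Putting the two terms together for a single completion, a minimal sketch of this loss on dummy tensors could look as follows; dividing pi_theta by its detached value keeps the first term numerically equal to the advantage while letting gradients flow through log pi_theta:

import torch

logp = torch.tensor([-1.2, -0.8, -2.1], requires_grad=True)  # log pi_theta per token
ref_logp = torch.tensor([-1.0, -0.9, -1.8])                  # log pi_ref per token (frozen)
advantage = torch.tensor(0.35)                               # group-relative advantage
beta = 0.04                                                  # KL penalty coefficient

# First term: pi_theta / [pi_theta]_no-grad * advantage; equal in value to the
# advantage, but gradients flow through log pi_theta
scaled_advantage = torch.exp(logp - logp.detach()) * advantage

# Second term: the KL approximator from the previous section
ratio = torch.exp(ref_logp - logp)
per_token_kl = ratio - torch.log(ratio) - 1

loss = -(scaled_advantage - beta * per_token_kl).mean()
loss.backward()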

In the original paper, this formulation is generalized to account for multiple updates after each generation by using the clipped surrogate objective:

$$\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min\!\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})} \hat{A}_{i,t},\; \text{clip}\!\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta\, \mathbb{D}_{\text{KL}}\left[\pi_\theta \,\|\, \pi_{\text{ref}}\right] \right]$$

where $\text{clip}(\cdot, 1-\epsilon, 1+\epsilon)$ ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between $1-\epsilon$ and $1+\epsilon$. In TRL, as in the original paper, we only perform one update per generation, so the loss can be simplified to the first form.
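For completeness, here is a sketch of the clipped term of the generalized objective on dummy tensors (epsilon is the clipping range; as noted above, TRL's default single-update-per-generation setting does not need this path):

import torch

logp = torch.tensor([-1.2, -0.8, -2.1])      # log pi_theta per token
old_logp = torch.tensor([-1.1, -0.9, -2.0])  # log pi_theta_old per token
advantage = torch.tensor(0.35)
epsilon = 0.2

ratio = torch.exp(logp - old_logp)           # policy ratio pi_theta / pi_theta_old
unclipped = ratio * advantage
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantage
surrogate = torch.min(unclipped, clipped)    # per-token clipped surrogate term
print(surrogate)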

Logged metrics

The GRPO Trainer logs the following metrics:

  • completion_length : The average completion length.
  • reward/{reward_func_name} : The reward computed by each reward function.
  • reward : The average reward.
  • reward_std : The average standard deviation within reward groups.
  • kl : The average KL divergence between the model and the reference model calculated on completions.
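These metrics go through the standard Trainer logging, so you can also read them programmatically after training. A small sketch, reusing the trainer object from the quick-start script and assuming the metric names above appear as keys in the log history:

# Inspect the logged GRPO metrics after trainer.train()
for entry in trainer.state.log_history:
    if "reward" in entry:
        print(entry.get("step"), entry["reward"], entry.get("kl"))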

Customization

Speed up training with vLLM-powered generation

Generation is usually the main bottleneck that makes online training slow. To speed up generation, you can use vLLM, a library that enables fast generation. To enable it, pass use_vllm=True in the training arguments:

from trl import GRPOConfig

training_args = GRPOConfig(..., use_vllm=True)

For more information, see Speeding up training with vLLM.

When using vLLM, an additional GPU is dedicated to generation. This means you need at least two available GPUs and must make sure one of them is not used by the trainer. To achieve this, run training with --num_processes <NUMBER_OF_GPUs - 1>.

For example, if you have 4 GPUs, set --num_processes 3 to allocate three GPUs for training while reserving one for generation:

accelerate launch --multi_gpu --num_processes 3 train_grpo.py


Using a custom reward function

The GRPOTrainer supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:

    1. Input arguments:
  • The function must accept the following as keyword arguments:
    • prompts (contains the prompts),
    • completions (contains the generated completions),
    • all column names (except prompt) that the dataset may have. For example, if the dataset contains a column named ground_truth, the function will be called with ground_truth as a keyword argument.
  • The easiest way to comply with this requirement is to use **kwargs in the function signature.
  • Depending on the dataset format, the input will vary:
    • For standard format, prompts and completions will be lists of strings.
    • For conversational format, prompts and completions will be lists of message dictionaries.
    2. Return value: The function must return a list of floats, where each float represents the reward for a single completion (a minimal skeleton is sketched below).
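As a minimal sketch of a compliant signature (the column name extra_column is hypothetical, only for illustration):

def my_reward_func(prompts, completions, **kwargs):
    # Any additional dataset column (e.g. a hypothetical "extra_column") arrives via kwargs
    extra_column = kwargs.get("extra_column")
    # Return one float per completion
    return [0.0 for _ in completions]
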
Example 1: Reward longer completions

Here is an example of a reward function for the standard format that rewards longer completions:

def reward_func(completions, **kwargs):
    """Reward function that gives higher scores to longer completions."""
    return [float(len(completion)) for completion in completions]

You can test it as follows:

prompts = ["The sky is", "The sun is"]
completions = [" blue.", " in the sky."]
print(reward_func(prompts=prompts, completions=completions))
# [6.0, 12.0]
Example 2: Reward completions with specific format

Below is an example of a reward function that checks whether the completion has a specific format. This example is inspired by the format reward function used in the paper DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. It is designed for the conversational format, where prompts and completions consist of structured messages.

import re

def format_reward_func(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]

You can test this function as follows:

prompts = [
    [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
    [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
]
completions = [
    [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
    [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
]
print(format_reward_func(prompts=prompts, completions=completions))
# [1.0, 0.0]
Example 3: Reward completions based on a reference

Below is an example of a reward function that checks whether the completion is correct. This example is inspired by the accuracy reward function used in the paper DeepSeek-R1. It is designed for the standard format, where the dataset contains a column named ground_truth.

import re

def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]

You can test this function as follows:

prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
ground_truth = ["2", "5"]
print(reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth))
# [1.0, 0.0]
Passing the reward function to the trainer

To use your custom reward function, pass it to the GRPOTrainer as follows:

from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=reward_func,
    ...,
)

If you have multiple reward functions, you can pass them as a list:

from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    ...,
)

The reward will then be computed as the sum of the rewards from each function, or as a weighted sum if reward_weights is provided in the config.
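A sketch of configuring the weighted combination, assuming reward_weights is available on GRPOConfig in your TRL version (the weights follow the order of reward_funcs):

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    reward_weights=[0.8, 0.2],  # one weight per reward function, in the same order as reward_funcs
)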
