Overview
TRL supports the GRPO Trainer for training language models, as described in the paper DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models.
Quick start
This example demonstrates how to train a model with the GRPO method. We train a Qwen 0.5B Instruct model on prompts from the TLDR dataset (the completion column is ignored!). You can view the data in the dataset here:
[Dataset preview: trl-lib/tldr]
Below is the script to train the model.
# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
Execute the script with the following command:
accelerate launch train_grpo.py
Distributed across 8 GPUs, training takes approximately 1 day.
How does it work?
Fine-tuning a language model via PPO consists of roughly three steps:
Rollout: The language model generates a response or continuation based on a query, which could be the start of a sentence.
Evaluation: The query and response are evaluated with a function, model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. The optimization will aim at maximizing this value.
Optimization: This is the most complex part. In the optimization step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
The whole process is illustrated in the figure below:
[Figure: overview of the PPO fine-tuning process]
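To make the interplay between the evaluation score and the KL penalty concrete, here is a minimal sketch of how a per-token training signal could be assembled; the function and argument names (combine_score_and_kl, score, logprobs_active, logprobs_ref, kl_coef) are illustrative and not TRL's internal API.

import torch

def combine_score_and_kl(score, logprobs_active, logprobs_ref, kl_coef=0.1):
    """Sketch: build per-token rewards from a scalar score and a KL penalty.

    score: scalar value produced in the evaluation step for one query/response pair.
    logprobs_active, logprobs_ref: per-token log-probabilities of the response under
    the model being trained and under the frozen reference model.
    """
    kl_per_token = logprobs_active - logprobs_ref  # per-token contribution to the KL divergence
    rewards = -kl_coef * kl_per_token              # penalize drifting away from the reference model
    rewards[-1] = rewards[-1] + score              # credit the scalar score on the final token
    return rewards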
Looking deeper into the GRPO method
GRPO is an online learning algorithm, meaning it improves iteratively by using data generated by the trained model itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: generating completions, computing the advantage, estimating the KL divergence, and computing the loss.
Generating completions
At each training step, we sample a batch of prompts and generate a set of $G$ completions for each prompt (denoted $o_i$).
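As a standalone sketch of what a group of $G$ completions per prompt means (this uses plain transformers generation for illustration and is not how GRPOTrainer generates internally; the prompt string and sampling settings are made up):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

G = 4  # group size: number of completions o_1, ..., o_G per prompt
inputs = tokenizer("Summarize: the post explains why the sky is blue.", return_tensors="pt")
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=64, num_return_sequences=G)
completions = tokenizer.batch_decode(
    outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)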
Computing the advantage
For each of the $G$ sequences, we compute the reward using a reward model. To align with the comparative nature of reward models (they are typically trained on datasets of comparisons between outputs for the same question), the advantage is calculated to reflect these relative comparisons. It is normalized as follows:

$$\hat{A}_{i,t} = \frac{r_i - \mathrm{mean}(\mathbf{r})}{\mathrm{std}(\mathbf{r})}$$

This approach gives the method its name: Group Relative Policy Optimization (GRPO).
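For a single group, this normalization can be sketched in a few lines (the reward values are made up; a small epsilon is added for numerical stability):

import torch

rewards = torch.tensor([0.5, 1.0, 0.0, 2.0])  # r_1, ..., r_G for one group of completions
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)  # group-relative advantages
print(advantages)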
Estimating the KL divergence
The KL divergence is estimated using the approximator introduced by Schulman et al. (2020). The approximator is defined as follows:

$$\mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right] = \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1$$
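Expressed with per-token log-probabilities, the approximator can be sketched as follows (per_token_logps and ref_per_token_logps are hypothetical tensors of log-probabilities under the trained and reference models):

import torch

def approx_kl(per_token_logps, ref_per_token_logps):
    """k3-style estimator: exp(log r) - log r - 1, with log r = log pi_ref - log pi_theta."""
    log_ratio = ref_per_token_logps - per_token_logps
    return torch.exp(log_ratio) - log_ratio - 1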
Computing the loss
The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:

$$\mathcal{L}_{\mathrm{GRPO}}(\theta) = -\frac{1}{G}\sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,<t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta\, \mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right] \right]$$

where the first term represents the scaled advantage and the second term penalizes deviation from the reference policy through the KL divergence.

In the original paper, this formulation is generalized to account for multiple updates after each generation by using the clipped surrogate objective:

$$\mathcal{L}_{\mathrm{GRPO}}(\theta) = -\frac{1}{G}\sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min\!\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})} \hat{A}_{i,t},\; \mathrm{clip}\!\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}, 1-\epsilon, 1+\epsilon \right) \hat{A}_{i,t} \right) - \beta\, \mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right] \right]$$

where $\mathrm{clip}(\cdot, 1-\epsilon, 1+\epsilon)$ ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between $1-\epsilon$ and $1+\epsilon$. In TRL, as in the original paper, we only perform one update per generation, so the loss can be simplified to the first form.
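Putting the pieces together, the simplified single-update loss can be sketched as below (shapes and the beta value are illustrative, and padding masks are ignored for brevity; the inputs correspond to the log-probabilities, advantages, and KL estimates from the sketches above):

import torch

def grpo_loss(per_token_logps, advantages, per_token_kl, beta=0.04):
    """Sketch of the simplified GRPO loss: scaled advantage minus a KL penalty.

    per_token_logps: (G, T) log-probabilities of the sampled tokens under the trained model.
    advantages:      (G,)   group-relative advantages, one per completion.
    per_token_kl:    (G, T) per-token KL estimates against the reference model.
    """
    # exp(x - x.detach()) equals 1 in value but keeps the gradient of x, which
    # reproduces the "scaled advantage" term of the single-update objective.
    ratio = torch.exp(per_token_logps - per_token_logps.detach())
    per_token_loss = -(ratio * advantages.unsqueeze(1) - beta * per_token_kl)
    return per_token_loss.mean(dim=1).mean()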
Logged metrics
The GRPO Trainer logs the following metrics:
- completion_length: The average completion length.
- reward/{reward_func_name}: The reward computed by each reward function.
- reward: The average reward.
- reward_std: The average standard deviation within reward groups.
- kl: The average KL divergence between the model and the reference model, calculated on completions.
Customization
Speed up training with vLLM-powered generation
Generation is often the main bottleneck that makes online training slow. To accelerate generation, you can use vLLM, a library that enables fast generation. To enable it, pass use_vllm=True in the training arguments.
from trl import GRPOConfig
training_args = GRPOConfig(..., use_vllm=True)
For more information, see Speeding up training with vLLM.
When using vLLM, an additional GPU is required and dedicated to generation. This means you need at least two available GPUs and must make sure that one of them is not used by the trainer. To achieve this, run the training with --num_processes <NUMBER_OF_GPUs - 1>.
For example, if you have 4 GPUs, set --num_processes 3 to allocate three GPUs for training while reserving one for generation.
accelerate launch --multi_gpu --num_processes 3 train_grpo.py
Using a custom reward function
The GRPOTrainer supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:
- Input arguments:
  - The function must accept the following as keyword arguments:
    - prompts (contains the prompts),
    - completions (contains the generated completions),
    - all column names (but prompt) that the dataset may have. For example, if the dataset contains a column named ground_truth, the function will be called with ground_truth as a keyword argument.
    The easiest way to comply with this requirement is to use **kwargs in the function signature.
  - Depending on the dataset format, the input will vary:
    - For the standard format, prompts and completions will be lists of strings.
    - For the conversational format, prompts and completions will be lists of message dictionaries.
- Return value: The function must return a list of floats. Each float represents the reward corresponding to a single completion.
Example 1: Reward longer completions
Here is an example of a reward function for the standard format that rewards longer completions:
def reward_func(completions, **kwargs):
    """Reward function that gives higher scores to longer completions."""
    return [float(len(completion)) for completion in completions]
You can test it as follows:
prompts = ["The sky is", "The sun is"]
completions = [" blue.", " in the sky."]
print(reward_func(prompts=prompts, completions=completions))
# [6.0, 12.0]
Example 2: Reward completions with specific format
Below is an example of a reward function that checks whether the completion has a specific format. This example is inspired by the format reward function used in the paper DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. It is designed for the conversational format, where prompts and completions consist of structured messages.
import re
def format_reward_func(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
You can test this function as follows:
prompts = [
    [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
    [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
]
completions = [
    [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
    [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
]
format\_reward\_func(prompts=prompts, completions=completions)
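The first completion follows the required <think>…</think><answer>…</answer> pattern and the second does not, so the call above should return:
# [1.0, 0.0]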
Example 3: Reward completions based on a reference
Below is an example of a reward function that checks whether the completion is correct. This example is inspired by the accuracy reward function used in the paper DeepSeek-R1. It is designed for the standard format, where the dataset contains a column named ground_truth.
import re
def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
You can test this function as follows:
prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
ground\_truth = ["2", "5"]
reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
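Here the first extracted answer ("2") matches its ground truth while the second ("6" vs. "5") does not, so the call above should return:
# [1.0, 0.0]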
Passing the reward function to the trainer
To use your custom reward function, pass it to the GRPOTrainer as follows:
from trl import GRPOTrainer
trainer = GRPOTrainer(
    reward_funcs=reward_func,
    ...,
)
If you have multiple reward functions, you can pass them as a list:
from trl import GRPOTrainer
trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    ...,
)
The reward will then be computed as the sum of the rewards from each function, or as a weighted sum if reward_weights is provided in the config.
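For example, assuming reward_weights is matched to reward_funcs in order, a weighted combination might look like this sketch (the weight values are illustrative):

from trl import GRPOConfig, GRPOTrainer

# Give the second reward function more influence than the first (illustrative weights).
training_args = GRPOConfig(..., reward_weights=[0.2, 0.8])
trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    args=training_args,
    ...,
)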