推理模型实战 | 如何训练自己的R1模型(下篇):GRPO训练

大模型向量数据库机器学习
  1. 引言 =====

今天继续从实战角度 介绍如何基于Unsloth训练私有、定制化的推理模型R1。具体到模型训练目标: 利用OpenR1的Math数据集,通过GRPO将Qwen3-4B-Base训练成推理模型

上篇中将聚焦于如何通过监督微调(SFT)技术,训练出一个能严格遵循自定义GRPO格式的基座模型。这一步至关重要,它将为后续训练提供一个高质量的初始化模型。本文作为下篇 ,我们将以此模型为起点,深入探讨如何应用GRPO算法进行进一步优化,最终打造出我们自己的目标R1模型。

上篇+下篇的完整代码 可以前往微信公众号"小窗幽记机器学习",回复"推理模型实战1 "获取。如想进一步与小编进一步交流也可以在公众号"小窗幽记机器学习"上添加小编好友。

往期推理模型相关文章回顾:

推理模型专题|一文纵览DeepSeek模型家族:从LLM到R1

推理模型专题|深度揭秘DeepSeek R1 背后的强化学习

推理模型专题|DeepSeek-R1如何用强化学习、冷启动和蒸馏,开启大模型训练新思路?

推理模型专题 | DeepSeek-R1比肩OpenAI满血版o1(技术报告解读)

推理模型专题|LLM推理中的强化学习及其实战:以GRPO为例

推理模型实战 | 如何训练自己的R1模型(上篇):GRPO前奏预微调SFT

更多大模型实战相关欢迎关注公众号"小窗幽记机器学习":

  1. 简介 =====

承接上文,我们将使用上一次预微调得到的模型,作为GRPO方法训练的基础模型。对于RL(Reforce Learning)强化学习,我们需要引入一些相应的背景知识。

强化学习的目标是:

  • 增加获得「好」结果的几率。
  • 降低出现「坏」结果的几率。

「好」结果与「坏」结果可以由人为去定义,没有绝对意义的「坏」结果。我们将通过不断的训练及奖励函数的约束,使得模型的输出离「坏」结果越来越远,离「好」结果越来越近。

本文将重点介绍如何构建Reward函数,设计实验方案,并开展具体的强化学习实践。

  1. 准备工作 =======

2.1 训练数据集

以下将用HuggingFace上 open-r1 数学数据集用于本次实验。

  
from datasets import load\_dataset  
dataset = load\_dataset("open-r1/DAPO-Math-17k-Processed", "en", split = "train")  

先简要分析一下数据格式,其中Open-r1的主要由数学算法思考题组成,其数据概览如下:

  
dataset  
# 数据集概览   
# Dataset({  
#     features: ['prompt', 'solution', 'data\_source', 'source\_prompt', 'ability', 'reward\_model', 'extra\_info'],  
#     num\_rows: 14116  
# })  
  
dataset[0]["prompt"]  
"""  
In triangle $ABC$, $\sin \angle A = \frac{4}{5}$ and $\angle A < 90^\circ$. Let $D$ be a point outside triangle $ABC$ such that $\angle BAD = \angle DAC$ and $\angle BDC = 90^\circ$. Suppose that $AD = 1$ and that $\frac{BD}{CD} = \frac{3}{2}$. If $AB + AC$ can be expressed in the form $\frac{a\sqrt{b}}{c}$ where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.  
中文解释:根据题目给定的三角形ABC的基本条件,计算互为质数的abc三数之和。  
"""   
  
dataset[0]["solution"]  
"""  
34  
"""  

2.2 构造训练Prompt数据

根据Prompt格式模版,构造训练用的数据格式。

  
dataset = dataset.map(lambda x: {  
    "prompt" : [  
        {"role": "system", "content": system\_prompt},  
        {"role": "user",   "content": x["prompt"]},  
    ],  
    "answer": extract\_hash\_answer(x["solution"]),  
})  

比如上述第一条示例数据,按照模板构造完的数据格式如下所示:

  
{'prompt': [{'content': 'You are given a problem.\nThink about the problem and provide your working out.\nPlace it between <start\_working\_out> and <end\_working\_out>.\nThen, provide your solution between <SOLUTION></SOLUTION>',  
   'role': 'system'},  
  {'content': 'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$ be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$ and $\\angle BDC = 90^\\circ$. Suppose that $AD = 1$ and that $\\frac{BD}{CD} = \\frac{3}{2}$. If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$ where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.',  
   'role': 'user'}],  
 'solution': '34',  
 'data\_source': 'math\_dapo',  
 'source\_prompt': [{'content': 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.\n\nIn triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$ be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$ and $\\angle BDC = 90^\\circ$. Suppose that $AD = 1$ and that $\\frac{BD}{CD} = \\frac{3}{2}$. If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$ where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.\n\nRemember to put your answer on its own line after "Answer:".',  
   'role': 'user'}],  
 'ability': 'MATH',  
 'reward\_model': {'ground\_truth': '34', 'style': 'rule-lighteval/MATH\_v2'},  
 'extra\_info': {'index': '9a9b6eb4-a1cb-49d1-8c1e-62eaf2f74079'},  
 'answer': '34'}  

2.3 制定验证器

为了判断模型响应后的结果是否正确,我们需要准确提取出答案,让其能顺利通过后续的奖励函数。通常来讲,这两个模块是结合在同一个function内使用的,互相配合返回Reward分数。

首先,我们通过制定正则脚本来匹配训练完数据生成的reasoning模块以及答案answer模块的内容。

  
import re  
  
# 可以选择是否要添加EOS\_token匹配模块。  
solution\_end\_regex = r"</SOLUTION>[\s]{0,}" + \  
    "(?:" + re.escape(tokenizer.eos\_token) + ")?"  
  
match\_format = re.compile(  
    rf"{reasoning\_end}.*?"\  
    rf"{solution\_start}(.+?){solution\_end\_regex}"\  
    rf"[\s]{{0,}}$",  
    flags = re.MULTILINE | re.DOTALL  
)  

上述的正则匹配主要用于获取以下部分的结构,匹配时允许 <SOLUTION> 内部跨行,允许结尾是 <SOLUTION> + 空格 + <eos>,并严格要求匹配到文本结尾。

  
...<REASONING>...</REASONING>  
<SOLUTION> 这部分内容将被提取 </SOLUTION> [空格] [可选的eos\_token]  

  1. 设定奖励函数 =========

如何设定合适的奖励函数?

举个简单的例子。将你的生成结果输入到 ChatGPT 4o 或 Llama 3.1 (8B) 等 LLM 中,并设计一个奖励函数和验证器来评估它。例如,将你的生成结果输入到你选择的 LLM 中,并设置一条规则:「如果答案听起来太机械化,则扣 3 分。」这有助于根据质量标准优化输出。

我们也可以参考Unsloth官方于2025 AI Engineer大会上PPT的趣味讲解来讲述如何设定奖励函数。

想必我们大多数人童年都玩过一个叫吃豆人的游戏,游戏的规则是这样:在一个迷宫中,吃尽可能多的豆子,同时会有捕捉玩家的“警官”,如果玩家的角色正面遭遇“警官”则会导致游戏结束。

picture.image

那么,根据吃豆人游戏的经验,我们根据游戏规则拆解,根据角色当前位置、最优路线的抉择来设定奖励函数。在这里我们制定以下奖励规则:

  1. 吃到豆且远离“警察” 奖励 1分。
  2. 没吃到豆但远离“警察”,无奖励。
  3. 往“警察”的方向走,则扣10分。

最后归一化这些操作的权重后,得到对应的奖励分数,如下图所示:

picture.image

那么,回到本次实践的内容,我们现在创建一个奖励函数来完全匹配格式——如果它成功了,我们奖励它3分;如果它失败了,但是通过计算每个符号,发现它至少部分遵循了格式,我们也想要适当奖励模型;最后,我们想要提取生成的答案,奖励或惩罚它!我们还会根据答案与真实答案的接近程度进行奖励,以下是这部分的代码逻辑实现:

  
def match\_format\_exactly(completions, **kwargs):  
   "奖励答案是否正确"  
    scores = []  
    for completion in completions:  
        score = 0  
        response = completion[0]["content"]  
        # 如果匹配的格式正确  
        if match\_format.search(response) isnotNone: score += 3.0  
        scores.append(score)  
    return scores  
  
def match\_format\_approximately(completions, **kwargs):  
"计算匹配格式正确的数量"  
    scores = []  
    for completion in completions:  
        score = 0  
        response = completion[0]["content"]  
        # 计算有多少关键词被看到-如果出现过多<start><\end>的标签字段,那么就给予惩罚  
    # 如果看到只出现一次,则+0.5分,否则扣1分  
        # 不需要奖励<start\_working\_out>,因为我们总是把它加在前面  
        # score += 0.5 if response.count(reasoning\_start) == 1 else -1.0  
        score += 0.5if response.count(reasoning\_end)   == 1else-1.0  
        score += 0.5if response.count(solution\_start)  == 1else-1.0  
        score += 0.5if response.count(solution\_end)    == 1else-1.0  
        scores.append(score)  
    return scores  
  
def check\_answer(prompts, completions, answer, **kwargs):  
   "校验答案的正确性"  
    question = prompts[0][-1]["content"]  
    responses = [completion[0]["content"] for completion in completions]  
  
    extracted\_responses = [  
        guess.group(1)  
        if (guess := match\_format.search(r)) isnotNoneelseNone \  
        for r in responses  
    ]  
  
    scores = []  
    for guess, true\_answer in zip(extracted\_responses, answer):  
        score = 0  
        if guess isNone:  
            scores.append(-2.0)  
            continue  
        # 完全正确的答案获得5分  
        if guess == true\_answer:  
            score += 5.0  
        # 答案正确但结果带有空格分数减少  
        elif guess.strip() == true\_answer.strip():  
            score += 3.5  
        else:  
            # 当响应答案和真实答案相近的时候,也给予奖励分数  
            try:  
                ratio = float(guess) / float(true\_answer)  
                if   ratio >= 0.9and ratio <= 1.1: score += 2.0  
                elif ratio >= 0.8and ratio <= 1.2: score += 1.5  
                else: score -= 2.5# 惩罚接近正确但还是错误的答案  
            except:  
                score -= 4.5# 惩罚完全偏离的答案  
        scores.append(score)  
    return scores  

有时候答案可能不会是完美的数字,我们需要从答案中去匹配到准确的数字结果。我们通过正则匹配来实现这个操作:

  
match\_numbers = re.compile(  
    solution\_start + r".*?[\s]{0,}([-]?[\d\.\,]{1,})",  
    flags = re.MULTILINE | re.DOTALL  
)  
global PRINTED\_TIMES  
PRINTED\_TIMES = 0  
global PRINT\_EVERY\_STEPS  
PRINT\_EVERY\_STEPS = 5  
  
def check\_numbers(prompts, completions, answer, **kwargs):  
    question = prompts[0][-1]["content"]  
    responses = [completion[0]["content"] for completion in completions]  
  
    extracted\_responses = [  
        guess.group(1)  
        if (guess := match\_numbers.search(r)) isnotNoneelseNone \  
        for r in respons es  
    ]  
  
    scores = []  
    global PRINTED\_TIMES  
    global PRINT\_EVERY\_STEPS  
    # 打印生成的响应和真实的答案,校验是否有不同  
    if PRINTED\_TIMES % PRINT\_EVERY\_STEPS == 0:  
        print(  
            '*'*20 + f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted\_responses[0]}"  
        )  
    PRINTED\_TIMES += 1  
  
    for guess, true\_answer in zip(extracted\_responses, answer):  
        if guess isNone:  
            scores.append(-2.5)  
            continue  
        # 转换成数字  
        try:  
            true\_answer = float(true\_answer.strip())  
            # 去除像123,456这样的结果出现  
            guess       = float(guess.strip().replace(",", ""))  
            scores.append(3.5if guess == true\_answer else-1.5)  
        except:  
            scores.append(0)  
            continue  
    return scores  

修饰过长的Prompt,获得仅有最大长度90%的Prompt数据,这样我们就不会不小心截断它们,我们将删除前10%的过长Prompt。

  
tokenized = dataset.map(  
    lambda x: {"tokens" : tokenizer.apply\_chat\_template(x["prompt"], add\_generation\_prompt = True, tokenize = True)},  
    batched = True,  
)  
print(tokenizer.decode(tokenized[0]["tokens"]))  
tokenized = tokenized.map(lambda x: {"L" : len(x["tokens"])})  
  
import numpy as np  
maximum\_length = int(np.quantile(tokenized["L"], 0.9))  
print("Max Length = ", maximum\_length)  
  
# 过滤得到小于最大长度90%的Prompt数据  
dataset = dataset.select(np.where(np.array(tokenized["L"]) <= maximum\_length)[0])  
del tokenized  

  1. 模型训练 =======

现在我们开始设置GRPO的训练器和配置,这一步的操作跟我们之前的预微调类似。

  
max\_prompt\_length = maximum\_length + 1 # + 1 just in case!  
max\_completion\_length = max\_seq\_length - max\_prompt\_length  
  
from vllm import SamplingParams  
vllm\_sampling\_params = SamplingParams(  
    min\_p = 0.1,  
    top\_p = 1.0,  
    top\_k = -1,  
    seed = 3407,  
    stop = [tokenizer.eos\_token],  
    include\_stop\_str\_in\_output = True,  
)  
  
from trl import GRPOConfig, GRPOTrainer  
training\_args = GRPOConfig(  
    vllm\_sampling\_params = vllm\_sampling\_params,  
    temperature = 1.0,  
    learning\_rate = 5e-6,  
    weight\_decay = 0.01,  
    warmup\_ratio = 0.1,  
    lr\_scheduler\_type = "linear",  
    optim = "adamw\_8bit",  
    logging\_steps = 1,  
    per\_device\_train\_batch\_size = 1,  
    gradient\_accumulation\_steps = 1, # 调节至4,梯度下降更平稳  
    num\_generations = 4, # 如果OOM就降低生成数量  
    max\_prompt\_length = max\_prompt\_length,  
    max\_completion\_length = max\_completion\_length,  
    # num\_train\_epochs = 1, # 设置为1完整跑完一轮epoch  
    max\_steps = 100,  
    save\_steps = 100,  
    report\_to = "none",   
    output\_dir = "outputs",  
    # 如果需要导入额外的验证集验证实验,加入以下参数  
    # fp16\_full\_eval = True,  
    # per\_device\_eval\_batch\_size = 4,  
    # eval\_accumulation\_steps = 1,  
    # eval\_strategy = "steps",  
    # eval\_steps = 1,  
)  
  
# Training+evaluation 过程  
# new\_dataset = dataset.train\_test\_split(test\_size = 0.01)  
  
trainer = GRPOTrainer(  
    model = model,  
    processing\_class = tokenizer,  
    reward\_funcs = [  
        match\_format\_exactly,  
        match\_format\_approximately,  
        check\_answer,  
        check\_numbers,  
    ],  
    args = training\_args,  
    train\_dataset = dataset,  
    # 可选是否加入 Training+evaluation 过程  
    # train\_dataset = new\_dataset["train"],  
    # eval\_dataset = new\_dataset["test"],  
)  
trainer.train()  

完成上述所有步骤任务后,我们可以开始观察训练的效果。根据过往的经验来看,前100轮step的训练我们往往得不到任何提升,但是请保持耐心,当训练任务流程执行到150~200步的时候,我们会慢慢看到RL带来的收益。Unsloth的框架也会输出每轮训练的log,如下方表格所示,图示为前10轮的训练log:

| Step | Training Loss | reward | reward_std | completion_length | kl | rewards / match_format_exactly | rewards / match_format_approximately | rewards / check_answer | rewards / check_numbers | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | | 1 | 0.006200 | -7.500000 | 0.000000 | 1846.000000 | 0.155724 | 0.000000 | -3.000000 | -2.000000 | -2.500000 | | 2 | 0.005200 | -5.500000 | 4.000000 | 1754.000000 | 0.130613 | 0.750000 | -1.875000 | -2.125000 | -2.250000 | | 3 | 0.006300 | -5.500000 | 4.000000 | 1826.000000 | 0.156329 | 0.750000 | -1.875000 | -2.125000 | -2.250000 | | 4 | 0.007100 | -7.500000 | 0.000000 | 1846.000000 | 0.176596 | 0.000000 | -3.000000 | -2.000000 | -2.500000 | | 5 | 0.007500 | 13.000000 | 0.000000 | 1297.500000 | 0.188479 | 3.000000 | 1.500000 | 5.000000 | 3.500000 | | 6 | 0.004800 | -7.500000 | 0.000000 | 1846.000000 | 0.119617 | 0.000000 | -3.000000 | -2.000000 | -2.500000 | | 7 | 0.006200 | -5.500000 | 4.000000 | 1679.000000 | 0.154963 | 0.750000 | -1.875000 | -2.125000 | -2.250000 | | 8 | 0.004200 | -7.500000 | 0.000000 | 1846.000000 | 0.105323 | 0.000000 | -3.000000 | -2.000000 | -2.500000 | | 9 | 0.006100 | -7.500000 | 0.000000 | 1846.000000 | 0.152696 | 0.000000 | -3.000000 | -2.000000 | -2.500000 | | 10 | 0.004900 | -0.875000 | 9.672771 | 1784.750000 | 0.123577 | 1.500000 | -0.750000 | -0.875000 | -0.750000 |

  1. 模型测试 =======

5.1 未做GRPO的模型

我们先测试一下没经过GRPO训练的模型,即Qwen3-4B-Base模型的效果

  
text = "What is the sqrt of 101?"  
  
from vllm import SamplingParams  
sampling\_params = SamplingParams(  
    temperature = 1.0,  
    top\_k = 50,  
    max\_tokens = 1024,  
)  
output = model.fast\_generate(  
    [text],  
    sampling\_params = sampling\_params,  
    lora\_request = None,  
)[0].outputs[0].text  
  

结果:

picture.image

可以看出,未经过推理训练的模型输出根号101的结果无法得出正确的结果。

5.2 做GRPO后的模型

我们导出上述已经训练完成模型的LoRA权重,并验证LoRA模型是否成功导出我们GRPO训练的结果。

  
model.save\_lora("grpo\_saved\_lora")  
  
# 验证模型  
from safetensors import safe\_open  
  
tensors = {}  
with safe\_open("grpo\_saved\_lora/adapter\_model.safetensors", framework = "pt") as f:  
    # Verify both A and B are non zero  
    for key in f.keys():  
        tensor = f.get\_tensor(key)  
        n\_zeros = (tensor == 0).sum() / tensor.numel()  
        assert(n\_zeros.item() != tensor.numel())  
  
messages = [  
    {"role": "system", "content": system\_prompt},  
    {"role": "user",   "content": "What is the sqrt of 101?"},  
]  
  
text = tokenizer.apply\_chat\_template(  
    messages,  
    add\_generation\_prompt = True, # Must add for generation  
    tokenize = False,  
)  
from vllm import SamplingParams  
sampling\_params = SamplingParams(  
    temperature = 1.0,  
    top\_k = 50,  
    max\_tokens = 2048,  
)  
output = model.fast\_generate(  
    text,  
    sampling\_params = sampling\_params,  
    lora\_request = model.load\_lora("grpo\_saved\_lora"),  
)[0].outputs[0].text  
  

最后验证该数学题的结果如下图所示,根号101的结果也按照我们预期的方式推理计算得到。

picture.image

这个结果也证明了我们的推理模型计算数学问题的结果相比以往base模型要好得多,虽然它并不总是正确的,因为我们只训练了一个小时左右。如果我们延长序列长度并训练更长时间,它的结果会更好!

  1. 总结 =====

本文详细介绍了如何基于 Unsloth 框架,从零开始构建属于自己的推理 R1 模型的完整流程。通过结合 Qwen3-4B-Base 模型、NVIDIA开放数学推理数据集,以及自定义的推理格式(GRPO),我们演示了一个高效、低内存消耗、支持大模型的微调范式。希望能为小伙伴们在学习和工作提供一些参考或借鉴。

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
大规模高性能计算集群优化实践
随着机器学习的发展,数据量和训练模型都有越来越大的趋势,这对基础设施有了更高的要求,包括硬件、网络架构等。本次分享主要介绍火山引擎支撑大规模高性能计算集群的架构和优化实践。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论