Reproducing the DeepSeek-R1-Zero "Aha Moment" Hands-On

Open-source projects that reproduce DeepSeek-R1-Zero

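  • Open-source GRPO implementations
    • trl grpo trainer: TRL's GRPOTrainer implementation. At the time of writing it has not been released yet, so you need to install trl from the main branch.
    • veRL: ByteDance's open-source RL framework, which also supports GRPO reward functions.
  • R1 reproduction projects and datasets
    • open-r1: code for data synthesis, SFT, and GRPO RL.
    • TinyZero: reproduces the R1 RL paradigm on a simple task similar to the 24 game.
    • SkyT1: an o1-like model trained on data distilled from QwQ.
    • HuatuoGPT-o1: an o1 reproduction for the medical domain (code, data, paper, and models are all open). It still uses a reward model and the improvement is small; it would be worth trying the R1 RL paradigm to see whether it brings a clear gain.
    • simpleRL-reason: reproduces the R1-Zero paradigm on an 8k-example MATH dataset.
    • open-r1-multimodal: a multimodal R1 reproduction project.
    • open-thoughts: the most mature R1 reproduction project. It has released the Bespoke-Stratos-17k and OpenThoughts-114k datasets, and SFT alone is enough to approach the R1-distill models.
    • R1-Distill-SFT: a dataset of 1.68M R1-distilled samples.
    • grpo_demo.py: an RL demo based on a 0.5B model, useful for learning how the training works.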

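The Countdown task

Dataset: https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4

A Countdown task gives a target and a list nums, and asks for an expression that reaches the target using addition, subtraction, multiplication, and division. For example, with target=24 and nums=[2,3,5,6], the four numbers must be combined to produce 24. One solution is (6/2) * (3+5) = 24.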
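To make the task concrete, here is a minimal brute-force solver sketch. It is not part of TinyZero or the dataset tooling; the function names `solve_countdown` and `_search` are made up for this illustration, and it assumes every number must be used exactly once, which matches the reward rule described later in this article.

```python
from fractions import Fraction
from itertools import combinations


def solve_countdown(nums, target):
    """Return one expression that uses every number exactly once and equals target, or None."""
    # Each item is (exact value, expression text); Fraction avoids float rounding.
    items = [(Fraction(n), str(n)) for n in nums]
    return _search(items, Fraction(target))


def _search(items, target):
    if len(items) == 1:
        value, expr = items[0]
        return expr if value == target else None
    # Pick any two items, merge them with one operator, then recurse on the shorter list.
    for i, j in combinations(range(len(items)), 2):
        (a, ea), (b, eb) = items[i], items[j]
        rest = [items[k] for k in range(len(items)) if k not in (i, j)]
        candidates = [
            (a + b, f"({ea} + {eb})"),
            (a - b, f"({ea} - {eb})"),
            (b - a, f"({eb} - {ea})"),
            (a * b, f"({ea} * {eb})"),
        ]
        if b != 0:
            candidates.append((a / b, f"({ea} / {eb})"))
        if a != 0:
            candidates.append((b / a, f"({eb} / {ea})"))
        for value, expr in candidates:
            found = _search(rest + [(value, expr)], target)
            if found is not None:
                return found
    return None


if __name__ == "__main__":
    print(solve_countdown([2, 3, 5, 6], 24))
```

With three or four operands the search space is tiny, so exhaustive search is enough. The RL experiment is interesting because the model has to learn this kind of trial-and-check behaviour on its own.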

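Reproducing DeepSeek-R1-Zero with TinyZero

The original author ran the experiment on 2 H200 GPUs. An A100 has considerably less memory than an H200 and easily runs out of memory (OOM), so the parameter configuration has to be adjusted before training will run. At least 2 A100s are recommended.

Machine environment

The reproduction uses a machine with 8 A100-40G GPUs, 4 of which are used for training. The machine has 1.5T of RAM.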


 
 
 
 
   
$ lsb_release -a
LSB Version:    core-9.20170808ubuntu1-noarch:security-9.20170808ubuntu1-noarch  
Distributor ID: Ubuntu  
Description:    Ubuntu 18.04.6 LTS  
Release:        18.04  
Codename:       bionic

 
 
 
 
   
$ nvidia-smi   
+-----------------------------------------------------------------------------+  
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |  
|-------------------------------+----------------------+----------------------+  
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |  
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |  
|                               |                      |               MIG M. |  
|===============================+======================+======================|  
|   0  NVIDIA A100-SXM...  On   | 00000000:17:00.0 Off |                    0 |  
| N/A   49C    P0    90W / 400W |  32242MiB / 40960MiB |     62%      Default |  
|                               |                      |             Disabled |  
+-------------------------------+----------------------+----------------------+  
|   1  NVIDIA A100-SXM...  On   | 00000000:1A:00.0 Off |                    0 |  
| N/A   49C    P0   127W / 400W |  32658MiB / 40960MiB |     32%      Default |  
|                               |                      |             Disabled |  
+-------------------------------+----------------------+----------------------+  
|   2  NVIDIA A100-SXM...  On   | 00000000:6B:00.0 Off |                    0 |  
| N/A   47C    P0   148W / 400W |  32724MiB / 40960MiB |     62%      Default |  
|                               |                      |             Disabled |  
+-------------------------------+----------------------+----------------------+  
|   3  NVIDIA A100-SXM...  On   | 00000000:6F:00.0 Off |                    0 |  
| N/A   51C    P0   140W / 400W |  32276MiB / 40960MiB |     51%      Default |  
|                               |                      |             Disabled |  
+-------------------------------+----------------------+----------------------+  
|   4  NVIDIA A100-SXM...  On   | 00000000:A9:00.0 Off |                    0 |  
| N/A   54C    P0   138W / 400W |  29516MiB / 40960MiB |     67%      Default |  
|                               |                      |             Disabled |  
+-------------------------------+----------------------+----------------------+  
|   5  NVIDIA A100-SXM...  On   | 00000000:AD:00.0 Off |                    0 |  
| N/A   49C    P0   125W / 400W |  29946MiB / 40960MiB |     69%      Default |  
|                               |                      |             Disabled |  
+-------------------------------+----------------------+----------------------+  
|   6  NVIDIA A100-SXM...  On   | 00000000:DB:00.0 Off |                    0 |  
| N/A   52C    P0   137W / 400W |  29954MiB / 40960MiB |     52%      Default |  
|                               |                      |             Disabled |  
+-------------------------------+----------------------+----------------------+  
|   7  NVIDIA A100-SXM...  On   | 00000000:DE:00.0 Off |                    0 |  
| N/A   53C    P0   130W / 400W |  29542MiB / 40960MiB |     53%      Default |  
|                               |                      |             Disabled |  
+-------------------------------+----------------------+----------------------+  
                                                                                 
+-----------------------------------------------------------------------------+  
| Processes:                                                                  |  
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |  
|        ID   ID                                                   Usage      |  
|=============================================================================|  
|    4   N/A  N/A     49888      C   xxxxxxxxx                       29516MiB |  
+-----------------------------------------------------------------------------+

 
 
 
 
   
$ free -h  
              total        used        free      shared  buff/cache   available  
Mem:           1.5T         97G        139G         70G        1.2T        1.2T  
Swap:            0B          0B          0B

 
 
 
 
   
$ lscpu  
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit  
Byte Order:          Little Endian  
CPU(s):              160  
On-line CPU(s) list: 0-159  
Thread(s) per core:  2  
Core(s) per socket:  20  
Socket(s):           4  
NUMA node(s):        4  
Vendor ID:           GenuineIntel  
CPU family:          6  
Model:               85  
Model name:          Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz  
Stepping:            7  
CPU MHz:             3199.999  
CPU max MHz:         3900.0000  
CPU min MHz:         1000.0000  
BogoMIPS:            5000.00  
Virtualization:      VT-x  
L1d cache:           32K  
L1i cache:           32K  
L2 cache:            1024K  
L3 cache:            28160K  
NUMA node0 CPU(s):   0-19,80-99  
NUMA node1 CPU(s):   20-39,100-119  
NUMA node2 CPU(s):   40-59,120-139  
NUMA node3 CPU(s):   60-79,140-159  
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku avx512_vnni md_clear flush_l1d arch_capabilities

Environment setup

TinyZero is built on the veRL framework.


 
 
 
 
   
conda create -n zero python=3.9  
source activate zero  
  
pip3 install --no-cache-dir \  
    torch==2.4.0 \  
    accelerate \  
    codetiming \  
    dill \  
    hydra-core \  
    numpy \  
    pybind11 \  
    tensordict \  
    "transformers <= 4.46.0"  
  
# Compiling flash-attn takes a long time; install ninja first to speed it up
pip install ninja  
pip3 install --no-cache-dir flash-attn==2.7.0.post2 --no-build-isolation  
pip install importlib-metadata  
  
# vllm depends on ray, and veRL does not support ray > 2.37  
pip3 install --no-cache-dir vllm==0.6.3 ray==2.10  
  
pip install wandb IPython matplotlib  
pip install absl-py  
pip install astunparse
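After installing, it can help to confirm that the pinned versions actually ended up in the environment. The following is a minimal sketch using only the standard library; `check_versions` and the `EXPECTED` table are names invented here, and the pins are simply copied from the install commands above.

```python
from importlib import metadata

# Pins copied from the install commands above; None means "any version".
EXPECTED = {
    "torch": "2.4.0",
    "vllm": "0.6.3",
    "ray": "2.10",
    "transformers": None,
}


def check_versions(expected):
    """Map each package name to (installed version or None, whether it satisfies the pin)."""
    report = {}
    for name, pinned in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        ok = installed is not None and (pinned is None or installed.startswith(pinned))
        report[name] = (installed, ok)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_versions(EXPECTED).items():
        print(f"{name:15s} {installed or 'missing':15s} {'ok' if ok else 'CHECK'}")
```

The prefix comparison is deliberately loose (for example `2.10` accepts `2.10.0`); it is a quick sanity check, not a dependency resolver.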

Model and data

The reproduction is based on Qwen2.5-3B; you can also compare against Qwen/Qwen2.5-3B-Instruct.


 
 
 
 
   
pip install -U huggingface_hub
export HF_ENDPOINT="https://hf-mirror.com"
  
# model  
huggingface-cli download --resume-download Qwen/Qwen2.5-3B  
huggingface-cli download --resume-download Qwen/Qwen2.5-3B-Instruct  
  
# dataset  
huggingface-cli download --resume-download --repo-type dataset Jiayi-Pan/Countdown-Tasks-3to4

By default, models are downloaded into the HF_HUB_CACHE directory, whose location is shown by:


 
 
 
 
   
$ huggingface-cli env

Copy-and-paste the text below in your GitHub issue.

- huggingface_hub version: 0.29.1
- Platform: Linux-5.10.0-1.0.0.28-x86_64-with-glibc2.27
- Python version: 3.9.19
- Running in iPython ?: No
- Running in notebook ?: No
- Running in Google Colab ?: No
- Running in Google Colab Enterprise ?: No
- Token path ?: /root/.cache/huggingface/token
- Has saved token ?: False
- Configured git credential helpers: store
- FastAI: N/A
- Tensorflow: N/A
- Torch: 2.4.0
- Jinja2: 3.1.4
- Graphviz: N/A
- keras: N/A
- Pydot: N/A
- Pillow: 10.4.0
- hf_transfer: N/A
- gradio: 4.44.1
- tensorboard: N/A
- numpy: 1.26.4
- pydantic: 2.10.6
- aiohttp: 3.10.3
- ENDPOINT: https://huggingface.co
- HF_HUB_CACHE: /root/.cache/huggingface/hub
- HF_ASSETS_CACHE: /root/.cache/huggingface/assets
- HF_TOKEN_PATH: /root/.cache/huggingface/token
- HF_STORED_TOKENS_PATH: /root/.cache/huggingface/stored_tokens
- HF_HUB_OFFLINE: False
- HF_HUB_DISABLE_TELEMETRY: False
- HF_HUB_DISABLE_PROGRESS_BARS: None
- HF_HUB_DISABLE_SYMLINKS_WARNING: False
- HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False
- HF_HUB_DISABLE_IMPLICIT_TOKEN: False
- HF_HUB_ENABLE_HF_TRANSFER: False
- HF_HUB_ETAG_TIMEOUT: 10
- HF_HUB_DOWNLOAD_TIMEOUT: 10

The downloaded model is laid out as follows:


 
 
 
 
   
$ tree pretrain_model/models--Qwen--Qwen2.5-3B
pretrain_model/models--Qwen--Qwen2.5-3B
├── blobs  
│   ├── 20024bfe7c83998e9aeaf98a0cd6a2ce6306c2f0  
│   ├── 38047c6284a25427043f0ab040f623a2a20dd093  
│   ├── 443909a61d429dff23010e5bddd28ff530edda00  
│   ├── 4783fe10ac3adce15ac8f358ef5462739852c569  
│   ├── 51410930d5cf19a998fdb17ef0c46e4d9ace72c97a975a3331395a8a500f5edb  
│   ├── 5a98c2f5cee568453413e49a600be32d3f010eaf  
│   ├── a6344aac8c09253b3b630fb776ae94478aa0275b  
│   ├── acd0065677f7999d568f327c94188adbbcb6e158  
│   ├── ba7e4c5637b9732dadcd66286ce48334e8b31e9e  
│   ├── cbbb3133034e192527e5321b4c679154e4819ab8  
│   ├── ed317377ac0abff39f17eec693ce664f4b8152af  
│   └── f9558df91d3b89b4826e4db37439edb52f1d62a4fd602685013e7ca6b9f60f8f  
├── refs  
│   └── main  
└── snapshots  
    └── 3aab1f1954e9cc14eb9509a215f9e5ca08227a9b # model directory
        ├── config.json -> ../../blobs/acd0065677f7999d568f327c94188adbbcb6e158
        ├── generation_config.json -> ../../blobs/cbbb3133034e192527e5321b4c679154e4819ab8
        ├── LICENSE -> ../../blobs/ed317377ac0abff39f17eec693ce664f4b8152af  
        ├── merges.txt -> ../../blobs/20024bfe7c83998e9aeaf98a0cd6a2ce6306c2f0  
        ├── model-00001-of-00002.safetensors -> ../../blobs/f9558df91d3b89b4826e4db37439edb52f1d62a4fd602685013e7ca6b9f60f8f  
        ├── model-00002-of-00002.safetensors -> ../../blobs/51410930d5cf19a998fdb17ef0c46e4d9ace72c97a975a3331395a8a500f5edb  
        ├── model.safetensors.index.json -> ../../blobs/38047c6284a25427043f0ab040f623a2a20dd093  
        ├── README.md -> ../../blobs/5a98c2f5cee568453413e49a600be32d3f010eaf  
        ├── tokenizer_config.json -> ../../blobs/ba7e4c5637b9732dadcd66286ce48334e8b31e9e
        ├── tokenizer.json -> ../../blobs/443909a61d429dff23010e5bddd28ff530edda00  
        └── vocab.json -> ../../blobs/4783fe10ac3adce15ac8f358ef5462739852c569  
  
4 directories, 24 files
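The directory that a training script needs as its model path is the one under snapshots/, not the models--… folder itself. The sketch below resolves it from the layout shown above by reading refs/main. `snapshot_dir` is a helper invented for this article, not a huggingface_hub API, and it assumes the cache follows exactly the structure in the tree output.

```python
from pathlib import Path


def snapshot_dir(cache_root, repo_id, revision="main"):
    """Resolve <cache_root>/models--ORG--NAME/snapshots/<commit> for a cached model repo."""
    repo_dir = Path(cache_root) / ("models--" + repo_id.replace("/", "--"))
    # refs/<revision> is a small text file that holds the commit hash of that revision.
    commit = (repo_dir / "refs" / revision).read_text().strip()
    path = repo_dir / "snapshots" / commit
    if not path.is_dir():
        raise FileNotFoundError(path)
    return path


if __name__ == "__main__":
    # "pretrain_model" is the cache root used in the tree listing above.
    print(snapshot_dir("pretrain_model", "Qwen/Qwen2.5-3B"))
```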

Dataset preprocessing


 
 
 
 
   
python ./examples/data_preprocess/countdown.py --local_dir data/countdown

This produces two .parquet files:


 
 
 
 
   
$ ls data/countdown/*parquet  
data/countdown/test.parquet  data/countdown/train.parquet  
  
$ du -sh  data/countdown/*parquet  
80K     data/countdown/test.parquet  
22M     data/countdown/train.parquet  
  
$ file data/countdown/train.parquet  
data/countdown/train.parquet: Apache Parquet

.parquet files are binary: they are compact on disk but inconvenient to inspect. You can modify countdown.py so that it also writes JSONL files:


 
 
 
 
   
    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
    # Also write JSONL copies so the samples can be inspected as text
    train_dataset.to_json(os.path.join(local_dir, 'train.jsonl'))
    test_dataset.to_json(os.path.join(local_dir, 'test.jsonl'))

 
 
 
 
   
$ du -sh data/countdown/*jsonl  
816K    data/countdown/test.jsonl  
257M    data/countdown/train.jsonl

 
 
 
 
   
$ wc -l data/countdown/*jsonl  
     1024 data/countdown/test.jsonl  
   327680 data/countdown/train.jsonl  
   328704 total

Let's look at a sample record:


 
 
 
 
   
$ shuf -n1 data/countdown/train.jsonl | jq
{
  "target": 88,  # answer
  "nums": [      # question
    11,
    76,
    1
  ],
  "data_source": "countdown",  # selects the rule-based reward function
  "prompt": [  # single-turn format containing system + query
    {
      "content": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [11, 76, 1], create an equation that equals 88. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>",
      "role": "user"
    }
  ],
  "ability": "math",
  "reward_model": {    # reward function and ground truth
    "ground_truth": {
      "numbers": [
        11,
        76,
        1
      ],
      "target": 88
    },
    "style": "rule"   # rule or model
  },
  "extra_info": {
    "index": 311487,  # sample index
    "split": "train"  # split name
  }
}
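Once the JSONL files exist, quick statistics are easy to compute with the standard library. As an illustration, the sketch below counts how many records have 3 operands and how many have 4, using the `nums` field from the record shown above. The helper name `count_by_operand_count` is made up here.

```python
import json
from collections import Counter


def count_by_operand_count(path):
    """Count records in a JSON Lines file, grouped by the length of their "nums" list."""
    counts = Counter()
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():  # skip blank lines
                counts[len(json.loads(line)["nums"])] += 1
    return counts


if __name__ == "__main__":
    print(count_by_operand_count("data/countdown/train.jsonl"))
```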

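The core preprocessing logic is shown below. The three most important fields of each record are prompt (which contains the question), reward_model (which contains the ground-truth answer), and data_source (which selects the rule-based reward function).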


 
 
 
 
   
    def make_map_fn(split):
        def process_fn(example, idx):
            question = make_prefix(example, template_type=args.template_type)
            solution = {
                "target": example['target'],
                "numbers": example['nums']
            }
            data = {
                "data_source": data_source,
                "prompt": [{
                    "role": "user",
                    "content": question,
                }],
                "ability": "math",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": solution
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                }
            }
            return data
        return process_fn

Different LLMs need different prompt formats; the logic is as follows:


 
 
 
 
   
def make_prefix(dp, template_type):
    target = dp['target']
    numbers = dp['nums']
    # NOTE: also need to change reward_score/countdown.py
    if template_type == 'base':
        """This works for any base model"""
        prefix = f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
User: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.
Assistant: Let me solve this step by step.
<think>"""
    elif template_type == 'qwen-instruct':
        """This works for Qwen Instruct Models"""
        prefix = f"""<|im_start|>system\nYou are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.<|im_end|>\n<|im_start|>user\n Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.<|im_end|>\n<|im_start|>assistant\nLet me solve this step by step.\n<think>"""
    return prefix

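The processing steps in the veRL framework are:

  • Look up the reward_function registered for the sample's data_source.
  • Run rollout on the question in the prompt to generate a response. The question already contains the specific nums and the target value.
  • Pass the response to the reward_function, which parses out the answer and compares it against the ground_truth target in reward_model to compute the reward.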
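The three steps above can be sketched end to end in a few lines. This is a toy illustration, not veRL code: `REWARD_FUNCTIONS`, `score_countdown`, `fake_rollout`, and `reward_for_sample` are all names invented here, the rollout is a hard-coded string standing in for model generation, and the scorer is a simplified 0/1 version of the real reward function shown later.

```python
import re


def score_countdown(response, ground_truth):
    """Toy rule-based scorer: 1.0 if the <answer> expression hits the target, else 0.0."""
    match = re.search(r"<answer>(.*?)</answer>", response, flags=re.DOTALL)
    if match is None:
        return 0.0
    expr = match.group(1).strip()
    # Only digits, arithmetic operators, parentheses, dots, and whitespace are allowed.
    if not re.fullmatch(r"[\d+\-*/().\s]+", expr):
        return 0.0
    try:
        value = eval(expr, {"__builtins__": None}, {})
        return 1.0 if abs(value - ground_truth["target"]) < 1e-5 else 0.0
    except Exception:
        return 0.0


# Step 1 needs a mapping from data_source to reward function.
REWARD_FUNCTIONS = {"countdown": score_countdown}


def fake_rollout(prompt):
    """Stand-in for the policy model's generation step."""
    return "84 - 54 = 30, 30 + 66 = 96 </think>\n<answer> (84 - 54 + 66) </answer>"


def reward_for_sample(sample, rollout=fake_rollout):
    reward_fn = REWARD_FUNCTIONS[sample["data_source"]]                  # step 1: select
    response = rollout(sample["prompt"][0]["content"])                   # step 2: rollout
    return reward_fn(response, sample["reward_model"]["ground_truth"])   # step 3: score


if __name__ == "__main__":
    sample = {
        "data_source": "countdown",
        "prompt": [{"role": "user", "content": "Using the numbers [84, 54, 66], ..."}],
        "reward_model": {"style": "rule", "ground_truth": {"target": 96, "numbers": [84, 54, 66]}},
    }
    print(reward_for_sample(sample))
```

The point of the sketch is the data flow: everything the scorer needs travels with the sample itself, so no learned reward model is involved.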

Rule-based Reward Function


 
 
 
 
   
(main_task pid=81572) --------------------------------
(main_task pid=81572) Target: 96 | Numbers: [84 54 66]
(main_task pid=81572) Extracted equation: (84 - 54 + 66)
(main_task pid=81572) Solution string: A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
(main_task pid=81572) User: Using the numbers [84, 54, 66], create an equation that equals 96. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.
(main_task pid=81572) Assistant: Let me solve this step by step.
(main_task pid=81572) <think> In order to create an equation that equals 96, we need to use three numbers and basic arithmetic operations. <answer> (84 - 54 + 66) </answer> </think><|endoftext|>
(main_task pid=81572) Correct equation: (84 - 54 + 66) = 96

The main entry point is verl/trainer/main_ppo.py, where _select_rm_score_fn chooses the reward function according to data_source:


 
 
 
 
   
from verl import DataProto
import torch
from verl.utils.reward_score import gsm8k, math, multiply, countdown
from verl.trainer.ppo.ray_trainer import RayPPOTrainer


# Pick the rule-based reward function that matches the dataset
def _select_rm_score_fn(data_source):
    if data_source == 'openai/gsm8k':
        return gsm8k.compute_score
    elif data_source == 'lighteval/MATH':
        return math.compute_score
    elif "multiply" in data_source or "arithmetic" in data_source:
        return multiply.compute_score
    elif "countdown" in data_source:
        return countdown.compute_score
    else:
        raise NotImplementedError


# Reward wrapper: calls the rule-based reward function chosen by _select_rm_score_fn to compute the reward.
class RewardManager():
    """The reward manager.
    """

    def __init__(self, tokenizer, num_examine) -> None:
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console

    def __call__(self, data: DataProto):
        """We will expand this function gradually based on the available datasets"""

        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
        if 'rm_scores' in data.batch.keys():
            return data.batch['rm_scores']

        # reward of responses
        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)

        already_print_data_sources = {}

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch['prompts']
            prompt_length = prompt_ids.shape[-1]

            # number of prompt tokens whose attention mask is 1
            valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
            # prompt_ids is left-padded
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch['responses']
            # number of response tokens whose attention mask is 1
            valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode the unmasked token ids (prompt followed by response)
            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
            # decoded sequence
            sequences_str = self.tokenizer.decode(sequences)

            ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']

            # select rm_score
            data_source = data_item.non_tensor_batch['data_source']
            compute_score_fn = _select_rm_score_fn(data_source)
            # compute the reward from the ground truth and the decoded response
            score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)
            reward_tensor[i, valid_response_length - 1] = score

            if data_source not in already_print_data_sources:
                already_print_data_sources[data_source] = 0

            if already_print_data_sources[data_source] < self.num_examine:
                already_print_data_sources[data_source] += 1
                print(sequences_str)

        return reward_tensor
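The indexing in RewardManager is easier to follow on a tiny example. The sketch below redoes the same bookkeeping with plain Python lists instead of tensors: the attention mask covers a left-padded prompt followed by a right-padded response, and the single scalar score is written onto the last valid response token. `place_reward` is a name invented for this illustration; unlike the original code, it leaves the row all zeros when the response is empty.

```python
def place_reward(attention_mask, prompt_length, score):
    """Return a per-response-token reward row with `score` on the last valid response token."""
    # Everything after the prompt slots belongs to the response (right-padded).
    response_mask = attention_mask[prompt_length:]
    valid_response_length = sum(response_mask)
    rewards = [0.0] * len(response_mask)
    if valid_response_length > 0:
        rewards[valid_response_length - 1] = score
    return rewards


if __name__ == "__main__":
    # 4 prompt slots (2 padding + 2 real tokens), 5 response slots (3 real + 2 padding)
    mask = [0, 0, 1, 1,   1, 1, 1, 0, 0]
    print(place_reward(mask, prompt_length=4, score=1.0))
    # -> [0.0, 0.0, 1.0, 0.0, 0.0]
```

So the reward is sparse: one non-zero entry per sequence, at the position where the response ends.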

The Countdown reward function is implemented in verl/utils/reward_score/countdown.py:


 
 
 
 
   
import re
import random
import ast
import operator


def extract_solution(solution_str):
    """Extract the equation from the solution string."""
    # Remove everything before the first "Assistant:"
    if "Assistant:" in solution_str:
        solution_str = solution_str.split("Assistant:", 1)[1]
    elif "<|im_start|>assistant" in solution_str:
        solution_str = solution_str.split("<|im_start|>assistant", 1)[1]
    else:
        return None
    solution_str = solution_str.split('\n')[-1]

    answer_pattern = r'<answer>(.*?)</answer>'
    match = re.finditer(answer_pattern, solution_str)
    matches = list(match)
    if matches:
        final_answer = matches[-1].group(1).strip()
    else:
        final_answer = None
    return final_answer


def validate_equation(equation_str, available_numbers):
    """Validate that equation only uses available numbers and each number once."""
    try:
        # Extract all numbers from the equation
        numbers_in_eq = [int(n) for n in re.findall(r'\d+', equation_str)]

        # Check if all numbers in equation are available
        available_numbers = sorted(available_numbers)
        numbers_in_eq = sorted(numbers_in_eq)

        # Each number should be used exactly once
        return numbers_in_eq == available_numbers
    except:
        return False


def evaluate_equation(equation_str):
    """Safely evaluate the arithmetic equation using eval() with precautions."""
    try:
        # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
        allowed_pattern = r'^[\d+\-*/().\s]+$'
        if not re.match(allowed_pattern, equation_str):
            raise ValueError("Invalid characters in equation.")

        # Evaluate the equation with restricted globals and locals
        result = eval(equation_str, {"__builtins__": None}, {})
        return result
    except Exception as e:
        return None


def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.):
    """The scoring function for countdown task.

    Args:
        solution_str: the solution text
        ground_truth: dictionary containing target number and available numbers
        method: the method to extract the solution
        format_score: the score for correct format but wrong answer
        score: the score for the correct answer
    """
    target = ground_truth['target']
    numbers = ground_truth['numbers']

    equation = extract_solution(solution_str=solution_str)
    do_print = random.randint(1, 64) == 1

    if do_print:
        print(f"--------------------------------")
        print(f"Target: {target} | Numbers: {numbers}")
        print(f"Extracted equation: {equation}")
        print(f"Solution string: {solution_str}")

    if equation is None:
        if do_print:
            print(f"No equation found")
        return 0   # reward

    # Validate equation uses correct numbers
    if not validate_equation(equation, numbers):
        if do_print:
            print(f"Invalid equation")
        return format_score # reward

    # Evaluate equation
    try:
        result = evaluate_equation(equation)
        if result is None:
            if do_print:
                print(f"Could not evaluate equation")
            return format_score # reward

        if abs(result - target) < 1e-5:  # Account for floating point precision
            if do_print:
                print(f"Correct equation: {equation} = {result}")
            return score # reward
        else:
            if do_print:
                print(f"Wrong result: equation = {result}, target = {target}")
            return format_score # reward
    except:
        if do_print:
            print(f"Error evaluating equation")
        return format_score # reward
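
  • extract_solution pulls out the model's structured output, i.e. the expression wrapped in <answer> </answer>. For the log sample above that is (84 - 54 + 66).
  • validate_equation checks that the expression uses exactly the required numbers, with every number used once and only once.
  • If that check passes, evaluate_equation verifies that the expression contains only digits, operators, parentheses, and whitespace, then computes the result in the crudest possible way by calling eval(expr) directly.
  • The reward score is tiered and lies in [0, 1]:
    • 0 when no answer can be extracted from the solution string;
    • format_score = 0.1 when an answer is extracted but the equation uses the wrong numbers, cannot be evaluated, or gives the wrong result;
    • 1.0 when the equation evaluates to the target.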
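The reward module imports ast and operator but evaluates with eval. As an alternative illustration only, the sketch below walks the parsed syntax tree and allows nothing except numeric literals and the four arithmetic operators. `safe_eval` is a name invented here; it is not what TinyZero uses, and division by zero still raises ZeroDivisionError, which a caller would need to catch.

```python
import ast
import operator

# The only binary operators a Countdown expression is allowed to contain.
_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def safe_eval(expr):
    """Evaluate +, -, *, / over numeric literals; raise ValueError on anything else."""

    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPS:
            return _BINARY_OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
            value = walk(node.operand)
            return value if isinstance(node.op, ast.UAdd) else -value
        raise ValueError(f"unsupported syntax: {type(node).__name__}")

    try:
        tree = ast.parse(expr.strip(), mode="eval")
    except SyntaxError as exc:
        raise ValueError("not a valid expression") from exc
    return walk(tree)


if __name__ == "__main__":
    print(safe_eval("(84 - 54 + 66)"))  # -> 96
```

One practical difference from the character whitelist: an input such as `2 ** 10` passes the regex in evaluate_equation (it contains only allowed characters) but is rejected here, because the power operator is not in the table.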

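This reward implementation has flaws. We run the experiments with the original version first and will look at the effect of improving it later.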
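The article does not spell out which flaws it has in mind. One that is visible in the log sample earlier: the response puts <answer> … </answer> inside the <think> block and still receives the full reward, because the scorer never checks the tag structure. The sketch below shows one possible stricter check. It is an assumption made for illustration, not the author's later fix, and `follows_think_then_answer` is a made-up name. It expects the text after "Assistant:", the same slice that extract_solution works on.

```python
import re


def follows_think_then_answer(text):
    """True if text holds one closed <think> block followed by one <answer> block."""
    # Each tag must appear exactly once.
    for tag in ("<think>", "</think>", "<answer>", "</answer>"):
        if text.count(tag) != 1:
            return False
    # The <answer> block must start only after </think> has closed.
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    return re.search(pattern, text, flags=re.DOTALL) is not None


if __name__ == "__main__":
    nested = "<think> we need 96. <answer> (84 - 54 + 66) </answer> </think>"
    clean = "<think> 84 - 54 = 30, 30 + 66 = 96 </think>\n<answer> (84 - 54 + 66) </answer>"
    print(follows_think_then_answer(nested), follows_think_then_answer(clean))
```

A check like this could gate the format_score tier, so that the 0.1 reward is given only when the structure is right. Whether that helps training is an empirical question that this sketch does not answer.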

Training script

veRL uses wandb to monitor training, so run wandb login beforehand. The key can be found under API Keys at https://wandb.ai/settings.

  • GRPO training configuration for A100-40G with 4 GPUs: