Open-source projects reproducing DeepSeek-R1-Zero
- GRPO open-source implementations
  - trl GRPOTrainer: TRL's GRPOTrainer implementation. Not yet in a release; install trl from the main branch.
  - veRL: ByteDance's open-source RL framework; also supports GRPO reward functions.
- R1 reproduction projects and datasets
  - open-r1: covers data synthesis, SFT, and GRPO RL code.
  - TinyZero: reproduces the R1 RL paradigm on a simple 24-game-like task.
  - SkyT1: an o1-like model trained on data distilled from QwQ.
  - HuatuoGPT-o1: an o1 reproduction in the medical domain (code, data, paper, and models released), but it still uses a reward model and the gains are small; worth trying the R1 RL paradigm to see whether it yields a clear improvement.
  - simpleRL-reason: reproduces the R1-Zero paradigm on an 8k MATH dataset.
  - open-r1-multimodal: a multimodal reproduction of R1.
  - open-thoughts: the most mature R1 reproduction; has released the Bespoke-Stratos-17k and OpenThoughts-114k datasets, and approaches the R1-distill models with SFT alone.
  - R1-Distill-SFT: 1.68M R1-distilled samples.
  - grpo_demo.py: an RL demo on a 0.5B model, useful for learning how training works.
The Countdown task
Dataset: https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4
The Countdown task: given a target and nums, reach target using +, -, *, and /, with each number used exactly once. For example, with target=24 and nums=[2,3,5,6], one solution is (6/2) * (3+5) = 24.
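To get a feel for the task, here is a small brute-force solver. This is my own illustration, not part of TinyZero; it only folds operators left-to-right, so fully general parenthesizations such as (6/2) * (3+5) are out of scope, but a left-to-right solution often exists:

```python
from itertools import permutations, product

def solve_countdown(nums, target, eps=1e-5):
    """Brute-force search over number orderings and operator choices,
    folding strictly left-to-right: (((a op b) op c) op d)."""
    ops = {
        '+': lambda a, b: a + b,
        '-': lambda a, b: a - b,
        '*': lambda a, b: a * b,
        '/': lambda a, b: a / b if b != 0 else float('inf'),
    }
    for perm in permutations(nums):
        for combo in product(ops, repeat=len(nums) - 1):
            acc, expr = perm[0], str(perm[0])
            for op, n in zip(combo, perm[1:]):
                acc = ops[op](acc, n)
                expr = f"({expr} {op} {n})"
            if abs(acc - target) < eps:
                return expr
    return None

print(solve_countdown([2, 3, 5, 6], 24))  # e.g. (((2 - 3) + 5) * 6)
```

The search space is tiny for 3-4 numbers, which is one reason this task works well as a cheap RL testbed.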
Reproducing DeepSeek-R1-Zero with TinyZero
The original author used two H200s. An A100 has considerably less memory than an H200 and OOMs easily, so the configuration must be adjusted before training runs; at least two A100s are recommended.
Machine environment
Reproduced on an A100-40G 8-GPU machine, training on 4 GPUs; the machine has 1.5 TB of RAM.
$ lsb_release -a
LSB Version: core-9.20170808ubuntu1-noarch:security-9.20170808ubuntu1-noarch
Distributor ID: Ubuntu
Description: Ubuntu 18.04.6 LTS
Release: 18.04
Codename: bionic
$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06 Driver Version: 525.125.06 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA A100-SXM... On | 00000000:17:00.0 Off | 0 |
| N/A 49C P0 90W / 400W | 32242MiB / 40960MiB | 62% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA A100-SXM... On | 00000000:1A:00.0 Off | 0 |
| N/A 49C P0 127W / 400W | 32658MiB / 40960MiB | 32% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 2 NVIDIA A100-SXM... On | 00000000:6B:00.0 Off | 0 |
| N/A 47C P0 148W / 400W | 32724MiB / 40960MiB | 62% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 3 NVIDIA A100-SXM... On | 00000000:6F:00.0 Off | 0 |
| N/A 51C P0 140W / 400W | 32276MiB / 40960MiB | 51% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 4 NVIDIA A100-SXM... On | 00000000:A9:00.0 Off | 0 |
| N/A 54C P0 138W / 400W | 29516MiB / 40960MiB | 67% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 5 NVIDIA A100-SXM... On | 00000000:AD:00.0 Off | 0 |
| N/A 49C P0 125W / 400W | 29946MiB / 40960MiB | 69% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 6 NVIDIA A100-SXM... On | 00000000:DB:00.0 Off | 0 |
| N/A 52C P0 137W / 400W | 29954MiB / 40960MiB | 52% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 7 NVIDIA A100-SXM... On | 00000000:DE:00.0 Off | 0 |
| N/A 53C P0 130W / 400W | 29542MiB / 40960MiB | 53% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 4 N/A N/A 49888 C xxxxxxxxx 29516MiB |
+-----------------------------------------------------------------------------+
$ free -h
total used free shared buff/cache available
Mem: 1.5T 97G 139G 70G 1.2T 1.2T
Swap: 0B 0B 0B
$ lscpu
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 160
On-line CPU(s) list: 0-159
Thread(s) per core: 2
Core(s) per socket: 20
Socket(s): 4
NUMA node(s): 4
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
Stepping: 7
CPU MHz: 3199.999
CPU max MHz: 3900.0000
CPU min MHz: 1000.0000
BogoMIPS: 5000.00
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 28160K
NUMA node0 CPU(s): 0-19,80-99
NUMA node1 CPU(s): 20-39,100-119
NUMA node2 CPU(s): 40-59,120-139
NUMA node3 CPU(s): 60-79,140-159
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku avx512_vnni md_clear flush_l1d arch_capabilities
Environment setup
TinyZero is built on the veRL framework.
conda create -n zero python=3.9
source activate zero
pip3 install --no-cache-dir \
torch==2.4.0 \
accelerate \
codetiming \
dill \
hydra-core \
numpy \
pybind11 \
tensordict \
"transformers <= 4.46.0"
# Compiling flash-attn takes a while; install ninja first to speed it up considerably
pip install ninja
pip3 install --no-cache-dir flash-attn==2.7.0.post2 --no-build-isolation
pip install importlib-metadata
# vllm depends on ray, and veRL does not support ray > 2.37
pip3 install --no-cache-dir vllm==0.6.3 ray==2.10
pip install wandb IPython matplotlib
pip install absl-py
pip install astunparse
Model and data
The reproduction uses Qwen2.5-3B; you can also compare against Qwen/Qwen2.5-3B-Instruct.
pip install -U huggingface_hub
export HF_ENDPOINT="https://hf-mirror.com"
# model
huggingface-cli download --resume-download Qwen/Qwen2.5-3B
huggingface-cli download --resume-download Qwen/Qwen2.5-3B-Instruct
# dataset
huggingface-cli download --resume-download --repo-type dataset Jiayi-Pan/Countdown-Tasks-3to4
By default, models are downloaded into the HF_HUB_CACHE directory, whose location is reported by:
$ huggingface-cli env
Copy-and-paste the text below in your GitHub issue.
- huggingface_hub version: 0.29.1
- Platform: Linux-5.10.0-1.0.0.28-x86_64-with-glibc2.27
- Python version: 3.9.19
- Running in iPython ?: No
- Running in notebook ?: No
- Running in Google Colab ?: No
- Running in Google Colab Enterprise ?: No
- Token path ?: /root/.cache/huggingface/token
- Has saved token ?: False
- Configured git credential helpers: store
- FastAI: N/A
- Tensorflow: N/A
- Torch: 2.4.0
- Jinja2: 3.1.4
- Graphviz: N/A
- keras: N/A
- Pydot: N/A
- Pillow: 10.4.0
- hf_transfer: N/A
- gradio: 4.44.1
- tensorboard: N/A
- numpy: 1.26.4
- pydantic: 2.10.6
- aiohttp: 3.10.3
- ENDPOINT: https://huggingface.co
- HF_HUB_CACHE: /root/.cache/huggingface/hub
- HF_ASSETS_CACHE: /root/.cache/huggingface/assets
- HF_TOKEN_PATH: /root/.cache/huggingface/token
- HF_STORED_TOKENS_PATH: /root/.cache/huggingface/stored_tokens
- HF_HUB_OFFLINE: False
- HF_HUB_DISABLE_TELEMETRY: False
- HF_HUB_DISABLE_PROGRESS_BARS: None
- HF_HUB_DISABLE_SYMLINKS_WARNING: False
- HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False
- HF_HUB_DISABLE_IMPLICIT_TOKEN: False
- HF_HUB_ENABLE_HF_TRANSFER: False
- HF_HUB_ETAG_TIMEOUT: 10
- HF_HUB_DOWNLOAD_TIMEOUT: 10
The downloaded model layout:
$ tree pretrain_model/models--Qwen--Qwen2.5-3B
pretrain_model/models--Qwen--Qwen2.5-3B
├── blobs
│ ├── 20024bfe7c83998e9aeaf98a0cd6a2ce6306c2f0
│ ├── 38047c6284a25427043f0ab040f623a2a20dd093
│ ├── 443909a61d429dff23010e5bddd28ff530edda00
│ ├── 4783fe10ac3adce15ac8f358ef5462739852c569
│ ├── 51410930d5cf19a998fdb17ef0c46e4d9ace72c97a975a3331395a8a500f5edb
│ ├── 5a98c2f5cee568453413e49a600be32d3f010eaf
│ ├── a6344aac8c09253b3b630fb776ae94478aa0275b
│ ├── acd0065677f7999d568f327c94188adbbcb6e158
│ ├── ba7e4c5637b9732dadcd66286ce48334e8b31e9e
│ ├── cbbb3133034e192527e5321b4c679154e4819ab8
│ ├── ed317377ac0abff39f17eec693ce664f4b8152af
│ └── f9558df91d3b89b4826e4db37439edb52f1d62a4fd602685013e7ca6b9f60f8f
├── refs
│ └── main
└── snapshots
└── 3aab1f1954e9cc14eb9509a215f9e5ca08227a9b # model directory
├── config.json -> ../../blobs/acd0065677f7999d568f327c94188adbbcb6e158
├── generation_config.json -> ../../blobs/cbbb3133034e192527e5321b4c679154e4819ab8
├── LICENSE -> ../../blobs/ed317377ac0abff39f17eec693ce664f4b8152af
├── merges.txt -> ../../blobs/20024bfe7c83998e9aeaf98a0cd6a2ce6306c2f0
├── model-00001-of-00002.safetensors -> ../../blobs/f9558df91d3b89b4826e4db37439edb52f1d62a4fd602685013e7ca6b9f60f8f
├── model-00002-of-00002.safetensors -> ../../blobs/51410930d5cf19a998fdb17ef0c46e4d9ace72c97a975a3331395a8a500f5edb
├── model.safetensors.index.json -> ../../blobs/38047c6284a25427043f0ab040f623a2a20dd093
├── README.md -> ../../blobs/5a98c2f5cee568453413e49a600be32d3f010eaf
├── tokenizer_config.json -> ../../blobs/ba7e4c5637b9732dadcd66286ce48334e8b31e9e
├── tokenizer.json -> ../../blobs/443909a61d429dff23010e5bddd28ff530edda00
└── vocab.json -> ../../blobs/4783fe10ac3adce15ac8f358ef5462739852c569
4 directories, 24 files
Dataset preprocessing
python ./examples/data_preprocess/countdown.py --local_dir data/countdown
This generates two .parquet files:
$ ls data/countdown/*parquet
data/countdown/test.parquet data/countdown/train.parquet
$ du -sh data/countdown/*parquet
80K data/countdown/test.parquet
22M data/countdown/train.parquet
$ file data/countdown/train.parquet
data/countdown/train.parquet: Apache Parquet
.parquet is a binary format: compact on disk but not convenient to inspect. You can modify countdown.py to also emit jsonl (the line numbers below refer to positions in countdown.py):
125 train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
126 test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
127 train_dataset.to_json(os.path.join(local_dir, 'train.jsonl'))
128 test_dataset.to_json(os.path.join(local_dir, 'test.jsonl'))
$ du -sh data/countdown/*jsonl
816K data/countdown/test.jsonl
257M data/countdown/train.jsonl
$ wc -l data/countdown/*jsonl
1024 data/countdown/test.jsonl
327680 data/countdown/train.jsonl
328704 total
Let's look at a random sample:
$ shuf -n1 data/countdown/train.jsonl | jq
{
  "target": 88,  # answer
  "nums": [  # question
    11,
    76,
    1
  ],
  "data_source": "countdown",  # selects the rule-based reward function
  "prompt": [  # single turn, system prompt + query combined
    {
      "content": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [11, 76, 1], create an equation that equals 88. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>",
      "role": "user"
    }
  ],
  "ability": "math",
  "reward_model": {  # reward function ground truth
    "ground_truth": {
      "numbers": [
        11,
        76,
        1
      ],
      "target": 88
    },
    "style": "rule"  # rule or model
  },
  "extra_info": {
    "index": 311487,  # sample index
    "split": "train"  # split name
  }
}
The core preprocessing logic is below. The three most important fields per sample are prompt (contains the question), reward_model (contains the ground truth), and data_source (used to select the rule-based reward function).
def make_map_fn(split):
    def process_fn(example, idx):
        question = make_prefix(example, template_type=args.template_type)
        solution = {
            "target": example['target'],
            "numbers": example['nums']
        }
        data = {
            "data_source": data_source,
            "prompt": [{
                "role": "user",
                "content": question,
            }],
            "ability": "math",
            "reward_model": {
                "style": "rule",
                "ground_truth": solution
            },
            "extra_info": {
                'split': split,
                'index': idx,
            }
        }
        return data
    return process_fn
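To sanity-check the mapping, a self-contained sketch (with make_prefix reduced to a one-line stub; the real one builds the full chat prompt) produces the same field layout as the sample above:

```python
data_source = "countdown"

def make_prefix(example):
    # stub for illustration only
    return f"Using the numbers {example['nums']}, create an equation that equals {example['target']}."

def process_fn(example, idx, split="train"):
    return {
        "data_source": data_source,
        "prompt": [{"role": "user", "content": make_prefix(example)}],
        "ability": "math",
        "reward_model": {
            "style": "rule",
            "ground_truth": {"target": example["target"], "numbers": example["nums"]},
        },
        "extra_info": {"split": split, "index": idx},
    }

row = process_fn({"target": 88, "nums": [11, 76, 1]}, 311487)
print(row["reward_model"]["ground_truth"])  # {'target': 88, 'numbers': [11, 76, 1]}
```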
The prompt format differs for different LLMs; the logic is:
def make_prefix(dp, template_type):
    target = dp['target']
    numbers = dp['nums']
    # NOTE: also need to change reward_score/countdown.py
    if template_type == 'base':
        """This works for any base model"""
        prefix = f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
User: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.
Assistant: Let me solve this step by step.
<think>"""
    elif template_type == 'qwen-instruct':
        """This works for Qwen Instruct Models"""
        prefix = f"""<|im_start|>system\nYou are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.<|im_end|>\n<|im_start|>user\n Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.<|im_end|>\n<|im_start|>assistant\nLet me solve this step by step.\n<think>"""
    return prefix
The processing steps inside the veRL framework:

- The registered reward function is looked up from the data_source field.
- A rollout generates the response from the question in prompt; the question already contains the concrete nums and target.
- The response is passed to the reward function, which parses out the answer, compares it against the ground_truth target in reward_model, and computes the reward.
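Schematically, the dispatch boils down to the sketch below. The names here are placeholders of my own, not the actual veRL API, and the reward function is a trivial stub:

```python
# Placeholder registry keyed by data_source; the stub always scores 1.0
REWARD_FNS = {
    "countdown": lambda response, ground_truth: 1.0,
}

def score_sample(sample, response):
    fn = REWARD_FNS[sample["data_source"]]        # step 1: dispatch on data_source
    # step 2: `response` was generated by rolling out sample["prompt"]
    gt = sample["reward_model"]["ground_truth"]   # step 3: compare answer to ground truth
    return fn(response, gt)

sample = {"data_source": "countdown",
          "reward_model": {"ground_truth": {"target": 88, "numbers": [11, 76, 1]}}}
print(score_sample(sample, "<answer>87 + 1</answer>"))  # 1.0 (from the stub)
```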
Rule-based Reward Function
(main_task pid=81572) --------------------------------
(main_task pid=81572) Target: 96 | Numbers: [84 54 66]
(main_task pid=81572) Extracted equation: (84 - 54 + 66)
(main_task pid=81572) Solution string: A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
(main_task pid=81572) User: Using the numbers [84, 54, 66], create an equation that equals 96. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.
(main_task pid=81572) Assistant: Let me solve this step by step.
(main_task pid=81572) <think> In order to create an equation that equals 96, we need to use three numbers and basic arithmetic operations. <answer> (84 - 54 + 66) </answer> </think><|endoftext|>
(main_task pid=81572) Correct equation: (84 - 54 + 66) = 96
The entry point is verl/trainer/main_ppo.py, where _select_rm_score_fn picks the reward function based on data_source:
from verl import DataProto
import torch
from verl.utils.reward_score import gsm8k, math, multiply, countdown
from verl.trainer.ppo.ray_trainer import RayPPOTrainer

# Select the rule-based reward function for the dataset
def _select_rm_score_fn(data_source):
    if data_source == 'openai/gsm8k':
        return gsm8k.compute_score
    elif data_source == 'lighteval/MATH':
        return math.compute_score
    elif "multiply" in data_source or "arithmetic" in data_source:
        return multiply.compute_score
    elif "countdown" in data_source:
        return countdown.compute_score
    else:
        raise NotImplementedError

# Reward wrapper: dispatches to the rule-based reward function
# chosen by _select_rm_score_fn and computes the reward.
class RewardManager():
    """The reward manager."""

    def __init__(self, tokenizer, num_examine) -> None:
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console

    def __call__(self, data: DataProto):
        """We will expand this function gradually based on the available datasets"""
        # If there is an rm score, return it directly; otherwise compute via the rm_score_fn
        if 'rm_scores' in data.batch.keys():
            return data.batch['rm_scores']

        # reward of responses
        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
        already_print_data_sources = {}

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch['prompts']
            prompt_length = prompt_ids.shape[-1]
            # valid prompt length from the attention mask
            valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
            # prompt_ids is left-padded
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch['responses']
            # valid response length
            valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode the unpadded (prompt, response) sequence
            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
            sequences_str = self.tokenizer.decode(sequences)

            ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']

            # select the rule-based reward function
            data_source = data_item.non_tensor_batch['data_source']
            compute_score_fn = _select_rm_score_fn(data_source)

            # compute the reward from the ground truth and the decoded response
            score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)
            reward_tensor[i, valid_response_length - 1] = score

            if data_source not in already_print_data_sources:
                already_print_data_sources[data_source] = 0
            if already_print_data_sources[data_source] < self.num_examine:
                already_print_data_sources[data_source] += 1
                print(sequences_str)

        return reward_tensor
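One detail worth noting in RewardManager: the scalar score is written only at the position of the last valid response token, so the reward tensor is sparse. A torch-free toy illustration with plain lists (lengths and scores are made up):

```python
# Each row is one response; the score lands at index valid_len - 1,
# i.e. at the last real (non-padding) token of that response.
batch = [{"valid_len": 5, "score": 1.0},   # full-length response
         {"valid_len": 3, "score": 0.1}]   # response that ended early
resp_len = 5

reward_rows = []
for item in batch:
    row = [0.0] * resp_len
    row[item["valid_len"] - 1] = item["score"]
    reward_rows.append(row)

print(reward_rows)  # [[0.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.1, 0.0, 0.0]]
```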
The countdown reward function is implemented in verl/utils/reward_score/countdown.py:
import re
import random
import ast
import operator

def extract_solution(solution_str):
    """Extract the equation from the solution string."""
    # Remove everything before the first "Assistant:"
    if "Assistant:" in solution_str:
        solution_str = solution_str.split("Assistant:", 1)[1]
    elif "<|im_start|>assistant" in solution_str:
        solution_str = solution_str.split("<|im_start|>assistant", 1)[1]
    else:
        return None
    solution_str = solution_str.split('\n')[-1]

    answer_pattern = r'<answer>(.*?)</answer>'
    match = re.finditer(answer_pattern, solution_str)
    matches = list(match)
    if matches:
        final_answer = matches[-1].group(1).strip()
    else:
        final_answer = None
    return final_answer

def validate_equation(equation_str, available_numbers):
    """Validate that equation only uses available numbers and each number once."""
    try:
        # Extract all numbers from the equation
        numbers_in_eq = [int(n) for n in re.findall(r'\d+', equation_str)]
        # Check if all numbers in equation are available
        available_numbers = sorted(available_numbers)
        numbers_in_eq = sorted(numbers_in_eq)
        # Each number should be used exactly once
        return numbers_in_eq == available_numbers
    except:
        return False

def evaluate_equation(equation_str):
    """Safely evaluate the arithmetic equation using eval() with precautions."""
    try:
        # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
        allowed_pattern = r'^[\d+\-*/().\s]+$'
        if not re.match(allowed_pattern, equation_str):
            raise ValueError("Invalid characters in equation.")
        # Evaluate the equation with restricted globals and locals
        result = eval(equation_str, {"__builtins__": None}, {})
        return result
    except Exception as e:
        return None

def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.):
    """The scoring function for the countdown task.

    Args:
        solution_str: the solution text
        ground_truth: dictionary containing target number and available numbers
        method: the method to extract the solution
        format_score: the score for correct format but wrong answer
        score: the score for the correct answer
    """
    target = ground_truth['target']
    numbers = ground_truth['numbers']

    equation = extract_solution(solution_str=solution_str)
    do_print = random.randint(1, 64) == 1

    if do_print:
        print(f"--------------------------------")
        print(f"Target: {target} | Numbers: {numbers}")
        print(f"Extracted equation: {equation}")
        print(f"Solution string: {solution_str}")

    if equation is None:
        if do_print:
            print(f"No equation found")
        return 0  # reward

    # Validate that the equation uses the correct numbers
    if not validate_equation(equation, numbers):
        if do_print:
            print(f"Invalid equation")
        return format_score  # reward

    # Evaluate the equation
    try:
        result = evaluate_equation(equation)
        if result is None:
            if do_print:
                print(f"Could not evaluate equation")
            return format_score  # reward

        if abs(result - target) < 1e-5:  # Account for floating point precision
            if do_print:
                print(f"Correct equation: {equation} = {result}")
            return score  # reward
        else:
            if do_print:
                print(f"Wrong result: equation = {result}, target = {target}")
            return format_score  # reward
    except:
        if do_print:
            print(f"Error evaluating equation")
        return format_score  # reward
- extract_solution pulls out the model's structured output, i.e. the content wrapped in <answer> </answer> tags; in the log above, (84 - 54 + 66).
- validate_equation checks the expression: it must use exactly the required numbers, each used exactly once, and be a computable arithmetic expression.
- If the expression is valid, evaluate_equation computes its result by simply calling eval(expr).
- The reward score is tiered over [0, 1]:
  - no extractable solution: reward 0
  - an answer was extracted but it is invalid or wrong: reward format_score = 0.1
  - correct result: reward 1.0
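The three tiers can be exercised with a compact re-statement of the logic (my own condensation for illustration, not the verl module itself):

```python
import re

# 0 for no parsable answer, format_score for parsable-but-wrong, score for correct.
def tiny_compute_score(answer, target, numbers, format_score=0.1, score=1.0):
    if answer is None:
        return 0
    if sorted(int(n) for n in re.findall(r'\d+', answer)) != sorted(numbers):
        return format_score          # wrong multiset of numbers
    if not re.match(r'^[\d+\-*/().\s]+$', answer):
        return format_score          # disallowed characters
    try:
        result = eval(answer, {"__builtins__": None}, {})
    except Exception:
        return format_score          # not computable
    return score if abs(result - target) < 1e-5 else format_score

print(tiny_compute_score(None, 96, [84, 54, 66]))               # 0
print(tiny_compute_score("(84 - 54) + 66", 100, [84, 54, 66]))  # 0.1 (valid form, wrong target)
print(tiny_compute_score("(84 - 54 + 66)", 96, [84, 54, 66]))   # 1.0
```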
This reward implementation has flaws; we run the experiments with the original version first and look at improvements later.
Training script
veRL uses wandb to monitor training, so run wandb login beforehand; the key is obtained from the API Keys section at https://wandb.ai/settings.
- GRPO training configuration for A100-40G with 4 GPUs:
