本文将介绍一下如何使用 Unsloth 框架和 TRL (Transformer Reinforcement Learning) 库,通过 GRPO (Group Relative Policy Optimization) 强化学习算法对 Qwen 2.5 (3B) 大模型进行微调(Fine-tuning)。不讲原理,直接上代码:
整个代码实现了以下主要功能:
-
环境配置:安装并配置 Unsloth、vLLM 和 TRL 等高性能训练库。
-
模型加载与量化:加载 Qwen 2.5-3B-Instruct 模型,并使用 4-bit 量化(QLoRA)以降低显存占用,同时启用 vLLM 进行快速推理。
-
数据集准备:加载 GSM8K(小学数学)数据集,并将其格式化为包含 和 XML 标签的 Prompt,旨在训练模型具备“思维链”(Chain-of-Thought)推理能力。
-
定义奖励函数 (Reward Functions):定义了一组规则来评价模型的输出,包括:答案准确性、是否为整数、XML 格式是否规范等。这是强化学习(RL)的核心部分。
-
GRPO 训练:使用 GRPO 算法进行训练。GRPO 会让模型针对同一个问题生成多个回答,通过对比这些回答的奖励分数来优化策略,而不需要额外的价值模型(Value Model)。
-
保存与推理:保存微调后的 LoRA 权重,并演示如何加载权重进行推理测试。
一、环境安装与设置
import os, numpy
# 设置环境变量,让 Unsloth 在 vLLM 中预留更多显存用于上下文
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
# 获取当前 numpy 版本以防止依赖冲突
numpy_version = f"numpy=={numpy.__version__}"
# Install dependencies with numpy version preservation
# 安装 Unsloth 及其依赖(Unsloth 用于加速训练,vLLM 用于加速推理)
!uv pip install unsloth_zoo
!uv pip install --upgrade unsloth vllm==0.9.2 {numpy_version} torchvision bitsandbytes xformers
!uv pip install triton==3.2.0
!uv pip install transformers==4.55.4
!uv pip install --no-deps trl==0.22.2
二、加载Model和Tokenizer
from unsloth import FastLanguageModel
import torch
# 设置最大上下文长度
max_seq_length = 1024
# 加载预训练模型和分词器
# 功能:加载 Qwen2.5-3B-Instruct 模型,使用 4-bit 量化加载以节省显存,开启 fast_inference (vLLM)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Qwen2.5-3B-Instruct",
max_seq_length = max_seq_length,
load_in_4bit = True, # 4bit 量化加载
fast_inference = True, # 启用 vLLM 快速推理引擎
max_lora_rank = 8, # LoRA 秩
gpu_memory_utilization = 0.9, # 显存利用率上限
)
三、配置 LoRA (低秩适应)
# 配置 PEFT (Parameter-Efficient Fine-Tuning)
# 功能:将模型转换为 LoRA 模式,只训练新增的少量参数,冻结原模型参数
model = FastLanguageModel.get_peft_model(
model,
r = 8, # LoRA 的秩
# 指定需要应用 LoRA 的模块(注意力层和前馈网络层)
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha = 8,
use_gradient_checkpointing = "unsloth", # 使用梯度检查点节省显存
random_state = 1234,
)
四、数据集处理与格式化
import re
from datasets import load_dataset, Dataset
# 系统提示词,强制模型使用特定的 XML 格式输出推理过程和答案
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
# 定义 XML 格式模板
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
# 函数:从模型输出中提取 XML 标签内的答案
def extract_xml_answer(text):
if "" not in text or "" not in text:
return ""
return text.split(" ", 1)[-1].split(" ", 1)[0].strip()
# 函数:从 GSM8K 数据集的原始答案字段中提取最终数值(通常在 #### 之后)
def extract_hash_answer(text):
return text.split("####")[-1].strip() if "####" in text else None
# 函数:加载并预处理 GSM8K 数据集
# 功能:加载 OpenAI 的 GSM8K 数据集,并将每个样本转化为包含 system prompt 和 user prompt 的对话格式
def get_gsm8k_dataset(split = "train"):
data = load_dataset("openai/gsm8k", "main")[split]
return data.map(
lambda x: {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": x["question"]},
],
"answer": extract_hash_answer(x["answer"]), # 提取标准答案用于后续奖励计算
}
)
# 加载处理好的数据集
dataset = get_gsm8k_dataset()
五、定义奖励函数 (Reward Functions)
这是 GRPO 的核心,模型生成的每个结果都会经过这些函数打分。
# 奖励函数 1:正确性奖励
# 功能:检查模型生成的答案(从 XML 中提取)是否与标准答案完全一致。正确得 2.0 分,否则 0 分。
def correctness_reward_func(prompts, completions, answer, **kwargs):
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
# 打印日志方便调试
print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
# 奖励函数 2:整数奖励
# 功能:检查提取出的答案是否为数字。是则得 0.5 分。
def int_reward_func(completions, **kwargs):
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
# 奖励函数 3:严格格式奖励
# 功能:使用正则检查输出是否严格符合 <reasoning>...\n<answer>... 的格式结构。
def strict_format_reward_func(completions, **kwargs):
pattern = r"^\n.*?\n\n\n.*?\n\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
# 奖励函数 4:宽松格式奖励
# 功能:检查输出是否至少包含了 XML 标签,允许格式上有少量空白差异。
def soft_format_reward_func(completions, **kwargs):
pattern = r".*?\s*.*?"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
# 辅助函数:计算 XML 标签的完整性
def count_xml(text):
count = 0.0
# 检查各个标签是否存在,每存在一个加分,如果格式混乱(如多余换行)则扣分
if text.count("\n") == 1:
count += 0.125
if text.count("\n\n") == 1:
count += 0.125
if text.count("\n\n") == 1:
count += 0.125
count -= len(text.split("\n\n")[-1])*0.001 # 惩罚项
if text.count("\n") == 1:
count += 0.125
count -= (len(text.split("\n")[-1]) - 1)*0.001 # 惩罚项
return count
# 奖励函数 5:XML 计数奖励
# 功能:基于 XML 标签的完整性和位置给予分数。
def xmlcount_reward_func(completions, **kwargs):
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
六、 配置与启动 GRPO 训练
from trl import GRPOConfig, GRPOTrainer
# 配置训练参数
training_args = GRPOConfig(
use_vllm = True, # 使用 vLLM 生成样本(极快)
learning_rate = 5e-6, # 学习率
adam_beta1 = 0.9,
adam_beta2 = 0.99,
weight_decay = 0.1,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
optim = "adamw_8bit", # 使用 8-bit 优化器节省显存
logging_steps = 1,
per_device_train_batch_size = 4,
gradient_accumulation_steps = 1,
num_generations = 4, # GRPO 核心:每个 prompt 生成 4 个回答进行对比
max_prompt_length = 256,
max_completion_length = 200,
max_steps = 250, # 训练总步数
save_steps = 250,
max_grad_norm = 0.1,
report_to = "none",
output_dir = "outputs",
)
# 初始化 GRPO 训练器
# 功能:将模型、奖励函数列表和训练配置结合
trainer = GRPOTrainer(
model = model,
processing_class = tokenizer,
reward_funcs = [
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func,
],
args = training_args,
train_dataset = dataset,
)
# 开始训练
# 功能:模型开始根据 prompt 生成多个回答,根据奖励函数的反馈更新 LoRA 权重,使模型更倾向于生成高分回答(格式正确且答案正确)。
trainer.train()
七、保存与推理测试
# 保存训练好的 LoRA 适配器
model.save_lora("grpo_saved_lora")
# --- 推理部分 ---
from vllm import SamplingParams
# 测试用的查询
query = "How many r's are in strawberry?"
# 构建聊天模板
text = tokenizer.apply_chat_template([
{"role" : "user", "content" : query},
], tokenize = False, add_generation_prompt = True)
# 设置采样参数
sampling_params = SamplingParams(
temperature = 0.8,
top_p = 0.95,
max_tokens = 1024,
)
# 生成回答(不加载 LoRA 或 加载 LoRA)
# 这里演示了如何使用 model.fast_generate 进行快速推理
output = model.fast_generate(
[text],
sampling_params = sampling_params,
lora_request = None, # 这里设为 None 表示用基础模型,若要用训练后的模型需加载 LoRA
)[0].outputs[0].text
print(output)
# 构建聊天模板
text = tokenizer.apply_chat_template([
{"role" : "system", "content" : SYSTEM_PROMPT},
{"role" : "user", "content" : query},
], tokenize = False, add_generation_prompt = True)
sampling_params = SamplingParams(
temperature = 0.8,
top_p = 0.95,
max_tokens = 1024,
)
# 再次生成,这次加载刚才保存的 LoRA 权重
# 功能:验证经过 GRPO 训练后的模型表现
output = model.fast_generate(
text,
sampling_params = sampling_params,
lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
print(output)
至此,完整的微调代码就介绍完了。
