LLM微调(六)| 使用Unsloth框架和TRL库通过 GRPO强化学习算法对 Qwen 2.5 (3B) 进行对齐

大模型机器学习算法
本文将介绍一下如何使用 Unsloth 框架和 TRL (Transformer Reinforcement Learning) 库,通过 GRPO (Group Relative Policy Optimization) 强化学习算法对 Qwen 2.5 (3B) 大模型进行微调(Fine-tuning)。不讲原理,直接上代码:

源代码已经放在:https://github.com/ArronAI007/Awesome-AGI/blob/main/LLM%20Pipeline/Fine-Tune/trl/01\_Train\_Qwen\_2\_5(3B)\_To\_Reason\_With\_GRPO.ipynb

整个代码实现了以下主要功能:

  • 环境配置:安装并配置 Unsloth、vLLM 和 TRL 等高性能训练库。

  • 模型加载与量化:加载 Qwen 2.5-3B-Instruct 模型,并使用 4-bit 量化(QLoRA)以降低显存占用,同时启用 vLLM 进行快速推理。

  • 数据集准备:加载 GSM8K(小学数学)数据集,并将其格式化为包含 和 XML 标签的 Prompt,旨在训练模型具备“思维链”(Chain-of-Thought)推理能力。

  • 定义奖励函数 (Reward Functions):定义了一组规则来评价模型的输出,包括:答案准确性、是否为整数、XML 格式是否规范等。这是强化学习(RL)的核心部分。

  • GRPO 训练:使用 GRPO 算法进行训练。GRPO 会让模型针对同一个问题生成多个回答,通过对比这些回答的奖励分数来优化策略,而不需要额外的价值模型(Value Model)。

  • 保存与推理:保存微调后的 LoRA 权重,并演示如何加载权重进行推理测试。

一、环境安装与设置

  
import os, numpy  
# 设置环境变量,让 Unsloth 在 vLLM 中预留更多显存用于上下文  
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"  
# 获取当前 numpy 版本以防止依赖冲突  
numpy_version = f"numpy=={numpy.__version__}"  
# Install dependencies with numpy version preservation  
# 安装 Unsloth 及其依赖(Unsloth 用于加速训练,vLLM 用于加速推理)  
!uv pip install unsloth_zoo  
!uv pip install --upgrade unsloth vllm==0.9.2 {numpy_version} torchvision bitsandbytes xformers  
!uv pip install triton==3.2.0  
!uv pip install transformers==4.55.4  
!uv pip install --no-deps trl==0.22.2

二、加载Model和Tokenizer

  
from unsloth import FastLanguageModel  
import torch  
# 设置最大上下文长度  
max_seq_length = 1024  
# 加载预训练模型和分词器  
# 功能:加载 Qwen2.5-3B-Instruct 模型,使用 4-bit 量化加载以节省显存,开启 fast_inference (vLLM)  
model, tokenizer = FastLanguageModel.from_pretrained(  
    model_name = "unsloth/Qwen2.5-3B-Instruct",  
    max_seq_length = max_seq_length,  
    load_in_4bit = True,               # 4bit 量化加载  
    fast_inference = True,             # 启用 vLLM 快速推理引擎  
    max_lora_rank = 8,                 # LoRA 秩  
    gpu_memory_utilization = 0.9,      # 显存利用率上限  
)

三、配置 LoRA (低秩适应)

  
# 配置 PEFT (Parameter-Efficient Fine-Tuning)  
# 功能:将模型转换为 LoRA 模式,只训练新增的少量参数,冻结原模型参数  
model = FastLanguageModel.get_peft_model(  
    model,  
    r = 8,  # LoRA 的秩  
    # 指定需要应用 LoRA 的模块(注意力层和前馈网络层)  
    target_modules = [  
        "q_proj", "k_proj", "v_proj", "o_proj",  
        "gate_proj", "up_proj", "down_proj",  
    ],  
    lora_alpha = 8,  
    use_gradient_checkpointing = "unsloth",       # 使用梯度检查点节省显存  
    random_state = 1234,  
)

四、数据集处理与格式化

  
import re  
from datasets import load_dataset, Dataset  
# 系统提示词,强制模型使用特定的 XML 格式输出推理过程和答案  
SYSTEM_PROMPT = """  
Respond in the following format:  
<reasoning>  
...  
</reasoning>  
<answer>  
...  
</answer>  
"""  
# 定义 XML 格式模板  
XML_COT_FORMAT = """\  
<reasoning>  
{reasoning}  
</reasoning>  
<answer>  
{answer}  
</answer>  
"""  
# 函数:从模型输出中提取 XML 标签内的答案  
def extract_xml_answer(text):  
    if "" not in text or "" not in text:  
        return ""  
    return text.split(" ", 1)[-1].split(" ", 1)[0].strip()  
# 函数:从 GSM8K 数据集的原始答案字段中提取最终数值(通常在 #### 之后)  
def extract_hash_answer(text):  
    return text.split("####")[-1].strip() if "####" in text else None  
# 函数:加载并预处理 GSM8K 数据集  
# 功能:加载 OpenAI 的 GSM8K 数据集,并将每个样本转化为包含 system prompt 和 user prompt 的对话格式  
def get_gsm8k_dataset(split = "train"):  
    data = load_dataset("openai/gsm8k", "main")[split]  
    return data.map(  
        lambda x: {  
            "prompt": [  
                {"role": "system", "content": SYSTEM_PROMPT},  
                {"role": "user", "content": x["question"]},  
            ],  
            "answer": extract_hash_answer(x["answer"]),           # 提取标准答案用于后续奖励计算  
        }  
    )  
# 加载处理好的数据集  
dataset = get_gsm8k_dataset()

五、定义奖励函数 (Reward Functions)

这是 GRPO 的核心,模型生成的每个结果都会经过这些函数打分。

  
# 奖励函数 1:正确性奖励  
# 功能:检查模型生成的答案(从 XML 中提取)是否与标准答案完全一致。正确得 2.0 分,否则 0 分。  
def correctness_reward_func(prompts, completions, answer, **kwargs):  
    responses = [completion[0]['content'] for completion in completions]  
    q = prompts[0][-1]['content']  
    extracted_responses = [extract_xml_answer(r) for r in responses]  
    # 打印日志方便调试  
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")  
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]  
# 奖励函数 2:整数奖励  
# 功能:检查提取出的答案是否为数字。是则得 0.5 分。  
def int_reward_func(completions, **kwargs):  
    responses = [completion[0]['content'] for completion in completions]  
    extracted_responses = [extract_xml_answer(r) for r in responses]  
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]  
# 奖励函数 3:严格格式奖励  
# 功能:使用正则检查输出是否严格符合 <reasoning>...\n<answer>... 的格式结构。  
def strict_format_reward_func(completions, **kwargs):  
    pattern = r"^\n.*?\n\n\n.*?\n\n$"  
    responses = [completion[0]["content"] for completion in completions]  
    matches = [re.match(pattern, r) for r in responses]  
    return [0.5 if match else 0.0 for match in matches]  
# 奖励函数 4:宽松格式奖励  
# 功能:检查输出是否至少包含了 XML 标签,允许格式上有少量空白差异。  
def soft_format_reward_func(completions, **kwargs):  
    pattern = r".*?\s*.*?"  
    responses = [completion[0]["content"] for completion in completions]  
    matches = [re.match(pattern, r) for r in responses]  
    return [0.5 if match else 0.0 for match in matches]  
# 辅助函数:计算 XML 标签的完整性  
def count_xml(text):  
    count = 0.0  
    # 检查各个标签是否存在,每存在一个加分,如果格式混乱(如多余换行)则扣分  
    if text.count("\n") == 1:  
        count += 0.125  
    if text.count("\n\n") == 1:  
        count += 0.125  
    if text.count("\n\n") == 1:  
        count += 0.125  
        count -= len(text.split("\n\n")[-1])*0.001       # 惩罚项  
    if text.count("\n") == 1:  
        count += 0.125  
        count -= (len(text.split("\n")[-1]) - 1)*0.001   # 惩罚项  
    return count  
# 奖励函数 5:XML 计数奖励  
# 功能:基于 XML 标签的完整性和位置给予分数。  
def xmlcount_reward_func(completions, **kwargs):  
    contents = [completion[0]["content"] for completion in completions]  
    return [count_xml(c) for c in contents]

六、 配置与启动 GRPO 训练

  
from trl import GRPOConfig, GRPOTrainer  
# 配置训练参数  
training_args = GRPOConfig(  
    use_vllm = True,                  # 使用 vLLM 生成样本(极快)  
    learning_rate = 5e-6,             # 学习率  
    adam_beta1 = 0.9,  
    adam_beta2 = 0.99,  
    weight_decay = 0.1,  
    warmup_ratio = 0.1,  
    lr_scheduler_type = "cosine",  
    optim = "adamw_8bit",             # 使用 8-bit 优化器节省显存  
    logging_steps = 1,  
    per_device_train_batch_size = 4,  
    gradient_accumulation_steps = 1,  
    num_generations = 4,              # GRPO 核心:每个 prompt 生成 4 个回答进行对比  
    max_prompt_length = 256,  
    max_completion_length = 200,  
    max_steps = 250,                  # 训练总步数  
    save_steps = 250,  
    max_grad_norm = 0.1,  
    report_to = "none",  
    output_dir = "outputs",  
)
  
# 初始化 GRPO 训练器  
# 功能:将模型、奖励函数列表和训练配置结合  
trainer = GRPOTrainer(  
    model = model,  
    processing_class = tokenizer,  
    reward_funcs = [  
        xmlcount_reward_func,  
        soft_format_reward_func,  
        strict_format_reward_func,  
        int_reward_func,  
        correctness_reward_func,  
    ],  
    args = training_args,  
    train_dataset = dataset,  
)
  
# 开始训练  
# 功能:模型开始根据 prompt 生成多个回答,根据奖励函数的反馈更新 LoRA 权重,使模型更倾向于生成高分回答(格式正确且答案正确)。  
trainer.train()

七、保存与推理测试

  
# 保存训练好的 LoRA 适配器  
model.save_lora("grpo_saved_lora")
  
# --- 推理部分 ---  
from vllm import SamplingParams  
# 测试用的查询  
query = "How many r's are in strawberry?"  
# 构建聊天模板  
text = tokenizer.apply_chat_template([  
    {"role" : "user", "content" : query},  
], tokenize = False, add_generation_prompt = True)  
# 设置采样参数  
sampling_params = SamplingParams(  
    temperature = 0.8,  
    top_p = 0.95,  
    max_tokens = 1024,  
)  
# 生成回答(不加载 LoRA 或 加载 LoRA)  
# 这里演示了如何使用 model.fast_generate 进行快速推理  
output = model.fast_generate(  
    [text],  
    sampling_params = sampling_params,  
    lora_request = None,                    # 这里设为 None 表示用基础模型,若要用训练后的模型需加载 LoRA  
)[0].outputs[0].text  
print(output)
  
# 构建聊天模板  
text = tokenizer.apply_chat_template([  
    {"role" : "system", "content" : SYSTEM_PROMPT},  
    {"role" : "user", "content" : query},  
], tokenize = False, add_generation_prompt = True)  
sampling_params = SamplingParams(  
    temperature = 0.8,  
    top_p = 0.95,  
    max_tokens = 1024,  
)  
# 再次生成,这次加载刚才保存的 LoRA 权重  
# 功能:验证经过 GRPO 训练后的模型表现  
output = model.fast_generate(  
    text,  
    sampling_params = sampling_params,  
    lora_request = model.load_lora("grpo_saved_lora"),  
)[0].outputs[0].text  
print(output)

至此,完整的微调代码就介绍完了。

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
在火山引擎云搜索服务上构建混合搜索的设计与实现
本次演讲将重点介绍字节跳动在混合搜索领域的探索,并探讨如何在多模态数据场景下进行海量数据搜索。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论