SFT 泛化新解读:强化学习 + 奖励修正,一文读懂

SFT 泛化新解读:强化学习 + 奖励修正,一文读懂

1. 研究背景与问题

  • SFT的局限性:传统监督微调(SFT)在LLM任务适配中简单高效,但泛化能力弱于强化学习(RL)。RL依赖奖励信号探索策略,但计算成本高且需人工设计奖励函数。
  • 核心问题:能否在不引入额外反馈(如负样本或奖励模型)的情况下,提升SFT的泛化性能?

2. 理论分析:SFT的本质缺陷

  • 数学等价性(公式5-6):
    通过重要性采样证明SFT梯度等价于策略梯度(Policy Gradient),但隐含一个病态奖励结构:
    r(x,y)=I[y=y],w(yx)=1/πθ(yx)r(x,y) = \mathbb{I}[y=y^*], \quad w(y|x) = 1/\pi_\theta(y|x)
    • 奖励信号极度稀疏(仅当生成结果与专家完全一致时非零)。
    • 权重项 w(yx)w(y|x) 在模型置信度低时(πθ0\pi_\theta \to 0)导致梯度方差无限大,迫使模型过度拟合训练数据。

3. 方法创新:Dynamic Fine-Tuning (DFT)

  • 核心思想:通过动态重加权抵消病态奖励结构。

  • 实现方式(公式7-9):
    LDFT=tsg(πθ(yty<t,x))logπθ(yty<t,x)\mathcal{L}_{\text{DFT}} = -\sum_t \text{sg}(\pi_\theta(y_t^* | y_{<t}^*, x)) \log \pi_\theta(y_t^* | y_{<t}^*, x)

    • 单行代码改动:在SFT损失函数中乘以当前token概率 πθ\pi_\theta(停止梯度)。
    • 作用机制:消除原始SFT中 1/πθ1/\pi_\theta 的逆概率权重,将梯度更新稳定为均匀加权。
  • 与Focal Loss对比
    DFT损失为 plogp-p \log p,而Focal Loss为 (1p)γlogp-(1-p)^\gamma \log p。前者抑制低置信样本的权重以缓解过拟合,后者侧重难样本以解决欠拟合,反映LLM时代的关键矛盾转变(从欠拟合变为过拟合)。

核心实现:DFT的代码修改

DFT的核心是对标准SFT损失函数的一个简单但关键的修改:缩放每个 token 的损失值基于其预测概率(detached以防止梯度流动)。这只需要在训练代码中添加一行代码,具体如下:

  • 修改位置:在模型训练脚本中,处理损失函数的部分。
  • 代码实现
    loss = loss * torch.softmax(shift_logits, dim=-1).gather(1, shift_labels.unsqueeze(-1)).squeeze(-1).detach()
    
    • 解释
      • shift_logits:模型的输出logits(预测分数)。
      • shift_labels:真实的token标签。
      • torch.softmax:将logits转换为概率分布。
      • .gather():选择对应标签的概率。
      • .detach():阻止梯度回传到概率计算,避免循环依赖。
      • 效果:这个缩放操作稳定了梯度更新,提高了模型在泛化任务(如数学推理)上的性能。

4. 实验结果

4.1 SFT场景(仅专家数据)

  • 泛化能力提升
    在5个数学推理基准(Math500、Olympiad Bench等)上,DFT显著超越SFT:

    模型SFT平均提升DFT平均提升增益倍数
    Qwen2.5-Math-1.5B+2.09+15.665.9×
    DeepSeekMath-7B+7.18+15.511.58×
    Qwen2.5-Math-7B+2.37+15.903.8×
  • 挑战性任务表现

    • Olympiad Bench任务:SFT使Qwen1.5B性能从15.88%→12.63%,而DFT提升至27.08%(+11.2点)。
    • AIME 2024任务:SFT使Qwen7B性能从6.68%→2.48%,而DFT提升至8.56%(逆转负迁移)。
  • 收敛效率
    DFT在早期训练步骤(10-20步)即超越SFT最终性能,且收敛速度更快:

picture.image

4.2 离线RL场景(含奖励信号)

  • 超越专业RL方法
    DFT在离线RL设定下优于DPO、RFT等算法,甚至超越在线RL(PPO、GRPO):
    方法平均性能(Qwen1.5B)
    DPO (离线)23.20%
    GRPO (在线)32.00%
    DFT (离线)35.43%

5. 机制分析

  • 概率分布变化(图2):
    • SFT:均匀提升所有token概率,导致过拟合。
    • DFT:极化效应——显著提升高信息量token概率,抑制低信息量token(如连词、标点)。

picture.image

  • 超参数鲁棒性
    DFT在不同学习率和批量大小下均稳定优于SFT,且最优学习率区间(1e-4~5e-5)与SFT一致。

6. 贡献总结

  • 理论层面:首次严格证明SFT梯度与策略梯度的等价性,揭示其泛化瓶颈的数学根源。
  • 实践层面:提出DFT——单行代码修改显著提升SFT泛化能力,在数学推理任务中实现:
    • 平均性能提升 2–6倍于SFT
    • 离线RL场景超越专业算法

局限:当前实验限于数学任务和≤7B模型,未来需验证多模态与更大规模模型。 DFT在非确定性任务(如数学CoT推理)上表现优秀,但在确定性任务(单一正确答案、低熵)可能不如标准SFT。 DFT在文学或金融任务上可能失败,建议先在数学或代码任务上测试。

参考文献

0
0
0
0
评论
未登录
暂无评论