SFT 泛化新解读:强化学习 + 奖励修正,一文读懂
1. 研究背景与问题
- SFT的局限性:传统监督微调(SFT)在LLM任务适配中简单高效,但泛化能力弱于强化学习(RL)。RL依赖奖励信号探索策略,但计算成本高且需人工设计奖励函数。
- 核心问题:能否在不引入额外反馈(如负样本或奖励模型)的情况下,提升SFT的泛化性能?
2. 理论分析:SFT的本质缺陷
- 数学等价性(公式5-6):
通过重要性采样证明SFT梯度等价于策略梯度(Policy Gradient),但隐含一个病态奖励结构:
- 奖励信号极度稀疏(仅当生成结果与专家完全一致时非零)。
- 权重项 在模型置信度低时()导致梯度方差无限大,迫使模型过度拟合训练数据。
3. 方法创新:Dynamic Fine-Tuning (DFT)
-
核心思想:通过动态重加权抵消病态奖励结构。
-
实现方式(公式7-9):
- 单行代码改动:在SFT损失函数中乘以当前token概率 (停止梯度)。
- 作用机制:消除原始SFT中 的逆概率权重,将梯度更新稳定为均匀加权。
-
与Focal Loss对比:
DFT损失为 ,而Focal Loss为 。前者抑制低置信样本的权重以缓解过拟合,后者侧重难样本以解决欠拟合,反映LLM时代的关键矛盾转变(从欠拟合变为过拟合)。
核心实现:DFT的代码修改
DFT的核心是对标准SFT损失函数的一个简单但关键的修改:缩放每个 token 的损失值基于其预测概率(detached以防止梯度流动)。这只需要在训练代码中添加一行代码,具体如下:
- 修改位置:在模型训练脚本中,处理损失函数的部分。
- 代码实现:
loss = loss * torch.softmax(shift_logits, dim=-1).gather(1, shift_labels.unsqueeze(-1)).squeeze(-1).detach()- 解释:
shift_logits:模型的输出logits(预测分数)。shift_labels:真实的token标签。torch.softmax:将logits转换为概率分布。.gather():选择对应标签的概率。.detach():阻止梯度回传到概率计算,避免循环依赖。- 效果:这个缩放操作稳定了梯度更新,提高了模型在泛化任务(如数学推理)上的性能。
- 解释:
4. 实验结果
4.1 SFT场景(仅专家数据)
-
泛化能力提升:
在5个数学推理基准(Math500、Olympiad Bench等)上,DFT显著超越SFT:模型 SFT平均提升 DFT平均提升 增益倍数 Qwen2.5-Math-1.5B +2.09 +15.66 5.9× DeepSeekMath-7B +7.18 +15.51 1.58× Qwen2.5-Math-7B +2.37 +15.90 3.8× -
挑战性任务表现:
- Olympiad Bench任务:SFT使Qwen1.5B性能从15.88%→12.63%,而DFT提升至27.08%(+11.2点)。
- AIME 2024任务:SFT使Qwen7B性能从6.68%→2.48%,而DFT提升至8.56%(逆转负迁移)。
-
收敛效率:
DFT在早期训练步骤(10-20步)即超越SFT最终性能,且收敛速度更快:
4.2 离线RL场景(含奖励信号)
- 超越专业RL方法:
DFT在离线RL设定下优于DPO、RFT等算法,甚至超越在线RL(PPO、GRPO):方法 平均性能(Qwen1.5B) DPO (离线) 23.20% GRPO (在线) 32.00% DFT (离线) 35.43%
5. 机制分析
- 概率分布变化(图2):
- SFT:均匀提升所有token概率,导致过拟合。
- DFT:极化效应——显著提升高信息量token概率,抑制低信息量token(如连词、标点)。
- 超参数鲁棒性:
DFT在不同学习率和批量大小下均稳定优于SFT,且最优学习率区间(1e-4~5e-5)与SFT一致。
6. 贡献总结
- 理论层面:首次严格证明SFT梯度与策略梯度的等价性,揭示其泛化瓶颈的数学根源。
- 实践层面:提出DFT——单行代码修改显著提升SFT泛化能力,在数学推理任务中实现:
- 平均性能提升 2–6倍于SFT
- 离线RL场景超越专业算法
局限:当前实验限于数学任务和≤7B模型,未来需验证多模态与更大规模模型。 DFT在非确定性任务(如数学CoT推理)上表现优秀,但在确定性任务(单一正确答案、低熵)可能不如标准SFT。 DFT在文学或金融任务上可能失败,建议先在数学或代码任务上测试。
