Background: Limitations of the Traditional PPO Algorithm
In reinforcement learning (RL) training of large language models, PPO (Proximal Policy Optimization) has long been the mainstream method. For a question $q$ drawn from the dataset $\mathcal{D}$, PPO uses the policy model $\pi_\theta$ (with parameters $\theta$) to generate a response $o$ and optimizes the following objective:

$$\mathcal{J}_{\text{PPO}}(\theta) = \mathbb{E}_{q \sim \mathcal{D},\, o \sim \pi_{\theta_{\text{old}}}(\cdot \mid q)} \left[ \frac{1}{|o|} \sum_{t=1}^{|o|} \min\!\Big( r_t(\theta)\,\hat{A}_t,\; \operatorname{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_t \Big) \right]$$

Here $r_t(\theta) = \dfrac{\pi_\theta(o_t \mid q, o_{<t})}{\pi_{\theta_{\text{old}}}(o_t \mid q, o_{<t})}$ is the importance sampling weight, which corrects the distribution mismatch during off-policy updates.
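To make the token-level clipping concrete, here is a minimal sketch of the clipped PPO surrogate as it is typically computed per token in PyTorch. The function and tensor names (`logp_new`, `logp_old`, `advantages`) are illustrative assumptions, not code from the paper:

```python
import torch

def ppo_token_loss(logp_new: torch.Tensor, logp_old: torch.Tensor,
                   advantages: torch.Tensor, eps: float = 0.2) -> torch.Tensor:
    """Clipped PPO surrogate averaged over tokens (illustrative sketch)."""
    ratio = torch.exp(logp_new - logp_old)               # importance sampling ratio r_t
    unclipped = ratio * advantages                       # r_t * A_t
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    return -torch.min(unclipped, clipped).mean()         # pessimistic (clipped) token update
```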
GRPO: Improvements and Shortcomings
GRPO (Group Relative Policy Optimization) simplifies PPO by removing the value model and defining the advantage relative to the rewards of the other responses in the same group:

$$\hat{A}_{i,t} = \frac{R_i - \operatorname{mean}\big(\{R_j\}_{j=1}^{G}\big)}{\operatorname{std}\big(\{R_j\}_{j=1}^{G}\big)}$$

where $R_i$ is the reward of response $o_i$, and the $G$ responses $\{o_i\}_{i=1}^{G}$ are sampled for each question.
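As a toy numerical illustration (the numbers are invented, not from the paper), the group-relative advantage is just a per-group standardization of the rewards:

```python
import torch

# Rewards of G = 4 sampled responses to one question (hypothetical 0/1 correctness rewards)
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
print(advantages)  # tensor([ 0.8660, -0.8660, -0.8660,  0.8660])
```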
The Core Problem: The Harm of Token Clipping
In their experiments, the research team identified a serious issue: the traditional clipping operation severely hurts training for long chain-of-thought reasoning.
How the problem manifests
1. Critical tokens get clipped away: tokens that mark reflective behavior (such as "However", "Recheck", "Wait", "Aha") have very low probability under the base model.
2. Reasoning paths are cut off: these tokens often act as the "forks" of a reasoning path, yet they produce large importance ratios $r_t(\theta)$ during policy updates.
3. Gradient contributions are lost: after the first policy update these tokens are clipped out, so they can no longer contribute gradients in the subsequent off-policy gradient updates.
This problem is especially severe in hybrid-architecture models and further hinders the scalability of reinforcement learning. Although DAPO tries to mitigate it by raising the clipping upper bound, this proves insufficient under a setting with 16 off-policy gradient updates per batch of rollouts.
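To see how clipping silences such a token, consider a single hypothetical reflective token (say "Wait") whose probability rises sharply after the first update. The numbers and bounds below are made up purely for illustration:

```python
import torch

# Hypothetical reflective token: rare under the old policy, much likelier under the new one
logp_old = torch.tensor(0.01).log()
logp_new = torch.tensor(0.05).log().requires_grad_()
advantage = torch.tensor(1.0)
eps_low, eps_high = 0.2, 0.28          # DAPO-style raised upper bound, still far exceeded

ratio = torch.exp(logp_new - logp_old)                        # r_t = 5.0 >> 1 + eps_high
clipped = torch.clamp(ratio, 1 - eps_low, 1 + eps_high) * advantage
loss = -torch.min(ratio * advantage, clipped)
loss.backward()
print(logp_new.grad)  # tensor(0.) -- the clipped branch wins and the token gets no gradient
```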
CISPO: A New Solution
Core Idea
The core idea of CISPO (Clipped Importance Sampling Policy Optimization) is: instead of clipping the token updates, clip the importance sampling weights.
Algorithm Derivation
Start from the standard REINFORCE objective, written here with an off-policy correction:

$$\mathcal{J}(\theta) = \mathbb{E}_{q \sim \mathcal{D},\, o_i \sim \pi_{\theta_{\text{old}}}(\cdot \mid q)} \left[ \sum_{t=1}^{|o_i|} \operatorname{sg}\big(r_{i,t}(\theta)\big)\, \hat{A}_{i,t}\, \log \pi_\theta(o_{i,t} \mid q, o_{i,<t}) \right]$$

where $\operatorname{sg}(\cdot)$ denotes the stop-gradient operation.

On top of this, CISPO introduces a clipped importance sampling weight

$$\hat{r}_{i,t}(\theta) = \operatorname{clip}\!\big(r_{i,t}(\theta),\; 1-\epsilon^{\text{IS}}_{\text{low}},\; 1+\epsilon^{\text{IS}}_{\text{high}}\big),$$

which gives the CISPO objective:

$$\mathcal{J}_{\text{CISPO}}(\theta) = \mathbb{E}_{q \sim \mathcal{D},\, \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\text{old}}}(\cdot \mid q)} \left[ \frac{1}{\sum_{i=1}^{G}|o_i|} \sum_{i=1}^{G}\sum_{t=1}^{|o_i|} \operatorname{sg}\big(\hat{r}_{i,t}(\theta)\big)\, \hat{A}_{i,t}\, \log \pi_\theta(o_{i,t} \mid q, o_{i,<t}) \right]$$
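In code, the CISPO objective reduces to a REINFORCE-style loss whose per-token weight is the clipped, stop-gradient importance ratio. The following is a minimal sketch with assumed tensor names; the full trainer further below implements the same idea at batch scale:

```python
import torch

def cispo_token_loss(logp_new: torch.Tensor, logp_old: torch.Tensor,
                     advantages: torch.Tensor,
                     eps_high_is: float = 2.0, eps_low_is: float = 1000.0) -> torch.Tensor:
    """Clipped-IS-weight REINFORCE loss averaged over tokens (illustrative sketch)."""
    ratio = torch.exp(logp_new - logp_old)
    ratio_hat = torch.clamp(ratio, 1 - eps_low_is, 1 + eps_high_is)  # clip the weight itself
    # sg(r_hat) * A_t * log pi_theta: gradients flow only through logp_new
    return -(ratio_hat.detach() * advantages * logp_new).mean()
```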
Key Innovation: Clipping the Importance Sampling Weights
In practice, the research team only tune $\epsilon^{\text{IS}}_{\text{high}}$, while $\epsilon^{\text{IS}}_{\text{low}}$ is set to a very large value, which effectively removes the lower-bound constraint.
The key difference is:
- Traditional methods: first compute the token update $r_{i,t}(\theta)\,\hat{A}_{i,t}$, then clip the entire product.
- CISPO: first clip the weight $r_{i,t}(\theta)$ itself, then multiply the stop-gradient weight by $\hat{A}_{i,t}\,\log \pi_\theta(o_{i,t} \mid q, o_{i,<t})$.
Advantages of the Algorithm
- Preserves the gradient contribution of every token: especially in long responses, every token takes part in the gradient update.
- Reduces variance: training is stabilized by clipping the weights rather than the token updates.
- No KL penalty term: a simplified design in line with other recent work.
A Unified Framework: A More Flexible Formulation
The research team also propose a more general formulation that controls gradients through a token-level mask $M_{i,t}$:

$$\mathcal{J}_{\text{unify}}(\theta) = \mathbb{E}_{q \sim \mathcal{D},\, \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\text{old}}}(\cdot \mid q)} \left[ \frac{1}{\sum_{i=1}^{G}|o_i|} \sum_{i=1}^{G}\sum_{t=1}^{|o_i|} \operatorname{sg}\big(\hat{r}_{i,t}(\theta)\big)\, \hat{A}_{i,t}\, \log \pi_\theta(o_{i,t} \mid q, o_{i,<t})\, M_{i,t} \right]$$
Choosing the mask $M_{i,t}$ as follows makes it equivalent to the implicit mask of the PPO trust region:

$$M_{i,t} = \begin{cases} 0, & \hat{A}_{i,t} > 0 \text{ and } r_{i,t}(\theta) > 1+\epsilon_{\text{high}} \\ 0, & \hat{A}_{i,t} < 0 \text{ and } r_{i,t}(\theta) < 1-\epsilon_{\text{low}} \\ 1, & \text{otherwise} \end{cases}$$

In words:
- If the advantage is positive and the importance weight exceeds $1+\epsilon_{\text{high}}$, then $M = 0$ (no update).
- If the advantage is negative and the importance weight is below $1-\epsilon_{\text{low}}$, then $M = 0$ (no update).
- Otherwise $M = 1$ (normal update).
This unified framework can flexibly express different clipping strategies.
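As a sketch of how such a mask could be computed from the advantages and the raw importance ratios (the function and argument names here are hypothetical, not from the paper):

```python
import torch

def trust_region_mask(ratio: torch.Tensor, advantages: torch.Tensor,
                      eps_low: float = 0.2, eps_high: float = 0.2) -> torch.Tensor:
    """M = 0 where PPO's trust region would zero a token's gradient, else 1."""
    drop_pos = (advantages > 0) & (ratio > 1 + eps_high)   # clipped on the upside
    drop_neg = (advantages < 0) & (ratio < 1 - eps_low)    # clipped on the downside
    return torch.where(drop_pos | drop_neg,
                       torch.zeros_like(ratio), torch.ones_like(ratio))
```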
Below is pseudocode I drafted with the help of a large language model; it should help illustrate how the algorithm works:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List
class CISPOTrainer:
    """
    CISPO (Clipped Importance Sampling Policy Optimization) implementation.

    Key difference from PPO/GRPO: instead of clipping token updates,
    we clip the importance sampling weights to preserve gradient contributions
    from all tokens, especially those critical for long-chain reasoning.
    """

    def __init__(self,
                 policy_model: nn.Module,
                 old_policy_model: nn.Module,
                 epsilon_high_is: float = 2.0,
                 epsilon_low_is: float = 1000.0,  # Large value = no lower bound
                 group_size: int = 8,
                 device: str = 'cuda'):
        """
        Args:
            policy_model: Current policy model being optimized
            old_policy_model: Old policy model for importance sampling
            epsilon_high_is: Upper bound offset for IS weight clipping
            epsilon_low_is: Lower bound offset for IS weight clipping (set large to disable)
            group_size: Number of responses per question for group advantage
            device: Training device
        """
        self.policy_model = policy_model
        self.old_policy_model = old_policy_model
        self.epsilon_high_is = epsilon_high_is
        self.epsilon_low_is = epsilon_low_is
        self.group_size = group_size
        self.device = device

        # Freeze the old policy model
        for param in self.old_policy_model.parameters():
            param.requires_grad = False

    def compute_log_probs(self, model: nn.Module, input_ids: torch.Tensor,
                          attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Compute per-token log probabilities for the given sequences.

        Args:
            model: Language model
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
        Returns:
            log_probs: Log probabilities [batch_size, seq_len-1]
        """
        with torch.set_grad_enabled(model.training):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits  # [batch_size, seq_len, vocab_size]

            # Shift for next-token prediction
            shift_logits = logits[:, :-1, :].contiguous()  # [batch_size, seq_len-1, vocab_size]
            shift_labels = input_ids[:, 1:].contiguous()   # [batch_size, seq_len-1]

            # Compute log probabilities and gather those of the actual tokens
            log_probs = F.log_softmax(shift_logits, dim=-1)
            log_probs = log_probs.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)
        return log_probs

    def compute_importance_sampling_weights(self,
                                            current_log_probs: torch.Tensor,
                                            old_log_probs: torch.Tensor,
                                            attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Compute clipped importance sampling weights.

        Args:
            current_log_probs: Log probs from current policy [batch_size, seq_len-1]
            old_log_probs: Log probs from old policy [batch_size, seq_len-1]
            attention_mask: Attention mask [batch_size, seq_len]
        Returns:
            clipped_is_weights: Clipped IS weights [batch_size, seq_len-1]
        """
        # Raw importance sampling weights
        log_ratio = current_log_probs - old_log_probs
        is_weights = torch.exp(log_ratio)

        # Clip the IS weights themselves (key difference from PPO/GRPO)
        clipped_is_weights = torch.clamp(
            is_weights,
            min=1.0 - self.epsilon_low_is,   # Effectively no lower bound
            max=1.0 + self.epsilon_high_is
        )

        # Mask out padding tokens (align with shifted tokens)
        mask = attention_mask[:, 1:]
        clipped_is_weights = clipped_is_weights * mask
        return clipped_is_weights

    def compute_group_advantage(self, rewards: torch.Tensor,
                                group_indices: torch.Tensor) -> torch.Tensor:
        """
        Compute group-relative advantages, following the GRPO approach.

        Args:
            rewards: Reward for each response [batch_size]
            group_indices: Group index for each response [batch_size]
        Returns:
            advantages: Normalized advantages [batch_size]
        """
        advantages = torch.zeros_like(rewards)
        for group_id in torch.unique(group_indices):
            group_mask = (group_indices == group_id)
            group_rewards = rewards[group_mask]
            if len(group_rewards) > 1:
                group_mean = group_rewards.mean()
                group_std = group_rewards.std() + 1e-8  # Avoid division by zero
                group_advantages = (group_rewards - group_mean) / group_std
            else:
                group_advantages = torch.zeros_like(group_rewards)
            advantages[group_mask] = group_advantages
        return advantages

    def compute_cispo_loss(self,
                           input_ids: torch.Tensor,
                           attention_mask: torch.Tensor,
                           rewards: torch.Tensor,
                           group_indices: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute the CISPO loss.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            rewards: Rewards for each response [batch_size]
            group_indices: Group membership [batch_size]
        Returns:
            loss_dict: Dictionary containing the loss and monitoring metrics
        """
        batch_size, seq_len = input_ids.shape

        # Log probabilities from both models
        current_log_probs = self.compute_log_probs(self.policy_model, input_ids, attention_mask)
        with torch.no_grad():
            old_log_probs = self.compute_log_probs(self.old_policy_model, input_ids, attention_mask)

        # Clipped importance sampling weights
        clipped_is_weights = self.compute_importance_sampling_weights(
            current_log_probs, old_log_probs, attention_mask
        )

        # Group-relative advantages, expanded to the token level
        advantages = self.compute_group_advantage(rewards, group_indices)
        token_advantages = advantages.unsqueeze(1).expand(-1, seq_len - 1)  # [batch_size, seq_len-1]

        # Attention mask aligned with the shifted tokens
        mask = attention_mask[:, 1:]

        # CISPO objective: no gradient through the IS weights, all token gradients preserved
        policy_loss = -torch.sum(
            clipped_is_weights.detach() * token_advantages * current_log_probs * mask
        ) / torch.sum(mask)

        # Metrics for monitoring
        with torch.no_grad():
            raw_is_weights = torch.exp(current_log_probs - old_log_probs)
            is_ratio_mean = (raw_is_weights * mask).sum() / mask.sum()
            is_ratio_max = (raw_is_weights * mask).max()
            # Fraction of tokens whose raw weight exceeds the upper clipping bound
            clipped_fraction = ((raw_is_weights > 1.0 + self.epsilon_high_is) * mask).sum() / mask.sum()
            # Approximate entropy signal (over sampled tokens) for monitoring exploration
            entropy = -(torch.exp(current_log_probs) * current_log_probs * mask).sum() / mask.sum()

        loss_dict = {
            'policy_loss': policy_loss,
            'is_ratio_mean': is_ratio_mean,
            'is_ratio_max': is_ratio_max,
            'clipped_fraction': clipped_fraction,
            'entropy': entropy,
            'advantage_mean': advantages.mean(),
            'advantage_std': advantages.std(),
            'reward_mean': rewards.mean()
        }
        return loss_dict

    def train_step(self, batch_data: Dict[str, torch.Tensor],
                   optimizer: torch.optim.Optimizer) -> Dict[str, float]:
        """
        Run a single training step.

        Args:
            batch_data: Dictionary containing the training batch
            optimizer: PyTorch optimizer
        Returns:
            metrics: Training metrics
        """
        self.policy_model.train()

        # Extract batch data
        input_ids = batch_data['input_ids'].to(self.device)
        attention_mask = batch_data['attention_mask'].to(self.device)
        rewards = batch_data['rewards'].to(self.device)
        group_indices = batch_data['group_indices'].to(self.device)

        # Forward pass
        loss_dict = self.compute_cispo_loss(input_ids, attention_mask, rewards, group_indices)

        # Backward pass
        optimizer.zero_grad()
        loss_dict['policy_loss'].backward()

        # Gradient clipping (optional but recommended)
        torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), max_norm=1.0)
        optimizer.step()

        # Convert to floats for logging
        metrics = {k: v.item() if torch.is_tensor(v) else v for k, v in loss_dict.items()}
        return metrics

    def update_old_policy(self):
        """Update the old policy model with the current policy weights."""
        self.old_policy_model.load_state_dict(self.policy_model.state_dict())
# Example usage and training loop
def example_training_loop(policy_model: nn.Module,
                          old_policy_model: nn.Module,
                          dataloader,
                          num_epochs: int = 1):
    """
    Example of how to use the CISPO trainer.

    `policy_model` and `old_policy_model` are placeholders for your language model
    and a copy of it; `dataloader` should yield batches as produced by
    `prepare_cispo_batch` below.
    """
    trainer = CISPOTrainer(
        policy_model=policy_model,
        old_policy_model=old_policy_model,
        epsilon_high_is=2.0,  # Allow larger IS weights than traditional PPO
        group_size=8
    )
    optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-5)

    # Training loop
    for epoch in range(num_epochs):
        for batch_idx, batch_data in enumerate(dataloader):
            # batch_data should contain:
            #   - input_ids:      [batch_size, seq_len]
            #   - attention_mask: [batch_size, seq_len]
            #   - rewards:        [batch_size]
            #   - group_indices:  [batch_size]
            metrics = trainer.train_step(batch_data, optimizer)

            # Log metrics
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}")
                print(f"Policy Loss: {metrics['policy_loss']:.4f}")
                print(f"IS Ratio Mean: {metrics['is_ratio_mean']:.4f}")
                print(f"Clipped Fraction: {metrics['clipped_fraction']:.4f}")
                print(f"Entropy: {metrics['entropy']:.4f}")

        # Update the old policy periodically (e.g., every epoch or every few batches)
        trainer.update_old_policy()
# Utility function for data preparation
def prepare_cispo_batch(questions: List[str],
                        responses: List[List[str]],
                        rewards: List[List[float]],
                        tokenizer,
                        max_length: int = 512) -> Dict[str, torch.Tensor]:
    """
    Prepare a batch for CISPO training.

    Args:
        questions: List of questions
        responses: List of response groups (each question has multiple responses)
        rewards: List of reward groups (corresponding to the responses)
        tokenizer: Tokenizer for encoding text
        max_length: Maximum sequence length
    Returns:
        batch_data: Dictionary ready for training
    """
    all_texts = []
    all_rewards = []
    all_group_indices = []

    for q_idx, (question, response_group, reward_group) in enumerate(zip(questions, responses, rewards)):
        for response, reward in zip(response_group, reward_group):
            # Format: question + response (adjust based on your prompt format)
            text = f"{question}\n{response}"
            all_texts.append(text)
            all_rewards.append(reward)
            all_group_indices.append(q_idx)

    # Tokenize
    encodings = tokenizer(
        all_texts,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors='pt'
    )

    batch_data = {
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'rewards': torch.tensor(all_rewards, dtype=torch.float32),
        'group_indices': torch.tensor(all_group_indices, dtype=torch.long)
    }
    return batch_data
```
Experimental Validation: Significant Performance Gains
Experimental Setup
The research team compared CISPO, DAPO, and GRPO under a zero-RL training setting:
- Base model: Qwen2.5-32B-base
- Training data: mathematical reasoning datasets
- Evaluation benchmark: AIME 2024
Key Findings
The results in Figure 2 show that:
- Clear performance advantage: at the same number of training steps, CISPO significantly outperforms both DAPO and GRPO.
- Much higher training efficiency: CISPO matches DAPO's performance using only 50% of the training steps.
- Better stability: it avoids the instability introduced by token clipping.
Theoretical Contribution
CISPO not only solves a practical problem; more importantly, it exposes an overlooked technical detail: in long-sequence reasoning tasks, traditional token-level clipping blocks the learning of tokens that are critical to reasoning.
Practical Value
- Higher training efficiency: the same performance with less compute.
- Stronger reasoning ability: especially on complex tasks that require long chains of thought.
- Better scalability: a new direction for large-scale reinforcement learning training.
The success of CISPO suggests that RL training of large language models still holds many technical details worth studying in depth. In particular, when handling long sequences and complex reasoning tasks, some of the assumptions behind traditional algorithms may need to be revisited.
Summary
With a seemingly simple change, moving the clipping from the token level to the importance-sampling-weight level, CISPO solves a key problem in reinforcement learning training. This work is a reminder that in complex machine learning systems, apparently minor technical details can have a major impact. For researchers and engineers training large language models, CISPO offers an optimization idea worth borrowing.