MiniMax-M1's Reinforcement Learning Algorithm CISPO Explained: Solving the Token Clipping Problem in RL

Background: Limitations of the Traditional PPO Algorithm

In reinforcement learning training of large language models, PPO (Proximal Policy Optimization) has long been the mainstream method. For a question $q$ from the dataset $\mathcal{D}$, PPO uses the policy model $\pi_\theta$ (with parameters $\theta$) to generate a response $o$, and optimizes the following objective:

$$J_{\text{PPO}}(\theta) = \mathbb{E}_{q \sim \mathcal{D},\, o \sim \pi_{\theta_{\text{old}}}(\cdot \mid q)}\left[\frac{1}{|o|}\sum_{t=1}^{|o|} \min\Big(r_t(\theta)\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\hat{A}_t\Big)\right]$$

Here $r_t(\theta) = \dfrac{\pi_\theta(o_t \mid q, o_{<t})}{\pi_{\theta_{\text{old}}}(o_t \mid q, o_{<t})}$ is the importance sampling weight, used to correct the distribution mismatch during off-policy updates.
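
For reference, here is a minimal PyTorch sketch of this per-token clipped objective (the tensor names `logp_new`, `logp_old`, and `advantages` are assumptions for illustration, not code from the paper):

import torch

def ppo_token_objective(logp_new, logp_old, advantages, eps=0.2):
    # r_t(theta) = pi_theta / pi_theta_old, computed in log space for stability
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    # PPO keeps the pessimistic (minimum) term for each token, then averages
    return torch.min(unclipped, clipped).mean()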

GRPO's Improvement and Its Shortcomings

GRPO (Group Relative Policy Optimization) simplifies PPO by removing the value model and defining the advantage relative to the rewards of the other responses in the same group:

$$\hat{A}_{i,t} = \frac{R_i - \operatorname{mean}\big(\{R_j\}_{j=1}^{G}\big)}{\operatorname{std}\big(\{R_j\}_{j=1}^{G}\big)}$$

where $R_i$ is the reward of response $o_i$, and the $G$ responses $\{o_i\}_{i=1}^{G}$ are sampled for each question.
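
A tiny worked sketch of this group-relative advantage (the rewards below are made up; the same logic appears later in `compute_group_advantage`):

import torch

rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])  # G = 4 sampled responses with 0/1 rewards
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
print(advantages)  # correct responses get positive advantage, incorrect ones negative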

The Core Problem: The Harm of Token Clipping

In their experiments, the research team found a serious issue: the traditional clipping operation severely hurts training for long chain-of-thought reasoning.

How the Problem Manifests

  1. Critical tokens are collateral damage: tokens expressing reflective behavior (such as "However", "Recheck", "Wait", "Aha") have low probability under the base model
  2. Reasoning paths get cut off: these tokens often act as "forks" in the reasoning path, yet they produce very high $r_t(\theta)$ values during policy updates
  3. Gradient contributions are lost: after the first policy update these tokens are clipped out, so they can no longer contribute to subsequent off-policy gradient updates (see the numeric sketch below)

This problem is especially severe in hybrid-architecture models and further limits the scalability of reinforcement learning. Although DAPO tries to mitigate it by raising the clipping upper bound, it does not work well under a setting with 16 off-policy gradient update rounds.
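
To make the failure mode concrete, here is a small numeric sketch (the probabilities are hypothetical, chosen only for illustration): a reflective token like "Wait" is rare under the old policy, so its importance ratio overshoots the PPO upper bound, the min/clip term becomes a constant, and the token receives zero gradient.

import torch

eps = 0.2
advantage = 1.0                                    # positive advantage: we want to reinforce "Wait"
logp_new = torch.tensor(-3.0, requires_grad=True)  # log pi_theta("Wait")      ~ log(0.05)
logp_old = torch.tensor(-4.6)                      # log pi_theta_old("Wait")  ~ log(0.01)

ratio = torch.exp(logp_new - logp_old)             # ~ 5.0, far above 1 + eps
ppo_term = torch.min(ratio * advantage,
                     torch.clamp(ratio, 1 - eps, 1 + eps) * advantage)
ppo_term.backward()
print(ratio.item(), logp_new.grad)                 # gradient is exactly zero for this token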

CISPO: A New Solution

Core Idea

The core idea of CISPO (Clipped Importance Sampling Policy Optimization) is: stop clipping token updates, and clip the importance sampling weights instead.

Algorithm Derivation

First, recall the standard REINFORCE objective with an importance sampling correction for off-policy updates:

$$J(\theta) = \mathbb{E}_{q \sim \mathcal{D},\, o \sim \pi_{\theta_{\text{old}}}(\cdot \mid q)}\left[\sum_{t=1}^{|o|} \operatorname{sg}\big(r_t(\theta)\big)\,\hat{A}_t \log \pi_\theta(o_t \mid q, o_{<t})\right]$$

where $\operatorname{sg}(\cdot)$ denotes the stop-gradient operation.

On top of this, CISPO introduces a clipped importance sampling weight $\hat{r}_{i,t}(\theta)$:

$$J_{\text{CISPO}}(\theta) = \mathbb{E}_{q \sim \mathcal{D},\, \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\text{old}}}(\cdot \mid q)}\left[\frac{1}{\sum_{i=1}^{G}|o_i|}\sum_{i=1}^{G}\sum_{t=1}^{|o_i|} \operatorname{sg}\big(\hat{r}_{i,t}(\theta)\big)\,\hat{A}_{i,t} \log \pi_\theta(o_{i,t} \mid q, o_{i,<t})\right]$$

with

$$\hat{r}_{i,t}(\theta) = \operatorname{clip}\big(r_{i,t}(\theta),\ 1-\epsilon^{\text{IS}}_{\text{low}},\ 1+\epsilon^{\text{IS}}_{\text{high}}\big)$$

Key Innovation: Clipping the Importance Sampling Weight

In practice, the research team only tunes $\epsilon^{\text{IS}}_{\text{high}}$, while setting $\epsilon^{\text{IS}}_{\text{low}}$ to a very large value, which is equivalent to imposing no lower bound.

The key difference is now clear (see the sketch below):

  • Traditional methods: first compute the product $r_t(\theta)\hat{A}_t$, then clip the whole token update through the min/clip term
  • CISPO: first clip $r_t(\theta)$ (and stop its gradient), then multiply by $\hat{A}_t \log \pi_\theta(o_t \mid q, o_{<t})$
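
A minimal side-by-side sketch of the two per-token update terms (the tensor names `logp_new`, `logp_old`, `adv` are illustrative assumptions, not the official implementation):

import torch

def ppo_token_term(logp_new, logp_old, adv, eps=0.2):
    ratio = torch.exp(logp_new - logp_old)
    # The min/clip turns into a constant for tokens pushed past the clip range,
    # so those tokens contribute no gradient
    return torch.min(ratio * adv, torch.clamp(ratio, 1 - eps, 1 + eps) * adv)

def cispo_token_term(logp_new, logp_old, adv, eps_high=2.0):
    ratio = torch.exp(logp_new - logp_old)
    # Clip the IS weight itself and stop its gradient; the log-prob term
    # keeps flowing gradients for every token
    r_hat = torch.clamp(ratio, max=1 + eps_high).detach()
    return r_hat * adv * logp_new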

Algorithm Advantages

  1. All tokens keep their gradient contribution: especially in long responses, every token takes part in the gradient update
  2. Lower variance: training is stabilized by clipping the weights rather than dropping token updates
  3. No KL penalty term: a simplification in line with other recent work

A Unified Framework: A More Flexible Formulation

The research team also proposes a more general formulation that controls gradients through a token-level mask $M_{i,t}$:

$$J_{\text{unify}}(\theta) = \mathbb{E}\left[\frac{1}{\sum_{i=1}^{G}|o_i|}\sum_{i=1}^{G}\sum_{t=1}^{|o_i|} \operatorname{sg}\big(\hat{r}_{i,t}(\theta)\big)\,\hat{A}_{i,t} \log \pi_\theta(o_{i,t} \mid q, o_{i,<t})\, M_{i,t}\right]$$

The mask $M_{i,t}$ can be defined to be equivalent to the implicit mask of the PPO trust region.

The mask rules are:

  • If the advantage > 0 and the importance weight > 1 + ε_high, then M = 0 (no update)
  • If the advantage < 0 and the importance weight < 1 - ε_low, then M = 0 (no update)
  • Otherwise M = 1 (normal update)

This unified framework can flexibly express different clipping strategies.
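
A minimal sketch of how this mask could be computed per token (the names `ratio`, `adv`, `eps_high`, `eps_low` are assumptions for illustration):

import torch

def trust_region_mask(ratio, adv, eps_high=0.2, eps_low=0.2):
    # M = 0 exactly where PPO's clipping would zero the gradient, 1 elsewhere
    drop_pos = (adv > 0) & (ratio > 1 + eps_high)
    drop_neg = (adv < 0) & (ratio < 1 - eps_low)
    return (~(drop_pos | drop_neg)).float()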

Below is reference pseudocode I drafted with the help of an LLM; it should help illustrate how the algorithm works:

  
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict

class CISPOTrainer:
    """
    CISPO (Clipped Importance Sampling Policy Optimization) Implementation

    Key difference from PPO/GRPO: Instead of clipping token updates,
    we clip the importance sampling weights to preserve gradient contributions
    from all tokens, especially those critical for long-chain reasoning.
    """

    def __init__(self,
                 policy_model: nn.Module,
                 old_policy_model: nn.Module,
                 epsilon_high_is: float = 2.0,
                 epsilon_low_is: float = 1000.0,  # Large value = no lower bound
                 group_size: int = 8,
                 device: str = 'cuda'):
        """
        Args:
            policy_model: Current policy model being optimized
            old_policy_model: Old policy model for importance sampling
            epsilon_high_is: Upper bound for IS weight clipping
            epsilon_low_is: Lower bound for IS weight clipping (set large to disable)
            group_size: Number of responses per question for group advantage
            device: Training device
        """
        self.policy_model = policy_model
        self.old_policy_model = old_policy_model
        self.epsilon_high_is = epsilon_high_is
        self.epsilon_low_is = epsilon_low_is
        self.group_size = group_size
        self.device = device

        # Freeze old policy model
        for param in self.old_policy_model.parameters():
            param.requires_grad = False

    def compute_log_probs(self, model: nn.Module, input_ids: torch.Tensor,
                          attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Compute log probabilities for each token in the sequences

        Args:
            model: Language model
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]

        Returns:
            log_probs: Log probabilities [batch_size, seq_len-1]
        """
        with torch.set_grad_enabled(model.training):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits  # [batch_size, seq_len, vocab_size]

            # Shift for next token prediction
            shift_logits = logits[:, :-1, :].contiguous()  # [batch_size, seq_len-1, vocab_size]
            shift_labels = input_ids[:, 1:].contiguous()   # [batch_size, seq_len-1]

            # Compute log probabilities
            log_probs = F.log_softmax(shift_logits, dim=-1)
            # Gather log probs for actual tokens
            log_probs = log_probs.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)

            return log_probs

    def compute_importance_sampling_weights(self,
                                            current_log_probs: torch.Tensor,
                                            old_log_probs: torch.Tensor,
                                            attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Compute importance sampling weights with clipping

        Args:
            current_log_probs: Log probs from current policy [batch_size, seq_len-1]
            old_log_probs: Log probs from old policy [batch_size, seq_len-1]
            attention_mask: Attention mask [batch_size, seq_len]

        Returns:
            clipped_is_weights: Clipped IS weights [batch_size, seq_len-1]
        """
        # Compute raw importance sampling weights
        log_ratio = current_log_probs - old_log_probs
        is_weights = torch.exp(log_ratio)

        # Apply clipping to IS weights (key difference from PPO/GRPO)
        clipped_is_weights = torch.clamp(
            is_weights,
            min=1.0 - self.epsilon_low_is,  # Effectively no lower bound
            max=1.0 + self.epsilon_high_is
        )

        # Mask out padding tokens
        mask = attention_mask[:, 1:]  # Align with shifted tokens
        clipped_is_weights = clipped_is_weights * mask

        return clipped_is_weights

    def compute_group_advantage(self, rewards: torch.Tensor,
                                group_indices: torch.Tensor) -> torch.Tensor:
        """
        Compute group relative advantage following GRPO approach

        Args:
            rewards: Reward for each response [batch_size]
            group_indices: Group index for each response [batch_size]

        Returns:
            advantages: Normalized advantages [batch_size]
        """
        advantages = torch.zeros_like(rewards)

        for group_id in torch.unique(group_indices):
            group_mask = (group_indices == group_id)
            group_rewards = rewards[group_mask]

            if len(group_rewards) > 1:
                group_mean = group_rewards.mean()
                group_std = group_rewards.std() + 1e-8  # Avoid division by zero
                group_advantages = (group_rewards - group_mean) / group_std
            else:
                group_advantages = torch.zeros_like(group_rewards)

            advantages[group_mask] = group_advantages

        return advantages

    def compute_cispo_loss(self,
                           input_ids: torch.Tensor,
                           attention_mask: torch.Tensor,
                           rewards: torch.Tensor,
                           group_indices: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute CISPO loss

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            rewards: Rewards for each response [batch_size]
            group_indices: Group membership [batch_size]

        Returns:
            loss_dict: Dictionary containing loss and metrics
        """
        batch_size, seq_len = input_ids.shape

        # Compute log probabilities from both models
        current_log_probs = self.compute_log_probs(self.policy_model, input_ids, attention_mask)

        with torch.no_grad():
            old_log_probs = self.compute_log_probs(self.old_policy_model, input_ids, attention_mask)

        # Compute clipped importance sampling weights
        clipped_is_weights = self.compute_importance_sampling_weights(
            current_log_probs, old_log_probs, attention_mask
        )

        # Compute group relative advantages
        advantages = self.compute_group_advantage(rewards, group_indices)

        # Expand advantages to token level
        token_advantages = advantages.unsqueeze(1).expand(-1, seq_len - 1)  # [batch_size, seq_len-1]

        # Apply attention mask
        # Note: for simplicity this only masks padding; a full implementation would
        # also mask prompt tokens so that only response tokens contribute to the loss
        mask = attention_mask[:, 1:]  # Align with shifted tokens

        # CISPO objective: no gradient on IS weights, preserve all token gradients
        policy_loss = -torch.sum(
            clipped_is_weights.detach() * token_advantages * current_log_probs * mask
        ) / torch.sum(mask)

        # Compute metrics for monitoring
        with torch.no_grad():
            raw_is_weights = torch.exp(current_log_probs - old_log_probs)
            is_ratio_mean = (raw_is_weights * mask).sum() / mask.sum()
            is_ratio_max = (raw_is_weights * mask).max()

            # Fraction of tokens that would be clipped in traditional PPO
            clipped_fraction = ((raw_is_weights > 1.0 + self.epsilon_high_is) * mask).sum() / mask.sum()

            # Approximate (sampled-token) entropy for monitoring exploration
            entropy = -(torch.exp(current_log_probs) * current_log_probs * mask).sum() / mask.sum()

        loss_dict = {
            'policy_loss': policy_loss,
            'is_ratio_mean': is_ratio_mean,
            'is_ratio_max': is_ratio_max,
            'clipped_fraction': clipped_fraction,
            'entropy': entropy,
            'advantage_mean': advantages.mean(),
            'advantage_std': advantages.std(),
            'reward_mean': rewards.mean()
        }

        return loss_dict

    def train_step(self, batch_data: Dict[str, torch.Tensor],
                   optimizer: torch.optim.Optimizer) -> Dict[str, float]:
        """
        Single training step

        Args:
            batch_data: Dictionary containing training data
            optimizer: PyTorch optimizer

        Returns:
            metrics: Training metrics
        """
        self.policy_model.train()

        # Extract batch data
        input_ids = batch_data['input_ids'].to(self.device)
        attention_mask = batch_data['attention_mask'].to(self.device)
        rewards = batch_data['rewards'].to(self.device)
        group_indices = batch_data['group_indices'].to(self.device)

        # Forward pass
        loss_dict = self.compute_cispo_loss(input_ids, attention_mask, rewards, group_indices)

        # Backward pass
        optimizer.zero_grad()
        loss_dict['policy_loss'].backward()

        # Gradient clipping (optional but recommended)
        torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), max_norm=1.0)

        optimizer.step()

        # Convert to float for logging
        metrics = {k: v.item() if torch.is_tensor(v) else v for k, v in loss_dict.items()}

        return metrics

    def update_old_policy(self):
        """Update old policy model with current policy weights"""
        self.old_policy_model.load_state_dict(self.policy_model.state_dict())


# Example usage and training loop
def example_training_loop():
    """
    Example of how to use CISPO trainer
    """
    # Initialize models (placeholder - replace with actual model initialization)
    policy_model = None  # Your language model
    old_policy_model = None  # Copy of the language model

    # Placeholders for the training setup (replace with your own)
    num_epochs = 3
    dataloader = []  # Your DataLoader yielding CISPO batches

    # Initialize trainer
    trainer = CISPOTrainer(
        policy_model=policy_model,
        old_policy_model=old_policy_model,
        epsilon_high_is=2.0,  # Allow larger IS weights than traditional PPO
        group_size=8
    )

    # Initialize optimizer
    optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-5)

    # Training loop
    for epoch in range(num_epochs):
        for batch_idx, batch_data in enumerate(dataloader):
            # batch_data should contain:
            # - input_ids: [batch_size, seq_len]
            # - attention_mask: [batch_size, seq_len]
            # - rewards: [batch_size]
            # - group_indices: [batch_size]

            metrics = trainer.train_step(batch_data, optimizer)

            # Log metrics
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}")
                print(f"Policy Loss: {metrics['policy_loss']:.4f}")
                print(f"IS Ratio Mean: {metrics['is_ratio_mean']:.4f}")
                print(f"Clipped Fraction: {metrics['clipped_fraction']:.4f}")
                print(f"Entropy: {metrics['entropy']:.4f}")

        # Update old policy periodically (e.g., every epoch or few batches)
        trainer.update_old_policy()


# Utility function for data preparation
def prepare_cispo_batch(questions: List[str],
                        responses: List[List[str]],
                        rewards: List[List[float]],
                        tokenizer,
                        max_length: int = 512) -> Dict[str, torch.Tensor]:
    """
    Prepare batch data for CISPO training

    Args:
        questions: List of questions
        responses: List of response groups (each question has multiple responses)
        rewards: List of reward groups (corresponding to responses)
        tokenizer: Tokenizer for encoding text
        max_length: Maximum sequence length

    Returns:
        batch_data: Dictionary ready for training
    """
    all_texts = []
    all_rewards = []
    all_group_indices = []

    for q_idx, (question, response_group, reward_group) in enumerate(zip(questions, responses, rewards)):
        for response, reward in zip(response_group, reward_group):
            # Format: question + response (adjust based on your prompt format)
            text = f"{question}\n{response}"
            all_texts.append(text)
            all_rewards.append(reward)
            all_group_indices.append(q_idx)

    # Tokenize
    encodings = tokenizer(
        all_texts,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors='pt'
    )

    batch_data = {
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'rewards': torch.tensor(all_rewards, dtype=torch.float32),
        'group_indices': torch.tensor(all_group_indices, dtype=torch.long)
    }

    return batch_data
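
A brief sketch of how the two pieces above could be wired together (the `tokenizer`, model, and data objects are placeholders, shown only to clarify the expected data flow):

# Hypothetical wiring of the utilities above (all objects are placeholders)
batch = prepare_cispo_batch(
    questions=["What is 2+2?"],
    responses=[["2+2=4", "The answer is 5"]],
    rewards=[[1.0, 0.0]],
    tokenizer=tokenizer,  # e.g. a Hugging Face tokenizer with padding enabled
)
metrics = trainer.train_step(batch, optimizer)
trainer.update_old_policy()  # refresh the old policy after a round of updates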

Experimental Validation: Significant Performance Gains

Experimental Setup

The research team compared CISPO, DAPO, and GRPO under a zero-RL training setup:

  • Base model: Qwen2.5-32B-base
  • Training data: mathematical reasoning datasets
  • Evaluation benchmark: AIME 2024

Key Findings

The results in Figure 2 show:

  1. A clear performance advantage: at the same number of training steps, CISPO significantly outperforms DAPO and GRPO
  2. Much higher training efficiency: CISPO matches DAPO's performance with only 50% of the training steps
  3. Better stability: it avoids the instability introduced by token clipping

Theoretical Contribution

CISPO does more than solve a practical problem; it exposes an overlooked technical detail: in long-sequence reasoning tasks, traditional token-level clipping blocks the learning of critical reasoning tokens.

Practical Value

  1. Higher training efficiency: the same performance with less compute
  2. Stronger reasoning ability: especially for complex tasks that require long chains of thought
  3. Better scalability: a new direction for large-scale reinforcement learning training

The success of CISPO shows that in reinforcement learning training of large language models, many technical details still deserve closer study. In particular, when handling long sequences and complex reasoning tasks, some assumptions of traditional algorithms may need to be revisited.

Summary

CISPO solves a key problem in reinforcement learning training through a seemingly simple change: moving clipping from the token level to the importance-sampling-weight level. This work is a reminder that in complex machine learning systems, seemingly minor technical details can have a major impact. For researchers and engineers working on large language model training, CISPO offers an optimization idea worth borrowing.
