“ 今天还是分享一个训练框架,它的优势是简单、而且性能很高,可以在单卡A100上微调34B的RLHF,4块4090做到7B模型的全参微调。训练RLHF比deepspeedchat快3倍
https://github.com/OpenLLMAI/OpenRLHF
OpenRLHF旨在基于Ray和DeepSpeed开发一个高性能的RLHF训练框架。OpenRLHF是一个最简单的高性能RLHF库,支持使用单个DGXA100(脚本)进行34B模型的RLHF训练。
OpenRLHF的关键是使用Ray将Actor模型、Reward模型、Reference模型和Critic模型分布到不同的GPU上,同时将Adam优化器放置在CPU上。这使得可以在多个24GB RTX 4090 GPU(或者多个A100 80G)上进行7B模型的全面微调,通过使用Adam Offload和Ray的能力以及大批量生成批处理大小,实现高效的训练。我们使用13B llama2模型的PPO性能是DeepSpeedChat的4倍。
features
- 一个基于DeepSpeed的快速LLaMA2 SFT/PPO训练框架。
- 适用于Slurm的多节点训练脚本。
- 支持DPO(直接偏好优化)。
- 基于Ray的分布式PPO,适用于34B和7B模型跑在RTX4090上。
- 支持决策Transformer(DT)对齐(https://arxiv.org/abs/2308.12050)。
- 支持大多的中文模型
- 支持Wandb日志(--wandb)。
- 支持conda环境/nvidia docker。
- 支持FlashAttention2(--flash_attn)。
- 预训练的7B/13B llama2检查点
- 支持GPT4评估和PPO vs SFT示例
- 支持多个奖励模型。
- 支持拒绝抽样。
性能
7B llama2 RLHF | 13B llama2 RLHF (50k samples) | |
---|---|---|
OpenRLHF | - | 22 hours with 8 A100 |
DeepSpeedChat | - | 48 hours with 16 A100 |