https://arxiv.org/abs/2403.17297
https://github.com/InternLM/InternLM
- 多维度并行性 :InternEvo通过结合数据并行、张量并行、序列并行和流水线并行,实现了在数千个GPU上进行模型训练的扩展能力。利用Zero策略,减少训练过程中所需的内存占用。InternEvo整合了FlashAttention技术,提高硬件利用率。:
- 强大的扩展性能 :在保持恒定的全局批次大小的情况下,InternEvo在扩展到更多GPU时仍能保持较高的模型FLOPs利用率(MFU), 扩展GPU数量 / 提高序列长度,Flops利用率。
- 优化了通信开销,提高训练效率,在执行前向和后向传播时,有效地重叠了通信和计算过程,最大化了训练流水线的效率。
- 容错性 :解决了GPU数据中心中常见的硬件故障、复杂的并行化策略和不平衡的资源利用等问题。
- 为了提高RLHF效率,在InternEvo和基础上,开发了一个RLHF框架
整体上基于LLama2的结构。InternLM2的矩阵布局设计考虑到了张量并行(tensor parallelism)的需求,通过交错排列每个头的Wk、Wq和Wv矩阵,使得模型能够更灵活地适应不同的分布式计算环境。不同的矩阵排列,在tp下呈现出不同的复杂度,如下图:
考虑到长上下文推理,使用GQA
中文和英文网页数据占86.46%。。虽然其他来源的数据量相对较小,如书籍和技术文献(简称techlit),但平均文档长度更长,内容质量相对更高,详细数据如下表:
数据流水线:
- 数据格式化:爬的网页提取正文和检测语言
- 规则处理:随机爬的网页很多脏数据,针对标点符号的异常断行、异常字符出现频率、标点符号分布情况等设计了一系列启发式过滤规则
- 重复数据过滤:使用minhash (5-gram) ,0.7阈值
- 安全过滤:采用“域名屏蔽”、“词屏蔽”、“色情分类”和“毒性分类”相结合的综合安全策略对数据进行过滤
- 毒性分类模型过滤 (基于kaggle相关数据集训练的bert)
- 质量过滤:互联网来源的数据包含大量低质量的内容,人工按照一些维度标注,然后训练模型二次过滤
代码数据
从GitHub直接爬取,公共数据集,以及与编码和编程相关的在线资源,如问答论坛,教程网站和API文档,具体如下图:
基于分类器,对数据进行高中低,质量区分;高质量的数据将具有更高的采样权重,并可以在预训练阶段进行多次训练迭代。中等质量的数据具有正常的采样权重,通常只训练一次。低质量的数据被排除在外,尽管低质量的比例相对较小,但删除它们对优化模型性能和确保训练稳定性至关重要
数据处理
- 格式清洗,转成markdown,但是部分转不了,就不管了
- 重复数据过滤,与文本类似
- 质量过滤
- 代码库依赖排序
长上下文数据
它包括三个阶段:
- 长度选择,基于规则的过滤器,选择超过32K字节的数据样本;
- 统计过滤器,利用统计特征识别和删除异常数据,目的是过滤掉无意义的数据,而不是选择高质量的数据,32k上下文下,样本是有清晰的特征分布
- 困惑度过滤器,利用困惑度的差异来评估文本段之间的连贯性,过滤掉具有分散上下文的样本。(估计两个文本段P(S2|S1)之间的条件概率,其中S1在S2之前。当S1和S2强相关时,条件概率应该高于单独估计S2的概率,这也意味着困惑度差为负。相反,如果概率向相反的方向改变,意味着S1是一个分散注意力的上下文,则应该从预训练语料库中删除它。理想情况下,添加更多上下文不应该对后续文本的可预测性产生负面影响)
(用于长上下文训练的所有数据都是标准预训练语料库的一个子集,这意味着在预训练期间将至少学习两次长上下文数据)
包括词典,超参数,训练配置等等。
SFT
使用了一个包含1000万个指令数据实例的数据集,这些数据实例经过了筛选,以确保它们是有用和无害的。该数据集包含了各种各样的主题,包括一般对话、NLP任务、数学问题、代码生成和函数调用等。比例如下图:
7B和20B模型都使用AdamW优化器进行一个epoch的训练,初始学习率为4e-5。
RLHF
实践中rlhf存在的问题
- 偏好冲突,有的时候我们希望模型提供有用的信息(有用的),同时不产生有害或不适当的内容(无害的)。然而,这两种偏好在实践中往往无法同时满足,因为在某些情况下,提供有用的信息可能涉及敏感或高风险内容。
- 计算成本,现有的RLHF系统可能常依赖于多个偏好模型进行评分
- reward hacking,模型可能学会通过捷径“欺骗”奖励系统以获得更高的奖励,而不是真正学习预期的行为
提出并使用了条件奖励模型的方案,与传统方法不同,条件奖励模型将不同类型的偏好的system prompt,合并在一个单一奖励模型中有效地建模各种偏好
由于奖励模型是从SFT模型初始化的,该模型已经学会了遵循不同的人类指令,因此还让奖励模型遵循不同的系统提示,以适应不同场景的不同偏好。数据集,包括对话、文章写作、诗歌、摘要、编码、数学和格式化输出等各个领域,有多达240万对二值化偏好。
损失函数:由2部分组成
- 为了降低数据集中易、难样本不平衡的影响,在排名损失中添加了一个难度衰减系数,类似于focal loss,使困难样本的损失值更大,容易样本的损失值更小,防止对大量容易样本的过拟合
当p > 0.5 时,难度系数才生效,否则为1;
- 为确保奖励模型在不同训练中的输出分数的稳定性和一致性,对奖励分数引入对数障碍惩罚,将分数分布限制在-5到5的范围内
最终损失为:
online rlhf
分为两种不同的路径:快速路径用于立即的、有针对性的改进,缓慢路径用于长期的、全面的改进奖励模型。快路径和慢路径是互补的,为减轻reward hacking 行为和增强人工反馈训练的llm的性能和可靠性提供了一个自适应框架。
- fast path:在PPO训练过程中,模型倾向于朝着高奖励区域移动,这通常会暴露出更多的reward hacking场景,这些场景可以很容易地被检测到。在每轮RLHF之后,通过比较当前轮次中早期和晚期PPO模型生成的回复,构建偏好对来突出这些模式。将20到100个这样的偏好对纳入训练过程足以防止奖励模型出现相应的问题。这个过程允许快速修复奖励模型以应对新出现的hacking行为,从而增强奖励模型的可靠性和对期望结果的遵循。
- slow path: 缓慢路径旨在通过覆盖最新和有能力的模型的LLMs响应,全面提高奖励模型的上限,特别是在高奖励区域的奖励模型的可靠性和鲁棒性;模型在训练的不同阶段(包括SFT模型、早期的PPO模型和后期的PPO模型)产生的回复被用来形成成对的比较。然后,这些对被呈现给专业的人类标注团队,以标记他们的偏好。这样的过程提供了更细致和彻底的改进奖励模型,但需要大量的人工注释时间
在online rlhf的实现过程中,进行了三轮的优化。在这些周期中,在快速路径中收集了数千个偏好补丁和在线偏好数据,以更新奖励模型,并使用以前模型响应的所有现有的人类偏好数据。每一轮在线RLHF都提供了有价值的见解,使我们能够动态调整和完善奖励模型,从而提高通过人工反馈训练的语言模型的整体性能和可靠性。
长上下文微调
为了保持微调后llm的长上下文能力,我们继续使用SFT和RLHF中的长上下文预训练数据,利用了两种类型的数据:一种是来自书籍的长上下文数据,而另一种是来自GitHub存储库(超过10000 star 的repo)的长上下文数据。实验结果表明,长上下文编码数据不仅提高了LLMs的长上下文能力,而且提高了 其编码能力。
工具增强LLM
没有细说,增加了environment的角色,<|interpreter|> <|plugin|> 之类的特殊token,来区分工具和上下文
与专注于解决奖励黑客问题的Fast Path不同,Slow Path旨在通过长期、全面的细化来提高奖励模型的上限,尤其是在高奖励区域的可靠性和鲁棒性方面。这个过程通过涵盖来自最新一代和最有能力模型的响应,使用这些响应形成成对比较。然后,这些成对比较被呈现给专业人类标注员进行偏好标注。这样的过程提供了对奖励模型更细致和全面的改进,但需要大量的人类标注时间。为了提高在线RLHF的效率,在实验启动时,只使用之前所有模型累积的人类偏好。通过持续根据人类反馈更新模型,Slow Path确保奖励模型与人类偏好的复杂性和微妙性同步发展。
简单贴一个
| Dataset | Baichuan2-7B-Chat | Mistral-7B-Instruct-v0.2 | Qwen-7B-Chat | InternLM2-Chat-7B | ChatGLM3-6B | Baichuan2-13B-Chat | Mixtral-8x7B-Instruct-v0.1 | Qwen-14B-Chat | InternLM2-Chat-20B |
|---|---|---|---|---|---|---|---|---|---|
| MMLU | 50.1 | 59.2 | 57.1 | 63.7 | 58.0 | 56.6 | 70.3 | 66.7 | 66.5 |
| CMMLU | 53.4 | 42.0 | 57.9 | 63.0 | 57.8 | 54.8 | 50.6 | 68.1 | 65.1 |
| AGIEval | 35.3 | 34.5 | 39.7 | 47.2 | 44.2 | 40.0 | 41.7 | 46.5 | 50.3 |
| C-Eval | 53.9 | 42.4 | 59.8 | 60.8 | 59.1 | 56.3 | 54.0 | 71.5 | 63.0 |
| TrivialQA | 37.6 | 35.0 | 46.1 | 50.8 | 38.1 | 40.3 | 57.7 | 54.5 | 53.9 |
| NaturalQuestions | 12.8 | 8.1 | 18.6 | 24.1 | 14.0 | 12.7 | 22.5 | 22.9 | 25.9 |
| C3 | 78.5 | 66.9 | 84.4 | 91.5 | 79.3 | 84.4 | 82.1 | 91.5 | 93.5 |
| CMRC | 8.1 | 5.6 | 14.6 | 63.8 | 43.2 | 27.8 | 5.3 | 13.0 | 50.4 |
| WinoGrande | 49.9 | 50.8 | 54.2 | 65.8 | 61.7 | 50.9 | 60.9 | 55.7 | 74.8 |
| BBH | 35.9 | 46.5 | 45.5 | 61.2 | 56.0 | 42.5 | 57.3 | 55.8 | 68.3 |
| GSM-8K | 32.4 | 48.3 | 44.1 | 70.7 | 53.8 | 56.0 | 71.7 | 57.7 | 79.6 |
| Math | 5.7 | 8.6 | 12.0 | 23.0 | 20.4 | 4.3 | 22.5 | 27.6 | 31.9 |
| HumanEval | 17.7 | 35.4 | 36.0 | 59.8 | 52.4 | 19.5 | 37.8 | 40.9 | 67.1 |
| MBPP | 37.7 | 25.7 | 33.9 | 51.4 | 55.6 | 40.9 | 40.9 | 30.0 | 65.8 |
