- 引言
- 简介
- Gemma 2模型介绍
- 架构设计
- 训练方法
- 后训练优化
- 关键发现:知识蒸馏的影响
- 性能评估
- 使用
- 体验:Hugging Chat
- 如何提示 Gemma 2
- 基于Hugging Face Transformers
- 结论与展望
- 模型汇总
两岸荔枝红,万家烟雨中。
小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:卖荔枝的小男孩。Google 最近谷歌发布了开源大语言模型 Gemma 2,目前可以在 huggingface 上找到 4 个开源模型(2 个基础模型和 2 个微调模型)。今天这篇小作文主要介绍Gemma 2的一些技术特点及其使用初体验,下一篇小作文将介绍如何微调Gemma 2模型。
技术报告原文:https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf
2024年6月27日,Google DeepMind发布了Gemma 2,这是Gemma系列轻量级开放语言模型的最新成员。Gemma 2在架构和训练方法上都有重大创新,在多项基准测试中取得了显著进步,甚至可以与参数规模大2-3倍的模型相媲美。本文将对Gemma 2技术报告的主要内容进行解读,包括模型架构、预训练和后训练方法、性能评估等方面。
Gemma 2 模型的训练数据量约为其第一代的两倍,总计 13 万亿 Tokens(27b 模型)和 8 万亿 Tokens(9b 模型)的网页数据(主要是英语)、代码和数学数据。官方尚未透露训练数据混合的具体细节。
Gemma 2系列包括三种规模的模型:27B(270亿参数)、9B(90亿参数)和2.6B(26亿参数)。其中27B和9B模型已经发布,2.6B模型即将发布。这些模型均采用Transformer解码器架构,具有以下主要特点:
- 上下文长度为8192个token
- 使用旋转位置编码(RoPE)
- 采用近似GeGLU非线性激活函数
- 交替使用局部滑动窗口注意力和全局注意力
- 应用logit软上限技术
- 使用RMSNorm进行后归一化和前归一化
- 9B和27B模型采用分组查询注意力(GQA)机制
架构设计
相比Gemma 1,Gemma 2在架构上有几项重要改进:
(1) 局部滑动窗口和全局注意力交替
Gemma 2在每隔一层交替使用局部滑动窗口注意力和全局注意力。局部注意力的窗口大小为4096个token,全局注意力的跨度为8192个token。这种设计兼顾了计算效率和长程依赖建模能力。
(2) Logit软上限
在每个注意力层和最终输出层,Gemma 2对logit值进行软上限处理,将其限制在[-soft_cap, +soft_cap]范围内。具体实现为:
logits ← soft_cap * tanh(logits/soft_cap)
9B和27B模型的注意力logit上限为50.0,最终logit上限为30.0。这项技术有助于稳定训练过程和生成质量。
(3) RMSNorm归一化
Gemma 2使用RMSNorm来归一化每个Transformer子层、注意力层和前馈层的输入和输出,以提高训练稳定性。
(4) 分组查询注意力(GQA)
9B和27B模型采用了GQA机制,组数为2。实验表明,GQA可以在保持下游任务性能的同时提高推理速度。
训练方法
Gemma 2的另一大亮点是采用了创新的训练方法,尤其是对较小规模模型(9B和2.6B)使用知识蒸馏技术。
(1) 预训练数据
27B模型在13万亿token的数据上训练,9B模型使用8万亿token,2.6B模型使用2万亿token。训练数据主要为英语,来源包括网页文档、代码和科学文章等。tokenizer采用SentencePiece方法,词表大小为256k。
(2) 知识蒸馏
对于9B和2.6B模型,Gemma 2团队创新性地将知识蒸馏应用于大规模预训练。具体做法是:使用一个大型模型作为教师,最小化学生模型与教师模型在每个token的条件概率分布之间的负对数似然。这种方法可以在有限的计算资源下,让小模型获得类似于训练更多token的效果。
(3) 计算基础设施
Gemma 2使用了TPUv4、TPUv5e和TPUv5p进行训练,采用数据并行和模型并行相结合的策略。例如,27B模型在6144个TPUv5p芯片上训练,采用768路数据复制和8路模型分片。
后训练优化
为了进一步提升模型性能和安全性,Gemma 2在预训练后进行了一系列优化:
(1) 监督微调(SFT)
在英语指令-响应对数据集上进行监督微调,数据包括合成和人工生成的样本。
(2) 人类反馈强化学习(RLHF)
在SFT基础上应用RLHF,奖励模型基于标注的英语偏好数据训练,策略使用与SFT相同的提示。
(3) 模型融合
将不同超参数实验得到的模型进行平均,以提升整体性能。
关键发现:知识蒸馏的影响
技术报告重点介绍了知识蒸馏对小型语言模型性能的显著影响:
- 在500B token上训练的2.6B模型,使用知识蒸馏比从头训练在3项基准测试的平均分上高出7.4个百分点。
- 随着模型规模增加(从200M到1B参数),知识蒸馏带来的收益依然存在。
- 在9B规模上,使用GQA替代多头注意力(MHA)对性能影响不大,但可以减少参数量并提高推理速度。
- 对于9B模型,更深的网络结构略优于更宽的结构。
- 可以在推理时调整局部注意力的滑动窗口大小,对困惑度的影响较小,为推理速度优化提供了灵活性。
性能评估
Gemma 2在多项基准测试中展现出优异性能:
- 在同等规模的开放模型中达到最佳性能
- 甚至可以与参数量大2-3倍的模型相竞争
- 在问答、常识推理、数学科学、编程等多个领域的任务上表现出色
具体评估涵盖了以下方面:
- 自动化基准测试
- 人工评估
- 问答能力(如SQuAD、Natural Questions)
- 常识推理(如WinoGrande、HellaSwag)
- 数学和科学(如MATH、MMLU)
- 代码能力(如HumanEval、MBPP)
体验:Hugging Chat
你可以在 Hugging Chat 上与 Gemma 27B 指令模型聊天!查看此链接:https://huggingface.co/chat/models/google/gemma-2-27b-it
如何提示 Gemma 2
基础模型没有提示格式。像其他基础模型一样,它们可以用于继续输入序列的合理延续或零样本/少样本推理。指令版本有一个非常简单的对话结构:
<start_of_turn>user
knock knock<end_of_turn>
<start_of_turn>model
who is there<end_of_turn>
<start_of_turn>user
LaMDA<end_of_turn>
<start_of_turn>model
LaMDA who?<end_of_turn><eos>
必须精确地复制此格式才能有效使用。稍后我们将展示如何使用 transformers
中的聊天模板轻松地复制指令提示。
基于Hugging Face Transformers
随着 Transformers 版本 4.42 的发布,你可以使用 Gemma 并利用 Hugging Face 生态系统中的所有工具。要使用 Transformers 使用 Gemma 模型,请确保使用最新的 transformers
版本:
pip install "transformers>=4.42.3" --upgrade
以下代码片段展示了如何使用 transformers
使用 gemma-2-9b-it
。它需要大约 18 GB 的RAM,适用于许多消费者 GPU。相同的代码片段适用于 gemma-2-27b-it
,需要 56GB 的 RAM,使其非常适合生产用例。通过加载 8-bit 或 4-bit 模式,可以进一步减少内存消耗。
from transformers import pipeline
import torch
pipe = pipeline(
"text-generation",
model="google/gemma-2-9b-it",
model_kwargs={"torch\_dtype": torch.bfloat16},
device="cuda",
)
messages = [
{"role": "user", "content": "请模仿蔡坤的方式说两句土味情话"},
]
outputs = pipe(
messages,
max_new_tokens=256,
do_sample=False,
)
assistant_response = outputs[0]["generated\_text"][-1]["content"]
print(assistant_response)
输出结果如下:
1. 你就像一碗热腾腾的番茄鸡蛋面,让我欲罢不能!
2. 你的眼睛像星星一样闪亮,照亮了我这颗孤独的心!
希望你喜欢我的土味情话!😜
这里我们使用 bfloat16 因为这是指令调优模型的参考精度。在你的硬件上运行 float16 可能会更快,90 亿模型的结果应该是相似的。然而,使用 float16 时,270 亿指令调优模型会产生不稳定的输出:对于该模型权重,你必须使用 bfloat16。
你还可以自动量化模型,以 8-bit 甚至 4-bit 模式加载。加载 4-bit 模式的 270 亿版本需要大约 18 GB 的内存,使其兼容许多消费者显卡和 Google Colab 中的 GPU。这是你在 4-bit 模式下加载生成管道的方式:
import os
from transformers import pipeline
import torch
init_model_dir = "/share\_model\_zoo/LLM/"
# init\_model\_id = "google/gemma-2-9b"
init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
pipeline = pipeline(
"text-generation",
model=init_model_path,
model_kwargs={
"torch\_dtype": torch.bfloat16,
"quantization\_config": {"load\_in\_4bit": True}
},
device_map="auto",
)
messages = [
{"role": "user", "content": "请模仿蔡坤的语气说两句土味情话"},
]
outputs = pipeline(
messages,
max_new_tokens=256,
do_sample=False,
)
assistant_response = outputs[0]["generated\_text"][-1]["content"]
print(assistant_response)
输出结果如下:
1. 宝贝,你就像一朵盛开的玫瑰,香气迷人,让我忍不住想把你捧在手里,永远守护着。
2. 你是我这辈子最想遇见的人,就像天上的星星,照亮我的整个世界,让我不再迷茫。
希望你喜欢这些土味情话!
有关使用 Transformers 模型的更多详细信息,请查看模型卡。
Gemma 2代表了轻量级开放语言模型的重大进展。通过创新的架构设计和训练方法,尤其是知识蒸馏技术的大规模应用,Gemma 2在有限的参数规模下实现了卓越的性能。这不仅推动了语言模型技术的发展,也为更广泛的AI应用提供了高效、实用的解决方案。
未来的研究方向可能包括:
- 进一步优化知识蒸馏技术
- 探索更高效的模型架构
- 增强模型的多语言和跨领域能力
- 改进模型的安全性和可控性
总的来说,Gemma 2的成功表明,通过精心的设计和创新的训练方法,我们可以在不盲目追求参数规模的情况下,显著提升语言模型的能力。这为构建更高效、更易部署的AI系统开辟了新的可能性,有望推动AI技术在更多实际场景中的应用与落地。
模型 | 百川2 | 千问Qwen1 | 天工 | Gemma 2 |
---|---|---|---|---|
参数量 | 7B,13B | 7B,14B | 13B | 2.6B, 9B,27B |
预训练token数 | 2.6万亿 | 3万亿 | 3.2万亿 | 2.6B:2万亿;9B:8万亿;27B:13万亿 |
tokenizer | BPE | BPE | BPE | BPE |
词表大小 | 125696 | 152K | 65536 | 256k |
位置编码 | 7b:RoPE ; 13b:ALiBi (影响不大) | RoPE | RoPE | RoPE |
最长上下文 | 4096 | 训练时2048;推理时8K | 4096 | 8192 |
模型外推 | -- | NTK插值、窗口注意力、LogN注意力缩放等技术来提升模型的上下文长度 | -- | -- |
激活函数 | SwiGLU | SwiGLU | SwiGLU | GeGLU |
归一化 | Layer Normalization; RMSNorm | Pre-Norm; RMSNorm | Pre-Norm; RMSNorm | Pre-Norm; RMSNorm |
注意力机制 | xFormers2 | Flash Attention | Flash Attention V2 | GQA |
优化器 | AdamW+NormHead+Max-z损失 | AdamW | AdamW | -- |
特色 | Infrastructure、Scaling Laws | -- | 两阶段的预训练 | 滑动窗口注意力,Logit 软上限,知识蒸馏,模型合并 |