Google最新开源大语言模型:Gemma 2介绍及其微调(上篇)

存储混合云弹性计算
  • 引言
  • 简介
  • Gemma 2模型介绍
  • 架构设计
  • 训练方法
  • 后训练优化
  • 关键发现:知识蒸馏的影响
  • 性能评估
  • 使用
  • 体验:Hugging Chat
  • 如何提示 Gemma 2
  • 基于Hugging Face Transformers
  • 结论与展望
  • 模型汇总
引言

两岸荔枝红,万家烟雨中。

picture.image

小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:卖荔枝的小男孩。Google 最近谷歌发布了开源大语言模型 Gemma 2,目前可以在 huggingface 上找到 4 个开源模型(2 个基础模型和 2 个微调模型)。今天这篇小作文主要介绍Gemma 2的一些技术特点及其使用初体验,下一篇小作文将介绍如何微调Gemma 2模型。

技术报告原文:https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf

简介

2024年6月27日,Google DeepMind发布了Gemma 2,这是Gemma系列轻量级开放语言模型的最新成员。Gemma 2在架构和训练方法上都有重大创新,在多项基准测试中取得了显著进步,甚至可以与参数规模大2-3倍的模型相媲美。本文将对Gemma 2技术报告的主要内容进行解读,包括模型架构、预训练和后训练方法、性能评估等方面。

Gemma 2 模型的训练数据量约为其第一代的两倍,总计 13 万亿 Tokens(27b 模型)和 8 万亿 Tokens(9b 模型)的网页数据(主要是英语)、代码和数学数据。官方尚未透露训练数据混合的具体细节。

Gemma 2模型介绍

Gemma 2系列包括三种规模的模型:27B(270亿参数)、9B(90亿参数)和2.6B(26亿参数)。其中27B和9B模型已经发布,2.6B模型即将发布。这些模型均采用Transformer解码器架构,具有以下主要特点:

  • 上下文长度为8192个token
  • 使用旋转位置编码(RoPE)
  • 采用近似GeGLU非线性激活函数
  • 交替使用局部滑动窗口注意力和全局注意力
  • 应用logit软上限技术
  • 使用RMSNorm进行后归一化和前归一化
  • 9B和27B模型采用分组查询注意力(GQA)机制

架构设计

相比Gemma 1,Gemma 2在架构上有几项重要改进:

(1) 局部滑动窗口和全局注意力交替

Gemma 2在每隔一层交替使用局部滑动窗口注意力和全局注意力。局部注意力的窗口大小为4096个token,全局注意力的跨度为8192个token。这种设计兼顾了计算效率和长程依赖建模能力。

(2) Logit软上限

在每个注意力层和最终输出层,Gemma 2对logit值进行软上限处理,将其限制在[-soft_cap, +soft_cap]范围内。具体实现为:

logits ← soft_cap * tanh(logits/soft_cap)

9B和27B模型的注意力logit上限为50.0,最终logit上限为30.0。这项技术有助于稳定训练过程和生成质量。

(3) RMSNorm归一化

Gemma 2使用RMSNorm来归一化每个Transformer子层、注意力层和前馈层的输入和输出,以提高训练稳定性。

(4) 分组查询注意力(GQA)

9B和27B模型采用了GQA机制,组数为2。实验表明,GQA可以在保持下游任务性能的同时提高推理速度。

训练方法

Gemma 2的另一大亮点是采用了创新的训练方法,尤其是对较小规模模型(9B和2.6B)使用知识蒸馏技术。

(1) 预训练数据

27B模型在13万亿token的数据上训练,9B模型使用8万亿token,2.6B模型使用2万亿token。训练数据主要为英语,来源包括网页文档、代码和科学文章等。tokenizer采用SentencePiece方法,词表大小为256k。

(2) 知识蒸馏

对于9B和2.6B模型,Gemma 2团队创新性地将知识蒸馏应用于大规模预训练。具体做法是:使用一个大型模型作为教师,最小化学生模型与教师模型在每个token的条件概率分布之间的负对数似然。这种方法可以在有限的计算资源下,让小模型获得类似于训练更多token的效果。

(3) 计算基础设施

Gemma 2使用了TPUv4、TPUv5e和TPUv5p进行训练,采用数据并行和模型并行相结合的策略。例如,27B模型在6144个TPUv5p芯片上训练,采用768路数据复制和8路模型分片。

后训练优化

为了进一步提升模型性能和安全性,Gemma 2在预训练后进行了一系列优化:

(1) 监督微调(SFT)

在英语指令-响应对数据集上进行监督微调,数据包括合成和人工生成的样本。

(2) 人类反馈强化学习(RLHF)

在SFT基础上应用RLHF,奖励模型基于标注的英语偏好数据训练,策略使用与SFT相同的提示。

(3) 模型融合

将不同超参数实验得到的模型进行平均,以提升整体性能。

关键发现:知识蒸馏的影响

技术报告重点介绍了知识蒸馏对小型语言模型性能的显著影响:

  • 在500B token上训练的2.6B模型,使用知识蒸馏比从头训练在3项基准测试的平均分上高出7.4个百分点。
  • 随着模型规模增加(从200M到1B参数),知识蒸馏带来的收益依然存在。
  • 在9B规模上,使用GQA替代多头注意力(MHA)对性能影响不大,但可以减少参数量并提高推理速度。
  • 对于9B模型,更深的网络结构略优于更宽的结构。
  • 可以在推理时调整局部注意力的滑动窗口大小,对困惑度的影响较小,为推理速度优化提供了灵活性。

性能评估

Gemma 2在多项基准测试中展现出优异性能:

  • 在同等规模的开放模型中达到最佳性能
  • 甚至可以与参数量大2-3倍的模型相竞争
  • 在问答、常识推理、数学科学、编程等多个领域的任务上表现出色

具体评估涵盖了以下方面:

  • 自动化基准测试
  • 人工评估
  • 问答能力(如SQuAD、Natural Questions)
  • 常识推理(如WinoGrande、HellaSwag)
  • 数学和科学(如MATH、MMLU)
  • 代码能力(如HumanEval、MBPP)
使用

体验:Hugging Chat

你可以在 Hugging Chat 上与 Gemma 27B 指令模型聊天!查看此链接:https://huggingface.co/chat/models/google/gemma-2-27b-it

如何提示 Gemma 2

基础模型没有提示格式。像其他基础模型一样,它们可以用于继续输入序列的合理延续或零样本/少样本推理。指令版本有一个非常简单的对话结构:


        
          
<start_of_turn>user  
knock knock<end_of_turn>  
<start_of_turn>model  
who is there<end_of_turn>  
<start_of_turn>user  
LaMDA<end_of_turn>  
<start_of_turn>model  
LaMDA who?<end_of_turn><eos>  

      

必须精确地复制此格式才能有效使用。稍后我们将展示如何使用 transformers 中的聊天模板轻松地复制指令提示。

基于Hugging Face Transformers

随着 Transformers 版本 4.42 的发布,你可以使用 Gemma 并利用 Hugging Face 生态系统中的所有工具。要使用 Transformers 使用 Gemma 模型,请确保使用最新的 transformers 版本:


        
          
pip install "transformers>=4.42.3" --upgrade  

      

以下代码片段展示了如何使用 transformers 使用 gemma-2-9b-it。它需要大约 18 GB 的RAM,适用于许多消费者 GPU。相同的代码片段适用于 gemma-2-27b-it,需要 56GB 的 RAM,使其非常适合生产用例。通过加载 8-bit 或 4-bit 模式,可以进一步减少内存消耗。


        
          
from transformers import pipeline  
import torch  
  
pipe = pipeline(  
    "text-generation",  
    model="google/gemma-2-9b-it",  
    model_kwargs={"torch\_dtype": torch.bfloat16},  
    device="cuda",  
)  
  
messages = [  
    {"role": "user", "content": "请模仿蔡坤的方式说两句土味情话"},  
]  
outputs = pipe(  
    messages,  
    max_new_tokens=256,  
    do_sample=False,  
)  
assistant_response = outputs[0]["generated\_text"][-1]["content"]  
print(assistant_response)  

      

输出结果如下:


        
          
1. 你就像一碗热腾腾的番茄鸡蛋面,让我欲罢不能!  
2. 你的眼睛像星星一样闪亮,照亮了我这颗孤独的心!  
  
  
希望你喜欢我的土味情话!😜  

      

这里我们使用 bfloat16 因为这是指令调优模型的参考精度。在你的硬件上运行 float16 可能会更快,90 亿模型的结果应该是相似的。然而,使用 float16 时,270 亿指令调优模型会产生不稳定的输出:对于该模型权重,你必须使用 bfloat16。

你还可以自动量化模型,以 8-bit 甚至 4-bit 模式加载。加载 4-bit 模式的 270 亿版本需要大约 18 GB 的内存,使其兼容许多消费者显卡和 Google Colab 中的 GPU。这是你在 4-bit 模式下加载生成管道的方式:


        
          
import os  
from transformers import pipeline  
import torch  
  
init_model_dir = "/share\_model\_zoo/LLM/"  
# init\_model\_id = "google/gemma-2-9b"  
init_model_id = "google/gemma-2-9b-it"  
init_model_path = os.path.join(init_model_dir, init_model_id)  
  
pipeline = pipeline(  
    "text-generation",  
    model=init_model_path,  
    model_kwargs={  
        "torch\_dtype": torch.bfloat16,  
        "quantization\_config": {"load\_in\_4bit": True}  
    },  
    device_map="auto",  
)  
  
messages = [  
    {"role": "user", "content": "请模仿蔡坤的语气说两句土味情话"},  
]  
outputs = pipeline(  
    messages,  
    max_new_tokens=256,  
    do_sample=False,  
)  
assistant_response = outputs[0]["generated\_text"][-1]["content"]  
print(assistant_response)  

      

输出结果如下:


        
          
1.  宝贝,你就像一朵盛开的玫瑰,香气迷人,让我忍不住想把你捧在手里,永远守护着。  
2.  你是我这辈子最想遇见的人,就像天上的星星,照亮我的整个世界,让我不再迷茫。  
  
  
希望你喜欢这些土味情话!  

      

有关使用 Transformers 模型的更多详细信息,请查看模型卡。

结论与展望

Gemma 2代表了轻量级开放语言模型的重大进展。通过创新的架构设计和训练方法,尤其是知识蒸馏技术的大规模应用,Gemma 2在有限的参数规模下实现了卓越的性能。这不仅推动了语言模型技术的发展,也为更广泛的AI应用提供了高效、实用的解决方案。

未来的研究方向可能包括:

  • 进一步优化知识蒸馏技术
  • 探索更高效的模型架构
  • 增强模型的多语言和跨领域能力
  • 改进模型的安全性和可控性

总的来说,Gemma 2的成功表明,通过精心的设计和创新的训练方法,我们可以在不盲目追求参数规模的情况下,显著提升语言模型的能力。这为构建更高效、更易部署的AI系统开辟了新的可能性,有望推动AI技术在更多实际场景中的应用与落地。

模型汇总
模型百川2千问Qwen1天工Gemma 2
参数量7B,13B7B,14B13B2.6B, 9B,27B
预训练token数2.6万亿3万亿3.2万亿2.6B:2万亿;9B:8万亿;27B:13万亿
tokenizerBPEBPEBPEBPE
词表大小125696152K65536256k
位置编码7b:RoPE ; 13b:ALiBi (影响不大)RoPERoPERoPE
最长上下文4096训练时2048;推理时8K40968192
模型外推--NTK插值、窗口注意力、LogN注意力缩放等技术来提升模型的上下文长度----
激活函数SwiGLUSwiGLUSwiGLUGeGLU
归一化Layer Normalization; RMSNormPre-Norm; RMSNormPre-Norm; RMSNormPre-Norm; RMSNorm
注意力机制xFormers2Flash AttentionFlash Attention V2GQA
优化器AdamW+NormHead+Max-z损失AdamWAdamW--
特色Infrastructure、Scaling Laws--两阶段的预训练滑动窗口注意力,Logit 软上限,知识蒸馏,模型合并
0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论