运用知识蒸馏(KD)构建小语言模型

火山方舟向量数据库大模型

点击下方 卡片 ,关注“ 慢慢学AIGC ”

picture.image

语言模型知识蒸馏的技术与实践

近期大型语言模型在语言建模和生成任务上都展现出了令人瞩目的成果。值得注意的是,这些结果与模型的规模成正比 - 包括参数数量、训练数据规模和GPU计算时间。这些模型的输出可以通过两种最常见的方法进行定制 - RAG和使用自定义数据集进行微调。本文将探讨RAG的一些常见缺陷,语言模型微调的挑战,并概述知识蒸馏(KD)的基本概念,同时提供一个实际示例来说明其应用。

RAG代表 检索增强生成 。在RAG中,应用程序逻辑负责检索与用户查询在语义上相关的内容。这些内容与提示一起发送给语言模型。简而言之,这种方法依赖于检索策略、输入数据的质量和提示。它还依赖于语言模型的单次学习能力,因此能接受更多输入的较大语言模型可以生成更好的响应。然而,除非正确实施,否则这种解决方案容易产生有偏见的响应、重复和幻觉。随着文本语料库的增长,为用户查询识别语义正确的输入变得越来越困难。

定制语言模型响应的另一种方法是 特定任务的微调 。在这种方法中,模型使用标记数据集进行微调。特定任务的微调可以提高响应质量,克服单次学习的局限性。在针对特定任务进行微调时,模型会使用不同的超参数和大型标记数据集进行多次迭代训练。在此过程中,最佳模型的快照、激活和梯度会存储在GPU内存中。默认的微调方法存在以下问题:

  • 微调成本 : 特定任务的微调需要几百小时(取决于数据大小)的多核GPU和GB级的RAM。生产系统应该能够快速适应数据的变化,这使得微调成为一项持续性活动。数据科学家可能还想尝试不同的模型配置,这使得迭代方法更加昂贵。粗略估计,使用5K个token微调一个模型需要约6小时。下表显示了在Azure Open AI上使用不同模型进行500小时训练(或450K个token)的成本。

picture.image

  • 在资源受限环境中的部署 : 在推理过程中,这些模型需要大量内存才能获得更好的性能。 模型中的参数数量决定了所需内存的大小。 例如,一个70亿参数的模型需要超过14GB的内存。 这仅仅是为了以半精度浮点格式加载参数。 显然,这超出了大多数边缘设备的能力。 有几种方法可以在不过多妥协质量的情况下压缩模型大小,如PEFT方法 - LoRA、QLoRA、量化、知识蒸馏。 本文主要关注大型语言模型中的知识蒸馏。

如果你已经了解这些概念并想查看代码,可以直接访问我的notebook。在这个notebook中,我使用前向KL散度将从微调的T5-small教师模型中蒸馏知识到更小的T5-small学生模型。下表中可以看到,在不损害性能的情况下,模型大小的减少是显著的。通过更多的训练时间,模型大小可以进一步减小。

picture.image

如果你想了解知识蒸馏的基础知识,请继续阅读。

什么是知识蒸馏?

知识蒸馏是一种旨在将大型复杂模型(教师模型)压缩成更小更简单的模型(学生模型)的技术,同时在一定程度上保持教师的性能 。知识蒸馏并不是一种新方法,它最初由Critstian Bucilua等人在2006年的这篇论文(https://dl.acm.org/doi/10.1145/1150402.1150464)中提出。以前,KD被应用于任何具有有效大型架构的神经网络模型。

可以使用教师模型的响应(或logits)(也称为基于响应的知识蒸馏)、教师模型的权重和激活(也称为基于特征的知识)以及模型参数之间的关系(也称为基于关系的知识)来从教师模型进行知识蒸馏。本文主要关注在大型语言模型中使用基于响应的知识。下图展示了使用来自较大教师模型的基于响应的知识进行知识蒸馏。

picture.image

大型语言模型中的基于响应的知识

在基于响应的知识蒸馏中,核心思想是使用教师的输出作为学生的软标签。学生被训练以预测教师的软标签而不是实际标签。通过这种方式,学生可以从教师的知识中学习,而无需访问教师的参数或架构。使用这种方法,知识蒸馏可以以两种方式应用 - 白盒KD和黑盒KD。

  • 黑盒KD : 在黑盒KD中,只有教师模型的提示和响应对可用。这种方法适用于不预测logits的模型。

  • 白盒KD : 在白盒KD中,使用教师模型的对数概率。白盒KD仅适用于产生logits的开源模型。

如何有效地使用logits来蒸馏知识是一个活跃的研究领域。在附带的notebook中,我们看到了一种非常基本的KD形式。

损失函数

为什么我们认为这种方法会有效?理解这个问题最简单的方法是学习损失函数。损失函数包含3个关键组成部分 - 教师的logits、学生的logits和温度。

教师模型的logits代表在应用任何非线性激活函数之前的最纯粹的预测形式。同样,学生模型也产生logits。任何两个类别的logits都不能直接比较,因此我们需要对logits进行归一化。我们应用softmax等非线性激活函数来归一化logits。归一化后的logits代表N个类别上的概率分布,也称为软标签。我们的目标是减少教师模型和学生模型的概率分布之间的差异,这样学生模型的行为就更像教师模型。

Kullback-Leibler散度损失(或KL散度损失)是计算任何两个概率分布之间差异的一种方法。在KD中,我们使用KL散度损失来通过教师模型分布改进学生模型的学习。以下等式描述了KD损失。

picture.image

在上述等式中,温度是一个整数类型的超参数,用于控制软标签的重要性。 温度使学生模型能够从软标签的微小差异中更好地学习。 Softmax有助于放大logits之间的微小差异,对数softmax软化梯度,防止梯度爆炸/消失。 因此,对学生logits应用logsoftmax。 除了KD损失外,还使用表示分类误差的交叉熵损失来改善学生模型的错误分类。 交叉熵损失定义为:

picture.image

总KD损失定义为 - KL损失 + 交叉熵损失,然后将最终损失反向传播以微调学生模型。可以使用离线微调的教师模型来训练学生模型,如notebook中所述。另一种替代方法是与学生模型一起微调教师模型。

总结

知识蒸馏相比基于RAG的实现有几个优势,特别是对于大型语言模型。一些好处包括:

  • 降低模型的计算成本和内存占用,使其更容易在不同设备上部署和运行。

  • 提高模型的泛化能力和鲁棒性,因为学生可以从教师的隐式正则化和噪声平滑中学习。

  • 增强模型的可解释性和可解释性,因为学生可以具有比教师更简单和更透明的结构。

微调大型模型的响应可以以不同方式用于增强学生模型的学习能力 - 前向KL散度就是这样一种方法。这是一个活跃的研究领域,你可以在MiniLLM中找到一些KD的替代方法(https://arxiv.org/abs/2306.08543)。

本文由 AI 生成。点击左下角“阅读原文”直达英文原文,发表时间为 2024 年 2 月 24 日,作者: Srikanth Machiraju, 微软云解决方案架构师 | AI & ML专业人士 | 出版作家 | 研究生。


扫描下方 二维码 ,关注“ 慢慢学AIGC ”

picture.image

0
0
0
0
关于作者
相关资源
在火山引擎云搜索服务上构建混合搜索的设计与实现
本次演讲将重点介绍字节跳动在混合搜索领域的探索,并探讨如何在多模态数据场景下进行海量数据搜索。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论