点击下方 卡片 ,关注“ 慢慢学AIGC ”
摘要
这是一篇来自 Yandex 公司的技术博客文章,介绍了他们开发的一种名为 YaFSDP 的新工具,旨在大幅加速大型语言模型的训练过程,并优化 GPU 显存消耗。文章详细解释了在多 GPU 集群上进行分布式语言模型训练时可能遇到的挑战,并介绍了一些现有的训练方法如 ZeRO 和 FSDP 及其优缺点。随后,作者阐述了 YaFSDP 方法的设计思路和具体实现细节,包括内存管理、流(stream)使用、可靠的 backward_hook 实现等。最后,作者提供了一些测试结果,显示 YaFSDP 在某些场景下可以比 FSDP 快 20% 以上,并在 Yandex 的预训练中带来 45% 的加速。该文介绍了一种新的分布式语言模型训练优化方法,对从事相关研究的人员可能有一定参考价值。
内容由 AI 生成。文章末尾点“阅读原文”可直达 Medium 英文原文。以下为正文。
在上周,我们开源了 YaFSDP 方法——一种用于大幅加速大型语言模型训练的新工具。
YaFSDP 开源链接:https://github.com/yandex/YaFSDP
在本文中,我们将探讨如何在集群上组织语言模型训练,以及可能遇到的问题。我们还将介绍 ZeRO 和 FSDP 等训练方法,并解释 YaFSDP 与它们有何不同。
多 GPU 训练的问题
在多 GPU 集群上进行分布式语言模型训练时,会面临哪些挑战?
首先,让我们考虑单 GPU 训练:
- 对新的数据批次进行前向传播并计算损失;
- 然后运行反向传播;
- 优化器更新优化器状态和模型权重;
那么使用多个 GPU 会有什么变化呢?
让我们看看在 4 个 GPU 上进行分布式数据并行(Distributed Data Parallelism)的最简单实现:
发生了什么变化?现在:
-
每个 GPU 处理较大数据批次的一部分,这使我们可以在相同内存负载下将批量大小(batch size)增加四倍;
-
我们需要同步 GPU。为此,我们使用 all_reduce 在 GPU 之间平均梯度,以确保不同 GPU 上的权重同步更新。all_reduce 操作是实现此目的的最快方式之一,它在 NCCL(NVIDIA 集合通信库)中可用,并在 torch.distributed 包中得到支持。
让我们回顾一下不同的通信操作(在本文中会多次提及):
- broadcast:将数据广播到每个进程;
- scatter:将数据分散到每个进程;
- gather:将数据从不同进程汇集到一个进程;
- all_gather:从所有进程中收集数据并广播到所有进程;
- reduce:将数据从不同进程规约到一个进程;
- all_reduce:在所有进程中规约数据并广播到所有进程;
- all_to_all:所有进程数据按分片汇集到每个进程,类似矩阵转置;
- reduce_scatter:对所有进程数据规约后分散到每个进程;
希望了解 更 多 可 以阅读这篇《 利用多 GPU 加速深度学习模型训练 》。
我们在这些通信中遇到的问题包括:
- 在 all_reduce 操作中,我们发送的梯度是网络参数的两倍。例如,当对 Llama 70B 的 fp16 梯度求和时,我们需要在每次迭代中在 GPU 之间发送 280 GB 的数据。在当今的集群中,这需要相当长的时间。
- 权重、梯度和优化器状态在 GPU 之间被重复。对于 Llama 70B 和 Adam 优化器的混合精度训练,需要超过 1 TB 的显存,而常规 GPU 显存只有 80 GB。
这意味着巨大的冗余显存负载使我们甚至无法将相对较小的模型装入 GPU 显存,而且由于所有这些额外的操作,我们的训练过程严重放缓。
有没有办法解决这些问题?
是的,有一些解决方案。其中,我们区分了一组数据并行方法,这些方法允许完全分片权重、梯度和优化器状态。对于 PyTorch,有三种这样的方法可用:ZeRO、FSDP 和 Yandex 的 YaFSDP。
ZeRO
2019 年,微软 DeepSpeed 开发团队发表了题为《ZeRO:Memory Optimizations Toward Training Trillion Parameter Models》的文章(https://arxiv.org/abs/1910.02054)。研究人员引入了一种新的内存优化解决方案——零冗余优化器(ZeRO),细节如下图所示:
从图中可见,一个 7.5B 参数量的模型使用 FP16 混合精度训练方法和 Adam 优化器,采用原始数据并行方案将对每个 GPU 都复制相同的优化器、权重、梯度等内容,导致显存开销高达 120 GB,超出了一张 A100 的显存最大容量,显然是无法实现的。
而 ZeRO-1(图中 P_os 这一行,os=optimizer states)则首先对显存占用最高的优化器下手做切分,在不同 GPU 间实现零冗余存储,这样每张 GPU 只需要存储 1/N 优化器状态,N 为 GPU 总数量。假设 N = 64 则每张 GPU 显存开销为 31.4 GB,具备了在 A100 或 V100 上的训练条件。
ZeRO-2(图中 P_os+g 这一行,g=gradients)进一步对梯度做切分,每张 GPU 显存开销下降到 16.6 GB,具备了在 3090 或 4090 上运行训练条件。
ZeRO-3(图中 P_os+g+p 这一行,p=parameters)进一步对权重做切分,每张 GPU 显存开销下降到 1.9 GB,具备了在 3060 等更低端 GPU 上运行训练的条件。
切分方式 节省了显存,代价是增加了通信量 ,是一种 时间换空间 的手法。
上面所提出的分区只是虚拟的。在前向和反向传播过程中,模型会像数据没有被分区一样处理所有张量。实现这一点的方法是异步收集张量。
这是 DeepSpeed 库在 N 个 GPU 上训练实现 ZeRO-3 的方式:
- 每个张量被分成 N 份,每一份存储在单独的进程内存中。
- 我们记录张量在第一次迭代中使用的顺序,在优化器步骤之前。
- 我们为收集的张量分配空间。在每个后续的前向和反向传播中,我们通过 all_gather 原语异步加载张量。当某个模块完成工作时,我们释放该模块张量占用的显存,并开始加载下一个张量。所有计算步骤都是并行的。
- 在反向传播过程中,一旦计算出梯度就运行 reduce_scatter。
- 在优化器步骤中,我们只更新属于特定 GPU 的权重和优化器参数。顺便说一下,这使优化器步骤相比之前数据并行方案加速了 N 倍!
如果我们每层只有一个参数张量,前向传播在 ZeRO-3 中的工作方式如下:
对每个 GPU 而言,计算步骤如下所示:
从图中可以看出:
- 通信现在是异步的。如果通信比计算更快,它们不会干扰计算或减慢整个过程;
- 现在有更多的通信;
- 优化器步骤所需时间大大减少;
在 DeepSpeed 中实现的 ZeRO 概念加速了许多语言模型的训练过程,显著优化了内存消耗。然而,也存在一些缺点:
- DeepSpeed 代码中有许多 bug 和瓶颈(新版已经好多了);
- 在大型集群上通信效率低下(千卡以上集群几乎没法直接用,得魔改);
- 所有 NCCL 集合操作都有一个奇特的原则:一次发送的数据越少,通信效率就越低(这很正常)。
假设我们有 N 个 GPU。那么对于 all_gather 操作,我们一次最多只能发送总参数数量的 1/N。当 N 增加时,通信效率会下降。
在 DeepSpeed 中,我们为每个参数张量运行 all_gather 和 reduce_scatter 操作。在 Llama 70B 中,常规参数张量大小为 8192×8192。因此,当在 1024 个 GPU 上训练时,我们一次最多只能发送 128 KB,这意味着网络利用率低下。
DeepSpeed 试图通过同时集成大量张量来解决这个问题。不幸的是,这种方法会导致许多缓慢的 GPU 显存操作,或者需要对所有通信进行定制实现。
结果,情况看起来像这样(流 7 表示计算,流 24 是通信):
显然,在集群规模增加时,DeepSpeed 往往会显著减慢训练过程。那么是否有更好的策略呢?事实上是有的。
FSDP 时代
全分片数据并行(Fully Sharded Data Parallelism,FSDP,论文见:https://arxiv.org/pdf/2304.11277)现在内置于 PyTorch,得到积极支持,并受到开发人员的欢迎。
这种新方法优势:
- FSDP 将多个层参数组合为单个 FlatParameter,在分片时进行拆分。这允许运行高效的集合通信。
- FSDP 有更用户友好的接口:
- DeepSpeed 改变了整个训练流程,更改了模型和优化器。
- FSDP 只改变模型,并且只将由该进程托管的权重和梯度发送给优化器。因此,可以在不额外设置的情况下使用自定义优化器。
-
FSDP 在常见用例中没有产生像 DeepSpeed 那样多的 bug。
-
动态图:ZeRO 要求模块始终按严格定义的顺序调用,否则它将无法理解加载哪个参数和何时加载。在 FSDP 中,你可以使用动态图。
尽管 FSDP 有这些优势,但也存在一些问题:
-
FSDP 动态分配层的显存,有时需要比实际需要的更多显存。
-
在反向传播过程中,我们遇到了一种被我们称为"避让效应"的现象。下面的分析说明了这一点:
这张图中的第一行是计算流,其他线代表通信流。
那么图中发生了什么?在 reduce_scatter 操作(蓝色)之前,有许多准备计算(通信下方的小操作)。这些小计算与主计算流并行运行,严重减慢了通信。这导致通信之间出现了大量空白,因此,在计算流中也出现了相同的空白。
我们试图克服这些问题,我们提出的解决方案就是 YaFSDP 方法。
YaFSDP
在这部分,我们将讨论开发过程,深入探讨如何设计和实现这样的解决方案。前面会有大量代码参考。如果你想了解高级的 PyTorch 使用方式,请继续阅读。
我们设定的目标是确保显存消耗得到优化,没有任何东西减慢通信速度。
为什么要节省显存?
这是一个好问题。让我们看看训练过程中显存的消耗情况:
- 权重、梯度和优化器状态都取决于 GPU 数量,随着 GPU 数量的增加,显存消耗趋近于零;
- 缓冲区只消耗固定的显存;
- 激活值取决于模型大小和每个进程的 token 数量;
事实证明,经过 ZeRO-3 和 FSDP 对优化器、梯度、权重进行切分后,随着 GPU 数量增加,激活值变成主要占用显存的部分。这没有错!对于 Llama 2 70B,当批量为 8192 个 token 且使用 Flash Attention 2 时,激活存储占用超过 110 GB (该数字可以显著减小,但这是另一个故事)。
激活值检查点能够大幅减少显存负载:对于前向传播,我们只存储 Transformers 层之间的激活值,对于反向传播,我们重新计算它们。这节省了大量显存:只需 5 GB 来存储激活值。问题是额外的冗余计算占整个训练时间的 25%。
这就是为什么释放显存以避免尽可能多的层使用激活值检查点是有意义的。
另外,如果你有一些空闲显存,某些通信的效率可以得到提高。
缓冲区
与 FSDP 一样,我们决定分片层而不是单个参数——这样,我们可以保持高效通信并避免重复操作。为了控制显存消耗,我们预先为所有所需数据分配了缓冲区,因为我们不希望 PyTorch 分配器管理该过程。
它是这样工作的: 分配两个缓冲区来存储中间权重和梯度。每个奇数层使用第一个缓冲区,每个偶数层使用第二个缓冲区 。
这样,来自不同层的权重存储在同一内存中。如果层具有相同的结构,它们将始终相同!重要的是要确保在需要层 X 时,缓冲区中有层 X 的权重。所有参数将存储在缓冲区对应内存块中:
除此之外,新方法与 FSDP 类似。我们需要以下内容:
- 缓冲区以 fp32 存储用于优化器的分片和梯度(由于混合精度)。
- 缓冲区以半精度(在我们这里是 bf16)存储权重分片。
现在我们需要设置通信:以便:
- 在层上的前向/反向传播开始之前,该层的权重被收集在其缓冲区中。
- 在层上的前向/反向传播完成之前,我们不会在该层的缓冲区中收集另一层。
- 在前一层完成使用相同梯度缓冲区的 reduce_scatter 操作之前,不会开始该层的反向传播。
- 在相应层的反向传播完成之前,不会在该缓冲区中开始 reduce_scatter 操作。
我们如何实现这种设置?
使用流
你可以使用 CUDA 流来促进并发计算和通信。
在 PyTorch 和其他框架中,CPU 和 GPU 之间的交互是如何组织的?内核(在 GPU 上执行的函数)按执行顺序从 CPU 加载到 GPU。为了避免 CPU 空闲,内核会提前加载并异步执行。在单个流中,内核始终按照加载到 GPU 的顺序执行。如果我们希望它们并行运行,我们需要将它们加载到不同的流中。请注意,如果不同流中的内核使用相同的资源,它们可能无法并行运行(还记得上面提到的"避让效应"吗),或者它们的执行可能会非常缓慢。
为了促进流之间的通信,你可以使用"事件"(event)原语(在 PyTorch 中 event = torch.cuda.Event())。我们可以将事件放入流中(event.record(stream)),然后它将作为微内核附加到流的末尾。我们可以在另一个流中等待该事件(event.wait(another_stream)),然后该流将暂停,直到第一个流到达该事件。
我们只需要两个流来实现这一点:一个计算流和一个通信流。以下是如何设置执行以确保满足上述条件 1 和 2:
在图中,粗线标记 event.record(),虚线用于 event.wait()。如你所见,第三层的前向传播不会开始,直到该层的 all_gather 操作完成(条件 1)。同样,第三层的 all_gather 操作不会开始,直到使用相同缓冲区的第一层的前向传播完成(条件 2)。由于该方案中没有循环,所以不可能发生死锁。
我们如何在 PyTorch 中实现这一点?您可以使用 forward_pre_hook(在前向传播之前在 CPU 上执行的代码)以及 forward_hook(在前向传播后执行):
这样,所有前置操作都在 forward_pre_hook 中执行。 有关 hook 的更多信息,请参阅 PyTorch 官方文档:
对于反向传播有什么不同?在这里,我们需要在进程之间平均梯度:
我们可以尝试像使用 forward_hook 和 forward_pre_hook 一样使用backward_hook 和 backward_pre_hook:
但是有一个陷阱:虽然 backward_pre_hook 的工作方式如预期,但 backward_hook 的行为可能出乎意料:
- 如果模块输入张量至少有一个不传递梯度的张量(例如注意力掩码),则 backward_hook 将在执行反向传播之前运行;
- 即使所有模块输入张量都传递梯度,也不能保证 backward_hook 会在计算所有张量的 .grad 之后运行;
因此,我们对 backward_hook 的最初实现不满意,需要一个更可靠的解决方案。
可靠的 backward_hook
为什么 backward_hook 不合适?让我们看看相对简单操作的梯度计算图:
我们对输入应用两个独立的线性层(Weight 1 和 Weight 2),并将它们的输出相乘。
梯度计算图将如下所示:
我们可以看到,在这个图中所有操作都有自己的*Backward 节点。对于图中的所有权重,有一个 GradAccum 节点,在那里参数的 .grad 被更新。该参数将由 YaFSDP 用于处理梯度。
需要注意的是,GradAccum 位于此图的叶节点。有趣的是 PyTorch 不保证图遍历的顺序。其中一个权重的 GradAccum 可能会在梯度离开该块之后执行。PyTorch 中的图执行是不确定的,并且可能在每次迭代时发生变化。
那么我们如何确保在另一层的反向传播开始之前计算出权重梯度?如果我们在没有确保满足此条件的情况下启动 reduce_scatter,它只会处理已计算梯度的一部分。在尝试寻找解决方案的过程中,我们想到了以下方案:
在每次前向传播之前,执行以下额外步骤:
-
我们将所有输入和权重缓冲区通过 GateGradFlow,这是一个基本的 torch.autograd.Function,只是将未改变的输入和梯度传递;
-
在层中,我们用权重缓冲区内存中存储的伪参数替换参数。为此,我们使用自定义的 Narrow 函数;
在反向传播过程中会发生什么:
参数的梯度可以通过两种方式分配:
- 通常,我们会在 Narrow 的实现过程中分配或添加梯度,这比我们到达缓冲区的 GradAccum 要早得多;
- 我们可以为层编写自定义函数,在其中我们将在不分配额外张量的情况下分配梯度以节省内存。然后 Narrow 将收到"None"而不是梯度,并且不执行任何操作;
通过这种方式,我们可以保证:
- 在执行 backward GateGradFlow 之前,所有梯度都将被写入梯度缓冲区。
- 在执行 backward GateGradFlow 之前,梯度不会流向输入,然后流向"backward"的下一层;
这意味着最合适的位置来调用 backward_hook 是在 backward GateGradFlow 中!在那一步,所有权重梯度已经被计算和写入,而其他层的反向传播还没有开始。现在我们已经具备了在反向传播中实现并发通信和计算所需的一切。
克服"避让效应"
"避让效应"问题在于,在 reduce_scatter 之前,通信流中会发生一些计算操作。这些操作包括将梯度复制到不同的缓冲区、防止 fp16 溢出的梯度"预除"(现在很少使用)等。
这是我们做的:
- 我们为 RMSNorm/LayerNorm 添加了单独的处理。由于这些在优化器中应该被稍微不同地处理,所以将它们分为一个单独的组是有意义的。这样的权重不多,所以我们在迭代开始时一次收集,并在最后平均梯度。这消除了"避让效应"中的重复操作;
- 由于在 bf16 或 fp32 中进行 reduce_scatter 没有溢出风险,我们将"预除"替换为"后除",将该操作移至反向传播的最后;
结果,我们消除了"避让效应",这大大减少了计算中的停顿时间:
限制
YaFSDP 方法优化了显存消耗,并允许显著提高性能。然而,它也有一些限制:
- 只有在以交替方式调用层时,才能达到峰值性能,即它们对应的缓冲区交替使用;
- 我们明确考虑了,从优化器的角度来看,只能有一大组具有大量参数的权重;
测试结果
以下是 YaFSDP 在 Llama 2 和 Llama 3 上相较于 FSDP 所达到的加速:
在小批量场景下,所实现的加速超过 20%,使 YaFSDP 成为微调模型的有用工具。
在 Yandex 的预训练中,YaFSDP 的实现以及其他显存优化策略带来了 45% 的加速。
现在 YaFSDP 已开源,您可以查看并告诉我们您的想法!请分享您的使用体验,如果有可能的 Pull Request,我们也乐意考虑。
作者简介
作者: 米哈伊尔·赫鲁晓夫,Yandex GPT 预训练团队负责人。 Yandex 是一家构建基于机器学习的智能产品和服务的科技公司, 目标是帮助消费者和企业更好地探索在线和离线世界。 自 1997 年以来,一直在提供世界级的、本地化相关的搜索和信息服务。 此外,还为全球数百万消费者开发了领先市场的按需交通服务、导航产品和其他移动应用程序。 在全球拥有 34 个办事处的 Yandex 自 2011 年起就已在纳斯达克上市。