清华再放大招 Stuffed Mamba | 基于RNN的长上下文建模中的状态崩溃与状态容量,实现近完美的 Key 检索 !

大模型机器学习数据库

点击下方卡片,关注 「AI视界引擎」 公众号

( 添加时备注:方向+学校/公司+昵称/姓名 )

picture.image

picture.image

循环神经网络(RNNs)相对于基于 Transformer 的语言模型,在序列长度方面具有线性的计算复杂度,这使得它们在推理过程中处理长序列的速度更快。

然而,大多数公开可用的RNNs(如Mamba和RWKV)都是在小于10K Token 的序列上进行训练,其在更长上下文中的有效性至今仍然不令人满意。

在本文中,作者研究了RNNs无法处理长上下文的原因,并提出了一些关键的缓解措施。

当将最先进的RNNs应用于长上下文时,有两个实际问题需要关注:

(1)无法将模型扩展到训练长度之外的输入;

(2)记忆容量上限。为了解决第一个问题,作者首先研究了状态崩溃(SC),这是一种在训练期间未遇到过的序列长度上的性能降低现象。通过受控实验,作者将其归因于由于循环状态对训练长度过于复杂的参数化导致的过拟合。

对于第二个问题,作者在长文档上训练一系列Mamba-2模型,以实证估计语言模型和 Key 检索的循环状态容量。

然后,作者提出了三种SC缓解方法,以提高Mamba-2的长度泛化能力,使模型在不发生SC的情况下处理超过1M Token 。

作者还发现, Key 检索中的循环状态容量呈指数增长到状态大小,作者在256K上下文长度上使用近完美的 Key 检索准确率实证训练了一个Mamba-2 370M模型。

1 Introduction

最近基于 Transformer 的语言模型在处理具有数千甚至数百万 Token 的长期序列时,展示了令人印象深刻的推理能力。

然而,它们依赖于注意力机制,该机制在序列长度方面扩展为平方,使得在长期序列上的推理极其昂贵。相比之下,循环神经网络(RNNs)(Bengio等人,1994年)具有恒定的状态大小。因此,在推理过程中,它们的每个 Token 的计算和内存复杂度都线性扩展为序列长度,使得它们在处理长期序列方面更加高效。

尽管RNN在处理长上下文方面具有有前途的未来,但它们的长期上下文性能远不令人满意。目前最先进(SOTA)的基于RNN的语言模型(以下简称为 RNNs以简化表达),如Mamba-1 ,Mamba-2 ,RWKV系列 和GLA 都是在小于10K Token 的序列上进行训练的。

现有研究表明,Mamba-1和RWKV-4在上下文长度超过它们的训练长度时,性能会出现严重的下降。

在本文中,作者研究了导致当前RNNs无法处理长序列的问题以及支持长序列的可能解决方案。在应用RNNs处理更长序列时,作者观察到两个关键问题。

(1)RNNs在序列长度上无法泛化。当上下文长度超过训练长度时,它们表现出异常行为,导致长序列性能较差。

(2)由于其内存大小为常数,尽管它们可以处理无限长的输入,但状态所能表示的信息量存在上限。因此,上下文记忆容量(可以记住的最大 Token 数)存在上限,超过该限制的 Token 将被遗忘。

然后作者更深入地研究了上述问题的形成。首先,作者将SOTA RNNs的长度泛化失败归因于作者称之为状态崩溃(SC)的现象。作者检查了随时间变化的记忆状态分布,并发现其崩溃是由一些主导的异常通道(爆炸值)引起的。当输出隐层表示规范化时,这些异常通道会导致其他通道的消失值。通过分析状态更新规则的各种组件,作者表明SC是由在可以记住的信息比它多时无法忘记最早 Token (通过用较小的乘数衰减状态)引起的。

在这些分析的基础上,作者提出三种无训练技术来减轻塌陷,以及一种基于在更长序列上持续训练的缓解方法。作者的方法依赖于通过减少记忆保留和插入强度、规范化重复状态或将重复性重新格式化为等效的滑动窗口状态来迫使模型忘记上下文信息。

实证结果显示,在Mamba-2上,作者无需训练的SC缓解方法使得模型在没有SC的情况下可以消耗超过1M tokens。通过在超过模型状态容量的长序列上进一步训练,作者实证验证了对于给定的状态大小,存在一个训练长度阈值,超过该阈值后,模型将不再表现出SC。

这一洞察让作者建立了状态容量和状态大小之间的关系。然后,通过训练不同大小的Mamba-2模型,作者发现状态容量是状态大小的线性函数。此外,作者在广泛使用的 Key 检索任务上进行相同的实验,并发现Mamba-2在接近完美的 Key 检索准确性下的长度是状态大小的指数函数。

实验结果得到了一个Mamba-2 370M模型,该模型在256K上下文长度上可以实现接近完美的 Key 检索准确性,明显优于相同大小的基于 Transformer 的模型在检索准确性和长度泛化方面的表现。作者的结果表明,基于RNN的模型通常使用的训练长度可能是次优的,基于RNN的长上下文建模具有巨大的潜力。

作者的贡献可以总结如下。

  1. 作者进行了首次系统性的研究,关于状态崩溃现象,该现象导致RNNs的长度泛化失败。
  2. 作者提出了三种无训练的缓解方法以及一种基于持续训练的方法,以提高RNNs的长度泛化能力。
  3. 通过分析隐表示,作者将状态崩溃归因于状态过拟合,从而建立了与状态容量之间的联系。基于这一分析,作者在状态大小为函数的情况下,实证估计Mamba-2的状态容量。
  4. 作者训练并发布了一个基于RNN的语言模型,在密码检索任务上的准确率达到近完美。该模型有3.7亿参数,是当时长度为256K的模型中,在密码检索任务上准确率最高的模型。
  5. 模型预训练权重和源代码已发布至https://www.github.com/thunlp/stuffed-mamba。

2 Related Works

基于RNN的语言模型最近受到了广泛关注,因为与基于 Transformer 的模型不同,它们的每个 Token 推理成本不会随着序列长度的增加而增加。线性注意力用核函数近似替换了 Transformer 模型中的softmax注意力,这些近似具有等价的循环形式。一些最近出现的RNN包括RWKV系列 ,Mamba系列,Gated Linear Attention等。这些模型在许多语言处理任务上表现出了强大的能力,有时甚至超过了基于 Transformer 的模型。然而,正如作者将通过实证方式展示的那样,其中一些模型在训练长度之外很难进行有效的外推。

一些基于 Transformer 的模型采用了滑动窗口注意力(Beltagy等人,2020;Jiang等人,2023),这实际上使它们变成了RNNs。然而,这些模型在长上下文任务上的表现不佳,无法扩展到非常长的上下文(Zhang等人,2024)。

近年来,大多数基于 Transformer 架构的AI语言模型都采用了某些位置编码的变体。这些模型在使用某些位置编码变体时,可以处理任意长的序列。然而,在训练序列之外的位置,它们表现出严重的性能下降。

为了解决这个问题,许多研究都关注于修改位置编码,有些甚至实现了训练无关的长度泛化。类似地,本研究也探讨了增强基于RNN的模型长度泛化能力的一些后训练修改方法。

一些并行工作探讨了通过控制离散化项(方程3中的)来扩展Mamba的上下文长度(Ben-Kish等人,2024年)。

例如,将其除以一个常数以使其变小1。这实际上使记忆衰减因子(方程4中的)更接近1,使状态保留更多的上下文信息。然而,这也无必要地减小了所有 Token 插入的信息。

3 Preliminary

该研究中大多数实验关注Mamba-2(Dao和Gu,2024年),因为它在多个任务上显示出强大的能力,并且有多大小不一的公开预训练权重,使作者能够探索状态大小和长度限制之间的关系。

此外,与其他RNNs相比,它受到的广泛关注更多,因此使用现有工作作为参考更为容易。

Mamba-2架构由L个Mamba-2层堆叠而成,每个Mamba-2层包含H个并行计算的head,该层的输出为每个head输出的和。每个层中的每个head可以表示为以下形式。

picture.image

在此处, 表示当前时间步,, 分别表示第 个 Token 的输入和输出隐层表示, 表示 RMS 归一化(Zhang 和 Sennrich,2019),,, 是可训练参数, 表示逐元素乘积, 是超参数,分别表示隐层维度、状态维度和头维度。其余变量按以下方式参数化:

(5) (6) (7) (8) (9)picture.image

其中 是可训练模型参数。CNN(卷积神经网络)表示一维的卷积层。σ 表示SiLU函数(Elfwing等人,2017)。

4 State Collapse

作者首先研究了状态崩溃(SC)现象——该现象导致RNN模型在输入长度超过训练过程中见过的长度时表现出异常行为。作者分析了SC对语言建模和 Key 检索任务的影响。

然后作者追踪SC到状态更新规则的组成部分,并提供了关于训练长度过拟合的解释。

最后,作者提出三种训练免费的缓解方法,通过修改更新规则和一种基于在更长序列上持续预训练的方法来避免过拟合。

Length Generalization Failure

语言模型损失图如图1所示,展示了Mamba-2和RWKV-6在训练长度之后的语言模型损失。为了实现可控制性和合成任意长度的 Prompt ,该损失在仅包含"r"字符(作者称之为"换行" Prompt )的 Prompt 上计算。然而,作者强调,在处理预训练语料库中的文本时,同样存在相同的观察结果。

结果表明,当上下文长度远大于它们的训练长度时,两种RNN都遭受了巨大的性能退化,收敛到随机猜测的损失附近。

picture.imageKey 检索评估语言模型可能不能反映下游能力,因此,作者在 Key 检索任务上评估了几种强大的RNN ,这是一个简单的合成任务,其中模型被 Prompt 从一个较长的上下文中回忆一个5位数的 Key 。其他RNN的超参数和结果可以在附录B和D中找到。Mamba-2的结果如图2所示。作者发现Mamba-2模型在训练长度之后的序列上无法泛化。

例如,Mamba-2在8K上下文窗口上训练,在8K上下文内的检索精度接近完美(除了较小的130M预训练权重),但在16K上下文之后,甚至无法获得良好的或零的检索精度,而模型大小无关。

picture.image这一行为是意外的,因为更新规则(方程3)具有稳定的指数衰减(如果变量固定,它将收敛到一个常数值)。因此,作者预计这种形式的RNN在最后个 Token 上的检索精度应该很好,而早于此的 Token 会被遗忘。

这一意外发现还暗示着,在处理的长度超过训练长度的上下文时,最好只保留最后个 Token ,并丢弃其余内容。然而,在在线推理场景中,这并不容易实现,因为所有 Token 信息都压缩到一个状态中。

What is the Cause of State Collapse?

由于递归状态的维数在时间上不会改变,状态塌陷期间行为的剧烈变化必须是由状态值的变化所导致。作者在Mamba-2 370M中的每个层的递归状态统计数据中发现,当上下文长度超过训练长度时,某些 Head 的均值和方差会急剧变化,如图4所示。

一个具有爆炸方差的 Head 在t=20K时的状态如图5所示。从图中,作者发现这种方差爆炸可以很大程度上归因于少数异常通道,而大多数通道相对稳定。

picture.image作者强调SC(状态崩溃)的发生很大程度上与 Prompt 无关,它在预训练数据样本和无意义的文本中都会发生,即使 Prompt 仅包含空白字符(图1(a)和(b)显示了SC在两个不同 Prompt 上的情况)。这意味着插入到状态中的信息不会导致其崩溃。

为了进一步归因于导致SC的特定变量,作者检查了在各种塌陷状态下的,和的值。图6报告了检查的一个示例,作者可以看到,与和相比,相对稳定,尽管它们都是的函数。作者还发现,比更早爆炸。因此,作者得出结论,塌陷主要归因于。进一步的检查发现,生成和(式8和10)的卷积权重在方差上明显大于(式9)的权重。作者留待将来进行更深入的归因研究。

picture.image#### 4.2.1 State Collapse as a Result of State Overparameterization

在这里,作者为SC提出一个高层次的解释。作者主张,SC源于状态过拟合相对于训练长度。换句话说,训练长度的状态容量过大,使得模型在没有学习如何忘记状态即将溢出时,能够实现强大的语言建模性能。为了支持这一论点,作者将隐藏状态表示为以前插入信息的加权和:

picture.image

因此,描述了在时刻第个 Token 的记忆强度。图7显示了不同时间步长下第一个 Token 的记忆强度,作者发现爆炸头(第38层中的2、4和7个头)有强烈的倾向于保留训练长度内的所有信息,在=8K时的记忆强度超过0.8。这意味着模型没有学会通过产生较小的衰减来忘记信息(以避免状态过载)。

此外,作者从零开始预训练Mamba-2,训练周期为,并在通过 Key 检索的中间预训练权重上评估,如图8所示。它表明,SC仅在训练达到一定量后出现,这与过拟合的行为一致——这是由于超参数化导致的。

人们还可以注意到,过拟合的预训练权重在较短序列上的性能优于早期预训练权重,这进一步增强了模型收敛到更少遗忘的假设。最后,如作者在第4.3.2节中将要展示的,对于给定的训练长度,存在一个状态大小,在该状态下,SC将出现,而且只有当模型的状态大小大于该值时,才会出现。

picture.image### How to Mitigate State Collapse?

根据前文的分析,作者提出了一些序列长度的SC缓解方法,以使模型在序列长度上更好地泛化。简而言之,作者通过修改更新规则避免了状态溢出,提出了三种无训练方法。

此外,作者还直接在更长的序列上进行训练,以鼓励模型在上下文过长时平稳地忘记最早的信息。

4.3.1 Training-Free Mitigation Methods

方法1:忘记更多,记住更少在SC(短期记忆)期间,状态的方差会爆炸。作者可以通过增加状态衰减(即忘记更多)或减少插入的信息(即记住更少)来降低这一现象。

根据第4.2节中的分析,作者选择干预组件和,分别控制插入强度和记忆衰减强度。现有研究已经尝试过修改,但它同时控制插入和衰减强度,使得分析和控制变得困难。

方法2:状态规范化 主要思想是,在每个更新后对状态进行规范化,以确保状态的范数始终小于一个阈值。具体而言,作者在每个时间步将状态衰减,以确保。因此,作者得到以下更新规则。

picture.image

值得提及的是,这使得模型变成了非线性RNN,并且无法像原始模型那样以相同的方式并行化,这使得它在预填充阶段变得非常慢。

方法3:状态差滑动窗口

作者可以利用状态 可以表示为加权求和(等式11)的事实,来实现一个滑动窗口机制,而无需在每个步骤从头开始重新处理。用 表示窗口大小, 表示在时间步 上应用模型在最后 个 Token 处的隐藏状态。然后作者可以像计算两个状态的差一样精确地计算 :

在 Stream 生成过程中,作者只需要维护3,并逐个并行推进它们。然而,直接计算可能会因为浮点数精度问题而出现不稳定性。因此,作者维护,并在每一步重新计算,这样既保证了计算效率,又使得翻译后的内容忠实于原文。

此方法适用于所有可以表示为加权和的形式的RNNs,包括RWKV 5和6、RetNet、GLA等。在生成过程中,它将计算和内存成本加倍,但作者认为这是可以接受的权衡,因为与基于 Transformer 的模型相比,RNNs的生成成本非常低,且上下文处理成本不变。

4.3.2 Training on Longer Sequences

根据假设,SC(状态超参数化)是由状态超参数化(见第4.2.1节)引起的,作者可以简单地在超过状态容量的长度上进行训练,这在本文中作者将进行。

为了确保数据包含尽可能多的长期结构,作者过滤掉了少于4K个 Token 的序列。Buckman和Gelada(2024)指出,这对于训练有效的长语义模型至关重要。尽管作者在训练的序列中使用的长度大于4K个 Token ,但作者没有使用更高的长度阈值,因为上述阈值已经删除了原始语料库中大约97.6%的数据。为了在更长序列上进行训练,作者只需将序列连接起来,并用特殊的EOS(结束 Token ) Token 分隔。

截断反向传播通过时间 在原始的Mamba-2中,每个数据样本的状态都初始化为零。而作者现在将状态设置为前一个序列的最终状态。这相当于将多个序列连接在一起,但只在某些间隔处停止梯度反向传播。这种技术已被证明可以帮助延长RNNs(Yang等人,2023年)的上下文长度,并减轻计算梯度所需的缓存激活的记忆成本。根据Yang等人(2023年)和作者的初步测试,作者默认使用这种技术将12个序列连接在一起。

5 State Capacity

根据第4.2.1节讨论的结果,SC仅当训练长度包含的信息少于状态容量时发生。因此,作者可以通过研究不同状态大小下不同训练长度的关系,间接估计状态容量。在本节中,作者实证研究了状态容量与状态大小之间的关系。

具体而言,作者在第4.3.2节中进行的训练相同。为了确定一个状态是否已崩溃,作者将"newlines" Prompt 喂给模型1M个 Token ,并定义崩溃为在训练长度内对数失真超过2倍的最大对数失真。作者训练了多个具有不同状态大小和训练长度的Mamba-2,并将SC不发生时的最小训练长度视为状态容量。

State Capacity in Passkey Retrieval

语言建模性能可能无法很好地反映下游能力(Fu等人,2024年)。因此,作者也研究了在 Key 检索任务上的状态容量。与前面的部分类似,作者使用不同长度的状态大小进行训练,并确定模型在95%以上准确率的最大上下文长度,作者认为这是在 Key 检索中的_状态容量。在这个任务中,噪声上下文是重复的,因此上下文信息量很大程度上独立于上下文长度,因此,容量应大致呈指数增长与状态大小。

值得强调的是,如果作者训练Mamba-2在 Key 检索数据上,理论上模型可以通过忽略所有无关 Token 来处理无穷长的上下文。在这里,模型只通过预测下一个 Token 进行训练,这意味着模型不会忽略无关上下文,而长时间保持信息的能力来源于语言建模。

6 Experiments

作者简要描述了长度外推实验的实验细节。

作者从RedPajama-V2(Computer,2023)开始,这是一个来自CommonCrawl4的30T Token 的开放数据集,作者进行去重以保证数据质量。在评估过程中,作者采样长度超过16K Token 的文档并连接它们,如果不够长。

作者尝试了七个具有不同状态大小的模型配置,以找到状态容量和大小之间的关系。对于其中的每一个,作者都进行了256K个 Token 长度的广泛搜索。为了节省成本,作者从Mamba-2的三个官方预训练权重(大小分别为130M、370M和780M)中继续预训练。这些模型都在8K个序列上进行训练。其余三个模型配置(36M、47M和85M)则从零开始训练。详细的配置见附录F.1。

作者使用 WSDAE(Hu等人,2024年)的 10% 衰减步骤。这种调度器被选择,因为它在竞争中具有与常见余弦调度器相媲美的性能,同时允许从中间预训练权重简单恢复,从而节省大量计算资源。作者在验证通过 Key 检索的过拟合测试中报告了最佳预训练权重选择的验证结果。更多信息请参阅附录F中的更多超参数。

7 Results

Training-Free Length Generalization

图9报告了在Mamba-2 780M上的训练无长度泛化方法的结果。作者可以看到,尽管LongMamba5可以通过比增加3倍以上的长度泛化能力来显著改进模型的长度泛化,但在较短的序列上会导致明显的更大对数熵,并且仍然不可避免地出现SC。作者的所有方法都成功地抑制了SC,使得模型可以泛化到超过64K个 Token ,尽管状态归一化在较短序列上的性能大大低于其他方法。这种低性能的一个解释是,归一化塌陷状态改变了头之间的规范比例,这破坏了学习的机制。

picture.image### Length Generalization by Training on Longer Sequences

在图10中,作者绘制了Mamba-2 130M和370M在不同训练长度下的语言建模困惑度。作者可以看到,对于每个模型大小,都有一个训练长度阈值,超过该阈值后,模型的长度外推效果显著改善,这支持了作者在第4.3.2节中讨论的论点。

picture.image### State Capacity as a Function of State Size

图12显示了Mamba-2在语言建模和 Key 检索任务上的状态容量。两幅图中最右侧的数据点对应于Mamba-2的370M。作者已经确认,780M在训练长度低于128K时也表现出SC,但由于资源有限,作者无法将其训练超过这个长度。结果建立了训练停止时SC不出现的长度Ttrain与状态大小S之间的线性关系。

picture.image图12的第二个图显示,Mamba-2在 Key 检索任务上的容量是关于状态大小呈指数增长的。这是因为上下文中的信息量并不随其长度增加。换句话说,作者存储的信息量是恒定的,而状态的组合数量则随着元素数量的指数增长。图11显示了Mamba-2在370M参数下的最佳预训练权重在 Key 检索任务上的表现。

结果非常令人鼓舞,因为作者所知,没有任何具有少于10亿模型参数的先前模型在128K Token 的任务上具有近完美的准确度。

picture.image8 Conclusion

这篇论文首次系统地研究了状态崩溃(SC),这是一种在RNN中导致长度泛化失败的的现象。

通过检查激活和Mamba-2的受控实验,作者得出结论,这种现象是由过度参数化的状态和过度的状态容量引起的。基于分析,作者提出了三种不需要训练的方法来减少SC至1M Token 。

然后,作者证明,通过在超过状态容量的上下文长度上进行训练,可以减轻SC。

有了这个洞见,作者在语言建模和 Key 检索任务上实证地估计了Mamba-2的状态容量。通过一些简单的数据工程和状态初始化技巧,作者在 Key 检索任务上比现有模型实现了更好的性能。

作者的结果表明,Mamba-2不仅在处理长序列方面非常高效,而且具有巨大的性能潜力。

参考文献

[0]. Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling.

点击上方卡片,关注 「AI视界引擎」 公众号

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论