Abstract
Large language models (LLMs) employ autoregressive decoding, which requires sequential computation: each step depends on the output of the previous one. This creates a bottleneck, because each step requires moving the full model parameters from high-bandwidth memory (HBM) into the accelerator's cache. While methods such as speculative decoding have been proposed to address this issue, their adoption is hindered by the challenges of obtaining and maintaining a separate draft model.
In this paper, we present Medusa, an efficient method that enhances LLM inference by adding extra decoding heads that predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, Medusa constructs multiple candidate continuations and verifies them simultaneously in each decoding step. By exploiting parallel processing, Medusa substantially reduces the number of decoding steps required.
We provide two levels of fine-tuning procedures for Medusa to meet the needs of different use cases. Medusa-1: Medusa is fine-tuned directly on top of a frozen backbone LLM, enabling lossless inference acceleration. Medusa-2: Medusa is fine-tuned together with the backbone LLM, enabling better prediction accuracy of the Medusa heads and a higher speedup, but requiring a special training recipe that preserves the model's capabilities. In addition, we propose several extensions that improve or broaden the utility of Medusa, including a self-distillation procedure for handling situations where no training data is available and a typical acceptance scheme that raises the acceptance rate while maintaining generation quality.
We evaluate Medusa on models of various sizes and training procedures. Our experiments show that Medusa-1 achieves more than 2.2× speedup without compromising generation quality, while Medusa-2 further improves the speedup to 2.3-2.8×.
Introduction
Recent advances in large language models (LLMs) have shown that the quality of language generation improves significantly with model size, now reaching tens of billions of parameters. However, this growth has increased inference latency, which poses a significant challenge in practical applications. From a systems perspective, LLM inference is largely memory-bandwidth-bound: the main latency bottleneck comes from the accelerator's memory bandwidth rather than from arithmetic computation. This bottleneck is inherent to the sequential nature of autoregressive decoding, in which every forward pass requires transferring the full model parameters from high-bandwidth memory (HBM) to the accelerator's cache. Each such pass produces only a single token, underutilizing the arithmetic capability of modern accelerators and resulting in inefficiency.
To address this, one approach to accelerating LLM inference is to increase the arithmetic intensity of the decoding process (the ratio of total floating-point operations (FLOPs) to total data movement) and to reduce the number of decoding steps. Following this idea, speculative decoding has been proposed. This method uses a smaller draft model to generate a sequence of tokens, which the original larger model then refines into an acceptable continuation. However, obtaining a suitable draft model remains challenging, and integrating a draft model into a distributed system is even harder.
In this paper, instead of using a separate draft model to sequentially generate candidate outputs, we revisit and refine the concept of using multiple decoding heads on top of the backbone model to accelerate inference. We find that, when applied effectively, this technique overcomes the challenges of speculative decoding and allows seamless integration into existing LLM systems. Concretely, we introduce Medusa, a method that enhances LLM inference by integrating additional decoding heads that concurrently predict multiple tokens. These heads are fine-tuned in a parameter-efficient manner and can be added to any existing model. Since no draft model is required, Medusa can be integrated easily into current LLM systems, including those in distributed environments, ensuring a user-friendly experience.
We further strengthen Medusa with two key insights. First, the current approach of generating a single candidate continuation at each decoding step makes inefficient use of computational resources. To address this, we propose using Medusa to generate multiple candidate continuations and verifying them simultaneously through a simple adjustment to the attention mask. Second, we can reuse the rejection sampling scheme from speculative decoding to produce responses with the same distribution as the original model, but this does not further improve the speedup. Alternatively, we introduce a typical acceptance scheme that selects plausible candidates from the Medusa heads' outputs. We use the temperature as a threshold to manage deviation from the original model's predictions, providing an efficient alternative to rejection sampling. Our results show that the proposed typical acceptance scheme further speeds up decoding while maintaining similar generation quality.
To equip LLMs with predictive Medusa heads, we propose two distinct fine-tuning procedures tailored to different scenarios. For cases with limited computational resources, or when the goal is to incorporate Medusa into an existing model without affecting its performance, we recommend Medusa-1. This approach requires minimal memory and can be optimized further with quantization techniques similar to QLoRA, without compromising generation quality, because the backbone model is fixed. However, Medusa-1 does not fully exploit the backbone model's potential. We can fine-tune the backbone further to improve the prediction accuracy of the Medusa heads, which translates directly into greater speedup. We therefore introduce Medusa-2, which suits scenarios with ample computational resources or direct supervised fine-tuning (SFT) from a base model. The key to Medusa-2 is a training protocol that jointly trains the Medusa heads and the backbone model without hurting the model's next-token prediction capability and output quality. We propose different strategies for obtaining a training dataset depending on the model's training recipe and dataset availability. When the model was fine-tuned on a public dataset, that dataset can be used directly for Medusa. If the dataset is unavailable, or the model went through a reinforcement learning from human feedback (RLHF) process, we suggest a self-distillation approach to generate a training dataset for the Medusa heads.
Our experiments focus mainly on the batch-size-1 scenario, which is representative of LLMs hosted locally for personal use. We test Medusa on models of different sizes and training settings, including Vicuna-7B and 13B (trained on a public dataset), Vicuna-33B (trained on a private dataset), and Zephyr-7B (trained with supervised fine-tuning and alignment). Medusa achieves a 2.3 to 2.8× speedup across different prompt types without compromising generation quality.
Methodology
Medusa follows the same framework as speculative decoding, where each decoding step consists primarily of three substeps: (1) generating candidates, (2) processing candidates, and (3) accepting candidates.
For Medusa, (1) is achieved by the Medusa heads and (2) by tree attention; because the Medusa heads sit on top of the original model, the logits computed in (2) can be reused for substep (1) of the next decoding step. The final substep (3) can be performed with either rejection sampling or typical acceptance (Section 2.3.1). The full pipeline is illustrated in Figure 1.
Figure 1: Medusa introduces multiple heads on top of the last hidden states of the LLM, enabling the prediction of several subsequent tokens in parallel (Section 2.1.1). During inference, each head generates multiple top predictions for its designated position. These predictions are assembled into candidates, which are processed in parallel using a tree-based attention mechanism (Section 2.1.2). The final step is to verify the candidates and accept a continuation. Besides the standard rejection sampling scheme, a typical acceptance scheme (Section 2.3.1) can also be used here to select reasonable continuations, and the longest accepted candidate prefix will be used for the next decoding phase.
In this section, we first introduce the key components of Medusa, namely the Medusa heads and tree attention. We then present two levels of fine-tuning procedures for Medusa to meet the needs of different use cases. Finally, we propose two extensions to Medusa, self-distillation and typical acceptance, which respectively handle the situation where no training data is available for Medusa and improve the efficiency of the decoding process.
Key Components
Medusa Heads
In speculative decoding, subsequent tokens are predicted by an auxiliary draft model. The draft model must be small yet effective enough to produce continuations the original model will accept. Fulfilling these requirements is challenging, and existing approaches often resort to separately pre-training a smaller model. This pre-training demands substantial extra computation; for example, (Miao et al., 2023) report using 275 NVIDIA A100 GPU hours. Moreover, separate pre-training can create a distribution shift between the draft model and the original model, producing continuations the original model may disfavor. Chen et al. also highlight the complexity of serving multiple models in a distributed environment.
To simplify and democratize the acceleration of LLM inference, we take inspiration from Stern et al., who leveraged parallel decoding for tasks such as machine translation and image super-resolution. Medusa heads are additional decoding heads attached to the last hidden states of the original model. Specifically, given the last hidden state $h_t$ of the original model at position $t$, we add $K$ decoding heads on top of $h_t$. The $k$-th head is used to predict the token at position $t+k+1$ (the original language model head predicts position $t+1$). We denote the prediction of the $k$-th head by $p_t^{(k)}$, a distribution over the vocabulary, and the prediction of the original model by $p_t^{(0)}$. Following Stern et al., we use a single-layer feed-forward network with a residual connection for each head; we find this simple design sufficient to achieve satisfactory performance. The $k$-th head is defined as:

$$p_t^{(k)} = \mathrm{softmax}\left(W_2^{(k)} \cdot \left(\mathrm{SiLU}\!\left(W_1^{(k)} \cdot h_t\right) + h_t\right)\right), \quad W_1^{(k)} \in \mathbb{R}^{d \times d},\ W_2^{(k)} \in \mathbb{R}^{V \times d}$$

where $d$ is the output dimension of the last hidden layer of the LLM and $V$ is the vocabulary size. We **initialize $W_2^{(k)}$ with the original language model head and $W_1^{(k)}$ to zero, so that the Medusa heads' initial predictions coincide with those of the original model**. The SiLU activation follows the Llama model.
Unlike a draft model, the Medusa heads are trained jointly with the original backbone model, which can either stay frozen during training (Medusa-1) or be trained together (Medusa-2). This makes it possible to fine-tune large models on a single GPU, taking advantage of the powerful representations learned by the base model. It also guarantees that the Medusa heads' distribution aligns with that of the original model, thereby mitigating the distribution-shift problem. Furthermore, since each new head consists of just a single layer akin to the original language model head, Medusa does not add complexity to the serving system design and is friendly to distributed settings. We discuss the training recipe for Medusa heads in Section 2.2.
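The zero-initialization property can be sanity-checked in a few lines: since SiLU(0) = 0, a zero $W_1^{(k)}$ makes the residual block an identity map, so a head whose output matrix is copied from the LM head reproduces the original model's logits exactly. Below is a toy pure-Python sketch; the names and dimensions are illustrative, not from the Medusa codebase.

```python
import math

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def medusa_head_logits(h, W1, W2):
    """Logits of one Medusa head: W2 @ (SiLU(W1 @ h) + h)."""
    z = [silu(sum(w * x for w, x in zip(row, h))) + hi
         for row, hi in zip(W1, h)]
    return [sum(w * x for w, x in zip(row, z)) for row in W2]

h = [1.0, 2.0]                                   # toy last hidden state (d = 2)
W1 = [[0.0, 0.0], [0.0, 0.0]]                    # zero-initialized -> identity block
lm_head = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # toy LM head weights (V = 3)
# With zero W1, the head's logits equal the LM head applied directly to h
assert medusa_head_logits(h, W1, lm_head) == [1.0, 2.0, 3.0]
```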
Tree Attention
Through the Medusa heads, we obtain probabilistic predictions for the subsequent $K+1$ tokens. These predictions enable us to create candidate continuations of length $K+1$. While speculative decoding studies suggest sampling a single continuation as the candidate, leveraging multiple candidates during decoding can increase the expected acceptance length within a decoding step. However, more candidates also raise computational demands. To strike a balance, we employ a tree-structured attention mechanism to process multiple candidates concurrently.
This attention mechanism departs from the traditional causal attention paradigm. Within this framework, only tokens from the same continuation are regarded as historical data. Drawing inspiration from the graph neural network literature on embedding graph structures into attention, we incorporate the tree structure into our attention mask, as illustrated in Figure 2.
Notably, similar ideas have been explored in concurrent works such as Miao et al. and Spector & Re, who follow a bottom-up approach, constructing the tree by merging multiple candidates generated by draft models. In our approach, we take a top-down approach to construct the tree, thanks to the candidate structure produced by the Medusa heads.
**For the $k$-th head, its top-$s_k$ predictions serve as the basis for forming candidates, where $s_k$ is a designated hyperparameter.** The candidates are formed by taking the Cartesian product of the top-$s_k$ predictions of each head. For example, in Figure 2, with $s_1 = 2$ and $s_2 = 3$, each prediction of the first head can be followed by any prediction of the second head. **This results in a tree structure with $s_k$ branches at the $k$-th level (treating the virtual root as level 0; in practice, level 0 is the prediction of the original model's language model head, which can be sampled independently).** In this tree, only a token's predecessors are regarded as its historical context, and our attention mask ensures that attention is applied only to a token's predecessors. By using this mask and setting the position indices of the positional encoding accordingly, we can process multiple candidates simultaneously without expanding the batch size. The cumulative number of new tokens is $\sum_{k=1}^{K} \prod_{i=1}^{k} s_i$.
In this section, we demonstrated the simplest, regular way of constructing the tree structure via the Cartesian product. However, the tree can be built in more sophisticated ways that exploit the unequal accuracies of the different top predictions of different heads. We discuss this in Section 2.3.3.
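As a concrete illustration, the Cartesian-product tree and its attention mask can be enumerated in a few lines of Python. This is a minimal sketch with hypothetical function names; real implementations (including the `generate_medusa_buffers` routine referenced later) also handle position indices and sparse choice lists.

```python
from itertools import product

def build_tree_nodes(s):
    """All tree nodes for per-head top sizes s = [s_1, ..., s_K];
    node (i_1, ..., i_k) denotes the i_j-th top prediction of head j."""
    nodes = []
    for k in range(1, len(s) + 1):
        nodes.extend(product(*(range(n) for n in s[:k])))
    return nodes

def tree_attention_mask(nodes):
    """mask[i][j] = 1 iff node j is an ancestor of node i (or node i itself),
    so each token attends only to its own continuation's prefix."""
    return [[1 if nodes[i][:len(nodes[j])] == nodes[j] else 0
             for j in range(len(nodes))] for i in range(len(nodes))]

nodes = build_tree_nodes([2, 3])   # s_1 = 2, s_2 = 3, as in Figure 2
mask = tree_attention_mask(nodes)  # 8 x 8 mask: 2 + 2*3 tree positions
```

Note that the mask size grows with the cumulative token count above, not with the batch size, which is exactly what lets all candidates share one forward pass.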
Training Strategies
At the most basic level, we can train Medusa heads by freezing the backbone model and fine-tuning only the heads. However, training the backbone together with the Medusa heads can significantly improve the heads' accuracy. Depending on the available computational resources and the specific requirements of the use case, we propose two levels of training strategies for Medusa heads.
In this section, we assume the availability of a training dataset that matches the target model's output distribution. This could be the dataset used for supervised fine-tuning (SFT) of the target model. In Section 2.3.2 we discuss a self-distillation approach that eliminates the need for such a dataset.
Medusa-1: Frozen Backbone
To train Medusa heads on top of a frozen backbone model, we can use the cross-entropy loss between the Medusa heads' predictions and the ground truth. Specifically, given that the ground-truth token at position $t+k+1$ is $y_{t+k+1}$, the loss for the $k$-th head is

$$L_k = -\log p_t^{(k)}(y_{t+k+1})$$

where $p_t^{(k)}(y_{t+k+1})$ denotes the probability the $k$-th head assigns to token $y_{t+k+1}$. We also observe that $L_k$ grows with $k$, which is reasonable: the prediction of the $k$-th head becomes more uncertain as $k$ increases. We can therefore attach a weight $\lambda_k$ to $L_k$ to balance the losses of the different heads. The total Medusa loss is:

$$L_{\text{Medusa-1}} = \sum_{k=1}^{K} -\lambda_k \log p_t^{(k)}(y_{t+k+1})$$

In practice, we set $\lambda_k$ to the $k$-th power of a constant such as 0.8. Since the backbone model is used only to provide hidden states, we can use a quantized version of it to reduce memory consumption. This introduces a more democratized way to accelerate LLM inference: much like quantization, Medusa can be trained for a large model on a single consumer GPU, similar to QLoRA. The training takes only a few hours (e.g., Medusa-1 for the Vicuna 7B model takes 5 hours on a single NVIDIA A100 PCIE GPU with 60k ShareGPT samples).
Medusa-2: Joint Training
To further improve the accuracy of the Medusa heads, we can train them together with the backbone model. However, this requires a special training recipe to preserve the backbone model's next-token prediction capability and output quality. To this end, we propose three strategies:
- Combined loss
To preserve the backbone model's next-token prediction capability, we add the backbone model's cross-entropy loss $L_{\text{LM}} = -\log p_t^{(0)}(y_{t+1})$ to the Medusa loss, together with a weight $\lambda_0$ to balance the losses of the backbone model and the Medusa heads. The total loss is therefore:

$$L_{\text{Medusa-2}} = L_{\text{LM}} + \lambda_0 L_{\text{Medusa-1}}$$
- Differential learning rates Since the backbone model is already well trained while the Medusa heads need more training, we can use separate learning rates for them, enabling faster convergence of the Medusa heads while preserving the backbone model's capability.
- Heads warmup Note that at the beginning of training, the Medusa heads incur a large loss, leading to large gradients that may distort the backbone model's parameters. Following the idea of Kumar et al., we can adopt a two-stage training process. In the first stage, we train only the Medusa heads, as in Medusa-1. In the second stage, we train the backbone model and the Medusa heads together with a warmup strategy: we first train the Medusa heads alone for a few epochs, and then train them jointly with the backbone model. Beyond this simple strategy, we can also use a more sophisticated warmup strategy that gradually increases the weight $\lambda_0$ of the backbone model's loss. We find both strategies effective in practice.
Combining these strategies, we can train the Medusa heads together with the backbone model without compromising the backbone's capabilities. Moreover, this recipe can be applied together with supervised fine-tuning (SFT), enabling us to obtain a model with native Medusa support.
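The differential-learning-rate strategy amounts to two parameter groups within one optimizer step. A minimal hand-rolled SGD sketch (the group-naming convention and learning-rate values are illustrative, not the paper's hyperparameters):

```python
def sgd_step(params, grads, lr_backbone=1e-5, lr_heads=1e-3):
    """One SGD update with per-group learning rates: the pretrained
    backbone moves slowly while the fresh Medusa heads converge fast."""
    return {
        name: p - (lr_heads if name.startswith("medusa_head")
                   else lr_backbone) * grads[name]
        for name, p in params.items()
    }

params = {"backbone.layer0.w": 1.0, "medusa_head.0.w": 1.0}
grads = {"backbone.layer0.w": 1.0, "medusa_head.0.w": 1.0}
params = sgd_step(params, grads)  # heads move 100x farther than the backbone
```

In a real PyTorch training loop the same effect is obtained by passing two parameter groups with different `lr` values to the optimizer.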
How to Select the Number of Heads
Empirically, we find that up to 5 heads are sufficient. We therefore recommend training with five heads and using the strategy described in Section 2.3.3 to determine the optimal tree-attention configuration. With an optimized tree attention, three or four heads are sometimes enough at inference time; in that case, we can simply ignore the redundant heads without incurring any overhead.
Extensions
Typical Acceptance
In the speculative decoding papers, the authors employ rejection sampling to produce diverse outputs consistent with the original model's distribution. However, subsequent implementations show that this sampling strategy becomes less efficient as the sampling temperature rises. Intuitively, this can be understood in the extreme case where the draft model is identical to the original model: with greedy decoding, all of the draft model's outputs are accepted, maximizing efficiency. In contrast, rejection sampling introduces extra overhead because the draft model and the original model are sampled independently; even if their distributions align perfectly, the draft model's outputs can still be rejected.
However, in real-world scenarios, sampling from language models is often used to produce diverse responses, with the temperature parameter merely modulating the "creativity" of the response. Higher temperatures should therefore give the original model more leeway to accept the draft model's outputs. We observe that matching the original model's distribution is typically unnecessary. Thus, instead of using rejection sampling, we propose a typical acceptance scheme to select plausible candidates. This approach draws inspiration from truncation sampling research (Hewitt et al., 2022) (see Appendix A for an in-depth explanation). Our goal is to select typical candidates, meaning candidates that are not extremely unlikely to be produced by the original model. We use the original model's predicted probability as a natural measure of this and set thresholds based on the predicted distribution to determine acceptance. Specifically, given $x_1, x_2, \cdots, x_n$ as the context, when evaluating the candidate sequence $(x_{n+1}, x_{n+2}, \cdots, x_{n+K+1})$ (composed of the top predictions of the original language model head and the Medusa heads), we consider the condition

$$p_{\text{original}}(x_{n+k} \mid x_1, \cdots, x_{n+k-1}) > \min\left(\epsilon,\ \delta \exp\left(-H\!\left(p_{\text{original}}(\cdot \mid x_1, \cdots, x_{n+k-1})\right)\right)\right)$$

where $H(\cdot)$ is the entropy function, $\epsilon$ is a hard threshold, and $\delta$ is an entropy-dependent threshold. This criterion, adapted from Hewitt et al., is based on two observations: (1) tokens with relatively high probability are meaningful, and (2) when the entropy of the distribution is high, a variety of continuations may be considered plausible. During decoding, every candidate is evaluated with this criterion, and a candidate's prefix is accepted as long as the condition is satisfied. To guarantee that at least one token is generated per step, we apply greedy decoding to the first token and accept it unconditionally, applying typical acceptance only to subsequent tokens. The final prediction of the current step is determined by the longest accepted prefix among all candidates.
Studying this scheme yields several insights. First, when the temperature is set to 0, it reduces to greedy decoding, since only the most probable token has non-zero probability. When the temperature exceeds 0, the greedy decoding result will always be accepted under appropriate $\epsilon, \delta$ parameters, because those tokens have the maximum probability, yielding maximal speedup. Likewise, in general, raising the temperature leads to correspondingly longer accepted sequences, as our experimental results confirm.
Empirically, we verify that typical acceptance achieves a better speedup while maintaining similar generation quality, as shown in Figure 5.
Self-Distillation
In Section 2.2 we assumed the existence of a training dataset that matches the target model's output distribution. However, this is not always the case. For example, model owners may release only the model without the training data, or the model may have gone through a reinforcement learning from human feedback (RLHF) process that makes its output distribution differ from the training dataset. To address this, we propose an automated self-distillation pipeline that uses the model itself to generate a training dataset for the Medusa heads matching the model's output distribution.
The dataset generation process is straightforward. We first take a public seed dataset from a domain similar to the target model's, for example the ShareGPT dataset for chat models. We then simply take the prompts from that dataset and ask the model to reply to them. To obtain multi-turn conversation samples, we can feed the prompts from the seed dataset into the model sequentially. Alternatively, for models like Zephyr 7B that were trained on both roles of a conversation and are thus capable of self-talk, we can simply feed the first prompt and let the model generate multiple rounds of conversation.
For Medusa-1, this dataset suffices to train the Medusa heads. For Medusa-2, however, we observe that using this dataset alone to train both the backbone and the Medusa heads usually leads to lower generation quality. In fact, even without training the Medusa heads, training the backbone model on this dataset degrades its performance. This suggests that instead of using the ground-truth tokens as labels for the backbone model, we also need to use the original model's probability predictions, similar to classic knowledge distillation works. Concretely, the loss for the backbone model is:

$$L_{\text{LM-distill}} = \mathrm{KL}\left(p_{\text{original},t}^{(0)} \,\middle\|\, p_t^{(0)}\right)$$

where $p_{\text{original},t}^{(0)}$ denotes the probability distribution predicted by the original model at position $t$.
However, obtaining the original model's probability predictions requires keeping two models in memory during training, increasing the memory requirements. To mitigate this, we propose a simple yet effective way to exploit the self-distillation setup: we can fine-tune the backbone model with a parameter-efficient adapter such as LoRA. In this way, the original model is simply the model with the adapter switched off, so distillation incurs no extra memory consumption. In summary, this self-distillation pipeline can be used to train Medusa-2 without hurting the backbone model's capability and with almost no additional memory. Finally, one tip for self-distillation: it is best to use LoRA without quantization in this case; otherwise, the teacher model would be the quantized model, which may lead to lower generation quality.
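The distillation objective is a standard KL divergence between the frozen (adapter-off) teacher distribution and the student's. A minimal sketch over explicit probability vectors, with toy values and hypothetical names:

```python
import math

def kl_divergence(p_teacher, p_student):
    """KL(p_teacher || p_student), summed over the vocabulary.
    Terms with zero teacher probability contribute nothing."""
    return sum(p * math.log(p / q)
               for p, q in zip(p_teacher, p_student) if p > 0)

# The loss is zero exactly when the student matches the teacher ...
assert kl_divergence([0.5, 0.5], [0.5, 0.5]) == 0.0
# ... and positive whenever the distributions disagree
mismatch = kl_divergence([0.5, 0.5], [0.25, 0.75])
```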
Searching for the Optimized Tree Construction
In Section 2.1.2, we presented the simplest way of constructing the tree structure via the Cartesian product. However, for a fixed budget on the total number of nodes in the tree, a regular tree structure may not be optimal. Intuitively, candidates composed of the different heads' top predictions may have different accuracies, so we can leverage estimates of these accuracies to construct the tree structure.
Specifically, we can use a calibration dataset and measure the accuracy of the different heads' top predictions. Let $a_k^{(i)}$ denote the accuracy of the $i$-th top prediction of the $k$-th head. Assuming the accuracies are independent, the accuracy of a candidate sequence composed of the top $\left[i_1, i_2, \cdots, i_k\right]$ predictions of the different heads can be estimated as $\prod_{j=1}^{k} a_j^{(i_j)}$. Let $I$ denote the set of all possible combinations; each element of $I$ can be mapped to a node of the tree (not only leaf nodes but all nodes). The expectation of the acceptance length of a candidate is then:

$$\mathbb{E}[\text{acceptance length}] = \sum_{[i_1, \cdots, i_k] \in I} \ \prod_{j=1}^{k} a_j^{(i_j)}$$
Consider building the tree by adding nodes one at a time: a new node's contribution to the expectation is exactly the accuracy associated with that node. We can therefore greedily grow the tree by repeatedly choosing the node that is connected to the current tree and has the highest accuracy, repeating until the total number of nodes reaches the desired budget. In this way, we construct a tree that maximizes the expected acceptance length. See Appendix C for details.
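The greedy construction can be sketched as follows. This toy version assumes a small table of measured accuracies `acc[k][i]`; the real choice lists shown in the implementation section (such as `vicuna_7b_stage2`) were produced by this kind of procedure on calibration data.

```python
def greedy_tree(acc, budget):
    """Grow the candidate tree by repeatedly adding the highest-value
    connectable node. A node's value is the product of accuracies along
    its path, which equals its contribution to the expected acceptance
    length. acc[k][i] = accuracy of the (i+1)-th top prediction of head k+1."""
    def expand(pool, prefix, score):
        k = len(prefix)
        if k < len(acc):                      # enqueue the node's children
            for i, a in enumerate(acc[k]):
                pool[prefix + (i,)] = score * a
    pool, tree = {}, []
    expand(pool, (), 1.0)                     # children of the virtual root
    while pool and len(tree) < budget:
        node = max(pool, key=pool.get)        # best node touching the tree
        score = pool.pop(node)
        tree.append(node)
        expand(pool, node, score)
    return tree

# Toy accuracies: head 1 is much more reliable than head 2
tree = greedy_tree([[0.8, 0.5], [0.6, 0.3]], budget=3)
```

Because the pool only ever contains children of already-selected nodes (or of the root), every node picked by `max` is automatically connectable to the current tree.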
Experiments
In this section, we demonstrate the effectiveness of Medusa under different settings through experiments. First, we evaluate Medusa on the Vicuna-7B and 13B models to showcase the performance of Medusa-1 and Medusa-2. We then evaluate our method on the Vicuna-33B and Zephyr-7B models to demonstrate the feasibility of self-distillation when the fine-tuning recipe is not directly accessible, as for Vicuna-33B, and for models such as Zephyr-7B that went through reinforcement learning from human feedback (RLHF). The evaluation is conducted on MT-Bench, a multi-turn, conversational-format benchmark. Detailed settings can be found in Appendix B.
Case Study: Medusa-1 vs. Medusa-2 on Vicuna 7B and 13B
Experimental Setup. We use the Vicuna model class, which comprises chat models of different sizes (7B, 13B, 33B) fine-tuned from the Llama models. Among them, the 7B and 13B models were trained on the ShareGPT dataset, while the 33B model is an experimental model trained on a private dataset. In this section, we train Medusa heads on the 7B and 13B models with the ShareGPT dataset for 2 epochs. We use the v1.5 versions of the Vicuna models, which are fine-tuned from the Llama-2 models with a sequence length of 4096.
Results. We collect the results and display them in Figure 3. The baseline is the default Huggingface implementation. In Figure 3(a), we can see that for the 7B model, both the Medusa-1 and Medusa-2 configurations yield significant speed improvements, measured in tokens processed per second. Medusa-1 shows a 2.18× speedup, while Medusa-2 further improves this to 2.83×. When applied to the larger 13B model, Medusa-1 delivers a 2.33× speedup, while Medusa-2 sustains a similar 2.83× gain over the baseline. We also plot the per-category speedup for the Medusa-2 Vicuna-7B model. We observe that the coding category benefits from a 3.29× speedup, suggesting that Medusa is particularly effective for tasks in this domain. This points to substantial potential for optimizing coding LLMs, which are widely used in software development and other programming-related tasks. The "Extraction" category shows the highest speedup at 3.62×, indicating that this task is highly amenable to Medusa's optimization. Overall, the results show that Medusa significantly improves inference speed across model sizes and tasks.
Figure 3:Left: Speed comparison of baseline, Medusa-1 and Medusa-2 on Vicuna-7B/13B. Medusa-1 achieves more than 2× wall-time speedup compared to the baseline implementation while Medusa-2 further improves the speedup by a significant margin. Right: Detailed speedup performance of Vicuna-7B with Medusa-2 on 8 categories from MT-Bench.
Case Study: Training with Self-Distillation on Vicuna-33B and Zephyr-7B
Experimental Setup. In this case study, we focus on the cases that necessitate self-distillation, taking the Vicuna-33B and Zephyr-7B models as examples. Following the procedure described in Section 2.3.2, we first generate datasets from some seed prompts. We use ShareGPT and UltraChat as seed datasets and collect datasets of roughly 100k samples for the two cases. Interestingly, we find that the Zephyr model can continually generate multiple rounds of conversation from a single prompt, which makes collecting a large dataset easy. For Vicuna-33B, we generate multi-turn conversations by iteratively feeding the prompts from each multi-turn seed conversation, using random sampling at temperature 0.3. Both models are trained with a sequence length of 2048 and a batch size of 128.
Results. Table 1 complements these findings by comparing the speedup rates, overhead, and quality of different Medusa-2 models on MT-Bench, with GPT-4 serving as the evaluator and assigning performance scores from 0 to 10. We report the quality difference between Medusa and the original model. Notably, while the Medusa-2 Vicuna-33B model shows a lower speedup, it maintains comparable quality. We hypothesize that this is due to a mismatch between the hidden training dataset and the dataset we used for self-distillation. Thus, self-distillation aligns the model's generation quality well, while the Medusa heads learn from a self-distilled distribution that may have shifted from the training set. In our study, we also apply speculative decoding to the Vicuna lineup using open-source draft models (details can be found in Appendix D).
These results underscore the complex interplay between speed and performance when scaling up model size and applying self-distillation techniques. The findings also highlight the potential of the Medusa-2 configuration to improve processing efficiency while preserving output quality, offering a promising direction for co-optimizing LLMs with Medusa heads.
Ablation Study
Configuration of Tree Attention
The study of tree attention is conducted with Medusa-2 Vicuna-7B on the writing and roleplay categories of the MT-Bench dataset. Our goal is to characterize the motivation for tree attention and its performance.
Figure 4(a) compares the speedup rates of randomly sampled dense tree configurations (Section 2.1.2, blue dots) with the optimized sparse tree settings (Section 2.3.3, red stars). A sparse tree configuration with 64 nodes achieves a better speedup rate than a dense tree configuration with 256 nodes. The decline in speed in Figure 4(b) is due to the growing overhead of becoming compute-bound: while a more elaborate tree can raise the acceleration rate, it comes at the expense of speed, owing to the dense matrix multiplications in the linear layers and self-attention. The growth of the acceleration rate follows a logarithmic trend, slowing as the tree grows, as shown in Figure 4(a). Nonetheless, the initial gains are considerable, allowing Medusa to achieve a substantial speedup. If the gain in acceleration is smaller than the overhead, overall performance degrades. See Appendix G for a detailed study.
Thresholds of Typical Acceptance
The thresholds of typical acceptance are studied with Medusa-2 Vicuna 7B on the writing and roleplay categories of the MT-Bench dataset. Using the Vicuna 7B model, we align our methodology with the approach described above and set $\delta = \sqrt{\epsilon}$, as recommended. Figure 5 presents a comparative analysis of our model's performance under different sampling settings, with the threshold starting at 0.01 and increasing to 0.25 in steps of 0.01. Our observations reveal a clear trade-off: higher thresholds improve quality at the cost of a reduced speedup. Moreover, for tasks demanding creativity, we note that default random sampling outperforms greedy sampling in quality, and the proposed typical sampling is comparable to random sampling as the sampling freedom (temperature) increases.
Effectiveness of Two-stage Fine-tuning
Table 2 shows the performance differences among the various fine-tuning strategies for the Vicuna-7B model. Medusa-1, which fine-tunes only the Medusa heads, achieves a 2.18× speedup without affecting generation quality. Medusa-2, with the two-stage fine-tuning (Section 2.2.2), preserves generation quality and provides a larger speedup than Medusa-1 (2.83×). In contrast, fine-tuning the model directly together with the Medusa heads degrades generation quality. These findings show that our Medusa-2 fine-tuning recipe maintains model quality while improving the speedup over Medusa-1.
Discussion
In summary, Medusa accelerates LLM inference by 2.3-2.8× by equipping the model with additional predictive decoding heads, allowing multiple tokens to be generated simultaneously and bypassing the sequential decoding constraint. Medusa's main advantages include its simplicity, parameter efficiency, and ease of integration into existing systems. Medusa avoids the need for a dedicated draft model, and the typical acceptance scheme removes the complexity of rejection sampling while delivering plausible outputs. Our approach includes two efficient training procedures that ensure high-quality outputs across a variety of models and prompt types. We summarize the development of each technique and its effect on speedup in Table 3.
In this paper, we focused on the batch-size-1 setting for simplicity. However, we want to emphasize that the ideas presented here generalize to larger-batch settings, which are now supported by libraries such as TensorRT and Huggingface TGI.
Implementation
kv_cache
github: https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/kv_cache.py
import torch
class KVCache:
"""
A key-value cache for the model.
This class provides a mechanism to maintain a growing cache of keys and values,
particularly useful for models that benefit from caching previous states,
like transformers during autoregressive decoding.
Attributes:
data (torch.Tensor): The tensor storing keys and values.
current_length (int): Current length of the data being stored.
"""
def __init__(self, data, current_length):
"""
Initialize the KVCache.
Args:
data (torch.Tensor): Initial tensor to store the keys and values.
current_length (int): Initial length of the data.
"""
self.data = data
self.current_length = current_length
@property
def shape(self):
"""Return the shape of the data tensor with updated length."""
return (
self.data.shape[0], # B
self.data.shape[1], # H
self.current_length.item(), # T
self.data.shape[3], # D
)
def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2):
"""
Copy values from the current data at specified indices to a new location.
Args:
indices (torch.Tensor): Indices of the data tensor to be copied.
prev_length (int): Previous length before adding new data.
dim (int, optional): Dimension along which copying should be performed. Default is 2.
"""
tgt = self.data.index_select(dim, indices)
dst = self.data.narrow(dim, prev_length, tgt.shape[dim])
dst.copy_(tgt, non_blocking=True)
self.current_length.fill_(prev_length + tgt.shape[dim])
def cat(self, tensor: torch.Tensor, dim: int = 2):
"""
Concatenate the given tensor with the current data.
Args:
tensor (torch.Tensor): The tensor to be concatenated.
dim (int, optional): The dimension along which concatenation should be done. Default is 2.
Returns:
torch.Tensor: The data tensor after concatenation up to the current length.
"""
dst = self.data.narrow(dim, self.current_length, tensor.shape[dim])
dst.copy_(tensor)
self.current_length.add_(tensor.shape[dim])
return torch.narrow(self.data, 2, 0, self.current_length)
def initialize_past_key_values(model):
"""
Initialize past key and value states for a given transformer model.
This function prepares key-value cache structures for the model, allowing it to store and reuse
past key and value states during autoregressive decoding, which can improve efficiency.
Args:
model (nn.Module): The transformer model for which past key-value states need to be initialized.
Returns:
tuple:
- past_key_values (list): A list of KVCache objects for each layer in the model.
- past_key_values_data (torch.Tensor): The tensor that will store all keys and values.
- current_length_data (torch.Tensor): A tensor tracking the current length of keys/values in the cache.
"""
# Extracting configuration from the model
config = model.config
# Initializing the batch size to 1, this can be modified if different batch sizes are required
batch_size = 1
# Initializing a tensor to store past keys and values for all layers
past_key_values_data = torch.zeros(
config.num_hidden_layers * 2, # Layers
batch_size, # B
config.num_key_value_heads, # Heads
config.max_position_embeddings, # T
config.hidden_size // config.num_attention_heads, # D
device=model.device,
dtype=model.dtype,
)
# Initialize tensor to store the current length of the cached data for all layers.
# [IMPORTANT] It needs to be kept on CPU for quick access and updates.
current_length_data = torch.zeros(
config.num_hidden_layers * 2, dtype=torch.long, device="cpu"
)
# Creating a KVCache for each pair of key and value in all layers
past_key_values = []  # note: `[] * n` is always the empty list; plain append below
for i in range(config.num_hidden_layers):
past_key_values.append(
[
KVCache(past_key_values_data[i * 2 + j], current_length_data[i * 2 + j])
for j in range(2) # key, val
]
)
return past_key_values, past_key_values_data, current_length_data
Config
https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/medusa\_model.py#L20
class MedusaConfig(PretrainedConfig):
"""
Configuration class for Medusa model.
Args:
medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 2.
medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
**kwargs: Additional keyword arguments to be passed to the parent class constructor.
"""
def __init__(
self,
medusa_num_heads=5,
medusa_num_layers=1,
base_model_name_or_path="lmsys/vicuna-7b-v1.3",
**kwargs,
):
super().__init__(**kwargs)
self.medusa_num_heads = medusa_num_heads
self.medusa_num_layers = medusa_num_layers
self.base_model_name_or_path = base_model_name_or_path
ResBlock
class ResBlock(nn.Module):
"""
A Residual Block module.
This module performs a linear transformation followed by a SiLU activation,
and then adds the result to the original input, creating a residual connection.
Args:
hidden_size (int): The size of the hidden layers in the block.
"""
def __init__(self, hidden_size):
super().__init__()
self.linear = nn.Linear(hidden_size, hidden_size)
# Initialize as an identity mapping
torch.nn.init.zeros_(self.linear.weight)
# Use SiLU activation to keep consistent with the Llama model
self.act = nn.SiLU()
def forward(self, x):
"""
Forward pass of the ResBlock.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output after the residual connection and activation.
"""
return x + self.act(self.linear(x))
MedusaModelABC
class MedusaModelABC(nn.Module):
"""The Medusa Language Model Head.
This module creates a series of prediction heads (based on the 'medusa' parameter)
on top of a given base model. Each head is composed of a sequence of residual blocks
followed by a linear layer.
"""
def __init__(
self,
config,
):
"""
Args:
config (PretrainedConfig): The configuration of the MedusaModel.
"""
super().__init__(config)
# For compatibility with the old APIs
medusa_num_heads = config.medusa_num_heads
medusa_num_layers = config.medusa_num_layers
base_model_name_or_path = config._name_or_path
self.hidden_size = config.hidden_size
self.vocab_size = config.vocab_size
self.medusa = medusa_num_heads
self.medusa_num_layers = medusa_num_layers
self.base_model_name_or_path = base_model_name_or_path
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
# Create a list of Medusa heads
self.medusa_head = nn.ModuleList(
[
nn.Sequential(
# Use a generator so each layer gets its own ResBlock instance
# (list multiplication would share one module across layers)
*(ResBlock(self.hidden_size) for _ in range(medusa_num_layers)),
nn.Linear(self.hidden_size, self.vocab_size, bias=False),
)
for _ in range(medusa_num_heads) # Heads
]
)
# Add a link named base_model to self
@property
def base_model(self):
return self
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
*args,
**kwargs,
):
# Manually load config to ensure that the medusa_num_heads parameter is loaded
try:
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
return super().from_pretrained(
pretrained_model_name_or_path,
*args,
**kwargs,
config=config,
)
except:
config = MedusaConfig.from_pretrained(pretrained_model_name_or_path)
base_model_config = AutoConfig.from_pretrained(config.base_model_name_or_path)
base_model_config.medusa_num_heads = 5 # TODO: fix the uploaded config (only include 2 heads)
base_model_config.medusa_num_layers = config.medusa_num_layers
model = super().from_pretrained(
config.base_model_name_or_path,
*args,
**kwargs,
config=base_model_config,
)
# load medusa head checkpoints
medusa_head_path = os.path.join(pretrained_model_name_or_path, "medusa_lm_head.pt")
if os.path.exists(medusa_head_path):
filename = medusa_head_path
else:
filename = hf_hub_download(pretrained_model_name_or_path, "medusa_lm_head.pt")
medusa_head_state_dict = torch.load(filename, map_location=model.device)
model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False)
return model
def get_tokenizer(self):
"""Get the tokenizer of the base model.
Returns:
Tokenizer: The tokenizer of the base model.
"""
return self.tokenizer
def forward(
self,
input_ids=None,
attention_mask=None,
past_key_values=None,
output_orig=False,
position_ids=None,
medusa_forward=False,
**kwargs,
):
"""Forward pass of the MedusaModel.
Args:
input_ids (torch.Tensor, optional): Input token IDs.
attention_mask (torch.Tensor, optional): Attention mask.
labels (torch.Tensor, optional): Ground truth labels for loss computation.
past_key_values (tuple, optional): Tuple containing past key and value states for attention.
output_orig (bool, optional): Whether to also output predictions from the original LM head.
position_ids (torch.Tensor, optional): Position IDs.
Returns:
torch.Tensor: A tensor containing predictions from all Medusa heads.
(Optional) Original predictions from the base model's LM head.
"""
if not medusa_forward:
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
**kwargs,
)
with torch.inference_mode():
# Pass input through the base model
outputs = self.base_model.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
**kwargs,
)
if output_orig:
# LM head, [B,T,D]
orig = self.base_model.lm_head(outputs[0])
# Clone the output hidden states
hidden_states = outputs[0].clone()
medusa_logits = []
# TODO: Consider parallelizing this loop for efficiency?
for i in range(self.medusa):
medusa_logits.append(self.medusa_head[i](hidden_states))
if output_orig:
# [H,B,T,D]
return torch.stack(medusa_logits, dim=0), outputs, orig
return torch.stack(medusa_logits, dim=0)
Medusa Decoding
def medusa_generate(
self,
input_ids,
attention_mask=None,
temperature=0.0,
max_steps=512,
# The hyperparameters below are for the Medusa
# top-1 prediction for the next token, top-7 predictions for the next token, top-6 predictions for the next next token.
medusa_choices=None,
posterior_threshold=0.09, # threshold validation of Medusa output
# another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
posterior_alpha=0.3,
top_p=0.8,
sampling = 'typical',
fast = True
):
"""
Args:
input_ids (torch.Tensor, optional): Input token IDs.
attention_mask (torch.Tensor, optional): Attention mask.
temperature (float, optional): Temperature for typical acceptance.
medusa_choices (list, optional): A list of integers indicating the number of choices for each Medusa head.
posterior_threshold (float, optional): Threshold for posterior validation.
posterior_alpha (float, optional): Another threshold hyperparameter, recommended to be sqrt(posterior_threshold).
top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to True.
Returns:
torch.Tensor: Output token IDs.
Warning: Only support batch size 1 for now!!
"""
assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
# Avoid modifying the input_ids in-place
input_ids = input_ids.clone()
# Cache medusa buffers (the fixed patterns for tree attention)
if medusa_choices is None:
medusa_choices = self.get_medusa_choice(self.base_model_name_or_path)
if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices:
# Load the cached medusa buffer
medusa_buffers = self.medusa_buffers
else:
# Initialize the medusa buffer
medusa_buffers = generate_medusa_buffers(
medusa_choices, device=self.base_model.device
)
self.medusa_buffers = medusa_buffers
self.medusa_choices = medusa_choices
# Initialize the past key and value states
if hasattr(self, "past_key_values"):
past_key_values = self.past_key_values
past_key_values_data = self.past_key_values_data
current_length_data = self.current_length_data
# Reset the past key and value states
current_length_data.zero_()
else:
(
past_key_values,
past_key_values_data,
current_length_data,
) = initialize_past_key_values(self.base_model)
self.past_key_values = past_key_values
self.past_key_values_data = past_key_values_data
self.current_length_data = current_length_data
input_len = input_ids.shape[1]
reset_medusa_mode(self)
# Initialize tree attention mask and process prefill tokens
medusa_logits, logits = initialize_medusa(
input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values
)
new_token = 0
last_round_token = 0
for idx in range(max_steps):
# Generate candidates with topk predictions from Medusa heads
candidates, tree_candidates = generate_candidates(
medusa_logits,
logits,
medusa_buffers["tree_indices"],
medusa_buffers["retrieve_indices"],
temperature=temperature,
posterior_alpha=posterior_alpha,
posterior_threshold=posterior_threshold,
top_p=top_p,
sampling=sampling,
fast=fast,
)
# Use tree attention to verify the candidates and get predictions
medusa_logits, logits, outputs = tree_decoding(
self,
tree_candidates,
past_key_values,
medusa_buffers["medusa_position_ids"],
input_ids,
medusa_buffers["retrieve_indices"],
)
# Evaluate the posterior of the candidates to select the accepted candidate prefix
best_candidate, accept_length = evaluate_posterior(
logits, candidates, temperature, posterior_threshold, posterior_alpha, top_p=top_p, sampling=sampling, fast=fast
)
# Update the input_ids and logits
input_ids, logits, medusa_logits, new_token = update_inference_inputs(
input_ids,
candidates,
best_candidate,
accept_length,
medusa_buffers["retrieve_indices"],
outputs,
logits,
medusa_logits,
new_token,
past_key_values_data,
current_length_data,
)
yield {
"text": self.tokenizer.decode(
input_ids[0, input_len:],
skip_special_tokens=True,
spaces_between_special_tokens=False,
clean_up_tokenization_spaces=True,
)
}
if self.tokenizer.eos_token_id in input_ids[0, input_len:]:
break
def get_medusa_choice(self, model_name):
if 'vicuna' in model_name:
if '7b' in model_name:
return vicuna_7b_stage2
elif '13b' in model_name:
return vicuna_13b_stage2
elif '33b' in model_name:
return vicuna_33b_stage2
elif 'zephyr' in model_name:
return zephyr_stage2
warnings.warn('Please specify medusa choice configuration!')
return mc_sim_7b_63
mc_sim_7b_63 = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]
vicuna_7b_stage2 = [(0,), (0, 0), (1,), (0, 1), (0, 0, 0), (1, 0), (2,), (0, 2), (0, 0, 1), (0, 3), (3,), (0, 1, 0), (2, 0), (4,), (0, 0, 2), (0, 4), (1, 1), (1, 0, 0), (0, 0, 0, 0), (5,), (0, 0, 3), (0, 5), (0, 2, 0), (3, 0), (0, 1, 1), (0, 6), (6,), (0, 7), (0, 0, 4), (4, 0), (1, 2), (0, 8), (7,), (0, 3, 0), (0, 0, 0, 1), (0, 0, 5), (2, 1), (0, 0, 6), (1, 0, 1), (0, 0, 1, 0), (2, 0, 0), (5, 0), (0, 9), (0, 1, 2), (8,), (0, 4, 0), (0, 2, 1), (1, 3), (0, 0, 7), (0, 0, 0, 2), (0, 0, 8), (1, 1, 0), (0, 1, 0, 0), (6, 0), (9,), (0, 1, 3), (0, 0, 0, 3), (1, 0, 2), (0, 5, 0), (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)]
vicuna_7b_stage1_ablation = [(0,), (0, 0), (1,), (0, 0, 0), (0, 1), (1, 0), (2,), (0, 2), (0, 0, 1), (3,), (0, 3), (0, 1, 0), (2, 0), (0, 0, 2), (0, 4), (4,), (0, 0, 0, 0), (1, 0, 0), (1, 1), (0, 0, 3), (0, 2, 0), (0, 5), (5,), (3, 0), (0, 1, 1), (0, 6), (6,), (0, 0, 4), (1, 2), (0, 0, 0, 1), (4, 0), (0, 0, 5), (0, 7), (0, 8), (0, 3, 0), (0, 0, 1, 0), (1, 0, 1), (7,), (2, 0, 0), (0, 0, 6), (2, 1), (0, 1, 2), (5, 0), (0, 2, 1), (0, 9), (0, 0, 0, 2), (0, 4, 0), (8,), (1, 3), (0, 0, 7), (0, 1, 0, 0), (1, 1, 0), (6, 0), (9,), (0, 0, 8), (0, 0, 9), (0, 5, 0), (0, 0, 2, 0), (1, 0, 2), (0, 1, 3), (0, 0, 0, 3), (3, 0, 0), (3, 1)]
vicuna_7b_stage1 = [(0,), (0, 0), (1,), (2,), (0, 1), (1, 0), (3,), (0, 2), (4,), (0, 0, 0), (0, 3), (5,), (2, 0), (0, 4), (6,), (0, 5), (1, 1), (0, 0, 1), (7,), (3, 0), (0, 6), (8,), (9,), (0, 1, 0), (0, 7), (0, 8), (4, 0), (0, 0, 2), (1, 2), (0, 9), (2, 1), (5, 0), (1, 0, 0), (0, 0, 3), (1, 3), (0, 2, 0), (0, 1, 1), (0, 0, 4), (6, 0), (1, 4), (0, 0, 5), (2, 2), (0, 3, 0), (3, 1), (0, 0, 6), (7, 0), (1, 5), (1, 0, 1), (2, 0, 0), (0, 0, 7), (8, 0), (0, 0, 0, 0), (4, 1), (0, 1, 2), (0, 4, 0), (9, 0), (0, 2, 1), (2, 3), (1, 6), (0, 0, 8), (0, 5, 0), (3, 2), (5, 1)]
vicuna_13b_stage2 = [(0,), (0, 0), (1,), (0, 0, 0), (0, 1), (1, 0), (2,), (0, 2), (0, 0, 1), (0, 1, 0), (3,), (0, 3), (2, 0), (0, 0, 2), (0, 0, 0, 0), (0, 4), (1, 0, 0), (1, 1), (4,), (0, 0, 3), (0, 5), (0, 2, 0), (5,), (3, 0), (0, 1, 1), (0, 6), (0, 0, 4), (0, 0, 0, 1), (0, 7), (0, 0, 5), (1, 2), (0, 0, 1, 0), (0, 3, 0), (1, 0, 1), (4, 0), (0, 0, 6), (0, 8), (2, 0, 0), (0, 9), (6,), (7,), (2, 1), (5, 0), (0, 1, 2), (0, 0, 0, 2), (8,), (0, 4, 0), (0, 1, 0, 0), (0, 2, 1), (0, 0, 7), (1, 1, 0), (1, 3), (0, 0, 2, 0), (9,), (0, 0, 8), (0, 5, 0), (0, 0, 0, 3), (0, 0, 9), (0, 1, 3), (1, 0, 2), (0, 0, 1, 1), (3, 0, 0), (1, 0, 0, 0)]
vicuna_13b_stage1 = [(0,), (0, 0), (1,), (0, 1), (2,), (1, 0), (0, 0, 0), (0, 2), (3,), (0, 3), (4,), (2, 0), (0, 4), (0, 0, 1), (0, 5), (5,), (1, 1), (0, 1, 0), (6,), (0, 6), (0, 0, 2), (7,), (3, 0), (8,), (0, 7), (0, 8), (1, 0, 0), (0, 0, 3), (4, 0), (1, 2), (9,), (0, 9), (2, 1), (0, 2, 0), (0, 0, 4), (1, 3), (0, 1, 1), (0, 0, 5), (5, 0), (0, 3, 0), (0, 0, 0, 0), (0, 0, 6), (6, 0), (1, 4), (2, 0, 0), (0, 1, 2), (3, 1), (0, 4, 0), (1, 0, 1), (2, 2), (0, 0, 7), (1, 5), (7, 0), (0, 0, 8), (8, 0), (0, 5, 0), (0, 0, 9), (0, 2, 1), (1, 1, 0), (0, 1, 3), (4, 1), (2, 3), (1, 6)]
vicuna_33b_stage2 = [(0,), (0, 0), (1,), (0, 1), (0, 0, 0), (1, 0), (2,), (0, 2), (0, 0, 1), (0, 3), (3,), (0, 1, 0), (2, 0), (0, 4), (4,), (0, 0, 2), (1, 1), (1, 0, 0), (0, 5), (5,), (0, 0, 0, 0), (0, 0, 3), (3, 0), (0, 2, 0), (0, 6), (0, 1, 1), (6,), (0, 0, 4), (0, 7), (7,), (1, 2), (4, 0), (8,), (0, 3, 0), (0, 0, 5), (0, 0, 0, 1), (0, 8), (2, 1), (0, 9), (1, 0, 1), (2, 0, 0), (0, 0, 6), (5, 0), (0, 0, 1, 0), (1, 3), (0, 1, 2), (0, 4, 0), (0, 0, 7), (0, 2, 1), (9,), (1, 1, 0), (0, 0, 0, 2), (6, 0), (0, 0, 8), (0, 1, 0, 0), (7, 0), (0, 1, 3), (0, 5, 0), (1, 4), (0, 0, 9), (3, 1), (1, 0, 2), (2, 2)]
vicuna_33b_stage1 = [(0,), (1,), (0, 0), (2,), (0, 1), (3,), (1, 0), (4,), (0, 2), (5,), (0, 3), (0, 0, 0), (6,), (0, 4), (2, 0), (7,), (1, 1), (0, 5), (3, 0), (8,), (9,), (0, 6), (0, 7), (0, 0, 1), (1, 2), (4, 0), (0, 1, 0), (0, 8), (0, 9), (2, 1), (0, 0, 2), (5, 0), (1, 3), (0, 0, 3), (1, 0, 0), (1, 4), (6, 0), (0, 2, 0), (3, 1), (2, 2), (0, 0, 4), (7, 0), (0, 1, 1), (1, 5), (4, 1), (0, 0, 5), (0, 3, 0), (9, 0), (8, 0), (1, 6), (0, 0, 6), (2, 3), (0, 1, 2), (3, 2), (0, 4, 0), (2, 0, 0), (1, 7), (1, 0, 1), (0, 0, 7), (5, 1), (2, 4), (0, 0, 8), (0, 2, 1)]
zephyr_stage2 = [(0,), (0, 0), (1,), (0, 1), (2,), (0, 0, 0), (1, 0), (0, 2), (3,), (0, 3), (4,), (2, 0), (0, 0, 1), (0, 4), (5,), (0, 5), (0, 1, 0), (1, 1), (6,), (0, 0, 2), (3, 0), (0, 6), (7,), (0, 7), (0, 8), (0, 0, 3), (1, 0, 0), (0, 9), (0, 2, 0), (1, 2), (4, 0), (8,), (9,), (2, 1), (0, 1, 1), (0, 0, 4), (0, 0, 0, 0), (5, 0), (0, 3, 0), (1, 3), (0, 0, 5), (0, 0, 6), (6, 0), (2, 0, 0), (1, 0, 1), (0, 1, 2), (0, 4, 0), (1, 4), (3, 1), (2, 2), (0, 0, 7), (7, 0), (0, 2, 1), (0, 0, 8), (0, 1, 3), (0, 5, 0), (1, 5), (0, 0, 9), (1, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0), (4, 1), (2, 3)]
def reset_medusa_mode(
model,
):
"""
Resets the Medusa settings and the past key-values to their initial state.
This function ensures that after any operations involving Medusa,
the base model and its settings return to their default state.
Specifically, it performs the following tasks:
1. Clears the Medusa attention mask in the base model.
2. Resets the Medusa mode in the base model.
3. Resets the current lengths in the past key-values to zero for all layers.
Args:
- model (MedusaLMHead): The model containing the Medusa layers and base model.
- past_key_values (list of torch.Tensor): Contains past hidden states and past attention values.
Returns:
- None
"""
model.base_model.model.medusa_mask = None
model.base_model.model.medusa_mode = None
def initialize_medusa(input_ids, model, medusa_attn_mask, past_key_values):
"""
Initializes the Medusa structure for a given model.
This function performs the following operations:
1. Forward pass through the model to obtain the Medusa logits, original model outputs, and logits.
2. Sets the Medusa attention mask within the base model.
Args:
- input_ids (torch.Tensor): The input tensor containing token ids.
- model (MedusaLMHead): The model containing the Medusa layers and base model.
- medusa_attn_mask (torch.Tensor): The attention mask designed specifically for the Medusa structure.
- past_key_values (list of torch.Tensor): Contains past hidden states and past attention values.
Returns:
- medusa_logits (torch.Tensor): Logits from the Medusa heads.
- logits (torch.Tensor): Original logits from the base model.
"""
medusa_logits, outputs, logits = model(
input_ids, past_key_values=past_key_values, output_orig=True, medusa_forward=True
)
model.base_model.model.medusa_mask = medusa_attn_mask
return medusa_logits, logits
def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, temperature=0, posterior_threshold=0.3, posterior_alpha=0.09, top_p=0.8, sampling='typical', fast=False):
"""
Generate candidates based on provided logits and indices.
Parameters:
- medusa_logits (torch.Tensor): Logits from a specialized Medusa structure, aiding in candidate selection.
- logits (torch.Tensor): Standard logits from a language model.
- tree_indices (list or torch.Tensor): Indices representing a tree structure, used for mapping candidates.
- retrieve_indices (list or torch.Tensor): Indices for extracting specific candidate tokens.
- temperature (float, optional): Controls the diversity of the sampling process. Defaults to 0.
- posterior_threshold (float, optional): Threshold for typical sampling. Defaults to 0.3.
- posterior_alpha (float, optional): Scaling factor for the entropy-based threshold in typical sampling. Defaults to 0.09.
- top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
- sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
- fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to False.
Returns:
- tuple (torch.Tensor, torch.Tensor): A tuple containing two sets of candidates:
1. Cartesian candidates derived from the combined original and Medusa logits.
2. Tree candidates mapped from the Cartesian candidates using tree indices.
"""
# Greedy decoding: Select the most probable candidate from the original logits.
if temperature == 0 or fast:
candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0)
else:
if sampling == 'typical':
candidates_logit = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
elif sampling == 'nucleus':
candidates_logit = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
else:
raise NotImplementedError
# Extract the TOPK candidates from the medusa logits (TOPK is a module-level constant).
candidates_medusa_logits = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices
# Combine the selected candidate from the original logits with the topk medusa logits.
candidates = torch.cat([candidates_logit, candidates_medusa_logits.view(-1)], dim=-1)
# Map the combined candidates to the tree indices to get tree candidates.
tree_candidates = candidates[tree_indices]
# Extend the tree candidates by appending a zero.
tree_candidates_ext = torch.cat([tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device)], dim=0)
# Retrieve the cartesian candidates using the retrieve indices.
cart_candidates = tree_candidates_ext[retrieve_indices]
# Unsqueeze the tree candidates for dimension consistency.
tree_candidates = tree_candidates.unsqueeze(0)
return cart_candidates, tree_candidates
def get_nucleus_one_token(logit, temperature, top_p):
"""
Performs token sampling based on the nucleus (top-p) sampling method.
This function selects a token from a given logit distribution using the nucleus sampling strategy.
It allows for more controlled and diverse generation compared to traditional top-k sampling.
Args:
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC).
temperature (float): A temperature parameter to control the randomness in sampling.
Higher values increase diversity, lower values make selections more deterministic.
top_p (float): The cumulative probability threshold for nucleus sampling.
It controls the size of the set of high-probability tokens to consider for sampling.
Returns:
torch.Tensor: A tensor containing the indices of the sampled tokens.
"""
if top_p >= 1:
return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
logit = logit / temperature
probs = torch.softmax(logit, dim=-1)
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
cum_probs = torch.cumsum(sorted_logits, dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logit[indices_to_remove] = float('-inf')
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
return sampled_tokens
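The core of the top-p cut above (sort descending, accumulate, mask the tail) can be checked on a toy distribution without torch. This is an illustrative plain-Python re-implementation, not the function used at runtime:

```python
def nucleus_keep_set(probs, top_p):
    """Indices kept by a top-p (nucleus) cut, mirroring the tensor code.

    The shift trick in the tensor version
    (sorted_indices_to_remove[..., 1:] = ...[..., :-1]) means a token is
    dropped only if the threshold was crossed *before* it, so at least
    one token always survives the cut.
    """
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    kept, cum = [], 0.0
    for rank, i in enumerate(order):
        if rank > 0 and cum > top_p:
            break  # threshold crossed before this token
        kept.append(i)
        cum += probs[i]
    return kept

print(nucleus_keep_set([0.5, 0.3, 0.15, 0.05], top_p=0.8))  # [0, 1, 2]
```

Note that index 2 is kept: the cumulative probability only exceeds 0.8 *after* including it, matching the `> top_p` comparison plus shift in the tensor code.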
def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
"""
Implements token sampling based on the typical sampling method.
This function selects a token from a given logit distribution using the typical sampling strategy,
aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods.
Args:
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor.
temperature (float): A parameter to control the randomness in sampling.
Higher values increase diversity, lower values make selections more deterministic.
posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling.
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
Returns:
torch.Tensor: A tensor containing the indices of the sampled tokens.
"""
logit = logit / temperature
probs = torch.softmax(logit, dim=-1)
entropy = -torch.sum(
probs * torch.log(probs + 1e-5), dim=-1
)
threshold = torch.minimum(
torch.ones_like(entropy) * posterior_threshold,
torch.exp(-entropy) * posterior_alpha,
)
indices_to_remove = probs < threshold.unsqueeze(-1)
logit[indices_to_remove] = float('-inf')
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
return sampled_tokens
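The acceptance bar used above, `min(posterior_threshold, posterior_alpha * exp(-H))`, adapts to the entropy H of the distribution: the flatter the distribution, the lower the bar, so more diverse tokens pass. A minimal sketch in plain Python (hypothetical helper name, same arithmetic):

```python
import math

def typical_threshold(probs, posterior_threshold=0.3, posterior_alpha=0.09):
    """Entropy-adaptive bar for typical acceptance:
    min(posterior_threshold, posterior_alpha * exp(-H)).

    Since exp(-H) <= 1 for any distribution, with the default alpha=0.09
    the adaptive term always governs; the hard cap only bites for larger
    alpha values.
    """
    entropy = -sum(p * math.log(p + 1e-5) for p in probs)
    return min(posterior_threshold, posterior_alpha * math.exp(-entropy))

peaked = [0.97, 0.01, 0.01, 0.01]
flat = [0.25, 0.25, 0.25, 0.25]
# A peaked distribution sets a higher bar than a flat one:
print(typical_threshold(peaked) > typical_threshold(flat))  # True
```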
def tree_decoding(
model,
tree_candidates,
past_key_values,
medusa_position_ids,
input_ids,
retrieve_indices,
):
"""
Decode the tree candidates using the provided model and reorganize the logits.
Parameters:
- model (nn.Module): Model to be used for decoding the tree candidates.
- tree_candidates (torch.Tensor): Input candidates based on a tree structure.
- past_key_values (torch.Tensor): Past states, such as key and value pairs, used in attention layers.
- medusa_position_ids (torch.Tensor): Positional IDs associated with the Medusa structure.
- input_ids (torch.Tensor): Input sequence IDs.
- retrieve_indices (list or torch.Tensor): Indices for reordering the logits.
Returns:
- tuple: Returns medusa logits, regular logits, and other outputs from the model.
"""
# Compute new position IDs by adding the Medusa position IDs to the length of the input sequence.
position_ids = medusa_position_ids + input_ids.shape[1]
# Use the model to decode the tree candidates.
# The model is expected to return logits for the Medusa structure, original logits, and possibly other outputs.
tree_medusa_logits, outputs, tree_logits = model(
tree_candidates,
output_orig=True,
past_key_values=past_key_values,
position_ids=position_ids,
medusa_forward=True,
)
# Reorder the obtained logits based on the retrieve_indices to ensure consistency with some reference ordering.
logits = tree_logits[0, retrieve_indices]
medusa_logits = tree_medusa_logits[:, 0, retrieve_indices]
return medusa_logits, logits, outputs
def evaluate_posterior(
logits, candidates, temperature, posterior_threshold=0.3, posterior_alpha=0.09, top_p=0.8, sampling='typical', fast=True
):
"""
Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate.
Depending on the temperature value, the function either uses greedy decoding or evaluates posterior
probabilities to select the best candidate.
Args:
- logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size).
- candidates (torch.Tensor): Candidate token sequences.
- temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding.
- posterior_threshold (float): Threshold for posterior probability.
- posterior_alpha (float): Scaling factor for the threshold.
- top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
- sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
- fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to True.
Returns:
- best_candidate (torch.Tensor): Index of the chosen best candidate.
- accept_length (int): Length of the accepted candidate sequence.
"""
# Greedy decoding based on temperature value
if temperature == 0:
# Find the tokens that match the maximum logits for each position in the sequence
posterior_mask = (
candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
).int()
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
accept_length = candidates_accept_length.max()
# Choose the best candidate
if accept_length == 0:
# Default to the first candidate if none are accepted
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
else:
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
return best_candidate, accept_length
if sampling == 'typical':
if fast:
posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
candidates_prob = torch.gather(
posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
).squeeze(-1)
posterior_entropy = -torch.sum(
posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
) # torch.sum(torch.log(*)) is faster than torch.prod
threshold = torch.minimum(
torch.ones_like(posterior_entropy) * posterior_threshold,
torch.exp(-posterior_entropy) * posterior_alpha,
)
posterior_mask = candidates_prob > threshold
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
# Choose the best candidate based on the evaluated posterior probabilities
accept_length = candidates_accept_length.max()
if accept_length == 0:
# If no candidates are accepted, just choose the first one
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
else:
best_candidates = torch.where(candidates_accept_length == accept_length)[0]
# Accept the best one according to likelihood
likelihood = torch.sum(
torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
)
best_candidate = best_candidates[torch.argmax(likelihood)]
return best_candidate, accept_length
# Calculate posterior probabilities and thresholds for candidate selection
posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha)
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
# Choose the best candidate based on the evaluated posterior probabilities
accept_length = candidates_accept_length.max()
if accept_length == 0:
# If no candidates are accepted, just choose the first one
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
else:
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
return best_candidate, accept_length
if sampling == 'nucleus':
assert top_p < 1.0 + 1e-6, "top_p should be between 0 and 1"
posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
accept_length = candidates_accept_length.max()
# Choose the best candidate
if accept_length == 0:
# Default to the first candidate if none are accepted
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
else:
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
return best_candidate, accept_length
else:
raise NotImplementedError
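All three branches above reduce a boolean acceptance mask to an accepted prefix length with `torch.cumprod(mask, dim=1).sum(dim=1)`: the running product drops to zero at the first rejected token and stays there, so the sum counts only leading accepted tokens. A plain-Python equivalent (hypothetical helper name):

```python
def leading_accept_lengths(mask_rows):
    """Equivalent of torch.cumprod(mask, dim=1).sum(dim=1): the number of
    accepted tokens before the first rejection in each candidate row."""
    lengths = []
    for row in mask_rows:
        n = 0
        for ok in row:
            if not ok:
                break  # cumprod hits 0 here and never recovers
            n += 1
        lengths.append(n)
    return lengths

masks = [
    [1, 1, 0, 1],  # the later 1 is unreachable: length 2
    [1, 1, 1, 0],  # length 3 -> this candidate wins
    [0, 1, 1, 1],  # first token rejected: length 0
]
print(leading_accept_lengths(masks))  # [2, 3, 0]
```

`argmax` over these lengths then picks the winning candidate (index 1 here), and `accept_length` many tokens plus the root token are appended in one step.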
def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
"""
Generates a posterior mask for token candidates using nucleus (top-p) sampling.
This function applies nucleus sampling to a set of logits, and then generates a mask indicating
which candidate tokens are selected. It adapts the sampling strategy to accommodate for
temperature scaling and cumulative probability thresholding.
Args:
logits (torch.Tensor): A tensor of logits from a language model output.
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
top_p (float): The cumulative probability threshold for nucleus sampling.
Returns:
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
"""
# adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79
# Apply temperature
logits = logits[:, :-1] / temperature
n_samples, n_tokens = logits.shape[0], logits.shape[1]
logits = logits.view(n_samples*n_tokens, -1)
if top_p >= 1:
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
return posterior_mask
# Convert to probabilities (softmax)
probs = F.softmax(logits, dim=-1)
# Sort the probabilities
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
# Compute cumulative probabilities
cum_probs = torch.cumsum(sorted_logits, dim=-1)
# Create mask for the top-p nucleus
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
# Remove low-probability tokens
logits[indices_to_remove] = float('-inf')
# Sample from the remaining tokens
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
# Create a mask for selected tokens
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
return posterior_mask
def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
"""
Args:
logits (torch.Tensor): A tensor of logits from a language model output.
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling.
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
Returns:
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
"""
logits = logits[:, :-1] / temperature
n_samples, n_tokens = logits.shape[0], logits.shape[1]
logits = logits.view(n_samples*n_tokens, -1)
probs = F.softmax(logits, dim=-1)
entropy = -torch.sum(
probs * torch.log(probs + 1e-5), dim=-1
)
threshold = torch.minimum(
torch.ones_like(entropy) * posterior_threshold,
torch.exp(-entropy) * posterior_alpha,
)
indices_to_remove = probs < threshold.unsqueeze(-1)
logits[indices_to_remove] = float('-inf')
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
return posterior_mask
def update_inference_inputs(
input_ids,
candidates,
best_candidate,
accept_length,
retrieve_indices,
outputs,
logits,
medusa_logits,
new_token,
past_key_values_data,
current_length_data,
):
"""
Update the input sequences and relevant tensors based on the selected best candidate from the inference results.
Args:
- input_ids (torch.Tensor): Current input token sequences.
- candidates (torch.Tensor): Candidate token sequences generated in the current step.
- best_candidate (int): Index of the chosen best candidate.
- accept_length (int): Length of the accepted candidate sequence.
- retrieve_indices (torch.Tensor): Indices to map tree to a cartesian product.
- outputs, logits, medusa_logits (torch.Tensor): Model's outputs from the previous inference step.
- new_token (int): Counter for the new tokens added during inference.
- past_key_values_data (torch.Tensor): Tensor containing past hidden states for the transformer model.
- current_length_data (torch.Tensor): Tensor containing the current length of sequences in the batch.
Returns:
- input_ids (torch.Tensor): Updated input token sequences.
- logits (torch.Tensor): Updated logits.
- medusa_logits (torch.Tensor): Updated medusa logits.
- new_token (int): Updated counter for the new tokens added.
"""
# Calculate the starting position for new tokens based on the previous input length
prev_input_len = input_ids.shape[1]
# Map the best candidate indices to the original indices in the sequence
select_indices = (
retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
)
# Append the tokens from the best candidate to the input sequence
input_ids = torch.cat(
[input_ids, candidates[None, best_candidate, : accept_length + 1]], dim=-1
)
# Update the past key values based on the selected tokens
# Source tensor that contains relevant past information based on the selected candidate
tgt = past_key_values_data[..., select_indices, :]
# Destination tensor where the relevant past information will be stored
dst = past_key_values_data[..., prev_input_len : prev_input_len + tgt.shape[-2], :]
# Copy relevant past information from the source to the destination
dst.copy_(tgt, non_blocking=True)
# Update the current length tensor (currently only support batch size is 1)
current_length_data.fill_(prev_input_len + tgt.shape[-2])
# Extract logits and medusa logits for the accepted tokens
logits = logits[None, best_candidate, accept_length : accept_length + 1]
medusa_logits = medusa_logits[
:, None, best_candidate, accept_length : accept_length + 1
]
# Update the new token counter
new_token += accept_length + 1
return input_ids, logits, medusa_logits, new_token
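The cache update above is easiest to see with concrete numbers: the accepted tokens' KV entries sit at scattered tree positions and are copied into the contiguous slots right after the old prompt. A toy trace of the index arithmetic (made-up values; batch size 1, as the function assumes):

```python
# Toy trace of the index arithmetic in update_inference_inputs.
prev_input_len = 10                    # tokens already in the KV cache
retrieve_indices = [
    [0, 1, 4, 9],                      # candidate 0's slots in the tree buffer
    [0, 2, 5, 9],                      # candidate 1's slots
]
best_candidate, accept_length = 1, 2   # candidate 1; 2 tokens accepted after the root

# Scattered cache positions of the accepted tokens (root + accepted):
select_indices = [i + prev_input_len
                  for i in retrieve_indices[best_candidate][:accept_length + 1]]
print(select_indices)  # [10, 12, 15]

# dst.copy_(tgt) then compacts these entries into the contiguous range
# [10, 13), and current_length_data is reset to prev_input_len + 3.
```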
