快速Transformer解码:一个写头就足够了

摘要

https://arxiv.org/pdf/1911.02150

Transformer神经序列模型中使用的多头注意力层是一种强大的替代RNN的方法,用于在序列内部和序列之间传递信息。虽然由于序列长度上的并行化,训练这些层通常快速且简单,但增量推理(在这种情况下并行化是不可能的)通常较慢,这是由于反复加载大型"键"和"值"张量所带来的内存带宽成本。我们提出了一种称为多查询注意力的变体,其中键和值在所有不同的注意力"头"之间共享,大大减少了这些张量的大小,从而降低了增量解码的内存带宽需求。我们通过实验验证了由此产生的模型确实可以更快地解码,并且与基线相比仅产生轻微的质量下降。

1 引言

Transformer神经序列模型[Vaswani et al., 2017]已成为循环序列模型的一种流行替代方案。Transformer依赖注意力层在序列之间和跨序列传递信息。Transformer的一个主要挑战是增量推理的速度。正如我们将要讨论的,在现代计算硬件上,增量Transformer推理的速度受限于重新加载编码注意力层状态的大型"键"和"值"张量所需的内存带宽。在接下来的部分中,我们将回顾Transformer使用的多头注意力层,提供性能分析,并提出一种架构变体(多查询注意力),该变体大大提高了推理速度,同时仅产生轻微的质量下降。

2 背景:神经注意力

神经注意力,由[Bahdanau et al., 2014]引入,是操作可变长度表示的有力工具。神经注意力函数接收单个查询向量qqmm个不同的(键向量,值向量)对(由矩阵KKVV表示),并产生输出向量yy。输出yy被计算为不同值向量的加权和,其中权重是通过将查询与键进行比较得出的。

2.1 点积注意力

以下代码描述了一种常见形式,其中权重被计算为查询与不同键的点积的softmax。

def DotProductAttention(q,K,V):
    """ 一个查询上的点积注意力。
    参数:
    q: 一个形状为[$k$]的向量
    K: 一个形状为[$m$, $k$]的矩阵
    V: 一个形状为[$m$, $v$]的矩阵
    返回:
    y: 一个形状为[$v$]的向量
    """
    logits = tf.einsum("k,mk->m",q,K)
    weights = tf.softmax(logits)
    return tf.einsum("m,mv->v", weights,V)

我们的代码示例使用TensorFlow和numpy中定义的einsum表示法,用于任意维度张量之间的广义收缩。在这种表示法中,方程命名输入和输出张量的维度。计算在数值上等同于将每个输入广播为具有所有维度的并集,逐元素相乘,并对所需输出形状中不存在的所有维度求和。

2.2 多头注意力

"Transformer"序列到序列模型[Vaswani et al., 2017]并行使用hh个不同的注意力层(头),作者将其称为"多头注意力"。hh个不同层的查询向量是从输入向量xxhh个不同的学习线性投影PqP_{q}派生的。类似地,键和值是从mm个不同输入向量的集合MMhh个不同的学习线性投影PkP_{k},PvP_{v}派生的。hh层的输出本身通过不同的学习线性投影PoP_{o}传递,然后求和。为简化起见,我们给输入和输出向量相同的维度dd。计算可以表示如下:

def MultiheadAttention(x,M,P_q,P_k,P_v,P_o):
    """ 一个查询上的多头注意力。
    参数:
    X:一个形状为[$d$]的向量
    M:一个形状为[$m$, $d$]的矩阵
    P_q: 一个形状为[$h$, $d$, $k$]的张量
    P_k: 一个形状为[$h$, $d$, $k$]的张量
    P_v: 一个形状为[$h$, $d$, $v$]的张量
    P_o:一个形状为[$h$, $d$, $v$]的张量
    返回:
    y: 一个形状为[$d$]的向量
    """
    q = tf.einsum("d, hdk->hk",x,P_q)
    K = tf.einsum("md, hdk->hmk",M,P_k)
    V = tf.einsum("md, hdv->hmv",M,P_v)
    logits = tf.einsum("hk, hmk->hm",q,K)
    weights = tf.softmax(logits)
    o = tf.einsum("hm, hmv->hv", weights,V)
    y = tf.einsum("hv, hdv->d",o,P_o)
    return y

注:[Vaswani et al., 2017]在logits上包含一个常数缩放因子。我们在代码中省略了这一点,因为它可以折叠到线性投影PqP_{q}PkP_{k}中。

2.3 多头注意力(批处理)

实际上,将多个查询批处理在一起要高效得多。下面的代码添加了两种类型的批处理。首先,我们从序列中nn个不同位置生成查询。这些查询都与相同的键和值交互。此外,我们一次处理bb个不同的非交互序列。按照[Vaswani et al., 2017],在自回归模型中,我们可以通过在logits中添加一个"掩码"来防止向后信息流,该掩码在非法位置包含值-\infty

def MultiheadAttentionBatched(X,M,mask,P_q,P_k,P_v,P_o):
    """ 多头注意力。
    参数:
    X: 一个形状为[$b$, $n$, $d$]的张量
    M: 一个形状为[$b$, $m$, $d$]的张量
    mask: 一个形状为[$b$, $h$, $n$, $m$]的张量
    P_q: 一个形状为[$h$, $d$, $k$]的张量
    P_k: 一个形状为[$h$, $d$, $k$]的张量
    P_v: 一个形状为[$h$, $d$, $v$]的张量
    P_o: 一个形状为[$h$, $d$, $v$]的张量
    返回:
    Y: 一个形状为[$b$, $n$, $d$]的张量
    """
    Q = tf.einsum("bnd, hdk->bhnk",X, P_q)
    K = tf.einsum("bmd, hdk->bhmk",M, P_k)
    V = tf.einsum("bmd, hdv->bhmv",M, P_v)
    logits = tf.einsum("bhnk, bhmk->bhnm",Q,K)
    weights = tf.softmax(logits + mask)
    O = tf.einsum("bhnm, bhmv->bhnv", weights,V)
    Y = tf.einsum("bhnv, hdv->bnd",O,P_o)
    return Y

2.3.1 批处理多头注意力的性能分析

为了简化性能分析,我们将做出几个简化假设:

m=nm=n

k=v=dhk=v=\frac{d}{h},如[Vaswani et al., 2017]所建议

ndn \leq d

算术运算的总数为Θ(bnd2)\Theta(bnd^{2})。(由于上述每个tf.einsum操作的复杂度在给定简化假设下为O(bnd2)O(bnd^{2})。)

要访问的内存总大小等于所涉及的所有张量大小之和:O(bnd+bhn2+d2)O(bnd+bhn^{2}+d^{2})。第一项是由于XX,MM,QQ,KK,VV,OOYY,第二项是由于logits和weights,第三项是由于投影张量PqP_{q},PkP_{k},PvP_{v}PoP_{o}

将两者相除,我们发现内存访问与算术运算的比率为O(1k+1bn)O(\frac{1}{k}+\frac{1}{bn})。在现代GPU/TPU硬件上,这种低比率对于良好性能是必要的,其中计算能力可能比内存带宽高两个数量级。

2.4 多头注意力(增量式)

在某些设置中,数据依赖性使得无法并行处理来自序列中多个位置的查询。例如,在自回归语言模型(如Transformer[Vaswani et al., 2017])中的自注意力层。在每个位置产生的查询关注到该位置及之前所有位置产生的键-值对。在训练期间,已知真实目标序列,我们可以使用类似于第2.3节中的高效并行实现。然而,当从训练好的模型生成时,特定位置的自注意力层的输出会影响在下一个位置生成的标记,这又影响该层在下一个位置的输入。这阻止了并行计算。下面显示了增量计算此自注意力层的代码。

def MultiheadSelfAttentionIncremental(X,prev_K,prev_V,P_q,P_k,P_v,P_o):
    """ 多头自注意力(一步)。
    参数:
    X: 一个形状为[$b$, $d$]的张量
    prev_K: 一个形状为[$b$, $h$, $m$, $k$]的张量
    prev_V: 一个形状为[$b$, $h$, $m$, $v$]的张量
    P_q: 一个形状为[$h$, $d$, $k$]的张量
    P_k: 一个形状为[$h$, $d$, $k$]的张量
    P_v: 一个形状为[$h$, $d$, $v$]的张量
    P_o: 一个形状为[$h$, $d$, $v$]的张量
    返回:
    y: 一个形状为[$b$, $d$]的张量
    new_K: 一个形状为[$b$, $h$, $m+1$, $k$]的张量
    new_V: 一个形状为[$b$, $h$, $m+1$, $v$]的张量
    """
    q = tf.einsum("bd, hdk->bhk",x,P_q)
    new_K = tf.concat(
        [prev_K,tf.expand_dims(tf.einsum("bd, hdk->bhk",M,P_k), axis=2)],
        axis=2)
    new_V = tf.concat(
        [prev_V,tf.expand_dims(tf.einsum("bd, hdv->bhv",M,P_v), axis=2)],
        axis=2)
    logits = tf.einsum("bhk, bhmk->bhm",q,new_K)
    weights = tf.softmax(logits)
    o = tf.einsum("bhm, bhmv->bhv", weights, new_V)
    y = tf.einsum("bhv, hdv->bd",O,P_o)
    return y, new_K, new_V

2.4.1 性能分析

我们做出与第2.3.1节相同的简化假设。

nn次调用中,算术运算的总数再次为Θ(bnd2)\Theta(bnd^{2})

nn次调用中,内存访问的总量为Θ(bn2d+nd2)\Theta(bn^{2}d+nd^{2}),第一项是由于KKVV,第二项是由于PqP_{q},PkP_{k},PvP_{v}PoP_{o}

将内存除以计算量,我们发现内存访问与算术运算的比率为Θ(nd+1b)\Theta(\frac{n}{d}+\frac{1}{b})。当ndn\approx db1b\approx1时,比率接近1,导致内存带宽成为现代计算硬件上的主要性能瓶颈。为了使增量生成高效,我们必须将这两个项都减少到1\ll11b\frac{1}{b}项更容易解决 - 我们只需使用更大的批量大小(内存大小允许的情况下)。

减少nd\frac{n}{d}项更难。该项与在每一步重新加载表示内存的KKVV张量的费用有关,这些张量的大小为bhmk=bn2bhmk=bn^{2}。一种解决方案是限制序列长度nn。另一种是减少被关注的位置数量,要么关注局部邻域,要么以其他方式压缩内存位置的数量,如[Liu et al., 2018]、[Zhang et al., 2018]、[Povey et al., 2018]中所述。在本文中,我们提出了一种减少KKVV张量大小的正交方法 - 即移除它们的"头"维度,同时保持查询中的"头"维度。

3 多查询注意力

我们将多查询注意力作为[Vaswani et al., 2017]中描述的多头注意力的一种变体引入。多头注意力由多个并行的注意力层(头)组成,对查询、键、值和输出有不同的线性变换。多查询注意力除了不同头共享一组键和值外,其他完全相同。增量多查询(自)注意力的代码与上面列出的多头注意力代码相同,只是我们在tf.einsum方程中移除了字母"h",其中它表示KK,VV,PkP_{k}PvP_{v}的"头"维度。

def MultiqueryAttentionBatched(X,M,mask,P_q,P_k,P_v,P_o):
    """ 多查询注意力。
    参数:
    X: 一个形状为[$b$, $n$, $d$]的张量
    M: 一个形状为[$b$, $m$, $d$]的张量
    mask: 一个形状为[$b$, $h$, $n$, $m$]的张量
    P_q: 一个形状为[$h$, $d$, $k$]的张量
    P_k: 一个形状为[$d$, $k$]的张量
    P_v: 一个形状为[$d$, $v$]的张量
    P_o: 一个形状为[$h$, $d$, $v$]的张量
    返回:
    Y: 一个形状为[$b$, $n$, $d$]的张量
    """
    Q = tf.einsum("bnd, hdk->bhnk",X, P_q)
    K = tf.einsum("bmd, dk->bmk",M,P_k)
    V = tf.einsum("bmd, dv->bmv",M, P_v)
    logits = tf.einsum("bhnk, bmk->bhnm",Q,K)
    weights = tf.softmax(logits + mask)
    O = tf.einsum("bhnm, bmv->bhnv", weights,V)
    Y = tf.einsum("bhnv, hdv->bnd",O,P_o)
    return Y
def MultiquerySelfAttentionIncremental(X,prev_K,prev_V,P_q,P_k,P_v,P_o):
    """ 多查询自注意力(一步)。
    参数:
    X: 一个形状为[$b$, $d$]的张量
    prev_K: 一个形状为[$b$, $m$, $k$]的张量
    prev_V: 一个形状为[$b$, $m$, $v$]的张量
    P_q: 一个形状为[$h$, $d$, $k$]的张量
    P_k: 一个形状为[$d$, $k$]的张量
    P_v: 一个形状为[$d$, $v$]的张量
    P_o: 一个形状为[$h$, $d$, $v$]的张量
    返回:
    y: 一个形状为[$b$, $d$]的张量
    new_K: 一个形状为[$b$, $m+1$, $k$]的张量
    new_V: 一个形状为[$b$, $m+1$, $v$]的张量
    """
    q = tf.einsum("bd, hdk->bhk",x,P_q)
    K = tf.concat(
        [prev_K,tf.expand_dims(tf.einsum("bd, dk->bk",M,P_k), axis=1)],
        axis=1)
    V = tf.concat(
        [prev_V,tf.expand_dims(tf.einsum("bd, dv->bv",M,P_v), axis=1)],
        axis=1)
    logits = tf.einsum("bhk, bmk->bhm",q,K)
    weights = tf.softmax(logits)
    O = tf.einsum("bhm, bmv->bhv", weights,V)
    y = tf.einsum("bhv, hdv->bd",O,P_o)
    return y,K,V

3.1 增量多查询注意力的性能分析

我们做出与第2.3.1节相同的简化假设。

nn次调用中,算术运算的总数再次为Θ(bnd2)\Theta(bnd^{2})

nn次调用中,内存访问的总量为Θ(bnd+bn2k+nd2)\Theta(bnd+bn^{2}k+nd^{2}),第一项是由于XX,qq,OOyy,第二项是由于KKVV,第三项是由于PqP_{q},PkP_{k},PvP_{v},PoP_{o}

将内存除以计算量,我们发现内存访问与算术运算的比率为Θ(1d+ndh+1b)\Theta(\frac{1}{d}+\frac{n}{dh}+\frac{1}{b})。我们已将有问题的nd\frac{n}{d}项减少了hh倍。理论上,给定大批量大小bb,这应该会显著提高增量生成的性能。在我们的实验部分,我们将展示性能增益是真实的,并且模型质量保持较高。

4 实验和结果

4.1 实验设置

按照[Vaswani et al., 2017],我们在WMT_2014英德翻译任务上进行评估。作为基线,我们使用具有6层的编码器-解码器Transformer模型,使用dmodel=1024d_{model}=1024dff=4096d_{ff}=4096h=8h=8dk=dv=128d_{k}=d_{v}=128,学习的位置嵌入,以及词嵌入和输出层之间的权重共享。基线模型和所有变体都有2.11亿个参数。所有模型都训练了100,000步(20个周期)。每个训练批次由128个示例组成,每个示例由256个标记的输入序列和256个标记的目标序列组成(多个训练句子被连接在一起以达到此长度)。模型在32核TPUv3集群上训练,每个模型大约需要2小时来训练。我们使用了来自tensor2tensor和mesh-tensorflow库的实现。

使用的配置可以在[出版前添加]找到,包括关于学习率、dropout、标签平滑等的详细信息。

在我们的"多查询"模型中,我们将模型中的所有注意力层替换为多查询注意力。这包括编码器自注意力、解码器自注意力和编码器-解码器注意力层。我们将前馈隐藏层从4096加宽到5440,以使总参数计数等于基线。

为了证明局部注意力和多查询注意力是正交的,我们还训练了基线和多查询模型的"局部"版本,其中解码器自注意力层(但不是其他注意力层)将注意力限制在当前位置和前31个位置。

减少KKVV大小的一种更简单的替代方法是减少头数hh和/或减少键和值的维度kkvv。我们训练了几个这样的模型进行比较,同样将前馈隐藏层加宽,以使总参数计数等于基线。

我们还使用"transformer-decoder"语言模型在BillionWord语言建模基准[Chelba et al., 2013]上进行了一组类似的实验。对于基线,我们使用具有6层的模型,dmodel=1024d_{model}=1024dff=8192d_{ff}=8192h=8h=8dk=dv=128d_{k}=d_{v}=128。基线和所有变体的总参数计数为1.92亿。我们在批量大小为64K标记的情况下训练了136K步(10个周期)。同样,我们使用32核TPUv3集群训练每个模型约3小时。

4.2 模型质量

表1显示了机器翻译实验的结果。我们使用贪婪最大似然解码对开发集进行解码,并使用sacrebleu "sacrebleu -t wmt13 -l en-de -tok intl"计算BLEU分数。我们还列出了开发集上每个子词标记的困惑度。根据这两个指标,多查询注意力模型似乎比基线稍差,但比任何涉及减少hhdkd_{k}dvd_{v}的替代方案要接近得多。

我们通过使用贪婪解码和束搜索(束大小4,α=0.6\alpha=0.6)对测试集进行解码,并使用sacrebleu "sacrebleu -t wmt14 -l en-de -tok intl"进行评估来验证结果。同样,多查询模型表现与基线相似,实际上在束搜索-4解码中获得了最高的BLEU分数(28.5)。

表3显示了十亿词语言建模基准的结果。模型通过开发集上每个词(而不是每个子词标记)的困惑度进行评估。结果与翻译结果相似。多查询注意力模型略差于基线,但明显优于任何涉及减少hhdkd_{k}dvd_{v}的替代方案。

4.3 速度

表2显示了各种模型的训练和推理时间。训练和推理速度均在一台TPUv2(8核)上评估。一个训练步骤(包含32,768个输入标记和32,768个目标标记,如上所述)对于基线模型需要433毫秒,对于多查询模型需要425毫秒。除以32,768,我们发现训练时间为每(输入标记+目标标记)13.2微秒,如表2所列。

我们在1024个序列的批次上运行增量贪婪推理(每核128个),使用128个标记的源序列长度和128个标记的目标序列长度。对于基线模型,模型的编码器部分耗时222毫秒,解码器的每个增量步骤耗时47毫秒。除以相应的标记数,我们发现编码器的摊销推理时间为每标记1.7微秒,而解码器则大得多,为每标记46微秒,如表2所列。对于多查询模型,编码器耗时195毫秒,解码器每步耗时3.9毫秒,摊销的每标记成本分别为1.5微秒和3.8微秒。表2显示了这些值以及束搜索的类似结果。

表1: WMT14 EN-DE结果。
注意力类型 $h$ $d_{k},d_{v}$ $d_{ff}$ ln(PPL) (dev) BLEU (test) beam 1 / 4
多头 8 128 4096 1.424 26.7 27.7 28.4
多查询 8 128 5440 1.439 26.5 27.5 28.5
多头局部 8 128 4096 1.427 26.6 27.5 /28.3
多查询局部 8 128 5440 1.437 26.5 27.6 / 28.2
多头 1 128 6784 1.518 25.8
多头 2 64 6784 1.480 26.2 26.8 / 27.9
多头 4 32 6784 1.488 26.1
多头 8 16 6784 1.513 25.8
表2: WMT14 EN-DE翻译任务的摊销训练和推理成本,序列长度为128。列出的值是以TPUv2微秒为单位的每输出标记成本。
注意力类型 训练 推理 enc. + dec. 束搜索-4 enc. + dec.
多头 13.2 1.7 + 46 2.0 + 203
多查询 13.0 1.5 + 3.8 1.6 + 32
多头局部 13.2 1.7 + 23 1.9 + 47
多查询局部 13.0 1.5 + 3.3 1.6 + 16
表3: 十亿词语言模型基准结果。
注意力类型 $h$ $d_{k},d_{v}$ $d_{ff}$ dev-PPL
多头 8 128 8192 29.9
多查询 8 128 9088 30.2
多头 1 128 9984 31.2
多头 2 64 9984 31.1
多头 4 32 9984 31.0
多头 8 16 9984 30.9
结论

我们提出了多查询注意力 - 一种多头注意力的替代方案,在增量设置中具有低得多的内存带宽需求。我们相信这使得基于注意力的序列模型能够在推理性能关键的应用中得到更广泛的采用。

0
0
0
0
评论
未登录
暂无评论