MHA(多头注意力)
Transformer 编码器块内的缩放点积注意力机制和多头注意力机制
MHA计算过程
标准注意力层分为4个核心阶段,输入为序列矩阵
(其中
为序列长度,
为模型隐藏层维度):
1. QKV线性投影
输入
通过可学习的权重矩阵
(
为注意力头的维度),分别映射为查询(Query)、键(Key)、值(Value)矩阵:
其中
,这一步的核心是将原始输入转换为注意力计算所需的表征空间。
2. 缩放点积注意力(SDPA)
计算查询与键的相似度,经缩放和softmax归一化后,对值矩阵进行加权求和,得到单头注意力输出:
- 缩放因子
用于缓解维度增长导致的内积值过大、softmax梯度消失问题;
- softmax确保注意力权重非负且每行和为1,实现对值的加权聚合。
3. 多头拼接
为捕捉多维度的语义信息,将注意力机制并行扩展为
个“头”(每个头有独立的
),并将所有头的输出拼接:
其中
,拼接后的输出维度为
。
4. 最终输出层
拼接后的多头输出通过一个线性变换层
,映射回模型的隐藏层维度,得到注意力层的最终输出:
用门控机制增强注意力层
门控可插入注意力层的5个关键节点(如下图左所示),直接影响其作用对象和效果:
- :SDPA输出后(即单头注意力计算完成后、多头拼接前);
- :分别在V、K、Q投影后(即
、
、
之后);
- :多头拼接后的最终输出层之后(即
之后)。
不同位置对应不同的调制对象(如
调制注意力聚合后的结果,
调制值矩阵本身),实验证明在 SDPA 输出之后加入门控机制(
)效果最好效果最优。如下图右:加上
门控后,Loss 曲线变得极其平滑,且 PPL 显著下降。
形式化表示为:
- :需要被调制(筛选/增强)的输入(如SDPA输出、QKV投影结果等);
- :用于计算门控分数的辅助输入(论文中采用预归一化后的隐藏状态);
- :门控的可学习参数;
- :激活函数(控制门控分数的范围);
- :门控调制后的输出。
门控配置
门控变体性能和结果
- 位置:
(SDPA输出后);
- 粒度:Elementwise:门控分数与
维度完全一致(如
为
,则门控分数也为
),实现逐元素的精细调制,且增加参数极少;相比Headwise,Elementwise门控更灵活,但参数和计算量略高;头级门控简洁高效。
- 结合方式:乘法型(Multiplicative):
,门控分数作为“权重”缩放
的元素(分数为0时完全过滤,1时完全保留);乘法型直接控制信息的“有无”,筛选能力更强;加法型仅微调信息强度,效果较弱。论文中乘法型(尤其是sigmoid激活的乘法型)表现最优。
- 激活函数:Sigmoid(
):输出范围为 [0, 1],天然适合作为“开关”(0=关闭,1=开启),仅用于乘法型门控。
- Head-Specific:每个注意力头有独立的
和门控分数,支持不同头根据自身功能(如有的头捕捉语法、有的捕捉语义)进行差异化调制;头特异性效果显著优于头共享。
分析-Gating机制为什么简单有效?
机制一:引入非线性
Gating引入非线性:传统Transformer的注意力层存在一个隐藏局限:连续线性映射导致的表达能力不足 。
1. 传统注意力的“低秩问题”
在多头注意力中,单个头的输出需经过两个关键线性层:
- **值投影(
)** :将输入映射为值向量(
); 2. **输出投影(
)** :将多头拼接后的结果映射回模型隐藏层维度(
)。
这两个线性层可合并为一个低秩线性映射 ——由于注意力头维度
(例如论文中
,
通常为数千),合并后的矩阵秩被限制在
,导致信息处理能力有限(类似“用直线拟合复杂曲线”)。若使用GQA(分组查询注意力),同一组内的头共享
,低秩问题会更突出。
2. Gating如何引入非线性?
Gating通过在线性映射之间插入非线性操作 ,打破低秩限制:
- 若门控加在
(SDPA输出后):相当于在“值投影(
)→ SDPA聚合”与“输出投影(
)”之间,插入sigmoid门控的非线性调制;
- 若门控加在
(值投影后):相当于在“值投影(
)”与“SDPA聚合”之间插入非线性。
论文通过实验验证了非线性的必要性如上表:
- 仅在SDPA后加SiLU(无门控参数),PPL从6.026降至5.975(虽有提升但有限);
- 移除加法型门控的SiLU(变为恒等映射),PPL回升至5.882,基准分数显著下降;
- 而SDPA元素级门控(sigmoid激活,强非线性)能将PPL降至5.761,MMLU提升2分以上。
这说明:非线性是提升表达能力的关键 ,而Gating的sigmoid激活(输出范围[0,1])能精准控制非线性的强度,比单纯的SiLU或归一化(如RMSNorm)更有效。
机制2:引入“输入依赖的稀疏性”
Gating的另一大优势是动态生成稀疏的门控分数 ——仅保留对当前查询(Query)有用的信息,过滤冗余内容,这是其优于传统注意力的核心差异。
1. 稀疏性的“输入依赖性”
论文4.2节强调:有效的稀疏性必须是查询依赖 的——门控分数由当前查询的隐藏态计算(而非固定值或仅依赖键/值),能针对不同查询动态调整筛选策略。具体为:
- **SDPA输出门控(
)** :门控分数基于“当前查询对应的SDPA聚合结果”计算,直接关联查询的语义需求;
- **值投影门控(
)** :门控分数仅依赖键/值的隐藏态,与查询无关,筛选精度更低。
实验如上表:
- SDPA元素级门控的平均分数仅0.116(大量分数接近0,稀疏性强),而值元素级门控的平均分数为0.221(稀疏性弱);
- 对应的性能:SDPA门控的MMLU达60.82,GSM8k达55.27,均显著高于值门控(MMLU 59.17,GSM8k 53.97)。
若强制门控“输入无关”(如零初始化门控参数,固定分数),即使保留非线性,PPL也仅降至5.917,远差于输入依赖的门控(5.761),证明稀疏性必须与查询绑定才能发挥作用 。
2. 头特异性增强稀疏性
论文还发现:每个注意力头需独立的门控分数 (头特异性),而非所有头共享(头共享)。原因是不同头负责不同任务(如部分头捕捉语法结构,部分头捕捉语义关联),独立门控能针对性筛选各头的有效信息。
- SDPA头级门控(头特异性,仅加1.6M参数)的PPL为5.792,MMLU 60.05;
- 相同位置的头共享门控(加201M参数)的PPL升至5.801,MMLU仅60.06(参数更多但效果更差)。
稀疏性的“粒度”至关重要——头特异性让门控更精准,避免了共享门控的“一刀切”问题。
衍生优势:根除“注意力Sink”,提升长上下文能力
注意力Sink回顾
注意力Sink是传统Transformer的顽疾:模型会过度关注序列早期token(如第一个token),导致后期token的注意力被稀释。
Softmax函数
Softmax 通常用于多类别分类问题中的输出层。在这个公式中,给定一个输入向量 (
),Softmax 函数将其转化为一个概率分布 (
)。每个元素
表示该样本属于第 (
) 类的概率。
表示输入向量中第 (
) 个元素的指数,而分母部分是所有元素的指数之和。这样的设计确保了输出概率分布的归一性,因为指数函数的性质使得所有元素都为正数,而分母的和则确保了概率总和为 1。
当
时,输入向量的第一个元素
往往会远远大于其他元素,这可以帮助模型在分类时更明确地选择一个主要的类别。
attention sink
语言模型(LLMs)在注意力机制中存在过度关注初始token(initial tokens)的现象。从以下两个角度探索下。
softmax角度
在SoftMax函数中,
是指数函数,这意味着即使输入的初始token (
) 在语义上与语言建模不相关,由于指数函数的存在,SoftMax 函数的输出中它仍然会有一个非零的值。因此,模型在进行自注意力机制时,即使当前的嵌入已经包含了足够的自包含信息用于预测,模型仍然需要从其他头和层中的其他token中汇聚一些信息。结果,模型倾向于将不必要的注意力值“倾泻 ”到特定的token上,这就是所谓的attention sink (注意力汇聚)。
- 局部模式在前两层中的呈现: 在第一层和第二层(layers 0 和 1),注意力图呈现出“ 局部 ”模式,即对最近的token给予更多的关注。这表明模型在初步的处理阶段更注重周围的token,强调了局部上下文的重要性。
- 模型跨所有层和头部(heads)都强烈关注初始token: 在底部两层之外,也就是深层网络中,模型在所有层和头部都倾向于强烈关注初始token。这与“ attention sink ”现象相吻合,即模型在处理过程中过度关注初始token,而 SoftMax 函数的特性是导致这种现象的原因之一。
自回归模型角度
由于自回归语言建模的顺序性质 ,初始token对所有后续token都是可见的 ,而后续token只对一组有限的后续token可见 。因此,初始token更容易被训练成注意力的聚焦点,捕捉到不必要的关注(在训练中可能导致模型过度集中注意力,捕捉到一些在语言建模任务中并不重要的信息。)。如何理解呢?
- 自回归性质: 这些模型是自回归的,即它们根据之前生成的token来生成下一个token。在这个过程中,初始token是最早生成的token,因此在生成整个序列的过程中,它对所有后续token都是可见的。
- 可见性的不对称性: 由于模型是按顺序生成token的,初始token在生成序列的整个过程中一直是可见的,而后续token只能被生成的一小部分token所看到。这种不对称性导致了对初始token更强烈的关注。
- 训练中的注意力聚焦点: 由于模型在训练过程中学会了将注意力集中在初始token上,这些token更容易成为“attention sink”,即吸引不必要的注意力。这可能是因为初始token在训练中更频繁地与后续token发生交互,从而更容易捕捉到一些模型认为重要的信息。
为什么Gating能消除注意力Sink?
注意力Sink的本质是softmax归一化导致的冗余注意力积累 ——早期token的键向量与后续查询的相似度被反复放大,形成“注意力垄断”。Gating通过以下方式打破这一循环:
- 门控分数是“查询依赖的稀疏过滤器”:若当前查询与早期token无关,门控分数会趋近于0,直接抑制早期token的注意力贡献;
- 稀疏性减少“ massive activation”:早期层FFN的极端激活值会加剧Sink(而Gating让隐藏态激活值从基线的1053降至94(下表,M-Act列),间接减少Sink的诱因。
M-Act列
- 基线模型Layer 21对第一个token的注意力达83%,而SDPA门控模型仅4%;
- 所有层的平均注意力占比(F-Attn)从0.467降至0.048,Sink现象基本消失。
进一步优势:长上下文扩展
注意力Sink的消除直接提升了模型的长上下文外推能力 。论文通过YaRN扩展上下文至128k,原因是:基线依赖Sink维持注意力分布,而门控模型通过查询依赖的稀疏性动态调整注意力,无需依赖固定token,因此对上下文长度变化更鲁棒。
参考文献
- Gated Attention for Large Language Models: Non-linearity, Sparsity,and Attention-Sink-Free,https://arxiv.org/pdf/2505.06708
- Efficient Streaming Language Models with Attention Sinks,https://arxiv.org/abs/2309.17453
