Introduction
单词的顺序对自然语言理解有很大的价值.
本文引入了一种新颖的方法,即 Rotary Position Embedding(RoPE),将位置信息应用到PLM的学习过程中。
具体来说, RoPE用旋转矩阵对绝对位置进行编码,同时在自注意公式中加入了显式的相对位置依赖 。
RoPE优于现有的方法,包括序列长度的灵活性,随着相对距离的增加而衰减token间依赖关系,以及用相对位置编码装备线性self-attention的能力。
简而言之,我们的贡献有以下三方面:
- 我们研究了现有的相对位置编码方法,发现它们大多是基于将位置编码添加到上下文表示的分解思想构建的。我们提出了一种新颖的方法,即 旋转位置嵌入(RoPE) ,将位置信息应用到PLMS的学习过程中。 关键思想是通过将上下文表示与具有明确理论解释的旋转矩阵相乘来编码相对位置 。
- 我们研究了RoPE的性质,表明 它随着相对距离的增加而衰减,这是自然语言编码所需要的 。我们善意地认为,以前基于相对位置编码的方法与线性自注意不兼容
- 我们在各种长文本基准数据集上评估了所提出的RoFormer。我们的实验表明,与它的替代品相比,它始终能够获得更好的性能。
Background and Related Work
Preliminary
Absolute position embedding
x + p 后仿射变换,q、k、v都有了绝对位置信息。
Relative position embedding
将位置m和位置n关联起来。
q没有位置信息,k、v包含相对位置信息。相对位置信息引入使得模型更关注相邻位置的信息,长距离的信息需要衰减。
公式6是绝对位置编码的
展开形式,7是其相对位置编码的形式。
公式8,9,10是其它相对位置编码的形式。
Proposed approach
Formulation
之前的相对位置编码都是在
展开式上最近似替换。RoPE是直接求包含相对位置m-n的
的理论公式。
Rotary position embedding
A 2D case
具体来说,结合相对位置嵌入很简单: 简单地旋转仿射变换的词嵌入向量,以其位置索引的角度倍数旋转 ,从而解释旋转位置嵌入背后的直觉。
仿射变换的词嵌入向量是
;以其位置索引的角度倍数旋转是
General form
公式16最后一个等式的第一个
应该会
与以往工作中采用的位置嵌入方法(即公式(3)至公式(10)的 加性 不同, 我们的方法是乘法的 。
此外,RoPE在应用自注意时,自然地 通过旋转矩阵积来整合相对位置信息 ,而不是改变可加位置编码扩展公式中的项。
x 经过q、k Linear后,再旋转向量就是引入相对位置编码(RoPE)后的
了。
Properties of RoPE
Long-term decay
设
θ
。
我们可以证明这种设置提供了一个长期的衰减特性,这意味着 当相对位置增加时内积会衰减 。这个性质与直觉一致,一对相对距离较长的token应该少些联系。
RoPE with linear attention
Transformers are RNNs:Fast Autoregressive Transformers with Linear Attention: https://arxiv.org/pdf/2006.16236v3.pdf
Efficient Attention: Attention with Linear Complexities: https://arxiv.org/pdf/1812.01243v9.pdf
Linear Attention 中的 attention score 由
,变成了
。 其中n是序列的长度,可变;
是模型的size,固定。
Theoretical Explanati
Derivation of RoPE under
根据:
证明:
具体证明参看论文。
Computational efficient realization of rotary matrix multiplication
Long-term decay of RoPE
证明略。
Experiments and Evaluatio
Machine Translatio
Pre-training Language Modeling
Performer with RoPE
Evaluation on Chinese Data
Implementation
https://github.com/search?q=repo%3Abojone%2Fbert4keras%20RoPE&type=code
# https://github.com/bojone/bert4keras/blob/a160e18b714c68f57e4d4d47afac2df39b36db50/bert4keras/layers.py#L1428
@recompute\_grad
def call(self, inputs, mask=None):
# 输入变换
inputs = self.dense(inputs)
inputs = tf.split(inputs, self.heads, axis=-1)
inputs = K.stack(inputs, axis=-2)
qw, kw = inputs[..., :self.head_size], inputs[..., self.head_size:]
# RoPE编码
if self.RoPE:
pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
qw, kw = apply_rotary_position_embeddings(pos, qw, kw)
# 计算内积
logits = tf.einsum('bmhd,bnhd->bhmn', qw, kw) / self.head_size**0.5
# 排除下三角
if self.tril_mask:
tril_mask = tf.linalg.band_part(K.ones_like(logits[0, 0]), 0, -1)
tril_mask = K.cast(tril_mask, 'bool')
else:
tril_mask = None
# 返回最终结果
return sequence_masking(logits, mask, -np.inf, [2, 3], tril_mask)
# https://github.com/bojone/bert4keras/blob/a160e18b714c68f57e4d4d47afac2df39b36db50/bert4keras/layers.py#L845
class SinusoidalPositionEmbedding(Layer):
"""定义Sin-Cos位置Embedding
"""
def \_\_init\_\_(
self,
output\_dim,
merge\_mode='add',
custom\_position\_ids=False,
**kwargs
):
super(SinusoidalPositionEmbedding, self).__init__(**kwargs)
self.output_dim = output_dim
self.merge_mode = merge_mode
self.custom_position_ids = custom_position_ids
def call(self, inputs):
"""如果custom\_position\_ids,那么第二个输入为自定义的位置id
"""
if self.custom_position_ids:
inputs, position_ids = inputs
if 'float' not in K.dtype(position_ids):
position_ids = K.cast(position_ids, K.floatx())
else:
input_shape = K.shape(inputs)
batch_size, seq_len = input_shape[0], input_shape[1]
position_ids = K.arange(0, seq_len, dtype=K.floatx())[None]
embeddings = sinusoidal_embeddings(position_ids, self.output_dim)
if self.merge_mode == 'add':
return inputs + embeddings
elif self.merge_mode == 'mul':
return inputs * (embeddings + 1.0)
elif self.merge_mode == 'zero':
return embeddings
else:
if not self.custom_position_ids:
embeddings = K.tile(embeddings, [batch_size, 1, 1])
return K.concatenate([inputs, embeddings])
# https://github.com/bojone/bert4keras/blob/a160e18b714c68f57e4d4d47afac2df39b36db50/bert4keras/backend.py#L336
def sinusoidal\_embeddings(pos, dim, base=10000):
"""计算pos位置的dim维sinusoidal编码
"""
assert dim % 2 == 0
# (d/2,)
indices = K.arange(0, dim // 2, dtype=K.floatx())
indices = K.pow(K.cast(base, K.floatx()), -2 * indices / dim)
# pos (1, T), indices (d/2,) -> (1, T, d/2)
embeddings = tf.einsum('...,d->...d', pos, indices)
# (1, T, d/2, 2)
embeddings = K.stack([K.sin(embeddings), K.cos(embeddings)], axis=-1)
# (1, T, d)
embeddings = K.flatten(embeddings, start_dim=-2)
return embeddings
def align(tensor: tf.Tensor, axes: List[int], ndim=None):
"""重新对齐tensor(批量版expand\_dims)
axes:原来的第i维对齐新tensor的第axes[i]维;
ndim:新tensor的维度。
"""
assert len(axes) == K.ndim(tensor)
assert ndim or min(axes) >= 0
ndim = ndim or max(axes) + 1
# a[0, None, 1] = a[0, np.newaxis, 1]
indices = [None] * ndim
for i in axes:
# slice nothing, a[0, slice(None), 1] = a[0, :, 1]
indices[i] = slice(None)
return tensor[indices]
# https://github.com/bojone/bert4keras/blob/a160e18b714c68f57e4d4d47afac2df39b36db50/bert4keras/backend.py#L359
def apply\_rotary\_position\_embeddings(sinusoidal, *tensors):
"""应用RoPE到tensors中
其中,sinusoidal.shape=[b, n, d],tensors为tensor的列表,而
tensor.shape=[b, n, ..., d]。
"""
assert len(tensors) > 0, 'at least one input tensor'
assert all([
K.int_shape(tensor) == K.int_shape(tensors[0]) for tensor in tensors[1:]
]), 'all tensors must have the same shape'
ndim = K.ndim(tensors[0])
# sinusoidal shape same with tensors[0]
# [b,n,d] -> [b,n,...,d]
sinusoidal = align(sinusoidal, [0, 1, -1], ndim)
# http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api\_docs/python/tf/keras/backend/repeat\_elements.html
# like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
# [b,n, ..., d/2] -> [b,n, ..., d]
cos_pos = K.repeat_elements(sinusoidal[..., 1::2], rep=2, axis=-1)
sin_pos = K.repeat_elements(sinusoidal[..., ::2], rep=2, axis=-1)
outputs = []
for tensor in tensors:
# x2 = [-x2, x1, -x4, x3, ..., -x\_d, x\_{d-1}]
tensor2 = K.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
tensor2 = K.reshape(tensor2, K.shape(tensor))
# 公式 34, out = x * cos\_pos + x2 * sin\_pos
outputs.append(tensor * cos_pos + tensor2 * sin_pos)
return outputs[0] if len(outputs) == 1 else outputs