点击下方卡片,关注「集智书童」公众号
旋转式位置编码(RoPE)最早是论文[1]
提出的一种能够将相对位置信息依赖集成到 self-attention 中并提升 transformer 架构性能的位置编码方式。而目前很火的 LLaMA 模型也是采用该位置编码方式。
接下来结合代码和论文来解读一下 RoPE。
基本概念
首先论文中定义一个长度为 N
的输入序列为:
其中 wi
表示输入序列中第 i
个 token,而输入序列 SN
对应的 embedding 表示为:
其中 xi
表示第 i
个 token wi
对应的 d
维词嵌入向量。
接着在做 self-attention 之前,会用词嵌入向量计算 q, k, v
向量同时加入位置信息,函数公式表达如下:
其中 qm
表示第 m
个 token 对应的词向量 xm
集成位置信息 m
之后的 query 向量。而 kn
和 vn
则表示第 n
个 token 对应的词向量 xn
集成位置信息 n
之后的 key 和 value 向量。
而基于 transformer 的位置编码方法都是着重于构造一个合适的 f{q,k,v}
函数形式。
而计算第 m 个词嵌入向量 xm
对应的 self-attention 输出结果,就是 qm
和其他 kn
都计算一个 attention score ,然后再将 attention score 乘以对应的 vn
再求和得到输出向量 om
:
绝对位置编码
对于位置编码,常规的做法是在计算 query, key 和 value 向量之前,会计算一个位置编码向量 pi
加到词嵌入 xi
上,位置编码向量 pi
同样也是 d
维向量,然后再乘以对应的变换矩阵 W{q,k,v}
:
而经典的位置编码向量 pi
的计算方式是:
其中 p_{i,2t}
表示位置 d
维度向量 pi
中的第 2t
个元素也就是偶数索引位置的计算公式,而 p_{i,2t+1}
就对应奇数索引位置的计算公式。
python 代码如下:
# position 就对应 token 序列中的位置索引 i
# hidden\_dim 就对应词嵌入维度大小 d
# seq\_len 表示 token 序列长度
def get\_position\_angle\_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / hidden_dim) for hid_j in range(hidden_dim)]
# position\_angle\_vecs.shape = [seq\_len, hidden\_dim]
position_angle_vecs = np.array([get_position_angle_vec(pos_i) for pos_i in range(seq_len)])
# 分别计算奇偶索引位置对应的 sin 和 cos 值
position_angle_vecs[:, 0::2] = np.sin(position_angle_vecs[:, 0::2]) # dim 2t
position_angle_vecs[:, 1::2] = np.cos(position_angle_vecs[:, 1::2]) # dim 2t+1
# positional\_embeddings.shape = [1, seq\_len, hidden\_dim]
positional_embeddings = torch.FloatTensor(position_angle_vecs).unsqueeze(0)
旋转式位置编码
接着论文中提出为了能利用上 token 之间的相对位置信息,假定 query 向量 qm
和 key 向量 kn
之间的内积操作可以被一个函数 g
表示,该函数 g
的输入是词嵌入向量 xm
, xn
和它们之间的相对位置 m - n
:
接下来的目标就是找到一个等价的位置编码方式,从而使得上述关系成立。
假定现在词嵌入向量的维度是两维 d=2
,这样就可以利用上2维度平面上的向量的几何性质,然后论文中提出了一个满足上述关系的 f
和 g
的形式如下:
(x_m,m),f_k(x_n,n)>
上面的公式一眼看过去感觉很复杂,怎么理解呢?
首先我们得先了解一下基本的复数相关知识。
首先看到上述 f
和 g
公式中有个指数函数
(x_m,m),f_k(x_n,n)>
这个其实是欧拉公式 [2]
,其中 x
表示任意实数, e
是自然对数的底数,i
是复数中的虚数单位,则根据欧拉公式有:
(x_m,m),f_k(x_n,n)>
上述指数函数可以表示为实部为 cosx
,虚部为 sinx
的一个复数,欧拉公式 [2]
建立了指数函数、三角函数和复数之间的桥梁。
则上述 f
和 g
公式中的
(x_m,m),f_k(x_n,n)>
然后我们看回公式:
(x_m,m),f_k(x_n,n)>
其中 Wq
是个二维矩阵,xm
是个二维向量,相乘的结果也是一个二维向量,这里用 qm
表示:
(x_m,m),f_k(x_n,n)>
然后首先将 qm
表示成复数形式:
(x_m,m),f_k(x_n,n)>
接着
(x_m,m),f_k(x_n,n)>
其实就是两个复数相乘:
(x_m,m),f_k(x_n,n)>
我们首先来复习一下复数乘法的性质:
(x_m,m),f_k(x_n,n)>
可以看到,复数乘法也是用的分配律,还有用到了复数的一个性质:
(x_m,m),f_k(x_n,n)>
然后就有:
(x_m,m),f_k(x_n,n)>
将结果重新表达成实数向量形式就是:
(x_m,m),f_k(x_n,n)>
相信读者看到这里会发现这不就是 query 向量乘以了一个旋转矩阵[5]
吗?
(x_m,m),f_k(x_n,n)>
这就是为什么叫做旋转式位置编码的原因。
同理可得 key 向量 kn
:
(x_m,m),f_k(x_n,n)>
最后还有个函数 g
:
(x_m,m),f_k(x_n,n)>
其中 Re[x]
表示一个复数 x
的实部部分,而
(x_m,m),f_k(x_n,n)>
则表示复数
(x_m,m),f_k(x_n,n)>
的共轭,复习一下共轭复数的定义:
所以可得:
(x_m,m),f_k(x_n,n)>
继续可得:
(x_m,m),f_k(x_n,n)>
ok,接下来我们就要证明函数 g
的计算公式是成立的。
首先回顾一下 attention 操作, 位置 m 的 query 和位置 n 的 key 会做一个内积操作:
(x_m,m),f_k(x_n,n)>
接着继续之前先复习一下三角函数的一些性质[3]
:
好了回到上面那坨式子,我们整理一下:
(x_m,m),f_k(x_n,n)>(x_m,m),f_k(x_n,n)>
这就证明上述关系是成立的,位置 m 的 query 和位置 n 的 key 的内积就是函数 g
。
然后上面的讲解是假定的词嵌入维度是2维向量,而对于d >= 2
的通用情况,则是将词嵌入向量元素按照两两一组分组,每组应用同样的旋转操作且每组的旋转角度计算方式如下:
(x_m,m),f_k(x_n,n)>(x_m,m),f_k(x_n,n)>(x_m,m),f_k(x_n,n)>
所以简单来说 RoPE 的 self-attention 操作的流程是,对于 token 序列中的每个词嵌入向量,首先计算其对应的 query 和 key 向量,然后对每个 token 位置都计算对应的旋转位置编码,接着对每个 token 位置的 query 和 key 向量的元素按照 两两一组 应用旋转变换,最后再计算 query 和 key 之间的内积得到 self-attention 的计算结果。
论文中有个很直观的图片展示了旋转变换的过程:
LLaMA 官方实现代码
[4]
如下(经过简化):
def precompute\_freqs\_cis(dim: int, seq\_len: int, theta: float = 10000.0):
# 计算词向量元素两两分组之后,每组元素对应的旋转角度
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 生成 token 序列索引 t = [0, 1,..., seq\_len-1]
t = torch.arange(seq_len, device=freqs.device)
# freqs.shape = [seq\_len, dim // 2]
freqs = torch.outer(t, freqs).float()
# torch.polar 的文档
# https://pytorch.org/docs/stable/generated/torch.polar.html
# 计算结果是个复数向量
# 假设 freqs = [x, y]
# 则 freqs\_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply\_rotary\_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs\_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# xq.shape = [batch\_size, seq\_len, dim]
# xq\_.shape = [batch\_size, seq\_len, dim // 2, 2]
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)
# 转为复数域
xq_ = torch.view_as_complex(xq_)
xk_ = torch.view_as_complex(xk_)
# 应用旋转操作,然后将结果转回实数域
# xq\_out.shape = [batch\_size, seq\_len, dim]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
return xq_out.type_as(xq), xk_out.type_as(xk)
class Attention(nn.Module):
def \_\_init\_\_(self, args: ModelArgs):
super().__init__()
self.wq = Linear(...)
self.wk = Linear(...)
self.wv = Linear(...)
self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)
def forward(self, x: torch.Tensor):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(batch_size, seq_len, dim)
xk = xk.view(batch_size, seq_len, dim)
xv = xv.view(batch_size, seq_len, dim)
# attention 操作之前,应用旋转位置编码
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
# scores.shape = (batch\_size, seq\_len, seqlen)
scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
scores = F.softmax(scores.float(), dim=-1)
output = torch.matmul(scores, xv) # (batch\_size, seq\_len, dim)
# ......
可以看到 LLaMA 的官方实现代码和论文 [1]
中的描述是一致的。
参考资料
- [1] https://arxiv.org/pdf/2104.09864.pdf
- [2] https://en.wikipedia.org/wiki/Euler's\_formula
- [3] https://en.wikipedia.org/wiki/List\_of\_trigonometric\_identities
- [4] https://github.com/facebookresearch/llama/tree/main
- [5] https://zh.wikipedia.org/wiki/旋转矩阵 (x_m,m),f_k(x_n,n)>(x_m,m),f_k(x_n,n)>(x_m,m),f_k(x_n,n)>
[
SAM增强技术 | SAMAug提出Point Prompt增强,让SAM模型天天向上](https://mp.weixin.qq.com/s?__biz=MzU5OTA2Mjk5Mw==&mid=2247510793&idx=1&sn=c4778dbaa5b57c32cad3999007360ab2&chksm=feb84bb7c9cfc2a14ef2a244a8512140151b57b79fd939cb6f58473df9b0dba0b7465171a92b&scene=21#wechat_redirect)
[
Backbone创新 | 中科大联合百度提出全新Transformer Backbone](https://mp.weixin.qq.com/s?__biz=MzU5OTA2Mjk5Mw==&mid=2247510723&idx=1&sn=b8a0ee23c73497afe1ec36785266a5a2&chksm=feb84a7dc9cfc36b7afd487d9f86d5bdba9a598ddf5235d361f2905671d88c9900a3f5e858a8&scene=21#wechat_redirect)
[
自动驾驶感知多任务框架 | MultiTask V3、HybridNets和YOLOP谁更强呢?](https://mp.weixin.qq.com/s?__biz=MzU5OTA2Mjk5Mw==&mid=2247510671&idx=1&sn=e1f0dcea5a1fce1e8d5798ef7dc27404&chksm=feb84a31c9cfc3270a3b6a3c365e071b868b43a4f3fc59ebdc6a475c5361fc953b85e79e75ae&scene=21#wechat_redirect)
扫码加入👉「集智书童」交流群
(备注: 方向+学校/公司+昵称 )
想要了解更多:
前沿AI视觉感知全栈知识👉「分类、检测、分割、关键点、车道线检测、3D视觉(分割、检测)、多模态、目标跟踪、NerF」
行业技术方案 👉「AI安防、AI医疗、AI自动驾驶」
AI模型部署落地实战 👉「CUDA、TensorRT、NCNN、OpenVINO、MNN、ONNXRuntime以及地平线框架」
欢迎扫描上方二维码,加入「集智书童-知识星球 」,日常分享论文、学习笔记、问题解决方案、部署方案以及全栈式答疑,期待交流!
免责声明
凡本公众号注明“来源:XXX(非集智书童)”的作品,均转载自其它媒体,版权归原作者所有,如有侵权请联系我们删除,谢谢。
点击下方“阅读原文 ”,
了解更多AI学习路上的 「武功秘籍」