- 前言
2017
年,谷歌研究人员在《Attention Is All You Need
》这篇论文中提出了Transformer
模型,该模型最初是被用于机器翻译任务中。由于其良好的可并行性和强大的特征提取能力,Transformer
模型在随后的几年中被用到自然语言处理、语音识别、计算机视觉等各个领域中,并表现出优异的性能。
本文基于论文的内容解读Transformer
模型的各个组成部分,然后用Python
实现一个完整的Transformer
模型。
- Transformer模型结构解析
1.1 模型总体架构
Transformer
的总体架构如下图所示,模型包含一个编码器和解码器(分别对应下图中的左侧和右侧部分),编码器和解码器都是由一系列堆叠的注意力结构和全连接层组成。
编码器
编码器由个相同的层组成,每个层又包含两个子层:第一个子层为多头自注意力机制,第二个子层为一个简单的全连接前馈网络。这两个子层都采用了残差连接结构,后面接一个LayerNorm
层,也就是说,每个子层的输出为。因为使用了残差连接结构,模型中所有子层,包括输入的Embedding
层,它们的输出维度都等于512
。
解码器
解码器也是由个相同的层组成,除了使用了与编码器相同的子层外,解码器还在其中插入了第三个子层,这个子层对编码器的输出memory
执行多头注意力机制。与编码器类似的,解码器的子层也采用残差连接结构,后面再接一个LayerNorm
层。需要注意的是,解码器在多头自注意力子层中添加了一个掩码,这种机制可以确保对位置的预测只能依赖于小于位置的已知输出。
解码器的输出通过可学习的线性变换层和SoftMax
函数转换为预测下一个Token
的概率。
1.2 模型结构详解
1.2.1 注意力机制
注意力函数可以描述为将查询(query
)和一组键(key
)- 值(value
)对映射到输出,其中query
、key
、和value
都是向量。注意力函数的功能就是计算value
的加权和,其中分配给每个value
的权重由与query
和key
相关的特定函数计算得出。
缩放点积注意力
作者提出的注意力称为 「缩放点积注意力」 ,它的输入是维度为的query
和key
,以及维度为的value
。对于输入的query
,首先计算它与key
的点积并除以缩放系数,然后用一个SoftMax
函数来计算应用到value
上的权重,这个权重再与value
做点积运算得到最终结果。
在实际应用中,会把一组query
向量打包在一起组成矩阵Q
,相应的key
和value
也分别打包为矩阵K
和V
,然后同时用注意力函数进行计算:
为什么Q
和K
的点积结果要除以系数?因为作者发现如果的值比较大,那么Q
和K
点积的结果会产生很大的值,这样经过SoftMax
函数后会产生非常小的梯度而不利于模型训练。为了消除这种影响,作者把点积结果除以一个系数,这也是为什么作者把这种注意力称为缩放注意力的原因。
多头注意力
把输入的query
、key
和value
用不同的、可学习的线性映射操作分别映射h
次,映射后的维度分别为、和,然后每个映射的版本再并行地进行注意力计算,产生维度的输出结果。把这h
个输出的结果拼接到一起然后再做一次映射,使得最后输出结果的维度与原始输入相同。作者把这种多次映射再分别进行注意力计算的结构称为 「多头注意力」 ,它比只使用一个维度为的query
、key
和value
来计算注意力的效果要好很多。
与单头注意力结构相比,多头注意力使得模型具备关注来自不同表示子空间信息的能力,模型的学习能力更强大。多头注意力机制其实就是将输入序列进行多组自注意力处理的过程,可以用公式表示为:
对于每个注意力头
,输入矩阵
Q
、
K
、
V
分别通过参数可学习的矩阵
、
、
进行映射,然后计算注意力。所有注意力头的输出结果会被拼接到一起,再用一个矩阵
把结果映射回维度
。作者在实际应用中采用了
个头,对于每个头,
。
通俗地讲,多头注意力就是将一个维度为的输入张量平均拆分成份,每一份都单独进行自注意力计算,然后把这个自注意力的结果进行汇总,最后把汇总的结果映射回原来的维度。
注意力的使用细节
在Transformer
模型中,对多头注意力的使用方式有以下3
种方式:
- 在编码器中的注意力层中,
query
、key
和value
都来自同一个输入,这种注意力叫做自注意力。自注意力被用来获取同一序列不同位置的依赖关系。 - 解码器的第一个多头注意力子层,
query
和key
做点积的结果会再添加一个掩码,这个掩码的作用是防止解码器在对位置进行预测的时候提前看到了位置及以后的信息。 - 解码器中间的那个子层,这个多头注意力的
query
来自解码器第一个多头自注意力的输出,但是它的key
和value
来自解码器的输出memory
。这种query
、key
和value
不同源的注意力叫做交叉注意力。解码器使用交叉注意力来处理输入序列和输出序列之间的依赖关系。
1.2.2 前馈网络
除了注意力子层之外,编码器和解码器中的每一层都包含一个全连接的前馈网络,该网络包括两个线性变换层,它们中间有一个ReLU
激活函数:
线性变换层在不同位置的参数是相同的,但它们在层与层之间是使用不同的参数,另外一种实现方式是采用两个卷积核大小为1
的卷积层。前馈网络的输入、输出维度均为,中间隐藏层的维度则为。
1.2.3 Embedding
与其他序列转导模型类似,Transformer
使用可学习的Embedding
将输入Token
和输出Token
转换为维度为的向量。此外,在Embedding
层中,权重会乘以系数。
1.2.4 位置编码
由于Transformer
模型中没有循环和卷积结构,为了使模型能利用序列的顺序信息,就必须为一个序列中的各个Token
注入相对或绝对位置的信息。为此,作者为编码器和解码器的输入Embedding
中各添加了一个位置编码信息,编码信息与输入Embedding
的维度都是,以方便二者相加。位置编码可以选择用可学习的和固定值的,作者试验了两种类型的位置编码方式,发现两种方式产生的结果几乎一致。
在Transformer
模型中,作者使用了不同频率的正弦和余弦函数来生成位置编码:
其中
是位置,
是维度。之所以选择这样的位置编码函数是因为它可以让模型轻松地学习相对位置:对于任意的偏移
,
都可以表示为
的线性变换 。此外,它允许模型在推理阶段处理比训练期间遇到过的最长序列长度更长的序列。
位置编码的函数还可以进一步化简,其中
- 用Python实现Transformer模型
本文代码来源于该博客:https://www.datacamp.com/tutorial/building-a-transformer-with-py-torch,略有改动。
2.1 多头注意力
从前文对Transformer
模型的结构解析可以知道,多头注意力和前馈网络是Transformer
模型的基本组成单元,我们首先来实现多头注意力单元。
多头注意力首先通过Linear
层分别对Q
,K
,V
进行降维,然后进行缩放点积注意力计算。这样的操作会进行h
次,这h
次的结果汇总后再通过Linear
层进行维度变换得到最终的结果。
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
# Ensure that the model dimension (d_model) is divisible by the number of heads
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# Initialize dimensions
self.d_model = d_model # Model's dimension
self.num_heads = num_heads # Number of attention heads
self.d_k = d_model // num_heads # Dimension of each head's key, query, and value
# Linear layers for transforming inputs
self.W_q = nn.Linear(d_model, d_model) # Query transformation
self.W_k = nn.Linear(d_model, d_model) # Key transformation
self.W_v = nn.Linear(d_model, d_model) # Value transformation
self.W_o = nn.Linear(d_model, d_model) # Output transformation
def scaled_dot_product_attention(self, Q, K, V, mask=None):
# Calculate attention scores
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# Apply mask if provided (useful for preventing attention to certain parts like padding)
if mask is not None:
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
# Softmax is applied to obtain attention probabilities
attn_probs = torch.softmax(attn_scores, dim=-1)
# Multiply by values to obtain the final output
output = torch.matmul(attn_probs, V)
return output
def split_heads(self, x):
# Reshape the input to have num_heads for multi-head attention
batch_size, seq_length, d_model = x.size()
return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
def combine_heads(self, x):
# Combine the multiple heads back to original shape
batch_size, _, seq_length, d_k = x.size()
return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
def forward(self, Q, K, V, mask=None):
# Apply linear transformations and split heads
Q = self.split_heads(self.W_q(Q))
K = self.split_heads(self.W_k(K))
V = self.split_heads(self.W_v(V))
# Perform scaled dot-product attention
attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
# Combine heads and apply output transformation
output = self.W_o(self.combine_heads(attn_output))
return output
上面的代码中,MultiHeadAttention
类的几个主要函数说明如下:
- 「
scaled_dot_product_attention
」 :实现缩放点积注意力,函数里的每一步实现应该都比较好懂。 - 「
split_heads
」 :对维度为(batch_size, seq_length, d_model
)的输入张量x
进行拆分,返回的张量维度为(batch_size, num_heads, seq_length, d_k
)。这个函数实际上只对x
做维度变换,这样做的好处是方便并行地实现多头注意力计算。 - 「combine_heads」 :对维度为(
batch_size, num_heads, seq_length, d_k
)的张量x
进行维度变换,返回维度为(batch_size, seq_length, d_model
)的张量。这个函数其实就是实现了对多个注意力头的输出结果进行Concat
操作。
2.2 前馈网络
前馈网络比较简单,就是两个全连接层,中间有个激活函数:
class PositionWiseFeedForward(nn.Module):
def __init__(self, d_model, d_ff):
super(PositionWiseFeedForward, self).__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.relu = nn.ReLU()
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
2.3 位置编码
位置编码按照前面的公式实现即可,该类的两个初始化参数说明如下:
- 「
d_model
」 :模型的输入维度 - 「
max_seq_length
」 :预设的最大序列长度
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_length):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_seq_length, d_model)
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
return x + self.pe[:, :x.size(1)]
由于位置编码数据是固定参数,在训练过程中不需要更新,所以调用register_buffer
函数向模型注册一个永久性缓冲区。
2.4 Encoder层
一个Encoder
层主要包含一个多头注意力模块和一个前馈网络模块,它们的输出都会接一个LayerNorm
层,都采用残差连接结构。
class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout):
super(EncoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
attn_output = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
2.5 Decoder层
一个Decoder
层主要包含一个多头自注意力模块、一个多头交叉注意力模块和一个前馈网络模块,它们的输出都会接一个LayerNorm
层,都采用残差连接结构,其中交叉注意力模块的K
和V
来自编码器的输出。
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_output, src_mask, tgt_mask):
attn_output = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn_output))
attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
x = self.norm2(x + self.dropout(attn_output))
ff_output = self.feed_forward(x)
x = self.norm3(x + self.dropout(ff_output))
return x
2.6 完整的Transformer模型
一个完整的Transformer
模型包含一个编码器和一个解码器,编码器和解码器分别包含N
个Encoder
层和Decoder
层。源序列和目标序列经过Embedding
层映射到向量空间并添加位置编码信息,然后分别送入编码器和解码器进行处理。
class Transformer(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
super(Transformer, self).__init__()
self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
self.fc = nn.Linear(d_model, tgt_vocab_size)
self.dropout = nn.Dropout(dropout)
def generate_mask(self, src, tgt):
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
seq_length = tgt.size(1)
nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
tgt_mask = tgt_mask & nopeak_mask
return src_mask, tgt_mask
def forward(self, src, tgt):
src_mask, tgt_mask = self.generate_mask(src, tgt)
src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)*math.sqrt(self.d_model)))
tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)*math.sqrt(self.d_model)))
enc_output = src_embedded
for enc_layer in self.encoder_layers:
enc_output = enc_layer(enc_output, src_mask)
dec_output = tgt_embedded
for dec_layer in self.decoder_layers:
dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)
output = self.fc(dec_output)
return output
2.7 训练一个简单的Transformer模型
通过上面的代码,我们就可以构建一个完整的Transformer
模型了。接下来可以试着训练一个简单的模型,源序列和目标序列都是一些随机整数,损失函数采用交叉熵损失。
src_vocab_size = 50
tgt_vocab_size = 100
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 10
dropout = 0.1
batch_size = 64
# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (batch_size, max_seq_length))
tgt_data = torch.randint(1, tgt_vocab_size, (batch_size, max_seq_length))
transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
transformer.train()
for epoch in range(10):
optimizer.zero_grad()
output = transformer(src_data, tgt_data)
loss = criterion(output.contiguous().view(-1,tgt_vocab_size), tgt_data.contiguous().view(-1))
loss.backward()
optimizer.step()
print(f"Epoch: {epoch+1}, Loss: {loss.item()}")
这是训练10
个Epoch
的结果,可以看到,Loss
是逐步下降的。
Epoch: 1, Loss: 4.66384220123291
Epoch: 2, Loss: 4.026397228240967
Epoch: 3, Loss: 3.4089579582214355
Epoch: 4, Loss: 2.810636281967163
Epoch: 5, Loss: 2.290071964263916
Epoch: 6, Loss: 1.8364582061767578
Epoch: 7, Loss: 1.4482641220092773
Epoch: 8, Loss: 1.1404035091400146
Epoch: 9, Loss: 0.8801152110099792
Epoch: 10, Loss: 0.6701229810714722
- 总结
本文主要从《Attention Is All You Need
》这篇论文的内容来解读Transformer
模型的结构,初学者看到可能还是不太能理解里面的细节。网上关于Transformer
模型解读的资料非常多,本文参考资料里列举的几篇博客个人认为写得非常好,推荐大家都读一读。
本文代码来源于参考资料[5
],每个子模块的代码可以对照该模块的图来进行理解,比较适合初学者入门。
- 参考资料