Transformer模型结构解析与Python代码实现

技术
  1. 前言

2017年,谷歌研究人员在《Attention Is All You Need》这篇论文中提出了Transformer模型,该模型最初是被用于机器翻译任务中。由于其良好的可并行性和强大的特征提取能力,Transformer模型在随后的几年中被用到自然语言处理、语音识别、计算机视觉等各个领域中,并表现出优异的性能。

本文基于论文的内容解读Transformer模型的各个组成部分,然后用Python实现一个完整的Transformer模型。

  1. Transformer模型结构解析

1.1 模型总体架构

Transformer的总体架构如下图所示,模型包含一个编码器和解码器(分别对应下图中的左侧和右侧部分),编码器和解码器都是由一系列堆叠的注意力结构和全连接层组成。

picture.image

编码器

编码器由个相同的层组成,每个层又包含两个子层:第一个子层为多头自注意力机制,第二个子层为一个简单的全连接前馈网络。这两个子层都采用了残差连接结构,后面接一个LayerNorm层,也就是说,每个子层的输出为。因为使用了残差连接结构,模型中所有子层,包括输入的Embedding层,它们的输出维度都等于512

解码器

解码器也是由个相同的层组成,除了使用了与编码器相同的子层外,解码器还在其中插入了第三个子层,这个子层对编码器的输出memory执行多头注意力机制。与编码器类似的,解码器的子层也采用残差连接结构,后面再接一个LayerNorm层。需要注意的是,解码器在多头自注意力子层中添加了一个掩码,这种机制可以确保对位置的预测只能依赖于小于位置的已知输出。

解码器的输出通过可学习的线性变换层和SoftMax函数转换为预测下一个Token的概率。

1.2 模型结构详解

1.2.1 注意力机制

注意力函数可以描述为将查询(query )和一组键(key)- 值(value)对映射到输出,其中querykey、和value都是向量。注意力函数的功能就是计算value的加权和,其中分配给每个value的权重由与querykey相关的特定函数计算得出。

缩放点积注意力

作者提出的注意力称为 「缩放点积注意力」 ,它的输入是维度为的querykey,以及维度为的value。对于输入的query,首先计算它与key的点积并除以缩放系数,然后用一个SoftMax函数来计算应用到value上的权重,这个权重再与value做点积运算得到最终结果。

picture.image

在实际应用中,会把一组query向量打包在一起组成矩阵Q,相应的keyvalue也分别打包为矩阵KV,然后同时用注意力函数进行计算:

为什么QK的点积结果要除以系数?因为作者发现如果的值比较大,那么QK点积的结果会产生很大的值,这样经过SoftMax函数后会产生非常小的梯度而不利于模型训练。为了消除这种影响,作者把点积结果除以一个系数,这也是为什么作者把这种注意力称为缩放注意力的原因。

多头注意力

把输入的querykeyvalue用不同的、可学习的线性映射操作分别映射h次,映射后的维度分别为、和,然后每个映射的版本再并行地进行注意力计算,产生维度的输出结果。把这h个输出的结果拼接到一起然后再做一次映射,使得最后输出结果的维度与原始输入相同。作者把这种多次映射再分别进行注意力计算的结构称为 「多头注意力」 ,它比只使用一个维度为的querykeyvalue来计算注意力的效果要好很多。

picture.image

与单头注意力结构相比,多头注意力使得模型具备关注来自不同表示子空间信息的能力,模型的学习能力更强大。多头注意力机制其实就是将输入序列进行多组自注意力处理的过程,可以用公式表示为:

picture.image

对于每个注意力头

,输入矩阵 QKV 分别通过参数可学习的矩阵

进行映射,然后计算注意力。所有注意力头的输出结果会被拼接到一起,再用一个矩阵

把结果映射回维度

。作者在实际应用中采用了

个头,对于每个头,

通俗地讲,多头注意力就是将一个维度为的输入张量平均拆分成份,每一份都单独进行自注意力计算,然后把这个自注意力的结果进行汇总,最后把汇总的结果映射回原来的维度。

注意力的使用细节

Transformer模型中,对多头注意力的使用方式有以下3种方式:

  • 在编码器中的注意力层中,querykeyvalue都来自同一个输入,这种注意力叫做自注意力。自注意力被用来获取同一序列不同位置的依赖关系。
  • 解码器的第一个多头注意力子层,querykey做点积的结果会再添加一个掩码,这个掩码的作用是防止解码器在对位置进行预测的时候提前看到了位置及以后的信息。
  • 解码器中间的那个子层,这个多头注意力的query来自解码器第一个多头自注意力的输出,但是它的keyvalue来自解码器的输出memory。这种querykeyvalue不同源的注意力叫做交叉注意力。解码器使用交叉注意力来处理输入序列和输出序列之间的依赖关系。

1.2.2 前馈网络

除了注意力子层之外,编码器和解码器中的每一层都包含一个全连接的前馈网络,该网络包括两个线性变换层,它们中间有一个ReLU激活函数:

线性变换层在不同位置的参数是相同的,但它们在层与层之间是使用不同的参数,另外一种实现方式是采用两个卷积核大小为1的卷积层。前馈网络的输入、输出维度均为,中间隐藏层的维度则为。

1.2.3 Embedding

与其他序列转导模型类似,Transformer使用可学习的Embedding将输入Token和输出Token转换为维度为的向量。此外,在Embedding层中,权重会乘以系数。

1.2.4 位置编码

由于Transformer模型中没有循环和卷积结构,为了使模型能利用序列的顺序信息,就必须为一个序列中的各个Token注入相对或绝对位置的信息。为此,作者为编码器和解码器的输入Embedding中各添加了一个位置编码信息,编码信息与输入Embedding的维度都是,以方便二者相加。位置编码可以选择用可学习的和固定值的,作者试验了两种类型的位置编码方式,发现两种方式产生的结果几乎一致。

Transformer模型中,作者使用了不同频率的正弦和余弦函数来生成位置编码:

picture.image

其中

是位置,

是维度。之所以选择这样的位置编码函数是因为它可以让模型轻松地学习相对位置:对于任意的偏移

都可以表示为

的线性变换 。此外,它允许模型在推理阶段处理比训练期间遇到过的最长序列长度更长的序列。

位置编码的函数还可以进一步化简,其中

picture.image

  1. 用Python实现Transformer模型

本文代码来源于该博客:https://www.datacamp.com/tutorial/building-a-transformer-with-py-torch,略有改动。

2.1 多头注意力

从前文对Transformer模型的结构解析可以知道,多头注意力和前馈网络是Transformer模型的基本组成单元,我们首先来实现多头注意力单元。

picture.image

多头注意力首先通过Linear层分别对QKV进行降维,然后进行缩放点积注意力计算。这样的操作会进行h次,这h次的结果汇总后再通过Linear层进行维度变换得到最终的结果。


            
class MultiHeadAttention(nn.Module):
            
    def __init__(self, d_model, num_heads):
            
        super(MultiHeadAttention, self).__init__()
            
        # Ensure that the model dimension (d_model) is divisible by the number of heads
            
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
            

            
        # Initialize dimensions
            
        self.d_model = d_model # Model's dimension
            
        self.num_heads = num_heads # Number of attention heads
            
        self.d_k = d_model // num_heads # Dimension of each head's key, query, and value
            

            
        # Linear layers for transforming inputs
            
        self.W_q = nn.Linear(d_model, d_model) # Query transformation
            
        self.W_k = nn.Linear(d_model, d_model) # Key transformation
            
        self.W_v = nn.Linear(d_model, d_model) # Value transformation
            
        self.W_o = nn.Linear(d_model, d_model) # Output transformation
            

            
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
            
        # Calculate attention scores
            
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
            

            
        # Apply mask if provided (useful for preventing attention to certain parts like padding)
            
        if mask is not None:
            
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
            

            
        # Softmax is applied to obtain attention probabilities
            
        attn_probs = torch.softmax(attn_scores, dim=-1)
            

            
        # Multiply by values to obtain the final output
            
        output = torch.matmul(attn_probs, V)
            
        return output
            

            
    def split_heads(self, x):
            
        # Reshape the input to have num_heads for multi-head attention
            
        batch_size, seq_length, d_model = x.size()
            
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
            

            
    def combine_heads(self, x):
            
        # Combine the multiple heads back to original shape
            
        batch_size, _, seq_length, d_k = x.size()
            
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
            

            
    def forward(self, Q, K, V, mask=None):
            
        # Apply linear transformations and split heads
            
        Q = self.split_heads(self.W_q(Q))
            
        K = self.split_heads(self.W_k(K))
            
        V = self.split_heads(self.W_v(V))
            

            
        # Perform scaled dot-product attention
            
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
            

            
        # Combine heads and apply output transformation
            
        output = self.W_o(self.combine_heads(attn_output))
            
        return output
        

上面的代码中,MultiHeadAttention类的几个主要函数说明如下:

  • scaled_dot_product_attention :实现缩放点积注意力,函数里的每一步实现应该都比较好懂。
  • split_heads :对维度为(batch_size, seq_length, d_model)的输入张量x进行拆分,返回的张量维度为(batch_size, num_heads, seq_length, d_k)。这个函数实际上只对x做维度变换,这样做的好处是方便并行地实现多头注意力计算。
  • 「combine_heads」 :对维度为(batch_size, num_heads, seq_length, d_k)的张量x进行维度变换,返回维度为(batch_size, seq_length, d_model)的张量。这个函数其实就是实现了对多个注意力头的输出结果进行Concat操作。

2.2 前馈网络

前馈网络比较简单,就是两个全连接层,中间有个激活函数:


            
class PositionWiseFeedForward(nn.Module):
            
    def __init__(self, d_model, d_ff):
            
        super(PositionWiseFeedForward, self).__init__()
            
        self.fc1 = nn.Linear(d_model, d_ff)
            
        self.fc2 = nn.Linear(d_ff, d_model)
            
        self.relu = nn.ReLU()
            

            
    def forward(self, x):
            
        return self.fc2(self.relu(self.fc1(x)))
        

2.3 位置编码

位置编码按照前面的公式实现即可,该类的两个初始化参数说明如下:

  • d_model :模型的输入维度
  • max_seq_length :预设的最大序列长度

            
class PositionalEncoding(nn.Module):
            
    def __init__(self, d_model, max_seq_length):
            
        super(PositionalEncoding, self).__init__()
            

            
        pe = torch.zeros(max_seq_length, d_model)
            
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
            
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
            

            
        pe[:, 0::2] = torch.sin(position * div_term)
            
        pe[:, 1::2] = torch.cos(position * div_term)
            

            
        self.register_buffer('pe', pe.unsqueeze(0))
            

            
    def forward(self, x):
            
        return x + self.pe[:, :x.size(1)]
        

由于位置编码数据是固定参数,在训练过程中不需要更新,所以调用register_buffer函数向模型注册一个永久性缓冲区。

2.4 Encoder层

一个Encoder层主要包含一个多头注意力模块和一个前馈网络模块,它们的输出都会接一个LayerNorm层,都采用残差连接结构。

picture.image


            
class EncoderLayer(nn.Module):
            
    def __init__(self, d_model, num_heads, d_ff, dropout):
            
        super(EncoderLayer, self).__init__()
            
        self.self_attn = MultiHeadAttention(d_model, num_heads)
            
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
            
        self.norm1 = nn.LayerNorm(d_model)
            
        self.norm2 = nn.LayerNorm(d_model)
            
        self.dropout = nn.Dropout(dropout)
            

            
    def forward(self, x, mask):
            
        attn_output = self.self_attn(x, x, x, mask)
            
        x = self.norm1(x + self.dropout(attn_output))
            
        ff_output = self.feed_forward(x)
            
        x = self.norm2(x + self.dropout(ff_output))
            
        return x
        

2.5 Decoder层

一个Decoder层主要包含一个多头自注意力模块、一个多头交叉注意力模块和一个前馈网络模块,它们的输出都会接一个LayerNorm层,都采用残差连接结构,其中交叉注意力模块的KV来自编码器的输出。

picture.image


            
class DecoderLayer(nn.Module):
            
    def __init__(self, d_model, num_heads, d_ff, dropout):
            
        super(DecoderLayer, self).__init__()
            
        self.self_attn = MultiHeadAttention(d_model, num_heads)
            
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
            
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
            
        self.norm1 = nn.LayerNorm(d_model)
            
        self.norm2 = nn.LayerNorm(d_model)
            
        self.norm3 = nn.LayerNorm(d_model)
            
        self.dropout = nn.Dropout(dropout)
            

            
    def forward(self, x, enc_output, src_mask, tgt_mask):
            
        attn_output = self.self_attn(x, x, x, tgt_mask)
            
        x = self.norm1(x + self.dropout(attn_output))
            
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
            
        x = self.norm2(x + self.dropout(attn_output))
            
        ff_output = self.feed_forward(x)
            
        x = self.norm3(x + self.dropout(ff_output))
            
        return x
        

2.6 完整的Transformer模型

一个完整的Transformer模型包含一个编码器和一个解码器,编码器和解码器分别包含NEncoder层和Decoder层。源序列和目标序列经过Embedding层映射到向量空间并添加位置编码信息,然后分别送入编码器和解码器进行处理。


            
class Transformer(nn.Module):
            
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
            
        super(Transformer, self).__init__()
            
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
            
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
            
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
            

            
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
            
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
            

            
        self.fc = nn.Linear(d_model, tgt_vocab_size)
            
        self.dropout = nn.Dropout(dropout)
            

            
    def generate_mask(self, src, tgt):
            
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
            
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
            
        seq_length = tgt.size(1)
            
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
            
        tgt_mask = tgt_mask & nopeak_mask
            
        return src_mask, tgt_mask
            

            
    def forward(self, src, tgt):
            
        src_mask, tgt_mask = self.generate_mask(src, tgt)
            
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)*math.sqrt(self.d_model)))
            
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)*math.sqrt(self.d_model)))
            

            
        enc_output = src_embedded
            
        for enc_layer in self.encoder_layers:
            
            enc_output = enc_layer(enc_output, src_mask)
            

            
        dec_output = tgt_embedded
            
        for dec_layer in self.decoder_layers:
            
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)
            

            
        output = self.fc(dec_output)
            
        return output
        

2.7 训练一个简单的Transformer模型

通过上面的代码,我们就可以构建一个完整的Transformer模型了。接下来可以试着训练一个简单的模型,源序列和目标序列都是一些随机整数,损失函数采用交叉熵损失。


            
src_vocab_size = 50
            
tgt_vocab_size = 100
            
d_model = 512
            
num_heads = 8
            
num_layers = 6
            
d_ff = 2048
            
max_seq_length = 10
            
dropout = 0.1
            
batch_size = 64
            

            
# Generate random sample data
            
src_data = torch.randint(1, src_vocab_size, (batch_size, max_seq_length)) 
            
tgt_data = torch.randint(1, tgt_vocab_size, (batch_size, max_seq_length))
            

            
transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
            

            
criterion = nn.CrossEntropyLoss(ignore_index=0)
            
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
            

            
transformer.train()
            

            
for epoch in range(10):
            
    optimizer.zero_grad()
            
    output = transformer(src_data, tgt_data)
            
    loss = criterion(output.contiguous().view(-1,tgt_vocab_size), tgt_data.contiguous().view(-1))
            
    loss.backward()
            
    optimizer.step()
            
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")
        

这是训练10Epoch的结果,可以看到,Loss是逐步下降的。


            
Epoch: 1, Loss: 4.66384220123291
            
Epoch: 2, Loss: 4.026397228240967
            
Epoch: 3, Loss: 3.4089579582214355
            
Epoch: 4, Loss: 2.810636281967163
            
Epoch: 5, Loss: 2.290071964263916
            
Epoch: 6, Loss: 1.8364582061767578
            
Epoch: 7, Loss: 1.4482641220092773
            
Epoch: 8, Loss: 1.1404035091400146
            
Epoch: 9, Loss: 0.8801152110099792
            
Epoch: 10, Loss: 0.6701229810714722
        
  1. 总结

本文主要从《Attention Is All You Need》这篇论文的内容来解读Transformer模型的结构,初学者看到可能还是不太能理解里面的细节。网上关于Transformer模型解读的资料非常多,本文参考资料里列举的几篇博客个人认为写得非常好,推荐大家都读一读。

本文代码来源于参考资料[5],每个子模块的代码可以对照该模块的图来进行理解,比较适合初学者入门。

  1. 参考资料

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
火山引擎HTTPDNS边缘云原生技术实践
《火山引擎HTTPDNS边缘云原生技术实践》 赵彦奇 | 火山引擎边缘云网络研发工程师
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论