ViT: Vision Transformer

用于大规模图像识别的Transformers。

““ATTENTION IS ALL YOU NEED””论文中介绍的 transformer 架构在NLP领域产生了巨大的影响。但是,它在计算机视觉领域的应用还很有限。2021年,谷歌的一个研究团队推出了一篇论文“AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE (2021)”,该论文将Transformer编码器架构应用于图像识别(分类)任务。

Idea of the paper

本文的思想是使用 Transformer encoder 架构创建一个 Vision Transformer,以最少的修改,并将其应用于图像分类任务。

当 Vision Transformers(ViT) 在足够大的数据量(>100M)上进行训练时,其计算资源比最先进的CNN (ResNet)少得多(少四倍),并转移到多个中型或小型图像识别基准时,它获得了出色的结果。本文的最后几节将进一步详细讨论这些结果。

Goal

Image Classification

图像分类处理的是为输入图像分配一个类标签。例如,如下图所示,我们预测输入图像的类为Dog,因为它在应用softmax后具有最高的置信度得分。

picture.image

在这里插入图片描述

The Vision Transformer

下图显示了 Vi(vision) T(ransformer) 体系结构。

picture.image

Vision Transformer Architecture Overview

为了更好地理解体系结构,让我们将其分为3个组件。

  1. Embedding
  2. Transformer Encoder
  3. MLP Head

Step 1: Embedding

在这一步中,我们将输入图像划分为[P, P]维的固定大小的小块,并通过拼接channels(如果存在)将它们线性平坦化 。例如,将大小为[P, P, C]的patch转换为[P*P*C, 1]。将这个线性平坦的patch进一步通过一个具有线性激活函数的 Feed-Forward,得到一个维数为[D, 1]的线性patch投影。D是在整个 transformer 中使用的称为嵌入维数的超参数。

通过保持 stride 等于 patch size,可以使用卷积层对图像进行 patched。这将把输入图像转换成所需大小的patches,然后将其 flattened 并传递到下一层

picture.image

Embedding Step

出于分类的目的,我们从原始BERT论文中获得灵感,将一个可学习的 class 嵌入与其他 patches 投影连接起来,这些patches投影在输出处的state作为class information。这个额外的class token被添加到负责 aggregating global image information 和 final classification 的image tokens集合中 。当它通过并学习attention层时,它能够学习这种全局聚合。我们还在线性patch中加入一维位置嵌入,以在输入patch中建立一定的顺序。

Why is positional encoding necessary?

Transformers不能记住输入的顺序。如果对图像patches进行重新排序,则失去了原始图像的意义。因此,我们在线性嵌入的图像patches中添加位置嵌入来跟踪序列。picture.image为了更好地理解嵌入步骤,让我们看看维度。

假设,我们有一个大小为224x224x1的输入图像,我们将其分成大小为16x16的固定大小的小块。设 patch大小为P,图像通道为 C,得到的patch总数N为196个。

将所有的patch线性flattening后得到维数为[N, P²C]的向量X。我们通过一个 Dense Layer 将其转换为D维向量,称为嵌入E [N, D]。然后我们添加一个可学习的class嵌入[1,D]来将E向量转换为维度[N+1, D]。最后一步是添加位置编码以获得最终向量Z。class嵌入和位置嵌入都是随机初始化的向量,在网络训练期间学习。

picture.image一旦我们有了向量Z,我们把它传递给一个transformer encoder层。

Step 2: Transformer Encoder

Transformer Encoder架构类似于““ATTENTION IS ALL YOU NEED””论文中提到的架构。它由多个相同的块堆栈组成。每个块都有一个多头注意层,然后是前馈层。两个子层周围都有一个残差连接,然后是层归一化。模型中的所有子层和嵌入层产生嵌入维D的输出。上一步的Z向量通过 Transformer Encoder架构得到上下文向量C

Transformer Encoder 架构由多个编码器块组成,其中每个块都有一个多头注意单元和一个前馈网络。每一层之后还有一个规范化层。

picture.image假设我们已经了解了前馈层的机制,让我们看看多头注意。

Multi-Head Attention:

picture.image

Multi-Head Attention unit 的主要组成部分是 Scaled Dot-Product Attention。首先,输入向量Z被复制3次,并乘以权重WqWkWv,分别得到 Queries、Keys和Values。然后将Queries乘以Keys,并将结果除以维度D的平方根,以避免梯度消失问题 。这个矩阵经过一个Softmax层,并乘以Values,得到最终的输出,称为Head H

picture.image如上所述的 Scaled Dot-Product Attention 被应用 h 次(h=8),得到h个注意头。这些注意力头被拼接起来并通过一个dense层来得到嵌入维数D的最终向量。

picture.image

Transformer Encoder Block

回到我们的 transformer encoder 架构,Z向量通过多个 Encoder 块来给我们最终的上下文向量C

这个MultiHead self-attention 可以在Pytorch中实现如下。

  
class MultiHeadSelfAttention(nn.Module):  
    def \_\_init\_\_(self, hidden\_dim, num\_heads):  
        super().__init__()  
          
        self.hidden_dim = hidden_dim  
        self.num_heads = num_heads  
          
        self.q_weights = [nn.Linear(hidden_dim, hidden_dim) for _ in range(self.num_heads)]  
        self.k_weights = [nn.Linear(hidden_dim, hidden_dim) for _ in range(self.num_heads)]  
        self.v_weights = [nn.Linear(hidden_dim, hidden_dim) for _ in range(self.num_heads)]  
        self.softmax = nn.Softmax(dim=-1)  
        self.linear = nn.Linear(num_heads*hidden_dim, hidden_dim)  
          
    def forward(self, X):  
        #B, N, D = X.shape   
        result = []  
        for x in X:  
            x_result = [] # H, N, D  
            for head in range(self.num_heads):  
                q = self.q_weights[head](x)  
                k = self.k_weights[head](x)  
                v = self.v_weights[head](x)  
                h = self.softmax(q @ k.T / self.hidden_dim**2) @ v # N, D  
                x_result.append(h)  
            result.append(torch.hstack(x_result)) # B, H, N, D  
        H = torch.cat([torch.unsqueeze(r, dim=0) for r in result])   
        out = self.linear(H)  
        return out # N, D  

Step 3: MLP head

一旦我们有了上下文向量C,我们只对上下文令牌C0感兴趣,以便进行分类。这个上下文令牌C0通过一个MLP头传递给我们最终的概率向量,以帮助预测class。MLP头部在预训练阶段由一个隐藏层和tanh作为非线性实现,在微调阶段由一个线性层实现。

picture.image最终的体系结构如上图所示。线性图像patches由一个[CLS]标记附加,并通过一个Dense层来获得最终的编码向量Z以及位置嵌入。然后将其通过Transformer Encoder体系结构传递以获得上下文向量C。将上下文令牌c0的值通过MLP头传递以获得最终预测。

以下是Vision Transformer的PyTorch实现以供参考。它使用如上所述的MultiHeadSelfAttention类。

  
lass VisionTransformer(nn.Module):  
    def \_\_init\_\_(self, img\_shape, patch\_size, hidden\_dim, num\_heads, out\_dim, num\_encoder\_blocks=6):  
        super().__init__()  
          
        self.img_shape = img_shape  
        self.patch_size = img_shape[0]*patch_size[0]*patch_size[1]  
        self.num_patches = int(img_shape[0]*img_shape[1]/patch_size[0]) ** 2  
        self.hidden_dim = hidden_dim  
        self.num_heads = num_heads  
        self.out_dim = out_dim  
        self.num_encoder_blocks = num_encoder_blocks  
          
        # Linear patching  
        self.linear_patching = nn.Linear(self.patch_size, self.hidden_dim)  
          
        # CLS embedding  
        self.cls_embedding = nn.Parameter(torch.rand(1, self.hidden_dim))  
          
        # Positional embedding  
        self.pos_embedding = nn.Parameter(torch.rand(1+self.num_patches, self.hidden_dim))  
          
        # Transformer  
        self.transformer_1 = nn.Sequential(  
                                nn.LayerNorm((1+self.num_patches, self.hidden_dim)),  
                                MultiHeadSelfAttention(self.hidden_dim, self.num_heads)  
                            )  
        self.transformer_2 = nn.Sequential(  
                                nn.LayerNorm((1+self.num_patches, self.hidden_dim)),  
                                nn.Linear(self.hidden_dim, self.hidden_dim),  
                            )  
          
        # MLP head  
        self.mlp_head = nn.Sequential(  
                            nn.Linear(self.hidden_dim, self.out_dim),  
                            nn.Tanh(),  
                        )  
      
    def forward(self, X):  
        N, C, H, W = X.shape  
        patches = X.reshape(N, self.num_patches, self.patch_size)  
        E = self.linear_patching(patches)  
        cls_embedding = nn.Parameter(self.cls_embedding.repeat(N, 1, 1))  
        E = torch.cat([cls_embedding, E], dim=1)  
        Epos = nn.Parameter(self.pos_embedding.repeat(N, 1, 1))  
        Z = E + Epos  
        for _ in range(self.num_encoder_blocks):  
            res1 = self.transformer_1(Z)  
            Z = self.transformer_2(res1 + Z)  
        C = self.mlp_head(Z[:, 0])  
        return C  

为了验证我们的代码,我们在CPU机器上使用MNIST数据集(60k)训练了我们的模型10次,得到了90.85%的准确率。对于更大的数据集,您可能需要GPU机器进行训练。

  
Test set: Average loss: 0.9649, Accuracy: 9198/10000 (90.85%)  

Training:

本文中提到的训练过程分为预训练和微调步骤,如下图所示。例如,ViT H/16模型首先在JFT 300M数据集上进行训练,然后在ImageNet或CIFAR数据集上进行微调。

picture.image

在这里插入图片描述

Pre-training:

该网络首先通过随机初始化权重在一个大型数据集上进行预训练。本文使用了3个预训练数据集。所有3个模型都使用Adam优化器进行预训练,批大小为4096,weight decay 为0.1。

picture.image

在这里插入图片描述

使用的主要架构如下所示。

picture.image

Source: https://arxiv.org/pdf/2010.11929v2.pdf

Fine-tuning:

一旦模型在大型数据集上进行了预训练,我们现在使用SGD在较小的数据集上微调ViT模型,ViT-L/16 和 ViT-H/14 模型的批大小分别为512和518。

Experiments

Scaling Up

为了了解预训练数据集的大小对模型性能的影响,作者在大型数据集上训练Vision transformer,并将结果与在相同数据集上训练的BiT进行比较。

picture.image

Source: https://arxiv.org/pdf/2010.11929v2.pdf

当在ImageNet (1M图像)上训练时,ViT的表现明显不如CNN (BiT)。然而,在ImageNet-21k (14M张图像)上的性能是相当的,在JFT (300M张图像)上,ViT优于BiT。

Comparison to the state of the art:

我们首先将我们最大的型号ViT-H/14 和 ViT-L/16 与最先进的 CNNs —— 1. BiT (Big Transfer) 2. Noisy student 进行比较。

picture.image正如我们所看到的,对于ImageNet,之前最先进的 Noisy Student 模型的准确率为88.4%。ViT-H/14 优于该基准,而在JFT和ImageNet21k上训练的 ViT-L/16 的性能略低于ImageNet数据集的基准模型。

但对于其他基准数据集,如CIFAR、oxford和VTAB, ViT-L/16 和 ViT-H/14 模型比最先进的ResNet模型表现更好,同时花费更少的资源进行训练。

我们可以看到,ViT模型所需的TPUv3核心天数大大少于BiT和Noisy Student模型,这证明ViT模型不仅性能良好,而且训练速度比现有的最先进模型快得多。

ViT还能很好地处理各种任务,例如VTAB-1k套件(19个任务,每个任务有1000个数据点)。

picture.image

Left: Image classification tasks. Right: Average performance across 19 tasks in the VTAB classification suite. Source: https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html

如所见,在这些流行的基准测试中,Vision Transformer的表现也与最先进的cnn相当或优于cnn。

But, why does Vision Transformer perform better?

为了理解这一点,让我们来看看Vision Transformer中的attention maps。

Vision Transformers 中的 Multi-Head Attention 有助于它只注意图像的相关部分。如果我们取所有注意力头输出的平均值,我们可以看到这种机制在起作用。该模型获得了语义上相关的图像区域进行分类。

picture.image

在这里插入图片描述

Transformers are yet not mainstream

通过这种简单但可扩展的ViT策略,当与大型数据集的预训练相结合时,它在许多图像分类数据集上达到或超过了目前的水平,同时预训练相对便宜,本文为分析ViT在其他计算机视觉任务(如检测和分割)上的性能设定了未来的范围。

参考文献

0
0
0
0
评论
未登录
暂无评论