用于大规模图像识别的Transformers。
““ATTENTION IS ALL YOU NEED””论文中介绍的 transformer 架构在NLP领域产生了巨大的影响。但是,它在计算机视觉领域的应用还很有限。2021年,谷歌的一个研究团队推出了一篇论文“AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE (2021)”,该论文将Transformer编码器架构应用于图像识别(分类)任务。
Idea of the paper
本文的思想是使用 Transformer encoder 架构创建一个 Vision Transformer,以最少的修改,并将其应用于图像分类任务。
当 Vision Transformers(ViT) 在足够大的数据量(>100M)上进行训练时,其计算资源比最先进的CNN (ResNet)少得多(少四倍),并转移到多个中型或小型图像识别基准时,它获得了出色的结果。本文的最后几节将进一步详细讨论这些结果。
Goal
Image Classification
图像分类处理的是为输入图像分配一个类标签。例如,如下图所示,我们预测输入图像的类为Dog,因为它在应用softmax后具有最高的置信度得分。
在这里插入图片描述
The Vision Transformer
下图显示了 Vi(vision) T(ransformer) 体系结构。
Vision Transformer Architecture Overview
为了更好地理解体系结构,让我们将其分为3个组件。
- Embedding
- Transformer Encoder
- MLP Head
Step 1: Embedding
在这一步中,我们将输入图像划分为[P, P]维的固定大小的小块,并通过拼接channels(如果存在)将它们线性平坦化 。例如,将大小为[P, P, C]的patch转换为[P*P*C, 1]。将这个线性平坦的patch进一步通过一个具有线性激活函数的 Feed-Forward,得到一个维数为[D, 1]的线性patch投影。D是在整个 transformer 中使用的称为嵌入维数的超参数。
通过保持 stride 等于 patch size,可以使用卷积层对图像进行 patched。这将把输入图像转换成所需大小的patches,然后将其 flattened 并传递到下一层 。
Embedding Step
出于分类的目的,我们从原始BERT论文中获得灵感,将一个可学习的 class 嵌入与其他 patches 投影连接起来,这些patches投影在输出处的state作为class information。这个额外的class token被添加到负责 aggregating global image information 和 final classification 的image tokens集合中 。当它通过并学习attention层时,它能够学习这种全局聚合。我们还在线性patch中加入一维位置嵌入,以在输入patch中建立一定的顺序。
Why is positional encoding necessary?
Transformers不能记住输入的顺序。如果对图像patches进行重新排序,则失去了原始图像的意义。因此,我们在线性嵌入的图像patches中添加位置嵌入来跟踪序列。为了更好地理解嵌入步骤,让我们看看维度。
假设,我们有一个大小为224x224x1的输入图像,我们将其分成大小为16x16的固定大小的小块。设 patch大小为P,图像通道为 C,得到的patch总数N为196个。
将所有的patch线性flattening后得到维数为[N, P²C]的向量X。我们通过一个 Dense Layer 将其转换为D维向量,称为嵌入E [N, D]。然后我们添加一个可学习的class嵌入[1,D]来将E向量转换为维度[N+1, D]。最后一步是添加位置编码以获得最终向量Z。class嵌入和位置嵌入都是随机初始化的向量,在网络训练期间学习。
一旦我们有了向量
Z,我们把它传递给一个transformer encoder层。
Step 2: Transformer Encoder
Transformer Encoder架构类似于““ATTENTION IS ALL YOU NEED””论文中提到的架构。它由多个相同的块堆栈组成。每个块都有一个多头注意层,然后是前馈层。两个子层周围都有一个残差连接,然后是层归一化。模型中的所有子层和嵌入层产生嵌入维D的输出。上一步的Z向量通过 Transformer Encoder架构得到上下文向量C。
Transformer Encoder 架构由多个编码器块组成,其中每个块都有一个多头注意单元和一个前馈网络。每一层之后还有一个规范化层。
假设我们已经了解了前馈层的机制,让我们看看多头注意。
Multi-Head Attention:
Multi-Head Attention unit 的主要组成部分是 Scaled Dot-Product Attention。首先,输入向量Z被复制3次,并乘以权重Wq、Wk和Wv,分别得到 Queries、Keys和Values。然后将Queries乘以Keys,并将结果除以维度D的平方根,以避免梯度消失问题 。这个矩阵经过一个Softmax层,并乘以Values,得到最终的输出,称为Head H。
如上所述的 Scaled Dot-Product Attention 被应用
h 次(h=8),得到h个注意头。这些注意力头被拼接起来并通过一个dense层来得到嵌入维数D的最终向量。
Transformer Encoder Block
回到我们的 transformer encoder 架构,Z向量通过多个 Encoder 块来给我们最终的上下文向量C。
这个MultiHead self-attention 可以在Pytorch中实现如下。
class MultiHeadSelfAttention(nn.Module):
def \_\_init\_\_(self, hidden\_dim, num\_heads):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.q_weights = [nn.Linear(hidden_dim, hidden_dim) for _ in range(self.num_heads)]
self.k_weights = [nn.Linear(hidden_dim, hidden_dim) for _ in range(self.num_heads)]
self.v_weights = [nn.Linear(hidden_dim, hidden_dim) for _ in range(self.num_heads)]
self.softmax = nn.Softmax(dim=-1)
self.linear = nn.Linear(num_heads*hidden_dim, hidden_dim)
def forward(self, X):
#B, N, D = X.shape
result = []
for x in X:
x_result = [] # H, N, D
for head in range(self.num_heads):
q = self.q_weights[head](x)
k = self.k_weights[head](x)
v = self.v_weights[head](x)
h = self.softmax(q @ k.T / self.hidden_dim**2) @ v # N, D
x_result.append(h)
result.append(torch.hstack(x_result)) # B, H, N, D
H = torch.cat([torch.unsqueeze(r, dim=0) for r in result])
out = self.linear(H)
return out # N, D
Step 3: MLP head
一旦我们有了上下文向量C,我们只对上下文令牌C0感兴趣,以便进行分类。这个上下文令牌C0通过一个MLP头传递给我们最终的概率向量,以帮助预测class。MLP头部在预训练阶段由一个隐藏层和tanh作为非线性实现,在微调阶段由一个线性层实现。
最终的体系结构如上图所示。线性图像patches由一个
[CLS]标记附加,并通过一个Dense层来获得最终的编码向量Z以及位置嵌入。然后将其通过Transformer Encoder体系结构传递以获得上下文向量C。将上下文令牌c0的值通过MLP头传递以获得最终预测。
以下是Vision Transformer的PyTorch实现以供参考。它使用如上所述的MultiHeadSelfAttention类。
lass VisionTransformer(nn.Module):
def \_\_init\_\_(self, img\_shape, patch\_size, hidden\_dim, num\_heads, out\_dim, num\_encoder\_blocks=6):
super().__init__()
self.img_shape = img_shape
self.patch_size = img_shape[0]*patch_size[0]*patch_size[1]
self.num_patches = int(img_shape[0]*img_shape[1]/patch_size[0]) ** 2
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.out_dim = out_dim
self.num_encoder_blocks = num_encoder_blocks
# Linear patching
self.linear_patching = nn.Linear(self.patch_size, self.hidden_dim)
# CLS embedding
self.cls_embedding = nn.Parameter(torch.rand(1, self.hidden_dim))
# Positional embedding
self.pos_embedding = nn.Parameter(torch.rand(1+self.num_patches, self.hidden_dim))
# Transformer
self.transformer_1 = nn.Sequential(
nn.LayerNorm((1+self.num_patches, self.hidden_dim)),
MultiHeadSelfAttention(self.hidden_dim, self.num_heads)
)
self.transformer_2 = nn.Sequential(
nn.LayerNorm((1+self.num_patches, self.hidden_dim)),
nn.Linear(self.hidden_dim, self.hidden_dim),
)
# MLP head
self.mlp_head = nn.Sequential(
nn.Linear(self.hidden_dim, self.out_dim),
nn.Tanh(),
)
def forward(self, X):
N, C, H, W = X.shape
patches = X.reshape(N, self.num_patches, self.patch_size)
E = self.linear_patching(patches)
cls_embedding = nn.Parameter(self.cls_embedding.repeat(N, 1, 1))
E = torch.cat([cls_embedding, E], dim=1)
Epos = nn.Parameter(self.pos_embedding.repeat(N, 1, 1))
Z = E + Epos
for _ in range(self.num_encoder_blocks):
res1 = self.transformer_1(Z)
Z = self.transformer_2(res1 + Z)
C = self.mlp_head(Z[:, 0])
return C
为了验证我们的代码,我们在CPU机器上使用MNIST数据集(60k)训练了我们的模型10次,得到了90.85%的准确率。对于更大的数据集,您可能需要GPU机器进行训练。
Test set: Average loss: 0.9649, Accuracy: 9198/10000 (90.85%)
Training:
本文中提到的训练过程分为预训练和微调步骤,如下图所示。例如,ViT H/16模型首先在JFT 300M数据集上进行训练,然后在ImageNet或CIFAR数据集上进行微调。
在这里插入图片描述
Pre-training:
该网络首先通过随机初始化权重在一个大型数据集上进行预训练。本文使用了3个预训练数据集。所有3个模型都使用Adam优化器进行预训练,批大小为4096,weight decay 为0.1。
在这里插入图片描述
使用的主要架构如下所示。
Source: https://arxiv.org/pdf/2010.11929v2.pdf
Fine-tuning:
一旦模型在大型数据集上进行了预训练,我们现在使用SGD在较小的数据集上微调ViT模型,ViT-L/16 和 ViT-H/14 模型的批大小分别为512和518。
Experiments
Scaling Up
为了了解预训练数据集的大小对模型性能的影响,作者在大型数据集上训练Vision transformer,并将结果与在相同数据集上训练的BiT进行比较。
Source: https://arxiv.org/pdf/2010.11929v2.pdf
当在ImageNet (1M图像)上训练时,ViT的表现明显不如CNN (BiT)。然而,在ImageNet-21k (14M张图像)上的性能是相当的,在JFT (300M张图像)上,ViT优于BiT。
Comparison to the state of the art:
我们首先将我们最大的型号ViT-H/14 和 ViT-L/16 与最先进的 CNNs —— 1. BiT (Big Transfer) 2. Noisy student 进行比较。
正如我们所看到的,对于ImageNet,之前最先进的 Noisy Student 模型的准确率为88.4%。ViT-H/14 优于该基准,而在JFT和ImageNet21k上训练的 ViT-L/16 的性能略低于ImageNet数据集的基准模型。
但对于其他基准数据集,如CIFAR、oxford和VTAB, ViT-L/16 和 ViT-H/14 模型比最先进的ResNet模型表现更好,同时花费更少的资源进行训练。
我们可以看到,ViT模型所需的TPUv3核心天数大大少于BiT和Noisy Student模型,这证明ViT模型不仅性能良好,而且训练速度比现有的最先进模型快得多。
ViT还能很好地处理各种任务,例如VTAB-1k套件(19个任务,每个任务有1000个数据点)。
Left: Image classification tasks. Right: Average performance across 19 tasks in the VTAB classification suite. Source: https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html
如所见,在这些流行的基准测试中,Vision Transformer的表现也与最先进的cnn相当或优于cnn。
But, why does Vision Transformer perform better?
为了理解这一点,让我们来看看Vision Transformer中的attention maps。
Vision Transformers 中的 Multi-Head Attention 有助于它只注意图像的相关部分。如果我们取所有注意力头输出的平均值,我们可以看到这种机制在起作用。该模型获得了语义上相关的图像区域进行分类。
在这里插入图片描述
Transformers are yet not mainstream
通过这种简单但可扩展的ViT策略,当与大型数据集的预训练相结合时,它在许多图像分类数据集上达到或超过了目前的水平,同时预训练相对便宜,本文为分析ViT在其他计算机视觉任务(如检测和分割)上的性能设定了未来的范围。
