用Grad-CAM可视化ViT输出,揭示视觉Transformer工作原理

容器MySQL微服务

picture.image

本文入选【技术写作训练营】优秀结营作品,作者:王悦天

Vision Transformer (ViT) 作为现在 CV 中的主流 backbone,它可以在图像分类任务上达到与卷积神经网络(CNN)相媲美甚至超越的性能。ViT 的核心思想是将输入图像划分为多个小块,然后将每个小块作为一个 token 输入到 Transformer 的编码器中,最终得到一个全局的类别 token 作为分类结果。

ViT 的优势在于它可以更好地捕捉图像中的长距离依赖关系,而不需要使用复杂的卷积操作。然而,这也带来了一个挑战,那就是如何解释 ViT 的决策过程,以及它是如何关注图像中的不同区域的。想要弄清楚这个问题,我们可以使用一种叫做 Grad-CAM 的技术,它可以根据 ViT 的输出和梯度,生成一张热力图,显示 ViT 在做出分类时最关注的图像区域。

原理

Grad-CAM 对 ViT 的输出进行可视化的原理是利用 ViT 的最后一个注意力块的输出和梯度,计算出每个 token 对分类结果的贡献度,然后将这些贡献度映射回原始图像的空间位置,形成一张热力图。具体来说,Grad-CAM+ViT 的步骤如下:

  1. 给定一个输入图像和一个目标类别,将图像划分为 14x14 个小块,并将每个小块转换为一个 768 维的向量。在这些向量之前,还要加上一个特殊的类别 token ,用于表示全局的分类信息。这样就得到了一个 197x768 的矩阵,作为 ViT 的输入。

  2. 将 ViT 的输入通过 Transformer 的编码器,得到一个 197x768 的输出矩阵。其中第一个向量就是类别 token ,它包含了 ViT 对整个图像的理解。我们将这个向量通过一个线性层和一个 softmax 层,得到最终的分类概率。

  3. 计算类别 token 对目标类别的梯度,即 ,其中 是目标类别的概率, 是 ViT 的输出矩阵。这个梯度表示了每个 token 对分类结果的重要性。

  4. 对每个 token 的梯度求平均值,得到一个 197 维的向量 ,其中 , 是梯度的维度,即 768 。这个向量 可以看作是每个 token 的权重。

  5. 将 ViT 的输出矩阵和权重向量相乘,得到一个 197 维的向量 ,其中 。这个向量 可以看作是每个 token 对分类结果的贡献度。

  6. 将贡献度向量 除去第一个元素(类别 token ),并重塑为一个 14x14 的矩阵 ,其中 。这个矩阵 可以看作是每个小块对分类结果的贡献度。

  7. 将贡献度矩阵 进行归一化和上采样,得到一个与原始图像大小相同的矩阵 ,其中 。这个矩阵 就是我们要求的热力图,它显示了 ViT 在做出分类时最关注的图像区域。

  8. 将热力图 和原始图像进行叠加,得到一张可视化的图像,可以直观地看到 ViT 的注意力分布。

使用代码

首先,import 进来 pytorch_grad_cam 工具和一些必要的包,再 load 进来我们要分析的 ViT 模型,这里使用 DeiT_Tiny 作为示例:


              
import cv2
              
import numpy as np
              
import torch
              

              
from pytorch_grad_cam import GradCAM, \
              
                            ScoreCAM, \
              
                            GradCAMPlusPlus, \
              
                            AblationCAM, \
              
                            XGradCAM, \
              
                            EigenCAM, \
              
                            EigenGradCAM, \
              
                            LayerCAM, \
              
                            FullGrad
              

              
from pytorch_grad_cam import GuidedBackpropReLUModel
              
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
              

              
# 加载预训练的 ViT 模型
              
model = torch.hub.load('facebookresearch/deit:main','deit_tiny_patch16_224', pretrained=True)
              
model.eval()
              

              
# 判断是否使用 GPU 加速
              
use_cuda = torch.cuda.is_available()
              
if use_cuda:
              
    model = model.cuda()
          

接下来,我们需要定义一个函数来将 ViT 的输出层从三维张量转换为二维张量,以便 Grad-CAM 能够处理:


              
def reshape_transform(tensor, height=14, width=14):
              
    # 去掉cls token
              
    result = tensor[:, 1:, :].reshape(tensor.size(0),
              
    height, width, tensor.size(2))
              

              
    # 将通道维度放到第一个位置
              
    result = result.transpose(2, 3).transpose(1, 2)
              
    return result
          

然后,我们需要选择一个目标层来计算 Grad-CAM。由于 ViT 的最后一层只有类别标记对预测类别有影响,所以我们不能选择最后一层。我们可以选择倒数第二层中的任意一个 Transformer 编码器作为目标层。在这里,我们选择第 11 层作为示例:


              
# 创建 GradCAM 对象
              
cam = GradCAM(model=model,
              
            target_layers=[model.blocks[-1].norm1],
              
            # 这里的target_layer要看模型情况,
              
            # 比如还有可能是:target_layers = [model.blocks[-1].ffn.norm]
              
            use_cuda=use_cuda,
              
            reshape_transform=reshape_transform)
          

接下来,我们需要准备一张输入图像,并将其转换为适合 ViT 的格式:


              
# 读取输入图像
              
image_path = "xxx.jpg"
              
rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
              
rgb_img = cv2.resize(rgb_img, (224, 224))
              

              
# 预处理图像
              
input_tensor = preprocess_image(rgb_img,
              
mean=[0.485, 0.456, 0.406],
              
std=[0.229, 0.224, 0.225])
              

              
# 看情况将图像转换为批量形式
              
# input_tensor = input_tensor.unsqueeze(0)
              
if use_cuda:
              
    input_tensor = input_tensor.cuda()
          

最后,我们可以调用 cam 对象的 forward 方法,传入输入张量和预测类别(如果不指定,则默认为最高概率的类别),得到 Grad-CAM 的输出:


              
# 计算 grad-cam
              
target_category = None # 可以指定一个类别,或者使用 None 表示最高概率的类别
              
grayscale_cam = cam(input_tensor=input_tensor, targets=target_category)
              
grayscale_cam = grayscale_cam[0, :]
              

              
# 将 grad-cam 的输出叠加到原始图像上
              
visualization = show_cam_on_image(rgb_img, grayscale_cam)
              

              
# 保存可视化结果
              
cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR, visualization)
              
cv2.imwrite('cam.jpg', visualization)
          

这样,我们就完成了使用 Grad-CAM 对 ViT 的输出进行可视化的过程。我们可以看到,ViT 主要关注了图像中的猫的头部和身体区域,这与我们的直觉相符。通过使用 Grad-CAM,我们可以更好地理解 ViT 的工作原理,以及它对不同图像区域的重要性。

PyTorch-Grad-CAM 库的更多方法

除了经典的 Grad-CAM,库里目前支持的方法还有:

picture.image

这里给出 MMPretrain 提供的对比示例:

picture.image

在 MMPretrain 中使用

如果你刚好在用 MMPretrain,那么有着方便的脚本文件来帮助你更加方便的进行上面的工作,具体可见: https://mmpretrain.readthedocs.io/zh\_CN/latest/useful\_tools/cam\_visualization.html

picture.image

示例

这里也放一些我自己试过的例子:

以这张可爱的猫猫作为输入:

picture.image

我们选择 DeiT_tiny 模型,并使用最经典的 Grad-CAM,设置 target_category = None ,即使用输出最高概率的类别,选择最后一层的第一个 Layer Norm 作为 target layer 得到结果如下所示:

picture.image

可以看出,heatmap 的高亮区域,似乎只出现在猫猫头上的部分区域,有聪明的同学知道这是为什么吗?(提示:ImageNet-1k 数据集中,猫的种类有 12 种;判别性区域)

再来看看换用更大一点的 DeiT-base 会怎么样呢?

picture.image

关注的区域变了,甚至一些似乎不在猫猫身上了,是为什么呢(想想 token mixer,或者有没有可能是分类错误呢),这里,我们不妨换为前面的层(e.g. 第四层)来看看:

picture.image

似乎更多的关注点出现了,再结合最后一层的结果想一想(ViT 有时会有这样的“散焦”)。

这里只是一个最基本的尝试,初步给大家展示了一下

ViT+Grad-CAM 的使用。后面,关于各种不同的预训练方法(MAE、SimMIM、DeiT、BeiT 等等)、各种 backbone 使用方法(linear prob、fine-tuning 与 layer-wise learning rate decay 的 ft)、去不去掉 cls token、甚至用别的 token 去接 fc 等等等...的各种 Vision Transformer 的 Grad-CAM 的可视化结果,就由大家来自由探索吧~,说不定会有新的、不一样的发现哦😊

总结

通过使用 Grad-CAM,我们可以更好地理解 ViT 的工作原理,以及它是如何从图像中提取有用的特征的。Grad-CAM 也可以用于其他基于 Transformer 的模型,例如 DeiT、Swin Transformer 等,只需要根据不同的模型结构和输出,调整相应的计算步骤即可。

扫码加入👉「集智书童」交流群

(备注: 方向+学校/公司+昵称 )

picture.image

picture.image

picture.image

picture.image

picture.image

picture.image

想要了解更多:

前沿AI视觉感知全栈知识👉「分类、检测、分割、关键点、车道线检测、3D视觉(分割、检测)、多模态、目标跟踪、NerF」

行业技术方案 👉「AI安防、AI医疗、AI自动驾驶」

AI模型部署落地实战 👉「CUDA、TensorRT、NCNN、OpenVINO、MNN、ONNXRuntime以及地平线框架」

欢迎扫描上方二维码,加入「集智书童-知识星球 」,日常分享论文、学习笔记、问题解决方案、部署方案以及全栈式答疑,期待交流!

免责声明

凡本公众号注明“来源:XXX(非集智书童)”的作品,均转载自其它媒体,版权归原作者所有,如有侵权请联系我们删除,谢谢。

点击下方“阅读原文 ”,

了解更多AI学习路上的 「武功秘籍」

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
CloudWeGo白皮书:字节跳动云原生微服务架构原理与开源实践
本书总结了字节跳动自2018年以来的微服务架构演进之路
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论