图神经网络13-图注意力模型GAT网络详解

技术

1

论文摘要

针对图结构数据,本文提出了一种GAT(graph attention networks)网络。该网络使用masked self-attention层解决了之前基于图卷积(或其近似)的模型所存在的问题。在GAT中,图中的每个节点可以根据邻节点的特征,为其分配不同的权值。GAT的另一个优点在于,无需使用预先构建好的图。因此,GAT可以解决一些基于谱的图神经网络中所具有的问题。实验证明,GAT模型可以有效地适用于(基于图的)归纳学习问题与转导学习问题。

创新点:

  • 引入masked self-attentional layers 来改进前面图卷积graph convolution的缺点
  • 对不同的相邻节点分配相应的权重,既不需要矩阵运算,也不需要事先知道图结构

attention 引入目的

  • 为每个节点分配不同权重
  • 关注那些作用比较大的节点,而忽视一些作用较小的节点
  • 在处理局部信息的时候同时能够关注整体的信息,不是用来给参与计算的各个节点进行加权的,而是表示一个全局的信息并参与计算

框架特点

  • attention 计算机制高效,为每个节点和其每个邻近节点计算attention 可以并行进行
  • 能够按照规则指定neighbor 不同的权重,不受邻居数目的影响
  • 可直接应用到归纳推理问题中

2

GAT模型结构

假设一个图有N个节点,节点的F维特征集合可以表示为

picture.image

注意力层的目的是输出新的节点特征集合,
picture.image

在这个过程中特征向量的维度可能会改变,即picture.image
picture.image

上式在将输入特征运用线性变换转化为高阶特征后,使用self-attention为每个节点分配注意力(权重)。其中picture.image的影响力系数(标量)。

上面的注意力计算考虑了图中任意两个节点,也就是说,图中每个节点对目标节点的影响都被考虑在内,这样就损失了图结构信息。论文中使用了masked attention,对于目标节点picture.image(包括自身的影响)。

为了更好的在不同节点之间分配权重,我们需要将目标节点与所有邻居计算出来的相关度进行统一的归一化处理,这里用softmax归一化:

picture.image

关于picture.image,使用负半轴斜率为0.2的LeakyReLU作为非线性激活函数:

picture.image

其中picture.image表示拼接操作。完整的权重系数计算公式为:

picture.image

得到归一化注意系数后,计算其对应特征的线性组合,通过非线性激活函数后,每个节点的最终输出特征向量为:

picture.image

3

MultiHead Attention

另外,本文使用多头注意力机制(multi-head attention)来稳定self-attention的学习过程,即对上式调用picture.image组相互独立的注意力机制,然后将输出结果拼接起来:

picture.image

其中picture.image个特征。为了减少输出的特征向量的维度,也可以将拼接操作替换为平均操作。

picture.image

下面是picture.image

4

不同模型比较

  • GAT计算高效。self-attetion层可以在所有边上并行计算,输出特征可以在所有节点上并行计算;不需要特征分解或者其他内存耗费大的矩阵操作。单个head的GAT的时间复杂度为picture.image
  • 与GCN不同的是,GAT为同一邻域中的节点分配不同的重要性,提升了模型的性能。
  • 注意力机制以共享的方式应用于图中的所有边,因此它不依赖于对全局图结构的预先访问,也不依赖于对所有节点(特征)的预先访问(这是许多先前技术的限制)。
  • 不必要无向图。如果边picture.image
  • 可以用于归纳学习;

5

评估

数据集

picture.image

其中前三个引文网络用于直推学习,第四个蛋白质交互网络PPI用于归纳学习。

6

实验设置

  • 直推学习
  • 两层GAT模型,第一层多头注意力picture.image(共64个特征),激活函数为指数线性单元(ELU);
  • 第二层单头注意力,计算picture.image为分类数),接softmax激活函数;
  • 为了处理小的训练集,模型中大量采用正则化方法,具体为L2正则化;
  • dropout;
  • 归纳学习:
  • 三层GAT模型,前两层多头注意力picture.image(共1024个特征),激活函数为指数非线性单元(ELU);
  • 最后一层用于多标签分类,picture.image,每个头计算121个特征,后接logistic sigmoid激活函数;
  • 不使用正则化和dropout;
  • 使用了跨越中间注意力层的跳跃连接。
  • batch_size = 2 graph

7

实验结果

  • 不同数据集的分类准确率效果对比(Transductive)

picture.image

  • 数据集PPI上的F1效果(归纳学习)

picture.image

  • 可视化

picture.image

8

核心代码

GAT层代码:


            
import numpy as np  
import torch  
import torch.nn as nn  
import torch.nn.functional as Fclass GraphAttentionLayer(nn.Module):  
    """  
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903  
    """  
    def \_\_init\_\_(self, in_features, out_features, dropout, alpha, concat=True):        super(GraphAttentionLayer, self).__init_\_()        self.dropout = dropout        self.in_features = in_features        self.out_features = out_features        self.alpha = alpha        self.concat = concat        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))  
        nn.init.xavier_uniform\_(self.W.data, gain=1.414)        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))  
        nn.init.xavier_uniform\_(self.a.data, gain=1.414)        self.leakyrelu = nn.LeakyReLU(self.alpha)    def forward(self, h, adj):  
        Wh = torch.mm(h, self.W) # h.shape: (N, in\_features), Wh.shape: (N, out\_features)  
        a_input = self._prepare_attentional_mechanism_input(Wh)  
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))  
  
        zero_vec = -9e15*torch.ones_like(e)  
        attention = torch.where(adj > 0, e, zero_vec)  
        attention = F.softmax(attention, dim=1)  
        attention = F.dropout(attention, self.dropout, training=self.training)  
        h_prime = torch.matmul(attention, Wh)        if self.concat:  
            return F.elu(h_prime)        else:  
            return h_prime    def \_prepare\_attentional\_mechanism\_input(self, Wh):  
        N = Wh.size()[0] # number of nodes  
  
        # Below, two matrices are created that contain embeddings in their rows in different orders.  
        # (e stands for embedding)  
        # These are the rows of the first matrix (Wh\_repeated\_in\_chunks):   
        # e1, e1, ..., e1,            e2, e2, ..., e2,            ..., eN, eN, ..., eN  
        # '-------------' -> N times  '-------------' -> N times       '-------------' -> N times  
        #   
        # These are the rows of the second matrix (Wh\_repeated\_alternating):   
        # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN   
        # '----------------------------------------------------' -> N times  
        #   
          
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)  
        Wh_repeated_alternating = Wh.repeat(N, 1)        # Wh\_repeated\_in\_chunks.shape == Wh\_repeated\_alternating.shape == (N * N, out\_features)  
  
        # The all\_combination\_matrix, created below, will look like this (|| denotes concatenation):  
        # e1 || e1  
        # e1 || e2  
        # e1 || e3  
        # ...  
        # e1 || eN  
        # e2 || e1  
        # e2 || e2  
        # e2 || e3  
        # ...  
        # e2 || eN  
        # ...  
        # eN || e1  
        # eN || e2  
        # eN || e3  
        # ...  
        # eN || eN  
  
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)        # all\_combinations\_matrix.shape == (N * N, 2 * out\_features)  
  
        return all_combinations_matrix.view(N, N, 2 * self.out_features)    def \_\_repr\_\_(self):        return self.__class_\_.__name_\_ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
        

GAT模型


            
import torchimport torch.nn as nnimport torch.nn.functional as Ffrom layers import GraphAttentionLayer, SpGraphAttentionLayerclass GAT(nn.Module):  
    def \_\_init\_\_(self, nfeat, nhid, nclass, dropout, alpha, nheads):  
        """Dense version of GAT."""  
        super(GAT, self).__init__()  
        self.dropout = dropout  
  
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]        for i, attention in enumerate(self.attentions):  
            self.add_module('attention\_{}'.format(i), attention)  
  
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)    def forward(self, x, adj):  
        x = F.dropout(x, self.dropout, training=self.training)  
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)  
        x = F.dropout(x, self.dropout, training=self.training)  
        x = F.elu(self.out_att(x, adj))        return F.log_softmax(x, dim=1)
        

8

参考资料

论文笔记图神经网络:图注意力网络(GAT):

https://jjzhou012.github.io/blog/2020/01/28/Graph-Attention-Networks.html

Graph Attention Networks

https://www.cnblogs.com/c-w-k/p/13488820.html

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
火山引擎音视频体验白皮书
火山引擎联合AMD发布了音视频体验白皮书,以抖音亿级日活用户实践和大规模场景落地经验,详细解读音视频体验评估指标和模型,分享火山引擎音视频实验室的评测方案和抖音在音视频体验优化上的典型策略、案例,助力企业优化用户体验,促进业务增长。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论