图神经网络16-DGL实战:构建图神经网络(GNN)模块

技术

1

DGL NN模块的构造函数

构造函数完成以下几个任务:

  1. 设置选项。
  2. 注册可学习的参数或者子模块。
  3. 初始化参数。

              
  
    import torch.nn as nn  
  
    from dgl.utils import expand_as_pair  
  
    class SAGEConv(nn.Module):  
        def \_\_init\_\_(self,  
                     in\_feats,  
                     out\_feats,  
                     aggregator\_type,  
                     bias=True,  
                     norm=None,  
                     activation=None):  
            super(SAGEConv, self).__init__()  
  
            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)  
            self._out_feats = out_feats  
            self._aggre_type = aggregator_type  
            self.norm = norm  
            self.activation = activation  

          

在构造函数中,用户首先需要设置数据的维度。对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。

除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。常用的聚合类型包括 meansummaxmin。一些模块可能会使用更加复杂的聚合函数,比如 lstm

上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化:


              
 # 聚合类型:mean、max\_pool、lstm、gcn  
if aggregator_type not in ['mean', 'max\_pool', 'lstm', 'gcn']:  
     raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))  
 if aggregator_type == 'max\_pool':  
     self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)  
if aggregator_type == 'lstm':  
      self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)  
if aggregator_type in ['mean', 'max\_pool', 'lstm']:  
      self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)  
      self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)  
      self.reset_parameters()  

          

注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linearnn.LSTM 等。构造函数的最后调用了 reset_parameters() 进行权重初始化。


              
 def reset\_parameters(self):  
        """重新初始化可学习的参数"""  
        gain = nn.init.calculate_gain('relu')  
        if self._aggre_type == 'max\_pool':  
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)  
        if self._aggre_type == 'lstm':  
            self.lstm.reset_parameters()  
        if self._aggre_type != 'gcn':  
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)  
            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)  

          

2

编写DGL NN模块的forward函数

在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比, DGL NN模块额外增加了1个参数 :class:dgl.DGLGraphforward() 函数的内容一般可以分为3项操作:

  • 检测输入图对象是否符合规范。
  • 消息传递和聚合。
  • 聚合后,更新特征作为输出。

下文展示了SAGEConv示例中的 forward() 函数。

输入图对象的规范检测


              
def forward(self, graph, feat):  
        with graph.local_scope():  
         # 指定图类型,然后根据图类型扩展输入特征  
         feat_src, feat_dst = expand_as_pair(feat, graph)  

          

forward() 函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。比如在 :class:~dgl.nn.pytorch.conv.GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。当1个节点入度为0时, mailbox 将为空,并且聚合函数的输出值全为0, 这可能会导致模型性能不佳。但是,在 :class:~dgl.nn.pytorch.conv.SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来,forward() 函数的输出不会全为0。在这种情况下,无需进行此类检验。

DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图(:ref:guide_cn-graph-heterogeneous)和子图块(:ref:guide_cn-minibatch)。

SAGEConv的数学公式如下:

picture.image

源节点特征 feat_src 和目标节点特征 feat_dst 需要根据图类型被指定。用于指定图类型并将 feat 扩展为 feat_srcfeat_dst 的函数是 :meth:~dgl.utils.expand_as_pair。该函数的细节如下所示。


              
  
    def expand\_as\_pair(input\_, g=None):  
        if isinstance(input_, tuple):  
            # 二分图的情况  
            return input_  
        elif g is not None and g.is_block:  
            # 子图块的情况  
            if isinstance(input_, Mapping):  
                input_dst = {  
                    k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))  
                    for k, v in input_.items()}  
            else:  
                input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())  
            return input_, input_dst  
        else:  
            # 同构图的情况  
            return input_, input_  

          

对于同构图上的全图训练,源节点和目标节点相同,它们都是图中的所有节点。

在异构图的情况下,图可以分为几个二分图,每种关系对应一个。关系表示为 (src_type, edge_type, dst_dtype)。当输入特征 feat 是1个元组时,图将会被视为二分图。元组中的第1个元素为源节点特征,第2个元素为目标节点特征。

在小批次训练中,计算应用于给定的一堆目标节点所采样的子图。子图在DGL中称为区块(block)。在区块创建的阶段,dst nodes 位于节点列表的最前面。通过索引 [0:g.number_of_dst_nodes()] 可以找到 feat_dst

确定 feat_srcfeat_dst 之后,以上3种图类型的计算方法是相同的。

消息传递和聚合


              
import dgl.function as fn  
import torch.nn.functional as F  
from dgl.utils import check_eq_shape  
       if self._aggre_type == 'mean':  
            graph.srcdata['h'] = feat_src  
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))  
            h_neigh = graph.dstdata['neigh']  
        elif self._aggre_type == 'gcn':  
                check_eq_shape(feat)  
                graph.srcdata['h'] = feat_src  
                graph.dstdata['h'] = feat_dst  
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))  
                # 除以入度  
                degs = graph.in_degrees().to(feat_dst)  
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)  
            elif self._aggre_type == 'max\_pool':  
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))  
                graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))  
                h_neigh = graph.dstdata['neigh']  
            else:  
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))  
  
            # GraphSAGE中gcn聚合不需要fc\_self  
            if self._aggre_type == 'gcn':  
                rst = self.fc_neigh(h_neigh)  
            else:  
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)  

          

上面的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。

聚合后,更新特征作为输出


              
 # 激活函数  
 if self.activation is not None:  
        rst = self.activation(rst)  
       # 归一化  
   if self.norm is not None:  
        rst = self.norm(rst)  
        return rst  

          

forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。

3

简单的图分类任务

在本教程中,我们将学习如何使用 DGL 执行图分类,这个例子的任务目标就是对下面显示的八种拓扑类型Grpah进行分类。picture.image

这里我们直接使用 DGL 中合成数据集 data.MiniGCDataset。数据集有八种不同类型的图,每个类都有相同数量的图样本


              
from dgl.data import MiniGCDataset  
import matplotlib.pyplot as plt  
import networkx as nx  
# A dataset with 80 samples, each graph is  
# of size [10, 20]  
dataset = MiniGCDataset(80, 10, 20)  
graph, label = dataset[0]  
fig, ax = plt.subplots()  
nx.draw(graph.to_networkx(), ax=ax)  
ax.set_title('Class: {:d}'.format(label))  
plt.show()  

          

              
Using backend: pytorch
          

picture.image

创建graph的批数据

picture.image image


              
import dgl  
import torch  
  
def collate(samples):  
    # The input `samples` is a list of pairs  
    #  (graph, label).  
    graphs, labels = map(list, zip(*samples))  
    batched_graph = dgl.batch(graphs)  
    return batched_graph, torch.tensor(labels,dtype=torch.long)  

          

构建Graph分类器

picture.image image


              
from dgl.nn.pytorch import GraphConv  
import torch.nn as nn  
import torch.nn.functional as F  
  
class Classifier(nn.Module):  
    def \_\_init\_\_(self, in\_dim, hidden\_dim, n\_classes):  
        super(Classifier, self).__init__()  
        self.conv1 = GraphConv(in_dim, hidden_dim)  
        self.conv2 = GraphConv(hidden_dim, hidden_dim)  
        self.classify = nn.Linear(hidden_dim, n_classes)  
  
    def forward(self, g):  
        # Use node degree as the initial node feature. For undirected graphs, the in-degree  
        # is the same as the out\_degree.  
        h = g.in_degrees().view(-1, 1).float()  
        # Perform graph convolution and activation function.  
        h = F.relu(self.conv1(g, h))  
        h = F.relu(self.conv2(g, h))  
        g.ndata['h'] = h  
        # Calculate graph representation by averaging all the node representations.  
        hg = dgl.mean_nodes(g, 'h')  
        return self.classify(hg)  

          

              
import torch.optim as optim  
from torch.utils.data import DataLoader  
  
# Create training and test sets.  
trainset = MiniGCDataset(320, 10, 20)  
testset = MiniGCDataset(80, 10, 20)  
# Use PyTorch's DataLoader and the collate function  
# defined before.  
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,  
                         collate_fn=collate)  
  
# Create model  
model = Classifier(1, 256, trainset.num_classes)  
loss_func = nn.CrossEntropyLoss()  
optimizer = optim.Adam(model.parameters(), lr=0.001)  
model.train()  
  
epoch_losses = []  
for epoch in range(80):  
    epoch_loss = 0  
    for iter, (bg, label) in enumerate(data_loader):  
        prediction = model(bg)  
        loss = loss_func(prediction, label)  
        optimizer.zero_grad()  
        loss.backward()  
        optimizer.step()  
        epoch_loss += loss.detach().item()  
    epoch_loss /= (iter + 1)  
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))  
    epoch_losses.append(epoch_loss)  

          

              
Epoch 0, loss 2.0010  
Epoch 1, loss 1.9744  
Epoch 2, loss 1.9551  
Epoch 3, loss 1.9444  
Epoch 4, loss 1.9318  
Epoch 5, loss 1.9170  
Epoch 6, loss 1.8928  
Epoch 7, loss 1.8573  
Epoch 8, loss 1.8212  
Epoch 9, loss 1.7715  
Epoch 10, loss 1.7152  
Epoch 11, loss 1.6570  
Epoch 12, loss 1.5885  
Epoch 13, loss 1.5308  
Epoch 14, loss 1.4719  
Epoch 15, loss 1.4158  
Epoch 16, loss 1.3515  
Epoch 17, loss 1.2963  
Epoch 18, loss 1.2417  
Epoch 19, loss 1.1978  
Epoch 20, loss 1.1698  
Epoch 21, loss 1.1086  
Epoch 22, loss 1.0780  
Epoch 23, loss 1.0459  
Epoch 24, loss 1.0192  
Epoch 25, loss 1.0017  
Epoch 26, loss 1.0297  
Epoch 27, loss 0.9784  
Epoch 28, loss 0.9486  
Epoch 29, loss 0.9327  
Epoch 30, loss 0.9133  
Epoch 31, loss 0.9265  
Epoch 32, loss 0.9177  
Epoch 33, loss 0.9303  
Epoch 34, loss 0.8666  
Epoch 35, loss 0.8639  
Epoch 36, loss 0.8474  
Epoch 37, loss 0.8858  
Epoch 38, loss 0.8393  
Epoch 39, loss 0.8306  
Epoch 40, loss 0.8204  
Epoch 41, loss 0.8057  
Epoch 42, loss 0.7998  
Epoch 43, loss 0.7909  
Epoch 44, loss 0.7840  
Epoch 45, loss 0.7807  
Epoch 46, loss 0.7882  
Epoch 47, loss 0.7701  
Epoch 48, loss 0.7612  
Epoch 49, loss 0.7563  
Epoch 50, loss 0.7430  
Epoch 51, loss 0.7354  
Epoch 52, loss 0.7357  
Epoch 53, loss 0.7326  
Epoch 54, loss 0.7249  
Epoch 55, loss 0.7181  
Epoch 56, loss 0.7146  
Epoch 57, loss 0.7306  
Epoch 58, loss 0.7143  
Epoch 59, loss 0.7018  
Epoch 60, loss 0.7130  
Epoch 61, loss 0.7003  
Epoch 62, loss 0.6977  
Epoch 63, loss 0.7120  
Epoch 64, loss 0.6979  
Epoch 65, loss 0.7370  
Epoch 66, loss 0.7223  
Epoch 67, loss 0.6980  
Epoch 68, loss 0.6891  
Epoch 69, loss 0.6715  
Epoch 70, loss 0.6736  
Epoch 71, loss 0.6709  
Epoch 72, loss 0.6583  
Epoch 73, loss 0.6717  
Epoch 74, loss 0.6683  
Epoch 75, loss 0.6656  
Epoch 76, loss 0.6477  
Epoch 77, loss 0.6414  
Epoch 78, loss 0.6442  
Epoch 79, loss 0.6398
          

              
plt.title('cross entropy averaged over minibatches')  
plt.plot(epoch_losses)  
plt.show()  

          

picture.image


              
model.eval()  
# Convert a list of tuples to two lists  
test_X, test_Y = map(list, zip(*testset))  
test_bg = dgl.batch(test_X)  
test_Y = torch.tensor(test_Y).float().view(-1, 1)  
probs_Y = torch.softmax(model(test_bg), 1)  
sampled_Y = torch.multinomial(probs_Y, 1)  
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)  
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(  
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))  
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(  
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))  

          

              
Accuracy of sampled predictions on the test set: 58.7500%  
Accuracy of argmax predictions on the test set: 62.500000%
          

大家加左图进入微信群;右图有QQ学习交流群

picture.image

picture.image

picture.image

发现“在看”和“赞”了吗,戳我试试吧

picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
如何利用云原生构建 AIGC 业务基石
AIGC即AI Generated Content,是指利用人工智能技术来生成内容,AIGC也被认为是继UGC、PGC之后的新型内容生产方式,AI绘画、AI写作等都属于AIGC的分支。而 AIGC 业务的部署也面临着异构资源管理、机器学习流程管理等问题,本次分享将和大家分享如何使用云原生技术构建 AIGC 业务。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论