LightGCN模型部分代码解读

技术

picture.image

点击蓝字关注,提升学习效率

代码地址:

tensorflow版:https://github.com/kuandeng/LightGCN

本文对LightGCN模型部分的代码进行了解读,对相应部分进行了简单的注释帮助大家理解。笔者第一次尝试代码阅读分享,有什么不足之处或者建议可以给我留言哦,感谢。

picture.image

picture.image

Dropout

在图上实施dropout,以一定概率忽略一部分边


            
def \_\_dropout\_x(self, x, keep\_prob):  
        # 获取self.Graph中的大小,下标和值,Graph采用稀疏矩阵的表示方法SparseTensor  
        size = x.size()  
        index = x.indices().t()  
        values = x.values()  
        # 通过rand得到len(values)数量的随机数,加上keep\_prob  
        random_index = torch.rand(len(values)) + keep_prob  
        # 通过对这些数字取int使得小于1的为0,在通过bool()将0->false,大于等于1的取True  
        random_index = random_index.int().bool()  
        # 利用上面得到的True,False数组选取下标,从而dropout了为False的下标  
        index = index[random_index]  
        # 由于dropout在训练和测试过程中的不一致,所以需要除以p  
        values = values[random_index]/keep_prob  
        # 得到新的graph  
        g = torch.sparse.FloatTensor(index.t(), values, size)  
        return g  
      
    def \_\_dropout(self, keep\_prob):  
        if self.A_split:  
            graph = []  
            for g in self.Graph:  
                graph.append(self.__dropout_x(g, keep_prob))  
        else:  
            graph = self.__dropout_x(self.Graph, keep_prob)  
        return graph
        

picture.image

picture.image

消息传播

computer函数是LightGCN类中用于进行图信息传播的实现方法,整体上通过在整个图上进行矩阵计算得到所有用户和商品的embedding。


            
def computer(self):  
        """  
        propagate methods for lightGCN  
        """         
        # 得到所有用户和所有商品的embedding  
        users_emb = self.embedding_user.weight  
        items_emb = self.embedding_item.weight  
        all_emb = torch.cat([users_emb, items_emb])  
        # torch.split(all\_emb , [self.num\_users, self.num\_items])  
        embs = [all_emb]  
        # 判断是否需要dropout  
        if self.config['dropout']:  
            if self.training:  
                print("droping")  
                g_droped = self.__dropout(self.keep_prob)  
            else:  
                g_droped = self.Graph   
        else:  
            g_droped = self.Graph   
        # 根据层数对图进行信息传播和聚合考虑n-hop  
        # 通过稀疏矩阵乘法对Graph进行n\_layers次的计算  
        for layer in range(self.n_layers):  
            if self.A_split:  
                temp_emb = []  
                for f in range(len(g_droped)):  
                    temp_emb.append(torch.sparse.mm(g_droped[f], all_emb))  
                side_emb = torch.cat(temp_emb, dim=0)  
                all_emb = side_emb  
            else:  
                all_emb = torch.sparse.mm(g_droped, all_emb)  
            embs.append(all_emb)  
        embs = torch.stack(embs, dim=1)  
        #print(embs.size())  
        # 对每一层得到的输出求均值,以此将不同层的信息进行融合  
        light_out = torch.mean(embs, dim=1)  
        users, items = torch.split(light_out, [self.num_users, self.num_items])  
        return users, items
        

picture.image

picture.image

损失构建

在computer函数计算得到所有用户和商品经过消息传播后的embedding之后,getEmbedding根据当前用户和商品查询出需要用到的embedding以及当前用户和商品的原始embedding,即未经GCN的embedding。

传播后的embedding用于计算bpr损失,原始embedding用于计算L2正则项。


            
def getEmbedding(self, users, pos\_items, neg\_items):  
        # 得到需要计算相似度的用户和商品的embedding  
        all_users, all_items = self.computer()  
        users_emb = all_users[users]  
        pos_emb = all_items[pos_items]  
        neg_emb = all_items[neg_items]  
        # 没经过传播的embedding,用于后续正则项计算  
        users_emb_ego = self.embedding_user(users)  
        pos_emb_ego = self.embedding_item(pos_items)  
        neg_emb_ego = self.embedding_item(neg_items)  
        return users_emb, pos_emb, neg_emb, users_emb_ego, pos_emb_ego, neg_emb_ego  
      
    def bpr\_loss(self, users, pos, neg):  
        (users_emb, pos_emb, neg_emb,   
        userEmb0, posEmb0, negEmb0) = self.getEmbedding(users.long(), pos.long(), neg.long())  
        # 这个损失计算的是LightGCN论文中损失函数中的正则项,即做了一个L2正则  
        reg_loss = (1/2)*(userEmb0.norm(2).pow(2) +   
                         posEmb0.norm(2).pow(2) +  
                         negEmb0.norm(2).pow(2))/float(len(users))  
        # 通过乘法计算用户和商品的相似度  
        pos_scores = torch.mul(users_emb, pos_emb)  
        pos_scores = torch.sum(pos_scores, dim=1)  
        neg_scores = torch.mul(users_emb, neg_emb)  
        neg_scores = torch.sum(neg_scores, dim=1)  
        # pair-wise的排序损失  
        loss = torch.mean(torch.nn.functional.softplus(neg_scores - pos_scores))
        

picture.image

往期推荐

在线学习方法FTRL原理与实现

SIGIR'21「腾讯」冷启动:元学习+FTRL+动态学习率=FORM模型

SIGIR'21「微信」利用元网络学习冷启动商品ID Embedding

学习交流群:联系作者--备注“研究方向-公司或学校”

picture.image

picture.image

picture.image

长按关注

更多精彩

picture.image

秋枫学习笔记

picture.image

picture.image

点个在看你最好看

picture.image

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论