首页 > 其他 > 详细

DGL学习(六): GCN实现

时间:2020-07-24 12:17:10      阅读:156      评论:0      收藏:0      [点我收藏+]

GCN可以认为由两步组成:

对于每个节点 $u$

1)汇总邻居的表示$h_v$ 产生中间表示 $\hat h_u$

2) 使用$W_u$线性投影 $\hat h_v$, 再经过非线性变换 $f$ , 即 $h_u = f(W_u \hat h_u)$

 

首先定义message函数和reduce函数。

import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

## 定义消息函数 和 reduce函数
gcn_msg = fn.copy_src(src=h, out=m)
gcn_reduce = fn.sum(msg=m, out=h)

定义GCN

## 定义GCNLayer
class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, feature):
        # Creating a local scope so that all the stored ndata and edata
        # (such as the `‘h‘` ndata below) are automatically popped out
        # when the scope exits.
        with g.local_scope():
            g.ndata[h] = feature
            g.update_all(gcn_msg, gcn_reduce)
            h = g.ndata[h]
            return self.linear(h)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = GCNLayer(1433, 16)
        self.layer2 = GCNLayer(16, 7)

    def forward(self, g, features):
        x = F.relu(self.layer1(g, features))
        x = self.layer2(g, x)
        return x
net = Net()
print(net)

 

DGL学习(六): GCN实现

原文:https://www.cnblogs.com/liyinggang/p/13370943.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!