-
Notifications
You must be signed in to change notification settings - Fork 73
/
model.py
104 lines (87 loc) · 3.69 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch.nn as nn
import torch
import math
import numpy as np
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class GraphConvolution(nn.Module):
    """Graph-convolution layer mixing propagated features with an initial state.

    ``forward`` propagates ``input`` through ``adj``, blends the result with
    ``h0`` using ``alpha``, and then interpolates between a learned linear
    map of that signal and the signal itself using
    ``theta = log(lamda / l + 1)``.

    Args:
        in_features: Width of the incoming features. With ``variant`` set,
            the stored ``in_features`` (and the first weight dimension) is
            twice this value, because the propagated features are
            concatenated with ``h0`` before the linear map.
        out_features: Second dimension of the weight matrix.
        residual: When true, ``forward`` adds ``input`` to its result.
        variant: When true, use the concatenation form described above.
    """

    def __init__(self, in_features, out_features, residual=False, variant=False):
        super().__init__()
        self.variant = variant
        # The concatenated [propagated, h0] input is twice as wide.
        self.in_features = 2 * in_features if variant else in_features
        self.out_features = out_features
        self.residual = residual
        self.weight = Parameter(
            torch.FloatTensor(self.in_features, self.out_features)
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Redraw ``weight`` uniformly from ``[-b, b]``, ``b = 1/sqrt(out_features)``."""
        bound = 1. / math.sqrt(self.out_features)
        nn.init.uniform_(self.weight, -bound, bound)

    def forward(self, input, adj, h0, lamda, alpha, l):
        """Apply the layer.

        Args:
            input: Features multiplied by ``adj``; also the residual term
                when ``residual`` is set.
            adj: Left operand of ``torch.spmm`` (presumably a sparse,
                normalised adjacency matrix -- confirm against the caller).
            h0: Initial features blended in with weight ``alpha``.
            lamda: Numerator of the ratio inside the logarithm.
            l: Denominator of that ratio; must be non-zero.
            alpha: Weight given to ``h0`` in the blend.

        Returns:
            ``theta * (support @ weight) + (1 - theta) * blend``, plus
            ``input`` when ``residual`` is set.
        """
        theta = math.log(lamda / l + 1)
        propagated = torch.spmm(adj, input)
        # Blend of propagated and initial features; used as the identity
        # branch in both modes and as the linear-map input when not variant.
        mixed = (1 - alpha) * propagated + alpha * h0
        if self.variant:
            support = torch.cat([propagated, h0], 1)
        else:
            support = mixed
        output = theta * torch.mm(support, self.weight) + (1 - theta) * mixed
        return output + input if self.residual else output
class GCNII(nn.Module):
    """Linear encoder, ``nlayers`` graph convolutions, linear head, log-softmax.

    Args:
        nfeat: Input width of the first linear layer.
        nlayers: Number of ``GraphConvolution`` layers.
        nhidden: Width of the hidden representation.
        nclass: Output width of the final linear layer.
        dropout: Probability handed to ``F.dropout`` before every layer.
        lamda: Forwarded to each convolution's ``forward``.
        alpha: Forwarded to each convolution's ``forward``.
        variant: Forwarded to each ``GraphConvolution`` constructor.
    """

    def __init__(self, nfeat, nlayers, nhidden, nclass, dropout, lamda, alpha, variant):
        super().__init__()
        self.convs = nn.ModuleList(
            [GraphConvolution(nhidden, nhidden, variant=variant)
             for _ in range(nlayers)]
        )
        self.fcs = nn.ModuleList([
            nn.Linear(nfeat, nhidden),
            nn.Linear(nhidden, nclass),
        ])
        # Plain-list snapshots of the two parameter groups (presumably so a
        # caller can give them separate optimiser settings -- confirm).
        self.params1 = list(self.convs.parameters())
        self.params2 = list(self.fcs.parameters())
        self.act_fn = nn.ReLU()
        self.dropout = dropout
        self.alpha = alpha
        self.lamda = lamda

    def forward(self, x, adj):
        """Return log-softmax scores over ``dim=1`` for features ``x``.

        ``adj`` is handed unchanged to every convolution layer.
        """
        drop_p = self.dropout
        x = F.dropout(x, drop_p, training=self.training)
        # Output of the first linear layer; every convolution receives it
        # again as its initial-state argument.
        h0 = self.act_fn(self.fcs[0](x))
        hidden = h0
        for depth, conv in enumerate(self.convs, start=1):
            hidden = F.dropout(hidden, drop_p, training=self.training)
            hidden = self.act_fn(
                conv(hidden, adj, h0, self.lamda, self.alpha, depth)
            )
        hidden = F.dropout(hidden, drop_p, training=self.training)
        logits = self.fcs[-1](hidden)
        return F.log_softmax(logits, dim=1)
class GCNIIppi(nn.Module):
    """Linear encoder, residual graph convolutions, linear head, sigmoid.

    Unlike the log-softmax output used elsewhere in this file, the final
    linear layer's output is passed through an element-wise sigmoid, and the
    convolution layers are built with ``residual=True``.

    Args:
        nfeat: Input width of the first linear layer.
        nlayers: Number of ``GraphConvolution`` layers.
        nhidden: Width of the hidden representation.
        nclass: Output width of the final linear layer.
        dropout: Probability handed to ``F.dropout`` before every layer.
        lamda: Forwarded to each convolution's ``forward``.
        alpha: Forwarded to each convolution's ``forward``.
        variant: Forwarded to each ``GraphConvolution`` constructor.
    """

    def __init__(self, nfeat, nlayers, nhidden, nclass, dropout, lamda, alpha, variant):
        super().__init__()
        self.convs = nn.ModuleList(
            [GraphConvolution(nhidden, nhidden, variant=variant, residual=True)
             for _ in range(nlayers)]
        )
        self.fcs = nn.ModuleList([
            nn.Linear(nfeat, nhidden),
            nn.Linear(nhidden, nclass),
        ])
        self.act_fn = nn.ReLU()
        self.sig = nn.Sigmoid()
        self.dropout = dropout
        self.alpha = alpha
        self.lamda = lamda

    def forward(self, x, adj):
        """Return element-wise sigmoid scores for features ``x``.

        ``adj`` is handed unchanged to every convolution layer.
        """
        drop_p = self.dropout
        x = F.dropout(x, drop_p, training=self.training)
        # Output of the first linear layer; every convolution receives it
        # again as its initial-state argument.
        h0 = self.act_fn(self.fcs[0](x))
        hidden = h0
        for depth, conv in enumerate(self.convs, start=1):
            hidden = F.dropout(hidden, drop_p, training=self.training)
            hidden = self.act_fn(
                conv(hidden, adj, h0, self.lamda, self.alpha, depth)
            )
        hidden = F.dropout(hidden, drop_p, training=self.training)
        return self.sig(self.fcs[-1](hidden))
# Script entry point: intentionally a no-op, so running this module directly
# only defines the classes above and then exits.
if __name__ == '__main__':
    pass