model.py (34 lines, 29 loc, 1.22 KB)
import torch.nn as nn
import torch.nn.functional as F

from layers import GCNConv_dense


class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, dropout_adj, Adj, sparse):
        super(GCN, self).__init__()

        # Stack of dense GCN layers: input -> hidden, (num_layers - 2) hidden -> hidden, hidden -> output.
        self.layers = nn.ModuleList()
        self.layers.append(GCNConv_dense(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.layers.append(GCNConv_dense(hidden_channels, hidden_channels))
        self.layers.append(GCNConv_dense(hidden_channels, out_channels))

        # Feature dropout rate and dropout applied to the adjacency.
        self.dropout = dropout
        self.dropout_adj = nn.Dropout(p=dropout_adj)
        self.dropout_adj_p = dropout_adj

        # The adjacency is a fixed input, not a learnable parameter.
        self.Adj = Adj
        self.Adj.requires_grad = False
        self.sparse = sparse

    def forward(self, x):
        if self.sparse:
            # Sparse case: Adj is expected to be a DGL-style graph whose edge
            # weights live in edata['w']; dropout is applied to those weights.
            Adj = self.Adj
            Adj.edata['w'] = F.dropout(Adj.edata['w'], p=self.dropout_adj_p, training=self.training)
        else:
            # Dense case: apply dropout directly to the dense adjacency matrix.
            Adj = self.dropout_adj(self.Adj)

        # All layers except the last: graph convolution -> ReLU -> feature dropout.
        for i, conv in enumerate(self.layers[:-1]):
            x = conv(x, Adj)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Final layer produces the output without an activation.
        x = self.layers[-1](x, Adj)
        return x
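A minimal usage sketch for the dense (non-sparse) path, assuming that GCNConv_dense in the repository's layers.py implements the usual A @ (X W) propagation and that Adj is a precomputed, normalized dense adjacency. The graph sizes, sparsity, and normalization below are illustrative, not taken from the repository.

# Hypothetical usage sketch: random graph, symmetrically normalized dense
# adjacency, one forward pass in eval mode. Assumes model.py and layers.py
# are importable from the repository.
import torch
from model import GCN

num_nodes, in_dim, hidden_dim, out_dim = 100, 16, 32, 7

# Random node features and a random symmetric 0/1 adjacency with self-loops.
x = torch.rand(num_nodes, in_dim)
adj = (torch.rand(num_nodes, num_nodes) < 0.05).float()
adj = ((adj + adj.t()) > 0).float()
adj.fill_diagonal_(1.0)

# Symmetric normalization D^{-1/2} A D^{-1/2}, the usual GCN preprocessing.
deg_inv_sqrt = adj.sum(dim=1).clamp(min=1.0).pow(-0.5)
adj = deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)

model = GCN(in_channels=in_dim, hidden_channels=hidden_dim, out_channels=out_dim,
            num_layers=2, dropout=0.5, dropout_adj=0.25, Adj=adj, sparse=False)

model.eval()                      # disables feature and adjacency dropout
with torch.no_grad():
    out = model(x)
print(out.shape)                  # expected: torch.Size([100, 7])

With num_layers=2 the hidden-to-hidden loop runs zero times, so the stack is just input -> hidden -> output; larger values insert additional hidden layers.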