-
Notifications
You must be signed in to change notification settings - Fork 186
/
example.py
120 lines (84 loc) · 3.28 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from graphgym.config import cfg
from graphgym.register import register_layer
# Note: A registered GNN layer must take 'batch' as input
# and return 'batch' as output
# Example 1: Directly define a GraphGym-format Conv layer
# that takes 'batch' as input and returns 'batch' as output
class ExampleConv1(MessagePassing):
    r"""Example GNN layer written directly in GraphGym format.

    The layer multiplies the node features by a learnable weight matrix,
    propagates the result along ``edge_index`` using the aggregation named
    by ``cfg.gnn.agg`` (read once, at construction time), and optionally
    adds a learnable bias to the aggregated output.

    Args:
        in_channels: Size of the last dimension of the input node features
            (number of rows of the weight matrix).
        out_channels: Size of the last dimension of the output node features
            (number of columns of the weight matrix).
        bias: If false, no bias parameter is created and ``self.bias`` is
            registered as ``None``.
        **kwargs: Extra keyword arguments forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super(ExampleConv1, self).__init__(aggr=cfg.gnn.agg, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels

        # torch.empty only allocates storage; the actual values are set by
        # reset_parameters() below.
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weight with ``glorot`` and the bias with zeros."""
        glorot(self.weight)
        # self.bias is None when the layer was built with bias=False, so
        # guard explicitly (the same check update() performs).
        if self.bias is not None:
            zeros(self.bias)

    def forward(self, batch):
        """Apply the layer to ``batch`` and return that same object.

        Reads ``batch.node_feature`` and ``batch.edge_index``, then
        overwrites ``batch.node_feature`` in place with the layer output.
        """
        x, edge_index = batch.node_feature, batch.edge_index
        x = torch.matmul(x, self.weight)
        batch.node_feature = self.propagate(edge_index, x=x)
        return batch

    def message(self, x_j):
        """Return the per-edge message: ``x_j`` unchanged.

        ``x_j`` is derived by ``propagate`` from the ``x`` keyword argument
        (PyG's ``_j`` suffix convention).
        """
        return x_j

    def update(self, aggr_out):
        """Add the bias, if the layer has one, to the aggregated messages."""
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
# Remember to register your layer! Here the class is registered under
# the key 'exampleconv1'.
register_layer('exampleconv1', ExampleConv1)
# Example 2: First define a PyG-format Conv layer,
# then wrap it so that it follows the GraphGym format
class ExampleConv2Layer(MessagePassing):
    r"""Example GNN layer written in PyG format.

    The layer multiplies the node features by a learnable weight matrix,
    propagates the result along ``edge_index`` using the aggregation named
    by ``cfg.gnn.agg`` (read once, at construction time), and optionally
    adds a learnable bias to the aggregated output.

    Args:
        in_channels: Size of the last dimension of the input node features
            (number of rows of the weight matrix).
        out_channels: Size of the last dimension of the output node features
            (number of columns of the weight matrix).
        bias: If false, no bias parameter is created and ``self.bias`` is
            registered as ``None``.
        **kwargs: Extra keyword arguments forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super(ExampleConv2Layer, self).__init__(aggr=cfg.gnn.agg, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels

        # torch.empty only allocates storage; the actual values are set by
        # reset_parameters() below.
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weight with ``glorot`` and the bias with zeros."""
        glorot(self.weight)
        # self.bias is None when the layer was built with bias=False, so
        # guard explicitly (the same check update() performs).
        if self.bias is not None:
            zeros(self.bias)

    def forward(self, x, edge_index):
        """Transform ``x`` with the weight and propagate it over the edges.

        Args:
            x: Node features; multiplied on the right by ``self.weight``.
            edge_index: Graph connectivity handed to ``propagate``.

        Returns:
            The value returned by ``propagate`` for the transformed features.
        """
        x = torch.matmul(x, self.weight)
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        """Return the per-edge message: ``x_j`` unchanged.

        ``x_j`` is derived by ``propagate`` from the ``x`` keyword argument
        (PyG's ``_j`` suffix convention).
        """
        return x_j

    def update(self, aggr_out):
        """Add the bias, if the layer has one, to the aggregated messages."""
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
class ExampleConv2(nn.Module):
    """GraphGym-format wrapper around ``ExampleConv2Layer``.

    Args:
        dim_in: Passed to ``ExampleConv2Layer`` as its input feature size.
        dim_out: Passed to ``ExampleConv2Layer`` as its output feature size.
        bias: Passed to ``ExampleConv2Layer``; note the default here is
            ``False``.
        **kwargs: Accepted but not used by this wrapper.
    """

    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(ExampleConv2, self).__init__()
        # Only dim_in, dim_out and bias reach the wrapped PyG-style layer.
        self.model = ExampleConv2Layer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        """Run the wrapped layer on ``batch`` and return that same object.

        ``batch.node_feature`` is overwritten in place with the layer output.
        """
        new_feature = self.model(batch.node_feature, batch.edge_index)
        batch.node_feature = new_feature
        return batch
# Remember to register your layer! Here the GraphGym-format wrapper
# (not the inner PyG-format layer) is registered under the key 'exampleconv2'.
register_layer('exampleconv2', ExampleConv2)