-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgnn.py
58 lines (49 loc) · 2.67 KB
/
gnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GINConv, BatchNorm, global_max_pool as gmp, global_add_pool as gap,global_mean_pool as gep, global_sort_pool
# GAT-based model (graph attention convolutions with linear residual shortcuts)
class GNNNet(torch.nn.Module):
    """Two-branch graph-attention encoder for drug graphs and protein graphs.

    Each branch stacks three GATConv layers. At every layer the convolution
    output is summed with a linear projection of the layer input (a
    residual-style shortcut), passed through ELU, then through BatchNorm.
    Node embeddings are finally mean-pooled per graph.

    Args:
        n_output: Stored as ``self.n_output``; not used inside this class
            (presumably read by a downstream prediction head -- confirm).
        num_features_pro: Per-head hidden width of the protein branch. The
            first two protein layers use 4 heads, so they emit
            ``4 * num_features_pro`` features.
        num_features_mol: Per-head hidden width of the drug branch. The first
            two drug layers emit ``4 * num_features_mol`` features.
        output_dim: Width of the final node embeddings and of the pooled
            graph embeddings, for both branches.
        dropout: Accepted for interface compatibility; currently unused.
        num_features_xd: Input node-feature width of drug graphs
            (default 78, the previously hard-coded value).
        num_features_xt: Input node-feature width of protein graphs
            (default 70, the previously hard-coded value).
    """

    def __init__(self, n_output=1, num_features_pro=32, num_features_mol=32,
                 output_dim=128, dropout=0.2, num_features_xd=78,
                 num_features_xt=70):
        super().__init__()
        print('GNNNet Loaded')
        self.n_output = n_output
        # Attribute name (singular "feature") kept as-is for existing callers.
        self.num_feature_pro = num_features_pro
        # NOTE(review): `dropout` is never applied anywhere -- confirm intent.

        # The first two GAT layers use 4 concatenated heads.
        mol_hidden = 4 * num_features_mol
        pro_hidden = 4 * num_features_pro

        # Drug branch. Layers are created in the original order so parameter
        # registration order (and thus state_dict layout) is unchanged.
        self.mol_conv1 = GATConv(num_features_xd, num_features_mol, heads=4)
        self.mol_lin1 = torch.nn.Linear(num_features_xd, mol_hidden)
        self.mol_conv2 = GATConv(mol_hidden, num_features_mol, heads=4)
        self.mol_lin2 = torch.nn.Linear(mol_hidden, mol_hidden)
        self.mol_conv3 = GATConv(mol_hidden, output_dim, concat=False)
        self.mol_lin3 = torch.nn.Linear(mol_hidden, output_dim)

        # Protein branch.
        self.pro_conv1 = GATConv(num_features_xt, num_features_pro, heads=4)
        self.lin1 = torch.nn.Linear(num_features_xt, pro_hidden)
        self.pro_conv2 = GATConv(pro_hidden, num_features_pro, heads=4)
        self.lin2 = torch.nn.Linear(pro_hidden, pro_hidden)
        self.pro_conv3 = GATConv(pro_hidden, output_dim, concat=False)
        self.lin3 = torch.nn.Linear(pro_hidden, output_dim)

        # This is an ELU activation; the attribute keeps its historical
        # name `relu` so external references keep working.
        self.relu = nn.ELU()

        # Each norm must match the width of the layer output it normalizes.
        # These were hard-coded to 128, which only matched the default sizes.
        self.bn1 = BatchNorm(pro_hidden)
        self.bn2 = BatchNorm(pro_hidden)
        self.bn3 = BatchNorm(output_dim)
        self.bn1_drug = BatchNorm(mol_hidden)
        self.bn2_drug = BatchNorm(mol_hidden)
        self.bn3_drug = BatchNorm(output_dim)

    def forward(self, data_drug, data_pro):
        """Encode a batch of drug graphs and a batch of protein graphs.

        Args:
            data_drug: Batched drug graphs exposing ``x``, ``edge_index`` and
                ``batch`` (presumably a torch_geometric ``Batch`` -- confirm).
            data_pro: Batched protein graphs exposing the same attributes.

        Returns:
            Tuple ``(xt, x)``: the mean-pooled protein embeddings and the
            mean-pooled drug embeddings, each ``output_dim`` wide. Note that
            the protein embedding comes first.
        """
        mol_x, mol_edge_index, mol_batch = (
            data_drug.x, data_drug.edge_index, data_drug.batch)
        target_x, target_edge_index, target_batch = (
            data_pro.x, data_pro.edge_index, data_pro.batch)

        # Drug branch: (conv(x) + linear shortcut(x)) -> ELU -> BatchNorm.
        x = self.bn1_drug(self.relu(
            self.mol_conv1(mol_x, mol_edge_index) + self.mol_lin1(mol_x)))
        x = self.bn2_drug(self.relu(
            self.mol_conv2(x, mol_edge_index) + self.mol_lin2(x)))
        x = self.bn3_drug(self.relu(
            self.mol_conv3(x, mol_edge_index) + self.mol_lin3(x)))
        x = gep(x, mol_batch)  # mean-pool node embeddings per graph

        # Protein branch: same layer pattern.
        xt = self.bn1(self.relu(
            self.pro_conv1(target_x, target_edge_index) + self.lin1(target_x)))
        xt = self.bn2(self.relu(
            self.pro_conv2(xt, target_edge_index) + self.lin2(xt)))
        xt = self.bn3(self.relu(
            self.pro_conv3(xt, target_edge_index) + self.lin3(xt)))
        xt = gep(xt, target_batch)  # mean-pool node embeddings per graph

        return xt, x