-
Notifications
You must be signed in to change notification settings - Fork 0
/
configure.py
40 lines (36 loc) · 915 Bytes
/
configure.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Best hyperparameters found."""
import torch
# Hyperparameters selected for the 'MWE-GCN' model (get_exp_configure picks
# this configuration when args['model'] == 'MWE-GCN'). The "_proteins" suffix
# presumably names the dataset these values were tuned on — confirm with the
# training script.
MWE_GCN_proteins = dict(
    num_ew_channels=8,
    num_epochs=2000,
    in_feats=1,
    hidden_feats=10,
    out_feats=112,
    n_layers=3,
    lr=2e-2,
    weight_decay=0,
    patience=1000,
    dropout=0.2,
    aggr_mode='sum',  # 'sum' or 'concat' for the aggregation across channels
    ewnorm='both',
)
# Hyperparameters selected for the 'MWE-DGCN' model (get_exp_configure picks
# this configuration when args['model'] == 'MWE-DGCN'). Unlike the MWE-GCN
# settings, this one also carries a 'residual' flag.
MWE_DGCN_proteins = dict(
    num_ew_channels=8,
    num_epochs=2000,
    in_feats=1,
    hidden_feats=10,
    out_feats=112,
    n_layers=2,
    lr=1e-2,
    weight_decay=0,
    patience=300,
    dropout=0.5,
    aggr_mode='sum',
    residual=True,
    ewnorm='none',
)
def get_exp_configure(args):
    """Return the hyperparameter configuration for the requested model.

    Args:
        args: Mapping with a ``'model'`` key naming the model; supported
            values are ``'MWE-GCN'`` and ``'MWE-DGCN'``.

    Returns:
        A shallow copy of the matching module-level hyperparameter dict, so
        callers may modify it without changing the defaults seen by later
        calls.

    Raises:
        KeyError: If ``args`` has no ``'model'`` key.
        ValueError: If ``args['model']`` is not a supported model name
            (previously this silently returned ``None``).
    """
    configs = {
        'MWE-GCN': MWE_GCN_proteins,
        'MWE-DGCN': MWE_DGCN_proteins,
    }
    model = args['model']
    try:
        config = configs[model]
    except KeyError:
        raise ValueError(
            "Unsupported model {!r}; expected one of {}".format(
                model, sorted(configs)
            )
        ) from None
    # Copy so a caller mutating its config cannot corrupt the shared defaults.
    return dict(config)