forked from RexYing/gnn-model-explainer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfigs.py
103 lines (96 loc) · 4.78 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import utils.parser_utils as parser_utils
def arg_parse(argv=None):
    """Build the training/explainer command-line parser and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is exactly
            what the original zero-argument call did.

    Returns:
        argparse.Namespace with every option registered below. Options not
        given on the command line take the values from ``set_defaults``.
    """
    parser = argparse.ArgumentParser(description='GraphPool arguments.')

    # Input sources: --dataset and --pkl cannot both be given on the command line.
    io_parser = parser.add_mutually_exclusive_group(required=False)
    io_parser.add_argument('--dataset', dest='dataset',
                           help='Input dataset.')
    # The group used to be created via io_parser.add_argument_group(). Nesting
    # a group inside a mutually exclusive group is deprecated since Python 3.11
    # and was never supported: that group was not registered with the parser,
    # so --bmname was parsed but missing from --help (and it was never
    # actually mutually exclusive with the io options). Attaching it to the
    # parser keeps parsing identical and makes the option visible in help.
    benchmark_parser = parser.add_argument_group()
    benchmark_parser.add_argument('--bmname', dest='bmname',
                                  help='Name of the benchmark dataset')
    io_parser.add_argument('--pkl', dest='pkl_fname',
                           help='Name of the pkl data file')

    # Soft-pooling hyper-parameters.
    softpool_parser = parser.add_argument_group()
    softpool_parser.add_argument('--assign-ratio', dest='assign_ratio', type=float,
                                 help='ratio of number of nodes in consecutive layers')
    softpool_parser.add_argument('--num-pool', dest='num_pool', type=int,
                                 help='number of pooling layers')

    parser.add_argument('--linkpred', dest='linkpred', action='store_const',
                        const=True, default=False,
                        help='Whether link prediction side objective is used')

    # Project helper; presumably registers the optimizer flags whose defaults
    # (opt, opt_scheduler, lr, clip) are set below -- TODO confirm in utils.
    parser_utils.parse_optimizer(parser)

    # Paths.
    parser.add_argument('--datadir', dest='datadir',
                        help='Directory where benchmark is located')
    parser.add_argument('--logdir', dest='logdir',
                        help='Tensorboard log directory')
    parser.add_argument('--ckptdir', dest='ckptdir',
                        help='Model checkpoint directory')

    # Device selection.
    parser.add_argument('--cuda', dest='cuda',
                        help='CUDA.')
    parser.add_argument('--gpu', dest='gpu', action='store_const',
                        const=True, default=False,
                        help='whether to use GPU.')

    # Data loading / training loop.
    parser.add_argument('--max_nodes', dest='max_nodes', type=int,
                        help='Maximum number of nodes (ignore graphs with nodes exceeding the number).')
    parser.add_argument('--batch_size', dest='batch_size', type=int,
                        help='Batch size.')
    parser.add_argument('--epochs', dest='num_epochs', type=int,
                        help='Number of epochs to train.')
    parser.add_argument('--train_ratio', dest='train_ratio', type=float,
                        help='Ratio of number of graphs training set to all graphs.')
    parser.add_argument('--num_workers', dest='num_workers', type=int,
                        help='Number of workers to load data.')

    # Model architecture.
    parser.add_argument('--feature', dest='feature_type',
                        help='Feature used for encoder. Can be: id, deg')
    parser.add_argument('--input_dim', dest='input_dim', type=int,
                        help='Input feature dimension')
    parser.add_argument('--hidden_dim', dest='hidden_dim', type=int,
                        help='Hidden dimension')
    parser.add_argument('--output_dim', dest='output_dim', type=int,
                        help='Output dimension')
    parser.add_argument('--num_classes', dest='num_classes', type=int,
                        help='Number of label classes')
    parser.add_argument('--num_gc_layers', dest='num_gc_layers', type=int,
                        help='Number of graph convolution layers before each pooling')
    parser.add_argument('--bn', dest='bn', action='store_const',
                        const=True, default=False,
                        help='Whether batch normalization is used')
    parser.add_argument('--dropout', dest='dropout', type=float,
                        help='Dropout rate.')
    # Passing --nobias stores False into args.bias; omitting it leaves True.
    parser.add_argument('--nobias', dest='bias', action='store_const',
                        const=False, default=True,
                        help='Disable the bias term (bias is used by default).')
    parser.add_argument('--weight_decay', dest='weight_decay', type=float,
                        help='Weight decay regularization constant.')
    parser.add_argument('--method', dest='method',
                        help='Method. Possible values: base, ')
    parser.add_argument('--name-suffix', dest='name_suffix',
                        help='suffix added to the output filename')

    parser.set_defaults(datadir='data',
                        logdir='log',
                        ckptdir='ckpt',
                        dataset='syn1',
                        # Optimizer defaults; the flags are presumably added
                        # by parser_utils.parse_optimizer -- verify there.
                        opt='adam',
                        opt_scheduler='none',
                        max_nodes=100,
                        cuda='1',
                        feature_type='default',
                        lr=0.001,
                        clip=2.0,
                        batch_size=20,
                        num_epochs=1000,
                        train_ratio=0.8,
                        # No --test_ratio flag is registered in this function,
                        # so this value cannot be changed from the command line here.
                        test_ratio=0.1,
                        num_workers=1,
                        input_dim=10,
                        hidden_dim=20,
                        output_dim=20,
                        num_classes=2,
                        num_gc_layers=3,
                        dropout=0.0,
                        weight_decay=0.005,
                        method='base',
                        name_suffix='',
                        assign_ratio=0.1,
                        )
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)