-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
180 lines (150 loc) · 7.12 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import sys
import argparse
# First positional selector: which search algorithm to run.
_ALGORITHMS = ("darts", "enas", "random")

# Second positional selector: which kind of job to run for that algorithm.
_TASKS = ("cnn", "rnn", "viz")

# (algorithm, task) pairs that are actually wired to an entry point. Every
# other valid pair is recognised on the command line but raises
# NotImplementedError.
_IMPLEMENTED = frozenset({
    ("darts", "cnn"),
    ("darts", "viz"),
    ("enas", "rnn"),
    ("random", "cnn"),
})


def invalid_use(k):
    """Print the usage line, optionally name the rejected argument, and exit.

    Args:
        k: Index into ``sys.argv`` of the argument that was rejected, or 0
            when there are too few arguments to single one out.

    Raises:
        SystemExit: Always, with exit status -1.
    """
    print("Usage: python main.py [darts|enas|random] [cnn|rnn|viz]")
    if k > 0:
        print(f"{sys.argv[k]} is not a valid option.")
    # sys.exit instead of the bare exit(): the latter is injected by the
    # ``site`` module and is missing under ``python -S`` or in frozen builds.
    sys.exit(-1)


def _add_general_arguments(parser):
    """Register the options shared by every algorithm and task."""
    parser.add_argument("--smoke-test", default=False, action="store_true", help="Finish quickly for testing")
    parser.add_argument("--ray-address", default=None, help="Address of Ray cluster for seamless distributed execution.")
    parser.add_argument("--cuda", default=False, action="store_true", help="Enables GPU training")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size of supervised learning")
    parser.add_argument("--num_workers", type=int, default=2, help="workers for torch data loaders")


def _add_algorithm_arguments(parser, algorithm):
    """Register the model-size option belonging to ``algorithm``."""
    if algorithm == "darts":
        parser.add_argument("--layers", default=10, type=int, help="Number of layers in model")
    else:
        # ENAS and RandomNAS (ENAS without the RNN controller) share this option.
        parser.add_argument("--num_blocks", default=12, type=int, help="Number of layers in model")


def _add_task_arguments(parser, task):
    """Register the options specific to ``task`` ('cnn', 'rnn' or 'viz')."""
    if task == "cnn":
        dataset_options = ["cifar10", "mnist", "imagenet"]
        parser.add_argument("--dataset", default="cifar10", choices=dataset_options, type=str.lower, help="Name of dataset")
    elif task == "rnn":
        parser.add_argument("--data_path", default="~/data/ptb", type=str, help="Path to text dataset")
    else:
        # Arguments for visualizing searched architectures.
        parser.add_argument("--load", default=None, type=str, help="Path to dir of a specific tune experiment")
        parser.add_argument("--save", default=None, type=str, help="Path to dir for saving the pngs of the model graph to. If unset defaults to load_dir")
        parser.add_argument("--viz", default=False, action="store_true", help="Open up vizualize pngs or not")


def _load_entry_point(algorithm, task):
    """Import and return the callable for an implemented (algorithm, task) pair.

    The imports are kept inside the branches so that only the selected
    sub-package is imported.
    """
    if (algorithm, task) == ("darts", "cnn"):
        from darts.run_cnn import run_experiment
        return run_experiment
    if (algorithm, task) == ("darts", "viz"):
        # TODO: Add Tune checkpointing and integrate with this for visualizing
        from darts.viz import viz_arch
        return viz_arch
    if (algorithm, task) == ("enas", "rnn"):
        from enas.run_rnn import run_experiment
        return run_experiment
    if (algorithm, task) == ("random", "cnn"):
        from random_nas.run_cnn import run_experiment
        return run_experiment
    raise NotImplementedError(f"'{algorithm} {task}' is not implemented")


def main():
    """Dispatch ``python main.py <algorithm> <task> [options]``.

    Reads the two positional selectors from ``sys.argv[1]`` and
    ``sys.argv[2]``, builds the matching argument parser, removes the
    selectors from ``sys.argv`` and hands the parsed options to the selected
    entry point.

    Raises:
        SystemExit: With status -1 (after printing the usage line) when fewer
            than two selectors are given or either selector is unknown; also
            raised by argparse itself on bad options or ``--help``.
        NotImplementedError: When the selectors are valid but that
            combination has no entry point yet.
    """
    parser = argparse.ArgumentParser(description="Neural Architecture Search with Ray")
    _add_general_arguments(parser)

    if len(sys.argv) < 3:
        invalid_use(0)
    algorithm, task = sys.argv[1], sys.argv[2]
    if algorithm not in _ALGORITHMS:
        invalid_use(1)
    _add_algorithm_arguments(parser, algorithm)
    if task not in _TASKS:
        invalid_use(2)
    if (algorithm, task) not in _IMPLEMENTED:
        raise NotImplementedError(f"'{algorithm} {task}' is not implemented")

    entry_point = _load_entry_point(algorithm, task)

    # Remove the two selectors by position so argparse only sees the options.
    # list.remove() would delete the first element with an equal value, which
    # is the wrong one whenever an earlier element (e.g. the script name)
    # happens to match the selector text.
    del sys.argv[1:3]

    _add_task_arguments(parser, task)
    args = parser.parse_args()

    if task == "viz":
        entry_point(args.load, args.save, viz=args.viz)
    else:
        entry_point(args)


if __name__ == "__main__":
    main()