-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
37 lines (29 loc) · 1.09 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from options.atme_options import AtmeOptions
from options.simple_options import SimpleOptions
import sys
import atme
import simple
import argparse
if __name__ == '__main__':
main_parser = argparse.ArgumentParser(description="Main parser with subparsers")
subparsers = main_parser.add_subparsers(dest="model")
parser1 = subparsers.add_parser(name="atme")
parser2 = subparsers.add_parser(name="simple")
atme_opt = simple_opt = None
if str(sys.argv[1]) == 'atme':
atme_opt = AtmeOptions().parse(parser1)
elif str(sys.argv[1]) == 'simple':
simple_opt = SimpleOptions().parse(parser2)
else:
print(f'model {str(sys.argv[1])} is not exist!')
opt = main_parser.parse_args()
if opt.model == "atme":
AtmeOptions().print_options(atme_opt)
if atme_opt.isTrain:
atme.train(atme_opt)
if opt.TestAfterTrain: atme.test(atme_opt)
else:
atme.test(atme_opt)
elif opt.model == "simple":
SimpleOptions().print_options(simple_opt)
simple.train(simple_opt)