-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
156 lines (133 loc) · 4.42 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import numpy as np
import os
import random
import torch
from tensorboardX import SummaryWriter
import argparse
from omegaconf import OmegaConf
from datetime import datetime
from utils.utils import save_yaml
from models import get_model
from trainers import get_trainer, get_logger
from loader import get_dataloader
from optimizers import get_optimizer
import wandb
def parse_arg_type(val):
    """Convert a command-line string into the most specific Python value.

    Conversion order:
      * ``'True'`` / ``'true'``  -> ``True``
      * ``'False'`` / ``'false'`` -> ``False``
      * anything ``int()`` accepts (including negatives such as ``'-3'``) -> ``int``
      * anything ``float()`` accepts (``'0.1'``, ``'1e-4'``, ``'inf'``) -> ``float``
      * otherwise the string itself.

    Args:
        val: raw string token taken from the command line.

    Returns:
        bool, int, float or str.
    """
    if val in ('True', 'true'):
        return True
    if val in ('False', 'false'):
        return False
    # Try int before float so that '-3' stays an int (str.isnumeric() rejects
    # the sign and used to yield -3.0). This also stops characters such as '½',
    # which isnumeric() accepts but int() rejects, from raising.
    try:
        return int(val)
    except ValueError:
        pass
    try:
        return float(val)
    except ValueError:
        # Only ValueError means "not a number"; any other error should surface
        # instead of being swallowed by a bare except.
        return str(val)
def parse_unknown_args(l_args):
    """Convert a list of unrecognised CLI tokens into a flat dict.

    Tokens are consumed as ``--key value`` pairs, e.g.
    ``['--training.seed', '3']`` -> ``{'training.seed': 3}``. Dashes are
    stripped from both ends of each key, and each value is converted with
    ``parse_arg_type``. This is similar in spirit to ``OmegaConf.from_cli()``,
    which expects ``key=value`` tokens instead.

    Args:
        l_args: list of strings, e.g. the ``unknown`` list returned by
            ``argparse.ArgumentParser.parse_known_args``.

    Returns:
        dict mapping (possibly dot-separated) keys to converted values.

    Raises:
        ValueError: if the tokens do not form complete key/value pairs, or if
            a key contains '=' (``--key=value`` form is not supported).
    """
    if len(l_args) % 2 != 0:
        # The old loop silently dropped the trailing token, so a key given
        # without a value (e.g. a typo) was ignored without any warning.
        raise ValueError(
            'optional arguments must come in "--key value" pairs; '
            f'got an odd number of tokens: {l_args}'
        )
    kwargs = {}
    for key, val in zip(l_args[0::2], l_args[1::2]):
        if '=' in key:
            # Raise rather than assert: asserts are stripped under `python -O`.
            raise ValueError('optional arguments should be separated by space')
        kwargs[key.strip('-')] = parse_arg_type(val)
    return kwargs
def parse_nested_args(d_cmd_cfg):
    """Expand dot-separated keys into nested dictionaries.

    Example: ``{'key1.key2': 1}`` becomes ``{'key1': {'key2': 1}}``.
    Intermediate dicts are created on demand and reused by keys that share a
    prefix, so ``{'a.b': 1, 'a.c': 2}`` becomes ``{'a': {'b': 1, 'c': 2}}``.
    """
    nested = {}
    for dotted_key, value in d_cmd_cfg.items():
        *parents, leaf = dotted_key.split('.')
        node = nested
        # Walk (creating as needed) every level except the last one.
        for part in parents:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[leaf] = value
    return nested
def run(cfg, writer):
    """Build data loaders, model, optimizer and trainer from ``cfg``, then train.

    Args:
        cfg: merged OmegaConf config. Reads ``cfg.training`` (``seed``,
            defaulting to 1, and ``optimizer``), ``cfg.data`` (one entry per
            data loader, which must include ``'validation'``) and ``cfg.device``.
        writer: tensorboardX ``SummaryWriter``. It is handed to the logger
            factory, and its log directory is passed to the trainer.

    Returns:
        ``(model, train_result)`` exactly as returned by ``trainer.train``.
        Previously these were computed and discarded; returning them lets
        callers reuse the trained model without changing existing callers.

    Raises:
        KeyError: if ``cfg.data`` defines no ``'validation'`` loader.
    """
    # Setup seeds
    seed = cfg.training.get('seed', 1)
    print(f"running with random seed : {seed}")
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.set_num_threads(8)

    # Setup device
    device = cfg.device

    # Setup Dataloader: one loader per entry in cfg.data, keyed by the same name.
    d_dataloaders = {
        key: get_dataloader(dataloader_cfg)
        for key, dataloader_cfg in cfg.data.items()
    }
    if 'validation' not in d_dataloaders:
        raise KeyError(
            "cfg.data must define a 'validation' loader; it is passed to get_model"
        )

    # The model factory receives the validation loader through the `dl` keyword.
    model = get_model(cfg, dl=d_dataloaders['validation']).to(device)
    logger = get_logger(cfg, writer)

    # Setup optimizer: only parameters that require gradients are optimized.
    optimizer = get_optimizer(
        cfg.training.optimizer,
        filter(lambda p: p.requires_grad, model.parameters()),
    )

    # Setup Trainer
    trainer = get_trainer(optimizer, cfg)
    model, train_result = trainer.train(
        model,
        d_dataloaders,
        logger=logger,
        logdir=writer.file_writer.get_logdir(),
    )
    return model, train_result
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_config", type=str, default=None,
                        help="optional base YAML; --config is merged on top of it")
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--device", default='any',
                        help="'cpu', 'any' (-> 'cuda') or a CUDA device index")
    parser.add_argument("--wandb", default='off')
    parser.add_argument("--run", default=None)
    # Every other "--key value" pair (dot-separated keys allowed) is treated
    # as a config override, e.g. "--training.seed 3" or "--logdir results/".
    args, unknown = parser.parse_known_args()
    d_cmd_cfg = parse_nested_args(parse_unknown_args(unknown))
    print(d_cmd_cfg)

    # Precedence, lowest to highest: base config, config, command-line overrides.
    cfg = OmegaConf.load(args.config)
    if args.base_config is not None:
        cfg = OmegaConf.merge(OmegaConf.load(args.base_config), cfg)
    cfg = OmegaConf.merge(cfg, d_cmd_cfg)
    print(OmegaConf.to_yaml(cfg))

    # Use an elif chain: with two independent ifs, "cpu" fell through to the
    # else branch and the device became "cuda:cpu".
    if args.device == "cpu":
        cfg["device"] = "cpu"
    elif args.device == "any":
        cfg["device"] = "cuda"
    else:
        cfg["device"] = f"cuda:{args.device}"

    run_id = datetime.now().strftime("%Y%m%d-%H%M") if args.run is None else args.run

    # The parser defines no --logdir option, so "--logdir <dir>" arrives through
    # the unknown args and is already merged into cfg above. The old fallback
    # to args.logdir raised AttributeError.
    base_logdir = cfg.get("logdir")
    if base_logdir is None:
        parser.error("no log directory: set 'logdir' in the config or pass --logdir <dir>")
    logdir = os.path.join(str(base_logdir), run_id)
    writer = SummaryWriter(logdir=logdir)
    print("Result directory: {}".format(logdir))

    # copy config file
    copied_yml = os.path.join(logdir, os.path.basename(args.config))
    save_yaml(copied_yml, OmegaConf.to_yaml(cfg))
    print(f"config saved as {copied_yml}")

    cfg['wandb'] = args.wandb
    if args.wandb == 'on':
        # To choose the wandb entity, pass: train.py ... --wandb_entity <your_entity>
        # It is not an argparse option, so it arrives as a config override
        # (cfg.wandb_entity). The old args.wandb_entity always raised
        # AttributeError. When it is unset, None is passed and wandb falls back
        # to its default entity.
        wandb.init(
            entity=cfg.get('wandb_entity'),
            project=cfg['wandb_project_name'],
            config=OmegaConf.to_container(cfg),
            name=logdir,
        )
    run(cfg, writer)