# train.py (forked from NVlabs/imaginaire)
# Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import argparse
from imaginaire.config import Config
from imaginaire.utils.cudnn import init_cudnn
from imaginaire.utils.dataset import get_train_and_val_dataloader
from imaginaire.utils.distributed import init_dist
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.gpu_affinity import set_affinity
from imaginaire.utils.logging import init_logging, make_logging_dir
from imaginaire.utils.trainer import (get_model_optimizer_and_scheduler,
get_trainer, set_random_seed)
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv (list of str, optional): Argument strings to parse. When
            ``None`` (the default) argparse reads ``sys.argv[1:]``, which is
            the behavior of the original zero-argument call.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``config``,
        ``logdir``, ``checkpoint``, ``seed``, ``local_rank``, ``single_gpu``
        and ``num_workers``.

    Raises:
        SystemExit: Raised by argparse when ``--config`` is missing or an
            argument cannot be parsed.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config',
                        help='Path to the training config file.',
                        required=True)
    parser.add_argument('--logdir', help='Dir for saving logs and models.')
    parser.add_argument('--checkpoint', default='', help='Checkpoint path.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    # Launchers from PyTorch 2.0 onwards pass the dashed ``--local-rank``,
    # older ones pass ``--local_rank``. Accept both spellings and store the
    # value under the same ``local_rank`` attribute.
    parser.add_argument('--local_rank', '--local-rank', dest='local_rank',
                        type=int, default=0)
    parser.add_argument('--single_gpu', action='store_true')
    # No default: ``None`` means "not given on the command line".
    parser.add_argument('--num_workers', type=int)
    return parser.parse_args(argv)
def main():
    """Entry point of the training script.

    Parses the command line, loads the config, initializes distributed
    training unless ``--single_gpu`` was given, builds the data loaders,
    networks, optimizers, schedulers and trainer, asks the trainer to load
    the checkpoint given by ``--checkpoint`` to obtain the starting epoch and
    iteration, and then runs the training loop. The loop stops after the
    epoch range up to ``cfg.max_epoch`` is exhausted or as soon as the
    iteration counter reaches ``cfg.max_iter``.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # Distributed data parallel is used unless --single_gpu was passed.
    use_distributed = not args.single_gpu
    if use_distributed:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # A --num_workers value given on the command line replaces the one
    # stored in the config.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Set up the directory that receives the training results.
    date_uid, logdir = init_logging(args.config, args.logdir)
    cfg.date_uid = date_uid
    cfg.logdir = logdir
    make_logging_dir(cfg.logdir)

    # Configure cudnn from the config.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Build data loaders, networks, optimizers, schedulers and the trainer.
    train_loader, val_loader = get_train_and_val_dataloader(cfg)
    (gen_net, dis_net,
     gen_opt, dis_opt,
     gen_sched, dis_sched) = get_model_optimizer_and_scheduler(
         cfg, seed=args.seed)
    trainer = get_trainer(cfg, gen_net, dis_net, gen_opt, dis_opt,
                          gen_sched, dis_sched, train_loader, val_loader)
    current_epoch, current_iteration = trainer.load_checkpoint(
        cfg, args.checkpoint)

    # Training loop. ``print`` is the master_only_print imported at the top
    # of the file.
    for epoch in range(current_epoch, cfg.max_epoch):
        print('Epoch {} ...'.format(epoch))
        if use_distributed:
            train_loader.sampler.set_epoch(current_epoch)
        trainer.start_of_epoch(current_epoch)

        for batch in train_loader:
            data = trainer.start_of_iteration(batch, current_iteration)

            # Discriminator steps first, then generator steps, on the
            # same batch.
            for _ in range(cfg.trainer.dis_step):
                trainer.dis_update(data)
            for _ in range(cfg.trainer.gen_step):
                trainer.gen_update(data)

            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)
            if current_iteration >= cfg.max_iter:
                print('Done with training!!!')
                return

        current_epoch += 1
        # NOTE(review): ``data`` is the last batch of the loop above; it is
        # unbound if the loader yielded no batch in the first epoch.
        trainer.end_of_epoch(data, current_epoch, current_iteration)

    print('Done with training!!!')
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()