#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
"""
Training script using the new "LazyConfig" python config files.

This script reads a given python config file and runs the training or evaluation.
It can be used to train any model or dataset, as long as they can be
instantiated by the recursive construction defined in the given config file.

Besides lazy construction of models, dataloaders, etc., this script expects a
few common configuration parameters currently defined in "configs/common/train.py".
To add more complicated training logic, you can easily add other configs
in the config file and implement a new train_net.py to handle them.
"""
import logging
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
    AMPTrainer,
    SimpleTrainer,
    default_argument_parser,
    default_setup,
    default_writers,
    hooks,
    launch,
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.evaluation import inference_on_dataset, print_csv_format
from detectron2.utils import comm
from utils.multiscale import MultiScale
from utils.lr_scheduler import LRHook
from utils.mosaic_close import MosaicClose
from utils import ema
from register_data import *

logger = logging.getLogger("detectron2")


def do_test(cfg, model, eval_only=False):
    logger = logging.getLogger("detectron2")

    if eval_only:
        logger.info("Run evaluation under eval-only mode")
        if cfg.train.model_ema.enabled and cfg.train.model_ema.use_ema_weights_for_eval_only:
            logger.info("Run evaluation with EMA.")
        else:
            logger.info("Run evaluation without EMA.")
        if "evaluator" in cfg.dataloader:
            ret = inference_on_dataset(
                model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
            )
            print_csv_format(ret)
        return ret

    logger.info("Run evaluation without EMA.")
    if "evaluator" in cfg.dataloader:
        ret = inference_on_dataset(
            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
        )
        print_csv_format(ret)

        # run evaluation again with the EMA weights and merge the results
        if cfg.train.model_ema.enabled:
            logger.info("Run evaluation with EMA.")
            with ema.apply_model_ema_and_restore(model):
                if "evaluator" in cfg.dataloader:
                    ema_ret = inference_on_dataset(
                        model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
                    )
                    print_csv_format(ema_ret)
                ret.update(ema_ret)
        return ret


def do_train(args, cfg):
    """
    Args:
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)
    model.device = torch.device(cfg.train.device)

    cfg.optimizer.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    # build model EMA
    ema.may_build_model_ema(cfg, model)

    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        trainer=trainer,
        # save model EMA state alongside the checkpoint
        **ema.may_get_ema_checkpointer(cfg, model)
    )
    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            ema.EMAHook(cfg, model) if cfg.train.model_ema.enabled else None,
            LRHook(
                trainer,
                cfg.lr_cfg.train_batch_size, cfg.lr_cfg.basic_lr_per_img, cfg.lr_cfg.scheduler_name,
                cfg.lr_cfg.iters_per_epoch, cfg.lr_cfg.max_eps, cfg.lr_cfg.num_warmup_eps,
                cfg.lr_cfg.warmup_lr_start, cfg.lr_cfg.no_aug_eps, cfg.lr_cfg.min_lr_ratio,
            ),
            MosaicClose(
                trainer, cfg.lr_cfg.iters_per_epoch, cfg.lr_cfg.no_aug_eps, cfg.dataloader.train.is_distributed
            ),
            MultiScale(
                trainer, cfg.dataloader.train.input_size, cfg.train.random_size, cfg.train.log_period, cfg.dataloader.train.is_distributed
            ),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            hooks.PeriodicWriter(
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )
    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    # `register_hooks` drops the `None` entries, so hook index 1 is the EMA hook
    # when EMA is enabled (otherwise the LR hook).
    trainer._hooks[1].updates_iter = cfg.train.start_iter = start_iter
    trainer.train(start_iter, cfg.train.max_iter)


def main(args):
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)

    if args.eval_only:
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        model.device = torch.device(cfg.train.device)
        model = create_ddp_model(model)
        # build model EMA so the checkpointer can load EMA weights for evaluation
        ema.may_build_model_ema(cfg, model)
        DetectionCheckpointer(model, **ema.may_get_ema_checkpointer(cfg, model)).load(cfg.train.init_checkpoint)
        # apply EMA state for evaluation
        if cfg.train.model_ema.enabled and cfg.train.model_ema.use_ema_weights_for_eval_only:
            ema.apply_model_ema(model)
        print(do_test(cfg, model, eval_only=True))
    else:
        do_train(args, cfg)


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

# Usage:
#   CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --num-gpus 4 --config-file mot17_train_config.py
#   CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --num-gpus 4 --config-file mot20_train_config.py
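#
# Config values can also be overridden from the command line as trailing
# "key=value" pairs, which `LazyConfig.apply_overrides` consumes via `args.opts`,
# e.g. (hypothetical override values):
#   CUDA_VISIBLE_DEVICES=0 python train.py --num-gpus 1 --config-file mot17_train_config.py \
#       train.max_iter=100 train.output_dir=./debug_run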