# main.py -- strong cluster baseline training script (forked from open-mmlab/OpenUnReID)
import argparse
import shutil
import sys
import time
from datetime import timedelta
from pathlib import Path

import torch

from openunreid.apis import BaseRunner, test_reid, set_random_seed
from openunreid.core.solvers import build_lr_scheduler, build_optimizer
from openunreid.data import build_test_dataloader, build_train_dataloader
from openunreid.models import build_model
from openunreid.models.losses import build_loss
from openunreid.utils.config import (
    cfg,
    cfg_from_list,
    cfg_from_yaml_file,
    log_config_to_file,
)
from openunreid.utils.dist_utils import init_dist, synchronize
from openunreid.utils.file_utils import mkdir_if_missing
from openunreid.utils.logger import Logger


def parse_config():
    parser = argparse.ArgumentParser(description="Strong cluster baseline training")
    parser.add_argument("config", help="train config file path")
    parser.add_argument(
        "--work-dir", help="the dir to save logs and models", default=""
    )
    parser.add_argument("--resume-from", help="the checkpoint file to resume from")
    parser.add_argument(
        "--launcher",
        type=str,
        choices=["none", "pytorch", "slurm"],
        default="none",
        help="job launcher",
    )
    parser.add_argument("--tcp-port", type=str, default="5017")
    parser.add_argument(
        "--set",
        dest="set_cfgs",
        default=None,
        nargs=argparse.REMAINDER,
        help="set extra config keys if needed",
    )
    args = parser.parse_args()

    cfg_from_yaml_file(args.config, cfg)
    cfg.launcher = args.launcher
    cfg.tcp_port = args.tcp_port
    if not args.work_dir:
        args.work_dir = Path(args.config).stem
    cfg.work_dir = cfg.LOGS_ROOT / args.work_dir
    mkdir_if_missing(cfg.work_dir)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs, cfg)
    shutil.copy(args.config, cfg.work_dir / "config.yaml")

    return args, cfg
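
# A minimal, illustrative invocation of this script (the config path and the
# overridden key below are hypothetical examples, not files shipped with this
# fork):
#
#   python main.py configs/cluster_baseline.yaml --work-dir cluster_baseline \
#       --set TRAIN.epochs 50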


def main():
    start_time = time.monotonic()

    # init distributed training
    args, cfg = parse_config()
    dist = init_dist(cfg)
    set_random_seed(cfg.TRAIN.seed, cfg.TRAIN.deterministic)
    synchronize()

    # init logging file
    logger = Logger(cfg.work_dir / "log.txt", debug=False)
    sys.stdout = logger
    print("==========\nArgs:{}\n==========".format(args))
    log_config_to_file(cfg)

    # build train loader
    train_loader, train_sets = build_train_dataloader(cfg)

    # the number of classes for the model is tricky:
    # it must be no smaller than the number of clusters,
    # so an upper bound is accumulated over all training sets
    num_classes = 0
    for idx, dataset in enumerate(train_sets):
        if idx in cfg.TRAIN.unsup_dataset_indexes:
            # the number of clusters in an unsupervised dataset
            # cannot exceed its number of images
            num_classes += len(dataset)
        else:
            # ground-truth classes for a supervised dataset
            num_classes += dataset.num_pids
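
    # For example (numbers are purely illustrative): one unsupervised set of
    # 12,936 images plus one supervised set with 751 identities gives a
    # classifier of size num_classes = 12936 + 751 = 13687.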

    # build model
    model = build_model(cfg, num_classes, init=cfg.MODEL.source_pretrained)
    model.cuda()

    if dist:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[cfg.gpu],
            output_device=cfg.gpu,
            find_unused_parameters=True,
        )
    elif cfg.total_gpus > 1:
        model = torch.nn.DataParallel(model)
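    # Note: DistributedDataParallel keeps one model replica per process
    # (typically one process per GPU under the "pytorch"/"slurm" launchers),
    # while DataParallel scatters each batch across GPUs from a single process.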

    # build optimizer
    optimizer = build_optimizer([model,], **cfg.TRAIN.OPTIM)

    # build lr_scheduler
    if cfg.TRAIN.SCHEDULER.lr_scheduler is not None:
        lr_scheduler = build_lr_scheduler(optimizer, **cfg.TRAIN.SCHEDULER)
    else:
        lr_scheduler = None

    # build loss functions
    criterions = build_loss(cfg.TRAIN.LOSS, num_classes=num_classes, cuda=True)

    # build runner
    runner = BaseRunner(
        cfg,
        model,
        optimizer,
        criterions,
        train_loader,
        train_sets=train_sets,
        lr_scheduler=lr_scheduler,
        reset_optim=True,
    )

    # resume
    if args.resume_from:
        runner.resume(args.resume_from)

    # start training
    runner.run()

    # load the best model
    runner.resume(cfg.work_dir / "model_best.pth")

    # final testing
    test_loaders, queries, galleries = build_test_dataloader(cfg)
    for i, (loader, query, gallery) in enumerate(
        zip(test_loaders, queries, galleries)
    ):
        cmc, mAP = test_reid(
            cfg, model, loader, query, gallery, dataset_name=cfg.TEST.datasets[i]
        )

    # print time
    end_time = time.monotonic()
    print("Total running time: ", timedelta(seconds=end_time - start_time))


if __name__ == "__main__":
    main()