support multi-gpu train with metric=Ap (#12); support amp (#6)
zhangming8 committed Aug 10, 2021
1 parent 9a743eb commit 5d6faac
Showing 14 changed files with 217 additions and 167 deletions.
42 changes: 17 additions & 25 deletions README.md
@@ -16,7 +16,7 @@

#### Model Zoo

All weights can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/1qEMLzikH5JwRNRoHpeCa6BJBeSQ6xXCH?usp=sharing) or [BaiduDrive](https://pan.baidu.com/s/1UsbdnyVwRJhr9Vy1tmJLeQ)(code:bc72)
All weights can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/1qEMLzikH5JwRNRoHpeCa6BJBeSQ6xXCH?usp=sharing) or [BaiduDrive](https://pan.baidu.com/s/1UsbdnyVwRJhr9Vy1tmJLeQ) (code:bc72)

|Model |test size |mAP<sup>val<br>0.5:0.95 |mAP<sup>test<br>0.5:0.95 | Params<br>(M) |
| ------ |:---: |:---: | :---: |:---: |
@@ -28,7 +28,7 @@ All weights can be downloaded from [GoogleDrive](https://drive.google.com/drive/
|yolox-x |640 |50.5 |51.1 |99.1 |
|yolox-x |800 |51.2 |51.9 |99.1 |

The weights were converted from [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX). mAP was re-evaluated on COCO val2017 and test2017, and some results are slightly better than the official implementation. You can reproduce them with the scripts in 'evaluate.sh'.
mAP was re-evaluated on COCO val2017 and test2017, and some results are slightly better than the official implementation [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX). You can reproduce them with the scripts in 'evaluate.sh'.

#### Dataset
Download COCO:
@@ -45,24 +45,29 @@ The weights were converted from [YOLOX](https://github.com/Megvii-BaseDetection/
Change opt.dataset_path = "/path/to/dataset" in 'config.py'
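
The expected directory layout, inferred from the paths used in 'evaluate.py' and 'evaluate.sh' (adjust if your local copy differs):

```
/path/to/dataset/coco_dataset
├── annotations/   # instances_train2017.json, instances_val2017.json, image_info_test-dev2017.json
└── images/
    ├── train2017/
    ├── val2017/
    └── test2017/
```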

#### Train

See more examples in 'train.sh'.
a. Train from scratch (backbone="CSPDarknet-s" means yolox-s; you can change it to any other backbone, e.g. CSPDarknet-nano, -tiny, -s, -m, -l, -x):
python train.py gpus='0' backbone="CSPDarknet-s" num_epochs=300 exp_id="coco_CSPDarknet-s_640x640" use_amp=False val_intervals=1 data_num_workers=8
python train.py gpus='0' backbone="CSPDarknet-s" num_epochs=300 exp_id="coco_CSPDarknet-s_640x640" use_amp=True val_intervals=2 data_num_workers=6 metric="ap" batch_size=48

b. Fine-tune: download a weight pre-trained on COCO and fine-tune it on a custom dataset:
python train.py gpus='0' backbone="CSPDarknet-s" num_epochs=300 exp_id="coco_CSPDarknet-s_640x640" use_amp=False val_intervals=1 data_num_workers=8 load_model="../weights/yolox-s.pth" resume=False
python train.py gpus='0' backbone="CSPDarknet-s" num_epochs=300 exp_id="coco_CSPDarknet-s_640x640" use_amp=True val_intervals=2 data_num_workers=6 metric="ap" batch_size=48 load_model="../weights/yolox-s.pth" resume=False

c. Resume: use 'resume=True' when training was accidentally interrupted:
python train.py gpus='0' backbone="CSPDarknet-s" num_epochs=300 exp_id="coco_CSPDarknet-s_640x640" use_amp=False val_intervals=1 data_num_workers=8 load_model="exp/coco_CSPDarknet-s_640x640/model_last.pth" resume=True
python train.py gpus='0' backbone="CSPDarknet-s" num_epochs=300 exp_id="coco_CSPDarknet-s_640x640" use_amp=True val_intervals=2 data_num_workers=6 metric="ap" batch_size=48 load_model="exp/coco_CSPDarknet-s_640x640/model_last.pth" resume=True

d. Some tips:
Ⅰ You can also change params in 'train.sh' (these params override opt.xxx in config.py) and train with 'sh train.sh'
Ⅱ If you want to disable multi-size training, change opt.random_size = None or (20, 21) in 'config.py'
Ⅲ Multi-GPU training: change opt.gpus = "3,5,6,7"
Ⅰ You can also change params in 'train.sh' (these params override opt.xxx in config.py) and train with 'nohup sh train.sh &'
Ⅱ If you want to disable multi-size training, change opt.random_size = None or (20, 21) in 'config.py', or set random_size=None in 'train.sh' (see the sketch after these tips)
Ⅲ Multi-GPU training: change opt.gpus = "3,5,6,7"
Ⅳ Visualize training with TensorBoard: tensorboard --logdir exp/your_exp_id/logs_2021-08-xx-xx-xx, then visit http://localhost:6006
You can also inspect the log with the following shell commands:
grep 'train epoch' exp/your_exp_id/logs_2021-08-xx-xx-xx/log.txt
grep 'val epoch' exp/your_exp_id/logs_2021-08-xx-xx-xx/log.txt
grep 'AP' exp/your_exp_id/logs_2021-08-xx-xx-xx/log.txt | grep 0.95
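
As referenced in tip Ⅱ, here is a rough sketch of what multi-size training and use_amp=True do each iteration. This is illustrative only, not the repo's exact trainer code: model, optimizer, and train_loader are assumed to exist, and the resize interval (every 10 iterations) is an assumption.

```python
import numpy as np
import torch
import torch.nn.functional as F

def train_steps(model, optimizer, train_loader):
    # Sketch only: the repo's real loop lives in train.py and differs in detail.
    scaler = torch.cuda.amp.GradScaler(enabled=True)      # created once when use_amp=True
    for iter_id, (imgs, targets) in enumerate(train_loader):
        if iter_id % 10 == 0:                             # interval is an assumption
            size = int(np.random.randint(14, 26)) * 32    # opt.random_size=(14, 26) -> 448..800
            imgs = F.interpolate(imgs, size=(size, size),
                                 mode="bilinear", align_corners=False)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):       # fp16 forward where safe
            loss = model(imgs, targets)
        scaler.scale(loss).backward()                     # loss scaling avoids fp16 underflow
        scaler.step(optimizer)
        scaler.update()
```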

#### Evaluate

The trained weights will be saved in './exp/your_exp_id/model_xx.pth'
The weights will be saved in './exp/your_exp_id/model_xx.pth'
Change load_model='weight/path/to/evaluate.pth' and backbone='backbone-type' in 'evaluate.sh', then run:
sh evaluate.sh

@@ -104,21 +109,7 @@ The weights were converted from [YOLOX](https://github.com/Megvii-BaseDetection/
DOING

#### Multi-class MOT Dataset
DOING, not ready
1. Download and unzip the VisDrone dataset: http://aiskyeye.com/download/multi-object-tracking_2021

2. Put the train and val sets into:
/path/to/dataset/VisDrone/VisDrone2019-MOT-train # This folder contains two subfolders, 'annotations' and 'sequences'
/path/to/dataset/VisDrone/VisDrone2019-MOT-val # This folder contains two subfolders, 'annotations' and 'sequences'

3. Change opt.dataset_path = "/path/to/dataset/VisDrone" in 'config.py'
4. python tools/visdrone_mot_to_coco.py # convert to COCO format
5. (Optional) python tools/show_coco_anns.py # visualize tracking IDs

6. Set the class names and tracking ID counts:
change opt.label_name=['pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
change opt.tracking_id_nums=[1829, 853, 323, 3017, 295, 159, 215, 79, 55, 749]
change opt.reid_dim=128
DOING

#### Train
DOING
@@ -133,3 +124,4 @@ The weights were converted from [YOLOX](https://github.com/Megvii-BaseDetection/
https://github.com/Megvii-BaseDetection/YOLOX
https://github.com/PaddlePaddle/PaddleDetection
https://github.com/open-mmlab/mmdetection
https://github.com/xingyizhou/CenterNet
8 changes: 4 additions & 4 deletions config.py
@@ -10,6 +10,7 @@


def update_nano_tiny(cfg, inp_params):
# yolo-nano, yolo-tiny config:
cfg.scale = cfg.scale if 'scale' in inp_params else (0.5, 1.5)
cfg.test_size = cfg.test_size if 'test_size' in inp_params else (416, 416)
cfg.enable_mixup = cfg.enable_mixup if 'enable_mixup' in inp_params else False
@@ -26,7 +27,6 @@ def update_nano_tiny(cfg, inp_params):
opt.exp_id = "coco_CSPDarknet-s_640x640" # experiment name, you can change it to any other name
opt.dataset_path = "/data/dataset/coco_dataset" # COCO detection
# opt.dataset_path = r"D:\work\public_dataset\coco2017" # Windows system
# opt.dataset_path = "/media/ming/DATA1/dataset/VisDrone" # MOT tracking
opt.backbone = "CSPDarknet-s" # CSPDarknet-nano, CSPDarknet-tiny, CSPDarknet-s, CSPDarknet-m, l, x
opt.input_size = (640, 640)
opt.random_size = (14, 26)  # None; multi-size training: from 448 to 800, randomly sample an int and multiply by 32 to get the input size
@@ -79,15 +79,15 @@ def update_nano_tiny(cfg, inp_params):
opt.perspective = 0.0
opt.enable_mixup = True
opt.seed = 0
opt.data_num_workers = 0
opt.data_num_workers = 4

opt.momentum = 0.9
opt.vis_thresh = 0.3 # inference confidence, used in 'predict.py'
opt.load_model = ''
opt.ema = True # False, Exponential Moving Average
opt.grad_clip = dict(max_norm=35, norm_type=2)  # None; gradient clipping makes training more stable
opt.print_iter = 1  # print the loss every iteration
opt.metric = "loss"  # 'Ap' or 'loss'; a little slow when set to 'Ap'
opt.metric = "loss"  # 'Ap' or 'loss'; used to select 'model_best.pth'
opt.val_intervals = 1  # evaluate (when metric='Ap') and save the best ckpt every epoch
opt.save_epoch = 1  # save a checkpoint every epoch
opt.resume = False # resume from 'model_last.pth' when set True
@@ -134,4 +134,4 @@ def update_nano_tiny(cfg, inp_params):
assert opt.tracking_id_nums is not None

os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus_str
print("\n{} final config: {}\n{}".format("-"*20, "-"*20, opt))
print("\n{} final config: {}\n{}".format("-" * 20, "-" * 20, opt))
15 changes: 3 additions & 12 deletions data/coco_dataset.py
@@ -9,7 +9,7 @@
import sys

sys.path.append(".")
from data import (COCODataset, TrainTransform, YoloBatchSampler, DataLoader, InfiniteSampler, MosaicDetection)
from data import COCODataset, TrainTransform, YoloBatchSampler, DataLoader, InfiniteSampler, MosaicDetection


def get_dataloader(opt, no_aug=False):
@@ -34,7 +34,7 @@ def get_dataloader(opt, no_aug=False):
enable_mixup=opt.enable_mixup,
tracking=do_tracking,
)
train_sampler = InfiniteSampler(len(train_dataset), seed=opt.seed)
train_sampler = InfiniteSampler(len(train_dataset), seed=opt.seed if opt.seed is not None else 0)
batch_sampler = YoloBatchSampler(
sampler=train_sampler,
batch_size=opt.batch_size,
@@ -56,21 +56,12 @@ def get_dataloader(opt, no_aug=False):
augment=False))
val_sampler = torch.utils.data.SequentialSampler(val_dataset)
val_kwargs = {"num_workers": opt.data_num_workers, "pin_memory": True, "sampler": val_sampler,
"batch_size": opt.batch_size}
"batch_size": opt.batch_size, "drop_last": True}
val_loader = torch.utils.data.DataLoader(val_dataset, **val_kwargs)

return train_loader, val_loader


def memory_info():
import psutil

mem_total = psutil.virtual_memory().total / 1024 / 1024 / 1024
mem_used = psutil.virtual_memory().used / 1024 / 1024 / 1024
mem_percent = psutil.virtual_memory().percent
return mem_percent, mem_used, mem_total


def vis_inputs(inputs, targets, opt):
from utils.util import label_color

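
The seed passed to InfiniteSampler above drives a deterministic, endless shuffle of dataset indices. A condensed sketch of the idea behind the YOLOX-style sampler this repo imports (the real class in data/ may differ in detail):

```python
import torch
from torch.utils.data import Sampler

class InfiniteSamplerSketch(Sampler):
    """Yields shuffled indices forever; epochs are defined by iteration count."""

    def __init__(self, size, seed=0):
        self.size = size
        self.seed = seed

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed)   # deterministic given the seed
        while True:                # never exhausts; pairs with a batch sampler
            yield from torch.randperm(self.size, generator=g).tolist()
```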
60 changes: 37 additions & 23 deletions evaluate.py
@@ -7,6 +7,7 @@
import cv2
import tqdm
import json
import numpy as np
import pycocotools.coco as coco_
from pycocotools.cocoeval import COCOeval

@@ -19,37 +20,50 @@ def evaluate():
detector = Detector(opt)
gt_ann = opt.val_ann if "test_ann" not in opt.keys() else opt.test_ann
img_dir = opt.dataset_path + "/images/" + ("test2017" if "test" in os.path.basename(gt_ann) else "val2017")
batch_size = opt.batch_size

assert os.path.isfile(gt_ann), 'cannot find gt {}'.format(gt_ann)
coco = coco_.COCO(gt_ann)
images = coco.getImgIds()
class_ids = sorted(coco.getCatIds())
num_samples = len(images)

print("==>> evaluating batch_size={}".format(batch_size))
print('find {} samples in {}'.format(num_samples, gt_ann))

result_file = "result_{}_{}.json".format(opt.backbone, opt.test_size[0])
coco_res = []
for index in tqdm.tqdm(range(num_samples)):
img_id = images[index]
file_name = coco.loadImgs(ids=[img_id])[0]['file_name']
image_path = img_dir + "/" + file_name
assert os.path.isfile(image_path), "cannot find img {}".format(image_path)
img = cv2.imread(image_path)
img_h, img_w = img.shape[:2]
results = detector.run(img, vis_thresh=0.001)
for res in results:
cls, conf, bbox = res[0], res[1], res[2]
bbox[0] = max(0, min(img_w, bbox[0]))
bbox[1] = max(0, min(img_h, bbox[1]))
bbox[2] = max(0, min(img_w, bbox[2]))
bbox[3] = max(0, min(img_h, bbox[3]))
if len(res) > 3:
reid_feat = res[4]
cls_index = opt.label_name.index(cls)
coco_res.append(
{'bbox': [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]],
'category_id': class_ids[cls_index],
'image_id': int(img_id),
'score': conf})
samples_idx = list(range(num_samples))
iterations = int(np.ceil(num_samples / float(batch_size)))
for its in tqdm.tqdm(range(iterations)):
batch_index = samples_idx[its * batch_size: (its + 1) * batch_size]
batch_images = []
batch_img_ids = []
for index in batch_index:
img_id = images[index]
file_name = coco.loadImgs(ids=[img_id])[0]['file_name']
image_path = img_dir + "/" + file_name
assert os.path.isfile(image_path), "cannot find img {}".format(image_path)
img = cv2.imread(image_path)

batch_images.append(img)
batch_img_ids.append(img_id)

batch_results = detector.run(batch_images, vis_thresh=0.001)

for index in range(len(batch_images)):
results = batch_results[index]
img_id = batch_img_ids[index]
for res in results:
cls, conf, bbox = res[0], res[1], res[2]
if len(res) > 3:
reid_feat = res[4]
cls_index = opt.label_name.index(cls)
coco_res.append(
{'bbox': [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]],
'category_id': class_ids[cls_index],
'image_id': int(img_id),
'score': conf})

with open(result_file, 'w') as f_dump:
json.dump(coco_res, f_dump, cls=NpEncoder)
Expand All @@ -60,7 +74,7 @@ def evaluate():
try:
zip_file = result_file.replace(".json", ".zip")
os.system("zip {} {}".format(zip_file, result_file))
print("create upload file done: {}".format(zip_file))
print("--> create upload file done: {}".format(zip_file))
except:
print("please zip it before uploading")
return
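
Once result_file is written, mAP comes from pycocotools; a sketch of the standard call sequence (variable names follow the code above):

```python
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO(gt_ann)                  # ground-truth annotations
coco_dt = coco_gt.loadRes(result_file)  # detections dumped above
coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()                   # first line is AP@[0.50:0.95]
```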
5 changes: 4 additions & 1 deletion evaluate.sh
@@ -19,4 +19,7 @@
#python evaluate.py gpus='0' backbone="CSPDarknet-x" load_model="../weights/yolox-x.pth" dataset_path="/data/dataset/coco_dataset" test_ann="/data/dataset/coco_dataset/annotations/image_info_test-dev2017.json" test_size="(800,800)"

# evaluate custom dataset
python evaluate.py gpus='0' backbone="CSPDarknet-s" load_model="exp/coco_CSPDarknet-s_640x640/model_best.pth"
python evaluate.py gpus='0' backbone="CSPDarknet-s" load_model="exp/coco_CSPDarknet-s_640x640/model_best.pth" batch_size=24

# fuse BN into Conv to speed up
#python evaluate.py gpus='0' backbone="CSPDarknet-s" load_model="exp/coco_CSPDarknet-s_640x640/model_best.pth" batch_size=24 fuse=True
8 changes: 5 additions & 3 deletions models/losses/yolox_loss.py
@@ -268,9 +268,11 @@ def get_assignments(
if mode == "cpu":
cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()

cls_preds_ = (cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid() * obj_preds_.unsqueeze(0).repeat(
num_gt, 1, 1).sigmoid())
pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt(), gt_cls_per_image, reduction="none").sum(-1)
with torch.cuda.amp.autocast(enabled=False):
cls_preds_ = (
cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid() * obj_preds_.unsqueeze(0).repeat(
num_gt, 1, 1).sigmoid())
pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt(), gt_cls_per_image, reduction="none").sum(-1)
del cls_preds_

cost = (pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center))
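
The autocast(enabled=False) wrapper above exists because F.binary_cross_entropy is not autocast-safe: under AMP, PyTorch refuses to run it on fp16 tensors (it suggests binary_cross_entropy_with_logits instead), so the pairwise classification cost is computed in fp32. A self-contained illustration of the pattern (shapes are arbitrary; requires a CUDA device):

```python
import torch
import torch.nn.functional as F

preds = torch.rand(8, 80, device="cuda")
targets = torch.rand(8, 80, device="cuda")

with torch.cuda.amp.autocast(enabled=True):       # AMP region, as in training
    with torch.cuda.amp.autocast(enabled=False):  # drop to fp32 for the unsafe op
        loss = F.binary_cross_entropy(
            preds.float(), targets.float(), reduction="none").sum(-1)
```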
44 changes: 44 additions & 0 deletions models/ops.py
@@ -4,6 +4,8 @@
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, Sequential

from models.backbone.csp_darknet import BaseConv


class ResLayer(Sequential):
"""ResLayer to build ResNet style backbone.
@@ -217,3 +219,45 @@ def __init__(self,
def forward(self, inputs):
out = self.conv_bn_layer(inputs)
return out


def fuse_conv_and_bn(conv, bn):
# Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
fusedconv = (
nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
groups=conv.groups,
bias=True,
)
.requires_grad_(False)
.to(conv.weight.device)
)

# prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

# prepare spatial bias
b_conv = (
torch.zeros(conv.weight.size(0), device=conv.weight.device)
if conv.bias is None
else conv.bias
)
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

return fusedconv


def fuse_model(model):
for m in model.modules():
if type(m) is BaseConv and hasattr(m, "bn"):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, "bn") # remove batchnorm
m.forward = m.fuseforward # update forward
return model
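
A hedged usage sketch for the new helpers, matching the fuse=True flag added in evaluate.sh (model stands for any loaded detector whose BaseConv blocks still carry a bn attribute; the helper name here is illustrative):

```python
import torch

def check_fusion(model):
    """Verify fuse_model keeps outputs identical up to fp noise."""
    model.eval()                   # fusion assumes frozen BN statistics
    x = torch.randn(1, 3, 640, 640)
    with torch.no_grad():
        y_ref = model(x)
        fused = fuse_model(model)  # folds each BN into the preceding conv
        y_fused = fused(x)
    return torch.allclose(y_ref, y_fused, atol=1e-4)
```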
8 changes: 7 additions & 1 deletion models/post_process.py
@@ -8,7 +8,7 @@
import torchvision


def yolox_post_process(outputs, down_strides, num_classes, conf_thre, nms_thre, label_name, img_ratios):
def yolox_post_process(outputs, down_strides, num_classes, conf_thre, nms_thre, label_name, img_ratios, img_shape):
hw = [i.shape[-2:] for i in outputs]
grids, strides = [], []
for (hsize, wsize), stride in zip(hw, down_strides):
@@ -63,11 +63,17 @@ def yolox_post_process(outputs, down_strides, num_classes, conf_thre, nms_thre,

detections[:, :4] = detections[:, :4] / img_ratios[i]

img_h, img_w = img_shape[i]
for det in detections:
x1, y1, x2, y2, obj_conf, class_conf, class_pred = det[0:7]
bbox = [float(x1), float(y1), float(x2), float(y2)]
conf = float(obj_conf * class_conf)
label = label_name[int(class_pred)]
# clip bbox
bbox[0] = max(0, min(img_w, bbox[0]))
bbox[1] = max(0, min(img_h, bbox[1]))
bbox[2] = max(0, min(img_w, bbox[2]))
bbox[3] = max(0, min(img_h, bbox[3]))

if reid_dim > 0:
reid_feat = det[7:].cpu().numpy().tolist()
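
For context, the grid/stride construction at the top of yolox_post_process implements the standard YOLOX decode; a compressed sketch for one FPN level (dummy tensors; 85 = 4 box + 1 obj + 80 class channels):

```python
import torch

hsize, wsize, stride = 80, 80, 8                # one FPN level of a 640 input
yv, xv = torch.meshgrid(torch.arange(hsize), torch.arange(wsize))
grid = torch.stack((xv, yv), 2).view(1, -1, 2)  # (1, H*W, 2) cell coordinates
out = torch.randn(1, hsize * wsize, 85)         # dummy raw head output
xy = (out[..., :2] + grid) * stride             # box centers in input pixels
wh = torch.exp(out[..., 2:4]) * stride          # box sizes in input pixels
```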
