diff --git a/README.md b/README.md index d4ddcf5..f9560f1 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [A PyTorch Implementation of Dual Shot Face Detector](https://arxiv.org/abs/1810.10220?utm_source=feedburner&utm_medium=feed&utm_campaign=Feed%3A+arxiv%2FQSXk+%28ExcitingAds%21+cs+updates+on+arXiv.org%29) ### Description -I use basenet [vgg](https://pan.baidu.com/s/1Q-YqoxJyqvln6KTcIck1tQ) to train DSFD,the model can be downloaded in [DSFD](https://pan.baidu.com/s/17cpDHEwYVxWmOIPqUy5zCQ).the AP in WIDER FACE as following: +I use basenet [vgg](https://pan.baidu.com/s/1Q-YqoxJyqvln6KTcIck1tQ) to train DSFD; the model can be downloaded from [DSFD with Baidu](https://pan.baidu.com/s/17cpDHEwYVxWmOIPqUy5zCQ) or [DSFD with Google Drive](https://drive.google.com/open?id=11pZy4DhslDP9cEk2uN9wNhTmB1VZ7GIR). The AP on WIDER FACE is as follows: | Easy MAP | Medium MAP | hard MAP | | ---------|------------| --------- | @@ -68,4 +68,4 @@ python demo.py ### References * [Dual Shot Face Detector](https://arxiv.org/abs/1810.10220?utm_source=feedburner&utm_medium=feed&utm_campaign=Feed%3A+arxiv%2FQSXk+%28ExcitingAds%21+cs+updates+on+arXiv.org%29) -* [ssd.pytorch](https://github.com/amdegroot/ssd.pytorch) \ No newline at end of file +* [ssd.pytorch](https://github.com/amdegroot/ssd.pytorch) diff --git a/demo.py b/demo.py index 0c5d9c3..6ae2231 100644 --- a/demo.py +++ b/demo.py @@ -33,7 +33,7 @@ help='Directory for detect result') parser.add_argument('--model', type=str, - default='weights/dsfd_face.pth', help='trained model') + default='weights/dsfd_vgg_0.880.pth', help='trained model') parser.add_argument('--thresh', default=0.4, type=float, help='Final confidence threshold') @@ -58,10 +58,8 @@ def detect(net, img_path, thresh): img = np.array(img) height, width, _ = img.shape - max_im_shrink = np.sqrt( - 1500 * 1000 / (img.shape[0] * img.shape[1])) - image = cv2.resize(img, None, None, fx=max_im_shrink, - fy=max_im_shrink, interpolation=cv2.INTER_LINEAR) + max_im_shrink 
= np.sqrt(1500 * 1000 / (img.shape[0] * img.shape[1])) + image = cv2.resize(img, None, None, fx=max_im_shrink,fy=max_im_shrink, interpolation=cv2.INTER_LINEAR) x = to_chw_bgr(image) x = x.astype('float32') @@ -88,23 +86,25 @@ def detect(net, img_path, thresh): j += 1 cv2.rectangle(img, left_up, right_bottom, (0, 0, 255), 2) conf = "{:.2f}".format(score) - text_size, baseline = cv2.getTextSize( - conf, cv2.FONT_HERSHEY_SIMPLEX, 0.3, 1) + text_size, baseline = cv2.getTextSize(conf, cv2.FONT_HERSHEY_SIMPLEX, 0.3, 1) p1 = (left_up[0], left_up[1] - text_size[1]) - cv2.rectangle(img, (p1[0] - 2 // 2, p1[1] - 2 - baseline), - (p1[0] + text_size[0], p1[1] + text_size[1]),[255,0,0], -1) - cv2.putText(img, conf, (p1[0], p1[ - 1] + baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1, 8) + cv2.rectangle(img, (p1[0] - 2 // 2, p1[1] - 2 - baseline),(p1[0] + text_size[0], p1[1] + text_size[1]),[255,0,0], -1) + cv2.putText(img, conf, (p1[0], p1[1] + baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1, 8) t2 = time.time() print('detect:{} timer:{}'.format(img_path, t2 - t1)) + print('Found:{}'.format(j)) cv2.imwrite(os.path.join(args.save_dir, os.path.basename(img_path)), img) if __name__ == '__main__': net = build_net('test', cfg.NUM_CLASSES, args.network) - net.load_state_dict(torch.load(args.model)) + if use_cuda: + net.load_state_dict(torch.load(args.model)) + else: + net.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage)) + net.eval() if use_cuda: @@ -112,7 +112,6 @@ def detect(net, img_path, thresh): cudnn.benckmark = True img_path = './img' - img_list = [os.path.join(img_path, x) - for x in os.listdir(img_path) if x.endswith('jpg')] + img_list = [os.path.join(img_path, x) for x in os.listdir(img_path) if x.endswith('jpg')] for path in img_list: detect(net, path, args.thresh) diff --git a/layers/functions/__init__.py b/layers/functions/__init__.py index bd1ef0f..5d4c8d3 100644 --- a/layers/functions/__init__.py +++ 
b/layers/functions/__init__.py @@ -1,5 +1,5 @@ from .prior_box import PriorBox -from detection import Detect +from .detection import Detect __all__=['Detect','PriorBox'] diff --git a/weights/README.md b/weights/README.md new file mode 100644 index 0000000..afcac15 --- /dev/null +++ b/weights/README.md @@ -0,0 +1 @@ +GitHub does not allow files over 100 MB. For this reason, please download the model from [Baidu](https://pan.baidu.com/s/17cpDHEwYVxWmOIPqUy5zCQ) or [Google](https://drive.google.com/open?id=11pZy4DhslDP9cEk2uN9wNhTmB1VZ7GIR) and place it here.