
Commit

clean up and ready
zhreshold committed Mar 28, 2017
1 parent 27cf0b5 commit 6f517aa
Showing 29 changed files with 139 additions and 2,027 deletions.
41 changes: 30 additions & 11 deletions README.md
@@ -12,9 +12,16 @@ The arXiv paper is available [here](http://arxiv.org/abs/1512.02325).
This example is intended to reproduce this nice detector while fully utilizing the
remarkable traits of MXNet.
* The model is fully compatible with the caffe version.
* Model converter from caffe is available, I'll release it once I can convert any symbol other than VGG16.
* Model [converter](#convert-caffemodel) from caffe is available now!
* The results are almost identical to the original version; however, due to different implementation details, they may differ slightly.

### What's new
* Updated to match the latest caffe version, with a 5% mAP increase.
* Uses a C++ record iterator, backed by the multi-threaded back-end engine, to achieve a huge speed-up in multi-GPU environments (see the sketch after this list).
* Added a symbol for 512x512 input.
* More network symbols are under development and testing.
* Extra operators are now in `mxnet/src/operator/contrib`, and symbols have been modified accordingly. Please use [Release-v0.2-beta](https://github.com/zhreshold/mxnet-ssd/releases/tag/v0.2-beta) for old models.
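
A rough usage sketch of the new record iterator (paths and shape are illustrative; the constructor signature follows `dataset/iterator.py` as changed in this commit):
```
from dataset.iterator import DetRecordIter

# Hypothetical record file, produced by tools/prepare_dataset.py (see below).
train_iter = DetRecordIter('./data/train.rec', batch_size=32,
                           data_shape=(3, 300, 300))

# Decoding and augmentation run in the multi-threaded C++ back-end,
# so Python only sees ready-to-train batches.
for batch in train_iter:
    pass  # batch.data / batch.label feed straight into Module training
```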

### Demo results
![demo1](https://cloud.githubusercontent.com/assets/3307514/19171057/8e1a0cc4-8be0-11e6-9d8f-088c25353b40.png)
![demo2](https://cloud.githubusercontent.com/assets/3307514/19171063/91ec2792-8be0-11e6-983c-773bd6868fa8.png)
@@ -23,7 +30,8 @@ remarkable traits of MXNet.
### mAP
| Model | Training data | Test data | mAP |
|:-----------------:|:----------------:|:---------:|:----:|
| VGG16_reduced 300x300 | VOC07+12 trainval| VOC07 test| 71.57|
| VGG16_reduced 300x300 | VOC07+12 trainval| VOC07 test| 77.4|
| VGG16_reduced 512x512 | VOC07+12 trainval | VOC07 test| 79.9|

### Speed
| Model | GPU | CUDNN | Batch-size | FPS* |
@@ -36,11 +44,11 @@ remarkable traits of MXNet.
- *Forward time only, data loading and drawing excluded.*

### Getting started
* You will need python modules: `easydict`, `cv2`, `matplotlib` and `numpy`.
* You will need python modules: `cv2`, `matplotlib` and `numpy`.
If you use the MXNet python API, you probably already have them.
You can install them via pip or package managers, such as `apt-get`:
```
sudo apt-get install python-opencv python-matplotlib python-numpy
sudo pip install easydict
```
* Clone this repo:
```
@@ -54,23 +62,24 @@ git clone --recursive https://github.com/zhreshold/mxnet-ssd.git
# git submodule update --recursive --init
cd mxnet-ssd/mxnet
```
* Build MXNet: `cd $REPO_ROOT/mxnet`. Follow the official instructions [here](http://mxnet.io/get_started/setup.html).
* Build MXNet: `cd /path/to/mxnet-ssd/mxnet`. Follow the official instructions [here](http://mxnet.io/get_started/setup.html).
```
# for Ubuntu/Debian
cp make/config.mk ./config.mk
# modify it if necessary
```
Remember to enable CUDA if you want to be able to train, since CPU training is
insanely slow. Using CUDNN is optional, it's not fully tested but should be fine.
insanely slow. Using CUDNN is optional, but highly recommended.

### Try the demo
* Download the pretrained model: [`ssd_300_voc_0712.zip`](https://dl.dropboxusercontent.com/u/39265872/ssd_300_voc0712.zip), and extract to `model/` directory. (This model is converted from VGG_VOC0712_SSD_300x300_iter_60000.caffemodel provided by paper author).
* Download the pretrained model: [`ssd_300_voc_0712.zip`](https://dl.dropboxusercontent.com/u/39265872/ssd_300_voc0712.zip), and extract to `model/` directory.
* Run
```
# cd /path/to/mxnet-ssd
python demo.py
# play with examples:
python demo.py --epoch 0 --images ./data/demo/dog.jpg --thresh 0.5
# wait for library to load for the first time
```
* Check `python demo.py --help` for more options.

@@ -99,18 +108,26 @@ in the same `VOCdevkit` folder.
ln -s /path/to/VOCdevkit /path/to/this_example/data/VOCdevkit
```
Using a hard link instead of a copy saves a bit of disk space.
* Create packed binary file for faster training:
```
# cd /path/to/mxnet-ssd
bash tools/prepare_pascal.sh
# or if you are using windows
python tools/prepare_dataset.py --dataset pascal --year 2007,2012 --set trainval --target ./data/train.lst
python tools/prepare_dataset.py --dataset pascal --year 2007 --set test --target ./data/val.lst --shuffle False
```
* Start training:
```
python train.py
```
* By default, this example will use `batch-size=32` and `learning_rate=0.002`.
* By default, this example will use `batch-size=32` and `learning_rate=0.004`.
You might need to change the parameters a bit if you have different configurations.
Check `python train.py --help` for more training options. For example, if you have 4 GPUs, use:
```
# note that a perfect training parameter set is yet to be discovered for multi-gpu
python train.py --gpus 0,1,2,3 --batch-size 128 --lr 0.0005
python train.py --gpus 0,1,2,3 --batch-size 128 --lr 0.001
```
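
For reference, the `--gpus` flag conventionally expands to a list of device contexts inside `train.py`; a minimal sketch of that mapping (an assumption, not the verbatim code):
```
import mxnet as mx

gpus = '0,1,2,3'  # e.g. the value passed to --gpus
ctx = [mx.gpu(int(i)) for i in gpus.split(',')] if gpus else [mx.cpu()]
```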
* Memory usage: MXNet is very memory efficient, training on `VGG16_reduced` model with `batch-size` 32 takes around 4684MB without CUDNN.
* Memory usage: MXNet is very memory efficient; training the `VGG16_reduced` model with `batch-size` 32 takes around 4684MB without CUDNN (conv1_x and conv2_x fixed).

### Evaluate trained model
Again, currently we only support evaluation on PASCAL VOC.
@@ -125,9 +142,11 @@ Useful when loading the python symbol is not available.
```
# cd /path/to/mxnet-ssd
python deploy.py --num-class 20
# then you can run demo with new model without loading python symbol
python demo.py --prefix model/ssd_300_deploy --epoch 0 --deploy
```

### Convert model from caffe
### Convert caffemodel
Converter from caffe is available at `/path/to/mxnet-ssd/tools/caffe_converter`

This is specifically modified to handle the custom layers in caffe-ssd. Usage:
49 changes: 23 additions & 26 deletions dataset/iterator.py
@@ -1,7 +1,6 @@
import mxnet as mx
import numpy as np
import cv2
from tools.image_processing import resize, transform
from tools.rand_sampler import RandSampler

class DetRecordIter(mx.io.DataIter):
@@ -39,7 +38,7 @@ class DetRecordIter(mx.io.DataIter):
Returns:
----------
"""
def __init__(self, path_imgrec, batch_size, data_shape, path_imglist="",
label_width=-1, label_pad_width=-1, label_pad_value=-1,
@@ -149,7 +148,7 @@ def __init__(self, imdb, batch_size, data_shape, \
if isinstance(data_shape, int):
data_shape = (data_shape, data_shape)
self._data_shape = data_shape
self._mean_pixels = mean_pixels
self._mean_pixels = mx.nd.array(mean_pixels).reshape((3,1,1))
if not rand_samplers:
self._rand_samplers = []
else:
@@ -203,7 +202,7 @@ def next(self):
raise StopIteration

def getindex(self):
return self._current / self.batch_size
return self._current // self.batch_size
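
For context on the `/` to `//` changes in this file: Python 3 makes `/` true division, which returns a float, while an iterator index must stay an integer; `//` keeps integer semantics under both Python 2 and 3. A quick illustration:
```
# Python 3 semantics: `/` yields a float, `//` floors to an int.
assert 7 / 2 == 3.5
assert 7 // 2 == 3
```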

def getpad(self):
pad = self._current + self.batch_size - self._size
@@ -213,30 +212,28 @@ def _get_batch(self):
"""
Load data/label from dataset
"""
batch_data = []
batch_data = mx.nd.zeros((self.batch_size, 3, self._data_shape[0], self._data_shape[1]))
batch_label = []
for i in range(self.batch_size):
if (self._current + i) >= self._size:
if not self.is_train:
continue
# use padding from middle in each epoch
idx = (self._current + i + self._size / 2) % self._size
idx = (self._current + i + self._size // 2) % self._size
index = self._index[idx]
else:
index = self._index[self._current + i]
# index = self.debug_index
im_path = self._imdb.image_path_from_index(index)
img = cv2.imread(im_path)
with open(im_path, 'rb') as fp:
img_content = fp.read()
img = mx.img.imdecode(img_content)
gt = self._imdb.label_from_index(index).copy() if self.is_train else None
data, label = self._data_augmentation(img, gt)
batch_data.append(data)
batch_data[i] = data
if self.is_train:
batch_label.append(label)
# pad data if not fully occupied
for i in range(self.batch_size - len(batch_data)):
assert len(batch_data) > 0
batch_data.append(batch_data[0] * 0)
self._data = {'data': mx.nd.array(np.array(batch_data))}
self._data = {'data': batch_data}
if self.is_train:
self._label = {'label': mx.nd.array(np.array(batch_label))}
else:
@@ -262,32 +259,32 @@ def _data_augmentation(self, data, label):
xmax = int(crop[2] * width)
ymax = int(crop[3] * height)
if xmin >= 0 and ymin >= 0 and xmax <= width and ymax <= height:
data = data[ymin:ymax, xmin:xmax, :]
data = mx.img.fixed_crop(data, xmin, ymin, xmax-xmin, ymax-ymin)
else:
# padding mode
new_width = xmax - xmin
new_height = ymax - ymin
offset_x = 0 - xmin
offset_y = 0 - ymin
data_bak = data
data = np.full((new_height, new_width, 3), 128.)
data = mx.nd.full((new_height, new_width, 3), 128, dtype='uint8')
data[offset_y:offset_y+height, offset_x:offset_x + width, :] = data_bak
label = rand_crops[index][1]

if self.is_train and self._rand_mirror:
if np.random.uniform(0, 1) > 0.5:
data = cv2.flip(data, 1)
valid_mask = np.where(label[:, 0] > -1)[0]
tmp = 1.0 - label[valid_mask, 1]
label[valid_mask, 1] = 1.0 - label[valid_mask, 3]
label[valid_mask, 3] = tmp

if self.is_train:
interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, \
cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
else:
interp_methods = [cv2.INTER_LINEAR]
interp_method = interp_methods[int(np.random.uniform(0, 1) * len(interp_methods))]
data = resize(data, self._data_shape, interp_method)
data = transform(data, self._mean_pixels)
data = mx.img.imresize(data, self._data_shape[1], self._data_shape[0], interp_method)
if self.is_train and self._rand_mirror:
if np.random.uniform(0, 1) > 0.5:
data = mx.nd.flip(data, axis=1)
valid_mask = np.where(label[:, 0] > -1)[0]
tmp = 1.0 - label[valid_mask, 1]
label[valid_mask, 1] = 1.0 - label[valid_mask, 3]
label[valid_mask, 3] = tmp
data = mx.nd.transpose(data, (2,0,1))
data = data.astype('float32')
data = data - self._mean_pixels
return data, label
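
To make the mirror transform above concrete: for a normalized box, a horizontal flip maps `(xmin, xmax)` to `(1 - xmax, 1 - xmin)`, which is exactly what the `tmp` swap computes. A standalone numpy illustration with one hypothetical label row in the `[cls_id, xmin, ymin, xmax, ymax]` layout used here (note also that the final `data - self._mean_pixels` step broadcasts the `(3, 1, 1)` mean over the transposed `(3, H, W)` image):
```
import numpy as np

# One hypothetical ground-truth row: [cls_id, xmin, ymin, xmax, ymax], normalized.
label = np.array([[0.0, 0.2, 0.3, 0.6, 0.9]])
valid_mask = np.where(label[:, 0] > -1)[0]         # rows with a real class id
tmp = 1.0 - label[valid_mask, 1]                   # 1 - old xmin -> new xmax
label[valid_mask, 1] = 1.0 - label[valid_mask, 3]  # 1 - old xmax -> new xmin
label[valid_mask, 3] = tmp
print(label)  # [[0.  0.4  0.3  0.8  0.9]]
```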
7 changes: 4 additions & 3 deletions dataset/pascal_voc.py
@@ -1,3 +1,4 @@
from __future__ import print_function
import os
import numpy as np
from imdb import Imdb
@@ -128,7 +129,7 @@ def _label_path_from_index(self, index):
full path of annotation file
"""
label_file = os.path.join(self.data_path, 'Annotations', index + '.xml')
assert os.path.exists(label_file), 'Path does not exist: {}'.format(image_file)
assert os.path.exists(label_file), 'Path does not exist: {}'.format(label_file)
return label_file

def _load_image_labels(self):
@@ -220,7 +221,7 @@ def write_pascal_results(self, all_boxes):
None
"""
for cls_ind, cls in enumerate(self.classes):
print 'Writing {} VOC results file'.format(cls)
print('Writing {} VOC results file'.format(cls))
filename = self.get_result_file_template().format(cls)
with open(filename, 'wt') as f:
for im_ind, index in enumerate(self.image_set_index):
@@ -250,7 +251,7 @@ def do_python_eval(self):
aps = []
# The PASCAL VOC metric changed in 2010
use_07_metric = True if int(self.year) < 2010 else False
print 'VOC07 metric? ' + ('Y' if use_07_metric else 'No')
print('VOC07 metric? ' + ('Y' if use_07_metric else 'No'))
for cls_ind, cls in enumerate(self.classes):
filename = self.get_result_file_template().format(cls)
rec, prec, ap = voc_eval(filename, annopath, imageset_file, cls, cache_dir,
2 changes: 1 addition & 1 deletion dataset/yolo_format.py
@@ -36,7 +36,7 @@ def __init__(self, name, classes, list_file, image_dir, label_dir, \
classes = [l.strip() for l in f.readlines()]
num_classes = len(classes)
else:
raise ValueError, "classes should be list/tuple or text file"
raise ValueError("classes should be list/tuple or text file")
assert num_classes > 0, "number of classes must > 0"
super(YoloFormat, self).__init__(name + '_' + str(num_classes))
self.classes = classes
4 changes: 2 additions & 2 deletions demo.py
@@ -44,8 +44,8 @@ def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx,

def parse_args():
parser = argparse.ArgumentParser(description='Single-shot detection network demo')
parser.add_argument('--network', dest='network', type=str, default='ssd_300',
choices=['ssd_300'], help='which network to use')
parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300',
choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use')
parser.add_argument('--images', dest='images', type=str, default='./data/demo/dog.jpg',
help='run demo with images, use comma (without space) to separate multiple images')
parser.add_argument('--dir', dest='dir', nargs='?',
9 changes: 5 additions & 4 deletions deploy.py
@@ -1,3 +1,4 @@
from __future__ import print_function
import argparse
import tools.find_mxnet
import mxnet as mx
@@ -7,8 +8,8 @@

def parse_args():
parser = argparse.ArgumentParser(description='Convert a trained model to deploy model')
parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced',
choices=['vgg16_reduced'], help='which network to use')
parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300',
choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use')
parser.add_argument('--epoch', dest='epoch', help='epoch of trained model',
default=0, type=int)
parser.add_argument('--prefix', dest='prefix', help='trained model prefix',
@@ -32,5 +33,5 @@ def parse_args():
tmp = args.prefix.rsplit('/', 1)
save_prefix = '/deploy_'.join(tmp)
mx.model.save_checkpoint(save_prefix, args.epoch, net, arg_params, aux_params)
print "Saved model: {}-{:04d}.param".format(save_prefix, args.epoch)
print "Saved symbol: {}-symbol.json".format(save_prefix)
print("Saved model: {}-{:04d}.param".format(save_prefix, args.epoch))
print("Saved symbol: {}-symbol.json".format(save_prefix))
7 changes: 4 additions & 3 deletions detect/detector.py
@@ -1,3 +1,4 @@
from __future__ import print_function
import mxnet as mx
import numpy as np
from timeit import default_timer as timer
@@ -33,7 +34,7 @@ def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels, \
load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
if symbol is None:
symbol = load_symbol
self.mod = mx.mod.Module(symbol, context=ctx)
self.mod = mx.mod.Module(symbol, label_names=None, context=ctx)
self.data_shape = data_shape
self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape, data_shape))])
self.mod.set_params(args, auxs)
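
A minimal inference-only setup in the same spirit as the change above (checkpoint prefix, epoch, and shape are illustrative assumptions): `label_names=None` tells the Module there is no label input, so `bind` succeeds with data shapes alone.
```
import mxnet as mx

# Hypothetical checkpoint; 300 matches the ssd_300 data shape.
sym, args, auxs = mx.model.load_checkpoint('model/ssd_300', 0)
mod = mx.mod.Module(sym, label_names=None, context=mx.cpu())
mod.bind(data_shapes=[('data', (1, 3, 300, 300))], for_training=False)
mod.set_params(args, auxs)
```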
@@ -62,8 +63,8 @@ def detect(self, det_iter, show_timer=False):
detections = self.mod.predict(det_iter).asnumpy()
time_elapsed = timer() - start
if show_timer:
print "Detection time for {} images: {:.4f} sec".format(
num_images, time_elapsed)
print("Detection time for {} images: {:.4f} sec".format(
num_images, time_elapsed))
result = []
for i in range(detections.shape[0]):
det = detections[i, :, :]
11 changes: 3 additions & 8 deletions evaluate.py
@@ -17,8 +17,8 @@ def parse_args():
default=os.path.join(os.getcwd(), 'data', 'val.rec'), type=str)
parser.add_argument('--list-path', dest='list_path', help='which list file to use',
default="", type=str)
parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced',
choices=['vgg16_reduced', 'ssd_300'], help='which network to use')
parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300',
choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use')
parser.add_argument('--batch-size', dest='batch_size', type=int, default=32,
help='evaluation batch size')
parser.add_argument('--num-class', dest='num_class', type=int, default=20,
@@ -31,7 +31,7 @@ def parse_args():
default=os.path.join(os.getcwd(), 'model', 'ssd'), type=str)
parser.add_argument('--gpus', dest='gpu_id', help='GPU devices to evaluate with',
default='0', type=str)
parser.add_argument('--cpu', dest='cpu', help='use cpu to evaluate',
parser.add_argument('--cpu', dest='cpu', help='use cpu to evaluate, this can be slow',
action='store_true')
parser.add_argument('--data-shape', dest='data_shape', type=int, default=300,
help='set image shape')
@@ -78,11 +78,6 @@ def parse_args():
else:
class_names = None

# evaluate_net(args.network, args.dataset, args.devkit_path,
# (args.mean_r, args.mean_g, args.mean_b), args.data_shape,
# args.prefix, args.epoch, ctx, year=args.year,
# sets=args.eval_set, batch_size=args.batch_size,
# nms_thresh=args.nms_thresh, force_nms=args.force_nms)
network = None if args.deploy_net else args.network
evaluate_net(network, args.rec_path, num_class,
(args.mean_r, args.mean_g, args.mean_b), args.data_shape,
15 changes: 9 additions & 6 deletions evaluate/eval_voc.py
@@ -1,10 +1,13 @@
"""
given a pascal voc imdb, compute mAP
"""

from __future__ import print_function
import numpy as np
import os
import cPickle
try:
import cPickle as pickle
except ImportError:
import pickle


def parse_voc_rec(filename):
@@ -88,13 +91,13 @@ def voc_eval(detpath, annopath, imageset_file, classname, cache_dir, ovthresh=0.
for ind, image_filename in enumerate(image_filenames):
recs[image_filename] = parse_voc_rec(annopath.format(image_filename))
if ind % 100 == 0:
print 'reading annotations for {:d}/{:d}'.format(ind + 1, len(image_filenames))
print 'saving annotations cache to {:s}'.format(cache_file)
print('reading annotations for {:d}/{:d}'.format(ind + 1, len(image_filenames)))
print('saving annotations cache to {:s}'.format(cache_file))
with open(cache_file, 'w') as f:
cPickle.dump(recs, f)
pickle.dump(recs, f)
else:
with open(cache_file, 'r') as f:
recs = cPickle.load(f)
recs = pickle.load(f)

# extract objects in :param classname:
class_recs = {}
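
One caveat the hunk above leaves untouched: under Python 3, `pickle` requires binary file modes, so the text-mode `'w'`/`'r'` opens around the dump/load calls would also need updating. A hedged sketch of that follow-up fix, reusing the same `cache_file` and `recs` variables:
```
# Binary modes are required by pickle under Python 3 (and harmless under 2).
with open(cache_file, 'wb') as f:
    pickle.dump(recs, f)
with open(cache_file, 'rb') as f:
    recs = pickle.load(f)
```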
