Deep Learning (DL) frameworks (e.g. TensorFlow) provide a fast way to develop robust image classifiers. The design process leverages the collected data to train personalized detection pipelines, which can be embedded into portable low-power devices, such as GAP8. In this context, the GAP flow is the missing piece between high-level DL frameworks and the GAP8 processing engine.
To showcase how a deep learning pipeline for visual object spotting can be customized and embedded into a GAP8-based smart sensor, this project walks through the design steps shown in the figure below for the case of image vehicle spotting. You can follow the same steps to quickly build your own smart camera system!
The repository includes a trained TF model for image vehicle spotting. To convert the model into C code with the GAP flow and run the application code on a GAP8-based camera system:
make clean all run [RGB=1 or 0]
The application works either with grayscale (default, Himax HM01B0) or RGB camera sensors (GalaxyCore GC0308); TF models are trained accordingly on grayscale or RGB augmented data.
The application code also runs on the GAP8 software simulator GVSOC:
make clean all run platform=gvsoc [RGB=1 or 0]
where the visual pipeline is fed with a sample image loaded through JTAG (images/COCO_val2014_000000000641.ppm).
In the following, we detail the design steps to train and deploy a custom visual spotting model for vehicle detection.
Tensorflow 1.13
Tensorflow Slim
Concerning the TF Slim scripts, we refer to the committed version added as a submodule to the repository.
GapSDK 3.5+
- Dataset Preparation and DL Model training
- Deep Model Deployment on GAP8 with the GAP flow
- Accuracy Validation with Platform Emulator
We refer to the open-source [Tensorflow Slim] framework (TF1.x) to build a custom image dataset and to train the DL model for visual image spotting. This tutorial guides you through the usage of the TF framework to reproduce our results for vehicle spotting. Specifically, we:
- Build a dataset
- Train a DL model, leveraging quantization-aware training
- Freeze the graph and convert to TFlite format
Afterwards, the GAP flow is fed with the tflite model.
The custom dataset for image vehicle spotting is distilled from the COCO dataset: the presence of the target object - vehicles - determines the new label of every image of the source dataset. In our case, the target object consists of a list of classes: 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat'. The TF framework provides a script to automatically download (if not already done) and convert the COCO dataset into the new dataset for visual image spotting, according to the TF-Record format:
python3 slim/download_and_convert.py --dataset_name=visualwakewords \
--dataset_dir=visualwakewords_vehicle \
--foreground_class_of_interest='bicycle','car','motorcycle','airplane','bus','train','truck','boat'\
--small_object_area_threshold=0.05 \
--download \
--coco_dir=coco
Argument details:
- dataset_name: name of the dataset to download (one of "flowers", "cifar10", "mnist", "visualwakewords")
- dataset_dir: where to store the generated "visualwakewords" dataset
- foreground_class_of_interest: list of COCO objects to be detected in the images
- small_object_area_threshold: minimum fraction of the image area that the object of interest must cover for the image to be labeled true
- download: whether to download the COCO dataset (can be skipped if it has already been downloaded)
- coco_dir: location of the COCO dataset (where it is stored if already downloaded, or where to download it otherwise)
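Before launching the training, it may be worth sanity-checking the generated TF-Records. The short sketch below is a hypothetical helper (not part of the repository) that counts positive and negative labels in one shard, assuming the standard slim visualwakewords feature key image/class/label:

# sanity_check_tfrecord.py - hypothetical helper, not part of the repository.
# Counts positive/negative labels in one TF-Record shard, assuming the standard
# slim visualwakewords feature key 'image/class/label'.
import sys
import tensorflow as tf  # TF 1.x

def count_labels(tfrecord_path):
    counts = {0: 0, 1: 0}
    for record in tf.python_io.tf_record_iterator(tfrecord_path):
        example = tf.train.Example()
        example.ParseFromString(record)
        label = example.features.feature['image/class/label'].int64_list.value[0]
        counts[int(label)] += 1
    return counts

if __name__ == '__main__':
    # e.g. python3 sanity_check_tfrecord.py visualwakewords_vehicle/<one train shard>
    print(count_labels(sys.argv[1]))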
Now, the dataset is ready to feed the training process.
We select a MobileNetV2 to solve the image classification problem, with a 224x224 input resolution and a width multiplier of 1 for grayscale images or 0.35 for RGB images. However, other models included in the TF Slim framework can be used as well.
A quantization-aware finetuning process takes place during the training phase. The model training is launched with:
python3 train_image_classifier.py \
--train_dir='vww_vehicle_train_grayscale' \
--dataset_name='visualwakewords' \
--dataset_split_name=train \
--dataset_dir='./visualwakewords_vehicle/' \
--log_every_n_steps=100 \
--model_name='mobilenet_v2' \
--checkpoint_path='./vww_vehicle_train_grayscale/' \
--max_number_of_steps=100000 \
--num_clones=1 \
--quantize_delay=90000 \
--use_grayscale
- use_grayscale (optional): apply grayscale conversion to the dataset samples
- train_dir: path where checkpoints and training info are stored
- dataset_name: training dataset name
- dataset_split_name: dataset partition to use for training
- dataset_dir: path of the dataset (TF-Records)
- model_name: network architecture to train (see slim/nets/nets_factory.py for the complete list of supported networks)
- checkpoint_path: path of the folder containing the checkpoint files; if not set, the network is trained from scratch
- max_number_of_steps: number of training steps
- num_clones: number of GPUs to use for training
- quantize_delay: number of steps after which the model is quantized
Since the target platform features optimized INT8 computational kernels, the quantization API of TensorFlow contrib has been replaced with its experimental version (contrib_quantize.experimental_create_training_graph(symmetric=True)). To disable symmetric quantization, you can force: --quantize_sym=False .
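For reference, the sketch below shows how such a symmetric quantization-aware rewrite is typically inserted before building the training op with the TF 1.x contrib API; it is a minimal illustration of the call mentioned above, not the repository's actual training script:

# Minimal sketch of the symmetric quantization-aware rewrite (TF 1.x contrib API).
# The real call lives inside the slim training script.
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize

graph = tf.Graph()
with graph.as_default():
    # ... build the MobileNetV2 forward pass and its loss here ...
    # Insert fake-quantization nodes; symmetric=True matches GAP8's INT8 kernels.
    contrib_quantize.experimental_create_training_graph(
        input_graph=graph,
        symmetric=True,
        quant_delay=90000)  # start quantizing after 90k steps, as --quantize_delay does
    # ... create the train_op and run the training loop as usual ...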
The model can now be evaluated on the validation dataset:
python3 eval_image_classifier.py \
--checkpoint_path='vww_train_vehicle_grayscale/' \
--eval_dir='vww_eval_vehicle_grayscale/' \
--dataset_split_name=val \
--dataset_dir='visualwakewords_vehicle/' \
--dataset_name='visualwakewords' \
--model_name='mobilenet_v2' \
--quantize \
--use_grayscale
Pass the --quantize flag only if the model has been trained with quantization.
The script computes the ratio of correct predictions, the number of false positives and false negatives. All these metrics can be inspected through tensorboard:
tensorboard --logdir='vww_eval_vehicle_grayscale'
To monitor the validation score during training, we added a bash script that evaluates the model validation accuracy at every epoch:
./train_eval_loop.sh -m ${MODEL_NAME} -b ${BATCH_SIZE} -e ${NUM_EPOCHS} -l ${LEARNING_RATE} -q ${QUANT_DELAY} -i ${IMAGE_SIZE} -g [to use grayscale]
NOTE: the dataset name and directory are hardcoded in the script. Change them according to your project.
To export the inference graph, i.e. the TensorFlow GraphDef file used for inference:
python3 slim/export_inference_graph.py \
--model_name=mobilenet_v2 \
--image_size=224 \
--output_file=./mobilenet_v2_224_grayscale.pb \
--quantize \
--use_grayscale
To obtain a frozen graph:
freeze_graph \
--input_graph=./mobilenet_v2_224_grayscale.pb \
--output_graph=./frozen_mbv2_224_grayscale.pb \
--input_checkpoint=./vww_train_vehicle_grayscale/model.ckpt-100000 \
--input_binary=true \
--output_node_names=MobilenetV2/Predictions/Reshape_1
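If you prefer a programmatic check over a graphical viewer, a small hypothetical helper like the following loads a GraphDef (e.g. the exported ./mobilenet_v2_224_grayscale.pb) and prints the nodes of the prediction head, from which the output node name can be read:

# list_nodes.py - hypothetical helper to inspect a GraphDef and spot the output
# node name (e.g. MobilenetV2/Predictions/Reshape_1).
import sys
import tensorflow as tf  # TF 1.x

graph_def = tf.GraphDef()
with tf.gfile.GFile(sys.argv[1], 'rb') as f:
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    if 'Predictions' in node.name:  # the classification head of MobileNetV2
        print(node.name, node.op)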
Alternatively, to inspect the graph and get the output_node_names you can use a graphical viewer such as Netron. Lastly, the TFLite model is generated by means of:
tflite_converter --graph_def=./frozen_mbv2_224_grayscale.pb \
--output_file=mbv2_grayscale.tflite \
--input_arrays=input \
--output_arrays=MobilenetV2/Predictions/Reshape_1 \
--inference_type=QUANTIZED_UINT8 \
--std_dev_val=128 \
--mean_val=128
The last three flags are needed only if the model has been trained with quantization (which is the suggested flow).
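As a quick functional check of the conversion, the .tflite file can be run once on the host with the TFLite Python interpreter. The sketch below is our own suggestion (not part of the repository) and assumes the quantized UINT8 conversion shown above:

# check_tflite.py - hypothetical sanity check of the converted model with the
# TF 1.x TFLite interpreter; assumes the quantized (UINT8) conversion above.
import numpy as np
import tensorflow as tf  # TF 1.x

interpreter = tf.lite.Interpreter(model_path='mbv2_grayscale.tflite')
interpreter.allocate_tensors()

inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
print('input :', inp['shape'], inp['dtype'])
print('output:', out['shape'], out['dtype'])

# Feed a dummy UINT8 frame just to verify that the graph executes end to end.
dummy = np.random.randint(0, 256, size=inp['shape'], dtype=np.uint8)
interpreter.set_tensor(inp['index'], dummy)
interpreter.invoke()
print('scores:', interpreter.get_tensor(out['index']))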
The GAP flow converts a frozen TFlite model into GAP-optimized C code.
The GAP flow consists of a two-step procedure. First, we use the nntool to convert the high-level tflite model description into a special graph description, which is then used by the Autotiler tool. The latter leverages the predictable memory access patterns of convolutional neural networks to generate optimized C code for the GAP platforms.
Besides producing the graph description needed by the Autotiler to optimize the memory accesses, the nntool includes debug features to inspect the target DL model.
Initially, the tflite model is opened with nntool:
nntool mbv2_grayscale.tflite [-q to load also the quantization information]
show
show will display the list of layers and the network topology. The model graph is now modified to match the Autotiler execution model:
adjust
fusions --scale8
These two commands translate the graph tensors from an HxWxC data layout (TF) to a CxHxW data layout (GAP8 computational kernels). Moreover, fusions finds and replaces all the subgraph portions that the Autotiler can handle with a single optimized layer (e.g. Conv+Pool+Relu fused into one layer).
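As an intuition for the layout change (an illustration only, nntool performs it on the graph for you), the equivalent transformation of a single tensor in NumPy terms would be:

# Illustration of the HxWxC -> CxHxW layout change performed by 'adjust'.
import numpy as np

hwc = np.zeros((224, 224, 3), dtype=np.int8)  # TF-style tensor: H x W x C
chw = np.transpose(hwc, (2, 0, 1))            # GAP8-style tensor: C x H x W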
Only if the model was not quantized during training:
The model runs inference on several images, i.e. the calibration dataset, to estimate the activation quantization ranges (post-training quantization). Weight and bias quantization ranges are computed statically.
aquant -f 8 -s calibration_images/*
The nntool can also validate the accuracy of the model over a validation dataset:
validation dataset/* [-q to run with quantized kernels]
The default behaviour (see --help for more information) interprets the last character of the filename as the label (e.g. COCO_0000000_1.png means that the true prediction value is 1). If the validation scores are poor, you can inspect the quantization log and check whether the quantization error, measured as the layer-wise QSNR (quantization signal-to-noise ratio) with respect to the float values, is too high:
qerror image.png
NOTE: when running inference in nntool, the input must be preprocessed in the same way as during training. For example, in our case the model expects inputs in [-1:1] but the input data lies in [0:255]; a normalization input function can be configured or selected among the available ones:
imageformat input_1 bw8 offset_int8
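To make the effect of the formatting step explicit, offset_int8 corresponds to a plain offset of the pixel values; in NumPy terms (an illustration of our understanding, not nntool code):

# Illustration of what the 'bw8 offset_int8' input formatting does: shift the
# [0:255] uint8 camera pixels down by 128 into the int8 range [-128:127].
import numpy as np

def offset_int8(gray_u8):
    # gray_u8: HxW uint8 grayscale frame (e.g. from the Himax HM01B0 sensor)
    return (gray_u8.astype(np.int16) - 128).astype(np.int8)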
Lastly, the nntool state can be saved:
save_state
This will save the adjusted+fused+quantized nntool graph in a .json file, alongside all its parameters in a .nnparam file. At this point you can generate the Autotiler Model code and the constant tensor files for weights and biases from the nntool saved state:
nntool -g model.json -m ATModel.c -T path/to/tensors/folder
All the above steps are automated by the common/model_rules.mk procedure, which also includes the Autotiler model compilation and code generation. The Autotiler generates 3 functions which you can call from your application code:
- Graph Constructor: allocates all the tensors needed by the graph in memory
- Graph Run: runs the sequence of layers with optimal memory movements
- Graph Destructor: deallocates all the structures allocated by the Constructor
Now it is time to run visual inference on a GAP8-based camera platform. Don't worry if you do not have one: GreenWaves provides the platform simulator GVSOC. You can take a look at the application code in main.c and run it with:
make clean all run [RGB=0 or 1]
The Autotiler features an optional, yet very useful, run mode: the __EMUL__ mode. If the flag is enabled, the Autotiler replaces all the GAP parallel code and built-in functions with x86 instructions, which can be executed by the host x86 PC. This speeds up functional testing on the host PC with respect to using the GVSOC platform. main_accuracy.c and emul.mk leverage this feature: main_accuracy.c takes a folder path and runs inference over the data samples (images) to report the overall accuracy:
make -f emul.mk clean all TEST_ACC=1 [RGB=0 or 1]
./mobv2_vwwvehicle_quant_asym_emul /path/to/dataset/
NOTE: only the uncompressed .ppm image format is supported.
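If your validation images are JPEG, a small conversion helper along these lines can produce the expected .ppm files; this is a hypothetical Pillow-based script (to_ppm.py, with the 224x224 size and grayscale mode as assumptions to adapt to your model):

# to_ppm.py - hypothetical converter from JPEG images to uncompressed PNM files
# for the __EMUL__ accuracy run (requires Pillow).
import glob, os, sys
from PIL import Image

src_dir, dst_dir = sys.argv[1], sys.argv[2]
os.makedirs(dst_dir, exist_ok=True)
for path in glob.glob(os.path.join(src_dir, '*.jpg')):
    img = Image.open(path).convert('L').resize((224, 224))  # use 'RGB' for the color model
    name = os.path.splitext(os.path.basename(path))[0] + '.ppm'
    img.save(os.path.join(dst_dir, name))  # Pillow writes an uncompressed PNM file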