added tflite support (experimental) (#40)
* tflite support (wip)

* linting

* just check for "resnet" in model path

* fallback to single tensor input

* adjust image_data shape to expected input shape

also optional image_size input

* removed commented out code

* only resize tensor if input image doesn't match

* added command to convert to tflite

* added --model-architecture argument

* added quantization arguments

* updated readme

* linting

* typo in readme

* improved wording in readme
de-code authored Nov 19, 2020
1 parent ebc5a4a commit 2fbefbb
Showing 6 changed files with 196 additions and 4 deletions.
1 change: 1 addition & 0 deletions .pylintrc
@@ -10,4 +10,5 @@ disable=
    too-few-public-methods,
    too-many-arguments,
    too-many-instance-attributes,
    duplicate-code,
    invalid-name
25 changes: 25 additions & 0 deletions README.md
@@ -176,6 +176,31 @@ python -m tf_bodypix \
    --threshold=0.75
```

## TensorFlow Lite support (experimental)

The model path may also point to a TensorFlow Lite model (`.tflite` extension). Whether that actually improves performance may depend on the platform and available hardware.

You can convert one of the available TensorFlow JS models to TensorFlow Lite using the following command:

```bash
python -m tf_bodypix \
    convert-to-tflite \
    --model-path \
    "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride16.json" \
    --optimize \
    --quantization-type=float16 \
    --output-model-file "./mobilenet-float16-stride16.tflite"
```

The above command is provided for convenience.
You may use alternative methods depending on your preference and requirements.
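
The converted `.tflite` file can then be passed to the other sub commands in place of a TensorFlow JS model. The following is a minimal sketch mirroring the CLI test added with this change; the `mobilenet_v1` architecture value and the source / output paths are assumptions you may need to adjust:

```bash
python -m tf_bodypix \
    draw-mask \
    --model-path "./mobilenet-float16-stride16.tflite" \
    --model-architecture=mobilenet_v1 \
    --output-stride=16 \
    --source "/path/to/input.jpg" \
    --output "/path/to/output-mask.jpg"
```

If the converted file name does not contain the usual hints (e.g. `mobilenet` / `resnet` or `strideN`), the model architecture and output stride cannot be guessed and have to be passed explicitly via `--model-architecture` and `--output-stride`.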

Relevant links:

* [TensorFlow Lite converter](https://www.tensorflow.org/lite/convert/)
* [TF Lite post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization)
* [TF GitHub #40183](https://github.com/tensorflow/tensorflow/issues/40183)

## Acknowledgements

* [Original TensorFlow JS Implementation of BodyPix](https://github.com/tensorflow/tfjs-models/tree/body-pix-v2.0.4/body-pix)
Expand Down
20 changes: 20 additions & 0 deletions tests/cli_test.py
@@ -2,6 +2,7 @@
from pathlib import Path

from tf_bodypix.download import BodyPixModelPaths
from tf_bodypix.model import ModelArchitectureNames
from tf_bodypix.cli import main


@@ -87,3 +88,22 @@ def test_should_list_all_default_model_urls(self, capsys):
        LOGGER.debug('output_urls: %s', output_urls)
        missing_urls = set(expected_urls) - set(output_urls)
        assert not missing_urls

    def test_should_be_able_to_convert_to_tflite_and_use_model(self, temp_dir: Path):
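        # end-to-end check: convert a TF JS model to tflite, then run draw-mask with the converted model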
        output_model_file = temp_dir / 'model.tflite'
        main([
            'convert-to-tflite',
            '--model-path=%s' % BodyPixModelPaths.MOBILENET_FLOAT_75_STRIDE_16,
            '--optimize',
            '--quantization-type=int8',
            '--output-model-file=%s' % output_model_file
        ])
        output_image_path = temp_dir / 'mask.jpg'
        main([
            'draw-mask',
            '--model-path=%s' % output_model_file,
            '--model-architecture=%s' % ModelArchitectureNames.MOBILENET_V1,
            '--output-stride=16',
            '--source=%s' % EXAMPLE_IMAGE_URL,
            '--output=%s' % output_image_path
        ])
61 changes: 60 additions & 1 deletion tf_bodypix/cli.py
@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from contextlib import ExitStack
from itertools import cycle
from pathlib import Path
from time import time
from typing import Dict, List

@@ -25,8 +26,10 @@
)
from tf_bodypix.utils.s3 import iter_s3_file_urls
from tf_bodypix.download import download_model
from tf_bodypix.tflite import get_tflite_converter_for_model_path
from tf_bodypix.model import (
    load_model,
    VALID_MODEL_ARCHITECTURE_NAMES,
    PART_CHANNELS,
    DEFAULT_RESIZE_METHOD,
    BodyPixModelWrapper,
@@ -77,6 +80,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        default=DEFAULT_MODEL_PATH,
        help="The path or URL to the bodypix model."
    )
    parser.add_argument(
        "--model-architecture",
        choices=VALID_MODEL_ARCHITECTURE_NAMES,
        help=(
            "The model architecture."
            " It will be guessed from the model path if not specified."
        )
    )
    parser.add_argument(
        "--output-stride",
        type=int,
@@ -219,7 +230,8 @@ def load_bodypix_model(args: argparse.Namespace) -> BodyPixModelWrapper:
    return load_model(
        local_model_path,
        internal_resolution=args.internal_resolution,
        output_stride=args.output_stride
        output_stride=args.output_stride,
        architecture_name=args.model_architecture
    )


@@ -266,6 +278,52 @@ def run(self, args: argparse.Namespace):  # pylint: disable=unused-argument
        print('\n'.join(bodypix_model_json_files))


class ConvertToTFLiteSubCommand(SubCommand):
    def __init__(self):
        super().__init__("convert-to-tflite", "Converts the model to a tflite model")

    def add_arguments(self, parser: argparse.ArgumentParser):
        add_common_arguments(parser)
        parser.add_argument(
            "--model-path",
            default=DEFAULT_MODEL_PATH,
            help="The path or URL to the bodypix model."
        )
        parser.add_argument(
            "--output-model-file",
            required=True,
            help="The path to the output file (tflite model)."
        )
        parser.add_argument(
            "--optimize",
            action='store_true',
            help="Enable optimization (quantization)."
        )
        parser.add_argument(
            "--quantization-type",
            choices=['float16', 'float32', 'int8'],
            help="The quantization type to use."
        )

    def run(self, args: argparse.Namespace):  # pylint: disable=unused-argument
        LOGGER.info('converting model: %s', args.model_path)
        converter = get_tflite_converter_for_model_path(download_model(
            args.model_path
        ))
        tflite_model = converter.convert()
        if args.optimize:
            LOGGER.info('enabled optimization')
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        if args.quantization_type:
            LOGGER.info('quantization type: %s', args.quantization_type)
            quantization_type = getattr(tf, args.quantization_type)
            converter.target_spec.supported_types = [quantization_type]
            converter.inference_input_type = quantization_type
            converter.inference_output_type = quantization_type
        LOGGER.info('saving tflite model to: %s', args.output_model_file)
        Path(args.output_model_file).write_bytes(tflite_model)


class AbstractWebcamFilterApp(ABC):
    def __init__(self, args: argparse.Namespace):
        self.args = args
@@ -497,6 +555,7 @@ def get_app(self, args: argparse.Namespace) -> AbstractWebcamFilterApp:

SUB_COMMANDS: List[SubCommand] = [
    ListModelsSubCommand(),
    ConvertToTFLiteSubCommand(),
    DrawMaskSubCommand(),
    BlurBackgroundSubCommand(),
    ReplaceBackgroundSubCommand()
68 changes: 65 additions & 3 deletions tf_bodypix/model.py
@@ -341,6 +341,66 @@ def get_structured_output_names(structured_outputs: List[tf.Tensor]) -> List[str]:
    ]


def to_number_of_dimensions(data: np.ndarray, dimension_count: int) -> np.ndarray:
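    # drop leading axes or add singleton axes until `data` has exactly `dimension_count` dimensions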
    while len(data.shape) > dimension_count:
        data = data[0]
    while len(data.shape) < dimension_count:
        data = np.expand_dims(data, axis=0)
    return data


def load_tflite_model(model_path: str):
    # Load TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    LOGGER.debug('input_details: %s', input_details)
    input_names = [item['name'] for item in input_details]
    LOGGER.debug('input_names: %s', input_names)
    input_details_map = dict(zip(input_names, input_details))

    output_details = interpreter.get_output_details()
    LOGGER.debug('output_details: %s', output_details)
    output_names = [item['name'] for item in output_details]
    LOGGER.debug('output_names: %s', output_names)

    try:
        image_input = input_details_map['image']
    except KeyError:
        assert len(input_details_map) == 1
        image_input = list(input_details_map.values())[0]
    input_shape = image_input['shape']
    LOGGER.debug('input_shape: %s', input_shape)

    def predict(image_data: np.ndarray):
        nonlocal input_shape
        image_data = to_number_of_dimensions(image_data, len(input_shape))
        LOGGER.debug('tflite predict, image_data.shape=%s (%s)', image_data.shape, image_data.dtype)
        height, width, *_ = image_data.shape
        if tuple(image_data.shape) != tuple(input_shape):
            LOGGER.info('resizing input tensor: %s -> %s', tuple(input_shape), image_data.shape)
            interpreter.resize_tensor_input(image_input['index'], list(image_data.shape))
            interpreter.allocate_tensors()
            input_shape = image_data.shape
        interpreter.set_tensor(image_input['index'], image_data)
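        # some converted models also expose an 'image_size' input; set it if present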
        if 'image_size' in input_details_map:
            interpreter.set_tensor(
                input_details_map['image_size']['index'],
                np.array([height, width], dtype=np.float32)
            )

        interpreter.invoke()

        # The function `get_tensor()` returns a copy of the tensor data.
        # Use `tensor()` in order to get a pointer to the tensor.
        return {
            item['name']: interpreter.get_tensor(item['index'])
            for item in output_details
        }
    return predict


def load_using_saved_model_and_get_predict_function(model_path):
    loaded = tf.saved_model.load(model_path)
    LOGGER.debug('loaded: %s', loaded)
@@ -366,24 +426,26 @@ def load_using_tfjs_graph_converter_and_get_predict_function(
def load_model_and_get_predict_function(
        model_path: str
) -> Callable[[np.ndarray], dict]:
    if model_path.endswith('.tflite'):
        return load_tflite_model(model_path)
    try:
        return load_using_saved_model_and_get_predict_function(model_path)
    except OSError:
        return load_using_tfjs_graph_converter_and_get_predict_function(model_path)


def get_output_stride_from_model_path(model_path: str) -> int:
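    # match either a 'strideN' hint in the path or an '_<stride>_quant' style file name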
    match = re.search(r'stride(\d+)', model_path)
    match = re.search(r'stride(\d+)|_(\d+)_quant', model_path)
    if not match:
        raise ValueError('cannot extract output stride from model path: %r' % model_path)
    return int(match.group(1))
    return int(match.group(1) or match.group(2))


def get_architecture_from_model_path(model_path: str) -> str:
    model_path_lower = model_path.lower()
    if 'mobilenet' in model_path_lower:
        return ModelArchitectureNames.MOBILENET_V1
    if 'resnet50' in model_path_lower:
    if 'resnet' in model_path_lower:
        return ModelArchitectureNames.RESNET_50
    raise ValueError('cannot extract model architecture from model path: %r' % model_path)

25 changes: 25 additions & 0 deletions tf_bodypix/tflite.py
@@ -0,0 +1,25 @@
import logging

import tensorflow as tf

try:
    import tfjs_graph_converter
except ImportError:
    tfjs_graph_converter = None


LOGGER = logging.getLogger(__name__)


def get_tflite_converter_for_tfjs_model_path(model_path: str) -> tf.lite.TFLiteConverter:
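    # load the TensorFlow JS graph model and wrap it as a concrete function accepted by the TFLite converter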
    if tfjs_graph_converter is None:
        raise ImportError('tfjs_graph_converter required')
    graph = tfjs_graph_converter.api.load_graph_model(model_path)
    tf_fn = tfjs_graph_converter.api.graph_to_function_v2(graph)
    return tf.lite.TFLiteConverter.from_concrete_functions([tf_fn])


def get_tflite_converter_for_model_path(model_path: str) -> tf.lite.TFLiteConverter:
    LOGGER.debug('converting model_path: %s', model_path)
    # if model_path.endswith('.json'):
    return get_tflite_converter_for_tfjs_model_path(model_path)
