From 09ee4d198defabafd1d2db2a5d19c4b52743d0e0 Mon Sep 17 00:00:00 2001 From: Daniel Ecer Date: Tue, 22 Feb 2022 21:08:15 +0000 Subject: [PATCH] added tflite runtime support (#167) * Release the tflite inference from tensorflow * Update README * Remove the bodypix_tflite from develop branch * Add tflite inference to the tflite_inference branch * added initial build_tflite workflow job * added --use-feature=in-tree-build * don't install tflite by default * moved build_tflite up * added tflite extra * using dev-install-tflite * make dev-install-tflite install build and dev depenencies * run pytest for tflite * using tflite extra when installing tflite * added make dev-pytest-tflite * linting: addressed markdown linting * made tensorflow import optional * added test_should_be_able_to_use_existing_tflite_model * import tflite_runtime.interpreter * extracted load_image * load image using pillow * adapted pad_and_resize_to using _pad_image_like_tensorflow * implemented _resize_image_to_using_pillow * added make dev-watch-tflite * fallback to np expand_dims without tf * extracted _get_mobilenet_preprocessed_image with np fallback * reuse resize_image_to for scale_and_crop_to_input_tensor_shape * automatically reduce dimension if needed * added support for single channel in _resize_image_to_using_pillow * extracted get_sigmoid and implemented np version * reuse resize_image_to * fixed failing test * removed trailing space from requirements.txt * cli: automatically select tflite model if full tf is not available * don't fail with missing tf when adding alpha mask * added tflite support to draw mask cli * added support for remote tflite models; defined model tflite paths * use model path constants for cli * added TensorFlow Lite Runtime support section to readme * ignore tflite models * removed obsolete bodypix_tflite diectory * fixed draw mask * replaced pillow resize with numpy handling floats * retain original dtype when padding * use float32 for imagenet preprocessing * debug logging of input image * added list-tflite-models sub command * fixed resnet tflite support * added more tflite models Co-authored-by: MrRiahi --- .github/workflows/ci.yml | 24 ++ .gitignore | 1 + Makefile | 89 ++++++- README.md | 19 +- requirements.tflite.txt | 1 + requirements.txt | 2 +- setup.py | 10 +- tests/cli_test.py | 24 +- .../bodypix_js_utils/decode_part_map.py | 14 +- tf_bodypix/bodypix_js_utils/util.py | 125 ++++++++-- tf_bodypix/cli.py | 76 ++++-- tf_bodypix/download.py | 67 ++++- tf_bodypix/model.py | 82 +++++-- tf_bodypix/sink.py | 10 +- tf_bodypix/source.py | 11 +- tf_bodypix/tflite.py | 9 +- tf_bodypix/utils/image.py | 231 +++++++++++++++++- 17 files changed, 698 insertions(+), 97 deletions(-) create mode 100644 requirements.tflite.txt diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f739e88..3d08ef6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,6 +21,30 @@ jobs: env: TEST_PYPI_PASSWORD: ${{ secrets.test_pypi_password }} + build_tflite: + needs: [] + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + include: + - python-version: 3.8 + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + make venv-create SYSTEM_PYTHON=python + make dev-install-tflite + - name: Test with pytest + run: | + make dev-pytest-tflite + build: needs: ["check_secrets"] runs-on: ${{ matrix.os }} diff --git a/.gitignore b/.gitignore index fbcb3d5..830f34e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ build *.egg-info *.pyc +*.tflite diff --git a/Makefile b/Makefile index 7fa4cd9..dbdf8d3 100644 --- a/Makefile +++ b/Makefile @@ -45,13 +45,25 @@ venv-create: $(SYSTEM_PYTHON) -m venv $(VENV) -dev-install: +dev-install-build-dependencies: $(PIP) install -r requirements.build.txt + + +dev-install: dev-install-build-dependencies $(PIP) install \ -r requirements.dev.txt \ -r requirements.txt +dev-install-tflite: dev-install-build-dependencies + $(PIP) install -r requirements.dev.txt + $(PIP) install --use-feature=in-tree-build .[tflite,image] + + +dev-run-pip: + $(PIP) $(ARGS) + + dev-venv: venv-create dev-install @@ -75,10 +87,20 @@ dev-pytest: $(PYTHON) -m pytest -p no:cacheprovider $(ARGS) +dev-pytest-tflite: + $(MAKE) dev-pytest \ + ARGS='tests/cli_test.py -k test_should_be_able_to_use_existing_tflite_model' + + dev-watch: $(PYTHON) -m pytest_watch -- -p no:cacheprovider -p no:warnings $(ARGS) +dev-watch-tflite: + $(MAKE) dev-watch \ + ARGS='tests/cli_test.py -k test_should_be_able_to_use_existing_tflite_model' + + dev-test: dev-lint dev-pytest @@ -114,6 +136,11 @@ list-models: list-models +list-tflite-models: + $(PYTHON) -m tf_bodypix \ + list-tflite-models + + convert-example-draw-mask: $(PYTHON) -m tf_bodypix \ draw-mask \ @@ -240,6 +267,66 @@ webcam-v4l2-replace-background: $(ARGS) +convert-tfjs-models-to-tflite: + mkdir -p "./data/tflite-models" + $(PYTHON) -m tf_bodypix \ + convert-to-tflite \ + --model-path \ + "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/050/model-stride8.json" \ + --optimize \ + --quantization-type=float16 \ + --output-model-file "./data/tflite-models/mobilenet-float-multiplier-050-stride8-float16.tflite" + $(PYTHON) -m tf_bodypix \ + convert-to-tflite \ + --model-path \ + "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/050/model-stride16.json" \ + --optimize \ + --quantization-type=float16 \ + --output-model-file "./data/tflite-models/mobilenet-float-multiplier-050-stride16-float16.tflite" + $(PYTHON) -m tf_bodypix \ + convert-to-tflite \ + --model-path \ + "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride8.json" \ + --optimize \ + --quantization-type=float16 \ + --output-model-file "./data/tflite-models/mobilenet-float-multiplier-075-stride8-float16.tflite" + $(PYTHON) -m tf_bodypix \ + convert-to-tflite \ + --model-path \ + "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride16.json" \ + --optimize \ + --quantization-type=float16 \ + --output-model-file "./data/tflite-models/mobilenet-float-multiplier-075-stride16-float16.tflite" + $(PYTHON) -m tf_bodypix \ + convert-to-tflite \ + --model-path \ + "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/100/model-stride8.json" \ + --optimize \ + --quantization-type=float16 \ + --output-model-file "./data/tflite-models/mobilenet-float-multiplier-100-stride8-float16.tflite" + $(PYTHON) -m tf_bodypix \ + convert-to-tflite \ + --model-path \ + "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/100/model-stride16.json" \ + --optimize \ + --quantization-type=float16 \ + --output-model-file "./data/tflite-models/mobilenet-float-multiplier-100-stride16-float16.tflite" + $(PYTHON) -m tf_bodypix \ + convert-to-tflite \ + --model-path \ + "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/resnet50/float/model-stride16.json" \ + --optimize \ + --quantization-type=float16 \ + --output-model-file "./data/tflite-models/resnet50-float-stride16-float16.tflite" + $(PYTHON) -m tf_bodypix \ + convert-to-tflite \ + --model-path \ + "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/resnet50/float/model-stride32.json" \ + --optimize \ + --quantization-type=float16 \ + --output-model-file "./data/tflite-models/resnet50-float-stride32-float16.tflite" + + docker-build: docker build . -t $(IMAGE_NAME):$(IMAGE_TAG) diff --git a/README.md b/README.md index f922a56..f1ea7f5 100644 --- a/README.md +++ b/README.md @@ -31,10 +31,11 @@ when using this project as a library: | ---------- | ----------- | tf | [TensorFlow](https://pypi.org/project/tensorflow/) (required). But you may use your own build. | tfjs | TensorFlow JS Model support, using [tfjs-graph-converter](https://pypi.org/project/tfjs-graph-converter/) +| tflite | [tflite-runtime](https://pypi.org/project/tflite-runtime/) | image | Image loading via [Pillow](https://pypi.org/project/Pillow/), required by the CLI. | video | Video support via [OpenCV](https://pypi.org/project/opencv-python/) | webcam | Webcam support via [OpenCV](https://pypi.org/project/opencv-python/) and [pyfakewebcam](https://pypi.org/project/pyfakewebcam/) -| all | All of the libraries +| all | All of the libraries (except `tflite-runtime`) ## Python API @@ -117,6 +118,12 @@ Those URLs can be passed as the `--model-path` arguments below, or to the `downl The CLI will download and cache the model from the provided path. If no `--model-path` is provided, it will use a default model (mobilenet). +To list TensorFlow Lite models instead: + +```bash +python -m tf_bodypix list-tflite-models +``` + ### Inputs and Outputs Most commands will work with inputs (source) and outputs. @@ -317,7 +324,7 @@ python -m tf_bodypix \ Background: [Brown Landscape Under Grey Sky](https://www.pexels.com/photo/brown-landscape-under-grey-sky-3244513/) -## TensorFlow Lite support (experimental) +## TensorFlow Lite Model support (experimental) The model path may also point to a TensorFlow Lite model (`.tflite` extension). Whether that actually improves performance may depend on the platform and available hardware. @@ -330,7 +337,7 @@ python -m tf_bodypix \ "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride16.json" \ --optimize \ --quantization-type=float16 \ - --output-model-file "./mobilenet-float16-stride16.tflite" + --output-model-file "./mobilenet-float-multiplier-075-stride16-float16.tflite" ``` The above command is provided for convenience. @@ -342,6 +349,12 @@ Relevant links: * [TF Lite post_training_quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) * [TF GitHub #40183](https://github.com/tensorflow/tensorflow/issues/40183). +## TensorFlow Lite Runtime support (experimental) + +This project can also be used with [tflite-runtime](https://pypi.org/project/tflite-runtime/) instead of full TensorFlow (e.g. by using the `tflite` extra). +However, [TensorFlow Lite converter](https://www.tensorflow.org/lite/convert/) would require full TensorFlow. +In order to avoid it, one needs to use a TensorFlow Lite model (see previous section). + ## Docker Usage You could also use the Docker image if you prefer. diff --git a/requirements.tflite.txt b/requirements.tflite.txt new file mode 100644 index 0000000..faad658 --- /dev/null +++ b/requirements.tflite.txt @@ -0,0 +1 @@ +tflite-runtime==2.7.0 diff --git a/requirements.txt b/requirements.txt index 596486a..fefbf9f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -opencv-python==4.5.5.62 +opencv-python==4.5.5.62 Pillow==8.4.0; python_version < "3.7" Pillow==9.0.1; python_version >= "3.7" pyfakewebcam==0.1.0 diff --git a/setup.py b/setup.py index d3fa798..d227f67 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,10 @@ REQUIRED_PACKAGES = f.readlines() +with open('requirements.tflite.txt', 'r', encoding='utf-8') as f: + TFLITE_REQUIRED_PACKAGES = f.readlines() + + with open('README.md', 'r', encoding='utf-8') as f: LONG_DESCRIPTION = '\n'.join([ line.rstrip() @@ -30,6 +34,10 @@ def local_scheme(version): get_requirements_with_groups(REQUIRED_PACKAGES) ) +ALL_EXTRAS = { + **EXTRAS, + 'tflite': TFLITE_REQUIRED_PACKAGES +} packages = find_packages(exclude=["tests", "tests.*"]) @@ -42,7 +50,7 @@ def local_scheme(version): author="Daniel Ecer", url="https://github.com/de-code/python-tf-bodypix", install_requires=DEFAULT_REQUIRED_PACKAGES, - extras_require=EXTRAS, + extras_require=ALL_EXTRAS, packages=packages, include_package_data=True, description='Python implemention of the TensorFlow BodyPix model.', diff --git a/tests/cli_test.py b/tests/cli_test.py index a6f3513..9c7319f 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -1,9 +1,9 @@ import logging from pathlib import Path -from tf_bodypix.download import BodyPixModelPaths +from tf_bodypix.download import ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS, BodyPixModelPaths from tf_bodypix.model import ModelArchitectureNames -from tf_bodypix.cli import main +from tf_bodypix.cli import DEFAULT_MODEL_TFLITE_PATH, main LOGGER = logging.getLogger(__name__) @@ -97,6 +97,15 @@ def test_should_list_all_default_model_urls(self, capsys): missing_urls = set(expected_urls) - set(output_urls) assert not missing_urls + def test_should_list_all_default_tflite_models(self, capsys): + expected_urls = ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS + main(['list-tflite-models']) + captured = capsys.readouterr() + output_urls = captured.out.splitlines() + LOGGER.debug('output_urls: %s', output_urls) + missing_urls = set(expected_urls) - set(output_urls) + assert not missing_urls + def test_should_be_able_to_convert_to_tflite_and_use_model(self, temp_dir: Path): output_model_file = temp_dir / 'model.tflite' main([ @@ -115,3 +124,14 @@ def test_should_be_able_to_convert_to_tflite_and_use_model(self, temp_dir: Path) '--source=%s' % EXAMPLE_IMAGE_URL, '--output=%s' % output_image_path ]) + + def test_should_be_able_to_use_existing_tflite_model(self, temp_dir: Path): + output_image_path = temp_dir / 'mask.jpg' + main([ + 'draw-mask', + '--model-path=%s' % DEFAULT_MODEL_TFLITE_PATH, + '--model-architecture=%s' % ModelArchitectureNames.MOBILENET_V1, + '--output-stride=16', + '--source=%s' % EXAMPLE_IMAGE_URL, + '--output=%s' % output_image_path + ]) diff --git a/tf_bodypix/bodypix_js_utils/decode_part_map.py b/tf_bodypix/bodypix_js_utils/decode_part_map.py index 845c4b6..0642e11 100644 --- a/tf_bodypix/bodypix_js_utils/decode_part_map.py +++ b/tf_bodypix/bodypix_js_utils/decode_part_map.py @@ -1,16 +1,26 @@ # based on: # https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.4/body-pix/src/decode_part_map.ts -import tensorflow as tf +try: + import tensorflow as tf +except ImportError: + tf = None import numpy as np +DEFAULT_DTYPE = ( + tf.int32 if tf is not None else np.int32 +) + + def to_mask_tensor( segment_scores: np.ndarray, threshold: float, - dtype: type = tf.int32 + dtype: type = DEFAULT_DTYPE ) -> np.ndarray: + if tf is None: + return (segment_scores > threshold).astype(dtype) return tf.cast( tf.greater(segment_scores, threshold), dtype diff --git a/tf_bodypix/bodypix_js_utils/util.py b/tf_bodypix/bodypix_js_utils/util.py index 00d06af..12167d5 100644 --- a/tf_bodypix/bodypix_js_utils/util.py +++ b/tf_bodypix/bodypix_js_utils/util.py @@ -4,11 +4,22 @@ import logging import math from collections import namedtuple -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union + +try: + import tensorflow as tf +except ImportError: + tf = None -import tensorflow as tf import numpy as np +from tf_bodypix.utils.image import ( + ResizeMethod, + crop_and_resize_batch, + resize_image_to, + ImageSize +) + from .types import Keypoint, Pose, Vector2D @@ -50,6 +61,41 @@ def get_bodypix_input_resolution_height_and_width( ) +def _pad_image_like_tensorflow( + image: np.ndarray, + padding: Padding +) -> np.ndarray: + """ + This is my padding function to replace with tf.image.pad_to_bounding_box + :param image: + :param padding: + :return: + """ + + padded = np.copy(image) + dims = padded.shape + dtype = image.dtype + + if padding.top != 0: + top_zero_row = np.zeros(shape=(padding.top, dims[1], dims[2]), dtype=dtype) + padded = np.vstack([top_zero_row, padded]) + + if padding.bottom != 0: + bottom_zero_row = np.zeros(shape=(padding.top, dims[1], dims[2]), dtype=dtype) + padded = np.vstack([padded, bottom_zero_row]) + + dims = padded.shape + if padding.left != 0: + left_zero_column = np.zeros(shape=(dims[0], padding.left, dims[2]), dtype=dtype) + padded = np.hstack([left_zero_column, padded]) + + if padding.right != 0: + right_zero_column = np.zeros(shape=(dims[0], padding.right, dims[2]), dtype=dtype) + padded = np.hstack([padded, right_zero_column]) + + return padded + + # see padAndResizeTo def pad_and_resize_to( image: np.ndarray, @@ -75,14 +121,28 @@ def pad_and_resize_to( right=0 ) - padded = tf.image.pad_to_bounding_box( - image, - offset_height=padding.top, - offset_width=padding.left, - target_height=padding.top + input_height + padding.bottom, - target_width=padding.left + input_width + padding.right - ) - resized = tf.image.resize([padded], [target_height, target_width])[0] + if tf is not None: + padded = tf.image.pad_to_bounding_box( + image, + offset_height=padding.top, + offset_width=padding.left, + target_height=padding.top + input_height + padding.bottom, + target_width=padding.left + input_width + padding.right + ) + resized = tf.image.resize([padded], [target_height, target_width])[0] + else: + padded = _pad_image_like_tensorflow(image, padding) + LOGGER.debug( + 'padded: %r (%r) -> %r (%r)', + image.shape, image.dtype, padded.shape, padded.dtype + ) + resized = resize_image_to( + padded, ImageSize(width=target_width, height=target_height) + ) + LOGGER.debug( + 'resized: %r (%r) -> %r (%r)', + padded.shape, padded.dtype, resized.shape, resized.dtype + ) return resized, padding @@ -90,8 +150,10 @@ def get_images_batch(image: np.ndarray) -> np.ndarray: if len(image.shape) == 4: return image if len(image.shape) == 3: - return image[tf.newaxis, ...] - raise ValueError('invalid dimension, shape=%s' % image.shape) + if tf is not None: + return image[tf.newaxis, ...] + return np.expand_dims(image, axis=0) + raise ValueError('invalid dimension, shape=%s' % str(image.shape)) # reverse of pad_and_resize_to @@ -100,8 +162,10 @@ def remove_padding_and_resize_back( original_height: int, original_width: int, padding: Padding, - resize_method: str = tf.image.ResizeMethod.BILINEAR + resize_method: Optional[str] = None ) -> np.ndarray: + if not resize_method: + resize_method = ResizeMethod.BILINEAR boxes = [[ padding.top / (original_height + padding.top + padding.bottom - 1.0), padding.left / (original_width + padding.left + padding.right - 1.0), @@ -114,7 +178,7 @@ def remove_padding_and_resize_back( / (original_width + padding.left + padding.right - 1.0) ) ]] - return tf.image.crop_and_resize( + return crop_and_resize_batch( get_images_batch(resized_and_padded), boxes=boxes, box_indices=[0], @@ -128,12 +192,14 @@ def remove_padding_and_resize_back_simple( original_height: int, original_width: int, padding: Padding, - resize_method: str = tf.image.ResizeMethod.BILINEAR + resize_method: Optional[str] = None ) -> np.ndarray: padded_height = padding.top + original_height + padding.bottom padded_width = padding.left + original_width + padding.right - padded = tf.image.resize( - resized_and_padded, [padded_height, padded_width], method=resize_method + padded = resize_image_to( + resized_and_padded, + ImageSize(height=padded_height, width=padded_width), + resize_method=resize_method ) cropped = tf.image.crop_to_bounding_box( padded, @@ -145,6 +211,20 @@ def remove_padding_and_resize_back_simple( return cropped[0] +def _get_sigmoid_using_tf(x: np.ndarray): + return tf.math.sigmoid(x) + + +def _get_sigmoid_using_numpy(x: np.ndarray): + return 1/(1 + np.exp(-x)) + + +def get_sigmoid(x: np.ndarray): + if tf is not None: + return _get_sigmoid_using_tf(x) + return _get_sigmoid_using_numpy(x) + + # see scaleAndCropToInputTensorShape def scale_and_crop_to_input_tensor_shape( image: np.ndarray, @@ -154,13 +234,16 @@ def scale_and_crop_to_input_tensor_shape( resized_width: int, padding: Padding, apply_sigmoid_activation: bool = False, - resize_method: str = tf.image.ResizeMethod.BILINEAR + resize_method: Optional[str] = None ) -> np.ndarray: - resized_and_padded = tf.image.resize( - image, [resized_height, resized_width], method=resize_method + resized_and_padded = resize_image_to( + image, + ImageSize(height=resized_height, width=resized_width), + resize_method=resize_method ) if apply_sigmoid_activation: - resized_and_padded = tf.math.sigmoid(resized_and_padded) + resized_and_padded = get_sigmoid(resized_and_padded) + LOGGER.debug('after sigmoid: %r', resized_and_padded.shape) return remove_padding_and_resize_back( resized_and_padded, input_height, input_width, diff --git a/tf_bodypix/cli.py b/tf_bodypix/cli.py index 0c5c494..1431ce1 100644 --- a/tf_bodypix/cli.py +++ b/tf_bodypix/cli.py @@ -7,14 +7,17 @@ from itertools import cycle from pathlib import Path from time import time, sleep -from typing import ContextManager, Dict, List +from typing import ContextManager, Dict, List, Optional, Sequence os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3" # pylint: disable=wrong-import-position # flake8: noqa: E402 -import tensorflow as tf +try: + import tensorflow as tf +except ImportError: + tf = None import numpy as np from tf_bodypix.utils.timer import LoggingTimer @@ -25,13 +28,17 @@ box_blur_image ) from tf_bodypix.utils.s3 import iter_s3_file_urls -from tf_bodypix.download import download_model +from tf_bodypix.download import ( + ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS, + BodyPixModelPaths, + TensorFlowLiteBodyPixModelPaths, + download_model +) from tf_bodypix.tflite import get_tflite_converter_for_model_path from tf_bodypix.model import ( load_model, VALID_MODEL_ARCHITECTURE_NAMES, PART_CHANNELS, - DEFAULT_RESIZE_METHOD, BodyPixModelWrapper, BodyPixResultWrapper ) @@ -52,9 +59,15 @@ def draw_poses(*_, **__): # type: ignore LOGGER = logging.getLogger(__name__) +DEFAULT_MODEL_TF_PATH = BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_16 + + +DEFAULT_MODEL_TFLITE_PATH = TensorFlowLiteBodyPixModelPaths.MOBILENET_FLOAT_75_STRIDE_16_FLOAT16 + + DEFAULT_MODEL_PATH = ( - r'https://storage.googleapis.com/tfjs-models/savedmodel/' - r'bodypix/mobilenet/float/050/model-stride16.json' + DEFAULT_MODEL_TF_PATH if tf is not None + else DEFAULT_MODEL_TFLITE_PATH ) @@ -251,7 +264,7 @@ def get_mask( masks: List[np.ndarray], timer: LoggingTimer, args: argparse.Namespace, - resize_method: str = DEFAULT_RESIZE_METHOD + resize_method: Optional[str] = None ) -> np.ndarray: mask = bodypix_result.get_mask(args.threshold, dtype=np.float32, resize_method=resize_method) if args.mask_blur: @@ -280,13 +293,29 @@ def add_arguments(self, parser: argparse.ArgumentParser): help="The base URL for the storage containing the models" ) - def run(self, args: argparse.Namespace): # pylint: disable=unused-argument - bodypix_model_json_files = [ + def get_model_paths(self, storage_url: str) -> Sequence[str]: + return [ file_url - for file_url in iter_s3_file_urls(args.storage_url) + for file_url in iter_s3_file_urls(storage_url) if re.match(r'.*/bodypix/.*/model.*\.json', file_url) ] - print('\n'.join(bodypix_model_json_files)) + + def run(self, args: argparse.Namespace): # pylint: disable=unused-argument + print('\n'.join(self.get_model_paths(storage_url=args.storage_url))) + + +class ListTensorFlowLiteModelsSubCommand(SubCommand): + def __init__(self): + super().__init__("list-tflite-models", "Lists available tflite bodypix models") + + def add_arguments(self, parser: argparse.ArgumentParser): + add_common_arguments(parser) + + def get_model_paths(self) -> Sequence[str]: + return ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS + + def run(self, args: argparse.Namespace): # pylint: disable=unused-argument + print('\n'.join(self.get_model_paths())) class ConvertToTFLiteSubCommand(SubCommand): @@ -297,7 +326,7 @@ def add_arguments(self, parser: argparse.ArgumentParser): add_common_arguments(parser) parser.add_argument( "--model-path", - default=DEFAULT_MODEL_PATH, + default=DEFAULT_MODEL_TF_PATH, help="The path or URL to the bodypix model." ) parser.add_argument( @@ -386,6 +415,7 @@ def next_frame(self): image_array = next(self.image_iterator) except StopIteration: return False + LOGGER.debug('image_array: %r (%r)', image_array.shape, image_array.dtype) self.timer.on_step_start('model') output_image = self.get_output_image(image_array) self.timer.on_step_start('out') @@ -424,7 +454,7 @@ def run(self, args: argparse.Namespace): class DrawMaskApp(AbstractWebcamFilterApp): def get_output_image(self, image_array: np.ndarray) -> np.ndarray: - resize_method = DEFAULT_RESIZE_METHOD + resize_method = None result = self.get_bodypix_result(image_array) self.timer.on_step_start('get_mask') mask = self.get_mask(result, resize_method=resize_method) @@ -439,14 +469,27 @@ def get_output_image(self, image_array: np.ndarray) -> np.ndarray: mask, part_names=self.args.parts, resize_method=resize_method ) * 255 else: - mask_image = mask * 255 + if LOGGER.isEnabledFor(logging.DEBUG): + LOGGER.debug( + 'mask: %r (%r, %r) (%s)', + mask.shape, np.min(mask), np.max(mask), mask.dtype + ) + mask_image = mask * 255.0 if self.args.mask_alpha is not None: self.timer.on_step_start('overlay') LOGGER.debug('mask.shape: %s (%s)', mask.shape, mask.dtype) alpha = self.args.mask_alpha try: - if mask_image.dtype == tf.int32: - mask_image = tf.cast(mask, tf.float32) + if tf is not None: + if mask_image.dtype == tf.int32: + mask_image = tf.cast(mask_image, tf.float32) + else: + image_array = np.asarray(image_array).astype(np.float32) + if LOGGER.isEnabledFor(logging.DEBUG): + LOGGER.debug( + 'mask_image: %r (%r, %r) (%s)', + mask_image.shape, np.min(mask_image), np.max(mask_image), mask_image.dtype + ) except TypeError: pass output = np.clip( @@ -599,6 +642,7 @@ def get_app(self, args: argparse.Namespace) -> AbstractWebcamFilterApp: SUB_COMMANDS: List[SubCommand] = [ ListModelsSubCommand(), + ListTensorFlowLiteModelsSubCommand(), ConvertToTFLiteSubCommand(), DrawMaskSubCommand(), DrawPoseSubCommand(), diff --git a/tf_bodypix/download.py b/tf_bodypix/download.py index 388fb36..9d4fad2 100644 --- a/tf_bodypix/download.py +++ b/tf_bodypix/download.py @@ -2,6 +2,7 @@ import json import os import re +from urllib.parse import urlparse from hashlib import md5 @@ -50,6 +51,54 @@ class BodyPixModelPaths: ) +_TFLITE_DOWNLOAD_URL_PREFIX = r'https://www.dropbox.com/sh/d6tqb3gfrugs7ne/' + + +class TensorFlowLiteBodyPixModelPaths: + MOBILENET_FLOAT_50_STRIDE_8_FLOAT16 = ( + _TFLITE_DOWNLOAD_URL_PREFIX + + 'AADUtMGoDO6vzOfRLP0Dg7ira/mobilenet-float-multiplier-050-stride8-float16.tflite?dl=1' + ) + MOBILENET_FLOAT_50_STRIDE_16_FLOAT16 = ( + _TFLITE_DOWNLOAD_URL_PREFIX + + 'AAAhnozSEO07xzgL495dW3h8a/mobilenet-float-multiplier-050-stride16-float16.tflite?dl=1' + ) + + MOBILENET_FLOAT_75_STRIDE_8_FLOAT16 = ( + _TFLITE_DOWNLOAD_URL_PREFIX + + 'AADBYGO2xj2v9Few4qBq62wZa/mobilenet-float-multiplier-075-stride8-float16.tflite?dl=1' + ) + MOBILENET_FLOAT_75_STRIDE_16_FLOAT16 = ( + _TFLITE_DOWNLOAD_URL_PREFIX + + 'AAAGYNAOTTWBl9ZDhALv7rEOa/mobilenet-float-multiplier-075-stride16-float16.tflite?dl=1' + ) + + MOBILENET_FLOAT_100_STRIDE_8_FLOAT16 = ( + _TFLITE_DOWNLOAD_URL_PREFIX + + 'AADr8zOtPZz2cWlQEvKgIbdTa/mobilenet-float-multiplier-100-stride8-float16.tflite?dl=1' + ) + MOBILENET_FLOAT_100_STRIDE_16_FLOAT16 = ( + _TFLITE_DOWNLOAD_URL_PREFIX + + 'AAAo-hkaCqx2pN99cCvDPcosa/mobilenet-float-multiplier-100-stride16-float16.tflite?dl=1' + ) + + RESNET50_FLOAT_STRIDE_16 = ( + _TFLITE_DOWNLOAD_URL_PREFIX + + 'AADvvgLyPXMPOeRyRY9WQ9Mva/resnet50-float-stride16-float16.tflite?dl=1' + ) + MOBILENET_RESNET50_FLOAT_STRIDE_32 = ( + _TFLITE_DOWNLOAD_URL_PREFIX + + 'AADGlTuMQQeL8vm6BuOwObKTa/resnet50-float-stride32-float16.tflite?dl=1' + ) + + +ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS = [ + value + for key, value in TensorFlowLiteBodyPixModelPaths.__dict__.items() + if key.isupper() and isinstance(value, str) +] + + class DownloadError(RuntimeError): pass @@ -57,13 +106,11 @@ class DownloadError(RuntimeError): def download_model(model_path: str) -> str: if os.path.exists(model_path): return model_path - if not model_path.endswith('.json'): - raise ValueError('remote model path needs to end with .json') - model_base_path = os.path.dirname(model_path) + parsed_model_path = urlparse(model_path) local_name_part = re.sub( r'[^a-zA-Z0-9]+', r'-', - os.path.splitext(model_path)[0] + os.path.splitext(parsed_model_path.path)[0] ) local_name = ( md5(model_path.encode('utf-8')).hexdigest() + '-' @@ -73,6 +120,18 @@ def download_model(model_path: str) -> str: cache_dir = get_default_cache_dir( cache_subdir=os.path.join('tf-bodypix', local_name) ) + if parsed_model_path.path.endswith('.tflite'): + return download_file_to( + source_url=model_path, + local_path=os.path.join( + cache_dir, + os.path.basename(parsed_model_path.path) + ), + skip_if_exists=True + ) + if not parsed_model_path.path.endswith('.json'): + raise ValueError('remote model path needs to end with .json') + model_base_path = os.path.dirname(model_path) local_model_json_path = download_file_to( source_url=model_path, local_path=os.path.join(cache_dir, 'model.json'), diff --git a/tf_bodypix/model.py b/tf_bodypix/model.py index d456ed1..5acf008 100644 --- a/tf_bodypix/model.py +++ b/tf_bodypix/model.py @@ -5,7 +5,13 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import tensorflow as tf + +try: + import tensorflow as tf + tflite = tf.lite +except ImportError: + tf = None + import tflite_runtime.interpreter as tflite # type: ignore try: import tfjs_graph_converter @@ -47,9 +53,6 @@ } -DEFAULT_RESIZE_METHOD = tf.image.ResizeMethod.BILINEAR - - ImageSize = namedtuple('ImageSize', ('height', 'width')) @@ -96,6 +99,27 @@ def __call__(self, image: np.ndarray) -> dict: pass +def _get_imagenet_preprocessed_image_using_numpy( + image_array: np.ndarray +) -> np.ndarray: + result = np.divide(image_array, 127.5, dtype=np.float32) + result = np.subtract(result, 1, out=result) + LOGGER.debug( + 'imagenet preprocessed: %r (%r) -> %r (%r)', + image_array.shape, image_array.dtype, + result.shape, result.dtype + ) + return result + + +def _get_mobilenet_preprocessed_image( + image_array: np.ndarray +) -> np.ndarray: + if tf is not None: + return tf.keras.applications.mobilenet.preprocess_input(image_array) + return _get_imagenet_preprocessed_image_using_numpy(image_array) + + class MobileNetBodyPixPredictWrapper(BodyPixArchitecture): def __init__(self, predict_fn: Callable[[np.ndarray], dict]): super().__init__(ModelArchitectureNames.MOBILENET_V1) @@ -103,9 +127,12 @@ def __init__(self, predict_fn: Callable[[np.ndarray], dict]): def __call__(self, image: np.ndarray) -> dict: if len(image.shape) == 3: - image = image[tf.newaxis, ...] + if tf is not None: + image = image[tf.newaxis, ...] + else: + image = np.expand_dims(image, axis=0) return self.predict_fn( - tf.keras.applications.mobilenet.preprocess_input(image) + _get_mobilenet_preprocessed_image(image) ) @@ -118,12 +145,16 @@ def __call__(self, image: np.ndarray) -> dict: image = np.add(image, np.array(IMAGE_NET_MEAN)) # Note: tf.keras.applications.resnet50.preprocess_input is rotating the image as well? if len(image.shape) == 3: - image = image[tf.newaxis, ...] - image = tf.cast(image, tf.float32) + if tf is not None: + image = image[tf.newaxis, ...] + else: + image = np.expand_dims(image, axis=0) + if tf is not None: + image = tf.constant(tf.cast(image, tf.float32)) + else: + image = np.asarray(image).astype(np.float32) LOGGER.debug('image.shape: %s (%s)', image.shape, image.dtype) - predictions = self.predict_fn( - tf.constant(image) - ) + predictions = self.predict_fn(image) return predictions @@ -216,8 +247,9 @@ def __init__( def _get_scaled_scores( self, logits: np.ndarray, - resize_method: str = DEFAULT_RESIZE_METHOD + resize_method: Optional[str] = None ) -> np.ndarray: + LOGGER.debug('logits: %r', logits.shape) return scale_and_crop_to_input_tensor_shape( logits, self.original_size.height, @@ -240,7 +272,7 @@ def get_scaled_part_segmentation( mask: np.ndarray = None, part_names: List[str] = None, outside_mask_value: int = -1, - resize_method: str = DEFAULT_RESIZE_METHOD + resize_method: Optional[str] = None ) -> np.ndarray: scaled_part_heatmap_argmax = np.argmax( self.get_scaled_part_heatmap_scores(resize_method=resize_method), @@ -268,7 +300,7 @@ def get_scaled_part_segmentation( def get_mask( self, threshold: float, - resize_method: str = DEFAULT_RESIZE_METHOD, + resize_method: Optional[str] = None, **kwargs ) -> np.ndarray: return to_mask_tensor( @@ -281,7 +313,7 @@ def get_part_mask( self, mask: np.ndarray, part_names: List[str] = None, - resize_method: str = DEFAULT_RESIZE_METHOD + resize_method: Optional[str] = None ) -> np.ndarray: if is_all_part_names(part_names): return mask @@ -301,7 +333,7 @@ def get_colored_part_mask( mask: np.ndarray, part_colors: List[T_Color] = None, part_names: List[str] = None, - resize_method: str = DEFAULT_RESIZE_METHOD + resize_method: Optional[str] = None ) -> np.ndarray: part_segmentation = self.get_scaled_part_segmentation( mask, part_names=part_names, resize_method=resize_method @@ -358,8 +390,8 @@ def get_padded_and_resized( self, image: np.ndarray, model_input_size: ImageSize ) -> Tuple[np.ndarray, Padding]: LOGGER.debug( - 'pad_and_resize_to: image.shape=%s, model_input_size=%s', - image.shape, model_input_size + 'pad_and_resize_to: image.shape=%s (%r), model_input_size=%s', + image.shape, image.dtype, model_input_size ) return pad_and_resize_to( image, @@ -393,8 +425,14 @@ def find_required_tensor_in_map( def predict_single(self, image: np.ndarray) -> BodyPixResultWrapper: original_size = ImageSize(*image.shape[:2]) + LOGGER.debug('original_size: %r (%r)', original_size, image.dtype) model_input_size = self.get_bodypix_input_size(original_size) + LOGGER.debug('model_input_size: %r', model_input_size) model_input_image, padding = self.get_padded_and_resized(image, model_input_size) + LOGGER.debug( + 'model_input_image: %r (%r)', model_input_image.shape, model_input_image.dtype + ) + LOGGER.debug('predict_fn: %r', self.predict_fn) tensor_map = self.predict_fn(model_input_image) @@ -437,7 +475,7 @@ def predict_single(self, image: np.ndarray) -> BodyPixResultWrapper: ) -def get_structured_output_names(structured_outputs: List[tf.Tensor]) -> List[str]: +def get_structured_output_names(structured_outputs: List['tf.Tensor']) -> List[str]: return [ tensor.name.replace(':0', '') for tensor in structured_outputs @@ -454,7 +492,7 @@ def to_number_of_dimensions(data: np.ndarray, dimension_count: int) -> np.ndarra def load_tflite_model(model_path: str): # Load TFLite model and allocate tensors. - interpreter = tf.lite.Interpreter(model_path=model_path) + interpreter = tflite.Interpreter(model_path=model_path) interpreter.allocate_tensors() input_details = interpreter.get_input_details() @@ -478,6 +516,10 @@ def load_tflite_model(model_path: str): def predict(image_data: np.ndarray): nonlocal input_shape + LOGGER.debug( + 'tflite predict, original image_data.shape=%s (%s)', + image_data.shape, image_data.dtype + ) image_data = to_number_of_dimensions(image_data, len(input_shape)) LOGGER.debug('tflite predict, image_data.shape=%s (%s)', image_data.shape, image_data.dtype) height, width, *_ = image_data.shape diff --git a/tf_bodypix/sink.py b/tf_bodypix/sink.py index 44de8ef..70ef229 100644 --- a/tf_bodypix/sink.py +++ b/tf_bodypix/sink.py @@ -1,11 +1,11 @@ -import os import logging from contextlib import contextmanager from functools import partial from typing import Callable, ContextManager, Iterator import numpy as np -import tensorflow as tf + +from tf_bodypix.utils.image import write_image_to # pylint: disable=import-outside-toplevel @@ -16,12 +16,6 @@ T_OutputSink = Callable[[np.ndarray], None] -def write_image_to(image_array: np.ndarray, path: str): - LOGGER.info('writing image to: %r', path) - os.makedirs(os.path.dirname(path), exist_ok=True) - tf.keras.preprocessing.image.save_img(path, image_array) - - def get_v4l2_output_sink(device_name: str) -> ContextManager[T_OutputSink]: from tf_bodypix.utils.v4l2 import VideoLoopbackImageSink return VideoLoopbackImageSink(device_name) diff --git a/tf_bodypix/source.py b/tf_bodypix/source.py index 49b4ea4..2294970 100644 --- a/tf_bodypix/source.py +++ b/tf_bodypix/source.py @@ -6,9 +6,7 @@ from threading import Thread from typing import ContextManager, Iterable, Iterator, Optional -import tensorflow as tf - -from tf_bodypix.utils.image import resize_image_to, ImageSize, ImageArray +from tf_bodypix.utils.image import load_image, ImageSize, ImageArray from tf_bodypix.utils.io import get_file, strip_url_suffix @@ -52,12 +50,7 @@ def get_simple_image_source( ) -> Iterator[Iterable[ImageArray]]: local_image_path = get_file(path) LOGGER.debug('local_image_path: %r', local_image_path) - image = tf.keras.preprocessing.image.load_img( - local_image_path - ) - image_array = tf.keras.preprocessing.image.img_to_array(image) - if image_size is not None: - image_array = resize_image_to(image_array, image_size) + image_array = load_image(local_image_path, image_size=image_size) yield [image_array] diff --git a/tf_bodypix/tflite.py b/tf_bodypix/tflite.py index 6429f48..e51b45a 100644 --- a/tf_bodypix/tflite.py +++ b/tf_bodypix/tflite.py @@ -1,6 +1,9 @@ import logging -import tensorflow as tf +try: + import tensorflow as tf +except ImportError: + tf = None try: import tfjs_graph_converter @@ -11,7 +14,7 @@ LOGGER = logging.getLogger(__name__) -def get_tflite_converter_for_tfjs_model_path(model_path: str) -> tf.lite.TFLiteConverter: +def get_tflite_converter_for_tfjs_model_path(model_path: str) -> 'tf.lite.TFLiteConverter': if tfjs_graph_converter is None: raise ImportError('tfjs_graph_converter required') graph = tfjs_graph_converter.api.load_graph_model(model_path) @@ -19,7 +22,7 @@ def get_tflite_converter_for_tfjs_model_path(model_path: str) -> tf.lite.TFLiteC return tf.lite.TFLiteConverter.from_concrete_functions([tf_fn]) -def get_tflite_converter_for_model_path(model_path: str) -> tf.lite.TFLiteConverter: +def get_tflite_converter_for_model_path(model_path: str) -> 'tf.lite.TFLiteConverter': LOGGER.debug('converting model_path: %s', model_path) # if model_path.endswith('.json'): return get_tflite_converter_for_tfjs_model_path(model_path) diff --git a/tf_bodypix/utils/image.py b/tf_bodypix/utils/image.py index 18d7b65..1ef042f 100644 --- a/tf_bodypix/utils/image.py +++ b/tf_bodypix/utils/image.py @@ -1,14 +1,25 @@ import logging +import os from collections import namedtuple +from typing import Optional, Sequence import numpy as np -import tensorflow as tf + +try: + import tensorflow as tf +except ImportError: + tf = None try: from cv2 import cv2 except ImportError: cv2 = None +try: + import PIL.Image +except ImportError: + PIL = None + LOGGER = logging.getLogger(__name__) @@ -19,6 +30,10 @@ ImageArray = np.ndarray +class ResizeMethod: + BILINEAR = 'bilinear' + + def require_opencv(): if cv2 is None: raise ImportError('OpenCV is required') @@ -42,12 +57,165 @@ def get_image_size(image: np.ndarray): return ImageSize(height=height, width=width) -def resize_image_to(image: np.ndarray, size: ImageSize) -> np.ndarray: - if get_image_size(image) == size: - LOGGER.debug('image has already desired size: %s', size) - return image +def _resize_image_to_using_tf( + image_array: np.ndarray, + image_size: ImageSize, + resize_method: Optional[str] = None +) -> np.ndarray: + if not resize_method: + resize_method = tf.image.ResizeMethod.BILINEAR + LOGGER.debug('resizing image: %r -> %r', image_array.shape, image_size) + return tf.image.resize( + image_array, + (image_size.height, image_size.width), + method=resize_method + ) + + +def _get_pil_image(image_array: np.ndarray) -> 'PIL.Image': + if image_array.shape[-1] == 1: + pil_mode = 'L' + image_array = np.reshape(image_array, image_array.shape[:2]) + else: + pil_mode = 'RGB' + image_array = image_array.astype(np.uint8) + pil_image = PIL.Image.fromarray(image_array, mode=pil_mode) + return pil_image + + +# copied from: +# https://chao-ji.github.io/jekyll/update/2018/07/19/BilinearResize.html +def _numpy_bilinear_resize_2d( # pylint: disable=too-many-locals + image: np.ndarray, + height: int, + width: int +) -> np.ndarray: + """ + `image` is a 2-D numpy array + `height` and `width` are the desired spatial dimension of the new 2-D array. + """ + img_height, img_width = image.shape + + image = image.ravel() + + x_ratio = float(img_width - 1) / (width - 1) if width > 1 else 0 + y_ratio = float(img_height - 1) / (height - 1) if height > 1 else 0 + + y, x = np.divmod(np.arange(height * width), width) + + x_l = np.floor(x_ratio * x).astype('int32') + y_l = np.floor(y_ratio * y).astype('int32') + + x_h = np.ceil(x_ratio * x).astype('int32') + y_h = np.ceil(y_ratio * y).astype('int32') + + x_weight = (x_ratio * x) - x_l + y_weight = (y_ratio * y) - y_l + + a = image[y_l * img_width + x_l] + b = image[y_l * img_width + x_h] + c = image[y_h * img_width + x_l] + d = image[y_h * img_width + x_h] + + resized = ( + a * (1 - x_weight) * (1 - y_weight) + + b * x_weight * (1 - y_weight) + + c * y_weight * (1 - x_weight) + + d * x_weight * y_weight + ) - return tf.image.resize([image], (size.height, size.width))[0] + return resized.reshape(height, width) + + +def _numpy_bilinear_resize_3d(image: np.ndarray, height: int, width: int) -> np.ndarray: + _, _, dimensions = image.shape + return np.stack( + [ + _numpy_bilinear_resize_2d( + image[:, :, dimension], height, width + ) + for dimension in range(dimensions) + ], + axis=-1 + ) + + +def _resize_image_to_using_numpy( + image_array: np.ndarray, + image_size: ImageSize, + resize_method: Optional[str] = None +) -> np.ndarray: + assert not resize_method or resize_method == 'bilinear' + if len(image_array.shape) == 4: + assert image_array.shape[0] == 1 + image_array = image_array[0] + LOGGER.debug( + 'resizing image: %r (%r) -> %r', image_array.shape, image_array.dtype, image_size + ) + resize_image_array = ( + _numpy_bilinear_resize_3d( + np.asarray(image_array), image_size.height, image_size.width + ).astype(image_array.dtype) + ) + LOGGER.debug( + 'resize_image_array image: %r (%r)', image_array.shape, resize_image_array.dtype + ) + return resize_image_array + + +def resize_image_to( + image_array: np.ndarray, + image_size: ImageSize, + resize_method: Optional[str] = None +) -> np.ndarray: + if get_image_size(image_array) == image_size: + LOGGER.debug('image has already desired size: %s', image_size) + return image_array + + if tf is not None: + return _resize_image_to_using_tf(image_array, image_size, resize_method) + return _resize_image_to_using_numpy(image_array, image_size, resize_method) + + +def crop_and_resize_batch( # pylint: disable=too-many-locals + image_array_batch: np.ndarray, + boxes: Sequence[Sequence[float]], + box_indices: Sequence[int], + crop_size: Sequence[int], + method='bilinear', +) -> np.ndarray: + if tf is not None: + return tf.image.crop_and_resize( + image_array_batch, + boxes=boxes, + box_indices=box_indices, + crop_size=crop_size, + method=method + ) + assert list(box_indices) == [0] + assert len(boxes) == 1 + assert len(crop_size) == 2 + box = np.array(boxes[0]) + assert np.min(box) >= 0 + assert np.max(box) <= 1 + y1, x1, y2, x2 = list(box) + assert y1 <= y2 + assert x1 <= x2 + assert len(image_array_batch) == 1 + image_size = get_image_size(image_array_batch[0]) + image_y1 = int(y1 * (image_size.height - 1)) + image_y2 = int(y2 * (image_size.height - 1)) + image_x1 = int(x1 * (image_size.width - 1)) + image_x2 = int(x2 * (image_size.width - 1)) + LOGGER.debug('image y1, x1, y2, x2: %r', (image_y1, image_x1, image_y2, image_x2)) + cropped_image_array = image_array_batch[0][ + image_y1:(1 + image_y2), image_x1: (1 + image_x2), : + ] + LOGGER.debug('cropped_image_array: %r', cropped_image_array.shape) + resized_cropped_image_array = resize_image_to( + cropped_image_array, ImageSize(height=crop_size[0], width=crop_size[1]) + ) + return np.expand_dims(resized_cropped_image_array, 0) def bgr_to_rgb(image: np.ndarray) -> np.ndarray: @@ -57,3 +225,54 @@ def bgr_to_rgb(image: np.ndarray) -> np.ndarray: def rgb_to_bgr(image: np.ndarray) -> np.ndarray: return bgr_to_rgb(image) + + +def _load_image_using_tf( + local_image_path: str, + image_size: ImageSize = None +) -> np.ndarray: + image = tf.keras.preprocessing.image.load_img( + local_image_path + ) + image_array = tf.keras.preprocessing.image.img_to_array(image) + if image_size is not None: + image_array = resize_image_to(image_array, image_size) + return image_array + + +def _load_image_using_pillow( + local_image_path: str, + image_size: ImageSize = None +) -> np.ndarray: + with PIL.Image.open(local_image_path) as image: + image_array = np.asarray(image) + if image_size is not None: + image_array = resize_image_to(image_array, image_size) + return image_array + + +def load_image( + local_image_path: str, + image_size: ImageSize = None +) -> np.ndarray: + if tf is not None: + return _load_image_using_tf(local_image_path, image_size=image_size) + return _load_image_using_pillow(local_image_path, image_size=image_size) + + +def save_image_using_tf(image_array: np.ndarray, path: str): + tf.keras.preprocessing.image.save_img(path, image_array) + + +def save_image_using_pillow(image_array: np.ndarray, path: str): + pil_image = _get_pil_image(image_array) + pil_image.save(path) + + +def write_image_to(image_array: np.ndarray, path: str): + LOGGER.info('writing image to: %r', path) + os.makedirs(os.path.dirname(path), exist_ok=True) + if tf is not None: + save_image_using_tf(image_array, path) + else: + save_image_using_pillow(image_array, path)