diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a7b3ac3..1bae00c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,10 +11,18 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} strategy: matrix: + os: [ubuntu-latest] python-version: [3.6, 3.7, 3.8] + include: + - python-version: 3.8 + push-package: true + - os: windows-latest + python-version: 3.8 + - os: macos-latest + python-version: 3.8 steps: - uses: actions/checkout@v2 @@ -24,7 +32,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - make dev-venv + make dev-venv SYSTEM_PYTHON=python - name: Lint run: | make dev-lint @@ -32,17 +40,17 @@ jobs: run: | make dev-pytest - name: Build dist - if: matrix.python-version == '3.8' + if: matrix.push-package == true run: | make dev-remove-dist dev-build-dist dev-list-dist-contents dev-test-install-dist - name: Publish distribution to Test PyPI - if: matrix.python-version == '3.8' + if: matrix.push-package == true uses: pypa/gh-action-pypi-publish@master with: password: ${{ secrets.test_pypi_password }} repository_url: https://test.pypi.org/legacy/ - name: Publish distribution to PyPI - if: matrix.python-version == '3.8' && startsWith(github.ref, 'refs/tags') + if: matrix.push-package == true && startsWith(github.ref, 'refs/tags') uses: pypa/gh-action-pypi-publish@master with: password: ${{ secrets.pypi_password }} diff --git a/Makefile b/Makefile index e34745a..1ef7559 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,15 @@ VENV = venv -PIP = $(VENV)/bin/pip -PYTHON = $(VENV)/bin/python + +ifeq ($(OS),Windows_NT) + VENV_BIN = $(VENV)/Scripts +else + VENV_BIN = $(VENV)/bin +endif + +PYTHON = $(VENV_BIN)/python +PIP = $(VENV_BIN)/python -m pip + +SYSTEM_PYTHON = python3 VENV_TEMP = venv_temp @@ -30,7 +39,7 @@ venv-clean: venv-create: - python3 -m venv $(VENV) + $(SYSTEM_PYTHON) -m venv $(VENV) dev-install: diff --git a/tests/cli_test.py b/tests/cli_test.py index c5c8971..a02937b 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -1,8 +1,13 @@ +import logging from pathlib import Path +from tf_bodypix.download import BodyPixModelPaths from tf_bodypix.cli import main +LOGGER = logging.getLogger(__name__) + + EXAMPLE_IMAGE_URL = ( r'https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/' r'Person_Of_Interest_-_Panel_%289353656298%29.jpg/' @@ -69,3 +74,16 @@ def test_should_not_fail_to_replace_background(self, temp_dir: Path): '--background=%s' % EXAMPLE_IMAGE_URL, '--output=%s' % output_image_path ]) + + def test_should_list_all_default_model_urls(self, capsys): + expected_urls = [ + value + for key, value in BodyPixModelPaths.__dict__.items() + if not key.startswith('_') + ] + main(['list-models']) + captured = capsys.readouterr() + output_urls = captured.out.splitlines() + LOGGER.debug('output_urls: %s', output_urls) + missing_urls = set(expected_urls) - set(output_urls) + assert not missing_urls diff --git a/tf_bodypix/cli.py b/tf_bodypix/cli.py index 0c31142..c6f4c91 100644 --- a/tf_bodypix/cli.py +++ b/tf_bodypix/cli.py @@ -253,7 +253,7 @@ def add_arguments(self, parser: argparse.ArgumentParser): add_common_arguments(parser) parser.add_argument( "--storage-url", - default="https://storage.googleapis.com/tfjs-models/", + default="https://storage.googleapis.com/tfjs-models", help="The base URL for the storage containing the models" ) diff --git a/tf_bodypix/download.py b/tf_bodypix/download.py index 23c89cb..3c6c8cf 100644 --- a/tf_bodypix/download.py +++ b/tf_bodypix/download.py @@ -78,7 +78,7 @@ def download_model(model_path: str) -> str: for weights_manifest_path in weights_manifest_paths: local_model_json_path = tf.keras.utils.get_file( os.path.basename(weights_manifest_path), - os.path.join(model_base_path, weights_manifest_path), + model_base_path + '/' + weights_manifest_path, cache_subdir=cache_subdir, ) return local_model_path diff --git a/tf_bodypix/utils/s3.py b/tf_bodypix/utils/s3.py index 867e9ed..6fab53a 100644 --- a/tf_bodypix/utils/s3.py +++ b/tf_bodypix/utils/s3.py @@ -1,6 +1,5 @@ import logging -import os import urllib.request from xml.etree import ElementTree from typing import Iterable @@ -17,6 +16,8 @@ def iter_s3_file_urls(base_url: str) -> Iterable[str]: + if not base_url.endswith('/'): + base_url += '/' marker = None while True: current_url = base_url @@ -29,7 +30,7 @@ def iter_s3_file_urls(base_url: str) -> Iterable[str]: for item in root.findall(S3_CONTENTS): key = item.findtext(S3_KEY) LOGGER.debug('key: %s', key) - yield os.path.join(base_url, key) + yield base_url + key next_marker = root.findtext(S3_NEXT_MARKER) if not next_marker or next_marker == marker: break