diff --git a/.conda/README.md b/.conda/README.md index 71a49d7f1..65fadd36e 100644 --- a/.conda/README.md +++ b/.conda/README.md @@ -3,7 +3,7 @@ This folder defines the conda package build for Linux and Windows. There are run To build, first go to the base repo directory and install the build environment: ``` -mamba env create -f environment_build.yml -n sleap_build && conda activate sleap_build +conda env create -f environment_build.yml -n sleap_build && conda activate sleap_build ``` And finally, run the build command pointing to this directory: @@ -15,7 +15,7 @@ conda build .conda --output-folder build -c conda-forge -c nvidia -c https://con To install the local package: ``` -mamba create -n sleap_0 -c conda-forge -c nvidia -c ./build -c https://conda.anaconda.org/sleap/ -c anaconda sleap=x.x.x +conda create -n sleap_0 -c conda-forge -c nvidia -c ./build -c https://conda.anaconda.org/sleap/ -c anaconda sleap=x.x.x ``` replacing x.x.x with the version of SLEAP that you just built. diff --git a/.conda/condarc.yaml b/.conda/condarc.yaml index f9ac6efbe..c5fbc2d96 100644 --- a/.conda/condarc.yaml +++ b/.conda/condarc.yaml @@ -1,5 +1,6 @@ channels: - conda-forge - nvidia + - https://conda.anaconda.org/sleap/label/dev - sleap - anaconda diff --git a/.conda/meta.yaml b/.conda/meta.yaml index caffe9fcb..c1781a3ee 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -16,7 +16,7 @@ source: path: ../ build: - number: 1 + number: 0 requirements: host: @@ -32,12 +32,12 @@ requirements: # unnecessary pypi packages are installed via the build script (bld.bat, build.sh) - conda-forge::attrs ==21.4.0 - conda-forge::cattrs ==1.1.1 - - conda-forge::h5py ==3.1 # [not win] - - conda-forge::imgaug ==0.4.0 + - conda-forge::h5py ==3.7.0 + - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin - conda-forge::jsonpickle ==1.2 - conda-forge::networkx - - conda-forge::opencv + - conda-forge::opencv <4.9.0 - conda-forge::pandas - conda-forge::pillow >=8.3.2 - conda-forge::psutil @@ -53,20 +53,24 @@ requirements: - conda-forge::scikit-learn ==1.0 - conda-forge::scikit-video - conda-forge::seaborn + - conda-forge::qudida + - conda-forge::albumentations + - conda-forge::ndx-pose <0.2.0 + - conda-forge::importlib-metadata ==4.11.4 run: - conda-forge::python ==3.7.12 # Run into _MAX_WINDOWS_WORKERS not found if < - conda-forge::attrs ==21.4.0 - conda-forge::cattrs ==1.1.1 + - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::cudatoolkit ==11.3.1 - conda-forge::cudnn=8.2.1 - nvidia::cuda-nvcc=11.3 - - conda-forge::h5py ==3.1 # [not win] - - conda-forge::imgaug ==0.4.0 + - conda-forge::h5py ==3.7.0 - conda-forge::jsmin - conda-forge::jsonpickle ==1.2 - conda-forge::networkx - numpy >=1.19.5,<1.23.0 # Linux likes anaconda, windows likes conda-forge - - conda-forge::opencv + - conda-forge::opencv <4.9.0 - conda-forge::pandas - conda-forge::pillow >=8.3.2 - conda-forge::psutil @@ -82,9 +86,15 @@ requirements: - conda-forge::scikit-learn ==1.0 - conda-forge::scikit-video - conda-forge::seaborn - - sleap::tensorflow >=2.6.3,<2.11 # No windows GPU support for >2.10, sleap channel has 2.6.3 + - sleap/label/dev::tensorflow ==2.7.0 # TODO: Switch to main label when updated - conda-forge::tensorflow-hub <0.14.0 # Causes pynwb conflicts on linux GH-1446 + - conda-forge::qudida + - conda-forge::albumentations + - conda-forge::ndx-pose <0.2.0 + - conda-forge::importlib-metadata ==4.11.4 -test: - imports: - - sleap \ No newline at 
end of file +# This no longer works so we have moved it to the build workflow +# https://github.com/talmolab/sleap/pull/1744 +# test: +# imports: +# - sleap \ No newline at end of file diff --git a/.conda_mac/build.sh b/.conda_mac/build.sh index 2036035f6..a68193560 100644 --- a/.conda_mac/build.sh +++ b/.conda_mac/build.sh @@ -2,7 +2,6 @@ # Install anything that didn't get conda installed via pip. # We need to turn pip index back on because Anaconda turns it off for some reason. - export PIP_NO_INDEX=False export PIP_NO_DEPENDENCIES=False export PIP_IGNORE_INSTALLED=False diff --git a/.conda_mac/condarc.yaml b/.conda_mac/condarc.yaml index df2727c74..c1be41bf1 100644 --- a/.conda_mac/condarc.yaml +++ b/.conda_mac/condarc.yaml @@ -1,4 +1,3 @@ -# This file is not used at the moment, but when github actions can be used to build the package, it needs to be listed. # https://github.com/github/roadmap/issues/528 channels: diff --git a/.conda_mac/meta.yaml b/.conda_mac/meta.yaml index 7496f2057..8f773badf 100644 --- a/.conda_mac/meta.yaml +++ b/.conda_mac/meta.yaml @@ -16,14 +16,14 @@ about: summary: {{ data.get('description') }} build: - number: 1 + number: 0 source: path: ../ requirements: host: - - conda-forge::python ~=3.9 + - conda-forge::python >=3.9.0, <3.10.0 - anaconda::numpy >=1.19.5,<1.23.0 - conda-forge::setuptools - conda-forge::packaging @@ -34,11 +34,11 @@ requirements: - conda-forge::attrs >=21.2.0 - conda-forge::cattrs ==1.1.1 - conda-forge::h5py - - conda-forge::imgaug ==0.4.0 + - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin - conda-forge::jsonpickle ==1.2 - conda-forge::keras <2.10.0,>=2.9.0rc0 # Required by tensorflow-macos - - conda-forge::networkx + - conda-forge::networkx <3.3 - conda-forge::opencv - conda-forge::pandas - conda-forge::pillow @@ -55,17 +55,20 @@ requirements: - conda-forge::scikit-learn ==1.0 - conda-forge::scikit-video - conda-forge::seaborn + - conda-forge::qudida + - conda-forge::albumentations + - conda-forge::ndx-pose <0.2.0 run: - - conda-forge::python ~=3.9 + - conda-forge::python >=3.9.0, <3.10.0 - conda-forge::attrs >=21.2.0 - conda-forge::cattrs ==1.1.1 - conda-forge::h5py - - conda-forge::imgaug ==0.4.0 + - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin - conda-forge::jsonpickle ==1.2 - conda-forge::keras <2.10.0,>=2.9.0rc0 # Required by tensorflow-macos - - conda-forge::networkx + - conda-forge::networkx <3.3 - anaconda::numpy >=1.19.5,<1.23.0 - conda-forge::opencv - conda-forge::pandas @@ -83,8 +86,11 @@ requirements: - conda-forge::scikit-learn ==1.0 - conda-forge::scikit-video - conda-forge::seaborn - - conda-forge::tensorflow-hub + # - conda-forge::tensorflow-hub # pulls in tensorflow cpu from conda-forge + - conda-forge::qudida + - conda-forge::albumentations + - conda-forge::ndx-pose <0.2.0 -test: - imports: - - sleap +# test: +# imports: +# - sleap diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 8c95f28dc..6a92c2e3b 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -25,15 +25,15 @@ Tell us a little about the system you're using. Please include information about how you installed. 
--> -- OS: +- OS: -- Version(s): - -- SLEAP installation method (listed [here](https://sleap.ai/installation.html#)): - - [ ] [Conda from package](https://sleap.ai/installation.html#conda-package) - - [ ] [Conda from source](https://sleap.ai/installation.html#conda-from-source) - - [ ] [pip package](https://sleap.ai/installation.html#pip-package) - - [ ] [Apple Silicon Macs](https://sleap.ai/installation.html#apple-silicon-macs) +- Version(s): + +- SLEAP installation method (listed [here](https://sleap.ai/installation.html#)): + - [ ] [Conda from package](https://sleap.ai/installation.html#conda-package) + - [ ] [Conda from source](https://sleap.ai/installation.html#conda-from-source) + - [ ] [pip package](https://sleap.ai/installation.html#pip-package) + - [ ] [Apple Silicon Macs](https://sleap.ai/installation.html#apple-silicon-macs)
Environment packages diff --git a/.github/workflows/archive/comment-template.yml b/.github/workflows/archive/comment-template.yml new file mode 100644 index 000000000..3bef84531 --- /dev/null +++ b/.github/workflows/archive/comment-template.yml @@ -0,0 +1,71 @@ +name: Reusable Comment Workflow + +on: + workflow_call: + inputs: + subject_id: + required: true + type: string + body_prefix: + required: true + type: string + comment_type: + required: true + type: string + +jobs: + comment: + runs-on: ubuntu-latest + steps: + - name: Post a comment + uses: actions/github-script@v6 + with: + script: | + const { owner, repo } = context.repo; + const subject_id = '${{ inputs.subject_id }}'; + const comment_type = '${{ inputs.comment_type }}'; + const baseBody = ` + We appreciate your input and will review it soon. + + > [!WARNING] + > A friendly reminder that this is a public forum. Please be cautious when clicking links, downloading files, or running scripts posted by others. + > + > - Always verify the credibility of links and code. + > - Avoid running scripts or installing files from untrusted sources. + > - If you're unsure, ask for clarification before proceeding. + + Stay safe and happy SLEAPing! + + Best regards, + The Team + `; + const body = `${{ inputs.body_prefix }}\n\n${baseBody}`; + + const mutation = comment_type === 'discussion' + ? ` + mutation($discussionId: ID!, $body: String!) { + addDiscussionComment(input: {discussionId: $discussionId, body: $body}) { + comment { + id + } + } + } + ` + : ` + mutation($issueId: ID!, $body: String!) { + addComment(input: {subjectId: $issueId, body: $body}) { + commentEdge { + node { + id + body + } + } + } + } + `; + + const variables = comment_type === 'discussion' + ? { discussionId: subject_id, body: body.trim() } + : { issueId: subject_id, body: body.trim() }; + + await github.graphql(mutation, variables); diff --git a/.github/workflows/archive/comment.yml b/.github/workflows/archive/comment.yml new file mode 100644 index 000000000..a24df018f --- /dev/null +++ b/.github/workflows/archive/comment.yml @@ -0,0 +1,24 @@ +name: Comment on New Discussions and Issues + +on: + discussion: + types: [created] + issues: + types: [opened] + +jobs: + comment_on_discussion: + if: github.event_name == 'discussion' + uses: ./.github/workflows/comment-template.yml + with: + subject_id: ${{ github.event.discussion.node_id }} + body_prefix: "Thank you for starting a new discussion!" + comment_type: "discussion" + + comment_on_issue: + if: github.event_name == 'issues' + uses: ./.github/workflows/comment-template.yml + with: + subject_id: ${{ github.event.issue.node_id }} + body_prefix: "Thank you for opening a new issue!" 
+ comment_type: "issue" diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 02bc8798b..74203245c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -13,39 +13,37 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-22.04", "windows-2022", "macos-latest"] + os: ["ubuntu-22.04", "windows-2022", "macos-14"] # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstrategymatrixinclude include: # Use this condarc as default - condarc: .conda/condarc.yaml + - pyver: "3.7" # Use special condarc if macos - - os: "macos-latest" + - os: "macos-14" condarc: .conda_mac/condarc.yaml + pyver: "3.9" steps: # Setup - - uses: actions/checkout@v2 - - name: Cache conda - uses: actions/cache@v1 - env: - # Increase this value to reset cache if environment_build.yml has not changed - CACHE_NUMBER: 0 - with: - path: ~/conda_pkgs_dir - key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('environment_build.yml', 'requirements.txt') }} + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Miniconda - # https://github.com/conda-incubator/setup-miniconda - uses: conda-incubator/setup-miniconda@v2.0.1 + uses: conda-incubator/setup-miniconda@v3.0.3 with: - python-version: 3.7 - use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! - environment-file: environment_build.yml + miniforge-version: latest condarc-file: ${{ matrix.condarc }} + python-version: ${{ matrix.pyver }} + environment-file: environment_build.yml activate-environment: sleap_ci + conda-solver: "libmamba" + - name: Print environment info shell: bash -l {0} run: | which python conda info + conda list # Build pip wheel (Ubuntu) - name: Build pip wheel (Ubuntu) @@ -69,21 +67,92 @@ jobs: shell: bash -l {0} run: | conda build .conda --output-folder build + echo "BUILD_PATH=$(pwd)/build" >> "$GITHUB_ENV" - # Build conda package (Windows) + # Build conda package (Windows) - name: Build conda package (Windows) if: matrix.os == 'windows-2022' shell: powershell run: | conda build .conda --output-folder build + echo "BUILD_PATH=\$(pwd)\build" >> "$env:GITHUB_ENV" # Build conda package (Mac) - name: Build conda package (Mac) - if: matrix.os == 'macos-latest' + if: matrix.os == 'macos-14' shell: bash -l {0} run: | conda build .conda_mac --output-folder build + echo "BUILD_PATH=$(pwd)/build" >> "$GITHUB_ENV" + + # Test built conda package (Ubuntu and Windows) + - name: Test built conda package (Ubuntu and Windows) + if: matrix.os != 'macos-14' + shell: bash -l {0} + run: | + echo "Current build path: $BUILD_PATH" + conda deactivate + + echo "Python executable before activating environment:" + which python + echo "Python version before activating environment:" + python --version + echo "Conda info before activating environment:" + conda info + + echo "Creating and testing conda environment with sleap package..." + conda create -y -n sleap_test -c file://$BUILD_PATH -c sleap/label/dev -c conda-forge -c nvidia -c anaconda sleap + conda activate sleap_test + echo "Python executable after activating sleap_test environment:" + which python + echo "Python version after activating sleap_test environment:" + python --version + echo "Conda info after activating sleap_test environment:" + conda info + echo "List of installed conda packages in the sleap_test environment:" + conda list + echo "List of installed pip packages in the sleap_test environment:" + pip list + + echo "Testing sleap package installation..." 
+ sleap_version=$(python -c "import sleap; print(sleap.__version__)") + echo "Test completed using sleap version: $sleap_version" + + # Test built conda package (Mac) + - name: Test built conda package (Mac) + if: matrix.os == 'macos-14' + shell: bash -l {0} + run: | + echo "Current build path: $BUILD_PATH" + conda deactivate + + echo "Python executable before activating environment:" + which python + echo "Python version before activating environment:" + python --version + echo "Conda info before activating environment:" + conda info + + echo "Creating and testing conda environment with sleap package..." + conda create -y -n sleap_test -c file://$BUILD_PATH -c conda-forge -c anaconda sleap + conda activate sleap_test + + echo "Python executable after activating sleap_test environment:" + which python + echo "Python version after activating sleap_test environment:" + python --version + echo "Conda info after activating sleap_test environment:" + conda info + echo "List of installed conda packages in the sleap_test environment:" + conda list + echo "List of installed pip packages in the sleap_test environment:" + pip list + + echo "Testing sleap package installation..." + sleap_version=$(python -c "import sleap; print(sleap.__version__)") + echo "Test completed using sleap version: $sleap_version" + # Login to conda (Ubuntu) - name: Login to Anaconda (Ubuntu) if: matrix.os == 'ubuntu-22.04' @@ -95,7 +164,7 @@ jobs: # Login to conda (Windows) - name: Login to Anaconda (Windows) - if: matrix.os == 'windows-2019' + if: matrix.os == 'windows-2022' env: ANACONDA_LOGIN: ${{ secrets.ANACONDA_LOGIN }} shell: powershell @@ -104,7 +173,7 @@ jobs: # Login to conda (Mac) - name: Login to Anaconda (Mac) - if: matrix.os == 'macos-latest' + if: matrix.os == 'macos-14' env: ANACONDA_LOGIN: ${{ secrets.ANACONDA_LOGIN }} shell: bash -l {0} @@ -113,12 +182,12 @@ jobs: # Upload conda package (Windows) - name: Upload conda package (Windows/main) - if: matrix.os == 'windows-2019' && !github.event.release.prerelease + if: matrix.os == 'windows-2022' && !github.event.release.prerelease shell: powershell run: | anaconda -v upload "build\win-64\*.tar.bz2" - name: Upload conda package (Windows/dev) - if: matrix.os == 'windows-2019' && github.event.release.prerelease + if: matrix.os == 'windows-2022' && github.event.release.prerelease shell: powershell run: | anaconda -v upload "build\win-64\*.tar.bz2" --label dev @@ -137,15 +206,15 @@ jobs: # Upload conda package (Mac) - name: Upload conda package (Mac/main) - if: matrix.os == 'macos-latest' && !github.event.release.prerelease + if: matrix.os == 'macos-14' && !github.event.release.prerelease shell: bash -l {0} run: | - anaconda -v upload build/osx-64/*.tar.bz2 --label dev + anaconda -v upload build/osx-arm64/*.tar.bz2 --label dev - name: Upload conda package (Mac/dev) - if: matrix.os == 'macos-latest' && github.event.release.prerelease + if: matrix.os == 'macos-14' && github.event.release.prerelease shell: bash -l {0} run: | - anaconda -v upload build/osx-64/*.tar.bz2 --label dev + anaconda -v upload build/osx-arm64/*.tar.bz2 --label dev # Logout - name: Logout from Anaconda diff --git a/.github/workflows/build_conda_ci.yml b/.github/workflows/build_conda_ci.yml new file mode 100644 index 000000000..3fd3d2b92 --- /dev/null +++ b/.github/workflows/build_conda_ci.yml @@ -0,0 +1,179 @@ +# Run tests using built conda packages. 
+name: Build Conda CI (no upload)
+
+# Run when there are changes to the conda recipes or their dependencies
+on:
+  push:
+    paths:
+      - ".conda/meta.yaml"
+      - ".conda_mac/meta.yaml"
+      - "setup.py"
+      - "requirements.txt"
+      - "dev_requirements.txt"
+      - "environment_build.yml"
+      - ".github/workflows/build_conda_ci.yml" # Run!
+
+# If RUN_BUILD_JOB is set to true, then RUN_ID will be overwritten to the current run id
+env:
+  RUN_BUILD_JOB: true
+  RUN_ID: 10713717594 # Only used if RUN_BUILD_JOB is false (to download build artifact)
+
+jobs:
+  build:
+    name: Build package from push (${{ matrix.os }})
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: ["windows-2022", "ubuntu-22.04", "macos-14"]
+        # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstrategymatrixinclude
+        include:
+          # Use these variables as defaults
+          - condarc: .conda/condarc.yaml
+          - conda-folder: .conda
+          - pyver: "3.10"
+          - build-prefix: win
+          - os: "ubuntu-22.04"
+            build-prefix: linux
+          # Use special condarc if macos
+          - os: "macos-14"
+            condarc: .conda_mac/condarc.yaml
+            conda-folder: .conda_mac
+            build-prefix: osx
+
+    steps:
+      # Setup
+      - name: Checkout
+        if: env.RUN_BUILD_JOB == 'true'
+        uses: actions/checkout@v4
+
+      - name: Setup Miniconda
+        if: env.RUN_BUILD_JOB == 'true'
+        uses: conda-incubator/setup-miniconda@v3.0.4
+        with:
+          miniforge-version: latest
+          condarc-file: ${{ matrix.condarc }}
+          python-version: ${{ matrix.pyver }}
+          environment-file: environment_build.yml
+          activate-environment: sleap_ci
+          conda-solver: "libmamba"
+
+      - name: Print build environment info
+        if: env.RUN_BUILD_JOB == 'true'
+        shell: bash -l {0}
+        run: |
+          which python
+          conda list
+          pip freeze
+
+      # Build conda package
+      - name: Build conda package
+        if: env.RUN_BUILD_JOB == 'true'
+        shell: bash -l {0}
+        run: |
+          conda build ${{ matrix.conda-folder }} --output-folder build
+
+      # Upload artifact so the "tests" job can use it
+      - name: Upload conda package artifact
+        if: env.RUN_BUILD_JOB == 'true'
+        uses: actions/upload-artifact@v4
+        with:
+          name: sleap-build-${{ matrix.build-prefix }}
+          path: build # Upload entire build directory
+          retention-days: 1
+          overwrite: true
+
+  tests:
+    name: Run tests using package (${{ matrix.os }})
+    runs-on: ${{ matrix.os }}
+    needs: build # Ensure the build job has completed before starting this job.
+    strategy:
+      fail-fast: false
+      matrix:
+        os: ["windows-2022", "ubuntu-22.04", "macos-14"]
+        # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstrategymatrixinclude
+        include:
+          # Default values
+          - build-prefix: win
+          - build-suffix: 64
+          - test_args: pytest --durations=-1 tests/
+          - condarc: .conda/condarc.yaml
+          - pyver: "3.10"
+          - conda-channels: -c conda-forge -c nvidia -c anaconda
+          # Ubuntu specific values
+          - os: ubuntu-22.04
+            build-prefix: linux
+            # Otherwise core dumped in github actions
+            test_args: |
+              sudo apt install xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0
+              sudo Xvfb :1 -screen 0 1024x768x24 > $GITHUB_ENV
+
+    steps:
+      # https://github.com/actions/download-artifact?tab=readme-ov-file#usage
+      - name: Download conda package artifact
+        uses: actions/download-artifact@v4
+        id: download
+        with:
+          name: sleap-build-${{ matrix.build-prefix }}
+          path: build
+          run-id: ${{ env.RUN_ID }}
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: List items in current directory
+        run: |
+          ls .
+ ls -R build + + - name: Setup Miniconda + uses: conda-incubator/setup-miniconda@v3.0.4 + with: + miniforge-version: latest + condarc-file: ${{ matrix.condarc }} + python-version: ${{ matrix.pyver }} + conda-solver: "libmamba" + + - name: Create conda environment + shell: bash -l {0} + run: conda create sleap -y -n sleap_ci -c ./build ${{ matrix.conda-channels }} + + - name: Install packages for testing + shell: bash -l {0} + run: | + conda activate sleap_ci + pip install -r "dev_requirements.txt" + + # Note: "conda activate" does not persist across steps + - name: Print environment info + shell: bash -l {0} + run: | + conda activate sleap_ci + conda info + conda list + pip freeze + + - name: Test package + shell: bash -l {0} + run: | + conda activate sleap_ci + ${{ matrix.test_args}} diff --git a/.github/workflows/build_manual.yml b/.github/workflows/build_manual.yml index ab689342d..7cba65d67 100644 --- a/.github/workflows/build_manual.yml +++ b/.github/workflows/build_manual.yml @@ -8,8 +8,10 @@ on: paths: - '.conda/meta.yaml' - '.conda_mac/meta.yaml' + - '.github/workflows/build_manual.yml' branches: - - develop + # - develop + - fakebranch jobs: build: @@ -18,39 +20,37 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-22.04", "windows-2022", "macos-latest"] + os: ["ubuntu-22.04", "windows-2022", "macos-14"] # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstrategymatrixinclude include: # Use this condarc as default - condarc: .conda/condarc.yaml + - pyver: "3.7" # Use special condarc if macos - - os: "macos-latest" + - os: "macos-14" condarc: .conda_mac/condarc.yaml + pyver: "3.9" steps: # Setup - - uses: actions/checkout@v2 - - name: Cache conda - uses: actions/cache@v1 - env: - # Increase this value to reset cache if environment_build.yml has not changed - CACHE_NUMBER: 0 - with: - path: ~/conda_pkgs_dir - key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('environment_build.yml', 'requirements.txt') }} + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Miniconda - # https://github.com/conda-incubator/setup-miniconda - uses: conda-incubator/setup-miniconda@v2.0.1 + uses: conda-incubator/setup-miniconda@v3.0.3 with: - python-version: 3.7 - use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! 
- environment-file: environment_build.yml + miniforge-version: latest condarc-file: ${{ matrix.condarc }} + python-version: ${{ matrix.pyver }} + environment-file: environment_build.yml activate-environment: sleap_ci + conda-solver: "libmamba" + - name: Print environment info shell: bash -l {0} run: | which python conda info + conda list # Build pip wheel (Not Windows) - name: Build pip wheel (Not Windows) @@ -59,14 +59,14 @@ jobs: run: | python setup.py bdist_wheel - # Upload pip wheel (Ubuntu) - - name: Upload pip wheel (Ubuntu) - if: matrix.os == 'ubuntu-22.04' - env: - PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} - shell: bash -l {0} - run: | - twine upload -u __token__ -p "$PYPI_TOKEN" dist/* --non-interactive --skip-existing --disable-progress-bar + # # Upload pip wheel (Ubuntu) + # - name: Upload pip wheel (Ubuntu) + # if: matrix.os == 'ubuntu-22.04' + # env: + # PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} + # shell: bash -l {0} + # run: | + # twine upload -u __token__ -p "$PYPI_TOKEN" dist/* --non-interactive --skip-existing --disable-progress-bar # Build conda package (Ubuntu) - name: Build conda package (Ubuntu) @@ -74,70 +74,141 @@ jobs: shell: bash -l {0} run: | conda build .conda --output-folder build + echo "BUILD_PATH=$(pwd)/build" >> "$GITHUB_ENV" - # Build conda package (Windows) + # Build conda package (Windows) - name: Build conda package (Windows) if: matrix.os == 'windows-2022' shell: powershell run: | conda build .conda --output-folder build + echo "BUILD_PATH=\$(pwd)\build" >> "$env:GITHUB_ENV" # Build conda package (Mac) - name: Build conda package (Mac) - if: matrix.os == 'macos-latest' + if: matrix.os == 'macos-14' shell: bash -l {0} run: | conda build .conda_mac --output-folder build + echo "BUILD_PATH=$(pwd)/build" >> "$GITHUB_ENV" - # Login to conda (Ubuntu) - - name: Login to Anaconda (Ubuntu) - if: matrix.os == 'ubuntu-22.04' - env: - ANACONDA_LOGIN: ${{ secrets.ANACONDA_LOGIN }} + # Test built conda package (Ubuntu and Windows) + - name: Test built conda package (Ubuntu and Windows) + if: matrix.os != 'macos-14' shell: bash -l {0} run: | - yes 2>/dev/null | anaconda login --username sleap --password "$ANACONDA_LOGIN" || true - - # Login to conda (Windows) - - name: Login to Anaconda (Windows) - if: matrix.os == 'windows-2022' - env: - ANACONDA_LOGIN: ${{ secrets.ANACONDA_LOGIN }} - shell: powershell - run: | - echo "yes" | anaconda login --username sleap --password "$env:ANACONDA_LOGIN" - - # Login to conda (Mac) - - name: Login to Anaconda (Mac) - if: matrix.os == 'macos-latest' - env: - ANACONDA_LOGIN: ${{ secrets.ANACONDA_LOGIN }} + echo "Current build path: $BUILD_PATH" + conda deactivate + + echo "Python executable before activating environment:" + which python + echo "Python version before activating environment:" + python --version + echo "Conda info before activating environment:" + conda info + + echo "Creating and testing conda environment with sleap package..." + conda create -y -n sleap_test -c file://$BUILD_PATH -c sleap/label/dev -c conda-forge -c nvidia -c anaconda sleap + conda activate sleap_test + + echo "Python executable after activating sleap_test environment:" + which python + echo "Python version after activating sleap_test environment:" + python --version + echo "Conda info after activating sleap_test environment:" + conda info + echo "List of installed conda packages in the sleap_test environment:" + conda list + echo "List of installed pip packages in the sleap_test environment:" + pip list + + echo "Testing sleap package installation..." 
+ sleap_version=$(python -c "import sleap; print(sleap.__version__)") + echo "Test completed using sleap version: $sleap_version" + + # Test built conda package (Mac) + - name: Test built conda package (Mac) + if: matrix.os == 'macos-14' shell: bash -l {0} run: | - yes 2>/dev/null | anaconda login --username sleap --password "$ANACONDA_LOGIN" || true + echo "Current build path: $BUILD_PATH" + conda deactivate + + echo "Python executable before activating environment:" + which python + echo "Python version before activating environment:" + python --version + echo "Conda info before activating environment:" + conda info + + echo "Creating and testing conda environment with sleap package..." + conda create -y -n sleap_test -c file://$BUILD_PATH -c conda-forge -c anaconda sleap + conda activate sleap_test + + echo "Python executable after activating sleap_test environment:" + which python + echo "Python version after activating sleap_test environment:" + python --version + echo "Conda info after activating sleap_test environment:" + conda info + echo "List of installed conda packages in the sleap_test environment:" + conda list + echo "List of installed pip packages in the sleap_test environment:" + pip list + + echo "Testing sleap package installation..." + sleap_version=$(python -c "import sleap; print(sleap.__version__)") + echo "Test completed using sleap version: $sleap_version" + + # # Login to conda (Ubuntu) + # - name: Login to Anaconda (Ubuntu) + # if: matrix.os == 'ubuntu-22.04' + # env: + # ANACONDA_LOGIN: ${{ secrets.ANACONDA_LOGIN }} + # shell: bash -l {0} + # run: | + # yes 2>/dev/null | anaconda login --username sleap --password "$ANACONDA_LOGIN" || true - # Upload conda package (Windows) - - name: Upload conda package (Windows/dev) - if: matrix.os == 'windows-2022' - shell: powershell - run: | - anaconda -v upload "build\win-64\*.tar.bz2" --label dev + # # Login to conda (Windows) + # - name: Login to Anaconda (Windows) + # if: matrix.os == 'windows-2022' + # env: + # ANACONDA_LOGIN: ${{ secrets.ANACONDA_LOGIN }} + # shell: powershell + # run: | + # echo "yes" | anaconda login --username sleap --password "$env:ANACONDA_LOGIN" - # Upload conda package (Ubuntu) - - name: Upload conda package (Ubuntu/dev) - if: matrix.os == 'ubuntu-22.04' - shell: bash -l {0} - run: | - anaconda -v upload build/linux-64/*.tar.bz2 --label dev + # # Login to conda (Mac) + # - name: Login to Anaconda (Mac) + # if: matrix.os == 'macos-14' + # env: + # ANACONDA_LOGIN: ${{ secrets.ANACONDA_LOGIN }} + # shell: bash -l {0} + # run: | + # yes 2>/dev/null | anaconda login --username sleap --password "$ANACONDA_LOGIN" || true - # Upload conda package (Mac) - - name: Upload conda package (Mac/dev) - if: matrix.os == 'macos-latest' - shell: bash -l {0} - run: | - anaconda -v upload build/osx-64/*.tar.bz2 --label dev + # # Upload conda package (Windows) + # - name: Upload conda package (Windows/dev) + # if: matrix.os == 'windows-2022' + # shell: powershell + # run: | + # anaconda -v upload "build\win-64\*.tar.bz2" --label dev - - name: Logout from Anaconda - shell: bash -l {0} - run: | - anaconda logout + # # Upload conda package (Ubuntu) + # - name: Upload conda package (Ubuntu/dev) + # if: matrix.os == 'ubuntu-22.04' + # shell: bash -l {0} + # run: | + # anaconda -v upload build/linux-64/*.tar.bz2 --label dev + + # # Upload conda package (Mac) + # - name: Upload conda package (Mac/dev) + # if: matrix.os == 'macos-14' + # shell: bash -l {0} + # run: | + # anaconda -v upload build/osx-arm64/*.tar.bz2 --label 
dev + + # - name: Logout from Anaconda + # shell: bash -l {0} + # run: | + # anaconda logout diff --git a/.github/workflows/build_ci.yml b/.github/workflows/build_pypi_ci.yml similarity index 71% rename from .github/workflows/build_ci.yml rename to .github/workflows/build_pypi_ci.yml index baf046295..68142b288 100644 --- a/.github/workflows/build_ci.yml +++ b/.github/workflows/build_pypi_ci.yml @@ -1,17 +1,17 @@ -# Run tests using built conda packages and wheels. -name: Build CI (no upload) +# Run tests using built wheels. +name: Build PyPI CI (no upload) # Run when changes to pip wheel on: push: paths: - - 'setup.py' - - 'requirements.txt' - - 'dev_requirements.txt' - - 'jupyter_requirements.txt' - - 'pypi_requirements.txt' - - 'environment_build.yml' - - '.github/workflows/build_ci.yml' + - "setup.py" + - "requirements.txt" + - "dev_requirements.txt" + - "jupyter_requirements.txt" + - "pypi_requirements.txt" + - "environment_build.yml" + - ".github/workflows/build_pypi_ci.yml" # Run! jobs: build: @@ -26,28 +26,21 @@ jobs: # Use this condarc as default - condarc: .conda/condarc.yaml - wheel_name: sleap-wheel-linux + - pyver: "3.7" steps: # Setup - - uses: actions/checkout@v2 + - name: Checkout + uses: actions/checkout@v4 - - name: Cache conda - uses: actions/cache@v1 - env: - # Increase this value to reset cache if environment_build.yml has not changed - CACHE_NUMBER: 0 + - name: Setup Miniconda + uses: conda-incubator/setup-miniconda@v3.0.3 with: - path: ~/conda_pkgs_dir - key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('environment_build.yml', 'pyproject.toml') }} - - - name: Setup Miniconda for Build - # https://github.com/conda-incubator/setup-miniconda - uses: conda-incubator/setup-miniconda@v2.0.1 - with: - python-version: 3.7 - use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! - environment-file: environment_build.yml + miniforge-version: latest condarc-file: ${{ matrix.condarc }} + python-version: ${{ matrix.pyver }} + environment-file: environment_build.yml activate-environment: sleap_ci + conda-solver: "libmamba" - name: Print build environment info shell: bash -l {0} @@ -61,10 +54,10 @@ jobs: shell: bash -l {0} run: | python setup.py bdist_wheel - + # Upload artifact "tests" can use it - name: Upload wheel artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ matrix.wheel_name }} path: dist/*.whl @@ -73,11 +66,12 @@ jobs: tests: name: Run tests using wheel (${{ matrix.os }}) runs-on: ${{ matrix.os }} - needs: build # Ensure the build job has completed before starting this job. + needs: build # Ensure the build job has completed before starting this job. strategy: fail-fast: false matrix: - os: ["ubuntu-22.04", "windows-2022", "macos-latest"] + os: ["ubuntu-22.04", "windows-2022"] + # os: ["ubuntu-22.04", "windows-2022", "macos-14"] # removing macos-14 for now since the setup-python action only support py>=3.10, which is breaking this CI. 
# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstrategymatrixinclude include: # Default values @@ -89,9 +83,11 @@ jobs: pip install '$wheel_path'[dev] - test_args: pytest --durations=-1 tests/ - condarc: .conda/condarc.yaml + - pyver: "3.7" # Use special condarc if macos - - os: "macos-latest" + - os: "macos-14" condarc: .conda_mac/condarc.yaml + pyver: "3.10" # Ubuntu specific values - os: ubuntu-22.04 # Otherwise core dumped in github actions @@ -106,16 +102,16 @@ jobs: steps: - name: Checkout repo - uses: actions/checkout@v3 - - - name: Set up Python 3.7 - uses: actions/setup-python@v4 + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 with: - python-version: 3.7 - + python-version: ${{ matrix.pyver }} + # Download wheel - name: Download wheel artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 id: download with: name: ${{ matrix.wheel_name }} @@ -148,8 +144,8 @@ jobs: run: | which python pip freeze - + # Install and test the wheel - name: Test the built wheel run: | - ${{ matrix.test_args}} \ No newline at end of file + ${{ matrix.test_args}} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e1d193724..84b028fc3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,45 +11,28 @@ on: - "environment_no_cuda.yml" - "requirements.txt" - "dev_requirements.txt" - push: - branches: - - master - - develop - paths: - - "sleap/**" - - "tests/**" - - ".github/workflows/ci.yml" - - "environment_no_cuda.yml" - - "requirements.txt" - - "dev_requirements.txt" + # push: + # branches: + # - main + # - develop + # paths: + # - "sleap/**" + # - "tests/**" + # - ".github/workflows/ci.yml" + # - "environment_no_cuda.yml" + # - "requirements.txt" + # - "dev_requirements.txt" jobs: - type_check: - name: Type Check - runs-on: "ubuntu-22.04" - steps: - - name: Checkout repo - uses: actions/checkout@v3 - - name: Set up Python 3.7 - uses: actions/setup-python@v4 - with: - python-version: 3.7 - - name: Install Dependencies - run: | - pip install mypy - - name: Run MyPy - # TODO: remove this once all MyPy errors get fixed - continue-on-error: true - run: | - mypy --follow-imports=skip --ignore-missing-imports sleap tests + # Lint lint: name: Lint runs-on: "ubuntu-22.04" steps: - name: Checkout repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python 3.7 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.7 - name: Install Dependencies @@ -59,58 +42,52 @@ jobs: - name: Run Black run: | black --check sleap tests + + # Tests tests: name: Tests (${{ matrix.os }}) runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - os: ["ubuntu-22.04", "windows-2022", "macos-latest"] + os: ["ubuntu-22.04", "windows-2022", "macos-14"] include: # Default values - env_file: environment_no_cuda.yml - - test_args: --durations=-1 tests/ # Mac specific values - - os: macos-latest + - os: macos-14 env_file: environment_mac.yml - # Ubuntu specific values - - os: ubuntu-22.04 - test_args: --cov=sleap --cov-report=xml --durations=-1 tests/ + steps: - name: Checkout repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - - name: Setup Micromamba - # https://github.com/mamba-org/setup-micromamba - uses: mamba-org/setup-micromamba@v1 + - name: Setup Conda + uses: conda-incubator/setup-miniconda@v3.0.3 with: - micromamba-version: '1.4.6-0' + miniforge-version: latest + conda-solver: "libmamba" environment-file: ${{ 
matrix.env_file }} - environment-name: sleap_ci - init-shell: >- - bash - powershell - post-cleanup: all + activate-environment: sleap_ci # Print environment info - name: Print environment info shell: bash -l {0} run: | which python - micromamba info - micromamba list + conda info + conda list pip freeze # Test environment - name: Test with pytest shell: bash -l {0} run: | - pytest ${{ matrix.test_args }} + pytest --cov=sleap --cov-report=xml --durations=-1 tests/ # Upload coverage - name: Upload coverage - uses: codecov/codecov-action@v1 - if: matrix.os == 'ubuntu-22.04' + uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true diff --git a/.github/workflows/website.yml b/.github/workflows/website.yml index 7db6b4d74..36c1d6ad7 100644 --- a/.github/workflows/website.yml +++ b/.github/workflows/website.yml @@ -7,8 +7,8 @@ on: branches: # 'main' triggers updates to 'sleap.ai', all others to 'sleap.ai/develop' - main - - develop - - liezl/add-pip-extras + - develop # Run + - liezl/bump-to-1.4.1 paths: - "docs/**" - "README.rst" @@ -20,21 +20,13 @@ jobs: steps: # Setup - name: Checkout - uses: actions/checkout@v2 - - name: Cache conda - uses: actions/cache@v1 - env: - # Increase this value to reset cache if environment_build.yml has not changed - CACHE_NUMBER: 0 - with: - path: ~/conda_pkgs_dir - key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('environment_build.yml', 'requirements.txt') }} + uses: actions/checkout@v4 + - name: Setup Miniconda # https://github.com/conda-incubator/setup-miniconda - uses: conda-incubator/setup-miniconda@v2.0.1 + uses: conda-incubator/setup-miniconda@v3.0.3 with: python-version: 3.7 - use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! environment-file: environment_no_cuda.yml activate-environment: sleap_ci - name: Print environment info @@ -51,7 +43,7 @@ jobs: make html - name: Deploy (sleap.ai) - uses: peaceiris/actions-gh-pages@v3 + uses: peaceiris/actions-gh-pages@v4 if: github.ref == 'refs/heads/main' with: github_token: ${{ secrets.GITHUB_TOKEN }} @@ -60,10 +52,10 @@ jobs: keep_files: true - name: Deploy (test) - uses: peaceiris/actions-gh-pages@v3 + uses: peaceiris/actions-gh-pages@v4 if: github.ref != 'refs/heads/main' with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_branch: gh-pages publish_dir: docs/build/html - destination_dir: develop \ No newline at end of file + destination_dir: develop diff --git a/AUTHORS b/AUTHORS index e6a78d2ba..11e40e839 100644 --- a/AUTHORS +++ b/AUTHORS @@ -11,3 +11,5 @@ John Smith Example Inc. Jeremy Delahanty The Salk Institute for Biological Studies + +Lili Karashchuk Allen Institute of Neural Dynamics diff --git a/README.rst b/README.rst index dbc5a7cac..f7a5acd6c 100644 --- a/README.rst +++ b/README.rst @@ -69,7 +69,7 @@ Quick install .. 
code-block:: bash - conda create -y -n sleap -c conda-forge -c nvidia -c sleap -c anaconda sleap + conda create -y -n sleap -c conda-forge -c nvidia -c sleap/label/dev -c sleap -c anaconda sleap `pip` **(any OS except Apple silicon)**: @@ -84,7 +84,7 @@ Learn to SLEAP -------------- - **Learn step-by-step**: `Tutorial `_ - **Learn more advanced usage**: `Guides `__ and `Notebooks `__ -- **Learn by watching**: `MIT CBMM Tutorial `_ +- **Learn by watching**: `ABL:AOC 2023 Workshop `_ and `MIT CBMM Tutorial `_ - **Learn by reading**: `Paper (Pereira et al., Nature Methods, 2022) `__ and `Review on behavioral quantification (Pereira et al., Nature Neuroscience, 2020) `_ - **Learn from others**: `Discussions on Github `_ diff --git a/dev_requirements.txt b/dev_requirements.txt index f7bb23643..709fb48fd 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -5,13 +5,14 @@ pytest-qt>=4.0.0 pytest-cov<=3.0.0 pytest-xvfb ipython -sphinx +sphinx>=5.0 # sphinxcontrib.applehelp extension needs at least Sphinx v5.0 # furo sphinx-book-theme sphinx-copybutton +sphinx-tabs nbformat==5.1.3 -myst-nb==0.13.2 -myst-parser==0.15.2 +myst-nb>=0.16.0 # sphinx>=5.0 needs myst-nb>=0.16.0 +myst-parser linkify-it-py sphinx-autobuild black==21.6b0 diff --git a/docs/_static/bonsai-connection.jpg b/docs/_static/bonsai-connection.jpg new file mode 100644 index 000000000..32b725416 Binary files /dev/null and b/docs/_static/bonsai-connection.jpg differ diff --git a/docs/_static/bonsai-filecapture.jpg b/docs/_static/bonsai-filecapture.jpg new file mode 100644 index 000000000..7a809d67a Binary files /dev/null and b/docs/_static/bonsai-filecapture.jpg differ diff --git a/docs/_static/bonsai-predictcentroids.jpg b/docs/_static/bonsai-predictcentroids.jpg new file mode 100644 index 000000000..e284f2338 Binary files /dev/null and b/docs/_static/bonsai-predictcentroids.jpg differ diff --git a/docs/_static/bonsai-predictposeidentities.jpg b/docs/_static/bonsai-predictposeidentities.jpg new file mode 100644 index 000000000..8582fd707 Binary files /dev/null and b/docs/_static/bonsai-predictposeidentities.jpg differ diff --git a/docs/_static/bonsai-predictposes.jpg b/docs/_static/bonsai-predictposes.jpg new file mode 100644 index 000000000..2e4f04a22 Binary files /dev/null and b/docs/_static/bonsai-predictposes.jpg differ diff --git a/docs/_static/bonsai-workflow.jpg b/docs/_static/bonsai-workflow.jpg new file mode 100644 index 000000000..0481c3dcf Binary files /dev/null and b/docs/_static/bonsai-workflow.jpg differ diff --git a/docs/_static/css/tabs.css b/docs/_static/css/tabs.css new file mode 100644 index 000000000..95765dff6 --- /dev/null +++ b/docs/_static/css/tabs.css @@ -0,0 +1,91 @@ +.sphinx-tabs { + margin-bottom: 1rem; +} + +[role="tablist"] { + border-bottom: 1px solid #a0b3bf; +} + +.sphinx-tabs-tab { + position: relative; + font-family: Lato,'Helvetica Neue',Arial,Helvetica,sans-serif; + color: var(--pst-color-link); + line-height: 24px; + margin: 3px; + font-size: 16px; + font-weight: 400; + background-color: rgb(241 244 249); + border-radius: 5px 5px 0 0; + border: 0; + padding: 1rem 1.5rem; + margin-bottom: 0; +} + +.sphinx-tabs-tab[aria-selected="true"] { + font-weight: 700; + border: 1px solid #a0b3bf; + border-bottom: 1px solid rgb(241 244 249); + margin: -1px; + background-color: rgb(242 247 255); +} + +.admonition .sphinx-tabs-tab[aria-selected="true"]:last-child { + margin-bottom: -1px; +} + +.sphinx-tabs-tab:focus { + z-index: 1; + outline-offset: 1px; +} + +.sphinx-tabs-panel { + position: relative; 
+ padding: 1rem; + border: 1px solid #a0b3bf; + margin: 0px -1px -1px -1px; + border-radius: 0 0 5px 5px; + border-top: 0; + background: rgb(242 247 255); +} + +.sphinx-tabs-panel.code-tab { + padding: 0.4rem; +} + +.sphinx-tab img { + margin-bottom: 24px; +} + +/* Dark theme preference styling */ + +html[data-theme="dark"] .sphinx-tabs-panel { + color: white; + background-color: rgb(50, 50, 50); +} + +html[data-theme="dark"] .sphinx-tabs-tab { + color: var(--pst-color-link); + background-color: rgba(255, 255, 255, 0.05); +} + +html[data-theme="dark"] .sphinx-tabs-tab[aria-selected="true"] { + border-bottom: 2px solid rgb(50, 50, 50); + background-color: rgb(50, 50, 50); +} + +/* Light theme preference styling */ + +html[data-theme="light"] .sphinx-tabs-panel { + color: black; + background-color: white; +} + +html[data-theme="light"] .sphinx-tabs-tab { + color: var(--pst-color-link); + background-color: rgba(0, 0, 0, 0.05); +} + +html[data-theme="light"] .sphinx-tabs-tab[aria-selected="true"] { + border-bottom: 2px solid white; + background-color: white; +} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 572e73ea0..074869903 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,7 +28,7 @@ copyright = f"2019–{date.today().year}, Talmo Lab" # The short X.Y version -version = "1.3.3" +version = "1.4.1" # Get the sleap version # with open("../sleap/version.py") as f: @@ -36,7 +36,7 @@ # version = re.search("\d.+(?=['\"])", version_file).group(0) # Release should be the full branch name -release = "v1.3.3" +release = "v1.4.1" html_title = f"SLEAP ({release})" html_short_title = "SLEAP" @@ -59,6 +59,7 @@ "sphinx.ext.linkcode", "sphinx.ext.napoleon", "sphinx_copybutton", + "sphinx_tabs.tabs", # For tabs inside docs # https://myst-nb.readthedocs.io/en/latest/ "myst_nb", ] @@ -85,6 +86,7 @@ pygments_style = "sphinx" pygments_dark_style = "monokai" + # Autosummary linkcode resolution # https://www.sphinx-doc.org/en/master/usage/extensions/linkcode.html def linkcode_resolve(domain, info): @@ -173,6 +175,12 @@ def linkcode_resolve(domain, info): # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] +# These paths are either relative to html_static_path +# or fully qualified paths (eg. https://...) +html_css_files = [ + "css/tabs.css", +] + # Custom sidebar templates, must be a dictionary that maps document names # to template names. # @@ -219,3 +227,7 @@ def linkcode_resolve(domain, info): # https://myst-nb.readthedocs.io/en/latest/use/config-reference.html jupyter_execute_notebooks = "off" + +# Sphinx-tabs settings +# https://sphinx-tabs.readthedocs.io/en/latest/ +sphinx_tabs_disable_css_loading = True # Use the theme's CSS diff --git a/docs/guides/bonsai.md b/docs/guides/bonsai.md new file mode 100644 index 000000000..d262873b6 --- /dev/null +++ b/docs/guides/bonsai.md @@ -0,0 +1,75 @@ +(bonsai)= + +# Using Bonsai with SLEAP + +Bonsai is a visual language for reactive programming and currently supports SLEAP models. + +:::{note} +Currently Bonsai supports only single instance, top-down and top-down-id SLEAP models. +::: + +### Exporting a SLEAP trained model + +Before we can import a trained model into Bonsai, we need to use the {code}`sleap-export` command to convert the model to a format supported by Bonsai. 
+For example, to export a top-down-id model, the command is as follows:
+
+```bash
+sleap-export -m centroid/model/folder/path -m top_down_id/model/folder/path -e exported/model/path
+```
+
+Please refer to the {ref}`sleap-export` docs for more details on using the command.
+
+This will generate the necessary `.pb` file and other information files required by Bonsai. In this example, these files were saved to the specified `exported/model/path` folder.
+
+The `exported/model/path` folder will have a structure like the following:
+
+```plaintext
+exported/model/path
+├── centroid_config.json
+├── confmap_config.json
+├── frozen_graph.pb
+└── info.json
+```
+
+### Installing Bonsai and necessary packages
+
+1. Install Bonsai. See the [Bonsai installation instructions](https://bonsai-rx.org/docs/articles/installation.html).
+
+2. Download and add the necessary packages for Bonsai to run with SLEAP. See the official [Bonsai SLEAP documentation](https://github.com/bonsai-rx/sleap?tab=readme-ov-file#bonsai---sleap) for more information.
+
+### Using Bonsai SLEAP modules
+
+Once you have Bonsai installed with the required packages, you should be able to open the Bonsai application. The workflow must have a source module `FileCapture` which can be found in the toolbox search in the workflow editor. Provide the path to the video that was used to train the SLEAP model in the `FileName` field of the module.
+
+![Bonsai FileCapture module](../_static/bonsai-filecapture.jpg)
+
+#### Top-down model
+The top-down model requires both the `PredictCentroids` and the `PredictPoses` modules.
+
+The `PredictCentroids` module will predict the centroids of detections. There are two fields inside the `PredictCentroids` module: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the centroid model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.
+
+![Bonsai PredictCentroids module](../_static/bonsai-predictcentroids.jpg)
+
+The `PredictPoses` module will predict the instances of detections. Similar to the `PredictCentroids` module, there are two fields inside the `PredictPoses` module: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the centered instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.
+
+![Bonsai PredictPoses module](../_static/bonsai-predictposes.jpg)
+
+#### Top-Down-ID model
+The `PredictPoseIdentities` module will predict the instances with identities. This module has two fields: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the top-down-id model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.
+
+![Bonsai PredictPoseIdentities module](../_static/bonsai-predictposeidentities.jpg)
+
+#### Single instance model
+The `PredictSinglePose` module will predict the poses for single instance models. This module also has two fields: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the single instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.
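+
+If you want to sanity-check the export before wiring it into Bonsai, you can try parsing the frozen graph directly. This is a minimal sketch and not part of the official workflow; it assumes TensorFlow is installed in the environment you run it from and that `exported/model/path` is the export folder from the example above:
+
+```bash
+# Parse the exported graph and report its size; a parse error here suggests
+# the export did not complete correctly. Adjust the path to your export folder.
+python - <<'EOF'
+import tensorflow as tf
+
+graph_def = tf.compat.v1.GraphDef()
+with tf.io.gfile.GFile("exported/model/path/frozen_graph.pb", "rb") as f:
+    graph_def.ParseFromString(f.read())
+print(f"Loaded frozen graph with {len(graph_def.node)} nodes")
+EOF
+```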
+
+### Connecting the modules
+Right-click on the `FileCapture` module and select **Create Connection**. Now click on the required SLEAP module to complete the connection.
+
+![Bonsai module connection](../_static/bonsai-connection.jpg)
+
+Once it is done, the workflow in Bonsai will look something like the following:
+
+![Bonsai.SLEAP workflow](../_static/bonsai-workflow.jpg)
+
+Now you can click the green start button to run the workflow, and you can add more modules to analyze and visualize the results in Bonsai.
+
+For more documentation on various modules and workflows, please refer to the [official Bonsai docs](https://bonsai-rx.org/docs/articles/editor.html).
diff --git a/docs/guides/cli.md b/docs/guides/cli.md
index 35ea52171..134461c60 100644
--- a/docs/guides/cli.md
+++ b/docs/guides/cli.md
@@ -36,8 +36,8 @@ optional arguments:
 ```none
 usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS]
-                   [--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
-                   [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
+                   [--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
+                   [--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
                    [--suffix SUFFIX]
                    training_job_path [labels_path]
@@ -68,6 +68,8 @@ optional arguments:
   --save_viz            Enable saving of prediction visualizations to the run
                         folder if not already specified in the training job
                         config.
+  --keep_viz            Keep prediction visualization images in the run
+                        folder after training if --save_viz is enabled.
   --zmq                 Enable ZMQ logging (for GUI) if not already specified
                         in the training job config.
   --run_name RUN_NAME   Run name to use when saving file, overrides other run
@@ -99,9 +101,9 @@ optional arguments:
   -e [EXPORT_PATH], --export_path [EXPORT_PATH]
                         Path to output directory where the frozen model will be exported to. Defaults to a folder named 'exported_model'.
-  -u, --unrag UNRAG
-                        Convert ragged tensors into regular tensors with NaN padding.
-                        Defaults to True.
+  -r, --ragged RAGGED
+                        Keep tensors ragged if present. If omitted, convert
+                        ragged tensors into regular tensors with NaN padding.
   -n, --max_instances MAX_INSTANCES
                         Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None.
@@ -136,7 +138,10 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [-
                    [data_path]
 
 positional arguments:
-  data_path             Path to data to predict on. This can be a labels (.slp) file or any supported video format.
+  data_path             Path to data to predict on. This can be one of the following: A .slp file containing labeled data; A folder containing multiple
+                        video files in supported formats; An individual video file in a supported format; A CSV file with a column of video file paths.
+                        If more than one column is provided in the CSV file, the first will be used for the input data paths and the next column will be
+                        used as the output paths; A text file with a path to a video file on each line.
 
 optional arguments:
   -h, --help            show this help message and exit
@@ -151,7 +156,7 @@ optional arguments:
                         Only run inference on unlabeled suggested frames when running on labels dataset. This is useful for generating predictions for
                         initialization during labeling.
   -o OUTPUT, --output OUTPUT
-                        The output filename to use for the predicted data. If not provided, defaults to '[data_path].predictions.slp'.
+                        The output filename or directory path to use for the predicted data. If not provided, defaults to '[data_path].predictions.slp'.
   --no-empty-frames     Clear any empty frames that did not have any detected instances before saving to output.
   --verbosity {none,rich,json}
                         Verbosity of inference progress reporting. 'none' does not output anything during inference, 'rich' displays an updating
@@ -202,7 +207,7 @@ optional arguments:
   --tracking.clean_iou_threshold TRACKING.CLEAN_IOU_THRESHOLD
                         IOU to use when culling instances *after* tracking. (default: 0)
   --tracking.similarity TRACKING.SIMILARITY
-                        Options: instance, centroid, iou (default: instance)
+                        Options: instance, normalized_instance, object_keypoint, centroid, iou (default: instance)
   --tracking.match TRACKING.MATCH
                         Options: hungarian, greedy (default: greedy)
   --tracking.robust TRACKING.ROBUST
@@ -322,7 +327,8 @@ optional arguments:
                         analysis file for the latter video is given a default name.
   --format FORMAT       Output format. Default ('slp') is SLEAP dataset; 'analysis' results in analysis.h5 file; 'analysis.nix' results
-                        in an analysis nix file; 'h5' or 'json' results in SLEAP dataset
+                        in an analysis nix file; 'analysis.csv' results
+                        in an analysis csv file; 'h5' or 'json' results in SLEAP dataset
                         with specified file format.
   --video VIDEO         Path to video (if needed for conversion).
 ```
@@ -389,6 +395,9 @@ optional arguments:
   --distinctly_color DISTINCTLY_COLOR
                         Specify how to color instances. Options include:
                         "instances", "edges", and "nodes" (default: "instances")
+  --background BACKGROUND
+                        Specify the type of background to be used to save the videos.
+                        Options: original, black, white and grey. (default: "original")
 ```
 
 ## Debugging
diff --git a/docs/guides/gui.md b/docs/guides/gui.md
index 88cf3f656..813ed68fa 100644
--- a/docs/guides/gui.md
+++ b/docs/guides/gui.md
@@ -60,7 +60,7 @@ Note that many of the menu command have keyboard shortcuts which can be configur
 
 "**Edge Style**" controls whether edges are drawn as thin lines or as wedges which indicate the {ref}`orientation` of the instance (as well as the direction of the part affinity field which would be used to predict the connection between nodes when using a "bottom-up" approach).
 
-"**Trail Length**" allows you to show a trail of where each instance was located in prior frames (the length of the trail is the number of prior frames). This can be useful when proofreading predictions since it can help you detect swaps in the identities of animals across frames.
+"**Trail Length**" allows you to show a trail of where each instance was located in prior frames (the length of the trail is the number of prior frames). This can be useful when proofreading predictions since it can help you detect swaps in the identities of animals across frames. By default, you can only select trail lengths of up to 250 frames. You can use a custom trail length by modifying the default length in the `preferences.yaml` file. However, using trail lengths longer than about 500 frames can result in significant lag.
 
 "**Fit Instances to View**" allows you to toggle whether the view is auto-zoomed to the instances in each frame. This can be useful when proofreading predictions.
diff --git a/docs/guides/index.md b/docs/guides/index.md
index 7eb55b2b2..6d773d9de 100644
--- a/docs/guides/index.md
+++ b/docs/guides/index.md
@@ -30,6 +30,10 @@
 
 {ref}`remote-inference` when you trained models and you want to run inference on a different machine using a **command-line interface**.
 
+## SLEAP with Bonsai
+
+{ref}`bonsai` when you want to run a trained SLEAP model in Bonsai to visualize poses, centroids, and identities for further analysis.
+
 ```{toctree}
 :hidden: true
 :maxdepth: 2
@@ -44,4 +48,5 @@
 proofreading
 colab
 custom-training
 remote
+bonsai
 ```
diff --git a/docs/guides/proofreading.md b/docs/guides/proofreading.md
index fea1c5ebc..941b85154 100644
--- a/docs/guides/proofreading.md
+++ b/docs/guides/proofreading.md
@@ -50,6 +50,8 @@ There are currently three methods for matching instances in frame N against thes
 - “**centroid**” measures similarity by the distance between the instance centroids
 - “**iou**” measures similarity by the intersection/overlap of the instance bounding boxes
 - “**instance**” measures similarity by looking at the distances between corresponding nodes in the instances, normalized by the number of valid nodes in the candidate instance.
+- “**normalized_instance**” measures similarity like “instance”, except that the keypoint coordinates are first normalized by the image size before the distances are computed (and the result is normalized by the number of valid nodes in the candidate instance).
+- “**object_keypoint**” measures similarity by computing the distance d between each keypoint in a reference instance and the corresponding keypoint in a query instance, taking exp(-d**2) for each pair, summing over all keypoints, and dividing by the number of visible keypoints in the reference instance.

 Once SLEAP has measured the similarity between all the candidates and the instances in frame N, you need to choose a way to pair them up. You can do this either by picking the best match, and the picking the best remaining match for each remaining instance in turn—this is “**greedy**” matching—or you can find the way of matching identities which minimizes the total cost (or: maximizes the total similarity)—this is “**Hungarian**” matching.
diff --git a/docs/installation.md b/docs/installation.md
index eea65cc31..2c1ef41be 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -1,10 +1,21 @@
 # Installation

-SLEAP can be installed as a Python package on Windows, Linux, Mac OS X, and Mac OS Apple Silicon.
+SLEAP can be installed as a Python package on Windows, Linux, and Mac OS. For a quick install using conda, see below:

-SLEAP requires many complex dependencies, so we **strongly** recommend using [Mambaforge](https://mamba.readthedocs.io/en/latest/installation.html) to install it in its own isolated environment. See {ref}`Installing Mambaforge` below for more instructions.
+````{tabs}
+  ```{group-tab} Windows and Linux
+   ```bash
+   conda create -y -n sleap -c conda-forge -c nvidia -c sleap/label/dev -c sleap -c anaconda sleap=1.4.1
+   ```
+  ```
+  ```{group-tab} Mac OS
+   ```bash
+   conda create -y -n sleap -c conda-forge -c anaconda -c sleap sleap=1.4.1
+   ```
+  ```
+````

-The newest version of SLEAP can always be found in the [Releases page](https://github.com/talmolab/sleap/releases).
+For more in-depth installation instructions, see the [installation methods](installation-methods). The newest version of SLEAP can always be found in the [Releases page](https://github.com/talmolab/sleap/releases).

 ```{contents} Contents
 ---
 local:
 ---
 ```
-````{hint}
-Installation requires entering commands in a terminal. To open one:
-
-**Windows:** Open the *Start menu* and search for the *Miniforge Prompt* (if using Mambaforge) or the *Command Prompt* if not.
-```{note}
-On Windows, our personal preference is to use alternative terminal apps like [Cmder](https://cmder.net) or [Windows Terminal](https://aka.ms/terminal).
-```
-
-**Linux:** Launch a new terminal by pressing Ctrl + Alt + T.
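To make the “object_keypoint” computation above concrete, here is a minimal NumPy sketch of the exp(-d**2) score as described. This is an illustration of the formula only, not SLEAP's exact implementation (which also supports per-keypoint errors and score weighting; see the tracker options later in this diff):

```python
import numpy as np

def object_keypoint_similarity(ref_pts: np.ndarray, qry_pts: np.ndarray) -> float:
    """Score a query instance against a reference instance.

    Both arrays are (n_nodes, 2), with NaN rows for invisible keypoints.
    """
    visible = ~np.isnan(ref_pts).any(axis=1)  # visible keypoints in the reference
    d = np.linalg.norm(ref_pts[visible] - qry_pts[visible], axis=1)
    # exp(-d**2) per keypoint, summed, divided by the number of visible
    # reference keypoints; missing query keypoints contribute nothing.
    return float(np.nansum(np.exp(-(d**2))) / max(int(visible.sum()), 1))
```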
- -**Mac:** Launch a new terminal by pressing Cmd + Space and searching for _Terminal_. - -```` - -(apple-silicon)= - -### Macs Pre-M1 (Pre-Installation) - -SLEAP can be installed on Macs by following these instructions: - -1. Make sure you're on **macOS Monterey** or later, i.e., version 12+. - -2. If you don't have it yet, [install **homebrew**](https://brew.sh/), a convenient package manager for Macs (skip this if you can run `brew` from the terminal): - - ```bash - /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" - ``` - - This might take a little while since it'll also install Xcode (which we'll need later). Once it's finished, your terminal should give you two extra commands to run listed under **Next Steps**. - - ````{note} - We recommend running the commands given in your terminal which will be similar to (but may differ slightly) from the commands below: - ```bash - echo 'eval "$(/opt/homebrew/bin/brew shellenv)"' >> ~/.zprofile - ``` - - ```bash - eval "$(/opt/homebrew/bin/brew shellenv)" - ``` - +`````{hint} + Installation requires entering commands in a terminal. To open one: + ````{tabs} + ```{tab} Windows + Open the *Start menu* and search for the *Anaconda Prompt* (if using Miniconda) or the *Command Prompt* if not. + ```{note} + On Windows, our personal preference is to use alternative terminal apps like [Cmder](https://cmder.net) or [Windows Terminal](https://aka.ms/terminal). + ``` + ``` + ```{tab} Linux + Launch a new terminal by pressing Ctrl + Alt + T. + ``` + ```{group-tab} Mac OS + Launch a new terminal by pressing Cmd + Space and searching for _Terminal_. + ``` ```` +````` - Then, close and re-open the terminal for it to take effect. +## Package Manager -3. Install wget, a CLI downloading utility (also makes sure your homebrew setup worked): - - ```bash - brew install wget - ``` - -(mambaforge)= - -## Installing Mambaforge - -**Anaconda** is a Python environment manager that makes it easy to install SLEAP and its necessary dependencies without affecting other Python software on your computer. - -[**Mambaforge**](https://mamba.readthedocs.io/en/latest/installation.html) is a lightweight installer of Anaconda with speedy package resolution that we recommend. +SLEAP requires many complex dependencies, so we **strongly** recommend using a package manager such as [Miniforge](https://github.com/conda-forge/miniforge) or [Miniconda](https://docs.anaconda.com/free/miniconda/) to install SLEAP in its own isolated environment. ````{note} -If you already have Anaconda on your computer, then you can [set the solver to `libmamba`](https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community) in the `base` environment (and skip the Mambaforge installation): +If you already have Anaconda on your computer (and it is an [older installation](https://conda.org/blog/2023-11-06-conda-23-10-0-release/)), then make sure to [set the solver to `libmamba`](https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community) in the `base` environment. ```bash conda update -n base conda @@ -80,195 +55,220 @@ conda config --set solver libmamba ``` ```{warning} -Any subsequent `mamba` commands in the docs will need to be replaced with `conda` if you choose to use your existing Anaconda installation. +Any subsequent `conda` commands in the docs will need to be replaced with `mamba` if you have [Mamba](https://mamba.readthedocs.io/en/latest/) installed instead of Anaconda or Miniconda. 
``` ```` -Otherwise, to install Mamba: - -**On Windows**, just click through the installation steps. - -1. Go to: https://github.com/conda-forge/miniforge#mambaforge -2. Download the latest version for your OS. -3. Follow the installer instructions. - -We recommend using the following settings: - -- Install for: All Users (requires admin privileges) -- Destination folder: `C:\mambaforge` -- Advanced Options: Add MambaForge to the system PATH environment variable -- Advanced Options: Register MambaForge as the system Python 3.X - These will make sure that MambaForge is easily accessible from most places on your computer. - -**On Linux**, it might be easier to do this straight from the terminal (Ctrl + Alt + T) with this one-liner: - -```bash -wget -nc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh && bash Mambaforge-Linux-x86_64.sh -b && ~/mambaforge/bin/conda init bash -``` - -Restart the terminal after running this command. - -```{note} -For other Linux architectures (arm64 and POWER8/9), replace the `.sh` filenames above with the correct installer name for your architecture. See the Download column in [this table](https://github.com/conda-forge/miniforge#mambaforge) for the correct filename. - -``` - -**On Macs (pre-M1)**, you can run the installer using this terminal command: +If you don't have a `conda` package manager installation, here are some quick install options: -```bash -wget -nc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-MacOSX-x86_64.sh && bash Mambaforge-MacOSX-x86_64.sh -b && ~/mambaforge/bin/conda init zsh -``` +### Miniforge (recommended) -**On Macs (Apple Silicon)**, use this terminal command: +Miniforge is a minimal installer for conda that includes the `conda` package manager and is maintained by the [conda-forge](https://conda-forge.org) community. The only difference between Miniforge and Miniconda is that Miniforge uses the `conda-forge` channel by default, which provides a much wider selection of community-maintained packages. -```bash -curl -fsSL --compressed https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-MacOSX-arm64.sh -o Mambaforge3-MacOSX-arm64.sh && chmod +x Mambaforge3-MacOSX-arm64.sh && ./Mambaforge3-MacOSX-arm64.sh -b -p ~/mambaforge3 && rm Mambaforge3-MacOSX-arm64.sh && ~/mambaforge3/bin/conda init "$(basename "${SHELL}")" && source "$HOME/.$(basename "${SHELL}")rc" -``` +````{tabs} + ```{group-tab} Windows + Open a new PowerShell terminal (does not need to be admin) and enter: -## Installation methods - -SLEAP can be installed three different ways: via {ref}`conda package`, {ref}`conda from source`, or {ref}`pip package`. Select one of the methods below to install SLEAP. We recommend {ref}`conda package`. - -(condapackage)= - -### `conda` package - -**Windows** and **Linux** - -```bash -mamba create -y -n sleap -c conda-forge -c nvidia -c sleap -c anaconda sleap=1.3.3 -``` - -**Mac OS X** and **Apple Silicon** - -```bash -mamba create -y -n sleap -c conda-forge -c anaconda -c sleap sleap=1.3.3 -``` - -**This is the recommended installation method**. - -```{note} -- This comes with CUDA to enable GPU support. All you need is to have an NVIDIA GPU and [updated drivers](https://nvidia.com/drivers). -- If you already have CUDA installed on your system, this will not conflict with it. -- This will also work in CPU mode if you don't have a GPU on your machine. -``` - -(condasource)= - -### `conda` from source - -1. 
First, ensure git is installed: - - ```bash - git --version + ```bash + Invoke-WebRequest -Uri "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Windows-x86_64.exe" -OutFile "$env:UserProfile/Downloads/Miniforge3-Windows-x86_64.exe"; Start-Process -FilePath "$env:UserProfile/Downloads/Miniforge3-Windows-x86_64.exe" -ArgumentList "/InstallationType=JustMe /RegisterPython=1 /S" -Wait; Remove-Item -Path "$env:UserProfile/Downloads/Miniforge3-Windows-x86_64.exe" + ``` ``` + ```{group-tab} Linux + Open a new terminal and enter: - If 'git' is not recognized, then [install git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git). - -2. Then, clone the repository: - - ```bash - git clone https://github.com/talmolab/sleap && cd sleap + ```bash + curl -fsSL --compressed https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -o "~/Downloads/Miniforge3-Linux-x86_64.sh" && chmod +x "~/Downloads/Miniforge3-Linux-x86_64.sh" && "~/Downloads/Miniforge3-Linux-x86_64.sh" -b -p ~/miniforge3 && rm "~/Downloads/Miniforge3-Linux-x86_64.sh" && ~/miniforge3/bin/conda init "$(basename "${SHELL}")" && source "$HOME/.$(basename "${SHELL}")rc" + ``` ``` + ```{group-tab} Mac (Apple Silicon) + Open a new terminal and enter: -3. Finally, install from the environment file (differs based on OS and GPU): - - **Windows** and **Linux** - - ```bash - mamba env create -f environment.yml -n sleap - ``` - - If you do not have a NVIDIA GPU, then you should use the no CUDA environment file: - - ```bash - mamba env create -f environment_no_cuda.yml -n sleap + ```bash + curl -fsSL --compressed https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh -o "~/Downloads/Miniforge3-MacOSX-arm64.sh" && chmod +x "~/Downloads/Miniforge3-MacOSX-arm64.sh" && "~/Downloads/Miniforge3-MacOSX-arm64.sh" -b -p ~/miniforge3 && rm "~/Downloads/Miniforge3-MacOSX-arm64.sh" && ~/miniforge3/bin/conda init "$(basename "${SHELL}")" && source "$HOME/.$(basename "${SHELL}")rc" + ``` ``` + ```{group-tab} Mac (Intel) + Open a new terminal and enter: - **Mac OS X** and **Apple Silicon** - - ```bash - mamba env create -f environment_mac.yml -n sleap + ```bash + curl -fsSL --compressed https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-x86_64.sh -o "~/Downloads/Miniforge3-MacOSX-x86_64.sh" && chmod +x "~/Downloads/Miniforge3-MacOSX-x86_64.sh" && "~/Downloads/Miniforge3-MacOSX-x86_64.sh" -b -p ~/miniforge3 && rm "~/Downloads/Miniforge3-MacOSX-x86_64.sh" && ~/miniforge3/bin/conda init "$(basename "${SHELL}")" && source "$HOME/.$(basename "${SHELL}")rc" + ``` ``` +```` - This is the **recommended method for development**. - -```{note} -- This installs SLEAP in development mode, which means that edits to the source code will be applied the next time you run SLEAP. -- Change the `-n sleap` in the command to create an environment with a different name (e.g., `-n sleap_develop`). -``` - -(pippackage)= - -### `pip` package - -Although you do not need Mambaforge installed to perform a `pip install`, we recommend {ref}`installing Mambaforge` to create a new environment where we can isolate the `pip install`. Alternatively, you can use a venv if you have an existing python installation. If you are working on **Google Colab**, skip to step 3 to perform the `pip install` without using a conda environment. +### Miniconda -1. 
Otherwise, create a new conda environment where we will `pip install sleap`: +This is a minimal installer for conda that includes the `conda` package manager and is maintained by the [Anaconda](https://www.anaconda.com) company. - either without GPU support: +````{tabs} + ```{group-tab} Windows + Open a new PowerShell terminal (does not need to be admin) and enter: - ```bash - mamba create --name sleap pip python=3.7.12 + ```bash + curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe -o miniconda.exe; Start-Process -FilePath ".\miniconda.exe" -ArgumentList "/S" -Wait; del miniconda.exe + ``` ``` + ```{group-tab} Linux + Open a new terminal and enter: - or with GPU support: - - ```bash - mamba create --name sleap pip python=3.7.12 cudatoolkit=11.3 cudnn=8.2 + ```bash + mkdir -p ~/miniconda3 && wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh && bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 && rm ~/miniconda3/miniconda.sh && ~/miniconda3/bin/conda init "$(basename "${SHELL}")" && source "$HOME/.$(basename "${SHELL}")rc" + ``` ``` + ```{group-tab} Mac (Apple Silicon) + Open a new terminal and enter: -2. Then activate the environment to isolate the `pip install` from other environments on your computer: - - ```bash - mamba activate sleap + ```bash + curl -fsSL --compressed https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o "~/Downloads/Miniconda3-latest-MacOSX-arm64.sh" && chmod +x "~/Downloads/Miniconda3-latest-MacOSX-arm64.sh" && "~/Downloads/Miniconda3-latest-MacOSX-arm64.sh" -b -u -p ~/miniconda3 && rm "~/Downloads/Miniconda3-latest-MacOSX-arm64.sh" && ~/miniconda3/bin/conda init "$(basename "${SHELL}")" && source "$HOME/.$(basename "${SHELL}")rc" + ``` ``` + ```{group-tab} Mac (Intel) + Open a new terminal and enter: - ```{warning} - Refrain from installing anything into the `base` environment. Always create a new environment to install new packages. + ```bash + curl -fsSL --compressed https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o "~/Downloads/Miniconda3-latest-MacOSX-x86_64.sh" && chmod +x "~/Downloads/Miniconda3-latest-MacOSX-x86_64.sh" && "~/Downloads/Miniconda3-latest-MacOSX-x86_64.sh" -b -u -p ~/miniconda3 && rm "~/Downloads/Miniconda3-latest-MacOSX-x86_64.sh" && ~/miniconda3/bin/conda init "$(basename "${SHELL}")" && source "$HOME/.$(basename "${SHELL}")rc" + ``` ``` +```` -3. Finally, we can perform the `pip install`: - - ```bash - pip install sleap[pypi]==1.3.3 - ``` +See the [Miniconda website](https://docs.anaconda.com/free/miniconda/) for up-to-date installation instructions if the above instructions don't work for your system. - This works on **any OS except Apple silicon** and on **Google Colab**. +(installation-methods)= - ```{note} - The pypi distributed package of SLEAP ships with the following extras: - - **pypi**: For installation without an mamba environment file. All dependencies come from PyPI. - - **jupyter**: This installs all *pypi* and jupyter lab dependencies. - - **dev**: This installs all *jupyter* dependencies and developement tools for testing and building docs. - - **conda_jupyter**: For installation using a mamba environment file included in the source code. Most dependencies are listed as conda packages in the environment file and only a few come from PyPI to allow jupyter lab support. 
- - **conda_dev**: For installation using [a mamba environment](https://github.com/search?q=repo%3Atalmolab%2Fsleap+path%3Aenvironment*.yml&type=code) with a few PyPI dependencies for development tools. - ``` +## Installation methods - ```{note} - - Requires Python 3.7 - - To enable GPU support, make sure that you have **CUDA Toolkit v11.3** and **cuDNN v8.2** installed. - ``` +SLEAP can be installed three different ways: via {ref}`conda package`, {ref}`conda from source`, or {ref}`pip package`. Select one of the methods below to install SLEAP. We recommend {ref}`conda package`. - ```{warning} - This will uninstall existing libraries and potentially install conflicting ones. +`````{tabs} + ```{tab} conda package + **This is the recommended installation method**. + ````{tabs} + ```{group-tab} Windows and Linux + ```bash + conda create -y -n sleap -c conda-forge -c nvidia -c sleap/label/dev -c sleap -c anaconda sleap=1.4.1 + ``` + ```{note} + - This comes with CUDA to enable GPU support. All you need is to have an NVIDIA GPU and [updated drivers](https://nvidia.com/drivers). + - If you already have CUDA installed on your system, this will not conflict with it. + - This will also work in CPU mode if you don't have a GPU on your machine. + ``` + ``` + ```{group-tab} Mac OS + ```bash + conda create -y -n sleap -c conda-forge -c anaconda -c sleap sleap=1.4.1 + ``` + ```{note} + This will also work in CPU mode if you don't have a GPU on your machine. + ``` + ``` + ```` - We strongly recommend that you **only use this method if you know what you're doing**! ``` + ```{tab} conda from source + This is the **recommended method for development**. + 1. First, ensure git is installed: + ```bash + git --version + ``` + If `git` is not recognized, then [install git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git). + 2. Then, clone the repository: + ```bash + git clone https://github.com/talmolab/sleap && cd sleap + ``` + 3. Finally, install SLEAP from the environment file: + ````{tabs} + ```{group-tab} Windows and Linux + ````{tabs} + ```{group-tab} NVIDIA GPU + ```bash + conda env create -f environment.yml -n sleap + ``` + ``` + ```{group-tab} CPU or other GPU + ```bash + conda env create -f environment_no_cuda.yml -n sleap + ``` + ``` + ```` + ``` + ```{group-tab} Mac OS + ```bash + conda env create -f environment_mac.yml -n sleap + ``` + ``` + ```` + ```{note} + - This installs SLEAP in development mode, which means that edits to the source code will be applied the next time you run SLEAP. + - Change the `-n sleap` in the command to create an environment with a different name (e.g., `-n sleap_develop`). + ``` + ``` + ```{tab} pip package + This is the **recommended method for Google Colab only**. + ```{warning} + This will uninstall existing libraries and potentially install conflicting ones. + + We strongly recommend that you **only use this method if you know what you're doing**! + ``` + ````{tabs} + ```{group-tab} Windows and Linux + ```{note} + - Requires Python 3.7 + - To enable GPU support, make sure that you have **CUDA Toolkit v11.3** and **cuDNN v8.2** installed. + ``` + Although you do not need Miniconda installed to perform a `pip install`, we recommend [installing Miniconda](https://docs.anaconda.com/free/miniconda/) to create a new environment where we can isolate the `pip install`. Alternatively, you can use a venv if you have an existing Python 3.7 installation. 
If you are working on **Google Colab**, skip to step 3 to perform the `pip install` without using a conda environment.
+      1. Otherwise, create a new conda environment where we will `pip install sleap`:
+        ````{tabs}
+          ```{group-tab} NVIDIA GPU
+          ```bash
+          conda create --name sleap pip python=3.7.12 cudatoolkit=11.3 cudnn=8.2 -c conda-forge -c nvidia
+          ```
+          ```
+          ```{group-tab} CPU or other GPU
+          ```bash
+          conda create --name sleap pip python=3.7.12
+          ```
+          ```
+        ````
+      2. Then activate the environment to isolate the `pip install` from other environments on your computer:
+        ```bash
+        conda activate sleap
+        ```
+        ```{warning}
+        Refrain from installing anything into the `base` environment. Always create a new environment to install new packages.
+        ```
+      3. Finally, we can perform the `pip install`:
+        ```bash
+        pip install sleap[pypi]==1.4.1
+        ```
+        ```{note}
+        The pypi distributed package of SLEAP ships with the following extras:
+        - **pypi**: For installation without a conda environment file. All dependencies come from PyPI.
+        - **jupyter**: This installs all *pypi* and jupyter lab dependencies.
+        - **dev**: This installs all *jupyter* dependencies and development tools for testing and building docs.
+        - **conda_jupyter**: For installation using a conda environment file included in the source code. Most dependencies are listed as conda packages in the environment file and only a few come from PyPI to allow jupyter lab support.
+        - **conda_dev**: For installation using [a conda environment](https://github.com/search?q=repo%3Atalmolab%2Fsleap+path%3Aenvironment*.yml&type=code) with a few PyPI dependencies for development tools.
+        ```
+      ```
+      ```{group-tab} Mac OS
+      Not supported.
+      ```
+      ````
+  ```
+`````

## Testing that things are working

If you installed using `conda`, first activate the `sleap` environment by opening a terminal and typing:

```bash
-mamba activate sleap
+conda activate sleap
```

````{hint}
Not sure what `conda` environments you already installed? You can get a list of the environments on your system with:

```
-mamba env list
+conda env list
```
````

@@ -301,7 +301,7 @@ python -c "import sleap; sleap.versions()"

### GPU support

-Assuming you installed using either of the `mamba`-based methods on Windows or Linux, SLEAP should automatically have GPU support enabled.
+Assuming you installed using either of the `conda`-based methods on Windows or Linux, SLEAP should automatically have GPU support enabled.

To check, verify that SLEAP can detect the GPUs on your system:

@@ -362,7 +362,7 @@ file: No such file or directory

then activate the environment:

```bash
-mamba activate sleap
+conda activate sleap
```

and run the commands:

@@ -391,13 +391,13 @@ We **strongly recommend** installing SLEAP in a fresh environment when updating.

To uninstall an existing environment named `sleap`:

```bash
-mamba env remove -n sleap
+conda env remove -n sleap
```

````{hint}
-Not sure what `mamba` environments you already installed?
You can get a list of the environments on your system with: ```bash -mamba env list +conda env list ``` ```` @@ -413,10 +413,10 @@ If you get any errors or the GUI fails to launch, try running the diagnostics to sleap-diagnostic ``` -If you were not able to get SLEAP installed, activate the mamba environment it is in and generate a list of the package versions installed: +If you were not able to get SLEAP installed, activate the conda environment it is in and generate a list of the package versions installed: ```bash -mamba list +conda list ``` Then, [open a new Issue](https://github.com/talmolab/sleap/issues) providing the versions from either command above, as well as any errors you saw in the console during the installation. Or [start a discussion](https://github.com/talmolab/sleap/discussions) to get help from the community. diff --git a/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb b/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb index b0211bbca..4e26cb286 100644 --- a/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb +++ b/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb @@ -335,7 +335,7 @@ " \"runs_folder\": \"models\",\n", " \"tags\": [],\n", " \"save_visualizations\": true,\n", - " \"delete_viz_images\": true,\n", + " \"keep_viz_images\": true,\n", " \"zip_outputs\": false,\n", " \"log_to_csv\": true,\n", " \"checkpointing\": {\n", @@ -727,7 +727,7 @@ " \"runs_folder\": \"models\",\n", " \"tags\": [],\n", " \"save_visualizations\": true,\n", - " \"delete_viz_images\": true,\n", + " \"keep_viz_images\": true,\n", " \"zip_outputs\": false,\n", " \"log_to_csv\": true,\n", " \"checkpointing\": {\n", diff --git a/environment.yml b/environment.yml index 67ed39d01..d8f752759 100644 --- a/environment.yml +++ b/environment.yml @@ -10,14 +10,15 @@ channels: dependencies: # Packages SLEAP uses directly - - conda-forge::attrs >=21.2.0 #,<=21.4.0 + - conda-forge::attrs >=21.2.0 - conda-forge::cattrs ==1.1.1 - - conda-forge::imgaug ==0.4.0 + - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin - conda-forge::jsonpickle ==1.2 - conda-forge::networkx - anaconda::numpy >=1.19.5,<1.23.0 - - conda-forge::opencv + - conda-forge::opencv <4.9.0 + - conda-forge::h5py <=3.7.0 - conda-forge::pandas - conda-forge::pip - conda-forge::pillow #>=8.3.1,<=8.4.0 @@ -35,8 +36,11 @@ dependencies: - conda-forge::scikit-learn ==1.0 - conda-forge::scikit-video - conda-forge::seaborn - - sleap::tensorflow >=2.6.3,<2.11 # No windows GPU support for >2.10 + - sleap/label/dev::tensorflow ==2.7.0 # TODO: Switch to main label when updated - conda-forge::tensorflow-hub # Pinned in meta.yml, but no problems here... 
yet + - conda-forge::qudida + - conda-forge::albumentations + - conda-forge::ndx-pose <0.2.0 # Packages required by tensorflow to find/use GPUs - conda-forge::cudatoolkit ==11.3.1 @@ -46,4 +50,3 @@ dependencies: - pip: - "--editable=.[conda_dev]" - \ No newline at end of file diff --git a/environment_mac.yml b/environment_mac.yml index 85ef7d3b9..2026154fa 100644 --- a/environment_mac.yml +++ b/environment_mac.yml @@ -9,13 +9,14 @@ channels: dependencies: # Packages SLEAP uses directly - conda-forge::attrs >=21.2.0 + - conda-forge::importlib-metadata <7.1.0 - conda-forge::cattrs ==1.1.1 - conda-forge::h5py - - conda-forge::imgaug ==0.4.0 + - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin - conda-forge::jsonpickle ==1.2 - conda-forge::keras <2.10.0,>=2.9.0rc0 # Required by tensorflow-macos - - conda-forge::networkx + - conda-forge::networkx <3.3 - anaconda::numpy >=1.19.5,<1.23.0 - conda-forge::opencv - conda-forge::pandas @@ -35,6 +36,9 @@ dependencies: - conda-forge::scikit-learn ==1.0 - conda-forge::scikit-video - conda-forge::seaborn - - conda-forge::tensorflow-hub + # - conda-forge::tensorflow-hub # pulls in tensorflow cpu from conda-forge + - conda-forge::qudida + - conda-forge::albumentations + - conda-forge::ndx-pose <0.2.0 - pip: - - "--editable=.[conda_dev]" + - "--editable=.[conda_dev]" \ No newline at end of file diff --git a/environment_no_cuda.yml b/environment_no_cuda.yml index 7e384b5f9..721c27fca 100644 --- a/environment_no_cuda.yml +++ b/environment_no_cuda.yml @@ -11,14 +11,14 @@ channels: dependencies: # Packages SLEAP uses directly - - conda-forge::attrs >=21.2.0 #,<=21.4.0 + - conda-forge::attrs >=21.2.0 - conda-forge::cattrs ==1.1.1 - - conda-forge::imgaug ==0.4.0 + - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin - conda-forge::jsonpickle ==1.2 - conda-forge::networkx - anaconda::numpy >=1.19.5,<1.23.0 - - conda-forge::opencv + - conda-forge::opencv <4.9.0 - conda-forge::pandas - conda-forge::pip - conda-forge::pillow #>=8.3.1,<=8.4.0 @@ -36,8 +36,12 @@ dependencies: - conda-forge::scikit-learn ==1.0 - conda-forge::scikit-video - conda-forge::seaborn - - sleap::tensorflow >=2.6.3,<2.11 # No windows GPU support for >2.10 + # - sleap::tensorflow >=2.6.3,<2.11 # No windows GPU support for >2.10 + - sleap/label/dev::tensorflow ==2.7.0 - conda-forge::tensorflow-hub + - conda-forge::qudida + - conda-forge::albumentations + - conda-forge::ndx-pose <0.2.0 - pip: - - "--editable=.[conda_dev]" + - "--editable=.[conda_dev]" \ No newline at end of file diff --git a/pypi_requirements.txt b/pypi_requirements.txt index 33f419c9c..775ce584e 100644 --- a/pypi_requirements.txt +++ b/pypi_requirements.txt @@ -3,15 +3,17 @@ # setup.py, the packages in requirements.txt will also be installed when running # pip install sleap[pypi]. -# These are also distrubuted through conda and not pip installed when using conda. +# These are also distributed through conda and not pip installed when using conda. 
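Several of the dependency files above drop `imgaug` (augmentation now comes from `albumentations`) and add `imageio-ffmpeg` so that imageio can read and write videos through ffmpeg. A quick sanity check that the ffmpeg plugin is picked up (the video path is hypothetical):

```python
import imageio.v2 as imageio  # imageio-ffmpeg provides the ffmpeg backend

reader = imageio.get_reader("example.mp4")  # any video readable by ffmpeg
print(reader.get_meta_data()["fps"], reader.get_data(0).shape)
reader.close()
```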
attrs>=21.2.0,<=21.4.0 cattrs==1.1.1 +imageio +imageio-ffmpeg # certifi>=2017.4.17,<=2021.10.8 jsmin jsonpickle==1.2 networkx numpy>=1.19.5,<1.23.0 -opencv-python>=4.2.0,<=4.6.0 +opencv-python>=4.2.0,<=4.7.0 pandas pillow>=8.3.1,<=8.4.0 psutil @@ -32,7 +34,10 @@ scikit-learn ==1.0.* scikit-video seaborn tensorflow>=2.6.3,<2.9; platform_machine != 'arm64' +# tensorflow ==2.7.4; platform_machine != 'arm64' tensorflow-hub<=0.14.0 +albumentations +ndx-pose<0.2.0 # These dependencies are untested since we do not offer a wheel for apple silicon atm. tensorflow-macos==2.9.2; sys_platform == 'darwin' and platform_machine == 'arm64' tensorflow-metal==0.5.0; sys_platform == 'darwin' and platform_machine == 'arm64' diff --git a/requirements.txt b/requirements.txt index cb0ef45c5..5db435ec8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,12 @@ # This file contains the minimal requirements to be installed via pip when using conda. # No conda packages for these -imgstore<0.3.0 # 0.3.3 results in https://github.com/O365/python-o365/issues/591 -ndx-pose +imgstore<0.3.0 # 0.3.3 results in https://github.com/O365/python-o365/issues/591 which is from https://github.com/regebro/tzlocal/issues/112 when tzlocal is v3.0 nixio>=1.5.3 # Constrain put on by @jgrewe from G-Node qimage2ndarray # ==1.9.0 segmentation-models tensorflow-macos==2.9.2; sys_platform == 'darwin' and platform_machine == 'arm64' tensorflow-metal==0.5.0; sys_platform == 'darwin' and platform_machine == 'arm64' +tensorflow-hub==0.12.0; sys_platform == 'darwin' and platform_machine == 'arm64' -# Conda installing results in https://github.com/h5py/h5py/issues/2037 -h5py<3.2; sys_platform == 'win32' # Newer versions result in error above, linking issue in Linux pynwb>=2.3.3 # 2.0.0 required by ndx-pose, 2.3.3 fixes importlib-metadata incompatibility diff --git a/sleap/config/frame_range_form.yaml b/sleap/config/frame_range_form.yaml new file mode 100644 index 000000000..3f01eade4 --- /dev/null +++ b/sleap/config/frame_range_form.yaml @@ -0,0 +1,13 @@ +main: + + - name: min_frame_idx + label: Minimum frame index + type: int + range: 1,1000000 + default: 1 + + - name: max_frame_idx + label: Maximum frame index + type: int + range: 1,1000000 + default: 1000 \ No newline at end of file diff --git a/sleap/config/labeled_clip_form.yaml b/sleap/config/labeled_clip_form.yaml index be0d64829..9236ad42b 100644 --- a/sleap/config/labeled_clip_form.yaml +++ b/sleap/config/labeled_clip_form.yaml @@ -18,6 +18,10 @@ main: label: Use GUI Visual Settings (colors, line widths) type: bool default: true + - name: background + label: Video Background + type: list + options: original,black,white,grey - name: open_when_done label: Open When Done Saving type: bool diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index cbcea2be5..1bb930e58 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -52,7 +52,7 @@ training: This pipeline uses two models: a "centroid" model to locate and crop around each animal in the frame, and a "centered-instance confidence map" model for predicted node locations - for each individual animal predicted by the centroid model.' + for each individual animal predicted by the centroid model.' 
- label: Max Instances name: max_instances type: optional_int @@ -211,6 +211,21 @@ training: options: ',RGB,grayscale' type: list +- type: text + text: 'ZMQ Options' + +- name: controller_port + label: Controller Port + type: int + default: 9000 + range: 1024,65535 + +- name: publish_port + label: Publish Port + type: int + default: 9001 + range: 1024,65535 + - type: text text: 'Output Options' @@ -271,6 +286,11 @@ training: type: bool default: true +- name: _keep_viz + label: Keep Prediction Visualization Images After Training + type: bool + default: false + - name: _predict_frames label: Predict On type: list @@ -287,7 +307,7 @@ inference: label: Training/Inference Pipeline Type type: stacked default: "multi-animal bottom-up " - options: "multi-animal bottom-up,multi-animal top-down,multi-animal bottom-up-id,multi-animal top-down-id,single animal,movenet-lightning,movenet-thunder,none" + options: "multi-animal bottom-up,multi-animal top-down,multi-animal bottom-up-id,multi-animal top-down-id,single animal,movenet-lightning,movenet-thunder,tracking-only" multi-animal bottom-up: - type: text @@ -365,7 +385,13 @@ inference: Note that this model is intended for human pose estimation. There is no support for videos containing more than one instance' - none: + tracking-only: + +- name: batch_size + label: Batch Size + type: int + default: 4 + range: 1,512 - name: tracking.tracker label: Tracker (cross-frame identity) Method @@ -413,7 +439,7 @@ inference: label: Similarity Method type: list default: instance - options: instance,centroid,iou + options: "instance,normalized_instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list @@ -452,6 +478,22 @@ inference: label: Nodes to use for Tracking type: string default: 0,1,2 + - type: text + text: 'Object keypoint similarity options:
+        Only used if this similarity method is selected.'
+      - name: tracking.oks_errors
+        label: Keypoint errors in pixels
+        help: 'Standard error in pixels of the distance for each keypoint.
+          If the list is empty, defaults to 1. If a singleton list, each keypoint has
+          the same error. Otherwise, the length should be the same as the number of
+          keypoints in the skeleton.'
+        type: string
+        default:
+      - name: tracking.oks_score_weighting
+        label: Use prediction score for weighting
+        help: 'Use prediction scores to weight the similarity of each keypoint'
+        type: bool
+        default: false
       - type: text
         text: 'Post-tracker data cleaning:'
       - name: tracking.post_connect_single_breaks
@@ -495,8 +537,8 @@ inference:
   - name: tracking.similarity
     label: Similarity Method
     type: list
-    default: iou
-    options: instance,centroid,iou
+    default: instance
+    options: "instance,normalized_instance,centroid,iou,object keypoint"
   - name: tracking.match
     label: Matching Method
     type: list
@@ -531,6 +573,22 @@ inference:
     label: Nodes to use for Tracking
     type: string
     default: 0,1,2
+      - type: text
+        text: 'Object keypoint similarity options:
+        Only used if this similarity method is selected.'
+      - name: tracking.oks_errors
+        label: Keypoint errors in pixels
+        help: 'Standard error in pixels of the distance for each keypoint.
+          If the list is empty, defaults to 1. If a singleton list, each keypoint has
+          the same error. Otherwise, the length should be the same as the number of
+          keypoints in the skeleton.'
+        type: string
+        default:
+      - name: tracking.oks_score_weighting
+        label: Use prediction score for weighting
+        help: 'Use prediction scores to weight the similarity of each keypoint'
+        type: bool
+        default: false
       - type: text
         text: 'Post-tracker data cleaning:'
       - name: tracking.post_connect_single_breaks
diff --git a/sleap/config/suggestions.yaml b/sleap/config/suggestions.yaml
index 8cf89728a..1440530fc 100644
--- a/sleap/config/suggestions.yaml
+++ b/sleap/config/suggestions.yaml
@@ -3,7 +3,7 @@ main:
     label: Method
     type: stacked
     default: " "
-    options: " ,image features,sample,prediction score,velocity,frame chunk"
+    options: " ,image features,sample,prediction score,velocity,frame chunk,max point displacement"

   " ":
   sample:
@@ -175,6 +175,13 @@ main:
     type: double
     default: 0.1
     range: 0.1,1.0
+
+  "max point displacement":
+  - name: displacement_threshold
+    label: Maximum Displacement Threshold
+    type: int
+    default: 10
+    range: 0,999

   - name: target
     label: Target
diff --git a/sleap/config/training_editor_form.yaml b/sleap/config/training_editor_form.yaml
index d10b840a0..7d7972892 100644
--- a/sleap/config/training_editor_form.yaml
+++ b/sleap/config/training_editor_form.yaml
@@ -44,7 +44,7 @@ model:
   label: Max Stride
   name: model.backbone.hourglass.max_stride
   type: list
-  options: 1,2,4,8,16,32,64
+  options: 1,2,4,8,16,32,64,128
 # - default: 4
 #   help: Determines the number of upsampling blocks in the network.
 #   label: Output Stride
@@ -81,7 +81,7 @@ model:
   label: Max Stride
   name: model.backbone.leap.max_stride
   type: list
-  options: 2,4,8,16,32,64
+  options: 2,4,8,16,32,64,128
 # - default: 1
 #   help: Determines the number of upsampling blocks in the network.
 #   label: Output Stride
@@ -190,7 +190,7 @@ model:
   label: Max Stride
   name: model.backbone.resnet.max_stride
   type: list
-  options: 2,4,8,16,32,64
+  options: 2,4,8,16,32,64,128
 # - default: 4
 #   help: Stride of the final output. If the upsampling branch is not defined, the
 #   output stride is controlled via dilated convolutions or reduced pooling in the
@@ -250,7 +250,7 @@ model:
   label: Max Stride
   name: model.backbone.unet.max_stride
   type: list
-  options: 2,4,8,16,32,64
+  options: 2,4,8,16,32,64,128
 # - default: 1
 #   help: Determines the number of upsampling blocks in the network.
 #   label: Output Stride
@@ -661,6 +661,7 @@ optimization:
   label: Batch Size
   name: optimization.batch_size
   type: int
+  range: 1,512
 - default: 100
   help: Maximum number of epochs to train for. Training can be stopped manually or automatically if early stopping is enabled and a plateau is detected.
   label: Epochs
diff --git a/sleap/gui/app.py b/sleap/gui/app.py
index de6ce9fbf..8b711c806 100644
--- a/sleap/gui/app.py
+++ b/sleap/gui/app.py
@@ -44,13 +44,16 @@ frame and instances listed in data view table.
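For intuition about the "max point displacement" suggestion method added in `suggestions.yaml` above, here is a hypothetical sketch of the idea: flag frames where any tracked point moves farther than `displacement_threshold` pixels relative to the previous frame. The function name and array layout are illustrative, not SLEAP internals:

```python
import numpy as np

def max_point_displacement_frames(points: np.ndarray, threshold: float = 10.0) -> np.ndarray:
    """points: (n_frames, n_tracks, n_nodes, 2) array, NaN where data is missing.

    Returns indices of frames whose largest per-point displacement from the
    previous frame exceeds the threshold (all-NaN frames are never flagged).
    """
    disp = np.linalg.norm(np.diff(points, axis=0), axis=-1)  # (n_frames - 1, n_tracks, n_nodes)
    max_disp = np.nanmax(disp.reshape(disp.shape[0], -1), axis=1)  # NaN for empty frames
    return np.flatnonzero(max_disp > threshold) + 1  # +1 because diff() drops frame 0
```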
""" - import os import platform import random import re +import traceback +from logging import getLogger from pathlib import Path from typing import Callable, List, Optional, Tuple +import sys +import subprocess from qtpy import QtCore, QtGui from qtpy.QtCore import QEvent, Qt @@ -82,7 +85,10 @@ from sleap.io.video import available_video_exts from sleap.prefs import prefs from sleap.skeleton import Skeleton -from sleap.util import parse_uri_path +from sleap.util import parse_uri_path, get_config_file + + +logger = getLogger(__name__) class MainWindow(QMainWindow): @@ -101,6 +107,7 @@ class MainWindow(QMainWindow): def __init__( self, labels_path: Optional[str] = None, + labels: Optional[Labels] = None, reset: bool = False, no_usage_data: bool = False, *args, @@ -118,7 +125,7 @@ def __init__( self.setAcceptDrops(True) self.state = GuiState() - self.labels = Labels() + self.labels = labels or Labels() self.commands = CommandContext( state=self.state, app=self, update_callback=self.on_data_update @@ -145,6 +152,7 @@ def __init__( self.state["edge style"] = prefs["edge style"] self.state["fit"] = False self.state["color predicted"] = prefs["color predicted"] + self.state["trail_length"] = prefs["trail length"] self.state["trail_shade"] = prefs["trail shade"] self.state["marker size"] = prefs["marker size"] self.state["propagate track labels"] = prefs["propagate track labels"] @@ -175,8 +183,10 @@ def __init__( print("Restoring GUI state...") self.restoreState(prefs["window state"]) - if labels_path: + if labels_path is not None: self.commands.loadProjectFile(filename=labels_path) + elif labels is not None: + self.commands.loadLabelsObject(labels=labels) else: self.state["project_loaded"] = False @@ -213,6 +223,7 @@ def closeEvent(self, event): prefs["edge style"] = self.state["edge style"] prefs["propagate track labels"] = self.state["propagate track labels"] prefs["color predicted"] = self.state["color predicted"] + prefs["trail length"] = self.state["trail_length"] prefs["trail shade"] = self.state["trail_shade"] prefs["share usage data"] = self.state["share usage data"] @@ -254,7 +265,6 @@ def dragEnterEvent(self, event): event.acceptProposedAction() def dropEvent(self, event): - # Parse filenames filenames = event.mimeData().data("text/uri-list").data().decode() filenames = [parse_uri_path(f.strip()) for f in filenames.strip().split("\n")] @@ -367,7 +377,9 @@ def add_menu_item(menu, key: str, name: str, action: Callable): def connect_check(key): self._menu_actions[key].setCheckable(True) self._menu_actions[key].setChecked(self.state[key]) - self.state.connect(key, self._menu_actions[key].setChecked) + self.state.connect( + key, lambda checked: self._menu_actions[key].setChecked(checked) + ) # add checkable menu item connected to state variable def add_menu_check_item(menu, key: str, name: str): @@ -506,6 +518,13 @@ def add_submenu_choices(menu, title, options, key): fileMenu, "reset prefs", "Reset preferences to defaults...", self.resetPrefs ) + add_menu_item( + fileMenu, + "open preference directory", + "Open Preferences Directory...", + self.openPrefs, + ) + fileMenu.addSeparator() add_menu_item(fileMenu, "close", "Quit", self.close) @@ -637,17 +656,18 @@ def prev_vid(): key="edge style", ) + # XXX add_submenu_choices( menu=viewMenu, title="Node Marker Size", - options=(1, 2, 4, 6, 8, 12), + options=prefs["node marker sizes"], key="marker size", ) add_submenu_choices( menu=viewMenu, title="Node Label Size", - options=(6, 12, 18, 24, 36), + options=prefs["node label sizes"], key="node 
label size", ) @@ -686,13 +706,17 @@ def prev_vid(): ) def new_instance_menu_action(): + """Determine which action to use when using Ctrl + I or menu Add Instance. + + We always add an offset of 10. + """ method_key = [ key for (key, val) in instance_adding_methods.items() if val == self.state["instance_init_method"] ] if method_key: - self.commands.newInstance(init_method=method_key[0]) + self.commands.newInstance(init_method=method_key[0], offset=10) labelMenu = self.menuBar().addMenu("Labels") add_menu_item( @@ -735,12 +759,12 @@ def new_instance_menu_action(): labelMenu.addAction( "Copy Instance", self.commands.copyInstance, - Qt.CTRL + Qt.Key_C, + Qt.CTRL | Qt.Key_C, ) labelMenu.addAction( "Paste Instance", self.commands.pasteInstance, - Qt.CTRL + Qt.Key_V, + Qt.CTRL | Qt.Key_V, ) labelMenu.addSeparator() @@ -775,6 +799,12 @@ def new_instance_menu_action(): "Delete Predictions with Low Score...", self.commands.deleteLowScorePredictions, ) + add_menu_item( + labelMenu, + "delete max instance predictions", + "Delete Predictions beyond Max Instances...", + self.commands.deleteInstanceLimitPredictions, + ) add_menu_item( labelMenu, "delete frame limit predictions", @@ -834,12 +864,12 @@ def new_instance_menu_action(): tracksMenu.addAction( "Copy Instance Track", self.commands.copyInstanceTrack, - Qt.CTRL + Qt.SHIFT + Qt.Key_C, + Qt.CTRL | Qt.SHIFT | Qt.Key_C, ) tracksMenu.addAction( "Paste Instance Track", self.commands.pasteInstanceTrack, - Qt.CTRL + Qt.SHIFT + Qt.Key_V, + Qt.CTRL | Qt.SHIFT | Qt.Key_V, ) tracksMenu.addSeparator() @@ -850,6 +880,8 @@ def new_instance_menu_action(): "Point Displacement (max)", "Primary Point Displacement (sum)", "Primary Point Displacement (max)", + "Tracking Score (mean)", + "Tracking Score (min)", "Instance Score (sum)", "Instance Score (min)", "Point Score (sum)", @@ -1018,6 +1050,7 @@ def _load_overlays(self): labels=self.labels, player=self.player, trail_shade=self.state["trail_shade"], + trail_length=self.state["trail_length"], ) self.overlays["instance"] = InstanceOverlay( labels=self.labels, player=self.player, state=self.state @@ -1307,7 +1340,7 @@ def updateStatusMessage(self, message: Optional[str] = None): message += f" [Hidden] Press '{hide_key}' to toggle." 
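# Aside on the shortcut changes in this hunk: newer Qt bindings (PyQt6/PySide6,
# which qtpy can wrap) removed "+" arithmetic on key enums, so key combinations
# such as Ctrl+C must now be built with the bitwise OR operator. A minimal,
# hedged illustration (assuming qtpy is available):
#
#     from qtpy.QtCore import Qt
#     from qtpy.QtGui import QKeySequence
#
#     copy_shortcut = QKeySequence(Qt.CTRL | Qt.Key_C)  # was: Qt.CTRL + Qt.Key_C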
self.statusBar().setStyleSheet("color: red") else: - self.statusBar().setStyleSheet("color: black") + self.statusBar().setStyleSheet("") self.statusBar().showMessage(message) @@ -1320,14 +1353,42 @@ def resetPrefs(self): ) msg.exec_() + def openPrefs(self): + """Open preference file directory""" + pref_path = get_config_file("preferences.yaml") + # Make sure the pref_path is a directory rather than a file + if pref_path.is_file(): + pref_path = pref_path.parent + # Open the file explorer at the folder containing the preferences.yaml file + if sys.platform == "win32": + subprocess.Popen(["explorer", str(pref_path)]) + elif sys.platform == "darwin": + subprocess.Popen(["open", str(pref_path)]) + else: + subprocess.Popen(["xdg-open", str(pref_path)]) + def _update_track_menu(self): """Updates track menu options.""" self.track_menu.clear() self.delete_tracks_menu.clear() + + # Create a dictionary mapping track indices to Qt.Key values + key_mapping = { + 0: Qt.Key_1, + 1: Qt.Key_2, + 2: Qt.Key_3, + 3: Qt.Key_4, + 4: Qt.Key_5, + 5: Qt.Key_6, + 6: Qt.Key_7, + 7: Qt.Key_8, + 8: Qt.Key_9, + 9: Qt.Key_0, + } for track_ind, track in enumerate(self.labels.tracks): key_command = "" if track_ind < 9: - key_command = Qt.CTRL + Qt.Key_0 + self.labels.tracks.index(track) + 1 + key_command = Qt.CTRL | key_mapping[track_ind] self.track_menu.addAction( f"{track.name}", lambda x=track: self.commands.setInstanceTrack(x), @@ -1337,7 +1398,7 @@ def _update_track_menu(self): f"{track.name}", lambda x=track: self.commands.deleteTrack(x) ) self.track_menu.addAction( - "New Track", self.commands.addTrack, Qt.CTRL + Qt.Key_0 + "New Track", self.commands.addTrack, Qt.CTRL | Qt.Key_0 ) def _update_seekbar_marks(self): @@ -1354,6 +1415,8 @@ def _set_seekbar_header(self, graph_name: str): "Point Displacement (max)": data_obj.get_point_displacement_series, "Primary Point Displacement (sum)": data_obj.get_primary_point_displacement_series, "Primary Point Displacement (max)": data_obj.get_primary_point_displacement_series, + "Tracking Score (mean)": data_obj.get_tracking_score_series, + "Tracking Score (min)": data_obj.get_tracking_score_series, "Instance Score (sum)": data_obj.get_instance_score_series, "Instance Score (min)": data_obj.get_instance_score_series, "Point Score (sum)": data_obj.get_point_score_series, @@ -1367,7 +1430,7 @@ def _set_seekbar_header(self, graph_name: str): else: if graph_name in header_functions: kwargs = dict(video=self.state["video"]) - reduction_name = re.search("\\((sum|max|min)\\)", graph_name) + reduction_name = re.search("\\((sum|max|min|mean)\\)", graph_name) if reduction_name is not None: kwargs["reduction"] = reduction_name.group(1) series = header_functions[graph_name](**kwargs) @@ -1594,8 +1657,12 @@ def _show_keyboard_shortcuts_window(self): ShortcutDialog().exec_() -def main(args: Optional[list] = None): - """Starts new instance of app.""" +def create_sleap_label_parser(): + """Creates parser for `sleap-label` command line arguments. + + Returns: + argparse.ArgumentParser: The parser. 
+ """ import argparse @@ -1635,6 +1702,23 @@ def main(args: Optional[list] = None): default=False, ) + return parser + + +def create_app(): + """Creates Qt application.""" + + app = QApplication([]) + app.setApplicationName(f"SLEAP v{sleap.version.__version__}") + app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png"))) + + return app + + +def main(args: Optional[list] = None, labels: Optional[Labels] = None): + """Starts new instance of app.""" + + parser = create_sleap_label_parser() args = parser.parse_args(args) if args.nonnative: @@ -1646,17 +1730,26 @@ def main(args: Optional[list] = None): # https://stackoverflow.com/q/64818879 os.environ["QT_MAC_WANTS_LAYER"] = "1" - app = QApplication([]) - app.setApplicationName(f"SLEAP v{sleap.version.__version__}") - app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png"))) + app = create_app() window = MainWindow( - labels_path=args.labels_path, reset=args.reset, no_usage_data=args.no_usage_data + labels_path=args.labels_path, + labels=labels, + reset=args.reset, + no_usage_data=args.no_usage_data, ) window.showMaximized() # Disable GPU in GUI process. This does not affect subprocesses. - sleap.use_cpu_only() + try: + sleap.use_cpu_only() + except RuntimeError: # Visible devices cannot be modified after being initialized + logger.warning( + "Running processes on the GPU. Restarting your GUI should allow switching " + "back to CPU-only mode.\n" + "Received the following error when trying to switch back to CPU-only mode:" + ) + traceback.print_exc() # Print versions. print() diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 698eed756..fca982327 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -36,7 +36,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from enum import Enum from glob import glob from pathlib import Path, PurePath -from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast import attr import cv2 @@ -49,7 +49,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from sleap.gui.dialogs.merge import MergeDialog, ReplaceSkeletonTableDialog from sleap.gui.dialogs.message import MessageDialog from sleap.gui.dialogs.missingfiles import MissingFilesDialog -from sleap.gui.dialogs.query import QueryDialog +from sleap.gui.dialogs.frame_range import FrameRangeDialog from sleap.gui.state import GuiState from sleap.gui.suggestions import VideoFrameSuggestions from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track @@ -260,16 +260,15 @@ def loadLabelsObject(self, labels: Labels, filename: Optional[str] = None): """ self.execute(LoadLabelsObject, labels=labels, filename=filename) - def loadProjectFile(self, filename: str): + def loadProjectFile(self, filename: Union[str, Labels]): """Loads given labels file into GUI. Args: - filename: The path to the saved labels dataset. If None, - then don't do anything. + filename: The path to the saved labels dataset or the `Labels` object. + If None, then don't do anything. 
Returns: None - """ self.execute(LoadProjectFile, filename=filename) @@ -492,8 +491,12 @@ def deleteLowScorePredictions(self): """Gui for deleting instances below some score threshold.""" self.execute(DeleteLowScorePredictions) - def deleteFrameLimitPredictions(self): + def deleteInstanceLimitPredictions(self): """Gui for deleting instances beyond some number in each frame.""" + self.execute(DeleteInstanceLimitPredictions) + + def deleteFrameLimitPredictions(self): + """Gui for deleting instances beyond some frame number.""" self.execute(DeleteFrameLimitPredictions) def completeInstanceNodes(self, instance: Instance): @@ -506,6 +509,7 @@ def newInstance( init_method: str = "best", location: Optional[QtCore.QPoint] = None, mark_complete: bool = False, + offset: int = 0, ): """Creates a new instance, copying node coordinates as appropriate. @@ -515,6 +519,8 @@ def newInstance( init_method: Method to use for positioning nodes. location: The location where instance should be added (if node init method supports custom location). + mark_complete: Whether to mark the instance as complete. + offset: Offset to apply to the location if given. """ self.execute( AddInstance, @@ -522,6 +528,7 @@ def newInstance( init_method=init_method, location=location, mark_complete=mark_complete, + offset=offset, ) def setPointLocations( @@ -647,9 +654,8 @@ def do_action(context: "CommandContext", params: dict): Returns: None. - """ - filename = params["filename"] + filename = params.get("filename", None) # If called with just a Labels object labels: Labels = params["labels"] context.state["labels"] = labels @@ -669,7 +675,9 @@ def do_action(context: "CommandContext", params: dict): context.state["video"] = labels.videos[0] context.state["project_loaded"] = True - context.state["has_changes"] = params.get("changed_on_load", False) + context.state["has_changes"] = params.get("changed_on_load", False) or ( + filename is None + ) # This is not listed as an edit command since we want a clean changestack context.app.on_data_update([UpdateTopic.project, UpdateTopic.all]) @@ -683,17 +691,16 @@ def ask(context: "CommandContext", params: dict): if len(filename) == 0: return - gui_video_callback = Labels.make_gui_video_callback( - search_paths=[os.path.dirname(filename)], context=params - ) - has_loaded = False labels = None - if type(filename) == Labels: + if isinstance(filename, Labels): labels = filename filename = None has_loaded = True else: + gui_video_callback = Labels.make_gui_video_callback( + search_paths=[os.path.dirname(filename)], context=params + ) try: labels = Labels.load_file(filename, video_search=gui_video_callback) has_loaded = True @@ -751,7 +758,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportAlphaTracker(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - video_path = params["video_path"] if "video_path" in params else None labels = Labels.load_alphatracker( @@ -791,7 +797,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportNWB(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_nwb(filename=params["filename"]) new_window = context.app.__class__() @@ -824,7 +829,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportDeepPoseKit(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.from_deepposekit( filename=params["filename"], video_path=params["video_path"], @@ -873,7 +877,6 @@ def ask(context: 
"CommandContext", params: dict) -> bool: class ImportLEAP(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_leap_matlab( filename=params["filename"], ) @@ -904,7 +907,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportCoco(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_coco( filename=params["filename"], img_dir=params["img_dir"], use_missing_gui=True ) @@ -936,7 +938,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportDeepLabCut(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_deeplabcut(filename=params["filename"]) new_window = context.app.__class__() @@ -1295,6 +1296,7 @@ def do_action(context: CommandContext, params: dict): frames=list(params["frames"]), fps=params["fps"], color_manager=params["color_manager"], + background=params["background"], show_edges=params["show edges"], edge_is_wedge=params["edge_is_wedge"], marker_size=params["marker size"], @@ -1309,7 +1311,6 @@ def do_action(context: CommandContext, params: dict): @staticmethod def ask(context: CommandContext, params: dict) -> bool: - from sleap.gui.dialogs.export_clip import ExportClipDialog dialog = ExportClipDialog() @@ -1333,17 +1334,15 @@ def ask(context: CommandContext, params: dict) -> bool: # makes mp4's that most programs can't open (VLC can). default_out_filename = context.state["filename"] + ".avi" - # But if we can write mpegs using sci-kit video, use .mp4 - # since it has trouble writing .avi files. - if VideoWriter.can_use_skvideo(): + if VideoWriter.can_use_ffmpeg(): default_out_filename = context.state["filename"] + ".mp4" - # Ask where use wants to save video file + # Ask where user wants to save video file filename, _ = FileDialog.save( context.app, caption="Save Video As...", dir=default_out_filename, - filter="Video (*.avi *mp4)", + filter="Video (*.avi *.mp4)", ) # Check if user hit cancel @@ -1354,6 +1353,7 @@ def ask(context: CommandContext, params: dict) -> bool: params["fps"] = export_options["fps"] params["scale"] = export_options["scale"] params["open_when_done"] = export_options["open_when_done"] + params["background"] = export_options["background"] params["crop"] = None @@ -1584,7 +1584,6 @@ class GoNextSuggestedFrame(NavCommand): @classmethod def do_action(cls, context: CommandContext, params: dict): - next_suggestion_frame = context.labels.get_next_suggestion( context.state["video"], context.state["frame_idx"], cls.seek_direction ) @@ -1770,7 +1769,6 @@ class ReplaceVideo(EditCommand): @staticmethod def do_action(context: CommandContext, params: dict) -> bool: - import_list = params["import_list"] for import_item, video in import_list: @@ -1899,7 +1897,6 @@ def ask(context: CommandContext, params: dict) -> bool: video_file_names = [] total_num_labeled_frames = 0 for idx in row_idxs: - video = videos[idx] if video is None: return False @@ -1944,7 +1941,6 @@ def load_skeleton(filename: str): def compare_skeletons( skeleton: Skeleton, new_skeleton: Skeleton ) -> Tuple[List[str], List[str], List[str]]: - delete_nodes = [] add_nodes = [] if skeleton.node_names != new_skeleton.node_names: @@ -2076,7 +2072,7 @@ def try_and_skip_if_error(func, *args, **kwargs): func(*args, **kwargs) except Exception as e: tb_str = traceback.format_exception( - etype=type(e), value=e, tb=e.__traceback__ + type(e), value=e, tb=e.__traceback__ ) logger.warning( f"Recieved the following error while replacing 
skeleton:\n" @@ -2307,6 +2303,8 @@ def _do_deletion(context: CommandContext, lf_inst_list: List[int]): lfs_to_remove = [] for lf, inst in lf_inst_list: context.labels.remove_instance(lf, inst, in_transaction=True) + if context.state["instance"] == inst: + context.state["instance"] = None if len(lf.instances) == 0: lfs_to_remove.append(lf) @@ -2449,7 +2447,7 @@ def ask(cls, context: CommandContext, params: dict) -> bool: return super().ask(context, params) -class DeleteFrameLimitPredictions(InstanceDeleteCommand): +class DeleteInstanceLimitPredictions(InstanceDeleteCommand): @staticmethod def get_frame_instance_list(context: CommandContext, params: dict): count_thresh = params["count_threshold"] @@ -2479,6 +2477,36 @@ def ask(cls, context: CommandContext, params: dict) -> bool: return super().ask(context, params) +class DeleteFrameLimitPredictions(InstanceDeleteCommand): + @staticmethod + def get_frame_instance_list(context: CommandContext, params: Dict): + """Called from the parent `InstanceDeleteCommand.ask` method. + + Returns: + List of instances to be deleted. + """ + instances = [] + # Select the instances to be deleted + for lf in context.labels.labeled_frames: + if lf.frame_idx < (params["min_frame_idx"] - 1) or lf.frame_idx > ( + params["max_frame_idx"] - 1 + ): + instances.extend([(lf, inst) for inst in lf.instances]) + return instances + + @classmethod + def ask(cls, context: CommandContext, params: Dict) -> bool: + current_video = context.state["video"] + dialog = FrameRangeDialog( + title="Delete Instances in Frame Range...", max_frame_idx=len(current_video) + ) + results = dialog.get_results() + if results: + params["min_frame_idx"] = results["min_frame_idx"] + params["max_frame_idx"] = results["max_frame_idx"] + return super().ask(context, params) + + class TransposeInstances(EditCommand): topics = [UpdateTopic.project_instances, UpdateTopic.tracks] @@ -2492,7 +2520,16 @@ def do_action(cls, context: CommandContext, params: dict): # Swap tracks for current and subsequent frames when we have tracks old_track, new_track = instances[0].track, instances[1].track if old_track is not None and new_track is not None: - frame_range = (context.state["frame_idx"], context.state["video"].frames) + if context.state["propagate track labels"]: + frame_range = ( + context.state["frame_idx"], + context.state["video"].frames, + ) + else: + frame_range = ( + context.state["frame_idx"], + context.state["frame_idx"] + 1, + ) context.labels.track_swap( context.state["video"], new_track, old_track, frame_range ) @@ -2535,6 +2572,7 @@ def do_action(context: CommandContext, params: dict): return context.labels.remove_instance(context.state["labeled_frame"], selected_inst) + context.state["instance"] = None class DeleteSelectedInstanceTrack(EditCommand): @@ -2552,6 +2590,7 @@ def do_action(context: CommandContext, params: dict): track = selected_inst.track context.labels.remove_instance(context.state["labeled_frame"], selected_inst) + context.state["instance"] = None if track is not None: # remove any instance on this track @@ -2723,7 +2762,6 @@ class GenerateSuggestions(EditCommand): @classmethod def do_action(cls, context: CommandContext, params: dict): - if len(context.labels.videos) == 0: print("Error: no videos to generate suggestions for") return @@ -2851,27 +2889,13 @@ def ask_and_do(cls, context: CommandContext, params: dict): class AddInstance(EditCommand): topics = [UpdateTopic.frame, UpdateTopic.project_instances, UpdateTopic.suggestions] - @staticmethod - def 
get_previous_frame_index(context: CommandContext) -> Optional[int]: - frames = context.labels.frames( - context.state["video"], - from_frame_idx=context.state["frame_idx"], - reverse=True, - ) - - try: - next_idx = next(frames).frame_idx - except: - return - - return next_idx - @classmethod def do_action(cls, context: CommandContext, params: dict): copy_instance = params.get("copy_instance", None) init_method = params.get("init_method", "best") location = params.get("location", None) mark_complete = params.get("mark_complete", False) + offset = params.get("offset", 0) if context.state["labeled_frame"] is None: return @@ -2879,6 +2903,250 @@ def do_action(cls, context: CommandContext, params: dict): if len(context.state["skeleton"]) == 0: return + ( + copy_instance, + from_predicted, + from_prev_frame, + ) = AddInstance.find_instance_to_copy_from( + context, copy_instance=copy_instance, init_method=init_method + ) + + new_instance = AddInstance.create_new_instance( + context=context, + from_predicted=from_predicted, + copy_instance=copy_instance, + mark_complete=mark_complete, + init_method=init_method, + location=location, + from_prev_frame=from_prev_frame, + offset=offset, + ) + + # Add the instance + context.labels.add_instance(context.state["labeled_frame"], new_instance) + + if context.state["labeled_frame"] not in context.labels.labels: + context.labels.append(context.state["labeled_frame"]) + + @staticmethod + def create_new_instance( + context: CommandContext, + from_predicted: Optional[PredictedInstance], + copy_instance: Optional[Union[Instance, PredictedInstance]], + mark_complete: bool, + init_method: str, + location: Optional[QtCore.QPoint], + from_prev_frame: bool, + offset: int = 0, + ) -> Instance: + """Create new instance.""" + + # Now create the new instance + new_instance = Instance( + skeleton=context.state["skeleton"], + from_predicted=from_predicted, + frame=context.state["labeled_frame"], + ) + + has_missing_nodes = AddInstance.set_visible_nodes( + context=context, + copy_instance=copy_instance, + new_instance=new_instance, + mark_complete=mark_complete, + init_method=init_method, + location=location, + offset=offset, + ) + + if has_missing_nodes: + AddInstance.fill_missing_nodes( + context=context, + copy_instance=copy_instance, + init_method=init_method, + new_instance=new_instance, + location=location, + ) + + # If we're copying a predicted instance or from another frame, copy the track + if hasattr(copy_instance, "score") or from_prev_frame: + copy_instance = cast(Union[PredictedInstance, Instance], copy_instance) + new_instance.track = copy_instance.track + + return new_instance + + @staticmethod + def fill_missing_nodes( + context: CommandContext, + copy_instance: Optional[Union[Instance, PredictedInstance]], + init_method: str, + new_instance: Instance, + location: Optional[QtCore.QPoint], + ): + """Fill in missing nodes for new instance. + + Args: + context: The command context. + copy_instance: The instance to copy from. + init_method: The initialization method. + new_instance: The new instance. + location: The location of the instance. 
+ + Returns: + None + """ + + # mark the node as not "visible" if we're copying from a predicted instance without this node + is_visible = copy_instance is None or (not hasattr(copy_instance, "score")) + + if init_method == "force_directed": + AddMissingInstanceNodes.add_force_directed_nodes( + context=context, + instance=new_instance, + visible=is_visible, + center_point=location, + ) + elif init_method == "random": + AddMissingInstanceNodes.add_random_nodes( + context=context, instance=new_instance, visible=is_visible + ) + elif init_method == "template": + AddMissingInstanceNodes.add_nodes_from_template( + context=context, + instance=new_instance, + visible=is_visible, + center_point=location, + ) + else: + AddMissingInstanceNodes.add_best_nodes( + context=context, instance=new_instance, visible=is_visible + ) + + @staticmethod + def set_visible_nodes( + context: CommandContext, + copy_instance: Optional[Union[Instance, PredictedInstance]], + new_instance: Instance, + mark_complete: bool, + init_method: str, + location: Optional[QtCore.QPoint] = None, + offset: int = 0, + ) -> bool: + """Sets visible nodes for new instance. + + Args: + context: The command context. + copy_instance: The instance to copy from. + new_instance: The new instance. + mark_complete: Whether to mark the instance as complete. + init_method: The initialization method. + location: The location of the mouse click if any. + offset: The offset to apply to all nodes. + + Returns: + Whether the new instance has missing nodes. + """ + + if copy_instance is None: + return True + + has_missing_nodes = False + + # Calculate scale factor for getting new x and y values. + old_size_width = copy_instance.frame.video.shape[2] + old_size_height = copy_instance.frame.video.shape[1] + new_size_width = new_instance.frame.video.shape[2] + new_size_height = new_instance.frame.video.shape[1] + scale_width = new_size_width / old_size_width + scale_height = new_size_height / old_size_height + + # The offset is 0, except when using Ctrl + I or Add Instance button. + offset_x = offset + offset_y = offset + + # Using right click and context menu with option "best" + if (init_method == "best") and (location is not None): + reference_node = next( + (node for node in copy_instance if not node.isnan()), None + ) + reference_x = reference_node.x + reference_y = reference_node.y + offset_x = location.x() - (reference_x * scale_width) + offset_y = location.y() - (reference_y * scale_height) + + # Go through each node in skeleton. + for node in context.state["skeleton"].node_names: + # If we're copying from a skeleton that has this node. + if node in copy_instance and not copy_instance[node].isnan(): + # Ensure x, y inside current frame, then copy x, y, and visible. + # We don't want to copy a PredictedPoint or score attribute. + x_old = copy_instance[node].x + y_old = copy_instance[node].y + + # Copy the instance without scale or offset if predicted + if isinstance(copy_instance, PredictedInstance): + x_new = x_old + y_new = y_old + else: + x_new = x_old * scale_width + y_new = y_old * scale_height + + # Apply offset if in bounds + x_new_offset = x_new + offset_x + y_new_offset = y_new + offset_y + + # Default visibility is same as copied instance. + visible = copy_instance[node].visible + + # If the node is offset to outside the frame, mark as not visible. 
+ if x_new_offset < 0: + x_new = 0 + visible = False + elif x_new_offset > new_size_width: + x_new = new_size_width + visible = False + else: + x_new = x_new_offset + if y_new_offset < 0: + y_new = 0 + visible = False + elif y_new_offset > new_size_height: + y_new = new_size_height + visible = False + else: + y_new = y_new_offset + + # Update the new instance with the new x, y, and visibility. + new_instance[node] = Point( + x=x_new, + y=y_new, + visible=visible, + complete=mark_complete, + ) + else: + has_missing_nodes = True + + return has_missing_nodes + + @staticmethod + def find_instance_to_copy_from( + context: CommandContext, + copy_instance: Optional[Union[Instance, PredictedInstance]], + init_method: str, + ) -> Tuple[ + Optional[Union[Instance, PredictedInstance]], Optional[PredictedInstance], bool + ]: + """Find instance to copy from. + + Args: + context: The command context. + copy_instance: The `Instance` to copy from. + init_method: The initialization method. + + Returns: + The instance to copy from, the predicted instance (if it is from a predicted + instance, else None), and whether it's from a previous frame. + """ + from_predicted = copy_instance from_prev_frame = False @@ -2904,7 +3172,7 @@ def do_action(cls, context: CommandContext, params: dict): ) or init_method == "prior_frame": # Otherwise, if there are instances in previous frames, # copy the points from one of those instances. - prev_idx = cls.get_previous_frame_index(context) + prev_idx = AddInstance.get_previous_frame_index(context) if prev_idx is not None: prev_instances = context.labels.find(
Optional[int]: + """Returns index of previous frame.""" - # If we're copying a predicted instance or from another frame, copy the track - if hasattr(copy_instance, "score") or from_prev_frame: - new_instance.track = copy_instance.track + frames = context.labels.frames( + context.state["video"], + from_frame_idx=context.state["frame_idx"], + reverse=True, + ) - # Add the instance - context.labels.add_instance(context.state["labeled_frame"], new_instance) + try: + next_idx = next(frames).frame_idx + except: + return - if context.state["labeled_frame"] not in context.labels.labels: - context.labels.append(context.state["labeled_frame"]) + return next_idx class SetInstancePointLocations(EditCommand): diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index 0a008bea7..721bdc321 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -15,20 +15,17 @@ """ -from qtpy import QtCore, QtWidgets, QtGui - -import numpy as np import os - from operator import itemgetter +from pathlib import Path +from typing import Any, Callable, List, Optional -from typing import Any, Callable, Dict, List, Optional, Type +import numpy as np +from qtpy import QtCore, QtGui, QtWidgets -from sleap.gui.state import GuiState from sleap.gui.commands import CommandContext -from sleap.gui.color import ColorManager -from sleap.io.dataset import Labels -from sleap.instance import LabeledFrame, Instance +from sleap.gui.state import GuiState +from sleap.instance import LabeledFrame from sleap.skeleton import Skeleton @@ -386,10 +383,25 @@ def getSelectedRowItem(self) -> Any: class VideosTableModel(GenericTableModel): - properties = ("filename", "frames", "height", "width", "channels") - - def item_to_data(self, obj, item): - return {key: getattr(item, key) for key in self.properties} + properties = ( + "name", + "filepath", + "frames", + "height", + "width", + "channels", + ) + + def item_to_data(self, obj, item: "Video"): + data = {} + for property in self.properties: + if property == "name": + data[property] = Path(item.filename).name + elif property == "filepath": + data[property] = str(Path(item.filename).parent) + else: + data[property] = getattr(item, property) + return data class SkeletonNodesTableModel(GenericTableModel): @@ -413,13 +425,6 @@ def set_item(self, item, key, value): elif key == "symmetry": self.context.setNodeSymmetry(skeleton=self.obj, node=item, symmetry=value) - def get_item_color(self, item: Any, key: str): - if self.skeleton: - color = self.context.app.color_manager.get_item_color( - item, parent_skeleton=self.skeleton - ) - return QtGui.QColor(*color) - class SkeletonEdgesTableModel(GenericTableModel): """Table model for skeleton edges.""" @@ -436,14 +441,6 @@ def object_to_items(self, skeleton: Skeleton): ] return items - def get_item_color(self, item: Any, key: str): - if self.skeleton: - edge_pair = (item["source"], item["destination"]) - color = self.context.app.color_manager.get_item_color( - edge_pair, parent_skeleton=self.skeleton - ) - return QtGui.QColor(*color) - class LabeledFrameTableModel(GenericTableModel): """Table model for listing instances in labeled frame. 
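For reference, the core of the `AddInstance.set_visible_nodes` refactor above is a scale-offset-clamp step: a copied point is rescaled from the source video's frame size to the destination's, shifted by the optional pixel offset, and pinned to the frame border (and hidden) if pushed out of bounds. A minimal standalone sketch of that step, assuming plain tuples; the `scale_offset_clamp` helper below is illustrative only and not part of the SLEAP API (the real method also inherits visibility from the copied node and skips rescaling for predicted instances):

```python
# Illustrative sketch (not SLEAP API) of the scale -> offset -> clamp logic
# used by AddInstance.set_visible_nodes when copying points across videos.
from typing import Tuple


def scale_offset_clamp(
    xy: Tuple[float, float],
    old_wh: Tuple[int, int],
    new_wh: Tuple[int, int],
    offset: float = 0.0,
) -> Tuple[float, float, bool]:
    """Return the transformed (x, y) and whether the point stays visible."""
    # Scale from the source frame size to the destination frame size.
    x = xy[0] * (new_wh[0] / old_wh[0]) + offset
    y = xy[1] * (new_wh[1] / old_wh[1]) + offset
    # Points pushed outside the frame are pinned to the border and hidden.
    visible = (0 <= x <= new_wh[0]) and (0 <= y <= new_wh[1])
    x = min(max(x, 0.0), float(new_wh[0]))
    y = min(max(y, 0.0), float(new_wh[1]))
    return x, y, visible


# Example: copying a node from a 640x480 video into a 1280x960 video with
# the 10 px offset used by the "New Instance" button.
print(scale_offset_clamp((320.0, 240.0), (640, 480), (1280, 960), offset=10.0))
# -> (650.0, 490.0, True)
```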
diff --git a/sleap/gui/dialogs/export_clip.py b/sleap/gui/dialogs/export_clip.py index 312f9a807..f84766d18 100644 --- a/sleap/gui/dialogs/export_clip.py +++ b/sleap/gui/dialogs/export_clip.py @@ -11,16 +11,16 @@ def __init__(self): super().__init__(form_name="labeled_clip_form") - can_use_skvideo = VideoWriter.can_use_skvideo() + can_use_ffmpeg = VideoWriter.can_use_ffmpeg() - if can_use_skvideo: + if can_use_ffmpeg: message = ( "MP4 file will be encoded using " - "system ffmpeg via scikit-video (preferred option)." + "system ffmpeg via imageio (preferred option)." ) else: message = ( - "Unable to use ffpmeg via scikit-video. " + "Unable to use ffmpeg via imageio. " "AVI file will be encoding using OpenCV." ) diff --git a/sleap/gui/dialogs/filedialog.py b/sleap/gui/dialogs/filedialog.py index 930c71b0d..ff394d191 100644 --- a/sleap/gui/dialogs/filedialog.py +++ b/sleap/gui/dialogs/filedialog.py @@ -29,7 +29,8 @@ def set_dialog_type(cls, *args, **kwargs): if cls.is_non_native: kwargs["options"] = kwargs.get("options", 0) - kwargs["options"] |= QtWidgets.QFileDialog.DontUseNativeDialog + if not kwargs["options"]: + kwargs["options"] = QtWidgets.QFileDialog.DontUseNativeDialog # Make sure we don't send empty options argument if "options" in kwargs and not kwargs["options"]: diff --git a/sleap/gui/dialogs/frame_range.py b/sleap/gui/dialogs/frame_range.py new file mode 100644 index 000000000..7165dd939 --- /dev/null +++ b/sleap/gui/dialogs/frame_range.py @@ -0,0 +1,42 @@ +"""Frame range dialog.""" +from qtpy import QtWidgets +from sleap.gui.dialogs.formbuilder import FormBuilderModalDialog +from typing import Optional + + +class FrameRangeDialog(FormBuilderModalDialog): + def __init__(self, max_frame_idx: Optional[int] = None, title: str = "Frame Range"): + + super().__init__(form_name="frame_range_form") + min_frame_idx_field = self.form_widget.fields["min_frame_idx"] + max_frame_idx_field = self.form_widget.fields["max_frame_idx"] + + if max_frame_idx is not None: + min_frame_idx_field.setRange(1, max_frame_idx) + min_frame_idx_field.setValue(1) + + max_frame_idx_field.setRange(1, max_frame_idx) + max_frame_idx_field.setValue(max_frame_idx) + + min_frame_idx_field.valueChanged.connect(self._update_max_frame_range) + max_frame_idx_field.valueChanged.connect(self._update_min_frame_range) + + self.setWindowTitle(title) + + def _update_max_frame_range(self, value): + min_frame_idx_field = self.form_widget.fields["min_frame_idx"] + max_frame_idx_field = self.form_widget.fields["max_frame_idx"] + + max_frame_idx_field.setRange(value, max_frame_idx_field.maximum()) + + def _update_min_frame_range(self, value): + min_frame_idx_field = self.form_widget.fields["min_frame_idx"] + max_frame_idx_field = self.form_widget.fields["max_frame_idx"] + + min_frame_idx_field.setRange(min_frame_idx_field.minimum(), value) + + +if __name__ == "__main__": + app = QtWidgets.QApplication([]) + dialog = FrameRangeDialog(max_frame_idx=100) + print(dialog.get_results()) diff --git a/sleap/gui/dialogs/metrics.py b/sleap/gui/dialogs/metrics.py index 864a6adf0..884b373a9 100644 --- a/sleap/gui/dialogs/metrics.py +++ b/sleap/gui/dialogs/metrics.py @@ -120,10 +120,11 @@ def _show_model_params( if cfg_info is None: cfg_info = self.table_view.getSelectedRowItem() + cfg_getter = self._cfg_getter key = cfg_info.path if key not in model_detail_widgets: model_detail_widgets[key] = TrainingEditorWidget.from_trained_config( - cfg_info + cfg_info, cfg_getter ) model_detail_widgets[key].show() diff --git a/sleap/gui/learning/dialog.py 
b/sleap/gui/learning/dialog.py index d9f872fda..bc26d826c 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -1,24 +1,20 @@ """ Dialogs for running training and/or inference in GUI. """ -import cattr -import os +import json import shutil -import atexit import tempfile from pathlib import Path +from typing import Dict, List, Optional, Text, cast + +import cattr +from qtpy import QtCore, QtGui, QtWidgets import sleap from sleap import Labels, Video from sleap.gui.dialogs.filedialog import FileDialog from sleap.gui.dialogs.formbuilder import YamlFormWidget -from sleap.gui.learning import runners, scopedkeydict, configs, datagen, receptivefield - -from typing import Dict, List, Optional, Text, Optional, cast - -from qtpy import QtWidgets, QtCore - -import json +from sleap.gui.learning import configs, datagen, receptivefield, runners, scopedkeydict # List of fields which should show list of skeleton nodes NODE_LIST_FIELDS = [ @@ -128,12 +124,25 @@ def __init__( self.message_widget = QtWidgets.QLabel("") # Layout for entire dialog - layout = QtWidgets.QVBoxLayout() - layout.addWidget(self.tab_widget) - layout.addWidget(self.message_widget) - layout.addWidget(buttons_layout_widget) + content_widget = QtWidgets.QWidget() + content_layout = QtWidgets.QVBoxLayout(content_widget) - self.setLayout(layout) + content_layout.addWidget(self.tab_widget) + content_layout.addWidget(self.message_widget) + content_layout.addWidget(buttons_layout_widget) + + # Create the QScrollArea. + scroll_area = QtWidgets.QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setWidget(content_widget) + + scroll_area.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) + scroll_area.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) + + layout = QtWidgets.QVBoxLayout(self) + layout.addWidget(scroll_area) + + self.adjust_initial_size() # Default to most recently trained pipeline (if there is one) self.set_default_pipeline_tab() @@ -157,6 +166,20 @@ def __init__( self.view_datagen ) + def adjust_initial_size(self): + # Get screen size + screen = QtGui.QGuiApplication.primaryScreen().availableGeometry() + + max_width = 1860 + max_height = 1150 + margin = 0.10 + + # Calculate target width and height + target_width = min(screen.width() - screen.width() * margin, max_width) + target_height = min(screen.height() - screen.height() * margin, max_height) + # Set the dialog's dimensions + self.resize(target_width, target_height) + def update_file_lists(self): self._cfg_getter.update() for tab in self.tabs.values(): @@ -579,6 +602,7 @@ def get_selected_frames_to_predict( def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInference: predict_frames_choice = pipeline_form_data.get("_predict_frames", "") + batch_size = pipeline_form_data.get("batch_size") frame_selection = self.get_selected_frames_to_predict(pipeline_form_data) frame_count = self.count_total_frames_for_selection_option(frame_selection) @@ -591,6 +615,7 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen ) ], total_frame_count=frame_count, + batch_size=batch_size, ) elif predict_frames_choice.startswith("suggested"): items_for_inference = runners.ItemsForInference( @@ -600,6 +625,7 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen ) ], total_frame_count=frame_count, + batch_size=batch_size, ) else: items_for_inference = runners.ItemsForInference.from_video_frames_dict( @@ -607,9 +633,24 @@ def get_items_for_inference(self, 
pipeline_form_data) -> runners.ItemsForInferen total_frame_count=frame_count, labels_path=self.labels_filename, labels=self.labels, + batch_size=batch_size, ) return items_for_inference + def _validate_id_model(self) -> bool: + """Make sure we have instances with tracks set for ID models.""" + if not self.labels.tracks: + message = "Cannot run ID model training without tracks." + return False + + found_tracks = False + for inst in self.labels.instances(): + if type(inst) == sleap.Instance and inst.track is not None: + found_tracks = True + break + + return found_tracks + def _validate_pipeline(self): can_run = True message = "" @@ -628,6 +669,15 @@ def _validate_pipeline(self): f"({', '.join(untrained)})." ) + # Make sure we have instances with tracks set for ID models. + if self.mode == "training" and self.current_pipeline in ( + "top-down-id", + "bottom-up-id", + ): + can_run = self._validate_id_model() + if not can_run: + message = "Cannot run ID model training without tracks." + # Make sure skeleton will be valid for bottom-up inference. if self.mode == "training" and self.current_pipeline == "bottom-up": skeleton = self.labels.skeletons[0] @@ -1088,8 +1138,12 @@ def __init__( self.setLayout(layout) @classmethod - def from_trained_config(cls, cfg_info: configs.ConfigFileInfo): - widget = cls(require_trained=True, head=cfg_info.head_name) + def from_trained_config( + cls, cfg_info: configs.ConfigFileInfo, cfg_getter: configs.TrainingConfigsGetter + ): + widget = cls( + require_trained=True, head=cfg_info.head_name, cfg_getter=cfg_getter + ) widget.acceptSelectedConfigInfo(cfg_info) widget.setWindowTitle(cfg_info.path_dir) return widget diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index ca60c4127..d0bb1f3ba 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -1,4 +1,5 @@ """Run training/inference in background process via CLI.""" + import abc import attr import os @@ -151,6 +152,7 @@ class ItemsForInference: items: List[ItemForInference] total_frame_count: int + batch_size: int def __len__(self): return len(self.items) @@ -160,6 +162,7 @@ def from_video_frames_dict( cls, video_frames_dict: Dict[Video, List[int]], total_frame_count: int, + batch_size: int, labels: Labels, labels_path: Optional[str] = None, ): @@ -174,7 +177,9 @@ def from_video_frames_dict( video_idx=labels.videos.index(video), ) ) - return cls(items=items, total_frame_count=total_frame_count) + return cls( + items=items, total_frame_count=total_frame_count, batch_size=batch_size + ) @attr.s(auto_attribs=True) @@ -255,12 +260,20 @@ def make_predict_cli_call( "tracking.max_tracking", "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", + "tracking.oks_score_weighting", ) for key in bool_items_as_ints: if key in self.inference_params: self.inference_params[key] = int(self.inference_params[key]) + remove_spaces_items = ("tracking.similarity",) + + for key in remove_spaces_items: + if key in self.inference_params: + value = self.inference_params[key] + self.inference_params[key] = value.replace(" ", "_") + for key, val in self.inference_params.items(): if not key.startswith(("_", "outputs.", "model.", "data.")): cli_args.extend((f"--{key}", str(val))) @@ -496,9 +509,11 @@ def write_pipeline_files( "data_path": os.path.basename(data_path), "models": [Path(p).as_posix() for p in new_cfg_filenames], "output_path": prediction_output_path, - "type": "labels" - if type(item_for_inference) == DatasetItemForInference - else "video", + "type": ( + "labels" 
+ if type(item_for_inference) == DatasetItemForInference + else "video" + ), "only_suggested_frames": only_suggested_frames, "tracking": tracking_args, } @@ -540,6 +555,7 @@ def run_learning_pipeline( """ save_viz = inference_params.get("_save_viz", False) + keep_viz = inference_params.get("_keep_viz", False) if "movenet" in inference_params["_pipeline"]: trained_job_paths = [inference_params["_pipeline"]] @@ -550,8 +566,10 @@ def run_learning_pipeline( labels_filename=labels_filename, labels=labels, config_info_list=config_info_list, + inference_params=inference_params, gui=True, save_viz=save_viz, + keep_viz=keep_viz, ) # Check that all the models were trained @@ -577,8 +595,10 @@ def run_gui_training( labels_filename: str, labels: Labels, config_info_list: List[ConfigFileInfo], + inference_params: Dict[str, Any], gui: bool = True, save_viz: bool = False, + keep_viz: bool = False, ) -> Dict[Text, Text]: """ Runs training for each training job. @@ -588,19 +608,28 @@ def run_gui_training( config_info_list: List of ConfigFileInfo with configs for training. gui: Whether to show gui windows and process gui events. save_viz: Whether to save visualizations from training. + keep_viz: Whether to keep prediction visualization images after training. Returns: Dictionary, keys are head name, values are path to trained config. """ trained_job_paths = dict() - + zmq_ports = None if gui: from sleap.gui.widgets.monitor import LossViewer from sleap.gui.widgets.imagedir import QtImageDirectoryWidget - # open training monitor window - win = LossViewer() + zmq_ports = dict() + zmq_ports["controller_port"] = inference_params.get("controller_port", 9000) + zmq_ports["publish_port"] = inference_params.get("publish_port", 9001) + + # Open training monitor window + win = LossViewer(zmq_ports=zmq_ports) + + # Reassign the values in the inference parameters in case the ports were changed + inference_params["controller_port"] = win.zmq_ports["controller_port"] + inference_params["publish_port"] = win.zmq_ports["publish_port"] win.resize(600, 400) win.show() @@ -664,10 +693,12 @@ def waiting(): # Run training trained_job_path, ret = train_subprocess( job_config=job, + inference_params=inference_params, labels_filename=labels_filename, video_paths=video_path_list, waiting_callback=waiting, save_viz=save_viz, + keep_viz=keep_viz, ) if ret == "success": @@ -806,9 +837,11 @@ def waiting_item(**kwargs): def train_subprocess( job_config: TrainingJobConfig, labels_filename: str, + inference_params: Dict[str, Any], video_paths: Optional[List[Text]] = None, waiting_callback: Optional[Callable] = None, save_viz: bool = False, + keep_viz: bool = False, ): """Runs training inside subprocess.""" run_path = job_config.outputs.run_path @@ -829,10 +862,16 @@ def train_subprocess( training_job_path, labels_filename, "--zmq", + "--controller_port", + str(inference_params["controller_port"]), + "--publish_port", + str(inference_params["publish_port"]), ] if save_viz: cli_args.append("--save_viz") + if keep_viz: + cli_args.append("--keep_viz") # Use cli arg since cli ignores setting in config if job_config.outputs.tensorboard.write_logs: diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index d27b069ac..879d12810 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -71,7 +71,6 @@ def remove_from_scene(self): except RuntimeError as e: # Internal C++ object (PySide2.QtWidgets.QGraphicsPathItem) already deleted. 
logger.debug(e) - pass # Stop tracking the items after they been removed from the scene self.items = [] diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index 361585719..c5f091658 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -1,17 +1,16 @@ """Track trail and track list overlays.""" +from typing import Dict, Iterable, List, Optional, Tuple + +import attr +from qtpy import QtCore, QtGui + from sleap.gui.overlays.base import BaseOverlay +from sleap.gui.widgets.video import QtTextWithBackground from sleap.instance import Track from sleap.io.dataset import Labels from sleap.io.video import Video from sleap.prefs import prefs -from sleap.gui.widgets.video import QtTextWithBackground - -import attr - -from typing import Iterable, List, Optional, Dict - -from qtpy import QtCore, QtGui @attr.s(auto_attribs=True) @@ -48,7 +47,9 @@ def __attrs_post_init__(self): @classmethod def get_length_options(cls): - return (0, 10, 50, 100, 250) + if prefs["trail length"] != 0: + return (0, 10, 50, 100, 250, 500, prefs["trail length"]) + return (0, 10, 50, 100, 250, 500) @classmethod def get_shade_options(cls): @@ -56,7 +57,9 @@ def get_shade_options(cls): return {"Dark": 0.6, "Normal": 1.0, "Light": 1.25} - def get_track_trails(self, frame_selection: Iterable["LabeledFrame"]): + def get_track_trails( + self, frame_selection: Iterable["LabeledFrame"] + ) -> Optional[Dict[Track, List[List[Tuple[float, float]]]]]: """Get data needed to draw track trail. Args: @@ -152,6 +155,8 @@ def add_to_scene(self, video: Video, frame_idx: int): frame_selection = self.get_frame_selection(video, frame_idx) all_track_trails = self.get_track_trails(frame_selection) + if all_track_trails is None: + return for track, trails in all_track_trails.items(): trail_color = tuple( diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py index 48b916437..b85d6ac32 100644 --- a/sleap/gui/suggestions.py +++ b/sleap/gui/suggestions.py @@ -61,6 +61,7 @@ def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame] prediction_score=cls.prediction_score, velocity=cls.velocity, frame_chunk=cls.frame_chunk, + max_point_displacement=cls.max_point_displacement, ) method = str.replace(params["method"], " ", "_") @@ -213,6 +214,7 @@ def _prediction_score_video( ): lfs = labels.find(video) frames = len(lfs) + # initiate an array filled with -1 to store frame index (starting from 0). 
idxs = np.full((frames), -1, dtype="int") @@ -291,6 +293,56 @@ def _velocity_video( return cls.idx_list_to_frame_list(frame_idxs, video) + @classmethod + def max_point_displacement( + cls, + labels: "Labels", + videos: List[Video], + displacement_threshold: float, + **kwargs, + ): + """Finds frames with maximum point displacement above a threshold.""" + + proposed_suggestions = [] + for video in videos: + proposed_suggestions.extend( + cls._max_point_displacement_video(video, labels, displacement_threshold) + ) + + suggestions = VideoFrameSuggestions.filter_unique_suggestions( + labels, videos, proposed_suggestions + ) + + return suggestions + + @classmethod + def _max_point_displacement_video( + cls, video: Video, labels: "Labels", displacement_threshold: float + ): + # Get numpy of shape (frames, tracks, nodes, x, y) + labels_numpy = labels.numpy(video=video, all_frames=True, untracked=False) + + # Return empty list if not enough frames + n_frames, n_tracks, n_nodes, _ = labels_numpy.shape + + if n_frames < 2: + return [] + + # Calculate displacements + diff = labels_numpy[1:] - labels_numpy[:-1] # (frames - 1, tracks, nodes, x, y) + euc_norm = np.linalg.norm(diff, axis=-1) # (frames - 1, tracks, nodes) + mean_euc_norm = np.nanmean(euc_norm, axis=-1) # (frames - 1, tracks) + + # Find frames where mean displacement is above threshold + threshold_mask = np.any( + mean_euc_norm > displacement_threshold, axis=-1 + ) # (frames - 1,) + frame_idxs = list( + np.argwhere(threshold_mask).flatten() + 1 + ) # [0, len(frames - 1)] + + return cls.idx_list_to_frame_list(frame_idxs, video) + @classmethod def frame_chunk( cls, diff --git a/sleap/gui/utils.py b/sleap/gui/utils.py new file mode 100644 index 000000000..4f8215706 --- /dev/null +++ b/sleap/gui/utils.py @@ -0,0 +1,28 @@ +"""Generic module containing utilities used for the GUI.""" + +import zmq +from typing import Optional + + +def is_port_free(port: int, zmq_context: Optional[zmq.Context] = None) -> bool: + """Checks if a port is free.""" + ctx = zmq.Context.instance() if zmq_context is None else zmq_context + socket = ctx.socket(zmq.REP) + address = f"tcp://127.0.0.1:{port}" + try: + socket.bind(address) + socket.unbind(address) + return True + except zmq.error.ZMQError: + return False + finally: + socket.close() + + +def select_zmq_port(zmq_context: Optional[zmq.Context] = None) -> int: + """Select a port that is free to connect within the given context.""" + ctx = zmq.Context.instance() if zmq_context is None else zmq_context + socket = ctx.socket(zmq.REP) + port = socket.bind_to_random_port("tcp://127.0.0.1") + socket.close() + return port diff --git a/sleap/gui/widgets/docks.py b/sleap/gui/widgets/docks.py index 43e218adb..bd20bf79a 100644 --- a/sleap/gui/widgets/docks.py +++ b/sleap/gui/widgets/docks.py @@ -30,10 +30,8 @@ ) from sleap.gui.dialogs.formbuilder import YamlFormWidget from sleap.gui.widgets.views import CollapsibleWidget -from sleap.skeleton import Skeleton -from sleap.util import decode_preview_image, find_files_by_suffix, get_package_file - -# from sleap.gui.app import MainWindow +from sleap.skeleton import Skeleton, SkeletonDecoder +from sleap.util import find_files_by_suffix, get_package_file class DockWidget(QDockWidget): @@ -365,7 +363,7 @@ def create_templates_groupbox(self) -> QGroupBox: def updatePreviewImage(preview_image_bytes: bytes): # Decode the preview image - preview_image = decode_preview_image(preview_image_bytes) + preview_image = SkeletonDecoder.decode_preview_image(preview_image_bytes) # Create a QImage 
from the Image preview_image = QtGui.QImage( @@ -557,7 +555,7 @@ def create_table_edit_buttons(self) -> QWidget: hb = QHBoxLayout() self.add_button( - hb, "New Instance", lambda x: main_window.commands.newInstance() + hb, "New Instance", lambda x: main_window.commands.newInstance(offset=10) ) self.add_button( hb, "Delete Instance", main_window.commands.deleteSelectedInstance diff --git a/sleap/gui/widgets/monitor.py b/sleap/gui/widgets/monitor.py index 93bc483e9..fff8a0327 100644 --- a/sleap/gui/widgets/monitor.py +++ b/sleap/gui/widgets/monitor.py @@ -1,20 +1,590 @@ """GUI for monitoring training progress interactively.""" -import numpy as np -from time import perf_counter -from sleap.nn.config.training_job import TrainingJobConfig -import zmq -import jsonpickle import logging -from typing import Optional -from qtpy import QtCore, QtWidgets, QtGui -from qtpy.QtCharts import QtCharts +from time import perf_counter +from typing import Dict, Optional, Tuple + import attr +import jsonpickle +import numpy as np +import zmq +from matplotlib.collections import PathCollection +import matplotlib.transforms as mtransforms +from qtpy import QtCore, QtWidgets +from sleap.gui.utils import is_port_free, select_zmq_port +from sleap.gui.widgets.mpl import MplCanvas +from sleap.nn.config.training_job import TrainingJobConfig logger = logging.getLogger(__name__) +class LossPlot(MplCanvas): + """Matplotlib canvas for displaying training and validation loss curves.""" + + def __init__( + self, + width: int = 5, + height: int = 4, + dpi: int = 100, + log_scale: bool = True, + ignore_outliers: bool = False, + ): + super().__init__(width=width, height=height, dpi=dpi) + + self._log_scale: bool = log_scale + + self.ignore_outliers = ignore_outliers + + # Initialize the series for the plot + self.series: dict = {} + COLOR_TRAIN = (18, 158, 220) + COLOR_VAL = (248, 167, 52) + COLOR_BEST_VAL = (151, 204, 89) + + # Initialize scatter series for batch training loss + self.series["batch"] = self._init_series( + series_type=self.axes.scatter, + name="Batch Training Loss", + color=COLOR_TRAIN + (48,), + border_color=(255, 255, 255, 25), + ) + + # Initialize line series for epoch training loss + self.series["epoch_loss"] = self._init_series( + series_type=self.axes.plot, + name="Epoch Training Loss", + color=COLOR_TRAIN + (255,), + line_width=3.0, + ) + + # Initialize line series for epoch validation loss + self.series["val_loss"] = self._init_series( + series_type=self.axes.plot, + name="Epoch Validation Loss", + color=COLOR_VAL + (255,), + line_width=3.0, + zorder=4, # Below best validation loss series + ) + + # Initialize scatter series for best epoch validation loss + self.series["val_loss_best"] = self._init_series( + series_type=self.axes.scatter, + name="Best Validation Loss", + color=COLOR_BEST_VAL + (255,), + border_color=(255, 255, 255, 25), + zorder=5, # Above epoch validation loss series + ) + + # Set the x and y positions for the xy labels (as fraction of figure size) + self.ypos_xlabel = 0.1 + self.xpos_ylabel = 0.05 + + # Padding between the axes and the xy labels + self.xpos_padding = 0.2 + self.ypos_padding = 0.1 + + # Set up the major gridlines + self._setup_major_gridlines() + + # Set up the x-axis + self._setup_x_axis() + + # Set up the y-axis + self._set_up_y_axis() + + # Set up the legend + self.legend_width, legend_height = self._setup_legend() + + # Set up the title space + self.ypos_title = None + title_height = self._set_title_space() + self.ypos_title = 1 - title_height - self.ypos_padding + + # 
Determine the top height of the plot + top_height = max(title_height, legend_height) + + # Adjust the figure layout + self.xpos_left_plot = self.xpos_ylabel + self.xpos_padding + self.xpos_right_plot = 0.97 + self.ypos_bottom_plot = self.ypos_xlabel + self.ypos_padding + self.ypos_top_plot = 1 - top_height - self.ypos_padding + + # Adjust the top parameters as needed + self.fig.subplots_adjust( + left=self.xpos_left_plot, + right=self.xpos_right_plot, + top=self.ypos_top_plot, + bottom=self.ypos_bottom_plot, + ) + + @property + def log_scale(self): + """Returns True if the plot has a log scale for y-axis.""" + + return self._log_scale + + @log_scale.setter + def log_scale(self, val): + """Sets the scale of the y axis to log if True else linear.""" + + if isinstance(val, bool): + self._log_scale = val + + y_scale = "log" if self._log_scale else "linear" + self.axes.set_yscale(y_scale) + self.redraw_plot() + + def set_data_on_scatter(self, xs, ys, which): + """Set data on a scatter plot. + + Not to be used with line plots. + + Args: + xs: The x-coordinates of the data points. + ys: The y-coordinates of the data points. + which: The type of data point. Possible values are: + * "batch" + * "val_loss_best" + """ + + offsets = np.column_stack((xs, ys)) + self.series[which].set_offsets(offsets) + + def add_data_to_plot(self, x, y, which): + """Add data to a line plot. + + Not to be used with scatter plots. + + Args: + x: The x-coordinate of the data point. + y: The y-coordinate of the data point. + which: The type of data point. Possible values are: + * "epoch_loss" + * "val_loss" + """ + + x_data, y_data = self.series[which].get_data() + self.series[which].set_data(np.append(x_data, x), np.append(y_data, y)) + + def resize_axes(self, x, y): + """Resize axes to fit data. + + This is only called when plotting batches. + + Args: + x: The x-coordinates of the data points. + y: The y-coordinates of the data points. + """ + + # Set X scale to show all points + x_min, x_max = self._calculate_xlim(x) + self.axes.set_xlim(x_min, x_max) + + # Set Y scale, ensuring that y_min and y_max do not lead to singular transform + y_min, y_max = self._calculate_ylim(y) + y_min, y_max = self.axes.yaxis.get_major_locator().nonsingular(y_min, y_max) + self.axes.set_ylim(y_min, y_max) + + # Add gridlines at midpoint between major ticks (major gridlines are automatic) + self._add_midpoint_gridlines() + + # Redraw the plot + self.redraw_plot() + + def redraw_plot(self): + """Redraw the plot.""" + + self.fig.canvas.draw_idle() + + def set_title(self, title, color=None): + """Set the title of the plot. + + Args: + title: The title text to display. + color: The color of the title text (defaults to black). 
+ """ + + if color is None: + color = "black" + + self.axes.set_title( + title, fontweight="light", fontsize="small", color=color, x=0.55, y=1.03 + ) + + def update_runtime_title( + self, + epoch: int, + dt_min: int, + dt_sec: int, + last_epoch_val_loss: float = None, + penultimate_epoch_val_loss: float = None, + mean_epoch_time_min: int = None, + mean_epoch_time_sec: int = None, + eta_ten_epochs_min: int = None, + epochs_in_plateau: int = None, + plateau_patience: int = None, + epoch_in_plateau_flag: bool = False, + best_val_x: int = None, + best_val_y: float = None, + epoch_size: int = None, + ): + + # Add training epoch and runtime info + title = self._get_training_epoch_and_runtime_text(epoch, dt_min, dt_sec) + + if last_epoch_val_loss is not None: + + if penultimate_epoch_val_loss is not None: + # Add mean epoch time and ETA for next 10 epochs + eta_text = self._get_eta_text( + mean_epoch_time_min, mean_epoch_time_sec, eta_ten_epochs_min + ) + title = self._add_with_newline(title, eta_text) + + # Add epochs in plateau if flag is set + if epoch_in_plateau_flag: + plateau_text = self._get_epochs_in_plateau_text( + epochs_in_plateau, plateau_patience + ) + title = self._add_with_newline(title, plateau_text) + + # Add last epoch validation loss + last_val_text = self._get_last_validation_loss_text(last_epoch_val_loss) + title = self._add_with_newline(title, last_val_text) + + # Add best epoch validation loss if available + if best_val_x is not None: + best_epoch = (best_val_x // epoch_size) + 1 + best_val_text = self._get_best_validation_loss_text( + best_val_y, best_epoch + ) + title = self._add_with_newline(title, best_val_text) + + self.set_title(title) + + @staticmethod + def _get_training_epoch_and_runtime_text(epoch: int, dt_min: int, dt_sec: int): + """Get the training epoch and runtime text to display in the plot. + + Args: + epoch: The current epoch. + dt_min: The number of minutes since training started. + dt_sec: The number of seconds since training started. + """ + + runtime_text = ( + r"Training Epoch $\mathbf{" + str(epoch + 1) + r"}$ / " + r"Runtime: $\mathbf{" + f"{int(dt_min):02}:{int(dt_sec):02}" + r"}$" + ) + + return runtime_text + + @staticmethod + def _get_eta_text(mean_epoch_time_min, mean_epoch_time_sec, eta_ten_epochs_min): + """Get the mean time and ETA text to display in the plot. + + Args: + mean_epoch_time_min: The mean time per epoch in minutes. + mean_epoch_time_sec: The mean time per epoch in seconds. + eta_ten_epochs_min: The estimated time for the next ten epochs in minutes. + """ + + runtime_text = ( + r"Mean Time per Epoch: $\mathbf{" + + f"{int(mean_epoch_time_min):02}:{int(mean_epoch_time_sec):02}" + + r"}$ / " + r"ETA Next 10 Epochs: $\mathbf{" + f"{int(eta_ten_epochs_min)}" + r"}$ min" + ) + + return runtime_text + + @staticmethod + def _get_epochs_in_plateau_text(epochs_in_plateau, plateau_patience): + """Get the epochs in plateau text to display in the plot. + + Args: + epochs_in_plateau: The number of epochs in plateau. + plateau_patience: The number of epochs to wait before stopping training. + """ + + plateau_text = ( + r"Epochs in Plateau: $\mathbf{" + f"{epochs_in_plateau}" + r"}$ / " + r"$\mathbf{" + f"{plateau_patience}" + r"}$" + ) + + return plateau_text + + @staticmethod + def _get_last_validation_loss_text(last_epoch_val_loss): + """Get the last epoch validation loss text to display in the plot. + + Args: + last_epoch_val_loss: The validation loss from the last epoch. 
+ """ + + last_val_loss_text = ( + "Last Epoch Validation Loss: " + r"$\mathbf{" + f"{last_epoch_val_loss:.3e}" + r"}$" + ) + + return last_val_loss_text + + @staticmethod + def _get_best_validation_loss_text(best_val_y, best_epoch): + """Get the best epoch validation loss text to display in the plot. + + Args: + best_val_x: The epoch number of the best validation loss. + best_val_y: The best validation loss. + """ + + best_val_loss_text = ( + r"Best Epoch Validation Loss: $\mathbf{" + + f"{best_val_y:.3e}" + + r"}$ (epoch $\mathbf{" + + str(best_epoch) + + r"}$)" + ) + + return best_val_loss_text + + @staticmethod + def _add_with_newline(old_text: str, new_text: str): + """Add a new line to the text. + + Args: + old_text: The existing text. + new_text: The text to add on a new line. + """ + + return old_text + "\n" + new_text + + @staticmethod + def _calculate_xlim(x: np.ndarray, dx: float = 0.5): + """Calculates x-axis limits. + + Args: + x: Array of x data to fit the limits to. + dx: The padding to add to the limits. + + Returns: + Tuple of the minimum and maximum x-axis limits. + """ + + x_min = min(x) - dx + x_min = x_min if x_min > 0 else 0 + x_max = max(x) + dx + + return x_min, x_max + + def _calculate_ylim(self, y: np.ndarray, dy: float = 0.02): + """Calculates y-axis limits. + + Args: + y: Array of y data to fit the limits to. + dy: The padding to add to the limits. + + Returns: + Tuple of the minimum and maximum y-axis limits. + """ + + if self.ignore_outliers: + dy = np.ptp(y) * 0.02 + # Set Y scale to exclude outliers + q1, q3 = np.quantile(y, (0.25, 0.75)) + iqr = q3 - q1 # Interquartile range + y_min = q1 - iqr * 1.5 + y_max = q3 + iqr * 1.5 + + # Keep within range of data + y_min = max(y_min, min(y) - dy) + y_max = min(y_max, max(y) + dy) + else: + # Set Y scale to show all points + dy = np.ptp(y) * 0.02 + y_min = min(y) - dy + y_max = max(y) + dy + + # For log scale, low cannot be 0 + if self.log_scale: + y_min = max(y_min, 1e-8) + + return y_min, y_max + + def _set_title_space(self): + """Set up the title space. + + Returns: + The height of the title space as a decimal fraction of the total figure height. + """ + + # Set a dummy title of the plot + n_lines = 5 # Number of lines in the title + title_str = "\n".join( + [r"Number: $\mathbf{" + str(n) + r"}$" for n in range(n_lines + 1)] + ) + self.set_title( + title_str, color="white" + ) # Set the title color to white so it's not visible + + # Draw the canvas to ensure the title is created + self.fig.canvas.draw() + + # Get the title Text object + title = self.axes.title + + # Get the bounding box of the title in display coordinates + bbox = title.get_window_extent() + + # Transform the bounding box to figure coordinates + bbox = bbox.transformed(self.fig.transFigure.inverted()) + + # Calculate the height of the title as a percentage of the total figure height + title_height = bbox.height + + return title_height + + def _setup_x_axis(self): + """Set up the x axis. + + This includes setting the label, limits, and bottom/right adjustment. + """ + + self.axes.set_xlim(0, 1) + self.axes.set_xlabel("Batches", fontweight="bold", fontsize="small") + + # Set the x-label in the center of the axes and some amount above the bottom of the figure + blended_transform = mtransforms.blended_transform_factory( + self.axes.transAxes, self.fig.transFigure + ) + self.axes.xaxis.set_label_coords( + 0.5, self.ypos_xlabel, transform=blended_transform + ) + + def _set_up_y_axis(self): + """Set up the y axis. 
+ + This includes setting the label, limits, scaling, and left adjustment. + """ + + # Set the minimum value of the y-axis depending on scaling + if self.log_scale: + yscale = "log" + y_min = 0.001 + else: + yscale = "linear" + y_min = 0 + self.axes.set_ylim(bottom=y_min) + self.axes.set_yscale(yscale) + + # Set the y-label name, size, weight, and position + self.axes.set_ylabel("Loss", fontweight="bold", fontsize="small") + self.axes.yaxis.set_label_coords( + self.xpos_ylabel, 0.5, transform=self.fig.transFigure + ) + + def _setup_legend(self): + """Set up the legend. + + Returns: + Tuple of the width and height of the legend as a decimal fraction of the total figure width and height. + """ + + # Move the legend outside the plot on the upper left + legend = self.axes.legend( + loc="upper left", + fontsize="small", + bbox_to_anchor=(0, 1), + bbox_transform=self.fig.transFigure, + ) + + # Draw the canvas to ensure the legend is created + self.fig.canvas.draw() + + # Get the bounding box of the legend in display coordinates + bbox = legend.get_window_extent() + + # Transform the bounding box to figure coordinates + bbox = bbox.transformed(self.fig.transFigure.inverted()) + + # Calculate the width and height of the legend as a percentage of the total figure width and height + return bbox.width, bbox.height + + def _setup_major_gridlines(self): + + # Set the outline color of the plot to gray + for spine in self.axes.spines.values(): + spine.set_edgecolor("#d3d3d3") # Light gray color + + # Remove the top and right axis spines + self.axes.spines["top"].set_visible(False) + self.axes.spines["right"].set_visible(False) + + # Set the tick markers color to light gray, but not the tick labels + self.axes.tick_params( + axis="both", which="both", color="#d3d3d3", labelsize="small" + ) + + # Add gridlines at the tick labels + self.axes.grid(True, which="major", linewidth=0.5, color="#d3d3d3") + + def _add_midpoint_gridlines(self): + # Clear existing minor vertical lines + for line in self.axes.get_lines(): + if line.get_linestyle() == ":": + line.remove() + + # Add gridlines at midpoint between major ticks + major_ticks = self.axes.yaxis.get_majorticklocs() + if len(major_ticks) > 1: + prev_major_tick = major_ticks[0] + for major_tick in major_ticks[:-1]: + midpoint = (major_tick + prev_major_tick) / 2 + self.axes.axhline( + midpoint, linestyle=":", linewidth=0.5, color="#d3d3d3" + ) + prev_major_tick = major_tick + + def _init_series( + self, + series_type, + color, + name: Optional[str] = None, + line_width: Optional[float] = None, + border_color: Optional[Tuple[int, int, int]] = None, + zorder: Optional[int] = None, + ): + + # Set the color + color = [c / 255.0 for c in color] # Normalize color values to [0, 1] + + # Create the series + series = series_type( + [], + [], + color=color, + label=name, + marker="o", + zorder=zorder, + ) + + # ax.scatter returns a PathCollection, but ax.plot returns a list of lines, so unwrap it + if not isinstance(series, PathCollection): + series = series[0] + + if line_width is not None: + series.set_linewidth(line_width) + + # Set the border color (edge color) + if border_color is not None: + border_color = [ + c / 255.0 for c in border_color + ] # Normalize color values to [0, 1] + series.set_edgecolor(border_color) + + return series + + class LossViewer(QtWidgets.QMainWindow): """Qt window for showing in-progress training metrics sent over ZMQ.""" @@ -22,6 +592,7 @@ class LossViewer(QtWidgets.QMainWindow): def __init__( self, + zmq_ports: Dict = None, zmq_context: 
Optional[zmq.Context] = None, show_controller=True, parent=None, @@ -33,41 +604,62 @@ def __init__( self.cancel_button = None self.canceled = False + # Set up ZMQ ports for communication. + zmq_ports = zmq_ports or dict() + zmq_ports["publish_port"] = zmq_ports.get("publish_port", 9001) + zmq_ports["controller_port"] = zmq_ports.get("controller_port", 9000) + self.zmq_ports = zmq_ports + self.batches_to_show = -1 # -1 to show all - self.ignore_outliers = False - self.log_scale = True + self._ignore_outliers = False + self._log_scale = True self.message_poll_time_ms = 20 # ms self.redraw_batch_time_ms = 500 # ms self.last_redraw_batch = None + self.canvas = None self.reset() - self.setup_zmq(zmq_context) + self._setup_zmq(zmq_context) def __del__(self): - self.unbind() + self._unbind() - def close(self): - """Disconnect from ZMQ ports and close the window.""" - self.unbind() - super().close() + @property + def is_timer_running(self) -> bool: + """Return True if the timer has started.""" + return self.t0 is not None and self.is_running - def unbind(self): - """Disconnect from all ZMQ sockets.""" - if self.sub is not None: - self.sub.unbind(self.sub.LAST_ENDPOINT) - self.sub.close() - self.sub = None + @property + def log_scale(self): + """Returns True if the plot has a log scale for y-axis.""" - if self.zmq_ctrl is not None: - url = self.zmq_ctrl.LAST_ENDPOINT - self.zmq_ctrl.unbind(url) - self.zmq_ctrl.close() - self.zmq_ctrl = None + return self._log_scale - # If we started out own zmq context, terminate it. - if not self.ctx_given and self.ctx is not None: - self.ctx.term() - self.ctx = None + @log_scale.setter + def log_scale(self, val): + """Sets the scale of the y axis to log if True else linear.""" + + if isinstance(val, bool): + self._log_scale = val + + # Set the log scale on the canvas + self.canvas.log_scale = self._log_scale + + @property + def ignore_outliers(self): + """Returns True if the plot ignores outliers.""" + + return self._ignore_outliers + + @ignore_outliers.setter + def ignore_outliers(self, val): + """Sets whether to ignore outliers in the plot.""" + + if isinstance(val, bool): + self._ignore_outliers = val + + # Set the ignore_outliers on the canvas + self.canvas.ignore_outliers = self._ignore_outliers def reset( self, @@ -80,112 +672,34 @@ def reset( what: String identifier indicating which job type the current run corresponds to. 
""" - self.chart = QtCharts.QChart() - - self.series = dict() + self.canvas = LossPlot( + width=5, + height=4, + dpi=100, + log_scale=self.log_scale, + ignore_outliers=self.ignore_outliers, + ) - COLOR_TRAIN = (18, 158, 220) - COLOR_VAL = (248, 167, 52) - COLOR_BEST_VAL = (151, 204, 89) + self.mp_series = dict() + self.mp_series["batch"] = self.canvas.series["batch"] + self.mp_series["epoch_loss"] = self.canvas.series["epoch_loss"] + self.mp_series["val_loss"] = self.canvas.series["val_loss"] + self.mp_series["val_loss_best"] = self.canvas.series["val_loss_best"] - self.series["batch"] = QtCharts.QScatterSeries() - self.series["batch"].setName("Batch Training Loss") - self.series["batch"].setColor(QtGui.QColor(*COLOR_TRAIN, 48)) - self.series["batch"].setMarkerSize(8.0) - self.series["batch"].setBorderColor(QtGui.QColor(255, 255, 255, 25)) - self.chart.addSeries(self.series["batch"]) - - self.series["epoch_loss"] = QtCharts.QLineSeries() - self.series["epoch_loss"].setName("Epoch Training Loss") - self.series["epoch_loss"].setColor(QtGui.QColor(*COLOR_TRAIN, 255)) - pen = self.series["epoch_loss"].pen() - pen.setWidth(4) - self.series["epoch_loss"].setPen(pen) - self.chart.addSeries(self.series["epoch_loss"]) - - self.series["epoch_loss_scatter"] = QtCharts.QScatterSeries() - self.series["epoch_loss_scatter"].setColor(QtGui.QColor(*COLOR_TRAIN, 255)) - self.series["epoch_loss_scatter"].setMarkerSize(12.0) - self.series["epoch_loss_scatter"].setBorderColor( - QtGui.QColor(255, 255, 255, 25) - ) - self.chart.addSeries(self.series["epoch_loss_scatter"]) - - self.series["val_loss"] = QtCharts.QLineSeries() - self.series["val_loss"].setName("Epoch Validation Loss") - self.series["val_loss"].setColor(QtGui.QColor(*COLOR_VAL, 255)) - pen = self.series["val_loss"].pen() - pen.setWidth(4) - self.series["val_loss"].setPen(pen) - self.chart.addSeries(self.series["val_loss"]) - - self.series["val_loss_scatter"] = QtCharts.QScatterSeries() - self.series["val_loss_scatter"].setColor(QtGui.QColor(*COLOR_VAL, 255)) - self.series["val_loss_scatter"].setMarkerSize(12.0) - self.series["val_loss_scatter"].setBorderColor(QtGui.QColor(255, 255, 255, 25)) - self.chart.addSeries(self.series["val_loss_scatter"]) - - self.series["val_loss_best"] = QtCharts.QScatterSeries() - self.series["val_loss_best"].setName("Best Validation Loss") - self.series["val_loss_best"].setColor(QtGui.QColor(*COLOR_BEST_VAL, 255)) - self.series["val_loss_best"].setMarkerSize(12.0) - self.series["val_loss_best"].setBorderColor(QtGui.QColor(32, 32, 32, 25)) - self.chart.addSeries(self.series["val_loss_best"]) - - axisX = QtCharts.QValueAxis() - axisX.setLabelFormat("%d") - axisX.setTitleText("Batches") - self.chart.addAxis(axisX, QtCore.Qt.AlignBottom) - - # Create the different Y axes that can be used. - self.axisY = dict() - - self.axisY["log"] = QtCharts.QLogValueAxis() - self.axisY["log"].setBase(10) - - self.axisY["linear"] = QtCharts.QValueAxis() - - # Apply settings that apply to all Y axes. - for axisY in self.axisY.values(): - axisY.setLabelFormat("%f") - axisY.setLabelsVisible(True) - axisY.setMinorTickCount(1) - axisY.setTitleText("Loss") - - # Use the default Y axis. - axisY = self.axisY["log"] if self.log_scale else self.axisY["linear"] - - # Add axes to chart and series. - self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) - for series in self.chart.series(): - series.attachAxis(axisX) - series.attachAxis(axisY) - - # Setup legend. 
- self.chart.legend().setVisible(True) - self.chart.legend().setAlignment(QtCore.Qt.AlignTop) - self.chart.legend().setMarkerShape(QtCharts.QLegend.MarkerShapeCircle) - - # Hide scatters for epoch and val loss from legend. - for s in ("epoch_loss_scatter", "val_loss_scatter"): - self.chart.legend().markers(self.series[s])[0].setVisible(False) - - self.chartView = QtCharts.QChartView(self.chart) - self.chartView.setRenderHint(QtGui.QPainter.Antialiasing) layout = QtWidgets.QVBoxLayout() - layout.addWidget(self.chartView) + layout.addWidget(self.canvas) if self.show_controller: control_layout = QtWidgets.QHBoxLayout() field = QtWidgets.QCheckBox("Log Scale") field.setChecked(self.log_scale) - field.stateChanged.connect(self.toggle_log_scale) + field.stateChanged.connect(self._toggle_log_scale) control_layout.addWidget(field) field = QtWidgets.QCheckBox("Ignore Outliers") field.setChecked(self.ignore_outliers) - field.stateChanged.connect(self.toggle_ignore_outliers) + field.stateChanged.connect(self._toggle_ignore_outliers) control_layout.addWidget(field) control_layout.addWidget(QtWidgets.QLabel("Batches to Show:")) @@ -203,7 +717,7 @@ def reset( # Set connection action for when user selects another option. field.currentIndexChanged.connect( - lambda x: self.set_batches_to_show(self.batch_options[x]) + lambda x: self._set_batches_to_show(self.batch_options[x]) ) # Store field as property and add to layout. @@ -213,10 +727,10 @@ def reset( control_layout.addStretch(1) self.stop_button = QtWidgets.QPushButton("Stop Early") - self.stop_button.clicked.connect(self.stop) + self.stop_button.clicked.connect(self._stop) control_layout.addWidget(self.stop_button) self.cancel_button = QtWidgets.QPushButton("Cancel Training") - self.cancel_button.clicked.connect(self.cancel) + self.cancel_button.clicked.connect(self._cancel) control_layout.addWidget(self.cancel_button) widget = QtWidgets.QWidget() @@ -248,48 +762,16 @@ def reset( self.last_batch_number = 0 self.is_running = False - def toggle_ignore_outliers(self): - """Toggles whether to ignore outliers in chart scaling.""" - self.ignore_outliers = not self.ignore_outliers - - def toggle_log_scale(self): - """Toggle whether to use log-scaled y-axis.""" - self.log_scale = not self.log_scale - self.update_y_axis() - - def set_batches_to_show(self, batches: str): - """Set the number of batches to show on the x-axis. + def set_message(self, text: str): + """Set the chart title text.""" + self.canvas.set_title(text) - Args: - batches: Number of batches as a string. If numeric, this will be converted - to an integer. If non-numeric string (e.g., "All"), then all batches - will be shown. - """ - if batches.isdigit(): - self.batches_to_show = int(batches) - else: - self.batches_to_show = -1 + def close(self): + """Disconnect from ZMQ ports and close the window.""" + self._unbind() + super().close() - def update_y_axis(self): - """Update the y-axis when scale changes.""" - to = "log" if self.log_scale else "linear" - - # Remove other axes. - for name, axisY in self.axisY.items(): - if name != to: - if axisY in self.chart.axes(): - self.chart.removeAxis(axisY) - for series in self.chart.series(): - if axisY in series.attachedAxes(): - series.detachAxis(axisY) - - # Add axis. 
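Each control here follows the same chain: a Qt signal fires a small private handler, the handler flips a public property, and the property setter pushes the value down to the canvas. A self-contained sketch of that chain, using a plain stand-in object for the canvas:

```python
# Sketch of the signal -> handler -> property -> canvas chain used above.
from types import SimpleNamespace

from qtpy import QtWidgets


class Controls(QtWidgets.QWidget):
    def __init__(self, canvas):
        super().__init__()
        self.canvas = canvas  # any object with a `log_scale` attribute
        self._log_scale = True
        box = QtWidgets.QCheckBox("Log Scale")
        box.setChecked(self._log_scale)
        box.stateChanged.connect(self._toggle_log_scale)
        layout = QtWidgets.QVBoxLayout(self)
        layout.addWidget(box)

    @property
    def log_scale(self) -> bool:
        return self._log_scale

    @log_scale.setter
    def log_scale(self, val):
        if isinstance(val, bool):  # ignore non-bool values, as in the setter above
            self._log_scale = val
        self.canvas.log_scale = self._log_scale  # push the change to the canvas

    def _toggle_log_scale(self):
        self.log_scale = not self.log_scale


app = QtWidgets.QApplication([])
panel = Controls(canvas=SimpleNamespace(log_scale=True))
panel._toggle_log_scale()
assert panel.canvas.log_scale is False
```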
- axisY = self.axisY[to] - self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) - for series in self.chart.series(): - series.attachAxis(axisY) - - def setup_zmq(self, zmq_context: Optional[zmq.Context] = None): + def _setup_zmq(self, zmq_context: Optional[zmq.Context] = None): """Connect to ZMQ ports that listen to commands and updates. Args: @@ -305,112 +787,69 @@ def setup_zmq(self, zmq_context: Optional[zmq.Context] = None): # Progress monitoring, SUBSCRIBER self.sub = self.ctx.socket(zmq.SUB) self.sub.subscribe("") - self.sub.bind("tcp://127.0.0.1:9001") + + def find_free_port(port: int, zmq_context: zmq.Context): + """Find free port to bind to. + + Args: + port: The port to start searching from. + zmq_context: The ZMQ context to use. + + Returns: + The free port. + """ + attempts = 0 + max_attempts = 10 + while not is_port_free(port=port, zmq_context=zmq_context): + if attempts >= max_attempts: + raise RuntimeError( + f"Could not find free port to display training progress after " + f"{max_attempts} attempts. Please check your network settings " + "or use the CLI `sleap-train` command." + ) + port = select_zmq_port(zmq_context=self.ctx) + attempts += 1 + + return port + + # Find a free port and bind to it. + self.zmq_ports["publish_port"] = find_free_port( + port=self.zmq_ports["publish_port"], zmq_context=self.ctx + ) + publish_address = f"tcp://127.0.0.1:{self.zmq_ports['publish_port']}" + self.sub.bind(publish_address) # Controller, PUBLISHER self.zmq_ctrl = None if self.show_controller: self.zmq_ctrl = self.ctx.socket(zmq.PUB) - self.zmq_ctrl.bind("tcp://127.0.0.1:9000") + + # Find a free port and bind to it. + self.zmq_ports["controller_port"] = find_free_port( + port=self.zmq_ports["controller_port"], zmq_context=self.ctx + ) + controller_address = f"tcp://127.0.0.1:{self.zmq_ports['controller_port']}" + self.zmq_ctrl.bind(controller_address) # Set timer to poll for messages. self.timer = QtCore.QTimer() - self.timer.timeout.connect(self.check_messages) + self.timer.timeout.connect(self._check_messages) self.timer.start(self.message_poll_time_ms) - def cancel(self): - """Set the cancel flag.""" - self.canceled = True - if self.cancel_button is not None: - self.cancel_button.setText("Canceling...") - self.cancel_button.setEnabled(False) - - def stop(self): - """Send command to stop training.""" - if self.zmq_ctrl is not None: - # Send command to stop training. - logger.info("Sending command to stop training.") - self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) - - # Disable the button to prevent double messages. - if self.stop_button is not None: - self.stop_button.setText("Stopping...") - self.stop_button.setEnabled(False) - - def add_datapoint(self, x: int, y: float, which: str): - """Add a data point to graph. + def _set_batches_to_show(self, batches: str): + """Set the number of batches to show on the x-axis. Args: - x: The batch number (out of all epochs, not just current), or epoch. - y: The loss value. - which: Type of data point we're adding. Possible values are: - * "batch" (loss for the batch) - * "epoch_loss" (loss for the entire epoch) - * "val_loss" (validation loss for the epoch) + batches: Number of batches as a string. If numeric, this will be converted + to an integer. If non-numeric string (e.g., "All"), then all batches + will be shown. """ - if which == "batch": - self.X.append(x) - self.Y.append(y) - - # Redraw batch at intervals (faster than plotting every batch). 
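`is_port_free` and `select_zmq_port` are imported from elsewhere in SLEAP and do not appear in this diff; a plausible pyzmq-based sketch of what they do (an assumption, not their actual implementations):

```python
# Hedged sketch of the two port helpers used by find_free_port above.
import zmq


def is_port_free(port: int, zmq_context: zmq.Context) -> bool:
    """Return True if a TCP socket can bind to `port` on localhost."""
    sock = zmq_context.socket(zmq.PUB)
    try:
        sock.bind(f"tcp://127.0.0.1:{port}")
        sock.unbind(sock.LAST_ENDPOINT)
        return True
    except zmq.ZMQError:  # e.g., EADDRINUSE when the port is taken
        return False
    finally:
        sock.close()


def select_zmq_port(zmq_context: zmq.Context) -> int:
    """Let the OS pick a free ephemeral port and return it."""
    sock = zmq_context.socket(zmq.PUB)
    try:
        return sock.bind_to_random_port("tcp://127.0.0.1")
    finally:
        sock.close()
```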
- draw_batch = False - if self.last_redraw_batch is None: - draw_batch = True - else: - dt = perf_counter() - self.last_redraw_batch - draw_batch = (dt * 1000) >= self.redraw_batch_time_ms - - if draw_batch: - self.last_redraw_batch = perf_counter() - if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: - xs, ys = self.X, self.Y - else: - xs, ys = ( - self.X[-self.batches_to_show :], - self.Y[-self.batches_to_show :], - ) - - points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys) if y > 0] - self.series["batch"].replace(points) - - # Set X scale to show all points - dx = 0.5 - self.chart.axisX().setRange(min(xs) - dx, max(xs) + dx) - - if self.ignore_outliers: - dy = np.ptp(ys) * 0.02 - # Set Y scale to exclude outliers - q1, q3 = np.quantile(ys, (0.25, 0.75)) - iqr = q3 - q1 # interquartile range - low = q1 - iqr * 1.5 - high = q3 + iqr * 1.5 - - low = max(low, min(ys) - dy) # keep within range of data - high = min(high, max(ys) + dy) - else: - # Set Y scale to show all points - dy = np.ptp(ys) * 0.02 - low = min(ys) - dy - high = max(ys) + dy - - if self.log_scale: - low = max(low, 1e-8) # for log scale, low cannot be 0 - - self.chart.axisY().setRange(low, high) - + if batches.isdigit(): + self.batches_to_show = int(batches) else: - if which == "epoch_loss": - self.series["epoch_loss"].append(x, y) - self.series["epoch_loss_scatter"].append(x, y) - elif which == "val_loss": - self.series["val_loss"].append(x, y) - self.series["val_loss_scatter"].append(x, y) - if self.best_val_y is None or y < self.best_val_y: - self.best_val_x = x - self.best_val_y = y - self.series["val_loss_best"].replace([QtCore.QPointF(x, y)]) + self.batches_to_show = -1 - def set_start_time(self, t0: float): + def _set_start_time(self, t0: float): """Mark the start flag and time of the run. Args: @@ -419,52 +858,31 @@ def set_start_time(self, t0: float): self.t0 = t0 self.is_running = True - def set_end(self): - """Mark the end of the run.""" - self.is_running = False - - def update_runtime(self): + def _update_runtime(self): """Update the title text with the current running time.""" + if self.is_timer_running: dt = perf_counter() - self.t0 dt_min, dt_sec = divmod(dt, 60) - title = f"Training Epoch {self.epoch + 1} / " - title += f"Runtime: {int(dt_min):02}:{int(dt_sec):02}" - if self.last_epoch_val_loss is not None: - if self.penultimate_epoch_val_loss is not None: - title += ( - f"
\nMean Time per Epoch: " - f"{int(self.mean_epoch_time_min):02}:{int(self.mean_epoch_time_sec):02} / " - f"ETA Next 10 Epochs: {int(self.eta_ten_epochs_min)} min" - ) - if self.epoch_in_plateau_flag: - title += ( - f"\n
Epochs in Plateau: " - f"{self.epochs_in_plateau} / " - f"{self.config.optimization.early_stopping.plateau_patience}" - ) - title += ( - f"\n
Last Epoch Validation Loss: " - f"{self.last_epoch_val_loss:.3e}" - ) - if self.best_val_x is not None: - best_epoch = (self.best_val_x // self.epoch_size) + 1 - title += ( - f"\n
Best Epoch Validation Loss: " - f"{self.best_val_y:.3e} (epoch {best_epoch})" - ) - self.set_message(title) - - @property - def is_timer_running(self) -> bool: - """Return True if the timer has started.""" - return self.t0 is not None and self.is_running - def set_message(self, text: str): - """Set the chart title text.""" - self.chart.setTitle(text) + self.canvas.update_runtime_title( + epoch=self.epoch, + dt_min=dt_min, + dt_sec=dt_sec, + last_epoch_val_loss=self.last_epoch_val_loss, + penultimate_epoch_val_loss=self.penultimate_epoch_val_loss, + mean_epoch_time_min=self.mean_epoch_time_min, + mean_epoch_time_sec=self.mean_epoch_time_sec, + eta_ten_epochs_min=self.eta_ten_epochs_min, + epochs_in_plateau=self.epochs_in_plateau, + plateau_patience=self.config.optimization.early_stopping.plateau_patience, + epoch_in_plateau_flag=self.epoch_in_plateau_flag, + best_val_x=self.best_val_x, + best_val_y=self.best_val_y, + epoch_size=self.epoch_size, + ) - def check_messages( + def _check_messages( self, timeout: int = 10, times_to_check: int = 10, do_update: bool = True ): """Poll for ZMQ messages and adds any received data to graph. @@ -496,7 +914,7 @@ def check_messages( msg = jsonpickle.decode(self.sub.recv_string()) if msg["event"] == "train_begin": - self.set_start_time(perf_counter()) + self._set_start_time(perf_counter()) self.current_job_output_type = msg["what"] # Make sure message matches current training job. @@ -504,15 +922,15 @@ def check_messages( if not self.is_timer_running: # We must have missed the train_begin message, so start timer now. - self.set_start_time(perf_counter()) + self._set_start_time(perf_counter()) if msg["event"] == "train_end": - self.set_end() + self._set_end() elif msg["event"] == "epoch_begin": self.epoch = msg["epoch"] elif msg["event"] == "epoch_end": self.epoch_size = max(self.epoch_size, self.last_batch_number + 1) - self.add_datapoint( + self._add_datapoint( (self.epoch + 1) * self.epoch_size, msg["logs"]["loss"], "epoch_loss", @@ -521,7 +939,7 @@ def check_messages( # update variables and add points to plot self.penultimate_epoch_val_loss = self.last_epoch_val_loss self.last_epoch_val_loss = msg["logs"]["val_loss"] - self.add_datapoint( + self._add_datapoint( (self.epoch + 1) * self.epoch_size, msg["logs"]["val_loss"], "val_loss", @@ -552,7 +970,7 @@ def check_messages( self.on_epoch.emit() elif msg["event"] == "batch_end": self.last_batch_number = msg["batch"] - self.add_datapoint( + self._add_datapoint( (self.epoch * self.epoch_size) + msg["batch"], msg["logs"]["loss"], "batch", @@ -560,9 +978,155 @@ def check_messages( # Check for messages again (up to times_to_check times). if times_to_check > 0: - self.check_messages( + self._check_messages( timeout=timeout, times_to_check=times_to_check - 1, do_update=False ) if do_update: - self.update_runtime() + self._update_runtime() + + def _add_datapoint(self, x: int, y: float, which: str): + """Add a data point to graph. + + Args: + x: The batch number (out of all epochs, not just current), or epoch. + y: The loss value. + which: Type of data point we're adding. Possible values are: + * "batch" (loss for the batch) + * "epoch_loss" (loss for the entire epoch) + * "val_loss" (validation loss for the epoch) + """ + if which == "batch": + self.X.append(x) + self.Y.append(y) + + # Redraw batch at intervals (faster than plotting every batch). 
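For context, `_check_messages` is the receiving half of a simple PUB/SUB protocol: a training process publishes jsonpickle-encoded event dicts to the viewer's bound publish port. A sketch of the sending side, with an illustrative `what` value and loss numbers (the event shapes follow the handler above):

```python
# Sketch of the trainer side of the protocol consumed by _check_messages.
import jsonpickle
import zmq

ctx = zmq.Context()
pub = ctx.socket(zmq.PUB)
pub.connect("tcp://127.0.0.1:9001")  # the viewer's default publish_port

def send_event(event: str, what: str = "centroid", **kwargs):
    """Publish one event dict; `what` tags the current job type."""
    pub.send_string(jsonpickle.encode(dict(event=event, what=what, **kwargs)))

send_event("train_begin")
send_event("epoch_begin", epoch=0)
send_event("batch_end", batch=0, logs={"loss": 0.25})
send_event("epoch_end", epoch=0, logs={"loss": 0.2, "val_loss": 0.3})
send_event("train_end")
```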
+ draw_batch = False + if self.last_redraw_batch is None: + draw_batch = True + else: + dt = perf_counter() - self.last_redraw_batch + draw_batch = (dt * 1000) >= self.redraw_batch_time_ms + + if draw_batch: + self.last_redraw_batch = perf_counter() + if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: + xs, ys = self.X, self.Y + else: + xs, ys = ( + self.X[-self.batches_to_show :], + self.Y[-self.batches_to_show :], + ) + + # Set data, resize and redraw the plot + self._set_data_on_scatter(xs, ys, which) + self._resize_axes(xs, ys) + + else: + + if which == "val_loss": + if self.best_val_y is None or y < self.best_val_y: + self.best_val_x = x + self.best_val_y = y + self._set_data_on_scatter([x], [y], "val_loss_best") + + # Add data and redraw the plot + self._add_data_to_plot(x, y, which) + self._redraw_plot() + + def _set_data_on_scatter(self, xs, ys, which): + """Add data to a scatter plot. + + Not to be used with line plots. + + Args: + xs: The x-coordinates of the data points. + ys: The y-coordinates of the data points. + which: The type of data point. Possible values are: + * "batch" + * "val_loss_best" + """ + + self.canvas.set_data_on_scatter(xs, ys, which) + + def _add_data_to_plot(self, x, y, which): + """Add data to a line plot. + + Not to be used with scatter plots. + + Args: + x: The x-coordinate of the data point. + y: The y-coordinate of the data point. + which: The type of data point. Possible values are: + * "epoch_loss" + * "val_loss" + """ + + self.canvas.add_data_to_plot(x, y, which) + + def _redraw_plot(self): + """Redraw the plot.""" + + self.canvas.redraw_plot() + + def _resize_axes(self, x, y): + """Resize axes to fit data. + + This is only called when plotting batches. + + Args: + x: The x-coordinates of the data points. + y: The y-coordinates of the data points. + """ + self.canvas.resize_axes(x, y) + + def _toggle_ignore_outliers(self): + """Toggles whether to ignore outliers in chart scaling.""" + + self.ignore_outliers = not self.ignore_outliers + + def _toggle_log_scale(self): + """Toggle whether to use log-scaled y-axis.""" + + self.log_scale = not self.log_scale + + def _stop(self): + """Send command to stop training.""" + if self.zmq_ctrl is not None: + # Send command to stop training. + logger.info("Sending command to stop training.") + self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) + + # Disable the button to prevent double messages. + if self.stop_button is not None: + self.stop_button.setText("Stopping...") + self.stop_button.setEnabled(False) + + def _cancel(self): + """Set the cancel flag.""" + self.canceled = True + if self.cancel_button is not None: + self.cancel_button.setText("Canceling...") + self.cancel_button.setEnabled(False) + + def _unbind(self): + """Disconnect from all ZMQ sockets.""" + if self.sub is not None: + self.sub.unbind(self.sub.LAST_ENDPOINT) + self.sub.close() + self.sub = None + + if self.zmq_ctrl is not None: + url = self.zmq_ctrl.LAST_ENDPOINT + self.zmq_ctrl.unbind(url) + self.zmq_ctrl.close() + self.zmq_ctrl = None + + # If we started out own zmq context, terminate it. 
+ if not self.ctx_given and self.ctx is not None: + self.ctx.term() + self.ctx = None + + def _set_end(self): + """Mark the end of the run.""" + self.is_running = False diff --git a/sleap/gui/widgets/mpl.py b/sleap/gui/widgets/mpl.py index a9b7fc838..890c1a67a 100644 --- a/sleap/gui/widgets/mpl.py +++ b/sleap/gui/widgets/mpl.py @@ -6,11 +6,10 @@ from qtpy import QtWidgets from matplotlib.figure import Figure -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as Canvas +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as Canvas import matplotlib -# Ensure using PyQt5 backend -matplotlib.use("QT5Agg") +matplotlib.use("QtAgg") class MplCanvas(Canvas): diff --git a/sleap/gui/widgets/training_monitor.py b/sleap/gui/widgets/training_monitor.py deleted file mode 100644 index ed405a747..000000000 --- a/sleap/gui/widgets/training_monitor.py +++ /dev/null @@ -1,566 +0,0 @@ -"""GUI for monitoring training progress interactively.""" - -import numpy as np -from time import perf_counter -from sleap.nn.config.training_job import TrainingJobConfig -import zmq -import jsonpickle -import logging -from typing import Optional -from qtpy import QtCore, QtWidgets, QtGui, QtCharts -import attr - -logger = logging.getLogger(__name__) - - -class LossViewer(QtWidgets.QMainWindow): - """Qt window for showing in-progress training metrics sent over ZMQ.""" - - on_epoch = QtCore.Signal() - - def __init__( - self, - zmq_context: Optional[zmq.Context] = None, - show_controller=True, - parent=None, - ): - super().__init__(parent) - - self.show_controller = show_controller - self.stop_button = None - self.cancel_button = None - self.canceled = False - - self.batches_to_show = -1 # -1 to show all - self.ignore_outliers = False - self.log_scale = True - self.message_poll_time_ms = 20 # ms - self.redraw_batch_time_ms = 500 # ms - self.last_redraw_batch = None - - self.reset() - self.setup_zmq(zmq_context) - - def __del__(self): - self.unbind() - - def close(self): - """Disconnect from ZMQ ports and close the window.""" - self.unbind() - super().close() - - def unbind(self): - """Disconnect from all ZMQ sockets.""" - if self.sub is not None: - self.sub.unbind(self.sub.LAST_ENDPOINT) - self.sub.close() - self.sub = None - - if self.zmq_ctrl is not None: - url = self.zmq_ctrl.LAST_ENDPOINT - self.zmq_ctrl.unbind(url) - self.zmq_ctrl.close() - self.zmq_ctrl = None - - # If we started out own zmq context, terminate it. - if not self.ctx_given and self.ctx is not None: - self.ctx.term() - self.ctx = None - - def reset( - self, - what: str = "", - config: TrainingJobConfig = attr.ib(factory=TrainingJobConfig), - ): - """Reset all chart series. - - Args: - what: String identifier indicating which job type the current run - corresponds to. 
- """ - self.chart = QtCharts.QChart() - - self.series = dict() - - COLOR_TRAIN = (18, 158, 220) - COLOR_VAL = (248, 167, 52) - COLOR_BEST_VAL = (151, 204, 89) - - self.series["batch"] = QtCharts.QScatterSeries() - self.series["batch"].setName("Batch Training Loss") - self.series["batch"].setColor(QtGui.QColor(*COLOR_TRAIN, 48)) - self.series["batch"].setMarkerSize(8.0) - self.series["batch"].setBorderColor(QtGui.QColor(255, 255, 255, 25)) - self.chart.addSeries(self.series["batch"]) - - self.series["epoch_loss"] = QtCharts.QLineSeries() - self.series["epoch_loss"].setName("Epoch Training Loss") - self.series["epoch_loss"].setColor(QtGui.QColor(*COLOR_TRAIN, 255)) - pen = self.series["epoch_loss"].pen() - pen.setWidth(4) - self.series["epoch_loss"].setPen(pen) - self.chart.addSeries(self.series["epoch_loss"]) - - self.series["epoch_loss_scatter"] = QtCharts.QScatterSeries() - self.series["epoch_loss_scatter"].setColor(QtGui.QColor(*COLOR_TRAIN, 255)) - self.series["epoch_loss_scatter"].setMarkerSize(12.0) - self.series["epoch_loss_scatter"].setBorderColor( - QtGui.QColor(255, 255, 255, 25) - ) - self.chart.addSeries(self.series["epoch_loss_scatter"]) - - self.series["val_loss"] = QtCharts.QLineSeries() - self.series["val_loss"].setName("Epoch Validation Loss") - self.series["val_loss"].setColor(QtGui.QColor(*COLOR_VAL, 255)) - pen = self.series["val_loss"].pen() - pen.setWidth(4) - self.series["val_loss"].setPen(pen) - self.chart.addSeries(self.series["val_loss"]) - - self.series["val_loss_scatter"] = QtCharts.QScatterSeries() - self.series["val_loss_scatter"].setColor(QtGui.QColor(*COLOR_VAL, 255)) - self.series["val_loss_scatter"].setMarkerSize(12.0) - self.series["val_loss_scatter"].setBorderColor(QtGui.QColor(255, 255, 255, 25)) - self.chart.addSeries(self.series["val_loss_scatter"]) - - self.series["val_loss_best"] = QtCharts.QScatterSeries() - self.series["val_loss_best"].setName("Best Validation Loss") - self.series["val_loss_best"].setColor(QtGui.QColor(*COLOR_BEST_VAL, 255)) - self.series["val_loss_best"].setMarkerSize(12.0) - self.series["val_loss_best"].setBorderColor(QtGui.QColor(32, 32, 32, 25)) - self.chart.addSeries(self.series["val_loss_best"]) - - axisX = QtCharts.QValueAxis() - axisX.setLabelFormat("%d") - axisX.setTitleText("Batches") - self.chart.addAxis(axisX, QtCore.Qt.AlignBottom) - - # Create the different Y axes that can be used. - self.axisY = dict() - - self.axisY["log"] = QtCharts.QLogValueAxis() - self.axisY["log"].setBase(10) - - self.axisY["linear"] = QtCharts.QValueAxis() - - # Apply settings that apply to all Y axes. - for axisY in self.axisY.values(): - axisY.setLabelFormat("%f") - axisY.setLabelsVisible(True) - axisY.setMinorTickCount(1) - axisY.setTitleText("Loss") - - # Use the default Y axis. - axisY = self.axisY["log"] if self.log_scale else self.axisY["linear"] - - # Add axes to chart and series. - self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) - for series in self.chart.series(): - series.attachAxis(axisX) - series.attachAxis(axisY) - - # Setup legend. - self.chart.legend().setVisible(True) - self.chart.legend().setAlignment(QtCore.Qt.AlignTop) - self.chart.legend().setMarkerShape(QtCharts.QLegend.MarkerShapeCircle) - - # Hide scatters for epoch and val loss from legend. 
- for s in ("epoch_loss_scatter", "val_loss_scatter"): - self.chart.legend().markers(self.series[s])[0].setVisible(False) - - self.chartView = QtCharts.QChartView(self.chart) - self.chartView.setRenderHint(QtGui.QPainter.Antialiasing) - layout = QtWidgets.QVBoxLayout() - layout.addWidget(self.chartView) - - if self.show_controller: - control_layout = QtWidgets.QHBoxLayout() - - field = QtWidgets.QCheckBox("Log Scale") - field.setChecked(self.log_scale) - field.stateChanged.connect(self.toggle_log_scale) - control_layout.addWidget(field) - - field = QtWidgets.QCheckBox("Ignore Outliers") - field.setChecked(self.ignore_outliers) - field.stateChanged.connect(self.toggle_ignore_outliers) - control_layout.addWidget(field) - - control_layout.addWidget(QtWidgets.QLabel("Batches to Show:")) - - # Add field for how many batches to show in chart. - field = QtWidgets.QComboBox() - self.batch_options = "200,1000,5000,All".split(",") - for opt in self.batch_options: - field.addItem(opt) - cur_opt_str = ( - "All" if self.batches_to_show < 0 else str(self.batches_to_show) - ) - if cur_opt_str in self.batch_options: - field.setCurrentText(cur_opt_str) - - # Set connection action for when user selects another option. - field.currentIndexChanged.connect( - lambda x: self.set_batches_to_show(self.batch_options[x]) - ) - - # Store field as property and add to layout. - self.batches_to_show_field = field - control_layout.addWidget(self.batches_to_show_field) - - control_layout.addStretch(1) - - self.stop_button = QtWidgets.QPushButton("Stop Early") - self.stop_button.clicked.connect(self.stop) - control_layout.addWidget(self.stop_button) - self.cancel_button = QtWidgets.QPushButton("Cancel Training") - self.cancel_button.clicked.connect(self.cancel) - control_layout.addWidget(self.cancel_button) - - widget = QtWidgets.QWidget() - widget.setLayout(control_layout) - layout.addWidget(widget) - - wid = QtWidgets.QWidget() - wid.setLayout(layout) - self.setCentralWidget(wid) - - self.config = config - self.X = [] - self.Y = [] - self.best_val_x = None - self.best_val_y = None - - self.t0 = None - self.mean_epoch_time_min = None - self.mean_epoch_time_sec = None - self.eta_ten_epochs_min = None - - self.current_job_output_type = what - self.epoch = 0 - self.epoch_size = 1 - self.epochs_in_plateau = 0 - self.last_epoch_val_loss = None - self.penultimate_epoch_val_loss = None - self.epoch_in_plateau_flag = False - self.last_batch_number = 0 - self.is_running = False - - def toggle_ignore_outliers(self): - """Toggles whether to ignore outliers in chart scaling.""" - self.ignore_outliers = not self.ignore_outliers - - def toggle_log_scale(self): - """Toggle whether to use log-scaled y-axis.""" - self.log_scale = not self.log_scale - self.update_y_axis() - - def set_batches_to_show(self, batches: str): - """Set the number of batches to show on the x-axis. - - Args: - batches: Number of batches as a string. If numeric, this will be converted - to an integer. If non-numeric string (e.g., "All"), then all batches - will be shown. - """ - if batches.isdigit(): - self.batches_to_show = int(batches) - else: - self.batches_to_show = -1 - - def update_y_axis(self): - """Update the y-axis when scale changes.""" - to = "log" if self.log_scale else "linear" - - # Remove other axes. - for name, axisY in self.axisY.items(): - if name != to: - if axisY in self.chart.axes(): - self.chart.removeAxis(axisY) - for series in self.chart.series(): - if axisY in series.attachedAxes(): - series.detachAxis(axisY) - - # Add axis. 
- axisY = self.axisY[to] - self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) - for series in self.chart.series(): - series.attachAxis(axisY) - - def setup_zmq(self, zmq_context: Optional[zmq.Context] = None): - """Connect to ZMQ ports that listen to commands and updates. - - Args: - zmq_context: The `zmq.Context` object to use for connections. A new one is - created if not specified and will be closed when the monitor exits. If - an existing one is provided, it will NOT be closed. - """ - # Keep track of whether we're using an existing context (which we won't close - # when done) or are creating our own (which we should close). - self.ctx_given = zmq_context is not None - self.ctx = zmq.Context() if zmq_context is None else zmq_context - - # Progress monitoring, SUBSCRIBER - self.sub = self.ctx.socket(zmq.SUB) - self.sub.subscribe("") - self.sub.bind("tcp://127.0.0.1:9001") - - # Controller, PUBLISHER - self.zmq_ctrl = None - if self.show_controller: - self.zmq_ctrl = self.ctx.socket(zmq.PUB) - self.zmq_ctrl.bind("tcp://127.0.0.1:9000") - - # Set timer to poll for messages. - self.timer = QtCore.QTimer() - self.timer.timeout.connect(self.check_messages) - self.timer.start(self.message_poll_time_ms) - - def cancel(self): - """Set the cancel flag.""" - self.canceled = True - if self.cancel_button is not None: - self.cancel_button.setText("Canceling...") - self.cancel_button.setEnabled(False) - - def stop(self): - """Send command to stop training.""" - if self.zmq_ctrl is not None: - # Send command to stop training. - logger.info("Sending command to stop training.") - self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) - - # Disable the button to prevent double messages. - if self.stop_button is not None: - self.stop_button.setText("Stopping...") - self.stop_button.setEnabled(False) - - def add_datapoint(self, x: int, y: float, which: str): - """Add a data point to graph. - - Args: - x: The batch number (out of all epochs, not just current), or epoch. - y: The loss value. - which: Type of data point we're adding. Possible values are: - * "batch" (loss for the batch) - * "epoch_loss" (loss for the entire epoch) - * "val_loss" (validation loss for the epoch) - """ - if which == "batch": - self.X.append(x) - self.Y.append(y) - - # Redraw batch at intervals (faster than plotting every batch). 
- draw_batch = False - if self.last_redraw_batch is None: - draw_batch = True - else: - dt = perf_counter() - self.last_redraw_batch - draw_batch = (dt * 1000) >= self.redraw_batch_time_ms - - if draw_batch: - self.last_redraw_batch = perf_counter() - if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: - xs, ys = self.X, self.Y - else: - xs, ys = ( - self.X[-self.batches_to_show :], - self.Y[-self.batches_to_show :], - ) - - points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys) if y > 0] - self.series["batch"].replace(points) - - # Set X scale to show all points - dx = 0.5 - self.chart.axisX().setRange(min(xs) - dx, max(xs) + dx) - - if self.ignore_outliers: - dy = np.ptp(ys) * 0.02 - # Set Y scale to exclude outliers - q1, q3 = np.quantile(ys, (0.25, 0.75)) - iqr = q3 - q1 # interquartile range - low = q1 - iqr * 1.5 - high = q3 + iqr * 1.5 - - low = max(low, min(ys) - dy) # keep within range of data - high = min(high, max(ys) + dy) - else: - # Set Y scale to show all points - dy = np.ptp(ys) * 0.02 - low = min(ys) - dy - high = max(ys) + dy - - if self.log_scale: - low = max(low, 1e-8) # for log scale, low cannot be 0 - - self.chart.axisY().setRange(low, high) - - else: - if which == "epoch_loss": - self.series["epoch_loss"].append(x, y) - self.series["epoch_loss_scatter"].append(x, y) - elif which == "val_loss": - self.series["val_loss"].append(x, y) - self.series["val_loss_scatter"].append(x, y) - if self.best_val_y is None or y < self.best_val_y: - self.best_val_x = x - self.best_val_y = y - self.series["val_loss_best"].replace([QtCore.QPointF(x, y)]) - - def set_start_time(self, t0: float): - """Mark the start flag and time of the run. - - Args: - t0: Start time in seconds. - """ - self.t0 = t0 - self.is_running = True - - def set_end(self): - """Mark the end of the run.""" - self.is_running = False - - def update_runtime(self): - """Update the title text with the current running time.""" - if self.is_timer_running: - dt = perf_counter() - self.t0 - dt_min, dt_sec = divmod(dt, 60) - title = f"Training Epoch {self.epoch + 1} / " - title += f"Runtime: {int(dt_min):02}:{int(dt_sec):02}" - if self.last_epoch_val_loss is not None: - if self.penultimate_epoch_val_loss is not None: - title += ( - f"
\nMean Time per Epoch: " - f"{int(self.mean_epoch_time_min):02}:{int(self.mean_epoch_time_sec):02} / " - f"ETA Next 10 Epochs: {int(self.eta_ten_epochs_min)} min" - ) - if self.epoch_in_plateau_flag: - title += ( - f"\n
Epochs in Plateau: " - f"{self.epochs_in_plateau} / " - f"{self.config.optimization.early_stopping.plateau_patience}" - ) - title += ( - f"\n
Last Epoch Validation Loss: " - f"{self.last_epoch_val_loss:.3e}" - ) - if self.best_val_x is not None: - best_epoch = (self.best_val_x // self.epoch_size) + 1 - title += ( - f"\n
Best Epoch Validation Loss: " - f"{self.best_val_y:.3e} (epoch {best_epoch})" - ) - self.set_message(title) - - @property - def is_timer_running(self) -> bool: - """Return True if the timer has started.""" - return self.t0 is not None and self.is_running - - def set_message(self, text: str): - """Set the chart title text.""" - self.chart.setTitle(text) - - def check_messages( - self, timeout: int = 10, times_to_check: int = 10, do_update: bool = True - ): - """Poll for ZMQ messages and adds any received data to graph. - - The message is a dictionary encoded as JSON: - * event - options include - * train_begin - * train_end - * epoch_begin - * epoch_end - * batch_end - * what - this should match the type of model we're training and - ensures that we ignore old messages when we start monitoring - a new training session (when we're training multiple types - of models in a sequence, as for the top-down pipeline). - * logs - dictionary with data relevant for plotting, can include - * loss - * val_loss - - Args: - timeout: Message polling timeout in milliseconds. This is how often we will - check for new command messages. - times_to_check: How many times to check for new messages in the queue before - going back to polling with a timeout. Helps to clear backlogs of - messages if necessary. - do_update: If True (the default), update the GUI text. - """ - if self.sub and self.sub.poll(timeout, zmq.POLLIN): - msg = jsonpickle.decode(self.sub.recv_string()) - - if msg["event"] == "train_begin": - self.set_start_time(perf_counter()) - self.current_job_output_type = msg["what"] - - # Make sure message matches current training job. - if msg.get("what", "") == self.current_job_output_type: - - if not self.is_timer_running: - # We must have missed the train_begin message, so start timer now. - self.set_start_time(perf_counter()) - - if msg["event"] == "train_end": - self.set_end() - elif msg["event"] == "epoch_begin": - self.epoch = msg["epoch"] - elif msg["event"] == "epoch_end": - self.epoch_size = max(self.epoch_size, self.last_batch_number + 1) - self.add_datapoint( - (self.epoch + 1) * self.epoch_size, - msg["logs"]["loss"], - "epoch_loss", - ) - if "val_loss" in msg["logs"].keys(): - # update variables and add points to plot - self.penultimate_epoch_val_loss = self.last_epoch_val_loss - self.last_epoch_val_loss = msg["logs"]["val_loss"] - self.add_datapoint( - (self.epoch + 1) * self.epoch_size, - msg["logs"]["val_loss"], - "val_loss", - ) - # calculate timing and flags at new epoch - if self.penultimate_epoch_val_loss is not None: - mean_epoch_time = (perf_counter() - self.t0) / ( - self.epoch + 1 - ) - self.mean_epoch_time_min, self.mean_epoch_time_sec = divmod( - mean_epoch_time, 60 - ) - self.eta_ten_epochs_min = (mean_epoch_time * 10) // 60 - - val_loss_delta = ( - self.penultimate_epoch_val_loss - - self.last_epoch_val_loss - ) - self.epoch_in_plateau_flag = ( - val_loss_delta - < self.config.optimization.early_stopping.plateau_min_delta - ) or (self.best_val_y < self.last_epoch_val_loss) - self.epochs_in_plateau = ( - self.epochs_in_plateau + 1 - if self.epoch_in_plateau_flag - else 0 - ) - self.on_epoch.emit() - elif msg["event"] == "batch_end": - self.last_batch_number = msg["batch"] - self.add_datapoint( - (self.epoch * self.epoch_size) + msg["batch"], - msg["logs"]["loss"], - "batch", - ) - - # Check for messages again (up to times_to_check times). 
- if times_to_check > 0: - self.check_messages( - timeout=timeout, times_to_check=times_to_check - 1, do_update=False - ) - - if do_update: - self.update_runtime() diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 502ea388e..08ee5bf36 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -62,6 +62,7 @@ QShortcut, QVBoxLayout, QWidget, + QPinchGesture, ) import sleap @@ -240,6 +241,8 @@ def __init__( self._register_shortcuts() + self.context_menu = None + self._menu_actions = dict() if self.context: self.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) self.customContextMenuRequested.connect(self.show_contextual_menu) @@ -358,41 +361,54 @@ def add_shortcut(key, step): def setSeekbarSelection(self, a: int, b: int): self.seekbar.setSelection(a, b) - def show_contextual_menu(self, where: QtCore.QPoint): - if not self.is_menu_enabled: - return + def create_contextual_menu(self, scene_pos: QtCore.QPointF) -> QtWidgets.QMenu: + """Create the context menu for the viewer. - scene_pos = self.view.mapToScene(where) - menu = QtWidgets.QMenu() + This is called when the user right-clicks in the viewer. This function also + stores the menu actions in the `_menu_actions` attribute so that they can be + accessed later and stores the context menu in the `context_menu` attribute. - menu.addAction("Add Instance:").setEnabled(False) + Args: + scene_pos: The position in the scene where the menu was requested. - menu.addAction("Default", lambda: self.context.newInstance(init_method="best")) + Returns: + The created context menu. + """ - menu.addAction( - "Average", - lambda: self.context.newInstance( - init_method="template", location=scene_pos - ), - ) + self.context_menu = QtWidgets.QMenu() + self.context_menu.addAction("Add Instance:").setEnabled(False) + + self._menu_actions = dict() + params_by_action_name = { + "Default": {"init_method": "best", "location": scene_pos}, + "Average": {"init_method": "template", "location": scene_pos}, + "Force Directed": {"init_method": "force_directed", "location": scene_pos}, + "Copy Prior Frame": {"init_method": "prior_frame"}, + "Random": {"init_method": "random", "location": scene_pos}, + } + for action_name, params in params_by_action_name.items(): + self._menu_actions[action_name] = self.context_menu.addAction( + action_name, lambda params=params: self.context.newInstance(**params) + ) - menu.addAction( - "Force Directed", - lambda: self.context.newInstance( - init_method="force_directed", location=scene_pos - ), - ) + return self.context_menu - menu.addAction( - "Copy Prior Frame", - lambda: self.context.newInstance(init_method="prior_frame"), - ) + def show_contextual_menu(self, where: QtCore.QPoint): + """Show the context menu at the given position in the viewer. - menu.addAction( - "Random", - lambda: self.context.newInstance(init_method="random", location=scene_pos), - ) + This is called when the user right-clicks in the viewer. This function calls + `create_contextual_menu` to create the menu and then shows the menu at the + given position. + Args: + where: The position in the viewer where the menu was requested. + """ + + if not self.is_menu_enabled: + return + + scene_pos = self.view.mapToScene(where) + menu = self.create_contextual_menu(scene_pos) menu.exec_(self.mapToGlobal(where)) def load_video(self, video: Video, plot=True): @@ -808,6 +824,8 @@ def __init__(self, state=None, player=None, *args, **kwargs): # Set icon as default background. 
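One detail in the refactored context menu above is worth spelling out: the loop binds `params` as a lambda default (`lambda params=params: ...`). Python closures capture names, not values, so without the default every menu action would call `newInstance` with the last iteration's parameters. A minimal demonstration:

```python
# Late-binding closures: a default argument freezes the current loop value.
late, frozen = {}, {}
for name in ("Default", "Average", "Random"):
    late[name] = lambda: name              # all three return "Random"
    frozen[name] = lambda name=name: name  # each returns its own name

assert late["Average"]() == "Random"
assert frozen["Average"]() == "Average"
```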
self.setImage(QImage(sleap.util.get_package_file("gui/background.png"))) + self.grabGesture(Qt.GestureType.PinchGesture) + def dragEnterEvent(self, event): if self.parentWidget(): self.parentWidget().dragEnterEvent(event) @@ -1147,8 +1165,13 @@ def mouseDoubleClickEvent(self, event: QMouseEvent): QGraphicsView.mouseDoubleClickEvent(self, event) def wheelEvent(self, event): - """Custom event handler. Zoom in/out based on scroll wheel change.""" - # zoom on wheel when no mouse buttons are pressed + """Custom event handler to zoom in/out based on scroll wheel change. + + We cannot use the default QGraphicsView.wheelEvent behavior since that will + scroll the view. + """ + + # Zoom on wheel when no mouse buttons are pressed if event.buttons() == Qt.NoButton: angle = event.angleDelta().y() factor = 1.1 if angle > 0 else 0.9 @@ -1156,20 +1179,10 @@ def wheelEvent(self, event): self.zoomFactor = max(factor * self.zoomFactor, 1) self.updateViewer() - # Trigger wheelEvent for all child elements. This is a bit of a hack. - # We can't use QGraphicsView.wheelEvent(self, event) since that will scroll - # view. - # We want to trigger for all children, since wheelEvent should continue rotating - # an skeleton even if the skeleton node/node label is no longer under the - # cursor. - # Note that children expect a QGraphicsSceneWheelEvent event, which is why we're - # explicitly ignoring TypeErrors. Everything seems to work fine since we don't - # care about the mouse position; if we did, we'd need to map pos to scene. + # Trigger only for rotation-relevant children (otherwise GUI crashes) for child in self.items(): - try: + if isinstance(child, (QtNode, QtNodeLabel)): child.wheelEvent(event) - except TypeError: - pass def keyPressEvent(self, event): """Custom event hander, disables default QGraphicsView behavior.""" @@ -1179,6 +1192,23 @@ def keyReleaseEvent(self, event): """Custom event hander, disables default QGraphicsView behavior.""" event.ignore() # Kicks the event up to parent + def event(self, event): + if event.type() == QtCore.QEvent.Gesture: + return self.handleGestureEvent(event) + return super().event(event) + + def handleGestureEvent(self, event): + gesture = event.gesture(Qt.GestureType.PinchGesture) + if gesture: + self.handlePinchGesture(gesture) + return True + + def handlePinchGesture(self, gesture: QPinchGesture): + if gesture.state() == Qt.GestureState.GestureUpdated: + factor = gesture.scaleFactor() + self.zoomFactor = max(factor * self.zoomFactor, 1) + self.updateViewer() + class QtNodeLabel(QGraphicsTextItem): """ @@ -1560,7 +1590,6 @@ def mousePressEvent(self, event): def mouseMoveEvent(self, event): """Custom event handler for mouse move.""" - # print(event) if self.dragParent: self.parentObject().mouseMoveEvent(event) else: @@ -1571,7 +1600,6 @@ def mouseMoveEvent(self, event): def mouseReleaseEvent(self, event): """Custom event handler for mouse release.""" - # print(event) self.unsetCursor() if self.dragParent: self.parentObject().mouseReleaseEvent(event) @@ -1587,7 +1615,9 @@ def mouseReleaseEvent(self, event): def wheelEvent(self, event): """Custom event handler for mouse scroll wheel.""" if self.dragParent: - angle = event.delta() / 20 + self.parentObject().rotation() + angle = ( + event.angleDelta().x() + event.angleDelta().y() + ) / 20 + self.parentObject().rotation() self.parentObject().setRotation(angle) event.accept() @@ -1598,6 +1628,10 @@ def mouseDoubleClickEvent(self, event: QMouseEvent): view = scene.views()[0] 
view.instanceDoubleClicked.emit(self.parentObject().instance, event) + def hoverEnterEvent(self, event): + """Custom event handler for mouse hover enter.""" + return super().hoverEnterEvent(event) + class QtEdge(QGraphicsPolygonItem): """ @@ -1797,6 +1831,7 @@ def __init__( self.labels = {} self.labels_shown = True self._selected = False + self._is_hovering = False self._bounding_rect = QRectF() # Show predicted instances behind non-predicted ones @@ -1818,6 +1853,7 @@ def __init__( box_pen.setStyle(Qt.DashLine) box_pen.setCosmetic(True) self.box.setPen(box_pen) + self.setAcceptHoverEvents(True) # Add label for highlighted instance self.highlight_label = QtTextWithBackground(parent=self) @@ -1979,7 +2015,12 @@ def updateBox(self, *args, **kwargs): select this instance. """ # Only show box if instance is selected - op = 0.7 if self._selected else 0 + op = 0 + if self._selected: + op = 0.8 + elif self._is_hovering: + op = 0.4 + self.box.setOpacity(op) # Update the position for the box rect = self.getPointsBoundingRect() @@ -2073,6 +2114,16 @@ def paint(self, painter, option, widget=None): """Method required by Qt.""" pass + def hoverEnterEvent(self, event): + self._is_hovering = True + self.updateBox() + return super().hoverEnterEvent(event) + + def hoverLeaveEvent(self, event): + self._is_hovering = False + self.updateBox() + return super().hoverLeaveEvent(event) + class VisibleBoundingBox(QtWidgets.QGraphicsRectItem): """QGraphicsRectItem for user instance bounding boxes. @@ -2263,7 +2314,7 @@ def mouseReleaseEvent(self, event): self.parent.nodes[node_key].setPos(new_x, new_y) # Update the instance - self.parent.updatePoints(complete=True, user_change=True) + self.parent.updatePoints(complete=False, user_change=True) self.resizing = None diff --git a/sleap/info/feature_suggestions.py b/sleap/info/feature_suggestions.py index 51f9038a5..a5f773fa7 100644 --- a/sleap/info/feature_suggestions.py +++ b/sleap/info/feature_suggestions.py @@ -644,7 +644,7 @@ class ParallelFeaturePipeline(object): def get(self, video_idx): """Apply pipeline to single video by idx. Can be called in process.""" video_dict = self.videos_as_dicts[video_idx] - video = cattr.structure(video_dict, Video) + video = Video.cattr().structure(video_dict, Video) group_offset = video_idx * self.pipeline.n_clusters # t0 = time() diff --git a/sleap/info/metrics.py b/sleap/info/metrics.py index 2ac61d339..5bec077e4 100644 --- a/sleap/info/metrics.py +++ b/sleap/info/metrics.py @@ -10,75 +10,6 @@ from sleap.io.dataset import Labels -def matched_instance_distances( - labels_gt: Labels, - labels_pr: Labels, - match_lists_function: Callable, - frame_range: Optional[range] = None, -) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]: - - """ - Distances between ground truth and predicted nodes over a set of frames. - - Args: - labels_gt: the `Labels` object with ground truth data - labels_pr: the `Labels` object with predicted data - match_lists_function: function for determining corresponding instances - Takes two lists of instances and returns "sorted" lists. - frame_range (optional): range of frames for which to compare data - If None, we compare every frame in labels_gt with corresponding - frame in labels_pr. 
- Returns: - Tuple: - * frame indices map: instance idx (for other matrices) -> frame idx - * distance matrix: (instances * nodes) - * ground truth points matrix: (instances * nodes * 2) - * predicted points matrix: (instances * nodes * 2) - """ - - frame_idxs = [] - points_gt = [] - points_pr = [] - for lf_gt in labels_gt.find(labels_gt.videos[0]): - frame_idx = lf_gt.frame_idx - - # Get instances from ground truth/predicted labels - instances_gt = lf_gt.instances - lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx) - if len(lfs_pr): - instances_pr = lfs_pr[0].instances - else: - instances_pr = [] - - # Sort ground truth and predicted instances. - # We'll then compare points between corresponding items in lists. - # We can use different "match" functions depending on what we want. - sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr) - - # Convert lists of instances to (instances, nodes, 2) matrices. - # This allows match_lists_function to return data as either - # a list of Instances or a (instances, nodes, 2) matrix. - if type(sorted_gt[0]) != np.ndarray: - sorted_gt = list_points_array(sorted_gt) - if type(sorted_pr[0]) != np.ndarray: - sorted_pr = list_points_array(sorted_pr) - - points_gt.append(sorted_gt) - points_pr.append(sorted_pr) - frame_idxs.extend([frame_idx] * len(sorted_gt)) - - # Convert arrays to numpy matrixes - # instances * nodes * (x,y) - points_gt = np.concatenate(points_gt) - points_pr = np.concatenate(points_pr) - - # Calculate distances between corresponding nodes for all corresponding - # ground truth and predicted instances. - D = np.linalg.norm(points_gt - points_pr, axis=2) - - return frame_idxs, D, points_gt, points_pr - - def match_instance_lists( instances_a: List[Union[Instance, PredictedInstance]], instances_b: List[Union[Instance, PredictedInstance]], @@ -165,6 +96,75 @@ def match_instance_lists_nodewise( return instances_a, best_points_array +def matched_instance_distances( + labels_gt: Labels, + labels_pr: Labels, + match_lists_function: Callable = match_instance_lists_nodewise, + frame_range: Optional[range] = None, +) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]: + + """ + Distances between ground truth and predicted nodes over a set of frames. + + Args: + labels_gt: the `Labels` object with ground truth data + labels_pr: the `Labels` object with predicted data + match_lists_function: function for determining corresponding instances + Takes two lists of instances and returns "sorted" lists. + frame_range (optional): range of frames for which to compare data + If None, we compare every frame in labels_gt with corresponding + frame in labels_pr. + Returns: + Tuple: + * frame indices map: instance idx (for other matrices) -> frame idx + * distance matrix: (instances * nodes) + * ground truth points matrix: (instances * nodes * 2) + * predicted points matrix: (instances * nodes * 2) + """ + + frame_idxs = [] + points_gt = [] + points_pr = [] + for lf_gt in labels_gt.find(labels_gt.videos[0]): + frame_idx = lf_gt.frame_idx + + # Get instances from ground truth/predicted labels + instances_gt = lf_gt.instances + lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx) + if len(lfs_pr): + instances_pr = lfs_pr[0].instances + else: + instances_pr = [] + + # Sort ground truth and predicted instances. + # We'll then compare points between corresponding items in lists. + # We can use different "match" functions depending on what we want. 
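Since `match_lists_function` now defaults to `match_instance_lists_nodewise`, the `__main__` demo removed further below reduces to a plain call (per-node mean distances between ground truth and predictions; the file paths are the test fixtures from the deleted demo):

```python
# Condensed from the removed __main__ demo, using the new nodewise default.
import numpy as np

from sleap.info.metrics import matched_instance_distances
from sleap.io.dataset import Labels

labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json")
labels_pr = Labels.load_json(
    "tests/data/json_format_v2/centered_pair_predictions.json"
)

frame_idxs, D, points_gt, points_pr = matched_instance_distances(labels_gt, labels_pr)

# Mean distance per node across all matched instances.
for node_idx, node_name in enumerate(labels_gt.skeletons[0].node_names):
    print(f"{node_name}\t{np.nanmean(D[..., node_idx])}")
```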
+ sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr) + + # Convert lists of instances to (instances, nodes, 2) matrices. + # This allows match_lists_function to return data as either + # a list of Instances or a (instances, nodes, 2) matrix. + if type(sorted_gt[0]) != np.ndarray: + sorted_gt = list_points_array(sorted_gt) + if type(sorted_pr[0]) != np.ndarray: + sorted_pr = list_points_array(sorted_pr) + + points_gt.append(sorted_gt) + points_pr.append(sorted_pr) + frame_idxs.extend([frame_idx] * len(sorted_gt)) + + # Convert arrays to numpy matrixes + # instances * nodes * (x,y) + points_gt = np.concatenate(points_gt) + points_pr = np.concatenate(points_pr) + + # Calculate distances between corresponding nodes for all corresponding + # ground truth and predicted instances. + D = np.linalg.norm(points_gt - points_pr, axis=2) + + return frame_idxs, D, points_gt, points_pr + + def point_dist( inst_a: Union[Instance, PredictedInstance], inst_b: Union[Instance, PredictedInstance], @@ -238,46 +238,3 @@ def point_match_count(dist_array: np.ndarray, thresh: float = 5) -> int: def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int: """Given an array of distances, returns number which are not <= threshold.""" return dist_array.shape[0] - point_match_count(dist_array, thresh) - - -if __name__ == "__main__": - - labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json") - labels_pr = Labels.load_json( - "tests/data/json_format_v2/centered_pair_predictions.json" - ) - - # OPTION 1 - - # Match each ground truth instance node to the closest corresponding node - # from any predicted instance in the same frame. - - nodewise_matching_func = match_instance_lists_nodewise - - # OPTION 2 - - # Match each ground truth instance to a distinct predicted instance: - # We want to maximize the number of "matching" points between instances, - # where "match" means the points are within some threshold distance. - # Note that each sorted list will be as long as the shorted input list. - - instwise_matching_func = lambda gt_list, pr_list: match_instance_lists( - gt_list, pr_list, point_nonmatch_count - ) - - # PICK THE FUNCTION - - inst_matching_func = nodewise_matching_func - # inst_matching_func = instwise_matching_func - - # Calculate distances - frame_idxs, D, points_gt, points_pr = matched_instance_distances( - labels_gt, labels_pr, inst_matching_func - ) - - # Show mean difference for each node - node_names = labels_gt.skeletons[0].node_names - - for node_idx, node_name in enumerate(node_names): - mean_d = np.nanmean(D[..., node_idx]) - print(f"{node_name}\t\t{mean_d}") diff --git a/sleap/info/summary.py b/sleap/info/summary.py index c6a6af60e..0cad1617e 100644 --- a/sleap/info/summary.py +++ b/sleap/info/summary.py @@ -21,7 +21,7 @@ class StatisticSeries: are frame index and value are some numerical value for the frame. Args: - labels: The :class:`Labels` for which to calculate series. + labels: The `Labels` for which to calculate series. """ labels: Labels @@ -41,7 +41,7 @@ def get_point_score_series( """Get series with statistic of point scores in each frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to scores: * sum * min @@ -67,7 +67,7 @@ def get_instance_score_series(self, video, reduction="sum") -> Dict[int, float]: """Get series with statistic of instance scores in each frame. 
Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to scores: * sum * min @@ -93,7 +93,7 @@ def get_point_displacement_series(self, video, reduction="sum") -> Dict[int, flo same track) from the closest earlier labeled frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to point scores: * sum * mean @@ -121,7 +121,7 @@ def get_primary_point_displacement_series( Get sum of displacement for single node of each instance per frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to point scores: * sum * mean @@ -226,7 +226,7 @@ def _calculate_frame_velocity( Calculate total point displacement between two given frames. Args: - lf: The :class:`LabeledFrame` for which we want velocity + lf: The `LabeledFrame` for which we want velocity last_lf: The frame from which to calculate displacement. reduce_function: Numpy function (e.g., np.sum, np.nanmean) is applied to *point* displacement, and then those @@ -246,3 +246,35 @@ def _calculate_frame_velocity( inst_dist = reduce_function(point_dist) val += inst_dist if not np.isnan(inst_dist) else 0 return val + + def get_tracking_score_series( + self, video: Video, reduction: str = "min" + ) -> Dict[int, float]: + """Get series with statistic of tracking scores in each frame. + + Args: + video: The `Video` for which to calculate statistic. + reduction: name of function applied to scores: + * mean + * min + + Returns: + The series dictionary (see class docs for details) + """ + reduce_fn = { + "min": np.nanmin, + "mean": np.nanmean, + }[reduction] + + series = dict() + + for lf in self.labels.find(video): + vals = [ + inst.tracking_score for inst in lf if hasattr(inst, "tracking_score") + ] + if vals: + val = reduce_fn(vals) + if not np.isnan(val): + series[lf.frame_idx] = val + + return series diff --git a/sleap/instance.py b/sleap/instance.py index c14038552..382ececf2 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -364,7 +364,7 @@ class Instance: from_predicted: Optional["PredictedInstance"] = attr.ib(default=None) _points: PointArray = attr.ib(default=None) _nodes: List = attr.ib(default=None) - frame: Union["LabeledFrame", None] = attr.ib(default=None) + frame: Union["LabeledFrame", None] = attr.ib(default=None) # TODO(LM): Make private # The underlying Point array type that this instances point array should be. _point_array_type = PointArray @@ -1049,7 +1049,9 @@ def scores(self) -> np.ndarray: return self.points_and_scores_array[:, 2] @classmethod - def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": + def from_instance( + cls, instance: Instance, score: float, tracking_score: float = 0.0 + ) -> "PredictedInstance": """Create a `PredictedInstance` from an `Instance`. The fields are copied in a shallow manner with the exception of points. For each @@ -1059,6 +1061,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": Args: instance: The `Instance` object to shallow copy data from. score: The score for this instance. + tracking_score: The tracking score for this instance. Returns: A `PredictedInstance` for the given `Instance`. 
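The new `get_tracking_score_series` mirrors the other series getters: pick a reduction (`"min"` or `"mean"`), reduce each frame's instance tracking scores, and skip frames whose reduction is NaN. A usage sketch (the `.slp` path is a placeholder):

```python
# Usage sketch for the new tracking-score series.
from sleap.info.summary import StatisticSeries
from sleap.io.dataset import Labels

labels = Labels.load_file("predictions.slp")  # placeholder path
video = labels.videos[0]

series = StatisticSeries(labels).get_tracking_score_series(video, reduction="min")

# `series` maps frame_idx -> reduced tracking score for that frame.
worst_frame = min(series, key=series.get) if series else None
```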
@@ -1070,6 +1073,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": ) kw_args["points"] = PredictedPointArray.from_array(instance._points) kw_args["score"] = score + kw_args["tracking_score"] = tracking_score return cls(**kw_args) @classmethod @@ -1080,6 +1084,7 @@ def from_arrays( instance_score: float, skeleton: Skeleton, track: Optional[Track] = None, + tracking_score: float = 0.0, ) -> "PredictedInstance": """Create a predicted instance from data arrays. @@ -1094,6 +1099,7 @@ def from_arrays( skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. track: Optional `sleap.Track` to associate with the instance. + tracking_score: Optional float representing the track matching score. Returns: A new `PredictedInstance`. @@ -1114,6 +1120,7 @@ def from_arrays( skeleton=skeleton, score=instance_score, track=track, + tracking_score=tracking_score, ) @classmethod @@ -1124,6 +1131,7 @@ def from_pointsarray( instance_score: float, skeleton: Skeleton, track: Optional[Track] = None, + tracking_score: float = 0.0, ) -> "PredictedInstance": """Create a predicted instance from data arrays. @@ -1138,12 +1146,18 @@ def from_pointsarray( skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. track: Optional `sleap.Track` to associate with the instance. + tracking_score: Optional float representing the track matching score. Returns: A new `PredictedInstance`. """ return cls.from_arrays( - points, point_confidences, instance_score, skeleton, track=track + points, + point_confidences, + instance_score, + skeleton, + track=track, + tracking_score=tracking_score, ) @classmethod @@ -1154,6 +1168,7 @@ def from_numpy( instance_score: float, skeleton: Skeleton, track: Optional[Track] = None, + tracking_score: float = 0.0, ) -> "PredictedInstance": """Create a predicted instance from data arrays. @@ -1168,12 +1183,18 @@ def from_numpy( skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. track: Optional `sleap.Track` to associate with the instance. + tracking_score: Optional float representing the track matching score. Returns: A new `PredictedInstance`. 
""" return cls.from_arrays( - points, point_confidences, instance_score, skeleton, track=track + points, + point_confidences, + instance_score, + skeleton, + track=track, + tracking_score=tracking_score, ) @@ -1214,6 +1235,9 @@ def unstructure_instance(x: Instance): converter.register_unstructure_hook(Instance, unstructure_instance) converter.register_unstructure_hook(PredictedInstance, unstructure_instance) + converter.register_unstructure_hook( + InstancesList, lambda x: [converter.unstructure(inst) for inst in x] + ) ## STRUCTURE HOOKS @@ -1229,35 +1253,37 @@ def structure_points(x, type): def structure_instances_list(x, type): inst_list = [] for inst_data in x: - if "score" in inst_data.keys(): - inst = converter.structure(inst_data, PredictedInstance) - else: - if ( - "from_predicted" in inst_data - and inst_data["from_predicted"] is not None - ): - inst_data["from_predicted"] = converter.structure( - inst_data["from_predicted"], PredictedInstance - ) - inst = converter.structure(inst_data, Instance) + inst = structure_instance(inst_data, type) inst_list.append(inst) return inst_list + def structure_instance(inst_data, type): + """Structure hook for Instance and PredictedInstance objects.""" + from_predicted = None + + if "score" in inst_data.keys(): + inst = converter.structure(inst_data, PredictedInstance) + else: + if ( + "from_predicted" in inst_data + and inst_data["from_predicted"] is not None + ): + from_predicted = converter.structure( + inst_data["from_predicted"], PredictedInstance + ) + # Remove the from_predicted key. We'll add it back afterwards. + inst_data["from_predicted"] = None + + # Structure the instance data, then add the from_predicted attribute. + inst = converter.structure(inst_data, Instance) + inst.from_predicted = from_predicted + return inst + converter.register_structure_hook( Union[List[Instance], List[PredictedInstance]], structure_instances_list ) - - # Structure forward reference for PredictedInstance for the Instance.from_predicted - # attribute. - converter.register_structure_hook_func( - lambda t: t.__class__ is ForwardRef, - lambda v, t: converter.structure(v, t.__forward_value__), - ) - # converter.register_structure_hook( - # ForwardRef("PredictedInstance"), - # lambda x, _: converter.structure(x, PredictedInstance), - # ) + converter.register_structure_hook(InstancesList, structure_instances_list) # We can register structure hooks for point arrays that do nothing # because Instance can have a dict of points passed to it in place of @@ -1278,6 +1304,127 @@ def structure_point_array(x, t): return converter +class InstancesList(list): + """A list of `Instance`s associated with a `LabeledFrame`. + + This class should only be used for the `LabeledFrame.instances` attribute. + """ + + def __init__(self, *args, labeled_frame: Optional["LabeledFrame"] = None): + super(InstancesList, self).__init__(*args) + + # Set the labeled frame for each instance + self.labeled_frame = labeled_frame + + @property + def labeled_frame(self) -> "LabeledFrame": + """Return the `LabeledFrame` associated with this list of instances.""" + + return self._labeled_frame + + @labeled_frame.setter + def labeled_frame(self, labeled_frame: "LabeledFrame"): + """Set the `LabeledFrame` associated with this list of instances. + + This updates the `frame` attribute on each instance. + + Args: + labeled_frame: The `LabeledFrame` to associate with this list of instances. 
+ """ + + try: + # If the labeled frame is the same as the one we're setting, then skip + if self._labeled_frame == labeled_frame: + return + except AttributeError: + # Only happens on init and updates each instance.frame (even if None) + pass + + # Otherwise, update the frame for each instance + self._labeled_frame = labeled_frame + for instance in self: + instance.frame = labeled_frame + + def append(self, instance: Union[Instance, PredictedInstance]): + """Append an `Instance` or `PredictedInstance` to the list, setting the frame. + + Args: + item: The `Instance` or `PredictedInstance` to append to the list. + """ + + if not isinstance(instance, (Instance, PredictedInstance)): + raise ValueError( + f"InstancesList can only contain Instance or PredictedInstance objects," + f" but got {type(instance)}." + ) + instance.frame = self.labeled_frame + super().append(instance) + + def extend(self, instances: List[Union[PredictedInstance, Instance]]): + """Extend the list with a list of `Instance`s or `PredictedInstance`s. + + Args: + instances: A list of `Instance` or `PredictedInstance` objects to add to the + list. + + Returns: + None + """ + for instance in instances: + self.append(instance) + + def __delitem__(self, index): + """Remove instance (by index), and set instance.frame to None.""" + + instance: Instance = self.__getitem__(index) + super().__delitem__(index) + + # Modify the instance to remove reference to the frame + instance.frame = None + + def insert(self, index: int, instance: Union[Instance, PredictedInstance]) -> None: + super().insert(index, instance) + instance.frame = self.labeled_frame + + def __setitem__(self, index, instance: Union[Instance, PredictedInstance]): + """Set nth instance in frame to the given instance. + + Args: + index: The index of instance to replace with new instance. + value: The new instance to associate with frame. + + Returns: + None. + """ + super().__setitem__(index, instance) + instance.frame = self.labeled_frame + + def pop(self, index: int) -> Union[Instance, PredictedInstance]: + """Remove and return instance at index, setting instance.frame to None.""" + + instance = super().pop(index) + instance.frame = None + return instance + + def remove(self, instance: Union[Instance, PredictedInstance]) -> None: + """Remove instance from list, setting instance.frame to None.""" + super().remove(instance) + instance.frame = None + + def clear(self) -> None: + """Remove all instances from list, setting instance.frame to None.""" + for instance in self: + instance.frame = None + super().clear() + + def copy(self) -> list: + """Return a shallow copy of the list of instances as a list. + + Note: This will not return an `InstancesList` object, but a normal list. + """ + return list(self) + + @attr.s(auto_attribs=True, eq=False, repr=False, str=False) class LabeledFrame: """Holds labeled data for a single frame of a video. @@ -1290,9 +1437,7 @@ class LabeledFrame: video: Video = attr.ib() frame_idx: int = attr.ib(converter=int) - _instances: Union[List[Instance], List[PredictedInstance]] = attr.ib( - default=attr.Factory(list) - ) + _instances: InstancesList = attr.ib(default=attr.Factory(InstancesList)) def __attrs_post_init__(self): """Called by attrs. 
@@ -1302,8 +1447,7 @@ def __attrs_post_init__(self): """ # Make sure all instances have a reference to this frame - for instance in self.instances: - instance.frame = self + self.instances = self._instances def __len__(self) -> int: """Return number of instances associated with frame.""" @@ -1319,13 +1463,8 @@ def index(self, value: Instance) -> int: def __delitem__(self, index): """Remove instance (by index) from frame.""" - value = self.instances.__getitem__(index) - self.instances.__delitem__(index) - # Modify the instance to remove reference to this frame - value.frame = None - def __repr__(self) -> str: """Return a readable representation of the LabeledFrame.""" return ( @@ -1348,9 +1487,6 @@ def insert(self, index: int, value: Instance): """ self.instances.insert(index, value) - # Modify the instance to have a reference back to this frame - value.frame = self - def __setitem__(self, index, value: Instance): """Set nth instance in frame to the given instance. @@ -1363,9 +1499,6 @@ def __setitem__(self, index, value: Instance): """ self.instances.__setitem__(index, value) - # Modify the instance to have a reference back to this frame - value.frame = self - def find( self, track: Optional[Union[Track, int]] = -1, user: bool = False ) -> List[Instance]: @@ -1393,7 +1526,7 @@ def instances(self) -> List[Instance]: return self._instances @instances.setter - def instances(self, instances: List[Instance]): + def instances(self, instances: Union[InstancesList, List[Instance]]): """Set the list of instances associated with this frame. Updates the `frame` attribute on each instance to the @@ -1408,9 +1541,11 @@ def instances(self, instances: List[Instance]): None """ - # Make sure to set the frame for each instance to this LabeledFrame - for instance in instances: - instance.frame = self + # Make sure to set the LabeledFrame for each instance to this frame + if isinstance(instances, InstancesList): + instances.labeled_frame = self + else: + instances = InstancesList(instances, labeled_frame=self) self._instances = instances @@ -1685,22 +1820,20 @@ def complex_frame_merge( * list of conflicting instances from base * list of conflicting instances from new """ - merged_instances = [] - redundant_instances = [] - extra_base_instances = copy(base_frame.instances) - extra_new_instances = [] + merged_instances: List[Instance] = [] # Only used for informing user + redundant_instances: List[Instance] = [] + extra_base_instances: List[Instance] = list(base_frame.instances) + extra_new_instances: List[Instance] = [] for new_inst in new_frame: redundant = False for base_inst in base_frame.instances: if new_inst.matches(base_inst): - base_inst.frame = None extra_base_instances.remove(base_inst) redundant_instances.append(base_inst) redundant = True continue if not redundant: - new_inst.frame = None extra_new_instances.append(new_inst) conflict = False @@ -1732,7 +1865,7 @@ def complex_frame_merge( else: # No conflict, so include all instances in base base_frame.instances.extend(extra_new_instances) - merged_instances = copy(extra_new_instances) + merged_instances: List[Instance] = copy(extra_new_instances) extra_base_instances = [] extra_new_instances = [] diff --git a/sleap/io/asyncvideo.py b/sleap/io/asyncvideo.py deleted file mode 100644 index c48d21a8b..000000000 --- a/sleap/io/asyncvideo.py +++ /dev/null @@ -1,218 +0,0 @@ -""" -Support for loading video frames (by chunk) in background process. 
-""" - -from sleap import Video -from sleap.message import PairedSender, PairedReceiver - -import cattr -import logging -import time -import numpy as np -from math import ceil -from multiprocessing import Process -from typing import Iterable, Iterator, List, Optional, Tuple - - -logger = logging.getLogger(__name__) - - -class AsyncVideo: - """Supports fetching chunks from video in background process.""" - - def __init__(self, base_port: int = 9010): - self.base_port = base_port - - # Spawn the server as a background process - self.server = AsyncVideoServer(self.base_port) - self.server.start() - - # Create sender/receiver for sending requests and receiving data via ZMQ - sender = PairedSender.from_tcp_ports(self.base_port, self.base_port + 1) - result_receiver = PairedReceiver.from_tcp_ports( - send_port=self.base_port + 2, rec_port=self.base_port + 3 - ) - - sender.setup() - result_receiver.setup() - - self.sender = sender - self.receiver = result_receiver - - # Use "handshake" to ensure that initial messages aren't dropped - self.handshake_success = sender.send_handshake() - - def close(self): - """Close the async video server and communication ports.""" - if self.sender and self.server: - self.sender.send_dict(dict(stop=True)) - self.server.join() - - self.server = None - - if self.sender: - self.sender.close() - self.sender = None - - if self.receiver: - self.receiver.close() - self.receiver = None - - def __del__(self): - self.close() - - @classmethod - def from_video( - cls, - video: Video, - frame_idxs: Optional[Iterable[int]] = None, - frames_per_chunk: int = 64, - ) -> "AsyncVideo": - """Create object and start loading frames in background process.""" - obj = cls() - obj.load_by_chunk( - video=video, frame_idxs=frame_idxs, frames_per_chunk=frames_per_chunk - ) - return obj - - def load_by_chunk( - self, - video: Video, - frame_idxs: Optional[Iterable[int]] = None, - frames_per_chunk: int = 64, - ): - """ - Sends request for loading video in background process. - - Args: - video: The :py:class:`Video` to load - frame_idxs: Frame indices we want to load; if None, then full video - is loaded. - frames_per_chunk: How many frames to load per chunk. - - Returns: - None, data should be accessed via :py:method:`chunks`. - """ - # prime the video since this seems to make frames load faster (!?) - video.test_frame - - request_dict = dict( - video=cattr.unstructure(video), frames_per_chunk=frames_per_chunk - ) - # if no frames are specified, whole video will be loaded - if frame_idxs is not None: - request_dict["frame_idxs"] = list(frame_idxs) - - # send the request - self.sender.send_dict(request_dict) - - @property - def chunks(self) -> Iterator[Tuple[List[int], np.ndarray]]: - """ - Generator for fetching chunks of frames. - - When all chunks are loaded, closes the server and communication ports. - - Yields: - Tuple with (list of frame indices, ndarray of frames) - """ - done = False - while not done: - results = self.receiver.check_messages() - if results: - for result in results: - yield result["frame_idxs"], result["ndarray"] - - if result["chunk"] == result["last_chunk"]: - done = True - - # automatically close when all chunks have been received - self.close() - - -class AsyncVideoServer(Process): - """ - Class which loads video frames in background on request. - - All interactions with video server should go through :py:class:`AsyncVideo` - which runs in local thread. 
- """ - - def __init__(self, base_port: int): - super(AsyncVideoServer, self).__init__() - - self.video = None - self.base_port = base_port - - def run(self): - receiver = PairedReceiver.from_tcp_ports(self.base_port + 1, self.base_port) - receiver.setup() - - result_sender = PairedSender.from_tcp_ports( - send_port=self.base_port + 3, rec_port=self.base_port + 2 - ) - result_sender.setup() - - running = True - while running: - requests = receiver.check_messages() - if requests: - - for request in requests: - - if "stop" in request: - running = False - logger.debug("stopping async video server") - break - - if "video" in request: - self.video = cattr.structure(request["video"], Video) - logger.debug(f"loaded video: {self.video.filename}") - - if self.video is not None: - if "frames_per_chunk" in request: - - load_time = 0 - send_time = 0 - - per_chunk = request["frames_per_chunk"] - - frame_idxs = request.get( - "frame_idxs", list(range(self.video.frames)) - ) - - frame_count = len(frame_idxs) - chunks = ceil(frame_count / per_chunk) - - for chunk_idx in range(chunks): - start = per_chunk * chunk_idx - end = min(per_chunk * (chunk_idx + 1), frame_count) - chunk_frame_idxs = frame_idxs[start:end] - - # load the frames - t0 = time.time() - frames = self.video[chunk_frame_idxs] - t1 = time.time() - load_time += t1 - t0 - - metadata = dict( - chunk=chunk_idx, - last_chunk=chunks - 1, - frame_idxs=chunk_frame_idxs, - ) - - # send back results - t0 = time.time() - result_sender.send_array(metadata, frames) - t1 = time.time() - send_time += t1 - t0 - - logger.debug(f"returned chunk: {chunk_idx+1}/{chunks}") - - logger.debug(f"total load time: {load_time}") - logger.debug(f"total send time: {send_time}") - else: - logger.warning( - "unable to process message since no video loaded" - ) - logger.warning(request) diff --git a/sleap/io/convert.py b/sleap/io/convert.py index 3353a169b..7045ed71f 100644 --- a/sleap/io/convert.py +++ b/sleap/io/convert.py @@ -70,6 +70,7 @@ def create_parser(): help="Output format. Default ('slp') is SLEAP dataset; " "'analysis' results in analysis.h5 file; " "'analysis.nix' results in an analysis nix file;" + "'analysis.csv' results in an analysis csv file;" "'h5' or 'json' results in SLEAP dataset " "with specified file format.", ) @@ -135,7 +136,12 @@ def main(args: list = None): outnames = [path for path in args.outputs] if len(outnames) < len(vids): # if there are less outnames provided than videos to convert... 
- out_suffix = "nix" if "nix" in args.format else "h5" + if "nix" in args.format: + out_suffix = "nix" + elif "csv" in args.format: + out_suffix = "csv" + else: + out_suffix = "h5" fn = args.input_path fn = re.sub("(\.json(\.zip)?|\.h5|\.slp)$", "", fn) fn = PurePath(fn) @@ -158,6 +164,20 @@ def main(args: list = None): NixAdaptor.write(outname, labels, args.input_path, video) except ValueError as e: print(e.args[0]) + + elif "csv" in args.format: + from sleap.info.write_tracking_h5 import main as write_analysis + + for video, output_path in zip(vids, outnames): + write_analysis( + labels, + output_path=output_path, + labels_path=args.input_path, + all_frames=True, + video=video, + csv=True, + ) + else: from sleap.info.write_tracking_h5 import main as write_analysis diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 45280cc54..1b894089f 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -2055,6 +2055,19 @@ def export(self, filename: str): SleapAnalysisAdaptor.write(filename, self) + def export_csv(self, filename: str): + """Export labels to CSV format. + + Args: + filename: Output path for the CSV format file. + + Notes: + This will write the contents of the labels out as a CSV file. + """ + from sleap.io.format.csv import CSVAdaptor + + CSVAdaptor.write(filename, self) + def export_nwb( self, filename: str, diff --git a/sleap/io/format/coco.py b/sleap/io/format/coco.py index 25122e4d0..44e7fb84a 100644 --- a/sleap/io/format/coco.py +++ b/sleap/io/format/coco.py @@ -180,6 +180,9 @@ def read( if flag == 0: # node not labeled for this instance + if (x, y) != (0, 0): + # If labeled but invisible, place the node at the coord + points[node] = Point(x, y, False) continue is_visible = flag == 2 diff --git a/sleap/io/format/deeplabcut.py b/sleap/io/format/deeplabcut.py index bb5dc3410..5892dba1a 100644 --- a/sleap/io/format/deeplabcut.py +++ b/sleap/io/format/deeplabcut.py @@ -19,10 +19,10 @@ import numpy as np import pandas as pd -from typing import List, Optional +from typing import List, Optional, Dict from sleap import Labels, Video, Skeleton -from sleap.instance import Instance, LabeledFrame, Point +from sleap.instance import Instance, LabeledFrame, Point, Track from sleap.util import find_files_by_suffix from .adaptor import Adaptor, SleapObjectType @@ -119,11 +119,12 @@ def read_frames( # Pull out animal and node names from the columns. start_col = 3 if is_new_format else 1 - animal_names = [] + tracks: Dict[str, Optional[Track]] = {} node_names = [] for animal_name, node_name, _ in data.columns[start_col:][::2]: - if animal_name not in animal_names: - animal_names.append(animal_name) + # Keep the starting frame index for each individual/track + if animal_name not in tracks.keys(): + tracks[animal_name] = None if node_name not in node_names: node_names.append(node_name) @@ -177,23 +178,33 @@ def read_frames( instances = [] if is_multianimal: - for animal_name in animal_names: + for animal_name in tracks.keys(): any_not_missing = False # Get points for each node. 
instance_points = dict() for node in node_names: - x, y = ( - data[(animal_name, node, "x")][i], - data[(animal_name, node, "y")][i], - ) + if (animal_name, node) in data.columns: + x, y = ( + data[(animal_name, node, "x")][i], + data[(animal_name, node, "y")][i], + ) + else: + x, y = np.nan, np.nan instance_points[node] = Point(x, y) if ~(np.isnan(x) and np.isnan(y)): any_not_missing = True if any_not_missing: + # Create track + if tracks[animal_name] is None: + tracks[animal_name] = Track(spawned_on=i, name=animal_name) # Create instance with points. instances.append( - Instance(skeleton=skeleton, points=instance_points) + Instance( + skeleton=skeleton, + points=instance_points, + track=tracks[animal_name], + ) ) else: # Get points for each node. @@ -270,6 +281,8 @@ def read( skeleton = Skeleton() if project_data.get("multianimalbodyparts", False): skeleton.add_nodes(project_data["multianimalbodyparts"]) + if "uniquebodyparts" in project_data: + skeleton.add_nodes(project_data["uniquebodyparts"]) else: skeleton.add_nodes(project_data["bodyparts"]) @@ -298,13 +311,24 @@ def read( # If subdirectory is foo, we look for foo.mp4 in videos dir. shortname = os.path.split(data_subdir)[-1] - video_path = os.path.join(videos_dir, f"{shortname}.mp4") - - if os.path.exists(video_path): + video_path = None + if os.path.exists(videos_dir): + with os.scandir(videos_dir) as file_iterator: + for file in file_iterator: + if not file.is_file(): + continue + if os.path.splitext(file.name)[0] != shortname: + continue + video_path = os.path.join(videos_dir, file.name) + break + + if video_path is not None and os.path.exists(video_path): video = Video.from_filename(video_path) else: # When no video is found, the individual frame images # stored in the labeled data subdir will be used. + if video_path is None: + video_path = os.path.join(videos_dir, f"{shortname}.mp4") print( f"Unable to find {video_path} so using individual frame images." ) diff --git a/sleap/io/format/hdf5.py b/sleap/io/format/hdf5.py index 353f88e3a..55a30d74f 100644 --- a/sleap/io/format/hdf5.py +++ b/sleap/io/format/hdf5.py @@ -81,7 +81,10 @@ def read_headers( # Extract the Labels JSON metadata and create Labels object with just this # metadata. - dicts = json_loads(f.require_group("metadata").attrs["json"].tobytes().decode()) + json = f.require_group("metadata").attrs["json"] + if not isinstance(json, str): + json = json.tobytes().decode() + dicts = json_loads(json) # These items are stored in separate lists because the metadata group got to be # too big. @@ -151,6 +154,45 @@ def read( points_dset[:]["x"] -= 0.5 points_dset[:]["y"] -= 0.5 + def cast_as_compound(arr, dtype): + out = np.empty(shape=(len(arr),), dtype=dtype) + if out.size == 0: + return out + for i, (name, _) in enumerate(dtype): + out[name] = arr[:, i] + return out + + # cast points, instances, and frames into complex dtype if not already + dtype_points = [("x", " np.ndarray: def get_frames_safely(self, idxs: Iterable[int]) -> Tuple[List[int], np.ndarray]: """Return list of frame indices and frames which were successfully loaded. + Args: + idxs: An iterable object that contains the indices of frames. - idxs: An iterable object that contains the indices of frames. Returns: A tuple of (frame indices, frames), where * frame indices is a subset of the specified idxs, and @@ -1442,19 +1443,31 @@ def to_hdf5( def encode(img): _, encoded = cv2.imencode("." 
+ format, img) - return np.squeeze(encoded) + return np.squeeze(encoded).astype("int8") + + # pad with zeroes to guarantee int8 type in hdf5 file + frames = [] + for i in range(len(frame_numbers)): + frames.append(encode(frame_data[i])) + + max_frame_size = ( + max([len(x) if len(x) else 0 for x in frames]) if len(frames) else 0 + ) - dtype = h5.special_dtype(vlen=np.dtype("int8")) dset = f.create_dataset( - dataset + "/video", (len(frame_numbers),), dtype=dtype + dataset + "/video", + (len(frame_numbers), max_frame_size), + dtype="int8", + compression="gzip", ) dset.attrs["format"] = format dset.attrs["channels"] = self.channels dset.attrs["height"] = self.height dset.attrs["width"] = self.width - for i in range(len(frame_numbers)): - dset[i] = encode(frame_data[i]) + for i, frame in enumerate(frames): + dset[i, 0 : len(frame)] = frame + else: f.create_dataset( dataset + "/video", @@ -1532,22 +1545,17 @@ def cattr(): A cattr converter. """ - # When we are structuring video backends, try to fixup the video file paths - # in case they are coming from a different computer or the file has been moved. - def fixup_video(x, cl): - if "filename" in x: - x["filename"] = Video.fixup_path(x["filename"]) - if "file" in x: - x["file"] = Video.fixup_path(x["file"]) + # Use from_filename to fixup the video path and determine backend + def fixup_video(x: dict, cl: Video): + backend_dict = x.pop("backend") + filename = backend_dict.pop("filename", None) or backend_dict.pop( + "file", None + ) - return Video.make_specific_backend(cl, x) + return Video.from_filename(filename, **backend_dict) vid_cattr = cattr.Converter() - - # Check the type hint for backend and register the video path - # fixup hook for each type in the Union. - for t in attr.fields(Video).backend.type.__args__: - vid_cattr.register_structure_hook(t, fixup_video) + vid_cattr.register_structure_hook(Video, fixup_video) return vid_cattr diff --git a/sleap/io/videowriter.py b/sleap/io/videowriter.py index 510fad739..cd710c9d5 100644 --- a/sleap/io/videowriter.py +++ b/sleap/io/videowriter.py @@ -12,6 +12,7 @@ from abc import ABC, abstractmethod import cv2 import numpy as np +import imageio.v2 as iio class VideoWriter(ABC): @@ -32,22 +33,26 @@ def close(self): @staticmethod def safe_builder(filename, height, width, fps): """Builds VideoWriter based on available dependencies.""" - if VideoWriter.can_use_skvideo(): - return VideoWriterSkvideo(filename, height, width, fps) + if VideoWriter.can_use_ffmpeg(): + return VideoWriterImageio(filename, height, width, fps) else: return VideoWriterOpenCV(filename, height, width, fps) @staticmethod - def can_use_skvideo(): - # See if we can import skvideo + def can_use_ffmpeg(): + """Check if ffmpeg is available for writing videos.""" try: - import skvideo + import imageio_ffmpeg as ffmpeg except ImportError: return False - # See if skvideo can find FFMPEG - if skvideo.getFFmpegVersion() != "0.0.0": - return True + try: + # Try to get the version of the ffmpeg plugin + ffmpeg_version = ffmpeg.get_ffmpeg_version() + if ffmpeg_version: + return True + except Exception: + return False return False @@ -68,11 +73,11 @@ def close(self): self._writer.release() -class VideoWriterSkvideo(VideoWriter): - """Writes video using scikit-video as wrapper for ffmpeg. +class VideoWriterImageio(VideoWriter): + """Writes video using imageio as a wrapper for ffmpeg. Attributes: - filename: Path to mp4 file to save to. + filename: Path to video file to save to. height: Height of movie frames. width: Width of movie frames. 
fps: Playback framerate to save at. @@ -85,28 +90,38 @@ class VideoWriterSkvideo(VideoWriter): def __init__( self, filename, height, width, fps, crf: int = 21, preset: str = "superfast" ): - import skvideo.io - - fps = str(fps) - self._writer = skvideo.io.FFmpegWriter( + self.filename = filename + self.height = height + self.width = width + self.fps = fps + self.crf = crf + self.preset = preset + + import imageio_ffmpeg as ffmpeg + + # Imageio's ffmpeg writer parameters + # https://imageio.readthedocs.io/en/stable/examples.html#writing-videos-with-ffmpeg-and-vaapi + # Use `ffmpeg -h encoder=libx264`` to see all options for libx264 output_params + # output_params must be a list of strings + # iio.help(name='FFMPEG') to test + self.writer = iio.get_writer( filename, - inputdict={ - "-r": fps, - }, - outputdict={ - "-c:v": "libx264", - "-preset": preset, - "-vf": "scale=trunc(iw/2)*2:trunc(ih/2)*2", # Need even dims for libx264 - "-framerate": fps, - "-crf": str(crf), - "-pix_fmt": "yuv420p", - }, + fps=fps, + codec="libx264", + format="FFMPEG", + pixelformat="yuv420p", + output_params=[ + "-preset", + preset, + "-crf", + str(crf), + ], ) def add_frame(self, img, bgr: bool = False): if bgr: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - self._writer.writeFrame(img) + self.writer.append_data(img) def close(self): - self._writer.close() + self.writer.close() diff --git a/sleap/io/visuals.py b/sleap/io/visuals.py index 2018ce0bf..f2dde0be3 100644 --- a/sleap/io/visuals.py +++ b/sleap/io/visuals.py @@ -27,7 +27,13 @@ _sentinel = object() -def reader(out_q: Queue, video: Video, frames: List[int], scale: float = 1.0): +def reader( + out_q: Queue, + video: Video, + frames: List[int], + scale: float = 1.0, + background: str = "original", +): """Read frame images from video and send them into queue. Args: @@ -36,11 +42,13 @@ def reader(out_q: Queue, video: Video, frames: List[int], scale: float = 1.0): video: The `Video` object to read. frames: Full list frame indexes we want to read. scale: Output scale for frame images. + background: output video background. Either original, black, white, grey Returns: None. """ + background = background.lower() cv2.setNumThreads(usable_cpu_count()) total_count = len(frames) @@ -64,6 +72,16 @@ def reader(out_q: Queue, video: Video, frames: List[int], scale: float = 1.0): loaded_chunk_idxs, video_frame_images = video.get_frames_safely( frames_idx_chunk ) + if background != "original": + # fill the frame with the color + fill_values = {"black": 0, "grey": 127, "white": 255} + try: + fill = fill_values[background] + except KeyError: + raise ValueError( + f"Invalid background color: {background}. Options include: {', '.join(fill_values.keys())}" + ) + video_frame_images = video_frame_images * 0 + fill if not loaded_chunk_idxs: print(f"No frames could be loaded from chunk {chunk_i}") @@ -497,6 +515,7 @@ def save_labeled_video( fps: int = 15, scale: float = 1.0, crop_size_xy: Optional[Tuple[int, int]] = None, + background: str = "original", show_edges: bool = True, edge_is_wedge: bool = False, marker_size: int = 4, @@ -515,6 +534,7 @@ def save_labeled_video( fps: Frames per second for output video. scale: scale of image (so we can scale point locations to match) crop_size_xy: size of crop around instances, or None for full images + background: output video background. 
Either original, black, white, grey show_edges: whether to draw lines between nodes edge_is_wedge: whether to draw edges as wedges (draw as line if False) marker_size: Size of marker in pixels before scaling by `scale` @@ -537,7 +557,7 @@ def save_labeled_video( q2 = Queue(maxsize=10) progress_queue = Queue() - thread_read = Thread(target=reader, args=(q1, video, frames, scale)) + thread_read = Thread(target=reader, args=(q1, video, frames, scale, background)) thread_mark = VideoMarkerThread( in_q=q1, out_q=q2, @@ -695,6 +715,15 @@ def main(args: list = None): "and 'nodes' (default: 'nodes')" ), ) + parser.add_argument( + "--background", + type=str, + default="original", + help=( + "Specify the type of background to be used to save the videos. " + "Options for background: original, black, white, and grey." + ), + ) args = parser.parse_args(args=args) labels = Labels.load_file( args.data_path, video_search=[os.path.dirname(args.data_path)] ) @@ -730,6 +759,7 @@ def main(args: list = None): marker_size=args.marker_size, palette=args.palette, distinctly_color=args.distinctly_color, + background=args.background, ) print(f"Video saved as: {filename}") diff --git a/sleap/nn/config/outputs.py b/sleap/nn/config/outputs.py index ffb0d76e4..ccb6077b1 100644 --- a/sleap/nn/config/outputs.py +++ b/sleap/nn/config/outputs.py @@ -151,8 +151,8 @@ class OutputsConfig: save_visualizations: If True, will render and save visualizations of the model predictions as PNGs to "{run_folder}/viz/{split}.{epoch:04d}.png", where the split is one of "train", "validation", "test". - delete_viz_images: If True, delete the saved visualizations after training - completes. This is useful to reduce the model folder size if you do not need + keep_viz_images: If True, keep the saved visualization images after training + completes. Set this to False to reduce the model folder size if you do not need to keep the visualization images. zip_outputs: If True, compress the run folder to a zip file. This will be named "{run_folder}.zip". @@ -170,7 +170,7 @@ runs_folder: Text = "models" tags: List[Text] = attr.ib(factory=list) save_visualizations: bool = True - delete_viz_images: bool = True + keep_viz_images: bool = False zip_outputs: bool = False log_to_csv: bool = True checkpointing: CheckpointingConfig = attr.ib(factory=CheckpointingConfig) diff --git a/sleap/nn/data/augmentation.py b/sleap/nn/data/augmentation.py index 21dfb29e6..b754c0fe9 100644 --- a/sleap/nn/data/augmentation.py +++ b/sleap/nn/data/augmentation.py @@ -1,19 +1,11 @@ """Transformers for applying data augmentation.""" -# Monkey patch for: https://github.com/aleju/imgaug/issues/537 -# TODO: Fix when PyPI/conda packages are available for version fencing. -import numpy - -if hasattr(numpy.random, "_bit_generator"): - numpy.random.bit_generator = numpy.random._bit_generator - import sleap import numpy as np import tensorflow as tf import attr from typing import List, Text, Optional -import imgaug as ia -import imgaug.augmenters as iaa +import albumentations as A from sleap.nn.config import AugmentationConfig from sleap.nn.data.instance_cropping import crop_bboxes @@ -111,15 +103,15 @@ def flip_instances_ud( @attr.s(auto_attribs=True) -class ImgaugAugmenter: - """Data transformer based on the `imgaug` library. +class AlbumentationsAugmenter: + """Data transformer based on the `albumentations` library. This class can generate a `tf.data.Dataset` from an existing one that generates image and instance data.
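To illustrate the `--background` option added above, here is a hedged Python-API equivalent (paths and frame range are placeholders; the assumed keyword signature follows the `save_labeled_video` hunk above):

```python
import sleap
from sleap.io.visuals import save_labeled_video

labels = sleap.load_file("predictions.slp")  # placeholder path
save_labeled_video(
    filename="clip.mp4",
    labels=labels,
    video=labels.videos[0],
    frames=list(range(200)),
    fps=15,
    background="grey",  # renders markers over a flat grey canvas
)
```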
Element of the output dataset will have a set of augmentation transformations applied. Attributes: - augmenter: An instance of `imgaug.augmenters.Sequential` that will be applied to + augmenter: An instance of `albumentations.Compose` that will be applied to each element of the input dataset. image_key: Name of the example key where the image is stored. Defaults to "image". @@ -127,7 +119,7 @@ class ImgaugAugmenter: Defaults to "instances". """ - augmenter: iaa.Sequential + augmenter: A.Compose image_key: str = "image" instances_key: str = "instances" @@ -137,7 +129,7 @@ def from_config( config: AugmentationConfig, image_key: Text = "image", instances_key: Text = "instances", - ) -> "ImgaugAugmenter": + ) -> "AlbumentationsAugmenter": """Create an augmenter from a set of configuration parameters. Args: @@ -148,52 +140,64 @@ def from_config( Defaults to "instances". Returns: - An instance of `ImgaugAugmenter` with the specified augmentation + An instance of `AlbumentationsAugmenter` with the specified augmentation configuration. """ aug_stack = [] if config.rotate: aug_stack.append( - iaa.Affine( - rotate=(config.rotation_min_angle, config.rotation_max_angle) + A.Rotate( + limit=(config.rotation_min_angle, config.rotation_max_angle), p=1.0 ) ) if config.translate: aug_stack.append( - iaa.Affine( + A.Affine( translate_px={ "x": (config.translate_min, config.translate_max), "y": (config.translate_min, config.translate_max), - } + }, + p=1.0, ) ) if config.scale: - aug_stack.append(iaa.Affine(scale=(config.scale_min, config.scale_max))) - if config.uniform_noise: aug_stack.append( - iaa.AddElementwise( - value=(config.uniform_noise_min_val, config.uniform_noise_max_val) - ) + A.Affine(scale=(config.scale_min, config.scale_max), p=1.0) ) + if config.uniform_noise: + + def uniform_noise(image, **kwargs): + return image + np.random.uniform( + config.uniform_noise_min_val, config.uniform_noise_max_val + ) + + aug_stack.append(A.Lambda(image=uniform_noise)) if config.gaussian_noise: aug_stack.append( - iaa.AdditiveGaussianNoise( - loc=config.gaussian_noise_mean, scale=config.gaussian_noise_stddev + A.GaussNoise( + mean=config.gaussian_noise_mean, + var_limit=config.gaussian_noise_stddev, ) ) if config.contrast: aug_stack.append( - iaa.GammaContrast( - gamma=(config.contrast_min_gamma, config.contrast_max_gamma) + A.RandomGamma( + gamma_limit=(config.contrast_min_gamma, config.contrast_max_gamma), + p=1.0, ) ) if config.brightness: aug_stack.append( - iaa.Add(value=(config.brightness_min_val, config.brightness_max_val)) + A.RandomBrightness( + limit=(config.brightness_min_val, config.brightness_max_val), p=1.0 + ) ) return cls( - augmenter=iaa.Sequential(aug_stack), + augmenter=A.Compose( + aug_stack, + keypoint_params=A.KeypointParams(format="xy", remove_invisible=False), + ), image_key=image_key, instances_key=instances_key, ) @@ -226,22 +230,16 @@ def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset: # Define augmentation function to map over each sample. def py_augment(image, instances): """Local processing function that will not be autographed.""" - # Ensure that the transformations applied to all data within this - # example are kept consistent. - aug_det = self.augmenter.to_deterministic() + # Convert to numpy arrays. + img = image.numpy() + kps = instances.numpy() + original_shape = kps.shape + kps = kps.reshape(-1, 2) - # Augment the image. - aug_img = aug_det.augment_image(image.numpy()) - - # This will get converted to a rank 3 tensor (n_instances, n_nodes, 2). 
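For reference, the `albumentations` pipeline that `from_config` assembles above behaves like this standalone sketch (the transform values are arbitrary examples, not SLEAP defaults):

```python
import albumentations as A
import numpy as np

augmenter = A.Compose(
    [
        A.Rotate(limit=(-15, 15), p=1.0),
        A.Affine(scale=(0.9, 1.1), p=1.0),
    ],
    keypoint_params=A.KeypointParams(format="xy", remove_invisible=False),
)

image = np.zeros((128, 128, 1), dtype=np.uint8)
keypoints = [(32.0, 40.0), (64.0, 80.0)]  # flattened (n_instances * n_nodes, 2)
out = augmenter(image=image, keypoints=keypoints)
aug_image, aug_keypoints = out["image"], out["keypoints"]
```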
- aug_instances = np.full_like(instances, np.nan) - - # Augment each set of points for each instance. - for i, instance in enumerate(instances): - kps = ia.KeypointsOnImage.from_xy_array( - instance.numpy(), tuple(image.shape) - ) - aug_instances[i] = aug_det.augment_keypoints(kps).to_xy_array() + # Augment. + augmented = self.augmenter(image=img, keypoints=kps) + aug_img = augmented["image"] + aug_instances = np.array(augmented["keypoints"]).reshape(original_shape) return aug_img, aug_instances @@ -258,7 +256,6 @@ def augment(frame_data): return frame_data # Apply the augmentation to each element. - # Note: We map sequentially since imgaug gets slower with tf.data parallelism. output_ds = input_ds.map(augment) return output_ds diff --git a/sleap/nn/data/pipelines.py b/sleap/nn/data/pipelines.py index b0892f8a1..2e334456a 100644 --- a/sleap/nn/data/pipelines.py +++ b/sleap/nn/data/pipelines.py @@ -18,7 +18,7 @@ from sleap.nn.data.providers import LabelsReader, VideoReader from sleap.nn.data.augmentation import ( AugmentationConfig, - ImgaugAugmenter, + AlbumentationsAugmenter, RandomCropper, RandomFlipper, ) @@ -68,7 +68,7 @@ PROVIDERS = (LabelsReader, VideoReader) TRANSFORMERS = ( - ImgaugAugmenter, + AlbumentationsAugmenter, RandomCropper, Normalizer, Resizer, @@ -406,7 +406,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: self.data_config.labels.skeletons[0], horizontal=self.optimization_config.augmentation_config.flip_horizontal, ) - pipeline += ImgaugAugmenter.from_config( + pipeline += AlbumentationsAugmenter.from_config( self.optimization_config.augmentation_config ) if self.optimization_config.augmentation_config.random_crop: @@ -550,7 +550,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: self.data_config.labels.skeletons[0], horizontal=self.optimization_config.augmentation_config.flip_horizontal, ) - pipeline += ImgaugAugmenter.from_config( + pipeline += AlbumentationsAugmenter.from_config( self.optimization_config.augmentation_config ) if self.optimization_config.augmentation_config.random_crop: @@ -713,7 +713,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: self.data_config.labels.skeletons[0], horizontal=self.optimization_config.augmentation_config.flip_horizontal, ) - pipeline += ImgaugAugmenter.from_config( + pipeline += AlbumentationsAugmenter.from_config( self.optimization_config.augmentation_config ) pipeline += Normalizer.from_config(self.data_config.preprocessing) @@ -863,7 +863,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: self.data_config.labels.skeletons[0], horizontal=aug_config.flip_horizontal, ) - pipeline += ImgaugAugmenter.from_config(aug_config) + pipeline += AlbumentationsAugmenter.from_config(aug_config) if aug_config.random_crop: pipeline += RandomCropper( crop_height=aug_config.random_crop_height, @@ -1028,7 +1028,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: horizontal=aug_config.flip_horizontal, ) - pipeline += ImgaugAugmenter.from_config(aug_config) + pipeline += AlbumentationsAugmenter.from_config(aug_config) if aug_config.random_crop: pipeline += RandomCropper( crop_height=aug_config.random_crop_height, @@ -1186,7 +1186,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: config=self.data_config.preprocessing, provider=data_provider, ) - pipeline += ImgaugAugmenter.from_config( + pipeline += AlbumentationsAugmenter.from_config( self.optimization_config.augmentation_config ) pipeline += 
Normalizer.from_config(self.data_config.preprocessing) diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index 16f439d10..9e93d0b18 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -394,9 +394,7 @@ def make_dataset(self) -> tf.data.Dataset: grid in order to properly map points to image coordinates. """ # Grab an image to test for the dtype. - test_image = tf.convert_to_tensor( - self.video.get_frame(self.video.last_frame_idx) - ) + test_image = tf.convert_to_tensor(self.video.get_frame(0)) image_dtype = test_image.dtype def py_fetch_frame(ind): diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 6d7d24f8c..3f01a1c3c 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -33,6 +33,7 @@ import atexit import subprocess import rich.progress +import pandas as pd from rich.pretty import pprint from collections import deque import json @@ -1142,6 +1143,7 @@ def export_model( info["frozen_model_inputs"] = frozen_func.inputs info["frozen_model_outputs"] = frozen_func.outputs + info["unragged_outputs"] = unrag_outputs with (Path(save_path) / "info.json").open("w") as fp: json.dump( @@ -1582,6 +1584,15 @@ def _object_builder(): try: for ex in generator: prediction_queue.put(ex) + + except KeyError as e: + # Gracefully handle seeking errors by early termination. + if "Unable to load frame" in str(e): + pass # TODO: Print warning obeying verbosity? (This code path is also + # called for interactive prediction where we don't want any spam.) + else: + raise + finally: prediction_queue.put(None) object_builder.join() @@ -2611,6 +2622,7 @@ def _object_builder(): # Set tracks for predicted instances in this frame. predicted_instances = self.tracker.track( untracked_instances=predicted_instances, + img_hw=ex["image"].shape[-3:-1], img=image, t=frame_ind, ) @@ -2632,6 +2644,15 @@ def _object_builder(): try: for ex in generator: prediction_queue.put(ex) + + except KeyError as e: + # Gracefully handle seeking errors by early termination. + if "Unable to load frame" in str(e): + pass # TODO: Print warning obeying verbosity? (This code path is also + # called for interactive prediction where we don't want any spam.) + else: + raise + finally: prediction_queue.put(None) object_builder.join() @@ -3244,6 +3265,7 @@ def _object_builder(): # Set tracks for predicted instances in this frame. predicted_instances = self.tracker.track( untracked_instances=predicted_instances, + img_hw=ex["image"].shape[-3:-1], img=image, t=frame_ind, ) @@ -3265,6 +3287,15 @@ def _object_builder(): try: for ex in generator: prediction_queue.put(ex) + + except KeyError as e: + # Gracefully handle seeking errors by early termination. + if "Unable to load frame" in str(e): + pass # TODO: Print warning obeying verbosity? (This code path is also + # called for interactive prediction where we don't want any spam.) + else: + raise + finally: prediction_queue.put(None) object_builder.join() @@ -3747,9 +3778,10 @@ def _object_builder(): PredictedInstance.from_numpy( points=pts, point_confidences=confs, - instance_score=np.nanmean(score), + instance_score=np.nanmean(confs), skeleton=skeleton, track=track, + tracking_score=np.nanmean(score), ) ) @@ -3770,6 +3802,15 @@ def _object_builder(): try: for ex in generator: prediction_queue.put(ex) + + except KeyError as e: + # Gracefully handle seeking errors by early termination. + if "Unable to load frame" in str(e): + pass # TODO: Print warning obeying verbosity? 
(This code path is also + # called for interactive prediction where we don't want any spam.) + else: + raise + finally: prediction_queue.put(None) object_builder.join() @@ -4412,18 +4453,27 @@ def _object_builder(): break # Loop over frames. - for image, video_ind, frame_ind, points, confidences, scores in zip( + for ( + image, + video_ind, + frame_ind, + centroid_vals, + points, + confidences, + scores, + ) in zip( ex["image"], ex["video_ind"], ex["frame_ind"], + ex["centroid_vals"], ex["instance_peaks"], ex["instance_peak_vals"], ex["instance_scores"], ): # Loop over instances. predicted_instances = [] - for i, (pts, confs, score) in enumerate( - zip(points, confidences, scores) + for i, (pts, centroid_val, confs, score) in enumerate( + zip(points, centroid_vals, confidences, scores) ): if np.isnan(pts).all(): continue @@ -4434,9 +4484,10 @@ def _object_builder(): PredictedInstance.from_numpy( points=pts, point_confidences=confs, - instance_score=np.nanmean(score), + instance_score=centroid_val, skeleton=skeleton, track=track, + tracking_score=score, ) ) @@ -4457,6 +4508,15 @@ def _object_builder(): try: for ex in generator: prediction_queue.put(ex) + + except KeyError as e: + # Gracefully handle seeking errors by early termination. + if "Unable to load frame" in str(e): + pass # TODO: Print warning obeying verbosity? (This code path is also + # called for interactive prediction where we don't want any spam.) + else: + raise + finally: prediction_queue.put(None) object_builder.join() @@ -4734,6 +4794,15 @@ def _object_builder(): try: for ex in generator: prediction_queue.put(ex) + + except KeyError as e: + # Gracefully handle seeking errors by early termination. + if "Unable to load frame" in str(e): + pass # TODO: Print warning obeying verbosity? (This code path is also + # called for interactive prediction where we don't want any spam.) + else: + raise + finally: prediction_queue.put(None) object_builder.join() @@ -4939,7 +5008,7 @@ def export_cli(args: Optional[list] = None): export_model( args.models, args.export_path, - unrag_outputs=args.unrag, + unrag_outputs=(not args.ragged), max_instances=args.max_instances, ) @@ -4971,13 +5040,13 @@ def _make_export_cli_parser() -> argparse.ArgumentParser: ), ) parser.add_argument( - "-u", - "--unrag", + "-r", + "--ragged", action="store_true", - default=True, + default=False, help=( - "Convert ragged tensors into regular tensors with NaN padding. " - "Defaults to True." + "Keep tensors ragged if present. If ommited, convert ragged tensors" + " into regular tensors with NaN padding." ), ) parser.add_argument( @@ -5230,15 +5299,14 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: args: Parsed CLI namespace. Returns: - A tuple of `(provider, data_path)` with the data `Provider` and path to the data - that was specified in the args. + `(provider_list, data_path_list, output_path_list)` where `provider_list` contains the data providers, + `data_path_list` contains the paths to the specified data, and the `output_path_list` contains the list + of output paths if a CSV file with a column of output paths was provided; otherwise, `output_path_list` + defaults to None """ + # Figure out which input path to use. 
- labels_path = getattr(args, "labels", None) - if labels_path is not None: - data_path = labels_path - else: - data_path = args.data_path + data_path = args.data_path if data_path is None or data_path == "": raise ValueError( @@ -5246,33 +5314,117 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: "Run 'sleap-track -h' to see full command documentation." ) - if data_path.endswith(".slp"): - labels = sleap.load_file(data_path) - - if args.only_labeled_frames: - provider = LabelsReader.from_user_labeled_frames(labels) - elif args.only_suggested_frames: - provider = LabelsReader.from_unlabeled_suggestions(labels) - elif getattr(args, "video.index") != "": - provider = VideoReader( - video=labels.videos[int(getattr(args, "video.index"))], - example_indices=frame_list(args.frames), - ) + data_path_obj = Path(data_path) + + # Set output_path_list to None as a default to return later + output_path_list = None + + # Check that input value is valid + if not data_path_obj.exists(): + raise ValueError("Path to data_path does not exist") + + elif data_path_obj.is_file(): + # If the file is a CSV file, check for data_paths and output_paths + if data_path_obj.suffix.lower() == ".csv": + try: + data_path_column = None + # Read the CSV file + df = pd.read_csv(data_path) + + # collect data_paths from column + for col_index in range(df.shape[1]): + path_str = df.iloc[0, col_index] + if Path(path_str).exists(): + data_path_column = df.columns[col_index] + break + if data_path_column is None: + raise ValueError( + f"Column containing valid data_paths does not exist in the CSV file: {data_path}" + ) + raw_data_path_list = df[data_path_column].tolist() + + # optional output_path column to specify multiple output_paths + output_path_column_index = df.columns.get_loc(data_path_column) + 1 + if ( + output_path_column_index < df.shape[1] + and df.iloc[:, output_path_column_index].dtype == object + ): + # Ensure the next column exists + output_path_list = df.iloc[:, output_path_column_index].tolist() + else: + output_path_list = None + + except pd.errors.EmptyDataError as e: + raise ValueError(f"CSV file is empty: {data_path}. Error: {e}") from e + + # If the file is a text file, collect data_paths + elif data_path_obj.suffix.lower() == ".txt": + try: + with open(data_path_obj, "r") as file: + raw_data_path_list = [line.strip() for line in file.readlines()] + except Exception as e: + raise ValueError( + f"Error reading text file: {data_path}. Error: {e}" + ) from e else: - provider = LabelsReader(labels) + raw_data_path_list = [data_path_obj.as_posix()] - else: - print(f"Video: {data_path}") - # TODO: Clean this up. 
- video_kwargs = dict( - dataset=vars(args).get("video.dataset"), - input_format=vars(args).get("video.input_format"), - ) - provider = VideoReader.from_filepath( - filename=data_path, example_indices=frame_list(args.frames), **video_kwargs - ) + raw_data_path_list = [Path(p) for p in raw_data_path_list] - return provider, data_path + # Check for multiple video inputs + # Compile file(s) into a list for later iteration + elif data_path_obj.is_dir(): + raw_data_path_list = [ + file_path for file_path in data_path_obj.iterdir() if file_path.is_file() + ] + + # Provider list to accomodate multiple video inputs + provider_list = [] + data_path_list = [] + for file_path in raw_data_path_list: + # Create a provider for each file + if file_path.as_posix().endswith(".slp") and len(raw_data_path_list) > 1: + print(f"slp file skipped: {file_path.as_posix()}") + + elif file_path.as_posix().endswith(".slp"): + labels = sleap.load_file(file_path.as_posix()) + + if args.only_labeled_frames: + provider_list.append(LabelsReader.from_user_labeled_frames(labels)) + elif args.only_suggested_frames: + provider_list.append(LabelsReader.from_unlabeled_suggestions(labels)) + elif getattr(args, "video.index") != "": + provider_list.append( + VideoReader( + video=labels.videos[int(getattr(args, "video.index"))], + example_indices=frame_list(args.frames), + ) + ) + else: + provider_list.append(LabelsReader(labels)) + + data_path_list.append(file_path) + + else: + try: + video_kwargs = dict( + dataset=vars(args).get("video.dataset"), + input_format=vars(args).get("video.input_format"), + ) + provider_list.append( + VideoReader.from_filepath( + filename=file_path.as_posix(), + example_indices=frame_list(args.frames), + **video_kwargs, + ) + ) + print(f"Video: {file_path.as_posix()}") + data_path_list.append(file_path) + # TODO: Clean this up. + except Exception: + print(f"Error reading file: {file_path.as_posix()}") + + return provider_list, data_path_list, output_path_list def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor: @@ -5367,8 +5519,6 @@ def main(args: Optional[list] = None): pprint(vars(args)) print() - output_path = args.output - # Setup devices. if args.cpu or not sleap.nn.system.is_gpu_system(): sleap.nn.system.use_cpu_only() @@ -5406,7 +5556,20 @@ def main(args: Optional[list] = None): print() # Setup data loader. - provider, data_path = _make_provider_from_cli(args) + provider_list, data_path_list, output_path_list = _make_provider_from_cli(args) + + output_path = None + + # if output_path has not been extracted from a csv file yet + if output_path_list is None and args.output is not None: + output_path = args.output + output_path_obj = Path(output_path) + + # check if output_path is valid before running inference + if Path(output_path).is_file() and len(data_path_list) > 1: + raise ValueError( + "output_path argument must be a directory if multiple video inputs are given" + ) # Setup tracker. tracker = _make_tracker_from_cli(args) @@ -5414,25 +5577,94 @@ def main(args: Optional[list] = None): if args.models is not None and "movenet" in args.models[0]: args.models = args.models[0] - # Either run inference (and tracking) or just run tracking + # Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run) if args.models is not None: - # Setup models. - predictor = _make_predictor_from_cli(args) - predictor.tracker = tracker - # Run inference! 
- labels_pr = predictor.predict(provider) + # Run inference on all input files + for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)): + # Setup models. + data_path_obj = Path(data_path) + predictor = _make_predictor_from_cli(args) + predictor.tracker = tracker - if output_path is None: - output_path = data_path + ".predictions.slp" + # Run inference! + labels_pr = predictor.predict(provider) - labels_pr.provenance["model_paths"] = predictor.model_paths - labels_pr.provenance["predictor"] = type(predictor).__name__ + # if output path was not provided, create an output path + if output_path is None: + # use the output path extracted from the CSV file, if given + if output_path_list: + output_path = output_path_list[i] + + else: + output_path = data_path_obj.with_suffix(".predictions.slp") + output_path_obj = Path(output_path) + + # if output_path was provided and multiple inputs were provided, create a directory to store outputs + elif len(data_path_list) > 1: + output_path_obj = Path(output_path) + output_path = ( + output_path_obj + / (data_path_obj.with_suffix(".predictions.slp")).name + ) + output_path_obj = Path(output_path) + # Create the containing directory if needed. + output_path_obj.parent.mkdir(exist_ok=True, parents=True) + + labels_pr.provenance["model_paths"] = predictor.model_paths + labels_pr.provenance["predictor"] = type(predictor).__name__ + + if args.no_empty_frames: + # Clear empty frames if specified. + labels_pr.remove_empty_frames() + + finish_timestamp = str(datetime.now()) + total_elapsed = time() - t0 + print("Finished inference at:", finish_timestamp) + print(f"Total runtime: {total_elapsed} secs") + print(f"Predicted frames: {len(labels_pr)}/{len(provider)}") + + # Add provenance metadata to predictions. + labels_pr.provenance["sleap_version"] = sleap.__version__ + labels_pr.provenance["platform"] = platform.platform() + labels_pr.provenance["command"] = " ".join(sys.argv) + labels_pr.provenance["data_path"] = data_path_obj.as_posix() + labels_pr.provenance["output_path"] = output_path_obj.as_posix() + labels_pr.provenance["total_elapsed"] = total_elapsed + labels_pr.provenance["start_timestamp"] = start_timestamp + labels_pr.provenance["finish_timestamp"] = finish_timestamp + + print("Provenance:") + pprint(labels_pr.provenance) + print() + + labels_pr.provenance["args"] = vars(args) + + # Save results. + try: + labels_pr.save(output_path) + except Exception: + print("WARNING: Provided output path invalid.") + fallback_path = data_path_obj.with_suffix(".predictions.slp") + labels_pr.save(fallback_path) + print("Saved output:", output_path) + + if args.open_in_gui: + subprocess.call(["sleap-label", output_path]) + + # Reset output_path for next iteration + output_path = args.output + + # Run tracking on an existing predictions file elif getattr(args, "tracking.tracker") is not None: + provider = provider_list[0] + data_path = data_path_list[0] + # Load predictions + data_path = args.data_path print("Loading predictions...") - labels_pr = sleap.load_file(args.data_path) + labels_pr = sleap.load_file(data_path) frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx) print("Starting tracker...") @@ -5444,6 +5676,40 @@ if output_path is None: output_path = f"{data_path}.{tracker.get_name()}.slp" + if args.no_empty_frames: + # Clear empty frames if specified.
+ labels_pr.remove_empty_frames() + + finish_timestamp = str(datetime.now()) + total_elapsed = time() - t0 + print("Finished inference at:", finish_timestamp) + print(f"Total runtime: {total_elapsed} secs") + print(f"Predicted frames: {len(labels_pr)}/{len(provider)}") + + # Add provenance metadata to predictions. + labels_pr.provenance["sleap_version"] = sleap.__version__ + labels_pr.provenance["platform"] = platform.platform() + labels_pr.provenance["command"] = " ".join(sys.argv) + labels_pr.provenance["data_path"] = data_path + labels_pr.provenance["output_path"] = output_path + labels_pr.provenance["total_elapsed"] = total_elapsed + labels_pr.provenance["start_timestamp"] = start_timestamp + labels_pr.provenance["finish_timestamp"] = finish_timestamp + + print("Provenance:") + pprint(labels_pr.provenance) + print() + + labels_pr.provenance["args"] = vars(args) + + # Save results. + labels_pr.save(output_path) + + print("Saved output:", output_path) + + if args.open_in_gui: + subprocess.call(["sleap-label", output_path]) + else: raise ValueError( "Neither tracker type nor path to trained models specified. " @@ -5451,36 +5717,3 @@ def main(args: Optional[list] = None): "To retrack on predictions, must specify tracker. " "Use \"sleap-track --tracking.tracker ...' to specify tracker to use." ) - - if args.no_empty_frames: - # Clear empty frames if specified. - labels_pr.remove_empty_frames() - - finish_timestamp = str(datetime.now()) - total_elapsed = time() - t0 - print("Finished inference at:", finish_timestamp) - print(f"Total runtime: {total_elapsed} secs") - print(f"Predicted frames: {len(labels_pr)}/{len(provider)}") - - # Add provenance metadata to predictions. - labels_pr.provenance["sleap_version"] = sleap.__version__ - labels_pr.provenance["platform"] = platform.platform() - labels_pr.provenance["command"] = " ".join(sys.argv) - labels_pr.provenance["data_path"] = data_path - labels_pr.provenance["output_path"] = output_path - labels_pr.provenance["total_elapsed"] = total_elapsed - labels_pr.provenance["start_timestamp"] = start_timestamp - labels_pr.provenance["finish_timestamp"] = finish_timestamp - - print("Provenance:") - pprint(labels_pr.provenance) - print() - - labels_pr.provenance["args"] = vars(args) - - # Save results. - labels_pr.save(output_path) - print("Saved output:", output_path) - - if args.open_in_gui: - subprocess.call(["sleap-label", output_path]) diff --git a/sleap/nn/peak_finding.py b/sleap/nn/peak_finding.py index 84dca00ae..e1fb43a6e 100644 --- a/sleap/nn/peak_finding.py +++ b/sleap/nn/peak_finding.py @@ -221,7 +221,7 @@ def find_global_peaks_rough( channels = tf.cast(tf.shape(cms)[-1], tf.int64) total_peaks = tf.cast(tf.shape(argmax_cols)[0], tf.int64) sample_subs = tf.range(total_peaks, dtype=tf.int64) // channels - channel_subs = tf.range(total_peaks, dtype=tf.int64) % channels + channel_subs = tf.math.mod(tf.range(total_peaks, dtype=tf.int64), channels) # Gather subscripts. 
diff --git a/sleap/nn/peak_finding.py b/sleap/nn/peak_finding.py
index 84dca00ae..e1fb43a6e 100644
--- a/sleap/nn/peak_finding.py
+++ b/sleap/nn/peak_finding.py
@@ -221,7 +221,7 @@ def find_global_peaks_rough(
     channels = tf.cast(tf.shape(cms)[-1], tf.int64)
     total_peaks = tf.cast(tf.shape(argmax_cols)[0], tf.int64)
     sample_subs = tf.range(total_peaks, dtype=tf.int64) // channels
-    channel_subs = tf.range(total_peaks, dtype=tf.int64) % channels
+    channel_subs = tf.math.mod(tf.range(total_peaks, dtype=tf.int64), channels)

     # Gather subscripts.
     peak_subs = tf.stack([sample_subs, argmax_rows, argmax_cols, channel_subs], axis=1)
diff --git a/sleap/nn/system.py b/sleap/nn/system.py
index eeb3f3ca4..4cc3d1804 100644
--- a/sleap/nn/system.py
+++ b/sleap/nn/system.py
@@ -48,7 +48,17 @@ def get_current_gpu() -> tf.config.PhysicalDevice:

 def use_cpu_only():
     """Hide GPUs from TensorFlow to ensure only the CPU is available."""
-    tf.config.set_visible_devices([], "GPU")
+    try:
+        tf.config.set_visible_devices([], "GPU")
+    except RuntimeError as ex:
+        if (
+            len(ex.args) > 0
+            and ex.args[0]
+            == "Visible devices cannot be modified after being initialized"
+        ):
+            print(
+                "Failed to hide GPUs. Visible devices cannot be modified after "
+                "being initialized."
+            )
+        else:
+            raise


 def use_gpu(device_ind: int):
@@ -58,7 +68,17 @@ def use_gpu(device_ind: int):
         device_ind: Index of the GPU within the list of system GPUs.
     """
     gpus = get_all_gpus()
-    tf.config.set_visible_devices(gpus[device_ind], "GPU")
+    try:
+        tf.config.set_visible_devices(gpus[device_ind], "GPU")
+    except RuntimeError as ex:
+        if (
+            len(ex.args) > 0
+            and ex.args[0]
+            == "Visible devices cannot be modified after being initialized"
+        ):
+            print(
+                "Failed to set visible GPU. Visible devices cannot be modified "
+                "after being initialized."
+            )
+        else:
+            raise


 def use_first_gpu():
@@ -159,7 +179,7 @@ def summary():
     for gpu in all_gpus:
         print(f" Device: {gpu.name}")
         print(f" Available: {gpu in gpus}")
-        print(f" Initalized: {is_initialized(gpu)}")
+        print(f" Initialized: {is_initialized(gpu)}")
         print(
             f" Memory growth: {tf.config.experimental.get_memory_growth(gpu)}"
         )
diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py
index 10b2953b7..0b77f4ac9 100644
--- a/sleap/nn/tracker/components.py
+++ b/sleap/nn/tracker/components.py
@@ -12,9 +12,11 @@

 """

+
 import operator
 from collections import defaultdict
-from typing import List, Tuple, Optional, TypeVar, Callable
+import logging
+from typing import List, Tuple, Union, Optional, TypeVar, Callable

 import attr
 import numpy as np
@@ -23,9 +25,26 @@

 from sleap import PredictedInstance, Instance, Track
 from sleap.nn import utils

+logger = logging.getLogger(__name__)
+
 InstanceType = TypeVar("InstanceType", Instance, PredictedInstance)


+def normalized_instance_similarity(
+    ref_instance: InstanceType, query_instance: InstanceType, img_hw: Tuple[int, int]
+) -> float:
+    """Computes similarity between instances with keypoints normalized by image size."""
+
+    normalize_factors = np.array((img_hw[1], img_hw[0]))
+    ref_visible = ~(np.isnan(ref_instance.points_array).any(axis=1))
+    normalized_query_keypoints = query_instance.points_array / normalize_factors
+    normalized_ref_keypoints = ref_instance.points_array / normalize_factors
+    dists = np.sum((normalized_query_keypoints - normalized_ref_keypoints) ** 2, axis=1)
+    similarity = np.nansum(np.exp(-dists)) / np.sum(ref_visible)
+
+    return similarity
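The normalized similarity added above reduces to a few lines of NumPy. A self-contained sketch with made-up coordinates standing in for the instances' `points_array` values (not part of the patch):

```python
import numpy as np

def normalized_similarity(ref_pts, query_pts, img_hw):
    """Mirror of normalized_instance_similarity, on raw (n, 2) keypoint arrays."""
    factors = np.array((img_hw[1], img_hw[0]))  # (width, height)
    ref_visible = ~np.isnan(ref_pts).any(axis=1)
    dists = np.sum((query_pts / factors - ref_pts / factors) ** 2, axis=1)
    # Invisible keypoints produce NaN distances, which nansum skips.
    return np.nansum(np.exp(-dists)) / np.sum(ref_visible)

ref = np.array([[100.0, 200.0], [150.0, 250.0], [np.nan, np.nan]])
query = np.array([[104.0, 203.0], [149.0, 255.0], [np.nan, np.nan]])
print(normalized_similarity(ref, query, img_hw=(480, 640)))  # ~1.0 for near-identical instances
```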
+
+
 def instance_similarity(
     ref_instance: InstanceType, query_instance: InstanceType
 ) -> float:
@@ -40,6 +59,95 @@ def instance_similarity(
     return similarity

+
+def factory_object_keypoint_similarity(
+    keypoint_errors: Optional[Union[List, int, float]] = None,
+    score_weighting: bool = False,
+    normalization_keypoints: str = "all",
+) -> Callable:
+    """Factory for similarity function based on object keypoints.
+
+    Args:
+        keypoint_errors: The standard error of the distance between the predicted
+            keypoint and the true value, in pixels.
+            If None or an empty list, defaults to 1.
+            If a scalar or a singleton list, every keypoint has the same error.
+            If a list, it defines the error for each keypoint; its length should
+            equal the number of keypoints in the skeleton.
+        score_weighting: If True, use the `score` of each `PredictedPoint` to
+            weight `keypoint_errors`. If False, do not weight `keypoint_errors`.
+        normalization_keypoints: Determines how to normalize the similarity score.
+            One of ["all", "ref", "union"]. If "all", the similarity score is
+            normalized by the number of reference points. If "ref", it is
+            normalized by the number of visible reference points. If "union", it
+            is normalized by the number of points visible in both the query and
+            the reference instance. Default is "all".
+
+    Returns:
+        Callable that returns the object keypoint similarity between two
+        `Instance`s.
+
+    """
+    keypoint_errors = 1 if keypoint_errors is None else keypoint_errors
+    with np.errstate(divide="ignore"):
+        kp_precision = 1 / (2 * np.array(keypoint_errors) ** 2)
+
+    def object_keypoint_similarity(
+        ref_instance: InstanceType, query_instance: InstanceType
+    ) -> float:
+        nonlocal kp_precision
+        # Keypoints
+        ref_points = ref_instance.points_array
+        query_points = query_instance.points_array
+        # Keypoint scores
+        if score_weighting:
+            ref_scores = getattr(ref_instance, "scores", np.ones(len(ref_points)))
+            query_scores = getattr(query_instance, "scores", np.ones(len(query_points)))
+        else:
+            ref_scores = 1
+            query_scores = 1
+        # Number of keypoints used for normalization
+        if normalization_keypoints in ("ref", "union"):
+            ref_visible = ~(np.isnan(ref_points).any(axis=1))
+            if normalization_keypoints == "ref":
+                max_n_keypoints = np.sum(ref_visible)
+            elif normalization_keypoints == "union":
+                query_visible = ~(np.isnan(query_points).any(axis=1))
+                max_n_keypoints = np.sum(np.logical_and(ref_visible, query_visible))
+        else:  # if normalization_keypoints == "all":
+            max_n_keypoints = len(ref_points)
+        if max_n_keypoints == 0:
+            return 0
+
+        # Make sure the sizes of kp_precision and n_points match
+        if kp_precision.size > 1 and 2 * kp_precision.size != ref_points.size:
+            # Correct kp_precision size to fit the number of points
+            n_points = ref_points.size // 2
+            mess = (
+                "keypoint_errors array should have the same size as the number of "
+                f"keypoints in the instance: {kp_precision.size} != {n_points}"
+            )
+
+            if kp_precision.size > n_points:
+                kp_precision = kp_precision[:n_points]
+                mess += "\nTruncating keypoint_errors array."
+
+            else:  # elif kp_precision.size < n_points:
+                pad = n_points - kp_precision.size
+                kp_precision = np.pad(kp_precision, (0, pad), "edge")
+                mess += "\nPadding keypoint_errors array by repeating the last value."
+ logger.warning(mess) + + # Compute distances + dists = np.sum((query_points - ref_points) ** 2, axis=1) * kp_precision + + similarity = ( + np.nansum(ref_scores * query_scores * np.exp(-dists)) / max_n_keypoints + ) + + return similarity + + return object_keypoint_similarity + + def centroid_distance( ref_instance: InstanceType, query_instance: InstanceType, cache: dict = dict() ) -> float: diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 9865b7db5..558aa9309 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -5,12 +5,15 @@ import attr import numpy as np import cv2 +import functools from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple from sleap import Track, LabeledFrame, Skeleton from sleap.nn.tracker.components import ( + factory_object_keypoint_similarity, instance_similarity, + normalized_instance_similarity, centroid_distance, instance_iou, hungarian_matching, @@ -391,6 +394,7 @@ def get_ref_instances( def get_candidates( self, track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], + max_tracking: bool, t: int, img: np.ndarray, *args, @@ -404,7 +408,7 @@ def get_candidates( tracks = [] for track, matched_items in track_matching_queue_dict.items(): - if len(tracks) <= self.max_tracks: + if not max_tracking or len(tracks) < self.max_tracks: tracks.append(track) for matched_item in matched_items: ref_t, ref_img = ( @@ -466,6 +470,7 @@ class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker): def get_candidates( self, track_matching_queue_dict: Dict, + max_tracking: bool, *args, **kwargs, ) -> List[InstanceType]: @@ -473,7 +478,7 @@ def get_candidates( candidate_instances = [] tracks = [] for track, matched_instances in track_matching_queue_dict.items(): - if len(tracks) <= self.max_tracks: + if not max_tracking or len(tracks) < self.max_tracks: tracks.append(track) for ref_instance in matched_instances: if ref_instance.instance_t.n_visible_points >= self.min_points: @@ -492,6 +497,8 @@ def get_candidates( instance=instance_similarity, centroid=centroid_distance, iou=instance_iou, + normalized_instance=normalized_instance_similarity, + object_keypoint=factory_object_keypoint_similarity, ) match_policies = dict( @@ -598,8 +605,15 @@ def _init_matching_queue(self): """Factory for instantiating default matching queue with specified size.""" return deque(maxlen=self.track_window) + @property + def has_max_tracking(self) -> bool: + return isinstance( + self.candidate_maker, + (SimpleMaxTracksCandidateMaker, FlowMaxTracksCandidateMaker), + ) + def reset_candidates(self): - if self.max_tracking: + if self.has_max_tracking: for track in self.track_matching_queue_dict: self.track_matching_queue_dict[track] = deque(maxlen=self.track_window) else: @@ -610,14 +624,15 @@ def unique_tracks_in_queue(self) -> List[Track]: """Returns the unique tracks in the matching queue.""" unique_tracks = set() - for match_item in self.track_matching_queue: - for instance in match_item.instances_t: - unique_tracks.add(instance.track) - - if self.max_tracking: + if self.has_max_tracking: for track in self.track_matching_queue_dict.keys(): unique_tracks.add(track) + else: + for match_item in self.track_matching_queue: + for instance in match_item.instances_t: + unique_tracks.add(instance.track) + return list(unique_tracks) @property @@ -627,6 +642,7 @@ def uses_image(self): def track( self, untracked_instances: List[InstanceType], + img_hw: Tuple[int], img: Optional[np.ndarray] = None, t: int = None, ) -> List[InstanceType]: @@ -634,19 +650,25 @@ def 
track( Args: untracked_instances: List of instances to assign to tracks. + img_hw: (height, width) of the image used to normalize the keypoints. img: Image data of the current frame for flow shifting. t: Current timestep. If not provided, increments from the internal queue. Returns: A list of the instances that were tracked. """ + if self.similarity_function == normalized_instance_similarity: + factory_normalized_instance = functools.partial( + normalized_instance_similarity, img_hw=img_hw + ) + self.similarity_function = factory_normalized_instance if self.candidate_maker is None: return untracked_instances # Infer timestep if not provided. if t is None: - if self.max_tracking: + if self.has_max_tracking: if len(self.track_matching_queue_dict) > 0: # Default to last timestep + 1 if available. @@ -684,10 +706,10 @@ def track( self.pre_cull_function(untracked_instances) # Build a pool of matchable candidate instances. - if self.max_tracking: + if self.has_max_tracking: candidate_instances = self.candidate_maker.get_candidates( track_matching_queue_dict=self.track_matching_queue_dict, - max_tracks=self.max_tracks, + max_tracking=self.max_tracking, t=t, img=img, ) @@ -721,13 +743,16 @@ def track( ) # Add the tracked instances to the dictionary of matched instances. - if self.max_tracking: + if self.has_max_tracking: for tracked_instance in tracked_instances: if tracked_instance.track in self.track_matching_queue_dict: self.track_matching_queue_dict[tracked_instance.track].append( MatchedFrameInstance(t, tracked_instance, img) ) - elif len(self.track_matching_queue_dict) < self.max_tracks: + elif ( + not self.max_tracking + or len(self.track_matching_queue_dict) < self.max_tracks + ): self.track_matching_queue_dict[tracked_instance.track] = deque( maxlen=self.track_window ) @@ -773,7 +798,8 @@ def spawn_for_untracked_instances( # Skip if we've reached the maximum number of tracks. if ( - self.max_tracking + self.has_max_tracking + and self.max_tracking and len(self.track_matching_queue_dict) >= self.max_tracks ): break @@ -838,8 +864,17 @@ def make_tracker_by_name( # Max tracking options max_tracks: Optional[int] = None, max_tracking: bool = False, + # Object keypoint similarity options + oks_errors: Optional[list] = None, + oks_score_weighting: bool = False, + oks_normalization: str = "all", **kwargs, ) -> BaseTracker: + # Parse max_tracking arguments, only True if max_tracks is not None and > 0 + max_tracking = max_tracking if max_tracks else False + if max_tracking and tracker in ("simple", "flow"): + # Force a candidate maker of 'maxtracks' type + tracker += "maxtracks" if tracker.lower() == "none": candidate_maker = None @@ -858,7 +893,14 @@ def make_tracker_by_name( raise ValueError(f"{match} is not a valid tracker matching function.") candidate_maker = tracker_policies[tracker](min_points=min_match_points) - similarity_function = similarity_policies[similarity] + if similarity == "object_keypoint": + similarity_function = factory_object_keypoint_similarity( + keypoint_errors=oks_errors, + score_weighting=oks_score_weighting, + normalization_keypoints=oks_normalization, + ) + else: + similarity_function = similarity_policies[similarity] matching_function = match_policies[match] if tracker == "flow": @@ -931,7 +973,10 @@ def get_by_name_factory_options(cls): option = dict(name="max_tracking", default=False) option["type"] = bool - option["help"] = "If true then the tracker will cap the max number of tracks." 
+ option["help"] = ( + "If true then the tracker will cap the max number of tracks. " + "Falls back to false if `max_tracks` is not defined or 0." + ) options.append(option) option = dict(name="max_tracks", default=None) @@ -1054,6 +1099,42 @@ def int_list_func(s): ] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." options.append(option) + def float_list_func(s): + return [float(x.strip()) for x in s.split(",")] if s else None + + option = dict(name="oks_errors", default="1") + option["type"] = float_list_func + option["help"] = ( + "For Object Keypoint similarity: the standard error of the distance " + "between the predicted keypoint and the true value, in pixels.\n" + "If None or empty list, defaults to 1. If a scalar or singleton list, " + "every keypoint has the same error. If a list, defines the error for each " + "keypoint, the length should be equal to the number of keypoints in the " + "skeleton." + ) + options.append(option) + + option = dict(name="oks_score_weighting", default="0") + option["type"] = int + option["help"] = ( + "For Object Keypoint similarity: if 0 (default), only the distance between the reference " + "and query keypoint is used to compute the similarity. If 1, each distance is weighted " + "by the prediction scores of the reference and query keypoint." + ) + options.append(option) + + option = dict(name="oks_normalization", default="all") + option["type"] = str + option["options"] = ["all", "ref", "union"] + option["help"] = ( + "For Object Keypoint similarity: Determine how to normalize similarity score. " + "If 'all', similarity score is normalized by number of reference points. " + "If 'ref', similarity score is normalized by number of visible reference points. " + "If 'union', similarity score is normalized by number of points both visible " + "in query and reference instance." + ) + options.append(option) + return options @classmethod @@ -1449,6 +1530,7 @@ def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[Labele track_args["img"] = lf.video[lf.frame_idx] else: track_args["img"] = None + track_args["img_hw"] = lf.image.shape[-3:-1] new_lf = LabeledFrame( frame_idx=lf.frame_idx, diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 16f027175..c3692637c 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -508,7 +508,7 @@ def setup_visualization( callbacks = [] try: - matplotlib.use("Qt5Agg") + matplotlib.use("QtAgg") except ImportError: print( "Unable to use Qt backend for matplotlib. " @@ -946,7 +946,7 @@ def train(self): if self.config.outputs.save_outputs: if ( self.config.outputs.save_visualizations - and self.config.outputs.delete_viz_images + and not self.config.outputs.keep_viz_images ): self.cleanup() @@ -997,7 +997,7 @@ def cleanup(self): def package(self): """Package model folder into a zip file for portability.""" - if self.config.outputs.delete_viz_images: + if not self.config.outputs.keep_viz_images: self.cleanup() logger.info(f"Packaging results to: {self.run_path}.zip") shutil.make_archive( @@ -1864,6 +1864,14 @@ def create_trainer_using_cli(args: Optional[List] = None): "already specified in the training job config." ), ) + parser.add_argument( + "--keep_viz", + action="store_true", + help=( + "Keep prediction visualization images in the run folder after training when " + "--save_viz is enabled." 
+ ), + ) parser.add_argument( "--zmq", action="store_true", @@ -1872,6 +1880,18 @@ def create_trainer_using_cli(args: Optional[List] = None): "job config." ), ) + parser.add_argument( + "--publish_port", + type=int, + default=9001, + help="Port to set up the publish address while using ZMQ, defaults to 9001.", + ) + parser.add_argument( + "--controller_port", + type=int, + default=9000, + help="Port to set up the controller address while using ZMQ, defaults to 9000.", + ) parser.add_argument( "--run_name", default="", @@ -1926,6 +1946,10 @@ def create_trainer_using_cli(args: Optional[List] = None): job_config.outputs.tensorboard.write_logs |= args.tensorboard job_config.outputs.zmq.publish_updates |= args.zmq job_config.outputs.zmq.subscribe_to_controller |= args.zmq + job_config.outputs.zmq.controller_address = "tcp://127.0.0.1:" + str( + args.controller_port + ) + job_config.outputs.zmq.publish_address = "tcp://127.0.0.1:" + str(args.publish_port) if args.run_name != "": job_config.outputs.run_name = args.run_name if args.prefix != "": @@ -1933,6 +1957,7 @@ def create_trainer_using_cli(args: Optional[List] = None): if args.suffix != "": job_config.outputs.run_name_suffix = args.suffix job_config.outputs.save_visualizations |= args.save_viz + job_config.outputs.keep_viz_images = args.keep_viz if args.labels_path == "": args.labels_path = None args.video_paths = args.video_paths.split(",") diff --git a/sleap/prefs.py b/sleap/prefs.py index 3d5a2113e..e043afc44 100644 --- a/sleap/prefs.py +++ b/sleap/prefs.py @@ -28,6 +28,8 @@ class Preferences(object): "node label size": 12, "show non-visible nodes": True, "share usage data": True, + "node marker sizes": (1, 2, 3, 4, 6, 8, 12), + "node label sizes": (6, 9, 12, 18, 24, 36), } _filename = "preferences.yaml" @@ -43,10 +45,14 @@ def load_(self): """Load preferences from file (regardless of whether loaded already).""" try: self._prefs = util.get_config_yaml(self._filename) - if not hasattr(self._prefs, "get"): - self._prefs = self._defaults except FileNotFoundError: - self._prefs = self._defaults + pass + + self._prefs = self._prefs or {} + + for k, v in self._defaults.items(): + if k not in self._prefs: + self._prefs[k] = v def save(self): """Save preferences to file.""" diff --git a/sleap/skeleton.py b/sleap/skeleton.py index eca393b8e..fbd1b909c 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -6,24 +6,24 @@ their connection to each other, and needed meta-data. """ -import attr -import cattr -import numpy as np -import jsonpickle -import json -import h5py +import base64 import copy - +import json import operator from enum import Enum +from io import BytesIO from itertools import count -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Text +from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union +import attr +import cattr +import h5py import networkx as nx +import numpy as np from networkx.readwrite import json_graph +from PIL import Image from scipy.io import loadmat - NodeRef = Union[str, "Node"] H5FileRef = Union[str, h5py.File] @@ -85,6 +85,502 @@ def matches(self, other: "Node") -> bool: return other.name == self.name and other.weight == self.weight +class SkeletonDecoder: + """Replace jsonpickle.decode with our own decoder. 
+
+    This class will decode the following from jsonpickle's encoded format:
+
+    `Node` objects from
+    {
+        "py/object": "sleap.skeleton.Node",
+        "py/state": { "py/tuple": ["thorax1", 1.0] }
+    }
+    to `Node(name="thorax1", weight=1.0)`,
+
+    `EdgeType` objects from
+    {
+        "py/reduce": [
+            { "py/type": "sleap.skeleton.EdgeType" },
+            { "py/tuple": [1] }
+        ]
+    }
+    to `EdgeType(1)`,
+
+    `bytes` from
+    {
+        "py/b64": "aVZC..."
+    }
+    to `b"iVBO..."`,
+
+    and any repeated objects from
+    {
+        "py/id": 1
+    }
+    to the object with the same reconstruction id (counting from top to bottom).
+    """
+
+    def __init__(self):
+        self.decoded_objects: List[Union[Node, EdgeType]] = []
+
+    def _decode_id(self, id: int) -> Union[Node, EdgeType]:
+        """Decode the object with the given `py/id` value of `id`.
+
+        Args:
+            id: The `py/id` value to decode (a 1-based index into the objects
+                that have already been decoded).
+
+        Returns:
+            The object with the given `py/id` value.
+        """
+        return self.decoded_objects[id - 1]
+
+    @staticmethod
+    def _decode_state(state: dict) -> Node:
+        """Reconstruct a `Node` from the 'py/state' key in the serialized nx_graph.
+
+        We support states in either dictionary or tuple format:
+        {
+            "py/state": { "py/tuple": ["thorax1", 1.0] }
+        }
+        or
+        {
+            "py/state": {"name": "thorax1", "weight": 1.0}
+        }
+
+        Args:
+            state: The state to decode, i.e. state = dict["py/state"]
+
+        Returns:
+            The `Node` object reconstructed from the state.
+        """
+
+        if "py/tuple" in state:
+            return Node(*state["py/tuple"])
+
+        return Node(**state)
+
+    @staticmethod
+    def _decode_object_dict(object_dict) -> Node:
+        """Decode a dict containing a `py/object` key in the serialized nx_graph.
+
+        Args:
+            object_dict: The dict to decode, i.e.
+                object_dict = {"py/object": ..., "py/state": ...}
+
+        Raises:
+            ValueError: If object_dict['py/object'] is not 'sleap.skeleton.Node'.
+
+        Returns:
+            The decoded `Node` object.
+        """
+
+        if object_dict["py/object"] != "sleap.skeleton.Node":
+            raise ValueError("Only 'sleap.skeleton.Node' objects are supported.")
+
+        node: Node = SkeletonDecoder._decode_state(state=object_dict["py/state"])
+        return node
+
+    def _decode_node(self, encoded_node: dict) -> Node:
+        """Decode an item believed to be an encoded `Node` object.
+
+        Also updates the list of decoded objects.
+
+        Args:
+            encoded_node: The encoded node to decode.
+
+        Returns:
+            The decoded node.
+        """
+
+        if isinstance(encoded_node, int):
+            # Using index mapping to replace the object (load from Labels)
+            return encoded_node
+        elif "py/object" in encoded_node:
+            decoded_node: Node = SkeletonDecoder._decode_object_dict(encoded_node)
+            self.decoded_objects.append(decoded_node)
+        elif "py/id" in encoded_node:
+            decoded_node: Node = self._decode_id(encoded_node["py/id"])
+
+        return decoded_node
+
+    def _decode_nodes(self, encoded_nodes: List[dict]) -> List[Dict[str, Node]]:
+        """Decode the 'nodes' key in the serialized nx_graph.
+
+        The encoded_nodes is a list of dictionaries of two types:
+        - A dictionary with 'py/object' and 'py/state' keys.
+        - A dictionary with a 'py/id' key.
+
+        Args:
+            encoded_nodes: The list of encoded nodes to decode.
+
+        Returns:
+            The decoded nodes.
+        """
+
+        decoded_nodes: List[Dict[str, Node]] = []
+        for e_node_dict in encoded_nodes:
+            e_node = e_node_dict["id"]
+            d_node = self._decode_node(e_node)
+            decoded_nodes.append({"id": d_node})
+
+        return decoded_nodes
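The `py/id` bookkeeping is the subtle part of the decoder methods above. A toy illustration of the rule, not part of the patch (tuples stand in for `Node` objects):

```python
# Objects are remembered in the order they are first decoded;
# {"py/id": k} then refers to the k-th decoded object (1-indexed).
decoded_objects = []

def decode_item(item: dict):
    if "py/object" in item:
        obj = tuple(item["py/state"]["py/tuple"])  # stand-in for Node(name, weight)
        decoded_objects.append(obj)
        return obj
    return decoded_objects[item["py/id"] - 1]  # resolve the back-reference

first = decode_item(
    {"py/object": "sleap.skeleton.Node", "py/state": {"py/tuple": ["thorax1", 1.0]}}
)
again = decode_item({"py/id": 1})
assert first is again  # both references resolve to the same reconstructed object
```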
+ """ + + decoded_nodes: List[Dict[str, Node]] = [] + for e_node_dict in encoded_nodes: + e_node = e_node_dict["id"] + d_node = self._decode_node(e_node) + decoded_nodes.append({"id": d_node}) + + return decoded_nodes + + def _decode_reduce_dict(self, reduce_dict: Dict[str, List[dict]]) -> EdgeType: + """Decode the 'reduce' key in the serialized nx_graph. + + The reduce_dict is a dictionary in the following format: + { + "py/reduce": [ + { "py/type": "sleap.skeleton.EdgeType" }, + { "py/tuple": [1] } + ] + } + + Args: + reduce_dict: The dictionary to decode i.e. reduce_dict = {"py/reduce": ...} + + Returns: + The decoded `EdgeType` object. + """ + + reduce_list = reduce_dict["py/reduce"] + has_py_type = has_py_tuple = False + for reduce_item in reduce_list: + if reduce_item is None: + # Sometimes the reduce list has None values, skip them + continue + if ( + "py/type" in reduce_item + and reduce_item["py/type"] == "sleap.skeleton.EdgeType" + ): + has_py_type = True + elif "py/tuple" in reduce_item: + edge_type: int = reduce_item["py/tuple"][0] + has_py_tuple = True + + if not has_py_type or not has_py_tuple: + raise ValueError( + "Only 'sleap.skeleton.EdgeType' objects are supported. " + "The 'py/reduce' list must have dictionaries with 'py/type' and " + "'py/tuple' keys." + f"\n\tHas py/type: {has_py_type}\n\tHas py/tuple: {has_py_tuple}" + ) + + edge = EdgeType(edge_type) + self.decoded_objects.append(edge) + + return edge + + def _decode_edge_type(self, encoded_edge_type: dict) -> EdgeType: + """Decode the 'type' key in the serialized nx_graph. + + Args: + encoded_edge_type: a dictionary with either 'py/id' or 'py/reduce' key. + + Returns: + The decoded `EdgeType` object. + """ + + if "py/reduce" in encoded_edge_type: + edge_type = self._decode_reduce_dict(encoded_edge_type) + else: + # Expect a "py/id" instead of "py/reduce" + edge_type = self._decode_id(encoded_edge_type["py/id"]) + return edge_type + + def _decode_links( + self, links: List[dict] + ) -> List[Dict[str, Union[int, Node, EdgeType]]]: + """Decode the 'links' key in the serialized nx_graph. + + The links are the edges in the graph and will have the following keys: + - source: The source node of the edge. + - target: The destination node of the edge. + - type: The type of the edge (e.g. BODY, SYMMETRY). + and more. + + Args: + encoded_links: The list of encoded links to decode. + """ + + for link in links: + for key, value in link.items(): + if key == "source": + link[key] = self._decode_node(value) + elif key == "target": + link[key] = self._decode_node(value) + elif key == "type": + link[key] = self._decode_edge_type(value) + + return links + + @staticmethod + def decode_preview_image( + img_b64: bytes, return_bytes: bool = False + ) -> Union[Image.Image, bytes]: + """Decode a skeleton preview image byte string representation to a `PIL.Image` + + Args: + img_b64: a byte string representation of a skeleton preview image + return_bytes: whether to return the decoded image as bytes + + Returns: + Either a PIL.Image of the skeleton preview image or the decoded image as bytes + (if `return_bytes` is True). 
+ """ + bytes = base64.b64decode(img_b64) + if return_bytes: + return bytes + + buffer = BytesIO(bytes) + img = Image.open(buffer) + return img + + def _decode(self, json_str: str): + dicts = json.loads(json_str) + + # Enforce same format across template and non-template skeletons + if "nx_graph" not in dicts: + # Non-template skeletons use the dicts as the "nx_graph" + dicts = {"nx_graph": dicts} + + # Decode the graph + nx_graph = dicts["nx_graph"] + + self.decoded_objects = [] # Reset the decoded objects incase reusing decoder + for key, value in nx_graph.items(): + if key == "nodes": + nx_graph[key] = self._decode_nodes(value) + elif key == "links": + nx_graph[key] = self._decode_links(value) + + # Decode the preview image (if it exists) + preview_image = dicts.get("preview_image", None) + if preview_image is not None: + dicts["preview_image"] = SkeletonDecoder.decode_preview_image( + preview_image["py/b64"], return_bytes=True + ) + + return dicts + + @classmethod + def decode(cls, json_str: str) -> Dict: + """Decode the given json string into a dictionary. + + Returns: + A dict with `Node`s, `EdgeType`s, and `bytes` decoded/reconstructed. + """ + decoder = cls() + return decoder._decode(json_str) + + +class SkeletonEncoder: + """Replace jsonpickle.encode with our own encoder. + + The input is a dictionary containing python objects that need to be encoded as + JSON strings. The output is a JSON string that represents the input dictionary. + + `Node(name='neck', weight=1.0)` => + { + "py/object": "sleap.Skeleton.Node", + "py/state": {"py/tuple" ["neck", 1.0]} + } + + `` => + {"py/reduce": [ + {"py/type": "sleap.Skeleton.EdgeType"}, + {"py/tuple": [1] } + ] + }` + + Where `name` and `weight` are the attributes of the `Node` class; weight is always 1.0. + `EdgeType` is an enum with values `BODY = 1` and `SYMMETRY = 2`. + + See sleap.skeleton.Node and sleap.skeleton.EdgeType. + + If the object has been "seen" before, it will not be encoded as the full JSON string + but referenced by its `py/id`, which starts at 1 and indexes the objects in the + order they are seen so that the second time the first object is used, it will be + referenced as `{"py/id": 1}`. + """ + + def __init__(self): + """Initializes a SkeletonEncoder instance.""" + # Maps object id to py/id + self._encoded_objects: Dict[int, int] = {} + + @classmethod + def encode(cls, data: Dict[str, Any]) -> str: + """Encodes the input dictionary as a JSON string. + + Args: + data: The data to encode. + + Returns: + json_str: The JSON string representation of the data. + """ + + # This is required for backwards compatibility with SLEAP <=1.3.4 + sorted_data = cls._recursively_sort_dict(data) + + encoder = cls() + encoded_data = encoder._encode(sorted_data) + json_str = json.dumps(encoded_data) + return json_str + + @staticmethod + def _recursively_sort_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]: + """Recursively sorts the dictionary by keys.""" + sorted_dict = dict(sorted(dictionary.items())) + for key, value in sorted_dict.items(): + if isinstance(value, dict): + sorted_dict[key] = SkeletonEncoder._recursively_sort_dict(value) + elif isinstance(value, list): + for i, item in enumerate(value): + if isinstance(item, dict): + sorted_dict[key][i] = SkeletonEncoder._recursively_sort_dict( + item + ) + return sorted_dict + + def _encode(self, obj: Any) -> Any: + """Recursively encodes the input object. + + Args: + obj: The object to encode. Can be a dictionary, list, Node, EdgeType or + primitive data type. 
+    def _encode(self, obj: Any) -> Any:
+        """Recursively encodes the input object.
+
+        Args:
+            obj: The object to encode. Can be a dictionary, list, Node, EdgeType or
+                a primitive data type.
+
+        Returns:
+            The encoded object as a dictionary.
+        """
+        if isinstance(obj, dict):
+            encoded_obj = {}
+            for key, value in obj.items():
+                if key == "links":
+                    encoded_obj[key] = self._encode_links(value)
+                else:
+                    encoded_obj[key] = self._encode(value)
+            return encoded_obj
+        elif isinstance(obj, list):
+            return [self._encode(v) for v in obj]
+        elif isinstance(obj, EdgeType):
+            return self._encode_edge_type(obj)
+        elif isinstance(obj, Node):
+            return self._encode_node(obj)
+        else:
+            return obj  # Primitive data types
+
+    def _encode_links(self, links: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Encodes the list of links (edges) in the skeleton graph.
+
+        Args:
+            links: A list of dictionaries, each representing an edge in the graph.
+
+        Returns:
+            A list of encoded edge dictionaries with keys ordered as specified.
+        """
+        encoded_links = []
+        for link in links:
+            # Use a regular dict (insertion order is preserved in Python 3.7+)
+            encoded_link = {}
+
+            for key, value in link.items():
+                if key in ("source", "target"):
+                    encoded_link[key] = self._encode_node(value)
+                elif key == "type":
+                    encoded_link[key] = self._encode_edge_type(value)
+                else:
+                    encoded_link[key] = self._encode(value)
+            encoded_links.append(encoded_link)
+
+        return encoded_links
+
+    def _encode_node(self, node: Union["Node", int]) -> Dict[str, Any]:
+        """Encodes a Node object.
+
+        Args:
+            node: The `Node` object to encode, or an integer index. The latter
+                requires that the class has the `idx_to_node` attribute set.
+
+        Returns:
+            The encoded `Node` object as a dictionary.
+        """
+        if isinstance(node, int):
+            # We sometimes have the node object already replaced by its index (when
+            # `node_to_idx` is provided). In this case, the node is already encoded.
+            return node
+
+        # Check if the object has been encoded before
+        first_encoding = self._is_first_encoding(node)
+        py_id = self._get_or_assign_id(node, first_encoding)
+        if first_encoding:
+            # Full encoding
+            return {
+                "py/object": "sleap.skeleton.Node",
+                "py/state": {"py/tuple": [node.name, node.weight]},
+            }
+        else:
+            # Reference by py/id
+            return {"py/id": py_id}
+
+    def _encode_edge_type(self, edge_type: "EdgeType") -> Dict[str, Any]:
+        """Encodes an EdgeType object.
+
+        Args:
+            edge_type: The EdgeType object to encode. Either `EdgeType.BODY` or
+                `EdgeType.SYMMETRY` enum with values 1 and 2 respectively.
+
+        Returns:
+            The encoded EdgeType object as a dictionary.
+        """
+        # Check if the object has been encoded before
+        first_encoding = self._is_first_encoding(edge_type)
+        py_id = self._get_or_assign_id(edge_type, first_encoding)
+        if first_encoding:
+            # Full encoding
+            return {
+                "py/reduce": [
+                    {"py/type": "sleap.skeleton.EdgeType"},
+                    {"py/tuple": [edge_type.value]},
+                ]
+            }
+        else:
+            # Reference by py/id
+            return {"py/id": py_id}
+
+    def _get_or_assign_id(self, obj: Any, first_encoding: bool) -> int:
+        """Gets or assigns a py/id for the object.
+
+        Args:
+            obj: The object to get or assign a py/id for.
+            first_encoding: Whether this is the first time the object is encoded.
+
+        Returns:
+            The py/id assigned to the object.
+        """
+        # Object id is unique for each object in the current session
+        obj_id = id(obj)
+        # Assign a py/id to the object if it hasn't been assigned one yet
+        if first_encoding:
+            py_id = len(self._encoded_objects) + 1  # py/id starts at 1
+            # Assign the py/id to the object and store it in _encoded_objects
+            self._encoded_objects[obj_id] = py_id
+        return self._encoded_objects[obj_id]
+
+    def _is_first_encoding(self, obj: Any) -> bool:
+        """Checks if the object is being encoded for the first time.
+
+        Args:
+            obj: The object to check.
+ + Returns: + True if this is the first encoding of the object, False otherwise. + """ + obj_id = id(obj) + first_time = obj_id not in self._encoded_objects + return first_time + + class Skeleton: """The main object for representing animal skeletons. @@ -937,7 +1433,7 @@ def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> D # This is a weird hack to serialize the whole _graph into a dict. # I use the underlying to_json and parse it. - return json.loads(obj.to_json(node_to_idx)) + return json.loads(obj.to_json(node_to_idx=node_to_idx)) @classmethod def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton": @@ -999,12 +1495,12 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: Returns: A string containing the JSON representation of the skeleton. """ - jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4) + if node_to_idx is not None: - indexed_node_graph = nx.relabel_nodes( - G=self._graph, mapping=node_to_idx - ) # map nodes to int + # Map Nodes to int + indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx) else: + # Keep graph nodes as Node objects indexed_node_graph = self._graph # Encode to JSON @@ -1023,7 +1519,7 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: else: data = graph - json_str = jsonpickle.encode(data) + json_str = SkeletonEncoder.encode(data) return json_str @@ -1071,7 +1567,7 @@ def from_json( Returns: An instance of the `Skeleton` object decoded from the JSON. """ - dicts = jsonpickle.decode(json_str) + dicts: dict = SkeletonDecoder.decode(json_str) nx_graph = dicts.get("nx_graph", dicts) graph = json_graph.node_link_graph(nx_graph) diff --git a/sleap/training_profiles/baseline.centroid.json b/sleap/training_profiles/baseline.centroid.json index 933989ecf..3a54db25c 100755 --- a/sleap/training_profiles/baseline.centroid.json +++ b/sleap/training_profiles/baseline.centroid.json @@ -116,6 +116,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_large_rf.bottomup.json b/sleap/training_profiles/baseline_large_rf.bottomup.json index ea45c9b25..18fb3104f 100644 --- a/sleap/training_profiles/baseline_large_rf.bottomup.json +++ b/sleap/training_profiles/baseline_large_rf.bottomup.json @@ -125,6 +125,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_large_rf.single.json b/sleap/training_profiles/baseline_large_rf.single.json index 75e97b1a6..3feeccd69 100644 --- a/sleap/training_profiles/baseline_large_rf.single.json +++ b/sleap/training_profiles/baseline_large_rf.single.json @@ -116,6 +116,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_large_rf.topdown.json b/sleap/training_profiles/baseline_large_rf.topdown.json index 9b17f6832..38e96594b 100644 --- a/sleap/training_profiles/baseline_large_rf.topdown.json +++ b/sleap/training_profiles/baseline_large_rf.topdown.json @@ -117,6 +117,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git 
a/sleap/training_profiles/baseline_medium_rf.bottomup.json b/sleap/training_profiles/baseline_medium_rf.bottomup.json index 1cc35330a..61b08515c 100644 --- a/sleap/training_profiles/baseline_medium_rf.bottomup.json +++ b/sleap/training_profiles/baseline_medium_rf.bottomup.json @@ -125,6 +125,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_medium_rf.single.json b/sleap/training_profiles/baseline_medium_rf.single.json index 579f6c8c3..0951bc761 100644 --- a/sleap/training_profiles/baseline_medium_rf.single.json +++ b/sleap/training_profiles/baseline_medium_rf.single.json @@ -116,6 +116,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_medium_rf.topdown.json b/sleap/training_profiles/baseline_medium_rf.topdown.json index 9e3a0bde5..9eccb76c1 100755 --- a/sleap/training_profiles/baseline_medium_rf.topdown.json +++ b/sleap/training_profiles/baseline_medium_rf.topdown.json @@ -117,6 +117,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.bottomup.json b/sleap/training_profiles/pretrained.bottomup.json index 3e4f3935f..57b7398b5 100644 --- a/sleap/training_profiles/pretrained.bottomup.json +++ b/sleap/training_profiles/pretrained.bottomup.json @@ -122,6 +122,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.centroid.json b/sleap/training_profiles/pretrained.centroid.json index a5df5e48a..74c43d3e2 100644 --- a/sleap/training_profiles/pretrained.centroid.json +++ b/sleap/training_profiles/pretrained.centroid.json @@ -113,6 +113,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.single.json b/sleap/training_profiles/pretrained.single.json index 7ca907007..615f0de4d 100644 --- a/sleap/training_profiles/pretrained.single.json +++ b/sleap/training_profiles/pretrained.single.json @@ -113,6 +113,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.topdown.json b/sleap/training_profiles/pretrained.topdown.json index aeeaebbd8..be0d97de8 100644 --- a/sleap/training_profiles/pretrained.topdown.json +++ b/sleap/training_profiles/pretrained.topdown.json @@ -114,6 +114,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/util.py b/sleap/util.py index 5edbf164b..bc3389b7d 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -3,13 +3,11 @@ Try not to put things in here unless they really have no other place. 
""" -import base64 import json import os import re import shutil from collections import defaultdict -from io import BytesIO from pathlib import Path from typing import Any, Dict, Hashable, Iterable, List, Optional from urllib.parse import unquote, urlparse @@ -26,7 +24,6 @@ from importlib.resources import files # New in 3.9+ except ImportError: from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. -from PIL import Image import sleap.version as sleap_version @@ -270,30 +267,20 @@ def get_config_file( The full path to the specified config file. """ - desired_path = None # Handle case where get_defaults, but cannot find package_path + desired_path = Path.home() / f".sleap/{sleap_version.__version__}/{shortname}" - if not get_defaults: - desired_path = os.path.expanduser( - f"~/.sleap/{sleap_version.__version__}/{shortname}" - ) + # Make sure there's a ~/.sleap// directory to store user version of the config file. + desired_path.parent.mkdir(parents=True, exist_ok=True) - # Make sure there's a ~/.sleap// directory to store user version of the - # config file. - try: - os.makedirs(os.path.expanduser(f"~/.sleap/{sleap_version.__version__}")) - except FileExistsError: - pass - - # If we don't care whether the file exists, just return the path - if ignore_file_not_found: - return desired_path - - # If we do care whether the file exists, check the package version of the - # config file if we can't find the user version. + # If we don't care whether the file exists, just return the path + if ignore_file_not_found: + return desired_path - if get_defaults or not os.path.exists(desired_path): + # If we do care whether the file exists, check the package version of the config file if we can't find the user version. + if get_defaults or not desired_path.exists(): package_path = get_package_file(f"config/{shortname}") - if not os.path.exists(package_path): + package_path = Path(package_path) + if not package_path.exists(): raise FileNotFoundError( f"Cannot locate {shortname} config file at {desired_path} or {package_path}." ) @@ -384,18 +371,3 @@ def find_files_by_suffix( def parse_uri_path(uri: str) -> str: """Parse a URI starting with 'file:///' to a posix path.""" return Path(url2pathname(urlparse(unquote(uri)).path)).as_posix() - - -def decode_preview_image(img_b64: bytes) -> Image: - """Decode a skeleton preview image byte string representation to a `PIL.Image` - - Args: - img_b64: a byte string representation of a skeleton preview image - - Returns: - A PIL.Image of the skeleton preview - """ - bytes = base64.b64decode(img_b64) - buffer = BytesIO(bytes) - img = Image.open(buffer) - return img diff --git a/sleap/version.py b/sleap/version.py index 437e17fba..698710132 100644 --- a/sleap/version.py +++ b/sleap/version.py @@ -11,8 +11,7 @@ Must be a semver string, "aN" should be appended for alpha releases. 
""" - -__version__ = "1.3.3" +__version__ = "1.4.1" def versions(): diff --git a/tests/data/dlc/labeled-data/video/CollectedData_LM.csv b/tests/data/dlc/labeled-data/video/CollectedData_LM.csv index f57b667f4..27c86f8af 100644 --- a/tests/data/dlc/labeled-data/video/CollectedData_LM.csv +++ b/tests/data/dlc/labeled-data/video/CollectedData_LM.csv @@ -1,8 +1,8 @@ -scorer,,,LM,LM,LM,LM,LM,LM,LM,LM,LM,LM,LM,LM -individuals,,,individual1,individual1,individual1,individual1,individual1,individual1,individual2,individual2,individual2,individual2,individual2,individual2 -bodyparts,,,A,A,B,B,C,C,A,A,B,B,C,C -coords,,,x,y,x,y,x,y,x,y,x,y,x,y -labeled-data,video,img000.png,0,1,2,3,4,5,6,7,8,9,10,11 -labeled-data,video,img001.png,12,13,,,15,16,17,18,,,20,21 +scorer,,,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer +individuals,,,Animal1,Animal1,Animal1,Animal1,Animal1,Animal1,Animal2,Animal2,Animal2,Animal2,Animal2,Animal2,single,single,single,single +bodyparts,,,A,A,B,B,C,C,A,A,B,B,C,C,D,D,E,E +coords,,,x,y,x,y,x,y,x,y,x,y,x,y,x,y,x,y +labeled-data,video,img000.png,0,1,2,3,4,5,6,7,8,9,10,11,,,, +labeled-data,video,img001.png,12,13,,,15,16,17,18,,,20,21,22,23,24,25 labeled-data,video,img002.png,,,,,,,,,,,, -labeled-data,video,img003.png,22,23,24,25,26,27,,,,,, +labeled-data,video,img003.png,26,27,28,29,30,31,,,,,,,32,33,34,35 diff --git a/tests/data/dlc/labeled-data/video/maudlc_testdata.csv b/tests/data/dlc/labeled-data/video/maudlc_testdata.csv new file mode 100644 index 000000000..4e3e3c28c --- /dev/null +++ b/tests/data/dlc/labeled-data/video/maudlc_testdata.csv @@ -0,0 +1,8 @@ +scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer +individuals,Animal1,Animal1,Animal1,Animal1,Animal1,Animal1,Animal2,Animal2,Animal2,Animal2,Animal2,Animal2,single,single,single,single +bodyparts,A,A,B,B,C,C,A,A,B,B,C,C,D,D,E,E +coords,x,y,x,y,x,y,x,y,x,y,x,y,x,y,x,y +labeled-data/video/img000.png,0,1,2,3,4,5,6,7,8,9,10,11,,,, +labeled-data/video/img001.png,12,13,,,15,16,17,18,,,20,21,22,23,24,25 +labeled-data/video/img002.png,,,,,,,,,,,, +labeled-data/video/img003.png,26,27,28,29,30,31,,,,,,,32,33,34,35 diff --git a/tests/data/dlc/labeled-data/video/maudlc_testdata_v2.csv b/tests/data/dlc/labeled-data/video/maudlc_testdata_v2.csv new file mode 100644 index 000000000..27c86f8af --- /dev/null +++ b/tests/data/dlc/labeled-data/video/maudlc_testdata_v2.csv @@ -0,0 +1,8 @@ +scorer,,,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer +individuals,,,Animal1,Animal1,Animal1,Animal1,Animal1,Animal1,Animal2,Animal2,Animal2,Animal2,Animal2,Animal2,single,single,single,single +bodyparts,,,A,A,B,B,C,C,A,A,B,B,C,C,D,D,E,E +coords,,,x,y,x,y,x,y,x,y,x,y,x,y,x,y,x,y +labeled-data,video,img000.png,0,1,2,3,4,5,6,7,8,9,10,11,,,, +labeled-data,video,img001.png,12,13,,,15,16,17,18,,,20,21,22,23,24,25 +labeled-data,video,img002.png,,,,,,,,,,,, +labeled-data,video,img003.png,26,27,28,29,30,31,,,,,,,32,33,34,35 diff --git a/tests/data/dlc/madlc_230_config.yaml b/tests/data/dlc/madlc_230_config.yaml index ae2cbb44b..01e1d32c1 100644 --- a/tests/data/dlc/madlc_230_config.yaml +++ b/tests/data/dlc/madlc_230_config.yaml @@ -1,12 +1,12 @@ # Project definitions (do not edit) -Task: madlc_2.3.0 +Task: maudlc_2.3.0 scorer: LM date: Mar1 multianimalproject: true identity: false # Project path (change when moving around) -project_path: 
D:\social-leap-estimates-animal-poses\pull-requests\sleap\tests\data\dlc\madlc_testdata_v3 +project_path: D:\social-leap-estimates-animal-poses\pull-requests\sleap\tests\data\dlc\maudlc_testdata_v3 # Annotation data set configuration (and individual video cropping parameters) video_sets: @@ -16,7 +16,9 @@ individuals: - individual1 - individual2 - individual3 -uniquebodyparts: [] +uniquebodyparts: +- D +- E multianimalbodyparts: - A - B diff --git a/tests/data/hdf5_format_v1/small_robot.000_small_robot_3_frame.analysis.h5 b/tests/data/hdf5_format_v1/small_robot.000_small_robot_3_frame.analysis.h5 new file mode 100644 index 000000000..d2cec1d1b Binary files /dev/null and b/tests/data/hdf5_format_v1/small_robot.000_small_robot_3_frame.analysis.h5 differ diff --git a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json index 7e52d1703..2ae0e925c 100644 --- a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json +++ b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json @@ -128,6 +128,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json index bcb2f26d5..7b6f817aa 100644 --- a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json +++ b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json @@ -191,6 +191,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json index 045890b21..5d8081628 100644 --- a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json +++ b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json @@ -141,7 +141,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, - "delete_viz_images": true, + "keep_viz_images": false, "zip_outputs": false, "log_to_csv": true, "checkpointing": { diff --git a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json index 070e9d3c0..9591e5b52 100644 --- a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json +++ b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json @@ -208,7 +208,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, - "delete_viz_images": true, + "keep_viz_images": false, "zip_outputs": false, "log_to_csv": true, "checkpointing": { diff --git a/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json b/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json index 8e39fea3f..68e4f894e 100644 --- a/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json +++ b/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json @@ -127,6 +127,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { 
"initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.bottomup/training_config.json b/tests/data/models/minimal_instance.UNet.bottomup/training_config.json index d1fb718ba..e3bfbc5f8 100644 --- a/tests/data/models/minimal_instance.UNet.bottomup/training_config.json +++ b/tests/data/models/minimal_instance.UNet.bottomup/training_config.json @@ -192,6 +192,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json b/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json index 739d8e3e7..f4914aae4 100644 --- a/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json +++ b/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json @@ -119,6 +119,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json b/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json index 7b6782a68..e747f6862 100644 --- a/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json +++ b/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json @@ -179,6 +179,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centroid/initial_config.json b/tests/data/models/minimal_instance.UNet.centroid/initial_config.json index 41d8ac8c3..977654b2e 100644 --- a/tests/data/models/minimal_instance.UNet.centroid/initial_config.json +++ b/tests/data/models/minimal_instance.UNet.centroid/initial_config.json @@ -118,6 +118,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centroid/training_config.json b/tests/data/models/minimal_instance.UNet.centroid/training_config.json index 2d2280a31..02e9683e1 100644 --- a/tests/data/models/minimal_instance.UNet.centroid/training_config.json +++ b/tests/data/models/minimal_instance.UNet.centroid/training_config.json @@ -175,6 +175,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json b/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json index cb2e4f353..f2bb907fa 100644 --- a/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json +++ b/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json @@ -120,6 +120,7 @@ "" ], "save_visualizations": false, + "keep_viz_images": true, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_robot.UNet.single_instance/training_config.json b/tests/data/models/minimal_robot.UNet.single_instance/training_config.json index 66901c9f0..dffecc1d9 100644 --- a/tests/data/models/minimal_robot.UNet.single_instance/training_config.json +++ b/tests/data/models/minimal_robot.UNet.single_instance/training_config.json @@ -180,6 
+180,7 @@ "" ], "save_visualizations": false, + "keep_viz_images": true, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json b/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json new file mode 100644 index 000000000..eae83d6bc --- /dev/null +++ b/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json @@ -0,0 +1 @@ +{"directed": true, "graph": {"name": "skeleton_legs.mat", "num_edges_inserted": 23}, "links": [{"edge_insert_idx": 1, "key": 0, "source": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "neck", "weight": 1.0}}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "head", "weight": 1.0}}, "type": {"py/reduce": [{"py/type": "sleap.skeleton.EdgeType"}, {"py/tuple": [1]}]}}, {"edge_insert_idx": 0, "key": 0, "source": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "thorax", "weight": 1.0}}, "target": {"py/id": 1}, "type": {"py/id": 3}}, {"edge_insert_idx": 2, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "abdomen", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 3, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "wingL", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 4, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "wingR", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 5, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 8, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 11, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 14, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 17, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 20, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 6, "key": 0, "source": {"py/id": 8}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 7, "key": 0, "source": {"py/id": 14}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 9, "key": 0, "source": {"py/id": 9}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 10, "key": 0, "source": {"py/id": 16}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 12, "key": 0, "source": {"py/id": 10}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 13, "key": 0, "source": {"py/id": 18}, "target": {"py/object": 
"sleap.skeleton.Node", "py/state": {"name": "midlegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 15, "key": 0, "source": {"py/id": 11}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 16, "key": 0, "source": {"py/id": 20}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 18, "key": 0, "source": {"py/id": 12}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 19, "key": 0, "source": {"py/id": 22}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 21, "key": 0, "source": {"py/id": 13}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 22, "key": 0, "source": {"py/id": 24}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR3", "weight": 1.0}}, "type": {"py/id": 3}}], "multigraph": true, "nodes": [{"id": {"py/id": 2}}, {"id": {"py/id": 1}}, {"id": {"py/id": 4}}, {"id": {"py/id": 5}}, {"id": {"py/id": 6}}, {"id": {"py/id": 7}}, {"id": {"py/id": 8}}, {"id": {"py/id": 14}}, {"id": {"py/id": 15}}, {"id": {"py/id": 9}}, {"id": {"py/id": 16}}, {"id": {"py/id": 17}}, {"id": {"py/id": 10}}, {"id": {"py/id": 18}}, {"id": {"py/id": 19}}, {"id": {"py/id": 11}}, {"id": {"py/id": 20}}, {"id": {"py/id": 21}}, {"id": {"py/id": 12}}, {"id": {"py/id": 22}}, {"id": {"py/id": 23}}, {"id": {"py/id": 13}}, {"id": {"py/id": 24}}, {"id": {"py/id": 25}}]} \ No newline at end of file diff --git a/tests/data/tracks/clip.predictions.slp b/tests/data/tracks/clip.predictions.slp new file mode 100644 index 000000000..652e21302 Binary files /dev/null and b/tests/data/tracks/clip.predictions.slp differ diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 801fcc092..c6507caec 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -41,6 +41,13 @@ def centered_pair_predictions(): return Labels.load_file(TEST_JSON_PREDICTIONS) +@pytest.fixture +def centered_pair_predictions_sorted(centered_pair_predictions): + labels: Labels = centered_pair_predictions + labels.labeled_frames.sort(key=lambda lf: lf.frame_idx) + return labels + + @pytest.fixture def min_labels(): return Labels.load_file(TEST_JSON_MIN_LABELS) @@ -90,6 +97,20 @@ def min_tracks_2node_labels(): ) +@pytest.fixture +def min_tracks_2node_predictions(): + """ + Generated with: + ``` + sleap-track -m "tests/data/models/min_tracks_2node.UNet.bottomup_multiclass" "tests/data/tracks/clip.mp4" + ``` + """ + return Labels.load_file( + "tests/data/tracks/clip.predictions.slp", + video_search=["tests/data/tracks/clip.mp4"], + ) + + @pytest.fixture def min_tracks_13node_labels(): return Labels.load_file( diff --git a/tests/fixtures/skeletons.py b/tests/fixtures/skeletons.py index 311510e6a..b432ca2c7 100644 --- a/tests/fixtures/skeletons.py +++ b/tests/fixtures/skeletons.py @@ -3,14 +3,27 @@ from sleap.skeleton import Skeleton TEST_FLY_LEGS_SKELETON = "tests/data/skeleton/fly_skeleton_legs.json" +TEST_FLY_LEGS_SKELETON_DICT = "tests/data/skeleton/fly_skeleton_legs_pystate_dict.json" @pytest.fixture def fly_legs_skeleton_json(): - """Path to fly_skeleton_legs.json""" + """Path to fly_skeleton_legs.json + + This skeleton json has py/state in tuple 
format. + """ return TEST_FLY_LEGS_SKELETON +@pytest.fixture +def fly_legs_skeleton_dict_json(): + """Path to fly_skeleton_legs_pystate_dict.json + + This skeleton json has py/state dict format. + """ + return TEST_FLY_LEGS_SKELETON_DICT + + @pytest.fixture def stickman(): diff --git a/tests/fixtures/videos.py b/tests/fixtures/videos.py index b160caedd..08974b3de 100644 --- a/tests/fixtures/videos.py +++ b/tests/fixtures/videos.py @@ -1,12 +1,21 @@ import pytest from sleap.io.video import Video +from sleap.io.format.filehandle import FileHandle TEST_H5_FILE = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" TEST_H5_DSET = "/box" TEST_H5_CONFMAPS = "/confmaps" TEST_H5_AFFINITY = "/pafs" TEST_H5_INPUT_FORMAT = "channels_first" +TEST_SMALL_ROBOT3_FRAME_H5 = ( + "tests/data/hdf5_format_v1/small_robot.000_small_robot_3_frame.analysis.h5" +) + + +@pytest.fixture +def small_robot_3_frame_hdf5(): + return FileHandle(filename=TEST_SMALL_ROBOT3_FRAME_H5) @pytest.fixture diff --git a/tests/gui/learning/test_dialog.py b/tests/gui/learning/test_dialog.py index 3d77c891f..389bb48a3 100644 --- a/tests/gui/learning/test_dialog.py +++ b/tests/gui/learning/test_dialog.py @@ -7,6 +7,7 @@ import pytest from qtpy import QtWidgets +import sleap from sleap.gui.learning.dialog import LearningDialog, TrainingEditorWidget from sleap.gui.learning.configs import ( TrainingConfigFilesWidget, @@ -429,3 +430,22 @@ def test_immutablilty_of_trained_config_info( # saving multiple configs from one config info. ld.save(output_dir=tmpdir) ld.save(output_dir=tmpdir) + + +def test_validate_id_model(qtbot, min_labels_slp, min_labels_slp_path): + app = MainWindow(no_usage_data=True) + ld = LearningDialog( + mode="training", + labels_filename=Path(min_labels_slp_path), + labels=min_labels_slp, + ) + assert not ld._validate_id_model() + + # Add track but don't assign it to instances + new_track = sleap.Track(name="new_track") + min_labels_slp.tracks.append(new_track) + assert not ld._validate_id_model() + + # Assign track to instances + min_labels_slp[0][0].track = new_track + assert ld._validate_id_model() diff --git a/tests/gui/test_app.py b/tests/gui/test_app.py index bacda4ae3..def835b6e 100644 --- a/tests/gui/test_app.py +++ b/tests/gui/test_app.py @@ -142,6 +142,7 @@ def assert_frame_chunk_suggestion_ui_updated( # Select and delete instance app.state["instance"] = inst_27_1 app.commands.deleteSelectedInstance() + assert app.state["instance"] is None assert len(app.state["labeled_frame"].instances) == 1 assert app.state["labeled_frame"].instances == [inst_27_0] @@ -179,6 +180,7 @@ def assert_frame_chunk_suggestion_ui_updated( # Delete all instances in track app.commands.deleteSelectedInstanceTrack() + assert app.state["instance"] is None assert len(app.state["labeled_frame"].instances) == 0 app.state["frame_idx"] = 29 @@ -412,6 +414,12 @@ def toggle_and_verify_visibility(expected_visibility: bool = True): window.showNormal() vp = window.player + # Change state and ensure menu-item check updates + color_predicted = window.state["color predicted"] + assert window._menu_actions["color predicted"].isChecked() == color_predicted + window.state["color predicted"] = not color_predicted + assert window._menu_actions["color predicted"].isChecked() == (not color_predicted) + # Enable distinct colors window.state["color predicted"] = True diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 13aa60e6b..e19e00236 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -3,11 +3,15 @@ 
import sys import time +import numpy as np from pathlib import PurePath, Path +from qtpy import QtCore from typing import List from sleap import Skeleton, Track, PredictedInstance +from sleap.gui.app import MainWindow from sleap.gui.commands import ( + AddInstance, CommandContext, ExportAnalysisFile, ExportDatasetWithImages, @@ -16,6 +20,7 @@ ReplaceVideo, OpenSkeleton, SaveProjectAs, + DeleteFrameLimitPredictions, get_new_version_filename, ) from sleap.instance import Instance, LabeledFrame @@ -65,7 +70,7 @@ def test_import_labels_from_dlc_folder(): assert len(labels.videos) == 2 assert len(labels.skeletons) == 1 assert len(labels.nodes) == 3 - assert len(labels.tracks) == 0 + assert len(labels.tracks) == 3 assert set( [fix_path_separator(l.video.backend.filename) for l in labels.labeled_frames] @@ -847,6 +852,26 @@ def load_and_assert_changes(new_video_path: Path): shutil.move(new_video_path, expected_video_path) +def test_DeleteFrameLimitPredictions( + centered_pair_predictions: Labels, centered_pair_vid: Video +): + """Test deleting instances beyond a certain frame limit.""" + labels = centered_pair_predictions + + # Set-up command context + context = CommandContext.from_labels(labels) + context.state["video"] = centered_pair_vid + + # Set-up params for the command + params = {"min_frame_idx": 900, "max_frame_idx": 1000} + + instances_to_delete = DeleteFrameLimitPredictions.get_frame_instance_list( + context, params + ) + + assert len(instances_to_delete) == 2070 + + @pytest.mark.parametrize("export_extension", [".json.zip", ".slp"]) def test_exportLabelsPackage(export_extension, centered_pair_labels: Labels, tmpdir): def assert_loaded_package_similar(path_to_pkg: Path, sugg=False, pred=False): @@ -922,3 +947,102 @@ def no_gui_ask(cls, context, params): # Case 3: Export all frames and suggested frames with image data. 
context.exportFullPackage() assert_loaded_package_similar(path_to_pkg, sugg=True, pred=True) + + +def test_newInstance(qtbot, centered_pair_predictions: Labels): + + # Get the data + labels = centered_pair_predictions + lf = labels[0] + pred_inst = lf.instances[0] + video = labels.video + + # Set-up command context + main_window = MainWindow(labels=labels) + context = main_window.commands + context.state["labeled_frame"] = lf + context.state["frame_idx"] = lf.frame_idx + context.state["skeleton"] = labels.skeleton + context.state["video"] = labels.videos[0] + + # Case 1: Double clicking a prediction results in no offset for new instance + + # Double click on prediction + assert len(lf.instances) == 2 + main_window._handle_instance_double_click(instance=pred_inst) + + # Check new instance + assert len(lf.instances) == 3 + new_inst = lf.instances[-1] + assert new_inst.from_predicted is pred_inst + assert np.array_equal(new_inst.numpy(), pred_inst.numpy()) # No offset + + # Case 2: Using Ctrl + I (or menu "Add Instance" button) + + # Connect the action to a slot + add_instance_menu_action = main_window._menu_actions["add instance"] + triggered = False + + def on_triggered(): + nonlocal triggered + triggered = True + + add_instance_menu_action.triggered.connect(on_triggered) + + # Find which instance we are going to copy from + ( + copy_instance, + from_predicted, + from_prev_frame, + ) = AddInstance.find_instance_to_copy_from( + context, copy_instance=None, init_method="best" + ) + + # Click on the menu action + assert len(lf.instances) == 3 + add_instance_menu_action.trigger() + assert triggered, "Action not triggered" + + # Check new instance was added with the default offset + assert len(lf.instances) == 4 + new_inst = lf.instances[-1] + offset = 10 + assert np.all( + np.nan_to_num(new_inst.numpy() - copy_instance.numpy(), nan=offset) == offset + ) + + # Case 3: Using right click and "Default" option + + # Find which instance we are going to copy from + ( + copy_instance, + from_predicted, + from_prev_frame, + ) = AddInstance.find_instance_to_copy_from( + context, copy_instance=None, init_method="best" + ) + + video_player = main_window.player + right_click_location_x = video.shape[2] / 2 + right_click_location_y = video.shape[1] / 2 + right_click_location = QtCore.QPointF( + right_click_location_x, right_click_location_y + ) + video_player.create_contextual_menu(scene_pos=right_click_location) + default_action = video_player._menu_actions["Default"] + default_action.trigger() + + # Check new instance + assert len(lf.instances) == 5 + new_inst = lf.instances[-1] + reference_node_idx = np.where( + np.all( + new_inst.numpy() == [right_click_location_x, right_click_location_y], axis=1 + ) + )[0][0] + offset = ( + new_inst.numpy()[reference_node_idx] - copy_instance.numpy()[reference_node_idx] + ) + diff = np.nan_to_num(new_inst.numpy() - copy_instance.numpy(), nan=offset) + assert np.all(diff == offset) diff --git a/tests/gui/test_dialogs.py b/tests/gui/test_dialogs.py index 4455550fb..611a73c85 100644 --- a/tests/gui/test_dialogs.py +++ b/tests/gui/test_dialogs.py @@ -1,6 +1,5 @@ """Module to test the dialogs of the GUI (contained in sleap/gui/dialogs).""" - import os from pathlib import Path diff --git a/tests/gui/test_monitor.py b/tests/gui/test_monitor.py index 51af0ca92..e0abea692 100644 --- a/tests/gui/test_monitor.py +++ b/tests/gui/test_monitor.py @@ -1,4 +1,3 @@ -from turtle import title from sleap.gui.widgets.monitor import LossViewer from sleap
import TrainingJobConfig @@ -12,6 +11,9 @@ def test_monitor_release(qtbot, min_centroid_model_path): win.reset(what="Model Type", config=config) assert win.config.optimization.early_stopping.plateau_patience == 10 + # Ensure zmq port is set correctly + assert win.zmq_ports["controller_port"] == 9000 + assert win.zmq_ports["publish_port"] == 9001 # Ensure all lines of update_runtime() are run error-free win.is_running = True win.t0 = 0 @@ -28,13 +30,17 @@ def test_monitor_release(qtbot, min_centroid_model_path): # Enter "best_val_x" conditional win.best_val_x = 0 win.best_val_y = win.last_epoch_val_loss - win.update_runtime() + win._update_runtime() win.close() # Make sure the first monitor released its zmq socket - win2 = LossViewer() + controller_port = 9191 + zmq_ports = dict(controller_port=controller_port) + win2 = LossViewer(zmq_ports=zmq_ports) win2.show() + assert win2.zmq_ports["controller_port"] == controller_port + assert win2.zmq_ports["publish_port"] == 9001 # Make sure the batches to show field is working correctly @@ -47,3 +53,14 @@ assert win2.batches_to_show == 200 win2.close() + + # Ensure zmq port is set correctly + controller_port = 9191 + publish_port = 9101 + zmq_ports = dict(controller_port=controller_port, publish_port=publish_port) + win3 = LossViewer(zmq_ports=zmq_ports) + win3.show() + assert win3.zmq_ports["controller_port"] == controller_port + assert win3.zmq_ports["publish_port"] == publish_port + + win3.close() diff --git a/tests/gui/test_suggestions.py b/tests/gui/test_suggestions.py index bbad73179..196ff2d35 100644 --- a/tests/gui/test_suggestions.py +++ b/tests/gui/test_suggestions.py @@ -24,6 +24,20 @@ def test_velocity_suggestions(centered_pair_predictions): assert suggestions[1].frame_idx == 45 +def test_max_point_displacement_suggestions(centered_pair_predictions): + suggestions = VideoFrameSuggestions.suggest( + labels=centered_pair_predictions, + params=dict( + videos=centered_pair_predictions.videos, + method="max_point_displacement", + displacement_threshold=6, + ), + ) + assert len(suggestions) == 19 + assert suggestions[0].frame_idx == 28 + assert suggestions[1].frame_idx == 82 + + def test_frame_increment(centered_pair_predictions: Labels): # Testing videos that have fewer frames than desired Samples per Video (stride) # Expected result is there should be n suggestions where n is equal to the frames diff --git a/tests/gui/test_video_player.py b/tests/gui/test_video_player.py index b0661a4e1..c246f0489 100644 --- a/tests/gui/test_video_player.py +++ b/tests/gui/test_video_player.py @@ -3,14 +3,13 @@ from sleap.gui.widgets.video import ( QtVideoPlayer, GraphicsView, - QtInstance, QtVideoPlayer, QtTextWithBackground, VisibleBoundingBox, ) from qtpy import QtCore, QtWidgets -from qtpy.QtGui import QColor +from qtpy.QtGui import QColor, QWheelEvent def test_gui_video(qtbot): @@ -20,10 +19,6 @@ def test_gui_video(qtbot): assert vp.close() - # Click the button 20 times - # for i in range(20): - # qtbot.mouseClick(vp.btn, QtCore.Qt.LeftButton) - def test_gui_video_instances(qtbot, small_robot_mp4_vid, centered_pair_labels): vp = QtVideoPlayer(small_robot_mp4_vid) @@ -144,3 +139,40 @@ def test_VisibleBoundingBox(qtbot, centered_pair_labels): # Check if bounding box scaled appropriately assert inst.box.rect().width() - initial_width == 2 * dx assert inst.box.rect().height() - initial_height == 2 * dy + + +def test_wheelEvent(qtbot): + """Test the wheelEvent method of the GraphicsView class."""
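+ # A synthetic QWheelEvent is built below with the explicit multi-argument + # constructor (Qt 5.12+); each argument is documented inline. Note that Qt + # reports angleDelta in eighths of a degree, so 120 corresponds to one 15-degree + # wheel notch. If GraphicsView.wheelEvent mishandles the event, the test process + # is expected to crash with a segfault (see the print statement below).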
graphics_view = GraphicsView() + + # Create a QWheelEvent + position = QtCore.QPointF(100, 100) # The position of the wheel event + global_position = QtCore.QPointF(100, 100) # The global position of the wheel event + pixel_delta = QtCore.QPoint(0, 120) # The distance in pixels the wheel is rotated + angle_delta = QtCore.QPoint(0, 120) # The distance in degrees the wheel is rotated + buttons = QtCore.Qt.NoButton # No mouse button is pressed + modifiers = QtCore.Qt.NoModifier # No keyboard modifier is pressed + phase = QtCore.Qt.ScrollUpdate # The phase of the scroll event + inverted = False # The scroll direction is not inverted + source = ( + QtCore.Qt.MouseEventNotSynthesized + ) # The event is not synthesized from a touch or tablet event + + event = QWheelEvent( + position, + global_position, + pixel_delta, + angle_delta, + buttons, + modifiers, + phase, + inverted, + source, + ) + + # Call the wheelEvent method + print( + "Testing GraphicsView.wheelEvent which will result in exit code 127 " + "originating from a segmentation fault if it fails." + ) + graphics_view.wheelEvent(event) diff --git a/tests/gui/widgets/test_docks.py b/tests/gui/widgets/test_docks.py index 69fe56a56..d5c16a763 100644 --- a/tests/gui/widgets/test_docks.py +++ b/tests/gui/widgets/test_docks.py @@ -1,15 +1,17 @@ """Module for testing dock widgets for the `MainWindow`.""" from pathlib import Path -import pytest + +import numpy as np + from sleap import Labels, Video from sleap.gui.app import MainWindow -from sleap.gui.commands import OpenSkeleton +from sleap.gui.commands import AddInstance, OpenSkeleton from sleap.gui.widgets.docks import ( InstancesDock, + SkeletonDock, SuggestionsDock, VideosDock, - SkeletonDock, ) @@ -99,11 +101,35 @@ def test_suggestions_dock(qtbot): assert dock.wgt_layout is dock.widget().layout() -def test_instances_dock(qtbot): +def test_instances_dock(qtbot, centered_pair_predictions: Labels): """Test the `DockWidget` class.""" - main_window = MainWindow() + main_window = MainWindow(labels=centered_pair_predictions) + labels = main_window.labels + context = main_window.commands + lf = context.state["labeled_frame"] dock = InstancesDock(main_window) assert dock.name == "Instances" assert dock.main_window is main_window assert dock.wgt_layout is dock.widget().layout() + + # Test new instance button + + offset = 10 + + # Find instance that we will copy from + ( + copy_instance, + from_predicted, + from_prev_frame, + ) = AddInstance.find_instance_to_copy_from( + context, copy_instance=None, init_method="best" + ) + n_instance = len(lf.instances) + dock.main_window._buttons["new instance"].click() + + # Check that new instance was added with offset + assert len(lf.instances) == n_instance + 1 + new_inst = lf.instances[-1] + diff = np.nan_to_num(new_inst.numpy() - copy_instance.numpy(), nan=offset) + assert np.all(diff == offset) diff --git a/tests/info/test_metrics.py b/tests/info/test_metrics.py new file mode 100644 index 000000000..0d2e097e6 --- /dev/null +++ b/tests/info/test_metrics.py @@ -0,0 +1,55 @@ +import numpy as np + +from sleap import Labels +from sleap.info.metrics import ( + match_instance_lists_nodewise, + matched_instance_distances, +) + + +def test_matched_instance_distances(centered_pair_labels, centered_pair_predictions): + labels_gt = centered_pair_labels + labels_pr = centered_pair_predictions + + # Match each ground truth instance node to the closest corresponding node + # from any predicted instance in the same frame. 
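+ # Concretely: for a given ground truth instance, the node-wise matcher may take + # the "head" match from one predicted instance and the "thorax" match from + # another, so D holds per-node distances rather than distances under a single + # whole-instance assignment.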
+ + inst_matching_func = match_instance_lists_nodewise + + # Calculate distances + frame_idxs, D, points_gt, points_pr = matched_instance_distances( + labels_gt, labels_pr, inst_matching_func + ) + + # Show mean difference for each node + node_names = labels_gt.skeletons[0].node_names + expected_values = { + "head": 0.872426920709296, + "neck": 0.8016280746914615, + "thorax": 0.8602021363390538, + "abdomen": 1.01012200038258, + "wingL": 1.1297727023475939, + "wingR": 1.0869857897008424, + "forelegL1": 0.780584225081443, + "forelegL2": 1.170805798894702, + "forelegL3": 1.1020486509389473, + "forelegR1": 0.9014698776116817, + "forelegR2": 0.9448001033112047, + "forelegR3": 1.308385214215777, + "midlegL1": 0.9095691623265347, + "midlegL2": 1.2203595627907582, + "midlegL3": 0.9813843358470163, + "midlegR1": 0.9871017182813739, + "midlegR2": 1.0209829335569256, + "midlegR3": 1.0990681234096988, + "hindlegL1": 1.0005335192834348, + "hindlegL2": 1.273539518539708, + "hindlegL3": 1.1752245985832817, + "hindlegR1": 1.1402833959265248, + "hindlegR2": 1.3143221301212737, + "hindlegR3": 1.0441458592503365, + } + + for node_idx, node_name in enumerate(node_names): + mean_d = np.nanmean(D[..., node_idx]) + assert np.isclose(mean_d, expected_values[node_name], atol=1e-6) diff --git a/tests/info/test_summary.py b/tests/info/test_summary.py index 2cf76c166..672d97e63 100644 --- a/tests/info/test_summary.py +++ b/tests/info/test_summary.py @@ -37,6 +37,19 @@ def test_frame_statistics(simple_predictions): x = stats.get_point_displacement_series(video, "max") assert len(x) == 2 - assert len(x) == 2 assert x[0] == 0 assert x[1] == 18.0 + + +def test_get_tracking_score_series(min_tracks_2node_predictions): + + stats = StatisticSeries(min_tracks_2node_predictions) + x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "min") + assert len(x) == 1500 + assert x[0] == 0.9999966621398926 + assert x[1000] == 0.9998022317886353 + + x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "mean") + assert len(x) == 1500 + assert x[0] == 0.9999983310699463 + assert x[1000] == 0.9999011158943176 diff --git a/tests/io/test_asyncvideo.py b/tests/io/test_asyncvideo.py deleted file mode 100644 index 1bc3f19c8..000000000 --- a/tests/io/test_asyncvideo.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest -import sys -from sleap import Video -from sleap.io.asyncvideo import AsyncVideo - - -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="ZMQ testing breaks locally on Windows" -) -def test_async_video(centered_pair_vid, small_robot_mp4_vid): - async_video = AsyncVideo.from_video(centered_pair_vid, frames_per_chunk=23) - - all_idxs = [] - for idxs, frames in async_video.chunks: - assert len(idxs) in (23, 19) # 19 for last chunk - all_idxs.extend(idxs) - - assert frames.shape[0] == len(idxs) - assert frames.shape[1:] == centered_pair_vid.shape[1:] - - assert len(all_idxs) == centered_pair_vid.num_frames - - # make sure we can load another video (i.e., previous video closed) - - async_video = AsyncVideo.from_video( - small_robot_mp4_vid, frame_idxs=range(0, 10, 2), frames_per_chunk=10 - ) - - for idxs, frames in async_video.chunks: - # there should only be single chunk - assert idxs == list(range(0, 10, 2)) diff --git a/tests/io/test_convert.py b/tests/io/test_convert.py index da1971c11..738c3d625 100644 --- a/tests/io/test_convert.py +++ b/tests/io/test_convert.py @@ -8,7 +8,7 @@ import pytest -@pytest.mark.parametrize("format", ["analysis", "analysis.nix"]) 
+@pytest.mark.parametrize("format", ["analysis", "analysis.nix", "analysis.csv"]) def test_analysis_format( min_labels_slp: Labels, min_labels_slp_path: Labels, @@ -27,7 +27,7 @@ def generate_filenames(paths, format="analysis"): labels_path = str(slp_path) fn = re.sub("(\\.json(\\.zip)?|\\.h5|\\.slp)$", "", labels_path) fn = PurePath(fn) - out_suffix = "nix" if "nix" in format else "h5" + out_suffix = "nix" if "nix" in format else "csv" if "csv" in format else "h5" default_names = [ default_analysis_filename( labels=labels, diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 5592ae437..d71d4cc83 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -1,9 +1,11 @@ import os +import pandas as pd import pytest import numpy as np from pathlib import Path, PurePath import sleap +from sleap.info.write_tracking_h5 import get_nodes_as_np_strings from sleap.skeleton import Skeleton from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track from sleap.io.video import Video, MediaVideo @@ -1234,7 +1236,7 @@ def test_has_frame(): @pytest.fixture def removal_test_labels(): skeleton = Skeleton() - video = Video(backend=MediaVideo(filename="test")) + video = Video(backend=MediaVideo(filename="test.mp4")) lf_user_only = LabeledFrame( video=video, frame_idx=0, instances=[Instance(skeleton=skeleton)] ) @@ -1559,3 +1561,45 @@ def test_export_nwb(centered_pair_predictions: Labels, tmpdir): # Read from NWB file read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename)) assert_read_labels_match(centered_pair_predictions, read_labels) + + +@pytest.mark.parametrize( + "labels_fixture_name", + [ + "centered_pair_labels", + "centered_pair_predictions", + "min_labels", + "min_labels_slp", + "min_labels_robot", + ], +) +def test_export_csv(labels_fixture_name, tmpdir, request): + # Retrieve Labels fixture by name + labels_fixture = request.getfixturevalue(labels_fixture_name) + + # Generate the filename for the CSV file + csv_filename = Path(tmpdir) / (labels_fixture_name + "_export.csv") + + # Export to CSV file + labels_fixture.export_csv(str(csv_filename)) + + # Assert that the CSV file was created + assert csv_filename.is_file(), f"CSV file '{csv_filename}' was not created" + + +def test_exported_csv(tmpdir, min_labels_slp, minimal_instance_predictions_csv_path): + # Construct the filename for the CSV file + filename_csv = Path(tmpdir) / "minimal_instance_predictions_export.csv" + labels = min_labels_slp + # Export to CSV file + labels.export_csv(filename_csv) + # Read the CSV file + labels_csv = pd.read_csv(filename_csv) + + # Read the csv file fixture + csv_predictions = pd.read_csv(minimal_instance_predictions_csv_path) + + assert labels_csv.equals(csv_predictions) + + # check number of cols + assert len(labels_csv.columns) - 3 == len(get_nodes_as_np_strings(labels)) * 3 diff --git a/tests/io/test_formats.py b/tests/io/test_formats.py index a89bf60d7..cee754b7c 100644 --- a/tests/io/test_formats.py +++ b/tests/io/test_formats.py @@ -19,6 +19,19 @@ from sleap.gui.app import MainWindow from sleap.gui.state import GuiState from sleap.info.write_tracking_h5 import get_nodes_as_np_strings +from sleap.io.format.sleap_analysis import SleapAnalysisAdaptor + + +def test_sleap_analysis_read(small_robot_3_frame_vid, small_robot_3_frame_hdf5): + + # Single instance hdf5 analysis file test + read_labels = SleapAnalysisAdaptor.read( + file=small_robot_3_frame_hdf5, video=small_robot_3_frame_vid + ) + + assert len(read_labels.videos) == 1 + 
assert len(read_labels.tracks) == 1 + assert len(read_labels.skeletons) == 1 def test_text_adaptor(tmpdir): @@ -198,7 +211,6 @@ def test_matching_adaptor(centered_pair_predictions_hdf5_path): [ "tests/data/dlc/labeled-data/video/madlc_testdata.csv", "tests/data/dlc/labeled-data/video/madlc_testdata_v2.csv", - "tests/data/dlc/madlc_230_config.yaml", ], ) def test_madlc(test_data): @@ -232,6 +244,78 @@ def test_madlc(test_data): assert labels[2].frame_idx == 3 +@pytest.mark.parametrize( + "test_data", + [ + "tests/data/dlc/labeled-data/video/maudlc_testdata.csv", + "tests/data/dlc/labeled-data/video/maudlc_testdata_v2.csv", + "tests/data/dlc/madlc_230_config.yaml", + ], +) +def test_maudlc(test_data): + labels = read( + test_data, + for_object="labels", + as_format="deeplabcut", + ) + + assert labels.skeleton.node_names == ["A", "B", "C", "D", "E"] + assert len(labels.videos) == 1 + assert len(labels.video.filenames) == 4 + assert labels.videos[0].filenames[0].endswith("img000.png") + assert labels.videos[0].filenames[1].endswith("img001.png") + assert labels.videos[0].filenames[2].endswith("img002.png") + assert labels.videos[0].filenames[3].endswith("img003.png") + + # Assert frames without any coor are not labeled + assert len(labels) == 3 + + # Assert number of instances per frame is correct + assert len(labels[0]) == 2 + assert len(labels[1]) == 3 + assert len(labels[2]) == 2 + + assert_array_equal( + labels[0][0].numpy(), + [[0, 1], [2, 3], [4, 5], [np.nan, np.nan], [np.nan, np.nan]], + ) + assert_array_equal( + labels[0][1].numpy(), + [[6, 7], [8, 9], [10, 11], [np.nan, np.nan], [np.nan, np.nan]], + ) + assert_array_equal( + labels[1][0].numpy(), + [[12, 13], [np.nan, np.nan], [15, 16], [np.nan, np.nan], [np.nan, np.nan]], + ) + assert_array_equal( + labels[1][1].numpy(), + [[17, 18], [np.nan, np.nan], [20, 21], [np.nan, np.nan], [np.nan, np.nan]], + ) + assert_array_equal( + labels[1][2].numpy(), + [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan], [22, 23], [24, 25]], + ) + assert_array_equal( + labels[2][0].numpy(), + [[26, 27], [28, 29], [30, 31], [np.nan, np.nan], [np.nan, np.nan]], + ) + assert_array_equal( + labels[2][1].numpy(), + [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan], [32, 33], [34, 35]], + ) + assert labels[2].frame_idx == 3 + + # Assert tracks are correct + assert len(labels.tracks) == 3 + sorted_animals = sorted(["Animal1", "Animal2", "single"]) + assert sorted([t.name for t in labels.tracks]) == sorted_animals + for t in labels.tracks: + if t.name == "single": + assert t.spawned_on == 1 + else: + assert t.spawned_on == 0 + + @pytest.mark.parametrize( "test_data", [ diff --git a/tests/io/test_videowriter.py b/tests/io/test_videowriter.py index dea193117..35d9bc6df 100644 --- a/tests/io/test_videowriter.py +++ b/tests/io/test_videowriter.py @@ -1,5 +1,7 @@ import os -from sleap.io.videowriter import VideoWriter, VideoWriterOpenCV +import cv2 +from pathlib import Path +from sleap.io.videowriter import VideoWriter, VideoWriterOpenCV, VideoWriterImageio def test_video_writer(tmpdir, small_robot_mp4_vid): @@ -38,3 +40,62 @@ def test_cv_video_writer(tmpdir, small_robot_mp4_vid): writer.close() assert os.path.exists(out_path) + + +def test_imageio_video_writer_avi(tmpdir, small_robot_mp4_vid): + out_path = Path(tmpdir) / "clip.avi" + + # Make sure imageio video writer works + writer = VideoWriterImageio( + out_path, + height=small_robot_mp4_vid.height, + width=small_robot_mp4_vid.width, + fps=small_robot_mp4_vid.fps, + ) + + 
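# small_robot_mp4_vid[i] yields a batch containing one frame, + # so the trailing [0] unwraps the single frame array before it is written + # (this assumes the batched Video indexing convention used elsewhere in these tests). +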
writer.add_frame(small_robot_mp4_vid[0][0]) + writer.add_frame(small_robot_mp4_vid[1][0]) + + writer.close() + + assert os.path.exists(out_path) + # Check attributes + assert writer.height == small_robot_mp4_vid.height + assert writer.width == small_robot_mp4_vid.width + assert writer.fps == small_robot_mp4_vid.fps + assert writer.filename == out_path + assert writer.crf == 21 + assert writer.preset == "superfast" + + +def test_imageio_video_writer_odd_size(tmpdir, movenet_video): + out_path = Path(tmpdir) / "clip.mp4" + + # Reduce the size of the video frames by 1 pixel in each dimension + reduced_height = movenet_video.height - 1 + reduced_width = movenet_video.width - 1 + + # Initialize the writer with the reduced dimensions + writer = VideoWriterImageio( + out_path, + height=reduced_height, + width=reduced_width, + fps=movenet_video.fps, + ) + + # Resize frames and add them to the video + for i in range(len(movenet_video) - 1): + frame = movenet_video[i][0] # Access the actual frame object + reduced_frame = cv2.resize(frame, (reduced_width, reduced_height)) + writer.add_frame(reduced_frame) + + writer.close() + + # Assertions to validate the test + assert os.path.exists(out_path) + assert writer.height == reduced_height + assert writer.width == reduced_width + assert writer.fps == movenet_video.fps + assert writer.filename == out_path + assert writer.crf == 21 + assert writer.preset == "superfast" diff --git a/tests/io/test_visuals.py b/tests/io/test_visuals.py index d6144e2c1..a1223bfdf 100644 --- a/tests/io/test_visuals.py +++ b/tests/io/test_visuals.py @@ -1,6 +1,7 @@ import numpy as np import os import pytest +import cv2 from sleap.io.dataset import Labels from sleap.io.visuals import ( save_labeled_video, @@ -63,6 +64,46 @@ def test_serial_pipeline(centered_pair_predictions, tmpdir): ) +@pytest.mark.parametrize("background", ["original", "black", "white", "grey"]) +def test_sleap_render_with_different_backgrounds(background): + args = ( + f"-o test_{background}.avi -f 2 --scale 1.2 --frames 1,2 --video-index 0 " + f"--background {background} " + "tests/data/json_format_v2/centered_pair_predictions.json".split() + ) + sleap_render(args) + assert ( + os.path.exists(f"test_{background}.avi") + and os.path.getsize(f"test_{background}.avi") > 0 + ) + + # Check if the background is set correctly if not original background + if background != "original": + saved_video_path = f"test_{background}.avi" + cap = cv2.VideoCapture(saved_video_path) + ret, frame = cap.read() + + # Calculate mean color of the channels + b, g, r = cv2.split(frame) + mean_b = np.mean(b) + mean_g = np.mean(g) + mean_r = np.mean(r) + + # Set threshold values. Color is white if greater than white threshold, black + # if less than grey threshold and grey if in between both threshold values. 
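+ # For example, mean channel values of roughly (250, 248, 251) classify as + # "white", roughly (20, 22, 19) as "black", and mid-range means fall through + # to "grey".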
+ white_threshold = 240 + grey_threshold = 40 + + # Check if the average color is white, grey, or black + if all(val > white_threshold for val in [mean_b, mean_g, mean_r]): + background_color = "white" + elif all(val < grey_threshold for val in [mean_b, mean_g, mean_r]): + background_color = "black" + else: + background_color = "grey" + assert background_color == background + + def test_sleap_render(centered_pair_predictions): args = ( "-o testvis.avi -f 2 --scale 1.2 --frames 1,2 --video-index 0 " diff --git a/tests/nn/architectures/test_common.py b/tests/nn/architectures/test_common.py index a40d621ef..96db870ea 100644 --- a/tests/nn/architectures/test_common.py +++ b/tests/nn/architectures/test_common.py @@ -2,7 +2,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.architectures import common diff --git a/tests/nn/architectures/test_encoder_decoder.py b/tests/nn/architectures/test_encoder_decoder.py index 3ce019371..8b8f51f0a 100644 --- a/tests/nn/architectures/test_encoder_decoder.py +++ b/tests/nn/architectures/test_encoder_decoder.py @@ -2,7 +2,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.architectures import encoder_decoder diff --git a/tests/nn/architectures/test_hourglass.py b/tests/nn/architectures/test_hourglass.py index 4efe79a1c..c45ff1b91 100644 --- a/tests/nn/architectures/test_hourglass.py +++ b/tests/nn/architectures/test_hourglass.py @@ -2,7 +2,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.architectures import hourglass from sleap.nn.config import HourglassConfig diff --git a/tests/nn/architectures/test_leap.py b/tests/nn/architectures/test_leap.py index edf07396b..9a73c80d5 100644 --- a/tests/nn/architectures/test_leap.py +++ b/tests/nn/architectures/test_leap.py @@ -2,7 +2,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.architectures import leap from sleap.nn.config import LEAPConfig diff --git a/tests/nn/architectures/test_pretrained_encoders.py b/tests/nn/architectures/test_pretrained_encoders.py index f318754ac..b1f7e0af8 100644 --- a/tests/nn/architectures/test_pretrained_encoders.py +++ b/tests/nn/architectures/test_pretrained_encoders.py @@ -3,7 +3,7 @@ import pytest from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.architectures import UnetPretrainedEncoder from sleap.nn.config import PretrainedEncoderConfig diff --git a/tests/nn/architectures/test_resnet.py b/tests/nn/architectures/test_resnet.py index 965ea3b72..b0d9d26eb 100644 --- a/tests/nn/architectures/test_resnet.py +++ b/tests/nn/architectures/test_resnet.py @@ -2,7 +2,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.architectures import upsampling from sleap.nn.architectures import resnet diff --git a/tests/nn/architectures/test_unet.py b/tests/nn/architectures/test_unet.py index 98b6d7768..1dad7ea05 100644 --- a/tests/nn/architectures/test_unet.py +++ b/tests/nn/architectures/test_unet.py @@ -2,7 +2,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only 
-use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.architectures import unet from sleap.nn.config import UNetConfig diff --git a/tests/nn/config/test_config_utils.py b/tests/nn/config/test_config_utils.py index 69e8ddec8..64d83a141 100644 --- a/tests/nn/config/test_config_utils.py +++ b/tests/nn/config/test_config_utils.py @@ -4,7 +4,7 @@ from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.config import utils diff --git a/tests/nn/data/test_augmentation.py b/tests/nn/data/test_augmentation.py index d2b468522..2b95a01a3 100644 --- a/tests/nn/data/test_augmentation.py +++ b/tests/nn/data/test_augmentation.py @@ -1,3 +1,4 @@ +import pytest import numpy as np import tensorflow as tf import sleap @@ -9,14 +10,95 @@ from sleap.nn.data import augmentation +@pytest.fixture +def dummy_instances_data_nans(): + return np.full((2, 2), np.nan, dtype=np.float32) + + +@pytest.fixture +def dummy_instances_data_mixed(): + return np.array([[0.1, np.nan], [0.0, 0.8]], dtype=np.float32) + + +@pytest.fixture +def dummy_image_data(): + return np.zeros((100, 100, 3), dtype=np.uint8) + + +@pytest.fixture +def dummy_instances_data_zeros(): + return np.zeros((2, 2), dtype=np.float32) + + +@pytest.fixture +def rotation_min_angle(): + return 90 + + +@pytest.fixture +def rotation_max_angle(): + return 90 + + +@pytest.fixture +def augmentation_config(rotation_min_angle, rotation_max_angle): + return augmentation.AugmentationConfig( + rotate=True, + rotation_min_angle=rotation_min_angle, + rotation_max_angle=rotation_max_angle, + ) + + +@pytest.fixture +def dummy_dataset(dummy_image_data, dummy_instances_data_zeros): + dataset = tf.data.Dataset.from_tensor_slices( + {"image": [dummy_image_data], "instances": [dummy_instances_data_zeros]} + ) + return dataset + + +@pytest.fixture +def augmenter(augmentation_config): + return augmentation.AlbumentationsAugmenter.from_config(augmentation_config) + + +# Test class instantiation and augmentation +@pytest.mark.parametrize( + "dummy_instances_data", + [ + pytest.param("dummy_instances_data_zeros", id="zeros"), + pytest.param("dummy_instances_data_nans", id="nans"), + pytest.param("dummy_instances_data_mixed", id="mixed"), + ], +) +def test_albumentations_augmenter( + dummy_image_data, dummy_instances_data, augmenter, dummy_dataset +): + # Apply augmentation + augmented_dataset = augmenter.transform_dataset(dummy_dataset) + + # Check if augmentation is applied + augmented_example = next(iter(augmented_dataset)) + assert augmented_example["image"].shape == (100, 100, 3) + assert augmented_example["instances"].shape == (2, 2) + + +# Test class method from_config +def test_albumentations_augmenter_from_config(augmentation_config): + augmenter = augmentation.AlbumentationsAugmenter.from_config(augmentation_config) + assert isinstance(augmenter, augmentation.AlbumentationsAugmenter) + assert augmenter.image_key == "image" + assert augmenter.instances_key == "instances" + + def test_augmentation(min_labels): labels_reader = providers.LabelsReader.from_user_instances(min_labels) ds = labels_reader.make_dataset() example_preaug = next(iter(ds)) - augmenter = augmentation.ImgaugAugmenter.from_config( + augmenter = augmentation.AlbumentationsAugmenter.from_config( augmentation.AugmentationConfig( - rotate=True, rotation_min_angle=-90, rotation_max_angle=-90 + rotate=True, rotation_min_angle=90, rotation_max_angle=90 ) ) ds = augmenter.transform_dataset(ds) @@ -52,13 
+134,39 @@ def test_augmentation_with_no_instances(min_labels): ) p = min_labels.to_pipeline(user_labeled_only=False) - p += augmentation.ImgaugAugmenter.from_config( + p += augmentation.AlbumentationsAugmenter.from_config( augmentation.AugmentationConfig(rotate=True) ) exs = p.run() assert exs[-1]["instances"].shape[0] == 0 +def test_augmentation_edges(min_labels): + # Tests 1722 + height, width = min_labels[0].video.shape[1:3] + min_labels[0].instances.append( + sleap.Instance.from_numpy( + [[0, 0], [width, height]], + skeleton=min_labels.skeleton, + ) + ) + + labels_reader = providers.LabelsReader.from_user_instances(min_labels) + ds = labels_reader.make_dataset() + example_preaug = next(iter(ds)) + + augmenter = augmentation.AlbumentationsAugmenter.from_config( + augmentation.AugmentationConfig( + rotate=True, rotation_min_angle=90, rotation_max_angle=90 + ) + ) + ds = augmenter.transform_dataset(ds) + + example = next(iter(ds)) + # TODO: check for correctness + assert example["instances"].shape == (3, 2, 2) + + def test_random_cropper(min_labels): cropper = augmentation.RandomCropper(crop_height=64, crop_width=32) assert "image" in cropper.input_keys diff --git a/tests/nn/data/test_data_training.py b/tests/nn/data/test_data_training.py index eb79464e0..c90a29365 100644 --- a/tests/nn/data/test_data_training.py +++ b/tests/nn/data/test_data_training.py @@ -3,7 +3,7 @@ from sleap.nn.data.training import split_labels_train_val -sleap.use_cpu_only() # hide GPUs for test +# sleap.use_cpu_only() # hide GPUs for test def test_split_labels_train_val(): diff --git a/tests/nn/data/test_edge_maps.py b/tests/nn/data/test_edge_maps.py index 295360538..5eb13f9b8 100644 --- a/tests/nn/data/test_edge_maps.py +++ b/tests/nn/data/test_edge_maps.py @@ -2,7 +2,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.data import providers from sleap.nn.data import edge_maps diff --git a/tests/nn/data/test_identity.py b/tests/nn/data/test_identity.py index 52d25dd1b..224eff0ba 100644 --- a/tests/nn/data/test_identity.py +++ b/tests/nn/data/test_identity.py @@ -10,7 +10,7 @@ ) -sleap.use_cpu_only() +# sleap.use_cpu_only() def test_make_class_vectors(): diff --git a/tests/nn/data/test_instance_centroids.py b/tests/nn/data/test_instance_centroids.py index 78dee251c..2d8f57627 100644 --- a/tests/nn/data/test_instance_centroids.py +++ b/tests/nn/data/test_instance_centroids.py @@ -3,7 +3,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test import sleap from sleap.nn.data import providers diff --git a/tests/nn/data/test_instance_cropping.py b/tests/nn/data/test_instance_cropping.py index b54fb0e99..688f50dbd 100644 --- a/tests/nn/data/test_instance_cropping.py +++ b/tests/nn/data/test_instance_cropping.py @@ -3,7 +3,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.data import providers from sleap.nn.data import instance_centroids diff --git a/tests/nn/data/test_normalization.py b/tests/nn/data/test_normalization.py index 20a1df4ec..d2eb7c290 100644 --- a/tests/nn/data/test_normalization.py +++ b/tests/nn/data/test_normalization.py @@ -3,7 +3,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.data import 
normalization from sleap.nn.data import providers diff --git a/tests/nn/data/test_offset_regression.py b/tests/nn/data/test_offset_regression.py index 31e688839..ce63894d6 100644 --- a/tests/nn/data/test_offset_regression.py +++ b/tests/nn/data/test_offset_regression.py @@ -4,7 +4,7 @@ from sleap.nn.data import offset_regression -sleap.use_cpu_only() # hide GPUs for test +# sleap.use_cpu_only() # hide GPUs for test def test_make_offsets(): diff --git a/tests/nn/data/test_pipelines.py b/tests/nn/data/test_pipelines.py index 30b67e13c..7d442c32d 100644 --- a/tests/nn/data/test_pipelines.py +++ b/tests/nn/data/test_pipelines.py @@ -3,7 +3,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test import sleap from sleap.nn.data import pipelines diff --git a/tests/nn/data/test_providers.py b/tests/nn/data/test_providers.py index 279244ea1..f30216e6a 100644 --- a/tests/nn/data/test_providers.py +++ b/tests/nn/data/test_providers.py @@ -2,7 +2,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test import sleap from sleap.nn.data import providers diff --git a/tests/nn/data/test_resizing.py b/tests/nn/data/test_resizing.py index 440ca66d0..6ef15c2f1 100644 --- a/tests/nn/data/test_resizing.py +++ b/tests/nn/data/test_resizing.py @@ -1,14 +1,10 @@ import pytest import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only - -use_cpu_only() # hide GPUs for test - import sleap from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.data import resizing from sleap.nn.data import providers from sleap.nn.data.resizing import SizeMatcher diff --git a/tests/nn/data/test_utils.py b/tests/nn/data/test_utils.py index 213e357e8..7fa98a57a 100644 --- a/tests/nn/data/test_utils.py +++ b/tests/nn/data/test_utils.py @@ -2,7 +2,7 @@ import tensorflow as tf from sleap.nn.system import use_cpu_only -use_cpu_only() # hide GPUs for test +# use_cpu_only() # hide GPUs for test from sleap.nn.data import utils diff --git a/tests/nn/test_evals.py b/tests/nn/test_evals.py index 265994056..48f0d69f8 100644 --- a/tests/nn/test_evals.py +++ b/tests/nn/test_evals.py @@ -20,7 +20,7 @@ from sleap.nn.model import Model -sleap.use_cpu_only() +# sleap.use_cpu_only() def test_compute_oks(): diff --git a/tests/nn/test_heads.py b/tests/nn/test_heads.py index 02fbc2737..a4acbb15f 100644 --- a/tests/nn/test_heads.py +++ b/tests/nn/test_heads.py @@ -21,7 +21,7 @@ ) -sleap.use_cpu_only() +# sleap.use_cpu_only() def test_single_instance_confmaps_head(): diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index fe848bb1c..0a978de0a 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -3,12 +3,15 @@ import zipfile from pathlib import Path from typing import cast +import shutil +import csv import numpy as np import pytest +import pandas as pd import tensorflow as tf -import tensorflow_hub as hub from numpy.testing import assert_array_equal, assert_allclose +from sleap.io.video import available_video_exts import sleap from sleap.gui.learning import runners @@ -50,6 +53,7 @@ _make_tracker_from_cli, main as sleap_track, export_cli as sleap_export, + _make_export_cli_parser, ) from sleap.nn.tracking import ( MatchedFrameInstance, @@ -60,7 +64,7 @@ from sleap.instance import Track -sleap.nn.system.use_cpu_only() +# 
sleap.nn.system.use_cpu_only() @@ -925,7 +929,7 @@ def test_load_model(resize_input_shape, model_fixture_name, request): predictor = load_model(model_path, resize_input_layer=resize_input_shape) # Determine predictor type - for (fname, mname, ptype, ishape) in fname_mname_ptype_ishape: + for fname, mname, ptype, ishape in fname_mname_ptype_ishape: if fname in model_fixture_name: expected_model_name = mname expected_predictor_type = ptype @@ -966,7 +970,6 @@ def test_topdown_multi_size_inference( def test_ensure_numpy( min_centroid_model_path, min_centered_instance_model_path, min_labels_slp ): - model = load_model([min_centroid_model_path, min_centered_instance_model_path]) # each frame has same number of instances @@ -1037,7 +1040,6 @@ def test_ensure_numpy( def test_centroid_inference(): - xv, yv = make_grid_vectors(image_height=12, image_width=12, output_stride=1) points = tf.cast([[[1.75, 2.75]], [[3.75, 4.75]], [[5.75, 6.75]]], tf.float32) cms = tf.expand_dims(make_multi_confmaps(points, xv, yv, sigma=1.5), axis=0) @@ -1093,7 +1095,6 @@ def test_centroid_inference(): def export_frozen_graph(model, preds, output_path): - tensors = {} for key, val in preds.items(): @@ -1120,7 +1121,6 @@ def export_frozen_graph(model, preds, output_path): info = json.load(json_file) for tensor_info in info["frozen_model_inputs"] + info["frozen_model_outputs"]: - saved_name = ( tensor_info.split("Tensor(")[1].split(", shape")[0].replace('"', "") ) @@ -1137,7 +1137,6 @@ def export_frozen_graph(model, preds, output_path): def test_single_instance_save(min_single_instance_robot_model_path, tmp_path): - single_instance_model = tf.keras.models.load_model( min_single_instance_robot_model_path + "/best_model.h5", compile=False ) @@ -1152,7 +1151,6 @@ def test_single_instance_save(min_single_instance_robot_model_path, tmp_path): def test_centroid_save(min_centroid_model_path, tmp_path): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1171,7 +1169,6 @@ def test_centroid_save(min_centroid_model_path, tmp_path): def test_topdown_save( min_centroid_model_path, min_centered_instance_model_path, min_labels_slp, tmp_path ): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1195,7 +1192,6 @@ def test_topdown_save( def test_topdown_id_save( min_centroid_model_path, min_topdown_multiclass_model_path, min_labels_slp, tmp_path ): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1217,7 +1213,6 @@ def test_topdown_id_save( def test_single_instance_predictor_save(min_single_instance_robot_model_path, tmp_path): - # directly initialize predictor predictor = SingleInstancePredictor.from_trained_models( min_single_instance_robot_model_path, resize_input_layer=False ) @@ -1254,10 +1249,33 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm ) + +def test_make_export_cli(): + models_path = r"pseudo/models/path" + export_path = r"pseudo/test/path" + max_instances = 5 + + parser = _make_export_cli_parser() + + # Test default values + args = None + args, _ = parser.parse_known_args(args=args) + assert args.models is None + assert args.export_path == "exported_model" + assert not args.ragged + assert args.max_instances is None + + # Test all arguments + cmd = f"-m {models_path} -e {export_path} -r -n {max_instances}" + args, _ = parser.parse_known_args(args=cmd.split()) + assert args.models == [models_path] +
assert args.export_path == export_path + assert args.ragged + assert args.max_instances == max_instances + + +def test_topdown_predictor_save( min_centroid_model_path, min_centered_instance_model_path, tmp_path ): - # directly initialize predictor predictor = TopDownPredictor.from_trained_models( centroid_model_path=min_centroid_model_path, @@ -1300,7 +1318,6 @@ def test_topdown_predictor_save( def test_topdown_id_predictor_save( min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path ): - # directly initialize predictor predictor = TopDownMultiClassPredictor.from_trained_models( centroid_model_path=min_centroid_model_path, @@ -1358,7 +1375,7 @@ def test_retracking( # Create sleap-track command cmd = ( f"{slp_path} --tracking.tracker {tracker_method} --video.index 0 --frames 1-3 " - "--cpu" + "--tracking.similarity object_keypoint --cpu" ) if tracker_method == "flow": cmd += " --tracking.save_shifted_instances 1" @@ -1378,6 +1395,8 @@ parser = _make_cli_parser() args, _ = parser.parse_known_args(args=args) tracker = _make_tracker_from_cli(args) + # Additional check for similarity method + assert tracker.similarity_function.__name__ == "object_keypoint_similarity" output_path = f"{slp_path}.{tracker.get_name()}.slp" # Assert tracked predictions file exists @@ -1433,7 +1452,49 @@ def test_make_predictor_from_cli( assert predictor.max_instances == 5 -def test_sleap_track( +def test_make_predictor_from_cli_mult_input( + centered_pair_predictions: Labels, + min_centroid_model_path: str, + min_centered_instance_model_path: str, + min_bottomup_model_path: str, + tmpdir, +): + slp_path = tmpdir.mkdir("slp_directory") + + slp_file = slp_path / "old_slp.slp" + Labels.save(centered_pair_predictions, slp_file) + + # Copy the SLP file into the temp dir multiple times + num_copies = 3 + for i in range(num_copies): + # Construct the destination path with a unique name for the SLP file + slp_dest_path = slp_path / f"old_slp_copy_{i}.slp" + shutil.copy(slp_file, slp_dest_path) + + # Create sleap-track command + model_args = [ + f"--model {min_centroid_model_path} --model {min_centered_instance_model_path}", + f"--model {min_bottomup_model_path}", + ] + for model_arg in model_args: + args = ( + f"{slp_path} {model_arg} --video.index 0 --frames 1-3 " + "--cpu --max_instances 5" + ).split() + parser = _make_cli_parser() + args, _ = parser.parse_known_args(args=args) + + # Create predictor + predictor = _make_predictor_from_cli(args=args) + if isinstance(predictor, TopDownPredictor): + assert predictor.inference_model.centroid_crop.max_instances == 5 + elif isinstance(predictor, BottomUpPredictor): + assert predictor.max_instances == 5 + + +def test_sleap_track_single_input( centered_pair_predictions: Labels, min_centroid_model_path: str, min_centered_instance_model_path: str, @@ -1452,7 +1513,7 @@ def test_sleap_track( sleap_track(args=args) # Assert predictions file exists - output_path = f"{slp_path}.predictions.slp" + output_path = Path(slp_path).with_suffix(".predictions.slp") assert Path(output_path).exists() # Create invalid sleap-track command @@ -1461,9 +1522,398 @@ def test_sleap_track( sleap_track(args=args) -def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): +@pytest.mark.parametrize("tracking", ["simple", "flow", "None"]) +def test_sleap_track_mult_input_slp( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + tmpdir, + centered_pair_predictions:
Labels, + tracking, +): + # Create temporary directory with the structured SLP files + slp_path = tmpdir.mkdir("slp_directory") + + slp_file = slp_path / "old_slp.slp" + Labels.save(centered_pair_predictions, slp_file) + + slp_path_obj = Path(slp_path) + + # Copy the SLP file into the temp dir multiple times + num_copies = 3 + for i in range(num_copies): + # Construct the destination path with a unique name for the SLP file + slp_dest_path = slp_path / f"old_slp_copy_{i}.slp" + shutil.copy(slp_file, slp_dest_path) + + # Create sleap-track command + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker {tracking} " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()] + + # Run inference + sleap_track(args=args) + + # Assert predictions file exists + expected_extensions = available_video_exts() + + for file_path in slp_path_list: + if file_path.suffix in expected_extensions: + expected_output_file = Path(file_path).with_suffix(".predictions.slp") + assert Path(expected_output_file).exists() + + +@pytest.mark.parametrize("tracking", ["simple", "flow", "None"]) +def test_sleap_track_mult_input_slp_mp4( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + centered_pair_vid_path, + tracking, + tmpdir, + centered_pair_predictions: Labels, +): + # Create temporary directory with the structured video files + slp_path = tmpdir.mkdir("slp_mp4_directory") + + slp_file = slp_path / "old_slp.slp" + Labels.save(centered_pair_predictions, slp_file) + + # Copy the video into the temp dir multiple times + num_copies = 3 + for i in range(num_copies): + # Construct the destination path with a unique name + dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4" + shutil.copy(centered_pair_vid_path, dest_path) + + slp_path_obj = Path(slp_path) + + # Create sleap-track command + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker {tracking} " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()] + + # Run inference + sleap_track(args=args) + + expected_extensions = available_video_exts() + + for file_path in slp_path_list: + if file_path.suffix in expected_extensions: + expected_output_file = Path(file_path).with_suffix(".predictions.slp") + assert Path(expected_output_file).exists() + + +@pytest.mark.parametrize("tracking", ["simple", "flow", "None"]) +def test_sleap_track_mult_input_mp4( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + centered_pair_vid_path, + tracking, + tmpdir, +): + + # Create temporary directory with the structured video files + slp_path = tmpdir.mkdir("mp4_directory") + + # Copy the video into the temp dir multiple times + num_copies = 3 + for i in range(num_copies): + # Construct the destination path with a unique name + dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4" + shutil.copy(centered_pair_vid_path, dest_path) + + slp_path_obj = Path(slp_path) + + # Create sleap-track command + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker {tracking} " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()] + + # Run inference + sleap_track(args=args) + + # Assert
+@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
+def test_sleap_track_mult_input_mp4(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    centered_pair_vid_path,
+    tracking,
+    tmpdir,
+):
+
+    # Create temporary directory with the structured video files
+    slp_path = tmpdir.mkdir("mp4_directory")
+
+    # Copy the video into the temp dir multiple times
+    num_copies = 3
+    for i in range(num_copies):
+        # Construct the destination path with a unique name
+        dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
+        shutil.copy(centered_pair_vid_path, dest_path)
+
+    slp_path_obj = Path(slp_path)
+
+    # Create sleap-track command
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker {tracking} "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]
+
+    # Run inference
+    sleap_track(args=args)
+
+    # Assert predictions file exists
+    expected_extensions = available_video_exts()
+
+    for file_path in slp_path_list:
+        if file_path.suffix in expected_extensions:
+            expected_output_file = Path(file_path).with_suffix(".predictions.slp")
+            assert Path(expected_output_file).exists()
+
+
+def test_sleap_track_output_mult(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    centered_pair_vid_path,
+    tmpdir,
+):
+
+    output_path = tmpdir.mkdir("output_directory")
+    output_path_obj = Path(output_path)
+
+    # Create temporary directory with the structured video files
+    slp_path = tmpdir.mkdir("mp4_directory")
+
+    # Copy the video into the temp dir multiple times
+    num_copies = 3
+    for i in range(num_copies):
+        # Construct the destination path with a unique name
+        dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
+        shutil.copy(centered_pair_vid_path, dest_path)
+
+    slp_path_obj = Path(slp_path)
+
+    # Create sleap-track command
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"-o {output_path} "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]
+
+    # Run inference
+    sleap_track(args=args)
+    slp_path = Path(slp_path)
+
+    # Assert predictions files exist in the output directory
+    expected_extensions = available_video_exts()
+
+    for file_path in slp_path_list:
+        if file_path.suffix in expected_extensions:
+            expected_output_file = output_path_obj / (
+                file_path.stem + ".predictions.slp"
+            )
+            assert Path(expected_output_file).exists()
+
+
+def test_sleap_track_invalid_output(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    centered_pair_vid_path,
+    centered_pair_predictions: Labels,
+    tmpdir,
+):
+
+    output_path = Path(tmpdir, "output_file.slp").as_posix()
+    Labels.save(centered_pair_predictions, output_path)
+
+    # Create temporary directory with the structured video files
+    slp_path = tmpdir.mkdir("mp4_directory")
+
+    # Copy the video into the temp dir multiple times
+    num_copies = 3
+    for i in range(num_copies):
+        # Construct the destination path with a unique name
+        dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
+        shutil.copy(centered_pair_vid_path, dest_path)
+
+    # Create sleap-track command
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"-o {output_path} "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    # Run inference
+    with pytest.raises(ValueError):
+        sleap_track(args=args)
+
+
+def test_sleap_track_invalid_input(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+):
+
+    slp_path = ""
+
+    # Create sleap-track command
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    # Run inference
+    with pytest.raises(ValueError):
+        sleap_track(args=args)
+
+    # Test with a non-existent path
+    slp_path = "/path/to/nonexistent/file.mp4"
+
+    # Create sleap-track command for non-existent path
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    # Run inference and expect a ValueError for non-existent path
+    with pytest.raises(ValueError):
+        sleap_track(args=args)
+
+
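For reference, the argument list these tests assemble is the same one a user would pass against a directory of videos; a sketch of the equivalent programmatic call (all paths are placeholders, and `sleap_track` is assumed to be the CLI entry point imported at the top of this test module):

```
# Hypothetical paths; substitute real model folders and a real video directory.
args = (
    "/path/to/video_dir "
    "--model /path/to/models/centroid "
    "--model /path/to/models/centered_instance "
    "--tracking.tracker simple "
    "-o /path/to/output_dir "
    "--video.index 0 --frames 1-3 --cpu"
).split()
sleap_track(args=args)
```
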
+def test_sleap_track_csv_input(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    centered_pair_vid_path,
+    tmpdir,
+):
+
+    # Create temporary directory with the structured video files
+    slp_path = Path(tmpdir.mkdir("mp4_directory"))
+
+    # Copy the video into the temp dir multiple times
+    num_copies = 3
+    file_paths = []
+    for i in range(num_copies):
+        # Construct the destination path with a unique name
+        dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
+        shutil.copy(centered_pair_vid_path, dest_path)
+        file_paths.append(dest_path)
+
+    # Generate output paths for each data_path
+    output_paths = [
+        file_path.with_suffix(".TESTpredictions.slp") for file_path in file_paths
+    ]
+
+    # Create a CSV file with the file paths
+    csv_file_path = slp_path / "file_paths.csv"
+    with open(csv_file_path, mode="w", newline="") as csv_file:
+        csv_writer = csv.writer(csv_file)
+        csv_writer.writerow(["data_path", "output_path"])
+        for data_path, output_path in zip(file_paths, output_paths):
+            csv_writer.writerow([data_path, output_path])
+
+    slp_path_obj = Path(slp_path)
+
+    # Create sleap-track command
+    args = (
+        f"{csv_file_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]
+
+    # Run inference
+    sleap_track(args=args)
+
+    # Assert predictions file exists
+    expected_extensions = available_video_exts()
+
+    for file_path in slp_path_list:
+        if file_path.suffix in expected_extensions:
+            expected_output_file = file_path.with_suffix(".TESTpredictions.slp")
+            assert Path(expected_output_file).exists()
+
+
+def test_sleap_track_invalid_csv(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    tmpdir,
+):
+
+    # Create a CSV file listing nonexistent data files
+    csv_nonexistent_files_path = tmpdir / "nonexistent_files.csv"
+    df_nonexistent_files = pd.DataFrame(
+        {"data_path": ["video1.mp4", "video2.mp4", "video3.mp4"]}
+    )
+    df_nonexistent_files.to_csv(csv_nonexistent_files_path, index=False)
+
+    # Create an empty CSV file
+    csv_empty_path = tmpdir / "empty.csv"
+    open(csv_empty_path, "w").close()
+
+    # Create sleap-track command for the CSV of nonexistent data files
+    args_nonexistent_files = (
+        f"{csv_nonexistent_files_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    # Run inference and expect ValueError for nonexistent data files
+    with pytest.raises(ValueError):
+        sleap_track(args=args_nonexistent_files)
+
+    # Create sleap-track command for empty CSV file
+    args_empty = (
+        f"{csv_empty_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    # Run inference and expect ValueError for empty CSV file
+    with pytest.raises(ValueError):
+        sleap_track(args=args_empty)
+
+
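The CSV driving `test_sleap_track_csv_input` is just two columns with one row per video; a minimal way to generate one with pandas (column names taken from the test, file names hypothetical):

```
import pandas as pd

df = pd.DataFrame(
    {
        # "data_path" points at the input video; "output_path" overrides the
        # default "<stem>.predictions.slp" naming for that row.
        "data_path": ["clips/session_0.mp4", "clips/session_1.mp4"],
        "output_path": ["out/session_0.predictions.slp", "out/session_1.predictions.slp"],
    }
)
df.to_csv("file_paths.csv", index=False)
```
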
slp_path / f"centered_pair_vid_copy_{i}.mp4" + shutil.copy(centered_pair_vid_path, dest_path) + file_paths.append(dest_path) + + # Create a text file with the file paths + txt_file_path = slp_path / "file_paths.txt" + with open(txt_file_path, mode="w") as txt_file: + for file_path in file_paths: + txt_file.write(f"{file_path}\n") + + slp_path_obj = Path(slp_path) + + # Create sleap-track command + args = ( + f"{txt_file_path} --model {min_centroid_model_path} " + f"--tracking.tracker simple " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()] + + # Run inference + sleap_track(args=args) + + # Assert predictions file exists + expected_extensions = available_video_exts() + + for file_path in slp_path_list: + if file_path.suffix in expected_extensions: + expected_output_file = Path(file_path).with_suffix(".predictions.slp") + assert Path(expected_output_file).exists() + + +def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir): """Test flow tracker instances are pruned.""" - labels: Labels = centered_pair_predictions + labels: Labels = centered_pair_predictions_sorted track_window = 5 # Setup tracker @@ -1473,17 +1923,20 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker) # Run tracking - frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx) + frames = labels.labeled_frames # Run tracking on subset of frames using psuedo-implementation of # sleap.nn.tracking.run_tracker for lf in frames[:20]: - # Clear the tracks for inst in lf.instances: inst.track = None - track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) + track_args = dict( + untracked_instances=lf.instances, + img=lf.video[lf.frame_idx], + img_hw=lf.image.shape[-3:-1], + ) tracker.track(**track_args) # Check that saved instances are pruned to track window @@ -1522,12 +1975,15 @@ def test_max_tracks_matching_queue( frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx) for lf in frames[:20]: - # Clear the tracks for inst in lf.instances: inst.track = None - track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) + track_args = dict( + untracked_instances=lf.instances, + img=lf.video[lf.frame_idx], + img_hw=lf.image.shape[-3:-1], + ) tracker.track(**track_args) if trackername == "flowmaxtracks": @@ -1583,8 +2039,7 @@ def test_movenet_predictor(min_dance_labels, movenet_video): [labels_pr[0][0].numpy(), labels_pr[1][0].numpy()], axis=0 ) - max_diff = np.nanmax(np.abs(points_gt - points_pr)) - assert max_diff < 0.1 + np.testing.assert_allclose(points_gt, points_pr, atol=0.75) @pytest.mark.parametrize( diff --git a/tests/nn/test_inference_identity.py b/tests/nn/test_inference_identity.py index 22be152ea..aaacfef61 100644 --- a/tests/nn/test_inference_identity.py +++ b/tests/nn/test_inference_identity.py @@ -9,7 +9,7 @@ ) -sleap.use_cpu_only() +# sleap.use_cpu_only() def test_group_class_peaks(): diff --git a/tests/nn/test_model.py b/tests/nn/test_model.py index 329e5528f..6c60cb354 100644 --- a/tests/nn/test_model.py +++ b/tests/nn/test_model.py @@ -15,7 +15,7 @@ ModelConfig, ) -sleap.use_cpu_only() +# sleap.use_cpu_only() def test_model_from_config(): diff --git a/tests/nn/test_nn_utils.py b/tests/nn/test_nn_utils.py index 15b9d4bf3..4e8703c05 100644 --- a/tests/nn/test_nn_utils.py +++ b/tests/nn/test_nn_utils.py @@ -6,7 +6,7 @@ 
diff --git a/tests/nn/test_inference_identity.py b/tests/nn/test_inference_identity.py
index 22be152ea..aaacfef61 100644
--- a/tests/nn/test_inference_identity.py
+++ b/tests/nn/test_inference_identity.py
@@ -9,7 +9,7 @@
 )
 
 
-sleap.use_cpu_only()
+# sleap.use_cpu_only()
 
 
 def test_group_class_peaks():
diff --git a/tests/nn/test_model.py b/tests/nn/test_model.py
index 329e5528f..6c60cb354 100644
--- a/tests/nn/test_model.py
+++ b/tests/nn/test_model.py
@@ -15,7 +15,7 @@
     ModelConfig,
 )
 
-sleap.use_cpu_only()
+# sleap.use_cpu_only()
 
 
 def test_model_from_config():
diff --git a/tests/nn/test_nn_utils.py b/tests/nn/test_nn_utils.py
index 15b9d4bf3..4e8703c05 100644
--- a/tests/nn/test_nn_utils.py
+++ b/tests/nn/test_nn_utils.py
@@ -6,7 +6,7 @@
 from sleap.nn.inference import TopDownPredictor
 from sleap.nn.utils import tf_linear_sum_assignment, match_points, reset_input_layer
 
-sleap.use_cpu_only()
+# sleap.use_cpu_only()
 
 
 def test_tf_linear_sum_assignment():
diff --git a/tests/nn/test_paf_grouping.py b/tests/nn/test_paf_grouping.py
index 4856c1fed..d9578bfa9 100644
--- a/tests/nn/test_paf_grouping.py
+++ b/tests/nn/test_paf_grouping.py
@@ -22,7 +22,7 @@
     assign_connections_to_instances,
 )
 
-sleap.nn.system.use_cpu_only()
+# sleap.nn.system.use_cpu_only()
 
 
 def test_get_connection_candidates():
diff --git a/tests/nn/test_peak_finding.py b/tests/nn/test_peak_finding.py
index 93beaa193..243653202 100644
--- a/tests/nn/test_peak_finding.py
+++ b/tests/nn/test_peak_finding.py
@@ -22,7 +22,7 @@
 )
 
 
-sleap.nn.system.use_cpu_only()
+# sleap.nn.system.use_cpu_only()
 
 
 def test_find_local_offsets():
diff --git a/tests/nn/test_system.py b/tests/nn/test_system.py
index fc95bb0ea..7b16f1219 100644
--- a/tests/nn/test_system.py
+++ b/tests/nn/test_system.py
@@ -4,13 +4,19 @@
 be available.
 """
 
-from sleap.nn.system import get_gpu_memory
-from sleap.nn.system import get_all_gpus
+from sleap.nn.system import (
+    get_gpu_memory,
+    get_all_gpus,
+    use_cpu_only,
+    use_gpu,
+    is_gpu_system,
+)
 import os
 import pytest
 import subprocess
 import tensorflow as tf
 import shutil
+import platform
 
 
 def test_get_gpu_memory():
@@ -93,3 +99,17 @@ def test_gpu_device_order():
     """Indirectly tests GPU device order by ensuring environment variable is set."""
 
     assert os.environ["CUDA_DEVICE_ORDER"] == "PCI_BUS_ID"
+
+
+@pytest.mark.skipif(
+    not ("arm64" in platform.platform()),
+    reason="Only test on macosx-arm64",
+)
+def test_reinitialize():
+    """This test tries to change the devices after they have been initialized."""
+    assert is_gpu_system()
+    use_gpu(0)
+    tf.zeros((1,))
+    tf.ones((1,))
+    # The following would normally throw:
+    # RuntimeError: Visible devices cannot be modified after being initialized
+    use_cpu_only()
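The module-level `use_cpu_only()` calls are commented out above because, as the new `test_reinitialize` documents, TensorFlow refuses to change visible devices once they are initialized. Under the hood the helper amounts to hiding the GPUs via `tf.config`; a sketch using the standard TF API, not necessarily SLEAP's exact implementation:

```
import tensorflow as tf

# Must run before any op touches a device; afterwards TF raises
# "RuntimeError: Visible devices cannot be modified after being initialized".
tf.config.set_visible_devices([], "GPU")
```
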
diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py
index f861241ee..0c7ba2b0a 100644
--- a/tests/nn/test_tracker_components.py
+++ b/tests/nn/test_tracker_components.py
@@ -9,23 +9,82 @@
     FrameMatches,
     greedy_matching,
 )
+from sleap.io.dataset import Labels
 
 from sleap.instance import PredictedInstance
 from sleap.skeleton import Skeleton
 
 
+def tracker_by_name(frames=None, **kwargs):
+    t = Tracker.make_tracker_by_name(**kwargs)
+    print(kwargs)
+    print(t.candidate_maker)
+    if frames is None:
+        t.track([])
+        t.final_pass([])
+        return
+
+    for lf in frames:
+        # Clear the tracks
+        for inst in lf.instances:
+            inst.track = None
+
+        track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx])
+        t.track(**track_args, img_hw=(1, 1))
+    t.final_pass(frames)
+
+
 @pytest.mark.parametrize(
     "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
 )
-@pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"])
+@pytest.mark.parametrize(
+    "similarity",
+    ["instance", "normalized_instance", "iou", "centroid", "object_keypoint"],
+)
 @pytest.mark.parametrize("match", ["greedy", "hungarian"])
 @pytest.mark.parametrize("count", [0, 2])
-def test_tracker_by_name(tracker, similarity, match, count):
-    t = Tracker.make_tracker_by_name(
-        "flow", "instance", "greedy", clean_instance_count=2
+def test_tracker_by_name(
+    centered_pair_predictions_sorted,
+    tracker,
+    similarity,
+    match,
+    count,
+):
+    # This is slow, so limit to 5 time points
+    frames = centered_pair_predictions_sorted[:5]
+
+    tracker_by_name(
+        frames=frames,
+        tracker=tracker,
+        similarity=similarity,
+        match=match,
+        max_tracks=count,
+    )
+
+
+@pytest.mark.parametrize(
+    "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
+)
+@pytest.mark.parametrize("oks_score_weighting", ["True", "False"])
+@pytest.mark.parametrize("oks_normalization", ["all", "ref", "union"])
+def test_oks_tracker_by_name(
+    centered_pair_predictions_sorted,
+    tracker,
+    oks_score_weighting,
+    oks_normalization,
+):
+    # This is slow, so limit to 5 time points
+    frames = centered_pair_predictions_sorted[:5]
+
+    tracker_by_name(
+        frames=frames,
+        tracker=tracker,
+        similarity="object_keypoint",
+        matching="greedy",
+        oks_score_weighting=oks_score_weighting,
+        oks_normalization=oks_normalization,
+        max_tracks=2,
     )
-    t.track([])
-    t.final_pass([])
 
 
 def test_cull_instances(centered_pair_predictions):
@@ -232,7 +291,7 @@ def test_max_tracking_large_gap_single_track():
 
     tracked = []
     for insts in preds:
-        tracked_insts = tracker.track(insts)
+        tracked_insts = tracker.track(insts, img_hw=(1, 1))
         tracked.append(tracked_insts)
 
     all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
@@ -249,7 +308,7 @@
 
     tracked = []
     for insts in preds:
-        tracked_insts = tracker.track(insts)
+        tracked_insts = tracker.track(insts, img_hw=(1, 1))
         tracked.append(tracked_insts)
 
     all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
@@ -296,7 +355,7 @@ def test_max_tracking_small_gap_on_both_tracks():
 
     tracked = []
     for insts in preds:
-        tracked_insts = tracker.track(insts)
+        tracked_insts = tracker.track(insts, img_hw=(1, 1))
         tracked.append(tracked_insts)
 
     all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
@@ -313,7 +372,7 @@
 
     tracked = []
    for insts in preds:
-        tracked_insts = tracker.track(insts)
+        tracked_insts = tracker.track(insts, img_hw=(1, 1))
         tracked.append(tracked_insts)
 
     all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
@@ -365,7 +424,7 @@ def test_max_tracking_extra_detections():
 
     tracked = []
     for insts in preds:
-        tracked_insts = tracker.track(insts)
+        tracked_insts = tracker.track(insts, img_hw=(1, 1))
         tracked.append(tracked_insts)
 
     all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
@@ -382,7 +441,7 @@
 
     tracked = []
     for insts in preds:
-        tracked_insts = tracker.track(insts)
+        tracked_insts = tracker.track(insts, img_hw=(1, 1))
         tracked.append(tracked_insts)
 
     all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py
index a6592dc4d..625302fd0 100644
--- a/tests/nn/test_tracking_integration.py
+++ b/tests/nn/test_tracking_integration.py
@@ -102,7 +102,7 @@ def run_tracker(frames, tracker):
         new_lf = LabeledFrame(
             frame_idx=lf.frame_idx,
             video=lf.video,
-            instances=tracker.track(**track_args),
+            instances=tracker.track(**track_args, img_hw=lf.image.shape[-3:-1]),
         )
         new_lfs.append(new_lf)
 
@@ -138,6 +138,8 @@ def main(f, dir):
         instance=sleap.nn.tracker.components.instance_similarity,
         centroid=sleap.nn.tracker.components.centroid_distance,
         iou=sleap.nn.tracker.components.instance_iou,
+        normalized_instance=sleap.nn.tracker.components.normalized_instance_similarity,
+        object_keypoint=sleap.nn.tracker.components.factory_object_keypoint_similarity(),
     )
     scales = (
         1,
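The new `object_keypoint` similarity scores candidate instances with an OKS-style measure. As a sketch of the underlying idea, here is the standard COCO-flavored OKS; SLEAP's `factory_object_keypoint_similarity` adds the score-weighting and all/ref/union normalization options parametrized above, which are omitted here:

```
import numpy as np

def oks(ref_pts, query_pts, scale, kappa=0.1):
    """COCO-style object keypoint similarity for (n_nodes, 2) arrays.

    NaN coordinates mark missing keypoints; `scale` is the object size
    (e.g. sqrt of bounding-box area) and `kappa` the per-keypoint falloff.
    """
    d2 = np.sum((ref_pts - query_pts) ** 2, axis=1)
    visible = ~np.isnan(d2)
    if not visible.any():
        return 0.0
    ks = np.exp(-d2[visible] / (2 * scale**2 * kappa**2))
    return float(ks.mean())
```
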
diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py
index 55f404929..72db17bb5 100644
--- a/tests/nn/test_training.py
+++ b/tests/nn/test_training.py
@@ -25,7 +25,7 @@
     create_trainer_using_cli as sleap_train,
 )
 
-sleap.use_cpu_only()
+# sleap.use_cpu_only()
 
 
 @pytest.fixture
@@ -44,7 +44,7 @@ def cfg():
     cfg = TrainingJobConfig()
     cfg.data.instance_cropping.center_on_part = "A"
     cfg.model.backbone.unet = UNetConfig(
-        max_stride=8, output_stride=1, filters=8, filters_rate=1.0
+        max_stride=8, output_stride=1, filters=2, filters_rate=1.0
     )
     cfg.optimization.preload_data = False
     cfg.optimization.batch_size = 1
@@ -123,34 +123,61 @@ def test_train_load_single_instance(
 
     assert (w == w2).all()
 
 
-def test_train_single_instance(min_labels_robot, cfg):
+def test_train_single_instance(min_labels_robot, cfg, tmp_path):
     cfg.model.heads.single_instance = SingleInstanceConfmapsHeadConfig(
         sigma=1.5, output_stride=1, offset_refinement=False
     )
+
+    # Set save directory
+    cfg.outputs.run_name = "test_run"
+    cfg.outputs.runs_folder = str(tmp_path / "training_runs")  # ensure it's a string
+    cfg.outputs.save_visualizations = True
+    cfg.outputs.keep_viz_images = True
+    cfg.outputs.save_outputs = True  # enable saving
+
     trainer = SingleInstanceModelTrainer.from_config(
         cfg, training_labels=min_labels_robot
     )
     trainer.setup()
     trainer.train()
+
+    run_path = Path(cfg.outputs.runs_folder, cfg.outputs.run_name)
+    viz_path = run_path / "viz"
+
     assert trainer.keras_model.output_names[0] == "SingleInstanceConfmapsHead"
     assert tuple(trainer.keras_model.outputs[0].shape) == (None, 320, 560, 2)
+    assert viz_path.exists()
 
 
-def test_train_single_instance_with_offset(min_labels_robot, cfg):
+def test_train_single_instance_with_offset(min_labels_robot, cfg, tmp_path):
     cfg.model.heads.single_instance = SingleInstanceConfmapsHeadConfig(
         sigma=1.5, output_stride=1, offset_refinement=True
     )
+
+    # Set save directory
+    cfg.outputs.run_name = "test_run"
+    cfg.outputs.runs_folder = str(tmp_path / "training_runs")  # ensure it's a string
+    cfg.outputs.save_visualizations = False
+    cfg.outputs.keep_viz_images = False
+    cfg.outputs.save_outputs = True  # enable saving
+
     trainer = SingleInstanceModelTrainer.from_config(
         cfg, training_labels=min_labels_robot
     )
     trainer.setup()
     trainer.train()
+
+    run_path = Path(cfg.outputs.runs_folder, cfg.outputs.run_name)
+    viz_path = run_path / "viz"
+
     assert trainer.keras_model.output_names[0] == "SingleInstanceConfmapsHead"
     assert tuple(trainer.keras_model.outputs[0].shape) == (None, 320, 560, 2)
     assert trainer.keras_model.output_names[1] == "OffsetRefinementHead"
     assert tuple(trainer.keras_model.outputs[1].shape) == (None, 320, 560, 4)
+    assert not viz_path.exists()
+
 
 def test_train_centroids(training_labels, cfg):
     cfg.model.heads.centroid = CentroidsHeadConfig(
@@ -251,12 +278,12 @@ def test_train_bottomup_with_offset(training_labels, cfg):
 
 def test_train_bottomup_multiclass(min_tracks_2node_labels, cfg):
     labels = min_tracks_2node_labels
-    cfg.data.preprocessing.input_scaling = 0.5
+    cfg.data.preprocessing.input_scaling = 0.25
     cfg.model.heads.multi_class_bottomup = sleap.nn.config.MultiClassBottomUpConfig(
         confmaps=sleap.nn.config.MultiInstanceConfmapsHeadConfig(
-            output_stride=2, offset_refinement=False
+            output_stride=4, offset_refinement=False
        ),
-        class_maps=sleap.nn.config.ClassMapsHeadConfig(output_stride=2),
+        class_maps=sleap.nn.config.ClassMapsHeadConfig(output_stride=4),
     )
     trainer = sleap.nn.training.BottomUpMultiClassModelTrainer.from_config(
         cfg, training_labels=labels
@@ -266,8 +293,8 @@ def test_train_bottomup_multiclass(min_tracks_2node_labels, cfg):
 
     assert trainer.keras_model.output_names[0] == "MultiInstanceConfmapsHead"
     assert trainer.keras_model.output_names[1] == "ClassMapsHead"
-    assert tuple(trainer.keras_model.outputs[0].shape) == (None, 256, 256, 2)
-    assert tuple(trainer.keras_model.outputs[1].shape) == (None, 256, 256, 2)
+    assert tuple(trainer.keras_model.outputs[0].shape) == (None, 64, 64, 2)
+    assert tuple(trainer.keras_model.outputs[1].shape) == (None, 64, 64, 2)
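The halved input scale and doubled output stride compound, so the head's spatial grid shrinks 4x per side; assuming the 1024 x 1024 frames these fixtures appear to use, the arithmetic matches both the old and new asserts:

```
# output size = input size * input_scaling / output_stride
# before: 1024 * 0.50 / 2 = 256  -> (None, 256, 256, 2)
# after:  1024 * 0.25 / 4 = 64   -> (None, 64, 64, 2)
```
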
 
 
 def test_train_topdown_multiclass(min_tracks_2node_labels, cfg):
@@ -360,3 +387,26 @@ def test_resume_training_cli(
 
     trainer = sleap_train(cli_args)
     assert trainer.config.model.base_checkpoint == base_checkpoint_path
+
+
+@pytest.mark.parametrize("keep_viz_cli", ["", "--keep_viz"])
+def test_keep_viz_cli(
+    keep_viz_cli,
+    min_single_instance_robot_model_path: str,
+    tmp_path: str,
+):
+    """Test training CLI for --keep_viz option."""
+    cfg_dir = min_single_instance_robot_model_path
+    cfg = TrainingJobConfig.load_json(str(Path(cfg_dir, "training_config.json")))
+
+    # Save training config to tmp folder
+    cfg_path = str(Path(tmp_path, "training_config.json"))
+    cfg.save_json(cfg_path)
+
+    cli_args = [cfg_path, keep_viz_cli]
+    trainer = sleap_train(cli_args)
+
+    # Check that --keep_viz is set correctly
+    assert trainer.config.outputs.keep_viz_images == (keep_viz_cli == "--keep_viz")
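The flag maps directly onto `outputs.keep_viz_images`; mirroring the test, an equivalent programmatic call looks like the sketch below (config path hypothetical, `sleap_train` as imported at the top of this module):

```
# Same as running: sleap-train training_config.json --keep_viz
trainer = sleap_train(["training_config.json", "--keep_viz"])
assert trainer.config.outputs.keep_viz_images
```
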
diff --git a/tests/test_instance.py b/tests/test_instance.py
index 74a8b192e..58a630a8b 100644
--- a/tests/test_instance.py
+++ b/tests/test_instance.py
@@ -1,19 +1,21 @@
-import os
-import math
 import copy
+import math
+import os
+from typing import List
 
-import pytest
 import numpy as np
+import pytest
 
-from sleap.skeleton import Skeleton
+from sleap import Labels
 from sleap.instance import (
     Instance,
-    PredictedInstance,
+    InstancesList,
+    LabeledFrame,
     Point,
+    PredictedInstance,
     PredictedPoint,
-    LabeledFrame,
 )
-from sleap import Labels
+from sleap.skeleton import Skeleton
 
 
 def test_instance_node_get_set_item(skeleton):
@@ -310,6 +312,8 @@ def test_frame_merge_predicted_and_user(skeleton, centered_pair_vid):
     # and we want to retain both even though they perfectly match.
     assert user_inst in user_frame.instances
     assert pred_inst in user_frame.instances
+    assert user_inst.frame == user_frame
+    assert pred_inst.frame == user_frame
     assert len(user_frame.instances) == 2
 
@@ -529,3 +533,216 @@ def test_instance_structuring_from_predicted(centered_pair_predictions):
 
     # Unstructure -> structure
     labels_copy = labels.copy()
+
+
+def test_instances_list(centered_pair_predictions):
+
+    labels = centered_pair_predictions
+
+    def test_extend(instances: InstancesList, list_of_instances: List[Instance]):
+        instances.extend(list_of_instances)
+        assert len(instances) == len(list_of_instances)
+        for instance in instances:
+            assert isinstance(instance, PredictedInstance)
+            if instances.labeled_frame is None:
+                assert instance.frame is None
+            else:
+                assert instance.frame == instances.labeled_frame
+
+    def test_append(instances: InstancesList, instance: Instance):
+        prev_len = len(instances)
+        instances.append(instance)
+        assert len(instances) == prev_len + 1
+        assert instances[-1] == instance
+        assert instance.frame == instances.labeled_frame
+
+    def test_labeled_frame_setter(
+        instances: InstancesList, labeled_frame: LabeledFrame
+    ):
+        instances.labeled_frame = labeled_frame
+        for instance in instances:
+            assert instance.frame == labeled_frame
+
+    # Case 1: Create an empty instances list
+    labeled_frame = labels.labeled_frames[0]
+    list_of_instances = list(labeled_frame.instances)
+    instances = InstancesList()
+    assert len(instances) == 0
+    assert instances._labeled_frame is None
+    assert instances.labeled_frame is None
+
+    # Extend instances list
+    assert not isinstance(list_of_instances, InstancesList)
+    assert isinstance(list_of_instances, list)
+    test_extend(instances, list_of_instances)
+
+    # Set the labeled frame
+    test_labeled_frame_setter(instances, labeled_frame)
+
+    # Case 2: Create an empty instances list but initialize the labeled frame
+    instances = InstancesList(labeled_frame=labeled_frame)
+    assert len(instances) == 0
+    assert instances._labeled_frame == labeled_frame
+    assert instances.labeled_frame == labeled_frame
+
+    # Extend instances to the list from a different labeled frame
+    labeled_frame = labels.labeled_frames[1]
+    list_of_instances = list(labeled_frame.instances)
+    test_extend(instances, list_of_instances)
+
+    # Add instance to the list
+    instance = list_of_instances[0]
+    instance.frame = None
+    test_append(instances, instance)
+
+    # Set the labeled frame
+    test_labeled_frame_setter(instances, labeled_frame)
+
+    # Test InstancesList.copy
+    instances_copy = instances.copy()
+    assert len(instances_copy) == len(instances)
+    assert not isinstance(instances_copy, InstancesList)
+    assert isinstance(instances_copy, list)
+
+    # Test InstancesList.clear
+    instances_in_instances = list(instances)
+    instances.clear()
+    assert len(instances) == 0
+    for instance in instances_in_instances:
+        assert instance.frame is None
+
+    # Case 3: Create an instances list with a list of instances
+    labeled_frame = labels.labeled_frames[0]
+    list_of_instances = list(labeled_frame.instances)
+    instances = InstancesList(list_of_instances)
+    assert len(instances) == len(list_of_instances)
+    assert instances._labeled_frame is None
+    assert instances.labeled_frame is None
+    for instance in instances:
+        assert instance.frame is None
+
+    # Add instance to the list
+    instance = list_of_instances[0]
+    test_append(instances, instance)
+
+    # Case 4: Create an instances list with a list of instances and initialize the frame
+    labeled_frame_1 = labels.labeled_frames[0]
+    labeled_frame_2 = labels.labeled_frames[1]
+    list_of_instances = list(labeled_frame_2.instances)
+    instances = InstancesList(list_of_instances, labeled_frame=labeled_frame_1)
+    assert len(instances) == len(list_of_instances)
+    assert instances._labeled_frame == labeled_frame_1
+    assert instances.labeled_frame == labeled_frame_1
+    for instance in instances:
+        assert instance.frame == labeled_frame_1
+
+    # Test InstancesList.__delitem__
+    instance_to_remove = instances[0]
+    del instances[0]
+    assert instance_to_remove not in instances
+    assert instance_to_remove.frame is None
+
+    # Test InstancesList.insert
+    instances.insert(0, instance_to_remove)
+    assert instances[0] == instance_to_remove
+    assert instance_to_remove.frame == instances.labeled_frame
+
+    # Test InstancesList.__setitem__
+    new_instance = labeled_frame_1.instances[0]
+    new_instance.frame = None
+    instances[0] = new_instance
+    assert instances[0] == new_instance
+    assert new_instance.frame == instances.labeled_frame
+
+    # Test InstancesList.pop
+    popped_instance = instances.pop(0)
+    assert popped_instance.frame is None
+
+    # Test InstancesList.remove
+    instance_to_remove = instances[0]
+    instances.remove(instance_to_remove)
+    assert instance_to_remove.frame is None
+    assert instance_to_remove not in instances
+
+    # Case 5: Create an instances list from an instances list
+    instances_1 = InstancesList(list_of_instances, labeled_frame=labeled_frame_1)
+    instances = InstancesList(instances_1)
+    assert len(instances) == len(instances_1)
+    assert instances._labeled_frame is None
+    assert instances.labeled_frame is None
+    for instance in instances:
+        assert instance.frame is None
+
+
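The invariant exercised throughout this test: `InstancesList` behaves like a plain `list` while keeping each element's `frame` back-reference pointed at the owning `LabeledFrame`, and clearing it on removal. A stripped-down sketch of the pattern, not SLEAP's actual implementation:

```
class OwnedList(list):
    """List that stamps an owner onto every element it holds."""

    def __init__(self, iterable=(), owner=None):
        super().__init__(iterable)
        self.owner = owner
        for item in self:
            item.frame = owner

    def append(self, item):
        item.frame = self.owner  # keep the back-reference in sync
        super().append(item)

    def remove(self, item):
        super().remove(item)
        item.frame = None  # detached items lose their owner
```
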
+def test_instances_list_with_labeled_frame(centered_pair_predictions):
+    labels: Labels = centered_pair_predictions
+    labels_lf_0: LabeledFrame = labels.labeled_frames[0]
+    video = labels_lf_0.video
+    frame_idx = labels_lf_0.frame_idx
+
+    def test_post_init(labeled_frame: LabeledFrame):
+        for instance in labeled_frame.instances:
+            assert instance.frame == labeled_frame
+
+    # Create labeled frame from list of instances
+    instances = list(labels_lf_0.instances)
+    for instance in instances:
+        instance.frame = None  # Change frame to None to test if it is set correctly
+    labeled_frame = LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
+    assert isinstance(labeled_frame.instances, InstancesList)
+    assert len(labeled_frame.instances) == len(instances)
+    test_post_init(labeled_frame)
+
+    # Create labeled frame from instances list
+    instances = InstancesList(labels_lf_0.instances)
+    labeled_frame = LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
+    assert isinstance(labeled_frame.instances, InstancesList)
+    assert len(labeled_frame.instances) == len(instances)
+    test_post_init(labeled_frame)
+
+    # Test LabeledFrame.__len__
+    assert len(labeled_frame.instances) == len(instances)
+
+    # Test LabeledFrame.__getitem__
+    assert labeled_frame[0] == instances[0]
+
+    # Test LabeledFrame.index
+    assert labeled_frame.index(instances[0]) == instances.index(instances[0]) == 0
+
+    # Test LabeledFrame.__delitem__
+    instance_to_remove = labeled_frame[0]
+    del labeled_frame[0]
+    assert instance_to_remove not in labeled_frame.instances
+    assert instance_to_remove.frame is None
+
+    # Test LabeledFrame.__repr__
+    print(labeled_frame)
+
+    # Test LabeledFrame.insert
+    labeled_frame.insert(0, instance_to_remove)
+    assert labeled_frame[0] == instance_to_remove
+    assert instance_to_remove.frame == labeled_frame
+
+    # Test LabeledFrame.__setitem__
+    new_instance = instances[1]
+    new_instance.frame = None
+    labeled_frame[0] = new_instance
+    assert labeled_frame[0] == new_instance
+    assert new_instance.frame == labeled_frame
+
+    # Test instances.setter (empty list)
+    labeled_frame.instances = []
+    assert len(labeled_frame.instances) == 0
+    assert labeled_frame.instances.labeled_frame == labeled_frame
+    # Test instances.setter (InstancesList)
+    labeled_frame.instances = labels.labeled_frames[1].instances
+    assert len(labeled_frame.instances) == len(labels.labeled_frames[1].instances)
+    assert labeled_frame.instances.labeled_frame == labeled_frame
+    for instance in labeled_frame.instances:
+        assert instance.frame == labeled_frame
+    # Test instances.setter (populated list)
+    labeled_frame.instances = list(labels.labeled_frames[1].instances)
+    assert len(labeled_frame.instances) == len(labels.labeled_frames[1].instances)
+    assert labeled_frame.instances.labeled_frame == labeled_frame
+    for instance in labeled_frame.instances:
+        assert instance.frame == labeled_frame
diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py
index 1f7c3a853..2320342f6 100644
--- a/tests/test_skeleton.py
+++ b/tests/test_skeleton.py
@@ -1,10 +1,74 @@
-import os
 import copy
-
-import jsonpickle
+import os
 import pytest
+import json
+
+from networkx.readwrite import json_graph
+from sleap.skeleton import Skeleton, SkeletonDecoder
+from sleap.skeleton import SkeletonEncoder
+
+
+def test_decoded_encoded_Skeleton_from_load_json(fly_legs_skeleton_json):
+    """
+    Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton.
+    """
+    # Get the skeleton from the fixture
+    skeleton = Skeleton.load_json(fly_legs_skeleton_json)
+    # Get the graph from the skeleton
+    indexed_node_graph = skeleton._graph
+    graph = json_graph.node_link_data(indexed_node_graph)
+
+    # Encode the graph as a json string to test .encode method
+    encoded_json_str = SkeletonEncoder.encode(graph)
+
+    # Get the skeleton from the encoded json string
+    decoded_skeleton = Skeleton.from_json(encoded_json_str)
+
+    # Check that the decoded skeleton is the same as the original skeleton
+    assert skeleton.matches(decoded_skeleton)
+
+
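The round-trip in these tests goes through networkx's node-link format, which reduces the skeleton graph to plain dicts and lists that JSON can handle. The same round-trip outside SLEAP, minus the jsonpickle `py/` tags that `SkeletonEncoder`/`SkeletonDecoder` additionally manage:

```
import json
import networkx as nx
from networkx.readwrite import json_graph

g = nx.DiGraph()
g.add_edge("head", "thorax")
g.add_edge("thorax", "abdomen")

# Flatten to node-link dicts, serialize with sorted keys (the property the
# test below asserts for backwards compatibility), and rebuild losslessly.
data = json_graph.node_link_data(g)
encoded = json.dumps(data, sort_keys=True)
rebuilt = json_graph.node_link_graph(json.loads(encoded))
assert set(rebuilt.edges) == set(g.edges)
```
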
+ """ + # Use request.getfixturevalue to get the actual fixture value by name + skeleton = request.getfixturevalue(skeleton_fixture_name) + + # Get the graph from the skeleton + indexed_node_graph = skeleton._graph + graph = json_graph.node_link_data(indexed_node_graph) + + # Encode the graph as a json string to test .encode method + encoded_json_str = SkeletonEncoder.encode(graph) + + # Assert that the encoded json has keys in sorted order (backwards compatibility) + encoded_dict = json.loads(encoded_json_str) + sorted_keys = sorted(encoded_dict.keys()) + assert list(encoded_dict.keys()) == sorted_keys + for key, value in encoded_dict.items(): + if isinstance(value, dict): + assert list(value.keys()) == sorted(value.keys()) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + assert list(item.keys()) == sorted(item.keys()) + + # Get the skeleton from the encoded json string + decoded_skeleton = Skeleton.from_json(encoded_json_str) + + # Check that the decoded skeleton is the same as the original skeleton + assert skeleton.matches(decoded_skeleton) + + # Now make everything into a JSON string + skeleton_json_str = skeleton.to_json() + decoded_skeleton_json_str = decoded_skeleton.to_json() -from sleap.skeleton import Skeleton + # Check that the JSON strings are the same + assert json.loads(skeleton_json_str) == json.loads(decoded_skeleton_json_str) def test_add_dupe_node(skeleton): @@ -194,9 +258,9 @@ def test_json(skeleton: Skeleton, tmpdir): ) assert skeleton.is_template == False json_str = skeleton.to_json() - json_dict = jsonpickle.decode(json_str) + json_dict = SkeletonDecoder.decode(json_str) json_dict_keys = list(json_dict.keys()) - assert "nx_graph" not in json_dict_keys + assert "nx_graph" in json_dict_keys # SkeletonDecoder adds this key assert "preview_image" not in json_dict_keys assert "description" not in json_dict_keys @@ -208,7 +272,7 @@ def test_json(skeleton: Skeleton, tmpdir): skeleton._is_template = True json_str = skeleton.to_json() - json_dict = jsonpickle.decode(json_str) + json_dict = SkeletonDecoder.decode(json_str) json_dict_keys = list(json_dict.keys()) assert "nx_graph" in json_dict_keys assert "preview_image" in json_dict_keys @@ -224,6 +288,26 @@ def test_json(skeleton: Skeleton, tmpdir): assert skeleton.matches(skeleton_copy) +def test_decode_preview_image(flies13_skeleton: Skeleton): + skeleton = flies13_skeleton + img_b64 = skeleton.preview_image + img = SkeletonDecoder.decode_preview_image(img_b64) + assert img.mode == "RGBA" + + +def test_skeleton_decoder(fly_legs_skeleton_json, fly_legs_skeleton_dict_json): + """Test that SkeletonDecoder can decode both tuple and dict py/state formats.""" + + skeleton_tuple_pystate = Skeleton.load_json(fly_legs_skeleton_json) + assert isinstance(skeleton_tuple_pystate, Skeleton) + + skeleton_dict_pystate = Skeleton.load_json(fly_legs_skeleton_dict_json) + assert isinstance(skeleton_dict_pystate, Skeleton) + + # These are the same skeleton, so they should match + assert skeleton_dict_pystate.matches(skeleton_tuple_pystate) + + def test_hdf5(skeleton, stickman, tmpdir): filename = os.path.join(tmpdir, "skeleton.h5") diff --git a/tests/test_util.py b/tests/test_util.py index a7916d47f..35b41afa8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,5 +1,4 @@ import pytest -from sleap.skeleton import Skeleton from sleap.util import * @@ -147,10 +146,3 @@ def test_save_dict_to_hdf5(tmpdir): assert f["bar"][-1].decode() == "zop" assert f["cab"]["a"][()] == 2 - - -def 
-def test_decode_preview_image(flies13_skeleton: Skeleton):
-    skeleton = flies13_skeleton
-    img_b64 = skeleton.preview_image
-    img = decode_preview_image(img_b64)
-    assert img.mode == "RGBA"