From 70d6c54d95d80bf3223beb2c9fb4782c3021df66 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 22 Sep 2024 13:45:15 -0400 Subject: [PATCH 1/4] Set up new build context and actions. --- .github/workflows/code-quality-main.yaml | 27 ++++ .github/workflows/code-quality-pr.yaml | 41 ++++++ .github/workflows/python-build.yaml | 99 ++++++++++++++ .github/workflows/tests.yaml | 48 +++++++ .pre-commit-config.yaml | 129 ++++++++++++++++++ mixins/__init__.py | 24 ---- pyproject.toml | 12 +- setup.py | 25 ---- src/mixins/__init__.py | 29 ++++ {mixins => src/mixins}/debuggable.py | 0 {mixins => src/mixins}/multiprocessingable.py | 0 {mixins => src/mixins}/saveable.py | 0 {mixins => src/mixins}/seedable.py | 0 {mixins => src/mixins}/swapcacheable.py | 0 {mixins => src/mixins}/tensorable.py | 0 {mixins => src/mixins}/timeable.py | 0 {mixins => src/mixins}/tqdmable.py | 0 {mixins => src/mixins}/utils.py | 0 tests/test_saveable_mixin.py | 3 - tests/test_seedable_mixin.py | 3 - tests/test_swapcacheable_mixin.py | 3 - tests/test_tensorable_mixin.py | 3 - tests/test_timeable_mixin.py | 5 +- tests/test_tqdmable_mixin.py | 3 - 24 files changed, 385 insertions(+), 69 deletions(-) create mode 100644 .github/workflows/code-quality-main.yaml create mode 100644 .github/workflows/code-quality-pr.yaml create mode 100644 .github/workflows/python-build.yaml create mode 100644 .github/workflows/tests.yaml create mode 100644 .pre-commit-config.yaml delete mode 100644 mixins/__init__.py delete mode 100644 setup.py create mode 100644 src/mixins/__init__.py rename {mixins => src/mixins}/debuggable.py (100%) rename {mixins => src/mixins}/multiprocessingable.py (100%) rename {mixins => src/mixins}/saveable.py (100%) rename {mixins => src/mixins}/seedable.py (100%) rename {mixins => src/mixins}/swapcacheable.py (100%) rename {mixins => src/mixins}/tensorable.py (100%) rename {mixins => src/mixins}/timeable.py (100%) rename {mixins => src/mixins}/tqdmable.py (100%) rename {mixins => src/mixins}/utils.py (100%) diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality-main.yaml new file mode 100644 index 0000000..ec878bf --- /dev/null +++ b/.github/workflows/code-quality-main.yaml @@ -0,0 +1,27 @@ +# Same as `code-quality-pr.yaml` but triggered on commit to main branch +# and runs on all files (instead of only the changed ones) + +name: Code Quality Main + +on: + push: + branches: [main] + +jobs: + code-quality: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Run pre-commits + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml new file mode 100644 index 0000000..2e08be0 --- /dev/null +++ b/.github/workflows/code-quality-pr.yaml @@ -0,0 +1,41 @@ +# This workflow finds which files were changed, prints them, +# and runs `pre-commit` on those files. + +# Inspired by the sktime library: +# https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml + +name: Code Quality PR + +on: + pull_request: + branches: [main, "release/*", "dev"] + +jobs: + code-quality: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Find modified files + id: file_changes + uses: trilom/file-changes-action@v1.2.4 + with: + output: " " + + - name: List modified files + run: echo '${{ steps.file_changes.outputs.files}}' + + - name: Run pre-commits + uses: pre-commit/action@v3.0.1 + with: + extra_args: --files ${{ steps.file_changes.outputs.files}} diff --git a/.github/workflows/python-build.yaml b/.github/workflows/python-build.yaml new file mode 100644 index 0000000..a32827f --- /dev/null +++ b/.github/workflows/python-build.yaml @@ -0,0 +1,99 @@ +name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI + +on: push + +jobs: + build: + name: Build distribution 📦 + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.12"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install pypa/build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: python3 -m build + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + publish-to-pypi: + name: >- + Publish Python 🐍 distribution 📦 to PyPI + if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes + needs: + - build + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/ml-mixins + permissions: + id-token: write # IMPORTANT: mandatory for trusted publishing + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + github-release: + name: >- + Sign the Python 🐍 distribution 📦 with Sigstore + and upload them to GitHub Release + needs: + - publish-to-pypi + runs-on: ubuntu-latest + + permissions: + contents: write # IMPORTANT: mandatory for making GitHub Releases + id-token: write # IMPORTANT: mandatory for sigstore + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + - name: Sign the dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v2.1.1 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + - name: Create GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + run: >- + gh release create + '${{ github.ref_name }}' + --repo '${{ github.repository }}' + --notes "" + - name: Upload artifact signatures to GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + # Upload to GitHub Release using the `gh` CLI. + # `dist/` contains the built packages, and the + # sigstore-produced signatures and certificates. + run: >- + gh release upload + '${{ github.ref_name }}' dist/** + --repo '${{ github.repository }}' diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 0000000..c82160d --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,48 @@ +name: Tests + +on: + push: + branches: [main] + pull_request: + branches: [main, "release/*", "dev"] + +jobs: + run_tests_ubuntu: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + fail-fast: false + + timeout-minutes: 30 + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install packages + run: | + pip install -e .[tests] + + #---------------------------------------------- + # run test suite + #---------------------------------------------- + - name: Run tests + run: | + pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4.0.1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..61bde52 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,129 @@ +default_language_version: + python: python3.12 + +exclude: "docs/index.md|MIMIC-IV_Example/README.md|eICU_Example/README.md" + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + # list of supported hooks: https://pre-commit.com/hooks.html + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-docstring-first + - id: check-yaml + - id: debug-statements + - id: detect-private-key + - id: check-executables-have-shebangs + - id: check-toml + - id: check-case-conflict + - id: check-added-large-files + args: [--maxkb, "800"] + + # python code formatting + - repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + args: [--line-length, "110"] + + # python import sorting + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files", "-o", "wandb"] + + - repo: https://github.com/PyCQA/autoflake + rev: v2.2.0 + hooks: + - id: autoflake + args: [--in-place, --remove-all-unused-imports] + + # python upgrading syntax to newer version + - repo: https://github.com/asottile/pyupgrade + rev: v3.10.1 + hooks: + - id: pyupgrade + args: [--py311-plus] + + # python docstring formatting + - repo: https://github.com/myint/docformatter + rev: v1.7.5 + hooks: + - id: docformatter + args: [--in-place, --wrap-summaries=110, --wrap-descriptions=110] + + # python check (PEP8), programming errors and code complexity + - repo: https://github.com/PyCQA/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + args: + [ + "--max-complexity=10", + "--extend-ignore", + "E402,E701,E251,E226,E302,W504,E704,E402,E401,C901,E203", + "--max-line-length=110", + "--exclude", + "logs/*,data/*", + "--per-file-ignores", + "__init__.py:F401", + ] + + # yaml formatting + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v3.0.3 + hooks: + - id: prettier + types: [yaml] + exclude: "environment.yaml" + + # shell scripts linter + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.9.0.5 + hooks: + - id: shellcheck + + # md formatting + - repo: https://github.com/executablebooks/mdformat + rev: 0.7.17 + hooks: + - id: mdformat + args: ["--number"] + additional_dependencies: + - mdformat-gfm + - mdformat-tables + - mdformat_frontmatter + - mdformat-black + - mdformat-config + - mdformat-shfmt + - mdformat-mkdocs + - mdformat-toc + - mdformat-admon + + # word spelling linter + - repo: https://github.com/codespell-project/codespell + rev: v2.2.5 + hooks: + - id: codespell + args: + - --skip=logs/**,data/**,*.ipynb,*.bib,env.yml,env_cpu.yml,*.svg,poetry.lock + - --ignore-words-list=ehr,crate + + # jupyter notebook cell output clearing + - repo: https://github.com/kynan/nbstripout + rev: 0.6.1 + hooks: + - id: nbstripout + + # jupyter notebook linting + - repo: https://github.com/nbQA-dev/nbQA + rev: 1.7.0 + hooks: + - id: nbqa-black + args: ["--line-length=110"] + - id: nbqa-isort + args: ["--profile=black"] + - id: nbqa-flake8 + args: ["--extend-ignore=E203,E402,E501,F401,F841", "--exclude=logs/*,data/*"] diff --git a/mixins/__init__.py b/mixins/__init__.py deleted file mode 100644 index a3b8dba..0000000 --- a/mixins/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# __all__ = [ -# 'debuggable', -# 'multiprocessingable', -# 'saveable', -# 'swapcacheable', -# 'tensorable', -# 'timeable', -# 'tqdmable', -# ] - -from .debuggable import DebuggableMixin -from .multiprocessingable import MultiprocessingMixin -from .saveable import SaveableMixin -from .seedable import SeedableMixin -from .swapcacheable import SwapcacheableMixin -from .timeable import TimeableMixin - -# Tensorable and Tqdmable rely on packages that may or may not be installed. - -try: from .tensorable import TensorableMixin -except ImportError as e: pass - -try: from .tqdmable import TQDMableMixin -except ImportError as e: pass diff --git a/pyproject.toml b/pyproject.toml index e1aff84..21775d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,10 @@ [build-system] -requires = ["setuptools>=61.0"] +requires = ["setuptools>=64", "setuptools-scm>=8.0", "wheel"] build-backend = "setuptools.build_meta" [project] name = "ml_mixins" -version = "0.0.6" +dynamic = ["version"] authors = [ { name="Matthew B. A. McDermott", email="mattmcdermott8@gmail.com" }, ] @@ -16,6 +16,14 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] +dependencies = ["numpy"] + +[project.optional-dependencies] +dev = ["pre-commit"] +tests = ["pytest", "pytest-cov"] +tqdmable = ["tqdm"] + +[tool.setuptools_scm] [project.urls] "Homepage" = "https://github.com/mmcdermott/ML_mixins" diff --git a/setup.py b/setup.py deleted file mode 100644 index 80f7cb8..0000000 --- a/setup.py +++ /dev/null @@ -1,25 +0,0 @@ -import setuptools - -with open("README.md", "r") as fh: long_description = fh.read() - -setuptools.setup( - name="ml_mixins_mmd", # Replace with your own username - version="0.0.1", - author="Matthew McDermott", - author_email="mattmcdermott8@gmail.com", - description="Various ML / data-science Mixins", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/mmcdermott/ML_mixins", - packages=setuptools.find_packages(), - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - install_requires=[ - 'numpy', - ], - python_requires='>=3.7', - test_suite='tests', -) diff --git a/src/mixins/__init__.py b/src/mixins/__init__.py new file mode 100644 index 0000000..23c83a0 --- /dev/null +++ b/src/mixins/__init__.py @@ -0,0 +1,29 @@ +from .debuggable import DebuggableMixin +from .multiprocessingable import MultiprocessingMixin +from .saveable import SaveableMixin +from .seedable import SeedableMixin +from .swapcacheable import SwapcacheableMixin +from .timeable import TimeableMixin + +__all__ = [ + "DebuggableMixin", + "MultiprocessingMixin", + "SaveableMixin", + "SeedableMixin", + "SwapcacheableMixin", + "TimeableMixin", +] + +# Tensorable and Tqdmable rely on packages that may or may not be installed. + +try: + from .tensorable import TensorableMixin + __all__.append("TensorableMixin") +except ImportError as e: + pass + +try: + from .tqdmable import TQDMableMixin + __all__.append("TQDMableMixin") +except ImportError as e: + pass diff --git a/mixins/debuggable.py b/src/mixins/debuggable.py similarity index 100% rename from mixins/debuggable.py rename to src/mixins/debuggable.py diff --git a/mixins/multiprocessingable.py b/src/mixins/multiprocessingable.py similarity index 100% rename from mixins/multiprocessingable.py rename to src/mixins/multiprocessingable.py diff --git a/mixins/saveable.py b/src/mixins/saveable.py similarity index 100% rename from mixins/saveable.py rename to src/mixins/saveable.py diff --git a/mixins/seedable.py b/src/mixins/seedable.py similarity index 100% rename from mixins/seedable.py rename to src/mixins/seedable.py diff --git a/mixins/swapcacheable.py b/src/mixins/swapcacheable.py similarity index 100% rename from mixins/swapcacheable.py rename to src/mixins/swapcacheable.py diff --git a/mixins/tensorable.py b/src/mixins/tensorable.py similarity index 100% rename from mixins/tensorable.py rename to src/mixins/tensorable.py diff --git a/mixins/timeable.py b/src/mixins/timeable.py similarity index 100% rename from mixins/timeable.py rename to src/mixins/timeable.py diff --git a/mixins/tqdmable.py b/src/mixins/tqdmable.py similarity index 100% rename from mixins/tqdmable.py rename to src/mixins/tqdmable.py diff --git a/mixins/utils.py b/src/mixins/utils.py similarity index 100% rename from mixins/utils.py rename to src/mixins/utils.py diff --git a/tests/test_saveable_mixin.py b/tests/test_saveable_mixin.py index 26a80a2..6b2133b 100644 --- a/tests/test_saveable_mixin.py +++ b/tests/test_saveable_mixin.py @@ -1,6 +1,3 @@ -import sys -sys.path.append('..') - import unittest from pathlib import Path from tempfile import TemporaryDirectory diff --git a/tests/test_seedable_mixin.py b/tests/test_seedable_mixin.py index e0daab9..051dae7 100644 --- a/tests/test_seedable_mixin.py +++ b/tests/test_seedable_mixin.py @@ -1,6 +1,3 @@ -import sys -sys.path.append('..') - import unittest import random, numpy as np diff --git a/tests/test_swapcacheable_mixin.py b/tests/test_swapcacheable_mixin.py index ac52d46..e5f4a9e 100644 --- a/tests/test_swapcacheable_mixin.py +++ b/tests/test_swapcacheable_mixin.py @@ -1,6 +1,3 @@ -import sys -sys.path.append('..') - import unittest from mixins import SwapcacheableMixin diff --git a/tests/test_tensorable_mixin.py b/tests/test_tensorable_mixin.py index 6c263d9..bfcf1c9 100644 --- a/tests/test_tensorable_mixin.py +++ b/tests/test_tensorable_mixin.py @@ -1,6 +1,3 @@ -import sys -sys.path.append('..') - import unittest try: diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py index 039610d..28a40bf 100644 --- a/tests/test_timeable_mixin.py +++ b/tests/test_timeable_mixin.py @@ -1,7 +1,6 @@ -import sys, time, numpy as np -sys.path.append('..') - import unittest +import time +import numpy as np from mixins import TimeableMixin diff --git a/tests/test_tqdmable_mixin.py b/tests/test_tqdmable_mixin.py index 372114b..25e0cb0 100644 --- a/tests/test_tqdmable_mixin.py +++ b/tests/test_tqdmable_mixin.py @@ -1,6 +1,3 @@ -import sys -sys.path.append('..') - import unittest from mixins import TQDMableMixin From a7fb0cc94d0574702d03e06bdf170dbe2d92e8d4 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 22 Sep 2024 14:10:24 -0400 Subject: [PATCH 2/4] Linted files. --- .github/workflows/tests.yaml | 2 +- .pre-commit-config.yaml | 6 ++- README.md | 43 ++++++++------- pyproject.toml | 1 + src/mixins/__init__.py | 10 ++-- src/mixins/debuggable.py | 30 +++++++---- src/mixins/multiprocessingable.py | 30 ++++++----- src/mixins/saveable.py | 50 +++++++++++------- src/mixins/seedable.py | 49 ++++++++++------- src/mixins/swapcacheable.py | 43 +++++++++------ src/mixins/tensorable.py | 52 ++++++++++-------- src/mixins/timeable.py | 88 +++++++++++++++++-------------- src/mixins/tqdmable.py | 32 ++++++----- src/mixins/utils.py | 6 ++- tests/test_saveable_mixin.py | 58 ++++++++++++-------- tests/test_seedable_mixin.py | 28 +++++----- tests/test_swapcacheable_mixin.py | 6 ++- tests/test_tensorable_mixin.py | 14 +++-- tests/test_timeable_mixin.py | 66 +++++++++++------------ tests/test_tqdmable_mixin.py | 6 ++- 20 files changed, 357 insertions(+), 263 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index c82160d..aface00 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -28,7 +28,7 @@ jobs: - name: Install packages run: | - pip install -e .[tests] + pip install -e .[tests,tqdmable,tensorable] #---------------------------------------------- # run test suite diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 61bde52..1b655c2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -126,4 +126,8 @@ repos: - id: nbqa-isort args: ["--profile=black"] - id: nbqa-flake8 - args: ["--extend-ignore=E203,E402,E501,F401,F841", "--exclude=logs/*,data/*"] + args: + [ + "--extend-ignore=E203,E402,E501,F401,F841", + "--exclude=logs/*,data/*", + ] diff --git a/README.md b/README.md index 7f25e6e..4f9ca19 100644 --- a/README.md +++ b/README.md @@ -1,35 +1,42 @@ # ML Mixins + ## Installation + this package can be installed via [`pip`](https://pypi.org/project/ml-mixins/): + ``` pip install ml-mixins ``` + Then + ``` from mixins import SeedableMixin ... ``` ## Description + Useful Python Mixins for ML. These are python mixin classes that can be used to add useful bits of discrete functionality to python objects for use in ML / data science. They currently include: - 1. `SeedableMixin` which adds nice seeding capabilities, including functions to seed various stages of - computation in a manner that is both random but also reproducible from a global seed, as well as to store - seeds used at various times so that a subsection of the computation can be reproduced exactly during - debugging outside of the rest of the computation flow. - 2. `TimeableMixin` adds functionality for timing sections of code. - 3. `SaveableMixin` adds customizable save/load functionality (using pickle) - 4. `SwapcacheableMixin`. This one is a bit more niche. It adds a "_swapcache_" to the class, which allows - one to store various iterations of parameters keyed by an arbitrary python object with an equality - operator, with a notion of a "current" setting whose values are then exposed as main class attributes. - The intended use-case is for data processing classes, where it may be desirable to try different - preprocesisng settings, have the object retain derived data for those settings, but present a - front-facing interface that looks like it is only computing a single setting. For example, if running - tfidf under different stopwords and ngram settings, one can run the system via the swapcache under - settings A, and the class can present an interface of `[obj].stop_words`, `obj.ngram_range`, - `obj.tfidf_vectorized_data`, but then this can be transparently updated to a different setting without - discarding that data via the swapcache interface. - 5. `TQDMableMixin`. This one adds a `_tqdm` method to a class which automatically progressbar-ifies ranges - for iteration, unless the range is sufficiently short or the class has `self.tqdm` set to `None` + +1. `SeedableMixin` which adds nice seeding capabilities, including functions to seed various stages of + computation in a manner that is both random but also reproducible from a global seed, as well as to store + seeds used at various times so that a subsection of the computation can be reproduced exactly during + debugging outside of the rest of the computation flow. +2. `TimeableMixin` adds functionality for timing sections of code. +3. `SaveableMixin` adds customizable save/load functionality (using pickle) +4. `SwapcacheableMixin`. This one is a bit more niche. It adds a "_swapcache_" to the class, which allows + one to store various iterations of parameters keyed by an arbitrary python object with an equality + operator, with a notion of a "current" setting whose values are then exposed as main class attributes. + The intended use-case is for data processing classes, where it may be desirable to try different + preprocesisng settings, have the object retain derived data for those settings, but present a + front-facing interface that looks like it is only computing a single setting. For example, if running + tfidf under different stopwords and ngram settings, one can run the system via the swapcache under + settings A, and the class can present an interface of `[obj].stop_words`, `obj.ngram_range`, + `obj.tfidf_vectorized_data`, but then this can be transparently updated to a different setting without + discarding that data via the swapcache interface. +5. `TQDMableMixin`. This one adds a `_tqdm` method to a class which automatically progressbar-ifies ranges + for iteration, unless the range is sufficiently short or the class has `self.tqdm` set to `None` None of these are guaranteed to work or be useful at this point. diff --git a/pyproject.toml b/pyproject.toml index 21775d8..573db20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = ["numpy"] dev = ["pre-commit"] tests = ["pytest", "pytest-cov"] tqdmable = ["tqdm"] +tensorable = ["torch"] [tool.setuptools_scm] diff --git a/src/mixins/__init__.py b/src/mixins/__init__.py index 23c83a0..77f2171 100644 --- a/src/mixins/__init__.py +++ b/src/mixins/__init__.py @@ -17,13 +17,15 @@ # Tensorable and Tqdmable rely on packages that may or may not be installed. try: - from .tensorable import TensorableMixin + from .tensorable import TensorableMixin # noqa + __all__.append("TensorableMixin") -except ImportError as e: +except ImportError: pass try: - from .tqdmable import TQDMableMixin + from .tqdmable import TQDMableMixin # noqa + __all__.append("TQDMableMixin") -except ImportError as e: +except ImportError: pass diff --git a/src/mixins/debuggable.py b/src/mixins/debuggable.py index 61830bb..4ef58b0 100644 --- a/src/mixins/debuggable.py +++ b/src/mixins/debuggable.py @@ -1,39 +1,47 @@ from __future__ import annotations -import functools, inspect, pickle +import functools +import inspect +import pickle from copy import deepcopy from pathlib import Path -from typing import Optional from .utils import doublewrap + class DebuggableMixin: @property def _do_debug(self): - if hasattr(self, 'do_debug'): return self.do_debug - else: return False + if hasattr(self, "do_debug"): + return self.do_debug + else: + return False @staticmethod @doublewrap - def CaptureErrorState(fn, store_global: Optional[bool] = None, filepath: Optional[Path] = None): - if store_global is None: store_global = (filepath is None) + def CaptureErrorState(fn, store_global: bool | None = None, filepath: Path | None = None): + if store_global is None: + store_global = filepath is None @functools.wraps(fn) - def debugging_wrapper(self, *args, seed: Optional[int] = None, **kwargs): - if not self._do_debug: return fn(self, *args, **kwargs) + def debugging_wrapper(self, *args, seed: int | None = None, **kwargs): + if not self._do_debug: + return fn(self, *args, **kwargs) try: return fn(self, *args, **kwargs) - except Exception as e: + except Exception: T = inspect.trace() for t in T: - if t[3] == fn.__name__: break + if t[3] == fn.__name__: + break new_vars = deepcopy(t[0].f_locals) if store_global: __builtins__["_DEBUGGER_VARS"] = new_vars if filepath: - with open(filepath, mode='wb') as f: + with open(filepath, mode="wb") as f: pickle.dump(new_vars, f) raise + return debugging_wrapper diff --git a/src/mixins/multiprocessingable.py b/src/mixins/multiprocessingable.py index d71d726..b8443af 100644 --- a/src/mixins/multiprocessingable.py +++ b/src/mixins/multiprocessingable.py @@ -1,28 +1,32 @@ from __future__ import annotations +from collections.abc import Callable, Sequence from multiprocessing import Pool -from typing import Callable, Optional, Sequence class MultiprocessingMixin: - def __init__(self, *args, multiprocessing_pool_size: Optional[int] = None, **kwargs): + def __init__(self, *args, multiprocessing_pool_size: int | None = None, **kwargs): self.multiprocessing_pool_size = multiprocessing_pool_size @property def _multiprocessing_pool_size(self): - if hasattr(self, 'multiprocessing_pool_size'): return self.multiprocessing_pool_size - else: return None + if hasattr(self, "multiprocessing_pool_size"): + return self.multiprocessing_pool_size + else: + return None @property def _use_multiprocessing(self): - return (self._multiprocessing_pool_size is not None and self._multiprocessing_pool_size > 1) + return self._multiprocessing_pool_size is not None and self._multiprocessing_pool_size > 1 - def _map( - self, fn: Callable, iterable: Sequence, tqdm: Optional[Callable] = None, **tqdm_kwargs - ) -> Sequence: + def _map(self, fn: Callable, iterable: Sequence, tqdm: Callable | None = None, **tqdm_kwargs) -> Sequence: if self._use_multiprocessing: - with Pool(self._multiprocessing_pool_size) as p: - if tqdm is None: return p.map(fn, iterable) - else: return list(tqdm(p.imap(fn, iterable), **tqdm_kwargs)) - elif tqdm is None: return [fn(x) for x in iterable] - else: return [fn(x) for x in tqdm(iterable, **tqdm_kwargs)] + with Pool(self._multiprocessing_pool_size) as p: + if tqdm is None: + return p.map(fn, iterable) + else: + return list(tqdm(p.imap(fn, iterable), **tqdm_kwargs)) + elif tqdm is None: + return [fn(x) for x in iterable] + else: + return [fn(x) for x in tqdm(iterable, **tqdm_kwargs)] diff --git a/src/mixins/saveable.py b/src/mixins/saveable.py index da7870e..fa6ea72 100644 --- a/src/mixins/saveable.py +++ b/src/mixins/saveable.py @@ -1,8 +1,10 @@ from __future__ import annotations import pickle as pickle + try: import dill + dill_imported = True dill_import_error = None except ImportError as e: @@ -10,16 +12,17 @@ dill_imported = False from pathlib import Path -from typing import Optional -class SaveableMixin(): + +class SaveableMixin: _DEL_BEFORE_SAVING_ATTRS = [] # TODO(mmd): Make StrEnum upon conversion to python 3.11 - _PICKLER = 'dill' if dill_imported else 'pickle' + _PICKLER = "dill" if dill_imported else "pickle" def __init__(self, *args, **kwargs): - self.do_overwrite = kwargs.get('do_overwrite', False) - if self._PICKLER == 'dill' and not dill_imported: raise dill_import_error + self.do_overwrite = kwargs.get("do_overwrite", False) + if self._PICKLER == "dill" and not dill_imported: + raise dill_import_error @classmethod def _load(cls, filepath: Path, **add_kwargs) -> None: @@ -28,16 +31,19 @@ def _load(cls, filepath: Path, **add_kwargs) -> None: elif not filepath.is_file(): raise IsADirectoryError(f"{filepath} is not a file.") - with open(filepath, mode='rb') as f: + with open(filepath, mode="rb") as f: match cls._PICKLER: - case 'dill': - if not dill_imported: raise dill_import_error + case "dill": + if not dill_imported: + raise dill_import_error obj = dill.load(f) - case 'pickle': obj = pickle.load(f) + case "pickle": + obj = pickle.load(f) case _: raise NotImplementedError(f"{cls._PICKLER} not supported! Options: {'dill', 'pickle'}") - for a, v in add_kwargs.items(): setattr(obj, a, v) + for a, v in add_kwargs.items(): + setattr(obj, a, v) obj._post_load(add_kwargs) return obj @@ -46,30 +52,34 @@ def _post_load(self, load_add_kwargs: dict) -> None: # Overwrite this in the base class if desired. return - def _save(self, filepath: Path, do_overwrite: Optional[bool] = False) -> None: - if not hasattr(self, 'do_overwrite'): self.do_overwrite = False + def _save(self, filepath: Path, do_overwrite: bool | None = False) -> None: + if not hasattr(self, "do_overwrite"): + self.do_overwrite = False if not (self.do_overwrite or do_overwrite): if filepath.exists(): raise FileExistsError(f"Filepath {filepath} already exists!") skipped_attrs = {} for attr in self._DEL_BEFORE_SAVING_ATTRS: - if hasattr(self, attr): skipped_attrs[attr] = self.__dict__.pop(attr) + if hasattr(self, attr): + skipped_attrs[attr] = self.__dict__.pop(attr) try: - with open(filepath, mode='wb') as f: + with open(filepath, mode="wb") as f: match self._PICKLER: - case 'dill': - if not dill_imported: raise dill_import_error + case "dill": + if not dill_imported: + raise dill_import_error dill.dump(self, f) - case 'pickle': pickle.dump(self, f) + case "pickle": + pickle.dump(self, f) case _: raise NotImplementedError( f"{self._PICKLER} not supported! Options: {'dill', 'pickle'}" ) - except: + except Exception: filepath.unlink() raise - - for attr, val in skipped_attrs.items(): setattr(self, attr, val) + for attr, val in skipped_attrs.items(): + setattr(self, attr, val) diff --git a/src/mixins/seedable.py b/src/mixins/seedable.py index a36bb06..b519d7a 100644 --- a/src/mixins/seedable.py +++ b/src/mixins/seedable.py @@ -1,38 +1,45 @@ from __future__ import annotations -import functools, random, numpy as np - +import functools +import os +import random from datetime import datetime -from typing import Optional + +import numpy as np from .utils import doublewrap -def seed_everything(seed: Optional[int] = None, try_import_torch: Optional[bool] = True) -> int: + +def seed_everything(seed: int | None = None, try_import_torch: bool | None = True) -> int: max_seed_value = np.iinfo(np.uint32).max min_seed_value = np.iinfo(np.uint32).min try: - if seed is None: seed = os.environ.get("PL_GLOBAL_SEED") + if seed is None: + seed = os.environ.get("PL_GLOBAL_SEED") seed = int(seed) except (TypeError, ValueError): seed = np.random.randint(min_seed_value, max_seed_value) - assert (min_seed_value <= seed <= max_seed_value) + assert min_seed_value <= seed <= max_seed_value random.seed(seed) np.random.seed(seed) if try_import_torch: try: import torch + torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - except ModuleNotFoundError: pass + except ModuleNotFoundError: + pass return seed -class SeedableMixin(): + +class SeedableMixin: def __init__(self, *args, **kwargs): - self._past_seeds = kwargs.get('_past_seeds', []) + self._past_seeds = kwargs.get("_past_seeds", []) def _last_seed(self, key: str): for idx, (s, k, time) in enumerate(self._past_seeds[::-1]): @@ -42,25 +49,31 @@ def _last_seed(self, key: str): return -1, None - def _seed(self, seed: Optional[int] = None, key: Optional[str] = None): - if seed is None: seed = random.randint(0, int(1e8)) - if key is None: key = '' + def _seed(self, seed: int | None = None, key: str | None = None): + if seed is None: + seed = random.randint(0, int(1e8)) + if key is None: + key = "" time = str(datetime.now()) self.seed = seed - if hasattr(self, '_past_seeds'): self._past_seeds.append((self.seed, key, time)) - else: self._past_seeds = [(self.seed, key, time)] + if hasattr(self, "_past_seeds"): + self._past_seeds.append((self.seed, key, time)) + else: + self._past_seeds = [(self.seed, key, time)] seed_everything(seed) return seed @staticmethod @doublewrap - def WithSeed(fn, key: Optional[str] = None): - if key is None: key = fn.__name__ + def WithSeed(fn, key: str | None = None): + if key is None: + key = fn.__name__ + @functools.wraps(fn) - def wrapper_seeding(self, *args, seed: Optional[int] = None, **kwargs): + def wrapper_seeding(self, *args, seed: int | None = None, **kwargs): self._seed(seed=seed, key=key) return fn(self, *args, **kwargs) - return wrapper_seeding + return wrapper_seeding diff --git a/src/mixins/swapcacheable.py b/src/mixins/swapcacheable.py index 451bdc7..5103adb 100644 --- a/src/mixins/swapcacheable.py +++ b/src/mixins/swapcacheable.py @@ -1,37 +1,45 @@ from __future__ import annotations import time -from typing import Hashable +from collections.abc import Hashable -class SwapcacheableMixin(): + +class SwapcacheableMixin: def __init__(self, *args, **kwargs): - self._cache_size = kwargs.get('cache_size', 5) + self._cache_size = kwargs.get("cache_size", 5) def _init_attrs(self): - if not hasattr(self, '_cache'): self._cache = {'keys': [], 'values': []} - if not hasattr(self, '_cache_size'): self._cache_size = 5 - if not hasattr(self, '_front_attrs'): self._front_attrs = [] - if not hasattr(self, '_front_cache_key'): self._front_cache_key = None - if not hasattr(self, '_front_cache_idx'): self._front_cache_idx = None + if not hasattr(self, "_cache"): + self._cache = {"keys": [], "values": []} + if not hasattr(self, "_cache_size"): + self._cache_size = 5 + if not hasattr(self, "_front_attrs"): + self._front_attrs = [] + if not hasattr(self, "_front_cache_key"): + self._front_cache_key = None + if not hasattr(self, "_front_cache_idx"): + self._front_cache_idx = None def _set_swapcache_key(self, key: Hashable) -> None: self._init_attrs() - if key == self._front_cache_key: return + if key == self._front_cache_key: + return seen_key = self._swapcache_has_key(key) if seen_key: idx = next(i for i, (k, t) in enumerate(self._seen_parameters) if k == key) else: - self._cache['keys'].append((key, time.time())) - self._cache['values'].append({}) + self._cache["keys"].append((key, time.time())) + self._cache["values"].append({}) - self._cache['keys'] = self._cache['keys'][-self._cache_size:] - self._cache['values'] = self._cache['values'][-self._cache_size:] + self._cache["keys"] = self._cache["keys"][-self._cache_size :] + self._cache["values"] = self._cache["values"][-self._cache_size :] idx = -1 # Clear out the old front-and-center attributes - for attr in self._front_attrs: delattr(self, attr) + for attr in self._front_attrs: + delattr(self, attr) self._front_cache_key = key self._front_cache_idx = idx @@ -40,7 +48,7 @@ def _set_swapcache_key(self, key: Hashable) -> None: def _swapcache_has_key(self, key: Hashable) -> bool: self._init_attrs() - return any(k == key for k, t in self._cache['keys']) + return any(k == key for k, t in self._cache["keys"]) def _swap_to_key(self, key: Hashable) -> None: self._init_attrs() @@ -50,14 +58,15 @@ def _swap_to_key(self, key: Hashable) -> None: def _update_front_attrs(self): self._init_attrs() # Set the new front-and-center attributes - for key, val in self._cache['values'][self._front_cache_idx].items(): setattr(self, key, val) + for key, val in self._cache["values"][self._front_cache_idx].items(): + setattr(self, key, val) def _update_swapcache_key_and_swap(self, key: Hashable, values_dict: dict): self._init_attrs() assert key is not None self._set_swapcache_key(key) - self._cache['values'][self._front_cache_idx].update(values_dict) + self._cache["values"][self._front_cache_idx].update(values_dict) self._update_front_attrs() def _update_current_swapcache_key(self, values_dict: dict): diff --git a/src/mixins/tensorable.py b/src/mixins/tensorable.py index ca908c3..d07a2b5 100644 --- a/src/mixins/tensorable.py +++ b/src/mixins/tensorable.py @@ -1,33 +1,39 @@ from __future__ import annotations -try: - import torch, numpy as np - from typing import Dict, Hashable, List, Optional, Tuple, Union +from collections.abc import Hashable +from typing import Union - class TensorableMixin(): - Tensorable_T = Union[np.ndarray, List[float], Tuple['Tensorable_T'], Dict[Hashable, 'Tensorable_T']] - Tensor_T = Union[torch.Tensor, Tuple['Tensor_T'], Dict[Hashable, 'Tensor_T']] +import numpy as np +import torch - def __init__(self, *args, **kwargs): - self.do_cuda = kwargs.get('do_cuda', torch.cuda.is_available) - def _cuda(self, T: torch.Tensor, do_cuda: Optional[bool] = None): - if do_cuda is None: - do_cuda = self.do_cuda if hasattr(self, 'do_cuda') else torch.cuda.is_available +class TensorableMixin: + Tensorable_T = Union[np.ndarray, list[float], tuple["Tensorable_T"], dict[Hashable, "Tensorable_T"]] + Tensor_T = Union[torch.Tensor, tuple["Tensor_T"], dict[Hashable, "Tensor_T"]] - return T.cuda() if do_cuda else T + def __init__(self, *args, **kwargs): + self.do_cuda = kwargs.get("do_cuda", torch.cuda.is_available) - def _from_numpy(self, obj: np.ndarray) -> torch.Tensor: - # I keep getting errors about "RuntimeError: expected scalar type Float but found Double" - if obj.dtype == np.float64: obj = obj.astype(np.float32) - return self._cuda(torch.from_numpy(obj)) + def _cuda(self, T: torch.Tensor, do_cuda: bool | None = None): + if do_cuda is None: + do_cuda = self.do_cuda if hasattr(self, "do_cuda") else torch.cuda.is_available - def _nested_to_tensor(self, obj: TensorableMixin.Tensorable_T) -> TensorableMixin.Tensor_T: - if isinstance(obj, np.ndarray): return self._from_numpy(obj) - elif isinstance(obj, list): return self._from_numpy(np.array(obj)) - elif isinstance(obj, dict): return {k: self._nested_to_tensor(v) for k, v in obj.items()} - elif isinstance(obj, tuple): return tuple((self._nested_to_tensor(e) for e in obj)) + return T.cuda() if do_cuda else T - raise ValueError(f"Don't know how to convert {type(obj)} object {obj} to tensor!") + def _from_numpy(self, obj: np.ndarray) -> torch.Tensor: + # I keep getting errors about "RuntimeError: expected scalar type Float but found Double" + if obj.dtype == np.float64: + obj = obj.astype(np.float32) + return self._cuda(torch.from_numpy(obj)) -except ImportError as e: pass + def _nested_to_tensor(self, obj: TensorableMixin.Tensorable_T) -> TensorableMixin.Tensor_T: + if isinstance(obj, np.ndarray): + return self._from_numpy(obj) + elif isinstance(obj, list): + return self._from_numpy(np.array(obj)) + elif isinstance(obj, dict): + return {k: self._nested_to_tensor(v) for k, v in obj.items()} + elif isinstance(obj, tuple): + return tuple(self._nested_to_tensor(e) for e in obj) + + raise ValueError(f"Don't know how to convert {type(obj)} object {obj} to tensor!") diff --git a/src/mixins/timeable.py b/src/mixins/timeable.py index a5d8939..14fb8ca 100644 --- a/src/mixins/timeable.py +++ b/src/mixins/timeable.py @@ -1,31 +1,35 @@ from __future__ import annotations -import functools, time, numpy as np +import functools +import time from collections import defaultdict from contextlib import contextmanager -from typing import Optional, Set, Tuple + +import numpy as np from .utils import doublewrap -class TimeableMixin(): - _START_TIME = 'start' - _END_TIME = 'end' + +class TimeableMixin: + _START_TIME = "start" + _END_TIME = "end" _CUTOFFS_AND_UNITS = [ - (1000, 'μs'), - (1000, 'ms'), - (60, 'sec'), - (60, 'min'), - (24, 'hour'), - (7, 'days'), - (None, 'weeks') + (1000, "μs"), + (1000, "ms"), + (60, "sec"), + (60, "min"), + (24, "hour"), + (7, "days"), + (None, "weeks"), ] @classmethod - def _get_pprint_num_unit(cls, x: float, x_unit: str = 'sec') -> Tuple[float, str]: + def _get_pprint_num_unit(cls, x: float, x_unit: str = "sec") -> tuple[float, str]: x_unit_factor = 1 for fac, unit in cls._CUTOFFS_AND_UNITS: - if unit == x_unit: break + if unit == x_unit: + break if fac is None: raise LookupError( f"Passed unit {x_unit} invalid! " @@ -36,35 +40,37 @@ def _get_pprint_num_unit(cls, x: float, x_unit: str = 'sec') -> Tuple[float, str min_unit = x * x_unit_factor upper_bound = 1 for upper_bound_factor, unit in cls._CUTOFFS_AND_UNITS: - if ( - (upper_bound_factor is None) or - (min_unit < upper_bound * upper_bound_factor) - ): return min_unit / upper_bound, unit + if (upper_bound_factor is None) or (min_unit < upper_bound * upper_bound_factor): + return min_unit / upper_bound, unit upper_bound *= upper_bound_factor @classmethod - def _pprint_duration(cls, mean_sec: float, n_times: int = 1, std_seconds: Optional[float] = None) -> str: + def _pprint_duration(cls, mean_sec: float, n_times: int = 1, std_seconds: float | None = None) -> str: mean_time, mean_unit = cls._get_pprint_num_unit(mean_sec) if std_seconds: - std_time = std_seconds * mean_time/mean_sec + std_time = std_seconds * mean_time / mean_sec mean_std_str = f"{mean_time:.1f} ± {std_time:.1f} {mean_unit}" - else: mean_std_str = f"{mean_time:.1f} {mean_unit}" + else: + mean_std_str = f"{mean_time:.1f} {mean_unit}" - if n_times > 1: return f"{mean_std_str} (x{n_times})" - else: return mean_std_str + if n_times > 1: + return f"{mean_std_str} (x{n_times})" + else: + return mean_std_str def __init__(self, *args, **kwargs): - self._timings = kwargs.get('_timings', defaultdict(list)) + self._timings = kwargs.get("_timings", defaultdict(list)) def __assert_key_exists(self, key: str) -> None: - assert hasattr(self, '_timings') and key in self._timings, f"{key} should exist in self._timings!" + assert hasattr(self, "_timings") and key in self._timings, f"{key} should exist in self._timings!" - def _times_for(self, key: str) -> List[float]: + def _times_for(self, key: str) -> list[float]: self.__assert_key_exists(key) return [ - t[self._END_TIME] - t[self._START_TIME] for t in self._timings[key] \ - if self._START_TIME in t and self._END_TIME in t + t[self._END_TIME] - t[self._START_TIME] + for t in self._timings[key] + if self._START_TIME in t and self._END_TIME in t ] def _time_so_far(self, key: str) -> float: @@ -73,12 +79,13 @@ def _time_so_far(self, key: str) -> float: return time.time() - self._timings[key][-1][self._START_TIME] def _register_start(self, key: str) -> None: - if not hasattr(self, '_timings'): self._timings = defaultdict(list) + if not hasattr(self, "_timings"): + self._timings = defaultdict(list) self._timings[key].append({self._START_TIME: time.time()}) def _register_end(self, key: str) -> None: - assert hasattr(self, '_timings') + assert hasattr(self, "_timings") assert key in self._timings and len(self._timings[key]) > 0 assert self._timings[key][-1].get(self._END_TIME, None) is None self._timings[key][-1][self._END_TIME] = time.time() @@ -93,25 +100,28 @@ def _time_as(self, key: str): @staticmethod @doublewrap - def TimeAs(fn, key: Optional[str] = None): - if key is None: key = fn.__name__ + def TimeAs(fn, key: str | None = None): + if key is None: + key = fn.__name__ + @functools.wraps(fn) - def wrapper_timing(self, *args, seed: Optional[int] = None, **kwargs): + def wrapper_timing(self, *args, seed: int | None = None, **kwargs): self._register_start(key=key) out = fn(self, *args, **kwargs) self._register_end(key=key) return out + return wrapper_timing @property - def _duration_stats(self): + def _duration_stats(self): out = {} for k in self._timings: arr = np.array(self._times_for(k)) out[k] = (arr.mean(), len(arr), arr.std()) return out - def _profile_durations(self, only_keys: Optional[Set[str]] = None): + def _profile_durations(self, only_keys: set[str] | None = None): stats = self._duration_stats if only_keys is not None: @@ -119,10 +129,8 @@ def _profile_durations(self, only_keys: Optional[Set[str]] = None): longest_key_length = max(len(k) for k in stats) ordered_keys = sorted(stats.keys(), key=lambda k: stats[k][0] * stats[k][1]) - tfk_str = '\n'.join( - ( - f"{k}:{' '*(longest_key_length - len(k))} " - f"{self._pprint_duration(*stats[k])}" - ) for k in ordered_keys + tfk_str = "\n".join( + (f"{k}:{' '*(longest_key_length - len(k))} " f"{self._pprint_duration(*stats[k])}") + for k in ordered_keys ) return tfk_str diff --git a/src/mixins/tqdmable.py b/src/mixins/tqdmable.py index 835e85a..b7f129f 100644 --- a/src/mixins/tqdmable.py +++ b/src/mixins/tqdmable.py @@ -1,23 +1,27 @@ from __future__ import annotations -try: - from tqdm.auto import tqdm +from tqdm.auto import tqdm - class TQDMableMixin(): - _SKIP_TQDM_IF_LE = 3 - def __init__(self, *args, **kwargs): - self.tqdm = kwargs.get('tqdm', tqdm) +class TQDMableMixin: + _SKIP_TQDM_IF_LE = 3 - def _tqdm(self, rng, **kwargs): - if not hasattr(self, 'tqdm'): self.tqdm = tqdm + def __init__(self, *args, **kwargs): + self.tqdm = kwargs.get("tqdm", tqdm) - if self.tqdm is None: return rng + def _tqdm(self, rng, **kwargs): + if not hasattr(self, "tqdm"): + self.tqdm = tqdm - try: N = len(rng) - except: return rng + if self.tqdm is None: + return rng - if N <= self._SKIP_TQDM_IF_LE: return rng + try: + N = len(rng) + except Exception: + return rng - return tqdm(rng, **kwargs) -except ImportError as e: pass + if N <= self._SKIP_TQDM_IF_LE: + return rng + + return tqdm(rng, **kwargs) diff --git a/src/mixins/utils.py b/src/mixins/utils.py index e8419c7..9e74e5f 100644 --- a/src/mixins/utils.py +++ b/src/mixins/utils.py @@ -1,12 +1,14 @@ import functools + def doublewrap(f): - ''' + """ a decorator decorator, allowing the decorator to be used as: @decorator(with, arguments, and=kwargs) or @decorator - ''' + """ + @functools.wraps(f) def new_dec(*args, **kwargs): if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): diff --git a/tests/test_saveable_mixin.py b/tests/test_saveable_mixin.py index 6b2133b..21b8865 100644 --- a/tests/test_saveable_mixin.py +++ b/tests/test_saveable_mixin.py @@ -5,66 +5,78 @@ from mixins import SaveableMixin + class Derived(SaveableMixin): - _PICKLER = 'pickle' + _PICKLER = "pickle" - def __init__(self, a: int = -1, b: str = 'unset', **kwargs): + def __init__(self, a: int = -1, b: str = "unset", **kwargs): super().__init__(**kwargs) self.a = a self.b = b def __eq__(self, other: Any) -> bool: - return type(self) == type(other) and (self.a == other.a) and (self.b == other.b) + return type(self) is type(other) and (self.a == other.a) and (self.b == other.b) + class DillDerived(SaveableMixin): - _PICKLER = 'dill' + _PICKLER = "dill" - def __init__(self, a: int = -1, b: str = 'unset', **kwargs): + def __init__(self, a: int = -1, b: str = "unset", **kwargs): super().__init__(**kwargs) self.a = a self.b = b def __eq__(self, other: Any) -> bool: - return type(self) == type(other) and (self.a == other.a) and (self.b == other.b) + return type(self) is type(other) and (self.a == other.a) and (self.b == other.b) + class BadDerived(SaveableMixin): - _PICKLER = 'not_supported' + _PICKLER = "not_supported" - def __init__(self, a: int = -1, b: str = 'unset', **kwargs): + def __init__(self, a: int = -1, b: str = "unset", **kwargs): super().__init__(**kwargs) self.a = a self.b = b def __eq__(self, other: Any) -> bool: - return type(self) == type(other) and (self.a == other.a) and (self.b == other.b) + return type(self) is type(other) and (self.a == other.a) and (self.b == other.b) + class TestSaveableMixin(unittest.TestCase): def test_saveable_mixin(self): - T = Derived(a=2, b='hi') + T = Derived(a=2, b="hi") with TemporaryDirectory() as d: - save_path = Path(d) / 'save.pkl' + save_path = Path(d) / "save.pkl" T._save(save_path) with self.assertRaises(FileExistsError): - new_t = Derived(a=3, b='bar') + new_t = Derived(a=3, b="bar") new_t._save(save_path) got_T = Derived._load(save_path) self.assertEqual(T, got_T) - bad_T = BadDerived(a=2, b='hi') - with self.assertRaises(NotImplementedError): bad_T._save(Path(d) / 'no_save.pkl') + bad_T = BadDerived(a=2, b="hi") + with self.assertRaises(NotImplementedError): + bad_T._save(Path(d) / "no_save.pkl") # This should error as that pickler isn't supported. - with self.assertRaises(FileNotFoundError): got_T = Derived._load(Path(d) / 'no_save.pkl') - with self.assertRaises(IsADirectoryError): got_T = Derived._load(Path(d)) + with self.assertRaises(FileNotFoundError): + got_T = Derived._load(Path(d) / "no_save.pkl") + with self.assertRaises(IsADirectoryError): + got_T = Derived._load(Path(d)) # This should error as dill isn't installed. - with self.assertRaises(ImportError): bad_T = DillDerived(a=3, b='baz') - T._PICKLER = 'dill' - with self.assertRaises(ImportError): T._save(Path(d) / 'no_save.pkl') - Derived._PICKLER = 'dill' - with self.assertRaises(ImportError): got_T = Derived._load(save_path) - -if __name__ == '__main__': unittest.main() + with self.assertRaises(ImportError): + bad_T = DillDerived(a=3, b="baz") + T._PICKLER = "dill" + with self.assertRaises(ImportError): + T._save(Path(d) / "no_save.pkl") + Derived._PICKLER = "dill" + with self.assertRaises(ImportError): + got_T = Derived._load(save_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_seedable_mixin.py b/tests/test_seedable_mixin.py index 051dae7..0fa8432 100644 --- a/tests/test_seedable_mixin.py +++ b/tests/test_seedable_mixin.py @@ -1,12 +1,14 @@ +import random import unittest -import random, numpy as np +import numpy as np from mixins import SeedableMixin + class SeedableDerived(SeedableMixin): def __init__(self): - self.foo = 'foo' + self.foo = "foo" # Doesn't call super().__init__()! Should still work in this case. def gen_random_num(self): @@ -21,21 +23,21 @@ def decorated_gen_random_num(self): def decorated_auto_key(self): return random.random() + class TestSeedableMixin(unittest.TestCase): def test_constructs(self): - T = SeedableMixin() - T = SeedableDerived() + SeedableMixin() + SeedableDerived() def test_responds_to_methods(self): T = SeedableMixin() T._seed() - T._last_seed('foo') + T._last_seed("foo") T = SeedableDerived() T._seed() - T._last_seed('foo') - + T._last_seed("foo") def test_seeding_freezes_randomness(self): T = SeedableDerived() @@ -93,7 +95,6 @@ def test_decorated_seeding_freezes_randomness(self): self.assertEqual(seeded_1_1, seeded_1_3) self.assertEqual(seeded_2_1, seeded_2_3) - def test_seeds_follow_consistent_sequence(self): T = SeedableDerived() @@ -119,12 +120,12 @@ def test_seeds_follow_consistent_sequence(self): def test_get_last_seed(self): T = SeedableDerived() - key = 'key' - non_key = 'not_key' + key = "key" + non_key = "not_key" seed_key_early = 1 - seed_key_late = 1 - seed_non_key = 2 + seed_key_late = 1 + seed_non_key = 2 T._seed() @@ -142,5 +143,6 @@ def test_get_last_seed(self): self.assertEqual(idx, 4) self.assertEqual(seed, seed_key_late) -if __name__ == '__main__': unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_swapcacheable_mixin.py b/tests/test_swapcacheable_mixin.py index e5f4a9e..eee0627 100644 --- a/tests/test_swapcacheable_mixin.py +++ b/tests/test_swapcacheable_mixin.py @@ -2,9 +2,11 @@ from mixins import SwapcacheableMixin + class TestSwapcacheableMixin(unittest.TestCase): def test_constructs(self): - T = SwapcacheableMixin() + SwapcacheableMixin() -if __name__ == '__main__': unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tensorable_mixin.py b/tests/test_tensorable_mixin.py index bfcf1c9..9e485a0 100644 --- a/tests/test_tensorable_mixin.py +++ b/tests/test_tensorable_mixin.py @@ -1,14 +1,12 @@ import unittest -try: - import torch - from mixins import TensorableMixin +from mixins import TensorableMixin - class TestTensorableMixin(unittest.TestCase): - def test_constructs(self): - T = TensorableMixin() -except ImportError: pass +class TestTensorableMixin(unittest.TestCase): + def test_constructs(self): + TensorableMixin() -if __name__ == '__main__': unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py index 28a40bf..4b70b70 100644 --- a/tests/test_timeable_mixin.py +++ b/tests/test_timeable_mixin.py @@ -1,16 +1,18 @@ -import unittest import time +import unittest + import numpy as np from mixins import TimeableMixin + class TimeableDerived(TimeableMixin): def __init__(self): - self.foo = 'foo' + self.foo = "foo" # Doesn't call super().__init__()! Should still work in this case. def uses_contextlib(self, num_seconds: int = 5): - with self._time_as('using_contextlib'): + with self._time_as("using_contextlib"): time.sleep(num_seconds) @TimeableMixin.TimeAs(key="decorated") @@ -21,72 +23,70 @@ def decorated_takes_time(self, num_seconds: int = 10): def decorated_takes_time_auto_key(self, num_seconds: int = 10): time.sleep(num_seconds) + class TestTimeableMixin(unittest.TestCase): def test_constructs(self): - T = TimeableMixin() - T = TimeableDerived() + TimeableMixin() + TimeableDerived() def test_responds_to_methods(self): T = TimeableMixin() - T._register_start('key') - T._register_end('key') + T._register_start("key") + T._register_end("key") - T._times_for('key') + T._times_for("key") - T._register_start('key') - T._time_so_far('key') - T._register_end('key') + T._register_start("key") + T._time_so_far("key") + T._register_end("key") def test_pprint_num_unit(self): - self.assertEqual((5, 'μs'), TimeableMixin._get_pprint_num_unit(5 * 1e-6)) + self.assertEqual((5, "μs"), TimeableMixin._get_pprint_num_unit(5 * 1e-6)) class Derived(TimeableMixin): - _CUTOFFS_AND_UNITS = [ - (10, 'foo'), - (2, 'bar'), - (None, 'biz') - ] + _CUTOFFS_AND_UNITS = [(10, "foo"), (2, "bar"), (None, "biz")] - self.assertEqual((3, 'biz'), Derived._get_pprint_num_unit(3, 'biz')) - self.assertEqual((3, 'foo'), Derived._get_pprint_num_unit(3/20, 'biz')) - self.assertEqual((1.2, 'biz'), Derived._get_pprint_num_unit(2.4 * 10, 'foo')) + self.assertEqual((3, "biz"), Derived._get_pprint_num_unit(3, "biz")) + self.assertEqual((3, "foo"), Derived._get_pprint_num_unit(3 / 20, "biz")) + self.assertEqual((1.2, "biz"), Derived._get_pprint_num_unit(2.4 * 10, "foo")) def test_context_manager(self): T = TimeableDerived() T.uses_contextlib(num_seconds=1) - duration = T._times_for('using_contextlib')[-1] + duration = T._times_for("using_contextlib")[-1] np.testing.assert_almost_equal(duration, 1, decimal=1) def test_times_and_profiling(self): T = TimeableDerived() T.decorated_takes_time(num_seconds=2) - duration = T._times_for('decorated')[-1] + duration = T._times_for("decorated")[-1] np.testing.assert_almost_equal(duration, 2, decimal=1) T.decorated_takes_time_auto_key(num_seconds=2) - duration = T._times_for('decorated_takes_time_auto_key')[-1] + duration = T._times_for("decorated_takes_time_auto_key")[-1] np.testing.assert_almost_equal(duration, 2, decimal=1) T.decorated_takes_time(num_seconds=1) stats = T._duration_stats - self.assertEqual({'decorated', 'decorated_takes_time_auto_key'}, set(stats.keys())) - np.testing.assert_almost_equal(1.5, stats['decorated'][0], decimal=1) - self.assertEqual(2, stats['decorated'][1]) - np.testing.assert_almost_equal(0.5, stats['decorated'][2], decimal=1) - np.testing.assert_almost_equal(2, stats['decorated_takes_time_auto_key'][0], decimal=1) - self.assertEqual(1, stats['decorated_takes_time_auto_key'][1]) - self.assertEqual(0, stats['decorated_takes_time_auto_key'][2]) + self.assertEqual({"decorated", "decorated_takes_time_auto_key"}, set(stats.keys())) + np.testing.assert_almost_equal(1.5, stats["decorated"][0], decimal=1) + self.assertEqual(2, stats["decorated"][1]) + np.testing.assert_almost_equal(0.5, stats["decorated"][2], decimal=1) + np.testing.assert_almost_equal(2, stats["decorated_takes_time_auto_key"][0], decimal=1) + self.assertEqual(1, stats["decorated_takes_time_auto_key"][1]) + self.assertEqual(0, stats["decorated_takes_time_auto_key"][2]) got_str = T._profile_durations() want_str = ( - "decorated_takes_time_auto_key: 2.0 sec\n" - "decorated: 1.5 ± 0.5 sec (x2)" + "decorated_takes_time_auto_key: 2.0 sec\n" "decorated: 1.5 ± 0.5 sec (x2)" ) self.assertEqual(want_str, got_str, msg=f"Want:\n{want_str}\nGot:\n{got_str}") -if __name__ == '__main__': unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tqdmable_mixin.py b/tests/test_tqdmable_mixin.py index 25e0cb0..c881cad 100644 --- a/tests/test_tqdmable_mixin.py +++ b/tests/test_tqdmable_mixin.py @@ -2,9 +2,11 @@ from mixins import TQDMableMixin + class TestTQDMableMixin(unittest.TestCase): def test_constructs(self): - T = TQDMableMixin() + TQDMableMixin() -if __name__ == '__main__': unittest.main() +if __name__ == "__main__": + unittest.main() From 6b60297fc6632a5322d3a06bcb94ff805d130b31 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 22 Sep 2024 17:06:58 -0400 Subject: [PATCH 3/4] Fix typo with torch cuda is_available call Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- src/mixins/tensorable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mixins/tensorable.py b/src/mixins/tensorable.py index d07a2b5..7393837 100644 --- a/src/mixins/tensorable.py +++ b/src/mixins/tensorable.py @@ -16,7 +16,7 @@ def __init__(self, *args, **kwargs): def _cuda(self, T: torch.Tensor, do_cuda: bool | None = None): if do_cuda is None: - do_cuda = self.do_cuda if hasattr(self, "do_cuda") else torch.cuda.is_available + do_cuda = self.do_cuda if hasattr(self, "do_cuda") else torch.cuda.is_available() return T.cuda() if do_cuda else T From 8a303a3bbb83f603dec1aa979afcf1a5be03de9b Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 22 Sep 2024 17:07:14 -0400 Subject: [PATCH 4/4] Fix typo with torch cuda is_available call Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- src/mixins/tensorable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mixins/tensorable.py b/src/mixins/tensorable.py index 7393837..c603fe1 100644 --- a/src/mixins/tensorable.py +++ b/src/mixins/tensorable.py @@ -12,7 +12,7 @@ class TensorableMixin: Tensor_T = Union[torch.Tensor, tuple["Tensor_T"], dict[Hashable, "Tensor_T"]] def __init__(self, *args, **kwargs): - self.do_cuda = kwargs.get("do_cuda", torch.cuda.is_available) + self.do_cuda = kwargs.get("do_cuda", torch.cuda.is_available()) def _cuda(self, T: torch.Tensor, do_cuda: bool | None = None): if do_cuda is None: