diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 9718018c..341f3778 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -47,8 +47,12 @@ jobs:
       - name: Install dependencies
         run: pdm install
       - name: Test with pytest (very fast)
+        env:
+          JAX_PLATFORMS: cpu
         run: pdm run pytest -v --shorter-than=1.0 --cov=project --cov-report=xml --cov-append
       - name: Test with pytest (fast)
+        env:
+          JAX_PLATFORMS: cpu
         run: pdm run pytest -v --cov=project --cov-report=xml --cov-append
       - name: Store coverage report as an artifact
diff --git a/conftest.py b/conftest.py
index 9f7490b3..ff0ff888 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,14 +1,9 @@
+from pathlib import Path
+
 import pytest
 
 
 def pytest_addoption(parser: pytest.Parser):
-    from argparse import BooleanOptionalAction
-
-    parser.addoption(
-        "--gen-missing",
-        action=BooleanOptionalAction,
-        help="Whether to generate missing regression files or raise an error when a regression file is missing.",
-    )
     parser.addoption(
         "--shorter-than",
         action="store",
@@ -18,6 +13,14 @@ def pytest_addoption(parser: pytest.Parser):
     )
 
 
+def pytest_ignore_collect(path: str):
+    p = Path(path)
+    # fixme: Trying to fix doctest issues for project/configs/algorithm/lr_scheduler/__init__.py::project.configs.algorithm.lr_scheduler.StepLRConfig
+    if p.name in ["lr_scheduler", "optimizer"] and "configs" in p.parts:
+        return True
+    return False
+
+
 def pytest_configure(config: pytest.Config):
     config.addinivalue_line("markers", "fast: mark test as fast to run (after fixtures are setup)")
     config.addinivalue_line(
diff --git a/pdm.lock b/pdm.lock
index 1a03ae58..12a8b637 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "dev"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.4.1"
-content_hash = "sha256:1e5611b1f430e5820256e84761ec57ca9de8a29a13612e23f80653c080095c5b"
+content_hash = "sha256:805e3f5f1a98de3530f8ec547141537d7d30b7d7d7ca5a3b5f9477809327ecdd"
 
 [[package]]
 name = "absl-py"
@@ -20,7 +20,7 @@ files = [
 
 [[package]]
 name = "aiohttp"
-version = "3.9.3"
+version = "3.9.5"
 requires_python = ">=3.8"
 summary = "Async http client/server framework (asyncio)"
 groups = ["default"]
 dependencies = [
     "aiosignal>=1.1.2",
     "attrs>=17.3.0",
     "frozenlist<2.0,>=1.1.1",
     "multidict<7.0,>=4.5",
     "yarl<2.0,>=1.0",
 ]
 files = [
-    {file = "aiohttp-3.9.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:38a19bc3b686ad55804ae931012f78f7a534cce165d089a2059f658f6c91fa60"},
-    {file = "aiohttp-3.9.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:770d015888c2a598b377bd2f663adfd947d78c0124cfe7b959e1ef39f5b13869"},
-    {file = "aiohttp-3.9.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ee43080e75fc92bf36219926c8e6de497f9b247301bbf88c5c7593d931426679"},
-    {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52df73f14ed99cee84865b95a3d9e044f226320a87af208f068ecc33e0c35b96"},
-    {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc9b311743a78043b26ffaeeb9715dc360335e5517832f5a8e339f8a43581e4d"},
-    {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b955ed993491f1a5da7f92e98d5dad3c1e14dc175f74517c4e610b1f2456fb11"},
-    {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:504b6981675ace64c28bf4a05a508af5cde526e36492c98916127f5a02354d53"},
-    {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6fe5571784af92b6bc2fda8d1925cccdf24642d49546d3144948a6a1ed58ca5"},
-    {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ba39e9c8627edc56544c8628cc180d88605df3892beeb2b94c9bc857774848ca"},
-    {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e5e46b578c0e9db71d04c4b506a2121c0cb371dd89af17a0586ff6769d4c58c1"},
-    {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:938a9653e1e0c592053f815f7028e41a3062e902095e5a7dc84617c87267ebd5"},
-    {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:c3452ea726c76e92f3b9fae4b34a151981a9ec0a4847a627c43d71a15ac32aa6"},
-    {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ff30218887e62209942f91ac1be902cc80cddb86bf00fbc6783b7a43b2bea26f"},
-    {file = "aiohttp-3.9.3-cp312-cp312-win32.whl", hash = "sha256:38f307b41e0bea3294a9a2a87833191e4bcf89bb0365e83a8be3a58b31fb7f38"},
-    {file = "aiohttp-3.9.3-cp312-cp312-win_amd64.whl", hash = "sha256:b791a3143681a520c0a17e26ae7465f1b6f99461a28019d1a2f425236e6eedb5"},
-    {file = "aiohttp-3.9.3.tar.gz", hash = "sha256:90842933e5d1ff760fae6caca4b2b3edba53ba8f4b71e95dacf2818a2aca06f7"},
+    {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"},
+    {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"},
+    {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"},
+    {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"},
+    {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"},
+    {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"},
+    {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"},
+    {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"},
+    {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"},
+    {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"},
+    {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"},
+    {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"},
+    {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"},
+    {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"},
+    {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"},
+    {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"},
 ]
 
 [[package]]
@@ -86,7 +86,7 @@ files = [
 
 [[package]]
 name = "anyio"
-version = "4.3.0"
+version = "4.4.0"
 requires_python = ">=3.8"
 summary = "High level compatibility layer for multiple asynchronous event loop implementations"
 groups = ["default"]
@@ -95,18 +95,8 @@ dependencies = [
     "sniffio>=1.1",
 ]
 files = [
-    {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"},
-    {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"},
-]
-
-[[package]]
-name = "appdirs"
-version = "1.4.4"
-summary = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
-groups = ["default"]
-files = [
-    {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"},
-    {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"},
+    {file = "anyio-4.4.0-py3-none-any.whl", hash = "sha256:c1b2d8f46a8a812513012e1107cb0e68c17159a7a594208005a57dc776e1bdc7"},
+    {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"},
 ]
 
 [[package]]
@@ -167,35 +157,35 @@ files = [
 
 [[package]]
 name = "blinker"
-version = "1.8.1"
+version = "1.8.2"
 requires_python = ">=3.8"
 summary = "Fast, simple object-to-object and broadcast signaling"
 groups = ["default"]
 files = [
-    {file = "blinker-1.8.1-py3-none-any.whl", hash = "sha256:5f1cdeff423b77c31b89de0565cd03e5275a03028f44b2b15f912632a58cced6"},
-    {file = "blinker-1.8.1.tar.gz", hash = "sha256:da44ec748222dcd0105ef975eed946da197d5bdf8bafb6aa92f5bc89da63fa25"},
+    {file = "blinker-1.8.2-py3-none-any.whl", hash = "sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01"},
+    {file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"},
 ]
 
 [[package]]
 name = "boto3"
-version = "1.34.66"
-requires_python = ">= 3.8"
+version = "1.34.116"
+requires_python = ">=3.8"
 summary = "The AWS SDK for Python"
 groups = ["default"]
 dependencies = [
-    "botocore<1.35.0,>=1.34.66",
+    "botocore<1.35.0,>=1.34.116",
     "jmespath<2.0.0,>=0.7.1",
     "s3transfer<0.11.0,>=0.10.0",
 ]
 files = [
-    {file = "boto3-1.34.66-py3-none-any.whl", hash = "sha256:036989117c0bc4029daaa4cf713c4ff8c227b3eac6ef0e2118eb4098c114080e"},
-    {file = "boto3-1.34.66.tar.gz", hash = "sha256:b1d6be3d5833e56198dc635ff4b428b93e5a2a2bd9bc4d94581a572a1ce97cfe"},
+    {file = "boto3-1.34.116-py3-none-any.whl", hash = "sha256:e7f5ab2d1f1b90971a2b9369760c2c6bae49dae98c084a5c3f5c78e3968ace15"},
+    {file = "boto3-1.34.116.tar.gz", hash = "sha256:53cb8aeb405afa1cd2b25421e27a951aeb568026675dec020587861fac96ac87"},
 ]
 
 [[package]]
 name = "botocore"
-version = "1.34.66"
-requires_python = ">= 3.8"
+version = "1.34.116"
+requires_python = ">=3.8"
 summary = "Low-level, data-driven core of boto 3."
groups = ["default"] dependencies = [ @@ -204,13 +194,13 @@ dependencies = [ "urllib3!=2.2.0,<3,>=1.25.4; python_version >= \"3.10\"", ] files = [ - {file = "botocore-1.34.66-py3-none-any.whl", hash = "sha256:92560f8fbdaa9dd221212a3d3a7609219ba0bbf308c13571674c0cda9d8f39e1"}, - {file = "botocore-1.34.66.tar.gz", hash = "sha256:fd7d8742007c220f897cb126b8916ca0cf3724a739d4d716aa5385d7f9d8aeb1"}, + {file = "botocore-1.34.116-py3-none-any.whl", hash = "sha256:ec4d42c816e9b2d87a2439ad277e7dda16a4a614ef6839cf66f4c1a58afa547c"}, + {file = "botocore-1.34.116.tar.gz", hash = "sha256:269cae7ba99081519a9f87d7298e238d9e68ba94eb4f8ddfa906224c34cb8b6c"}, ] [[package]] name = "brax" -version = "0.10.3" +version = "0.10.4" summary = "A differentiable physics engine written in JAX." groups = ["default"] dependencies = [ @@ -239,8 +229,8 @@ dependencies = [ "typing-extensions", ] files = [ - {file = "brax-0.10.3-py3-none-any.whl", hash = "sha256:22fae3e2d33944db8e256e7abbd1f76f3fa19cf8ec9c7d2d43bc37849ac07662"}, - {file = "brax-0.10.3.tar.gz", hash = "sha256:3a334049e3451e3bdb34696b706d0662dabccb574b4b77067b66172e0d883731"}, + {file = "brax-0.10.4-py3-none-any.whl", hash = "sha256:c47affa423ed0b2a987baef2553eeb84e701d52bfaa72695421d8b4ed9a826a5"}, + {file = "brax-0.10.4.tar.gz", hash = "sha256:6646bb5e280d3de2301f4908f236a14333817bdba5c7ec7faf38d4e8a627aec8"}, ] [[package]] @@ -364,28 +354,28 @@ files = [ [[package]] name = "contourpy" -version = "1.2.0" +version = "1.2.1" requires_python = ">=3.9" summary = "Python library for calculating contours of 2D quadrilateral grids" groups = ["default"] dependencies = [ - "numpy<2.0,>=1.20", + "numpy>=1.20", ] files = [ - {file = "contourpy-1.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:575bcaf957a25d1194903a10bc9f316c136c19f24e0985a2b9b5608bdf5dbfe0"}, - {file = "contourpy-1.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9e6c93b5b2dbcedad20a2f18ec22cae47da0d705d454308063421a3b290d9ea4"}, - {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:464b423bc2a009088f19bdf1f232299e8b6917963e2b7e1d277da5041f33a779"}, - {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:68ce4788b7d93e47f84edd3f1f95acdcd142ae60bc0e5493bfd120683d2d4316"}, - {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d7d1f8871998cdff5d2ff6a087e5e1780139abe2838e85b0b46b7ae6cc25399"}, - {file = "contourpy-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e739530c662a8d6d42c37c2ed52a6f0932c2d4a3e8c1f90692ad0ce1274abe0"}, - {file = "contourpy-1.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:247b9d16535acaa766d03037d8e8fb20866d054d3c7fbf6fd1f993f11fc60ca0"}, - {file = "contourpy-1.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:461e3ae84cd90b30f8d533f07d87c00379644205b1d33a5ea03381edc4b69431"}, - {file = "contourpy-1.2.0-cp312-cp312-win32.whl", hash = "sha256:1c2559d6cffc94890b0529ea7eeecc20d6fadc1539273aa27faf503eb4656d8f"}, - {file = "contourpy-1.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:491b1917afdd8638a05b611a56d46587d5a632cabead889a5440f7c638bc6ed9"}, - {file = "contourpy-1.2.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:be16975d94c320432657ad2402f6760990cb640c161ae6da1363051805fa8108"}, - {file = "contourpy-1.2.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b95a225d4948b26a28c08307a60ac00fb8671b14f2047fc5476613252a129776"}, - 
{file = "contourpy-1.2.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:0d7e03c0f9a4f90dc18d4e77e9ef4ec7b7bbb437f7f675be8e530d65ae6ef956"}, - {file = "contourpy-1.2.0.tar.gz", hash = "sha256:171f311cb758de7da13fc53af221ae47a5877be5a0843a9fe150818c51ed276a"}, + {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"}, + {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"}, + {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"}, + {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"}, + {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"}, + {file = "contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"}, + {file = "contourpy-1.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609"}, + {file = "contourpy-1.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3"}, + {file = "contourpy-1.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f"}, + {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"}, ] [[package]] @@ -486,16 +476,16 @@ files = [ [[package]] name = "deepdiff" -version = "6.7.1" -requires_python = ">=3.7" +version = "7.0.1" +requires_python = ">=3.8" summary = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other." 
groups = ["default"] dependencies = [ - "ordered-set<4.2.0,>=4.0.2", + "ordered-set<4.2.0,>=4.1.0", ] files = [ - {file = "deepdiff-6.7.1-py3-none-any.whl", hash = "sha256:58396bb7a863cbb4ed5193f548c56f18218060362311aa1dc36397b2f25108bd"}, - {file = "deepdiff-6.7.1.tar.gz", hash = "sha256:b367e6fa6caac1c9f500adc79ada1b5b1242c50d5f716a1a4362030197847d30"}, + {file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"}, + {file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"}, ] [[package]] @@ -560,24 +550,24 @@ files = [ [[package]] name = "etils" -version = "1.8.0" +version = "1.9.0" requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] files = [ - {file = "etils-1.8.0-py3-none-any.whl", hash = "sha256:f31d7f27a889457eaa44eab18ce836d24fd6d40dbbb167d38879b7296f6456ea"}, - {file = "etils-1.8.0.tar.gz", hash = "sha256:fb478f57fec202e260e54c9192b317692fd63db2d11d993e70bcdffa29cccd58"}, + {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, + {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, ] [[package]] name = "etils" -version = "1.8.0" +version = "1.9.0" extras = ["epath", "epy"] requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] dependencies = [ - "etils==1.8.0", + "etils==1.9.0", "etils[epy]", "fsspec", "importlib-resources", @@ -586,19 +576,19 @@ dependencies = [ "zipp", ] files = [ - {file = "etils-1.8.0-py3-none-any.whl", hash = "sha256:f31d7f27a889457eaa44eab18ce836d24fd6d40dbbb167d38879b7296f6456ea"}, - {file = "etils-1.8.0.tar.gz", hash = "sha256:fb478f57fec202e260e54c9192b317692fd63db2d11d993e70bcdffa29cccd58"}, + {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, + {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, ] [[package]] name = "etils" -version = "1.8.0" +version = "1.9.0" extras = ["epath"] requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] dependencies = [ - "etils==1.8.0", + "etils==1.9.0", "etils[epy]", "fsspec", "importlib-resources", @@ -606,35 +596,35 @@ dependencies = [ "zipp", ] files = [ - {file = "etils-1.8.0-py3-none-any.whl", hash = "sha256:f31d7f27a889457eaa44eab18ce836d24fd6d40dbbb167d38879b7296f6456ea"}, - {file = "etils-1.8.0.tar.gz", hash = "sha256:fb478f57fec202e260e54c9192b317692fd63db2d11d993e70bcdffa29cccd58"}, + {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, + {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, ] [[package]] name = "etils" -version = "1.8.0" +version = "1.9.0" extras = ["epy"] requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] dependencies = [ - "etils==1.8.0", + "etils==1.9.0", "typing-extensions", ] files = [ - {file = "etils-1.8.0-py3-none-any.whl", hash = "sha256:f31d7f27a889457eaa44eab18ce836d24fd6d40dbbb167d38879b7296f6456ea"}, - {file = "etils-1.8.0.tar.gz", hash = "sha256:fb478f57fec202e260e54c9192b317692fd63db2d11d993e70bcdffa29cccd58"}, + {file = "etils-1.9.0-py3-none-any.whl", hash = 
"sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, + {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, ] [[package]] name = "execnet" -version = "2.0.2" -requires_python = ">=3.7" +version = "2.1.1" +requires_python = ">=3.8" summary = "execnet: rapid multi-Python deployment" groups = ["dev"] files = [ - {file = "execnet-2.0.2-py3-none-any.whl", hash = "sha256:88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41"}, - {file = "execnet-2.0.2.tar.gz", hash = "sha256:cc59bc4423742fd71ad227122eb0dd44db51efb3dc4095b45ac9a08c770096af"}, + {file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"}, + {file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"}, ] [[package]] @@ -664,13 +654,13 @@ files = [ [[package]] name = "filelock" -version = "3.13.1" +version = "3.14.0" requires_python = ">=3.8" summary = "A platform independent file lock." -groups = ["default"] +groups = ["default", "dev"] files = [ - {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, - {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, + {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, + {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, ] [[package]] @@ -693,20 +683,20 @@ files = [ [[package]] name = "flask-cors" -version = "4.0.0" +version = "4.0.1" summary = "A Flask extension adding a decorator for CORS support" groups = ["default"] dependencies = [ "Flask>=0.9", ] files = [ - {file = "Flask-Cors-4.0.0.tar.gz", hash = "sha256:f268522fcb2f73e2ecdde1ef45e2fd5c71cc48fe03cffb4b441c6d1b40684eb0"}, - {file = "Flask_Cors-4.0.0-py2.py3-none-any.whl", hash = "sha256:bc3492bfd6368d27cfe79c7821df5a8a319e1a6d5eab277a3794be19bdc51783"}, + {file = "Flask_Cors-4.0.1-py2.py3-none-any.whl", hash = "sha256:f2a704e4458665580c074b714c4627dd5a306b333deb9074d0b1794dfa2fb677"}, + {file = "flask_cors-4.0.1.tar.gz", hash = "sha256:eeb69b342142fdbf4766ad99357a7f3876a2ceb77689dc10ff912aac06c389e4"}, ] [[package]] name = "flax" -version = "0.8.3" +version = "0.8.4" requires_python = ">=3.9" summary = "Flax: A neural network library for JAX designed for flexibility" groups = ["default"] @@ -724,27 +714,27 @@ dependencies = [ "typing-extensions>=4.2", ] files = [ - {file = "flax-0.8.3-py3-none-any.whl", hash = "sha256:87933bc2aa5e70e92ac227a9bd2adeea6b9960a84eade18139a534851ddaf91d"}, - {file = "flax-0.8.3.tar.gz", hash = "sha256:5b051b4c27f4c0c43deb80c5e2509d2ee5ed4441c54b28940855119f83ac7d0f"}, + {file = "flax-0.8.4-py3-none-any.whl", hash = "sha256:785707e3a48f782a1bec17aa665697b7618c113a357d5f975791dcb090d818d8"}, + {file = "flax-0.8.4.tar.gz", hash = "sha256:968683f850198e1aa5eb2d9d1e20bead880ef7423c14f042db9d60848cb1c90b"}, ] [[package]] name = "fonttools" -version = "4.50.0" +version = "4.53.0" requires_python = ">=3.8" summary = "Tools to manipulate font files" groups = ["default"] files = [ - {file = "fonttools-4.50.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b4a886a6dbe60100ba1cd24de962f8cd18139bd32808da80de1fa9f9f27bf1dc"}, - {file = "fonttools-4.50.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:b2ca1837bfbe5eafa11313dbc7edada79052709a1fffa10cea691210af4aa1fa"}, - {file = "fonttools-4.50.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0493dd97ac8977e48ffc1476b932b37c847cbb87fd68673dee5182004906828"}, - {file = "fonttools-4.50.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77844e2f1b0889120b6c222fc49b2b75c3d88b930615e98893b899b9352a27ea"}, - {file = "fonttools-4.50.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3566bfb8c55ed9100afe1ba6f0f12265cd63a1387b9661eb6031a1578a28bad1"}, - {file = "fonttools-4.50.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:35e10ddbc129cf61775d58a14f2d44121178d89874d32cae1eac722e687d9019"}, - {file = "fonttools-4.50.0-cp312-cp312-win32.whl", hash = "sha256:cc8140baf9fa8f9b903f2b393a6c413a220fa990264b215bf48484f3d0bf8710"}, - {file = "fonttools-4.50.0-cp312-cp312-win_amd64.whl", hash = "sha256:0ccc85fd96373ab73c59833b824d7a73846670a0cb1f3afbaee2b2c426a8f931"}, - {file = "fonttools-4.50.0-py3-none-any.whl", hash = "sha256:48fa36da06247aa8282766cfd63efff1bb24e55f020f29a335939ed3844d20d3"}, - {file = "fonttools-4.50.0.tar.gz", hash = "sha256:fa5cf61058c7dbb104c2ac4e782bf1b2016a8cf2f69de6e4dd6a865d2c969bb5"}, + {file = "fonttools-4.53.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d8f191a17369bd53a5557a5ee4bab91d5330ca3aefcdf17fab9a497b0e7cff7a"}, + {file = "fonttools-4.53.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:93156dd7f90ae0a1b0e8871032a07ef3178f553f0c70c386025a808f3a63b1f4"}, + {file = "fonttools-4.53.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bff98816cb144fb7b85e4b5ba3888a33b56ecef075b0e95b95bcd0a5fbf20f06"}, + {file = "fonttools-4.53.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:973d030180eca8255b1bce6ffc09ef38a05dcec0e8320cc9b7bcaa65346f341d"}, + {file = "fonttools-4.53.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c4ee5a24e281fbd8261c6ab29faa7fd9a87a12e8c0eed485b705236c65999109"}, + {file = "fonttools-4.53.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bd5bc124fae781a4422f61b98d1d7faa47985f663a64770b78f13d2c072410c2"}, + {file = "fonttools-4.53.0-cp312-cp312-win32.whl", hash = "sha256:a239afa1126b6a619130909c8404070e2b473dd2b7fc4aacacd2e763f8597fea"}, + {file = "fonttools-4.53.0-cp312-cp312-win_amd64.whl", hash = "sha256:45b4afb069039f0366a43a5d454bc54eea942bfb66b3fc3e9a2c07ef4d617380"}, + {file = "fonttools-4.53.0-py3-none-any.whl", hash = "sha256:6b4f04b1fbc01a3569d63359f2227c89ab294550de277fd09d8fca6185669fa4"}, + {file = "fonttools-4.53.0.tar.gz", hash = "sha256:c93ed66d32de1559b6fc348838c7572d5c0ac1e4a258e76763a5caddd8944002"}, ] [[package]] @@ -778,7 +768,7 @@ name = "fsspec" version = "2023.12.2" requires_python = ">=3.8" summary = "File-system specification" -groups = ["default"] +groups = ["default", "dev"] files = [ {file = "fsspec-2023.12.2-py3-none-any.whl", hash = "sha256:d800d87f72189a745fa3d6b033b9dc4a34ad069f60ca60b943a63599f5501960"}, {file = "fsspec-2023.12.2.tar.gz", hash = "sha256:8548d39e8810b59c38014934f6b31e57f40c1b20f911f4cc2b85389c7e9bf0cb"}, @@ -803,7 +793,7 @@ files = [ [[package]] name = "gdown" -version = "5.1.0" +version = "5.2.0" requires_python = ">=3.8" summary = "Google Drive Public File/Folder Downloader" groups = ["default"] @@ -814,8 +804,8 @@ dependencies = [ "tqdm", ] files = [ - {file = "gdown-5.1.0-py3-none-any.whl", hash = 
"sha256:421530fd238fa15d41ba43219a79fdc28efe8ac11022173abad333701b77de2c"}, - {file = "gdown-5.1.0.tar.gz", hash = "sha256:550a72dc5ca2819fe4bcc15d80d05d7c98c0b90e57256254b77d0256b9df4683"}, + {file = "gdown-5.2.0-py3-none-any.whl", hash = "sha256:33083832d82b1101bdd0e9df3edd0fbc0e1c5f14c9d8c38d2a35bf1683b526d6"}, + {file = "gdown-5.2.0.tar.gz", hash = "sha256:2145165062d85520a3cd98b356c9ed522c5e7984d408535409fd46f94defc787"}, ] [[package]] @@ -834,7 +824,7 @@ files = [ [[package]] name = "gitpython" -version = "3.1.42" +version = "3.1.43" requires_python = ">=3.7" summary = "GitPython is a Python library used to interact with Git repositories" groups = ["default"] @@ -842,8 +832,8 @@ dependencies = [ "gitdb<5,>=4.0.1", ] files = [ - {file = "GitPython-3.1.42-py3-none-any.whl", hash = "sha256:1bf9cd7c9e7255f77778ea54359e54ac22a72a5b51288c457c881057b7bb9ecd"}, - {file = "GitPython-3.1.42.tar.gz", hash = "sha256:2d99869e0fef71a73cbd242528105af1d6c1b108c60dfabd994bf292f76c3ceb"}, + {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"}, + {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"}, ] [[package]] @@ -865,21 +855,21 @@ files = [ [[package]] name = "grpcio" -version = "1.63.0" +version = "1.64.0" requires_python = ">=3.8" summary = "HTTP/2-based RPC framework" groups = ["default"] files = [ - {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, - {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, - {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, - {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, - {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, + {file = "grpcio-1.64.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:1ce4cd5a61d4532651079e7aae0fedf9a80e613eed895d5b9743e66b52d15812"}, + {file = "grpcio-1.64.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:650a8150a9b288f40d5b7c1d5400cc11724eae50bd1f501a66e1ea949173649b"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8de0399b983f8676a7ccfdd45e5b2caec74a7e3cc576c6b1eecf3b3680deda5e"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46b8b43ba6a2a8f3103f103f97996cad507bcfd72359af6516363c48793d5a7b"}, + {file = 
"grpcio-1.64.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a54362f03d4dcfae63be455d0a7d4c1403673498b92c6bfe22157d935b57c7a9"}, + {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1f8ea18b928e539046bb5f9c124d717fbf00cc4b2d960ae0b8468562846f5aa1"}, + {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c56c91bd2923ddb6e7ed28ebb66d15633b03e0df22206f22dfcdde08047e0a48"}, + {file = "grpcio-1.64.0-cp312-cp312-win32.whl", hash = "sha256:874c741c8a66f0834f653a69e7e64b4e67fcd4a8d40296919b93bab2ccc780ba"}, + {file = "grpcio-1.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:0da1d921f8e4bcee307aeef6c7095eb26e617c471f8cb1c454fd389c5c296d1e"}, + {file = "grpcio-1.64.0.tar.gz", hash = "sha256:257baf07f53a571c215eebe9679c3058a313fd1d1f7c4eede5a8660108c52d9c"}, ] [[package]] @@ -1002,7 +992,7 @@ files = [ [[package]] name = "hydra-zen" -version = "0.12.1" +version = "0.13.0" requires_python = ">=3.8" summary = "Configurable, reproducible, and scalable workflows in Python, via Hydra" groups = ["default"] @@ -1012,24 +1002,24 @@ dependencies = [ "typing-extensions!=4.6.0,>=4.1.0", ] files = [ - {file = "hydra_zen-0.12.1-py3-none-any.whl", hash = "sha256:23ad9df648f0db1747c15cc4c90c8e46be2809099493b5d00dcc6e49979d2545"}, - {file = "hydra_zen-0.12.1.tar.gz", hash = "sha256:e27979c16505a654918c8c36cc190cc61ded14054a4d84efd65c0b83e7078050"}, + {file = "hydra_zen-0.13.0-py3-none-any.whl", hash = "sha256:6050b62be96d2a47b2abf0e9c0ebcce1e9a4e259e173870338ab049b833f26cf"}, + {file = "hydra_zen-0.13.0.tar.gz", hash = "sha256:1b53d74aa1f0baa04fafdac6aba7a94ae40929e7b0a5a5081d8740f74322052d"}, ] [[package]] name = "idna" -version = "3.6" +version = "3.7" requires_python = ">=3.5" summary = "Internationalized Domain Names in Applications (IDNA)" groups = ["default"] files = [ - {file = "idna-3.6-py3-none-any.whl", hash = "sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f"}, - {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"}, + {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, + {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] [[package]] name = "imageio" -version = "2.34.0" +version = "2.34.1" requires_python = ">=3.8" summary = "Library for reading and writing a wide range of image, video, scientific, and volumetric data formats." 
groups = ["default"] @@ -1039,13 +1029,13 @@ dependencies = [ "pillow>=8.3.2", ] files = [ - {file = "imageio-2.34.0-py3-none-any.whl", hash = "sha256:08082bf47ccb54843d9c73fe9fc8f3a88c72452ab676b58aca74f36167e8ccba"}, - {file = "imageio-2.34.0.tar.gz", hash = "sha256:ae9732e10acf807a22c389aef193f42215718e16bd06eed0c5bb57e1034a4d53"}, + {file = "imageio-2.34.1-py3-none-any.whl", hash = "sha256:408c1d4d62f72c9e8347e7d1ca9bc11d8673328af3913868db3b828e28b40a4c"}, + {file = "imageio-2.34.1.tar.gz", hash = "sha256:f13eb76e4922f936ac4a7fec77ce8a783e63b93543d4ea3e40793a6cabd9ac7d"}, ] [[package]] name = "imageio-ffmpeg" -version = "0.4.9" +version = "0.5.0" requires_python = ">=3.5" summary = "FFMPEG wrapper for Python" groups = ["default"] @@ -1054,12 +1044,12 @@ dependencies = [ "setuptools", ] files = [ - {file = "imageio-ffmpeg-0.4.9.tar.gz", hash = "sha256:39bcd1660118ef360fa4047456501071364661aa9d9021d3d26c58f1ee2081f5"}, - {file = "imageio_ffmpeg-0.4.9-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl", hash = "sha256:24095e882a126a0d217197b86265f821b4bb3cf9004104f67c1384a2b4b49168"}, - {file = "imageio_ffmpeg-0.4.9-py3-none-manylinux2010_x86_64.whl", hash = "sha256:2996c64af3e5489227096580269317719ea1a8121d207f2e28d6c24ebc4a253e"}, - {file = "imageio_ffmpeg-0.4.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7eead662d2f46d748c0ab446b68f423eb63d2b54d0a8ef96f80607245540866d"}, - {file = "imageio_ffmpeg-0.4.9-py3-none-win32.whl", hash = "sha256:b6de1e18911687c538d5585d8287ab1a23624ca9dc2044fcc4607de667bcf11e"}, - {file = "imageio_ffmpeg-0.4.9-py3-none-win_amd64.whl", hash = "sha256:7e900c695c6541b1cb17feb1baacd4009b30a53a45b81c23d53a67ab13ffb766"}, + {file = "imageio-ffmpeg-0.5.0.tar.gz", hash = "sha256:75c9c45079510cfeb4849a17fcd3edd4f14062ea6b69c5b62695fb2075295c87"}, + {file = "imageio_ffmpeg-0.5.0-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl", hash = "sha256:e9aba9cdd01164a50a4cfb1b825fc8769151a0d3b5b5a7d5d50ff9fcda7eee9c"}, + {file = "imageio_ffmpeg-0.5.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:ba55f392ee5db9eb0a6d7699e0060a2edcaa7dbc740ca29671bdc8dbb763ca3b"}, + {file = "imageio_ffmpeg-0.5.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9c813be7d6a24236bb68aeab249ea67f5a7fdf7d86988855578247694c42e94a"}, + {file = "imageio_ffmpeg-0.5.0-py3-none-win32.whl", hash = "sha256:c4a3b32fc38d4a26c15582bf12246ddae060932889da5c9da487cc675740039b"}, + {file = "imageio_ffmpeg-0.5.0-py3-none-win_amd64.whl", hash = "sha256:8135f4d146094b62b31721ca53fe943f4134e3578e22015468e3df595217c24b"}, ] [[package]] @@ -1078,7 +1068,7 @@ name = "iniconfig" version = "2.0.0" requires_python = ">=3.7" summary = "brain-dead simple config-ini parsing" -groups = ["dev"] +groups = ["default", "dev"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -1100,15 +1090,29 @@ files = [ {file = "inquirer-3.2.4.tar.gz", hash = "sha256:33b09efc1b742b9d687b540296a8b6a3f773399673321fcc2ab0eb4c109bf9b5"}, ] +[[package]] +name = "intel-openmp" +version = "2021.4.0" +summary = "IntelĀ® OpenMP* Runtime Library" +groups = ["default", "dev"] +marker = "platform_system == \"Windows\"" +files = [ + {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = 
"sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, +] + [[package]] name = "itsdangerous" -version = "2.1.2" -requires_python = ">=3.7" +version = "2.2.0" +requires_python = ">=3.8" summary = "Safely pass data to untrusted environments and back." groups = ["default"] files = [ - {file = "itsdangerous-2.1.2-py3-none-any.whl", hash = "sha256:2c2349112351b88699d8d4b6b075022c0808887cb7ad10069318a8b0bc88db44"}, - {file = "itsdangerous-2.1.2.tar.gz", hash = "sha256:5dbbc68b317e5e42f327f9021763545dc3fc3bfe22e6deb96aaf1fc38874156a"}, + {file = "itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef"}, + {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"}, ] [[package]] @@ -1222,16 +1226,16 @@ files = [ [[package]] name = "jinja2" -version = "3.1.3" +version = "3.1.4" requires_python = ">=3.7" summary = "A very fast and expressive template engine." -groups = ["default"] +groups = ["default", "dev"] dependencies = [ "MarkupSafe>=2.0", ] files = [ - {file = "Jinja2-3.1.3-py3-none-any.whl", hash = "sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa"}, - {file = "Jinja2-3.1.3.tar.gz", hash = "sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90"}, + {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, + {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, ] [[package]] @@ -1347,7 +1351,7 @@ files = [ [[package]] name = "lightning-cloud" -version = "0.5.65" +version = "0.5.69" requires_python = ">=3.7.0" summary = "Lightning Cloud" groups = ["default"] @@ -1355,6 +1359,7 @@ dependencies = [ "boto3", "click", "fastapi", + "protobuf", "pyjwt", "python-multipart", "requests", @@ -1365,13 +1370,13 @@ dependencies = [ "websocket-client", ] files = [ - {file = "lightning_cloud-0.5.65-py3-none-any.whl", hash = "sha256:cd0024e48c5e6807c0015052a1ece1ac6b25fb165c430fbd682f9e94b31fbe5d"}, - {file = "lightning_cloud-0.5.65.tar.gz", hash = "sha256:8ddd84c270ca486edc1178cc68c293029092c9c7c893d1ab9e085cf9c17179f3"}, + {file = "lightning_cloud-0.5.69-py3-none-any.whl", hash = "sha256:8e26b534c3970ea939d37c284e9de5d0c880339a49d18c9b9181c0e093f95fd1"}, + {file = "lightning_cloud-0.5.69.tar.gz", hash = "sha256:0baeef05c06a6d89c482abea1826cc3e3bec48901d10cc2749f39b344e6f1dc3"}, ] [[package]] name = "lightning-utilities" -version = "0.11.0" +version = "0.11.2" requires_python = ">=3.8" summary = "Lightning toolbox for across the our ecosystem." 
groups = ["default"] @@ -1381,8 +1386,8 @@ dependencies = [ "typing-extensions", ] files = [ - {file = "lightning-utilities-0.11.0.tar.gz", hash = "sha256:dd704795785ceba1e0cd60ba3a9b0553c7902ec9efc1578a74e893a291416e62"}, - {file = "lightning_utilities-0.11.0-py3-none-any.whl", hash = "sha256:bf576a421027fdbaf48e80cbc2fdf900a3316a469748a953c33a8ca2b2718a20"}, + {file = "lightning-utilities-0.11.2.tar.gz", hash = "sha256:adf4cf9c5d912fe505db4729e51d1369c6927f3a8ac55a9dff895ce5c0da08d9"}, + {file = "lightning_utilities-0.11.2-py3-none-any.whl", hash = "sha256:541f471ed94e18a28d72879338c8c52e873bb46f4c47644d89228faeb6751159"}, ] [[package]] @@ -1415,7 +1420,7 @@ name = "markupsafe" version = "2.1.5" requires_python = ">=3.7" summary = "Safely add untrusted strings to HTML/XML markup." -groups = ["default"] +groups = ["default", "dev"] files = [ {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"}, {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"}, @@ -1432,7 +1437,7 @@ files = [ [[package]] name = "matplotlib" -version = "3.8.3" +version = "3.9.0" requires_python = ">=3.9" summary = "Python plotting package" groups = ["default"] @@ -1441,23 +1446,24 @@ dependencies = [ "cycler>=0.10", "fonttools>=4.22.0", "kiwisolver>=1.3.1", - "numpy<2,>=1.21", + "numpy>=1.23", "packaging>=20.0", "pillow>=8", "pyparsing>=2.3.1", "python-dateutil>=2.7", ] files = [ - {file = "matplotlib-3.8.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:09074f8057917d17ab52c242fdf4916f30e99959c1908958b1fc6032e2d0f6d4"}, - {file = "matplotlib-3.8.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5745f6d0fb5acfabbb2790318db03809a253096e98c91b9a31969df28ee604aa"}, - {file = "matplotlib-3.8.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b97653d869a71721b639714b42d87cda4cfee0ee74b47c569e4874c7590c55c5"}, - {file = "matplotlib-3.8.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:242489efdb75b690c9c2e70bb5c6550727058c8a614e4c7716f363c27e10bba1"}, - {file = "matplotlib-3.8.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:83c0653c64b73926730bd9ea14aa0f50f202ba187c307a881673bad4985967b7"}, - {file = "matplotlib-3.8.3-cp312-cp312-win_amd64.whl", hash = "sha256:ef6c1025a570354297d6c15f7d0f296d95f88bd3850066b7f1e7b4f2f4c13a39"}, - {file = "matplotlib-3.8.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fa93695d5c08544f4a0dfd0965f378e7afc410d8672816aff1e81be1f45dbf2e"}, - {file = "matplotlib-3.8.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9764df0e8778f06414b9d281a75235c1e85071f64bb5d71564b97c1306a2afc"}, - {file = "matplotlib-3.8.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:5e431a09e6fab4012b01fc155db0ce6dccacdbabe8198197f523a4ef4805eb26"}, - {file = "matplotlib-3.8.3.tar.gz", hash = "sha256:7b416239e9ae38be54b028abbf9048aff5054a9aba5416bef0bd17f9162ce161"}, + {file = "matplotlib-3.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382"}, + {file = "matplotlib-3.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84"}, + {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5"}, + {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db"}, + {file = "matplotlib-3.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7"}, + {file = "matplotlib-3.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e"}, + {file = "matplotlib-3.9.0.tar.gz", hash = "sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a"}, ] [[package]] @@ -1471,6 +1477,24 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mkl" +version = "2021.4.0" +summary = "IntelĀ® oneAPI Math Kernel Library" +groups = ["default", "dev"] +marker = "platform_system == \"Windows\"" +dependencies = [ + "intel-openmp==2021.*", + "tbb==2021.*", +] +files = [ + {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"}, + {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"}, + {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"}, + {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"}, + {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"}, +] + [[package]] name = "ml-collections" version = "0.1.1" @@ -1530,7 +1554,7 @@ files = [ name = "mpmath" version = "1.3.0" summary = "Python library for arbitrary-precision floating-point arithmetic" -groups = ["default"] +groups = ["default", "dev"] files = [ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, @@ -1559,7 +1583,7 @@ files = [ [[package]] name = "mujoco" -version = "3.1.4" +version = "3.1.5" requires_python = ">=3.8" summary = "MuJoCo Physics Simulator" groups = ["default"] @@ -1571,17 +1595,17 @@ dependencies = [ "pyopengl", ] files = [ - {file = "mujoco-3.1.4-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:6c1cf0d9b3edb6ac60ebb0bc2f0384dbff68ae16aae2d0e01996e83ea8de8f72"}, - {file = "mujoco-3.1.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bfabf43a22c3ea61bd65e13efba024117f287d151b0692eea103010438e55e9e"}, - {file = "mujoco-3.1.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:0cfafc98b528bb5ae43c706dec6db22b7a8818406e436c71e6bbc8d6f7064298"}, - {file = "mujoco-3.1.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36a68f95900426261887f3bde03981110f78fc5fb57268dd4a0726a2bc6ce3c6"}, - {file = "mujoco-3.1.4-cp312-cp312-win_amd64.whl", hash = "sha256:326e1709158c3c52edb094ebefd12a3f47ba04ecbd85b41e4ace5e5d23e0d741"}, - {file = "mujoco-3.1.4.tar.gz", hash = "sha256:19d78bd7332b8bf02b8d7ca35d381a9f8f1654f4c70c0d7f499c6d4d807c4059"}, + {file = "mujoco-3.1.5-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:0a78079b07e63d04f2985684ccd3a9937badba4cf51432662ff818b092442dbc"}, + {file = "mujoco-3.1.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4145c6277a1e71000a54c0bfef337c885a57452c5f0aa7cddf4b41932b639f41"}, + {file = "mujoco-3.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20bb70bfee28e026efc71f6872871c689fa2eaecc54d019ae1a21362453619cd"}, + {file = "mujoco-3.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f93bf770c3c963efe03c27b34ca59015e27ae70cdd4272a8312e583f52dbf40"}, + {file = "mujoco-3.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:8b139b1950ad52924e8666561414dd8f4f3f69f89364f1d0304371839be9264e"}, + {file = "mujoco-3.1.5.tar.gz", hash = "sha256:9099ba6001341cc9e38b7b94b8ef7a67346c7638fa3e94f520743a357891f296"}, ] [[package]] name = "mujoco-mjx" -version = "3.1.4" +version = "3.1.5" requires_python = ">=3.8" summary = "MuJoCo XLA (MJX)" groups = ["default"] @@ -1590,13 +1614,13 @@ dependencies = [ "etils[epath]", "jax", "jaxlib", - "mujoco>=3.1.4.dev0", + "mujoco>=3.1.5.dev0", "scipy", "trimesh", ] files = [ - {file = "mujoco-mjx-3.1.4.tar.gz", hash = "sha256:519e6aa0485fea75a1e96bf77ce4ee859ba7f8b83a976da4846e9c5fa09afd0c"}, - {file = "mujoco_mjx-3.1.4-py3-none-any.whl", hash = "sha256:393f4169ec68304a4e3ee3b0bd6379ee45c4fe6a504d084e43902a69ad17c188"}, + {file = "mujoco_mjx-3.1.5-py3-none-any.whl", hash = "sha256:4fc54e10c0cb811fd97584222a00ce9fa433f79d7ce46a8d7b22c8a054c35238"}, + {file = "mujoco_mjx-3.1.5.tar.gz", hash = "sha256:ee6b409d694a0a34ab93803089e3c1297ed91ae6a9461661cd1d80a9f0565880"}, ] [[package]] @@ -1638,13 +1662,13 @@ files = [ [[package]] name = "networkx" -version = "3.2.1" -requires_python = ">=3.9" +version = "3.3" +requires_python = ">=3.10" summary = "Python package for creating and manipulating graphs and networks" -groups = ["default"] +groups = ["default", "dev"] files = [ - {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, - {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, + {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"}, + {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"}, ] [[package]] @@ -1652,7 +1676,7 @@ name = "numpy" version = "1.26.4" requires_python = ">=3.9" summary = "Fundamental package for array computing in Python" -groups = ["default"] +groups = ["default", "dev"] files = [ {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, @@ -1673,7 +1697,7 @@ name = "nvidia-cublas-cu12" version = "12.1.3.1" 
 requires_python = ">=3"
 summary = "CUBLAS native runtime libraries"
-groups = ["default"]
+groups = ["default", "dev"]
 files = [
     {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"},
     {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"},
@@ -1684,7 +1708,7 @@ name = "nvidia-cuda-cupti-cu12"
 version = "12.1.105"
 requires_python = ">=3"
 summary = "CUDA profiling tools runtime libs."
-groups = ["default"]
+groups = ["default", "dev"]
 files = [
     {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"},
     {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"},
@@ -1692,13 +1716,13 @@ files = [
 
 [[package]]
 name = "nvidia-cuda-nvcc-cu12"
-version = "12.4.131"
+version = "12.5.40"
 requires_python = ">=3"
 summary = "CUDA nvcc"
 groups = ["default"]
 files = [
-    {file = "nvidia_cuda_nvcc_cu12-12.4.131-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e4c9a35435c3b0e36c6397c5696334c9ea6650e524736cbd1c5d345ea099bb04"},
-    {file = "nvidia_cuda_nvcc_cu12-12.4.131-py3-none-win_amd64.whl", hash = "sha256:aadd9fb307352bbcd5bc89b5f98c10cd78000915094882d97321f5fe36441742"},
+    {file = "nvidia_cuda_nvcc_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8347e2458c99eb9db3c392035c1781798f2593d495554106cf45502eeabc1a10"},
+    {file = "nvidia_cuda_nvcc_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:616cd3280a05657d1e40d4985058bbd4c88384b92c88a7c30228643abe7465f2"},
 ]
 
 [[package]]
@@ -1706,7 +1730,7 @@ name = "nvidia-cuda-nvrtc-cu12"
 version = "12.1.105"
 requires_python = ">=3"
 summary = "NVRTC native runtime libraries"
-groups = ["default"]
+groups = ["default", "dev"]
 marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
 files = [
     {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"},
@@ -1718,7 +1742,7 @@ name = "nvidia-cuda-runtime-cu12"
 version = "12.1.105"
 requires_python = ">=3"
 summary = "CUDA Runtime native Libraries"
-groups = ["default"]
+groups = ["default", "dev"]
 files = [
     {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"},
     {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"},
@@ -1729,7 +1753,7 @@ name = "nvidia-cudnn-cu12"
 version = "8.9.2.26"
 requires_python = ">=3"
 summary = "cuDNN runtime libraries"
-groups = ["default"]
+groups = ["default", "dev"]
 dependencies = [
     "nvidia-cublas-cu12",
 ]
@@ -1742,7 +1766,7 @@ name = "nvidia-cufft-cu12"
 version = "11.0.2.54"
 requires_python = ">=3"
 summary = "CUFFT native runtime libraries"
-groups = ["default"]
+groups = ["default", "dev"]
 files = [
     {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"},
     {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"},
@@ -1753,7 +1777,7 @@ name = "nvidia-curand-cu12"
 version = "10.3.2.106"
 requires_python = ">=3"
 summary = "CURAND native runtime libraries"
-groups = ["default"]
+groups = ["default", "dev"]
 marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
 files = [
     {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"},
@@ -1765,7 +1789,7 @@ name = "nvidia-cusolver-cu12"
 version = "11.4.5.107"
 requires_python = ">=3"
 summary = "CUDA solver native runtime libraries"
-groups = ["default"]
+groups = ["default", "dev"]
 dependencies = [
     "nvidia-cublas-cu12",
     "nvidia-cusparse-cu12",
@@ -1781,7 +1805,7 @@ name = "nvidia-cusparse-cu12"
 version = "12.1.0.106"
 requires_python = ">=3"
 summary = "CUSPARSE native runtime libraries"
-groups = ["default"]
+groups = ["default", "dev"]
 dependencies = [
     "nvidia-nvjitlink-cu12",
 ]
@@ -1792,24 +1816,24 @@ files = [
 
 [[package]]
 name = "nvidia-nccl-cu12"
-version = "2.19.3"
+version = "2.20.5"
 requires_python = ">=3"
 summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
-groups = ["default"]
+groups = ["default", "dev"]
 files = [
-    {file = "nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:a9734707a2c96443331c1e48c717024aa6678a0e2a4cb66b2c364d18cee6b48d"},
+    {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"},
+    {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"},
 ]
 
 [[package]]
 name = "nvidia-nvjitlink-cu12"
-version = "12.4.99"
+version = "12.5.40"
 requires_python = ">=3"
 summary = "Nvidia JIT LTO Library"
-groups = ["default"]
+groups = ["default", "dev"]
 files = [
-    {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_aarch64.whl", hash = "sha256:75d6498c96d9adb9435f2bbdbddb479805ddfb97b5c1b32395c694185c20ca57"},
-    {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"},
-    {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"},
+    {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"},
+    {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"},
 ]
 
 [[package]]
@@ -1817,7 +1841,7 @@ name = "nvidia-nvtx-cu12"
 version = "12.1.105"
 requires_python = ">=3"
 summary = "NVIDIA Tools Extension"
-groups = ["default"]
+groups = ["default", "dev"]
 marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
 files = [
     {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"},
@@ -1873,7 +1897,7 @@ files = [
 
 [[package]]
 name = "orbax-checkpoint"
-version = "0.5.10"
+version = "0.5.15"
 requires_python = ">=3.9"
 summary = "Orbax Checkpoint"
 groups = ["default"]
 dependencies = [
@@ -1891,8 +1915,8 @@ dependencies = [
     "typing-extensions",
 ]
 files = [
-    {file = "orbax_checkpoint-0.5.10-py3-none-any.whl", hash = "sha256:377dd952410038731486e5fbbb91704e0feb027c2d4faf1f56f5276bcb01ee51"},
-    {file = "orbax_checkpoint-0.5.10.tar.gz", hash = "sha256:2e00657383c6fcdc3209203c32d95decad798d975ec5973592e5e32d40455d80"},
+    {file = "orbax_checkpoint-0.5.15-py3-none-any.whl", hash = "sha256:658dd89bc925cecc584d89eaa19af9a7e16e3371377907eb713fbd59b85262e4"},
+    {file = "orbax_checkpoint-0.5.15.tar.gz", hash = "sha256:15195e8d1b381b56f23a62a25599a3644f5d08655fa64f60bb1b938b8ffe7ef3"},
 ]
 
 [[package]]
@@ -1942,47 +1966,59 @@ files = [
 
 [[package]]
 name = "pillow"
-version = "10.2.0"
+version = "10.3.0"
 requires_python = ">=3.8"
 summary = "Python Imaging Library (Fork)"
 groups = ["default"]
 files = [
-    {file = "pillow-10.2.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:1b5e1b74d1bd1b78bc3477528919414874748dd363e6272efd5abf7654e68bef"},
-    {file = "pillow-10.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0eae2073305f451d8ecacb5474997c08569fb4eb4ac231ffa4ad7d342fdc25ac"},
-    {file = "pillow-10.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7c2286c23cd350b80d2fc9d424fc797575fb16f854b831d16fd47ceec078f2c"},
-    {file = "pillow-10.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e23412b5c41e58cec602f1135c57dfcf15482013ce6e5f093a86db69646a5aa"},
-    {file = "pillow-10.2.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:52a50aa3fb3acb9cf7213573ef55d31d6eca37f5709c69e6858fe3bc04a5c2a2"},
-    {file = "pillow-10.2.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:127cee571038f252a552760076407f9cff79761c3d436a12af6000cd182a9d04"},
-    {file = "pillow-10.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8d12251f02d69d8310b046e82572ed486685c38f02176bd08baf216746eb947f"},
-    {file = "pillow-10.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:54f1852cd531aa981bc0965b7d609f5f6cc8ce8c41b1139f6ed6b3c54ab82bfb"},
-    {file = "pillow-10.2.0-cp312-cp312-win32.whl", hash = "sha256:257d8788df5ca62c980314053197f4d46eefedf4e6175bc9412f14412ec4ea2f"},
-    {file = "pillow-10.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:154e939c5f0053a383de4fd3d3da48d9427a7e985f58af8e94d0b3c9fcfcf4f9"},
-    {file = "pillow-10.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:f379abd2f1e3dddb2b61bc67977a6b5a0a3f7485538bcc6f39ec76163891ee48"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:322209c642aabdd6207517e9739c704dc9f9db943015535783239022002f054a"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eedd52442c0a5ff4f887fab0c1c0bb164d8635b32c894bc1faf4c618dd89df2"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb28c753fd5eb3dd859b4ee95de66cc62af91bcff5db5f2571d32a520baf1f04"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:33870dc4653c5017bf4c8873e5488d8f8d5f8935e2f1fb9a2208c47cdd66efd2"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3c31822339516fb3c82d03f30e22b1d038da87ef27b6a78c9549888f8ceda39a"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a2b56ba36e05f973d450582fb015594aaa78834fefe8dfb8fcd79b93e64ba4c6"},
-    {file = "pillow-10.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:d8e6aeb9201e655354b3ad049cb77d19813ad4ece0df1249d3c793de3774f8c7"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:2247178effb34a77c11c0e8ac355c7a741ceca0a732b27bf11e747bbc950722f"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15587643b9e5eb26c48e49a7b33659790d28f190fc514a322d55da2fb5c2950e"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753cd8f2086b2b80180d9b3010dd4ed147efc167c90d3bf593fe2af21265e5a5"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:7c8f97e8e7a9009bcacbe3766a36175056c12f9a44e6e6f2d5caad06dcfbf03b"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d1b35bcd6c5543b9cb547dee3150c93008f8dd0f1fef78fc0cd2b141c5baf58a"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fe4c15f6c9285dc54ce6553a3ce908ed37c8f3825b5a51a15c91442bb955b868"},
-    {file = "pillow-10.2.0.tar.gz", hash = "sha256:e87f0b2c78157e12d7686b27d63c070fd65d994e8ddae6f328e0dcf4a0cd007e"},
+    {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"},
+    {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"},
+    {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"},
+    {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"},
+    {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"},
+    {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"},
+    {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"},
+    {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"},
+    {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"},
+    {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"},
+    {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"},
+    {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = 
"sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"}, + {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"}, +] + +[[package]] +name = "platformdirs" +version = "4.2.2" +requires_python = ">=3.8" +summary = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." +groups = ["default"] +files = [ + {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, + {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, ] [[package]] name = "pluggy" -version = "1.4.0" +version = "1.5.0" requires_python = ">=3.8" summary = "plugin and hook calling mechanisms for python" -groups = ["dev"] +groups = ["default", "dev"] files = [ - {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, - {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, ] [[package]] @@ -2042,7 +2078,7 @@ files = [ [[package]] name = "pydantic" -version = "1.10.14" +version = "1.10.15" requires_python = ">=3.7" summary = "Data validation and settings management using python type hints" groups = ["default"] @@ -2050,8 +2086,8 @@ dependencies = [ "typing-extensions>=4.2.0", ] files = [ - {file = "pydantic-1.10.14-py3-none-any.whl", hash = "sha256:8ee853cd12ac2ddbf0ecbac1c289f95882b2d4482258048079d13be700aa114c"}, - {file = "pydantic-1.10.14.tar.gz", hash = "sha256:46f17b832fe27de7850896f3afee50ea682220dd218f7e9c88d436788419dca6"}, + {file = "pydantic-1.10.15-py3-none-any.whl", hash = "sha256:28e552a060ba2740d0d2aabe35162652c1459a0b9069fe0db7f4ee0e18e74d58"}, + {file = "pydantic-1.10.15.tar.gz", hash = "sha256:ca832e124eda231a60a041da4f013e3ff24949d94a01154b137fc2f2a43c3ffb"}, ] [[package]] @@ -2082,13 +2118,13 @@ files = [ [[package]] name = "pygments" -version = "2.17.2" -requires_python = ">=3.7" +version = "2.18.0" +requires_python = ">=3.8" summary = "Pygments is a syntax highlighting package written in Python." 
groups = ["default"] files = [ - {file = "pygments-2.17.2-py3-none-any.whl", hash = "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c"}, - {file = "pygments-2.17.2.tar.gz", hash = "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367"}, + {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"}, + {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, ] [[package]] @@ -2136,19 +2172,19 @@ files = [ [[package]] name = "pytest" -version = "8.1.1" +version = "8.2.1" requires_python = ">=3.8" summary = "pytest: simple powerful testing with Python" -groups = ["dev"] +groups = ["default", "dev"] dependencies = [ "colorama; sys_platform == \"win32\"", "iniconfig", "packaging", - "pluggy<2.0,>=1.4", + "pluggy<2.0,>=1.5", ] files = [ - {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"}, - {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"}, + {file = "pytest-8.2.1-py3-none-any.whl", hash = "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1"}, + {file = "pytest-8.2.1.tar.gz", hash = "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd"}, ] [[package]] @@ -2186,7 +2222,7 @@ name = "pytest-datadir" version = "1.5.0" requires_python = ">=3.8" summary = "pytest plugin for test data directories and files" -groups = ["dev"] +groups = ["default", "dev"] dependencies = [ "pytest>=5.0", ] @@ -2214,7 +2250,7 @@ name = "pytest-regressions" version = "2.5.0" requires_python = ">=3.8" summary = "Easy to use fixtures to write regression tests." 
-groups = ["dev"] +groups = ["default", "dev"] dependencies = [ "pytest-datadir>=1.2.0", "pytest>=6.2.0", @@ -2255,17 +2291,17 @@ files = [ [[package]] name = "pytest-xdist" -version = "3.5.0" -requires_python = ">=3.7" +version = "3.6.1" +requires_python = ">=3.8" summary = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" groups = ["dev"] dependencies = [ - "execnet>=1.1", - "pytest>=6.2.0", + "execnet>=2.1", + "pytest>=7.0.0", ] files = [ - {file = "pytest-xdist-3.5.0.tar.gz", hash = "sha256:cbb36f3d67e0c478baa57fa4edc8843887e0f6cfc42d677530a36d7472b32d8a"}, - {file = "pytest_xdist-3.5.0-py3-none-any.whl", hash = "sha256:d075629c7e00b611df89f490a5063944bee7a4362a5ff11c7cc7824a03dfce24"}, + {file = "pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7"}, + {file = "pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d"}, ] [[package]] @@ -2305,6 +2341,22 @@ files = [ {file = "pytinyrenderer-0.0.14.tar.gz", hash = "sha256:5fedb4798509cb911a03a3bc9e8de8d4d5aa36b1de52eb878efef104b95a3d15"}, ] +[[package]] +name = "pytorch2jax" +version = "0.1.0" +requires_python = ">=3.6, <4" +summary = "Convert PyTorch models to Jax functions and Flax models" +groups = ["default"] +dependencies = [ + "jax", + "jaxlib", + "torch", +] +files = [ + {file = "pytorch2jax-0.1.0-py3-none-any.whl", hash = "sha256:93313fe032b8fe5b404dcd7daed6c9b96ff6964a593a0849c70ac29a480f2867"}, + {file = "pytorch2jax-0.1.0.tar.gz", hash = "sha256:c579bbd23e1c7902c5ee636f0d32a24cd065302a30e4365b28b4dcd5f92e9936"}, +] + [[package]] name = "pytz" version = "2024.1" @@ -2334,22 +2386,19 @@ files = [ [[package]] name = "readchar" -version = "4.0.6" +version = "4.1.0" requires_python = ">=3.8" summary = "Library to easily read single chars and key strokes" groups = ["default"] -dependencies = [ - "setuptools>=41.0", -] files = [ - {file = "readchar-4.0.6-py3-none-any.whl", hash = "sha256:b4b31dd35de4897be738f27e8f9f62426b5fedb54b648364987e30ae534b71bc"}, - {file = "readchar-4.0.6.tar.gz", hash = "sha256:e0dae942d3a746f8d5423f83dbad67efe704004baafe31b626477929faaee472"}, + {file = "readchar-4.1.0-py3-none-any.whl", hash = "sha256:d163680656b34f263fb5074023db44b999c68ff31ab394445ebfd1a2a41fe9a2"}, + {file = "readchar-4.1.0.tar.gz", hash = "sha256:6f44d1b5f0fd93bd93236eac7da39609f15df647ab9cea39f5bc7478b3344b99"}, ] [[package]] name = "requests" -version = "2.31.0" -requires_python = ">=3.7" +version = "2.32.3" +requires_python = ">=3.8" summary = "Python HTTP for Humans." groups = ["default"] dependencies = [ @@ -2359,24 +2408,24 @@ dependencies = [ "urllib3<3,>=1.21.1", ] files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, ] [[package]] name = "requests" -version = "2.31.0" +version = "2.32.3" extras = ["socks"] -requires_python = ">=3.7" +requires_python = ">=3.8" summary = "Python HTTP for Humans." 
groups = ["default"] dependencies = [ "PySocks!=1.5.7,>=1.5.6", - "requests==2.31.0", + "requests==2.32.3", ] files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, ] [[package]] @@ -2396,28 +2445,28 @@ files = [ [[package]] name = "ruff" -version = "0.3.3" +version = "0.4.6" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." groups = ["dev"] files = [ - {file = "ruff-0.3.3-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:973a0e388b7bc2e9148c7f9be8b8c6ae7471b9be37e1cc732f8f44a6f6d7720d"}, - {file = "ruff-0.3.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfa60d23269d6e2031129b053fdb4e5a7b0637fc6c9c0586737b962b2f834493"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eca7ff7a47043cf6ce5c7f45f603b09121a7cc047447744b029d1b719278eb5"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7d3f6762217c1da954de24b4a1a70515630d29f71e268ec5000afe81377642d"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b24c19e8598916d9c6f5a5437671f55ee93c212a2c4c569605dc3842b6820386"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5a6cbf216b69c7090f0fe4669501a27326c34e119068c1494f35aaf4cc683778"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:352e95ead6964974b234e16ba8a66dad102ec7bf8ac064a23f95371d8b198aab"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d6ab88c81c4040a817aa432484e838aaddf8bfd7ca70e4e615482757acb64f8"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79bca3a03a759cc773fca69e0bdeac8abd1c13c31b798d5bb3c9da4a03144a9f"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2700a804d5336bcffe063fd789ca2c7b02b552d2e323a336700abb8ae9e6a3f8"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd66469f1a18fdb9d32e22b79f486223052ddf057dc56dea0caaf1a47bdfaf4e"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:45817af234605525cdf6317005923bf532514e1ea3d9270acf61ca2440691376"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0da458989ce0159555ef224d5b7c24d3d2e4bf4c300b85467b08c3261c6bc6a8"}, - {file = "ruff-0.3.3-py3-none-win32.whl", hash = "sha256:f2831ec6a580a97f1ea82ea1eda0401c3cdf512cf2045fa3c85e8ef109e87de0"}, - {file = "ruff-0.3.3-py3-none-win_amd64.whl", hash = "sha256:be90bcae57c24d9f9d023b12d627e958eb55f595428bafcb7fec0791ad25ddfc"}, - {file = "ruff-0.3.3-py3-none-win_arm64.whl", hash = "sha256:0171aab5fecdc54383993389710a3d1227f2da124d76a2784a7098e818f92d61"}, - {file = "ruff-0.3.3.tar.gz", hash = "sha256:38671be06f57a2f8aba957d9f701ea889aa5736be806f18c0cd03d6ff0cbca8d"}, + {file = "ruff-0.4.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ef995583a038cd4a7edf1422c9e19118e2511b8ba0b015861b4abd26ec5367c5"}, + {file = 
"ruff-0.4.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:602ebd7ad909eab6e7da65d3c091547781bb06f5f826974a53dbe563d357e53c"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f9ced5cbb7510fd7525448eeb204e0a22cabb6e99a3cb160272262817d49786"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04a80acfc862e0e1630c8b738e70dcca03f350bad9e106968a8108379e12b31f"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be47700ecb004dfa3fd4dcdddf7322d4e632de3c06cd05329d69c45c0280e618"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1ff930d6e05f444090a0139e4e13e1e2e1f02bd51bb4547734823c760c621e79"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f13410aabd3b5776f9c5699f42b37a3a348d65498c4310589bc6e5c548dc8a2f"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0cf5cc02d3ae52dfb0c8a946eb7a1d6ffe4d91846ffc8ce388baa8f627e3bd50"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea3424793c29906407e3cf417f28fc33f689dacbbadfb52b7e9a809dd535dcef"}, + {file = "ruff-0.4.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1fa8561489fadf483ffbb091ea94b9c39a00ed63efacd426aae2f197a45e67fc"}, + {file = "ruff-0.4.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4d5b914818d8047270308fe3e85d9d7f4a31ec86c6475c9f418fbd1624d198e0"}, + {file = "ruff-0.4.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:4f02284335c766678778475e7698b7ab83abaf2f9ff0554a07b6f28df3b5c259"}, + {file = "ruff-0.4.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3a6a0a4f4b5f54fff7c860010ab3dd81425445e37d35701a965c0248819dde7a"}, + {file = "ruff-0.4.6-py3-none-win32.whl", hash = "sha256:9018bf59b3aa8ad4fba2b1dc0299a6e4e60a4c3bc62bbeaea222679865453062"}, + {file = "ruff-0.4.6-py3-none-win_amd64.whl", hash = "sha256:a769ae07ac74ff1a019d6bd529426427c3e30d75bdf1e08bb3d46ac8f417326a"}, + {file = "ruff-0.4.6-py3-none-win_arm64.whl", hash = "sha256:735a16407a1a8f58e4c5b913ad6102722e80b562dd17acb88887685ff6f20cf6"}, + {file = "ruff-0.4.6.tar.gz", hash = "sha256:a797a87da50603f71e6d0765282098245aca6e3b94b7c17473115167d8dfb0b7"}, ] [[package]] @@ -2450,7 +2499,7 @@ files = [ [[package]] name = "scipy" -version = "1.13.0" +version = "1.13.1" requires_python = ">=3.9" summary = "Fundamental algorithms for scientific computing in Python" groups = ["default"] @@ -2458,13 +2507,13 @@ dependencies = [ "numpy<2.3,>=1.22.4", ] files = [ - {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"}, - {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"}, - {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"}, - {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"}, - {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"}, - {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"}, - {file = 
"scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"}, + {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"}, + {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"}, + {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"}, ] [[package]] @@ -2485,16 +2534,17 @@ files = [ [[package]] name = "sentry-sdk" -version = "1.43.0" +version = "2.3.1" +requires_python = ">=3.6" summary = "Python client for Sentry (https://sentry.io)" groups = ["default"] dependencies = [ "certifi", - "urllib3>=1.26.11; python_version >= \"3.6\"", + "urllib3>=1.26.11", ] files = [ - {file = "sentry-sdk-1.43.0.tar.gz", hash = "sha256:41df73af89d22921d8733714fb0fc5586c3461907e06688e6537d01a27e0e0f6"}, - {file = "sentry_sdk-1.43.0-py2.py3-none-any.whl", hash = "sha256:8d768724839ca18d7b4c7463ef7528c40b7aa2bfbf7fe554d5f9a7c044acfd36"}, + {file = "sentry_sdk-2.3.1-py2.py3-none-any.whl", hash = "sha256:c5aeb095ba226391d337dd42a6f9470d86c9fc236ecc71cfc7cd1942b45010c6"}, + {file = "sentry_sdk-2.3.1.tar.gz", hash = "sha256:139a71a19f5e9eb5d3623942491ce03cf8ebc14ea2e39ba3e6fe79560d8a5b1f"}, ] [[package]] @@ -2537,13 +2587,13 @@ files = [ [[package]] name = "setuptools" -version = "69.2.0" +version = "70.0.0" requires_python = ">=3.8" summary = "Easily download, build, install, upgrade, and uninstall Python packages" groups = ["default"] files = [ - {file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"}, - {file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"}, + {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, + {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, ] [[package]] @@ -2636,16 +2686,43 @@ files = [ [[package]] name = "sympy" -version = "1.12" +version = "1.12.1" requires_python = ">=3.8" summary = "Computer algebra system (CAS) in Python" -groups = ["default"] +groups = ["default", "dev"] dependencies = [ - "mpmath>=0.19", + "mpmath<1.4.0,>=1.1.0", ] files = [ - {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, - {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, + {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, + {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, +] + +[[package]] 
+name = "tbb" +version = "2021.12.0" +summary = "IntelĀ® oneAPI Threading Building Blocks (oneTBB)" +groups = ["default", "dev"] +marker = "platform_system == \"Windows\"" +files = [ + {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, + {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, + {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, + {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, +] + +[[package]] +name = "tensor-regression" +version = "0.0.2.post3.dev0" +requires_python = "<4.0,>=3.11" +git = "https://www.github.com/lebrice/tensor_regression" +revision = "2b15f9312fe8891f0c617b5cbce1ba757d514a0a" +summary = "A small wrapper around pytest_regressions for Tensors" +groups = ["default", "dev"] +dependencies = [ + "numpy<2.0.0,>=1.26.4", + "pytest-regressions<3.0.0,>=2.5.0", + "torch<3.0.0,>=2.3.1", ] [[package]] @@ -2698,7 +2775,7 @@ files = [ [[package]] name = "tensorstore" -version = "0.1.58" +version = "0.1.60" requires_python = ">=3.9" summary = "Read and write large, multi-dimensional arrays" groups = ["default"] @@ -2707,11 +2784,11 @@ dependencies = [ "numpy>=1.16.0", ] files = [ - {file = "tensorstore-0.1.58-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b999068aabab50ca96154783083a9efbdc0c5315745304b5a1ef543aa788c66e"}, - {file = "tensorstore-0.1.58-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2bb73015268f0894d23796f3384751fda3d40423d42e129469d91cca7728d4e0"}, - {file = "tensorstore-0.1.58-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e974d0f35d07cdfded317030caa08e18aa522bb950bcf345a0d34cc2ea9035aa"}, - {file = "tensorstore-0.1.58-cp312-cp312-win_amd64.whl", hash = "sha256:6c8bbb75e0cb764325702771bb818f26cf9fb6f39178693174fbc6107d2df156"}, - {file = "tensorstore-0.1.58.tar.gz", hash = "sha256:899bcf2fad09d78a886dc4a9ee70dba7dc9c1fb5a1d7d38f164a97046b5434d9"}, + {file = "tensorstore-0.1.60-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:65677e21304fcf272557f195c597704f4ccf55b75314e68ece17bb1784cb59f7"}, + {file = "tensorstore-0.1.60-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:725d1f70c17838815704805d2853c636bb2d680424e81f91677a7defea68373b"}, + {file = "tensorstore-0.1.60-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c477a0e6948326c414ed1bcdab2949e975f0b4e7e449cce39e0fec14b273e1b2"}, + {file = "tensorstore-0.1.60-cp312-cp312-win_amd64.whl", hash = "sha256:32cba3cf0ae6dd03d504162b8ea387f140050e279cf23e7eced68d3c845693da"}, + {file = "tensorstore-0.1.60.tar.gz", hash = "sha256:88da8f1978982101b8dbb144fd29ee362e4e8c97fc595c4992d555f80ce62a79"}, ] [[package]] @@ -2727,14 +2804,15 @@ files = [ [[package]] name = "torch" -version = "2.2.1" +version = "2.3.1" requires_python = ">=3.8.0" summary = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -groups = ["default"] +groups = ["default", "dev"] dependencies = [ "filelock", "fsspec", "jinja2", + "mkl<=2021.4.0,>=2021.1.1; platform_system == \"Windows\"", "networkx", "nvidia-cublas-cu12==12.1.3.1; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "nvidia-cuda-cupti-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"", @@ -2745,22 +2823,36 @@ 
dependencies = [ "nvidia-curand-cu12==10.3.2.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "nvidia-cusolver-cu12==11.4.5.107; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "nvidia-cusparse-cu12==12.1.0.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"", - "nvidia-nccl-cu12==2.19.3; platform_system == \"Linux\" and platform_machine == \"x86_64\"", + "nvidia-nccl-cu12==2.20.5; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "nvidia-nvtx-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "sympy", "typing-extensions>=4.8.0", ] files = [ - {file = "torch-2.2.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ada53aebede1c89570e56861b08d12ba4518a1f8b82d467c32665ec4d1f4b3c8"}, - {file = "torch-2.2.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:be21d4c41ecebed9e99430dac87de1439a8c7882faf23bba7fea3fea7b906ac1"}, - {file = "torch-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:79848f46196750367dcdf1d2132b722180b9d889571e14d579ae82d2f50596c5"}, - {file = "torch-2.2.1-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:7ee804847be6be0032fbd2d1e6742fea2814c92bebccb177f0d3b8e92b2d2b18"}, - {file = "torch-2.2.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:84b2fb322ab091039fdfe74e17442ff046b258eb5e513a28093152c5b07325a7"}, + {file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"}, + {file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"}, + {file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"}, + {file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"}, +] + +[[package]] +name = "torch-jax-interop" +version = "0.0.4.post7.dev0" +requires_python = "<4.0,>=3.11" +git = "https://www.github.com/lebrice/torch_jax_interop" +revision = "7f0c72fe19d8bd4bd957f20dd90d77acd8178bd4" +summary = "Utility to convert Tensors from Jax to Torch and vice-versa" +groups = ["default"] +dependencies = [ + "flax<1.0.0,>=0.8.4", + "jax[cuda12]<1.0.0,>=0.4.28", + "pytorch2jax<1.0.0,>=0.1.0", + "torch<3.0.0,>=2.3.0", ] [[package]] name = "torchmetrics" -version = "1.3.2" +version = "1.4.0.post0" requires_python = ">=3.8" summary = "PyTorch native Metrics" groups = ["default"] @@ -2771,32 +2863,31 @@ dependencies = [ "torch>=1.10.0", ] files = [ - {file = "torchmetrics-1.3.2-py3-none-any.whl", hash = "sha256:44ca3a9f86dc050cb3f554836ef291698ea797778457195b4f685fce8e2e64a3"}, - {file = "torchmetrics-1.3.2.tar.gz", hash = "sha256:0a67694a4c4265eeb54cda741eaf5cb1f3a71da74b7e7e6215ad156c9f2379f6"}, + {file = "torchmetrics-1.4.0.post0-py3-none-any.whl", hash = "sha256:ab234216598e3fbd8d62ee4541a0e74e7e8fc935d099683af5b8da50f745b3c8"}, + {file = "torchmetrics-1.4.0.post0.tar.gz", hash = "sha256:ab9bcfe80e65dbabbddb6cecd9be21f1f1d5207bb74051ef95260740f2762358"}, ] [[package]] name = "torchvision" -version = "0.17.1" +version = "0.18.1" requires_python = ">=3.8" summary = "image and video datasets and models for torch deep learning" groups = ["default"] dependencies = [ "numpy", "pillow!=8.3.*,>=5.3.0", - "torch==2.2.1", + "torch==2.3.1", ] files = [ - {file = "torchvision-0.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:5d241d2a5fb4e608677fccf6f80b34a124446d324ee40c7814ce54bce888275b"}, - {file = "torchvision-0.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e0fe98d9d92c23d2262ff82f973242951b9357fb640f8888ac50848bd00f5b45"}, - {file = "torchvision-0.17.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:32dc5de86d2ade399e11087095674ca08a1649fb322cfe69336d28add467edcb"}, - {file = "torchvision-0.17.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:54902877410ffb5458ee52b6d0de4b25cf01496bee736d6825301a5f0398536e"}, - {file = "torchvision-0.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:cc22c1ed0f1aba3f98fd72b6f60021f57aec1d2f6af518522e8a0a83848de3a8"}, + {file = "torchvision-0.18.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2be6f0bf7c455c89a51a1dbb6f668d36c6edc479f49ac912d745d10df5715657"}, + {file = "torchvision-0.18.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:f118d887bfde3a948a41d56587525401e5cac1b7db2eaca203324d6ed2b1caca"}, + {file = "torchvision-0.18.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:13d24d904f65e62d66a1e0c41faec630bc193867b8a4a01166769e8a8e8df8e9"}, + {file = "torchvision-0.18.1-cp312-cp312-win_amd64.whl", hash = "sha256:ed6340b69a63a625e512a66127210d412551d9c5f2ad2978130c6a45bf56cd4a"}, ] [[package]] name = "tqdm" -version = "4.66.2" +version = "4.66.4" requires_python = ">=3.7" summary = "Fast, Extensible Progress Meter" groups = ["default"] @@ -2804,24 +2895,24 @@ dependencies = [ "colorama; platform_system == \"Windows\"", ] files = [ - {file = "tqdm-4.66.2-py3-none-any.whl", hash = "sha256:1ee4f8a893eb9bef51c6e35730cebf234d5d0b6bd112b0271e10ed7c24a02bd9"}, - {file = "tqdm-4.66.2.tar.gz", hash = "sha256:6cd52cdf0fef0e0f543299cfc96fec90d7b8a7e88745f411ec33eb44d5ed3531"}, + {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"}, + {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"}, ] [[package]] name = "traitlets" -version = "5.14.2" +version = "5.14.3" requires_python = ">=3.8" summary = "Traitlets Python configuration system" groups = ["default"] files = [ - {file = "traitlets-5.14.2-py3-none-any.whl", hash = "sha256:fcdf85684a772ddeba87db2f398ce00b40ff550d1528c03c14dbf6a02003cd80"}, - {file = "traitlets-5.14.2.tar.gz", hash = "sha256:8cdd83c040dab7d1dee822678e5f5d100b514f7b72b01615b26fc5718916fdf9"}, + {file = "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, + {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, ] [[package]] name = "trimesh" -version = "4.3.2" +version = "4.4.0" requires_python = ">=3.7" summary = "Import, export, process, analyze and view triangular meshes." 
groups = ["default"] @@ -2829,8 +2920,8 @@ dependencies = [ "numpy>=1.20", ] files = [ - {file = "trimesh-4.3.2-py3-none-any.whl", hash = "sha256:7563182a9379485b88a44e87156fe54b41fb6f8f030001b9b6de39abdef05c22"}, - {file = "trimesh-4.3.2.tar.gz", hash = "sha256:1450dbd1aae8dd825eddd56c5a7d7d1b35cad7efc2c63d535e19569577c25916"}, + {file = "trimesh-4.4.0-py3-none-any.whl", hash = "sha256:e192458da391c1b0a850df0b713c59234a6582e641569b004b588ada337b05c0"}, + {file = "trimesh-4.4.0.tar.gz", hash = "sha256:daf6e56715de2e93dd905e926f9bb10d23dc4157f9724aa7caab5d0e28963e56"}, ] [[package]] @@ -2846,13 +2937,13 @@ files = [ [[package]] name = "typing-extensions" -version = "4.10.0" +version = "4.12.0" requires_python = ">=3.8" summary = "Backported and Experimental Type Hints for Python 3.8+" -groups = ["default"] +groups = ["default", "dev"] files = [ - {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, - {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, + {file = "typing_extensions-4.12.0-py3-none-any.whl", hash = "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"}, + {file = "typing_extensions-4.12.0.tar.gz", hash = "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8"}, ] [[package]] @@ -2879,7 +2970,7 @@ files = [ [[package]] name = "uvicorn" -version = "0.29.0" +version = "0.30.0" requires_python = ">=3.8" summary = "The lightning-fast ASGI server." groups = ["default"] @@ -2888,33 +2979,38 @@ dependencies = [ "h11>=0.8", ] files = [ - {file = "uvicorn-0.29.0-py3-none-any.whl", hash = "sha256:2c2aac7ff4f4365c206fd773a39bf4ebd1047c238f8b8268ad996829323473de"}, - {file = "uvicorn-0.29.0.tar.gz", hash = "sha256:6a69214c0b6a087462412670b3ef21224fa48cae0e452b5883e8e8bdfdd11dd0"}, + {file = "uvicorn-0.30.0-py3-none-any.whl", hash = "sha256:78fa0b5f56abb8562024a59041caeb555c86e48d0efdd23c3fe7de7a4075bdab"}, + {file = "uvicorn-0.30.0.tar.gz", hash = "sha256:f678dec4fa3a39706bbf49b9ec5fc40049d42418716cea52b53f07828a60aa37"}, ] [[package]] name = "wandb" -version = "0.16.4" +version = "0.17.0" requires_python = ">=3.7" summary = "A CLI and library for interacting with the Weights & Biases API." 
groups = ["default"] dependencies = [ - "Click!=8.0.0,>=7.1", - "GitPython!=3.1.29,>=1.0.0", - "PyYAML", - "appdirs>=1.4.3", + "click!=8.0.0,>=7.1", "docker-pycreds>=0.4.0", + "gitpython!=3.1.29,>=1.0.0", + "platformdirs", "protobuf!=4.21.0,<5,>=3.19.0; python_version > \"3.9\" and sys_platform == \"linux\"", "protobuf!=4.21.0,<5,>=3.19.0; sys_platform != \"linux\"", "psutil>=5.0.0", + "pyyaml", "requests<3,>=2.0.0", "sentry-sdk>=1.0.0", "setproctitle", "setuptools", ] files = [ - {file = "wandb-0.16.4-py3-none-any.whl", hash = "sha256:bb9eb5aa2c2c85e11c76040c4271366f54d4975167aa6320ba86c3f2d97fe5fa"}, - {file = "wandb-0.16.4.tar.gz", hash = "sha256:8752c67d1347a4c29777e64dc1e1a742a66c5ecde03aebadf2b0d62183fa307c"}, + {file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"}, + {file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"}, + {file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"}, + {file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"}, + {file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"}, + {file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"}, + {file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"}, ] [[package]] @@ -2929,13 +3025,13 @@ files = [ [[package]] name = "websocket-client" -version = "1.7.0" +version = "1.8.0" requires_python = ">=3.8" summary = "WebSocket client for Python with low level API options" groups = ["default"] files = [ - {file = "websocket-client-1.7.0.tar.gz", hash = "sha256:10e511ea3a8c744631d3bd77e61eb17ed09304c413ad42cf6ddfa4c7787e8fe6"}, - {file = "websocket_client-1.7.0-py3-none-any.whl", hash = "sha256:f4c3d22fec12a2461427a29957ff07d35098ee2d976d3ba244e688b8b4057588"}, + {file = "websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526"}, + {file = "websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da"}, ] [[package]] @@ -2966,7 +3062,7 @@ files = [ [[package]] name = "werkzeug" -version = "3.0.2" +version = "3.0.3" requires_python = ">=3.8" summary = "The comprehensive WSGI web application library." 
groups = ["default"] @@ -2974,8 +3070,8 @@ dependencies = [ "MarkupSafe>=2.1.1", ] files = [ - {file = "werkzeug-3.0.2-py3-none-any.whl", hash = "sha256:3aac3f5da756f93030740bc235d3e09449efcf65f2f55e3602e1d851b8f48795"}, - {file = "werkzeug-3.0.2.tar.gz", hash = "sha256:e39b645a6ac92822588e7b39a692e7828724ceae0b0d702ef96701f90e70128d"}, + {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, + {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, ] [[package]] @@ -3021,11 +3117,11 @@ files = [ [[package]] name = "zipp" -version = "3.18.1" +version = "3.19.1" requires_python = ">=3.8" summary = "Backport of pathlib-compatible object wrapper for zip files" groups = ["default"] files = [ - {file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"}, - {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"}, + {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"}, + {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"}, ] diff --git a/project/algorithms/__init__.py b/project/algorithms/__init__.py index fe64d6b3..5a0a3021 100644 --- a/project/algorithms/__init__.py +++ b/project/algorithms/__init__.py @@ -1,5 +1,6 @@ from hydra_zen import builds, store +from project.algorithms.jax_algo import JaxAlgorithm from project.algorithms.no_op import NoOp from .bases.algorithm import Algorithm @@ -13,11 +14,12 @@ # If you add a configuration file under `configs/algorithm`, it will also be available as an option # from the command-line, and be validated against the schema. - +# todo: It might be nicer if we did this this `configs/algorithms` instead of here, no? algorithm_store = store(group="algorithm") algorithm_store(ExampleAlgorithm.HParams(), name="example_algo") algorithm_store(ManualGradientsExample.HParams(), name="manual_optimization") algorithm_store(builds(NoOp, populate_full_signature=False), name="no_op") +algorithm_store(JaxAlgorithm.HParams(), name="jax_algo") algorithm_store.add_to_hydra_store() diff --git a/project/algorithms/bases/algorithm.py b/project/algorithms/bases/algorithm.py index ffc7328f..5806c625 100644 --- a/project/algorithms/bases/algorithm.py +++ b/project/algorithms/bases/algorithm.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Any, TypedDict +import torch from lightning import Callback, LightningModule, Trainer from torch import Tensor, nn from typing_extensions import Generic, TypeVar # noqa @@ -46,12 +47,18 @@ def __init__( self, *, datamodule: DataModule[BatchType] | None = None, - network: NetworkType, + network: NetworkType | None = None, hp: HParams | None = None, ): super().__init__() self.datamodule = datamodule - self._device = get_device(network) # fix for `self.device` property which defaults to cpu. + if isinstance(network, torch.nn.Module): + # fix for `self.device` property which defaults to cpu. + self._device = get_device(network) + elif network and not isinstance(network, torch.nn.Module): + # todo: Should we automatically convert jax networks to torch in case the base class + # doesn't? 
+            pass
        self.network = network
        self.hp = hp or self.HParams()
        self.trainer: Trainer
diff --git a/project/algorithms/bases/algorithm_test.py b/project/algorithms/bases/algorithm_test.py
index 98c7d00a..5df5ee95 100644
--- a/project/algorithms/bases/algorithm_test.py
+++ b/project/algorithms/bases/algorithm_test.py
@@ -18,11 +18,12 @@
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT
from omegaconf import DictConfig
+from tensor_regression import TensorRegressionFixture
from torch import Tensor, nn
from torch.utils.data import DataLoader
from typing_extensions import ParamSpec

-from project.configs.config import Config, cs
+from project.configs import Config, cs
from project.conftest import setup_hydra_for_tests_and_compose
from project.datamodules.image_classification import (
    ImageClassificationDataModule,
@@ -34,7 +35,6 @@
)
from project.main import main
from project.utils.hydra_utils import resolve_dictconfig
-from project.utils.tensor_regression import TensorRegressionFixture
from project.utils.testutils import (
    default_marks_for_config_name,
    get_all_datamodule_names_params,
@@ -332,9 +332,11 @@ def _hydra_config(

        All overrides should have already been applied.
        """
+        # todo: remove this hard-coded check somehow.
        if "resnet" in network_name and datamodule_name in ["mnist", "fashion_mnist"]:
            pytest.skip(reason="ResNet's can't be used on MNIST datasets.")

+        # todo: Get the name of the algorithm from the hydra config?
        algorithm_name = self.algorithm_name
        with setup_hydra_for_tests_and_compose(
            all_overrides=[
@@ -388,8 +390,9 @@ def network(
                    f"type {type(network)}"
                )
            )
-        assert isinstance(network, nn.Module)
-        return network.to(device=device)
+        if isinstance(network, nn.Module):
+            network = network.to(device=device)
+        return network

    @pytest.fixture(scope="class")
    def hp(self, experiment_config: Config) -> Algorithm.HParams:  # type: ignore
@@ -554,7 +557,7 @@ def on_train_batch_end(
        pl_module: LightningModule,
        outputs,
        batch: tuple[Tensor, Tensor],
-        batch_idx: int,
+        batch_index: int,
    ) -> None:
        assert self.metric in trainer.logged_metrics, (self.metric, trainer.logged_metrics.keys())
        metric_value = trainer.logged_metrics[self.metric]
@@ -591,9 +594,9 @@ def on_train_batch_end(
        pl_module: LightningModule,
        outputs,
        batch: tuple[Tensor, Tensor],
-        batch_idx: int,
+        batch_index: int,
    ) -> None:
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_index)
        self.num_training_steps += 1

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -644,9 +647,9 @@ def on_train_batch_end(
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
-        batch_idx: int,
+        batch_index: int,
    ) -> None:
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_index)
        parameters_with_nans = [
            name for name, param in pl_module.named_parameters() if param.isnan().any()
@@ -701,7 +704,7 @@ def on_train_batch_end(
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
-        batch_idx: int,
+        batch_index: int,
    ) -> None:
        if self.item_index is not None:
            batch = batch[self.item_index]
diff --git a/project/algorithms/bases/image_classification.py b/project/algorithms/bases/image_classification.py
index 45fef1d1..49841119 100644
--- a/project/algorithms/bases/image_classification.py
+++ b/project/algorithms/bases/image_classification.py
@@ -58,7 +58,7 @@ def __init__(
        # NOTE: Setting this property allows PL to infer the shapes and number of params.
        # TODO: Check if PL now moves the `example_input_array` to the right device automatically.
        # If possible, we'd like to remove any reference to the device from the algorithm.
-        self.example_input_array = torch.rand(
+        self.example_input_array = torch.zeros(
            [datamodule.batch_size, *datamodule.dims],
            device=self.device,
        )
@@ -74,21 +74,23 @@ def __init__(
        self.val_top5_accuracy = MulticlassAccuracy(num_classes=num_classes, top_k=5)
        self.test_top5_accuracy = MulticlassAccuracy(num_classes=num_classes, top_k=5)

-    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> ClassificationOutputs:
+    def training_step(
+        self, batch: tuple[Tensor, Tensor], batch_index: int
+    ) -> ClassificationOutputs:
        """Performs a training step."""
-        return self.shared_step(batch=batch, batch_idx=batch_idx, phase="train")
+        return self.shared_step(batch=batch, batch_index=batch_index, phase="train")

    def validation_step(
-        self, batch: tuple[Tensor, Tensor], batch_idx: int
+        self, batch: tuple[Tensor, Tensor], batch_index: int
    ) -> ClassificationOutputs:
        """Performs a validation step."""
-        return self.shared_step(batch=batch, batch_idx=batch_idx, phase="val")
+        return self.shared_step(batch=batch, batch_index=batch_index, phase="val")

-    def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> ClassificationOutputs:
+    def test_step(self, batch: tuple[Tensor, Tensor], batch_index: int) -> ClassificationOutputs:
        """Performs a test step."""
-        return self.shared_step(batch=batch, batch_idx=batch_idx, phase="test")
+        return self.shared_step(batch=batch, batch_index=batch_index, phase="test")

-    def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: int):
+    def predict_step(self, batch: Tensor, batch_index: int, dataloader_idx: int):
        """Performs a prediction step."""
        return self.predict(batch)
@@ -98,7 +100,7 @@ def predict(self, x: Tensor) -> Tensor:

    @abstractmethod
    def shared_step(
-        self, batch: tuple[Tensor, Tensor], batch_idx: int, phase: PhaseStr
+        self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr
    ) -> ClassificationOutputs:
        """Performs a training/validation/test step.
diff --git a/project/algorithms/callbacks/callback.py b/project/algorithms/callbacks/callback.py
index f7b17ebb..ae8f12be 100644
--- a/project/algorithms/callbacks/callback.py
+++ b/project/algorithms/callbacks/callback.py
@@ -34,7 +34,7 @@ def on_shared_batch_start(
        trainer: Trainer,
        pl_module: Algorithm[BatchType, StepOutputType],
        batch: BatchType,
-        batch_idx: int,
+        batch_index: int,
        phase: PhaseStr,
        dataloader_idx: int | None = None,
    ): ...
@@ -45,7 +45,7 @@ def on_shared_batch_end(
        trainer: Trainer,
        pl_module: Algorithm[BatchType, StepOutputType],
        outputs: StepOutputType,
        batch: BatchType,
-        batch_idx: int,
+        batch_index: int,
        phase: PhaseStr,
        dataloader_idx: int | None = None,
    ): ...
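+    # (illustrative usage, an assumption about the intended pattern: since the
+    # phase-specific hooks below all forward into these two shared hooks, a subclass
+    # only needs to override the hook once for all phases, e.g.
+    #     class LogBatchCallback(Callback):
+    #         def on_shared_batch_end(self, trainer, pl_module, outputs, batch,
+    #                                 batch_index, phase, dataloader_idx=None):
+    #             print(f"{phase} batch {batch_index} done")
+    # rather than duplicating the logic across on_train/validation/test_batch_end.)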
@@ -65,21 +65,21 @@ def on_train_batch_end(
        pl_module: Algorithm[BatchType, StepOutputType],
        outputs: StepOutputType,
        batch: BatchType,
-        batch_idx: int,
+        batch_index: int,
    ) -> None:
        super().on_train_batch_end(
            trainer=trainer,
            pl_module=pl_module,
            outputs=outputs,  # type: ignore
            batch=batch,
-            batch_idx=batch_idx,
+            batch_idx=batch_index,
        )
        self.on_shared_batch_end(
            trainer=trainer,
            pl_module=pl_module,
            outputs=outputs,
            batch=batch,
-            batch_idx=batch_idx,
+            batch_index=batch_index,
            phase="train",
        )
@@ -90,15 +90,15 @@ def on_validation_batch_end(
        pl_module: Algorithm[BatchType, StepOutputType],
        outputs: StepOutputType,
        batch: BatchType,
-        batch_idx: int,
-        dataloader_idx: int,
+        batch_index: int,
+        dataloader_idx: int = 0,
    ) -> None:
        super().on_validation_batch_end(
            trainer=trainer,
            pl_module=pl_module,
            outputs=outputs,  # type: ignore
            batch=batch,
-            batch_idx=batch_idx,
+            batch_idx=batch_index,
            dataloader_idx=dataloader_idx,
        )
        self.on_shared_batch_end(
@@ -106,9 +106,9 @@
            pl_module=pl_module,
            outputs=outputs,
            batch=batch,
-            batch_idx=batch_idx,
-            dataloader_idx=dataloader_idx,
+            batch_index=batch_index,
            phase="val",
+            dataloader_idx=dataloader_idx,
        )
    @override
@@ -118,15 +118,15 @@ def on_test_batch_end(
        trainer: Trainer,
        pl_module: Algorithm[BatchType, StepOutputType],
        outputs: StepOutputType,
        batch: BatchType,
-        batch_idx: int,
-        dataloader_idx: int,
+        batch_index: int,
+        dataloader_idx: int = 0,
    ) -> None:
        super().on_test_batch_end(
            trainer=trainer,
            pl_module=pl_module,
            outputs=outputs,  # type: ignore
            batch=batch,
-            batch_idx=batch_idx,
+            batch_idx=batch_index,
            dataloader_idx=dataloader_idx,
        )
        self.on_shared_batch_end(
@@ -134,7 +134,7 @@
            pl_module=pl_module,
            outputs=outputs,
            batch=batch,
-            batch_idx=batch_idx,
+            batch_index=batch_index,
            dataloader_idx=dataloader_idx,
            phase="test",
        )
@@ -145,11 +145,15 @@ def on_train_batch_start(
        trainer: Trainer,
        pl_module: Algorithm[BatchType, StepOutputType],
        batch: BatchType,
-        batch_idx: int,
+        batch_index: int,
    ) -> None:
-        super().on_train_batch_start(trainer, pl_module, batch, batch_idx)
+        super().on_train_batch_start(trainer, pl_module, batch, batch_index)
        self.on_shared_batch_start(
-            trainer=trainer, pl_module=pl_module, batch=batch, batch_idx=batch_idx, phase="train"
+            trainer=trainer,
+            pl_module=pl_module,
+            batch=batch,
+            batch_index=batch_index,
+            phase="train",
        )
    @override
@@ -158,15 +162,15 @@ def on_validation_batch_start(
        trainer: Trainer,
        pl_module: Algorithm[BatchType, StepOutputType],
        batch: BatchType,
-        batch_idx: int,
-        dataloader_idx: int,
+        batch_index: int,
+        dataloader_idx: int = 0,
    ) -> None:
-        super().on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
+        super().on_validation_batch_start(trainer, pl_module, batch, batch_index, dataloader_idx)
        self.on_shared_batch_start(
            trainer,
            pl_module,
            batch,
-            batch_idx,
+            batch_index,
            dataloader_idx=dataloader_idx,
            phase="val",
        )
@@ -177,15 +181,15 @@ def on_test_batch_start(
        trainer: Trainer,
        pl_module: Algorithm[BatchType, StepOutputType],
        batch: BatchType,
-        batch_idx: int,
-        dataloader_idx: int,
+        batch_index: int,
+        dataloader_idx: int = 0,
    ) -> None:
-        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
+        super().on_test_batch_start(trainer, pl_module, batch, batch_index, dataloader_idx)
        self.on_shared_batch_start(
            trainer,
            pl_module,
            batch,
-            batch_idx,
+            batch_index,
            dataloader_idx=dataloader_idx,
            phase="test",
        )
diff --git a/project/algorithms/callbacks/classification_metrics.py b/project/algorithms/callbacks/classification_metrics.py
new file mode 100644
index 00000000..fbbad03b
--- /dev/null
+++ b/project/algorithms/callbacks/classification_metrics.py
@@ -0,0 +1,177 @@
+import warnings
+from logging import getLogger as get_logger
+from typing import Any, Required
+
+import torch
+import torchmetrics
+from lightning import LightningModule, Trainer
+from torch import Tensor
+from torchmetrics.classification import MulticlassAccuracy
+
+from project.algorithms.bases.algorithm import Algorithm, BatchType
+from project.algorithms.bases.image_classification import StepOutputDict
+from project.algorithms.callbacks.callback import Callback
+from project.utils.types import PhaseStr, StageStr
+from project.utils.types.protocols import ClassificationDataModule
+
+logger = get_logger(__name__)
+
+
+class ClassificationOutputs(StepOutputDict):
+    """The dictionary format that is minimally required to be returned from
+    `training/val/test_step` for classification algorithms."""
+
+    logits: Required[Tensor]
+    """The un-normalized logits."""
+
+    y: Required[Tensor]
+    """The class labels."""
+
+
+class ClassificationMetricsCallback(Callback[BatchType, ClassificationOutputs]):
+    """Callback that adds classification metrics to the pl module."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.disabled = False
+
+    @classmethod
+    def attach_to(cls, algorithm: Algorithm, num_classes: int):
+        callback = cls()
+        callback.add_metrics_to(algorithm, num_classes=num_classes)
+        return callback
+
+    def add_metrics_to(self, pl_module: LightningModule, num_classes: int) -> None:
+        # IDEA: Could use a dict of metrics from torchmetrics instead of just accuracy:
+        # self.supervised_metrics: dict[str, torchmetrics.Metric]
+        # NOTE: Need to have one per phase! Not 100% sure that I'm not forgetting a phase here.
+
+        # Slightly ugly. Need to set the metrics on the pl module for things to be logged / synced
+        # easily.
+        metrics_to_add = {
+            "train_accuracy": MulticlassAccuracy(num_classes=num_classes),
+            "val_accuracy": MulticlassAccuracy(num_classes=num_classes),
+            "test_accuracy": MulticlassAccuracy(num_classes=num_classes),
+            "train_top5_accuracy": MulticlassAccuracy(num_classes=num_classes, top_k=5),
+            "val_top5_accuracy": MulticlassAccuracy(num_classes=num_classes, top_k=5),
+            "test_top5_accuracy": MulticlassAccuracy(num_classes=num_classes, top_k=5),
+        }
+        if all(
+            hasattr(pl_module, name) and isinstance(getattr(pl_module, name), type(metric))
+            for name, metric in metrics_to_add.items()
+        ):
+            logger.info("Not adding metrics to the pl module because they are already present.")
+            return
+
+        for metric_name, metric in metrics_to_add.items():
+            self._set_metric(pl_module, metric_name, metric)
+
+    # todo: change these two if we end up putting metrics in a ModuleDict.
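+    # (hypothetical sketch of the ModuleDict idea from the todo above:
+    #     pl_module.metrics = torch.nn.ModuleDict(metrics_to_add)
+    # would register every metric as a submodule under a single attribute instead of
+    # one attribute per metric, at the cost of changing the lookups in
+    # `_set_metric`/`_get_metric` below.)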
+ @staticmethod + def _set_metric(pl_module: LightningModule, name: str, metric: torchmetrics.Metric): + if hasattr(pl_module, name): + raise RuntimeError(f"The pl module already has an attribute with the name {name}.") + logger.info(f"Setting a new metric on the pl module at attribute {name}.") + setattr(pl_module, name, metric) + + @staticmethod + def _get_metric(pl_module: LightningModule, name: str): + return getattr(pl_module, name) + + def setup( + self, + trainer: Trainer, + pl_module: Algorithm[BatchType, ClassificationOutputs, Any], + stage: StageStr, + ) -> None: + if self.disabled: + return + datamodule = pl_module.datamodule + if not isinstance(datamodule, ClassificationDataModule): + warnings.warn( + RuntimeWarning( + f"Disabling the {type(self).__name__} callback because it only works with " + f"classification datamodules, but {pl_module.datamodule=} isn't a " + f"{ClassificationDataModule.__name__}." + ) + ) + self.disabled = True + return + + num_classes = datamodule.num_classes + self.add_metrics_to(pl_module, num_classes=num_classes) + + def on_shared_batch_end( + self, + trainer: Trainer, + pl_module: Algorithm[BatchType, ClassificationOutputs, Any], + outputs: ClassificationOutputs, + batch: BatchType, + batch_index: int, + phase: PhaseStr, + dataloader_idx: int | None = None, + ): + if self.disabled: + return + step_output = outputs + required_entries = ClassificationOutputs.__required_keys__ + if not isinstance(outputs, dict): + warnings.warn( + RuntimeWarning( + f"Expected the {phase} step method to output a dictionary with at least the " + f"{required_entries} keys, but got an output of type {type(step_output)} instead!\n" + f"Disabling the {type(self).__name__} callback." + ) + ) + self.disabled = True + return + if not all(k in step_output for k in required_entries): + warnings.warn( + RuntimeWarning( + f"Expected all the following keys to be in the output of the {phase} step " + f"method: {required_entries}. Disabling the {type(self).__name__} callback." + ) + ) + self.disabled = True + return + + logits = step_output["logits"] + y = step_output["y"] + + probs = torch.softmax(logits, -1) + + accuracy = self._get_metric(pl_module, f"{phase}_accuracy") + top5_accuracy = self._get_metric(pl_module, f"{phase}_top5_accuracy") + assert isinstance(accuracy, MulticlassAccuracy) + assert isinstance(top5_accuracy, MulticlassAccuracy) + + # TODO: It's a bit confusing, not sure if this is the right way to use this: + accuracy(probs, y) + top5_accuracy(probs, y) + prog_bar = phase == "train" + + pl_module.log(f"{phase}/accuracy", accuracy, prog_bar=prog_bar, sync_dist=True) + pl_module.log(f"{phase}/top5_accuracy", top5_accuracy, prog_bar=prog_bar, sync_dist=True) + + if "cross_entropy" not in step_output: + # Add the cross entropy loss as a metric. + with torch.no_grad(): + ce_loss = torch.nn.functional.cross_entropy(logits.detach(), y, reduction="mean") + pl_module.log(f"{phase}/cross_entropy", ce_loss, prog_bar=prog_bar, sync_dist=True) + + loss: Tensor | float | None = step_output.get("loss", None) + if loss is not None: + # note: Perhaps we should be careful not to overwrite the logged value if it's already been logged? + pl_module.log( + f"{phase}/loss", torch.as_tensor(loss).mean(), prog_bar=prog_bar, sync_dist=True + ) + + # This part isn't necessary here: Average out the losses properly. + # fused_output = step_output.copy() + # if isinstance(loss, Tensor) and loss.shape: + # # Replace the loss with its mean.
This is useful when automatic + # # optimization is enabled, for example in the baseline (backprop), where each replica + # # returns the un-reduced cross-entropy loss. Here we need to reduce it to a scalar. + # fused_output["loss"] = loss.mean() + + # return fused_output diff --git a/project/algorithms/callbacks/samples_per_second.py b/project/algorithms/callbacks/samples_per_second.py index 2bb5f50c..6b6e2b12 100644 --- a/project/algorithms/callbacks/samples_per_second.py +++ b/project/algorithms/callbacks/samples_per_second.py @@ -1,7 +1,8 @@ import time -from lightning import Trainer +from lightning import LightningModule, Trainer from torch import Tensor, nn +from torch.optim import Optimizer from project.algorithms.bases.algorithm import Algorithm, BatchType, StepOutputDict from project.algorithms.callbacks.callback import Callback @@ -12,6 +13,8 @@ class MeasureSamplesPerSecondCallback(Callback[BatchType, StepOutputDict]): def __init__(self): super().__init__() self.last_step_times: dict[PhaseStr, float] = {} + self.last_update_time: dict[int, float | None] = {} + self.num_optimizers: int | None = None def on_shared_epoch_start( self, @@ -19,7 +22,14 @@ def on_shared_epoch_start( pl_module: Algorithm[BatchType, StepOutputDict, nn.Module], phase: PhaseStr, ) -> None: + self.last_update_time.clear() self.last_step_times.pop(phase, None) + if self.num_optimizers is None: + optimizer_or_optimizers = pl_module.optimizers() + if not isinstance(optimizer_or_optimizers, list): + self.num_optimizers = 1 + else: + self.num_optimizers = len(optimizer_or_optimizers) def on_shared_batch_end( self, @@ -27,7 +37,7 @@ def on_shared_batch_end( pl_module: Algorithm[BatchType, StepOutputDict, nn.Module], outputs: StepOutputDict, batch: BatchType, - batch_idx: int, + batch_index: int, phase: PhaseStr, dataloader_idx: int | None = None, ): @@ -36,7 +46,7 @@ def on_shared_batch_end( pl_module=pl_module, outputs=outputs, batch=batch, - batch_idx=batch_idx, + batch_index=batch_index, phase=phase, dataloader_idx=dataloader_idx, ) @@ -45,6 +55,35 @@ def on_shared_batch_end( elapsed = now - self.last_step_times[phase] if is_sequence_of(batch, Tensor): batch_size = batch[0].shape[0] - pl_module.log(f"{phase}/samples_per_second", batch_size / elapsed, prog_bar=True) + pl_module.log( + f"{phase}/samples_per_second", + batch_size / elapsed, + prog_bar=True, + on_step=True, + on_epoch=True, + sync_dist=True, + ) # todo: support other kinds of batches self.last_step_times[phase] = now + + def on_before_optimizer_step( + self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer, opt_idx: int = 0 + ) -> None: + if opt_idx not in self.last_update_time or self.last_update_time[opt_idx] is None: + self.last_update_time[opt_idx] = time.perf_counter() + return + last_update_time = self.last_update_time[opt_idx] + assert last_update_time is not None + now = time.perf_counter() + elapsed = now - last_update_time + updates_per_second = 1 / elapsed + if self.num_optimizers == 1: + key = "ups" + else: + key = f"optimizer_{opt_idx}/ups" + pl_module.log( + key, + updates_per_second, + prog_bar=False, + on_step=True, + ) diff --git a/project/algorithms/example_algo.py b/project/algorithms/example_algo.py index 188c326e..92c6902b 100644 --- a/project/algorithms/example_algo.py +++ b/project/algorithms/example_algo.py @@ -87,7 +87,7 @@ def forward(self, input: Tensor) -> Tensor: def shared_step( self, batch: tuple[Tensor, Tensor], - batch_idx: int, + batch_index: int, phase: PhaseStr, ) -> ClassificationOutputs: x, y 
= batch diff --git a/project/algorithms/example_algo_test.py b/project/algorithms/example_algo_test.py index 460d86a6..55c67195 100644 --- a/project/algorithms/example_algo_test.py +++ b/project/algorithms/example_algo_test.py @@ -1,5 +1,7 @@ from typing import ClassVar +import torch + from project.algorithms.bases.image_classification_test import ImageClassificationAlgorithmTests from .example_algo import ExampleAlgorithm @@ -9,3 +11,4 @@ class TestExampleAlgorithm(ImageClassificationAlgorithmTests[ExampleAlgorithm]): algorithm_type = ExampleAlgorithm algorithm_name: str = "example_algo" unsupported_datamodule_names: ClassVar[list[str]] = ["rl"] + _supported_network_types: ClassVar[list[type]] = [torch.nn.Module] diff --git a/project/algorithms/jax_algo.py b/project/algorithms/jax_algo.py new file mode 100644 index 00000000..5b527e7d --- /dev/null +++ b/project/algorithms/jax_algo.py @@ -0,0 +1,216 @@ +import dataclasses +import logging +import os +from collections.abc import Callable +from typing import Concatenate, Literal + +import flax.linen +import jax +import lightning +import lightning.pytorch +import lightning.pytorch.callbacks +import rich +import rich.logging +import torch +import torch.distributed +from lightning import Callback, Trainer +from torch_jax_interop import WrappedJaxFunction, torch_to_jax + +from project.algorithms.bases.algorithm import Algorithm +from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback +from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback +from project.datamodules.image_classification.base import ImageClassificationDataModule +from project.datamodules.image_classification.mnist import MNISTDataModule +from project.utils.types import PhaseStr +from project.utils.types.protocols import ClassificationDataModule + +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + + +def flatten(x: jax.Array) -> jax.Array: + return x.reshape((x.shape[0], -1)) + + +class CNN(flax.linen.Module): + """A simple CNN model. + + Taken from https://flax.readthedocs.io/en/latest/quick_start.html#define-network + """ + + num_classes: int = 10 + + @flax.linen.compact + def __call__(self, x: jax.Array): + x = to_channels_last(x) + x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x) + x = flax.linen.relu(x) + x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x) + x = flax.linen.relu(x) + x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + + x = flatten(x) + x = flax.linen.Dense(features=256)(x) + x = flax.linen.relu(x) + x = flax.linen.Dense(features=self.num_classes)(x) + return x + + +class JaxFcNet(flax.linen.Module): + num_classes: int = 10 + num_features: int = 256 + + @flax.linen.compact + def __call__(self, x: jax.Array): + x = flatten(x) + x = flax.linen.Dense(features=self.num_features)(x) + x = flax.linen.relu(x) + x = flax.linen.Dense(features=self.num_classes)(x) + return x + + +# Register a handler function to "convert" `torch.nn.Parameter`s to jax Arrays: they can be viewed +# as jax Arrays by just viewing their data as a jax array. +@torch_to_jax.register(torch.nn.Parameter) +def _parameter_to_jax_array(value: torch.nn.Parameter) -> jax.Array: + return torch_to_jax(value.data) + + +class JaxAlgorithm(Algorithm): + """Example of an algorithm that uses Jax. + + In this case, the network is a flax.linen.Module, and its forward and backward passes are + written in Jax. 
+ """ + + @dataclasses.dataclass + class HParams(Algorithm.HParams): + lr: float = 1e-3 + seed: int = 123 + debug: bool = True + + def __init__( + self, + *, + network: flax.linen.Module, + datamodule: ImageClassificationDataModule, + hp: HParams | None = None, + ): + super().__init__(datamodule=datamodule) + self.hp: JaxAlgorithm.HParams = hp or self.HParams() + + example_input = torch.zeros( + (datamodule.batch_size, *datamodule.dims), + device=self.device, + ) + # Initialize the jax parameters with a forward pass. + params = network.init(jax.random.key(self.hp.seed), x=torch_to_jax(example_input)) + + # Wrap the jax network into a nn.Module: + self.network = WrappedJaxFunction( + jax_function=jax.jit(network.apply) if not self.hp.debug else network.apply, + jax_params=params, + # Need to call .clone() when doing distributed training, otherwise we get a RuntimeError: + # Invalid device pointer when trying to share the CUDA tensors that come from jax. + clone_params=True, + has_aux=False, + ) + + self.example_input_array = example_input + + def shared_step( + self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr + ): + x, y = batch + assert not x.requires_grad + logits = self.network(x) + assert isinstance(logits, torch.Tensor) + # In this example we use a jax "encoder" network and a PyTorch loss function, but we could + # also just as easily have done the whole forward and backward pass in jax if we wanted to. + loss = torch.nn.functional.cross_entropy(logits, target=y, reduction="mean") + acc = logits.argmax(-1).eq(y).float().mean() + self.log(f"{phase}/loss", loss, prog_bar=True, sync_dist=True) + self.log(f"{phase}/acc", acc, prog_bar=True, sync_dist=True) + return {"loss": loss, "logits": logits, "y": y} + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=self.hp.lr) + + def configure_callbacks(self) -> list[Callback]: + assert isinstance(self.datamodule, ClassificationDataModule) + return super().configure_callbacks() + [ + MeasureSamplesPerSecondCallback(), + ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes), + ] + + +def is_channels_first(shape: tuple[int, ...]) -> bool: + if len(shape) == 4: + return is_channels_first(shape[1:]) + if len(shape) != 3: + return False + return (shape[0] in (1, 3) and shape[1] not in {1, 3} and shape[2] not in {1, 3}) or ( + shape[0] < min(shape[1], shape[2]) + ) + + +def is_channels_last(shape: tuple[int, ...]) -> bool: + if len(shape) == 4: + return is_channels_last(shape[1:]) + if len(shape) != 3: + return False + return (shape[2] in (1, 3) and shape[0] not in {1, 3} and shape[1] not in {1, 3}) or ( + shape[2] < min(shape[0], shape[1]) + ) + + +def to_channels_last(x: jax.Array) -> jax.Array: + shape = tuple(x.shape) + if is_channels_last(shape): + return x + if not is_channels_first(shape): + return x + if x.ndim == 3: + return x.transpose(1, 2, 0) + assert x.ndim == 4 + return x.transpose(0, 2, 3, 1) + + +def jit[**P, Out]( + fn: Callable[P, Out], +) -> Callable[P, Out]: + """Small type hint fix for jax's `jit` (preserves the signature of the callable).""" + return jax.jit(fn) # type: ignore + + +def value_and_grad[In, **P, Out, Aux]( + fn: Callable[Concatenate[In, P], tuple[Out, Aux]], + argnums: Literal[0] = 0, + has_aux: Literal[True] = True, +) -> Callable[Concatenate[In, P], tuple[tuple[Out, Aux], In]]: + """Small type hint fix for jax's `value_and_grad` (preserves the signature of the callable).""" + return jax.value_and_grad(fn, 
argnums=argnums, has_aux=has_aux) # type: ignore + + +def main(): + logging.basicConfig( + level=logging.INFO, format="%(message)s", handlers=[rich.logging.RichHandler()] + ) + trainer = Trainer( + devices="auto", + max_epochs=10, + accelerator="auto", + callbacks=[lightning.pytorch.callbacks.RichProgressBar()], + ) + datamodule = MNISTDataModule(num_workers=4, batch_size=512) + network = CNN(num_classes=datamodule.num_classes) + + model = JaxAlgorithm(network=network, datamodule=datamodule) + trainer.fit(model, datamodule=datamodule) + + ... + + +if __name__ == "__main__": + main() + print("Done!") diff --git a/project/algorithms/jax_algo_test.py b/project/algorithms/jax_algo_test.py new file mode 100644 index 00000000..2c5acd52 --- /dev/null +++ b/project/algorithms/jax_algo_test.py @@ -0,0 +1,17 @@ +from typing import ClassVar + +import flax +import flax.linen +import torch + +from project.algorithms.jax_algo import JaxAlgorithm + +from .bases.algorithm_test import AlgorithmTests + + +class TestJaxAlgorithm(AlgorithmTests[JaxAlgorithm]): + """This algorithm only works with Jax modules.""" + + algorithm_name: ClassVar[str] = "jax_algo" + unsupported_network_types: ClassVar[list[type]] = [torch.nn.Module] + _supported_network_types: ClassVar[list[type]] = [flax.linen.Module] diff --git a/project/algorithms/manual_optimization_example.py b/project/algorithms/manual_optimization_example.py index 55a1f5e6..3941965e 100644 --- a/project/algorithms/manual_optimization_example.py +++ b/project/algorithms/manual_optimization_example.py @@ -49,16 +49,18 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return self.network(x) - def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> ClassificationOutputs: - return self.shared_step(batch, batch_idx, "train") + def training_step( + self, batch: tuple[Tensor, Tensor], batch_index: int + ) -> ClassificationOutputs: + return self.shared_step(batch, batch_index, "train") def validation_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int + self, batch: tuple[Tensor, Tensor], batch_index: int ) -> ClassificationOutputs: - return self.shared_step(batch, batch_idx, "val") + return self.shared_step(batch, batch_index, "val") def shared_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int, phase: PhaseStr + self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr ) -> ClassificationOutputs: """Performs a training/validation/test step. 
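Stepping back from the diff for a moment: the core jax-in-torch idea in `JaxAlgorithm` can be condensed into the sketch below. It reuses the same `WrappedJaxFunction` keywords that appear in this PR, but the one-layer flax network and random input are hypothetical, and it assumes `torch_jax_interop` behaves as it is used above:

```python
import flax.linen
import jax
import torch
from torch_jax_interop import WrappedJaxFunction, torch_to_jax

network = flax.linen.Dense(features=10)
x = torch.randn(8, 32)

# Initialize the jax parameters with a forward pass, as JaxAlgorithm.__init__ does:
params = network.init(jax.random.key(0), torch_to_jax(x))

# Wrap the jax `apply` function so its parameters show up as torch Parameters:
wrapped = WrappedJaxFunction(
    jax_function=jax.jit(network.apply),
    jax_params=params,
    clone_params=True,
    has_aux=False,
)

output = wrapped(x)  # a regular torch.Tensor
output.square().mean().backward()  # gradients flow back into the wrapped jax params
```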
diff --git a/project/algorithms/manual_optimization_example_test.py b/project/algorithms/manual_optimization_example_test.py index 486f8767..8509df95 100644 --- a/project/algorithms/manual_optimization_example_test.py +++ b/project/algorithms/manual_optimization_example_test.py @@ -1,5 +1,7 @@ from typing import ClassVar +import torch + from project.algorithms.bases.image_classification_test import ImageClassificationAlgorithmTests from .manual_optimization_example import ManualGradientsExample @@ -10,3 +12,4 @@ class TestManualOptimizationExample(ImageClassificationAlgorithmTests[ManualGrad algorithm_name: str = "manual_optimization" unsupported_datamodule_names: ClassVar[list[str]] = ["rl"] + _supported_network_types: ClassVar[list[type]] = [torch.nn.Module] diff --git a/project/configs/__init__.py b/project/configs/__init__.py index 74ce1a83..bb734082 100644 --- a/project/configs/__init__.py +++ b/project/configs/__init__.py @@ -2,35 +2,20 @@ from hydra.core.config_store import ConfigStore -from project.networks import FcNetConfig, ResNet18Config - from .config import Config from .datamodule import ( REPO_ROOTDIR, SLURM_JOB_ID, SLURM_TMPDIR, - CIFAR10DataModuleConfig, - DataModuleConfig, - FashionMNISTDataModuleConfig, - ImageNet32DataModuleConfig, - INaturalistDataModuleConfig, - MNISTDataModuleConfig, + datamodule_store, ) - -# todo: look into using this instead: -# from hydra_zen import store +from .network import network_store cs = ConfigStore.instance() -cs.store(group="datamodule", name="base", node=DataModuleConfig) -cs.store(group="datamodule", name="cifar10", node=CIFAR10DataModuleConfig) -cs.store(group="datamodule", name="mnist", node=MNISTDataModuleConfig) -cs.store(group="datamodule", name="fashion_mnist", node=FashionMNISTDataModuleConfig) -cs.store(group="datamodule", name="imagenet32", node=ImageNet32DataModuleConfig) -cs.store(group="datamodule", name="inaturalist", node=INaturalistDataModuleConfig) - - -cs.store(group="network", name="fcnet", node=FcNetConfig) -cs.store(group="network", name="resnet18", node=ResNet18Config) +cs.store(name="base_config", node=Config) +datamodule_store.add_to_hydra_store() +network_store.add_to_hydra_store() +# todo: move the algorithm_store.add_to_hydra_store() here? __all__ = [ "Config", diff --git a/project/configs/algorithm/lr_scheduler/__init__.py b/project/configs/algorithm/lr_scheduler/__init__.py index 4413a2fb..67180aa1 100644 --- a/project/configs/algorithm/lr_scheduler/__init__.py +++ b/project/configs/algorithm/lr_scheduler/__init__.py @@ -12,6 +12,7 @@ ] +# TODO: getting doctest issues here? 
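For context, the hydra-zen `store` pattern that replaces the manual `cs.store(...)` calls above works roughly as follows; a hedged sketch with a hypothetical config (the real groups are set up in `project/configs/datamodule` and `project/configs/network`):

```python
import hydra_zen
from hydra_zen import store

# A store scoped to one config group, mirroring datamodule_store / network_store:
demo_store = store(group="datamodule")

# Registering a config only records it in the local store:
demo_store(hydra_zen.builds(dict, batch_size=128), name="demo")

# It only becomes visible to Hydra once the store is published, which is why
# project/configs/__init__.py calls add_to_hydra_store() on each store:
demo_store.add_to_hydra_store()
```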
@hydrated_dataclass(target=torch.optim.lr_scheduler.StepLR, zen_partial=True) class StepLRConfig: """Config for the StepLR Scheduler.""" diff --git a/project/configs/config.py b/project/configs/config.py index b6d3bebb..5ff41808 100644 --- a/project/configs/config.py +++ b/project/configs/config.py @@ -3,8 +3,6 @@ from logging import getLogger as get_logger from typing import Any, Literal -from hydra.core.config_store import ConfigStore - logger = get_logger(__name__) LogLevel = Literal["debug", "info", "warning", "error", "critical"] @@ -39,7 +37,3 @@ class Config: debug: bool = False verbose: bool = False - - -cs = ConfigStore.instance() -cs.store(name="base_config", node=Config) diff --git a/project/configs/datamodule/__init__.py b/project/configs/datamodule/__init__.py index bf2fb834..03154b43 100644 --- a/project/configs/datamodule/__init__.py +++ b/project/configs/datamodule/__init__.py @@ -5,7 +5,7 @@ from pathlib import Path import torch -from hydra_zen import hydrated_dataclass, instantiate +from hydra_zen import hydrated_dataclass, instantiate, store from torch import Tensor from project.datamodules import ( @@ -78,6 +78,9 @@ class DataModuleConfig: ... +datamodule_store = store(group="datamodule") + + @hydrated_dataclass(target=VisionDataModule, populate_full_signature=True) class VisionDataModuleConfig(DataModuleConfig): data_dir: str | None = str(TORCHVISION_DIR or DATA_DIR) @@ -143,3 +146,10 @@ class INaturalistDataModuleConfig(VisionDataModuleConfig): data_dir: Path | None = None version: Version = "2021_train" target_type: TargetType | list[TargetType] = "full" + + +datamodule_store(CIFAR10DataModuleConfig, name="cifar10") +datamodule_store(MNISTDataModuleConfig, name="mnist") +datamodule_store(FashionMNISTDataModuleConfig, name="fashion_mnist") +datamodule_store(ImageNet32DataModuleConfig, name="imagenet32") +datamodule_store(INaturalistDataModuleConfig, name="inaturalist") diff --git a/project/configs/network/__init__.py b/project/configs/network/__init__.py index e69de29b..31723745 100644 --- a/project/configs/network/__init__.py +++ b/project/configs/network/__init__.py @@ -0,0 +1,26 @@ +import hydra_zen +import torchvision.models +from hydra_zen import store + +from project.networks.fcnet import FcNet +from project.utils.hydra_utils import interpolate_config_attribute + +network_store = store(group="network") +network_store( + hydra_zen.builds( + torchvision.models.resnet18, + populate_full_signature=True, + num_classes=interpolate_config_attribute("datamodule.num_classes"), + ), + name="resnet18", +) +network_store( + hydra_zen.builds( + FcNet, + hydra_convert="object", + hydra_recursive=True, + populate_full_signature=True, + output_dims=interpolate_config_attribute("datamodule.num_classes"), + ), + name="fcnet", +) diff --git a/project/configs/network/jax_cnn.yaml b/project/configs/network/jax_cnn.yaml new file mode 100644 index 00000000..4fc6dc8c --- /dev/null +++ b/project/configs/network/jax_cnn.yaml @@ -0,0 +1,2 @@ +_target_: project.algorithms.jax_algo.CNN +num_classes: ${instance_attr:datamodule.num_classes} diff --git a/project/configs/network/jax_fcnet.yaml b/project/configs/network/jax_fcnet.yaml new file mode 100644 index 00000000..55ed3023 --- /dev/null +++ b/project/configs/network/jax_fcnet.yaml @@ -0,0 +1,3 @@ +_target_: project.algorithms.jax_algo.JaxFcNet +num_classes: ${instance_attr:datamodule.num_classes} +num_features: 256 diff --git a/project/configs/network/resnet50.yaml b/project/configs/network/resnet50.yaml new file mode 100644 index 
00000000..fc5d10be --- /dev/null +++ b/project/configs/network/resnet50.yaml @@ -0,0 +1,3 @@ +_target_: torchvision.models.resnet50 +num_classes: "${instance_attr:datamodule.num_classes}" +weights: null diff --git a/project/conftest.py b/project/conftest.py index e55d490b..0a69188d 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -17,8 +17,6 @@ from hydra import compose, initialize_config_module from lightning import seed_everything from omegaconf import DictConfig, open_dict -from pytest_regressions.data_regression import DataRegressionFixture -from pytest_regressions.ndarrays_regression import NDArraysRegressionFixture from torch import Tensor, nn from torch.utils.data import DataLoader @@ -38,7 +36,6 @@ setup_logging, ) from project.utils.hydra_utils import resolve_dictconfig -from project.utils.tensor_regression import TensorRegressionFixture from project.utils.testutils import default_marks_for_config_name from project.utils.types import is_sequence_of from project.utils.types.protocols import DataModule @@ -503,7 +500,8 @@ def network( input: Tensor, request: pytest.FixtureRequest, ): - network = instantiate_network(experiment_config, datamodule=datamodule).to(device) + with device: + network = instantiate_network(experiment_config, datamodule=datamodule) try: _ = network(input) except RuntimeError as err: @@ -550,41 +548,6 @@ def make_torch_deterministic(): torch.set_deterministic_debug_mode(mode_before) -@pytest.fixture -def tensor_regression( - datadir: Path, - original_datadir: Path, - request: pytest.FixtureRequest, - ndarrays_regression: NDArraysRegressionFixture, - data_regression: DataRegressionFixture, - monkeypatch: pytest.MonkeyPatch, - make_torch_deterministic: None, -) -> TensorRegressionFixture: - """Similar to num_regression, but supports numpy arrays with arbitrary shape. The dictionary is - stored as an NPZ file. The values of the dictionary must be accepted by ``np.asarray``. 
- - Example:: - - def test_some_data(tensor_regression): - points, values = some_function() - tensor_regression.check( - { - 'points': points, # tensor with shape (100, 3) - 'values': values, # tensor with shape (100,) - }, - default_tolerance=dict(atol=1e-8, rtol=1e-8) - ) - """ - return TensorRegressionFixture( - datadir=datadir, - original_datadir=original_datadir, - request=request, - ndarrays_regression=ndarrays_regression, - data_regression=data_regression, - monkeypatch=monkeypatch, - ) - - # Incremental testing: https://docs.pytest.org/en/7.1.x/example/simple.html#incremental-testing-test-steps # content of conftest.py diff --git a/project/datamodules/datamodules_test.py b/project/datamodules/datamodules_test.py index f647fa8d..432a4f7f 100644 --- a/project/datamodules/datamodules_test.py +++ b/project/datamodules/datamodules_test.py @@ -3,12 +3,9 @@ import matplotlib.pyplot as plt import pytest +from tensor_regression.fixture import TensorRegressionFixture, get_test_source_and_temp_file_paths from torch import Tensor -from project.utils.tensor_regression import ( - TensorRegressionFixture, - get_test_source_and_temp_file_paths, -) from project.utils.testutils import run_for_all_datamodules from project.utils.types import is_sequence_of diff --git a/project/datamodules/image_classification/imagenet32.py b/project/datamodules/image_classification/imagenet32.py index 8143dd01..825ba493 100644 --- a/project/datamodules/image_classification/imagenet32.py +++ b/project/datamodules/image_classification/imagenet32.py @@ -13,9 +13,9 @@ import numpy as np from PIL import Image from torch.utils.data import DataLoader, Dataset, Subset -from torchvision import transforms +from torchvision.datasets import VisionDataset +from torchvision.transforms import v2 as transforms -from project.datamodules.image_classification.base import ImageClassificationDataModule from project.utils.types import C, H, StageStr, W from ..vision.base import VisionDataModule @@ -27,7 +27,7 @@ def imagenet32_normalization(): return transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) -class ImageNet32Dataset(ImageClassificationDataModule): +class ImageNet32Dataset(VisionDataset): """Downsampled ImageNet 32x32 Dataset.""" url: ClassVar[str] = "https://drive.google.com/uc?id=1XAlD_wshHhGNzaqy8ML-Jk0ZhAm8J5J_" diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py index f050f392..9b54144e 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision/base.py @@ -21,7 +21,11 @@ P = ParamSpec("P") SLURM_TMPDIR: Path | None = ( - Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None + Path(os.environ["SLURM_TMPDIR"]) + if "SLURM_TMPDIR" in os.environ + else tmp + if "SLURM_JOB_ID" in os.environ and (tmp := Path("/tmp")).exists() + else None ) logger = get_logger(__name__) @@ -75,8 +79,9 @@ def __init__( """ super().__init__() + from project.configs.datamodule import DATA_DIR - self.data_dir = data_dir if data_dir is not None else os.getcwd() + self.data_dir = data_dir if data_dir is not None else DATA_DIR self.val_split = val_split if num_workers is None: num_workers = num_cpus_on_node() @@ -240,7 +245,6 @@ def train_dataloader( ) | kwargs ), - persistent_workers=True, ) def val_dataloader( @@ -256,7 +260,6 @@ def val_dataloader( _dataloader_fn=_dataloader_fn, *args, **(dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) | kwargs), - persistent_workers=True, ) def test_dataloader( @@ -274,7 +277,6 @@ def test_dataloader( 
_dataloader_fn=_dataloader_fn, *args, **(dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) | kwargs), - persistent_workers=True, ) def _data_loader( @@ -290,6 +292,7 @@ def _data_loader( num_workers=self.num_workers, drop_last=self.drop_last, pin_memory=self.pin_memory, + persistent_workers=True if self.num_workers > 0 else False, ) | dataloader_kwargs ) diff --git a/project/experiment.py b/project/experiment.py index 4ec1b9ff..f4112356 100644 --- a/project/experiment.py +++ b/project/experiment.py @@ -170,12 +170,15 @@ def get_experiment_device(experiment_config: Config | DictConfig) -> torch.devic def instantiate_network(experiment_config: Config, datamodule: DataModule) -> nn.Module: + device = get_experiment_device(experiment_config) + network_config = experiment_config.network - device = get_experiment_device(experiment_config) - if hasattr(network_config, "_target_"): + # todo: Should we wrap flax.linen.Modules into torch modules automatically for torch-based algos? + + if isinstance(network_config, dict | DictConfig) or hasattr(network_config, "_target_"): with device: - network = instantiate(network_config) + network = hydra_zen.instantiate(network_config) elif is_dataclass(network_config): with device: network = instantiate_network_from_hparams( diff --git a/project/main_test.py b/project/main_test.py index 12b415d0..a855ba57 100644 --- a/project/main_test.py +++ b/project/main_test.py @@ -4,6 +4,7 @@ import typing from pathlib import Path +import hydra_zen import pytest from project.algorithms import Algorithm, ExampleAlgorithm @@ -11,7 +12,7 @@ from project.configs.datamodule import CIFAR10DataModuleConfig from project.conftest import setup_hydra_for_tests_and_compose, use_overrides from project.datamodules.image_classification.cifar10 import CIFAR10DataModule -from project.networks import FcNetConfig +from project.networks.fcnet import FcNet from project.utils.hydra_utils import resolve_dictconfig if typing.TYPE_CHECKING: @@ -77,13 +78,13 @@ def test_setting_algorithm( @pytest.mark.parametrize( ("overrides", "expected_type"), [ - (["algorithm=example_algo", "network=fcnet"], FcNetConfig), + (["algorithm=example_algo", "network=fcnet"], FcNet), ], ids=_ids, ) def test_setting_network( overrides: list[str], - expected_type: type[Algorithm.HParams], + expected_type: type, testing_overrides: list[str], tmp_path: Path, ) -> None: @@ -93,7 +94,7 @@ def test_setting_network( ) as dictconfig: options = resolve_dictconfig(dictconfig) assert isinstance(options, Config) - assert isinstance(options.network, expected_type) + assert hydra_zen.get_target(options.network) is expected_type # TODO: Add some more integration tests: diff --git a/project/networks/__init__.py b/project/networks/__init__.py index 028d444e..c44d7cfc 100644 --- a/project/networks/__init__.py +++ b/project/networks/__init__.py @@ -13,30 +13,7 @@ # _cs.store(group="network", name="fcnet", node=FcNetConfig) # _cs.store(group="network", name="resnet18", node=ResNet18Config) # Add your network configs here. 
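A quick aside on the updated assertion in `test_setting_network`: since networks are now plain structured configs rather than dedicated `*Config` dataclasses, the test resolves the config's target instead of checking the config's type. A minimal sketch of that idiom, with a hypothetical config:

```python
import hydra_zen
import torchvision.models

config = hydra_zen.builds(torchvision.models.resnet18, populate_full_signature=True)

# `get_target` resolves the config's `_target_` back to the actual callable:
assert hydra_zen.get_target(config) is torchvision.models.resnet18
```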
-from dataclasses import field - -from hydra_zen import hydrated_dataclass -from torchvision.models import resnet18 - -from project.utils.hydra_utils import interpolated_field from .fcnet import FcNet - -@hydrated_dataclass(target=FcNet, hydra_convert="object", hydra_recursive=True) -class FcNetConfig: - output_dims: int = interpolated_field( - "${instance_attr:datamodule.num_classes,datamodule.action_dims}", default=-1 - ) - hparams: FcNet.HParams = field(default_factory=FcNet.HParams) - - -@hydrated_dataclass(target=resnet18) -class ResNet18Config: - pretrained: bool = False - num_classes: int = interpolated_field( - "${instance_attr:datamodule.num_classes,datamodule.action_dims}", default=1000 - ) - - __all__ = ["FcNet"] diff --git a/project/utils/hydra_utils.py b/project/utils/hydra_utils.py index 824966be..1ae7cc0e 100644 --- a/project/utils/hydra_utils.py +++ b/project/utils/hydra_utils.py @@ -27,6 +27,86 @@ T = TypeVar("T") + +def interpolate_config_attribute(*attributes: str, default: Any | Literal[MISSING] = MISSING): + """Use this in a config to get an attribute from another config after it is instantiated. + + Multiple attributes can be specified, which will lead to trying each of them in order until the + attribute is found. If none are found, then an error will be raised. + + For example, if we only know the number of classes in the datamodule after it is instantiated, + we can set this in the network config so it is created with the right number of output dims. + + ```yaml + _target_: torchvision.models.resnet50 + num_classes: ${instance_attr:datamodule.num_classes} + ``` + + This is equivalent to: + + >>> import hydra_zen + >>> import torchvision.models + >>> resnet50_config = hydra_zen.builds( + ... torchvision.models.resnet50, + ... num_classes=interpolate_config_attribute("datamodule.num_classes"), + ... populate_full_signature=True, + ... ) + >>> print(hydra_zen.to_yaml(resnet50_config)) # doctest: +NORMALIZE_WHITESPACE + _target_: torchvision.models.resnet.resnet50 + weights: null + progress: true + num_classes: ${instance_attr:datamodule.num_classes} + """ + if default is MISSING: + return "${instance_attr:" + ",".join(attributes) + "}" + return "${instance_attr:" + ",".join(attributes) + ":" + str(default) + "}" + + +def interpolated_field( + interpolation: str, + default: T | Literal[MISSING] = MISSING, + default_factory: Callable[[], T] | Literal[MISSING] = MISSING, + instance_attr: bool = False, +) -> T: + """Field with a default value computed with an OmegaConf-style interpolation when appropriate. + + When the dataclass is created by Hydra / OmegaConf, the interpolation is used. + Otherwise, behaves as usual (either using default or calling the default_factory). + + Parameters + ---------- + interpolation: The string interpolation to use to get the default value. + default: The default value to use when not in a hydra/OmegaConf context. + default_factory: A factory that computes the default value to use when not in a hydra/OmegaConf context. + instance_attr: Whether to use the `instance_attr` custom resolver to run the interpolation \ + with respect to instantiated objects instead of their configs. + Passing `interpolation='${instance_attr:some_config.some_attr}'` has the same effect. + + This last parameter is important, since in order to retrieve the instance attribute, we need to + instantiate the objects, which could be expensive. These instantiated objects are reused at + least, but still, be mindful when using this parameter.
+ """ + assert "${" in interpolation and "}" in interpolation + + if instance_attr: + if not interpolation.startswith("${instance_attr:"): + interpolation = interpolation.removeprefix("${") + interpolation = "${instance_attr:" + interpolation + + if default is MISSING and default_factory is MISSING: + raise RuntimeError( + "Interpolated fields currently still require a default value or default factory for " + "when they are used outside the Hydra/OmegaConf context." + ) + return field( + default_factory=functools.partial( + _default_factory, + interpolation=interpolation, + default=default, + default_factory=default_factory, + ) + ) + + # @dataclass(init=False) class Partial(functools.partial[T], _Partial[T]): def __getattr__(self, name: str): @@ -262,52 +342,6 @@ def get_instantiated_attr( ) -def interpolated_field( - interpolation: str, - default: T | Literal[MISSING] = MISSING, - default_factory: Callable[[], T] | Literal[MISSING] = MISSING, - instance_attr: bool = False, -) -> T: - """Field with a default value computed with a OmegaConf-style interpolation when appropriate. - - When the dataclass is created by Hydra / OmegaConf, the interpolation is used. - Otherwise, behaves as usual (either using default or calling the default_factory). - - Parameters - ---------- - interpolation: The string interpolation to use to get the default value. - default: The default value to use when not in a hydra/OmegaConf context. - default_factory: The default value to use when not in a hydra/OmegaConf context. - instance_attr: Whether to use the `instance_attr` custom resolver to run the interpolation \ - with respect to instantiated objects instead of their configs. - Passing `interpolation='${instance_attr:some_config.some_attr}'` has the same effect. - - This last parameter is important, since in order to retrieve the instance attribute, we need to - instantiate the objects, which could be expensive. These instantiated objects are reused at - least, but still, be mindful when using this parameter. - """ - assert "${" in interpolation and "}" in interpolation - - if instance_attr: - if not interpolation.startswith("${instance_attr:"): - interpolation = interpolation.removeprefix("${") - interpolation = "${instance_attr:" + interpolation - - if default is MISSING and default_factory is MISSING: - raise RuntimeError( - "Interpolated fields currently still require a default value or default factory for " - "when they are used outside the Hydra/OmegaConf context." 
- ) - return field( - default_factory=functools.partial( - _default_factory, - interpolation=interpolation, - default=default, - default_factory=default_factory, - ) - ) - - def being_called_in_hydra_context() -> bool: import hydra.core.utils import omegaconf._utils diff --git a/project/utils/tensor_regression.py b/project/utils/tensor_regression.py deleted file mode 100644 index 677c2929..00000000 --- a/project/utils/tensor_regression.py +++ /dev/null @@ -1,449 +0,0 @@ -import contextlib -import functools -import os -import re -import warnings -from collections.abc import Mapping -from logging import getLogger as get_logger -from pathlib import Path -from typing import Any - -import numpy as np -import pytest -import torch -from _pytest.outcomes import Failed -from pytest_regressions.data_regression import DataRegressionFixture -from pytest_regressions.ndarrays_regression import NDArraysRegressionFixture -from torch import Tensor - -from project.utils.utils import flatten_dict, get_shape_ish - -logger = get_logger(__name__) - -PRECISION = 3 -"""Number of decimals used when rounding the simple stats of Tensor / ndarray in the pre-check. - -Full precision is used in the actual regression check, but this is just for the simple attributes -(min, max, mean, etc.) which seem to be slightly different on the GitHub CI than on a local -machine. -""" - - -@functools.singledispatch -def to_ndarray(v: Any) -> np.ndarray | None: - return np.asarray(v) - - -@to_ndarray.register(type(None)) -def _none_to_ndarray(v: None) -> None: - return None - - -@to_ndarray.register(list) -def _list_to_ndarray(v: list) -> np.ndarray: - if all(isinstance(v_i, list) for v_i in v): - lengths = [len(v_i) for v_i in v] - if len(set(lengths)) != 1: - # List of lists of something, (e.g. a nested tensor-like list of dicts for instance). - if all(isinstance(v_i_j, dict) and not v_i_j for v_i in v for v_i_j in v_i): - # all empty dicts! - return np.asarray([f"list of {len_i} empty dicts" for len_i in lengths]) - raise NotImplementedError(v) - return np.asarray(v) - - -@to_ndarray.register(Tensor) -def _tensor_to_ndarray(v: Tensor) -> np.ndarray: - if v.is_nested: - v = v.to_padded_tensor(padding=0.0) - return v.detach().cpu().numpy() - - -@functools.singledispatch -def _hash(v: Any) -> int: - return hash(v) - - -@_hash.register(Tensor) -def tensor_hash(tensor: Tensor) -> int: - return hash(tuple(tensor.flatten().tolist())) - - -@_hash.register(np.ndarray) -def ndarray_hash(array: np.ndarray) -> int: - return hash(tuple(array.flat)) - - -class TensorRegressionFixture: - """Save some statistics (and a hash) of tensors in a file that is saved with git, but save the - entire tensors in gitignored files. - - This way, the first time the tests run, they re-generate the full regression files, and check - that their contents' hash matches what is stored with git! - - TODO: Add a `--regen-missing` option (currently implicitly always true) that decides if we - raise an error if a file is missing. (for example in unit tests we don't want this to be true!) 
- """ - - def __init__( - self, - datadir: Path, - original_datadir: Path, - request: pytest.FixtureRequest, - ndarrays_regression: NDArraysRegressionFixture, - data_regression: DataRegressionFixture, - monkeypatch: pytest.MonkeyPatch, - simple_attributes_precision: int = PRECISION, - ) -> None: - self.request = request - self.datadir = datadir - self.original_datadir = original_datadir - - self.ndarrays_regression = ndarrays_regression - self.data_regression = data_regression - self.monkeypatch = monkeypatch - self.simple_attributes_precision = simple_attributes_precision - self.generate_missing_files: bool | None = self.request.config.getoption( - "--gen-missing", - default=None, # type: ignore - ) - - def get_source_file(self, extension: str, additional_subfolder: str | None = None) -> Path: - source_file, _test_file = get_test_source_and_temp_file_paths( - extension=extension, - request=self.request, - original_datadir=self.original_datadir, - datadir=self.datadir, - additional_subfolder=additional_subfolder, - ) - return source_file - - # Would be nice if this were a singledispatch method or something similar. - - def check( - self, - data_dict: Mapping[str, Any], - tolerances: dict[str, dict[str, float]] | None = None, - default_tolerance: dict[str, float] | None = None, - ) -> None: - # IDEA: - # - Get the hashes of each array, and actually run the regression check first with those files. - # - Then, if that check passes, run the actual check with the full files. - # NOTE: If the array hash files exist, but the full files don't, then we should just - # re-create the full files instead of failing. - # __tracebackhide__ = True - - data_dict = flatten_dict(data_dict) - - if not isinstance(data_dict, dict): - raise TypeError( - "Only dictionaries with Tensors, NumPy arrays or array-like objects are " - "supported on ndarray_regression fixture.\n" - f"Object with type '{str(type(data_dict))}' was given." - ) - - # File some simple attributes of the full arrays/tensors. This one is saved with git. - simple_attributes_source_file = self.get_source_file(extension=".yaml") - - # File with the full arrays/tensors. This one is ignored by git. - arrays_source_file = self.get_source_file(extension=".npz") - - regen_all = self.request.config.getoption("regen_all") - assert isinstance(regen_all, bool) - - if regen_all: - assert self.generate_missing_files in [ - True, - None, - ], "--gen-missing contradicts --regen-all!" - # Regenerate everything. - if arrays_source_file.exists(): - arrays_source_file.unlink() - if simple_attributes_source_file.exists(): - simple_attributes_source_file.unlink() - - if arrays_source_file.exists(): - logger.info(f"Full arrays file found at {arrays_source_file}.") - if not simple_attributes_source_file.exists(): - # Weird: the simple attributes file doesn't exist. Re-create it if allowed. - with dont_fail_if_files_are_missing(enabled=bool(self.generate_missing_files)): - self.pre_check( - data_dict, - simple_attributes_source_file=simple_attributes_source_file, - ) - - # We already generated the file with the full tensors (and we also already checked - # that their hashes correspond to what we expect.) - # 1. Check that they match the data_dict. - logger.info("Checking the full arrays.") - self.regular_check( - data_dict=data_dict, - fullpath=arrays_source_file, - tolerances=tolerances, - default_tolerance=default_tolerance, - ) - # the simple attributes file should already have been generated and saved in git. 
- assert simple_attributes_source_file.exists() - # NOTE: No need to do this step here. Saves us a super super tiny amount of time. - # logger.debug("Checking that the hashes of the full arrays still match.") - # self.pre_check( - # data_dict, - # simple_attributes_source_file=simple_attributes_source_file, - # ) - return - - if simple_attributes_source_file.exists(): - logger.debug(f"Simple attributes file found at {simple_attributes_source_file}.") - logger.debug(f"Regenerating the full arrays at {arrays_source_file}") - # Go straight to the full check. - # TODO: Need to get the full error when the tensors change instead of just the check - # for the hash, which should only be used when re-creating the full regression files. - - with dont_fail_if_files_are_missing(): - self.regular_check( - data_dict=data_dict, - fullpath=arrays_source_file, - tolerances=tolerances, - default_tolerance=default_tolerance, - ) - logger.debug( - "Checking if the newly-generated full tensor regression files match the expected " - "attributes and hashes." - ) - self.pre_check( - data_dict, - simple_attributes_source_file=simple_attributes_source_file, - ) - return - - logger.warning(f"Creating the simple attributes file at {simple_attributes_source_file}.") - - with dont_fail_if_files_are_missing(enabled=bool(self.generate_missing_files)): - self.pre_check( - data_dict, - simple_attributes_source_file=simple_attributes_source_file, - ) - - with dont_fail_if_files_are_missing(enabled=bool(self.generate_missing_files)): - self.regular_check( - data_dict=data_dict, - fullpath=arrays_source_file, - tolerances=tolerances, - default_tolerance=default_tolerance, - ) - - test_dir = self.original_datadir - assert test_dir.exists() - gitignore_file = test_dir / ".gitignore" - if not gitignore_file.exists(): - logger.info(f"Making a new .gitignore file at {gitignore_file}") - gitignore_file.write_text( - "\n".join( - [ - "# Ignore full tensor files, but not the files with tensor attributes and hashes.", - "*.npz", - ] - ) - + "\n" - ) - - def pre_check(self, data_dict: dict[str, Any], simple_attributes_source_file: Path) -> None: - version_controlled_simple_attributes = get_version_controlled_attributes( - data_dict, precision=self.simple_attributes_precision - ) - # Run the regression check with the hashes (and don't fail if they don't exist) - __tracebackhide__ = True - # TODO: Figure out how to include/use the names of the GPUs: - # - Should it be part of the hash? Or should there be a subfolder for each GPU type? - _gpu_names = get_gpu_names(data_dict) - if len(set(_gpu_names)) == 1: - gpu_name = _gpu_names[0] - if any(isinstance(t, Tensor) and t.device.type == "cuda" for t in data_dict.values()): - version_controlled_simple_attributes["GPU"] = gpu_name - - self.data_regression.check( - version_controlled_simple_attributes, fullpath=simple_attributes_source_file - ) - - def regular_check( - self, - data_dict: dict[str, Any], - basename: str | None = None, - fullpath: os.PathLike[str] | None = None, - tolerances: dict[str, dict[str, float]] | None = None, - default_tolerance: dict[str, float] | None = None, - ) -> None: - array_dict: dict[str, np.ndarray] = {} - for key, array in data_dict.items(): - if isinstance(key, (int | bool | float)): - new_key = f"{key}" - assert new_key not in data_dict - key = new_key - assert isinstance( - key, str - ), f"The dictionary keys must be strings. 
Found key with type '{str(type(key))}'" - - ndarray_value = to_ndarray(array) - if ndarray_value is None: - logger.debug( - f"Got a value of `None` for key {key} not including it in the saved dict." - ) - else: - array_dict[key] = ndarray_value - self.ndarrays_regression.check( - array_dict, - basename=basename, - fullpath=fullpath, - tolerances=tolerances, - default_tolerance=default_tolerance, - ) - return - - -def get_test_source_and_temp_file_paths( - extension: str, - request: pytest.FixtureRequest, - original_datadir: Path, - datadir: Path, - additional_subfolder: str | None = None, -) -> tuple[Path, Path]: - """Returns the path to the (maybe version controlled) source file and the path to the temporary - file where test results might be generated during a regression test. - - NOTE: This is different than in pytest-regressions. Here we use a subfolder with the same name - as the test function. - """ - basename = re.sub(r"[\W]", "_", request.node.name) - overrides_name = basename.removeprefix(request.node.function.__name__).lstrip("_") - - if extension.startswith(".") and overrides_name: - # Remove trailing _'s if the extension starts with a dot. - overrides_name = overrides_name.rstrip("_") - - if overrides_name: - # There are overrides, so use a subdirectory. - relative_path = Path(request.node.function.__name__) / overrides_name - else: - # There are no overrides, so use the regular base name. - relative_path = Path(basename) - - relative_path = relative_path.with_suffix(extension) - if additional_subfolder: - relative_path = relative_path.parent / additional_subfolder / relative_path.name - - source_file = original_datadir / relative_path - test_file = datadir / relative_path - return source_file, test_file - - -@functools.singledispatch -def get_simple_attributes(value: Any, precision: int) -> Any: - raise NotImplementedError( - f"get_simple_attributes doesn't have a registered handler for values of type {type(value)}" - ) - - -@get_simple_attributes.register(type(None)) -def _get_none_attributes(value: None, precision: int): - return {"type": "None"} - - -@get_simple_attributes.register(bool) -@get_simple_attributes.register(int | float | str) -def _get_bool_attributes(value: Any, precision: int): - return {"value": value, "type": type(value).__name__} - - -@get_simple_attributes.register(list) -def list_simple_attributes(some_list: list[Any], precision: int): - return { - "length": len(some_list), - "item_types": sorted(set(type(item).__name__ for item in some_list)), - } - - -@get_simple_attributes.register(dict) -def dict_simple_attributes(some_dict: dict[str, Any], precision: int): - return {k: get_simple_attributes(v, precision=precision) for k, v in some_dict.items()} - - -@get_simple_attributes.register(np.ndarray) -def ndarray_simple_attributes(array: np.ndarray, precision: int) -> dict: - return { - "shape": tuple(array.shape), - "hash": _hash(array), - "min": round(array.min().item(), precision), - "max": round(array.max().item(), precision), - "sum": round(array.sum().item(), precision), - "mean": round(array.mean(), precision), - } - - -@get_simple_attributes.register(Tensor) -def tensor_simple_attributes(tensor: Tensor, precision: int) -> dict: - if tensor.is_nested: - # assert not [tensor_i.any() for tensor_i in tensor.unbind()], tensor - # TODO: It might be a good idea to make a distinction here between '0' as the default, and - # '0' as a value in the tensor? Hopefully this should be clear enough. 
- tensor = tensor.to_padded_tensor(padding=0.0) - - return { - "shape": tuple(tensor.shape) if not tensor.is_nested else get_shape_ish(tensor), - "hash": _hash(tensor), - "min": round(tensor.min().item(), precision), - "max": round(tensor.max().item(), precision), - "sum": round(tensor.sum().item(), precision), - "mean": round(tensor.float().mean().item(), precision), - "device": ( - "cpu" if tensor.device.type == "cpu" else f"{tensor.device.type}:{tensor.device.index}" - ), - } - - -def get_gpu_names(data_dict: dict[str, Any]) -> list[str]: - """Returns the names of the GPUS that tensors in this dict are on.""" - return sorted( - set( - torch.cuda.get_device_name(tensor.device) - for tensor in data_dict.values() - if isinstance(tensor, Tensor) and tensor.device.type == "cuda" - ) - ) - - -def get_version_controlled_attributes(data_dict: dict[str, Any], precision: int) -> dict[str, Any]: - return { - key: get_simple_attributes(value, precision=precision) for key, value in data_dict.items() - } - - -class FilesDidntExist(Failed): - pass - - -@contextlib.contextmanager -def dont_fail_if_files_are_missing(enabled: bool = True): - try: - with _catch_fails_with_files_didnt_exist(): - yield - except FilesDidntExist as exc: - if enabled: - logger.warning(exc) - warnings.warn(RuntimeWarning(exc.msg)) - else: - raise - - -@contextlib.contextmanager -def _catch_fails_with_files_didnt_exist(): - try: - yield - except Failed as failure_exc: - if failure_exc.msg and "File not found in data directory, created" in failure_exc.msg: - raise FilesDidntExist( - failure_exc.msg - + "\n(Use the --gen-missing flag to create any missing regression files.)", - pytrace=failure_exc.pytrace, - ) from failure_exc - else: - raise diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 0d6dbe31..bfe9a203 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -13,6 +13,7 @@ from typing import Any, TypeVar import hydra.errors +import hydra_zen import pytest import torch import yaml @@ -22,7 +23,7 @@ from torch import Tensor, nn from torch.optim import Optimizer -from project.configs.config import Config, cs +from project.configs import Config, cs from project.configs.datamodule import DATA_DIR, SLURM_JOB_ID from project.datamodules.image_classification import ( ImageClassificationDataModule, @@ -124,12 +125,29 @@ def _parametrized_fixture_method(request: pytest.FixtureRequest): return _parametrized_fixture_method +def get_config_loader(): + from hydra._internal.config_loader_impl import ConfigLoaderImpl + from hydra._internal.utils import create_automatic_config_search_path + + search_path = create_automatic_config_search_path( + calling_file=None, calling_module=None, config_path="pkg://project.configs" + ) + config_loader = ConfigLoaderImpl(config_search_path=search_path) + return config_loader + + def get_all_configs_in_group(group_name: str) -> list[str]: - names_yaml = cs.list(group_name) - names = [name.rpartition(".")[0] for name in names_yaml] - if "base" in names: - names.remove("base") - return names + # note: here we're copying a bit of the internal code from Hydra so that we also get the + # configs that are just yaml files, in addition to the configs we added programmatically to the + # configstores. 
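Assuming Hydra's internal `get_group_options` behaves as relied upon in the rewritten `get_all_configs_in_group` below, the helper should now also return configs that exist only as yaml files. A hypothetical usage sketch (the listed names are simply the ones touched by this PR, not a verified output):

```python
networks = get_all_configs_in_group("network")
# expected to now include yaml-only configs too, e.g. something like:
# ["fcnet", "jax_cnn", "jax_fcnet", "resnet18", "resnet50"]
```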
+ + # names_yaml = cs.list(group_name) + # names = [name.rpartition(".")[0] for name in names_yaml] + # if "base" in names: + # names.remove("base") + # return names + + return get_config_loader().get_group_options(group_name) def get_all_algorithm_names() -> list[str]: @@ -142,7 +160,20 @@ def get_type_for_config_name(config_group: str, config_name: str, _cs: ConfigSto In the case of inner dataclasses (e.g. Model.HParams), this returns the outer class (Model). """ + + config_loader = get_config_loader() + _, caching_repo = config_loader._parse_overrides_and_create_caching_repo( + config_name=None, overrides=[] + ) + config_result = caching_repo.load_config(f"{config_group}/{config_name}.yaml") + if config_result is not None: + try: + return hydra_zen.get_target(config_result.config) # type: ignore + except TypeError: + pass + config_node = _cs._load(f"{config_group}/{config_name}.yaml") + if "_target_" in config_node.node: target: str = config_node.node["_target_"] module_name, _, class_name = target.rpartition(".") @@ -367,13 +398,15 @@ def reconstruct(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor: return self.inf_network(input) - def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - return self.shared_step(batch, batch_idx, phase="train") + def training_step(self, batch: tuple[Tensor, Tensor], batch_index: int) -> Tensor: + return self.shared_step(batch, batch_index, phase="train") - def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - return self.shared_step(batch, batch_idx, phase="val") + def validation_step(self, batch: tuple[Tensor, Tensor], batch_index: int) -> Tensor: + return self.shared_step(batch, batch_index, phase="val") - def shared_step(self, batch: tuple[Tensor, Tensor], batch_idx: int, phase: PhaseStr) -> Tensor: + def shared_step( + self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr + ) -> Tensor: x, _y = batch latents = self.inf_network(x) x_hat = self.gen_network(latents) @@ -405,7 +438,9 @@ def forward(self, input: Tensor) -> Tensor: assert isinstance(output, Tensor) return output - def shared_step(self, batch: tuple[Tensor, Tensor], batch_idx: int, phase: PhaseStr) -> Tensor: + def shared_step( + self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr + ) -> Tensor: x, y = batch latents = self.inf_network(x) x_hat = self.gen_network(latents) @@ -436,13 +471,15 @@ def configure_optimizers(self) -> Optimizer: def forward(self, input: Tensor) -> Tensor: return self.network(input) - def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - return self.shared_step(batch, batch_idx, phase="train") + def training_step(self, batch: tuple[Tensor, Tensor], batch_index: int) -> Tensor: + return self.shared_step(batch, batch_index, phase="train") - def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - return self.shared_step(batch, batch_idx, phase="val") + def validation_step(self, batch: tuple[Tensor, Tensor], batch_index: int) -> Tensor: + return self.shared_step(batch, batch_index, phase="val") - def shared_step(self, batch: tuple[Tensor, Tensor], batch_idx: int, phase: PhaseStr) -> Tensor: + def shared_step( + self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr + ) -> Tensor: x, y = batch logits = self.network(x) assert isinstance(logits, Tensor) diff --git a/project/utils/types/__init__.py b/project/utils/types/__init__.py index 7a9ccc69..6858b489 100644 --- 
a/project/utils/types/__init__.py +++ b/project/utils/types/__init__.py @@ -39,6 +39,7 @@ type NestedDict[K, V] = dict[K, V | NestedDict[K, V]] type NestedMapping[K, V] = Mapping[K, V | NestedMapping[K, V]] +type PyTree[T] = T | tuple[PyTree[T], ...] | list[PyTree[T]] | Mapping[Any, PyTree[T]] def is_list_of[V](object: Any, item_type: type[V] | tuple[type[V], ...]) -> TypeGuard[list[V]]: @@ -51,10 +52,18 @@ def is_sequence_of[V]( ) -> TypeGuard[Sequence[V]]: """Used to check (and tell the type checker) that `object` is a sequence of items of this type.""" - try: - return all(isinstance(value, item_type) for value in object) - except TypeError: - return False + return isinstance(object, Sequence) and all(isinstance(value, item_type) for value in object) + + +def is_mapping_of[K, V]( + object: Any, key_type: type[K], value_type: type[V] +) -> TypeGuard[Mapping[K, V]]: + """Used to check (and tell the type checker) that `object` is a mapping with keys and values of + the given types.""" + return isinstance(object, Mapping) and all( + isinstance(key, key_type) and isinstance(value, value_type) + for key, value in object.items() + ) __all__ = [ diff --git a/project/utils/types/protocols.py b/project/utils/types/protocols.py index 4f0e8302..25c4bd4a 100644 --- a/project/utils/types/protocols.py +++ b/project/utils/types/protocols.py @@ -64,3 +64,8 @@ def prepare_data(self) -> None: ... def setup(self, stage: StageStr) -> None: ... def train_dataloader(self) -> Iterable[BatchType]: ... + + +@runtime_checkable +class ClassificationDataModule[BatchType](DataModule[BatchType], Protocol): + num_classes: int diff --git a/pyproject.toml b/pyproject.toml index e07e050a..1443b0ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ dependencies = [ "brax>=0.10.3", "tensorboard>=2.16.2", "gymnax>=0.0.8", + "torch-jax-interop @ git+https://www.github.com/lebrice/torch_jax_interop", + "tensor-regression @ git+https://www.github.com/lebrice/tensor_regression", ] requires-python = ">=3.12" readme = "README.md" @@ -47,6 +49,7 @@ dev = [ "ruff>=0.3.3", "pytest-benchmark>=4.0.0", "pytest-cov>=5.0.0", + "tensor-regression>=0.0.2.post3.dev0", ] [[tool.pdm.source]] @@ -60,6 +63,8 @@ build-backend = "setuptools.build_meta" [tool.pytest.ini_options] testpaths = ["project"] +# todo: look into using https://github.com/scientific-python/pytest-doctestplus +addopts = ["--doctest-modules", "--stats-rounding-precision=3"] [tool.ruff] line-length = 99
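To round off the typing changes: a short, self-contained usage sketch for the tightened `is_sequence_of` and the new `is_mapping_of` guards (assuming the `project.utils.types` import path shown above; the values are hypothetical):

```python
from project.utils.types import is_mapping_of, is_sequence_of


def total_loss(logged: object) -> float:
    # Each guard checks at runtime *and* narrows the static type:
    if is_mapping_of(logged, str, float):
        return sum(logged.values())  # narrowed to Mapping[str, float]
    if is_sequence_of(logged, float):
        return sum(logged)  # narrowed to Sequence[float]
    raise TypeError(f"Unsupported type: {type(logged)}")


assert total_loss({"train/loss": 0.5, "val/loss": 0.25}) == 0.75
assert total_loss([0.5, 0.25]) == 0.75
```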