From a5db3c77095c94b26f5be0e0dfb90ea3cb98a315 Mon Sep 17 00:00:00 2001 From: Noah Syrkis Date: Sun, 9 Jun 2024 07:44:52 -0300 Subject: [PATCH] downgrade jax hpc cuda 11 test --- Dockerfile | 5 +- poetry.lock | 187 +---------------------------------------------- pyproject.toml | 2 - requirements.txt | 7 -- 4 files changed, 4 insertions(+), 197 deletions(-) diff --git a/Dockerfile b/Dockerfile index 15f933f..24ed6eb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,9 +20,8 @@ COPY requirements.txt . RUN python3.11 -m pip install -r requirements.txt RUN python3.11 -m pip install \ - jaxlib==0.4.25+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ - jax==0.4.25 \ - optax==0.2.2 + pip install -U "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ + optax==0.1.5 ENV PYGLFW_PREVIEW=1 diff --git a/poetry.lock b/poetry.lock index f059813..7ac26d2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,16 +1,5 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. -[[package]] -name = "absl-py" -version = "2.1.0" -description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." -optional = false -python-versions = ">=3.7" -files = [ - {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, - {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, -] - [[package]] name = "altair" version = "5.3.0" @@ -186,26 +175,6 @@ files = [ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] -[[package]] -name = "chex" -version = "0.1.86" -description = "Chex: Testing made fun, in JAX!" -optional = false -python-versions = ">=3.9" -files = [ - {file = "chex-0.1.86-py3-none-any.whl", hash = "sha256:251c20821092323a3d9c28e1cf80e4a58180978bec368f531949bd9847eee568"}, - {file = "chex-0.1.86.tar.gz", hash = "sha256:e8b0f96330eba4144659e1617c0f7a57b161e8cbb021e55c6d5056c7378091d1"}, -] - -[package.dependencies] -absl-py = ">=0.9.0" -jax = ">=0.4.16" -jaxlib = ">=0.1.37" -numpy = ">=1.24.1" -setuptools = {version = "*", markers = "python_version >= \"3.12\""} -toolz = ">=0.9.0" -typing-extensions = ">=4.2.0" - [[package]] name = "click" version = "8.1.7" @@ -456,81 +425,6 @@ files = [ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] -[[package]] -name = "jax" -version = "0.4.28" -description = "Differentiate, compile, and transform Numpy code." -optional = false -python-versions = ">=3.9" -files = [ - {file = "jax-0.4.28-py3-none-any.whl", hash = "sha256:6a181e6b5a5b1140e19cdd2d5c4aa779e4cb4ec627757b918be322d8e81035ba"}, - {file = "jax-0.4.28.tar.gz", hash = "sha256:dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9"}, -] - -[package.dependencies] -ml-dtypes = ">=0.2.0" -numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, -] -opt-einsum = "*" -scipy = [ - {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, - {version = ">=1.9", markers = "python_version < \"3.12\""}, -] - -[package.extras] -australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.27)"] -cpu = ["jaxlib (==0.4.28)"] -cuda = ["jaxlib (==0.4.28+cuda12.cudnn89)"] -cuda12 = ["jax-cuda12-plugin (==0.4.28)", "jaxlib (==0.4.28)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] -cuda12-cudnn89 = ["jaxlib (==0.4.28+cuda12.cudnn89)"] -cuda12-local = ["jaxlib (==0.4.28+cuda12.cudnn89)"] -cuda12-pip = ["jaxlib (==0.4.28+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] -minimum-jaxlib = ["jaxlib (==0.4.27)"] -tpu = ["jaxlib (==0.4.28)", "libtpu-nightly (==0.1.dev20240508)", "requests"] - -[[package]] -name = "jaxlib" -version = "0.4.28" -description = "XLA library for JAX" -optional = false -python-versions = ">=3.9" -files = [ - {file = "jaxlib-0.4.28-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:a421d237f8c25d2850166d334603c673ddb9b6c26f52bc496704b8782297bd66"}, - {file = "jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f038e68bd10d1a3554722b0bbe36e6a448384437a75aa9d283f696f0ed9f8c09"}, - {file = "jaxlib-0.4.28-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:fabe77c174e9e196e9373097cefbb67e00c7e5f9d864583a7cfcf9dabd2429b6"}, - {file = "jaxlib-0.4.28-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:e3bcdc6f8e60f8554f415c14d930134e602e3ca33c38e546274fd545f875769b"}, - {file = "jaxlib-0.4.28-cp310-cp310-win_amd64.whl", hash = "sha256:a8b31c0e5eea36b7915696b9be40ea8646edc395a3e5437bf7ef26b7239a567a"}, - {file = "jaxlib-0.4.28-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:2ff8290edc7b92c7eae52517f65492633e267b2e9067bad3e4c323d213e77cf5"}, - {file = "jaxlib-0.4.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:793857faf37f371cafe752fea5fc811f435e43b8fb4b502058444a7f5eccf829"}, - {file = "jaxlib-0.4.28-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b41a6b0d506c09f86a18ecc05bd376f072b548af89c333107e49bb0c09c1a3f8"}, - {file = "jaxlib-0.4.28-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:45ce0f3c840cff8236cff26c37f26c9ff078695f93e0c162c320c281f5041275"}, - {file = "jaxlib-0.4.28-cp311-cp311-win_amd64.whl", hash = "sha256:d4d762c3971d74e610a0e85a7ee063cea81a004b365b2a7dc65133f08b04fac5"}, - {file = "jaxlib-0.4.28-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:d6c09a545329722461af056e735146d2c8c74c22ac7426a845eb69f326b4f7a0"}, - {file = "jaxlib-0.4.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8dd8bffe3853702f63cd924da0ee25734a4d19cd5c926be033d772ba7d1c175d"}, - {file = "jaxlib-0.4.28-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:de2e8521eb51e16e85093a42cb51a781773fa1040dcf9245d7ea160a14ee5a5b"}, - {file = "jaxlib-0.4.28-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:46a1aa857f4feee8a43fcba95c0e0ab62d40c26cc9730b6c69655908ba359f8d"}, - {file = "jaxlib-0.4.28-cp312-cp312-win_amd64.whl", hash = "sha256:eee428eac31697a070d655f1f24f6ab39ced76750d93b1de862377a52dcc2401"}, - {file = "jaxlib-0.4.28-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:4f98cc837b2b6c6dcfe0ab7ff9eb109314920946119aa3af9faa139718ff2787"}, - {file = "jaxlib-0.4.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b01562ec8ad75719b7d0389752489e97eb6b4dcb4c8c113be491634d5282ad3c"}, - {file = "jaxlib-0.4.28-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:aa77a9360a395ba9faf6932df637686fb0c14ddcf4fdc1d2febe04bc88a580a6"}, - {file = "jaxlib-0.4.28-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:4a56ebf05b4a4c1791699d874e072f3f808f0986b4010b14fb549a69c90ca9dc"}, - {file = "jaxlib-0.4.28-cp39-cp39-win_amd64.whl", hash = "sha256:459a4ddcc3e120904b9f13a245430d7801d707bca48925981cbdc59628057dc8"}, -] - -[package.dependencies] -ml-dtypes = ">=0.2.0" -numpy = ">=1.22" -scipy = [ - {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, - {version = ">=1.9", markers = "python_version < \"3.12\""}, -] - -[package.extras] -cuda12-pip = ["nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] - [[package]] name = "jinja2" version = "3.1.4" @@ -863,41 +757,6 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] -[[package]] -name = "ml-dtypes" -version = "0.4.0" -description = "" -optional = false -python-versions = ">=3.9" -files = [ - {file = "ml_dtypes-0.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81"}, - {file = "ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d"}, - {file = "ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5"}, - {file = "ml_dtypes-0.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364"}, - {file = "ml_dtypes-0.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1"}, - {file = "ml_dtypes-0.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c"}, - {file = "ml_dtypes-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6"}, - {file = "ml_dtypes-0.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e"}, - {file = "ml_dtypes-0.4.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06"}, - {file = "ml_dtypes-0.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49"}, - {file = "ml_dtypes-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259"}, - {file = "ml_dtypes-0.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675"}, - {file = "ml_dtypes-0.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368"}, - {file = "ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0"}, - {file = "ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1"}, - {file = "ml_dtypes-0.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e"}, - {file = "ml_dtypes-0.4.0.tar.gz", hash = "sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb"}, -] - -[package.dependencies] -numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, -] - -[package.extras] -dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] - [[package]] name = "mpmath" version = "1.3.0" @@ -985,48 +844,6 @@ files = [ matplotlib = ">=3.1,<4" sympy = ">=1.2,<2" -[[package]] -name = "opt-einsum" -version = "3.3.0" -description = "Optimizing numpys einsum function" -optional = false -python-versions = ">=3.5" -files = [ - {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"}, - {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"}, -] - -[package.dependencies] -numpy = ">=1.7" - -[package.extras] -docs = ["numpydoc", "sphinx (==1.2.3)", "sphinx-rtd-theme", "sphinxcontrib-napoleon"] -tests = ["pytest", "pytest-cov", "pytest-pep8"] - -[[package]] -name = "optax" -version = "0.2.2" -description = "A gradient processing and optimisation library in JAX." -optional = false -python-versions = ">=3.9" -files = [ - {file = "optax-0.2.2-py3-none-any.whl", hash = "sha256:411c414a76aae259f4191a60b712663968741a5163ca92fc250b5d5c7d36fb57"}, - {file = "optax-0.2.2.tar.gz", hash = "sha256:f09bf790ef4b09fb9c35f79a07594c6196a719919985f542dc84b0bf97812e0e"}, -] - -[package.dependencies] -absl-py = ">=0.7.1" -chex = ">=0.1.86" -jax = ">=0.1.55" -jaxlib = ">=0.1.37" -numpy = ">=1.18.0" - -[package.extras] -docs = ["flax", "ipython (>=8.8.0)", "matplotlib (>=3.5.0)", "myst-nb (>=1.0.0)", "sphinx (>=6.0.0)", "sphinx-autodoc-typehints", "sphinx-book-theme (>=1.0.1)", "sphinx-collections (>=0.0.1)", "sphinx-gallery (>=0.14.0)", "sphinx_contributors", "sphinxcontrib-katex", "tensorflow (>=2.4.0)", "tensorflow-datasets (>=4.2.0)"] -dp-accounting = ["absl-py (>=1.0.0)", "attrs (>=21.4.0)", "mpmath (>=1.2.1)", "numpy (>=1.21.4)", "scipy (>=1.7.1)"] -examples = ["dp_accounting (>=0.4)", "flax", "tensorflow (>=2.4.0)", "tensorflow-datasets (>=4.2.0)"] -test = ["dm-tree (>=0.1.7)", "flax (>=0.5.3)"] - [[package]] name = "packaging" version = "24.0" @@ -1078,8 +895,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -2151,4 +1968,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "0ef5f8f341d1491a5805196240212c41bbd7e4fa853f0a86bbd680cdac4f5e5e" +content-hash = "44b08e99f2f33906ec403ed9dfc377d57d426a3ca210c2f67a881dcf424b4c25" diff --git a/pyproject.toml b/pyproject.toml index 455c40d..570471b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,6 @@ pyyaml = "^6.0.1" numpy-hilbert-curve = "^1.0.1" wandb = "^0.17.0" seaborn = "^0.13.2" -jax = "^0.4.28" -optax = "^0.2.2" scikit-learn = "^1.5.0" diff --git a/requirements.txt b/requirements.txt index aee65db..d241d46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,9 @@ -absl-py==2.1.0 altair==5.3.0 attrs==23.2.0 blinker==1.8.2 cachetools==5.3.3 certifi==2024.2.2 charset-normalizer==3.3.2 -chex==0.1.86 click==8.1.7 contourpy==1.2.1 cycler==0.12.1 @@ -16,8 +14,6 @@ fonttools==4.52.4 gitdb==4.0.11 GitPython==3.1.43 idna==3.7 -jax==0.4.28 -jaxlib==0.4.28 Jinja2==3.1.4 joblib==1.4.2 jsonschema==4.22.0 @@ -27,13 +23,10 @@ markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.0 mdurl==0.1.2 -ml-dtypes==0.4.0 mpmath==1.3.0 numpy==1.26.4 numpy-hilbert-curve==1.0.1 oeis==2023.3.10 -opt-einsum==3.3.0 -optax==0.2.2 packaging==24.0 pandas==2.2.2 pillow==10.3.0