From 44712cf2aa6945ba26ac82df88a285729423fc4b Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 16 Sep 2024 12:39:59 +0200 Subject: [PATCH] [nnx] add flaxlib --- .../{pythonpublish.yml => flax_publish.yml} | 4 +- .../workflows/{build.yml => flax_test.yml} | 27 +-- .github/workflows/flaxlib_publish.yml | 72 ++++++++ flax/nnx/graph.py | 2 +- flaxlib/.gitignore | 72 ++++++++ flaxlib/Cargo.lock | 171 ++++++++++++++++++ flaxlib/Cargo.toml | 12 ++ flaxlib/flaxlib/__init__.py | 15 ++ flaxlib/flaxlib/flaxlib.pyi | 15 ++ flaxlib/pyproject.toml | 19 ++ flaxlib/src/lib.rs | 28 +++ flaxlib/uv.lock | 86 +++++++++ pyproject.toml | 2 + tests/flaxlib_test.py | 7 + uv.lock | 57 ++++-- 15 files changed, 550 insertions(+), 39 deletions(-) rename .github/workflows/{pythonpublish.yml => flax_publish.yml} (92%) rename .github/workflows/{build.yml => flax_test.yml} (90%) create mode 100644 .github/workflows/flaxlib_publish.yml create mode 100644 flaxlib/.gitignore create mode 100644 flaxlib/Cargo.lock create mode 100644 flaxlib/Cargo.toml create mode 100644 flaxlib/flaxlib/__init__.py create mode 100644 flaxlib/flaxlib/flaxlib.pyi create mode 100644 flaxlib/pyproject.toml create mode 100644 flaxlib/src/lib.rs create mode 100644 flaxlib/uv.lock create mode 100644 tests/flaxlib_test.py diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/flax_publish.yml similarity index 92% rename from .github/workflows/pythonpublish.yml rename to .github/workflows/flax_publish.yml index 3ebd99443d..383461a5e7 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/flax_publish.yml @@ -1,11 +1,11 @@ # This workflows will upload a Python Package using Twine when a release is created # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries -name: Upload Python Package +name: Flax - Build and upload to PyPI on: release: - types: [created] + types: [published] jobs: deploy: diff --git a/.github/workflows/build.yml b/.github/workflows/flax_test.yml similarity index 90% rename from .github/workflows/build.yml rename to .github/workflows/flax_test.yml index 1c8f3b0fdd..3597932d6b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/flax_test.yml @@ -1,7 +1,7 @@ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions -name: Build +name: Flax - Test on: push: @@ -70,7 +70,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - - uses: yezz123/setup-uv@v4 + - uses: astral-sh/setup-uv@v2 with: uv-version: "0.3.0" - name: Install standalone dependencies only @@ -104,23 +104,16 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - - uses: yezz123/setup-uv@v4 + - name: Setup uv + uses: astral-sh/setup-uv@v2 with: - uv-version: "0.3.0" - - name: Cached virtual environment - id: venv_cache - uses: actions/cache@v3 - with: - path: .venv - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('uv.lock') }} - - name: Install Dependencies for cache - if: steps.venv_cache.outputs.cache-hit != 'true' - run: | - if [ -d ".venv" ]; then rm -rf .venv; fi - uv sync --locked --all-extras - - name: Check lockfile + version: "0.3.0" + - name: Setup Rust (flaxlib) + uses: actions-rs/toolchain@v1 + - name: Install dependencies run: | - uv sync --locked --all-extras + uv sync --locked --extra all --extra testing --extra docs + uv pip install ./flaxlib - name: Install JAX run: | if [[ "${{ matrix.jax-version }}" == "newest" ]]; then diff --git a/.github/workflows/flaxlib_publish.yml b/.github/workflows/flaxlib_publish.yml new file mode 100644 index 0000000000..f304c17d16 --- /dev/null +++ b/.github/workflows/flaxlib_publish.yml @@ -0,0 +1,72 @@ +name: Flaxlib - Build and upload to PyPI + +# for testing only: +on: + pull_request: + branches: [main] + +# on: +# workflow_dispatch: +# pull_request: +# push: +# branches: [main] +# paths: ['flaxlib/**'] +# release: +# types: [published] + +jobs: + build_wheels: + name: Build wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + # macos-13 is an intel runner, macos-14 is apple silicon + os: [ubuntu-latest, windows-latest, macos-13, macos-14] + + steps: + - uses: actions/checkout@v4 + + - name: Build wheels + uses: pypa/cibuildwheel@v2.21.0 + with: + package-dir: './flaxlib' + output-dir: './flaxlib/wheelhouse' + + - uses: actions/upload-artifact@v4 + with: + name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} + path: ./flaxlib/wheelhouse/*.whl + + build_sdist: + name: Build source distribution + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Build sdist + run: pipx run build --sdist flaxlib + + - uses: actions/upload-artifact@v4 + with: + name: cibw-sdist + path: dist/*.tar.gz + + upload_pypi: + needs: [build_wheels, build_sdist] + runs-on: ubuntu-latest + permissions: + id-token: write + steps: + - uses: actions/download-artifact@v4 + with: + # unpacks all CIBW artifacts into dist/ + pattern: cibw-* + path: ./flaxlib/dist + merge-multiple: true + + - name: Build and publish + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + twine upload flaxlib/dist/* \ No newline at end of file diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 5468a5a987..d363c801f9 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -51,7 +51,7 @@ AuxData = tp.TypeVar('AuxData') StateLeaf = VariableState[tp.Any] -NodeLeaf = VariableState[tp.Any] +NodeLeaf = Variable[tp.Any] GraphState = State[Key, StateLeaf] GraphFlatState = FlatState[StateLeaf] diff --git a/flaxlib/.gitignore b/flaxlib/.gitignore new file mode 100644 index 0000000000..c8f044299d --- /dev/null +++ b/flaxlib/.gitignore @@ -0,0 +1,72 @@ +/target + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version diff --git a/flaxlib/Cargo.lock b/flaxlib/Cargo.lock new file mode 100644 index 0000000000..286d7ae3d6 --- /dev/null +++ b/flaxlib/Cargo.lock @@ -0,0 +1,171 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "flaxlib" +version = "0.0.1" +dependencies = [ + "pyo3", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + +[[package]] +name = "libc" +version = "0.2.158" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ea5043e58958ee56f3e15a90aee535795cd7dfd319846288d93c5b57d85cbe" + +[[package]] +name = "portable-atomic" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" diff --git a/flaxlib/Cargo.toml b/flaxlib/Cargo.toml new file mode 100644 index 0000000000..d0c201aa3c --- /dev/null +++ b/flaxlib/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "flaxlib" +version = "0.0.1" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "flaxlib" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = "0.22.0" diff --git a/flaxlib/flaxlib/__init__.py b/flaxlib/flaxlib/__init__.py new file mode 100644 index 0000000000..049f475bc1 --- /dev/null +++ b/flaxlib/flaxlib/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .flaxlib import sum_as_string as sum_as_string diff --git a/flaxlib/flaxlib/flaxlib.pyi b/flaxlib/flaxlib/flaxlib.pyi new file mode 100644 index 0000000000..505fd3d0f0 --- /dev/null +++ b/flaxlib/flaxlib/flaxlib.pyi @@ -0,0 +1,15 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def sum_as_string(a: int, b: int) -> str: ... diff --git a/flaxlib/pyproject.toml b/flaxlib/pyproject.toml new file mode 100644 index 0000000000..993b9703a6 --- /dev/null +++ b/flaxlib/pyproject.toml @@ -0,0 +1,19 @@ +[build-system] +requires = ["maturin>=1.7,<2.0"] +build-backend = "maturin" + +[project] +name = "flaxlib" +requires-python = ">=3.10" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] +[project.optional-dependencies] +tests = [ + "pytest", +] +[tool.maturin] +features = ["pyo3/extension-module"] diff --git a/flaxlib/src/lib.rs b/flaxlib/src/lib.rs new file mode 100644 index 0000000000..c906f402b4 --- /dev/null +++ b/flaxlib/src/lib.rs @@ -0,0 +1,28 @@ +// Copyright 2024 The Flax Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use pyo3::prelude::*; + +/// Formats the sum of two numbers as string. +#[pyfunction] +fn sum_as_string(a: usize, b: usize) -> PyResult { + Ok((a + b).to_string()) +} + +/// A Python module implemented in Rust. +#[pymodule] +fn flaxlib(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; + Ok(()) +} diff --git a/flaxlib/uv.lock b/flaxlib/uv.lock new file mode 100644 index 0000000000..c819ecdad4 --- /dev/null +++ b/flaxlib/uv.lock @@ -0,0 +1,86 @@ +version = 1 +requires-python = ">=3.8" + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + +[[package]] +name = "exceptiongroup" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/09/35/2495c4ac46b980e4ca1f6ad6db102322ef3ad2410b79fdde159a4b0f3b92/exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc", size = 28883 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 }, +] + +[[package]] +name = "flaxlib" +version = "0.1.0" +source = { editable = "." } + +[package.optional-dependencies] +tests = [ + { name = "pytest" }, +] + +[package.metadata] +requires-dist = [{ name = "pytest", marker = "extra == 'tests'" }] + +[[package]] +name = "iniconfig" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, +] + +[[package]] +name = "packaging" +version = "24.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/51/65/50db4dda066951078f0a96cf12f4b9ada6e4b811516bf0262c0f4f7064d4/packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002", size = 148788 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/aa/cc0199a5f0ad350994d660967a8efb233fe0416e4639146c089643407ce6/packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124", size = 53985 }, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, +] + +[[package]] +name = "pytest" +version = "8.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8b/6c/62bbd536103af674e227c41a8f3dcd022d591f6eed5facb5a0f31ee33bbc/pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181", size = 1442487 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, +] + +[[package]] +name = "tomli" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c0/3f/d7af728f075fb08564c5949a9c95e44352e23dee646869fa104a3b2060a3/tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f", size = 15164 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/97/75/10a9ebee3fd790d20926a90a2547f0bf78f371b2f13aa822c759680ca7b9/tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc", size = 12757 }, +] diff --git a/pyproject.toml b/pyproject.toml index 0b21a5c277..90668acdc0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,8 @@ docs = [ ] dev = [ "pre-commit>=3.8.0", + "maturin>=1.7.1", + "pip>=24.2", ] [project.urls] diff --git a/tests/flaxlib_test.py b/tests/flaxlib_test.py new file mode 100644 index 0000000000..5df524a80a --- /dev/null +++ b/tests/flaxlib_test.py @@ -0,0 +1,7 @@ +from absl.testing import absltest +import flaxlib + + +class TestFlaxlib(absltest.TestCase): + def test_flaxlib(self): + self.assertEqual(flaxlib.sum_as_string(1, 2), '3') diff --git a/uv.lock b/uv.lock index 29d358e255..4115ec74c8 100644 --- a/uv.lock +++ b/uv.lock @@ -504,7 +504,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "python_full_version == '3.11'" }, + { name = "tomli", marker = "python_full_version <= '3.11'" }, ] [[package]] @@ -786,6 +786,8 @@ all = [ { name = "matplotlib" }, ] dev = [ + { name = "maturin" }, + { name = "pip" }, { name = "pre-commit" }, ] docs = [ @@ -809,9 +811,7 @@ docs = [ testing = [ { name = "clu" }, { name = "einops" }, - { name = "gymnasium" }, - { name = "gymnasium", extra = ["accept-rom-license"] }, - { name = "gymnasium", extra = ["atari"] }, + { name = "gymnasium", extra = ["accept-rom-license", "atari"] }, { name = "jaxlib" }, { name = "jaxtyping" }, { name = "jraph" }, @@ -849,6 +849,7 @@ requires-dist = [ { name = "jupytext", marker = "extra == 'docs'", specifier = "==1.13.8" }, { name = "matplotlib", marker = "extra == 'all'" }, { name = "matplotlib", marker = "extra == 'docs'" }, + { name = "maturin", marker = "extra == 'dev'", specifier = ">=1.7.1" }, { name = "ml-collections", marker = "extra == 'docs'" }, { name = "ml-collections", marker = "extra == 'testing'" }, { name = "msgpack" }, @@ -860,6 +861,7 @@ requires-dist = [ { name = "opencv-python", marker = "extra == 'testing'" }, { name = "optax" }, { name = "orbax-checkpoint" }, + { name = "pip", marker = "extra == 'dev'", specifier = ">=24.2" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.8.0" }, { name = "pygments", marker = "extra == 'docs'", specifier = ">=2.6.1" }, { name = "pytest", marker = "extra == 'testing'" }, @@ -1046,11 +1048,9 @@ wheels = [ [package.optional-dependencies] accept-rom-license = [ - { name = "autorom" }, { name = "autorom", extra = ["accept-rom-license"] }, ] atari = [ - { name = "shimmy" }, { name = "shimmy", extra = ["atari"] }, ] @@ -1662,6 +1662,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899 }, ] +[[package]] +name = "maturin" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/ec/1f688d6ad82a568fd7c239f1c7a130d3fc2634977df4ef662ee0ac58a153/maturin-1.7.1.tar.gz", hash = "sha256:147754cb3d81177ee12d9baf575d93549e76121dacd3544ad6a50ab718de2b9c", size = 190286 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/71/2da6a923a8c65749c614f95046ea0190ff00d6923edc20b0c5ecff2119f1/maturin-1.7.1-py3-none-linux_armv6l.whl", hash = "sha256:372a141b31ae7396728d2dedc6061fe4522c1803ae1c05700d37008e1d1a2cc9", size = 8198799 }, + { url = "https://files.pythonhosted.org/packages/21/7c/70e4f4e634777652101277eb1449777310f960f831c831bf7956ea81ef82/maturin-1.7.1-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:49939608095d9bcdf19d081dfd6ac1e8f915c645115090514c7b86e1e382f241", size = 15603724 }, + { url = "https://files.pythonhosted.org/packages/0d/2c/06702f20e9f8f019bc036084292c9fe3ae04b4f6a163929ee10627dd0258/maturin-1.7.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:973126a36cfb9861b3207df579678c1bcd7c348578a41ccfbe80d811a84f1740", size = 7982580 }, + { url = "https://files.pythonhosted.org/packages/6c/a3/a4841dddb81e1855b57acf393ba72c405f097a3c6d7d5078e4d7105a4735/maturin-1.7.1-py3-none-manylinux_2_12_i686.manylinux2010_i686.musllinux_1_1_i686.whl", hash = "sha256:6eec984d26f707b18765478f4892e58ac72e777287cd2ba721d6e2ef6da1f66e", size = 8548076 }, + { url = "https://files.pythonhosted.org/packages/ad/1c/1d0fd54bb2d068d0f9d513b0fdfb089a0fe8d20f020e673de0a0cda4f485/maturin-1.7.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.musllinux_1_1_x86_64.whl", hash = "sha256:0df0a6aaf7e9ab92cce2490b03d80b8f5ecbfa0689747a2ea4dfb9e63877b79c", size = 8705393 }, + { url = "https://files.pythonhosted.org/packages/61/f4/6f4023c9653256fbcf2ef1ab6926f9fd4260390d25c258108ddfd45978d3/maturin-1.7.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:09cca3491c756d1bce6ffff13f004e8a10e67c72a1cba9579058f58220505881", size = 8422778 }, + { url = "https://files.pythonhosted.org/packages/15/d9/d927f225959e95c89fc6999130426d90d3f8285815dc2503a473049cb232/maturin-1.7.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:00f0f8f5051f4c0d0f69bdd0c6297ea87e979f70fb78a377eb4277c932804e2d", size = 8090248 }, + { url = "https://files.pythonhosted.org/packages/e3/d7/577d081996b901e02c2f3f9881fa1c1b4097bf3a0a46b7f7d8481a37ce1e/maturin-1.7.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.musllinux_1_1_ppc64le.whl", hash = "sha256:7bb184cfbac4e3c55ca21d322e4801e0f75e7932287e156c280c279eae60b69e", size = 8731356 }, + { url = "https://files.pythonhosted.org/packages/53/3e/725176fac7ce884bc577603f58026fd56b4faed067e81fe6a0839f5a4464/maturin-1.7.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5e8e61468d7d79790f0b54f2ed24f2fefbce3518548bc4e1a1f0c7be5bad710", size = 9901905 }, + { url = "https://files.pythonhosted.org/packages/61/69/3960d152d0a3e527212b4fe991ada3618fd2f5ec64edffdd38875adb1b9c/maturin-1.7.1-py3-none-win32.whl", hash = "sha256:07c8800603e551a45e16fe7ad1742977097ea43c18b28e491df74d4ca15c5857", size = 6479042 }, + { url = "https://files.pythonhosted.org/packages/a1/5b/512efa939f747f1a1277f981ca1de332f01bb187d193cb8d67f816c38735/maturin-1.7.1-py3-none-win_amd64.whl", hash = "sha256:c5e7e6d130072ca76956106daa276f24a66c3407cfe6cf64c196d4299fd4175c", size = 7268270 }, + { url = "https://files.pythonhosted.org/packages/c6/ce/eda05e623102dfb75b60f8b222ab3d6bc98a6e7182cc44602b422bd0f07a/maturin-1.7.1-py3-none-win_arm64.whl", hash = "sha256:acf9f539f53a7ad64d406a40b27b768f67d75e6e4e93cb04b29025144a74ef45", size = 6277021 }, +] + [[package]] name = "mdit-py-plugins" version = "0.3.5" @@ -2011,7 +2034,6 @@ version = "12.1.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 }, - { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 }, ] [[package]] @@ -2020,7 +2042,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 }, - { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, ] [[package]] @@ -2029,7 +2050,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 }, - { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 }, ] [[package]] @@ -2038,7 +2058,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, ] [[package]] @@ -2050,7 +2069,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892 }, ] [[package]] @@ -2059,7 +2077,6 @@ version = "11.0.2.54" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, - { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, ] [[package]] @@ -2068,7 +2085,6 @@ version = "10.3.2.106" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 }, - { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 }, ] [[package]] @@ -2082,7 +2098,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, - { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, ] [[package]] @@ -2094,7 +2109,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, - { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, ] [[package]] @@ -2113,7 +2127,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/81/b3/e456a1b2d499bb84bdc6670bfbcf41ff3bac58bd2fae6880d62834641558/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb", size = 19252608 }, { url = "https://files.pythonhosted.org/packages/59/65/7ff0569494fbaea45ad2814972cc88da843d53cc96eb8554fcd0908941d9/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79", size = 19724950 }, - { url = "https://files.pythonhosted.org/packages/cb/ef/8f96c82e1cfcf6d5b770f7b043c3cc24841fc247b37629a7cc643dbf72a1/nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6", size = 162012830 }, ] [[package]] @@ -2122,7 +2135,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 }, - { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] [[package]] @@ -2332,6 +2344,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/3b/ce7a01026a7cf46e5452afa86f97a5e88ca97f562cafa76570178ab56d8d/pillow-10.4.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5", size = 2554661 }, ] +[[package]] +name = "pip" +version = "24.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4d/87/fb90046e096a03aeab235e139436b3fe804cdd447ed2093b0d70eba3f7f8/pip-24.2.tar.gz", hash = "sha256:5b5e490b5e9cb275c879595064adce9ebd31b854e3e803740b72f9ccf34a45b8", size = 1922041 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/55/90db48d85f7689ec6f81c0db0622d704306c5284850383c090e6c7195a5c/pip-24.2-py3-none-any.whl", hash = "sha256:2cd581cf58ab7fcfca4ce8efa6dcacd0de5bf8d0a3eb9ec927e07405f4d9e2a2", size = 1815170 }, +] + [[package]] name = "platformdirs" version = "4.2.2" @@ -3412,9 +3433,7 @@ dependencies = [ { name = "tensorflow", marker = "platform_system != 'Darwin'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/a2/e3/33fc5957790cf4710e0a9116cf37c0a881eda673e5f8b569bfff5654a48c/tensorflow_text-2.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8eba0b5804235519b571c827c97337c332de270107f06af6d2171cdefdc4c6a0", size = 6109587 }, { url = "https://files.pythonhosted.org/packages/61/59/2090318555d98dc9dc868b3c585ada2e1139be538d954340726aa3d3899a/tensorflow_text-2.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f04c3f478f1885ad4c7380643a768a72a3de79e1f8f40d50b48cc1fbf73893", size = 5205819 }, - { url = "https://files.pythonhosted.org/packages/92/65/e2d3d9300173a0927e8b7e3cf9a35f9539e9269786c1e1d9d945223fe21a/tensorflow_text-2.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a9b9f9c8a06878714a14f4e086fa8122beb2e141f82d0aa5a8f6b8f9b694db51", size = 6109684 }, { url = "https://files.pythonhosted.org/packages/de/32/182ecf4eb1432942876d9b0b089625564084c5ed4d03c02ddf2872177e95/tensorflow_text-2.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161c09380b090774ed721cdcce973194458708250d7dfbac7cb9ea8a3e9ac762", size = 5205866 }, ]