diff --git a/backend/find_pytorch.py b/backend/find_pytorch.py index e01f4e84fe..ff645de458 100644 --- a/backend/find_pytorch.py +++ b/backend/find_pytorch.py @@ -22,6 +22,9 @@ Union, ) +from packaging.specifiers import ( + SpecifierSet, +) from packaging.version import ( Version, ) @@ -104,6 +107,16 @@ def get_pt_requirement(pt_version: str = "") -> dict: """ if pt_version is None: return {"torch": []} + if os.environ.get("CIBUILDWHEEL", "0") == "1": + cuda_version = os.environ.get("CUDA_VERSION", "12.2") + if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"): + # CUDA 12.2, cudnn 9 + pt_version = "2.5.0" + elif cuda_version in SpecifierSet(">=11,<12"): + # CUDA 11.8, cudnn 8 + pt_version = "2.3.1" + else: + raise RuntimeError("Unsupported CUDA version") from None if pt_version == "": pt_version = os.environ.get("PYTORCH_VERSION", "") diff --git a/backend/find_tensorflow.py b/backend/find_tensorflow.py index 5b0de0b2dd..1fc3a8a6d9 100644 --- a/backend/find_tensorflow.py +++ b/backend/find_tensorflow.py @@ -85,14 +85,14 @@ def find_tensorflow() -> tuple[Optional[str], list[str]]: if os.environ.get("CIBUILDWHEEL", "0") == "1": cuda_version = os.environ.get("CUDA_VERSION", "12.2") if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"): - # CUDA 12.2 + # CUDA 12.2, cudnn 9 requires.extend( [ - "tensorflow-cpu>=2.15.0rc0; platform_machine=='x86_64' and platform_system == 'Linux'", + "tensorflow-cpu>=2.18.0rc0; platform_machine=='x86_64' and platform_system == 'Linux'", ] ) elif cuda_version in SpecifierSet(">=11,<12"): - # CUDA 11.8 + # CUDA 11.8, cudnn 8 requires.extend( [ "tensorflow-cpu>=2.5.0rc0,<2.15; platform_machine=='x86_64' and platform_system == 'Linux'", diff --git a/pyproject.toml b/pyproject.toml index c0c6b13719..06d39fe2f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -276,8 +276,6 @@ PATH = "/usr/lib64/mpich/bin:$PATH" UV_EXTRA_INDEX_URL = "https://download.pytorch.org/whl/cpu" # trick to find the correction version of mpich CMAKE_PREFIX_PATH="/opt/python/cp311-cp311/" -TENSORFLOW_VERSION = "2.18.0rc2" -PYTORCH_VERSION = "2.5.0" [tool.cibuildwheel.windows] test-extras = ["cpu", "torch"]