From 804848a5b33ca116d9db5bbad5104b9dd19eee0f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 11 Mar 2024 07:18:17 -0400 Subject: [PATCH] fix: do not install tf-keras for cu11 (#3444) Signed-off-by: Jinzhe Zeng --- backend/find_tensorflow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backend/find_tensorflow.py b/backend/find_tensorflow.py index fb9e719600..4d63f3118d 100644 --- a/backend/find_tensorflow.py +++ b/backend/find_tensorflow.py @@ -83,6 +83,7 @@ def find_tensorflow() -> Tuple[Optional[str], List[str]]: # TypeError if submodule_search_locations are None # IndexError if submodule_search_locations is an empty list except (AttributeError, TypeError, IndexError): + tf_version = "" if os.environ.get("CIBUILDWHEEL", "0") == "1": cuda_version = os.environ.get("CUDA_VERSION", "12.2") if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"): @@ -99,9 +100,10 @@ def find_tensorflow() -> Tuple[Optional[str], List[str]]: "tensorflow-cpu>=2.5.0rc0,<2.15; platform_machine=='x86_64' and platform_system == 'Linux'", ] ) + tf_version = "2.14.1" else: raise RuntimeError("Unsupported CUDA version") - requires.extend(get_tf_requirement()["cpu"]) + requires.extend(get_tf_requirement(tf_version)["cpu"]) # setuptools will re-find tensorflow after installing setup_requires tf_install_dir = None return tf_install_dir, requires