Skip to content

Commit

Permalink
jax is only available in Python 3.10
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Sep 23, 2024
1 parent 64df179 commit b8a33b1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ cu12 = [
"nvidia-cuda-nvcc-cu12",
]
jax = [
"jax>=0.4.33",
'jax>=0.4.33;python_version>="3.10"',
]

[tool.deepmd_build_backend.scripts]
Expand Down
4 changes: 2 additions & 2 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
INSTALLED_PT = Backend.get_backend("pytorch")().is_available()
INSTALLED_JAX = Backend.get_backend("jax")().is_available()

if os.environ.get("CI") and not (INSTALLED_TF and INSTALLED_PT and INSTALLED_JAX):
raise ImportError("TensorFlow or PyTorch or JAX should be tested in the CI")
if os.environ.get("CI") and not (INSTALLED_TF and INSTALLED_PT):
raise ImportError("TensorFlow or PyTorch should be tested in the CI")


if INSTALLED_TF:
Expand Down

0 comments on commit b8a33b1

Please sign in to comment.