From a89f1ebb728f80a4019b6a0e634bc874bbedd960 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 22 Sep 2024 21:10:17 -0400 Subject: [PATCH] fix tests on py38 Signed-off-by: Jinzhe Zeng --- source/tests/consistent/test_activation.py | 3 ++- source/tests/consistent/test_type_embedding.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index 1d8d776d10..5630e913a8 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -2,7 +2,6 @@ import sys import unittest -import array_api_strict as xp import numpy as np from deepmd.common import ( @@ -69,6 +68,8 @@ def test_pt_consistent_with_ref(self): sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8" ) def test_arary_api_strict(self): + import array_api_strict as xp + xp.set_array_api_strict_flags( api_version=get_activation_fn_dp.array_api_version ) diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index b8b0389a9c..c66ef0fbaa 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -36,6 +36,8 @@ jnp, ) from deepmd.jax.utils.type_embed import TypeEmbedNet as TypeEmbedNetJAX +else: + TypeEmbedNetJAX = object @parameterized(