diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index a443d3e970..a08e849c6c 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -101,7 +101,7 @@ class CommonTest(ABC): # we may usually skip jax before jax is fully supported skip_jax: ClassVar[bool] = True """Whether to skip the JAX model.""" - skip_pd: ClassVar[bool] = not INSTALLED_PD + skip_pd: ClassVar[bool] = True """Whether to skip the Paddle model.""" skip_array_api_strict: ClassVar[bool] = True """Whether to skip the array_api_strict model.""" diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index a463960fb7..8838696108 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -136,7 +136,7 @@ def skip_pd(self) -> bool: precision, env_protection, ) = self.param - return CommonTest.skip_pd + return not INSTALLED_PD @property def skip_array_api_strict(self) -> bool: diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index 12fafa7ba8..f5a79acabe 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -135,7 +135,7 @@ def skip_pd(self) -> bool: ) = self.param # Paddle do not support "bfloat16" in some kernels, # so skip this in CI test - return CommonTest.skip_pd or precision == "bfloat16" + return not INSTALLED_PD or precision == "bfloat16" tf_class = EnerFittingTF dp_class = EnerFittingDP