diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index f834754195..6e6113b494 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -8,6 +8,7 @@ Optional, ) +import array_api_compat import ml_dtypes import numpy as np @@ -105,7 +106,14 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]: """ if x is None: return None - return np.asarray(x) + try: + # asarray is not within Array API standard, so may fail + return np.asarray(x) + except (ValueError, AttributeError): + xp = array_api_compat.array_namespace(x) + # to fix BufferError: Cannot export readonly array since signalling readonly is unsupported by DLPack. + x = xp.asarray(x, copy=True) + return np.from_dlpack(x) __all__ = [ diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index a84cc18882..259593e731 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -18,6 +18,9 @@ from deepmd.dpmodel.array_api import ( xp_take_along_axis, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils import ( EmbeddingNet, EnvMat, @@ -548,8 +551,8 @@ def serialize(self) -> dict: "exclude_types": obj.exclude_types, "env_protection": obj.env_protection, "@variables": { - "davg": np.array(obj["davg"]), - "dstd": np.array(obj["dstd"]), + "davg": to_numpy_array(obj["davg"]), + "dstd": to_numpy_array(obj["dstd"]), }, ## to be updated when the options are supported. "trainable": self.trainable, @@ -1022,8 +1025,8 @@ def serialize(self) -> dict: "exclude_types": obj.exclude_types, "env_protection": obj.env_protection, "@variables": { - "davg": np.array(obj["davg"]), - "dstd": np.array(obj["dstd"]), + "davg": to_numpy_array(obj["davg"]), + "dstd": to_numpy_array(obj["dstd"]), }, } if obj.tebd_input_mode in ["strip"]: diff --git a/pyproject.toml b/pyproject.toml index 7d64d48e80..7a42adf8b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,8 +85,7 @@ test = [ "pytest-sugar", "pytest-split", "dpgui", - # https://github.com/data-apis/array-api-strict/issues/85 - 'array-api-strict>=2,<2.1.1;python_version>="3.9"', + 'array-api-strict>=2,!=2.1.1;python_version>="3.9"', ] docs = [ "sphinx>=3.1.1", diff --git a/source/tests/common/dpmodel/array_api/test_env_mat.py b/source/tests/common/dpmodel/array_api/test_env_mat.py index 8dfa199d53..0c0a69fc2e 100644 --- a/source/tests/common/dpmodel/array_api/test_env_mat.py +++ b/source/tests/common/dpmodel/array_api/test_env_mat.py @@ -14,7 +14,6 @@ class TestEnvMat(unittest.TestCase, ArrayAPITest): def test_compute_smooth_weight(self): - self.set_array_api_version(compute_smooth_weight) d = xp.arange(10, dtype=xp.float64) w = compute_smooth_weight( d, diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index e0ca30c799..a469a22348 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -8,6 +8,9 @@ from deepmd.common import ( make_default_mesh, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils.nlist import ( build_neighbor_list, extend_coord_with_ghosts, @@ -141,7 +144,6 @@ def eval_array_api_strict_descriptor( box, mixed_types: bool = False, ) -> Any: - array_api_strict.set_array_api_strict_flags(api_version="2023.12") ext_coords, ext_atype, mapping = extend_coord_with_ghosts( array_api_strict.asarray(coords.reshape(1, -1, 3)), array_api_strict.asarray(atype.reshape(1, -1)), @@ -157,7 +159,7 @@ def eval_array_api_strict_descriptor( distinguish_types=(not mixed_types), ) return [ - np.asarray(x) if hasattr(x, "__array_namespace__") else x + to_numpy_array(x) if hasattr(x, "__array_namespace__") else x for x in array_api_strict_obj( ext_coords, ext_atype, nlist=nlist, mapping=mapping ) diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 60ee7322c1..088cb30238 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingDP from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, @@ -175,7 +178,7 @@ def eval_jax(self, jax_obj: Any) -> Any: ) def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: - return np.asarray( + return to_numpy_array( array_api_strict_obj( array_api_strict.asarray(self.inputs), array_api_strict.asarray(self.atype.reshape(1, -1)), diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index d3de3ef151..0649681ccb 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingDP from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, @@ -218,7 +221,6 @@ def eval_jax(self, jax_obj: Any) -> Any: ) def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: - array_api_strict.set_array_api_strict_flags(api_version="2023.12") ( resnet_dt, precision, @@ -227,7 +229,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: numb_aparam, numb_dos, ) = self.param - return np.asarray( + return to_numpy_array( array_api_strict_obj( array_api_strict.asarray(self.inputs), array_api_strict.asarray(self.atype.reshape(1, -1)), diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index f4e78ce966..7be0382b16 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnerFittingDP from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, @@ -232,7 +235,6 @@ def eval_jax(self, jax_obj: Any) -> Any: ) def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: - array_api_strict.set_array_api_strict_flags(api_version="2023.12") ( resnet_dt, precision, @@ -241,7 +243,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: (numb_aparam, use_aparam_as_mask), atom_ener, ) = self.param - return np.asarray( + return to_numpy_array( array_api_strict_obj( array_api_strict.asarray(self.inputs), array_api_strict.asarray(self.atype.reshape(1, -1)), diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index bd9d013b8d..12f13d1e08 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.polarizability_fitting import PolarFitting as PolarFittingDP from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, @@ -175,7 +178,7 @@ def eval_jax(self, jax_obj: Any) -> Any: ) def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: - return np.asarray( + return to_numpy_array( array_api_strict_obj( array_api_strict.asarray(self.inputs), array_api_strict.asarray(self.atype.reshape(1, -1)), diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index a096d4dd68..d8a56447a4 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.property_fitting import ( PropertyFittingNet as PropertyFittingDP, ) @@ -226,7 +229,6 @@ def eval_jax(self, jax_obj: Any) -> Any: ) def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: - array_api_strict.set_array_api_strict_flags(api_version="2023.12") ( resnet_dt, precision, @@ -236,7 +238,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: task_dim, intensive, ) = self.param - return np.asarray( + return to_numpy_array( array_api_strict_obj( array_api_strict.asarray(self.inputs), array_api_strict.asarray(self.atype.reshape(1, -1)), diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index 5630e913a8..2af545fc35 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -7,6 +7,9 @@ from deepmd.common import ( VALID_ACTIVATION, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils.network import get_activation_fn as get_activation_fn_dp from ..seed import ( @@ -21,8 +24,8 @@ if INSTALLED_PT: from deepmd.pt.utils.utils import ActivationFn as ActivationFn_pt + from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy from deepmd.pt.utils.utils import ( - to_numpy_array, to_torch_tensor, ) if INSTALLED_TF: @@ -59,7 +62,7 @@ def test_tf_consistent_with_ref(self): @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed") def test_pt_consistent_with_ref(self): if INSTALLED_PT: - test = to_numpy_array( + test = torch_to_numpy( ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input)) ) np.testing.assert_allclose(self.ref, test, atol=1e-10) @@ -70,12 +73,9 @@ def test_pt_consistent_with_ref(self): def test_arary_api_strict(self): import array_api_strict as xp - xp.set_array_api_strict_flags( - api_version=get_activation_fn_dp.array_api_version - ) input = xp.asarray(self.random_input) test = get_activation_fn_dp(self.activation)(input) - np.testing.assert_allclose(self.ref, np.array(test), atol=1e-10) + np.testing.assert_allclose(self.ref, to_numpy_array(test), atol=1e-10) @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") def test_jax_consistent_with_ref(self): diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index 0dd17c841e..10cbd1837d 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP from deepmd.utils.argcheck import ( type_embedding_args, @@ -130,7 +133,8 @@ def eval_jax(self, jax_obj: Any) -> Any: def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: out = array_api_strict_obj() return [ - np.asarray(x) if hasattr(x, "__array_namespace__") else x for x in (out,) + to_numpy_array(x) if hasattr(x, "__array_namespace__") else x + for x in (out,) ] def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: