Skip to content

Commit

Permalink
fix: fix for __array__ removal from array-api-strict (#4325)
Browse files Browse the repository at this point in the history
See:
https://github.com/data-apis/array-api-strict/blob/main/docs/changelog.md#major-changes

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
- Enhanced array conversion capability to handle device-specific data
for improved performance.
	- Introduced a robust error handling mechanism for array conversions.
- Updated serialization methods to consistently format data as NumPy
arrays.
- Standardized the conversion process across various evaluation methods
in tests.
- **Bug Fixes**
- Improved handling of input arrays from the CPU to ensure seamless
conversion to NumPy arrays.
- **Chores**
- Refined dependency specifications in `pyproject.toml` for better
compatibility.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 9, 2024
1 parent 22123aa commit 8f11bc7
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 25 deletions.
10 changes: 9 additions & 1 deletion deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Optional,
)

import array_api_compat
import ml_dtypes
import numpy as np

Expand Down Expand Up @@ -105,7 +106,14 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]:
"""
if x is None:
return None
return np.asarray(x)
try:
# asarray is not within Array API standard, so may fail
return np.asarray(x)
except (ValueError, AttributeError):
xp = array_api_compat.array_namespace(x)
# to fix BufferError: Cannot export readonly array since signalling readonly is unsupported by DLPack.
x = xp.asarray(x, copy=True)
return np.from_dlpack(x)


__all__ = [
Expand Down
11 changes: 7 additions & 4 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
Expand Down Expand Up @@ -548,8 +551,8 @@ def serialize(self) -> dict:
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": np.array(obj["davg"]),
"dstd": np.array(obj["dstd"]),
"davg": to_numpy_array(obj["davg"]),
"dstd": to_numpy_array(obj["dstd"]),
},
## to be updated when the options are supported.
"trainable": self.trainable,
Expand Down Expand Up @@ -1022,8 +1025,8 @@ def serialize(self) -> dict:
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": np.array(obj["davg"]),
"dstd": np.array(obj["dstd"]),
"davg": to_numpy_array(obj["davg"]),
"dstd": to_numpy_array(obj["dstd"]),
},
}
if obj.tebd_input_mode in ["strip"]:
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ test = [
"pytest-sugar",
"pytest-split",
"dpgui",
# https://github.com/data-apis/array-api-strict/issues/85
'array-api-strict>=2,<2.1.1;python_version>="3.9"',
'array-api-strict>=2,!=2.1.1;python_version>="3.9"',
]
docs = [
"sphinx>=3.1.1",
Expand Down
1 change: 0 additions & 1 deletion source/tests/common/dpmodel/array_api/test_env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

class TestEnvMat(unittest.TestCase, ArrayAPITest):
def test_compute_smooth_weight(self):
self.set_array_api_version(compute_smooth_weight)
d = xp.arange(10, dtype=xp.float64)
w = compute_smooth_weight(
d,
Expand Down
6 changes: 4 additions & 2 deletions source/tests/consistent/descriptor/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from deepmd.common import (
make_default_mesh,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
Expand Down Expand Up @@ -141,7 +144,6 @@ def eval_array_api_strict_descriptor(
box,
mixed_types: bool = False,
) -> Any:
array_api_strict.set_array_api_strict_flags(api_version="2023.12")
ext_coords, ext_atype, mapping = extend_coord_with_ghosts(
array_api_strict.asarray(coords.reshape(1, -1, 3)),
array_api_strict.asarray(atype.reshape(1, -1)),
Expand All @@ -157,7 +159,7 @@ def eval_array_api_strict_descriptor(
distinguish_types=(not mixed_types),
)
return [
np.asarray(x) if hasattr(x, "__array_namespace__") else x
to_numpy_array(x) if hasattr(x, "__array_namespace__") else x
for x in array_api_strict_obj(
ext_coords, ext_atype, nlist=nlist, mapping=mapping
)
Expand Down
5 changes: 4 additions & 1 deletion source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingDP
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -175,7 +178,7 @@ def eval_jax(self, jax_obj: Any) -> Any:
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return np.asarray(
return to_numpy_array(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
Expand Down
6 changes: 4 additions & 2 deletions source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingDP
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -218,7 +221,6 @@ def eval_jax(self, jax_obj: Any) -> Any:
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
array_api_strict.set_array_api_strict_flags(api_version="2023.12")
(
resnet_dt,
precision,
Expand All @@ -227,7 +229,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
numb_aparam,
numb_dos,
) = self.param
return np.asarray(
return to_numpy_array(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
Expand Down
6 changes: 4 additions & 2 deletions source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnerFittingDP
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -232,7 +235,6 @@ def eval_jax(self, jax_obj: Any) -> Any:
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
array_api_strict.set_array_api_strict_flags(api_version="2023.12")
(
resnet_dt,
precision,
Expand All @@ -241,7 +243,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
(numb_aparam, use_aparam_as_mask),
atom_ener,
) = self.param
return np.asarray(
return to_numpy_array(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
Expand Down
5 changes: 4 additions & 1 deletion source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.polarizability_fitting import PolarFitting as PolarFittingDP
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -175,7 +178,7 @@ def eval_jax(self, jax_obj: Any) -> Any:
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return np.asarray(
return to_numpy_array(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
Expand Down
6 changes: 4 additions & 2 deletions source/tests/consistent/fitting/test_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.property_fitting import (
PropertyFittingNet as PropertyFittingDP,
)
Expand Down Expand Up @@ -226,7 +229,6 @@ def eval_jax(self, jax_obj: Any) -> Any:
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
array_api_strict.set_array_api_strict_flags(api_version="2023.12")
(
resnet_dt,
precision,
Expand All @@ -236,7 +238,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
task_dim,
intensive,
) = self.param
return np.asarray(
return to_numpy_array(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
Expand Down
12 changes: 6 additions & 6 deletions source/tests/consistent/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from deepmd.common import (
VALID_ACTIVATION,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils.network import get_activation_fn as get_activation_fn_dp

from ..seed import (
Expand All @@ -21,8 +24,8 @@

if INSTALLED_PT:
from deepmd.pt.utils.utils import ActivationFn as ActivationFn_pt
from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy
from deepmd.pt.utils.utils import (
to_numpy_array,
to_torch_tensor,
)
if INSTALLED_TF:
Expand Down Expand Up @@ -59,7 +62,7 @@ def test_tf_consistent_with_ref(self):
@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
def test_pt_consistent_with_ref(self):
if INSTALLED_PT:
test = to_numpy_array(
test = torch_to_numpy(
ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input))
)
np.testing.assert_allclose(self.ref, test, atol=1e-10)
Expand All @@ -70,12 +73,9 @@ def test_pt_consistent_with_ref(self):
def test_arary_api_strict(self):
import array_api_strict as xp

xp.set_array_api_strict_flags(
api_version=get_activation_fn_dp.array_api_version
)
input = xp.asarray(self.random_input)
test = get_activation_fn_dp(self.activation)(input)
np.testing.assert_allclose(self.ref, np.array(test), atol=1e-10)
np.testing.assert_allclose(self.ref, to_numpy_array(test), atol=1e-10)

@unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
def test_jax_consistent_with_ref(self):
Expand Down
6 changes: 5 additions & 1 deletion source/tests/consistent/test_type_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
from deepmd.utils.argcheck import (
type_embedding_args,
Expand Down Expand Up @@ -130,7 +133,8 @@ def eval_jax(self, jax_obj: Any) -> Any:
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
out = array_api_strict_obj()
return [
np.asarray(x) if hasattr(x, "__array_namespace__") else x for x in (out,)
to_numpy_array(x) if hasattr(x, "__array_namespace__") else x
for x in (out,)
]

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
Expand Down

0 comments on commit 8f11bc7

Please sign in to comment.