Skip to content

Commit

Permalink
try to fix
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 8, 2024
1 parent c8a57c7 commit fae3404
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
11 changes: 5 additions & 6 deletions deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,11 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]:
"""
if x is None:
return None
if hasattr(x, "__dlpack_device__") and x.__dlpack_device__()[0] == 1:
# CPU = 1, see https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack_device__.html#api-specification-generated-array-api-array-dlpack-device--page-root
# dlpack needs the device to be the same
return np.from_dlpack(x)
# asarray is not within Array API standard, so may fail
return np.asarray(x)
try:
# asarray is not within Array API standard, so may fail
return np.asarray(x)
except (ValueError, AttributeError):
return np.from_dlpack(x, copy=True)


__all__ = [
Expand Down
5 changes: 4 additions & 1 deletion source/tests/consistent/descriptor/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from deepmd.common import (
make_default_mesh,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
Expand Down Expand Up @@ -157,7 +160,7 @@ def eval_array_api_strict_descriptor(
distinguish_types=(not mixed_types),
)
return [
np.asarray(x) if hasattr(x, "__array_namespace__") else x
to_numpy_array(x) if hasattr(x, "__array_namespace__") else x
for x in array_api_strict_obj(
ext_coords, ext_atype, nlist=nlist, mapping=mapping
)
Expand Down
6 changes: 5 additions & 1 deletion source/tests/consistent/test_type_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
from deepmd.utils.argcheck import (
type_embedding_args,
Expand Down Expand Up @@ -130,7 +133,8 @@ def eval_jax(self, jax_obj: Any) -> Any:
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
out = array_api_strict_obj()
return [
np.asarray(x) if hasattr(x, "__array_namespace__") else x for x in (out,)
to_numpy_array(x) if hasattr(x, "__array_namespace__") else x
for x in (out,)
]

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
Expand Down

0 comments on commit fae3404

Please sign in to comment.