diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index 360df78a7b..322bf0e151 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -61,7 +61,7 @@ def xp_take_along_axis(arr, indices, axis): else: indices = xp.reshape(indices, (0, 0)) - offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis] + offset = (xp.arange(indices.shape[0], dtype=indices.dtype) * m)[:, xp.newaxis] indices = xp.reshape(offset + indices, (-1,)) out = xp.take(arr, indices) diff --git a/source/checker/deepmd_checker.py b/source/checker/deepmd_checker.py index d763835fdc..0e11ed71c7 100644 --- a/source/checker/deepmd_checker.py +++ b/source/checker/deepmd_checker.py @@ -37,7 +37,7 @@ def visit_call(self, node): if ( isinstance(node.func, Attribute) and isinstance(node.func.expr, Name) - and node.func.expr.name in {"np", "tf", "torch"} + and node.func.expr.name in {"np", "tf", "torch", "xp", "jnp"} and node.func.attrname in { # https://pytorch.org/docs/stable/torch.html#creation-ops