diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index 25c8868174..322bf0e151 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -61,7 +61,7 @@ def xp_take_along_axis(arr, indices, axis): else: indices = xp.reshape(indices, (0, 0)) - offset = (xp.arange(indices.shape[0], dtype=indices.type) * m)[:, xp.newaxis] + offset = (xp.arange(indices.shape[0], dtype=indices.dtype) * m)[:, xp.newaxis] indices = xp.reshape(offset + indices, (-1,)) out = xp.take(arr, indices)