From 993b41d837b733a293cbfaac81a76b103e47c9fb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 23 Oct 2024 16:15:17 -0400 Subject: [PATCH] set dtype for `xp_take_along_axis` Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index 360df78a7b..25c8868174 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -61,7 +61,7 @@ def xp_take_along_axis(arr, indices, axis): else: indices = xp.reshape(indices, (0, 0)) - offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis] + offset = (xp.arange(indices.shape[0], dtype=indices.type) * m)[:, xp.newaxis] indices = xp.reshape(offset + indices, (-1,)) out = xp.take(arr, indices)