Skip to content

Commit

Permalink
set dtype for xp_take_along_axis
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Oct 23, 2024
1 parent 6074d33 commit 993b41d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def xp_take_along_axis(arr, indices, axis):
else:
indices = xp.reshape(indices, (0, 0))

offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
offset = (xp.arange(indices.shape[0], dtype=indices.type) * m)[:, xp.newaxis]
indices = xp.reshape(offset + indices, (-1,))

out = xp.take(arr, indices)
Expand Down

0 comments on commit 993b41d

Please sign in to comment.