From 2bf4869ec0a86afab0a31a9c6df723e2dcf2e9f9 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 22 Nov 2024 18:42:29 -0500 Subject: [PATCH] chore: use `xp.take_along_axis` is Array API version >=2024.12 see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39 Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/array_api.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index 322bf0e151..e5c0557851 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -2,6 +2,9 @@ """Utilities for the array API.""" import array_api_compat +from packaging.version import ( + Version, +) def support_array_api(version: str) -> callable: @@ -45,6 +48,9 @@ def xp_swapaxes(a, axis1, axis2): def xp_take_along_axis(arr, indices, axis): xp = array_api_compat.array_namespace(arr) + if Version(xp.__array_api_version__) >= Version("2024.12"): + # see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39 + return xp.take_along_axis(arr, indices, axis=axis) arr = xp_swapaxes(arr, axis, -1) indices = xp_swapaxes(indices, axis, -1)