From 1ebdfa1be08578d534cf7432e8be8d8875d4a203 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 28 Dec 2024 11:00:57 +0200 Subject: [PATCH] torch: add take_along_axis --- array_api_compat/torch/_aliases.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 44b0766..6288964 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -744,6 +744,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - axis = 0 return torch.index_select(x, axis, indices, **kwargs) + +def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array: + return torch.take_along_dim(x, indices, dim=axis) + + def sign(x: array, /) -> array: # torch sign() does not support complex numbers and does not propagate # nans. See https://github.com/data-apis/array-api-compat/issues/136 @@ -775,6 +780,6 @@ def sign(x: array, /) -> array: 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'sign'] + 'take', 'take_along_axis', 'sign'] _all_ignore = ['torch', 'get_xp']