Skip to content

Commit

Permalink
torch: add take_along_axis
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Dec 28, 2024
1 parent 508f652 commit 1ebdfa1
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
axis = 0
return torch.index_select(x, axis, indices, **kwargs)


def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array:
return torch.take_along_dim(x, indices, dim=axis)


def sign(x: array, /) -> array:
# torch sign() does not support complex numbers and does not propagate
# nans. See https://github.com/data-apis/array-api-compat/issues/136
Expand Down Expand Up @@ -775,6 +780,6 @@ def sign(x: array, /) -> array:
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
'take', 'sign']
'take', 'take_along_axis', 'sign']

_all_ignore = ['torch', 'get_xp']

0 comments on commit 1ebdfa1

Please sign in to comment.