diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 63d99e45..a9847d61 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -144,6 +144,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: # Basic renames bitwise_invert = torch.bitwise_not +newaxis = None # Two-arg elementwise functions # These require a wrapper to do the correct type promotion on 0-D tensors @@ -690,8 +691,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - axis = 0 return torch.index_select(x, axis, indices, **kwargs) -__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add', - 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', +__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', + 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',