diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index e5231770..c057e71d 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -386,6 +386,12 @@ def sort( res = xp.flip(res, axis=axis) return res +# nonzero should error for zero-dimensional arrays +def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: + if x.ndim == 0: + raise ValueError("nonzero() does not support zero-dimensional arrays") + return xp.nonzero(x, **kwargs) + # sum() and prod() should always upcast when dtype=None def sum( x: ndarray, @@ -526,5 +532,5 @@ def isdtype( 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', - 'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul', - 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] + 'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc', + 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index b43c371f..da2ebf03 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -52,6 +52,7 @@ reshape = get_xp(cp)(_aliases.reshape) argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) +nonzero = get_xp(cp)(_aliases.nonzero) sum = get_xp(cp)(_aliases.sum) prod = get_xp(cp)(_aliases.prod) ceil = get_xp(cp)(_aliases.ceil) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 08f4de0b..633b2f62 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -52,6 +52,7 @@ reshape = get_xp(np)(_aliases.reshape) argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) +nonzero = get_xp(np)(_aliases.nonzero) sum = get_xp(np)(_aliases.sum) prod = get_xp(np)(_aliases.prod) ceil = get_xp(np)(_aliases.ceil)