Skip to content

Commit

Permalink
Wrap numpy and cupy nonzero to error on zero-dimensional arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Jan 18, 2024
1 parent 486ca51 commit 739730c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
10 changes: 8 additions & 2 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,12 @@ def sort(
res = xp.flip(res, axis=axis)
return res

# nonzero should error for zero-dimensional arrays
def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
if x.ndim == 0:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return xp.nonzero(x, **kwargs)

# sum() and prod() should always upcast when dtype=None
def sum(
x: ndarray,
Expand Down Expand Up @@ -526,5 +532,5 @@ def isdtype(
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul',
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
1 change: 1 addition & 0 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
reshape = get_xp(cp)(_aliases.reshape)
argsort = get_xp(cp)(_aliases.argsort)
sort = get_xp(cp)(_aliases.sort)
nonzero = get_xp(cp)(_aliases.nonzero)
sum = get_xp(cp)(_aliases.sum)
prod = get_xp(cp)(_aliases.prod)
ceil = get_xp(cp)(_aliases.ceil)
Expand Down
1 change: 1 addition & 0 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
reshape = get_xp(np)(_aliases.reshape)
argsort = get_xp(np)(_aliases.argsort)
sort = get_xp(np)(_aliases.sort)
nonzero = get_xp(np)(_aliases.nonzero)
sum = get_xp(np)(_aliases.sum)
prod = get_xp(np)(_aliases.prod)
ceil = get_xp(np)(_aliases.ceil)
Expand Down

0 comments on commit 739730c

Please sign in to comment.