Skip to content

Commit

Permalink
Add a wrapper for sign for NumPy-likes
Browse files Browse the repository at this point in the history
Fixes #183
  • Loading branch information
asmeurer committed Oct 24, 2024
1 parent 5affae5 commit dd44814
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 2 deletions.
13 changes: 12 additions & 1 deletion array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,22 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
raise ValueError("Input array must be at least 1-d.")
return tuple(xp.moveaxis(x, axis, 0))

# numpy 1.26 does not use the standard definition for sign on complex numbers

def sign(x: array, /, xp, **kwargs) -> array:

Check failure on line 535 in array_api_compat/common/_aliases.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F821)

array_api_compat/common/_aliases.py:535:13: F821 Undefined name `array`

Check failure on line 535 in array_api_compat/common/_aliases.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F821)

array_api_compat/common/_aliases.py:535:40: F821 Undefined name `array`
if isdtype(x.dtype, 'complex floating', xp=xp):
out = (x/xp.abs(x, **kwargs))[...]
# sign(0) = 0 but the above formula would give nan
out[x == 0+0j] = 0+0j
return out[()]
else:
return xp.sign(x, **kwargs)

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
'unstack']
'unstack', 'sign']
1 change: 1 addition & 0 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
matmul = get_xp(cp)(_aliases.matmul)
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)
sign = get_xp(cp)(_aliases.sign)

_copy_default = object()

Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _dask_arange(
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)

sign = get_xp(np)(_aliases.sign)

# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
Expand Down
1 change: 1 addition & 0 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
matmul = get_xp(np)(_aliases.matmul)
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)

def _supports_buffer_protocol(obj):
try:
Expand Down

0 comments on commit dd44814

Please sign in to comment.