From dd44814f6a26092d6b4e01ecc66b7e83c0ebf2b9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 24 Oct 2024 15:58:36 -0600 Subject: [PATCH 1/4] Add a wrapper for sign for NumPy-likes Fixes #183 --- array_api_compat/common/_aliases.py | 13 ++++++++++++- array_api_compat/cupy/_aliases.py | 1 + array_api_compat/dask/array/_aliases.py | 2 +- array_api_compat/numpy/_aliases.py | 1 + 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 91c4d9a7..d32c5ddd 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -530,6 +530,17 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) +# numpy 1.26 does not use the standard definition for sign on complex numbers + +def sign(x: array, /, xp, **kwargs) -> array: + if isdtype(x.dtype, 'complex floating', xp=xp): + out = (x/xp.abs(x, **kwargs))[...] + # sign(0) = 0 but the above formula would give nan + out[x == 0+0j] = 0+0j + return out[()] + else: + return xp.sign(x, **kwargs) + __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', @@ -537,4 +548,4 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: 'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims', 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack'] + 'unstack', 'sign'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 30ae2943..f3f83100 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -62,6 +62,7 @@ matmul = get_xp(cp)(_aliases.matmul) matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) +sign = get_xp(cp)(_aliases.sign) _copy_default = object() diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index a24694f3..ee2d88c0 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -104,7 +104,7 @@ def _dask_arange( trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) - +sign = get_xp(np)(_aliases.sign) # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 355215e4..2bfc98ff 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -62,6 +62,7 @@ matmul = get_xp(np)(_aliases.matmul) matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) +sign = get_xp(np)(_aliases.sign) def _supports_buffer_protocol(obj): try: From 25390576a1ee87b9f6e16e8a1db672686907c83d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 24 Oct 2024 16:03:59 -0600 Subject: [PATCH 2/4] Fix ruff errors Ensure nan propagation is still handled correctly for CuPy sign(). --- array_api_compat/common/_aliases.py | 9 ++++++--- array_api_compat/cupy/_aliases.py | 7 ------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index d32c5ddd..37d8ebd1 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -532,14 +532,17 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: array, /, xp, **kwargs) -> array: +def sign(x: ndarray, /, xp, **kwargs) -> ndarray: if isdtype(x.dtype, 'complex floating', xp=xp): out = (x/xp.abs(x, **kwargs))[...] # sign(0) = 0 but the above formula would give nan out[x == 0+0j] = 0+0j - return out[()] else: - return xp.sign(x, **kwargs) + out = xp.sign(x, **kwargs) + # CuPy sign() does not propagate nans. See + # https://github.com/data-apis/array-api-compat/issues/136 + out[xp.isnan(x)] = xp.nan + return out[()] __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index f3f83100..3627fb6b 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -110,13 +110,6 @@ def asarray( return cp.array(obj, dtype=dtype, **kwargs) -def sign(x: ndarray, /) -> ndarray: - # CuPy sign() does not propagate nans. See - # https://github.com/data-apis/array-api-compat/issues/136 - out = cp.sign(x) - out[cp.isnan(x)] = cp.nan - return out - # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): From 55e3f717f843d01ecad5741ed4d06194a4be92c1 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 24 Oct 2024 16:06:25 -0600 Subject: [PATCH 3/4] Fix cupy sign nan handling --- array_api_compat/common/_aliases.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 37d8ebd1..7a90f444 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -12,7 +12,7 @@ from typing import NamedTuple import inspect -from ._helpers import array_namespace, _check_device, device, is_torch_array +from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace # These functions are modified from the NumPy versions. @@ -541,7 +541,8 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - out[xp.isnan(x)] = xp.nan + if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + out[xp.isnan(x)] = xp.nan return out[()] __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', From 8dee4d613c9e45a2dd1edc8ef62465b0d13edd29 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Oct 2024 15:11:25 -0600 Subject: [PATCH 4/4] Update torch xfails --- torch-xfails.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch-xfails.txt b/torch-xfails.txt index c7abe2e9..c972659e 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -56,6 +56,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1 array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] +# inverse trig functions are too inaccurate on CPU +array_api_tests/test_operators_and_elementwise_functions.py::test_acos +array_api_tests/test_operators_and_elementwise_functions.py::test_atan +array_api_tests/test_operators_and_elementwise_functions.py::test_asin # overflow near float max array_api_tests/test_operators_and_elementwise_functions.py::test_log1p