From f228b584e87dfa1280561e15373f04d530b10762 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 8 Jan 2024 18:40:18 -0700 Subject: [PATCH 01/13] Add fft support for numpy and cupy This is based off of https://github.com/numpy/numpy/pull/25317 --- array_api_compat/common/_fft.py | 183 +++++++++++++++++++++++++++++ array_api_compat/cupy/__init__.py | 2 + array_api_compat/cupy/fft.py | 29 +++++ array_api_compat/numpy/__init__.py | 2 + array_api_compat/numpy/fft.py | 29 +++++ 5 files changed, 245 insertions(+) create mode 100644 array_api_compat/common/_fft.py create mode 100644 array_api_compat/cupy/fft.py create mode 100644 array_api_compat/numpy/fft.py diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py new file mode 100644 index 00000000..4d2bb3fe --- /dev/null +++ b/array_api_compat/common/_fft.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Union, Optional, Literal + +if TYPE_CHECKING: + from ._typing import Device, ndarray + from collections.abc import Sequence + +# Note: NumPy fft functions improperly upcast float32 and complex64 to +# complex128, which is why we require wrapping them all here. + +def fft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.fft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def ifft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def fftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def ifftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def rfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) + if x.dtype == xp.float32: + return res.astype(xp.complex64) + return res + +def irfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) + if x.dtype == xp.complex64: + return res.astype(xp.float32) + return res + +def rfftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) + if x.dtype == xp.float32: + return res.astype(xp.complex64) + return res + +def irfftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) + if x.dtype == xp.complex64: + return res.astype(xp.float32) + return res + +def hfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def ihfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: + if device not in ["cpu", None]: + raise ValueError(f"Unsupported device {device!r}") + return xp.fft.fftfreq(n, d=d) + +def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: + if device not in ["cpu", None]: + raise ValueError(f"Unsupported device {device!r}") + return xp.fft.rfftfreq(n, d=d) + +def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: + return xp.fft.fftshift(x, axes=axes) + +def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: + return xp.fft.ifftshift(x, axes=axes) + +__all__ = [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", +] diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index ec113f9d..d820e44b 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -9,6 +9,8 @@ # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') +__import__(__package__ + '.fft') + from .linalg import matrix_transpose, vecdot from ..common._helpers import * diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py new file mode 100644 index 00000000..8e83abb8 --- /dev/null +++ b/array_api_compat/cupy/fft.py @@ -0,0 +1,29 @@ +from cupy.fft import * +from cupy.fft import __all__ as fft_all + +from ..common import _fft +from .._internal import get_xp + +import cupy as cp + +fft = get_xp(cp)(_fft.fft), +ifft = get_xp(cp)(_fft.ifft), +fftn = get_xp(cp)(_fft.fftn), +ifftn = get_xp(cp)(_fft.ifftn), +rfft = get_xp(cp)(_fft.rfft), +irfft = get_xp(cp)(_fft.irfft), +rfftn = get_xp(cp)(_fft.rfftn), +irfftn = get_xp(cp)(_fft.irfftn), +hfft = get_xp(cp)(_fft.hfft), +ihfft = get_xp(cp)(_fft.ihfft), +fftfreq = get_xp(cp)(_fft.fftfreq), +rfftfreq = get_xp(cp)(_fft.rfftfreq), +fftshift = get_xp(cp)(_fft.fftshift), +ifftshift = get_xp(cp)(_fft.ifftshift), + +__all__ = fft_all + _fft.__all__ + +del get_xp +del cp +del fft_all +del _fft diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 4a49f2f1..ff5efdfd 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -15,6 +15,8 @@ # dynamically so that the library can be vendored. __import__(__package__ + '.linalg') +__import__(__package__ + '.fft') + from .linalg import matrix_transpose, vecdot from ..common._helpers import * diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py new file mode 100644 index 00000000..6093b19d --- /dev/null +++ b/array_api_compat/numpy/fft.py @@ -0,0 +1,29 @@ +from numpy.fft import * +from numpy.fft import __all__ as fft_all + +from ..common import _fft +from .._internal import get_xp + +import numpy as np + +fft = get_xp(np)(_fft.fft) +ifft = get_xp(np)(_fft.ifft) +fftn = get_xp(np)(_fft.fftn) +ifftn = get_xp(np)(_fft.ifftn) +rfft = get_xp(np)(_fft.rfft) +irfft = get_xp(np)(_fft.irfft) +rfftn = get_xp(np)(_fft.rfftn) +irfftn = get_xp(np)(_fft.irfftn) +hfft = get_xp(np)(_fft.hfft) +ihfft = get_xp(np)(_fft.ihfft) +fftfreq = get_xp(np)(_fft.fftfreq) +rfftfreq = get_xp(np)(_fft.rfftfreq) +fftshift = get_xp(np)(_fft.fftshift) +ifftshift = get_xp(np)(_fft.ifftshift) + +__all__ = fft_all + _fft.__all__ + +del get_xp +del np +del fft_all +del _fft From d7a9ecbad6e522ce43d2d4941e10e0792e252941 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Mar 2024 16:37:57 -0700 Subject: [PATCH 02/13] Fix hfft downcasting logic --- array_api_compat/common/_fft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index 4d2bb3fe..666b0b1f 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -132,7 +132,7 @@ def hfft( ) -> ndarray: res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: - return res.astype(xp.complex64) + return res.astype(xp.float32) return res def ihfft( From f97b59ec343a7290d0ae24554797f00195cb6e45 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Mar 2024 16:39:44 -0700 Subject: [PATCH 03/13] Remove fft xfails --- cupy-xfails.txt | 13 ------------- numpy-1-21-xfails.txt | 13 ------------- numpy-dev-xfails.txt | 13 ------------- numpy-xfails.txt | 13 ------------- torch-xfails.txt | 13 ------------- 5 files changed, 65 deletions(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index cfacbe33..e76c4c32 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -164,16 +164,3 @@ array_api_tests/test_special_cases.py::test_unary[sqrt(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0] - -# fft functions are not yet supported -# (https://github.com/data-apis/array-api-compat/issues/67) -array_api_tests/test_fft.py::test_fft -array_api_tests/test_fft.py::test_ifft -array_api_tests/test_fft.py::test_fftn -array_api_tests/test_fft.py::test_ifftn -array_api_tests/test_fft.py::test_rfft -array_api_tests/test_fft.py::test_irfft -array_api_tests/test_fft.py::test_rfftn -array_api_tests/test_fft.py::test_irfftn -array_api_tests/test_fft.py::test_hfft -array_api_tests/test_fft.py::test_ihfft diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 9a0d2827..dce83859 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -50,19 +50,6 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices -# fft functions are not yet supported -# (https://github.com/data-apis/array-api-compat/issues/67) -array_api_tests/test_fft.py::test_fft -array_api_tests/test_fft.py::test_ifft -array_api_tests/test_fft.py::test_fftn -array_api_tests/test_fft.py::test_ifftn -array_api_tests/test_fft.py::test_rfft -array_api_tests/test_fft.py::test_irfft -array_api_tests/test_fft.py::test_rfftn -array_api_tests/test_fft.py::test_irfftn -array_api_tests/test_fft.py::test_hfft -array_api_tests/test_fft.py::test_ihfft - # NumPy 1.21 specific XFAILS ############################ diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 5e270a95..8d291d01 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -42,16 +42,3 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # The test suite is incorrectly checking sums that have loss of significance # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum - -# fft functions are not yet supported -# (https://github.com/data-apis/array-api-compat/issues/67) -array_api_tests/test_fft.py::test_fft -array_api_tests/test_fft.py::test_ifft -array_api_tests/test_fft.py::test_fftn -array_api_tests/test_fft.py::test_ifftn -array_api_tests/test_fft.py::test_rfft -array_api_tests/test_fft.py::test_irfft -array_api_tests/test_fft.py::test_rfftn -array_api_tests/test_fft.py::test_irfftn -array_api_tests/test_fft.py::test_hfft -array_api_tests/test_fft.py::test_ihfft diff --git a/numpy-xfails.txt b/numpy-xfails.txt index d0be245b..e44d7035 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -44,16 +44,3 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # The test suite is incorrectly checking sums that have loss of significance # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum - -# fft functions are not yet supported -# (https://github.com/data-apis/array-api-compat/issues/67) -array_api_tests/test_fft.py::test_fft -array_api_tests/test_fft.py::test_ifft -array_api_tests/test_fft.py::test_fftn -array_api_tests/test_fft.py::test_ifftn -array_api_tests/test_fft.py::test_rfft -array_api_tests/test_fft.py::test_irfft -array_api_tests/test_fft.py::test_rfftn -array_api_tests/test_fft.py::test_irfftn -array_api_tests/test_fft.py::test_hfft -array_api_tests/test_fft.py::test_ihfft diff --git a/torch-xfails.txt b/torch-xfails.txt index caf1aa65..a9106fae 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -190,16 +190,3 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_expm1 array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values - -# fft functions are not yet supported -# (https://github.com/data-apis/array-api-compat/issues/67) -array_api_tests/test_fft.py::test_fftn -array_api_tests/test_fft.py::test_ifftn -array_api_tests/test_fft.py::test_rfft -array_api_tests/test_fft.py::test_irfft -array_api_tests/test_fft.py::test_rfftn -array_api_tests/test_fft.py::test_irfftn -array_api_tests/test_fft.py::test_hfft -array_api_tests/test_fft.py::test_ihfft -array_api_tests/test_fft.py::test_shift_func[fftshift] -array_api_tests/test_fft.py::test_shift_func[ifftshift] From 1ea7ecd97f98e025c6c28ce2a18b7ff36e128354 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Mar 2024 16:53:14 -0700 Subject: [PATCH 04/13] Add wrappers for torch.fft The only thing that needs to be wrapped is a few functions which do not properly map axes to dim. --- array_api_compat/torch/__init__.py | 2 + array_api_compat/torch/fft.py | 84 ++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 array_api_compat/torch/fft.py diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 59898aab..172f5279 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -17,6 +17,8 @@ # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') +__import__(__package__ + '.fft') + from ..common._helpers import * # noqa: F403 __array_api_version__ = '2022.12' diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py new file mode 100644 index 00000000..dbf74cb0 --- /dev/null +++ b/array_api_compat/torch/fft.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import torch + array = torch.Tensor + from typing import Union, Sequence, Literal + +from torch.fft import * # noqa: F403 +import torch.fft + +# Several torch fft functions do not map axes to dim + +def fftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) + +def ifftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) + +def rfftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) + +def irfftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) + +def fftshift( + x: array, + /, + *, + axes: Union[int, Sequence[int]] = None, + **kwargs, +) -> array: + return torch.fft.fftshift(x, dim=axes, **kwargs) + +def ifftshift( + x: array, + /, + *, + axes: Union[int, Sequence[int]] = None, + **kwargs, +) -> array: + return torch.fft.ifftshift(x, dim=axes, **kwargs) + + +__all__ = torch.fft.__all__ + [ + "fftn", + "ifftn", + "rfftn", + "irfftn", + "fftshift", + "ifftshift", +] From 18960aafc9bbdad120a251e49e4e98a0bf89649f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Mar 2024 18:05:59 -0700 Subject: [PATCH 05/13] Fix ruff and tests --- array_api_compat/cupy/__init__.py | 2 +- array_api_compat/cupy/fft.py | 2 +- array_api_compat/numpy/fft.py | 2 +- array_api_compat/torch/fft.py | 2 ++ 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 697c20f3..7968d68d 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,4 +1,4 @@ -from cupy import * +from cupy import * # noqa: F403 # from cupy import * doesn't overwrite these builtin names from cupy import abs, max, min, round # noqa: F401 diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 8e83abb8..297a52b6 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -1,4 +1,4 @@ -from cupy.fft import * +from cupy.fft import * # noqa: F403 from cupy.fft import __all__ as fft_all from ..common import _fft diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 6093b19d..28667594 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,4 +1,4 @@ -from numpy.fft import * +from numpy.fft import * # noqa: F403 from numpy.fft import __all__ as fft_all from ..common import _fft diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index dbf74cb0..3c9117ee 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -82,3 +82,5 @@ def ifftshift( "fftshift", "ifftshift", ] + +_all_ignore = ['torch'] From da6d4e44572339a6a40b5ce0ef6e46d417481e9c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Mar 2024 18:12:02 -0700 Subject: [PATCH 06/13] Fix cupy fft __all__ --- array_api_compat/cupy/fft.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 297a52b6..db1f8047 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -1,5 +1,12 @@ from cupy.fft import * # noqa: F403 -from cupy.fft import __all__ as fft_all +# cupy.fft doesn't have __all__. If it is added, replace this with +# +# from cupy.fft import __all__ as linalg_all +_n = {} +exec('from cupy.fft import *', _n) +del _n['__builtins__'] +fft_all = list(_n) +del _n from ..common import _fft from .._internal import get_xp From 912e80c146c41a25e034e5cd04b83cff9ee28c3e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Mar 2024 18:13:43 -0700 Subject: [PATCH 07/13] Avoid testing against vendored array_api_compat in test_all --- tests/test_all.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_all.py b/tests/test_all.py index 5b49fa14..7a6f74f0 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -21,7 +21,7 @@ def test_all(library): import_(library, wrapper=True) for mod_name in sys.modules: - if 'array_api_compat.' + library not in mod_name: + if not mod_name.startswith('array_api_compat.' + library): continue module = sys.modules[mod_name] From 4018fe43d0ead5ed83fbcf07580197801bc35f1e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Mar 2024 18:15:52 -0700 Subject: [PATCH 08/13] Fix import_('cupy', wrapper=True) tests helper --- tests/_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/_helpers.py b/tests/_helpers.py index 23cb5db9..c41bc881 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -6,8 +6,6 @@ def import_(library, wrapper=False): - if library == 'cupy': - return pytest.importorskip(library) if 'jax' in library and sys.version_info < (3, 9): pytest.skip('JAX array API support does not support Python 3.8') @@ -16,5 +14,7 @@ def import_(library, wrapper=False): library = 'jax.experimental.array_api' else: library = 'array_api_compat.' + library + elif library == 'cupy': + return pytest.importorskip(library) return import_module(library) From d7f95a32e61e168042d1baae4c3919da5a714662 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Mar 2024 18:16:58 -0700 Subject: [PATCH 09/13] Fix test_all for cupy --- array_api_compat/cupy/_aliases.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 968b974b..b9364ac6 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -77,3 +77,5 @@ 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'concat', 'pow'] + +_all_ignore = ['cp', 'get_xp'] From 29ec4d6d7c79da4162f4ba3b3d1be267f91b2b91 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Mar 2024 18:18:31 -0700 Subject: [PATCH 10/13] Fix array api tests pytest call in test_cupy.sh --- test_cupy.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_cupy.sh b/test_cupy.sh index 6b4d6b56..3d8c0711 100755 --- a/test_cupy.sh +++ b/test_cupy.sh @@ -26,4 +26,4 @@ mkdir -p $SCRIPT_DIR/.hypothesis ln -s $SCRIPT_DIR/.hypothesis .hypothesis export ARRAY_API_TESTS_MODULE=array_api_compat.cupy -pytest ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@" +pytest array_api_tests/ ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@" From de64f5ecd7d1a1e5abb87b98ea7ad5e3759a5825 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Mar 2024 18:59:31 -0700 Subject: [PATCH 11/13] Remove a bunch of incorrect trailing commas --- array_api_compat/cupy/fft.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index db1f8047..307e0f72 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -13,20 +13,20 @@ import cupy as cp -fft = get_xp(cp)(_fft.fft), -ifft = get_xp(cp)(_fft.ifft), -fftn = get_xp(cp)(_fft.fftn), -ifftn = get_xp(cp)(_fft.ifftn), -rfft = get_xp(cp)(_fft.rfft), -irfft = get_xp(cp)(_fft.irfft), -rfftn = get_xp(cp)(_fft.rfftn), -irfftn = get_xp(cp)(_fft.irfftn), -hfft = get_xp(cp)(_fft.hfft), -ihfft = get_xp(cp)(_fft.ihfft), -fftfreq = get_xp(cp)(_fft.fftfreq), -rfftfreq = get_xp(cp)(_fft.rfftfreq), -fftshift = get_xp(cp)(_fft.fftshift), -ifftshift = get_xp(cp)(_fft.ifftshift), +fft = get_xp(cp)(_fft.fft) +ifft = get_xp(cp)(_fft.ifft) +fftn = get_xp(cp)(_fft.fftn) +ifftn = get_xp(cp)(_fft.ifftn) +rfft = get_xp(cp)(_fft.rfft) +irfft = get_xp(cp)(_fft.irfft) +rfftn = get_xp(cp)(_fft.rfftn) +irfftn = get_xp(cp)(_fft.irfftn) +hfft = get_xp(cp)(_fft.hfft) +ihfft = get_xp(cp)(_fft.ihfft) +fftfreq = get_xp(cp)(_fft.fftfreq) +rfftfreq = get_xp(cp)(_fft.rfftfreq) +fftshift = get_xp(cp)(_fft.fftshift) +ifftshift = get_xp(cp)(_fft.ifftshift) __all__ = fft_all + _fft.__all__ From 9e613cce3197f53f5f0bd7a03d676c4218521923 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Mar 2024 14:28:09 -0700 Subject: [PATCH 12/13] Fix cupy skipping in the tests --- tests/_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/_helpers.py b/tests/_helpers.py index c41bc881..e8421b52 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -9,12 +9,12 @@ def import_(library, wrapper=False): if 'jax' in library and sys.version_info < (3, 9): pytest.skip('JAX array API support does not support Python 3.8') + if library == 'cupy': + pytest.importorskip(library) if wrapper: if 'jax' in library: library = 'jax.experimental.array_api' else: library = 'array_api_compat.' + library - elif library == 'cupy': - return pytest.importorskip(library) return import_module(library) From cb46aad155d87d7a75aa2d04e8d4c80c7196d691 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Mar 2024 14:36:57 -0700 Subject: [PATCH 13/13] Add xfails for cupy n-dim fft funcs --- cupy-xfails.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index e76c4c32..85ca5aa4 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -164,3 +164,9 @@ array_api_tests/test_special_cases.py::test_unary[sqrt(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0] + +# CuPy gives the wrong shape for n-dim fft funcs. See +# https://github.com/data-apis/array-api-compat/pull/78#issuecomment-1984527870 +array_api_tests/test_fft.py::test_fftn +array_api_tests/test_fft.py::test_ifftn +array_api_tests/test_fft.py::test_rfftn