diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py new file mode 100644 index 00000000..666b0b1f --- /dev/null +++ b/array_api_compat/common/_fft.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Union, Optional, Literal + +if TYPE_CHECKING: + from ._typing import Device, ndarray + from collections.abc import Sequence + +# Note: NumPy fft functions improperly upcast float32 and complex64 to +# complex128, which is why we require wrapping them all here. + +def fft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.fft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def ifft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def fftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def ifftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def rfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) + if x.dtype == xp.float32: + return res.astype(xp.complex64) + return res + +def irfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) + if x.dtype == xp.complex64: + return res.astype(xp.float32) + return res + +def rfftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) + if x.dtype == xp.float32: + return res.astype(xp.complex64) + return res + +def irfftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) + if x.dtype == xp.complex64: + return res.astype(xp.float32) + return res + +def hfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.float32) + return res + +def ihfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: + if device not in ["cpu", None]: + raise ValueError(f"Unsupported device {device!r}") + return xp.fft.fftfreq(n, d=d) + +def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: + if device not in ["cpu", None]: + raise ValueError(f"Unsupported device {device!r}") + return xp.fft.rfftfreq(n, d=d) + +def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: + return xp.fft.fftshift(x, axes=axes) + +def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: + return xp.fft.ifftshift(x, axes=axes) + +__all__ = [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", +] diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 148691f5..7968d68d 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -9,6 +9,8 @@ # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') +__import__(__package__ + '.fft') + from ..common._helpers import * # noqa: F401,F403 __array_api_version__ = '2022.12' diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 968b974b..b9364ac6 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -77,3 +77,5 @@ 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'concat', 'pow'] + +_all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py new file mode 100644 index 00000000..307e0f72 --- /dev/null +++ b/array_api_compat/cupy/fft.py @@ -0,0 +1,36 @@ +from cupy.fft import * # noqa: F403 +# cupy.fft doesn't have __all__. If it is added, replace this with +# +# from cupy.fft import __all__ as linalg_all +_n = {} +exec('from cupy.fft import *', _n) +del _n['__builtins__'] +fft_all = list(_n) +del _n + +from ..common import _fft +from .._internal import get_xp + +import cupy as cp + +fft = get_xp(cp)(_fft.fft) +ifft = get_xp(cp)(_fft.ifft) +fftn = get_xp(cp)(_fft.fftn) +ifftn = get_xp(cp)(_fft.ifftn) +rfft = get_xp(cp)(_fft.rfft) +irfft = get_xp(cp)(_fft.irfft) +rfftn = get_xp(cp)(_fft.rfftn) +irfftn = get_xp(cp)(_fft.irfftn) +hfft = get_xp(cp)(_fft.hfft) +ihfft = get_xp(cp)(_fft.ihfft) +fftfreq = get_xp(cp)(_fft.fftfreq) +rfftfreq = get_xp(cp)(_fft.rfftfreq) +fftshift = get_xp(cp)(_fft.fftshift) +ifftshift = get_xp(cp)(_fft.ifftshift) + +__all__ = fft_all + _fft.__all__ + +del get_xp +del cp +del fft_all +del _fft diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index adf20191..87908709 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -15,6 +15,8 @@ # dynamically so that the library can be vendored. __import__(__package__ + '.linalg') +__import__(__package__ + '.fft') + from .linalg import matrix_transpose, vecdot # noqa: F401 from ..common._helpers import * # noqa: F403 diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py new file mode 100644 index 00000000..28667594 --- /dev/null +++ b/array_api_compat/numpy/fft.py @@ -0,0 +1,29 @@ +from numpy.fft import * # noqa: F403 +from numpy.fft import __all__ as fft_all + +from ..common import _fft +from .._internal import get_xp + +import numpy as np + +fft = get_xp(np)(_fft.fft) +ifft = get_xp(np)(_fft.ifft) +fftn = get_xp(np)(_fft.fftn) +ifftn = get_xp(np)(_fft.ifftn) +rfft = get_xp(np)(_fft.rfft) +irfft = get_xp(np)(_fft.irfft) +rfftn = get_xp(np)(_fft.rfftn) +irfftn = get_xp(np)(_fft.irfftn) +hfft = get_xp(np)(_fft.hfft) +ihfft = get_xp(np)(_fft.ihfft) +fftfreq = get_xp(np)(_fft.fftfreq) +rfftfreq = get_xp(np)(_fft.rfftfreq) +fftshift = get_xp(np)(_fft.fftshift) +ifftshift = get_xp(np)(_fft.ifftshift) + +__all__ = fft_all + _fft.__all__ + +del get_xp +del np +del fft_all +del _fft diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 59898aab..172f5279 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -17,6 +17,8 @@ # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') +__import__(__package__ + '.fft') + from ..common._helpers import * # noqa: F403 __array_api_version__ = '2022.12' diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py new file mode 100644 index 00000000..3c9117ee --- /dev/null +++ b/array_api_compat/torch/fft.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import torch + array = torch.Tensor + from typing import Union, Sequence, Literal + +from torch.fft import * # noqa: F403 +import torch.fft + +# Several torch fft functions do not map axes to dim + +def fftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) + +def ifftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) + +def rfftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) + +def irfftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) + +def fftshift( + x: array, + /, + *, + axes: Union[int, Sequence[int]] = None, + **kwargs, +) -> array: + return torch.fft.fftshift(x, dim=axes, **kwargs) + +def ifftshift( + x: array, + /, + *, + axes: Union[int, Sequence[int]] = None, + **kwargs, +) -> array: + return torch.fft.ifftshift(x, dim=axes, **kwargs) + + +__all__ = torch.fft.__all__ + [ + "fftn", + "ifftn", + "rfftn", + "irfftn", + "fftshift", + "ifftshift", +] + +_all_ignore = ['torch'] diff --git a/cupy-xfails.txt b/cupy-xfails.txt index cfacbe33..85ca5aa4 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -165,15 +165,8 @@ array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0] -# fft functions are not yet supported -# (https://github.com/data-apis/array-api-compat/issues/67) -array_api_tests/test_fft.py::test_fft -array_api_tests/test_fft.py::test_ifft +# CuPy gives the wrong shape for n-dim fft funcs. See +# https://github.com/data-apis/array-api-compat/pull/78#issuecomment-1984527870 array_api_tests/test_fft.py::test_fftn array_api_tests/test_fft.py::test_ifftn -array_api_tests/test_fft.py::test_rfft -array_api_tests/test_fft.py::test_irfft array_api_tests/test_fft.py::test_rfftn -array_api_tests/test_fft.py::test_irfftn -array_api_tests/test_fft.py::test_hfft -array_api_tests/test_fft.py::test_ihfft diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 1c38b209..868b3c80 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -50,19 +50,6 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices -# fft functions are not yet supported -# (https://github.com/data-apis/array-api-compat/issues/67) -array_api_tests/test_fft.py::test_fft -array_api_tests/test_fft.py::test_ifft -array_api_tests/test_fft.py::test_fftn -array_api_tests/test_fft.py::test_ifftn -array_api_tests/test_fft.py::test_rfft -array_api_tests/test_fft.py::test_irfft -array_api_tests/test_fft.py::test_rfftn -array_api_tests/test_fft.py::test_irfftn -array_api_tests/test_fft.py::test_hfft -array_api_tests/test_fft.py::test_ihfft - # NumPy 1.21 specific XFAILS ############################ diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 5e270a95..8d291d01 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -42,16 +42,3 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # The test suite is incorrectly checking sums that have loss of significance # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum - -# fft functions are not yet supported -# (https://github.com/data-apis/array-api-compat/issues/67) -array_api_tests/test_fft.py::test_fft -array_api_tests/test_fft.py::test_ifft -array_api_tests/test_fft.py::test_fftn -array_api_tests/test_fft.py::test_ifftn -array_api_tests/test_fft.py::test_rfft -array_api_tests/test_fft.py::test_irfft -array_api_tests/test_fft.py::test_rfftn -array_api_tests/test_fft.py::test_irfftn -array_api_tests/test_fft.py::test_hfft -array_api_tests/test_fft.py::test_ihfft diff --git a/numpy-xfails.txt b/numpy-xfails.txt index d0be245b..e44d7035 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -44,16 +44,3 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # The test suite is incorrectly checking sums that have loss of significance # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum - -# fft functions are not yet supported -# (https://github.com/data-apis/array-api-compat/issues/67) -array_api_tests/test_fft.py::test_fft -array_api_tests/test_fft.py::test_ifft -array_api_tests/test_fft.py::test_fftn -array_api_tests/test_fft.py::test_ifftn -array_api_tests/test_fft.py::test_rfft -array_api_tests/test_fft.py::test_irfft -array_api_tests/test_fft.py::test_rfftn -array_api_tests/test_fft.py::test_irfftn -array_api_tests/test_fft.py::test_hfft -array_api_tests/test_fft.py::test_ihfft diff --git a/test_cupy.sh b/test_cupy.sh index 6b4d6b56..3d8c0711 100755 --- a/test_cupy.sh +++ b/test_cupy.sh @@ -26,4 +26,4 @@ mkdir -p $SCRIPT_DIR/.hypothesis ln -s $SCRIPT_DIR/.hypothesis .hypothesis export ARRAY_API_TESTS_MODULE=array_api_compat.cupy -pytest ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@" +pytest array_api_tests/ ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@" diff --git a/tests/_helpers.py b/tests/_helpers.py index 23cb5db9..e8421b52 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -6,11 +6,11 @@ def import_(library, wrapper=False): - if library == 'cupy': - return pytest.importorskip(library) if 'jax' in library and sys.version_info < (3, 9): pytest.skip('JAX array API support does not support Python 3.8') + if library == 'cupy': + pytest.importorskip(library) if wrapper: if 'jax' in library: library = 'jax.experimental.array_api' diff --git a/tests/test_all.py b/tests/test_all.py index 5b49fa14..7a6f74f0 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -21,7 +21,7 @@ def test_all(library): import_(library, wrapper=True) for mod_name in sys.modules: - if 'array_api_compat.' + library not in mod_name: + if not mod_name.startswith('array_api_compat.' + library): continue module = sys.modules[mod_name] diff --git a/torch-xfails.txt b/torch-xfails.txt index caf1aa65..a9106fae 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -190,16 +190,3 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_expm1 array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values - -# fft functions are not yet supported -# (https://github.com/data-apis/array-api-compat/issues/67) -array_api_tests/test_fft.py::test_fftn -array_api_tests/test_fft.py::test_ifftn -array_api_tests/test_fft.py::test_rfft -array_api_tests/test_fft.py::test_irfft -array_api_tests/test_fft.py::test_rfftn -array_api_tests/test_fft.py::test_irfftn -array_api_tests/test_fft.py::test_hfft -array_api_tests/test_fft.py::test_ihfft -array_api_tests/test_fft.py::test_shift_func[fftshift] -array_api_tests/test_fft.py::test_shift_func[ifftshift]