From f228b584e87dfa1280561e15373f04d530b10762 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 8 Jan 2024 18:40:18 -0700 Subject: [PATCH] Add fft support for numpy and cupy This is based off of https://github.com/numpy/numpy/pull/25317 --- array_api_compat/common/_fft.py | 183 +++++++++++++++++++++++++++++ array_api_compat/cupy/__init__.py | 2 + array_api_compat/cupy/fft.py | 29 +++++ array_api_compat/numpy/__init__.py | 2 + array_api_compat/numpy/fft.py | 29 +++++ 5 files changed, 245 insertions(+) create mode 100644 array_api_compat/common/_fft.py create mode 100644 array_api_compat/cupy/fft.py create mode 100644 array_api_compat/numpy/fft.py diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py new file mode 100644 index 00000000..4d2bb3fe --- /dev/null +++ b/array_api_compat/common/_fft.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Union, Optional, Literal + +if TYPE_CHECKING: + from ._typing import Device, ndarray + from collections.abc import Sequence + +# Note: NumPy fft functions improperly upcast float32 and complex64 to +# complex128, which is why we require wrapping them all here. + +def fft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.fft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def ifft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def fftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def ifftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def rfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) + if x.dtype == xp.float32: + return res.astype(xp.complex64) + return res + +def irfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) + if x.dtype == xp.complex64: + return res.astype(xp.float32) + return res + +def rfftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) + if x.dtype == xp.float32: + return res.astype(xp.complex64) + return res + +def irfftn( + x: ndarray, + /, + xp, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) + if x.dtype == xp.complex64: + return res.astype(xp.float32) + return res + +def hfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def ihfft( + x: ndarray, + /, + xp, + *, + n: Optional[int] = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> ndarray: + res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) + if x.dtype in [xp.float32, xp.complex64]: + return res.astype(xp.complex64) + return res + +def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: + if device not in ["cpu", None]: + raise ValueError(f"Unsupported device {device!r}") + return xp.fft.fftfreq(n, d=d) + +def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: + if device not in ["cpu", None]: + raise ValueError(f"Unsupported device {device!r}") + return xp.fft.rfftfreq(n, d=d) + +def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: + return xp.fft.fftshift(x, axes=axes) + +def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: + return xp.fft.ifftshift(x, axes=axes) + +__all__ = [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", +] diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index ec113f9d..d820e44b 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -9,6 +9,8 @@ # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') +__import__(__package__ + '.fft') + from .linalg import matrix_transpose, vecdot from ..common._helpers import * diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py new file mode 100644 index 00000000..8e83abb8 --- /dev/null +++ b/array_api_compat/cupy/fft.py @@ -0,0 +1,29 @@ +from cupy.fft import * +from cupy.fft import __all__ as fft_all + +from ..common import _fft +from .._internal import get_xp + +import cupy as cp + +fft = get_xp(cp)(_fft.fft), +ifft = get_xp(cp)(_fft.ifft), +fftn = get_xp(cp)(_fft.fftn), +ifftn = get_xp(cp)(_fft.ifftn), +rfft = get_xp(cp)(_fft.rfft), +irfft = get_xp(cp)(_fft.irfft), +rfftn = get_xp(cp)(_fft.rfftn), +irfftn = get_xp(cp)(_fft.irfftn), +hfft = get_xp(cp)(_fft.hfft), +ihfft = get_xp(cp)(_fft.ihfft), +fftfreq = get_xp(cp)(_fft.fftfreq), +rfftfreq = get_xp(cp)(_fft.rfftfreq), +fftshift = get_xp(cp)(_fft.fftshift), +ifftshift = get_xp(cp)(_fft.ifftshift), + +__all__ = fft_all + _fft.__all__ + +del get_xp +del cp +del fft_all +del _fft diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 4a49f2f1..ff5efdfd 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -15,6 +15,8 @@ # dynamically so that the library can be vendored. __import__(__package__ + '.linalg') +__import__(__package__ + '.fft') + from .linalg import matrix_transpose, vecdot from ..common._helpers import * diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py new file mode 100644 index 00000000..6093b19d --- /dev/null +++ b/array_api_compat/numpy/fft.py @@ -0,0 +1,29 @@ +from numpy.fft import * +from numpy.fft import __all__ as fft_all + +from ..common import _fft +from .._internal import get_xp + +import numpy as np + +fft = get_xp(np)(_fft.fft) +ifft = get_xp(np)(_fft.ifft) +fftn = get_xp(np)(_fft.fftn) +ifftn = get_xp(np)(_fft.ifftn) +rfft = get_xp(np)(_fft.rfft) +irfft = get_xp(np)(_fft.irfft) +rfftn = get_xp(np)(_fft.rfftn) +irfftn = get_xp(np)(_fft.irfftn) +hfft = get_xp(np)(_fft.hfft) +ihfft = get_xp(np)(_fft.ihfft) +fftfreq = get_xp(np)(_fft.fftfreq) +rfftfreq = get_xp(np)(_fft.rfftfreq) +fftshift = get_xp(np)(_fft.fftshift) +ifftshift = get_xp(np)(_fft.ifftshift) + +__all__ = fft_all + _fft.__all__ + +del get_xp +del np +del fft_all +del _fft