Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fft support for numpy and cupy #78

Merged
merged 15 commits into from
Mar 8, 2024
183 changes: 183 additions & 0 deletions array_api_compat/common/_fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Union, Optional, Literal

if TYPE_CHECKING:
from ._typing import Device, ndarray
from collections.abc import Sequence

# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.

def fft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def ifft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def fftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def ifftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def rfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res

def irfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res

def rfftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res

def irfftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res

def hfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.float32)
return res

def ihfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return xp.fft.fftfreq(n, d=d)

def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return xp.fft.rfftfreq(n, d=d)

def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
return xp.fft.fftshift(x, axes=axes)

def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
return xp.fft.ifftshift(x, axes=axes)

__all__ = [
"fft",
"ifft",
"fftn",
"ifftn",
"rfft",
"irfft",
"rfftn",
"irfftn",
"hfft",
"ihfft",
"fftfreq",
"rfftfreq",
"fftshift",
"ifftshift",
]
2 changes: 2 additions & 0 deletions array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')

__import__(__package__ + '.fft')

from ..common._helpers import * # noqa: F401,F403

__array_api_version__ = '2022.12'
2 changes: 2 additions & 0 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,5 @@
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']

_all_ignore = ['cp', 'get_xp']
36 changes: 36 additions & 0 deletions array_api_compat/cupy/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from cupy.fft import * # noqa: F403
# cupy.fft doesn't have __all__. If it is added, replace this with
#
# from cupy.fft import __all__ as linalg_all
_n = {}
exec('from cupy.fft import *', _n)
del _n['__builtins__']
fft_all = list(_n)
del _n

from ..common import _fft
from .._internal import get_xp

import cupy as cp

fft = get_xp(cp)(_fft.fft)
ifft = get_xp(cp)(_fft.ifft)
fftn = get_xp(cp)(_fft.fftn)
ifftn = get_xp(cp)(_fft.ifftn)
rfft = get_xp(cp)(_fft.rfft)
irfft = get_xp(cp)(_fft.irfft)
rfftn = get_xp(cp)(_fft.rfftn)
irfftn = get_xp(cp)(_fft.irfftn)
hfft = get_xp(cp)(_fft.hfft)
ihfft = get_xp(cp)(_fft.ihfft)
fftfreq = get_xp(cp)(_fft.fftfreq)
rfftfreq = get_xp(cp)(_fft.rfftfreq)
fftshift = get_xp(cp)(_fft.fftshift)
ifftshift = get_xp(cp)(_fft.ifftshift)

__all__ = fft_all + _fft.__all__

del get_xp
del cp
del fft_all
del _fft
2 changes: 2 additions & 0 deletions array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# dynamically so that the library can be vendored.
__import__(__package__ + '.linalg')

__import__(__package__ + '.fft')

from .linalg import matrix_transpose, vecdot # noqa: F401

from ..common._helpers import * # noqa: F403
Expand Down
29 changes: 29 additions & 0 deletions array_api_compat/numpy/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from numpy.fft import * # noqa: F403
from numpy.fft import __all__ as fft_all

from ..common import _fft
from .._internal import get_xp

import numpy as np

fft = get_xp(np)(_fft.fft)
ifft = get_xp(np)(_fft.ifft)
fftn = get_xp(np)(_fft.fftn)
ifftn = get_xp(np)(_fft.ifftn)
rfft = get_xp(np)(_fft.rfft)
irfft = get_xp(np)(_fft.irfft)
rfftn = get_xp(np)(_fft.rfftn)
irfftn = get_xp(np)(_fft.irfftn)
hfft = get_xp(np)(_fft.hfft)
ihfft = get_xp(np)(_fft.ihfft)
fftfreq = get_xp(np)(_fft.fftfreq)
rfftfreq = get_xp(np)(_fft.rfftfreq)
fftshift = get_xp(np)(_fft.fftshift)
ifftshift = get_xp(np)(_fft.ifftshift)

__all__ = fft_all + _fft.__all__

del get_xp
del np
del fft_all
del _fft
2 changes: 2 additions & 0 deletions array_api_compat/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')

__import__(__package__ + '.fft')

from ..common._helpers import * # noqa: F403

__array_api_version__ = '2022.12'
86 changes: 86 additions & 0 deletions array_api_compat/torch/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

from typing import TYPE_CHECKING
if TYPE_CHECKING:
import torch
array = torch.Tensor
from typing import Union, Sequence, Literal

from torch.fft import * # noqa: F403
import torch.fft

# Several torch fft functions do not map axes to dim

def fftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)

def ifftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)

def rfftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)

def irfftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)

def fftshift(
x: array,
/,
*,
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
return torch.fft.fftshift(x, dim=axes, **kwargs)

def ifftshift(
x: array,
/,
*,
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
return torch.fft.ifftshift(x, dim=axes, **kwargs)


__all__ = torch.fft.__all__ + [
"fftn",
"ifftn",
"rfftn",
"irfftn",
"fftshift",
"ifftshift",
]

_all_ignore = ['torch']
11 changes: 2 additions & 9 deletions cupy-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,8 @@ array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0]
array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0]
array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0]

# fft functions are not yet supported
# (https://github.com/data-apis/array-api-compat/issues/67)
array_api_tests/test_fft.py::test_fft
array_api_tests/test_fft.py::test_ifft
# CuPy gives the wrong shape for n-dim fft funcs. See
# https://github.com/data-apis/array-api-compat/pull/78#issuecomment-1984527870
array_api_tests/test_fft.py::test_fftn
array_api_tests/test_fft.py::test_ifftn
array_api_tests/test_fft.py::test_rfft
array_api_tests/test_fft.py::test_irfft
array_api_tests/test_fft.py::test_rfftn
array_api_tests/test_fft.py::test_irfftn
array_api_tests/test_fft.py::test_hfft
array_api_tests/test_fft.py::test_ihfft
Loading