Skip to content

Commit

Permalink
Add fft support for numpy and cupy
Browse files Browse the repository at this point in the history
This is based off of numpy/numpy#25317
  • Loading branch information
asmeurer committed Jan 9, 2024
1 parent d235910 commit f228b58
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 0 deletions.
183 changes: 183 additions & 0 deletions array_api_compat/common/_fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Union, Optional, Literal

if TYPE_CHECKING:
from ._typing import Device, ndarray
from collections.abc import Sequence

# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.

def fft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def ifft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def fftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def ifftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def rfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res

def irfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res

def rfftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res

def irfftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res

def hfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def ihfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return xp.fft.fftfreq(n, d=d)

def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return xp.fft.rfftfreq(n, d=d)

def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
return xp.fft.fftshift(x, axes=axes)

def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
return xp.fft.ifftshift(x, axes=axes)

__all__ = [
"fft",
"ifft",
"fftn",
"ifftn",
"rfft",
"irfft",
"rfftn",
"irfftn",
"hfft",
"ihfft",
"fftfreq",
"rfftfreq",
"fftshift",
"ifftshift",
]
2 changes: 2 additions & 0 deletions array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')

__import__(__package__ + '.fft')

from .linalg import matrix_transpose, vecdot

from ..common._helpers import *
Expand Down
29 changes: 29 additions & 0 deletions array_api_compat/cupy/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from cupy.fft import *
from cupy.fft import __all__ as fft_all

from ..common import _fft
from .._internal import get_xp

import cupy as cp

fft = get_xp(cp)(_fft.fft),
ifft = get_xp(cp)(_fft.ifft),
fftn = get_xp(cp)(_fft.fftn),
ifftn = get_xp(cp)(_fft.ifftn),
rfft = get_xp(cp)(_fft.rfft),
irfft = get_xp(cp)(_fft.irfft),
rfftn = get_xp(cp)(_fft.rfftn),
irfftn = get_xp(cp)(_fft.irfftn),
hfft = get_xp(cp)(_fft.hfft),
ihfft = get_xp(cp)(_fft.ihfft),
fftfreq = get_xp(cp)(_fft.fftfreq),
rfftfreq = get_xp(cp)(_fft.rfftfreq),
fftshift = get_xp(cp)(_fft.fftshift),
ifftshift = get_xp(cp)(_fft.ifftshift),

__all__ = fft_all + _fft.__all__

del get_xp
del cp
del fft_all
del _fft
2 changes: 2 additions & 0 deletions array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# dynamically so that the library can be vendored.
__import__(__package__ + '.linalg')

__import__(__package__ + '.fft')

from .linalg import matrix_transpose, vecdot

from ..common._helpers import *
Expand Down
29 changes: 29 additions & 0 deletions array_api_compat/numpy/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from numpy.fft import *
from numpy.fft import __all__ as fft_all

from ..common import _fft
from .._internal import get_xp

import numpy as np

fft = get_xp(np)(_fft.fft)
ifft = get_xp(np)(_fft.ifft)
fftn = get_xp(np)(_fft.fftn)
ifftn = get_xp(np)(_fft.ifftn)
rfft = get_xp(np)(_fft.rfft)
irfft = get_xp(np)(_fft.irfft)
rfftn = get_xp(np)(_fft.rfftn)
irfftn = get_xp(np)(_fft.irfftn)
hfft = get_xp(np)(_fft.hfft)
ihfft = get_xp(np)(_fft.ihfft)
fftfreq = get_xp(np)(_fft.fftfreq)
rfftfreq = get_xp(np)(_fft.rfftfreq)
fftshift = get_xp(np)(_fft.fftshift)
ifftshift = get_xp(np)(_fft.ifftshift)

__all__ = fft_all + _fft.__all__

del get_xp
del np
del fft_all
del _fft

0 comments on commit f228b58

Please sign in to comment.