forked from data-apis/array-api-compat
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This is based off of numpy/numpy#25317
- Loading branch information
Showing
5 changed files
with
245 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Union, Optional, Literal | ||
|
||
if TYPE_CHECKING: | ||
from ._typing import Device, ndarray | ||
from collections.abc import Sequence | ||
|
||
# Note: NumPy fft functions improperly upcast float32 and complex64 to | ||
# complex128, which is why we require wrapping them all here. | ||
|
||
def fft( | ||
x: ndarray, | ||
/, | ||
xp, | ||
*, | ||
n: Optional[int] = None, | ||
axis: int = -1, | ||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||
) -> ndarray: | ||
res = xp.fft.fft(x, n=n, axis=axis, norm=norm) | ||
if x.dtype in [xp.float32, xp.complex64]: | ||
return res.astype(xp.complex64) | ||
return res | ||
|
||
def ifft( | ||
x: ndarray, | ||
/, | ||
xp, | ||
*, | ||
n: Optional[int] = None, | ||
axis: int = -1, | ||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||
) -> ndarray: | ||
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) | ||
if x.dtype in [xp.float32, xp.complex64]: | ||
return res.astype(xp.complex64) | ||
return res | ||
|
||
def fftn( | ||
x: ndarray, | ||
/, | ||
xp, | ||
*, | ||
s: Sequence[int] = None, | ||
axes: Sequence[int] = None, | ||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||
) -> ndarray: | ||
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) | ||
if x.dtype in [xp.float32, xp.complex64]: | ||
return res.astype(xp.complex64) | ||
return res | ||
|
||
def ifftn( | ||
x: ndarray, | ||
/, | ||
xp, | ||
*, | ||
s: Sequence[int] = None, | ||
axes: Sequence[int] = None, | ||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||
) -> ndarray: | ||
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) | ||
if x.dtype in [xp.float32, xp.complex64]: | ||
return res.astype(xp.complex64) | ||
return res | ||
|
||
def rfft( | ||
x: ndarray, | ||
/, | ||
xp, | ||
*, | ||
n: Optional[int] = None, | ||
axis: int = -1, | ||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||
) -> ndarray: | ||
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) | ||
if x.dtype == xp.float32: | ||
return res.astype(xp.complex64) | ||
return res | ||
|
||
def irfft( | ||
x: ndarray, | ||
/, | ||
xp, | ||
*, | ||
n: Optional[int] = None, | ||
axis: int = -1, | ||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||
) -> ndarray: | ||
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) | ||
if x.dtype == xp.complex64: | ||
return res.astype(xp.float32) | ||
return res | ||
|
||
def rfftn( | ||
x: ndarray, | ||
/, | ||
xp, | ||
*, | ||
s: Sequence[int] = None, | ||
axes: Sequence[int] = None, | ||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||
) -> ndarray: | ||
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) | ||
if x.dtype == xp.float32: | ||
return res.astype(xp.complex64) | ||
return res | ||
|
||
def irfftn( | ||
x: ndarray, | ||
/, | ||
xp, | ||
*, | ||
s: Sequence[int] = None, | ||
axes: Sequence[int] = None, | ||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||
) -> ndarray: | ||
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) | ||
if x.dtype == xp.complex64: | ||
return res.astype(xp.float32) | ||
return res | ||
|
||
def hfft( | ||
x: ndarray, | ||
/, | ||
xp, | ||
*, | ||
n: Optional[int] = None, | ||
axis: int = -1, | ||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||
) -> ndarray: | ||
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) | ||
if x.dtype in [xp.float32, xp.complex64]: | ||
return res.astype(xp.complex64) | ||
return res | ||
|
||
def ihfft( | ||
x: ndarray, | ||
/, | ||
xp, | ||
*, | ||
n: Optional[int] = None, | ||
axis: int = -1, | ||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||
) -> ndarray: | ||
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) | ||
if x.dtype in [xp.float32, xp.complex64]: | ||
return res.astype(xp.complex64) | ||
return res | ||
|
||
def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: | ||
if device not in ["cpu", None]: | ||
raise ValueError(f"Unsupported device {device!r}") | ||
return xp.fft.fftfreq(n, d=d) | ||
|
||
def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: | ||
if device not in ["cpu", None]: | ||
raise ValueError(f"Unsupported device {device!r}") | ||
return xp.fft.rfftfreq(n, d=d) | ||
|
||
def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: | ||
return xp.fft.fftshift(x, axes=axes) | ||
|
||
def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: | ||
return xp.fft.ifftshift(x, axes=axes) | ||
|
||
__all__ = [ | ||
"fft", | ||
"ifft", | ||
"fftn", | ||
"ifftn", | ||
"rfft", | ||
"irfft", | ||
"rfftn", | ||
"irfftn", | ||
"hfft", | ||
"ihfft", | ||
"fftfreq", | ||
"rfftfreq", | ||
"fftshift", | ||
"ifftshift", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from cupy.fft import * | ||
from cupy.fft import __all__ as fft_all | ||
|
||
from ..common import _fft | ||
from .._internal import get_xp | ||
|
||
import cupy as cp | ||
|
||
fft = get_xp(cp)(_fft.fft), | ||
ifft = get_xp(cp)(_fft.ifft), | ||
fftn = get_xp(cp)(_fft.fftn), | ||
ifftn = get_xp(cp)(_fft.ifftn), | ||
rfft = get_xp(cp)(_fft.rfft), | ||
irfft = get_xp(cp)(_fft.irfft), | ||
rfftn = get_xp(cp)(_fft.rfftn), | ||
irfftn = get_xp(cp)(_fft.irfftn), | ||
hfft = get_xp(cp)(_fft.hfft), | ||
ihfft = get_xp(cp)(_fft.ihfft), | ||
fftfreq = get_xp(cp)(_fft.fftfreq), | ||
rfftfreq = get_xp(cp)(_fft.rfftfreq), | ||
fftshift = get_xp(cp)(_fft.fftshift), | ||
ifftshift = get_xp(cp)(_fft.ifftshift), | ||
|
||
__all__ = fft_all + _fft.__all__ | ||
|
||
del get_xp | ||
del cp | ||
del fft_all | ||
del _fft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from numpy.fft import * | ||
from numpy.fft import __all__ as fft_all | ||
|
||
from ..common import _fft | ||
from .._internal import get_xp | ||
|
||
import numpy as np | ||
|
||
fft = get_xp(np)(_fft.fft) | ||
ifft = get_xp(np)(_fft.ifft) | ||
fftn = get_xp(np)(_fft.fftn) | ||
ifftn = get_xp(np)(_fft.ifftn) | ||
rfft = get_xp(np)(_fft.rfft) | ||
irfft = get_xp(np)(_fft.irfft) | ||
rfftn = get_xp(np)(_fft.rfftn) | ||
irfftn = get_xp(np)(_fft.irfftn) | ||
hfft = get_xp(np)(_fft.hfft) | ||
ihfft = get_xp(np)(_fft.ihfft) | ||
fftfreq = get_xp(np)(_fft.fftfreq) | ||
rfftfreq = get_xp(np)(_fft.rfftfreq) | ||
fftshift = get_xp(np)(_fft.fftshift) | ||
ifftshift = get_xp(np)(_fft.ifftshift) | ||
|
||
__all__ = fft_all + _fft.__all__ | ||
|
||
del get_xp | ||
del np | ||
del fft_all | ||
del _fft |