Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type 3 #102

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ API Reference
.. autofunction:: pytorch_finufft.functional.finufft_type1

.. autofunction:: pytorch_finufft.functional.finufft_type2

.. autofunction:: pytorch_finufft.functional.finufft_type3
6 changes: 3 additions & 3 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ Pre-requistes
-------------

Pytorch-FINUFFT requires either ``finufft`` *and/or* ``cufinufft``
2.1.0 or greater.
2.2.0 or greater.

Note that currently, this version of ``cufinufft`` is unreleased
and can only be installed from source. See the relevant installation pages for
These are available via `pip` or can be built from source.
See the relevant installation pages for
:external+finufft:doc:`finufft <install>` and
:external+finufft:doc:`cufinufft <install_gpu>`.

Expand Down
50 changes: 47 additions & 3 deletions pytorch_finufft/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def check_devices(*tensors: torch.Tensor) -> None:
)


def check_dtypes(data: torch.Tensor, points: torch.Tensor, name: str) -> None:
def check_dtypes(
data: torch.Tensor, points: torch.Tensor, name: str, points_name: str = "Points"
) -> None:
"""
Checks that data is complex-valued
and that points is real-valued of the same precision
Expand All @@ -38,8 +40,8 @@ def check_dtypes(data: torch.Tensor, points: torch.Tensor, name: str) -> None:

if points.dtype is not real_dtype:
raise TypeError(
f"Points must have a dtype of {real_dtype} as {name.lower()} has a "
f"dtype of {complex_dtype}"
f"{points_name} must have a dtype of {real_dtype} as {name.lower()} has a "
f"dtype of {complex_dtype}, but got {points.dtype} instead"
)


Expand Down Expand Up @@ -102,3 +104,45 @@ def check_sizes_t2(targets: torch.Tensor, points: torch.Tensor) -> None:
f"For type 2 {points_dim}d FINUFFT, targets must be at "
f"least a {points_dim}d tensor"
)


def check_sizes_t3(
points: torch.Tensor, strengths: torch.Tensor, targets: torch.Tensor
) -> None:
"""
Checks that targets and points are of the same dimension.
This is used in type3.
"""
points_len = len(points.shape)
targets_len = len(targets.shape)

if points_len == 1:
points_dim = 1
elif points_len == 2:
points_dim = points.shape[0]
else:
raise ValueError("The points tensor must be 1d or 2d")

if targets_len == 1:
targets_dim = 1
elif targets_len == 2:
targets_dim = targets.shape[0]
else:
raise ValueError("The targets tensor must be 1d or 2d")

if targets_dim != points_dim:
raise ValueError(
"Points and targets must be of the same dimension!"
+ f" Got {points_dim=} and {targets_dim=} instead"
)

if points_dim not in {1, 2, 3}:
raise ValueError(
f"Points and targets can be at most 3d, got {points_dim} instead"
)

n_points = points.shape[-1]
n_strengths = strengths.shape[-1]

if n_points != n_strengths:
raise ValueError("The same number of points and strengths must be supplied")
143 changes: 132 additions & 11 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,6 @@ def backward( # type: ignore[override]
grad_values,
None,
None,
None,
None,
)


Expand Down Expand Up @@ -383,13 +381,7 @@ def vmap( # type: ignore[override]
@staticmethod
def backward( # type: ignore[override]
ctx: Any, grad_output: torch.Tensor
) -> Tuple[
Union[torch.Tensor, None],
Union[torch.Tensor, None],
None,
None,
None,
]:
) -> Tuple[Union[torch.Tensor, None], ...]:
_i_sign = ctx.isign
_mode_ordering = ctx.mode_ordering
finufftkwargs = ctx.finufftkwargs
Expand Down Expand Up @@ -450,11 +442,98 @@ def backward( # type: ignore[override]
grad_points,
grad_targets,
None,
None,
None,
)


class FinufftType3(torch.autograd.Function):
"""
FINUFFT problem type 3
"""

ISIGN_DEFAULT = -1 # note: FINUFFT default is 1

@staticmethod
def setup_context(
ctx: Any,
inputs: Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
Optional[Dict[str, Union[int, float]]],
],
output: Any,
) -> None:
points, strengths, targets, finufftkwargs = inputs
if finufftkwargs is None:
finufftkwargs = {}
else: # copy to avoid mutating caller's dictionary
finufftkwargs = finufftkwargs.copy()
ctx.save_for_backward(points, strengths, targets)
ctx.isign = finufftkwargs.pop("isign", FinufftType3.ISIGN_DEFAULT)
ctx.finufftkwargs = finufftkwargs

@staticmethod
def forward( # type: ignore[override]
points: torch.Tensor,
strengths: torch.Tensor,
targets: torch.Tensor,
finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
) -> torch.Tensor:
checks.check_devices(targets, strengths, points)
checks.check_dtypes(strengths, points, "Strengths")
checks.check_dtypes(strengths, targets, "Strengths", points_name="Targets")
checks.check_sizes_t3(points, strengths, targets)

if finufftkwargs is None:
finufftkwargs = dict()
else:
finufftkwargs = finufftkwargs.copy()

finufftkwargs.setdefault("isign", FinufftType3.ISIGN_DEFAULT)

points = torch.atleast_2d(points)
targets = torch.atleast_2d(targets)

ndim = points.shape[0]
npoints = points.shape[1]
batch_dims = strengths.shape[:-1]

if points.device.type != "cpu":
raise NotImplementedError("Type 3 is not currently implemented for GPU")

nufft_func = get_nufft_func(ndim, 3, points.device)

finufft_out = nufft_func(
*points,
strengths.reshape(-1, npoints),
*targets,
**finufftkwargs,
)
finufft_out = finufft_out.reshape(*batch_dims, targets.shape[-1])

return finufft_out

@staticmethod
def backward( # type: ignore[override]
ctx: Any, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
_i_sign = ctx.isign
# finufftkwargs = ctx.finufftkwargs

points, strengths, targets = ctx.saved_tensors
points = torch.atleast_2d(points)
targets = torch.atleast_2d(targets)

# device = points.device
# ndim = points.shape[0]

grad_points = None
grad_strengths = None
grad_targets = None

return grad_points, grad_strengths, grad_targets, None


def finufft_type1(
points: torch.Tensor,
values: torch.Tensor,
Expand Down Expand Up @@ -534,3 +613,45 @@ def finufft_type2(
"""
res: torch.Tensor = FinufftType2.apply(points, targets, finufftkwargs)
return res


def finufft_type3(
points: torch.Tensor,
strengths: torch.Tensor,
targets: torch.Tensor,
**finufftkwargs: Union[int, float],
) -> torch.Tensor:
"""
Evaluates the Type 3 (nonuniform-to-nonuniform) NUFFT on the inputs.

This is a wrapper around :func:`finufft.nufft1d3`, :func:`finufft.nufft2d3`, and
:func:`finufft.nufft3d3` on CPU.

Note that this function is **not implemented** for GPUs at the time of writing.

Parameters
----------
points : torch.Tensor
DxM tensor of locations of the non-uniform source points.
Points must lie in the range ``[-3pi, 3pi]``.
strengths: torch.Tensor
Complex-valued tensor of source strengths at the non-uniform points.
All dimensions except the final dimension are treated as batch
dimensions. The final dimension must have size ``M``.
targets : torch.Tensor
DxN tensor of locations of the non-uniform target points.
**finufftkwargs : int | float
Additional keyword arguments are forwarded to the underlying
FINUFFT functions. A few notable options are

- ``eps``: precision requested (default: ``1e-6``)
- ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``)
- ``isign``: Sign of the exponent in the Fourier transform (default: ``-1``)

Returns
-------
torch.Tensor
A ``[batch]xN`` tensor of values at the target non-uniform points.
"""
res: torch.Tensor = FinufftType3.apply(points, strengths, targets, finufftkwargs)
return res
Loading
Loading