diff --git a/docs/api.rst b/docs/api.rst
index c1d1601..7b0d305 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -6,3 +6,5 @@ API Reference
 
 .. autofunction:: pytorch_finufft.functional.finufft_type1
 .. autofunction:: pytorch_finufft.functional.finufft_type2
+
+.. autofunction:: pytorch_finufft.functional.finufft_type3
diff --git a/docs/installation.rst b/docs/installation.rst
index 49dbcd3..4a7723e 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -6,10 +6,10 @@ Pre-requistes
 -------------
 
 Pytorch-FINUFFT requires either ``finufft`` *and/or* ``cufinufft``
-2.1.0 or greater.
+2.2.0 or greater.
 
-Note that currently, this version of ``cufinufft`` is unreleased
-and can only be installed from source. See the relevant installation pages for
+These are available via ``pip`` or can be built from source.
+See the relevant installation pages for
 :external+finufft:doc:`finufft <install>` and
 :external+finufft:doc:`cufinufft <install_gpu>`.
 
diff --git a/pytorch_finufft/checks.py b/pytorch_finufft/checks.py
index eedabae..b1a20db 100644
--- a/pytorch_finufft/checks.py
+++ b/pytorch_finufft/checks.py
@@ -21,7 +21,9 @@ def check_devices(*tensors: torch.Tensor) -> None:
         )
 
 
-def check_dtypes(data: torch.Tensor, points: torch.Tensor, name: str) -> None:
+def check_dtypes(
+    data: torch.Tensor, points: torch.Tensor, name: str, points_name: str = "Points"
+) -> None:
     """
     Checks that data is complex-valued and that points is real-valued
     of the same precision
@@ -38,8 +40,8 @@ def check_dtypes(data: torch.Tensor, points: torch.Tensor, name: str) -> None:
 
     if points.dtype is not real_dtype:
         raise TypeError(
-            f"Points must have a dtype of {real_dtype} as {name.lower()} has a "
-            f"dtype of {complex_dtype}"
+            f"{points_name} must have a dtype of {real_dtype} as {name.lower()} has a "
+            f"dtype of {complex_dtype}, but got {points.dtype} instead"
         )
 
 
@@ -102,3 +104,45 @@ def check_sizes_t2(targets: torch.Tensor, points: torch.Tensor) -> None:
             f"For type 2 {points_dim}d FINUFFT, targets must be at "
             f"least a {points_dim}d tensor"
         )
+
+
+def check_sizes_t3(
+    points: torch.Tensor, strengths: torch.Tensor, targets: torch.Tensor
+) -> None:
+    """
+    Checks that points and targets share the same dimension (at most 3) and
+    that the same number of points and strengths is supplied. Used in type 3.
+    """
+    points_len = len(points.shape)
+    targets_len = len(targets.shape)
+
+    if points_len == 1:
+        points_dim = 1
+    elif points_len == 2:
+        points_dim = points.shape[0]
+    else:
+        raise ValueError("The points tensor must be 1d or 2d")
+
+    if targets_len == 1:
+        targets_dim = 1
+    elif targets_len == 2:
+        targets_dim = targets.shape[0]
+    else:
+        raise ValueError("The targets tensor must be 1d or 2d")
+
+    if targets_dim != points_dim:
+        raise ValueError(
+            "Points and targets must be of the same dimension!"
+ + f" Got {points_dim=} and {targets_dim=} instead" + ) + + if points_dim not in {1, 2, 3}: + raise ValueError( + f"Points and targets can be at most 3d, got {points_dim} instead" + ) + + n_points = points.shape[-1] + n_strengths = strengths.shape[-1] + + if n_points != n_strengths: + raise ValueError("The same number of points and strengths must be supplied") diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index cdd71bf..7133fe9 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -266,8 +266,6 @@ def backward( # type: ignore[override] grad_values, None, None, - None, - None, ) @@ -383,13 +381,7 @@ def vmap( # type: ignore[override] @staticmethod def backward( # type: ignore[override] ctx: Any, grad_output: torch.Tensor - ) -> Tuple[ - Union[torch.Tensor, None], - Union[torch.Tensor, None], - None, - None, - None, - ]: + ) -> Tuple[Union[torch.Tensor, None], ...]: _i_sign = ctx.isign _mode_ordering = ctx.mode_ordering finufftkwargs = ctx.finufftkwargs @@ -450,11 +442,98 @@ def backward( # type: ignore[override] grad_points, grad_targets, None, - None, - None, ) +class FinufftType3(torch.autograd.Function): + """ + FINUFFT problem type 3 + """ + + ISIGN_DEFAULT = -1 # note: FINUFFT default is 1 + + @staticmethod + def setup_context( + ctx: Any, + inputs: Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[Dict[str, Union[int, float]]], + ], + output: Any, + ) -> None: + points, strengths, targets, finufftkwargs = inputs + if finufftkwargs is None: + finufftkwargs = {} + else: # copy to avoid mutating caller's dictionary + finufftkwargs = finufftkwargs.copy() + ctx.save_for_backward(points, strengths, targets) + ctx.isign = finufftkwargs.pop("isign", FinufftType3.ISIGN_DEFAULT) + ctx.finufftkwargs = finufftkwargs + + @staticmethod + def forward( # type: ignore[override] + points: torch.Tensor, + strengths: torch.Tensor, + targets: torch.Tensor, + finufftkwargs: Optional[Dict[str, Union[int, float]]] = None, + ) -> torch.Tensor: + checks.check_devices(targets, strengths, points) + checks.check_dtypes(strengths, points, "Strengths") + checks.check_dtypes(strengths, targets, "Strengths", points_name="Targets") + checks.check_sizes_t3(points, strengths, targets) + + if finufftkwargs is None: + finufftkwargs = dict() + else: + finufftkwargs = finufftkwargs.copy() + + finufftkwargs.setdefault("isign", FinufftType3.ISIGN_DEFAULT) + + points = torch.atleast_2d(points) + targets = torch.atleast_2d(targets) + + ndim = points.shape[0] + npoints = points.shape[1] + batch_dims = strengths.shape[:-1] + + if points.device.type != "cpu": + raise NotImplementedError("Type 3 is not currently implemented for GPU") + + nufft_func = get_nufft_func(ndim, 3, points.device) + + finufft_out = nufft_func( + *points, + strengths.reshape(-1, npoints), + *targets, + **finufftkwargs, + ) + finufft_out = finufft_out.reshape(*batch_dims, targets.shape[-1]) + + return finufft_out + + @staticmethod + def backward( # type: ignore[override] + ctx: Any, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + _i_sign = ctx.isign + # finufftkwargs = ctx.finufftkwargs + + points, strengths, targets = ctx.saved_tensors + points = torch.atleast_2d(points) + targets = torch.atleast_2d(targets) + + # device = points.device + # ndim = points.shape[0] + + grad_points = None + grad_strengths = None + grad_targets = None + + return grad_points, grad_strengths, grad_targets, None + + def finufft_type1( points: torch.Tensor, values: torch.Tensor, @@ 
@@ -534,3 +613,45 @@ def finufft_type2(
     """
     res: torch.Tensor = FinufftType2.apply(points, targets, finufftkwargs)
     return res
+
+
+def finufft_type3(
+    points: torch.Tensor,
+    strengths: torch.Tensor,
+    targets: torch.Tensor,
+    **finufftkwargs: Union[int, float],
+) -> torch.Tensor:
+    """
+    Evaluates the Type 3 (nonuniform-to-nonuniform) NUFFT on the inputs.
+
+    This is a wrapper around :func:`finufft.nufft1d3`, :func:`finufft.nufft2d3`, and
+    :func:`finufft.nufft3d3` on CPU.
+
+    Note that this function is **not implemented** for GPUs at the time of writing.
+
+    Parameters
+    ----------
+    points : torch.Tensor
+        DxM tensor of locations of the non-uniform source points.
+        Points must lie in the range ``[-3pi, 3pi]``.
+    strengths : torch.Tensor
+        Complex-valued tensor of source strengths at the non-uniform points.
+        All dimensions except the final dimension are treated as batch
+        dimensions. The final dimension must have size ``M``.
+    targets : torch.Tensor
+        DxN tensor of locations of the non-uniform target points.
+    **finufftkwargs : int | float
+        Additional keyword arguments are forwarded to the underlying
+        FINUFFT functions. A few notable options are
+
+        - ``eps``: precision requested (default: ``1e-6``)
+        - ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``)
+        - ``isign``: sign of the exponent in the Fourier transform (default: ``-1``)
+
+    Returns
+    -------
+    torch.Tensor
+        A ``[batch]xN`` tensor of values at the target non-uniform points.
+    """
+    res: torch.Tensor = FinufftType3.apply(points, strengths, targets, finufftkwargs)
+    return res
diff --git a/tests/test_errors.py b/tests/test_errors.py
index 69d128e..ad73d5d 100644
--- a/tests/test_errors.py
+++ b/tests/test_errors.py
@@ -46,6 +46,37 @@ def test_t2_mismatch_cuda_index() -> None:
         pytorch_finufft.functional.finufft_type2(points, targets)
 
 
+def test_t3_mismatch_device_cuda_cpu() -> None:
+    points = torch.rand((2, 10), dtype=torch.float64)
+    values = torch.randn(10, dtype=torch.complex128)
+    targets = torch.rand((2, 10), dtype=torch.float64)
+
+    with pytest.raises(ValueError, match="Some tensors are not on the same device"):
+        pytorch_finufft.functional.finufft_type3(points.to("cuda:0"), values, targets)
+
+    with pytest.raises(ValueError, match="Some tensors are not on the same device"):
+        pytorch_finufft.functional.finufft_type3(points, values.to("cuda:0"), targets)
+
+    with pytest.raises(ValueError, match="Some tensors are not on the same device"):
+        pytorch_finufft.functional.finufft_type3(points, values, targets.to("cuda:0"))
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires multiple GPUs")
+def test_t3_mismatch_cuda_index() -> None:
+    points = torch.rand((2, 10), dtype=torch.float64).to("cuda:0")
+    values = torch.randn(10, dtype=torch.complex128).to("cuda:0")
+    targets = torch.rand((2, 10), dtype=torch.float64).to("cuda:0")
+
+    with pytest.raises(ValueError, match="Some tensors are not on the same device"):
+        pytorch_finufft.functional.finufft_type3(points.to("cuda:1"), values, targets)
+
+    with pytest.raises(ValueError, match="Some tensors are not on the same device"):
+        pytorch_finufft.functional.finufft_type3(points, values.to("cuda:1"), targets)
+
+    with pytest.raises(ValueError, match="Some tensors are not on the same device"):
+        pytorch_finufft.functional.finufft_type3(points, values, targets.to("cuda:1"))
+
+
 # dtypes
 
 
@@ -171,6 +202,101 @@ def test_t2_mismatch_precision() -> None:
         pytorch_finufft.functional.finufft_type2(points, targets)
 
 
+def test_t3_non_complex_strengths() -> None:
+    points = torch.rand((2, 10), dtype=torch.float64)
+    values = torch.randn(10, dtype=torch.float64)
+    targets = torch.rand((2, 10), dtype=torch.float64)
+
+    with pytest.raises(
+        TypeError,
+        match="Strengths must have a dtype of torch.complex64 or torch.complex128",
+    ):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+
+def test_t3_half_complex_strengths() -> None:
+    points = torch.rand((2, 10), dtype=torch.float64)
+    targets = torch.rand((2, 10), dtype=torch.float64)
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", UserWarning)
+        values = torch.randn(10, dtype=torch.complex32)
+
+    with pytest.raises(
+        TypeError,
+        match="Strengths must have a dtype of torch.complex64 or torch.complex128",
+    ):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+
+def test_t3_non_real_points() -> None:
+    points = torch.rand((2, 10), dtype=torch.complex128)
+    values = torch.randn(10, dtype=torch.complex128)
+    targets = torch.rand((2, 10), dtype=torch.float64)
+
+    with pytest.raises(
+        TypeError,
+        match="Points must have a dtype of torch.float64 as strengths has "
+        "a dtype of torch.complex128",
+    ):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+
+def test_t3_non_real_targets() -> None:
+    points = torch.rand((2, 10), dtype=torch.float64)
+    values = torch.randn(10, dtype=torch.complex128)
+    targets = torch.rand((2, 10), dtype=torch.complex128)
+
+    with pytest.raises(
+        TypeError,
+        match="Targets must have a dtype of torch.float64 as strengths has "
+        "a dtype of torch.complex128",
+    ):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+
+def test_t3_mismatch_precision() -> None:
+    points = torch.rand((2, 10), dtype=torch.float32)
+    values = torch.randn(10, dtype=torch.complex128)
+    targets = torch.rand((2, 10), dtype=torch.float64)
+
+    with pytest.raises(
+        TypeError,
+        match="Points must have a dtype of torch.float64 as strengths has "
+        "a dtype of torch.complex128",
+    ):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+    points = points.to(torch.float64)
+    targets = targets.to(torch.float32)
+
+    with pytest.raises(
+        TypeError,
+        match="Targets must have a dtype of torch.float64 as strengths has "
+        "a dtype of torch.complex128",
+    ):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+    values = values.to(torch.complex64)
+
+    with pytest.raises(
+        TypeError,
+        match="Points must have a dtype of torch.float32 as strengths has "
+        "a dtype of torch.complex64",
+    ):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+    points = points.to(torch.float32)
+    targets = targets.to(torch.float64)
+
+    with pytest.raises(
+        TypeError,
+        match="Targets must have a dtype of torch.float32 as strengths has "
+        "a dtype of torch.complex64",
+    ):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+
 # sizes
 
 
@@ -272,6 +398,51 @@ def test_t2_mismatch_dims() -> None:
         pytorch_finufft.functional.finufft_type2(points, targets)
 
 
+def test_t3_points_targets_4d() -> None:
+    points = torch.rand((4, 10), dtype=torch.float64)
+    values = torch.randn(10, dtype=torch.complex128)
+    targets = torch.rand((4, 10), dtype=torch.float64)
+
+    with pytest.raises(ValueError, match="Points and targets can be at most 3d, got"):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+
+def test_t3_points_targets_mismatch_dims() -> None:
+    points = torch.rand((2, 10), dtype=torch.float64)
+    values = torch.randn(10, dtype=torch.complex128)
+    targets = torch.rand((3, 10), dtype=torch.float64)
+
+    with pytest.raises(
+        ValueError, match="Points and targets must be of the same dimension"
+    ):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+
+def test_t3_too_many_dims() -> None:
+    points = torch.rand((1, 4, 10), dtype=torch.float64)
+    values = torch.randn(10, dtype=torch.complex128)
+    targets = torch.rand((1, 4, 10), dtype=torch.float64)
+
+    with pytest.raises(ValueError, match="The points tensor must be 1d or 2d"):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+    points = torch.rand((2, 10), dtype=torch.float64)
+
+    with pytest.raises(ValueError, match="The targets tensor must be 1d or 2d"):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+
+def test_t3_mismatch_dims() -> None:
+    points = torch.rand((2, 10), dtype=torch.float64)
+    values = torch.randn(11, dtype=torch.complex128)
+    targets = torch.rand((2, 12), dtype=torch.float64)
+
+    with pytest.raises(
+        ValueError, match="The same number of points and strengths must be supplied"
+    ):
+        pytorch_finufft.functional.finufft_type3(points, values, targets)
+
+
 # dependencies
 def test_finufft_not_installed():
     if not pytorch_finufft.functional.CUFINUFFT_AVAIL:
diff --git a/tests/test_t3_forward.py b/tests/test_t3_forward.py
new file mode 100644
index 0000000..0b2142c
--- /dev/null
+++ b/tests/test_t3_forward.py
@@ -0,0 +1,72 @@
+import numpy as np
+import pytest
+import torch
+
+import pytorch_finufft
+
+torch.manual_seed(45678)
+
+
+def check_t3_forward(N: int, dim: int, device: str) -> None:
+    """
+    Tests the type 3 transform against torch.fft by placing both sources and
+    targets on a uniform grid, where the NUFFT reduces to an ordinary DFT.
+    """
+
+    slices = tuple(slice(None, N) for _ in range(dim))
+    g = np.mgrid[slices] * 2 * np.pi / N
+    points = torch.from_numpy(g.reshape(dim, -1)).to(device)
+    values = torch.randn(*g[0].shape, dtype=torch.complex128).to(device)
+    targets = (
+        torch.from_numpy(np.mgrid[slices].astype(np.float64))
+        .reshape(dim, -1)
+        .to(device)
+    )
+
+    print(f"N is {N}")
+    print(f"dim is {dim}")
+    print(f"shape of points is {points.shape}")
+    print(f"shape of values is {values.shape}")
+    print(f"shape of targets is {targets.shape}")
+
+    finufft_out = pytorch_finufft.functional.finufft_type3(
+        points, values.flatten(), targets, eps=1e-7
+    )
+
+    against_torch = torch.fft.fftn(values)
+
+    abs_errors = torch.abs(finufft_out - against_torch.flatten())
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
+    assert l_inf_error < 5e-5 * N**1.5
+    assert l_2_error < 1.5e-5 * N**3.2
+    assert l_1_error < 1.5e-5 * N**4.5
+
+
+Ns_and_dims = [
+    (2, 1),
+    (3, 1),
+    (5, 1),
+    (10, 1),
+    (100, 1),
+    (101, 1),
+    (1000, 1),
+    (10001, 1),
+    (2, 2),
+    (3, 2),
+    (5, 2),
+    (10, 2),
+    (101, 2),
+    (2, 3),
+    (3, 3),
+    (5, 3),
+    (10, 3),
+    (37, 3),
+]
+
+
+@pytest.mark.parametrize("N, dim", Ns_and_dims)
+def test_t3_forward_CPU(N: int, dim: int) -> None:
+    check_t3_forward(N, dim, "cpu")
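
For a quick end-to-end check of the new wrapper, here is a minimal sketch (CPU only; the seed, problem sizes, and tolerance are illustrative, not taken from the diff) comparing `finufft_type3` against the naive O(M*N) direct sum it approximates. With the wrapper's `isign` default of -1, the reference uses `exp(-1j * phase)`; passing `isign=1` would flip the sign.

import torch

import pytorch_finufft

torch.manual_seed(0)  # illustrative seed

# 2d problem: M source points with complex strengths, N target points.
dim, n_sources, n_targets = 2, 50, 40
points = 2 * torch.pi * (torch.rand((dim, n_sources), dtype=torch.float64) - 0.5)
strengths = torch.randn(n_sources, dtype=torch.complex128)
targets = torch.randn((dim, n_targets), dtype=torch.float64)

out = pytorch_finufft.functional.finufft_type3(points, strengths, targets, eps=1e-9)

# Naive direct sum: f[k] = sum_j c[j] * exp(-1j * <s_k, x_j>)  (isign = -1)
phase = torch.einsum("dn,dm->nm", targets, points)  # <s_k, x_j> for all pairs
reference = (strengths * torch.exp(-1j * phase)).sum(-1)

assert torch.allclose(out, reference, atol=1e-6)  # tolerance is illustrative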
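Similarly, a short sketch of the batching convention in `FinufftType3.forward`: per the implementation above, every leading dimension of `strengths` is flattened into a stack of transforms passed to FINUFFT and then restored on the output. The shapes below are arbitrary.

import torch

import pytorch_finufft

batch_dims, n_sources, n_targets = (3, 4), 100, 17
points = 2 * torch.pi * torch.rand((3, n_sources), dtype=torch.float64)  # 3d sources
strengths = torch.randn(*batch_dims, n_sources, dtype=torch.complex128)  # batched
targets = 2 * torch.pi * torch.rand((3, n_targets), dtype=torch.float64)

out = pytorch_finufft.functional.finufft_type3(points, strengths, targets)

# Leading dimensions of strengths come back unchanged on the output.
assert out.shape == (*batch_dims, n_targets)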