Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP consolidation of dimensionalities for nufft type 1 #64

Merged
merged 10 commits into from
Oct 6, 2023
159 changes: 159 additions & 0 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import finufft
import torch

Expand Down Expand Up @@ -1592,3 +1593,161 @@ def backward(
None,
None,
)





###############################################################################
# Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1
###############################################################################

def get_nufft_func(dim, nufft_type):
return getattr(finufft, f"nufft{dim}d{nufft_type}")


class finufft_type1(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
points: torch.Tensor,
values: torch.Tensor,
output_shape: Union[int, tuple[int, int], tuple[int, int, int]],
out: Optional[torch.Tensor]=None,
fftshift: bool=False,
finufftkwargs: dict[str, Union[int, float]]=None):
"""
Evaluates the Type 1 NUFFT on the inputs.

"""

if out is not None:
print("In-place results are not yet implemented")
# All this requires is a check on the out array to make sure it is the
# correct shape.

err._type1_checks(points, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately
# ^ make sure these checks check for consistency between output shape and len(points)

if finufftkwargs is None:
finufftkwargs = dict()
finufftkwargs = {k: v for k, v in finufftkwargs.items()}
_mode_ordering = finufftkwargs.pop("modeord", 1)
_i_sign = finufftkwargs.pop("isign", -1)

if fftshift:
# TODO -- this check should be done elsewhere? or error msg changed
# to note instead that there is a conflict in fftshift
if _mode_ordering != 1:
raise ValueError(
"Double specification of ordering; only one of fftshift and modeord should be provided"
)
_mode_ordering = 0

ctx.save_for_backward(points, values)

ctx.isign = _i_sign
ctx.mode_ordering = _mode_ordering
ctx.finufftkwargs = finufftkwargs

# this below should be a pre-check
ndim = points.shape[0]
assert len(output_shape) == ndim

nufft_func = get_nufft_func(ndim, 1)
finufft_out = torch.from_numpy(
nufft_func(
*points.data.numpy(),
values.data.numpy(),
output_shape,
modeord=_mode_ordering,
isign=_i_sign,
**finufftkwargs,
)
)

return finufft_out

@staticmethod
def backward(
ctx: Any, grad_output: torch.Tensor
) -> tuple[Union[torch.Tensor, None], ...]:
"""
Implements derivatives wrt. each argument in the forward method.

Parameters
----------
ctx : Any
Pytorch context object.
grad_output : torch.Tensor
Backpass gradient output

Returns
-------
tuple[Union[torch.Tensor, None], ...]
A tuple of derivatives wrt. each argument in the forward method
"""
_i_sign = -1 * ctx.isign
_mode_ordering = ctx.mode_ordering
finufftkwargs = ctx.finufftkwargs

points, values = ctx.saved_tensors

start_points = -(np.array(grad_output.shape) // 2)
end_points = start_points + grad_output.shape
slices = tuple(slice(start, end) for start, end in zip(start_points, end_points))

# CPU idiosyncracy that needs to be done differently
coord_ramps = torch.from_numpy(np.mgrid[slices])

grads_points = None
grad_values = None

ndim = points.shape[0]

nufft_func = get_nufft_func(ndim, 2)

if ctx.needs_input_grad[0]:
# wrt points

if _mode_ordering != 0:
coord_ramps = torch.fft.ifftshift(coord_ramps, dim=tuple(range(1, ndim+1)))

ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign

grads_points = []
for ramp in ramped_grad_output: # we can batch this into finufft
backprop_ramp = torch.from_numpy(
nufft_func(
*points.numpy(),
ramp.data.numpy(),
isign=_i_sign,
modeord=_mode_ordering,
**finufftkwargs,
))
grad_points = (backprop_ramp.conj() * values).real
grads_points.append(grad_points)

grads_points = torch.stack(grads_points)

if ctx.needs_input_grad[1]:
np_grad_output = grad_output.data.numpy()

grad_values = torch.from_numpy(
nufft_func(
*points.numpy(),
np_grad_output,
isign=_i_sign,
modeord=_mode_ordering,
**finufftkwargs,
)
)

return (
grads_points,
grad_values,
None,
None,
None,
None,
)
45 changes: 45 additions & 0 deletions tests/test_1d/test_forward_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ def test_1d_t1_forward_CPU(values: torch.Tensor) -> None:
) == pytest.approx(0, abs=1e-06)


abs_errors = torch.abs(finufft1D1_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 3.5e-3 * N ** .6
assert l_2_error < 7.5e-4 * N ** 1.1
assert l_1_error < 5e-4 * N ** 1.6


@pytest.mark.parametrize("targets", cases)
def test_1d_t2_forward_CPU(targets: torch.Tensor):
"""
Expand Down Expand Up @@ -96,6 +106,41 @@ def test_1d_t2_forward_CPU(targets: torch.Tensor):
)


@pytest.mark.parametrize("N", Ns)
def test_t1_forward_CPU(N: int) -> None:
"""
Tests against implementations of the FFT by setting up a uniform grid
over which to call FINUFFT through the API.
"""
g = np.mgrid[:N] * 2 * np.pi / N
g.shape = 1, -1
points = torch.from_numpy(g.reshape(1, -1))

values = torch.randn(*points[0].shape, dtype=torch.complex128)

print("N is " + str(N))
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
points,
values,
(N,),
)

against_torch = torch.fft.fft(values.reshape(g[0].shape))

abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 4.5e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3



# @pytest.mark.parametrize("values", cases)
# def test_1d_t3_forward_CPU(values: torch.Tensor) -> None:
# """
Expand Down
47 changes: 47 additions & 0 deletions tests/test_2d/test_backward_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

import pytorch_finufft

from functools import partial

torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.float64)
torch.manual_seed(0)

######################################################################
# APPLY WRAPPERS
Expand Down Expand Up @@ -97,6 +100,50 @@ def test_t1_backward_CPU_values(
assert gradcheck(apply_finufft2d1(modifier, fftshift, isign), inputs)


@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t1_consolidated_backward_CPU_values(N: int, modifier: int, fftshift: bool, isign: int) -> None:

points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi
values = torch.randn(N, dtype=torch.complex128)

points.requires_grad = False
values.requires_grad = True

inputs = (points, values)

def func(points, values):
return pytorch_finufft.functional.finufft_type1.apply(
points, values, (N,N + modifier), None, fftshift, dict(isign=isign)
)

assert gradcheck(func, inputs)


@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t1_consolidated_backward_CPU_points(N: int, modifier: int, fftshift: bool, isign: int) -> None:

points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi
values = torch.randn(N, dtype=torch.complex128)

points.requires_grad = True
values.requires_grad = False

inputs = (points, values)

def func(points, values):
return pytorch_finufft.functional.finufft_type1.apply(
points, values, (N,N + modifier), None, fftshift, dict(isign=isign)
)

assert gradcheck(func, inputs, atol=1e-5 * N)


@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [True, False])
Expand Down
74 changes: 50 additions & 24 deletions tests/test_2d/test_forward_2d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
import torch
torch.manual_seed(0)

import pytorch_finufft

Expand Down Expand Up @@ -45,28 +46,14 @@ def test_2d_t1_forward_CPU(N: int) -> None:

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
0, abs=1e-6
)
abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

values = torch.randn(*x.shape, dtype=torch.complex64)

finufft_out = pytorch_finufft.functional.finufft2D1.apply(
torch.from_numpy(x).to(torch.float32),
torch.from_numpy(y).to(torch.float32),
values,
N,
)

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

# NOTE -- the below tolerance is set to 1e-5 instead of -6 due
# to the occasional failing case that seems to be caused by
# the randomness of the test cases in addition to the expected
# accruation of numerical inaccuracies
assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
0, abs=1e-5
)
assert l_inf_error < 5e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3


@pytest.mark.parametrize("N", Ns)
Expand Down Expand Up @@ -109,9 +96,14 @@ def test_2d_t2_forward_CPU(N: int) -> None:

against_torch = torch.fft.ifft2(values)

assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
0, abs=1e-6
)
abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 1e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3


# @pytest.mark.parametrize("N", Ns)
Expand All @@ -128,3 +120,37 @@ def test_2d_t2_forward_CPU(N: int) -> None:
# assert abs((f - comparison).sum()) / (N**3) == pytest.approx(0, abs=1e-6)

# pass


@pytest.mark.parametrize("N", Ns)
def test_t1_forward_CPU(N: int) -> None:
"""
Tests against implementations of the FFT by setting up a uniform grid
over which to call FINUFFT through the API.
"""
g = np.mgrid[:N, :N] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(2, -1))

values = torch.randn(*points[0].shape, dtype=torch.complex128)

print("N is " + str(N))
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
points,
values,
(N, N),
)

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 4.5e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3

Loading
Loading