Skip to content

Commit

Permalink
WIP outline of consolidation idea for type 1
Browse files Browse the repository at this point in the history
  • Loading branch information
eickenberg committed Oct 4, 2023
1 parent d8a8e4b commit 88e8741
Showing 1 changed file with 148 additions and 0 deletions.
148 changes: 148 additions & 0 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,3 +1592,151 @@ def backward(
None,
None,
)





###############################################################################
# Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1
###############################################################################

# This function takes a ctx object and is supposed to later replace all type 1
# functions above for all dimensionalities.

def finufft_type1(
ctx: Any,
points: torch.Tensor,
values: torch.Tensor,
output_shape: Union[int, tuple[int, int], tuple[int, int, int]],
out: Optional[torch.Tensor]=None,
fftshift: bool=False,
finufftkwargs: dict[str, Union[int, float]]=None):
"""
Evaluates the Type 1 NUFFT on the inputs.
"""

if out is not None:
print("In-place results are not yet implemented")
# All this requires is a check on the out array to make sure it is the
# correct shape.

err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately


if finufftkwargs is None:
finufftkwargs = dict()
finufftkwargs = {k: v for k, v in finufftkwargs.items()}
_mode_ordering = finufftkwargs.pop("modeord", 1)
_i_sign = finufftkwargs.pop("isign", -1)

if fftshift:
# TODO -- this check should be done elsewhere? or error msg changed
# to note instead that there is a conflict in fftshift
if _mode_ordering != 1:
raise ValueError(
"Double specification of ordering; only one of fftshift and modeord should be provided"
)
_mode_ordering = 0

ctx.save_for_backward(*points.T, values)

ctx.isign = _i_sign
ctx.mode_ordering = _mode_ordering
ctx.finufftkwargs = finufftkwargs

finufft_out = torch.from_numpy(
finufft.nufft3d1(
*points.data.T.numpy(),
values.data.numpy(),
output_shape,
modeord=_mode_ordering,
isign=_i_sign,
**finufftkwargs,
)
)

return finufft_out




def backward_type1(
ctx: Any, grad_output: torch.Tensor
) -> tuple[Union[torch.Tensor, None], ...]:
"""
Implements derivatives wrt. each argument in the forward method.
Parameters
----------
ctx : Any
Pytorch context object.
grad_output : torch.Tensor
Backpass gradient output
Returns
-------
tuple[Union[torch.Tensor, None], ...]
A tuple of derivatives wrt. each argument in the forward method
"""
_i_sign = -1 * ctx.isign
_mode_ordering = ctx.mode_ordering
finufftkwargs = ctx.finufftkwargs


points, values = ctx.saved_tensors

start_points = -np.array(grad_output.shape) // 2
end_points = start_points + grad_output.shape
slices = tuple(slice(start, end) for start, end in zip(start_points, end_points))

coord_ramps = torch.mgrid[slices]

grads_points = None
grad_values = None

if ctx.needs_input_grad[0]:
# wrt points

if _mode_ordering != 0:
coord_ramps = torch.fft.ifftshift(coord_ramps)

ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign

grads_points = []
for ramp in ramped_grad_output: # we can batch this into finufft
backprop_ramp = torch.from_numpy(
finufft.nufft3d2(
*points.T.numpy(),
ramp.data.numpy(),
isign=_i_sign,
modeord=_mode_ordering,
**finufftkwargs,
))
grad_points = (backprop_ramp.conj() * values).real
grads_points.append(grad_points)

grads_points = torch.stack(grads_points)

if ctx.needs_input_grad[1]:
np_grad_output = grad_output.data.numpy()

grad_values = torch.from_numpy(
finufft.nufft3d2(
*points.T.numpy()
np_grad_output,
isign=_i_sign,
modeord=_mode_ordering,
**finufftkwargs,
)
)

return (
grads_points,
grad_values,
None,
None,
None,
None,
)

0 comments on commit 88e8741

Please sign in to comment.