From 88e87412e20cf2665654b88507e111830ae199d5 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Wed, 4 Oct 2023 12:36:16 -0400 Subject: [PATCH] WIP outline of consolidation idea for type 1 --- pytorch_finufft/functional.py | 148 ++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 96551b5..a63289e 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1592,3 +1592,151 @@ def backward( None, None, ) + + + + + +############################################################################### +# Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1 +############################################################################### + +# This function takes a ctx object and is supposed to later replace all type 1 +# functions above for all dimensionalities. + +def finufft_type1( + ctx: Any, + points: torch.Tensor, + values: torch.Tensor, + output_shape: Union[int, tuple[int, int], tuple[int, int, int]], + out: Optional[torch.Tensor]=None, + fftshift: bool=False, + finufftkwargs: dict[str, Union[int, float]]=None): + """ + Evaluates the Type 1 NUFFT on the inputs. + + """ + + if out is not None: + print("In-place results are not yet implemented") + # All this requires is a check on the out array to make sure it is the + # correct shape. + + err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately + + + if finufftkwargs is None: + finufftkwargs = dict() + finufftkwargs = {k: v for k, v in finufftkwargs.items()} + _mode_ordering = finufftkwargs.pop("modeord", 1) + _i_sign = finufftkwargs.pop("isign", -1) + + if fftshift: + # TODO -- this check should be done elsewhere? or error msg changed + # to note instead that there is a conflict in fftshift + if _mode_ordering != 1: + raise ValueError( + "Double specification of ordering; only one of fftshift and modeord should be provided" + ) + _mode_ordering = 0 + + ctx.save_for_backward(*points.T, values) + + ctx.isign = _i_sign + ctx.mode_ordering = _mode_ordering + ctx.finufftkwargs = finufftkwargs + + finufft_out = torch.from_numpy( + finufft.nufft3d1( + *points.data.T.numpy(), + values.data.numpy(), + output_shape, + modeord=_mode_ordering, + isign=_i_sign, + **finufftkwargs, + ) + ) + + return finufft_out + + + + +def backward_type1( + ctx: Any, grad_output: torch.Tensor +) -> tuple[Union[torch.Tensor, None], ...]: + """ + Implements derivatives wrt. each argument in the forward method. + + Parameters + ---------- + ctx : Any + Pytorch context object. + grad_output : torch.Tensor + Backpass gradient output + + Returns + ------- + tuple[Union[torch.Tensor, None], ...] + A tuple of derivatives wrt. each argument in the forward method + """ + _i_sign = -1 * ctx.isign + _mode_ordering = ctx.mode_ordering + finufftkwargs = ctx.finufftkwargs + + + points, values = ctx.saved_tensors + + start_points = -np.array(grad_output.shape) // 2 + end_points = start_points + grad_output.shape + slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) + + coord_ramps = torch.mgrid[slices] + + grads_points = None + grad_values = None + + if ctx.needs_input_grad[0]: + # wrt points + + if _mode_ordering != 0: + coord_ramps = torch.fft.ifftshift(coord_ramps) + + ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign + + grads_points = [] + for ramp in ramped_grad_output: # we can batch this into finufft + backprop_ramp = torch.from_numpy( + finufft.nufft3d2( + *points.T.numpy(), + ramp.data.numpy(), + isign=_i_sign, + modeord=_mode_ordering, + **finufftkwargs, + )) + grad_points = (backprop_ramp.conj() * values).real + grads_points.append(grad_points) + + grads_points = torch.stack(grads_points) + + if ctx.needs_input_grad[1]: + np_grad_output = grad_output.data.numpy() + + grad_values = torch.from_numpy( + finufft.nufft3d2( + *points.T.numpy() + np_grad_output, + isign=_i_sign, + modeord=_mode_ordering, + **finufftkwargs, + ) + ) + + return ( + grads_points, + grad_values, + None, + None, + None, + None, + )