diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index a63289e..0546f62 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1601,142 +1601,139 @@ def backward( # Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1 ############################################################################### -# This function takes a ctx object and is supposed to later replace all type 1 -# functions above for all dimensionalities. - -def finufft_type1( - ctx: Any, - points: torch.Tensor, - values: torch.Tensor, - output_shape: Union[int, tuple[int, int], tuple[int, int, int]], - out: Optional[torch.Tensor]=None, - fftshift: bool=False, - finufftkwargs: dict[str, Union[int, float]]=None): - """ - Evaluates the Type 1 NUFFT on the inputs. +class finufft_type1(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + points: torch.Tensor, + values: torch.Tensor, + output_shape: Union[int, tuple[int, int], tuple[int, int, int]], + out: Optional[torch.Tensor]=None, + fftshift: bool=False, + finufftkwargs: dict[str, Union[int, float]]=None): + """ + Evaluates the Type 1 NUFFT on the inputs. - """ + """ - if out is not None: - print("In-place results are not yet implemented") - # All this requires is a check on the out array to make sure it is the - # correct shape. + if out is not None: + print("In-place results are not yet implemented") + # All this requires is a check on the out array to make sure it is the + # correct shape. - err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately + err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately - if finufftkwargs is None: - finufftkwargs = dict() - finufftkwargs = {k: v for k, v in finufftkwargs.items()} - _mode_ordering = finufftkwargs.pop("modeord", 1) - _i_sign = finufftkwargs.pop("isign", -1) + if finufftkwargs is None: + finufftkwargs = dict() + finufftkwargs = {k: v for k, v in finufftkwargs.items()} + _mode_ordering = finufftkwargs.pop("modeord", 1) + _i_sign = finufftkwargs.pop("isign", -1) - if fftshift: - # TODO -- this check should be done elsewhere? or error msg changed - # to note instead that there is a conflict in fftshift - if _mode_ordering != 1: - raise ValueError( - "Double specification of ordering; only one of fftshift and modeord should be provided" - ) - _mode_ordering = 0 + if fftshift: + # TODO -- this check should be done elsewhere? or error msg changed + # to note instead that there is a conflict in fftshift + if _mode_ordering != 1: + raise ValueError( + "Double specification of ordering; only one of fftshift and modeord should be provided" + ) + _mode_ordering = 0 - ctx.save_for_backward(*points.T, values) + ctx.save_for_backward(*points.T, values) - ctx.isign = _i_sign - ctx.mode_ordering = _mode_ordering - ctx.finufftkwargs = finufftkwargs + ctx.isign = _i_sign + ctx.mode_ordering = _mode_ordering + ctx.finufftkwargs = finufftkwargs - finufft_out = torch.from_numpy( - finufft.nufft3d1( - *points.data.T.numpy(), - values.data.numpy(), - output_shape, - modeord=_mode_ordering, - isign=_i_sign, - **finufftkwargs, + finufft_out = torch.from_numpy( + finufft.nufft3d1( + *points.data.T.numpy(), + values.data.numpy(), + output_shape, + modeord=_mode_ordering, + isign=_i_sign, + **finufftkwargs, + ) ) - ) - return finufft_out + return finufft_out + @staticmethod + def backward( + ctx: Any, grad_output: torch.Tensor + ) -> tuple[Union[torch.Tensor, None], ...]: + """ + Implements derivatives wrt. each argument in the forward method. + Parameters + ---------- + ctx : Any + Pytorch context object. + grad_output : torch.Tensor + Backpass gradient output + Returns + ------- + tuple[Union[torch.Tensor, None], ...] + A tuple of derivatives wrt. each argument in the forward method + """ + _i_sign = -1 * ctx.isign + _mode_ordering = ctx.mode_ordering + finufftkwargs = ctx.finufftkwargs -def backward_type1( - ctx: Any, grad_output: torch.Tensor -) -> tuple[Union[torch.Tensor, None], ...]: - """ - Implements derivatives wrt. each argument in the forward method. - - Parameters - ---------- - ctx : Any - Pytorch context object. - grad_output : torch.Tensor - Backpass gradient output - - Returns - ------- - tuple[Union[torch.Tensor, None], ...] - A tuple of derivatives wrt. each argument in the forward method - """ - _i_sign = -1 * ctx.isign - _mode_ordering = ctx.mode_ordering - finufftkwargs = ctx.finufftkwargs + points, values = ctx.saved_tensors - points, values = ctx.saved_tensors + start_points = -np.array(grad_output.shape) // 2 + end_points = start_points + grad_output.shape + slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) - start_points = -np.array(grad_output.shape) // 2 - end_points = start_points + grad_output.shape - slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) + coord_ramps = torch.mgrid[slices] - coord_ramps = torch.mgrid[slices] + grads_points = None + grad_values = None - grads_points = None - grad_values = None + if ctx.needs_input_grad[0]: + # wrt points - if ctx.needs_input_grad[0]: - # wrt points + if _mode_ordering != 0: + coord_ramps = torch.fft.ifftshift(coord_ramps) + + ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign + + grads_points = [] + for ramp in ramped_grad_output: # we can batch this into finufft + backprop_ramp = torch.from_numpy( + finufft.nufft3d2( + *points.T.numpy(), + ramp.data.numpy(), + isign=_i_sign, + modeord=_mode_ordering, + **finufftkwargs, + )) + grad_points = (backprop_ramp.conj() * values).real + grads_points.append(grad_points) + + grads_points = torch.stack(grads_points) - if _mode_ordering != 0: - coord_ramps = torch.fft.ifftshift(coord_ramps) - - ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign + if ctx.needs_input_grad[1]: + np_grad_output = grad_output.data.numpy() - grads_points = [] - for ramp in ramped_grad_output: # we can batch this into finufft - backprop_ramp = torch.from_numpy( + grad_values = torch.from_numpy( finufft.nufft3d2( - *points.T.numpy(), - ramp.data.numpy(), + *points.T.numpy() + np_grad_output, isign=_i_sign, modeord=_mode_ordering, **finufftkwargs, - )) - grad_points = (backprop_ramp.conj() * values).real - grads_points.append(grad_points) - - grads_points = torch.stack(grads_points) - - if ctx.needs_input_grad[1]: - np_grad_output = grad_output.data.numpy() - - grad_values = torch.from_numpy( - finufft.nufft3d2( - *points.T.numpy() - np_grad_output, - isign=_i_sign, - modeord=_mode_ordering, - **finufftkwargs, + ) ) - ) - return ( - grads_points, - grad_values, - None, - None, - None, - None, - ) + return ( + grads_points, + grad_values, + None, + None, + None, + None, + )