diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 2b1591e..e59586f 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -27,7 +27,7 @@ def forward( output_shape: int, out: Optional[torch.Tensor] = None, fftshift: bool = False, - finufftkwargs: dict[str, Union[int, float]] = {}, + finufftkwargs: dict[str, Union[int, float]] = None, ) -> torch.Tensor: """ Evaluates the Type 1 NUFFT on the inputs. @@ -79,6 +79,9 @@ def forward( err._type1_checks((points,), values, output_shape) + if finufftkwargs is None: + finufftkwargs = dict() + finufftkwargs = {k: v for k, v in finufftkwargs.items()} _mode_ordering = finufftkwargs.pop("modeord", 1) _i_sign = finufftkwargs.pop("isign", -1)