diff --git a/src/jax_finufft/options.py b/src/jax_finufft/options.py index ec08f77..55c4f60 100644 --- a/src/jax_finufft/options.py +++ b/src/jax_finufft/options.py @@ -40,6 +40,9 @@ class GpuMethod(IntEnum): @dataclass(frozen=True) class Opts: + + # These correspond to the default cufinufft options + # set in vendor/finufft/src/cuda/cufinufft.cu modeord: bool = False chkbnds: bool = True debug: DebugLevel = DebugLevel.Silent @@ -59,12 +62,12 @@ class Opts: gpu_upsampfac: float = 2.0 gpu_method: GpuMethod = 0 gpu_sort: bool = True - gpu_binsizex: int = -1 - gpu_binsizey: int = -1 - gpu_binsizez: int = -1 - gpu_obinsizex: int = -1 - gpu_obinsizey: int = -1 - gpu_obinsizez: int = -1 + gpu_binsizex: int = 0 + gpu_binsizey: int = 0 + gpu_binsizez: int = 0 + gpu_obinsizex: int = 0 + gpu_obinsizey: int = 0 + gpu_obinsizez: int = 0 gpu_maxsubprobsize: int = 1024 gpu_kerevalmeth: bool = True gpu_spreadinterponly: bool = False