From b50ddcd0241c247363d9f7d493075424d040e333 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 10:31:28 -0400 Subject: [PATCH 01/17] MVP: cuda support for type 1 Needs more checks, modeord support? --- pytorch_finufft/functional.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 7b40e71..bf8651e 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -6,6 +6,11 @@ import numpy as np import finufft +try: + import cufinufft + CUFINUFFT_AVAIL = True +except: + CUFINUFFT_AVAIL = False import torch import pytorch_finufft._err as err @@ -1602,7 +1607,9 @@ def backward( # Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1 ############################################################################### -def get_nufft_func(dim, nufft_type): +def get_nufft_func(dim, nufft_type, device_type): + if device_type == 'cuda': + return getattr(cufinufft, f"nufft{dim}d{nufft_type}") return getattr(finufft, f"nufft{dim}d{nufft_type}") @@ -1654,8 +1661,15 @@ def forward( ndim = points.shape[0] assert len(output_shape) == ndim - nufft_func = get_nufft_func(ndim, 1) - finufft_out = torch.from_numpy( + nufft_func = get_nufft_func(ndim, 1, points.device.type) + if points.device.type == 'cuda': + finufft_out = nufft_func( + *points, values, output_shape, + # modeord=_mode_ordering, # TODO(cuda): modeord not supported? + isign=_i_sign, **finufftkwargs + ) + else: + finufft_out = torch.from_numpy( nufft_func( *points.data.numpy(), values.data.numpy(), @@ -1712,7 +1726,7 @@ def backward( if _mode_ordering != 0: coord_ramps = torch.fft.ifftshift(coord_ramps, dim=tuple(range(1, ndim+1))) - + ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign grads_points = [] @@ -1727,7 +1741,7 @@ def backward( )) grad_points = (backprop_ramp.conj() * values).real grads_points.append(grad_points) - + grads_points = torch.stack(grads_points) if ctx.needs_input_grad[1]: From 7ca3c6f843f82c0b8fab92bbc3edf03682c27cd5 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 10:31:56 -0400 Subject: [PATCH 02/17] Add (failing) tests --- tests/test_1d/test_forward_1d.py | 33 ++++++++++++++++++++++++++++ tests/test_2d/test_forward_2d.py | 33 ++++++++++++++++++++++++++++ tests/test_3d/test_forward_3d.py | 37 +++++++++++++++++++++++++++++++- 3 files changed, 102 insertions(+), 1 deletion(-) diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py index 5379927..8e93a33 100644 --- a/tests/test_1d/test_forward_1d.py +++ b/tests/test_1d/test_forward_1d.py @@ -140,6 +140,39 @@ def test_t1_forward_CPU(N: int) -> None: assert l_1_error < 1e-5 * N ** 3 +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_cuda(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. + """ + g = np.mgrid[:N] * 2 * np.pi / N + g.shape = 1, -1 + points = torch.from_numpy(g.reshape(1, -1)).to('cuda') + + values = torch.randn(*points[0].shape, dtype=torch.complex128).to('cuda') + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type1.apply( + points, + values, + (N,), + ) + + against_torch = torch.fft.fft(values.reshape(g[0].shape)) + + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 4.5e-5 * N + assert l_2_error < 1e-5 * N ** 2 + assert l_1_error < 1e-5 * N ** 3 + # @pytest.mark.parametrize("values", cases) # def test_1d_t3_forward_CPU(values: torch.Tensor) -> None: diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index 1dda568..04d37f5 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -154,3 +154,36 @@ def test_t1_forward_CPU(N: int) -> None: assert l_2_error < 1e-5 * N ** 2 assert l_1_error < 1e-5 * N ** 3 + + +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_cuda(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. + """ + g = np.mgrid[:N, :N] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(2, -1)).to('cuda') + + values = torch.randn(*points[0].shape, dtype=torch.complex128).to('cuda') + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type1.apply( + points, + values, + (N, N), + ) + + against_torch = torch.fft.fft2(values.reshape(g[0].shape)) + + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 4.5e-5 * N + assert l_2_error < 1e-5 * N ** 2 + assert l_1_error < 1e-5 * N ** 3 diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py index 45484aa..3eccd58 100644 --- a/tests/test_3d/test_forward_3d.py +++ b/tests/test_3d/test_forward_3d.py @@ -114,4 +114,39 @@ def test_t1_forward_CPU(N: int) -> None: assert l_inf_error < 1.5e-5 * N ** 1.5 assert l_2_error < 1e-5 * N ** 3 - assert l_1_error < 1e-5 * N ** 4.5 \ No newline at end of file + assert l_1_error < 1e-5 * N ** 4.5 + + +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_cuda(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. + """ + g = np.mgrid[:N, :N, :N] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(3, -1)).to('cuda') + + values = torch.randn(*points[0].shape, dtype=torch.complex128).to('cuda') + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type1.apply( + points, + values, + (N, N, N), + ) + + against_torch = torch.fft.fftn(values.reshape(g[0].shape)) + + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + + + assert l_inf_error < 1.5e-5 * N ** 1.5 + assert l_2_error < 1e-5 * N ** 3 + assert l_1_error < 1e-5 * N ** 4.5 From c8c9e0cbe65812644941756d26fd6ea894d3cf54 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 10:33:33 -0400 Subject: [PATCH 03/17] First pass at CI --- Jenkinsfile | 90 +++++++++++++++++++++++++++++++++++ ci/docker/Dockerfile-cuda11.8 | 69 +++++++++++++++++++++++++++ ci/docker/cuda.repo | 6 +++ 3 files changed, 165 insertions(+) create mode 100644 Jenkinsfile create mode 100644 ci/docker/Dockerfile-cuda11.8 create mode 100644 ci/docker/cuda.repo diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 0000000..09ee115 --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,90 @@ +pipeline { + agent none + options { + disableConcurrentBuilds() + buildDiscarder(logRotator(numToKeepStr: '8', daysToKeepStr: '20')) + timeout(time: 1, unit: 'HOURS') + } + stages { + stage('main') { + agent { + dockerfile { + filename 'ci/docker/Dockerfile-cuda11.8' + args '--gpus 2' + label 'v100' + } + } + environment { + HOME = "$WORKSPACE" + PYBIN = "/opt/python/cp38-cp38/bin" + LIBRARY_PATH = "$WORKSPACE/finufft/build" + LD_LIBRARY_PATH = "$WORKSPACE/finufft/build" + } + steps { + + // TODO - reconsider install strategy once finufft/cufinufft 2.2 is released + checkout([$class: 'GitSCM', + branches: [[name: '*/master']], + userRemoteConfigs: [[url: "https://github.com/flatironinstitute/finufft"]]] + ) + + sh '''#!/bin/bash -ex + nvidia-smi + ''' + sh '''#!/bin/bash -ex + echo $HOME + ''' + sh '''#!/bin/bash -ex + cd finufft + # v100 cuda arch + cuda_arch="70" + + cmake -B build . -DFINUFFT_USE_CUDA=ON \ + -DFINUFFT_USE_CPU=OFF \ + -DFINUFFT_BUILD_TESTS=ON \ + -DCMAKE_CUDA_ARCHITECTURES="$cuda_arch" \ + -DBUILD_TESTING=ON + cd build + make -j4 + ''' + + sh '${PYBIN}/python3 -m venv $HOME' + sh '''#!/bin/bash -ex + source $HOME/bin/activate + python3 -m pip install --upgrade pip + # we could also move pytorch install inside docker + python3 -m pip install "torch~=2.1.0" --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install finufft/python/cufinufft + python3 -m pip install finufft/python/finufft + + python3 -m pip install -e .[dev] + + python3 -m pytest -k "cuda" tests/ --cov + ''' + } + } + } + post { + failure { + emailext subject: '$PROJECT_NAME - Build #$BUILD_NUMBER - $BUILD_STATUS', + body: '''$PROJECT_NAME - Build #$BUILD_NUMBER - $BUILD_STATUS + +Check console output at $BUILD_URL to view full results. + +Building $BRANCH_NAME for $CAUSE +$JOB_DESCRIPTION + +Chages: +$CHANGES + +End of build log: +${BUILD_LOG,maxLines=200} +''', + recipientProviders: [ + [$class: 'DevelopersRecipientProvider'], + ], + replyTo: '$DEFAULT_REPLYTO', + to: 'janden@flatironinstitute.org' + } + } +} diff --git a/ci/docker/Dockerfile-cuda11.8 b/ci/docker/Dockerfile-cuda11.8 new file mode 100644 index 0000000..dd0a5c2 --- /dev/null +++ b/ci/docker/Dockerfile-cuda11.8 @@ -0,0 +1,69 @@ +# Based on https://github.com/flatironinstitute/finufft/blob/master/tools/cufinufft/docker/cuda11.2/Dockerfile-x86_64 + +FROM quay.io/pypa/manylinux2014_x86_64 +LABEL maintainer "Brian Ward" + +ENV CUDA_MAJOR 11 +ENV CUDA_MINOR 8 +ENV CUDA_DASH_VERSION ${CUDA_MAJOR}-${CUDA_MINOR} +ENV CUDA_DOT_VERSION ${CUDA_MAJOR}.${CUDA_MINOR} + +# ---- The following block adds layers for CUDA --- # +# base +RUN NVIDIA_GPGKEY_SUM=d0664fbbdb8c32356d45de36c5984617217b2d0bef41b93ccecd326ba3b80c87 && \ + curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/D42D0685.pub | sed '/^Version/d' > /etc/pki/rpm-gpg/RPM-GPG-KEY-NVIDIA && \ + echo "$NVIDIA_GPGKEY_SUM /etc/pki/rpm-gpg/RPM-GPG-KEY-NVIDIA" | sha256sum -c --strict - + +COPY ci/docker/cuda.repo /etc/yum.repos.d/cuda.repo + +# For libraries in the cuda-compat-* package: https://docs.nvidia.com/cuda/eula/index.html#attachment-a +RUN yum install -y \ + cuda-cudart-${CUDA_DASH_VERSION} \ + cuda-compat-${CUDA_DASH_VERSION} && \ + ln -s cuda-${CUDA_DOT_VERSION} /usr/local/cuda && \ + rm -rf /var/cache/yum/* + +# nvidia-docker 1.0 +RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \ + echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf + +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH} +ENV LD_LIBRARY_PATH ${LD_LIBRARY_PATH}:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + +# nvidia-container-runtime +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility +ENV NVIDIA_REQUIRE_CUDA "cuda>=${CUDA_DOT_VERSION} brand=tesla,driver>=418,driver<419 brand=tesla,driver>=440,driver<441" + +# runtime +RUN yum install -y \ + cuda-libraries-${CUDA_DASH_VERSION} \ + cuda-nvtx-${CUDA_DASH_VERSION} && \ + rm -rf /var/cache/yum/* + +# devel +RUN yum install -y \ + cuda-cudart-devel-${CUDA_DASH_VERSION} \ + cuda-libraries-devel-${CUDA_DASH_VERSION} \ + cuda-nvprof-${CUDA_DASH_VERSION} \ + cuda-nvcc-${CUDA_DASH_VERSION} && \ + rm -rf /var/cache/yum/* + +ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs + +# /CUDA # + +# CUDA 11 doesn't work on gcc/g++ newer than v9 +RUN yum install -y \ + devtoolset-9-gcc \ + devtoolset-9-gcc-c++ && \ + rm -rf /var/cache/yum/* + +ENV PATH /opt/rh/devtoolset-9/root/usr/bin:${PATH} + +# finufft reqs +RUN yum install -y \ + cmake && \ + rm -rf /var/cache/yum/* + +RUN diff --git a/ci/docker/cuda.repo b/ci/docker/cuda.repo new file mode 100644 index 0000000..ba2cba6 --- /dev/null +++ b/ci/docker/cuda.repo @@ -0,0 +1,6 @@ +[cuda] +name=cuda +baseurl=https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64 +enabled=1 +gpgcheck=1 +gpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-NVIDIA From a89b8ef8a72eec67b684e6f177a89f4a4927d360 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 10:53:26 -0400 Subject: [PATCH 04/17] Work around modeord issue, cuda forward tests passing --- pytorch_finufft/functional.py | 68 ++++++++++++++++++------------- tests/test_1d/test_forward_1d.py | 19 ++++----- tests/test_2d/test_backward_2d.py | 14 ++++--- tests/test_2d/test_forward_2d.py | 22 +++++----- tests/test_3d/test_forward_3d.py | 32 +++++++-------- 5 files changed, 83 insertions(+), 72 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index bf8651e..2fa9d7c 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -6,8 +6,10 @@ import numpy as np import finufft + try: import cufinufft + CUFINUFFT_AVAIL = True except: CUFINUFFT_AVAIL = False @@ -1600,15 +1602,13 @@ def backward( ) - - - ############################################################################### # Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1 ############################################################################### + def get_nufft_func(dim, nufft_type, device_type): - if device_type == 'cuda': + if device_type == "cuda": return getattr(cufinufft, f"nufft{dim}d{nufft_type}") return getattr(finufft, f"nufft{dim}d{nufft_type}") @@ -1616,13 +1616,14 @@ def get_nufft_func(dim, nufft_type, device_type): class finufft_type1(torch.autograd.Function): @staticmethod def forward( - ctx: Any, - points: torch.Tensor, - values: torch.Tensor, - output_shape: Union[int, tuple[int, int], tuple[int, int, int]], - out: Optional[torch.Tensor]=None, - fftshift: bool=False, - finufftkwargs: dict[str, Union[int, float]]=None): + ctx: Any, + points: torch.Tensor, + values: torch.Tensor, + output_shape: Union[int, tuple[int, int], tuple[int, int, int]], + out: Optional[torch.Tensor] = None, + fftshift: bool = False, + finufftkwargs: dict[str, Union[int, float]] = None, + ): """ Evaluates the Type 1 NUFFT on the inputs. @@ -1633,7 +1634,9 @@ def forward( # All this requires is a check on the out array to make sure it is the # correct shape. - err._type1_checks(points, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately + err._type1_checks( + points, values, output_shape + ) # revisit these error checks to take into account the shape of points instead of passing them separately # ^ make sure these checks check for consistency between output shape and len(points) if finufftkwargs is None: @@ -1661,24 +1664,28 @@ def forward( ndim = points.shape[0] assert len(output_shape) == ndim + # if _mode_ordering: + # values = torch.fft.ifftshift(values) + nufft_func = get_nufft_func(ndim, 1, points.device.type) - if points.device.type == 'cuda': + if points.device.type == "cuda": finufft_out = nufft_func( - *points, values, output_shape, - # modeord=_mode_ordering, # TODO(cuda): modeord not supported? - isign=_i_sign, **finufftkwargs + *points, values, output_shape, isign=_i_sign, **finufftkwargs ) else: finufft_out = torch.from_numpy( - nufft_func( - *points.data.numpy(), - values.data.numpy(), - output_shape, - modeord=_mode_ordering, - isign=_i_sign, - **finufftkwargs, + nufft_func( + *points.data.numpy(), + values.data.numpy(), + output_shape, + isign=_i_sign, + **finufftkwargs, + ) ) - ) + + # because modeord is missing from cufinufft + if _mode_ordering: + finufft_out = torch.fft.ifftshift(finufft_out) return finufft_out @@ -1709,7 +1716,9 @@ def backward( start_points = -(np.array(grad_output.shape) // 2) end_points = start_points + grad_output.shape - slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) + slices = tuple( + slice(start, end) for start, end in zip(start_points, end_points) + ) # CPU idiosyncracy that needs to be done differently coord_ramps = torch.from_numpy(np.mgrid[slices]) @@ -1725,12 +1734,14 @@ def backward( # wrt points if _mode_ordering != 0: - coord_ramps = torch.fft.ifftshift(coord_ramps, dim=tuple(range(1, ndim+1))) + coord_ramps = torch.fft.ifftshift( + coord_ramps, dim=tuple(range(1, ndim + 1)) + ) ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign grads_points = [] - for ramp in ramped_grad_output: # we can batch this into finufft + for ramp in ramped_grad_output: # we can batch this into finufft backprop_ramp = torch.from_numpy( nufft_func( *points.numpy(), @@ -1738,7 +1749,8 @@ def backward( isign=_i_sign, modeord=_mode_ordering, **finufftkwargs, - )) + ) + ) grad_points = (backprop_ramp.conj() * values).real grads_points.append(grad_points) diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py index 8e93a33..8dc6ff9 100644 --- a/tests/test_1d/test_forward_1d.py +++ b/tests/test_1d/test_forward_1d.py @@ -66,15 +66,14 @@ def test_1d_t1_forward_CPU(values: torch.Tensor) -> None: torch.linalg.norm(finufft1D1_out - against_scipy) / N**2 ) == pytest.approx(0, abs=1e-06) - abs_errors = torch.abs(finufft1D1_out - against_torch) l_inf_error = abs_errors.max() l_2_error = torch.sqrt(torch.sum(abs_errors**2)) l_1_error = torch.sum(abs_errors) - assert l_inf_error < 3.5e-3 * N ** .6 - assert l_2_error < 7.5e-4 * N ** 1.1 - assert l_1_error < 5e-4 * N ** 1.6 + assert l_inf_error < 3.5e-3 * N**0.6 + assert l_2_error < 7.5e-4 * N**1.1 + assert l_1_error < 5e-4 * N**1.6 @pytest.mark.parametrize("targets", cases) @@ -136,8 +135,8 @@ def test_t1_forward_CPU(N: int) -> None: l_1_error = torch.sum(abs_errors) assert l_inf_error < 4.5e-5 * N - assert l_2_error < 1e-5 * N ** 2 - assert l_1_error < 1e-5 * N ** 3 + assert l_2_error < 1e-5 * N**2 + assert l_1_error < 1e-5 * N**3 @pytest.mark.parametrize("N", Ns) @@ -148,9 +147,9 @@ def test_t1_forward_cuda(N: int) -> None: """ g = np.mgrid[:N] * 2 * np.pi / N g.shape = 1, -1 - points = torch.from_numpy(g.reshape(1, -1)).to('cuda') + points = torch.from_numpy(g.reshape(1, -1)).to("cuda") - values = torch.randn(*points[0].shape, dtype=torch.complex128).to('cuda') + values = torch.randn(*points[0].shape, dtype=torch.complex128).to("cuda") print("N is " + str(N)) print("shape of points is " + str(points.shape)) @@ -170,8 +169,8 @@ def test_t1_forward_cuda(N: int) -> None: l_1_error = torch.sum(abs_errors) assert l_inf_error < 4.5e-5 * N - assert l_2_error < 1e-5 * N ** 2 - assert l_1_error < 1e-5 * N ** 3 + assert l_2_error < 1e-5 * N**2 + assert l_1_error < 1e-5 * N**3 # @pytest.mark.parametrize("values", cases) diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index 6a9b707..9cf2400 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -104,8 +104,9 @@ def test_t1_backward_CPU_values( @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [False, True]) @pytest.mark.parametrize("isign", [-1, 1]) -def test_t1_consolidated_backward_CPU_values(N: int, modifier: int, fftshift: bool, isign: int) -> None: - +def test_t1_consolidated_backward_CPU_values( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi values = torch.randn(N, dtype=torch.complex128) @@ -116,7 +117,7 @@ def test_t1_consolidated_backward_CPU_values(N: int, modifier: int, fftshift: bo def func(points, values): return pytorch_finufft.functional.finufft_type1.apply( - points, values, (N,N + modifier), None, fftshift, dict(isign=isign) + points, values, (N, N + modifier), None, fftshift, dict(isign=isign) ) assert gradcheck(func, inputs) @@ -126,8 +127,9 @@ def func(points, values): @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [False, True]) @pytest.mark.parametrize("isign", [-1, 1]) -def test_t1_consolidated_backward_CPU_points(N: int, modifier: int, fftshift: bool, isign: int) -> None: - +def test_t1_consolidated_backward_CPU_points( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi values = torch.randn(N, dtype=torch.complex128) @@ -138,7 +140,7 @@ def test_t1_consolidated_backward_CPU_points(N: int, modifier: int, fftshift: bo def func(points, values): return pytorch_finufft.functional.finufft_type1.apply( - points, values, (N,N + modifier), None, fftshift, dict(isign=isign) + points, values, (N, N + modifier), None, fftshift, dict(isign=isign) ) assert gradcheck(func, inputs, atol=1e-5 * N) diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index 04d37f5..3008f8a 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch + torch.manual_seed(0) import pytorch_finufft @@ -52,8 +53,8 @@ def test_2d_t1_forward_CPU(N: int) -> None: l_1_error = torch.sum(abs_errors) assert l_inf_error < 5e-5 * N - assert l_2_error < 1e-5 * N ** 2 - assert l_1_error < 1e-5 * N ** 3 + assert l_2_error < 1e-5 * N**2 + assert l_1_error < 1e-5 * N**3 @pytest.mark.parametrize("N", Ns) @@ -102,8 +103,8 @@ def test_2d_t2_forward_CPU(N: int) -> None: l_1_error = torch.sum(abs_errors) assert l_inf_error < 1e-5 * N - assert l_2_error < 1e-5 * N ** 2 - assert l_1_error < 1e-5 * N ** 3 + assert l_2_error < 1e-5 * N**2 + assert l_1_error < 1e-5 * N**3 # @pytest.mark.parametrize("N", Ns) @@ -151,9 +152,8 @@ def test_t1_forward_CPU(N: int) -> None: l_1_error = torch.sum(abs_errors) assert l_inf_error < 4.5e-5 * N - assert l_2_error < 1e-5 * N ** 2 - assert l_1_error < 1e-5 * N ** 3 - + assert l_2_error < 1e-5 * N**2 + assert l_1_error < 1e-5 * N**3 @pytest.mark.parametrize("N", Ns) @@ -163,9 +163,9 @@ def test_t1_forward_cuda(N: int) -> None: over which to call FINUFFT through the API. """ g = np.mgrid[:N, :N] * 2 * np.pi / N - points = torch.from_numpy(g.reshape(2, -1)).to('cuda') + points = torch.from_numpy(g.reshape(2, -1)).to("cuda") - values = torch.randn(*points[0].shape, dtype=torch.complex128).to('cuda') + values = torch.randn(*points[0].shape, dtype=torch.complex128).to("cuda") print("N is " + str(N)) print("shape of points is " + str(points.shape)) @@ -185,5 +185,5 @@ def test_t1_forward_cuda(N: int) -> None: l_1_error = torch.sum(abs_errors) assert l_inf_error < 4.5e-5 * N - assert l_2_error < 1e-5 * N ** 2 - assert l_1_error < 1e-5 * N ** 3 + assert l_2_error < 1e-5 * N**2 + assert l_1_error < 1e-5 * N**3 diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py index 3eccd58..dd7c2ed 100644 --- a/tests/test_3d/test_forward_3d.py +++ b/tests/test_3d/test_forward_3d.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch + torch.manual_seed(0) import pytorch_finufft @@ -45,10 +46,9 @@ def test_3d_t1_forward_CPU(N: int) -> None: l_2_error = torch.sqrt(torch.sum(abs_errors**2)) l_1_error = torch.sum(abs_errors) - assert l_inf_error < 2e-5 * N ** 1.5 - assert l_2_error < 1e-5 * N ** 3 - assert l_1_error < 1e-5 * N ** 4.5 - + assert l_inf_error < 2e-5 * N**1.5 + assert l_2_error < 1e-5 * N**3 + assert l_1_error < 1e-5 * N**4.5 @pytest.mark.parametrize("N", Ns) @@ -79,9 +79,9 @@ def test_3d_t2_forward_CPU(N: int) -> None: l_2_error = torch.sqrt(torch.sum(abs_errors**2)) l_1_error = torch.sum(abs_errors) - assert l_inf_error < 1e-5 * N ** 1.5 - assert l_2_error < 1e-5 * N ** 3 - assert l_1_error < 1e-5 * N ** 4.5 + assert l_inf_error < 1e-5 * N**1.5 + assert l_2_error < 1e-5 * N**3 + assert l_1_error < 1e-5 * N**4.5 @pytest.mark.parametrize("N", Ns) @@ -112,9 +112,9 @@ def test_t1_forward_CPU(N: int) -> None: l_2_error = torch.sqrt(torch.sum(abs_errors**2)) l_1_error = torch.sum(abs_errors) - assert l_inf_error < 1.5e-5 * N ** 1.5 - assert l_2_error < 1e-5 * N ** 3 - assert l_1_error < 1e-5 * N ** 4.5 + assert l_inf_error < 1.5e-5 * N**1.5 + assert l_2_error < 1e-5 * N**3 + assert l_1_error < 1e-5 * N**4.5 @pytest.mark.parametrize("N", Ns) @@ -124,9 +124,9 @@ def test_t1_forward_cuda(N: int) -> None: over which to call FINUFFT through the API. """ g = np.mgrid[:N, :N, :N] * 2 * np.pi / N - points = torch.from_numpy(g.reshape(3, -1)).to('cuda') + points = torch.from_numpy(g.reshape(3, -1)).to("cuda") - values = torch.randn(*points[0].shape, dtype=torch.complex128).to('cuda') + values = torch.randn(*points[0].shape, dtype=torch.complex128).to("cuda") print("N is " + str(N)) print("shape of points is " + str(points.shape)) @@ -145,8 +145,6 @@ def test_t1_forward_cuda(N: int) -> None: l_2_error = torch.sqrt(torch.sum(abs_errors**2)) l_1_error = torch.sum(abs_errors) - - - assert l_inf_error < 1.5e-5 * N ** 1.5 - assert l_2_error < 1e-5 * N ** 3 - assert l_1_error < 1e-5 * N ** 4.5 + assert l_inf_error < 1.5e-5 * N**1.5 + assert l_2_error < 1e-5 * N**3 + assert l_1_error < 1e-5 * N**4.5 From 69371242a25adf0bdfde90c4361264f8e3124ba7 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 10:58:55 -0400 Subject: [PATCH 05/17] Skip cuda tests in GHA --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4fcbc91..50b4cb9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,4 +34,4 @@ jobs: - name: Pytest run: | - pytest tests/ --cov + pytest tests/ --cov -k "not cuda" From 6bd21c2d23b57e2ab0d1b12719bf5b10cb739cc8 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 11:35:23 -0400 Subject: [PATCH 06/17] MVP cuda backward --- pytorch_finufft/functional.py | 72 +++++++++++++++---------------- tests/test_2d/test_backward_2d.py | 46 ++++++++++++++++++++ 2 files changed, 82 insertions(+), 36 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 2fa9d7c..76e4e4d 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1610,7 +1610,19 @@ def backward( def get_nufft_func(dim, nufft_type, device_type): if device_type == "cuda": return getattr(cufinufft, f"nufft{dim}d{nufft_type}") - return getattr(finufft, f"nufft{dim}d{nufft_type}") + + # CPU needs extra work to go to/from torch and numpy + finufft_func = getattr(finufft, f"nufft{dim}d{nufft_type}") + + def f(*args, **kwargs): + new_args = [arg for arg in args] + for i in range(len(new_args)): + if isinstance(new_args[i], torch.Tensor): + new_args[i] = new_args[i].data.numpy() + + return torch.from_numpy(finufft_func(*new_args, **kwargs)) + + return f class finufft_type1(torch.autograd.Function): @@ -1668,21 +1680,9 @@ def forward( # values = torch.fft.ifftshift(values) nufft_func = get_nufft_func(ndim, 1, points.device.type) - if points.device.type == "cuda": - finufft_out = nufft_func( - *points, values, output_shape, isign=_i_sign, **finufftkwargs - ) - else: - finufft_out = torch.from_numpy( - nufft_func( - *points.data.numpy(), - values.data.numpy(), - output_shape, - isign=_i_sign, - **finufftkwargs, - ) - ) - + finufft_out = nufft_func( + *points, values, output_shape, isign=_i_sign, **finufftkwargs + ) # because modeord is missing from cufinufft if _mode_ordering: finufft_out = torch.fft.ifftshift(finufft_out) @@ -1721,19 +1721,19 @@ def backward( ) # CPU idiosyncracy that needs to be done differently - coord_ramps = torch.from_numpy(np.mgrid[slices]) + coord_ramps = torch.from_numpy(np.mgrid[slices]).to(points.device) grads_points = None grad_values = None ndim = points.shape[0] - nufft_func = get_nufft_func(ndim, 2) + nufft_func = get_nufft_func(ndim, 2, points.device.type) if ctx.needs_input_grad[0]: # wrt points - if _mode_ordering != 0: + if _mode_ordering: coord_ramps = torch.fft.ifftshift( coord_ramps, dim=tuple(range(1, ndim + 1)) ) @@ -1742,31 +1742,31 @@ def backward( grads_points = [] for ramp in ramped_grad_output: # we can batch this into finufft - backprop_ramp = torch.from_numpy( - nufft_func( - *points.numpy(), - ramp.data.numpy(), - isign=_i_sign, - modeord=_mode_ordering, - **finufftkwargs, - ) + if _mode_ordering: + ramp = torch.fft.fftshift(ramp) + + backprop_ramp = nufft_func( + *points, + ramp, + isign=_i_sign, + **finufftkwargs, ) + grad_points = (backprop_ramp.conj() * values).real + grads_points.append(grad_points) grads_points = torch.stack(grads_points) if ctx.needs_input_grad[1]: - np_grad_output = grad_output.data.numpy() + if _mode_ordering: + grad_output = torch.fft.fftshift(grad_output) - grad_values = torch.from_numpy( - nufft_func( - *points.numpy(), - np_grad_output, - isign=_i_sign, - modeord=_mode_ordering, - **finufftkwargs, - ) + grad_values = nufft_func( + *points, + grad_output, + isign=_i_sign, + **finufftkwargs, ) return ( diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index 9cf2400..0cf7f71 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -146,6 +146,52 @@ def func(points, values): assert gradcheck(func, inputs, atol=1e-5 * N) +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t1_consolidated_backward_cuda_values( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + points = torch.rand((2, N), dtype=torch.float64).to("cuda") * 2 * np.pi + values = torch.randn(N, dtype=torch.complex128).to("cuda") + + points.requires_grad = False + values.requires_grad = True + + inputs = (points, values) + + def func(points, values): + return pytorch_finufft.functional.finufft_type1.apply( + points, values, (N, N + modifier), None, fftshift, dict(isign=isign) + ) + + assert gradcheck(func, inputs) + + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t1_consolidated_backward_cuda_points( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + points = torch.rand((2, N), dtype=torch.float64).to("cuda") * 2 * np.pi + values = torch.randn(N, dtype=torch.complex128).to("cuda") + + points.requires_grad = True + values.requires_grad = False + + inputs = (points, values) + + def func(points, values): + return pytorch_finufft.functional.finufft_type1.apply( + points, values, (N, N + modifier), None, fftshift, dict(isign=isign) + ) + + assert gradcheck(func, inputs, atol=1e-5 * N) + + @pytest.mark.parametrize("N", Ns) @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [True, False]) From fd1f729338b4b40e0f7b2b6ead0e887154ce72f4 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 11:38:08 -0400 Subject: [PATCH 07/17] Minor clean up --- pytorch_finufft/functional.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 76e4e4d..3aaf62a 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1650,6 +1650,7 @@ def forward( points, values, output_shape ) # revisit these error checks to take into account the shape of points instead of passing them separately # ^ make sure these checks check for consistency between output shape and len(points) + # need device checks if finufftkwargs is None: finufftkwargs = dict() @@ -1676,9 +1677,6 @@ def forward( ndim = points.shape[0] assert len(output_shape) == ndim - # if _mode_ordering: - # values = torch.fft.ifftshift(values) - nufft_func = get_nufft_func(ndim, 1, points.device.type) finufft_out = nufft_func( *points, values, output_shape, isign=_i_sign, **finufftkwargs From b66a421c3366c9e6e0e9fec048e18e25dd9571bb Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 12:49:12 -0400 Subject: [PATCH 08/17] Factor out common test code --- tests/test_1d/test_forward_1d.py | 41 ++++------------- tests/test_2d/test_backward_2d.py | 74 ++++++++++--------------------- tests/test_2d/test_forward_2d.py | 40 ++++------------- tests/test_3d/test_forward_3d.py | 40 ++++------------- 4 files changed, 47 insertions(+), 148 deletions(-) diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py index 8dc6ff9..eec8a9b 100644 --- a/tests/test_1d/test_forward_1d.py +++ b/tests/test_1d/test_forward_1d.py @@ -105,17 +105,16 @@ def test_1d_t2_forward_CPU(targets: torch.Tensor): ) -@pytest.mark.parametrize("N", Ns) -def test_t1_forward_CPU(N: int) -> None: +def check_t1_forward(N: int, device: str) -> None: """ Tests against implementations of the FFT by setting up a uniform grid over which to call FINUFFT through the API. """ g = np.mgrid[:N] * 2 * np.pi / N g.shape = 1, -1 - points = torch.from_numpy(g.reshape(1, -1)) + points = torch.from_numpy(g.reshape(1, -1)).to(device) - values = torch.randn(*points[0].shape, dtype=torch.complex128) + values = torch.randn(*points[0].shape, dtype=torch.complex128).to(device) print("N is " + str(N)) print("shape of points is " + str(points.shape)) @@ -140,37 +139,13 @@ def test_t1_forward_CPU(N: int) -> None: @pytest.mark.parametrize("N", Ns) -def test_t1_forward_cuda(N: int) -> None: - """ - Tests against implementations of the FFT by setting up a uniform grid - over which to call FINUFFT through the API. - """ - g = np.mgrid[:N] * 2 * np.pi / N - g.shape = 1, -1 - points = torch.from_numpy(g.reshape(1, -1)).to("cuda") - - values = torch.randn(*points[0].shape, dtype=torch.complex128).to("cuda") - - print("N is " + str(N)) - print("shape of points is " + str(points.shape)) - print("shape of values is " + str(values.shape)) - - finufft_out = pytorch_finufft.functional.finufft_type1.apply( - points, - values, - (N,), - ) - - against_torch = torch.fft.fft(values.reshape(g[0].shape)) +def test_t1_forward_CPU(N: int) -> None: + check_t1_forward(N, "cpu") - abs_errors = torch.abs(finufft_out - against_torch) - l_inf_error = abs_errors.max() - l_2_error = torch.sqrt(torch.sum(abs_errors**2)) - l_1_error = torch.sum(abs_errors) - assert l_inf_error < 4.5e-5 * N - assert l_2_error < 1e-5 * N**2 - assert l_1_error < 1e-5 * N**3 +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_cuda(N: int) -> None: + check_t1_forward(N, "cuda") # @pytest.mark.parametrize("values", cases) diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index 0cf7f71..751f364 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -100,18 +100,19 @@ def test_t1_backward_CPU_values( assert gradcheck(apply_finufft2d1(modifier, fftshift, isign), inputs) -@pytest.mark.parametrize("N", Ns) -@pytest.mark.parametrize("modifier", length_modifiers) -@pytest.mark.parametrize("fftshift", [False, True]) -@pytest.mark.parametrize("isign", [-1, 1]) -def test_t1_consolidated_backward_CPU_values( - N: int, modifier: int, fftshift: bool, isign: int +def check_t1_backward( + N: int, + modifier: int, + fftshift: bool, + isign: int, + device: str, + points_or_values: bool, ) -> None: - points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi - values = torch.randn(N, dtype=torch.complex128) + points = torch.rand((2, N), dtype=torch.float64).to(device) * 2 * np.pi + values = torch.randn(N, dtype=torch.complex128).to(device) - points.requires_grad = False - values.requires_grad = True + points.requires_grad = points_or_values + values.requires_grad = not points_or_values inputs = (points, values) @@ -120,7 +121,7 @@ def func(points, values): points, values, (N, N + modifier), None, fftshift, dict(isign=isign) ) - assert gradcheck(func, inputs) + assert gradcheck(func, inputs, atol=1e-5 * N) @pytest.mark.parametrize("N", Ns) @@ -130,20 +131,17 @@ def func(points, values): def test_t1_consolidated_backward_CPU_points( N: int, modifier: int, fftshift: bool, isign: int ) -> None: - points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi - values = torch.randn(N, dtype=torch.complex128) - - points.requires_grad = True - values.requires_grad = False + check_t1_backward(N, modifier, fftshift, isign, "cpu", True) - inputs = (points, values) - - def func(points, values): - return pytorch_finufft.functional.finufft_type1.apply( - points, values, (N, N + modifier), None, fftshift, dict(isign=isign) - ) - assert gradcheck(func, inputs, atol=1e-5 * N) +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t1_consolidated_backward_CPU_values( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t1_backward(N, modifier, fftshift, isign, "cpu", False) @pytest.mark.parametrize("N", Ns) @@ -153,20 +151,7 @@ def func(points, values): def test_t1_consolidated_backward_cuda_values( N: int, modifier: int, fftshift: bool, isign: int ) -> None: - points = torch.rand((2, N), dtype=torch.float64).to("cuda") * 2 * np.pi - values = torch.randn(N, dtype=torch.complex128).to("cuda") - - points.requires_grad = False - values.requires_grad = True - - inputs = (points, values) - - def func(points, values): - return pytorch_finufft.functional.finufft_type1.apply( - points, values, (N, N + modifier), None, fftshift, dict(isign=isign) - ) - - assert gradcheck(func, inputs) + check_t1_backward(N, modifier, fftshift, isign, "cuda", False) @pytest.mark.parametrize("N", Ns) @@ -176,20 +161,7 @@ def func(points, values): def test_t1_consolidated_backward_cuda_points( N: int, modifier: int, fftshift: bool, isign: int ) -> None: - points = torch.rand((2, N), dtype=torch.float64).to("cuda") * 2 * np.pi - values = torch.randn(N, dtype=torch.complex128).to("cuda") - - points.requires_grad = True - values.requires_grad = False - - inputs = (points, values) - - def func(points, values): - return pytorch_finufft.functional.finufft_type1.apply( - points, values, (N, N + modifier), None, fftshift, dict(isign=isign) - ) - - assert gradcheck(func, inputs, atol=1e-5 * N) + check_t1_backward(N, modifier, fftshift, isign, "cuda", True) @pytest.mark.parametrize("N", Ns) diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index 3008f8a..c24d4d0 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -123,16 +123,15 @@ def test_2d_t2_forward_CPU(N: int) -> None: # pass -@pytest.mark.parametrize("N", Ns) -def test_t1_forward_CPU(N: int) -> None: +def check_t1_forward(N: int, device: str) -> None: """ Tests against implementations of the FFT by setting up a uniform grid over which to call FINUFFT through the API. """ g = np.mgrid[:N, :N] * 2 * np.pi / N - points = torch.from_numpy(g.reshape(2, -1)) + points = torch.from_numpy(g.reshape(2, -1)).to(device) - values = torch.randn(*points[0].shape, dtype=torch.complex128) + values = torch.randn(*points[0].shape, dtype=torch.complex128).to(device) print("N is " + str(N)) print("shape of points is " + str(points.shape)) @@ -157,33 +156,10 @@ def test_t1_forward_CPU(N: int) -> None: @pytest.mark.parametrize("N", Ns) -def test_t1_forward_cuda(N: int) -> None: - """ - Tests against implementations of the FFT by setting up a uniform grid - over which to call FINUFFT through the API. - """ - g = np.mgrid[:N, :N] * 2 * np.pi / N - points = torch.from_numpy(g.reshape(2, -1)).to("cuda") - - values = torch.randn(*points[0].shape, dtype=torch.complex128).to("cuda") - - print("N is " + str(N)) - print("shape of points is " + str(points.shape)) - print("shape of values is " + str(values.shape)) - - finufft_out = pytorch_finufft.functional.finufft_type1.apply( - points, - values, - (N, N), - ) +def test_t1_forward_CPU(N: int) -> None: + check_t1_forward(N, "cpu") - against_torch = torch.fft.fft2(values.reshape(g[0].shape)) - abs_errors = torch.abs(finufft_out - against_torch) - l_inf_error = abs_errors.max() - l_2_error = torch.sqrt(torch.sum(abs_errors**2)) - l_1_error = torch.sum(abs_errors) - - assert l_inf_error < 4.5e-5 * N - assert l_2_error < 1e-5 * N**2 - assert l_1_error < 1e-5 * N**3 +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_cuda(N: int) -> None: + check_t1_forward(N, "cuda") diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py index dd7c2ed..bc148b1 100644 --- a/tests/test_3d/test_forward_3d.py +++ b/tests/test_3d/test_forward_3d.py @@ -84,16 +84,15 @@ def test_3d_t2_forward_CPU(N: int) -> None: assert l_1_error < 1e-5 * N**4.5 -@pytest.mark.parametrize("N", Ns) -def test_t1_forward_CPU(N: int) -> None: +def check_t1_forward(N: int, device: str) -> None: """ Tests against implementations of the FFT by setting up a uniform grid over which to call FINUFFT through the API. """ g = np.mgrid[:N, :N, :N] * 2 * np.pi / N - points = torch.from_numpy(g.reshape(3, -1)) + points = torch.from_numpy(g.reshape(3, -1)).to(device) - values = torch.randn(*points[0].shape, dtype=torch.complex128) + values = torch.randn(*points[0].shape, dtype=torch.complex128).to(device) print("N is " + str(N)) print("shape of points is " + str(points.shape)) @@ -118,33 +117,10 @@ def test_t1_forward_CPU(N: int) -> None: @pytest.mark.parametrize("N", Ns) -def test_t1_forward_cuda(N: int) -> None: - """ - Tests against implementations of the FFT by setting up a uniform grid - over which to call FINUFFT through the API. - """ - g = np.mgrid[:N, :N, :N] * 2 * np.pi / N - points = torch.from_numpy(g.reshape(3, -1)).to("cuda") - - values = torch.randn(*points[0].shape, dtype=torch.complex128).to("cuda") - - print("N is " + str(N)) - print("shape of points is " + str(points.shape)) - print("shape of values is " + str(values.shape)) - - finufft_out = pytorch_finufft.functional.finufft_type1.apply( - points, - values, - (N, N, N), - ) +def test_t1_forward_CPU(N: int) -> None: + check_t1_forward(N, "cpu") - against_torch = torch.fft.fftn(values.reshape(g[0].shape)) - abs_errors = torch.abs(finufft_out - against_torch) - l_inf_error = abs_errors.max() - l_2_error = torch.sqrt(torch.sum(abs_errors**2)) - l_1_error = torch.sum(abs_errors) - - assert l_inf_error < 1.5e-5 * N**1.5 - assert l_2_error < 1e-5 * N**3 - assert l_1_error < 1e-5 * N**4.5 +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_cuda(N: int) -> None: + check_t1_forward(N, "cuda") From 052cf4144d50ffe425d75046e8847ed37ce8aaaa Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 12:51:32 -0400 Subject: [PATCH 09/17] Formatting --- pytorch_finufft/functional.py | 2 +- tests/test_2d/test_backward_2d.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 3aaf62a..3f410ab 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -4,8 +4,8 @@ from typing import Any, Dict, Optional, Tuple, Union -import numpy as np import finufft +import numpy as np try: import cufinufft diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index 751f364..2aba3da 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -1,3 +1,5 @@ +from functools import partial + import numpy as np import pytest import torch @@ -5,8 +7,6 @@ import pytorch_finufft -from functools import partial - torch.set_default_tensor_type(torch.DoubleTensor) torch.set_default_dtype(torch.float64) torch.manual_seed(0) From 8d8f7485016114ae79ef36d38df8fbe894d8ae76 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 14:10:18 -0400 Subject: [PATCH 10/17] Consolidate yum installs --- ci/docker/Dockerfile-cuda11.8 | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/ci/docker/Dockerfile-cuda11.8 b/ci/docker/Dockerfile-cuda11.8 index dd0a5c2..ce622f7 100644 --- a/ci/docker/Dockerfile-cuda11.8 +++ b/ci/docker/Dockerfile-cuda11.8 @@ -20,8 +20,7 @@ COPY ci/docker/cuda.repo /etc/yum.repos.d/cuda.repo RUN yum install -y \ cuda-cudart-${CUDA_DASH_VERSION} \ cuda-compat-${CUDA_DASH_VERSION} && \ - ln -s cuda-${CUDA_DOT_VERSION} /usr/local/cuda && \ - rm -rf /var/cache/yum/* + ln -s cuda-${CUDA_DOT_VERSION} /usr/local/cuda # nvidia-docker 1.0 RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \ @@ -38,16 +37,11 @@ ENV NVIDIA_REQUIRE_CUDA "cuda>=${CUDA_DOT_VERSION} brand=tesla,driver>=418,drive # runtime RUN yum install -y \ cuda-libraries-${CUDA_DASH_VERSION} \ - cuda-nvtx-${CUDA_DASH_VERSION} && \ - rm -rf /var/cache/yum/* - -# devel -RUN yum install -y \ + cuda-nvtx-${CUDA_DASH_VERSION} \ cuda-cudart-devel-${CUDA_DASH_VERSION} \ cuda-libraries-devel-${CUDA_DASH_VERSION} \ cuda-nvprof-${CUDA_DASH_VERSION} \ - cuda-nvcc-${CUDA_DASH_VERSION} && \ - rm -rf /var/cache/yum/* + cuda-nvcc-${CUDA_DASH_VERSION} ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs @@ -56,14 +50,9 @@ ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs # CUDA 11 doesn't work on gcc/g++ newer than v9 RUN yum install -y \ devtoolset-9-gcc \ - devtoolset-9-gcc-c++ && \ + devtoolset-9-gcc-c++ \ + cmake && \ rm -rf /var/cache/yum/* ENV PATH /opt/rh/devtoolset-9/root/usr/bin:${PATH} -# finufft reqs -RUN yum install -y \ - cmake && \ - rm -rf /var/cache/yum/* - -RUN From ed5c4e543695c2218087382c7db48091dd8d1517 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 14:24:09 -0400 Subject: [PATCH 11/17] CI: Jenkinsfile work --- Jenkinsfile | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 09ee115..a4f55d5 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -6,7 +6,7 @@ pipeline { timeout(time: 1, unit: 'HOURS') } stages { - stage('main') { + stage('CUDA Tests') { agent { dockerfile { filename 'ci/docker/Dockerfile-cuda11.8' @@ -23,16 +23,18 @@ pipeline { steps { // TODO - reconsider install strategy once finufft/cufinufft 2.2 is released - checkout([$class: 'GitSCM', - branches: [[name: '*/master']], - userRemoteConfigs: [[url: "https://github.com/flatironinstitute/finufft"]]] - ) + checkout scmGit(branches: [[name: '*/master']], + extensions: [cloneOption(noTags: true, reference: '', shallow: true), + [$class: 'RelativeTargetDirectory', relativeTargetDir: 'finufft'], + cleanAfterCheckout()], + userRemoteConfigs: [[url: 'https://github.com/flatironinstitute/finufft']]) sh '''#!/bin/bash -ex nvidia-smi ''' sh '''#!/bin/bash -ex echo $HOME + ls ''' sh '''#!/bin/bash -ex cd finufft From 549224367bdc6990f8a5c3715f945922ccd20ed1 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 14:25:47 -0400 Subject: [PATCH 12/17] CI: Fix email --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index a4f55d5..ed6a0a8 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -86,7 +86,7 @@ ${BUILD_LOG,maxLines=200} [$class: 'DevelopersRecipientProvider'], ], replyTo: '$DEFAULT_REPLYTO', - to: 'janden@flatironinstitute.org' + to: 'bward@flatironinstitute.org' } } } From 41cab364566a299483751034ba485366b592f6d8 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 14:31:10 -0400 Subject: [PATCH 13/17] Lint fixes --- pytorch_finufft/functional.py | 17 ++++++++++------- tests/test_2d/test_backward_2d.py | 2 -- tests/test_2d/test_forward_2d.py | 3 ++- tests/test_3d/test_forward_3d.py | 3 ++- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 3f410ab..07d964d 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -11,7 +11,7 @@ import cufinufft CUFINUFFT_AVAIL = True -except: +except ImportError: CUFINUFFT_AVAIL = False import torch @@ -1646,11 +1646,13 @@ def forward( # All this requires is a check on the out array to make sure it is the # correct shape. - err._type1_checks( - points, values, output_shape - ) # revisit these error checks to take into account the shape of points instead of passing them separately - # ^ make sure these checks check for consistency between output shape and len(points) - # need device checks + # TODO: + # revisit these error checks to take into account the shape of points + # instead of passing them separately + # make sure these checks check for consistency between output shape and + # len(points) + # Also need device checks + err._type1_checks(points, values, output_shape) if finufftkwargs is None: finufftkwargs = dict() @@ -1663,7 +1665,8 @@ def forward( # to note instead that there is a conflict in fftshift if _mode_ordering != 1: raise ValueError( - "Double specification of ordering; only one of fftshift and modeord should be provided" + "Double specification of ordering; only one of fftshift and " + "modeord should be provided" ) _mode_ordering = 0 diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index 2aba3da..ddde4cf 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -1,5 +1,3 @@ -from functools import partial - import numpy as np import pytest import torch diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index c24d4d0..0ae6d85 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -2,9 +2,10 @@ import pytest import torch +import pytorch_finufft + torch.manual_seed(0) -import pytorch_finufft # Case generation Ns = [ diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py index bc148b1..524e9a6 100644 --- a/tests/test_3d/test_forward_3d.py +++ b/tests/test_3d/test_forward_3d.py @@ -2,9 +2,10 @@ import pytest import torch +import pytorch_finufft + torch.manual_seed(0) -import pytorch_finufft # Case generation Ns = [ From 2a4f17d45a3be3bc5d078ec934072afc212664ad Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 16:40:59 -0400 Subject: [PATCH 14/17] Try k40s --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index ed6a0a8..f7cab71 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -11,7 +11,7 @@ pipeline { dockerfile { filename 'ci/docker/Dockerfile-cuda11.8' args '--gpus 2' - label 'v100' + label 'docker && gpu' } } environment { From 80a0dcc5bda31ece774889001903d4f5fa1a75e6 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 17:49:39 -0400 Subject: [PATCH 15/17] Back to v100 --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index f7cab71..2eea4b9 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -11,7 +11,7 @@ pipeline { dockerfile { filename 'ci/docker/Dockerfile-cuda11.8' args '--gpus 2' - label 'docker && gpu' + label 'docker && v100' } } environment { From 03c79d626fb0a94ff98a04cfc563627bd0d0d101 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 17:55:09 -0400 Subject: [PATCH 16/17] CI: Build finufft as well as cufinufft --- Jenkinsfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 2eea4b9..5064921 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -42,8 +42,7 @@ pipeline { cuda_arch="70" cmake -B build . -DFINUFFT_USE_CUDA=ON \ - -DFINUFFT_USE_CPU=OFF \ - -DFINUFFT_BUILD_TESTS=ON \ + -DFINUFFT_BUILD_TESTS=OFF \ -DCMAKE_CUDA_ARCHITECTURES="$cuda_arch" \ -DBUILD_TESTING=ON cd build From f43192001bc173ff30e2ef0e8f45548579bee4da Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 6 Oct 2023 18:20:20 -0400 Subject: [PATCH 17/17] CI: tweaks --- Jenkinsfile | 6 ++--- pytorch_finufft/functional.py | 42 ++++++++++++++++++++++------------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 5064921..c030e39 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -16,7 +16,7 @@ pipeline { } environment { HOME = "$WORKSPACE" - PYBIN = "/opt/python/cp38-cp38/bin" + PYBIN = "/opt/python/cp39-cp39/bin" LIBRARY_PATH = "$WORKSPACE/finufft/build" LD_LIBRARY_PATH = "$WORKSPACE/finufft/build" } @@ -42,6 +42,7 @@ pipeline { cuda_arch="70" cmake -B build . -DFINUFFT_USE_CUDA=ON \ + -DFINUFFT_USE_CPU=OFF \ -DFINUFFT_BUILD_TESTS=OFF \ -DCMAKE_CUDA_ARCHITECTURES="$cuda_arch" \ -DBUILD_TESTING=ON @@ -56,11 +57,10 @@ pipeline { # we could also move pytorch install inside docker python3 -m pip install "torch~=2.1.0" --index-url https://download.pytorch.org/whl/cu118 python3 -m pip install finufft/python/cufinufft - python3 -m pip install finufft/python/finufft python3 -m pip install -e .[dev] - python3 -m pytest -k "cuda" tests/ --cov + python3 -m pytest -k "cuda" tests/ --cov -v ''' } } diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 07d964d..c6fde87 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -4,8 +4,15 @@ from typing import Any, Dict, Optional, Tuple, Union -import finufft import numpy as np +import torch + +try: + import finufft + + FINUFFT_AVAIL = True +except ImportError: + FINUFFT_AVAIL = False try: import cufinufft @@ -13,7 +20,12 @@ CUFINUFFT_AVAIL = True except ImportError: CUFINUFFT_AVAIL = False -import torch + +if not (FINUFFT_AVAIL or CUFINUFFT_AVAIL): + raise ImportError( + "No FINUFFT implementation available. " + "Install either finufft or cufinufft and ensure they are importable." + ) import pytorch_finufft._err as err @@ -1631,7 +1643,7 @@ def forward( ctx: Any, points: torch.Tensor, values: torch.Tensor, - output_shape: Union[int, tuple[int, int], tuple[int, int, int]], + output_shape: Union[int, Tuple[int, int], Tuple[int, int, int]], out: Optional[torch.Tensor] = None, fftshift: bool = False, finufftkwargs: dict[str, Union[int, float]] = None, @@ -1693,21 +1705,21 @@ def forward( @staticmethod def backward( ctx: Any, grad_output: torch.Tensor - ) -> tuple[Union[torch.Tensor, None], ...]: + ) -> Tuple[Union[torch.Tensor, None], ...]: """ - Implements derivatives wrt. each argument in the forward method. + Implements derivatives wrt. each argument in the forward method. - Parameters - ---------- - ctx : Any - Pytorch context object. - grad_output : torch.Tensor - Backpass gradient output + Parameters + ---------- + ctx : Any + Pytorch context object. + grad_output : torch.Tensor + Backpass gradient output - Returns - ------- - tuple[Union[torch.Tensor, None], ...] - A tuple of derivatives wrt. each argument in the forward method + Returns + ------- + Tuple[Union[torch.Tensor, None], ...] + A tuple of derivatives wrt. each argument in the forward method """ _i_sign = -1 * ctx.isign _mode_ordering = ctx.mode_ordering