Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Starting to add GPU support using cuFINUFFT #4

Closed
wants to merge 13 commits into from
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ build
_skbuild
dist
MANIFEST
__pycache__
*.egg-info
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "finufft"]
path = vendor/finufft
url = https://github.com/flatironinstitute/finufft
[submodule "vendor/cufinufft"]
path = vendor/cufinufft
url = https://github.com/flatironinstitute/cufinufft
97 changes: 86 additions & 11 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
cmake_minimum_required(VERSION 3.12)
project(jax_finufft LANGUAGES C CXX)
message(STATUS "Using CMake version: " ${CMAKE_VERSION})

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR})
# Add the /cmake directory to the module path so that we can find FFTW
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR}/cmake)

# Handle Python settings passed from scikit-build
if(SKBUILD)
set(Python_EXECUTABLE "${PYTHON_EXECUTABLE}")
set(Python_INCLUDE_DIR "${PYTHON_INCLUDE_DIR}")
Expand All @@ -13,22 +14,29 @@ if(SKBUILD)
OUTPUT_VARIABLE _tmp_dir
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
else()
find_package(Python COMPONENTS Interpreter Development REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
OUTPUT_VARIABLE _tmp_dir
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
endif()

# Find and link pybind11 and fftw
find_package(pybind11 CONFIG REQUIRED)
find_package(FFTW REQUIRED COMPONENTS FLOAT_LIB DOUBLE_LIB)
link_libraries(${FFTW_FLOAT_LIB} ${FFTW_DOUBLE_LIB})

# Work out compiler flags
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
add_compile_options(-Wall -O3 -funroll-loops)

add_compile_options(-Wall -Wno-unknown-pragmas -O3 -funroll-loops)
set(FINUFFT_INCLUDE_DIRS
${CMAKE_CURRENT_LIST_DIR}/lib
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include
${FFTW_INCLUDE_DIRS})

message(STATUS "FINUFFT include dirs: " "${FINUFFT_INCLUDE_DIRS}")

# Build single and double point versions of the FINUFFT library
add_library(finufft STATIC
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/spreadinterp.cpp
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/utils.cpp
Expand All @@ -44,10 +52,77 @@ add_library(finufft_32 STATIC
target_compile_definitions(finufft_32 PUBLIC SINGLE)
target_include_directories(finufft_32 PRIVATE ${FINUFFT_INCLUDE_DIRS})

pybind11_add_module(jax_finufft
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft.cc
# Build the XLA bindings to those libraries
pybind11_add_module(jax_finufft_cpu
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cpu.cc
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/utils_precindep.cpp
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib/legendre_rule_fast.c)
target_link_libraries(jax_finufft PRIVATE finufft finufft_32)
target_include_directories(jax_finufft PRIVATE ${FINUFFT_INCLUDE_DIRS})
install(TARGETS jax_finufft DESTINATION .)
target_link_libraries(jax_finufft_cpu PRIVATE finufft finufft_32)
target_include_directories(jax_finufft_cpu PRIVATE ${FINUFFT_INCLUDE_DIRS})
install(TARGETS jax_finufft_cpu DESTINATION .)

# GPU support: build the cuFINUFFT libraries and the CUDA XLA bindings when a
# CUDA compiler is available; otherwise skip them with a status message.
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
  enable_language(CUDA)
  set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)

  # Default to a broad set of architectures unless the user specified their own
  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
    set(CMAKE_CUDA_ARCHITECTURES "52;60;61;70;75;80")
  endif()

  # cuFFT and nvToolsExt ship with the CUDA toolkit. REQUIRED because the GPU
  # bindings below cannot link without them; FindCUDAToolkit also provides the
  # CUDA::cufft / CUDA::nvToolsExt imported targets used below.
  find_package(CUDAToolkit REQUIRED)

  set(CUFINUFFT_INCLUDE_DIRS
    ${CMAKE_CURRENT_LIST_DIR}/lib
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/include
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/cuda_samples
    ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

  # Precision-dependent sources: compiled twice below, once per precision
  # (cufinufft_64 for double, cufinufft_32 for float with -DSINGLE)
  set(CUFINUFFT_SOURCES
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/2d/spreadinterp2d.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/2d/spread2d_wrapper.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/2d/spread2d_wrapper_paul.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/2d/interp2d_wrapper.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/2d/cufinufft2d.cu

    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/3d/spreadinterp3d.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/3d/spread3d_wrapper.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/3d/interp3d_wrapper.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/3d/cufinufft3d.cu

    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/memtransfer_wrapper.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/deconvolve_wrapper.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/cufinufft.cu

    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/dirft2d.cpp
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/common.cpp
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/spreadinterp.cpp
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/utils_fp.cpp)

  # Precision-independent pieces of cuFINUFFT, compiled only once
  add_library(cufinufft STATIC
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/precision_independent.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/profile.cu
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/legendre_rule_fast.c
    ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/utils.cpp)
  target_include_directories(cufinufft PRIVATE ${CUFINUFFT_INCLUDE_DIRS})

  # Double-precision build
  add_library(cufinufft_64 STATIC ${CUFINUFFT_SOURCES})
  target_include_directories(cufinufft_64 PRIVATE ${CUFINUFFT_INCLUDE_DIRS})

  # Single-precision build (SINGLE switches the cuFINUFFT precision macros)
  add_library(cufinufft_32 STATIC ${CUFINUFFT_SOURCES})
  target_compile_definitions(cufinufft_32 PUBLIC SINGLE)
  target_include_directories(cufinufft_32 PRIVATE ${CUFINUFFT_INCLUDE_DIRS})

  # The XLA bindings to those libraries
  pybind11_add_module(jax_finufft_gpu
    ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
    ${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
  target_link_libraries(jax_finufft_gpu PRIVATE cufinufft cufinufft_32 cufinufft_64)
  # Use the imported targets provided by FindCUDAToolkit. The legacy
  # CUDA_cufft_LIBRARY / CUDA_nvToolsExt_LIBRARY variables are only populated
  # by the deprecated FindCUDA module and would be empty here, silently
  # dropping the cuFFT/NVTX link dependencies.
  target_link_libraries(jax_finufft_gpu PRIVATE CUDA::cufft CUDA::nvToolsExt)
  target_include_directories(jax_finufft_gpu PRIVATE ${CUFINUFFT_INCLUDE_DIRS})
  install(TARGETS jax_finufft_gpu DESTINATION .)
else()
  message(STATUS "No CUDA compiler found; GPU support will be disabled")
endif()
File renamed without changes.
21 changes: 21 additions & 0 deletions lib/jax_finufft_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef _JAX_FINUFFT_COMMON_H_
#define _JAX_FINUFFT_COMMON_H_

// This descriptor is common to both the CPU and GPU extension modules.
// We will use the jax_finufft namespace for both.
// NOTE(review): identifiers beginning with an underscore followed by an
// uppercase letter (like this include guard) are reserved in C++; consider
// renaming the guard, e.g. JAX_FINUFFT_COMMON_H_.

namespace jax_finufft {

// Problem description for a single NUFFT call. It is packed into an opaque
// byte string on the Python side and unpacked in the XLA custom call (via
// kernel_helpers.h), so the field order and sizes here are part of the wire
// format — do not reorder or resize fields without updating both sides.
template <typename T>
struct NufftDescriptor {
T eps;  // requested tolerance, forwarded to (cu)finufft makeplan
int iflag;  // sign of the exponent in the transform (+1/-1 convention — see finufft docs)
int64_t n_tot;  // presumably the total number of stacked problems — TODO confirm against caller
int n_transf;  // number of transforms per plan (batch size passed to makeplan)
int64_t n_j;  // number of non-uniform points
int64_t n_k[3];  // number of modes per dimension (unused trailing entries for <3D)
};

}  // namespace jax_finufft

#endif
4 changes: 3 additions & 1 deletion lib/jax_finufft.cc → lib/jax_finufft_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "pybind11_kernel_helpers.h"

#include "jax_finufft_cpu.h"

using namespace jax_finufft;

namespace {
Expand Down Expand Up @@ -85,7 +87,7 @@ pybind11::dict Registrations() {
return dict;
}

PYBIND11_MODULE(jax_finufft, m) {
PYBIND11_MODULE(jax_finufft_cpu, m) {
m.def("registrations", &Registrations);
m.def("build_descriptorf", &build_descriptor<float>);
m.def("build_descriptor", &build_descriptor<double>);
Expand Down
10 changes: 0 additions & 10 deletions lib/jax_finufft.h → lib/jax_finufft_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,6 @@

namespace jax_finufft {

template <typename T>
struct NufftDescriptor {
T eps;
int iflag;
int64_t n_tot;
int n_transf;
int64_t n_j;
int64_t n_k[3];
};

template <typename T>
struct plan_type;

Expand Down
38 changes: 38 additions & 0 deletions lib/jax_finufft_gpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// This file defines the Python interface to the XLA custom call implemented on the GPU.
// (Fixed: the comment previously said "CPU", copied from the CPU module.)
// It is exposed as a standard pybind11 module defining "capsule" objects containing our
// method. For simplicity, we export a separate capsule for each supported dtype.

#include "pybind11_kernel_helpers.h"
#include "jax_finufft_gpu.h"
#include "kernels.h"

using namespace jax_finufft;

namespace {

// Builds the dict mapping XLA custom-call target names to PyCapsules wrapping
// the CUDA kernel entry points declared in kernels.h. The Python side passes
// this dict to XLA when registering the GPU ops.
pybind11::dict Registrations() {
pybind11::dict dict;

// TODO: do we prefer to keep these names the same as the CPU version or prefix them with "cu"?
// 1D transforms are not yet wired up on the GPU, hence the commented entries:
// dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
// dict["nufft1d2f"] = encapsulate_function(nufft1d2f);
dict["nufft2d1f"] = encapsulate_function(nufft2d1f);
dict["nufft2d2f"] = encapsulate_function(nufft2d2f);
dict["nufft3d1f"] = encapsulate_function(nufft3d1f);
dict["nufft3d2f"] = encapsulate_function(nufft3d2f);

// dict["nufft1d1"] = encapsulate_function(nufft1d1);
// dict["nufft1d2"] = encapsulate_function(nufft1d2);
dict["nufft2d1"] = encapsulate_function(nufft2d1);
dict["nufft2d2"] = encapsulate_function(nufft2d2);
dict["nufft3d1"] = encapsulate_function(nufft3d1);
dict["nufft3d2"] = encapsulate_function(nufft3d2);

return dict;
}

PYBIND11_MODULE(jax_finufft_gpu, m) {
m.def("registrations", &Registrations);
}

}  // namespace
139 changes: 139 additions & 0 deletions lib/jax_finufft_gpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#ifndef _JAX_FINUFFT_GPU_H_
#define _JAX_FINUFFT_GPU_H_

#include <complex>

#include "cufinufft.h"

namespace jax_finufft {

// Maps a floating-point precision onto the matching cufinufft plan handle
// type; only float and double are supported.
template <typename T>
struct plan_type;

// Double precision uses the standard cufinufft plan handle.
template <>
struct plan_type<double> {
  using type = cufinufft_plan;
};

// Single precision uses the "f"-suffixed cufinufft plan handle.
template <>
struct plan_type<float> {
  using type = cufinufftf_plan;
};

// Precision-dispatched wrapper around cufinufft(f)_default_opts, filling
// *opts with the library defaults for the given transform type and dimension.
template <typename T>
void default_opts(int type, int dim, cufinufft_opts* opts);

// NOTE: explicit specializations of function templates are NOT implicitly
// inline, so "inline" is required here to avoid ODR violations / duplicate
// symbols when this header is included from more than one translation unit
// (jax_finufft_gpu.cc and kernels.cc.cu).
template <>
inline void default_opts<float>(int type, int dim, cufinufft_opts* opts) {
  cufinufftf_default_opts(type, dim, opts);
}

template <>
inline void default_opts<double>(int type, int dim, cufinufft_opts* opts) {
  cufinufft_default_opts(type, dim, opts);

  // double precision in 3D blows out shared memory.
  // Fall back to a slower, non-shared memory algorithm
  // https://github.com/flatironinstitute/cufinufft/issues/58
  if (dim > 2) {
    opts->gpu_method = 1;
  }
}

// Precision-dispatched wrapper around cufinufft(f)_makeplan. Returns the
// cufinufft status code (0 on success); *plan receives the new plan handle.
template <typename T>
int makeplan(int type, int dim, int* nmodes, int iflag, int ntr, T eps, int batch,
             typename plan_type<T>::type* plan, cufinufft_opts* opts);

// "inline" is required: explicit specializations defined in a header are not
// implicitly inline and would otherwise produce duplicate-symbol link errors
// when this header is included from multiple translation units.
template <>
inline int makeplan<float>(int type, int dim, int* nmodes, int iflag, int ntr, float eps,
                           int batch, typename plan_type<float>::type* plan,
                           cufinufft_opts* opts) {
  return cufinufftf_makeplan(type, dim, nmodes, iflag, ntr, eps, batch, plan, opts);
}

template <>
inline int makeplan<double>(int type, int dim, int* nmodes, int iflag, int ntr, double eps,
                            int batch, typename plan_type<double>::type* plan,
                            cufinufft_opts* opts) {
  return cufinufft_makeplan(type, dim, nmodes, iflag, ntr, eps, batch, plan, opts);
}

// Precision-dispatched wrapper around cufinufft(f)_setpts. Note the argument
// order: this wrapper takes the plan first, while the cufinufft C API takes
// it last. y and z may be NULL for lower-dimensional transforms (see
// y_index/z_index below). Returns the cufinufft status code.
template <typename T>
int setpts(typename plan_type<T>::type plan, int64_t M, T* x, T* y, T* z, int64_t N, T* s, T* t,
           T* u);

// "inline" avoids ODR violations: explicit specializations defined in a
// header are not implicitly inline.
template <>
inline int setpts<float>(typename plan_type<float>::type plan, int64_t M, float* x, float* y,
                         float* z, int64_t N, float* s, float* t, float* u) {
  return cufinufftf_setpts(M, x, y, z, N, s, t, u, plan);
}

template <>
inline int setpts<double>(typename plan_type<double>::type plan, int64_t M, double* x, double* y,
                          double* z, int64_t N, double* s, double* t, double* u) {
  return cufinufft_setpts(M, x, y, z, N, s, t, u, plan);
}

// Precision-dispatched wrapper around cufinufft(f)_execute, running the
// planned transform between the complex strengths c and modes f. Returns the
// cufinufft status code.
template <typename T>
int execute(typename plan_type<T>::type plan, std::complex<T>* c, std::complex<T>* f);

// std::complex<T> is guaranteed layout-compatible with T[2] (C++11
// [complex.numbers]), matching cuFloatComplex/cuDoubleComplex, so the
// reinterpret_casts below are well-defined. "inline" is required because
// explicit specializations defined in a header are not implicitly inline
// (avoids ODR violations across translation units).
template <>
inline int execute<float>(typename plan_type<float>::type plan, std::complex<float>* c,
                          std::complex<float>* f) {
  cuFloatComplex* _c = reinterpret_cast<cuFloatComplex*>(c);
  cuFloatComplex* _f = reinterpret_cast<cuFloatComplex*>(f);
  return cufinufftf_execute(_c, _f, plan);
}

template <>
inline int execute<double>(typename plan_type<double>::type plan, std::complex<double>* c,
                           std::complex<double>* f) {
  cuDoubleComplex* _c = reinterpret_cast<cuDoubleComplex*>(c);
  cuDoubleComplex* _f = reinterpret_cast<cuDoubleComplex*>(f);
  return cufinufft_execute(_c, _f, plan);
}

// Precision-dispatched wrapper around cufinufft(f)_destroy, releasing the
// resources owned by the plan. The cufinufft status code is discarded, which
// matches the original interface (callers cannot observe destroy failures).
template <typename T>
void destroy(typename plan_type<T>::type plan);

// "inline" is required: explicit specializations defined in a header are not
// implicitly inline and would otherwise violate the ODR when this header is
// included from multiple translation units.
template <>
inline void destroy<float>(typename plan_type<float>::type plan) {
  cufinufftf_destroy(plan);
}

template <>
inline void destroy<double>(typename plan_type<double>::type plan) {
  cufinufft_destroy(plan);
}

template <int ndim, typename T>
T* y_index(T* y, int64_t index) {
return &(y[index]);
}

template <>
double* y_index<1, double>(double* y, int64_t index) {
return NULL;
}

template <>
float* y_index<1, float>(float* y, int64_t index) {
return NULL;
}

template <int ndim, typename T>
T* z_index(T* z, int64_t index) {
return NULL;
}

template <>
double* z_index<3, double>(double* z, int64_t index) {
return &(z[index]);
}

template <>
float* z_index<3, float>(float* z, int64_t index) {
return &(z[index]);
}

} // namespace jax_finufft

#endif
2 changes: 2 additions & 0 deletions lib/kernel_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <string>
#include <type_traits>

#include "jax_finufft_common.h"

namespace jax_finufft {

// https://en.cppreference.com/w/cpp/numeric/bit_cast
Expand Down
Loading