Update vendored finufft and add GPU support #20

Merged: 28 commits, Oct 30, 2023
Changes from 25 commits

Commits
8a4e748  starting to add optional cuda support (dfm, Nov 8, 2021)
2e827c0  include dirs for cuda (dfm, Nov 8, 2021)
e50c233  Merge branch 'main' of https://github.com/dfm/jax-finufft into gpu (dfm, Nov 9, 2021)
94c5bbf  getting cufinufft to compile (dfm, Nov 9, 2021)
974b0e9  adding first pass at gpu kernels (dfm, Nov 10, 2021)
9ffc006  order of parameters (dfm, Nov 10, 2021)
9473643  Minor refactoring to support GPU (lgarrison, Nov 12, 2021)
1b12da0  Maybe sort-of calling all the right functions? (lgarrison, Nov 12, 2021)
fe79e97  Add FindCUDAToolkit to cmake to bring in cufft (lgarrison, Nov 15, 2021)
1295a1c  Trying to hook up Jax CUDA ops (lgarrison, Nov 15, 2021)
93ba4b3  Don't fail on no CUDA (lgarrison, Nov 15, 2021)
65f399c  first pass at getting GPU ops to work (dfm, Nov 16, 2021)
1d94e9e  Fix GPU tests (lgarrison, Nov 18, 2021)
b4c7dea  vendor: update vendored finufft version to latest and fix deprecations (lgarrison, Oct 11, 2023)
b195154  Merge branch 'gpu' into 2023-gpu (lgarrison, Oct 16, 2023)
23e5e1b  gpu: use new cufinufft API and change CMake to reflect the fact that … (lgarrison, Oct 19, 2023)
d2ae165  xla: uppercase CUDA doesn't work anymore, use cuda. GPU tests now run… (lgarrison, Oct 19, 2023)
c770a54  gpu: fix extraneous translation_rule arg (lgarrison, Oct 20, 2023)
f4ac665  gpu: custom call target registration uses capital CUDA, while transla… (lgarrison, Oct 20, 2023)
c36b6c3  gpu: use x64 for some tests that were off by 1.1e-7 (lgarrison, Oct 20, 2023)
097be09  gpu: skip some 1D tests (lgarrison, Oct 20, 2023)
37db2f0  cmake: get colored output through ninja (lgarrison, Oct 20, 2023)
e02f231  gpu: use the CUDA stream provided by JAX (lgarrison, Oct 20, 2023)
5cbb18c  vendor: use lgarrison fork of finufft until flatironinstitute/finufft… (lgarrison, Oct 20, 2023)
1ca71ee  Merge branch 'main' into 2023-gpu (lgarrison, Oct 20, 2023)
d3c8ccf  Fixes for modern JAX: block until CUDA operations complete. Import ja… (lgarrison, Oct 24, 2023)
e3db065  Probably don't need to sync the stream, JAX ought to do that. But we … (lgarrison, Oct 27, 2023)
e3230c6  vendor: update finufft (lgarrison, Oct 27, 2023)
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ _skbuild
dist
MANIFEST
__pycache__/
*.egg-info
2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,3 +1,3 @@
[submodule "finufft"]
path = vendor/finufft
url = https://github.com/flatironinstitute/finufft
url = https://github.com/lgarrison/finufft
92 changes: 81 additions & 11 deletions CMakeLists.txt
@@ -2,8 +2,10 @@ cmake_minimum_required(VERSION 3.15)
project(${SKBUILD_PROJECT_NAME} LANGUAGES C CXX)
message(STATUS "Using CMake version: " ${CMAKE_VERSION})

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR})
# Add the /cmake directory to the module path so that we can find FFTW
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR}/cmake)

# Handle Python settings passed from scikit-build
if(SKBUILD)
set(Python_EXECUTABLE "${PYTHON_EXECUTABLE}")
set(Python_INCLUDE_DIR "${PYTHON_INCLUDE_DIR}")
@@ -13,23 +15,29 @@ if(SKBUILD)
OUTPUT_VARIABLE _tmp_dir
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
else()
find_package(Python COMPONENTS Interpreter Development REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
OUTPUT_VARIABLE _tmp_dir
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
endif()

set(PYBIND11_NEWPYTHON ON)
find_package(pybind11 CONFIG REQUIRED)
find_package(FFTW REQUIRED COMPONENTS FLOAT_LIB DOUBLE_LIB)
link_libraries(${FFTW_FLOAT_LIB} ${FFTW_DOUBLE_LIB})

# Work out compiler flags
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
add_compile_options(-Wall -O3 -funroll-loops)

add_compile_options(-Wall -Wno-unknown-pragmas -O3 -funroll-loops -fdiagnostics-color)
set(FINUFFT_INCLUDE_DIRS
${CMAKE_CURRENT_LIST_DIR}/lib
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include
${FFTW_INCLUDE_DIRS})

message(STATUS "FINUFFT include dirs: " "${FINUFFT_INCLUDE_DIRS}")

# Build single and double point versions of the FINUFFT library
add_library(finufft STATIC
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/spreadinterp.cpp
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/utils.cpp
@@ -45,10 +53,72 @@ add_library(finufft_32 STATIC
target_compile_definitions(finufft_32 PUBLIC SINGLE)
target_include_directories(finufft_32 PRIVATE ${FINUFFT_INCLUDE_DIRS})

pybind11_add_module(jax_finufft
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft.cc
# Build the XLA bindings to those libraries
pybind11_add_module(jax_finufft_cpu
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cpu.cc
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/utils_precindep.cpp
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib/legendre_rule_fast.c)
target_link_libraries(jax_finufft PRIVATE finufft finufft_32)
target_include_directories(jax_finufft PRIVATE ${FINUFFT_INCLUDE_DIRS})
install(TARGETS jax_finufft DESTINATION .)
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib/legendre_rule_fast.cpp)

target_link_libraries(jax_finufft_cpu PRIVATE finufft finufft_32)
target_include_directories(jax_finufft_cpu PRIVATE ${FINUFFT_INCLUDE_DIRS})
install(TARGETS jax_finufft_cpu DESTINATION .)

include(CheckLanguage)
check_language(CUDA)
if (CMAKE_CUDA_COMPILER)
enable_language(CUDA)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80;90")
endif()

# Find cufft
find_package(CUDAToolkit)

set(CUFINUFFT_INCLUDE_DIRS
${CMAKE_CURRENT_LIST_DIR}/lib
# ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include/cufinufft/contrib/
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include/cufinufft/contrib/cuda_samples
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

set(CUFINUFFT_SOURCES
# TODO: 1D not supported via JAX, but needed for compilation
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/1d/spread1d_wrapper.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/1d/interp1d_wrapper.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/1d/cufinufft1d.cu

${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/2d/spread2d_wrapper.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/2d/interp2d_wrapper.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/2d/cufinufft2d.cu

${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/3d/spread3d_wrapper.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/3d/interp3d_wrapper.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/3d/cufinufft3d.cu

${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/deconvolve_wrapper.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/memtransfer_wrapper.cu

${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/spreadinterp.cpp

${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/precision_independent.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib/legendre_rule_fast.cpp
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/utils.cpp
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/cufinufft.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/common.cu)

add_library(cufinufft STATIC ${CUFINUFFT_SOURCES})
target_include_directories(cufinufft PRIVATE ${CUFINUFFT_INCLUDE_DIRS})

pybind11_add_module(jax_finufft_gpu
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
target_link_libraries(jax_finufft_gpu PRIVATE cufinufft)
target_link_libraries(jax_finufft_gpu PRIVATE ${CUDA_cufft_LIBRARY} ${CUDA_nvToolsExt_LIBRARY})
target_include_directories(jax_finufft_gpu PRIVATE ${CUFINUFFT_INCLUDE_DIRS})
install(TARGETS jax_finufft_gpu DESTINATION .)

else()
message(STATUS "No CUDA compiler found; GPU support will be disabled")
endif()
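For orientation (not part of this diff), here is a minimal Python sketch of how one might check which of the two extension modules ended up being built and installed. The module names jax_finufft_cpu and jax_finufft_gpu come from the pybind11 targets above; importing them as top-level modules is an assumption about where install(... DESTINATION .) places them.

def available_backends():
    """Report which jax-finufft extension modules can be imported."""
    backends = []
    try:
        import jax_finufft_cpu  # built unconditionally
        backends.append("cpu")
    except ImportError:
        pass
    try:
        import jax_finufft_gpu  # only built when check_language(CUDA) finds a compiler
        backends.append("gpu")
    except ImportError:
        pass
    return backends

print(available_backends())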
File renamed without changes.
21 changes: 21 additions & 0 deletions lib/jax_finufft_common.h
@@ -0,0 +1,21 @@
#ifndef _JAX_FINUFFT_COMMON_H_
#define _JAX_FINUFFT_COMMON_H_

// This descriptor is common to both the jax_finufft_cpu and jax_finufft_gpu modules
// We will use the jax_finufft namespace for both

namespace jax_finufft {

template <typename T>
struct NufftDescriptor {
T eps;
int iflag;
int64_t n_tot;
int n_transf;
int64_t n_j;
int64_t n_k[3];
};

}  // namespace jax_finufft

#endif
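For readers following along in Python, a hedged mirror of this descriptor is shown below; the field meanings are inferred from how the descriptor is used by the bindings and are not stated in this header, so treat them as assumptions.

from dataclasses import dataclass
from typing import Tuple

@dataclass
class NufftDescriptorPy:
    """Python-side mirror of jax_finufft::NufftDescriptor (semantics inferred)."""
    eps: float                 # requested transform tolerance
    iflag: int                 # sign of the imaginary exponent (+1 or -1)
    n_tot: int                 # number of batched problems
    n_transf: int              # transforms handled per plan
    n_j: int                   # number of nonuniform points
    n_k: Tuple[int, int, int]  # modes per dimension; unused dimensions set to 1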
6 changes: 4 additions & 2 deletions lib/jax_finufft.cc → lib/jax_finufft_cpu.cc
@@ -4,6 +4,8 @@

#include "pybind11_kernel_helpers.h"

#include "jax_finufft_cpu.h"

using namespace jax_finufft;

namespace {
@@ -15,7 +17,7 @@ void run_nufft(int type, void *desc_in, T *x, T *y, T *z, std::complex<T> *c, st
int64_t n_k = 1;
for (int d = 0; d < ndim; ++d) n_k *= descriptor->n_k[d];

nufft_opts *opts = new nufft_opts;
finufft_opts *opts = new finufft_opts;
default_opts<T>(opts);

typename plan_type<T>::type plan;
@@ -86,7 +88,7 @@ pybind11::dict Registrations() {
return dict;
}

PYBIND11_MODULE(jax_finufft, m) {
PYBIND11_MODULE(jax_finufft_cpu, m) {
m.def("registrations", &Registrations);
m.def("build_descriptorf", &build_descriptor<float>);
m.def("build_descriptor", &build_descriptor<double>);
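The build_descriptor / build_descriptorf bindings above pack a NufftDescriptor into the opaque bytes blob that rides along with the XLA custom call. A hedged sketch of a call is below; the argument order simply mirrors the struct fields, and the exact signature is an assumption rather than something shown in this diff.

import jax_finufft_cpu

# Hypothetical call: pack a double-precision descriptor for a 2D transform
# with 10 batched problems, 1000 nonuniform points, and a 64x64 mode grid.
opaque = jax_finufft_cpu.build_descriptor(
    1e-6,       # eps
    1,          # iflag
    10,         # n_tot
    1,          # n_transf
    1000,       # n_j
    64, 64, 1,  # n_k per dimension
)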
22 changes: 6 additions & 16 deletions lib/jax_finufft.h → lib/jax_finufft_cpu.h
@@ -7,16 +7,6 @@

namespace jax_finufft {

template <typename T>
struct NufftDescriptor {
T eps;
int iflag;
int64_t n_tot;
int n_transf;
int64_t n_j;
int64_t n_k[3];
};

template <typename T>
struct plan_type;

@@ -31,31 +21,31 @@ struct plan_type<float> {
};

template <typename T>
void default_opts(nufft_opts* opts);
void default_opts(finufft_opts* opts);

template <>
void default_opts<float>(nufft_opts* opts) {
void default_opts<float>(finufft_opts* opts) {
finufftf_default_opts(opts);
}

template <>
void default_opts<double>(nufft_opts* opts) {
void default_opts<double>(finufft_opts* opts) {
finufft_default_opts(opts);
}

template <typename T>
int makeplan(int type, int dim, int64_t* nmodes, int iflag, int ntr, T eps,
typename plan_type<T>::type* plan, nufft_opts* opts);
typename plan_type<T>::type* plan, finufft_opts* opts);

template <>
int makeplan<float>(int type, int dim, int64_t* nmodes, int iflag, int ntr, float eps,
typename plan_type<float>::type* plan, nufft_opts* opts) {
typename plan_type<float>::type* plan, finufft_opts* opts) {
return finufftf_makeplan(type, dim, nmodes, iflag, ntr, eps, plan, opts);
}

template <>
int makeplan<double>(int type, int dim, int64_t* nmodes, int iflag, int ntr, double eps,
typename plan_type<double>::type* plan, nufft_opts* opts) {
typename plan_type<double>::type* plan, finufft_opts* opts) {
return finufft_makeplan(type, dim, nmodes, iflag, ntr, eps, plan, opts);
}

37 changes: 37 additions & 0 deletions lib/jax_finufft_gpu.cc
@@ -0,0 +1,37 @@
// This file defines the Python interface to the XLA custom call implemented on the GPU.
// It is exposed as a standard pybind11 module defining "capsule" objects containing our
// method. For simplicity, we export a separate capsule for each supported dtype.

#include "pybind11_kernel_helpers.h"
#include "kernels.h"

using namespace jax_finufft;

namespace {

pybind11::dict Registrations() {
pybind11::dict dict;

// TODO: do we prefer to keep these names the same as the CPU version or prefix them with "cu"?
// dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
// dict["nufft1d2f"] = encapsulate_function(nufft1d2f);
dict["nufft2d1f"] = encapsulate_function(nufft2d1f);
dict["nufft2d2f"] = encapsulate_function(nufft2d2f);
dict["nufft3d1f"] = encapsulate_function(nufft3d1f);
dict["nufft3d2f"] = encapsulate_function(nufft3d2f);

// dict["nufft1d1"] = encapsulate_function(nufft1d1);
// dict["nufft1d2"] = encapsulate_function(nufft1d2);
dict["nufft2d1"] = encapsulate_function(nufft2d1);
dict["nufft2d2"] = encapsulate_function(nufft2d2);
dict["nufft3d1"] = encapsulate_function(nufft3d1);
dict["nufft3d2"] = encapsulate_function(nufft3d2);

return dict;
}

PYBIND11_MODULE(jax_finufft_gpu, m) {
m.def("registrations", &Registrations);
}

} // namespace
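On the Python side, the registrations() dict exposed here is typically fed to XLA's custom-call registry, roughly as sketched below (the wrapper code itself is not part of this diff). Note the case convention flagged in the commit messages: the registration platform string is the uppercase "CUDA", while the lowering/translation rules use lowercase "cuda".

from jax.lib import xla_client
import jax_finufft_gpu  # the GPU module built by this PR

for name, capsule in jax_finufft_gpu.registrations().items():
    # Register each capsule as a GPU custom-call target with XLA.
    xla_client.register_custom_call_target(name, capsule, platform="CUDA")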