
Update vendored finufft and add GPU support #20

Merged
merged 28 commits into from
Oct 30, 2023

Conversation

lgarrison
Member

In recent months, cufinufft has been merged into the primary finufft codebase (thanks to @blackwer), which itself has matured a lot over the last few years. The vendored finufft hasn't been updated in a long time, so this PR does that, fixes some minor compatibility issues, and removes the deprecated vendored cufinufft. It also adds GPU support.
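For context, the type-1 transform that finufft and cufinufft accelerate evaluates sums of complex exponentials at nonuniform points. A naive numpy reference of that definition (illustrative only; the libraries evaluate it approximately and much faster):

```python
import numpy as np

def nufft1_direct(N, c, x, iflag=1):
    # f_k = sum_j c_j * exp(iflag * 1j * k * x_j), for modes k = -N//2 .. (N-1)//2
    k = np.arange(-(N // 2), (N + 1) // 2)
    return np.exp(iflag * 1j * np.outer(k, x)) @ c

rng = np.random.default_rng(0)
x = 2 * np.pi * rng.uniform(size=50)            # nonuniform points in [0, 2*pi)
c = rng.normal(size=50) + 1j * rng.normal(size=50)  # complex strengths
f = nufft1_direct(8, c, x)
print(f.shape)  # (8,)
```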

The GPU support comes from #4 but with an important update, which is that we can now pass the GPU stream that JAX gives us to the cufinufft library via flatironinstitute/finufft#330. This is probably required for multi-GPU support (although I haven't tested it) and may help with performance too. The only other changes were various updates to match the new API, build fixes, etc.

Some notes about the CMake build: cufinufft no longer requires building two static libraries with different macro definitions (one for float and one for double). The multiple precisions are supported via C++ templates, so the library can be built all at once. Also, now that finufft itself uses CMake, we might prefer to include it as a CMake sub-package rather than itemizing source files. But I left the itemization approach for now, since it required fewer changes.

For the moment, the vendored finufft is maintained in my own fork while we wait for some important PRs (including flatironinstitute/finufft#330 and flatironinstitute/finufft#354) to be merged. It should be easy to re-target the primary repo later.

The tests pass on both the CPU and GPU, but more help testing the GPU in particular would be great! For anybody looking to run this on the Flatiron clusters, this is my build environment:

```sh
# env.sh
ml modules/2.2
ml gcc
ml python/3.11
ml fftw
ml cmake
ml cuda/12
ml cudnn
```

dfm and others added 25 commits November 8, 2021 15:03
Passes CPU tests; GPU compilation still needs to be fixed for finufft refactor.
…the single and double precision interfaces are compiled together now
```python
for n in range(num_repeat):
    np.testing.assert_allclose(
        calc_unmap_pt[n], func(c[n], *(x_[n] for x_ in x[:-1]), x[-1][0])
    )
with jax.experimental.enable_x64():
```
Collaborator

Do we need this here? Perhaps we could adjust the tolerance for the allclose calls below to test both? I think the jax._src.public_test_util.check_close might do what we want.
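One way to exercise both precisions without enabling x64 globally is to pick the `allclose` tolerances from the input dtype. A minimal sketch (the helper name is illustrative, not part of the test suite):

```python
import numpy as np

def assert_close_for_dtype(actual, expected):
    # Loose tolerance for single precision, tight for double.
    tol = 1e-12 if np.asarray(actual).dtype == np.float64 else 1e-5
    np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol)

# float32 round-off (~1e-7) is within the single-precision tolerance:
assert_close_for_dtype(np.float32(np.pi), np.pi)
# float64 is held to the much tighter tolerance:
assert_close_for_dtype(np.float64(np.pi), np.pi)
```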

@dfm
Collaborator

dfm commented Oct 23, 2023

This is awesome! I wonder if it's worth getting the Flatiron Jenkins CI set up to test the GPU ops? I've been doing that with the JAX CUDA ops in my exoplanet-core package: https://github.com/exoplanet-dev/exoplanet-core/blob/main/ci/Jenkinsfile

I'd like to update all the custom call stuff because the best practices have changed a little in the meantime, but that should be a separate PR.

For now, I'm keen to merge this with the only question being about running a GPU-enabled CI.

@lgarrison
Member Author

Yeah, I think it definitely makes sense to set this up on Jenkins! I actually started looking into it, but @Matematija reported a non-deterministic crash that I'm investigating first; it may only occur with the latest JAX. I'll keep you posted.

…x.experimental. Point to vendored finufft with more fixes.
@lgarrison
Member Author

The problems appear at JAX 0.4.9, which is also the version where JAX starts using a non-blocking CUDA stream, according to jax-ml/jax#16580. I've gone through and fixed some cufinufft stream race conditions in the kernel launches, cudaMemcpyAsync calls, Thrust usage, and (maybe) cufft. That resolves the particular crash @Matematija reported, but there are still problems: the tests don't pass. I'll need to dig into why (though I may be delayed by jury duty).

@lgarrison
Member Author

I think this is fixed now! The relevant patch is in the finufft submodule, so don't forget to update submodules when you pull. I'll work on Jenkins next.

@lgarrison
Member Author

@dfm Can you add the Jenkins webhook to this project, or give me permissions to do so? I think I would need to be a repo admin.

@dfm dfm merged commit b2b2cd0 into main Oct 30, 2023
2 checks passed
@lgarrison lgarrison deleted the 2023-gpu branch November 2, 2023 17:43