
Starting to add GPU support using cuFINUFFT #4

Closed
wants to merge 13 commits into from

Conversation

dfm
Collaborator

@dfm dfm commented Nov 9, 2021

So far I just have the CMake definitions to compile cuFINUFFT when nvcc is found, but I haven't started writing the boilerplate needed to loop it into XLA. Coming soon!

Keeping @lgarrison in the loop.

@dfm
Collaborator Author

dfm commented Nov 9, 2021

As a note to self: the modules I'm testing with on my CCA machine are: module load fftw/3 gcc/10 cuda cmake

> module list
Currently Loaded Modules:
  1) slurm (S)   2) cuda/11.4.2   3) gcc/10.2.0   4) openblas/0.3.15-threaded (S)   5) fftw/3.3.10   6) cmake/3.21.4

@lgarrison
Member

@dfm I've made some progress on the compiled side of things! The pybind11 module seems to compile against cufinufft now. I did start some fumbling attempts at hooking up the JAX calls, but I'm definitely out of my depth on that. I've pushed what I've done so far, which fails with NotImplementedError: XLA translation rule for primitive 'nufft1' not found. Probably needs your eyes on it; I'm not even sure I'm on the right track!

I've been testing with pytest -x tests/gpu_ops_test.py.

@dfm
Collaborator Author

dfm commented Nov 15, 2021

@lgarrison: amazing!! The next few days are a little crazy, but I'll take a look ASAP and see if I can put together the glue.
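To sketch what that glue has to look like: XLA hands a GPU custom call a stream, a flat array of device buffer pointers, and an opaque descriptor blob. The following is a minimal host-only illustration of that calling convention -- the `nufft1` name echoes the primitive in the error above, but the descriptor layout and the body are hypothetical stand-ins (a plain copy instead of a cuFINUFFT transform), not jax-finufft's actual code:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical descriptor carrying problem sizes; the real layout is
// whatever gets serialized on the Python side and passed via `opaque`.
struct NufftDescriptor {
  int64_t n_points;
};

// XLA invokes GPU custom-call targets with this shape: the stream the
// work should be enqueued on, an array of device buffer pointers
// (inputs first, then outputs), and an opaque descriptor blob. Here the
// "kernel" is a host-side copy so the sketch runs without a GPU.
extern "C" void nufft1(void* stream, void** buffers,
                       const char* opaque, size_t opaque_len) {
  (void)stream;
  assert(opaque_len == sizeof(NufftDescriptor));
  NufftDescriptor d;
  std::memcpy(&d, opaque, sizeof(d));
  const double* in = static_cast<const double*>(buffers[0]);
  double* out = static_cast<double*>(buffers[1]);
  // In the real op this would be cufinufft plan/setpts/execute,
  // launched on `stream`; here we just copy input to output.
  for (int64_t i = 0; i < d.n_points; ++i) out[i] = in[i];
}
```

On the Python side, registering this function pointer as a custom-call target is what replaces the missing translation rule.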

@dfm
Collaborator Author

dfm commented Nov 16, 2021

@lgarrison: I took a stab at this and the Python part is actually way simpler than what you were trying :D

The Type 1, 2D tests are all passing, and the float32 3D tests are passing, but I'm seeing problems with the other ops: both numerical discrepancies and memory errors. Want to take a look and see if you can track it down? I won't be able to do much for the next few days.

@lgarrison
Member

Thanks, this is so, so much better than what I was trying! Will definitely try this out and look into the test failures.

@lgarrison
Member

@dfm I think the GPU tests might be fixed! Want to try it out?

A few notes/todos:

  • I had to set a cuFINUFFT option to use a slower(?) algorithm for 3D, double precision (that was my reading of "error: not enough shared memory when setting tol to less than 1e-3", cufinufft#58). There might be a solution in changing the bin size and/or increasing the shared memory limit; I don't quite understand that part of the code. I haven't checked the performance impact.
  • I added CUDA arch 80 to cmake to test on A100s; is this always okay for whatever minimum CUDA version we're requiring?
  • None of this is using the CUDA stream that JAX is passing us... is that safe?? Will it do anything asynchronous without that? I'll do some tests.
  • Will JAX use multiple GPUs? How does that affect how we call CUFINUFFT?
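On the arch bullet: `sm_80` (A100) requires CUDA 11.0 or newer, so whatever minimum toolkit we settle on determines whether 80 can stay in the list. A sketch of what that could look like in CMake (the exact arch list here is an assumption for illustration, not necessarily what this PR sets):

```cmake
# sm_80 (A100) needs CUDA >= 11.0; configuring with an older toolkit
# will fail if 80 is in this list.
set(CMAKE_CUDA_ARCHITECTURES 60 70 80)

# Alternatively, with CMake >= 3.23 one can build for every major arch
# the installed toolkit supports:
# set(CMAKE_CUDA_ARCHITECTURES all-major)
```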

@dfm
Collaborator Author

dfm commented Nov 18, 2021

This is awesome! I'll test it too. I don't know much about all this, but responses inline:

> I added CUDA arch 80 to cmake to test on A100s; is this always okay for whatever minimum CUDA version we're requiring?

I have no idea so I trust your judgement :D

> None of this is using the CUDA stream that JAX is passing us... is that safe?? Will it do anything asynchronous without that? I'll do some tests.

I don't think this is great. JAX supports assigning data and computation to specific devices - I think that would fail here? See a discussion here and an example here.

> Will JAX use multiple GPUs? How does that affect how we call CUFINUFFT?

I think this is related to the above comment. JAX will want to be able to assign our computation to a specific device (and manage the device memory accordingly). I think that's managed by the stream parameter (this is beyond the scope of my CUDA knowledge)? Do you have any sense how hard it would be to support this? We could check in with the cufinufft developers because it sounded like @JBlaschke had already made some progress on multi-device support.
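As I understand it (my reading of XLA's GPU custom-call behaviour, not something this PR implements yet), XLA makes the op's assigned device current before invoking the call, so inside the call one can query the current device and should enqueue all work on the passed stream rather than the default stream. A host-only sketch with the CUDA runtime stubbed out:

```cpp
#include <cassert>

// Stand-ins for the CUDA runtime so this sketch runs without a GPU.
using StreamHandle = void*;
int g_xla_assigned_device = 0;  // mimics what cudaGetDevice() would report

int get_current_device() { return g_xla_assigned_device; }

// Sketch of a custom-call body: because XLA has already made the right
// device current, querying the current device tells us which GPU this
// op was placed on, and every kernel launch / plan should be bound to
// `stream` so it stays ordered within XLA's computation.
int nufft_custom_call(StreamHandle stream, void** buffers) {
  (void)stream;
  (void)buffers;
  int device = get_current_device();  // device JAX/XLA assigned this op to
  // real code: hand `device` and `stream` through to cuFINUFFT here
  return device;
}
```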

@lgarrison
Member

I think it wouldn't be so bad to add stream support to CUFINUFFT. It could be implemented as an optional C++ argument to the existing API. The other piece seems to be passing the CUDA device ID to CUFINUFFT; I'm sure we can convince the stream or JAX to give us the ID.

Would be good to hear from @JBlaschke and the other CUFINUFFT devs if this sounds like the right approach! I'd be happy to open an upstream PR.
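Concretely, the optional-argument idea could look like this on the C++ side (illustrative only -- `makeplan` here is a stand-in, not cufinufft's real signature, and the stream type is stubbed as `void*` so the sketch builds without CUDA):

```cpp
#include <cassert>

using StreamHandle = void*;  // stand-in for cudaStream_t

struct Plan {
  StreamHandle stream = nullptr;  // stream all kernels/transfers use
};

// Defaulting the new parameter keeps every existing call site compiling
// unchanged; passing nullptr means "use the default stream", which
// matches current behaviour.
int makeplan(int type, int dim, Plan* plan, StreamHandle stream = nullptr) {
  (void)type;
  (void)dim;
  plan->stream = stream;  // real code would also record device ID, etc.
  return 0;               // 0 = success, mirroring C-style status codes
}
```

Existing callers (`makeplan(1, 2, &plan)`) keep working; new callers pass along the stream JAX hands them.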

@JBlaschke

Hey @dfm and @lgarrison -- I fully support making cuFINUFFT compatible with streams :)

I started working on multi-gpu support that doesn't clobber pre-existing contexts (thanks NVIDIA for implementing API contexts that are inconsistent with the Driver) here: flatironinstitute/cufinufft#99

Anyway, that work stalled a little (our "temporary" fixes ended up being permanent enough). And if y'all are willing to wait till after Thanksgiving, then this would be a good excuse to finish that work.

This is what I am thinking:

  1. If no arguments are given, then cuFINUFFT will use the currently active context (if none is active, use the primary context).
  2. If the user specifies a stream, context, and/or device, then cuFINUFFT will use those.

Mode (1) is the "we got this" mode, where the calling software is responsible for maintaining consistency between the data pointers and the active stream, context, and device. This mode is fine -- if you remember to set the correct stream, device, and/or context before calling cuFINUFFT. However, I find that this is not always safe [1], and therefore mode (2) makes sense. Furthermore, relying on the environment to be consistent with device pointers can be difficult when using task-based parallelism. So I think that mode (2) is not just idealism.

[1] -- NB: I think on a multi-GPU system (or when using multiple streams/contexts) it is good form to specify the device ID (stream ID) together with data pointers -- as this is a "complete" descriptor of where the data is. However, this is only an opinion.
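The two modes above could be captured by an options struct whose fields are all optional (again a sketch, not the planned cuFINUFFT interface; the "current device" lookup is a stub in place of cudaGetDevice):

```cpp
#include <cassert>

using StreamHandle = void*;   // stand-in for cudaStream_t
int g_current_device = 0;     // stub mimicking cudaGetDevice()

struct GpuOpts {
  bool has_device = false;            // mode (2) if the caller sets this
  int device = 0;
  StreamHandle stream = nullptr;      // nullptr = default stream
};

// Mode (1): nothing specified -- trust the currently active device and
// stream, leaving consistency up to the caller.
// Mode (2): the caller pinned a device (and/or stream) alongside the
// data pointers -- use exactly those.
int resolve_device(const GpuOpts& opts) {
  return opts.has_device ? opts.device : g_current_device;
}
```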

@lgarrison
Member

Thanks, that sounds great! I'll definitely defer to your expertise on CUDA contexts, as I haven't had to directly manage them before.

I probably have a few development cycles to spare for this if that would be helpful. Otherwise, I'm happy to wait for the larger refactoring to be done (or even implement just the streams part if the larger refactoring ends up more complicated).

Out of curiosity, I did try compiling the JAX extensions against your branch in flatironinstitute/cufinufft#99, and it seems to build but segfaults on import. Might be something silly on my end, or maybe this is an expected failure at this point?

@lgarrison
Member

Closing in favor of #20.

@lgarrison closed this Oct 20, 2023