Starting to add GPU support using cuFINUFFT #4
Conversation
As a note to self: the modules I'm testing with on my CCA machine are: […]
@dfm I've made some progress on the compiled side of things! The pybind11 module seems to compile against cufinufft now. I did start some fumbling attempts at hooking up the JAX calls, but I'm definitely out of my depth on that. I've pushed what I've done so far, which fails with […]. I've been testing with […].
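For context, the glue being discussed typically follows the standard XLA custom-call registration pattern. Below is a minimal sketch of that pattern, not this repo's actual code; the op name `nufft1d1` and module name `jax_finufft_gpu` are illustrative.

```cpp
// Sketch of the usual XLA custom-call registration pattern via pybind11.
// Names here are hypothetical, not this repository's real interface.
#include <cstddef>
#include <cuda_runtime.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// XLA invokes a GPU custom-call target with the stream it wants the work
// enqueued on, the device buffers, and an opaque descriptor built in Python.
void nufft1d1(cudaStream_t stream, void** buffers, const char* opaque,
              std::size_t opaque_len) {
  // ... this is where the cuFINUFFT plan/exec calls would go, launched on
  // `stream` and reading/writing the XLA-provided `buffers` ...
}

// XLA locates the target through a PyCapsule with this exact name.
template <typename T>
py::capsule EncapsulateFunction(T* fn) {
  return py::capsule(reinterpret_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
}

PYBIND11_MODULE(jax_finufft_gpu, m) {
  m.def("registrations", []() {
    py::dict d;
    d["nufft1d1"] = EncapsulateFunction(nufft1d1);
    return d;
  });
}
```

On the Python side, each capsule would then be registered with `xla_client.register_custom_call_target(..., platform="gpu")` so a JAX translation rule can emit the custom call.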
@lgarrison: amazing!! The next few days are a little crazy, but I'll take a look ASAP and see if I can put together the glue.
@lgarrison: I took a stab at this and the Python part is actually way simpler than what you were trying :D The type 1, 2D tests are all passing, and the float32 3D tests are passing, but I'm hitting both numerical and memory issues with the other ops. Want to take a look and see if you can track it down? I won't be able to do much for the next few days.
Thanks, this is so, so much better than what I was trying! Will definitely try this out and look into the test failures.
@dfm I think the GPU tests might be fixed! Want to try it out? A few notes/todos: […]
This is awesome! I'll test it too. I don't know much about all this, but responses inline:
I have no idea so I trust your judgement :D
I don't think this is great. JAX supports assigning data and computation to specific devices - I think that would fail here? See a discussion here and an example here.
I think this is related to the above comment. JAX will want to be able to assign our computation to a specific device (and manage the device memory accordingly). I think that's managed by the […]
I think it wouldn't be so bad to add stream support to cuFINUFFT. It could be implemented as an optional C++ argument to the existing API. The other piece seems to be passing the CUDA device ID to cuFINUFFT; I'm sure we can convince the stream or JAX to give us the ID. Would be good to hear from @JBlaschke and the other cuFINUFFT devs if this sounds like the right approach! I'd be happy to open an upstream PR.
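For the stream part, a defaulted C++ argument would keep existing call sites working. The signature below is schematic only; the real cuFINUFFT signatures differ, and the `stream` parameter does not exist upstream (it is exactly what the proposed PR would add).

```cpp
// Hypothetical sketch of adding an optional stream to an existing entry
// point; not the actual cuFINUFFT API.
#include <cuda_runtime.h>

struct cufinufft_plan_s;               // opaque plan handle, as upstream
typedef cufinufft_plan_s* cufinufft_plan;

// The plan would store `stream`, and every internal kernel launch and async
// copy would be enqueued on it. nullptr falls back to the default stream,
// preserving current behavior for existing callers.
int cufinufft_makeplan(int type, int dim, int* nmodes, int iflag, int ntransf,
                       double tol, int maxbatchsize, cufinufft_plan* plan,
                       cudaStream_t stream = nullptr);
```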
Hey @dfm and @lgarrison -- I fully support making cuFINUFFT compatible with streams :) I started working on multi-GPU support that doesn't clobber pre-existing contexts (thanks NVIDIA for implementing API contexts that are inconsistent with the Driver) here: flatironinstitute/cufinufft#99 Anyway, that work stalled a little (our "temporary" fixes ended up being permanent enough). And if y'all are willing to wait till after Thanksgiving, then this would be a good excuse to finish that work. This is what I am thinking, with two modes of operation:
1. cuFINUFFT uses whatever stream, context, and device are active when it is called, and the caller keeps those consistent with its data pointers.
2. The caller passes the stream and device ID explicitly alongside the data pointers, and cuFINUFFT sets up the environment itself.
Mode (1) is the "we got this" mode, where the calling software is responsible for maintaining consistency between the data pointers and the active stream, context, and device. This mode is fine -- if you remember to set the correct stream, device, and/or context before calling cuFINUFFT. However, I find that this is not always safe [1], and therefore mode (2) makes sense. Furthermore, relying on the environment to be consistent with device pointers can be difficult when using task-based parallelism. So I think that mode (2) is not just idealism.

[1] -- NB: I think on a multi-GPU system (or when using multiple streams/contexts) it is good form to specify the device ID (stream ID) together with the data pointers -- as this is a "complete" descriptor of where the data is. However, this is only an opinion.
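A minimal sketch of what mode (2) could look like, assuming a hypothetical `DeviceBuffer` type (none of these names exist in cuFINUFFT): the device and stream travel with the pointer, making it a "complete" descriptor in the sense above, so the library can restore the right environment itself.

```cpp
// Illustrative only: a device pointer bundled with the device and stream
// that own it, so callers never rely on ambient CUDA state (mode 2).
#include <cuda_runtime.h>

struct DeviceBuffer {
  void* ptr;            // device allocation
  int device_id;        // CUDA device that owns the allocation
  cudaStream_t stream;  // stream the caller wants the work ordered on
};

// Inside the library, mode (2) would look roughly like this: activate the
// buffer's device, then enqueue everything on the buffer's stream.
inline cudaError_t launch_on(const DeviceBuffer& buf) {
  cudaError_t err = cudaSetDevice(buf.device_id);
  if (err != cudaSuccess) return err;
  // ... enqueue kernels with <<<grid, block, shared, buf.stream>>> ...
  return cudaSuccess;
}
```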
Thanks, that sounds great! I'll definitely defer to your expertise on CUDA contexts, as I haven't had to directly manage them before. I probably have a few development cycles to spare for this if that would be helpful. Otherwise, I'm happy to wait for the larger refactoring to be done (or even implement just the streams part if the larger refactoring ends up more complicated). Out of curiosity, I did try compiling the JAX extensions against your branch in flatironinstitute/cufinufft#99, and it seems to build but segfaults on import. Might be something silly on my end, or maybe this is an expected failure at this point?
Closing in favor of #20. |
So far I just have the CMake definitions to compile cuFINUFFT when nvcc is found, but I haven't started writing the boilerplate needed to loop it into XLA. Coming soon!
Keeping @lgarrison in the loop.
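As a sketch of that conditional setup (paths and target names here are hypothetical, not necessarily what this branch uses), CMake's `check_language` makes the nvcc probe straightforward:

```cmake
# Hypothetical sketch: build the GPU extension only when a CUDA compiler
# (e.g. nvcc) is available; paths and target names are illustrative.
include(CheckLanguage)
check_language(CUDA)

if(CMAKE_CUDA_COMPILER)
  enable_language(CUDA)
  add_subdirectory(vendor/cufinufft)  # hypothetical vendored copy
  pybind11_add_module(jax_finufft_gpu lib/jax_finufft_gpu.cc)  # hypothetical sources
  target_link_libraries(jax_finufft_gpu PRIVATE cufinufft)
else()
  message(STATUS "CUDA compiler not found; skipping cuFINUFFT/GPU support")
endif()
```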