diff --git a/src/function.rs b/src/function.rs
index 5b732a8..e3eeec4 100644
--- a/src/function.rs
+++ b/src/function.rs
@@ -411,6 +411,39 @@ macro_rules! launch {
     };
 }
 
+/// Launch a cooperative kernel function asynchronously.
+#[macro_export]
+macro_rules! launch_cooperative {
+    ($module:ident . $function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
+        {
+            let name = std::ffi::CString::new(stringify!($function)).unwrap();
+            let function = $module.get_function(&name);
+            match function {
+                Ok(f) => launch_cooperative!(f<<<$grid, $block, $shared, $stream>>>( $($arg),* ) ),
+                Err(e) => Err(e),
+            }
+        }
+    };
+    ($function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
+        {
+            fn assert_impl_devicecopy<T: $crate::memory::DeviceCopy>(_val: T) {};
+            if false {
+                $(
+                    assert_impl_devicecopy($arg);
+                )*
+            };
+
+            $stream.launch_cooperative(&$function, $grid, $block, $shared,
+                &[
+                    $(
+                        &$arg as *const _ as *mut ::std::ffi::c_void,
+                    )*
+                ]
+            )
+        }
+    };
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
diff --git a/src/stream.rs b/src/stream.rs
index 8eadbe2..edf4a87 100644
--- a/src/stream.rs
+++ b/src/stream.rs
@@ -293,6 +293,38 @@ impl Stream {
         .to_result()
     }
 
+    // Hidden implementation detail function. Highly unsafe. Use the `launch_cooperative!` macro instead.
+    #[doc(hidden)]
+    pub unsafe fn launch_cooperative<G, B>(
+        &self,
+        func: &Function,
+        grid_size: G,
+        block_size: B,
+        shared_mem_bytes: u32,
+        args: &[*mut c_void],
+    ) -> CudaResult<()>
+    where
+        G: Into<GridSize>,
+        B: Into<BlockSize>,
+    {
+        let grid_size: GridSize = grid_size.into();
+        let block_size: BlockSize = block_size.into();
+
+        cuda_driver_sys::cuLaunchCooperativeKernel(
+            func.to_inner(),
+            grid_size.x,
+            grid_size.y,
+            grid_size.z,
+            block_size.x,
+            block_size.y,
+            block_size.z,
+            shared_mem_bytes,
+            self.inner,
+            args.as_ptr() as *mut _,
+        )
+        .to_result()
+    }
+
     // Get the inner `CUstream` from the `Stream`.
     //
     // Necessary for certain CUDA functions outside of this