diff --git a/cuda/bn254.cu b/cuda/bn254.cu index e0611e4..fc97057 100644 --- a/cuda/bn254.cu +++ b/cuda/bn254.cu @@ -34,9 +34,9 @@ extern "C" RustError cuda_bn254(point_t *out, const affine_t points[], size_t np return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, +extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, size_t npoints, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars); } #endif diff --git a/cuda/grumpkin.cu b/cuda/grumpkin.cu index 106be8f..8403f92 100644 --- a/cuda/grumpkin.cu +++ b/cuda/grumpkin.cu @@ -34,9 +34,9 @@ extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, +extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, size_t npoints, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars); } #endif diff --git a/cuda/pallas.cu b/cuda/pallas.cu index dd76686..62a375a 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -34,10 +34,10 @@ extern "C" RustError cuda_pallas(point_t *out, const affine_t points[], size_t n return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, +extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, size_t npoints, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars); } #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index ff9d594..63db638 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -18,26 +18,26 @@ typedef pallas_t scalar_t; #ifndef __CUDA_ARCH__ -extern "C" void drop_msm_context_grumpkin(msm_context_t &ref) { +extern "C" void drop_msm_context_vesta(msm_context_t &ref) { CUDA_OK(cudaFree(ref.d_points)); } extern "C" RustError -cuda_grumpkin_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +cuda_vesta_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) { return mult_pippenger_init(points, npoints, msm_context); } -extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t npoints, +extern "C" RustError cuda_vesta(point_t *out, const affine_t points[], size_t npoints, const scalar_t scalars[]) { return mult_pippenger(out, points, npoints, scalars); } -extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, +extern "C" RustError cuda_vesta_with(point_t *out, msm_context_t *msm_context, size_t npoints, const scalar_t scalars[]) { - return mult_pippenger_with(out, msm_context, scalars); + return mult_pippenger_with(out, msm_context, npoints, scalars); } #endif diff --git a/src/lib.rs b/src/lib.rs index a1c9bac..cd51f3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,9 +23,9 @@ pub mod bn256 { bn256::{Fr as Scalar, G1Affine as Affine, G1 as Point}, CurveExt, }; - + use crate::impl_msm; - + impl_msm!( cuda_bn254, cuda_bn254_init, @@ -42,9 +42,9 @@ pub mod grumpkin { grumpkin::{Fr as Scalar, G1Affine as Affine, G1 as Point}, CurveExt, }; - + use crate::impl_msm; - + impl_msm!( cuda_grumpkin, cuda_grumpkin_init, @@ -105,11 +105,10 @@ macro_rules! impl_msm { } #[derive(Default, Debug, Clone)] - pub enum MSMContext<'a> { - CUDA(CudaMSMContext), - CPU(&'a [$affine]), - #[default] - Uninit, + pub struct MSMContext<'a> { + cuda_context: CudaMSMContext, + on_gpu: bool, + cpu_context: &'a [$affine], } unsafe impl<'a> Send for MSMContext<'a> {} @@ -117,41 +116,33 @@ macro_rules! impl_msm { unsafe impl<'a> Sync for MSMContext<'a> {} impl<'a> MSMContext<'a> { - pub fn new_cpu(points: &'a [$affine]) -> Self { - Self::CPU(points) - } - - pub fn new_cuda(cuda_context: CudaMSMContext) -> Self { - Self::CUDA(cuda_context) + fn new(points: &'a [$affine]) -> Self { + Self { + cuda_context: CudaMSMContext::default(), + on_gpu: false, + cpu_context: points, + } } - pub fn npoints(&self) -> usize { - match self { - Self::CUDA(cuda_context) => cuda_context.npoints, - Self::CPU(points) => points.len(), - Self::Uninit => panic!("not initialized"), + fn npoints(&self) -> usize { + if self.on_gpu { + assert_eq!( + self.cpu_context.len(), + self.cuda_context.npoints + ); } + self.cpu_context.len() } - pub fn cuda(&self) -> &CudaMSMContext { - match self { - Self::CUDA(cuda_context) => cuda_context, - Self::CPU(_) => panic!("not a cuda context"), - Self::Uninit => panic!("not initialized"), - } + fn cuda(&self) -> &CudaMSMContext { + &self.cuda_context } - pub fn points(&self) -> &[$affine] { - match self { - Self::CUDA(_) => { - panic!("cuda context; no host side points") - } - Self::CPU(points) => points, - Self::Uninit => panic!("not initialized"), - } + fn points(&self) -> &[$affine] { + &self.cpu_context } } - + extern "C" { fn $name_cpu( out: *mut $point, @@ -159,13 +150,13 @@ macro_rules! impl_msm { npoints: usize, scalars: *const $scalar, ); - + } - + pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { let npoints = points.len(); assert!(npoints == scalars.len(), "length mismatch"); - + #[cfg(feature = "cuda")] if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { @@ -175,23 +166,28 @@ macro_rules! impl_msm { npoints: usize, scalars: *const $scalar, ) -> cuda::Error; - + } let mut ret = $point::default(); - let err = - unsafe { $name(&mut ret, &points[0], npoints, &scalars[0]) }; + let err = unsafe { + $name(&mut ret, &points[0], npoints, &scalars[0]) + }; assert!(err.code == 0, "{}", String::from(err)); - + return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); } let mut ret = $point::default(); unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0]) }; $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() } - + pub fn init(points: &[$affine]) -> MSMContext { + let npoints = points.len(); + + let mut ret = MSMContext::new(points); + #[cfg(feature = "cuda")] - if unsafe { !CUDA_OFF && cuda_available() } { + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { fn $name_init( points: *const $affine, @@ -200,50 +196,60 @@ macro_rules! impl_msm { ) -> cuda::Error; } - let mut ret = CudaMSMContext::default(); - let npoints = points.len(); let err = unsafe { - $name_init(points.as_ptr() as *const _, npoints, &mut ret) + $name_init( + points.as_ptr() as *const _, + npoints, + &mut ret.cuda_context, + ) }; assert!(err.code == 0, "{}", String::from(err)); - return MSMContext::new_cuda(ret); + ret.on_gpu = true; + return ret; } - MSMContext::new_cpu(points) + ret } - + pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point { - assert!(context.npoints() >= scalars.len(), "not enough points"); - + let npoints = context.npoints(); + let nscalars = scalars.len(); + assert!(npoints >= nscalars, "not enough points"); + let mut ret = $point::default(); - + #[cfg(feature = "cuda")] - if unsafe { !CUDA_OFF && cuda_available() } { + if nscalars >= 1 << 16 + && context.on_gpu + && unsafe { !CUDA_OFF && cuda_available() } + { extern "C" { fn $name_with( out: *mut $point, context: &CudaMSMContext, + npoints: usize, scalars: *const $scalar, ) -> cuda::Error; } - - let err = unsafe { $name_with(&mut ret, context.cuda(), &scalars[0]) }; + + let err = unsafe { + $name_with(&mut ret, &context.cuda_context, nscalars, &scalars[0]) + }; assert!(err.code == 0, "{}", String::from(err)); return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); } - + unsafe { $name_cpu( &mut ret, - &context.points()[0], - context.npoints(), + &context.cpu_context[0], + nscalars, &scalars[0], ) }; $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() } - }; } @@ -269,6 +275,11 @@ mod tests { let ret = crate::bn256::msm(&points, &scalars).to_affine(); println!("{:?}", ret); + let context = crate::bn256::init(&points); + let ret_other = crate::bn256::with(&context, &scalars).to_affine(); + println!("{:?}", ret_other); + assert_eq!(ret, naive); + assert_eq!(ret, ret_other); } } diff --git a/src/pasta.rs b/src/pasta.rs index 5b072e2..de0e781 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -87,11 +87,10 @@ macro_rules! impl_pasta { } #[derive(Default, Debug, Clone)] - pub enum MSMContext<'a> { - CUDA(CudaMSMContext), - CPU(&'a [$affine]), - #[default] - Uninit, + pub struct MSMContext<'a> { + cuda_context: CudaMSMContext, + on_gpu: bool, + cpu_context: &'a [$affine], } unsafe impl<'a> Send for MSMContext<'a> {} @@ -99,38 +98,30 @@ macro_rules! impl_pasta { unsafe impl<'a> Sync for MSMContext<'a> {} impl<'a> MSMContext<'a> { - pub fn new_cpu(points: &'a [$affine]) -> Self { - Self::CPU(points) - } - - pub fn new_cuda(cuda_context: CudaMSMContext) -> Self { - Self::CUDA(cuda_context) + fn new(points: &'a [$affine]) -> Self { + Self { + cuda_context: CudaMSMContext::default(), + on_gpu: false, + cpu_context: points, + } } - pub fn npoints(&self) -> usize { - match self { - Self::CUDA(cuda_context) => cuda_context.npoints, - Self::CPU(points) => points.len(), - Self::Uninit => panic!("not initialized"), + fn npoints(&self) -> usize { + if self.on_gpu { + assert_eq!( + self.cpu_context.len(), + self.cuda_context.npoints + ); } + self.cpu_context.len() } - pub fn cuda(&self) -> &CudaMSMContext { - match self { - Self::CUDA(cuda_context) => cuda_context, - Self::CPU(_) => panic!("not a cuda context"), - Self::Uninit => panic!("not initialized"), - } + fn cuda(&self) -> &CudaMSMContext { + &self.cuda_context } - pub fn points(&self) -> &[$affine] { - match self { - Self::CUDA(_) => { - panic!("cuda context; no host side points") - } - Self::CPU(points) => points, - Self::Uninit => panic!("not initialized"), - } + fn points(&self) -> &[$affine] { + &self.cpu_context } } @@ -140,6 +131,7 @@ macro_rules! impl_pasta { points: *const $affine, npoints: usize, scalars: *const $scalar, + is_mont: bool, ); } @@ -156,25 +148,32 @@ macro_rules! impl_pasta { points: *const $affine, npoints: usize, scalars: *const $scalar, + is_mont: bool, ) -> cuda::Error; } let mut ret = $point::default(); let err = unsafe { - $name(&mut ret, &points[0], npoints, &scalars[0]) + $name(&mut ret, &points[0], npoints, &scalars[0], true) }; assert!(err.code == 0, "{}", String::from(err)); return ret; } let mut ret = $point::default(); - unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0]) }; + unsafe { + $name_cpu(&mut ret, &points[0], npoints, &scalars[0], true) + }; ret } pub fn init(points: &[$affine]) -> MSMContext { + let npoints = points.len(); + + let mut ret = MSMContext::new(points); + #[cfg(feature = "cuda")] - if unsafe { !CUDA_OFF && cuda_available() } { + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { fn $name_init( points: *const $affine, @@ -183,35 +182,46 @@ macro_rules! impl_pasta { ) -> cuda::Error; } - let mut ret = CudaMSMContext::default(); - let npoints = points.len(); let err = unsafe { - $name_init(points.as_ptr() as *const _, npoints, &mut ret) + $name_init( + points.as_ptr() as *const _, + npoints, + &mut ret.cuda_context, + ) }; assert!(err.code == 0, "{}", String::from(err)); - return MSMContext::new_cuda(ret); + ret.on_gpu = true; + return ret; } - MSMContext::new_cpu(points) + ret } pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point { - assert!(context.npoints() >= scalars.len(), "not enough points"); + let npoints = context.npoints(); + let nscalars = scalars.len(); + assert!(npoints >= nscalars, "not enough points"); let mut ret = $point::default(); #[cfg(feature = "cuda")] - if unsafe { !CUDA_OFF && cuda_available() } { + if nscalars >= 1 << 16 + && unsafe { !CUDA_OFF && cuda_available() } + { extern "C" { fn $name_with( out: *mut $point, context: &CudaMSMContext, + npoints: usize, scalars: *const $scalar, + is_mont: bool, ) -> cuda::Error; } - let err = unsafe { $name_with(&mut ret, context.cuda(), &scalars[0]) }; + let err = unsafe { + $name_with(&mut ret, context.cuda(), nscalars, &scalars[0], true) + }; assert!(err.code == 0, "{}", String::from(err)); return ret; } @@ -219,9 +229,10 @@ macro_rules! impl_pasta { unsafe { $name_cpu( &mut ret, - &context.points()[0], - context.npoints(), + &context.cpu_context[0], + nscalars, &scalars[0], + true, ) }; @@ -366,6 +377,11 @@ mod tests { let ret = pallas::msm(&points, &scalars).to_affine(); println!("{:?}", ret); + let context = pallas::init(&points); + let ret_other = pallas::with(&context, &scalars).to_affine(); + println!("{:?}", ret_other); + assert_eq!(ret, naive); + assert_eq!(ret, ret_other); } }