Skip to content

Commit

Permalink
points()
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Jan 11, 2024
1 parent b0e2003 commit f1d4593
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 114 deletions.
4 changes: 2 additions & 2 deletions cuda/bn254.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ extern "C" RustError cuda_bn254(point_t *out, const affine_t points[], size_t np
return mult_pippenger<bucket_t>(out, points, npoints, scalars);
}

extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context,
extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context, size_t npoints,
const scalar_t scalars[])
{
return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context, scalars);
return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context, npoints, scalars);
}
#endif
4 changes: 2 additions & 2 deletions cuda/grumpkin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t
return mult_pippenger<bucket_t>(out, points, npoints, scalars);
}

extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context,
extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context, size_t npoints,
const scalar_t scalars[])
{
return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context, scalars);
return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context, npoints, scalars);
}
#endif
4 changes: 2 additions & 2 deletions cuda/pallas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ extern "C" RustError cuda_pallas(point_t *out, const affine_t points[], size_t n
return mult_pippenger<bucket_t>(out, points, npoints, scalars);
}

extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context,
extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context, size_t npoints,
const scalar_t scalars[])
{
return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context, scalars);
return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context, npoints, scalars);
}

#endif
10 changes: 5 additions & 5 deletions cuda/vesta.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,26 @@ typedef pallas_t scalar_t;

#ifndef __CUDA_ARCH__

extern "C" void drop_msm_context_grumpkin(msm_context_t<affine_t::mem_t> &ref) {
extern "C" void drop_msm_context_vesta(msm_context_t<affine_t::mem_t> &ref) {
CUDA_OK(cudaFree(ref.d_points));
}

extern "C" RustError
cuda_grumpkin_init(const affine_t points[], size_t npoints, msm_context_t<affine_t::mem_t> *msm_context)
cuda_vesta_init(const affine_t points[], size_t npoints, msm_context_t<affine_t::mem_t> *msm_context)
{
return mult_pippenger_init<bucket_t, point_t, affine_t, scalar_t>(points, npoints, msm_context);
}

extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t npoints,
extern "C" RustError cuda_vesta(point_t *out, const affine_t points[], size_t npoints,
const scalar_t scalars[])
{
return mult_pippenger<bucket_t>(out, points, npoints, scalars);
}

extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context,
extern "C" RustError cuda_vesta_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context, size_t npoints,
const scalar_t scalars[])
{
return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context, scalars);
return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context, npoints, scalars);
}

#endif
131 changes: 71 additions & 60 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ pub mod bn256 {
bn256::{Fr as Scalar, G1Affine as Affine, G1 as Point},
CurveExt,
};

use crate::impl_msm;

impl_msm!(
cuda_bn254,
cuda_bn254_init,
Expand All @@ -42,9 +42,9 @@ pub mod grumpkin {
grumpkin::{Fr as Scalar, G1Affine as Affine, G1 as Point},
CurveExt,
};

use crate::impl_msm;

impl_msm!(
cuda_grumpkin,
cuda_grumpkin_init,
Expand Down Expand Up @@ -105,67 +105,58 @@ macro_rules! impl_msm {
}

#[derive(Default, Debug, Clone)]
pub enum MSMContext<'a> {
CUDA(CudaMSMContext),
CPU(&'a [$affine]),
#[default]
Uninit,
pub struct MSMContext<'a> {
cuda_context: CudaMSMContext,
on_gpu: bool,
cpu_context: &'a [$affine],
}

unsafe impl<'a> Send for MSMContext<'a> {}

unsafe impl<'a> Sync for MSMContext<'a> {}

impl<'a> MSMContext<'a> {
pub fn new_cpu(points: &'a [$affine]) -> Self {
Self::CPU(points)
}

pub fn new_cuda(cuda_context: CudaMSMContext) -> Self {
Self::CUDA(cuda_context)
fn new(points: &'a [$affine]) -> Self {
Self {
cuda_context: CudaMSMContext::default(),
on_gpu: false,
cpu_context: points,
}
}

pub fn npoints(&self) -> usize {
match self {
Self::CUDA(cuda_context) => cuda_context.npoints,
Self::CPU(points) => points.len(),
Self::Uninit => panic!("not initialized"),
fn npoints(&self) -> usize {
if self.on_gpu {
assert_eq!(
self.cpu_context.len(),
self.cuda_context.npoints
);
}
self.cpu_context.len()
}

pub fn cuda(&self) -> &CudaMSMContext {
match self {
Self::CUDA(cuda_context) => cuda_context,
Self::CPU(_) => panic!("not a cuda context"),
Self::Uninit => panic!("not initialized"),
}
fn cuda(&self) -> &CudaMSMContext {
&self.cuda_context
}

pub fn points(&self) -> &[$affine] {
match self {
Self::CUDA(_) => {
panic!("cuda context; no host side points")
}
Self::CPU(points) => points,
Self::Uninit => panic!("not initialized"),
}
fn points(&self) -> &[$affine] {
&self.cpu_context
}
}

extern "C" {
fn $name_cpu(
out: *mut $point,
points: *const $affine,
npoints: usize,
scalars: *const $scalar,
);

}

pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point {
let npoints = points.len();
assert!(npoints == scalars.len(), "length mismatch");

#[cfg(feature = "cuda")]
if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } {
extern "C" {
Expand All @@ -175,23 +166,28 @@ macro_rules! impl_msm {
npoints: usize,
scalars: *const $scalar,
) -> cuda::Error;

}
let mut ret = $point::default();
let err =
unsafe { $name(&mut ret, &points[0], npoints, &scalars[0]) };
let err = unsafe {
$name(&mut ret, &points[0], npoints, &scalars[0])
};
assert!(err.code == 0, "{}", String::from(err));

return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap();
}
let mut ret = $point::default();
unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0]) };
$point::new_jacobian(ret.x, ret.y, ret.z).unwrap()
}

pub fn init(points: &[$affine]) -> MSMContext {
let npoints = points.len();

let mut ret = MSMContext::new(points);

#[cfg(feature = "cuda")]
if unsafe { !CUDA_OFF && cuda_available() } {
if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } {
extern "C" {
fn $name_init(
points: *const $affine,
Expand All @@ -200,50 +196,60 @@ macro_rules! impl_msm {
) -> cuda::Error;
}

let mut ret = CudaMSMContext::default();

let npoints = points.len();
let err = unsafe {
$name_init(points.as_ptr() as *const _, npoints, &mut ret)
$name_init(
points.as_ptr() as *const _,
npoints,
&mut ret.cuda_context,
)
};
assert!(err.code == 0, "{}", String::from(err));
return MSMContext::new_cuda(ret);
ret.on_gpu = true;
return ret;
}

MSMContext::new_cpu(points)
ret
}

pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point {
assert!(context.npoints() >= scalars.len(), "not enough points");

let npoints = context.npoints();
let nscalars = scalars.len();
assert!(npoints >= nscalars, "not enough points");

let mut ret = $point::default();

#[cfg(feature = "cuda")]
if unsafe { !CUDA_OFF && cuda_available() } {
if nscalars >= 1 << 16
&& context.on_gpu
&& unsafe { !CUDA_OFF && cuda_available() }
{
extern "C" {
fn $name_with(
out: *mut $point,
context: &CudaMSMContext,
npoints: usize,
scalars: *const $scalar,
) -> cuda::Error;
}

let err = unsafe { $name_with(&mut ret, context.cuda(), &scalars[0]) };

let err = unsafe {
$name_with(&mut ret, &context.cuda_context, nscalars, &scalars[0])
};
assert!(err.code == 0, "{}", String::from(err));
return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap();
}

unsafe {
$name_cpu(
&mut ret,
&context.points()[0],
context.npoints(),
&context.cpu_context[0],
nscalars,
&scalars[0],
)
};
$point::new_jacobian(ret.x, ret.y, ret.z).unwrap()
}

};
}

Expand All @@ -269,6 +275,11 @@ mod tests {
let ret = crate::bn256::msm(&points, &scalars).to_affine();
println!("{:?}", ret);

let context = crate::bn256::init(&points);
let ret_other = crate::bn256::with(&context, &scalars).to_affine();
println!("{:?}", ret_other);

assert_eq!(ret, naive);
assert_eq!(ret, ret_other);
}
}
Loading

0 comments on commit f1d4593

Please sign in to comment.