From 547539caeb353651f0cf6e63fbe35e157d1782a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= Date: Wed, 3 Jan 2024 17:47:31 -0500 Subject: [PATCH] chore: Refactors CPU MSM operations using halo2curves library - Expanded the `msm` module within the `provider/util` directory and introduced a new function, `cpu_best_msm`. - Changed the curves library dependency in `bn256_grumpkin.rs` from `pasta_curves` to `halo2curves`, - Removed the `msm.rs` file along with two associated functions `cpu_msm_serial` and `cpu_best_msm` used for non-GPU accelerated operations, and all related tests. - Reorganized the import of `CurveAffine, CurveExt` from the `halo2curves` library in `provider/mod.rs`. Fixes #193 --- Cargo.toml | 2 +- src/provider/bn256_grumpkin.rs | 4 +- src/provider/util/mod.rs | 11 ++- src/provider/util/msm.rs | 154 --------------------------------- 4 files changed, 13 insertions(+), 158 deletions(-) delete mode 100644 src/provider/util/msm.rs diff --git a/Cargo.toml b/Cargo.toml index c828aec51..9293f65ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ pasta-msm = { git = "https://github.com/lurk-lab/pasta-msm", branch = "dev", ver grumpkin-msm = { git = "https://github.com/lurk-lab/grumpkin-msm", branch = "dev", features = ["dont-implement-sort"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -halo2curves = { version = "0.5.0", features = ["bits", "derive_serde"] } +halo2curves = { version = "0.5.0", features = ["bits", "derive_serde", "multicore"] } [target.'cfg(target_arch = "wasm32")'.dependencies] # see https://github.com/rust-random/rand/pull/948 diff --git a/src/provider/bn256_grumpkin.rs b/src/provider/bn256_grumpkin.rs index ba4561df3..1f1b36677 100644 --- a/src/provider/bn256_grumpkin.rs +++ b/src/provider/bn256_grumpkin.rs @@ -9,10 +9,10 @@ use ff::{FromUniformBytes, PrimeField}; use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup}; #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] use grumpkin_msm::{bn256 as bn256_msm, grumpkin as grumpkin_msm}; +// Remove this when https://github.com/zcash/pasta_curves/issues/41 resolves +use halo2curves::{CurveAffine, CurveExt}; use num_bigint::BigInt; use num_traits::Num; -// Remove this when https://github.com/zcash/pasta_curves/issues/41 resolves -use pasta_curves::arithmetic::{CurveAffine, CurveExt}; use rayon::prelude::*; use sha3::Shake256; use std::io::Read; diff --git a/src/provider/util/mod.rs b/src/provider/util/mod.rs index 40a0443fa..ec949c239 100644 --- a/src/provider/util/mod.rs +++ b/src/provider/util/mod.rs @@ -1,3 +1,12 @@ /// Utilities for provider module pub(in crate::provider) mod fb_msm; -pub(in crate::provider) mod msm; +pub mod msm { + use halo2curves::msm::best_multiexp; + use halo2curves::CurveAffine; + + // this argument swap is useful until Rust gets named arguments + // and saves significant complexity in macro code + pub fn cpu_best_msm(bases: &[C], scalars: &[C::Scalar]) -> C::Curve { + best_multiexp(scalars, bases) + } +} diff --git a/src/provider/util/msm.rs b/src/provider/util/msm.rs deleted file mode 100644 index b5ea2a889..000000000 --- a/src/provider/util/msm.rs +++ /dev/null @@ -1,154 +0,0 @@ -//! This module provides a multi-scalar multiplication routine -/// Adapted from zcash/halo2 -use ff::PrimeField; -use itertools::Itertools as _; -use pasta_curves::{self, arithmetic::CurveAffine, group::Group as AnotherGroup}; -use rayon::{current_num_threads, prelude::*}; - -fn cpu_msm_serial(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { - let c = if bases.len() < 4 { - 1 - } else if bases.len() < 32 { - 3 - } else { - (f64::from(bases.len() as u32)).ln().ceil() as usize - }; - - fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { - let skip_bits = segment * c; - let skip_bytes = skip_bits / 8; - - if skip_bytes >= 32 { - return 0; - } - - let mut v = [0; 8]; - #[allow(clippy::disallowed_methods)] - for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { - *v = *o; - } - - let mut tmp = u64::from_le_bytes(v); - tmp >>= skip_bits - (skip_bytes * 8); - tmp %= 1 << c; - - tmp as usize - } - - let segments = (256 / c) + 1; - - (0..segments) - .rev() - .fold(C::Curve::identity(), |mut acc, segment| { - (0..c).for_each(|_| acc = acc.double()); - - #[derive(Clone, Copy)] - enum Bucket { - None, - Affine(C), - Projective(C::Curve), - } - - impl Bucket { - fn add_assign(&mut self, other: &C) { - *self = match *self { - Self::None => Self::Affine(*other), - Self::Affine(a) => Self::Projective(a + *other), - Self::Projective(a) => Self::Projective(a + other), - } - } - - fn add(self, other: C::Curve) -> C::Curve { - match self { - Self::None => other, - Self::Affine(a) => other + a, - Self::Projective(a) => other + a, - } - } - } - - let mut buckets = vec![Bucket::None; (1 << c) - 1]; - - for (coeff, base) in coeffs.iter().zip_eq(bases.iter()) { - let coeff = get_at::(segment, c, &coeff.to_repr()); - if coeff != 0 { - buckets[coeff - 1].add_assign(base); - } - } - - // Summation by parts - // e.g. 3a + 2b + 1c = a + - // (a) + b + - // ((a) + b) + c - let mut running_sum = C::Curve::identity(); - for exp in buckets.into_iter().rev() { - running_sum = exp.add(running_sum); - acc += &running_sum; - } - acc - }) -} - -/// Performs a multi-scalar-multiplication operation without GPU acceleration. -/// -/// This function will panic if coeffs and bases have a different length. -/// -/// This will use multithreading if beneficial. -/// Adapted from zcash/halo2 -pub(crate) fn cpu_best_msm(bases: &[C], coeffs: &[C::Scalar]) -> C::Curve { - assert_eq!(coeffs.len(), bases.len()); - - let num_threads = current_num_threads(); - if coeffs.len() > num_threads { - let chunk = coeffs.len() / num_threads; - coeffs - .par_chunks(chunk) - .zip_eq(bases.par_chunks(chunk)) - .map(|(coeffs, bases)| cpu_msm_serial(coeffs, bases)) - .reduce(C::Curve::identity, |sum, evl| sum + evl) - } else { - cpu_msm_serial(coeffs, bases) - } -} - -#[cfg(test)] -mod tests { - use super::cpu_best_msm; - - use crate::provider::{ - bn256_grumpkin::{bn256, grumpkin}, - secp_secq::{secp256k1, secq256k1}, - }; - use group::{ff::Field, Group}; - use halo2curves::CurveAffine; - use itertools::Itertools as _; - use pasta_curves::{pallas, vesta}; - use rand_core::OsRng; - - fn test_msm_with>() { - let n = 8; - let coeffs = (0..n).map(|_| F::random(OsRng)).collect::>(); - let bases = (0..n) - .map(|_| A::from(A::generator() * F::random(OsRng))) - .collect::>(); - let naive = coeffs - .iter() - .zip_eq(bases.iter()) - .fold(A::CurveExt::identity(), |acc, (coeff, base)| { - acc + *base * coeff - }); - let msm = cpu_best_msm(&bases, &coeffs); - - assert_eq!(naive, msm) - } - - #[test] - fn test_msm() { - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - } -}