From 7688c9f22f91c4ce5b1ba41a5b392902ddc5c48e Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Wed, 3 Jan 2024 14:06:18 -0800 Subject: [PATCH] upgrade halo2curves to 0.5.0; shed local MSM code (#288) * upgrade halo2curves to 0.5.0; shed local MSM code * remove asm specific digest tests update digests to pass tests * include halo2curves dependency with flags --- Cargo.toml | 7 +- src/lib.rs | 23 ++--- src/provider/bn256_grumpkin.rs | 23 +++-- src/provider/mod.rs | 41 +-------- src/provider/msm.rs | 151 --------------------------------- src/provider/pasta.rs | 10 +-- src/provider/secp_secq.rs | 13 ++- src/provider/traits.rs | 2 +- 8 files changed, 35 insertions(+), 235 deletions(-) delete mode 100644 src/provider/msm.rs diff --git a/Cargo.toml b/Cargo.toml index 7a7ca9c5..785fc94f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,16 +32,19 @@ bincode = "1.3" bitvec = "1.0" byteorder = "1.4.3" thiserror = "1.0" -halo2curves = { version = "0.4.0", features = ["derive_serde"] } group = "0.13.0" once_cell = "1.18.0" [target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies] pasta-msm = { version = "0.1.4" } -[target.wasm32-unknown-unknown.dependencies] +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +halo2curves = { version = "0.5.0", features = ["bits", "derive_serde"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies] # see https://github.com/rust-random/rand/pull/948 getrandom = { version = "0.2.0", default-features = false, features = ["js"] } +halo2curves = { version = "0.5.0", default-features = false, features = ["bits", "derive_serde"] } [dev-dependencies] criterion = { version = "0.4", features = ["html_reports"] } diff --git a/src/lib.rs b/src/lib.rs index 01bb0bc5..6a509a32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -959,29 +959,16 @@ mod tests { let trivial_circuit2_grumpkin = TrivialCircuit::<::Scalar>::default(); let cubic_circuit1_grumpkin = CubicCircuit::<::Scalar>::default(); - #[cfg(feature = "asm")] test_pp_digest_with::( &trivial_circuit1_grumpkin, &trivial_circuit2_grumpkin, - "c4ecd363a6c1473de7e0d24fc1dbb660f563556e2e13fb4614acdff04cab7701", + "1507bae161c78d6fbb231d5aa288a5cbc91f667c563e1fc4d47e7965a00a6b02", ); - #[cfg(feature = "asm")] - test_pp_digest_with::( - &cubic_circuit1_grumpkin, - &trivial_circuit2_grumpkin, - "4853a6463b6309f6ae76442934d0a423f51f1e10abaddd0d39bf5644ed589100", - ); - #[cfg(not(feature = "asm"))] - test_pp_digest_with::( - &trivial_circuit1_grumpkin, - &trivial_circuit2_grumpkin, - "c26cc841d42c19bf98bc2482e66cd30903922f2a923927b85d66f375a821f101", - ); - #[cfg(not(feature = "asm"))] + test_pp_digest_with::( &cubic_circuit1_grumpkin, &trivial_circuit2_grumpkin, - "4c484cab71e93dda69b420beb7276af969c2034a7ffb0ea8e6964e96a7e5a901", + "3ffcbf855534eea209f2c9735c71ed055e88eecc7342144d47d5de9597432001", ); let trivial_circuit1_secp = TrivialCircuit::<::Scalar>::default(); @@ -991,12 +978,12 @@ mod tests { test_pp_digest_with::( &trivial_circuit1_secp, &trivial_circuit2_secp, - "b794d655fb39891eaf530ca3be1ec2a5ac97f72a0d07c45dbb84529d8a611502", + "ac3329f372c18a100b89fe6363844d2df42e6be539ce21bdfbe867e709be5403", ); test_pp_digest_with::( &cubic_circuit1_secp, &trivial_circuit2_secp, - "50e6acf363c31c2ac1c9c646b4494cb21aae6cb648c7b0d4c95015c811fba302", + "2310754f2fd0e1c4e097d178f7d36e18c0362ee59c713f2a0157a9d9be066103", ); } diff --git a/src/provider/bn256_grumpkin.rs b/src/provider/bn256_grumpkin.rs index d80d7b59..d15779d6 100644 --- a/src/provider/bn256_grumpkin.rs +++ b/src/provider/bn256_grumpkin.rs @@ -1,10 +1,7 @@ //! This module implements the Nova traits for `bn256::Point`, `bn256::Scalar`, `grumpkin::Point`, `grumpkin::Scalar`. use crate::{ impl_traits, - provider::{ - msm::cpu_best_msm, - traits::{CompressedGroup, DlogGroup, PairingGroup}, - }, + provider::traits::{CompressedGroup, DlogGroup, PairingGroup}, traits::{Group, PrimeFieldExt, TranscriptReprTrait}, }; use digest::{ExtendableOutput, Update}; @@ -13,19 +10,19 @@ use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup, GroupEn use num_bigint::BigInt; use num_traits::Num; // Remove this when https://github.com/zcash/pasta_curves/issues/41 resolves +use halo2curves::{ + bn256::{ + pairing, G1Affine as Bn256Affine, G1Compressed as Bn256Compressed, G2Affine, G2Compressed, Gt, + G1 as Bn256Point, G2, + }, + grumpkin::{G1Affine as GrumpkinAffine, G1Compressed as GrumpkinCompressed, G1 as GrumpkinPoint}, + msm::best_multiexp, +}; use pasta_curves::arithmetic::{CurveAffine, CurveExt}; use rayon::prelude::*; use sha3::Shake256; use std::io::Read; -use halo2curves::bn256::{ - pairing, G1Affine as Bn256Affine, G1Compressed as Bn256Compressed, G2Affine, G2Compressed, Gt, - G1 as Bn256Point, G2, -}; -use halo2curves::grumpkin::{ - G1Affine as GrumpkinAffine, G1Compressed as GrumpkinCompressed, G1 as GrumpkinPoint, -}; - /// Re-exports that give access to the standard aliases used in the code base, for bn256 pub mod bn256 { pub use halo2curves::bn256::{Fq as Base, Fr as Scalar, G1Affine as Affine, G1 as Point}; @@ -93,7 +90,7 @@ impl DlogGroup for G2 { scalars: &[Self::Scalar], bases: &[Self::PreprocessedGroupElement], ) -> Self { - cpu_best_msm(scalars, bases) + best_multiexp(scalars, bases) } fn preprocessed(&self) -> Self::PreprocessedGroupElement { diff --git a/src/provider/mod.rs b/src/provider/mod.rs index 1236f129..b2739cec 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -14,7 +14,6 @@ pub(crate) mod traits; // crate-private modules mod keccak; -mod msm; use crate::{ provider::{ @@ -114,17 +113,11 @@ impl Engine for VestaEngine { #[cfg(test)] mod tests { - use crate::provider::{ - bn256_grumpkin::{bn256, grumpkin}, - msm::cpu_best_msm, - secp_secq::{secp256k1, secq256k1}, - traits::DlogGroup, - }; + use crate::provider::{bn256_grumpkin::bn256, secp_secq::secp256k1, traits::DlogGroup}; use digest::{ExtendableOutput, Update}; - use group::{ff::Field, Curve, Group}; - use halo2curves::{CurveAffine, CurveExt}; - use pasta_curves::{pallas, vesta}; - use rand_core::OsRng; + use group::Curve; + use halo2curves::CurveExt; + use pasta_curves::pallas; use sha3::Shake256; use std::io::Read; @@ -157,32 +150,6 @@ mod tests { }; } - fn test_msm_with>() { - let n = 8; - let coeffs = (0..n).map(|_| F::random(OsRng)).collect::>(); - let bases = (0..n) - .map(|_| A::from(A::generator() * F::random(OsRng))) - .collect::>(); - let naive = coeffs - .iter() - .zip(bases.iter()) - .fold(A::CurveExt::identity(), |acc, (coeff, base)| { - acc + *base * coeff - }); - - assert_eq!(naive, cpu_best_msm(&coeffs, &bases)) - } - - #[test] - fn test_msm() { - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - } - #[test] fn test_bn256_from_label() { impl_cycle_pair_test!(bn256); diff --git a/src/provider/msm.rs b/src/provider/msm.rs deleted file mode 100644 index 4c499a4f..00000000 --- a/src/provider/msm.rs +++ /dev/null @@ -1,151 +0,0 @@ -//! This module provides a multi-scalar multiplication routine -/// Adapted from zcash/halo2 -use ff::PrimeField; -use pasta_curves::{self, arithmetic::CurveAffine, group::Group as AnotherGroup}; -use rayon::{current_num_threads, prelude::*}; - -fn cpu_msm_serial(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { - let c = if bases.len() < 4 { - 1 - } else if bases.len() < 32 { - 3 - } else { - (f64::from(bases.len() as u32)).ln().ceil() as usize - }; - - fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { - let skip_bits = segment * c; - let skip_bytes = skip_bits / 8; - - if skip_bytes >= 32 { - return 0; - } - - let mut v = [0; 8]; - for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { - *v = *o; - } - - let mut tmp = u64::from_le_bytes(v); - tmp >>= skip_bits - (skip_bytes * 8); - tmp %= 1 << c; - - tmp as usize - } - - let segments = (256 / c) + 1; - - (0..segments) - .rev() - .fold(C::Curve::identity(), |mut acc, segment| { - (0..c).for_each(|_| acc = acc.double()); - - #[derive(Clone, Copy)] - enum Bucket { - None, - Affine(C), - Projective(C::Curve), - } - - impl Bucket { - fn add_assign(&mut self, other: &C) { - *self = match *self { - Bucket::None => Bucket::Affine(*other), - Bucket::Affine(a) => Bucket::Projective(a + *other), - Bucket::Projective(a) => Bucket::Projective(a + other), - } - } - - fn add(self, other: C::Curve) -> C::Curve { - match self { - Bucket::None => other, - Bucket::Affine(a) => other + a, - Bucket::Projective(a) => other + a, - } - } - } - - let mut buckets = vec![Bucket::None; (1 << c) - 1]; - - for (coeff, base) in coeffs.iter().zip(bases.iter()) { - let coeff = get_at::(segment, c, &coeff.to_repr()); - if coeff != 0 { - buckets[coeff - 1].add_assign(base); - } - } - - // Summation by parts - // e.g. 3a + 2b + 1c = a + - // (a) + b + - // ((a) + b) + c - let mut running_sum = C::Curve::identity(); - for exp in buckets.into_iter().rev() { - running_sum = exp.add(running_sum); - acc += &running_sum; - } - acc - }) -} - -/// Performs a multi-scalar-multiplication operation without GPU acceleration. -/// -/// This function will panic if coeffs and bases have a different length. -/// -/// This will use multithreading if beneficial. -/// Adapted from zcash/halo2 -pub(crate) fn cpu_best_msm(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { - assert_eq!(coeffs.len(), bases.len()); - - let num_threads = current_num_threads(); - if coeffs.len() > num_threads { - let chunk = coeffs.len() / num_threads; - coeffs - .par_chunks(chunk) - .zip(bases.par_chunks(chunk)) - .map(|(coeffs, bases)| cpu_msm_serial(coeffs, bases)) - .reduce(C::Curve::identity, |sum, evl| sum + evl) - } else { - cpu_msm_serial(coeffs, bases) - } -} - -#[cfg(test)] -mod tests { - use super::cpu_best_msm; - - use crate::provider::{ - bn256_grumpkin::{bn256, grumpkin}, - secp_secq::{secp256k1, secq256k1}, - }; - use group::{ff::Field, Group}; - use halo2curves::CurveAffine; - use pasta_curves::{pallas, vesta}; - use rand_core::OsRng; - - fn test_msm_with>() { - let n = 8; - let coeffs = (0..n).map(|_| F::random(OsRng)).collect::>(); - let bases = (0..n) - .map(|_| A::from(A::generator() * F::random(OsRng))) - .collect::>(); - let naive = coeffs - .iter() - .zip(bases.iter()) - .fold(A::CurveExt::identity(), |acc, (coeff, base)| { - acc + *base * coeff - }); - let msm = cpu_best_msm(&coeffs, &bases); - - assert_eq!(naive, msm) - } - - #[test] - fn test_msm() { - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - test_msm_with::(); - } -} diff --git a/src/provider/pasta.rs b/src/provider/pasta.rs index 8f4b907b..e529af49 100644 --- a/src/provider/pasta.rs +++ b/src/provider/pasta.rs @@ -1,13 +1,11 @@ //! This module implements the Nova traits for `pallas::Point`, `pallas::Scalar`, `vesta::Point`, `vesta::Scalar`. use crate::{ - provider::{ - msm::cpu_best_msm, - traits::{CompressedGroup, DlogGroup}, - }, + provider::traits::{CompressedGroup, DlogGroup}, traits::{Group, PrimeFieldExt, TranscriptReprTrait}, }; use digest::{ExtendableOutput, Update}; use ff::{FromUniformBytes, PrimeField}; +use halo2curves::msm::best_multiexp; use num_bigint::BigInt; use num_traits::Num; use pasta_curves::{ @@ -82,10 +80,10 @@ macro_rules! impl_traits { if scalars.len() >= 128 { pasta_msm::$name(bases, scalars) } else { - cpu_best_msm(scalars, bases) + best_multiexp(scalars, bases) } #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - cpu_best_msm(scalars, bases) + best_multiexp(scalars, bases) } fn preprocessed(&self) -> Self::PreprocessedGroupElement { diff --git a/src/provider/secp_secq.rs b/src/provider/secp_secq.rs index 39b58209..3f2ee1d3 100644 --- a/src/provider/secp_secq.rs +++ b/src/provider/secp_secq.rs @@ -1,15 +1,17 @@ //! This module implements the Nova traits for `secp::Point`, `secp::Scalar`, `secq::Point`, `secq::Scalar`. use crate::{ impl_traits, - provider::{ - msm::cpu_best_msm, - traits::{CompressedGroup, DlogGroup}, - }, + provider::traits::{CompressedGroup, DlogGroup}, traits::{Group, PrimeFieldExt, TranscriptReprTrait}, }; use digest::{ExtendableOutput, Update}; use ff::{FromUniformBytes, PrimeField}; use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup, GroupEncoding}; +use halo2curves::{ + msm::best_multiexp, + secp256k1::{Secp256k1, Secp256k1Affine, Secp256k1Compressed}, + secq256k1::{Secq256k1, Secq256k1Affine, Secq256k1Compressed}, +}; use num_bigint::BigInt; use num_traits::Num; use pasta_curves::arithmetic::{CurveAffine, CurveExt}; @@ -17,9 +19,6 @@ use rayon::prelude::*; use sha3::Shake256; use std::io::Read; -use halo2curves::secp256k1::{Secp256k1, Secp256k1Affine, Secp256k1Compressed}; -use halo2curves::secq256k1::{Secq256k1, Secq256k1Affine, Secq256k1Compressed}; - /// Re-exports that give access to the standard aliases used in the code base, for secp pub mod secp256k1 { pub use halo2curves::secp256k1::{ diff --git a/src/provider/traits.rs b/src/provider/traits.rs index fd9761a6..5d569498 100644 --- a/src/provider/traits.rs +++ b/src/provider/traits.rs @@ -145,7 +145,7 @@ macro_rules! impl_traits { scalars: &[Self::Scalar], bases: &[Self::PreprocessedGroupElement], ) -> Self { - cpu_best_msm(scalars, bases) + best_multiexp(scalars, bases) } fn preprocessed(&self) -> Self::PreprocessedGroupElement {