From e1c0135b4f69a331a563499c3e4b6685f39ddf30 Mon Sep 17 00:00:00 2001 From: Ming Date: Thu, 19 Sep 2024 20:16:37 +0800 Subject: [PATCH] sumcheck protocol support mixed num_vars monomial form (#235) ### Goal To make sumcheck protocol support different num_vars, aiming for - [x] minimal change to sumcheck protocol, make verifier remain the same. no extra meta data passed from prover, and what prover have just mle and it's eval size Besides this PR also remove some parallism in verifier since it's unnecessary for relative low cost. ### design rationale (Also comments in codebase for reference) To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars, for it evaluation value we need to times 2^(max_num_vars - num_vars) E.g. Giving multivariate poly $f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n$ For i round univariate poly, $f^i(x)$ $f^i[0] = \sum_b f(r, 0, b), b \in [0, 1]^{n-i-1}, r \in {F}^{n-i-1}$ chanllenge get from prev rounds = $\sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n - n'$ = $2^{(|b| - |b1|)} * \sum_{b1} f_1(r, 0, b1) + \sum_b f_2(r, 0, b)$ same applied on f^i[1] It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value ### benchmark benchmark with ceno_zkvm `riscv_add`, and gkr `keccak` both remain the same and no impact. You might see some redundancy coding style, but this is for retain the best performance. I tried other variants and it impact benchmark results ### scope Related to #109 #210 .... To address #126 #127 This enhance protocol features potiential can be used for `range table-circuit`, `init/final-memory`, `cpu-init/cpu-final halt` to make selector sumcheck support batching different num_instance witin. --- Cargo.lock | 1 + ceno_zkvm/src/virtual_polys.rs | 105 +++++++- multilinear_extensions/Cargo.toml | 1 + multilinear_extensions/src/mle.rs | 10 +- multilinear_extensions/src/util.rs | 5 + multilinear_extensions/src/virtual_poly_v2.rs | 53 ++-- sumcheck/src/prover_v2.rs | 254 +++++++++++++----- sumcheck/src/test.rs | 2 + sumcheck/src/util.rs | 2 +- sumcheck/src/verifier.rs | 11 +- 10 files changed, 336 insertions(+), 108 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c47791660..357fad1cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1198,6 +1198,7 @@ dependencies = [ "ff", "ff_ext", "goldilocks", + "itertools 0.12.1", "log", "rayon", "serde", diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 375f15b8c..5d81ea042 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -168,19 +168,28 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { #[cfg(test)] mod tests { + use ark_std::test_rng; + use ff_ext::ExtensionField; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use multilinear_extensions::{mle::IntoMLE, virtual_poly_v2::ArcMultilinearExtension}; + use multilinear_extensions::{ + mle::IntoMLE, + virtual_poly::VPAuxInfo, + virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}, + }; + use sumcheck::structs::{IOPProverStateV2, IOPVerifierState}; + use transcript::Transcript; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, expression::{Expression, ToExpr}, virtual_polys::VirtualPolynomials, }; + use ff::Field; + type E = GoldilocksExt2; #[test] fn test_add_mle_list_by_expr() { - type E = GoldilocksExt2; let mut cs = ConstraintSystem::new(|| "test_root"); let mut cb = CircuitBuilder::::new(&mut cs); let x = cb.create_witin(|| "x").unwrap(); @@ -218,4 +227,96 @@ mod tests { assert!(distrinct_zerocheck_terms_set.len() == 1); assert!(virtual_polys.degree() == 3); } + + #[test] + fn test_sumcheck_different_degree() { + let max_num_vars = 3; + let fn_eval = |fs: &[ArcMultilinearExtension]| -> E { + let base_2 = ::BaseField::from(2); + + let evals = fs.iter().fold( + vec![::BaseField::ONE; 1 << fs[0].num_vars()], + |mut evals, f| { + evals + .iter_mut() + .zip(f.get_base_field_vec()) + .for_each(|(e, v)| { + *e *= v; + }); + evals + }, + ); + + <::BaseField as std::convert::Into>::into( + evals.iter().sum::<::BaseField>() + * base_2.pow([(max_num_vars - fs[0].num_vars()) as u64]), + ) + }; + let num_threads = 1; + let mut transcript = Transcript::new(b"test"); + + let mut rng = test_rng(); + + let f1: [ArcMultilinearExtension; 2] = std::array::from_fn(|_| { + (0..1 << (max_num_vars - 2)) + .map(|_| ::BaseField::random(&mut rng)) + .collect_vec() + .into_mle() + .into() + }); + let f2: [ArcMultilinearExtension; 1] = std::array::from_fn(|_| { + (0..1 << (max_num_vars)) + .map(|_| ::BaseField::random(&mut rng)) + .collect_vec() + .into_mle() + .into() + }); + let f3: [ArcMultilinearExtension; 3] = std::array::from_fn(|_| { + (0..1 << (max_num_vars - 1)) + .map(|_| ::BaseField::random(&mut rng)) + .collect_vec() + .into_mle() + .into() + }); + + let mut virtual_polys = VirtualPolynomials::::new(num_threads, max_num_vars); + + virtual_polys.add_mle_list(f1.iter().collect(), E::ONE); + virtual_polys.add_mle_list(f2.iter().collect(), E::ONE); + virtual_polys.add_mle_list(f3.iter().collect(), E::ONE); + + let (sumcheck_proofs, _) = IOPProverStateV2::prove_batch_polys( + num_threads, + virtual_polys.get_batched_polys(), + &mut transcript, + ); + + let mut transcript = Transcript::new(b"test"); + let subclaim = IOPVerifierState::::verify( + fn_eval(&f1) + fn_eval(&f2) + fn_eval(&f3), + &sumcheck_proofs, + &VPAuxInfo { + max_degree: 3, + num_variables: max_num_vars, + phantom: std::marker::PhantomData, + }, + &mut transcript, + ); + + let mut verifier_poly = VirtualPolynomialV2::new(max_num_vars); + verifier_poly.add_mle_list(f1.to_vec(), E::ONE); + verifier_poly.add_mle_list(f2.to_vec(), E::ONE); + verifier_poly.add_mle_list(f3.to_vec(), E::ONE); + assert!( + verifier_poly.evaluate( + subclaim + .point + .iter() + .map(|c| c.elements) + .collect::>() + .as_ref() + ) == subclaim.expected_evaluation, + "wrong subclaim" + ); + } } diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index b4bf646ba..903dd776a 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -11,6 +11,7 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } tracing-flame = "0.2.0" ff_ext = { path = "../ff_ext" } +itertools = "0.12.1" ark-std.workspace = true ff.workspace = true goldilocks.workspace = true diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index a4530919c..8ce82ae02 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -124,7 +124,9 @@ pub trait IntoMLEs: Sized { fn into_mles(self) -> Vec; } -impl> IntoMLEs> for Vec> { +impl> IntoMLEs> + for Vec> +{ fn into_mles(self) -> Vec> { self.into_iter().map(|v| v.into_mle()).collect() } @@ -1000,12 +1002,6 @@ macro_rules! op_mle { match &$a.evaluations() { $crate::mle::FieldType::Base(a) => { let $tmp_a = if let Some((start, offset)) = $a.evaluations_range() { - println!( - "op_mle start {}, offset {}, a.len {}", - start, - offset, - a.len() - ); &a[start..][..offset] } else { &a[..] diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index c25e1f8eb..28e4f8284 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -25,3 +25,8 @@ pub fn create_uninit_vec(len: usize) -> Vec> { unsafe { vec.set_len(len) }; vec } + +#[inline(always)] +pub fn largest_even_below(n: usize) -> usize { + if n % 2 == 0 { n } else { n.saturating_sub(1) } +} diff --git a/multilinear_extensions/src/virtual_poly_v2.rs b/multilinear_extensions/src/virtual_poly_v2.rs index c8a46a367..dcf588baf 100644 --- a/multilinear_extensions/src/virtual_poly_v2.rs +++ b/multilinear_extensions/src/virtual_poly_v2.rs @@ -6,6 +6,7 @@ use crate::{ }; use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; +use itertools::Itertools; use serde::{Deserialize, Serialize}; pub type ArcMultilinearExtension<'a, E> = @@ -55,8 +56,8 @@ pub struct VirtualPolynomialV2<'a, E: ExtensionField> { pub struct VPAuxInfo { /// max number of multiplicands in each product pub max_degree: usize, - /// number of variables of the polynomial - pub num_variables: usize, + /// max number of variables of the polynomial + pub max_num_variables: usize, /// Associated field #[doc(hidden)] pub phantom: PhantomData, @@ -69,12 +70,12 @@ impl AsRef<[u8]> for VPAuxInfo { } impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { - /// Creates an empty virtual polynomial with `num_variables`. - pub fn new(num_variables: usize) -> Self { + /// Creates an empty virtual polynomial with `max_num_variables`. + pub fn new(max_num_variables: usize) -> Self { VirtualPolynomialV2 { aux_info: VPAuxInfo { max_degree: 0, - num_variables, + max_num_variables, phantom: PhantomData, }, products: Vec::new(), @@ -93,7 +94,7 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { aux_info: VPAuxInfo { // The max degree is the max degree of any individual variable max_degree: 1, - num_variables: mle.num_vars(), + max_num_variables: mle.num_vars(), phantom: PhantomData, }, // here `0` points to the first polynomial of `flattened_ml_extensions` @@ -104,8 +105,10 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { } /// Add a product of list of multilinear extensions to self - /// Returns an error if the list is empty, or the MLE has a different - /// `num_vars()` from self. + /// Returns an error if the list is empty. + /// + /// mle in mle_list must be in same num_vars() in same product, + /// while different product can have different num_vars() /// /// The MLEs will be multiplied together, and then multiplied by the scalar /// `coefficient`. @@ -114,18 +117,20 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { let mut indexed_product = Vec::with_capacity(mle_list.len()); assert!(!mle_list.is_empty(), "input mle_list is empty"); + // sanity check: all mle in mle_list must have same num_vars() + assert!( + mle_list + .iter() + .map(|m| { + assert!(m.num_vars() <= self.aux_info.max_num_variables); + m.num_vars() + }) + .all_equal() + ); self.aux_info.max_degree = max(self.aux_info.max_degree, mle_list.len()); for mle in mle_list { - assert_eq!( - mle.num_vars(), - self.aux_info.num_variables, - "product has a multiplicand with wrong number of variables {} vs {}", - mle.num_vars(), - self.aux_info.num_variables - ); - let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize; if let Some(index) = self.raw_pointers_lookup_table.get(&mle_ptr) { indexed_product.push(*index) @@ -163,10 +168,10 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { assert_eq!( mle.num_vars(), - self.aux_info.num_variables, + self.aux_info.max_num_variables, "product has a multiplicand with wrong number of variables {} vs {}", mle.num_vars(), - self.aux_info.num_variables + self.aux_info.max_num_variables ); let mle_ptr = Arc::as_ptr(&mle) as *const () as usize; @@ -200,17 +205,17 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { let start = start_timer!(|| "evaluation"); assert_eq!( - self.aux_info.num_variables, + self.aux_info.max_num_variables, point.len(), "wrong number of variables {} vs {}", - self.aux_info.num_variables, + self.aux_info.max_num_variables, point.len() ); let evals: Vec = self .flattened_ml_extensions .iter() - .map(|x| x.evaluate(point)) + .map(|x| x.evaluate(&point[0..x.num_vars()])) .collect(); let res = self @@ -225,11 +230,11 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { /// Print out the evaluation map for testing. Panic if the num_vars() > 5. pub fn print_evals(&self) { - if self.aux_info.num_variables > 5 { + if self.aux_info.max_num_variables > 5 { panic!("this function is used for testing only. cannot print more than 5 num_vars()") } - for i in 0..1 << self.aux_info.num_variables { - let point = bit_decompose(i, self.aux_info.num_variables); + for i in 0..1 << self.aux_info.max_num_variables { + let point = bit_decompose(i, self.aux_info.max_num_variables); let point_fr: Vec = point.iter().map(|&x| E::from(x as u64)).collect(); println!("{} {:?}", i, self.evaluate(point_fr.as_ref())) } diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index ce8f0418b..b5900ca69 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -8,6 +8,7 @@ use multilinear_extensions::{ commutative_op_mle_pair, mle::{DenseMultilinearExtension, MultilinearExtension}, op_mle, op_mle_3, + util::largest_even_below, virtual_poly_v2::VirtualPolynomialV2, }; use rayon::{ @@ -47,11 +48,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { assert!( polys .iter() - .map(|poly| (poly.aux_info.num_variables, poly.aux_info.max_degree)) + .map(|poly| (poly.aux_info.max_num_variables, poly.aux_info.max_degree)) .all_equal() ); let (num_variables, max_degree) = ( - polys[0].aux_info.num_variables, + polys[0].aux_info.max_num_variables, polys[0].aux_info.max_degree, ); @@ -81,8 +82,6 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { }) .collect::>(); - // rayon::in_place_scope( - // let (mut prover_states, mut prover_msgs) = rayon::in_place_scope( let scoped_fn = |s: &Scope<'a>| { // spawn extra #(max_thread_id - 1) work threads, whereas the main-thread be the last // work thread @@ -120,7 +119,9 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { .iter_mut() .for_each(|mle| { let mle = Arc::get_mut(mle).unwrap(); - mle.fix_variables_in_place(&[p.elements]); + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]); + } }); tx_prover_state .send(Some((thread_id, prover_state))) @@ -189,10 +190,20 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { .for_each(|mle| { if num_variables == 1 { // first time fix variable should be create new instance - *mle = mle.fix_variables(&[p.elements]).into(); + if mle.num_vars() > 0 { + *mle = mle.fix_variables(&[p.elements]).into(); + } else { + *mle = + Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + mle.get_base_field_vec().to_vec(), + )) + } } else { let mle = Arc::get_mut(mle).unwrap(); - mle.fix_variables_in_place(&[p.elements]); + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]); + } } }); tx_prover_state @@ -292,9 +303,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { |mle: &mut Arc< dyn MultilinearExtension>, >| { - Arc::get_mut(mle) - .unwrap() - .fix_variables_in_place(&[p.elements]); + if mle.num_vars() > 0 { + Arc::get_mut(mle) + .unwrap() + .fix_variables_in_place(&[p.elements]); + } }, ); }; @@ -325,7 +338,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ) -> Self { let start = start_timer!(|| "sum check prover init"); assert_ne!( - polynomial.aux_info.num_variables, 0, + polynomial.aux_info.max_num_variables, 0, "Attempt to prove a constant." ); end_timer!(start); @@ -333,7 +346,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let max_degree = polynomial.aux_info.max_degree; assert!(extrapolation_aux.len() == max_degree - 1); Self { - challenges: Vec::with_capacity(polynomial.aux_info.num_variables), + challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, extrapolation_aux, @@ -353,7 +366,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); assert!( - self.round < self.poly.aux_info.num_variables, + self.round < self.poly.aux_info.max_num_variables, "Prover is not active" ); @@ -385,7 +398,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { if self.challenges.len() == 1 { self.poly.flattened_ml_extensions.iter_mut().for_each(|f| { - *f = Arc::new(f.fix_variables(&[r.elements])); + if f.num_vars() > 0 { + *f = Arc::new(f.fix_variables(&[r.elements])); + } else { + panic!("calling sumcheck on constant") + } }); } else { self.poly @@ -396,7 +413,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { // reason .map(Arc::get_mut) .for_each(|f| { - f.unwrap().fix_variables_in_place(&[r.elements]); + if let Some(f) = f { + if f.num_vars() > 0 { + f.fix_variables_in_place(&[r.elements]); + } + } }); } } @@ -407,6 +428,16 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) + // + // To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars, + // for it evaluation value we need to times 2^(max_num_vars - num_vars) + // E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n + // For i round univariate poly, f^i(x) + // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds + // = \sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n - n' + // = 2^(|b| - |b1|) * \sum_b1 f_1(r, 0, b1) + \sum_b f_2(r, 0, b) + // same applied on f^i[1] + // It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value let span = entered_span!("products_sum"); let AdditiveVec(products_sum) = self.poly.products.iter().fold( AdditiveVec::new(self.poly.aux_info.max_degree + 1), @@ -418,13 +449,24 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let f = &self.poly.flattened_ml_extensions[products[0]]; op_mle! { |f| { - (0..f.len()) + let res = (0..largest_even_below(f.len())) .step_by(2) - .fold(AdditiveArray::(array::from_fn(|_| 0.into())), |mut acc, b| { + .fold(AdditiveArray::<_, 2>(array::from_fn(|_| 0.into())), |mut acc, b| { acc.0[0] += f[b]; acc.0[1] += f[b+1]; acc - }) + }); + let res = if f.len() == 1 { + AdditiveArray::<_, 2>([f[0]; 2]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } }, |sum| AdditiveArray(sum.0.map(E::from)) } @@ -436,16 +478,28 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { &self.poly.flattened_ml_extensions[products[1]], ); commutative_op_mle_pair!( - |f, g| (0..f.len()).step_by(2).fold( - AdditiveArray::(array::from_fn(|_| 0.into())), - |mut acc, b| { - acc.0[0] += f[b] * g[b]; - acc.0[1] += f[b + 1] * g[b + 1]; - acc.0[2] += - (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); - acc + |f, g| { + let res = (0..largest_even_below(f.len())).step_by(2).fold( + AdditiveArray::<_, 3>(array::from_fn(|_| 0.into())), + |mut acc, b| { + acc.0[0] += f[b] * g[b]; + acc.0[1] += f[b + 1] * g[b + 1]; + acc.0[2] += + (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); + acc + }); + let res = if f.len() == 1 { + AdditiveArray::<_, 3>([f[0] * g[0]; 3]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res } - ), + }, |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() @@ -457,23 +511,36 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { &self.poly.flattened_ml_extensions[products[2]], ); op_mle_3!( - |f1, f2, f3| (0..f1.len()) - .step_by(2) - .map(|b| { - // f = c x + d - let c1 = f1[b + 1] - f1[b]; - let c2 = f2[b + 1] - f2[b]; - let c3 = f3[b + 1] - f3[b]; - AdditiveArray([ - f1[b] * (f2[b] * f3[b]), - f1[b + 1] * (f2[b + 1] * f3[b + 1]), - (c1 + f1[b + 1]) - * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), - (c1 + c1 + f1[b + 1]) - * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), - ]) - }) - .sum::>(), + |f1, f2, f3| { + let res = (0..largest_even_below(f1.len())) + .step_by(2) + .map(|b| { + // f = c x + d + let c1 = f1[b + 1] - f1[b]; + let c2 = f2[b + 1] - f2[b]; + let c3 = f3[b + 1] - f3[b]; + AdditiveArray([ + f1[b] * (f2[b] * f3[b]), + f1[b + 1] * (f2[b + 1] * f3[b + 1]), + (c1 + f1[b + 1]) + * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), + (c1 + c1 + f1[b + 1]) + * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), + ]) + }) + .sum::>(); + let res = if f1.len() == 1 { + AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() @@ -538,7 +605,8 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { poly: VirtualPolynomialV2<'a, E>, transcript: &mut Transcript, ) -> (IOPProof, IOPProverStateV2<'a, E>) { - let (num_variables, max_degree) = (poly.aux_info.num_variables, poly.aux_info.max_degree); + let (num_variables, max_degree) = + (poly.aux_info.max_num_variables, poly.aux_info.max_degree); // return empty proof when target polymonial is constant if num_variables == 0 { @@ -588,9 +656,14 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { .par_iter_mut() .for_each(|mle| { if let Some(mle) = Arc::get_mut(mle) { - mle.fix_variables_in_place_parallel(&[p.elements]) + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]) + } } else { - *mle = mle.fix_variables(&[p.elements]).into() + *mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + mle.get_base_field_vec().to_vec(), + )) } }); }; @@ -616,13 +689,13 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { pub(crate) fn prover_init_parallel(polynomial: VirtualPolynomialV2<'a, E>) -> Self { let start = start_timer!(|| "sum check prover init"); assert_ne!( - polynomial.aux_info.num_variables, 0, + polynomial.aux_info.max_num_variables, 0, "Attempt to prove a constant." ); let max_degree = polynomial.aux_info.max_degree; let prover_state = Self { - challenges: Vec::with_capacity(polynomial.aux_info.num_variables), + challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, extrapolation_aux: (1..max_degree) @@ -651,7 +724,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); assert!( - self.round < self.poly.aux_info.num_variables, + self.round < self.poly.aux_info.max_num_variables, "Prover is not active" ); @@ -682,7 +755,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { .flattened_ml_extensions .par_iter_mut() .for_each(|f| { - *f = Arc::new(f.fix_variables_parallel(&[r.elements])); + if f.num_vars() > 0 { + *f = Arc::new(f.fix_variables_parallel(&[r.elements])); + } else { + panic!("calling sumcheck on constant") + } }); } else { self.poly @@ -693,7 +770,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { // reason .map(Arc::get_mut) .for_each(|f| { - f.unwrap().fix_variables_in_place_parallel(&[r.elements]); + if let Some(f) = f { + if f.num_vars() > 0 { + f.fix_variables_in_place_parallel(&[r.elements]) + } + } }); } } @@ -718,17 +799,30 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { 1 => { let f = &self.poly.flattened_ml_extensions[products[0]]; op_mle! { - |f| (0..f.len()) - .into_par_iter() - .step_by(2) - .with_min_len(64) - .map(|b| { - AdditiveArray([ - f[b], - f[b + 1] - ]) - }) - .sum::>(), + |f| { + let res = (0..largest_even_below(f.len())) + .into_par_iter() + .step_by(2) + .with_min_len(64) + .map(|b| { + AdditiveArray([ + f[b], + f[b + 1] + ]) + }) + .sum::>(); + let res = if f.len() == 1 { + AdditiveArray::<_, 2>([f[0]; 2]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, |sum| AdditiveArray(sum.0.map(E::from)) } .to_vec() @@ -739,7 +833,8 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { &self.poly.flattened_ml_extensions[products[1]], ); commutative_op_mle_pair!( - |f, g| (0..f.len()) + |f, g| { + let res = (0..largest_even_below(f.len())) .into_par_iter() .step_by(2) .with_min_len(64) @@ -751,7 +846,19 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { * (g[b + 1] + g[b + 1] - g[b]), ]) }) - .sum::>(), + .sum::>(); + let res = if f.len() == 1 { + AdditiveArray::<_, 3>([f[0] * g[0]; 3]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() @@ -763,7 +870,8 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { &self.poly.flattened_ml_extensions[products[2]], ); op_mle_3!( - |f1, f2, f3| (0..f1.len()) + |f1, f2, f3| { + let res = (0..largest_even_below(f1.len())) .step_by(2) .map(|b| { // f = c x + d @@ -779,7 +887,19 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), ]) }) - .sum::>(), + .sum::>(); + let res = if f1.len() == 1 { + AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index 0188746c7..94e367be8 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -13,6 +13,8 @@ use crate::{ util::interpolate_uni_poly, }; +// TODO add more tests related to various num_vars combination after PR #162 + fn test_sumcheck( nv: usize, num_multiplicands_range: (usize, usize), diff --git a/sumcheck/src/util.rs b/sumcheck/src/util.rs index eb89a5655..08b87487d 100644 --- a/sumcheck/src/util.rs +++ b/sumcheck/src/util.rs @@ -235,7 +235,7 @@ pub(crate) fn merge_sumcheck_polys_v2<'a, E: ExtensionField>( ) -> VirtualPolynomialV2<'a, E> { let log2_max_thread_id = ceil_log2(max_thread_id); let mut poly = prover_states[0].poly.clone(); // giving only one evaluation left, this clone is low cost. - poly.aux_info.num_variables = log2_max_thread_id; // size_log2 variates sumcheck + poly.aux_info.max_num_variables = log2_max_thread_id; // size_log2 variates sumcheck for i in 0..poly.flattened_ml_extensions.len() { let ml_ext = DenseMultilinearExtension::from_evaluations_ext_vec( log2_max_thread_id, diff --git a/sumcheck/src/verifier.rs b/sumcheck/src/verifier.rs index c402424bd..f67c09103 100644 --- a/sumcheck/src/verifier.rs +++ b/sumcheck/src/verifier.rs @@ -1,7 +1,6 @@ use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; use multilinear_extensions::virtual_poly::VPAuxInfo; -use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; use transcript::{Challenge, Transcript}; use crate::{ @@ -122,12 +121,10 @@ impl IOPVerifierState { // the deferred check during the interactive phase: // 2. set `expected` to P(r)` - let mut expected_vec = self .polynomials_received - .clone() - .into_par_iter() - .zip(self.challenges.clone().into_par_iter()) + .iter() + .zip(self.challenges.iter()) .map(|(evaluations, challenge)| { if evaluations.len() != self.max_degree + 1 { panic!( @@ -136,11 +133,11 @@ impl IOPVerifierState { self.max_degree + 1 ); } - interpolate_uni_poly::(&evaluations, challenge.elements) + interpolate_uni_poly::(evaluations, challenge.elements) }) .collect::>(); - // insert the asserted_sum to the first position of the expected vector + // l-append asserted_sum to the first position of the expected vector expected_vec.insert(0, *asserted_sum); for (i, (evaluations, &expected)) in self