diff --git a/Cargo.lock b/Cargo.lock index c47791660..357fad1cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1198,6 +1198,7 @@ dependencies = [ "ff", "ff_ext", "goldilocks", + "itertools 0.12.1", "log", "rayon", "serde", diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 375f15b8c..5d81ea042 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -168,19 +168,28 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { #[cfg(test)] mod tests { + use ark_std::test_rng; + use ff_ext::ExtensionField; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use multilinear_extensions::{mle::IntoMLE, virtual_poly_v2::ArcMultilinearExtension}; + use multilinear_extensions::{ + mle::IntoMLE, + virtual_poly::VPAuxInfo, + virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}, + }; + use sumcheck::structs::{IOPProverStateV2, IOPVerifierState}; + use transcript::Transcript; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, expression::{Expression, ToExpr}, virtual_polys::VirtualPolynomials, }; + use ff::Field; + type E = GoldilocksExt2; #[test] fn test_add_mle_list_by_expr() { - type E = GoldilocksExt2; let mut cs = ConstraintSystem::new(|| "test_root"); let mut cb = CircuitBuilder::::new(&mut cs); let x = cb.create_witin(|| "x").unwrap(); @@ -218,4 +227,96 @@ mod tests { assert!(distrinct_zerocheck_terms_set.len() == 1); assert!(virtual_polys.degree() == 3); } + + #[test] + fn test_sumcheck_different_degree() { + let max_num_vars = 3; + let fn_eval = |fs: &[ArcMultilinearExtension]| -> E { + let base_2 = ::BaseField::from(2); + + let evals = fs.iter().fold( + vec![::BaseField::ONE; 1 << fs[0].num_vars()], + |mut evals, f| { + evals + .iter_mut() + .zip(f.get_base_field_vec()) + .for_each(|(e, v)| { + *e *= v; + }); + evals + }, + ); + + <::BaseField as std::convert::Into>::into( + evals.iter().sum::<::BaseField>() + * base_2.pow([(max_num_vars - fs[0].num_vars()) as u64]), + ) + }; + let num_threads = 1; + let mut transcript = Transcript::new(b"test"); + + let mut rng = test_rng(); + + let f1: [ArcMultilinearExtension; 2] = std::array::from_fn(|_| { + (0..1 << (max_num_vars - 2)) + .map(|_| ::BaseField::random(&mut rng)) + .collect_vec() + .into_mle() + .into() + }); + let f2: [ArcMultilinearExtension; 1] = std::array::from_fn(|_| { + (0..1 << (max_num_vars)) + .map(|_| ::BaseField::random(&mut rng)) + .collect_vec() + .into_mle() + .into() + }); + let f3: [ArcMultilinearExtension; 3] = std::array::from_fn(|_| { + (0..1 << (max_num_vars - 1)) + .map(|_| ::BaseField::random(&mut rng)) + .collect_vec() + .into_mle() + .into() + }); + + let mut virtual_polys = VirtualPolynomials::::new(num_threads, max_num_vars); + + virtual_polys.add_mle_list(f1.iter().collect(), E::ONE); + virtual_polys.add_mle_list(f2.iter().collect(), E::ONE); + virtual_polys.add_mle_list(f3.iter().collect(), E::ONE); + + let (sumcheck_proofs, _) = IOPProverStateV2::prove_batch_polys( + num_threads, + virtual_polys.get_batched_polys(), + &mut transcript, + ); + + let mut transcript = Transcript::new(b"test"); + let subclaim = IOPVerifierState::::verify( + fn_eval(&f1) + fn_eval(&f2) + fn_eval(&f3), + &sumcheck_proofs, + &VPAuxInfo { + max_degree: 3, + num_variables: max_num_vars, + phantom: std::marker::PhantomData, + }, + &mut transcript, + ); + + let mut verifier_poly = VirtualPolynomialV2::new(max_num_vars); + verifier_poly.add_mle_list(f1.to_vec(), E::ONE); + verifier_poly.add_mle_list(f2.to_vec(), E::ONE); + verifier_poly.add_mle_list(f3.to_vec(), E::ONE); + assert!( + verifier_poly.evaluate( + subclaim + .point + .iter() + .map(|c| c.elements) + .collect::>() + .as_ref() + ) == subclaim.expected_evaluation, + "wrong subclaim" + ); + } } diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index b4bf646ba..903dd776a 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -11,6 +11,7 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } tracing-flame = "0.2.0" ff_ext = { path = "../ff_ext" } +itertools = "0.12.1" ark-std.workspace = true ff.workspace = true goldilocks.workspace = true diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index a4530919c..8ce82ae02 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -124,7 +124,9 @@ pub trait IntoMLEs: Sized { fn into_mles(self) -> Vec; } -impl> IntoMLEs> for Vec> { +impl> IntoMLEs> + for Vec> +{ fn into_mles(self) -> Vec> { self.into_iter().map(|v| v.into_mle()).collect() } @@ -1000,12 +1002,6 @@ macro_rules! op_mle { match &$a.evaluations() { $crate::mle::FieldType::Base(a) => { let $tmp_a = if let Some((start, offset)) = $a.evaluations_range() { - println!( - "op_mle start {}, offset {}, a.len {}", - start, - offset, - a.len() - ); &a[start..][..offset] } else { &a[..] diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index c25e1f8eb..28e4f8284 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -25,3 +25,8 @@ pub fn create_uninit_vec(len: usize) -> Vec> { unsafe { vec.set_len(len) }; vec } + +#[inline(always)] +pub fn largest_even_below(n: usize) -> usize { + if n % 2 == 0 { n } else { n.saturating_sub(1) } +} diff --git a/multilinear_extensions/src/virtual_poly_v2.rs b/multilinear_extensions/src/virtual_poly_v2.rs index c8a46a367..dcf588baf 100644 --- a/multilinear_extensions/src/virtual_poly_v2.rs +++ b/multilinear_extensions/src/virtual_poly_v2.rs @@ -6,6 +6,7 @@ use crate::{ }; use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; +use itertools::Itertools; use serde::{Deserialize, Serialize}; pub type ArcMultilinearExtension<'a, E> = @@ -55,8 +56,8 @@ pub struct VirtualPolynomialV2<'a, E: ExtensionField> { pub struct VPAuxInfo { /// max number of multiplicands in each product pub max_degree: usize, - /// number of variables of the polynomial - pub num_variables: usize, + /// max number of variables of the polynomial + pub max_num_variables: usize, /// Associated field #[doc(hidden)] pub phantom: PhantomData, @@ -69,12 +70,12 @@ impl AsRef<[u8]> for VPAuxInfo { } impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { - /// Creates an empty virtual polynomial with `num_variables`. - pub fn new(num_variables: usize) -> Self { + /// Creates an empty virtual polynomial with `max_num_variables`. + pub fn new(max_num_variables: usize) -> Self { VirtualPolynomialV2 { aux_info: VPAuxInfo { max_degree: 0, - num_variables, + max_num_variables, phantom: PhantomData, }, products: Vec::new(), @@ -93,7 +94,7 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { aux_info: VPAuxInfo { // The max degree is the max degree of any individual variable max_degree: 1, - num_variables: mle.num_vars(), + max_num_variables: mle.num_vars(), phantom: PhantomData, }, // here `0` points to the first polynomial of `flattened_ml_extensions` @@ -104,8 +105,10 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { } /// Add a product of list of multilinear extensions to self - /// Returns an error if the list is empty, or the MLE has a different - /// `num_vars()` from self. + /// Returns an error if the list is empty. + /// + /// mle in mle_list must be in same num_vars() in same product, + /// while different product can have different num_vars() /// /// The MLEs will be multiplied together, and then multiplied by the scalar /// `coefficient`. @@ -114,18 +117,20 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { let mut indexed_product = Vec::with_capacity(mle_list.len()); assert!(!mle_list.is_empty(), "input mle_list is empty"); + // sanity check: all mle in mle_list must have same num_vars() + assert!( + mle_list + .iter() + .map(|m| { + assert!(m.num_vars() <= self.aux_info.max_num_variables); + m.num_vars() + }) + .all_equal() + ); self.aux_info.max_degree = max(self.aux_info.max_degree, mle_list.len()); for mle in mle_list { - assert_eq!( - mle.num_vars(), - self.aux_info.num_variables, - "product has a multiplicand with wrong number of variables {} vs {}", - mle.num_vars(), - self.aux_info.num_variables - ); - let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize; if let Some(index) = self.raw_pointers_lookup_table.get(&mle_ptr) { indexed_product.push(*index) @@ -163,10 +168,10 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { assert_eq!( mle.num_vars(), - self.aux_info.num_variables, + self.aux_info.max_num_variables, "product has a multiplicand with wrong number of variables {} vs {}", mle.num_vars(), - self.aux_info.num_variables + self.aux_info.max_num_variables ); let mle_ptr = Arc::as_ptr(&mle) as *const () as usize; @@ -200,17 +205,17 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { let start = start_timer!(|| "evaluation"); assert_eq!( - self.aux_info.num_variables, + self.aux_info.max_num_variables, point.len(), "wrong number of variables {} vs {}", - self.aux_info.num_variables, + self.aux_info.max_num_variables, point.len() ); let evals: Vec = self .flattened_ml_extensions .iter() - .map(|x| x.evaluate(point)) + .map(|x| x.evaluate(&point[0..x.num_vars()])) .collect(); let res = self @@ -225,11 +230,11 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { /// Print out the evaluation map for testing. Panic if the num_vars() > 5. pub fn print_evals(&self) { - if self.aux_info.num_variables > 5 { + if self.aux_info.max_num_variables > 5 { panic!("this function is used for testing only. cannot print more than 5 num_vars()") } - for i in 0..1 << self.aux_info.num_variables { - let point = bit_decompose(i, self.aux_info.num_variables); + for i in 0..1 << self.aux_info.max_num_variables { + let point = bit_decompose(i, self.aux_info.max_num_variables); let point_fr: Vec = point.iter().map(|&x| E::from(x as u64)).collect(); println!("{} {:?}", i, self.evaluate(point_fr.as_ref())) } diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index ce8f0418b..b5900ca69 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -8,6 +8,7 @@ use multilinear_extensions::{ commutative_op_mle_pair, mle::{DenseMultilinearExtension, MultilinearExtension}, op_mle, op_mle_3, + util::largest_even_below, virtual_poly_v2::VirtualPolynomialV2, }; use rayon::{ @@ -47,11 +48,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { assert!( polys .iter() - .map(|poly| (poly.aux_info.num_variables, poly.aux_info.max_degree)) + .map(|poly| (poly.aux_info.max_num_variables, poly.aux_info.max_degree)) .all_equal() ); let (num_variables, max_degree) = ( - polys[0].aux_info.num_variables, + polys[0].aux_info.max_num_variables, polys[0].aux_info.max_degree, ); @@ -81,8 +82,6 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { }) .collect::>(); - // rayon::in_place_scope( - // let (mut prover_states, mut prover_msgs) = rayon::in_place_scope( let scoped_fn = |s: &Scope<'a>| { // spawn extra #(max_thread_id - 1) work threads, whereas the main-thread be the last // work thread @@ -120,7 +119,9 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { .iter_mut() .for_each(|mle| { let mle = Arc::get_mut(mle).unwrap(); - mle.fix_variables_in_place(&[p.elements]); + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]); + } }); tx_prover_state .send(Some((thread_id, prover_state))) @@ -189,10 +190,20 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { .for_each(|mle| { if num_variables == 1 { // first time fix variable should be create new instance - *mle = mle.fix_variables(&[p.elements]).into(); + if mle.num_vars() > 0 { + *mle = mle.fix_variables(&[p.elements]).into(); + } else { + *mle = + Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + mle.get_base_field_vec().to_vec(), + )) + } } else { let mle = Arc::get_mut(mle).unwrap(); - mle.fix_variables_in_place(&[p.elements]); + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]); + } } }); tx_prover_state @@ -292,9 +303,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { |mle: &mut Arc< dyn MultilinearExtension>, >| { - Arc::get_mut(mle) - .unwrap() - .fix_variables_in_place(&[p.elements]); + if mle.num_vars() > 0 { + Arc::get_mut(mle) + .unwrap() + .fix_variables_in_place(&[p.elements]); + } }, ); }; @@ -325,7 +338,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ) -> Self { let start = start_timer!(|| "sum check prover init"); assert_ne!( - polynomial.aux_info.num_variables, 0, + polynomial.aux_info.max_num_variables, 0, "Attempt to prove a constant." ); end_timer!(start); @@ -333,7 +346,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let max_degree = polynomial.aux_info.max_degree; assert!(extrapolation_aux.len() == max_degree - 1); Self { - challenges: Vec::with_capacity(polynomial.aux_info.num_variables), + challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, extrapolation_aux, @@ -353,7 +366,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); assert!( - self.round < self.poly.aux_info.num_variables, + self.round < self.poly.aux_info.max_num_variables, "Prover is not active" ); @@ -385,7 +398,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { if self.challenges.len() == 1 { self.poly.flattened_ml_extensions.iter_mut().for_each(|f| { - *f = Arc::new(f.fix_variables(&[r.elements])); + if f.num_vars() > 0 { + *f = Arc::new(f.fix_variables(&[r.elements])); + } else { + panic!("calling sumcheck on constant") + } }); } else { self.poly @@ -396,7 +413,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { // reason .map(Arc::get_mut) .for_each(|f| { - f.unwrap().fix_variables_in_place(&[r.elements]); + if let Some(f) = f { + if f.num_vars() > 0 { + f.fix_variables_in_place(&[r.elements]); + } + } }); } } @@ -407,6 +428,16 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) + // + // To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars, + // for it evaluation value we need to times 2^(max_num_vars - num_vars) + // E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n + // For i round univariate poly, f^i(x) + // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds + // = \sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n - n' + // = 2^(|b| - |b1|) * \sum_b1 f_1(r, 0, b1) + \sum_b f_2(r, 0, b) + // same applied on f^i[1] + // It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value let span = entered_span!("products_sum"); let AdditiveVec(products_sum) = self.poly.products.iter().fold( AdditiveVec::new(self.poly.aux_info.max_degree + 1), @@ -418,13 +449,24 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let f = &self.poly.flattened_ml_extensions[products[0]]; op_mle! { |f| { - (0..f.len()) + let res = (0..largest_even_below(f.len())) .step_by(2) - .fold(AdditiveArray::(array::from_fn(|_| 0.into())), |mut acc, b| { + .fold(AdditiveArray::<_, 2>(array::from_fn(|_| 0.into())), |mut acc, b| { acc.0[0] += f[b]; acc.0[1] += f[b+1]; acc - }) + }); + let res = if f.len() == 1 { + AdditiveArray::<_, 2>([f[0]; 2]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } }, |sum| AdditiveArray(sum.0.map(E::from)) } @@ -436,16 +478,28 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { &self.poly.flattened_ml_extensions[products[1]], ); commutative_op_mle_pair!( - |f, g| (0..f.len()).step_by(2).fold( - AdditiveArray::(array::from_fn(|_| 0.into())), - |mut acc, b| { - acc.0[0] += f[b] * g[b]; - acc.0[1] += f[b + 1] * g[b + 1]; - acc.0[2] += - (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); - acc + |f, g| { + let res = (0..largest_even_below(f.len())).step_by(2).fold( + AdditiveArray::<_, 3>(array::from_fn(|_| 0.into())), + |mut acc, b| { + acc.0[0] += f[b] * g[b]; + acc.0[1] += f[b + 1] * g[b + 1]; + acc.0[2] += + (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); + acc + }); + let res = if f.len() == 1 { + AdditiveArray::<_, 3>([f[0] * g[0]; 3]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res } - ), + }, |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() @@ -457,23 +511,36 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { &self.poly.flattened_ml_extensions[products[2]], ); op_mle_3!( - |f1, f2, f3| (0..f1.len()) - .step_by(2) - .map(|b| { - // f = c x + d - let c1 = f1[b + 1] - f1[b]; - let c2 = f2[b + 1] - f2[b]; - let c3 = f3[b + 1] - f3[b]; - AdditiveArray([ - f1[b] * (f2[b] * f3[b]), - f1[b + 1] * (f2[b + 1] * f3[b + 1]), - (c1 + f1[b + 1]) - * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), - (c1 + c1 + f1[b + 1]) - * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), - ]) - }) - .sum::>(), + |f1, f2, f3| { + let res = (0..largest_even_below(f1.len())) + .step_by(2) + .map(|b| { + // f = c x + d + let c1 = f1[b + 1] - f1[b]; + let c2 = f2[b + 1] - f2[b]; + let c3 = f3[b + 1] - f3[b]; + AdditiveArray([ + f1[b] * (f2[b] * f3[b]), + f1[b + 1] * (f2[b + 1] * f3[b + 1]), + (c1 + f1[b + 1]) + * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), + (c1 + c1 + f1[b + 1]) + * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), + ]) + }) + .sum::>(); + let res = if f1.len() == 1 { + AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() @@ -538,7 +605,8 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { poly: VirtualPolynomialV2<'a, E>, transcript: &mut Transcript, ) -> (IOPProof, IOPProverStateV2<'a, E>) { - let (num_variables, max_degree) = (poly.aux_info.num_variables, poly.aux_info.max_degree); + let (num_variables, max_degree) = + (poly.aux_info.max_num_variables, poly.aux_info.max_degree); // return empty proof when target polymonial is constant if num_variables == 0 { @@ -588,9 +656,14 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { .par_iter_mut() .for_each(|mle| { if let Some(mle) = Arc::get_mut(mle) { - mle.fix_variables_in_place_parallel(&[p.elements]) + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]) + } } else { - *mle = mle.fix_variables(&[p.elements]).into() + *mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + mle.get_base_field_vec().to_vec(), + )) } }); }; @@ -616,13 +689,13 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { pub(crate) fn prover_init_parallel(polynomial: VirtualPolynomialV2<'a, E>) -> Self { let start = start_timer!(|| "sum check prover init"); assert_ne!( - polynomial.aux_info.num_variables, 0, + polynomial.aux_info.max_num_variables, 0, "Attempt to prove a constant." ); let max_degree = polynomial.aux_info.max_degree; let prover_state = Self { - challenges: Vec::with_capacity(polynomial.aux_info.num_variables), + challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, extrapolation_aux: (1..max_degree) @@ -651,7 +724,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); assert!( - self.round < self.poly.aux_info.num_variables, + self.round < self.poly.aux_info.max_num_variables, "Prover is not active" ); @@ -682,7 +755,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { .flattened_ml_extensions .par_iter_mut() .for_each(|f| { - *f = Arc::new(f.fix_variables_parallel(&[r.elements])); + if f.num_vars() > 0 { + *f = Arc::new(f.fix_variables_parallel(&[r.elements])); + } else { + panic!("calling sumcheck on constant") + } }); } else { self.poly @@ -693,7 +770,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { // reason .map(Arc::get_mut) .for_each(|f| { - f.unwrap().fix_variables_in_place_parallel(&[r.elements]); + if let Some(f) = f { + if f.num_vars() > 0 { + f.fix_variables_in_place_parallel(&[r.elements]) + } + } }); } } @@ -718,17 +799,30 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { 1 => { let f = &self.poly.flattened_ml_extensions[products[0]]; op_mle! { - |f| (0..f.len()) - .into_par_iter() - .step_by(2) - .with_min_len(64) - .map(|b| { - AdditiveArray([ - f[b], - f[b + 1] - ]) - }) - .sum::>(), + |f| { + let res = (0..largest_even_below(f.len())) + .into_par_iter() + .step_by(2) + .with_min_len(64) + .map(|b| { + AdditiveArray([ + f[b], + f[b + 1] + ]) + }) + .sum::>(); + let res = if f.len() == 1 { + AdditiveArray::<_, 2>([f[0]; 2]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, |sum| AdditiveArray(sum.0.map(E::from)) } .to_vec() @@ -739,7 +833,8 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { &self.poly.flattened_ml_extensions[products[1]], ); commutative_op_mle_pair!( - |f, g| (0..f.len()) + |f, g| { + let res = (0..largest_even_below(f.len())) .into_par_iter() .step_by(2) .with_min_len(64) @@ -751,7 +846,19 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { * (g[b + 1] + g[b + 1] - g[b]), ]) }) - .sum::>(), + .sum::>(); + let res = if f.len() == 1 { + AdditiveArray::<_, 3>([f[0] * g[0]; 3]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() @@ -763,7 +870,8 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { &self.poly.flattened_ml_extensions[products[2]], ); op_mle_3!( - |f1, f2, f3| (0..f1.len()) + |f1, f2, f3| { + let res = (0..largest_even_below(f1.len())) .step_by(2) .map(|b| { // f = c x + d @@ -779,7 +887,19 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), ]) }) - .sum::>(), + .sum::>(); + let res = if f1.len() == 1 { + AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index 0188746c7..94e367be8 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -13,6 +13,8 @@ use crate::{ util::interpolate_uni_poly, }; +// TODO add more tests related to various num_vars combination after PR #162 + fn test_sumcheck( nv: usize, num_multiplicands_range: (usize, usize), diff --git a/sumcheck/src/util.rs b/sumcheck/src/util.rs index eb89a5655..08b87487d 100644 --- a/sumcheck/src/util.rs +++ b/sumcheck/src/util.rs @@ -235,7 +235,7 @@ pub(crate) fn merge_sumcheck_polys_v2<'a, E: ExtensionField>( ) -> VirtualPolynomialV2<'a, E> { let log2_max_thread_id = ceil_log2(max_thread_id); let mut poly = prover_states[0].poly.clone(); // giving only one evaluation left, this clone is low cost. - poly.aux_info.num_variables = log2_max_thread_id; // size_log2 variates sumcheck + poly.aux_info.max_num_variables = log2_max_thread_id; // size_log2 variates sumcheck for i in 0..poly.flattened_ml_extensions.len() { let ml_ext = DenseMultilinearExtension::from_evaluations_ext_vec( log2_max_thread_id, diff --git a/sumcheck/src/verifier.rs b/sumcheck/src/verifier.rs index c402424bd..f67c09103 100644 --- a/sumcheck/src/verifier.rs +++ b/sumcheck/src/verifier.rs @@ -1,7 +1,6 @@ use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; use multilinear_extensions::virtual_poly::VPAuxInfo; -use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; use transcript::{Challenge, Transcript}; use crate::{ @@ -122,12 +121,10 @@ impl IOPVerifierState { // the deferred check during the interactive phase: // 2. set `expected` to P(r)` - let mut expected_vec = self .polynomials_received - .clone() - .into_par_iter() - .zip(self.challenges.clone().into_par_iter()) + .iter() + .zip(self.challenges.iter()) .map(|(evaluations, challenge)| { if evaluations.len() != self.max_degree + 1 { panic!( @@ -136,11 +133,11 @@ impl IOPVerifierState { self.max_degree + 1 ); } - interpolate_uni_poly::(&evaluations, challenge.elements) + interpolate_uni_poly::(evaluations, challenge.elements) }) .collect::>(); - // insert the asserted_sum to the first position of the expected vector + // l-append asserted_sum to the first position of the expected vector expected_vec.insert(0, *asserted_sum); for (i, (evaluations, &expected)) in self