diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index c4dcfeb3e..f3117a706 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -18,7 +18,12 @@ use multilinear_extensions::{ }; use transcript::Transcript; -criterion_group!(benches, sumcheck_fn, devirgo_sumcheck_fn,); +criterion_group!( + benches, + sumcheck_serial_fn, + sumcheck_fn, + devirgo_sumcheck_fn, +); criterion_main!(benches); const NUM_SAMPLES: usize = 10; @@ -84,7 +89,7 @@ const RAYON_NUM_THREADS: usize = 8; fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; - for nv in 24..25 { + for nv in 20..25 { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("sumcheck_nv_{}", nv)); group.sample_size(NUM_SAMPLES); @@ -119,10 +124,48 @@ fn sumcheck_fn(c: &mut Criterion) { } } +fn sumcheck_serial_fn(c: &mut Criterion) { + type E = GoldilocksExt2; + + for nv in 20..25 { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("sumcheck_serial_nv_{}", nv)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new("prove_sumcheck", format!("sumcheck_serial_nv_{}", nv)), + |b| { + b.iter_with_setup( + || { + let prover_transcript = Transcript::::new(b"test"); + let (asserted_sum, virtual_poly, virtual_poly_splitted) = + { prepare_input(RAYON_NUM_THREADS, nv) }; + ( + prover_transcript, + asserted_sum, + virtual_poly, + virtual_poly_splitted, + ) + }, + |(mut prover_transcript, asserted_sum, virtual_poly, virtual_poly_splitted)| { + let (sumcheck_proof_v1_serial, _) = IOPProverState::::prove_serial( + virtual_poly.clone(), + &mut prover_transcript, + ); + }, + ); + }, + ); + + group.finish(); + } +} + fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; - for nv in 24..25 { + for nv in 20..25 { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv)); group.sample_size(NUM_SAMPLES); diff --git a/sumcheck/examples/devirgo_sumcheck.rs b/sumcheck/examples/devirgo_sumcheck.rs index 3cc7be741..23d8493a0 100644 --- a/sumcheck/examples/devirgo_sumcheck.rs +++ b/sumcheck/examples/devirgo_sumcheck.rs @@ -78,6 +78,7 @@ const RAYON_NUM_THREADS: usize = 8; fn main() { let mut prover_transcript_v1 = Transcript::::new(b"test"); + let mut prover_transcript_v1_serial = Transcript::::new(b"test"); let mut prover_transcript_v2 = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = prepare_input(RAYON_NUM_THREADS); @@ -113,4 +114,10 @@ fn main() { println!("v1 finish"); assert!(sumcheck_proof_v2 == sumcheck_proof_v1); + + let (sumcheck_proof_v1_serial, _) = + IOPProverState::::prove_serial(virtual_poly.clone(), &mut prover_transcript_v1_serial); + + println!("v1 serial finish"); + assert!(sumcheck_proof_v2 == sumcheck_proof_v1_serial); } diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 4d83e92a2..7d07d1595 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -424,7 +424,7 @@ impl IOPProverState { let span = entered_span!("extrapolation"); let extrapolation = (0..self.poly.aux_info.max_degree - products.len()) - .into_par_iter() + .into_iter() .map(|i| { let (points, weights) = &self.extrapolation_aux[products.len() - 1]; let at = E::from((products.len() + 1 + i) as u64); @@ -449,6 +449,80 @@ impl IOPProverState { } } + // The prover runs in sequential mode in the sense that we use the serial version of + // fix_variables and eval_partial_poly in each round of sumcheck. + pub fn prove_serial( + poly: VirtualPolynomial, + transcript: &mut Transcript, + ) -> (IOPProof, IOPProverState) { + let (num_variables, max_degree) = (poly.aux_info.num_variables, poly.aux_info.max_degree); + + // return empty proof when target polymonial is constant + if num_variables == 0 { + return ( + IOPProof::default(), + IOPProverState { + poly: poly, + ..Default::default() + }, + ); + } + let start = start_timer!(|| "sum check prove"); + + transcript.append_message(&num_variables.to_le_bytes()); + transcript.append_message(&max_degree.to_le_bytes()); + + let mut prover_state = Self::prover_init_parallel(poly); + let mut challenge = None; + let mut prover_msgs = Vec::with_capacity(num_variables); + let span = entered_span!("prove_rounds"); + for _ in 0..num_variables { + let prover_msg = + IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); + + prover_msg + .evaluations + .iter() + .for_each(|e| transcript.append_field_element_ext(e)); + + prover_msgs.push(prover_msg); + let span = entered_span!("get_challenge"); + challenge = Some(transcript.get_and_append_challenge(b"Internal round")); + exit_span!(span); + } + exit_span!(span); + + let span = entered_span!("after_rounds_prover_state"); + // pushing the last challenge point to the state + if let Some(p) = challenge { + prover_state.challenges.push(p); + // fix last challenge to collect final evaluation + prover_state + .poly + .flattened_ml_extensions + .iter_mut() + .for_each(|mle| { + Arc::make_mut(mle).fix_variables_in_place(&[p.elements]); + }); + }; + exit_span!(span); + + end_timer!(start); + ( + IOPProof { + // the point consists of the first elements in the challenge + point: prover_state + .challenges + .iter() + .map(|challenge| challenge.elements) + .collect(), + proofs: prover_msgs, + ..Default::default() + }, + prover_state.into(), + ) + } + /// collect all mle evaluation (claim) after sumcheck pub fn get_mle_final_evaluations(&self) -> Vec { self.poly