From 86573f3b2f1f34bdd530fbe5c4ec70816c80fd45 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 2 Jul 2024 20:19:40 +0800 Subject: [PATCH 1/7] add fix_var benches --- Cargo.lock | 1 + multilinear_extensions/Cargo.toml | 7 +++ multilinear_extensions/benches/mle.rs | 70 +++++++++++++++++++++++++++ 3 files changed, 78 insertions(+) create mode 100644 multilinear_extensions/benches/mle.rs diff --git a/Cargo.lock b/Cargo.lock index 1aaab2229..cfd03e3f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -782,6 +782,7 @@ name = "multilinear_extensions" version = "0.1.0" dependencies = [ "ark-std", + "criterion", "ff", "ff_ext", "goldilocks", diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index 4c4b4dcac..f07cee0d7 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -15,5 +15,12 @@ goldilocks.workspace = true rayon.workspace = true serde.workspace = true +[dev-dependencies] +criterion = { version = "0.5", features = ["html_reports"] } + +[[bench]] +name = "mle" +harness = false + [features] parallel = [ ] diff --git a/multilinear_extensions/benches/mle.rs b/multilinear_extensions/benches/mle.rs new file mode 100644 index 000000000..54c601a8b --- /dev/null +++ b/multilinear_extensions/benches/mle.rs @@ -0,0 +1,70 @@ +use ark_std::rand::thread_rng; +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; +use multilinear_extensions::mle::DenseMultilinearExtension; + +fn fix_var(c: &mut Criterion) { + let mut rng = thread_rng(); + + const NUM_SAMPLES: usize = 10; + for nv in 12..20 { + let mut group = c.benchmark_group("mle"); + group.sample_size(NUM_SAMPLES); + + for i in 0..nv { + group.bench_function( + BenchmarkId::new("fix_var", format!("({},{})", nv, nv - i)), + |b| { + b.iter_with_setup( + || { + let mut v = + DenseMultilinearExtension::::random(nv, &mut rng); + let r = GoldilocksExt2::random(&mut rng); + for _ in 0..i { + v.fix_variables_in_place(&[r]); + } + (v, r) + }, + |(mut v, r)| v.fix_variables_in_place(&[r]), + ); + }, + ); + } + group.finish(); + } +} + +fn fix_var_par(c: &mut Criterion) { + let mut rng = thread_rng(); + + const NUM_SAMPLES: usize = 10; + for nv in 12..20 { + let mut group = c.benchmark_group("mle"); + group.sample_size(NUM_SAMPLES); + + for i in 0..nv { + group.bench_function( + BenchmarkId::new("fix_var_par", format!("({},{})", nv, nv - i)), + |b| { + b.iter_with_setup( + || { + let mut v = + DenseMultilinearExtension::::random(nv, &mut rng); + let r = GoldilocksExt2::random(&mut rng); + for _ in 0..i { + v.fix_variables_in_place(&[r]); + } + (v, r) + }, + |(mut v, r)| v.fix_variables_in_place_parallel(&[r]), + ); + }, + ); + } + group.finish(); + } +} + +criterion_group!(benches, fix_var, fix_var_par,); +criterion_main!(benches); From 8847fa7fbf524d22b504b5742c54d928f823e5c6 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 3 Jul 2024 18:14:11 +0800 Subject: [PATCH 2/7] add docs for extrapolation based on barycentric formula --- sumcheck/src/prover.rs | 11 +++++++++-- sumcheck/src/util.rs | 10 ++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 5af55e749..4d83e92a2 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -648,6 +648,7 @@ impl IOPProverState { |mut products_sum, (coefficient, products)| { let span = entered_span!("sum"); + // p^m(0), p^m(1), ... let mut sum = match products.len() { 1 => { let f = &self.poly.flattened_ml_extensions[products[0]]; @@ -678,8 +679,12 @@ impl IOPProverState { .step_by(2) .with_min_len(64) .map(|b| { + // at round m, + // f^m(0,b) = f[b], g^m(0,b) = g[b] + // f^m(1,b) = f[b+1], g^m(1,b) = g[b+1] + // f^m(2,b) = 2*f[b+1]-f[b], g^m(2,b) = 2*g[b+1]-g[b] AdditiveArray([ - f[b] * g[b], + f[b] * g[b], // f^m(0)*g^m(0) = \sum_b f^m(0,b)*g^m(0,b) f[b + 1] * g[b + 1], (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]), @@ -696,8 +701,10 @@ impl IOPProverState { sum.iter_mut().for_each(|sum| *sum *= coefficient); let span = entered_span!("extrapolation"); + // from { p^m(0), p^m(1), ..., p^m(products.len()) } + // extrapolate { p^m(products.len()+1), ..., p^m(max_degree) } let extrapolation = (0..self.poly.aux_info.max_degree - products.len()) - .into_par_iter() + .into_iter() .map(|i| { let (points, weights) = &self.extrapolation_aux[products.len() - 1]; let at = E::from((products.len() + 1 + i) as u64); diff --git a/sumcheck/src/util.rs b/sumcheck/src/util.rs index 1098fd240..0eff13504 100644 --- a/sumcheck/src/util.rs +++ b/sumcheck/src/util.rs @@ -15,6 +15,7 @@ use rayon::{prelude::ParallelIterator, slice::ParallelSliceMut}; use crate::structs::IOPProverState; +// for each j, its barycentric weight = \frac{1}{ \prod_{i != j} point_j - point_i } pub fn barycentric_weights(points: &[F]) -> Vec { let mut weights = points .iter() @@ -91,13 +92,19 @@ fn serial_batch_inversion_and_mul(v: &mut [F], coeff: &F) { } } +// refer to https://people.maths.ox.ac.uk/trefethen/barycentric.pdf for barycentric formula +// p(x) = \sum f_j * (w_j / (x - x_j)) / \sum w_j / (x - x_j) pub(crate) fn extrapolate(points: &[F], weights: &[F], evals: &[F], at: &F) -> F { let (coeffs, sum_inv) = { + // x - x_j let mut coeffs = points.iter().map(|point| *at - point).collect::>(); + // 1 / (x - x_j) batch_inversion(&mut coeffs); let mut sum = F::ZERO; coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| { + // w_j / (x - x_j) *coeff *= weight; + // \sum w_j / (x - x_j) sum += *coeff }); let sum_inv = sum.invert().unwrap_or(F::ZERO); @@ -106,8 +113,11 @@ pub(crate) fn extrapolate(points: &[F], weights: &[F], evals: &[F coeffs .iter() .zip(evals) + // w_j * y_j / (x - x_j) .map(|(coeff, eval)| *coeff * eval) + // \sum_j w_j * y_j / (x - x_j) .sum::() + // p(x) * sum_inv } From b3e8a385d02b6a9a71647c3920c44a78c6023ff6 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 3 Jul 2024 19:36:02 +0800 Subject: [PATCH 3/7] add eval_partial_poly benches --- sumcheck/Cargo.toml | 4 + sumcheck/benches/eval_partial_poly.rs | 127 ++++++++++++++++++++++++++ sumcheck/src/util.rs | 2 +- 3 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 sumcheck/benches/eval_partial_poly.rs diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 5a749e41b..bd244f1c7 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -30,5 +30,9 @@ criterion = { version = "0.5", features = ["html_reports"] } name = "devirgo_sumcheck" harness = false +[[bench]] +name = "eval_partial_poly" +harness = false + [features] non_pow2_rayon_thread = [ ] \ No newline at end of file diff --git a/sumcheck/benches/eval_partial_poly.rs b/sumcheck/benches/eval_partial_poly.rs new file mode 100644 index 000000000..a9fcb6ee4 --- /dev/null +++ b/sumcheck/benches/eval_partial_poly.rs @@ -0,0 +1,127 @@ +use std::sync::Arc; + +use ark_std::rand::thread_rng; +use criterion::*; +use ff::Field; +use goldilocks::{Goldilocks, GoldilocksExt2}; +use multilinear_extensions::{ + commutative_op_mle_pair, mle::DenseMultilinearExtension, op_mle, + virtual_poly::VirtualPolynomial, +}; +use sumcheck::util::{barycentric_weights, extrapolate, AdditiveArray, AdditiveVec}; + +fn eval_partial_poly(c: &mut Criterion) { + type E = GoldilocksExt2; + type F = Goldilocks; + let mut rng = thread_rng(); + + const NUM_SAMPLES: usize = 10; + for nv in 12..20 { + let mut group = c.benchmark_group("mle"); + group.sample_size(NUM_SAMPLES); + + let mut setup = |nv, i| { + let mut f = Arc::new(DenseMultilinearExtension::::random(nv, &mut rng)); + let mut g = Arc::new(DenseMultilinearExtension::::random(nv, &mut rng)); + let mut h = Arc::new(DenseMultilinearExtension::::random(nv, &mut rng)); + let r = E::random(&mut rng); + for _ in 0..i { + Arc::get_mut(&mut f).unwrap().fix_variables_in_place(&[r]); + Arc::get_mut(&mut g).unwrap().fix_variables_in_place(&[r]); + Arc::get_mut(&mut h).unwrap().fix_variables_in_place(&[r]); + } + let mut p = VirtualPolynomial::new_from_mle(f, F::from(2_u64)); + p.mul_by_mle(g.clone(), F::from(3u64)); + + let mut q = VirtualPolynomial::new_from_mle(g, F::from(5u64)); + q.mul_by_mle(h, F::from(7u64)); + + p.merge(&q); + + assert_eq!(p.products.len(), 2); + assert_eq!(p.flattened_ml_extensions.len(), 3); + + p + }; + + let routine = |poly: VirtualPolynomial| { + let AdditiveVec(products_sum) = poly.products.iter().fold( + AdditiveVec::new(poly.aux_info.max_degree + 1), + |mut products_sum, (coefficient, products)| { + assert_eq!(products.len(), 2); + let mut sum = match products.len() { + 1 => { + let f = &poly.flattened_ml_extensions[products[0]]; + op_mle! { + |f| (0..f.len()) + .into_iter() + .step_by(2) + .map(|b| { + AdditiveArray([ + f[b], + f[b + 1] + ]) + }) + .sum::>(), + |sum| AdditiveArray(sum.0.map(E::from)) + } + .to_vec() + } + 2 => { + let (f, g) = ( + &poly.flattened_ml_extensions[products[0]], + &poly.flattened_ml_extensions[products[1]], + ); + commutative_op_mle_pair!( + |f, g| (0..f.len()) + .into_iter() + .step_by(2) + .map(|b| { + AdditiveArray([ + f[b] * g[b], + f[b + 1] * g[b + 1], + (f[b + 1] + f[b + 1] - f[b]) + * (g[b + 1] + g[b + 1] - g[b]), + ]) + }) + .sum::>(), + |sum| AdditiveArray(sum.0.map(E::from)) + ) + .to_vec() + } + _ => unimplemented!("do not support degree > 2"), + }; + sum.iter_mut().for_each(|sum| *sum *= coefficient); + + let extrapolation = (0..poly.aux_info.max_degree - products.len()) + .into_iter() + .map(|i| { + let points = (0..(1 + products.len()) as u64) + .map(E::from) + .collect::>(); + let weights = barycentric_weights(&points); + let at = E::from((products.len() + 1 + i) as u64); + extrapolate(&points, &weights, &sum, &at) + }) + .collect::>(); + sum.extend(extrapolation); + + products_sum += AdditiveVec(sum); + products_sum + }, + ); + }; + for i in 0..nv { + group.bench_function( + BenchmarkId::new("eval_partial_poly", format!("({},{})", nv, nv - i)), + |b| { + b.iter_with_setup(|| setup(nv, i), routine); + }, + ); + } + group.finish(); + } +} + +criterion_group!(benches, eval_partial_poly); +criterion_main!(benches); diff --git a/sumcheck/src/util.rs b/sumcheck/src/util.rs index 0eff13504..0b5804ada 100644 --- a/sumcheck/src/util.rs +++ b/sumcheck/src/util.rs @@ -94,7 +94,7 @@ fn serial_batch_inversion_and_mul(v: &mut [F], coeff: &F) { // refer to https://people.maths.ox.ac.uk/trefethen/barycentric.pdf for barycentric formula // p(x) = \sum f_j * (w_j / (x - x_j)) / \sum w_j / (x - x_j) -pub(crate) fn extrapolate(points: &[F], weights: &[F], evals: &[F], at: &F) -> F { +pub fn extrapolate(points: &[F], weights: &[F], evals: &[F], at: &F) -> F { let (coeffs, sum_inv) = { // x - x_j let mut coeffs = points.iter().map(|point| *at - point).collect::>(); From a4660d1a5a251808bd5e364bbd81f5ad61b35d0e Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 11 Jul 2024 09:51:48 +0800 Subject: [PATCH 4/7] add serial version of sumcheck prover --- sumcheck/benches/devirgo_sumcheck.rs | 49 +++++++++++++++-- sumcheck/examples/devirgo_sumcheck.rs | 7 +++ sumcheck/src/prover.rs | 76 ++++++++++++++++++++++++++- 3 files changed, 128 insertions(+), 4 deletions(-) diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index c4dcfeb3e..f3117a706 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -18,7 +18,12 @@ use multilinear_extensions::{ }; use transcript::Transcript; -criterion_group!(benches, sumcheck_fn, devirgo_sumcheck_fn,); +criterion_group!( + benches, + sumcheck_serial_fn, + sumcheck_fn, + devirgo_sumcheck_fn, +); criterion_main!(benches); const NUM_SAMPLES: usize = 10; @@ -84,7 +89,7 @@ const RAYON_NUM_THREADS: usize = 8; fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; - for nv in 24..25 { + for nv in 20..25 { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("sumcheck_nv_{}", nv)); group.sample_size(NUM_SAMPLES); @@ -119,10 +124,48 @@ fn sumcheck_fn(c: &mut Criterion) { } } +fn sumcheck_serial_fn(c: &mut Criterion) { + type E = GoldilocksExt2; + + for nv in 20..25 { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("sumcheck_serial_nv_{}", nv)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new("prove_sumcheck", format!("sumcheck_serial_nv_{}", nv)), + |b| { + b.iter_with_setup( + || { + let prover_transcript = Transcript::::new(b"test"); + let (asserted_sum, virtual_poly, virtual_poly_splitted) = + { prepare_input(RAYON_NUM_THREADS, nv) }; + ( + prover_transcript, + asserted_sum, + virtual_poly, + virtual_poly_splitted, + ) + }, + |(mut prover_transcript, asserted_sum, virtual_poly, virtual_poly_splitted)| { + let (sumcheck_proof_v1_serial, _) = IOPProverState::::prove_serial( + virtual_poly.clone(), + &mut prover_transcript, + ); + }, + ); + }, + ); + + group.finish(); + } +} + fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; - for nv in 24..25 { + for nv in 20..25 { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv)); group.sample_size(NUM_SAMPLES); diff --git a/sumcheck/examples/devirgo_sumcheck.rs b/sumcheck/examples/devirgo_sumcheck.rs index 3cc7be741..23d8493a0 100644 --- a/sumcheck/examples/devirgo_sumcheck.rs +++ b/sumcheck/examples/devirgo_sumcheck.rs @@ -78,6 +78,7 @@ const RAYON_NUM_THREADS: usize = 8; fn main() { let mut prover_transcript_v1 = Transcript::::new(b"test"); + let mut prover_transcript_v1_serial = Transcript::::new(b"test"); let mut prover_transcript_v2 = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = prepare_input(RAYON_NUM_THREADS); @@ -113,4 +114,10 @@ fn main() { println!("v1 finish"); assert!(sumcheck_proof_v2 == sumcheck_proof_v1); + + let (sumcheck_proof_v1_serial, _) = + IOPProverState::::prove_serial(virtual_poly.clone(), &mut prover_transcript_v1_serial); + + println!("v1 serial finish"); + assert!(sumcheck_proof_v2 == sumcheck_proof_v1_serial); } diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 4d83e92a2..7d07d1595 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -424,7 +424,7 @@ impl IOPProverState { let span = entered_span!("extrapolation"); let extrapolation = (0..self.poly.aux_info.max_degree - products.len()) - .into_par_iter() + .into_iter() .map(|i| { let (points, weights) = &self.extrapolation_aux[products.len() - 1]; let at = E::from((products.len() + 1 + i) as u64); @@ -449,6 +449,80 @@ impl IOPProverState { } } + // The prover runs in sequential mode in the sense that we use the serial version of + // fix_variables and eval_partial_poly in each round of sumcheck. + pub fn prove_serial( + poly: VirtualPolynomial, + transcript: &mut Transcript, + ) -> (IOPProof, IOPProverState) { + let (num_variables, max_degree) = (poly.aux_info.num_variables, poly.aux_info.max_degree); + + // return empty proof when target polymonial is constant + if num_variables == 0 { + return ( + IOPProof::default(), + IOPProverState { + poly: poly, + ..Default::default() + }, + ); + } + let start = start_timer!(|| "sum check prove"); + + transcript.append_message(&num_variables.to_le_bytes()); + transcript.append_message(&max_degree.to_le_bytes()); + + let mut prover_state = Self::prover_init_parallel(poly); + let mut challenge = None; + let mut prover_msgs = Vec::with_capacity(num_variables); + let span = entered_span!("prove_rounds"); + for _ in 0..num_variables { + let prover_msg = + IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); + + prover_msg + .evaluations + .iter() + .for_each(|e| transcript.append_field_element_ext(e)); + + prover_msgs.push(prover_msg); + let span = entered_span!("get_challenge"); + challenge = Some(transcript.get_and_append_challenge(b"Internal round")); + exit_span!(span); + } + exit_span!(span); + + let span = entered_span!("after_rounds_prover_state"); + // pushing the last challenge point to the state + if let Some(p) = challenge { + prover_state.challenges.push(p); + // fix last challenge to collect final evaluation + prover_state + .poly + .flattened_ml_extensions + .iter_mut() + .for_each(|mle| { + Arc::make_mut(mle).fix_variables_in_place(&[p.elements]); + }); + }; + exit_span!(span); + + end_timer!(start); + ( + IOPProof { + // the point consists of the first elements in the challenge + point: prover_state + .challenges + .iter() + .map(|challenge| challenge.elements) + .collect(), + proofs: prover_msgs, + ..Default::default() + }, + prover_state.into(), + ) + } + /// collect all mle evaluation (claim) after sumcheck pub fn get_mle_final_evaluations(&self) -> Vec { self.poly From 097cf9034ce41e69435bb82b1ddb8e8d6bbf601e Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 25 Jul 2024 20:11:41 +0800 Subject: [PATCH 5/7] add logs for counting GKR proving time distribution --- gkr/examples/keccak256.rs | 6 ++-- gkr/src/prover.rs | 42 ++++++++++++++++++++++----- multilinear_extensions/benches/mle.rs | 4 +-- sumcheck/benches/eval_partial_poly.rs | 6 ++-- 4 files changed, 42 insertions(+), 16 deletions(-) diff --git a/gkr/examples/keccak256.rs b/gkr/examples/keccak256.rs index d8a1433ac..938153f70 100644 --- a/gkr/examples/keccak256.rs +++ b/gkr/examples/keccak256.rs @@ -24,7 +24,7 @@ fn main() { #[allow(unused_mut)] let mut max_thread_id: usize = env::var("RAYON_NUM_THREADS") .map(|v| str::parse::(&v).unwrap_or(1)) - .unwrap(); + .unwrap_or(1); if !is_power_of_2(max_thread_id) { #[cfg(not(feature = "non_pow2_rayon_thread"))] @@ -85,7 +85,7 @@ fn main() { let subscriber = Registry::default() .with( fmt::layer() - .compact() + .compact() .with_thread_ids(false) .with_thread_names(false), ) @@ -93,7 +93,7 @@ fn main() { .with(flame_layer.with_threads_collapsed(true)); tracing::subscriber::set_global_default(subscriber).unwrap(); - for log2_n in 0..12 { + for log2_n in 10..11 { let Some((proof, output_mle)) = prove_keccak256::(log2_n, &circuit, (1 << log2_n).min(max_thread_id)) else { diff --git a/gkr/src/prover.rs b/gkr/src/prover.rs index 26a8c6c87..e060922d3 100644 --- a/gkr/src/prover.rs +++ b/gkr/src/prover.rs @@ -1,4 +1,4 @@ -use std::mem; +use std::{mem, ops::AddAssign, time::{Duration, Instant}}; use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; @@ -59,12 +59,21 @@ impl IOPProverState { ) }); + let sumcheck_prove_dur = Instant::now(); + let mut phase2_build_mle_durs = vec![]; + let mut phase2_sumcheck_prove_durs = vec![]; + let mut phase1_prove_sumcheck_durs = vec![]; + let mut build_eq_durs = vec![]; let sumcheck_proofs = (0..circuit.layers.len() as LayerId) .map(|layer_id| { let timer = start_timer!(|| format!("Prove layer {}", layer_id)); prover_state.layer_id = layer_id; + let mut phase2_build_mle_dur = Duration::ZERO; + let mut phase2_sumcheck_dur = Duration::ZERO; + let mut phase1_prove_sumcheck_dur = Duration::ZERO; + let mut build_eq_dur = Duration::ZERO; let dummy_step = SumcheckStepType::Undefined; let proofs = circuit.layers[layer_id as usize] .sumcheck_steps @@ -85,8 +94,9 @@ impl IOPProverState { transcript, )].to_vec() }, - (SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2, _) => - [ + (SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2, _) => { + let dur = Instant::now(); + let proofs = [ prover_state .prove_and_update_state_phase1_step1( circuit, @@ -99,8 +109,10 @@ impl IOPProverState { circuit_witness, transcript, ), - ].to_vec() - , + ].to_vec(); + phase1_prove_sumcheck_dur.add_assign(dur.elapsed()); + proofs + }, (SumcheckStepType::Phase2Step1, step2, _) => { let span = entered_span!("phase2_gkr"); let max_steps = match step2 { @@ -113,9 +125,12 @@ impl IOPProverState { let mut layer_polys = (0..max_thread_id).map(|_| ArcDenseMultilinearExtension::default()).collect::>>(); let mut res = vec![]; for step in 0..max_steps { + let dur = Instant::now(); let bounded_eval_point = prover_state.to_next_step_point.clone(); eqs.push(build_eq_x_r_vec(&bounded_eval_point)); + build_eq_dur.add_assign(dur.elapsed()); // build step round poly + let build_mle_dur = Instant::now(); let virtual_polys: Vec> = (0..max_thread_id).into_par_iter().zip(layer_polys.par_iter_mut()).map(|(thread_id, layer_poly)| { let span = entered_span!("build_poly"); let (next_layer_poly_step1, virtual_poly) = match step { @@ -154,18 +169,23 @@ impl IOPProverState { }, _ => unimplemented!(), }; + if let Some(next_layer_poly_step1) = next_layer_poly_step1 { let _ = mem::replace(layer_poly, next_layer_poly_step1); } exit_span!(span); virtual_poly }).collect(); + let build_mle_dur = build_mle_dur.elapsed(); + phase2_build_mle_dur.add_assign(build_mle_dur); + let sumcheck_dur = Instant::now(); let (sumcheck_proof, sumcheck_prover_state) = sumcheck::structs::IOPProverState::::prove_batch_polys( max_thread_id, virtual_polys.try_into().unwrap(), transcript, ); + phase2_sumcheck_dur.add_assign(sumcheck_dur.elapsed()); let iop_prover_step = match step { @@ -221,14 +241,22 @@ impl IOPProverState { }) .collect_vec(); end_timer!(timer); - + println!("layer {} | phase2 build mle: {:?}", layer_id, phase2_build_mle_dur); + phase2_build_mle_durs.push(phase2_build_mle_dur); + phase2_sumcheck_prove_durs.push(phase2_sumcheck_dur); + phase1_prove_sumcheck_durs.push(phase1_prove_sumcheck_dur); + build_eq_durs.push(build_eq_dur); proofs }) .flatten() .collect_vec(); end_timer!(timer); exit_span!(span); - + println!("phase2 build mle in total: {:?}", phase2_build_mle_durs.iter().sum::()); + println!("phase2 prove sumcheck in total: {:?}", phase2_sumcheck_prove_durs.iter().sum::()); + println!("phase1 prove sumcheck in total: {:?}", phase1_prove_sumcheck_durs.iter().sum::()); + println!("build eq in total: {:?}", build_eq_durs.iter().sum::()); + println!("prove sumcheck took {:?}", sumcheck_prove_dur.elapsed()); ( IOPProof { sumcheck_proofs }, GKRInputClaims { diff --git a/multilinear_extensions/benches/mle.rs b/multilinear_extensions/benches/mle.rs index 54c601a8b..b40494994 100644 --- a/multilinear_extensions/benches/mle.rs +++ b/multilinear_extensions/benches/mle.rs @@ -8,7 +8,7 @@ fn fix_var(c: &mut Criterion) { let mut rng = thread_rng(); const NUM_SAMPLES: usize = 10; - for nv in 12..20 { + for nv in 16..24 { let mut group = c.benchmark_group("mle"); group.sample_size(NUM_SAMPLES); @@ -39,7 +39,7 @@ fn fix_var_par(c: &mut Criterion) { let mut rng = thread_rng(); const NUM_SAMPLES: usize = 10; - for nv in 12..20 { + for nv in 16..24 { let mut group = c.benchmark_group("mle"); group.sample_size(NUM_SAMPLES); diff --git a/sumcheck/benches/eval_partial_poly.rs b/sumcheck/benches/eval_partial_poly.rs index a9fcb6ee4..29419ef48 100644 --- a/sumcheck/benches/eval_partial_poly.rs +++ b/sumcheck/benches/eval_partial_poly.rs @@ -20,6 +20,8 @@ fn eval_partial_poly(c: &mut Criterion) { let mut group = c.benchmark_group("mle"); group.sample_size(NUM_SAMPLES); + let points = (0..=2u64).map(E::from).collect::>(); + let weights = barycentric_weights(&points); let mut setup = |nv, i| { let mut f = Arc::new(DenseMultilinearExtension::::random(nv, &mut rng)); let mut g = Arc::new(DenseMultilinearExtension::::random(nv, &mut rng)); @@ -96,10 +98,6 @@ fn eval_partial_poly(c: &mut Criterion) { let extrapolation = (0..poly.aux_info.max_degree - products.len()) .into_iter() .map(|i| { - let points = (0..(1 + products.len()) as u64) - .map(E::from) - .collect::>(); - let weights = barycentric_weights(&points); let at = E::from((products.len() + 1 + i) as u64); extrapolate(&points, &weights, &sum, &at) }) From 010977905d25097176b960be52417b7aeeb33f49 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 25 Jul 2024 21:06:15 +0800 Subject: [PATCH 6/7] add build_eq_r benchmark --- multilinear_extensions/benches/mle.rs | 46 ++++++++++++++++++++-- multilinear_extensions/src/virtual_poly.rs | 9 ++--- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/multilinear_extensions/benches/mle.rs b/multilinear_extensions/benches/mle.rs index b40494994..5bd6a5387 100644 --- a/multilinear_extensions/benches/mle.rs +++ b/multilinear_extensions/benches/mle.rs @@ -1,8 +1,11 @@ -use ark_std::rand::thread_rng; +use ark_std::rand::{thread_rng, Rng}; use criterion::*; use ff::Field; use goldilocks::GoldilocksExt2; -use multilinear_extensions::mle::DenseMultilinearExtension; +use multilinear_extensions::{ + mle::DenseMultilinearExtension, + virtual_poly::{build_eq_x_r, build_eq_x_r_sequential}, +}; fn fix_var(c: &mut Criterion) { let mut rng = thread_rng(); @@ -66,5 +69,42 @@ fn fix_var_par(c: &mut Criterion) { } } -criterion_group!(benches, fix_var, fix_var_par,); +fn bench_build_eq_internal(c: &mut Criterion, use_par: bool) { + const NUM_SAMPLES: usize = 10; + let mut rng = thread_rng(); + let group_name = if use_par { + "build_eq_par" + } else { + "build_eq_seq" + }; + let mut group = c.benchmark_group(group_name); + group.sample_size(NUM_SAMPLES); + + for num_vars in 15..24 { + group.bench_function(format!("{}", num_vars), |b| { + b.iter_batched( + || { + (0..num_vars) + .map(|_| GoldilocksExt2::random(&mut rng)) + .collect::>() + }, + |r| { + if use_par { + build_eq_x_r(&r) + } else { + build_eq_x_r_sequential(&r) + } + }, + BatchSize::SmallInput, + ); + }); + } +} + +fn bench_build_eq(c: &mut Criterion) { + bench_build_eq_internal(c, false); + bench_build_eq_internal(c, true); +} + +criterion_group!(benches, bench_build_eq, fix_var, fix_var_par,); criterion_main!(benches); diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index 34b96a047..933d59d82 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -358,7 +358,7 @@ pub fn eq_eval(x: &[F], y: &[F]) -> F { /// Evaluate /// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i)) /// over r, which is -/// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i)) +/// eq(x,r) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i)) pub fn build_eq_x_r_sequential(r: &[E]) -> ArcDenseMultilinearExtension { let evals = build_eq_x_r_vec_sequential(r); let mle = DenseMultilinearExtension::from_evaluations_ext_vec(r.len(), evals); @@ -438,8 +438,7 @@ pub fn build_eq_x_r(r: &[E]) -> ArcDenseMultilinearExtension< /// Evaluate /// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i)) /// over r, which is -/// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i)) - +/// eq(x,r) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i)) #[tracing::instrument(skip_all, name = "multilinear_extensions::build_eq_x_r_vec")] pub fn build_eq_x_r_vec(r: &[E]) -> Vec { // avoid unnecessary allocation @@ -479,8 +478,8 @@ fn build_eq_x_r_helper(r: &[E], buf: &mut [Vec; 2]) { for (i, r) in r.iter().rev().enumerate() { let [current, next] = buf; let (cur_size, next_size) = (1 << i, 1 << (i + 1)); - // suppose at the previous step we processed buf [0..size] - // for the current step we are populating new buf[0..2*size] + // suppose at the previous step we processed current_buf[0..size] + // for the current step we are populating new_buf[0..2*size] // for j travese 0..size // buf[2*j + 1] = r * buf[j] // buf[2*j] = (1 - r) * buf[j] From 9c8a86fd079baab2342c345b0f0fbd88168c28ab Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 25 Jul 2024 23:08:03 +0800 Subject: [PATCH 7/7] fmt --- gkr/src/prover.rs | 26 +++++++++++++++++++++----- multilinear_extensions/benches/mle.rs | 2 +- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/gkr/src/prover.rs b/gkr/src/prover.rs index e060922d3..d6760fe07 100644 --- a/gkr/src/prover.rs +++ b/gkr/src/prover.rs @@ -1,4 +1,8 @@ -use std::{mem, ops::AddAssign, time::{Duration, Instant}}; +use std::{ + mem, + ops::AddAssign, + time::{Duration, Instant}, +}; use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; @@ -252,10 +256,22 @@ impl IOPProverState { .collect_vec(); end_timer!(timer); exit_span!(span); - println!("phase2 build mle in total: {:?}", phase2_build_mle_durs.iter().sum::()); - println!("phase2 prove sumcheck in total: {:?}", phase2_sumcheck_prove_durs.iter().sum::()); - println!("phase1 prove sumcheck in total: {:?}", phase1_prove_sumcheck_durs.iter().sum::()); - println!("build eq in total: {:?}", build_eq_durs.iter().sum::()); + println!( + "phase2 build mle in total: {:?}", + phase2_build_mle_durs.iter().sum::() + ); + println!( + "phase2 prove sumcheck in total: {:?}", + phase2_sumcheck_prove_durs.iter().sum::() + ); + println!( + "phase1 prove sumcheck in total: {:?}", + phase1_prove_sumcheck_durs.iter().sum::() + ); + println!( + "build eq in total: {:?}", + build_eq_durs.iter().sum::() + ); println!("prove sumcheck took {:?}", sumcheck_prove_dur.elapsed()); ( IOPProof { sumcheck_proofs }, diff --git a/multilinear_extensions/benches/mle.rs b/multilinear_extensions/benches/mle.rs index 5bd6a5387..1d9641a45 100644 --- a/multilinear_extensions/benches/mle.rs +++ b/multilinear_extensions/benches/mle.rs @@ -1,4 +1,4 @@ -use ark_std::rand::{thread_rng, Rng}; +use ark_std::rand::thread_rng; use criterion::*; use ff::Field; use goldilocks::GoldilocksExt2;