From 8e5dbdb3f722cb927dd7ea97dd545aa3abdf23da Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Tue, 10 Sep 2024 08:09:11 +0800 Subject: [PATCH] Fix some parallization issues --- mpcs/src/basefold/commit_phase.rs | 17 +++++++++++++---- mpcs/src/basefold/structure.rs | 2 +- mpcs/src/util/merkle_tree.rs | 9 ++++++--- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index da8b4ad75..99fe42d99 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -24,7 +24,8 @@ use multilinear_extensions::{mle::FieldType, virtual_poly::build_eq_x_r_vec}; use crate::util::plonky2_util::reverse_index_bits_in_place; use rayon::prelude::{ - IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator, ParallelSlice, + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, + ParallelSlice, }; use super::structure::BasefoldCommitmentWithData; @@ -318,10 +319,14 @@ where let timer = start_timer!(|| "Simple batch commit phase"); assert_eq!(point.len(), num_vars); assert_eq!(comm.num_polys, batch_coeffs.len()); + let prepare_timer = start_timer!(|| "Prepare"); let mut oracles = Vec::with_capacity(num_vars); let mut trees = Vec::with_capacity(num_vars); - let mut running_oracle = comm.batch_codewords(&batch_coeffs.to_vec()); + let batch_codewords_timer = start_timer!(|| "Batch codewords"); + let mut running_oracle = comm.batch_codewords(batch_coeffs); + end_timer!(batch_codewords_timer); let mut running_evals = (0..(1 << num_vars)) + .into_par_iter() .map(|i| { comm.polynomials_bh_evals .iter() @@ -329,13 +334,17 @@ where .map(|(eval, coeff)| field_type_index_ext(eval, i) * *coeff) .sum() }) - .collect_vec(); + .collect(); + end_timer!(prepare_timer); // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube - let build_eq_timer = start_timer!(|| "Basefold::open"); + let build_eq_timer = start_timer!(|| "Basefold::build eq"); let mut eq = build_eq_x_r_vec(point); end_timer!(build_eq_timer); + + let reverse_bits_timer = start_timer!(|| "Basefold::reverse bits"); reverse_index_bits_in_place(&mut eq); + end_timer!(reverse_bits_timer); let sumcheck_timer = start_timer!(|| "Basefold sumcheck first round"); let mut last_sumcheck_message = sum_check_first_round(&mut eq, &mut running_evals); diff --git a/mpcs/src/basefold/structure.rs b/mpcs/src/basefold/structure.rs index 0a1972c38..f79e57865 100644 --- a/mpcs/src/basefold/structure.rs +++ b/mpcs/src/basefold/structure.rs @@ -89,7 +89,7 @@ where self.codeword_tree.leaves() } - pub fn batch_codewords(&self, coeffs: &Vec) -> Vec { + pub fn batch_codewords(&self, coeffs: &[E]) -> Vec { self.codeword_tree.batch_leaves(coeffs) } diff --git a/mpcs/src/util/merkle_tree.rs b/mpcs/src/util/merkle_tree.rs index b3d1bf899..eb187cdef 100644 --- a/mpcs/src/util/merkle_tree.rs +++ b/mpcs/src/util/merkle_tree.rs @@ -2,7 +2,9 @@ use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::FieldType; use rayon::{ - iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, + iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, + }, slice::ParallelSlice, }; @@ -63,12 +65,13 @@ where &self.leaves } - pub fn batch_leaves(&self, coeffs: &Vec) -> Vec { + pub fn batch_leaves(&self, coeffs: &[E]) -> Vec { (0..self.leaves[0].len()) + .into_par_iter() .map(|i| { self.leaves .iter() - .zip(coeffs) + .zip(coeffs.iter()) .map(|(leaf, coeff)| field_type_index_ext(leaf, i) * *coeff) .sum() })