Skip to content

Commit

Permalink
Fix some parallization issues
Browse files Browse the repository at this point in the history
  • Loading branch information
yczhangsjtu committed Sep 9, 2024
1 parent 0f3cf90 commit 5092195
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
17 changes: 13 additions & 4 deletions mpcs/src/basefold/commit_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use multilinear_extensions::{mle::FieldType, virtual_poly::build_eq_x_r_vec};

use crate::util::plonky2_util::reverse_index_bits_in_place;
use rayon::prelude::{
IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator, ParallelSlice,
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
ParallelSlice,
};

use super::structure::BasefoldCommitmentWithData;
Expand Down Expand Up @@ -318,24 +319,32 @@ where
let timer = start_timer!(|| "Simple batch commit phase");
assert_eq!(point.len(), num_vars);
assert_eq!(comm.num_polys, batch_coeffs.len());
let prepare_timer = start_timer!(|| "Prepare");
let mut oracles = Vec::with_capacity(num_vars);
let mut trees = Vec::with_capacity(num_vars);
let mut running_oracle = comm.batch_codewords(&batch_coeffs.to_vec());
let batch_codewords_timer = start_timer!(|| "Batch codewords");
let mut running_oracle = comm.batch_codewords(batch_coeffs);
end_timer!(batch_codewords_timer);
let mut running_evals = (0..(1 << num_vars))
.into_par_iter()
.map(|i| {
comm.polynomials_bh_evals
.iter()
.zip(batch_coeffs)
.map(|(eval, coeff)| field_type_index_ext(eval, i) * *coeff)
.sum()
})
.collect_vec();
.collect();
end_timer!(prepare_timer);

// eq is the evaluation representation of the eq(X,r) polynomial over the hypercube
let build_eq_timer = start_timer!(|| "Basefold::open");
let build_eq_timer = start_timer!(|| "Basefold::build eq");
let mut eq = build_eq_x_r_vec(point);
end_timer!(build_eq_timer);

let reverse_bits_timer = start_timer!(|| "Basefold::reverse bits");
reverse_index_bits_in_place(&mut eq);
end_timer!(reverse_bits_timer);

let sumcheck_timer = start_timer!(|| "Basefold sumcheck first round");
let mut last_sumcheck_message = sum_check_first_round(&mut eq, &mut running_evals);
Expand Down
2 changes: 1 addition & 1 deletion mpcs/src/basefold/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ where
self.codeword_tree.leaves()
}

pub fn batch_codewords(&self, coeffs: &Vec<E>) -> Vec<E> {
pub fn batch_codewords(&self, coeffs: &[E]) -> Vec<E> {
self.codeword_tree.batch_leaves(coeffs)
}

Expand Down
9 changes: 6 additions & 3 deletions mpcs/src/util/merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use ff_ext::ExtensionField;
use itertools::Itertools;
use multilinear_extensions::mle::FieldType;
use rayon::{
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
},
slice::ParallelSlice,
};

Expand Down Expand Up @@ -63,12 +65,13 @@ where
&self.leaves
}

pub fn batch_leaves(&self, coeffs: &Vec<E>) -> Vec<E> {
pub fn batch_leaves(&self, coeffs: &[E]) -> Vec<E> {
(0..self.leaves[0].len())
.into_par_iter()
.map(|i| {
self.leaves
.iter()
.zip(coeffs)
.zip(coeffs.iter())
.map(|(leaf, coeff)| field_type_index_ext(leaf, i) * *coeff)
.sum()
})
Expand Down

0 comments on commit 5092195

Please sign in to comment.