diff --git a/multilinear_extensions/src/lib.rs b/multilinear_extensions/src/lib.rs index 9f669e348..63d0f0449 100644 --- a/multilinear_extensions/src/lib.rs +++ b/multilinear_extensions/src/lib.rs @@ -1,4 +1,5 @@ #![deny(clippy::cargo)] +#![feature(sync_unsafe_cell)] pub mod mle; pub mod util; pub mod virtual_poly; diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 8a182e645..f3732297d 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -1,12 +1,22 @@ -use std::{any::TypeId, borrow::Cow, mem, sync::Arc}; +use std::{ + any::TypeId, + borrow::Cow, + cell::SyncUnsafeCell, + mem::{self, MaybeUninit}, + sync::Arc, +}; -use crate::{op_mle, util::ceil_log2}; +use crate::{ + op_mle, + util::{ceil_log2, create_uninit_vec, max_usable_threads}, +}; use ark_std::{end_timer, rand::RngCore, start_timer}; use core::hash::Hash; use ff::Field; use ff_ext::ExtensionField; use rayon::iter::{ - IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelIterator, }; use serde::{Deserialize, Serialize}; use std::fmt::Debug; @@ -611,6 +621,7 @@ impl MultilinearExtension for DenseMultilinearExtension /// Reduce the number of variables of `self` by fixing the /// `partial_point.len()` variables at `partial_point`. fn fix_variables_parallel(&self, partial_point: &[E]) -> Self { + let n_threads = max_usable_threads(); // TODO: return error. assert!( partial_point.len() <= self.num_vars(), @@ -626,12 +637,25 @@ impl MultilinearExtension for DenseMultilinearExtension *poly = op_mle!(self, |evaluations| { Cow::Owned(DenseMultilinearExtension::from_evaluations_ext_vec( self.num_vars() - 1, - evaluations - .par_iter() - .chunks(2) - .with_min_len(64) - .map(|buf| *point * (*buf[1] - *buf[0]) + *buf[0]) - .collect(), + unsafe { + let data = create_uninit_vec::(evaluations.len() / 2); + let vec_sync_unsafe = SyncUnsafeCell::new(data); + (0..n_threads).into_par_iter().for_each(|thread_id| { + let ptr = (*vec_sync_unsafe.get()).as_mut_ptr(); + (0..evaluations.len()) + .skip(2 * thread_id) + .step_by(2 * n_threads) + .for_each(|i| { + *ptr.add(i / 2) = MaybeUninit::new( + *point * (evaluations[i + 1] - evaluations[i]) + + evaluations[i], + ); + }); + }); + let maybe_uninit_vec: Vec> = + vec_sync_unsafe.into_inner(); + std::mem::transmute::>, Vec>(maybe_uninit_vec) + }, )) }); } @@ -653,20 +677,34 @@ impl MultilinearExtension for DenseMultilinearExtension self.num_vars() ); let nv = self.num_vars(); + let n_threads = max_usable_threads(); // evaluate single variable of partial point from left to right for (i, point) in partial_point.iter().enumerate() { let max_log2_size = nv - i; // override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1, b2,..bt, 1] in parallel match &mut self.evaluations { - FieldType::Base(evaluations) => { - let evaluations_ext = evaluations - .par_iter() - .chunks(2) - .with_min_len(64) - .map(|buf| *point * (*buf[1] - *buf[0]) + *buf[0]) - .collect(); - let _ = mem::replace(&mut self.evaluations, FieldType::Ext(evaluations_ext)); - } + FieldType::Base(evaluations) => unsafe { + let data = create_uninit_vec::(evaluations.len() / 2); + let vec_sync_unsafe = SyncUnsafeCell::new(data); + (0..n_threads).into_par_iter().for_each(|thread_id| { + let ptr = (*vec_sync_unsafe.get()).as_mut_ptr(); + (0..evaluations.len()) + .skip(2 * thread_id) + .step_by(2 * n_threads) + .for_each(|i| { + *ptr.add(i / 2) = MaybeUninit::new( + *point * (evaluations[i + 1] - evaluations[i]) + evaluations[i], + ); + }); + }); + let maybe_uninit_vec: Vec> = vec_sync_unsafe.into_inner(); + let _ = mem::replace( + &mut self.evaluations, + FieldType::Ext(std::mem::transmute::>, Vec>( + maybe_uninit_vec, + )), + ); + }, FieldType::Ext(evaluations) => { evaluations .par_iter_mut() diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index b4021b77d..7f0b6b832 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -8,12 +8,12 @@ use multilinear_extensions::{ commutative_op_mle_pair, mle::{DenseMultilinearExtension, MultilinearExtension}, op_mle, op_mle_product_3, op_mle3_range, - util::largest_even_below, + util::{largest_even_below, max_usable_threads}, virtual_poly_v2::VirtualPolynomialV2, }; use rayon::{ Scope, - iter::{IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator}, + iter::IntoParallelRefMutIterator, prelude::{IntoParallelIterator, ParallelIterator}, }; use transcript::{Challenge, Transcript, TranscriptSyncronized}; @@ -722,6 +722,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { // // eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n) let span = entered_span!("fix_variables"); + let n_threads = max_usable_threads(); if self.round == 0 { assert!(challenge.is_none(), "first round should be prover first."); } else { @@ -769,8 +770,8 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let AdditiveVec(products_sum) = self .poly .products - .par_iter() - .fold_with( + .iter() + .fold( AdditiveVec::new(self.poly.aux_info.max_degree + 1), |mut products_sum, (coefficient, products)| { let span = entered_span!("sum"); @@ -780,17 +781,17 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let f = &self.poly.flattened_ml_extensions[products[0]]; op_mle! { |f| { - let res = (0..largest_even_below(f.len())) - .into_par_iter() - .step_by(2) - .with_min_len(64) - .map(|b| { + let res = (0..n_threads).into_par_iter().map(|thread_id| { + (0..largest_even_below(f.len())) + .skip(2*thread_id) + .step_by(2*n_threads) + .map(|b| { AdditiveArray([ f[b], f[b + 1] ]) - }) - .sum::>(); + }).sum::>() + }).sum::>(); let res = if f.len() == 1 { AdditiveArray::<_, 2>([f[0]; 2]) } else { @@ -814,19 +815,20 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ); commutative_op_mle_pair!( |f, g| { - let res = (0..largest_even_below(f.len())) - .into_par_iter() - .step_by(2) - .with_min_len(64) - .map(|b| { - AdditiveArray([ - f[b] * g[b], - f[b + 1] * g[b + 1], - (f[b + 1] + f[b + 1] - f[b]) - * (g[b + 1] + g[b + 1] - g[b]), - ]) - }) - .sum::>(); + let res = (0..n_threads).into_par_iter().map(|thread_id| { + (0..largest_even_below(f.len())) + .skip(2*thread_id) + .step_by(2*n_threads) + .map(|b| { + AdditiveArray([ + f[b] * g[b], + f[b + 1] * g[b + 1], + (f[b + 1] + f[b + 1] - f[b]) + * (g[b + 1] + g[b + 1] - g[b]), + ]) + }).sum::>() + }).sum::>(); + let res = if f.len() == 1 { AdditiveArray::<_, 3>([f[0] * g[0]; 3]) } else { @@ -851,23 +853,26 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ); op_mle_product_3!( |f1, f2, f3| { - let res = (0..largest_even_below(f1.len())) - .step_by(2) - .map(|b| { - // f = c x + d - let c1 = f1[b + 1] - f1[b]; - let c2 = f2[b + 1] - f2[b]; - let c3 = f3[b + 1] - f3[b]; - AdditiveArray([ - f1[b] * (f2[b] * f3[b]), - f1[b + 1] * (f2[b + 1] * f3[b + 1]), - (c1 + f1[b + 1]) - * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), - (c1 + c1 + f1[b + 1]) - * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), - ]) - }) - .sum::>(); + let res = (0..n_threads).into_par_iter().map(|thread_id| { + (0..largest_even_below(f1.len())) + .skip(2*thread_id) + .step_by(2*n_threads) + .map(|b| { + // f = c x + d + let c1 = f1[b + 1] - f1[b]; + let c2 = f2[b + 1] - f2[b]; + let c3 = f3[b + 1] - f3[b]; + AdditiveArray([ + f1[b] * (f2[b] * f3[b]), + f1[b + 1] * (f2[b + 1] * f3[b + 1]), + (c1 + f1[b + 1]) + * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), + (c1 + c1 + f1[b + 1]) + * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), + ]) + }).sum::>() + }).sum::>(); + let res = if f1.len() == 1 { AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4]) } else { @@ -905,9 +910,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { exit_span!(span); products_sum }, - ) - .reduce_with(|acc, item| acc + item) - .unwrap(); + ); exit_span!(span); end_timer!(start);