Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize parallel version sumcheck #743

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions multilinear_extensions/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![deny(clippy::cargo)]
#![feature(sync_unsafe_cell)]
pub mod mle;
pub mod util;
pub mod virtual_poly;
Expand Down
74 changes: 56 additions & 18 deletions multilinear_extensions/src/mle.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
use std::{any::TypeId, borrow::Cow, mem, sync::Arc};
use std::{
any::TypeId,
borrow::Cow,
cell::SyncUnsafeCell,
mem::{self, MaybeUninit},
sync::Arc,
};

use crate::{op_mle, util::ceil_log2};
use crate::{
op_mle,
util::{ceil_log2, create_uninit_vec, max_usable_threads},
};
use ark_std::{end_timer, rand::RngCore, start_timer};
use core::hash::Hash;
use ff::Field;
use ff_ext::ExtensionField;
use rayon::iter::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
Expand Down Expand Up @@ -611,6 +621,7 @@ impl<E: ExtensionField> MultilinearExtension<E> for DenseMultilinearExtension<E>
/// Reduce the number of variables of `self` by fixing the
/// `partial_point.len()` variables at `partial_point`.
fn fix_variables_parallel(&self, partial_point: &[E]) -> Self {
let n_threads = max_usable_threads();
// TODO: return error.
assert!(
partial_point.len() <= self.num_vars(),
Expand All @@ -626,12 +637,25 @@ impl<E: ExtensionField> MultilinearExtension<E> for DenseMultilinearExtension<E>
*poly = op_mle!(self, |evaluations| {
Cow::Owned(DenseMultilinearExtension::from_evaluations_ext_vec(
self.num_vars() - 1,
evaluations
.par_iter()
.chunks(2)
.with_min_len(64)
.map(|buf| *point * (*buf[1] - *buf[0]) + *buf[0])
.collect(),
unsafe {
let data = create_uninit_vec::<E>(evaluations.len() / 2);
let vec_sync_unsafe = SyncUnsafeCell::new(data);
(0..n_threads).into_par_iter().for_each(|thread_id| {
let ptr = (*vec_sync_unsafe.get()).as_mut_ptr();
(0..evaluations.len())
.skip(2 * thread_id)
.step_by(2 * n_threads)
.for_each(|i| {
*ptr.add(i / 2) = MaybeUninit::new(
*point * (evaluations[i + 1] - evaluations[i])
+ evaluations[i],
);
});
});
let maybe_uninit_vec: Vec<MaybeUninit<E>> =
vec_sync_unsafe.into_inner();
std::mem::transmute::<Vec<MaybeUninit<E>>, Vec<E>>(maybe_uninit_vec)
},
))
});
}
Expand All @@ -653,20 +677,34 @@ impl<E: ExtensionField> MultilinearExtension<E> for DenseMultilinearExtension<E>
self.num_vars()
);
let nv = self.num_vars();
let n_threads = max_usable_threads();
// evaluate single variable of partial point from left to right
for (i, point) in partial_point.iter().enumerate() {
let max_log2_size = nv - i;
// override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1, b2,..bt, 1] in parallel
match &mut self.evaluations {
FieldType::Base(evaluations) => {
let evaluations_ext = evaluations
.par_iter()
.chunks(2)
.with_min_len(64)
.map(|buf| *point * (*buf[1] - *buf[0]) + *buf[0])
.collect();
let _ = mem::replace(&mut self.evaluations, FieldType::Ext(evaluations_ext));
}
FieldType::Base(evaluations) => unsafe {
let data = create_uninit_vec::<E>(evaluations.len() / 2);
let vec_sync_unsafe = SyncUnsafeCell::new(data);
(0..n_threads).into_par_iter().for_each(|thread_id| {
let ptr = (*vec_sync_unsafe.get()).as_mut_ptr();
(0..evaluations.len())
.skip(2 * thread_id)
.step_by(2 * n_threads)
.for_each(|i| {
*ptr.add(i / 2) = MaybeUninit::new(
*point * (evaluations[i + 1] - evaluations[i]) + evaluations[i],
);
});
});
let maybe_uninit_vec: Vec<MaybeUninit<E>> = vec_sync_unsafe.into_inner();
let _ = mem::replace(
&mut self.evaluations,
FieldType::Ext(std::mem::transmute::<Vec<MaybeUninit<E>>, Vec<E>>(
maybe_uninit_vec,
)),
);
},
FieldType::Ext(evaluations) => {
evaluations
.par_iter_mut()
Expand Down
91 changes: 47 additions & 44 deletions sumcheck/src/prover_v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ use multilinear_extensions::{
commutative_op_mle_pair,
mle::{DenseMultilinearExtension, MultilinearExtension},
op_mle, op_mle_product_3, op_mle3_range,
util::largest_even_below,
util::{largest_even_below, max_usable_threads},
virtual_poly_v2::VirtualPolynomialV2,
};
use rayon::{
Scope,
iter::{IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator},
iter::IntoParallelRefMutIterator,
prelude::{IntoParallelIterator, ParallelIterator},
};
use transcript::{Challenge, Transcript, TranscriptSyncronized};
Expand Down Expand Up @@ -722,6 +722,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
//
// eval g over r_m, and mutate g to g(r_1, ... r_m, x_{m+1}... x_n)
let span = entered_span!("fix_variables");
let n_threads = max_usable_threads();
if self.round == 0 {
assert!(challenge.is_none(), "first round should be prover first.");
} else {
Expand Down Expand Up @@ -769,8 +770,8 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
let AdditiveVec(products_sum) = self
.poly
.products
.par_iter()
.fold_with(
.iter()
.fold(
AdditiveVec::new(self.poly.aux_info.max_degree + 1),
|mut products_sum, (coefficient, products)| {
let span = entered_span!("sum");
Expand All @@ -780,17 +781,17 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
let f = &self.poly.flattened_ml_extensions[products[0]];
op_mle! {
|f| {
let res = (0..largest_even_below(f.len()))
.into_par_iter()
.step_by(2)
.with_min_len(64)
.map(|b| {
let res = (0..n_threads).into_par_iter().map(|thread_id| {
(0..largest_even_below(f.len()))
.skip(2*thread_id)
.step_by(2*n_threads)
.map(|b| {
AdditiveArray([
f[b],
f[b + 1]
])
})
.sum::<AdditiveArray<_, 2>>();
}).sum::<AdditiveArray<_, 2>>()
}).sum::<AdditiveArray<_, 2>>();
let res = if f.len() == 1 {
AdditiveArray::<_, 2>([f[0]; 2])
} else {
Expand All @@ -814,19 +815,20 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
);
commutative_op_mle_pair!(
|f, g| {
let res = (0..largest_even_below(f.len()))
.into_par_iter()
.step_by(2)
.with_min_len(64)
.map(|b| {
AdditiveArray([
f[b] * g[b],
f[b + 1] * g[b + 1],
(f[b + 1] + f[b + 1] - f[b])
* (g[b + 1] + g[b + 1] - g[b]),
])
})
.sum::<AdditiveArray<_, 3>>();
let res = (0..n_threads).into_par_iter().map(|thread_id| {
(0..largest_even_below(f.len()))
.skip(2*thread_id)
.step_by(2*n_threads)
.map(|b| {
AdditiveArray([
f[b] * g[b],
f[b + 1] * g[b + 1],
(f[b + 1] + f[b + 1] - f[b])
* (g[b + 1] + g[b + 1] - g[b]),
])
}).sum::<AdditiveArray<_, 3>>()
}).sum::<AdditiveArray<_, 3>>();

let res = if f.len() == 1 {
AdditiveArray::<_, 3>([f[0] * g[0]; 3])
} else {
Expand All @@ -851,23 +853,26 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
);
op_mle_product_3!(
|f1, f2, f3| {
let res = (0..largest_even_below(f1.len()))
.step_by(2)
.map(|b| {
// f = c x + d
let c1 = f1[b + 1] - f1[b];
let c2 = f2[b + 1] - f2[b];
let c3 = f3[b + 1] - f3[b];
AdditiveArray([
f1[b] * (f2[b] * f3[b]),
f1[b + 1] * (f2[b + 1] * f3[b + 1]),
(c1 + f1[b + 1])
* ((c2 + f2[b + 1]) * (c3 + f3[b + 1])),
(c1 + c1 + f1[b + 1])
* ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])),
])
})
.sum::<AdditiveArray<_, 4>>();
let res = (0..n_threads).into_par_iter().map(|thread_id| {
(0..largest_even_below(f1.len()))
.skip(2*thread_id)
.step_by(2*n_threads)
.map(|b| {
// f = c x + d
let c1 = f1[b + 1] - f1[b];
let c2 = f2[b + 1] - f2[b];
let c3 = f3[b + 1] - f3[b];
AdditiveArray([
f1[b] * (f2[b] * f3[b]),
f1[b + 1] * (f2[b + 1] * f3[b + 1]),
(c1 + f1[b + 1])
* ((c2 + f2[b + 1]) * (c3 + f3[b + 1])),
(c1 + c1 + f1[b + 1])
* ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])),
])
}).sum::<AdditiveArray<_, 4>>()
}).sum::<AdditiveArray<_, 4>>();

let res = if f1.len() == 1 {
AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4])
} else {
Expand Down Expand Up @@ -905,9 +910,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
exit_span!(span);
products_sum
},
)
.reduce_with(|acc, item| acc + item)
.unwrap();
);
exit_span!(span);

end_timer!(start);
Expand Down