Skip to content

Commit

Permalink
add serial version of sumcheck prover
Browse files Browse the repository at this point in the history
  • Loading branch information
kunxian-xia committed Jul 11, 2024
1 parent b3e8a38 commit a4660d1
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 4 deletions.
49 changes: 46 additions & 3 deletions sumcheck/benches/devirgo_sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ use multilinear_extensions::{
};
use transcript::Transcript;

criterion_group!(benches, sumcheck_fn, devirgo_sumcheck_fn,);
criterion_group!(
benches,
sumcheck_serial_fn,
sumcheck_fn,
devirgo_sumcheck_fn,
);
criterion_main!(benches);

const NUM_SAMPLES: usize = 10;
Expand Down Expand Up @@ -84,7 +89,7 @@ const RAYON_NUM_THREADS: usize = 8;
fn sumcheck_fn(c: &mut Criterion) {
type E = GoldilocksExt2;

for nv in 24..25 {
for nv in 20..25 {
// expand more input size once runtime is acceptable
let mut group = c.benchmark_group(format!("sumcheck_nv_{}", nv));
group.sample_size(NUM_SAMPLES);
Expand Down Expand Up @@ -119,10 +124,48 @@ fn sumcheck_fn(c: &mut Criterion) {
}
}

fn sumcheck_serial_fn(c: &mut Criterion) {
type E = GoldilocksExt2;

for nv in 20..25 {
// expand more input size once runtime is acceptable
let mut group = c.benchmark_group(format!("sumcheck_serial_nv_{}", nv));
group.sample_size(NUM_SAMPLES);

// Benchmark the proving time
group.bench_function(
BenchmarkId::new("prove_sumcheck", format!("sumcheck_serial_nv_{}", nv)),
|b| {
b.iter_with_setup(
|| {
let prover_transcript = Transcript::<E>::new(b"test");
let (asserted_sum, virtual_poly, virtual_poly_splitted) =
{ prepare_input(RAYON_NUM_THREADS, nv) };
(
prover_transcript,
asserted_sum,
virtual_poly,
virtual_poly_splitted,
)
},
|(mut prover_transcript, asserted_sum, virtual_poly, virtual_poly_splitted)| {
let (sumcheck_proof_v1_serial, _) = IOPProverState::<E>::prove_serial(
virtual_poly.clone(),
&mut prover_transcript,
);
},
);
},
);

group.finish();
}
}

fn devirgo_sumcheck_fn(c: &mut Criterion) {
type E = GoldilocksExt2;

for nv in 24..25 {
for nv in 20..25 {
// expand more input size once runtime is acceptable
let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv));
group.sample_size(NUM_SAMPLES);
Expand Down
7 changes: 7 additions & 0 deletions sumcheck/examples/devirgo_sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ const RAYON_NUM_THREADS: usize = 8;

fn main() {
let mut prover_transcript_v1 = Transcript::<E>::new(b"test");
let mut prover_transcript_v1_serial = Transcript::<E>::new(b"test");
let mut prover_transcript_v2 = Transcript::<E>::new(b"test");

let (asserted_sum, virtual_poly, virtual_poly_splitted) = prepare_input(RAYON_NUM_THREADS);
Expand Down Expand Up @@ -113,4 +114,10 @@ fn main() {

println!("v1 finish");
assert!(sumcheck_proof_v2 == sumcheck_proof_v1);

let (sumcheck_proof_v1_serial, _) =
IOPProverState::<E>::prove_serial(virtual_poly.clone(), &mut prover_transcript_v1_serial);

println!("v1 serial finish");
assert!(sumcheck_proof_v2 == sumcheck_proof_v1_serial);
}
76 changes: 75 additions & 1 deletion sumcheck/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ impl<E: ExtensionField> IOPProverState<E> {

let span = entered_span!("extrapolation");
let extrapolation = (0..self.poly.aux_info.max_degree - products.len())
.into_par_iter()
.into_iter()
.map(|i| {
let (points, weights) = &self.extrapolation_aux[products.len() - 1];
let at = E::from((products.len() + 1 + i) as u64);
Expand All @@ -449,6 +449,80 @@ impl<E: ExtensionField> IOPProverState<E> {
}
}

// The prover runs in sequential mode in the sense that we use the serial version of
// fix_variables and eval_partial_poly in each round of sumcheck.
pub fn prove_serial(
poly: VirtualPolynomial<E>,
transcript: &mut Transcript<E>,
) -> (IOPProof<E>, IOPProverState<E>) {
let (num_variables, max_degree) = (poly.aux_info.num_variables, poly.aux_info.max_degree);

// return empty proof when target polymonial is constant
if num_variables == 0 {
return (
IOPProof::default(),
IOPProverState {
poly: poly,
..Default::default()
},
);
}
let start = start_timer!(|| "sum check prove");

transcript.append_message(&num_variables.to_le_bytes());
transcript.append_message(&max_degree.to_le_bytes());

let mut prover_state = Self::prover_init_parallel(poly);

Check warning on line 475 in sumcheck/src/prover.rs

View workflow job for this annotation

GitHub Actions / Run Tests

use of deprecated associated function `prover::<impl structs::IOPProverState<E>>::prover_init_parallel`: deprecated parallel version due to syncronizaion overhead
let mut challenge = None;
let mut prover_msgs = Vec::with_capacity(num_variables);
let span = entered_span!("prove_rounds");
for _ in 0..num_variables {
let prover_msg =
IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge);

prover_msg
.evaluations
.iter()
.for_each(|e| transcript.append_field_element_ext(e));

prover_msgs.push(prover_msg);
let span = entered_span!("get_challenge");
challenge = Some(transcript.get_and_append_challenge(b"Internal round"));
exit_span!(span);
}
exit_span!(span);

let span = entered_span!("after_rounds_prover_state");
// pushing the last challenge point to the state
if let Some(p) = challenge {
prover_state.challenges.push(p);
// fix last challenge to collect final evaluation
prover_state
.poly
.flattened_ml_extensions
.iter_mut()
.for_each(|mle| {
Arc::make_mut(mle).fix_variables_in_place(&[p.elements]);
});
};
exit_span!(span);

end_timer!(start);
(
IOPProof {
// the point consists of the first elements in the challenge
point: prover_state
.challenges
.iter()
.map(|challenge| challenge.elements)
.collect(),
proofs: prover_msgs,
..Default::default()
},
prover_state.into(),
)
}

/// collect all mle evaluation (claim) after sumcheck
pub fn get_mle_final_evaluations(&self) -> Vec<E> {
self.poly
Expand Down

0 comments on commit a4660d1

Please sign in to comment.