Skip to content

Commit

Permalink
Faster polyeval
Browse files Browse the repository at this point in the history
  • Loading branch information
Kunming Jiang committed Jan 2, 2025
1 parent 1520dbe commit ec8458a
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 37 deletions.
54 changes: 53 additions & 1 deletion spartan_parallel/src/dense_mlpoly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ use super::random::RandomTape;
use super::transcript::ProofTranscript;
use core::ops::Index;
use merlin::Transcript;
use rayon::{iter::ParallelIterator, slice::ParallelSliceMut};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::{cmp::min, collections::HashMap};

#[cfg(feature = "multicore")]
use rayon::prelude::*;
Expand Down Expand Up @@ -247,6 +248,57 @@ impl<S: SpartanExtensionField> DensePolynomial<S> {
self.len = n;
}

/// Folds `proofs` in place, running one round per challenge in `r`.
///
/// Every round first halves the active entry count `l` (rounding up) and then,
/// for each `i` in `0..l`, overwrites `proofs[i * step]` with
/// `(1 - c) * proofs[2 * i * step] + c * proofs[(2 * i + 1) * step]`,
/// where `c` is the challenge of that round. Active entries are always read and
/// written `step` slots apart, so with `step == 1` the results are compacted to
/// the front of the slice.
///
/// Indexing is unchecked beyond the usual slice bounds check: the caller must
/// make sure `(2 * i + 1) * step` stays inside `proofs` for every round.
fn fold_r(proofs: &mut [S], r: &[S], step: usize, mut l: usize) {
  for &challenge in r {
    let lo_weight = S::field_one() - challenge;
    let hi_weight = challenge;

    l = l.div_ceil(2);
    for i in 0..l {
      // Read both operands before writing: the destination slot may alias the
      // low operand (always the case for i == 0).
      let lo = proofs[2 * i * step];
      let hi = proofs[(2 * i + 1) * step];
      proofs[i * step] = lo_weight * lo + hi_weight * hi;
    }
  }
}

/// Returns Z(r) in O(n) time, folding the evaluation table in place.
///
/// The table is moved out of `self.Z` (which is left as its `Default` value),
/// so the polynomial must not be evaluated again afterwards.
///
/// The work is split in two phases:
/// 1. the table is cut into chunks of `dist_size` entries and every chunk is
///    folded independently (in parallel) with the first `log2(dist_size)`
///    challenges of `r`, leaving each chunk's result in its first slot;
/// 2. the per-chunk results, which sit `dist_size` slots apart, are folded on
///    the calling thread with the remaining challenges.
///
/// # Panics
/// * if `r.len()` differs from `self.get_num_vars()`;
/// * if `self.len` is zero (integer division by zero) or not a power of two;
/// * if the table is too short for the indices touched by `fold_r`.
pub fn evaluate_and_consume_parallel(&mut self, r: &[S]) -> S {
  assert_eq!(r.len(), self.get_num_vars());
  let mut inst = std::mem::take(&mut self.Z);

  let len = self.len;
  // Number of entries handled by each parallel task; the thread count is
  // rounded up to a power of two and capped at `len`.
  let dist_size = len / min(len, rayon::current_num_threads().next_power_of_two());
  let num_threads = len / dist_size;

  // To perform rigorous parallelism, both len and # threads must be powers of 2
  // # threads must fully divide num_proofs for even distribution
  assert_eq!(len, len.next_power_of_two());
  assert_eq!(num_threads, num_threads.next_power_of_two());

  // Determine parallelism levels
  let levels = len.log_2(); // total layers
  let sub_levels = dist_size.log_2(); // parallel layers
  let final_levels = num_threads.log_2(); // single core final layers
  // Divide r into sub and final
  let sub_r = &r[0..sub_levels];
  let final_r = &r[sub_levels..levels];

  if sub_levels > 0 {
    // `fold_r` already writes its result into the chunk, so the chunks are
    // folded in place; no per-chunk copy or re-collection of the table is needed.
    inst.par_chunks_mut(dist_size).for_each(|chunk| {
      Self::fold_r(chunk, sub_r, 1, dist_size);
    });
  }

  if final_levels > 0 {
    // aggregate the final result from sub-threads outputs using a single core
    Self::fold_r(&mut inst, final_r, dist_size, num_threads);
  }
  inst[0]
}

// returns Z(r) in O(n) time
pub fn evaluate(&self, r: &[S]) -> S {
// r must have a value for each variable
Expand Down
57 changes: 22 additions & 35 deletions spartan_parallel/src/r1csproof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
evals_ABC: &mut DensePolynomialPqx<S>,
evals_z: &mut DensePolynomialPqx<S>,
transcript: &mut Transcript,
) -> (SumcheckInstanceProof<S>, Vec<S>, Vec<S>) {
) -> (SumcheckInstanceProof<S>, Vec<S>, Vec<S>, Vec<Vec<S>>) {
let comb_func = |poly_A_comp: &S, poly_B_comp: &S, poly_C_comp: &S| -> S {
*poly_A_comp * *poly_B_comp * *poly_C_comp
};
let (sc_proof_phase_two, r, claims) = SumcheckInstanceProof::<S>::prove_cubic_disjoint_rounds(
let (sc_proof_phase_two, r, claims, claimed_vars_at_ry) = SumcheckInstanceProof::<S>::prove_cubic_disjoint_rounds(
claim,
num_rounds,
num_rounds_y_max,
Expand All @@ -102,7 +102,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
transcript,
);

(sc_proof_phase_two, r, claims)
(sc_proof_phase_two, r, claims, claimed_vars_at_ry)
}

fn protocol_name() -> &'static [u8] {
Expand Down Expand Up @@ -344,7 +344,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {

// Sumcheck 2: (rA + rB + rC) * Z * eq(p) = e
let timer_tmp = Timer::new("prove_sum_check");
let (sc_proof_phase2, ry_rev, _claims_phase2) = R1CSProof::prove_phase_two(
let (sc_proof_phase2, ry_rev, _claims_phase2, claimed_vars_at_ry) = R1CSProof::prove_phase_two(
num_rounds_y + num_rounds_w + num_rounds_p,
num_rounds_y,
num_rounds_w,
Expand Down Expand Up @@ -378,6 +378,10 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
let timer_polyeval = Timer::new("polyeval");

// For every possible wit_sec.num_inputs, compute ry_factor = prodX(1 - ryX)...
let mut rq_factors = vec![ONE; num_rounds_q + 1];
for i in 0..num_rounds_q {
rq_factors[i + 1] = rq_factors[i] * (ONE - rq[i]);
}
let mut ry_factors = vec![ONE; num_rounds_y + 1];
for i in 0..num_rounds_y {
ry_factors[i + 1] = ry_factors[i] * (ONE - ry[i]);
Expand All @@ -388,42 +392,26 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
let mut num_inputs_list = Vec::new();
// List of all evaluations
let mut Zr_list = Vec::new();
// List of evaluations separated by witness_secs
// Obtain list of evaluations separated by witness_secs
// Note: eval_vars_at_ry_list and raw_eval_vars_at_ry_list are W * P but claimed_vars_at_ry_list is P * W, and
// raw_eval_vars_at_ry_list does not multiply rq_factor and ry_factor
let mut eval_vars_at_ry_list = vec![Vec::new(); num_witness_secs];
let mut raw_eval_vars_at_ry_list = vec![Vec::new(); num_witness_secs]; // Does not multiply ry_factor
let mut raw_eval_vars_at_ry_list = vec![Vec::new(); num_witness_secs];
for i in 0..num_witness_secs {
let w = witness_secs[i];
let wit_sec_num_instance = w.w_mat.len();
eval_vars_at_ry_list.push(Vec::new());

for p in 0..wit_sec_num_instance {
if w.num_inputs[p] > 1 {
poly_list.push(&w.poly_w[p]);
num_proofs_list.push(w.w_mat[p].len());
num_inputs_list.push(w.num_inputs[p]);
// Depending on w.num_inputs[p], ry_short can be two different values
let ry_short = {
// if w.num_inputs[p] >= num_inputs, need to pad 0's to the front of ry
if w.num_inputs[p] >= max_num_inputs {
let ry_pad = vec![ZERO; w.num_inputs[p].log_2() - max_num_inputs.log_2()];
[ry_pad, ry.clone()].concat()
}
// Else ry_short is the last w.num_inputs[p].log_2() entries of ry
// thus, to obtain the actual ry, need to multiply by (1 - ry0)(1 - ry1)..., which is ry_factors[num_rounds_y - w.num_inputs[p]]
else {
ry[num_rounds_y - w.num_inputs[p].log_2()..].to_vec()
}
};
let rq_short = rq[num_rounds_q - num_proofs_list[num_proofs_list.len() - 1].log_2()..].to_vec();
let r = &[rq_short, ry_short.clone()].concat();
let eval_vars_at_ry = poly_list[poly_list.len() - 1].evaluate(r);
Zr_list.push(eval_vars_at_ry);
if w.num_inputs[p] >= max_num_inputs {
eval_vars_at_ry_list[i].push(eval_vars_at_ry);
} else {
eval_vars_at_ry_list[i].push(eval_vars_at_ry * ry_factors[num_rounds_y - w.num_inputs[p].log_2()]);
}
raw_eval_vars_at_ry_list[i].push(eval_vars_at_ry);
// Find out the extra q and y padding to remove in raw_eval_vars_at_ry_list
let rq_pad_inv = rq_factors[num_rounds_q - num_proofs[p].log_2()].invert().unwrap();
let ry_pad_inv = if w.num_inputs[p] >= max_num_inputs { ONE } else { ry_factors[num_rounds_y - w.num_inputs[p].log_2()].invert().unwrap() };
eval_vars_at_ry_list[i].push(claimed_vars_at_ry[p][i] * rq_pad_inv); // I don't know why need to divide by rq and later multiply it back, but it doesn't work without this
let claimed_vars_at_ry_no_pad = claimed_vars_at_ry[p][i] * rq_pad_inv * ry_pad_inv;
Zr_list.push(claimed_vars_at_ry_no_pad);
raw_eval_vars_at_ry_list[i].push(claimed_vars_at_ry_no_pad);
} else {
eval_vars_at_ry_list[i].push(ZERO);
raw_eval_vars_at_ry_list[i].push(ZERO);
Expand Down Expand Up @@ -491,16 +479,15 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
};
let mut eval_vars_comb =
(0..num_witness_secs).fold(ZERO, |s, i| s + prefix_list[i] * e(i));
for q in 0..(num_rounds_q - num_proofs[p].log_2()) {
eval_vars_comb = eval_vars_comb * (ONE - rq[q]);
}
eval_vars_comb *= rq_factors[num_rounds_q - num_proofs[p].log_2()];
eval_vars_comb_list.push(eval_vars_comb);
}
timer_polyeval.stop();

let poly_vars = DensePolynomial::new(eval_vars_comb_list);
let eval_vars_at_ry = poly_vars.evaluate(&rp);

// prove the final step of sum-check #2
// Deferred to verifier
timer_prove.stop();

(
Expand Down
13 changes: 12 additions & 1 deletion spartan_parallel/src/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
poly_C: &mut DensePolynomialPqx<S>,
comb_func: F,
transcript: &mut Transcript,
) -> (Self, Vec<S>, Vec<S>)
) -> (Self, Vec<S>, Vec<S>, Vec<Vec<S>>)
where
F: Fn(&S, &S, &S) -> S,
{
Expand All @@ -353,6 +353,8 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
let mut witness_secs_len = num_rounds_w.pow2();
let mut instance_len: usize = num_rounds_p.pow2();

// Every variable bound to ry
let mut claimed_vars_at_ry = Vec::new();
for j in 0..num_rounds {
/* For debugging only */
/* If the value is not 0, the instance / input is wrong */
Expand Down Expand Up @@ -385,6 +387,14 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
} else {
MODE_P
};
if j == num_rounds_y_max {
for p in 0..poly_C.num_instances {
claimed_vars_at_ry.push(Vec::new());
for w in 0..poly_C.num_witness_secs {
claimed_vars_at_ry[p].push(poly_C.index(p, 0, w, 0));
}
}
}

if inputs_len > 1 {
inputs_len /= 2
Expand Down Expand Up @@ -486,6 +496,7 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
poly_B.index(0, 0, 0, 0),
poly_C.index(0, 0, 0, 0),
],
claimed_vars_at_ry,
)
}

Expand Down

0 comments on commit ec8458a

Please sign in to comment.