Skip to content

Commit

Permalink
Refactor UnivariatePolynomial (#29)
Browse files Browse the repository at this point in the history
* refactor: use `enum` instead of `trait` for `UnivariatePolynomial` basis

* refactor `UnivariatePolynomial` and others

* chore
  • Loading branch information
han0110 authored Oct 16, 2023
1 parent 6466eb2 commit 4ce89e1
Show file tree
Hide file tree
Showing 25 changed files with 867 additions and 310 deletions.
11 changes: 4 additions & 7 deletions plonkish_backend/src/accumulation/protostar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
},
backend::PlonkishBackend,
pcs::{AdditiveCommitment, PolynomialCommitmentScheme},
poly::Polynomial,
util::{
arithmetic::{inner_product, powers, Field},
chain,
Expand Down Expand Up @@ -101,33 +100,31 @@ where
{
fn init(
strategy: ProtostarStrategy,
k: usize,
num_instances: &[usize],
num_witness_polys: usize,
num_challenges: usize,
) -> Self {
let zero_poly = Pcs::Polynomial::from_evals(vec![F::ZERO; 1 << k]);
Self {
instance: ProtostarAccumulatorInstance::init(
strategy,
num_instances,
num_witness_polys,
num_challenges,
),
witness_polys: iter::repeat_with(|| zero_poly.clone())
witness_polys: iter::repeat_with(Default::default)
.take(num_witness_polys)
.collect(),
e_poly: zero_poly,
e_poly: Default::default(),
_marker: PhantomData,
}
}

fn from_nark(strategy: ProtostarStrategy, k: usize, nark: PlonkishNark<F, Pcs>) -> Self {
fn from_nark(strategy: ProtostarStrategy, nark: PlonkishNark<F, Pcs>) -> Self {
let witness_polys = nark.witness_polys;
Self {
instance: ProtostarAccumulatorInstance::from_nark(strategy, nark.instance),
witness_polys,
e_poly: Pcs::Polynomial::from_evals(vec![F::ZERO; 1 << k]),
e_poly: Default::default(),
_marker: PhantomData,
}
}
Expand Down
7 changes: 1 addition & 6 deletions plonkish_backend/src/accumulation/protostar/hyperplonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ where
fn init_accumulator(pp: &Self::ProverParam) -> Result<Self::Accumulator, Error> {
Ok(ProtostarAccumulator::init(
pp.strategy,
pp.pp.num_vars,
&pp.pp.num_instances,
pp.num_folding_witness_polys,
pp.num_folding_challenges,
Expand All @@ -89,11 +88,7 @@ where
pp: &Self::ProverParam,
nark: PlonkishNark<F, Self::Pcs>,
) -> Result<Self::Accumulator, Error> {
Ok(ProtostarAccumulator::from_nark(
pp.strategy,
pp.pp.num_vars,
nark,
))
Ok(ProtostarAccumulator::from_nark(pp.strategy, nark))
}

fn prove_nark(
Expand Down
16 changes: 14 additions & 2 deletions plonkish_backend/src/accumulation/protostar/hyperplonk/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ where
return Vec::new();
}

let num_cross_terms = cross_term_expressions.len();
if accumulator.instance.u.is_zero_vartime() {
return vec![MultilinearPolynomial::new(vec![F::ZERO; 1 << num_vars]); num_cross_terms];
}

let ev = init_hadamard_evaluator(
cross_term_expressions,
num_vars,
Expand All @@ -105,7 +110,6 @@ where

let size = 1 << ev.num_vars;
let chunk_size = div_ceil(size, num_threads());
let num_cross_terms = ev.reg.indexed_outputs().len();

let mut outputs = vec![F::ZERO; num_cross_terms * size];
parallelize_iter(
Expand Down Expand Up @@ -141,6 +145,11 @@ where
return Vec::new();
}

let num_cross_terms = cross_term_expressions.len();
if accumulator.instance.u.is_zero_vartime() {
return vec![F::ZERO; num_cross_terms];
}

let ev = init_hadamard_evaluator(
cross_term_expressions,
num_vars,
Expand All @@ -152,7 +161,6 @@ where
let size = 1 << ev.num_vars;
let num_threads = num_threads();
let chunk_size = div_ceil(size, num_threads);
let num_cross_terms = ev.reg.indexed_outputs().len();

let mut partial_sums = vec![vec![F::ZERO; num_cross_terms]; num_threads];
parallelize_iter(
Expand Down Expand Up @@ -183,6 +191,10 @@ where
F: PrimeField,
Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>,
{
if accumulator.instance.u.is_zero_vartime() {
return MultilinearPolynomial::new(vec![F::ZERO; 1 << num_vars]);
}

let [(acc_pow, acc_zeta, acc_u), (incoming_pow, incoming_zeta, incoming_u)] =
[accumulator, incoming].map(|witness| {
let pow = witness.witness_polys.last().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions plonkish_backend/src/backend/hyperplonk/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
classic::{ClassicSumCheck, EvaluationsProver},
SumCheck, VirtualPolynomial,
},
poly::{multilinear::MultilinearPolynomial, Polynomial},
poly::multilinear::MultilinearPolynomial,
util::{
arithmetic::{div_ceil, steps_by, sum, BatchInvert, BooleanHypercube, PrimeField},
end_timer,
Expand Down Expand Up @@ -377,7 +377,7 @@ pub(crate) fn prove_sum_check<F: PrimeField>(
let num_vars = polys[0].num_vars();
let ys = [y];
let virtual_poly = VirtualPolynomial::new(expression, polys.to_vec(), &challenges, &ys);
let (x, evals) = ClassicSumCheck::<EvaluationsProver<_>>::prove(
let (_, x, evals) = ClassicSumCheck::<EvaluationsProver<_>>::prove(
&(),
num_vars,
virtual_poly,
Expand Down
2 changes: 1 addition & 1 deletion plonkish_backend/src/backend/hyperplonk/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
mock::MockCircuit,
PlonkishCircuit, PlonkishCircuitInfo,
},
poly::{multilinear::MultilinearPolynomial, Polynomial},
poly::multilinear::MultilinearPolynomial,
util::{
arithmetic::{powers, BooleanHypercube, PrimeField},
expression::{Expression, Query, Rotation},
Expand Down
15 changes: 7 additions & 8 deletions plonkish_backend/src/pcs/multilinear.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
poly::{multilinear::MultilinearPolynomial, Polynomial},
poly::multilinear::MultilinearPolynomial,
util::{arithmetic::Field, end_timer, izip, parallel::parallelize, start_timer, Itertools},
Error,
};
Expand Down Expand Up @@ -116,7 +116,7 @@ mod additive {
classic::{ClassicSumCheck, CoefficientsProver},
eq_xy_eval, SumCheck as _, VirtualPolynomial,
},
poly::{multilinear::MultilinearPolynomial, Polynomial},
poly::multilinear::MultilinearPolynomial,
util::{
arithmetic::{inner_product, PrimeField},
end_timer,
Expand Down Expand Up @@ -155,7 +155,7 @@ mod additive {
let merged_polys = evals.iter().zip(eq_xt.evals().iter()).fold(
vec![(F::ONE, Cow::<MultilinearPolynomial<_>>::default()); points.len()],
|mut merged_polys, (eval, eq_xt_i)| {
if merged_polys[eval.point()].1.is_zero() {
if merged_polys[eval.point()].1.is_empty() {
merged_polys[eval.point()] = (*eq_xt_i, Cow::Borrowed(polys[eval.poly()]));
} else {
let coeff = merged_polys[eval.point()].0;
Expand Down Expand Up @@ -197,7 +197,7 @@ mod additive {
);
let tilde_gs_sum =
inner_product(evals.iter().map(Evaluation::value), &eq_xt[..evals.len()]);
let (challenges, _) =
let (g_prime_eval, challenges, _) =
SumCheck::prove(&(), num_vars, virtual_poly, tilde_gs_sum, transcript)?;

let timer = start_timer(|| "g_prime");
Expand All @@ -212,17 +212,16 @@ mod additive {
.sum::<MultilinearPolynomial<_>>();
end_timer(timer);

let (g_prime_comm, g_prime_eval) = if cfg!(feature = "sanity-check") {
let g_prime_comm = if cfg!(feature = "sanity-check") {
let scalars = evals
.iter()
.zip(eq_xt.evals())
.map(|(eval, eq_xt_i)| eq_xy_evals[eval.point()] * eq_xt_i)
.collect_vec();
let bases = evals.iter().map(|eval| comms[eval.poly()]);
let comm = Pcs::Commitment::sum_with_scalar(&scalars, bases);
(comm, g_prime.evaluate(&challenges))
Pcs::Commitment::sum_with_scalar(&scalars, bases)
} else {
(Pcs::Commitment::default(), F::ZERO)
Pcs::Commitment::default()
};
Pcs::open(
pp,
Expand Down
2 changes: 1 addition & 1 deletion plonkish_backend/src/pcs/multilinear/brakedown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use crate::{
pcs::{multilinear::validate_input, Evaluation, Point, PolynomialCommitmentScheme},
poly::{multilinear::MultilinearPolynomial, Polynomial},
poly::multilinear::MultilinearPolynomial,
util::{
arithmetic::{div_ceil, inner_product, PrimeField},
code::{Brakedown, BrakedownSpec, LinearCodes},
Expand Down
8 changes: 4 additions & 4 deletions plonkish_backend/src/pcs/multilinear/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
},
poly::{
multilinear::{merge_into, MultilinearPolynomial},
univariate::UnivariatePolynomial,
univariate::{UnivariateBasis::Monomial, UnivariatePolynomial},
Polynomial,
},
util::{
Expand Down Expand Up @@ -62,7 +62,7 @@ where
)));
}

Ok(UnivariateKzg::commit_coeffs(pp, poly.evals()))
Ok(UnivariateKzg::commit_monomial(pp, poly.evals()))
}

fn batch_commit<'a>(
Expand Down Expand Up @@ -99,12 +99,12 @@ where

let fs = {
let mut fs = Vec::with_capacity(num_vars);
fs.push(UnivariatePolynomial::new(poly.evals().to_vec()));
fs.push(UnivariatePolynomial::new(Monomial, poly.evals().to_vec()));
for x_i in &point[..num_vars - 1] {
let f_i_minus_one = fs.last().unwrap().coeffs();
let mut f_i = Vec::with_capacity(f_i_minus_one.len() >> 1);
merge_into(&mut f_i, f_i_minus_one, x_i, 1, 0);
fs.push(UnivariatePolynomial::new(f_i));
fs.push(UnivariatePolynomial::new(Monomial, f_i));
}

if cfg!(feature = "sanity-check") {
Expand Down
25 changes: 9 additions & 16 deletions plonkish_backend/src/pcs/multilinear/hyrax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use crate::{
},
AdditiveCommitment, Evaluation, Point, PolynomialCommitmentScheme,
},
poly::{multilinear::MultilinearPolynomial, Polynomial},
poly::multilinear::MultilinearPolynomial,
util::{
arithmetic::{div_ceil, variable_base_msm, Curve, CurveAffine, Group},
arithmetic::{batch_projective_to_affine, div_ceil, variable_base_msm, CurveAffine, Group},
parallel::parallelize,
transcript::{TranscriptRead, TranscriptWrite},
Deserialize, DeserializeOwned, Itertools, Serialize,
Expand Down Expand Up @@ -93,16 +93,13 @@ impl<C: CurveAffine> AdditiveCommitment<C::Scalar> for MultilinearHyraxCommitmen
assert_eq!(bases.0.len(), num_chunks);
}

let mut output_projective = vec![C::CurveExt::identity(); num_chunks];
parallelize(&mut output_projective, |(output, start)| {
let mut output = vec![C::CurveExt::identity(); num_chunks];
parallelize(&mut output, |(output, start)| {
for (output, idx) in output.iter_mut().zip(start..) {
*output = variable_base_msm(scalars.clone(), bases.iter().map(|base| &base.0[idx]))
}
});
let mut output = vec![C::identity(); num_chunks];
C::CurveExt::batch_normalize(&output_projective, &mut output);

MultilinearHyraxCommitment(output)
MultilinearHyraxCommitment(batch_projective_to_affine(&output))
}
}

Expand Down Expand Up @@ -171,17 +168,15 @@ where

let row_len = pp.row_len();
let scalars = poly.evals();
let comm_projective = {
let comm = {
let mut comm = vec![C::CurveExt::identity(); pp.num_chunks()];
parallelize(&mut comm, |(comm, start)| {
for (comm, start) in comm.iter_mut().zip((start * row_len..).step_by(row_len)) {
*comm = variable_base_msm(&scalars[start..start + row_len], pp.g());
}
});
comm
batch_projective_to_affine(&comm)
};
let mut comm = vec![C::identity(); pp.num_chunks()];
C::CurveExt::batch_normalize(&comm_projective, &mut comm);

Ok(MultilinearHyraxCommitment(comm))
}
Expand All @@ -200,17 +195,15 @@ where
.iter()
.flat_map(|poly| poly.evals().chunks(pp.row_len()))
.collect_vec();
let comms_projective = {
let comms = {
let mut comms = vec![C::CurveExt::identity(); scalars.len()];
parallelize(&mut comms, |(comms, start)| {
for (comm, scalars) in comms.iter_mut().zip(&scalars[start..]) {
*comm = variable_base_msm(*scalars, pp.g());
}
});
comms
batch_projective_to_affine(&comms)
};
let mut comms = vec![C::identity(); scalars.len()];
C::CurveExt::batch_normalize(&comms_projective, &mut comms);

Ok(comms
.into_iter()
Expand Down
17 changes: 5 additions & 12 deletions plonkish_backend/src/pcs/multilinear/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ use crate::{
multilinear::{additive, err_too_many_variates, validate_input},
AdditiveCommitment, Evaluation, Point, PolynomialCommitmentScheme,
},
poly::{multilinear::MultilinearPolynomial, Polynomial},
poly::multilinear::MultilinearPolynomial,
util::{
arithmetic::{
inner_product, variable_base_msm, Curve, CurveAffine, CurveExt, Field, Group,
batch_projective_to_affine, inner_product, variable_base_msm, Curve, CurveAffine,
CurveExt, Field, Group,
},
chain,
parallel::parallelize,
Expand Down Expand Up @@ -99,7 +100,7 @@ where
assert!(poly_size.is_power_of_two());
let num_vars = poly_size.ilog2() as usize;

let g_projective = {
let g = {
let mut g = vec![C::Curve::identity(); poly_size];
parallelize(&mut g, |(g, start)| {
let hasher = C::CurveExt::hash_to_curve("MultilinearIpa::setup");
Expand All @@ -109,15 +110,7 @@ where
*g = hasher(&message);
}
});
g
};

let g = {
let mut g = vec![C::identity(); poly_size];
parallelize(&mut g, |(g, start)| {
C::Curve::batch_normalize(&g_projective[start..(start + g.len())], g);
});
g
batch_projective_to_affine(&g)
};

let hasher = C::CurveExt::hash_to_curve("MultilinearIpa::setup");
Expand Down
Loading

0 comments on commit 4ce89e1

Please sign in to comment.