Skip to content

Commit

Permalink
Fix clippy warning, make code cleaner
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamATD committed Sep 4, 2024
1 parent d6e75ec commit 3d59141
Show file tree
Hide file tree
Showing 19 changed files with 356 additions and 394 deletions.
55 changes: 23 additions & 32 deletions mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ where

fn batch_commit_and_write(
pp: &Self::ProverParam,
polys: &Vec<DenseMultilinearExtension<E>>,
polys: &[DenseMultilinearExtension<E>],
transcript: &mut impl TranscriptWrite<Self::CommitmentChunk, E>,
) -> Result<Self::CommitmentWithData, Error> {
let timer = start_timer!(|| "Basefold::batch_commit_and_write");
Expand All @@ -263,14 +263,14 @@ where

fn batch_commit(
pp: &Self::ProverParam,
polys: &Vec<DenseMultilinearExtension<E>>,
polys: &[DenseMultilinearExtension<E>],
) -> Result<Self::CommitmentWithData, Error> {
// assumptions
// 1. there must be at least one polynomial
// 2. all polynomials must exist in the same field type
// 3. all polynomials must have the same number of variables

if polys.len() == 0 {
if polys.is_empty() {
return Err(Error::InvalidPcsParam(
"cannot batch commit to zero polynomials".to_string(),
));
Expand Down Expand Up @@ -330,8 +330,8 @@ where
assert!(comm.num_polys == 1);
let (trees, oracles) = commit_phase::<E, Spec>(
&pp.encoding_params,
&point,
&comm,
point,
comm,
transcript,
poly.num_vars,
poly.num_vars - Spec::get_basecode_msg_size_log(),
Expand All @@ -341,7 +341,7 @@ where
let query_timer = start_timer!(|| "Basefold::open::query_phase");
// Each entry in queried_els stores a list of triples (F, F, i) indicating the
// position opened at each round and the two values at that round
let queries = prover_query_phase(transcript, &comm, &oracles, Spec::get_number_queries());
let queries = prover_query_phase(transcript, comm, &oracles, Spec::get_number_queries());
end_timer!(query_timer);

let query_timer = start_timer!(|| "Basefold::open::build_query_result");
Expand All @@ -366,8 +366,8 @@ where
/// not very useful in ceno.
fn batch_open(
pp: &Self::ProverParam,
polys: &Vec<DenseMultilinearExtension<E>>,
comms: &Vec<Self::CommitmentWithData>,
polys: &[DenseMultilinearExtension<E>],
comms: &[Self::CommitmentWithData],
points: &[Vec<E>],
evals: &[Evaluation<E>],
transcript: &mut impl TranscriptWrite<Self::CommitmentChunk, E>,
Expand All @@ -391,12 +391,7 @@ where
})
}

validate_input(
"batch open",
pp.get_max_message_size_log(),
&polys.clone(),
&points.to_vec(),
)?;
validate_input("batch open", pp.get_max_message_size_log(), polys, points)?;

let sumcheck_timer = start_timer!(|| "Basefold::batch_open::initial sumcheck");
// evals.len() is the batch size, i.e., how many polynomials are being opened together
Expand Down Expand Up @@ -463,7 +458,7 @@ where
.map(|((scalar, poly), point)| {
inner_product(
&poly_iter_ext(poly).collect_vec(),
build_eq_x_r_vec(&point).iter(),
build_eq_x_r_vec(point).iter(),
) * scalar
* E::from(1 << (num_vars - poly.num_vars))
// When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube
Expand Down Expand Up @@ -528,7 +523,7 @@ where
);
*scalar
* evals_from_sum_check
* &eq_xy_eval(point.as_slice(), &challenges[0..point.len()])
* eq_xy_eval(point.as_slice(), &challenges[0..point.len()])
})
.sum::<E>();
assert_eq!(new_target_sum, desired_sum);
Expand All @@ -541,7 +536,7 @@ where
let (trees, oracles) = batch_commit_phase::<E, Spec>(
&pp.encoding_params,
&point,
comms.as_slice(),
comms,
transcript,
num_vars,
num_vars - Spec::get_basecode_msg_size_log(),
Expand All @@ -553,7 +548,7 @@ where
let query_result = batch_prover_query_phase(
transcript,
1 << (num_vars + Spec::get_rate_log()),
comms.as_slice(),
comms,
&oracles,
Spec::get_number_queries(),
);
Expand All @@ -564,7 +559,7 @@ where
BatchedQueriesResultWithMerklePath::from_batched_query_result(
query_result,
&trees,
&comms,
comms,
);
end_timer!(query_timer);

Expand All @@ -583,7 +578,7 @@ where
/// 3. The point is already a random point generated by a sum-check.
fn simple_batch_open(
pp: &Self::ProverParam,
polys: &Vec<DenseMultilinearExtension<E>>,
polys: &[DenseMultilinearExtension<E>],
comm: &Self::CommitmentWithData,
point: &[E],
evals: &[E],
Expand Down Expand Up @@ -623,9 +618,9 @@ where
// the new target sum, where coeffs is computed as follows
let (trees, oracles) = simple_batch_commit_phase::<E, Spec>(
&pp.encoding_params,
&point,
point,
&eq_xt,
&comm,
comm,
transcript,
num_vars,
num_vars - Spec::get_basecode_msg_size_log(),
Expand All @@ -635,12 +630,8 @@ where
let query_timer = start_timer!(|| "Basefold::open::query_phase");
// Each entry in queried_els stores a list of triples (F, F, i) indicating the
// position opened at each round and the two values at that round
let queries = simple_batch_prover_query_phase(
transcript,
&comm,
&oracles,
Spec::get_number_queries(),
);
let queries =
simple_batch_prover_query_phase(transcript, comm, &oracles, Spec::get_number_queries());
end_timer!(query_timer);

let query_timer = start_timer!(|| "Basefold::open::build_query_result");
Expand Down Expand Up @@ -779,7 +770,7 @@ where
&roots,
comm,
eq.as_slice(),
&eval,
eval,
&hasher,
);
end_timer!(timer);
Expand All @@ -789,7 +780,7 @@ where

fn batch_verify(
vp: &Self::VerifierParam,
comms: &Vec<Self::Commitment>,
comms: &[Self::Commitment],
points: &[Vec<E>],
evals: &[Evaluation<E>],
transcript: &mut impl TranscriptRead<Self::CommitmentChunk, E>,
Expand All @@ -798,10 +789,10 @@ where
// let key = "RAYON_NUM_THREADS";
// env::set_var(key, "32");
let hasher = new_hasher::<E::BaseField>();
let comms = comms.into_iter().collect_vec();
let comms = comms.iter().collect_vec();
let num_vars = points.iter().map(|point| point.len()).max().unwrap();
let num_rounds = num_vars - Spec::get_basecode_msg_size_log();
validate_input("batch verify", num_vars, &vec![], &points.to_vec())?;
validate_input("batch verify", num_vars, &[], points)?;
let poly_num_vars = comms.iter().map(|c| c.num_vars().unwrap()).collect_vec();
evals.iter().for_each(|eval| {
assert_eq!(
Expand Down
22 changes: 12 additions & 10 deletions mpcs/src/basefold/commit_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ where

// eq is the evaluation representation of the eq(X,r) polynomial over the hypercube
let build_eq_timer = start_timer!(|| "Basefold::open");
let mut eq = build_eq_x_r_vec(&point);
let mut eq = build_eq_x_r_vec(point);
end_timer!(build_eq_timer);
reverse_index_bits_in_place(&mut eq);

Expand Down Expand Up @@ -142,10 +142,11 @@ where
end_timer!(sumcheck_timer);
}
end_timer!(timer);
return (trees, oracles);
(trees, oracles)
}

// outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval)
#[allow(clippy::too_many_arguments)]
pub fn batch_commit_phase<E: ExtensionField, Spec: BasefoldSpec<E>>(
pp: &<Spec::EncodingScheme as EncodingScheme<E>>::ProverParameters,
point: &[E],
Expand Down Expand Up @@ -177,7 +178,7 @@ where
running_oracle
.iter_mut()
.zip_eq(field_type_iter_ext(&comm.get_codewords()[0]))
.for_each(|(r, a)| *r += E::from(a) * coeffs[index]);
.for_each(|(r, a)| *r += a * coeffs[index]);
});
end_timer!(build_oracle_timer);

Expand All @@ -196,16 +197,16 @@ where
// to align the polynomials to the variable with index 0 before adding them
// together. So each element is repeated by
// sum_of_all_evals_for_sumcheck.len() / bh_evals.len() times
*r += E::from(field_type_index_ext(
*r += field_type_index_ext(
&comm.polynomials_bh_evals[0],
pos >> (num_vars - log2_strict(comm.polynomials_bh_evals[0].len())),
)) * coeffs[index]
) * coeffs[index]
});
});
end_timer!(build_oracle_timer);

// eq is the evaluation representation of the eq(X,r) polynomial over the hypercube
let mut eq = build_eq_x_r_vec(&point);
let mut eq = build_eq_x_r_vec(point);
reverse_index_bits_in_place(&mut eq);

let sumcheck_timer = start_timer!(|| "Basefold first round");
Expand Down Expand Up @@ -256,7 +257,7 @@ where
running_oracle
.iter_mut()
.zip_eq(field_type_iter_ext(&comm.get_codewords()[0]))
.for_each(|(r, a)| *r += E::from(a) * coeffs[index]);
.for_each(|(r, a)| *r += a * coeffs[index]);
});
} else {
// The difference of the last round is that we don't need to compute the message,
Expand Down Expand Up @@ -296,10 +297,11 @@ where
end_timer!(sumcheck_timer);
}
end_timer!(timer);
return (trees, oracles);
(trees, oracles)
}

// outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval)
#[allow(clippy::too_many_arguments)]
pub fn simple_batch_commit_phase<E: ExtensionField, Spec: BasefoldSpec<E>>(
pp: &<Spec::EncodingScheme as EncodingScheme<E>>::ProverParameters,
point: &[E],
Expand Down Expand Up @@ -331,7 +333,7 @@ where

// eq is the evaluation representation of the eq(X,r) polynomial over the hypercube
let build_eq_timer = start_timer!(|| "Basefold::open");
let mut eq = build_eq_x_r_vec(&point);
let mut eq = build_eq_x_r_vec(point);
end_timer!(build_eq_timer);
reverse_index_bits_in_place(&mut eq);

Expand Down Expand Up @@ -404,7 +406,7 @@ where
end_timer!(sumcheck_timer);
}
end_timer!(timer);
return (trees, oracles);
(trees, oracles)
}

fn basefold_one_round_by_interpolation_weights<E: ExtensionField, Spec: BasefoldSpec<E>>(
Expand Down
14 changes: 7 additions & 7 deletions mpcs/src/basefold/encoding.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use ff_ext::ExtensionField;
use multilinear_extensions::mle::FieldType;

mod utils;

mod basecode;
pub use basecode::{Basecode, BasecodeDefaultSpec};

Expand Down Expand Up @@ -132,27 +134,25 @@ pub trait EncodingScheme<E: ExtensionField>: std::fmt::Debug + Clone {
}
}

fn concatenate_field_types<E: ExtensionField>(coeffs: &Vec<FieldType<E>>) -> FieldType<E> {
fn concatenate_field_types<E: ExtensionField>(coeffs: &[FieldType<E>]) -> FieldType<E> {
match coeffs[0] {
FieldType::Ext(_) => {
let res = coeffs
.iter()
.map(|x| match x {
FieldType::Ext(x) => x.iter().map(|x| *x),
.flat_map(|x| match x {
FieldType::Ext(x) => x.iter().copied(),
_ => unreachable!(),
})
.flatten()
.collect::<Vec<_>>();
FieldType::Ext(res)
}
FieldType::Base(_) => {
let res = coeffs
.iter()
.map(|x| match x {
FieldType::Base(x) => x.iter().map(|x| *x),
.flat_map(|x| match x {
FieldType::Base(x) => x.iter().copied(),
_ => unreachable!(),
})
.flatten()
.collect::<Vec<_>>();
FieldType::Base(res)
}
Expand Down
Loading

0 comments on commit 3d59141

Please sign in to comment.