From 3d591416f8f98afce8c9c4cd5bea8e2a99f46871 Mon Sep 17 00:00:00 2001 From: dreamATD Date: Wed, 4 Sep 2024 16:22:19 +0800 Subject: [PATCH] Fix clippy warning, make code cleaner --- mpcs/src/basefold.rs | 55 ++++----- mpcs/src/basefold/commit_phase.rs | 22 ++-- mpcs/src/basefold/encoding.rs | 14 +-- mpcs/src/basefold/encoding/basecode.rs | 75 +++++------- mpcs/src/basefold/encoding/rs.rs | 149 ++++++++--------------- mpcs/src/basefold/encoding/utils.rs | 35 ++++++ mpcs/src/basefold/query_phase.rs | 162 +++++++++++++------------ mpcs/src/basefold/structure.rs | 12 +- mpcs/src/basefold/sumcheck.rs | 55 ++++----- mpcs/src/lib.rs | 47 ++++--- mpcs/src/sum_check.rs | 10 +- mpcs/src/sum_check/classic.rs | 8 +- mpcs/src/sum_check/classic/coeff.rs | 8 +- mpcs/src/util/arithmetic.rs | 27 ++--- mpcs/src/util/arithmetic/hypercube.rs | 2 +- mpcs/src/util/hash.rs | 4 +- mpcs/src/util/merkle_tree.rs | 47 ++++--- mpcs/src/util/parallel.rs | 5 +- mpcs/src/util/transcript.rs | 13 +- 19 files changed, 356 insertions(+), 394 deletions(-) create mode 100644 mpcs/src/basefold/encoding/utils.rs diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index ce2378df3..bb923ebd7 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -242,7 +242,7 @@ where fn batch_commit_and_write( pp: &Self::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], transcript: &mut impl TranscriptWrite, ) -> Result { let timer = start_timer!(|| "Basefold::batch_commit_and_write"); @@ -263,14 +263,14 @@ where fn batch_commit( pp: &Self::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], ) -> Result { // assumptions // 1. there must be at least one polynomial // 2. all polynomials must exist in the same field type // 3. all polynomials must have the same number of variables - if polys.len() == 0 { + if polys.is_empty() { return Err(Error::InvalidPcsParam( "cannot batch commit to zero polynomials".to_string(), )); @@ -330,8 +330,8 @@ where assert!(comm.num_polys == 1); let (trees, oracles) = commit_phase::( &pp.encoding_params, - &point, - &comm, + point, + comm, transcript, poly.num_vars, poly.num_vars - Spec::get_basecode_msg_size_log(), @@ -341,7 +341,7 @@ where let query_timer = start_timer!(|| "Basefold::open::query_phase"); // Each entry in queried_els stores a list of triples (F, F, i) indicating the // position opened at each round and the two values at that round - let queries = prover_query_phase(transcript, &comm, &oracles, Spec::get_number_queries()); + let queries = prover_query_phase(transcript, comm, &oracles, Spec::get_number_queries()); end_timer!(query_timer); let query_timer = start_timer!(|| "Basefold::open::build_query_result"); @@ -366,8 +366,8 @@ where /// not very useful in ceno. fn batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptWrite, @@ -391,12 +391,7 @@ where }) } - validate_input( - "batch open", - pp.get_max_message_size_log(), - &polys.clone(), - &points.to_vec(), - )?; + validate_input("batch open", pp.get_max_message_size_log(), polys, points)?; let sumcheck_timer = start_timer!(|| "Basefold::batch_open::initial sumcheck"); // evals.len() is the batch size, i.e., how many polynomials are being opened together @@ -463,7 +458,7 @@ where .map(|((scalar, poly), point)| { inner_product( &poly_iter_ext(poly).collect_vec(), - build_eq_x_r_vec(&point).iter(), + build_eq_x_r_vec(point).iter(), ) * scalar * E::from(1 << (num_vars - poly.num_vars)) // When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube @@ -528,7 +523,7 @@ where ); *scalar * evals_from_sum_check - * &eq_xy_eval(point.as_slice(), &challenges[0..point.len()]) + * eq_xy_eval(point.as_slice(), &challenges[0..point.len()]) }) .sum::(); assert_eq!(new_target_sum, desired_sum); @@ -541,7 +536,7 @@ where let (trees, oracles) = batch_commit_phase::( &pp.encoding_params, &point, - comms.as_slice(), + comms, transcript, num_vars, num_vars - Spec::get_basecode_msg_size_log(), @@ -553,7 +548,7 @@ where let query_result = batch_prover_query_phase( transcript, 1 << (num_vars + Spec::get_rate_log()), - comms.as_slice(), + comms, &oracles, Spec::get_number_queries(), ); @@ -564,7 +559,7 @@ where BatchedQueriesResultWithMerklePath::from_batched_query_result( query_result, &trees, - &comms, + comms, ); end_timer!(query_timer); @@ -583,7 +578,7 @@ where /// 3. The point is already a random point generated by a sum-check. fn simple_batch_open( pp: &Self::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], comm: &Self::CommitmentWithData, point: &[E], evals: &[E], @@ -623,9 +618,9 @@ where // the new target sum, where coeffs is computed as follows let (trees, oracles) = simple_batch_commit_phase::( &pp.encoding_params, - &point, + point, &eq_xt, - &comm, + comm, transcript, num_vars, num_vars - Spec::get_basecode_msg_size_log(), @@ -635,12 +630,8 @@ where let query_timer = start_timer!(|| "Basefold::open::query_phase"); // Each entry in queried_els stores a list of triples (F, F, i) indicating the // position opened at each round and the two values at that round - let queries = simple_batch_prover_query_phase( - transcript, - &comm, - &oracles, - Spec::get_number_queries(), - ); + let queries = + simple_batch_prover_query_phase(transcript, comm, &oracles, Spec::get_number_queries()); end_timer!(query_timer); let query_timer = start_timer!(|| "Basefold::open::build_query_result"); @@ -779,7 +770,7 @@ where &roots, comm, eq.as_slice(), - &eval, + eval, &hasher, ); end_timer!(timer); @@ -789,7 +780,7 @@ where fn batch_verify( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptRead, @@ -798,10 +789,10 @@ where // let key = "RAYON_NUM_THREADS"; // env::set_var(key, "32"); let hasher = new_hasher::(); - let comms = comms.into_iter().collect_vec(); + let comms = comms.iter().collect_vec(); let num_vars = points.iter().map(|point| point.len()).max().unwrap(); let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); - validate_input("batch verify", num_vars, &vec![], &points.to_vec())?; + validate_input("batch verify", num_vars, &[], points)?; let poly_num_vars = comms.iter().map(|c| c.num_vars().unwrap()).collect_vec(); evals.iter().for_each(|eval| { assert_eq!( diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index e99a1d10a..da8b4ad75 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -60,7 +60,7 @@ where // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube let build_eq_timer = start_timer!(|| "Basefold::open"); - let mut eq = build_eq_x_r_vec(&point); + let mut eq = build_eq_x_r_vec(point); end_timer!(build_eq_timer); reverse_index_bits_in_place(&mut eq); @@ -142,10 +142,11 @@ where end_timer!(sumcheck_timer); } end_timer!(timer); - return (trees, oracles); + (trees, oracles) } // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) +#[allow(clippy::too_many_arguments)] pub fn batch_commit_phase>( pp: &>::ProverParameters, point: &[E], @@ -177,7 +178,7 @@ where running_oracle .iter_mut() .zip_eq(field_type_iter_ext(&comm.get_codewords()[0])) - .for_each(|(r, a)| *r += E::from(a) * coeffs[index]); + .for_each(|(r, a)| *r += a * coeffs[index]); }); end_timer!(build_oracle_timer); @@ -196,16 +197,16 @@ where // to align the polynomials to the variable with index 0 before adding them // together. So each element is repeated by // sum_of_all_evals_for_sumcheck.len() / bh_evals.len() times - *r += E::from(field_type_index_ext( + *r += field_type_index_ext( &comm.polynomials_bh_evals[0], pos >> (num_vars - log2_strict(comm.polynomials_bh_evals[0].len())), - )) * coeffs[index] + ) * coeffs[index] }); }); end_timer!(build_oracle_timer); // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube - let mut eq = build_eq_x_r_vec(&point); + let mut eq = build_eq_x_r_vec(point); reverse_index_bits_in_place(&mut eq); let sumcheck_timer = start_timer!(|| "Basefold first round"); @@ -256,7 +257,7 @@ where running_oracle .iter_mut() .zip_eq(field_type_iter_ext(&comm.get_codewords()[0])) - .for_each(|(r, a)| *r += E::from(a) * coeffs[index]); + .for_each(|(r, a)| *r += a * coeffs[index]); }); } else { // The difference of the last round is that we don't need to compute the message, @@ -296,10 +297,11 @@ where end_timer!(sumcheck_timer); } end_timer!(timer); - return (trees, oracles); + (trees, oracles) } // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) +#[allow(clippy::too_many_arguments)] pub fn simple_batch_commit_phase>( pp: &>::ProverParameters, point: &[E], @@ -331,7 +333,7 @@ where // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube let build_eq_timer = start_timer!(|| "Basefold::open"); - let mut eq = build_eq_x_r_vec(&point); + let mut eq = build_eq_x_r_vec(point); end_timer!(build_eq_timer); reverse_index_bits_in_place(&mut eq); @@ -404,7 +406,7 @@ where end_timer!(sumcheck_timer); } end_timer!(timer); - return (trees, oracles); + (trees, oracles) } fn basefold_one_round_by_interpolation_weights>( diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs index 92bdae11d..9fa651e2f 100644 --- a/mpcs/src/basefold/encoding.rs +++ b/mpcs/src/basefold/encoding.rs @@ -1,6 +1,8 @@ use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; +mod utils; + mod basecode; pub use basecode::{Basecode, BasecodeDefaultSpec}; @@ -132,27 +134,25 @@ pub trait EncodingScheme: std::fmt::Debug + Clone { } } -fn concatenate_field_types(coeffs: &Vec>) -> FieldType { +fn concatenate_field_types(coeffs: &[FieldType]) -> FieldType { match coeffs[0] { FieldType::Ext(_) => { let res = coeffs .iter() - .map(|x| match x { - FieldType::Ext(x) => x.iter().map(|x| *x), + .flat_map(|x| match x { + FieldType::Ext(x) => x.iter().copied(), _ => unreachable!(), }) - .flatten() .collect::>(); FieldType::Ext(res) } FieldType::Base(_) => { let res = coeffs .iter() - .map(|x| match x { - FieldType::Base(x) => x.iter().map(|x| *x), + .flat_map(|x| match x { + FieldType::Base(x) => x.iter().copied(), _ => unreachable!(), }) - .flatten() .collect::>(); FieldType::Base(res) } diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs index 27922487a..5f13035a7 100644 --- a/mpcs/src/basefold/encoding/basecode.rs +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -5,11 +5,10 @@ use crate::{ util::{ arithmetic::base_from_raw_bytes, log2_strict, num_of_bytes, plonky2_util::reverse_bits, }, - Error, + vec_mut, Error, }; use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; use ark_std::{end_timer, start_timer}; -use ctr; use ff::{BatchInverter, Field, PrimeField}; use ff_ext::ExtensionField; use generic_array::GenericArray; @@ -106,7 +105,7 @@ where type VerifierParameters = BasecodeVerifierParameters; fn setup(max_msg_size_log: usize, rng_seed: [u8; 32]) -> Self::PublicParameters { - let rng = ChaCha8Rng::from_seed(rng_seed.clone()); + let rng = ChaCha8Rng::from_seed(rng_seed); let (table_w_weights, table) = get_table_aes::(max_msg_size_log, Spec::get_rate_log(), &mut rng.clone()); BasecodeParameters { @@ -137,11 +136,11 @@ where Self::ProverParameters { table_w_weights: pp.table_w_weights.clone(), table: pp.table.clone(), - rng_seed: pp.rng_seed.clone(), + rng_seed: pp.rng_seed, _phantom: PhantomData, }, Self::VerifierParameters { - rng_seed: pp.rng_seed.clone(), + rng_seed: pp.rng_seed, aes_key: key, aes_iv: iv, }, @@ -175,15 +174,15 @@ where } fn get_number_queries() -> usize { - return Spec::get_number_queries(); + Spec::get_number_queries() } fn get_rate_log() -> usize { - return Spec::get_rate_log(); + Spec::get_rate_log() } fn get_basecode_msg_size_log() -> usize { - return Spec::get_basecode_msg_size_log(); + Spec::get_basecode_msg_size_log() } fn message_is_left_and_right_folding() -> bool { @@ -250,7 +249,7 @@ fn get_basecode(poly: &Vec, rate: usize, message_size: usize) -> Ve target .iter_mut() .enumerate() - .for_each(|(i, target)| *target = horner(&chunk[..], &domain[i])); + .for_each(|(i, target)| *target = horner(chunk, &domain[i])); target }) .collect::>>(); @@ -265,7 +264,7 @@ pub fn evaluate_over_foldable_domain_generic_basecode( num_coeffs: usize, log_rate: usize, base_codewords: Vec>, - table: &Vec>, + table: &[Vec], ) -> FieldType { let timer = start_timer!(|| "evaluate over foldable domain"); let k = num_coeffs; @@ -284,44 +283,28 @@ pub fn evaluate_over_foldable_domain_generic_basecode( let level = &table[i + log_rate]; // chunk_size is equal to 1 << (i+1), i.e., the codeword size after the current iteration // half_chunk is equal to 1 << i, i.e. the current codeword size - chunk_size = chunk_size << 1; + chunk_size <<= 1; assert_eq!(level.len(), chunk_size >> 1); - match coeffs_with_bc { - FieldType::Ext(ref mut coeffs_with_bc) => { - coeffs_with_bc.par_chunks_mut(chunk_size).for_each(|chunk| { - let half_chunk = chunk_size >> 1; - for j in half_chunk..chunk_size { - // Suppose the current codewords are (a, b) - // The new codeword is computed by two halves: - // left = a + t * b - // right = a - t * b - let rhs = chunk[j] * E::from(level[j - half_chunk]); - chunk[j] = chunk[j - half_chunk] - rhs; - chunk[j - half_chunk] = chunk[j - half_chunk] + rhs; - } - }); - } - FieldType::Base(ref mut coeffs_with_bc) => { - coeffs_with_bc.par_chunks_mut(chunk_size).for_each(|chunk| { - let half_chunk = chunk_size >> 1; - for j in half_chunk..chunk_size { - // Suppose the current codewords are (a, b) - // The new codeword is computed by two halves: - // left = a + t * b - // right = a - t * b - let rhs = chunk[j] * level[j - half_chunk]; - chunk[j] = chunk[j - half_chunk] - rhs; - chunk[j - half_chunk] = chunk[j - half_chunk] + rhs; - } - }); - } - _ => unreachable!(), - } + vec_mut!(coeffs_with_bc, |c| { + c.par_chunks_mut(chunk_size).for_each(|chunk| { + let half_chunk = chunk_size >> 1; + for j in half_chunk..chunk_size { + // Suppose the current codewords are (a, b) + // The new codeword is computed by two halves: + // left = a + t * b + // right = a - t * b + let rhs = chunk[j] * level[j - half_chunk]; + chunk[j] = chunk[j - half_chunk] - rhs; + chunk[j - half_chunk] += rhs; + } + }); + }); } end_timer!(timer); coeffs_with_bc } +#[allow(clippy::type_complexity)] pub fn get_table_aes( poly_size_log: usize, rate: usize, @@ -355,7 +338,7 @@ pub fn get_table_aes( // Collect the bytes into field elements let flat_table: Vec = dest .par_chunks_exact(num_of_bytes::(1)) - .map(|chunk| base_from_raw_bytes::(&chunk.to_vec())) + .map(|chunk| base_from_raw_bytes::(chunk)) .collect::>(); // Now, flat_table is a field vector of size n, filled with random field elements @@ -396,7 +379,7 @@ pub fn get_table_aes( unflattened_table_w_weights[i] = level; } - return (unflattened_table_w_weights, unflattened_table); + (unflattened_table_w_weights, unflattened_table) } pub fn query_root_table_from_rng_aes( @@ -421,9 +404,7 @@ pub fn query_root_table_from_rng_aes( let mut dest: Vec = vec![0u8; bytes]; cipher.apply_keystream(&mut dest); - let res = base_from_raw_bytes::(&dest); - - res + base_from_raw_bytes::(&dest) } #[cfg(test)] diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs index 7e6fc10e4..79c047060 100644 --- a/mpcs/src/basefold/encoding/rs.rs +++ b/mpcs/src/basefold/encoding/rs.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use super::{EncodingProverParameters, EncodingScheme}; use crate::{ util::{field_type_index_mul_base, log2_strict, plonky2_util::reverse_bits}, - Error, + vec_mut, Error, }; use ark_std::{end_timer, start_timer}; use ff::{Field, PrimeField}; @@ -31,7 +31,7 @@ pub fn fft_root_table(lg_n: usize) -> FftRootTable { // bases[i] = g^2^i, for i = 0, ..., lg_n - 1 // Note that the end of bases is g^{n/2} = -1 let mut bases = Vec::with_capacity(lg_n); - let mut base = F::ROOT_OF_UNITY.pow(&[(1 << (F::S - lg_n as u32)) as u64]); + let mut base = F::ROOT_OF_UNITY.pow([(1 << (F::S - lg_n as u32)) as u64]); bases.push(base); for _ in 1..lg_n { base = base.square(); // base = g^2^_ @@ -72,34 +72,20 @@ fn ifft( let n_inv = (E::BaseField::ONE + E::BaseField::ONE) .invert() .unwrap() - .pow(&[lg_n as u64]); + .pow([lg_n as u64]); fft(poly, zero_factor, root_table); // We reverse all values except the first, and divide each by n. field_type_index_mul_base(poly, 0, &n_inv); field_type_index_mul_base(poly, n / 2, &n_inv); - match poly { - FieldType::Base(poly) => { - for i in 1..(n / 2) { - let j = n - i; - let coeffs_i = poly[j] * n_inv; - let coeffs_j = poly[i] * n_inv; - poly[i] = coeffs_i; - poly[j] = coeffs_j; - } - } - FieldType::Ext(poly) => { - for i in 1..(n / 2) { - let j = n - i; - let coeffs_i = poly[j] * n_inv; - let coeffs_j = poly[i] * n_inv; - poly[i] = coeffs_i; - poly[j] = coeffs_j; - } - } - _ => panic!("Unsupported field type"), - } + vec_mut!(|poly| for i in 1..(n / 2) { + let j = n - i; + let coeffs_i = poly[j] * n_inv; + let coeffs_j = poly[i] * n_inv; + poly[i] = coeffs_i; + poly[j] = coeffs_j; + }) } /// Core FFT implementation. @@ -111,7 +97,7 @@ fn fft_classic_inner( ) { // We've already done the first lg_packed_width (if they were required) iterations. - for lg_half_m in r..lg_n { + for (lg_half_m, cur_root_table) in root_table.iter().enumerate().take(lg_n).skip(r) { let n = 1 << lg_n; let lg_m = lg_half_m + 1; let m = 1 << lg_m; // Subarray size (in field elements). @@ -119,32 +105,18 @@ fn fft_classic_inner( debug_assert!(half_m != 0); // omega values for this iteration, as slice of vectors - let omega_table = &root_table[lg_half_m][..]; - match values { - FieldType::Base(values) => { - for k in (0..n).step_by(m) { - for j in 0..half_m { - let omega = omega_table[j]; - let t = omega * values[k + half_m + j]; - let u = values[k + j]; - values[k + j] = u + t; - values[k + half_m + j] = u - t; - } - } - } - FieldType::Ext(values) => { - for k in (0..n).step_by(m) { - for j in 0..half_m { - let omega = omega_table[j]; - let t = values[k + half_m + j] * omega; - let u = values[k + j]; - values[k + j] = u + t; - values[k + half_m + j] = u - t; - } + let omega_table = &cur_root_table[..]; + vec_mut!(|values| { + for k in (0..n).step_by(m) { + for j in 0..half_m { + let omega = omega_table[j]; + let t = values[k + half_m + j] * omega; + let u = values[k + j]; + values[k + j] = u + t; + values[k + half_m + j] = u - t; } } - _ => panic!("Unsupported field type"), - } + }) } } @@ -159,15 +131,7 @@ pub fn fft( r: usize, root_table: &[Vec], ) { - match values { - FieldType::Base(values) => { - reverse_index_bits_in_place(values); - } - FieldType::Ext(values) => { - reverse_index_bits_in_place(values); - } - _ => panic!("Unsupported field type"), - } + vec_mut!(|values| reverse_index_bits_in_place(values)); let n = values.len(); let lg_n = log2_strict(n); @@ -215,21 +179,12 @@ pub fn coset_fft( root_table: &[Vec], ) { let mut shift_power = E::BaseField::ONE; - match coeffs { - FieldType::Base(coeffs) => { - for coeff in coeffs.iter_mut() { - *coeff *= shift_power; - shift_power *= shift; - } - } - FieldType::Ext(coeffs) => { - for coeff in coeffs.iter_mut() { - *coeff *= shift_power; - shift_power *= shift; - } + vec_mut!(|coeffs| { + for coeff in coeffs.iter_mut() { + *coeff *= shift_power; + shift_power *= shift; } - _ => panic!("Unsupported field type"), - } + }); fft(coeffs, zero_factor, root_table); } @@ -347,7 +302,7 @@ where fft_root_table: pp.fft_root_table [..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()] .iter() - .map(|v| v.clone()) + .cloned() .chain( pp.fft_root_table [Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..] @@ -373,15 +328,15 @@ where } fn get_number_queries() -> usize { - return Spec::get_number_queries(); + Spec::get_number_queries() } fn get_rate_log() -> usize { - return Spec::get_rate_log(); + Spec::get_rate_log() } fn get_basecode_msg_size_log() -> usize { - return Spec::get_basecode_msg_size_log(); + Spec::get_basecode_msg_size_log() } fn message_is_left_and_right_folding() -> bool { @@ -442,7 +397,7 @@ where } else { // In this case, the level-th row of fft root table of the verifier // only stores the first 2^(level+1)-th roots of unity. - vp.fft_root_table[level][0].pow(&[index as u64]) + vp.fft_root_table[level][0].pow([index as u64]) } * vp.gamma_powers[vp.full_message_size_log + Spec::get_rate_log() - level - 1]; let x1 = -x0; // The weight is 1/(x1-x0) = -1/(2x0) @@ -462,7 +417,7 @@ where } else { // In this case, this level of fft root table of the verifier // only stores the first 2^(level+1)-th root of unity. - vp.fft_root_table[level][0].pow(&[(1 << (level + 1)) - index as u64]) + vp.fft_root_table[level][0].pow([(1 << (level + 1)) - index as u64]) }; (E::from(x0), E::from(x1), E::from(w)) } @@ -508,7 +463,7 @@ impl RSCode { let k = 1 << (full_message_size_log - lg_m); coset_fft( &mut ret, - E::BaseField::MULTIPLICATIVE_GENERATOR.pow(&[k]), + E::BaseField::MULTIPLICATIVE_GENERATOR.pow([k]), Spec::get_rate_log(), fft_root_table, ); @@ -527,10 +482,10 @@ impl RSCode { // x0 is the index-th 2^(level+1)-th root of unity, multiplied by // the shift factor at level+1, which is gamma^2^(full_codeword_log_n - level - 1). let x0 = E::BaseField::ROOT_OF_UNITY - .pow(&[1 << (E::BaseField::S - (level as u32 + 1))]) - .pow(&[index as u64]) + .pow([1 << (E::BaseField::S - (level as u32 + 1))]) + .pow([index as u64]) * E::BaseField::MULTIPLICATIVE_GENERATOR - .pow(&[1 << (full_message_size_log + Spec::get_rate_log() - level - 1)]); + .pow([1 << (full_message_size_log + Spec::get_rate_log() - level - 1)]); let x1 = -x0; let w = (x1 - x0).invert().unwrap(); (E::from(x0), E::from(x1), E::from(w)) @@ -538,11 +493,11 @@ impl RSCode { } #[allow(unused)] -fn naive_fft(poly: &Vec, rate: usize, shift: E::BaseField) -> Vec { +fn naive_fft(poly: &[E], rate: usize, shift: E::BaseField) -> Vec { let timer = start_timer!(|| "Encode RSCode"); let message_size = poly.len(); let domain_size_bit = log2_strict(message_size * rate); - let root = E::BaseField::ROOT_OF_UNITY.pow(&[1 << (E::BaseField::S - domain_size_bit as u32)]); + let root = E::BaseField::ROOT_OF_UNITY.pow([1 << (E::BaseField::S - domain_size_bit as u32)]); // The domain is shift * H where H is the multiplicative subgroup of size // message_size * rate. let mut domain = Vec::::with_capacity(message_size * rate); @@ -553,7 +508,7 @@ fn naive_fft(poly: &Vec, rate: usize, shift: E::BaseField) let mut res = vec![E::ZERO; message_size * rate]; res.iter_mut() .enumerate() - .for_each(|(i, target)| *target = horner(&poly[..], &E::from(domain[i]))); + .for_each(|(i, target)| *target = horner(poly, &E::from(domain[i]))); end_timer!(timer); res @@ -573,9 +528,7 @@ mod tests { fn test_naive_fft() { let num_vars = 5; - let poly: Vec = (0..(1 << num_vars)) - .map(|i| GoldilocksExt2::from(i)) - .collect(); + let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); let mut poly2 = FieldType::Ext(poly.clone()); let naive = naive_fft::(&poly, 1, Goldilocks::ONE); @@ -655,9 +608,7 @@ mod tests { fn test_ifft() { let num_vars = 5; - let poly: Vec = (0..(1 << num_vars)) - .map(|i| GoldilocksExt2::from(i)) - .collect(); + let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); let mut poly = FieldType::Ext(poly); let original = poly.clone(); @@ -705,7 +656,7 @@ mod tests { pub fn test_colinearity() { let num_vars = 10; - let poly: Vec = (0..(1 << num_vars)).map(|i| E::from(i)).collect(); + let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); let poly = FieldType::Ext(poly); let rng_seed = [0; 32]; @@ -744,7 +695,7 @@ mod tests { pub fn test_low_degree() { let num_vars = 10; - let poly: Vec = (0..(1 << num_vars)).map(|i| E::from(i)).collect(); + let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); let poly = FieldType::Ext(poly); let rng_seed = [0; 32]; @@ -799,9 +750,9 @@ mod tests { reverse_index_bits_in_place(&mut left_right_diff); assert_eq!(left_right_diff[1], c1 - c_mid1); let root_of_unity_inv = F::ROOT_OF_UNITY_INV - .pow(&[1 << (F::S as usize - log2_strict(left_right_diff.len()) - 1)]); + .pow([1 << (F::S as usize - log2_strict(left_right_diff.len()) - 1)]); for (i, coeff) in left_right_diff.iter_mut().enumerate() { - *coeff *= root_of_unity_inv.pow(&[i as u64]); + *coeff *= root_of_unity_inv.pow([i as u64]); } assert_eq!(left_right_diff[0], c0 - c_mid); assert_eq!(left_right_diff[1], (c1 - c_mid1) * root_of_unity_inv); @@ -851,13 +802,13 @@ mod tests { let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 1); let root_of_unity = - F::ROOT_OF_UNITY.pow(&[1 << (F::S as usize - log2_strict(codeword.len()))]); - assert_eq!(root_of_unity.pow(&[codeword.len() as u64]), F::ONE); - assert_eq!(root_of_unity.pow(&[(codeword.len() >> 1) as u64]), -F::ONE); + F::ROOT_OF_UNITY.pow([1 << (F::S as usize - log2_strict(codeword.len()))]); + assert_eq!(root_of_unity.pow([codeword.len() as u64]), F::ONE); + assert_eq!(root_of_unity.pow([(codeword.len() >> 1) as u64]), -F::ONE); assert_eq!( folding_coeffs.0, E::from(F::MULTIPLICATIVE_GENERATOR) - * E::from(root_of_unity).pow(&[(codeword.len() >> 2) as u64]) + * E::from(root_of_unity).pow([(codeword.len() >> 2) as u64]) ); assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); assert_eq!( diff --git a/mpcs/src/basefold/encoding/utils.rs b/mpcs/src/basefold/encoding/utils.rs new file mode 100644 index 000000000..36137526a --- /dev/null +++ b/mpcs/src/basefold/encoding/utils.rs @@ -0,0 +1,35 @@ +#[macro_export] +macro_rules! vec_mut { + ($a:ident, |$tmp_a:ident| $op:expr) => { + match $a { + multilinear_extensions::mle::FieldType::Base(ref mut $tmp_a) => $op, + multilinear_extensions::mle::FieldType::Ext(ref mut $tmp_a) => $op, + _ => unreachable!(), + } + }; + (|$a:ident| $op:expr) => { + vec_mut!($a, |$a| $op) + }; +} + +#[macro_export] +macro_rules! vec_map { + ($a:ident, |$tmp_a:ident| $op:expr) => { + match &$a { + multilinear_extensions::mle::FieldType::Base(a) => { + let $tmp_a = &a[..]; + let out = $op; + multilinear_extensions::mle::FieldType::Base(out) + } + multilinear_extensions::mle::FieldType::Ext(a) => { + let $tmp_a = &a[..]; + let out = $op; + multilinear_extensions::mle::FieldType::Base(out) + } + _ => unreachable!(), + } + }; + (|$a:ident| $op:expr) => { + vec_map!($a, |$a| $op) + }; +} diff --git a/mpcs/src/basefold/query_phase.rs b/mpcs/src/basefold/query_phase.rs index 150c0b9c0..57d41db68 100644 --- a/mpcs/src/basefold/query_phase.rs +++ b/mpcs/src/basefold/query_phase.rs @@ -29,7 +29,7 @@ use super::{ pub fn prover_query_phase( transcript: &mut impl TranscriptWrite, E>, comm: &BasefoldCommitmentWithData, - oracles: &Vec>, + oracles: &[Vec], num_verifier_queries: usize, ) -> QueriesResult where @@ -49,7 +49,7 @@ where .map(|x_index| { ( *x_index, - basefold_get_query::(&comm.get_codewords()[0], &oracles, *x_index), + basefold_get_query::(&comm.get_codewords()[0], oracles, *x_index), ) }) .collect(), @@ -60,7 +60,7 @@ pub fn batch_prover_query_phase( transcript: &mut impl TranscriptWrite, E>, codeword_size: usize, comms: &[BasefoldCommitmentWithData], - oracles: &Vec>, + oracles: &[Vec], num_verifier_queries: usize, ) -> BatchedQueriesResult where @@ -80,7 +80,7 @@ where .map(|x_index| { ( *x_index, - batch_basefold_get_query::(comms, &oracles, codeword_size, *x_index), + batch_basefold_get_query::(comms, oracles, codeword_size, *x_index), ) }) .collect(), @@ -90,7 +90,7 @@ where pub fn simple_batch_prover_query_phase( transcript: &mut impl TranscriptWrite, E>, comm: &BasefoldCommitmentWithData, - oracles: &Vec>, + oracles: &[Vec], num_verifier_queries: usize, ) -> SimpleBatchQueriesResult where @@ -110,21 +110,22 @@ where .map(|x_index| { ( *x_index, - simple_batch_basefold_get_query::(comm.get_codewords(), &oracles, *x_index), + simple_batch_basefold_get_query::(comm.get_codewords(), oracles, *x_index), ) }) .collect(), } } +#[allow(clippy::too_many_arguments)] pub fn verifier_query_phase>( vp: &>::VerifierParameters, queries: &QueriesResultWithMerklePath, - sum_check_messages: &Vec>, - fold_challenges: &Vec, + sum_check_messages: &[Vec], + fold_challenges: &[E], num_rounds: usize, num_vars: usize, - final_message: &Vec, + final_message: &[E], roots: &Vec>, comm: &BasefoldCommitment, partial_eq: &[E], @@ -136,7 +137,7 @@ pub fn verifier_query_phase>( let timer = start_timer!(|| "Verifier query phase"); let encode_timer = start_timer!(|| "Encode final codeword"); - let mut message = final_message.clone(); + let mut message = final_message.to_vec(); interpolate_over_boolean_hypercube(&mut message); if >::message_is_even_and_odd_folding() { reverse_index_bits_in_place(&mut message); @@ -186,16 +187,17 @@ pub fn verifier_query_phase>( end_timer!(timer); } +#[allow(clippy::too_many_arguments)] pub fn batch_verifier_query_phase>( vp: &>::VerifierParameters, queries: &BatchedQueriesResultWithMerklePath, - sum_check_messages: &Vec>, - fold_challenges: &Vec, + sum_check_messages: &[Vec], + fold_challenges: &[E], num_rounds: usize, num_vars: usize, - final_message: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, + final_message: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], coeffs: &[E], partial_eq: &[E], eval: &E, @@ -205,7 +207,7 @@ pub fn batch_verifier_query_phase>( { let timer = start_timer!(|| "Verifier batch query phase"); let encode_timer = start_timer!(|| "Encode final codeword"); - let mut message = final_message.clone(); + let mut message = final_message.to_vec(); if >::message_is_even_and_odd_folding() { reverse_index_bits_in_place(&mut message); } @@ -259,16 +261,17 @@ pub fn batch_verifier_query_phase>( end_timer!(timer); } +#[allow(clippy::too_many_arguments)] pub fn simple_batch_verifier_query_phase>( vp: &>::VerifierParameters, queries: &SimpleBatchQueriesResultWithMerklePath, - sum_check_messages: &Vec>, - fold_challenges: &Vec, - batch_coeffs: &Vec, + sum_check_messages: &[Vec], + fold_challenges: &[E], + batch_coeffs: &[E], num_rounds: usize, num_vars: usize, - final_message: &Vec, - roots: &Vec>, + final_message: &[E], + roots: &[Digest], comm: &BasefoldCommitment, partial_eq: &[E], evals: &[E], @@ -279,7 +282,7 @@ pub fn simple_batch_verifier_query_phase>::message_is_even_and_odd_folding() { reverse_index_bits_in_place(&mut message); } @@ -337,7 +340,7 @@ pub fn simple_batch_verifier_query_phase( poly_codeword: &FieldType, - oracles: &Vec>, + oracles: &[Vec], x_index: usize, ) -> SingleQueryResult where @@ -373,15 +376,15 @@ where inner: oracle_queries, }; - return SingleQueryResult { + SingleQueryResult { oracle_query, commitment_query, - }; + } } fn batch_basefold_get_query( comms: &[BasefoldCommitmentWithData], - oracles: &Vec>, + oracles: &[Vec], codeword_size: usize, x_index: usize, ) -> BatchedSingleQueryResult @@ -433,8 +436,8 @@ where } fn simple_batch_basefold_get_query( - poly_codewords: &Vec>, - oracles: &Vec>, + poly_codewords: &[FieldType], + oracles: &[Vec], x_index: usize, ) -> SimpleBatchSingleQueryResult where @@ -486,10 +489,10 @@ where inner: oracle_queries, }; - return SimpleBatchSingleQueryResult { + SimpleBatchSingleQueryResult { oracle_query, commitment_query, - }; + } } #[derive(Debug, Copy, Clone, Serialize, Deserialize)] @@ -530,7 +533,7 @@ where } } - pub fn batch(&self, coeffs: &Vec) -> (E, E) { + pub fn batch(&self, coeffs: &[E]) -> (E, E) { match self { SimpleBatchLeavesPair::Ext(x) => { let mut result = (E::ZERO, E::ZERO); @@ -892,12 +895,9 @@ where ) -> Vec> { let ret = self .get_inner() - .into_iter() + .iter() .enumerate() - .map(|(i, query_result)| { - let path = path(i, query_result.index); - path - }) + .map(|(i, query_result)| path(i, query_result.index)) .collect_vec(); ret } @@ -919,7 +919,7 @@ where query_result .merkle_path(path) .into_iter() - .zip(query_result.get_inner_into().into_iter()) + .zip(query_result.get_inner_into()) .map( |(path, codeword_result)| CodewordSingleQueryResultWithMerklePath { query: codeword_result, @@ -936,7 +936,7 @@ where .for_each(|q| q.write_transcript(transcript)); } - fn check_merkle_paths(&self, roots: &Vec>, hasher: &Hasher) { + fn check_merkle_paths(&self, roots: &[Digest], hasher: &Hasher) { // let timer = start_timer!(|| "ListQuery::Check Merkle Path"); self.get_inner() .iter() @@ -972,7 +972,7 @@ where { pub fn from_single_query_result( single_query_result: SingleQueryResult, - oracle_trees: &Vec>, + oracle_trees: &[MerkleTree], commitment: &BasefoldCommitmentWithData, ) -> Self { assert!(commitment.codeword_tree.height() > 0); @@ -982,7 +982,7 @@ where |i, j| oracle_trees[i].merkle_path_without_leaf_sibling_or_root(j), ), commitment_query: CodewordSingleQueryResultWithMerklePath { - query: single_query_result.commitment_query.clone(), + query: single_query_result.commitment_query, merkle_path: commitment .codeword_tree .merkle_path_without_leaf_sibling_or_root( @@ -1043,14 +1043,15 @@ where } } + #[allow(clippy::too_many_arguments)] pub fn check>( &self, vp: &>::VerifierParameters, - fold_challenges: &Vec, + fold_challenges: &[E], num_rounds: usize, num_vars: usize, - final_codeword: &Vec, - roots: &Vec>, + final_codeword: &[E], + roots: &[Digest], comm: &BasefoldCommitment, index: usize, hasher: &Hasher, @@ -1058,7 +1059,7 @@ where // let timer = start_timer!(|| "Checking codeword single query"); self.oracle_query.check_merkle_paths(roots, hasher); self.commitment_query - .check_merkle_path(&Digest(comm.root().0.try_into().unwrap()), hasher); + .check_merkle_path(&Digest(comm.root().0), hasher); let (mut curr_left, mut curr_right) = self.commitment_query.query.codepoints.as_ext(); @@ -1118,7 +1119,7 @@ where { pub fn from_query_result( query_result: QueriesResult, - oracle_trees: &Vec>, + oracle_trees: &[MerkleTree], commitment: &BasefoldCommitmentWithData, ) -> Self { Self { @@ -1197,14 +1198,15 @@ where } } + #[allow(clippy::too_many_arguments)] pub fn check>( &self, vp: &>::VerifierParameters, - fold_challenges: &Vec, + fold_challenges: &[E], num_rounds: usize, num_vars: usize, - final_codeword: &Vec, - roots: &Vec>, + final_codeword: &[E], + roots: &[Digest], comm: &BasefoldCommitment, hasher: &Hasher, ) { @@ -1250,8 +1252,8 @@ where { pub fn from_batched_single_query_result( batched_single_query_result: BatchedSingleQueryResult, - oracle_trees: &Vec>, - commitments: &Vec>, + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithData], ) -> Self { Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( @@ -1324,22 +1326,29 @@ where } } + #[allow(clippy::too_many_arguments)] pub fn check>( &self, vp: &>::VerifierParameters, - fold_challenges: &Vec, + fold_challenges: &[E], num_rounds: usize, num_vars: usize, - final_codeword: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, + final_codeword: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], coeffs: &[E], index: usize, hasher: &Hasher, ) { self.oracle_query.check_merkle_paths(roots, hasher); - self.commitments_query - .check_merkle_paths(&comms.iter().map(|comm| comm.root()).collect(), hasher); + self.commitments_query.check_merkle_paths( + comms + .iter() + .map(|comm| comm.root()) + .collect_vec() + .as_slice(), + hasher, + ); // end_timer!(commit_timer); let mut curr_left = E::ZERO; @@ -1358,7 +1367,7 @@ where .collect_vec(); matching_comms.iter().for_each(|index| { - let query = self.commitments_query.get_inner()[*index].query.clone(); + let query = self.commitments_query.get_inner()[*index].query; assert_eq!(query.index >> 1, left_index >> 1); curr_left += query.left_ext() * coeffs[*index]; curr_right += query.right_ext() * coeffs[*index]; @@ -1404,7 +1413,7 @@ where matching_comms.iter().for_each(|index| { let query: CodewordSingleQueryResult = - self.commitments_query.get_inner()[*index].query.clone(); + self.commitments_query.get_inner()[*index].query; assert_eq!(query.index >> 1, next_index >> 1); if next_index & 1 == 0 { res += query.left_ext() * coeffs[*index]; @@ -1444,8 +1453,8 @@ where { pub fn from_batched_query_result( batched_query_result: BatchedQueriesResult, - oracle_trees: &Vec>, - commitments: &Vec>, + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithData], ) -> Self { Self { inner: batched_query_result @@ -1523,15 +1532,16 @@ where } } + #[allow(clippy::too_many_arguments)] pub fn check>( &self, vp: &>::VerifierParameters, - fold_challenges: &Vec, + fold_challenges: &[E], num_rounds: usize, num_vars: usize, - final_codeword: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, + final_codeword: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], coeffs: &[E], hasher: &Hasher, ) { @@ -1767,7 +1777,7 @@ where { pub fn from_single_query_result( single_query_result: SimpleBatchSingleQueryResult, - oracle_trees: &Vec>, + oracle_trees: &[MerkleTree], commitment: &BasefoldCommitmentWithData, ) -> Self { Self { @@ -1843,15 +1853,16 @@ where } } + #[allow(clippy::too_many_arguments)] pub fn check>( &self, vp: &>::VerifierParameters, - fold_challenges: &Vec, - batch_coeffs: &Vec, + fold_challenges: &[E], + batch_coeffs: &[E], num_rounds: usize, num_vars: usize, - final_codeword: &Vec, - roots: &Vec>, + final_codeword: &[E], + roots: &[Digest], comm: &BasefoldCommitment, index: usize, hasher: &Hasher, @@ -1859,7 +1870,7 @@ where let timer = start_timer!(|| "Checking codeword single query"); self.oracle_query.check_merkle_paths(roots, hasher); self.commitment_query - .check_merkle_path(&Digest(comm.root().0.try_into().unwrap()), hasher); + .check_merkle_path(&Digest(comm.root().0), hasher); let (mut curr_left, mut curr_right) = self.commitment_query.query.leaves.batch(batch_coeffs); @@ -1922,7 +1933,7 @@ where { pub fn from_query_result( query_result: SimpleBatchQueriesResult, - oracle_trees: &Vec>, + oracle_trees: &[MerkleTree], commitment: &BasefoldCommitmentWithData, ) -> Self { Self { @@ -2005,15 +2016,16 @@ where } } + #[allow(clippy::too_many_arguments)] pub fn check>( &self, vp: &>::VerifierParameters, - fold_challenges: &Vec, - batch_coeffs: &Vec, + fold_challenges: &[E], + batch_coeffs: &[E], num_rounds: usize, num_vars: usize, - final_codeword: &Vec, - roots: &Vec>, + final_codeword: &[E], + roots: &[Digest], comm: &BasefoldCommitment, hasher: &Hasher, ) { diff --git a/mpcs/src/basefold/structure.rs b/mpcs/src/basefold/structure.rs index 4318421f8..0a1972c38 100644 --- a/mpcs/src/basefold/structure.rs +++ b/mpcs/src/basefold/structure.rs @@ -118,21 +118,21 @@ where } } -impl Into> for BasefoldCommitmentWithData +impl From> for Digest where E::BaseField: Serialize + DeserializeOwned, { - fn into(self) -> Digest { - self.get_root_as() + fn from(val: BasefoldCommitmentWithData) -> Self { + val.get_root_as() } } -impl Into> for &BasefoldCommitmentWithData +impl From<&BasefoldCommitmentWithData> for BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, { - fn into(self) -> BasefoldCommitment { - self.to_commitment() + fn from(val: &BasefoldCommitmentWithData) -> Self { + val.to_commitment() } } diff --git a/mpcs/src/basefold/sumcheck.rs b/mpcs/src/basefold/sumcheck.rs index a6a84f0eb..bb5312111 100644 --- a/mpcs/src/basefold/sumcheck.rs +++ b/mpcs/src/basefold/sumcheck.rs @@ -7,8 +7,8 @@ use rayon::prelude::{ }; pub fn sum_check_first_round_field_type( - mut eq: &mut Vec, - mut bh_values: &mut FieldType, + eq: &mut Vec, + bh_values: &mut FieldType, ) -> Vec { // The input polynomials are in the form of evaluations. Instead of viewing // every one element as the evaluation of the polynomial at a single point, @@ -16,24 +16,21 @@ pub fn sum_check_first_round_field_type( // a single point, leaving the first variable free, and obtaining a univariate // polynomial. The one_level_interp_hc transforms the evaluation forms into // the coefficient forms, for every of these partial polynomials. - one_level_interp_hc(&mut eq); - one_level_interp_hc_field_type(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc_field_type(bh_values); parallel_pi_field_type(bh_values, eq) // p_i(&bh_values, &eq) } -pub fn sum_check_first_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, -) -> Vec { +pub fn sum_check_first_round(eq: &mut Vec, bh_values: &mut Vec) -> Vec { // The input polynomials are in the form of evaluations. Instead of viewing // every one element as the evaluation of the polynomial at a single point, // we can view every two elements as partially evaluating the polynomial at // a single point, leaving the first variable free, and obtaining a univariate // polynomial. The one_level_interp_hc transforms the evaluation forms into // the coefficient forms, for every of these partial polynomials. - one_level_interp_hc(&mut eq); - one_level_interp_hc(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc(bh_values); parallel_pi(bh_values, eq) // p_i(&bh_values, &eq) } @@ -51,7 +48,7 @@ pub fn one_level_interp_hc(evals: &mut Vec) { return; } evals.par_chunks_mut(2).for_each(|chunk| { - chunk[1] = chunk[1] - chunk[0]; + chunk[1] -= chunk[0]; }); } @@ -68,15 +65,15 @@ pub fn one_level_eval_hc(evals: &mut Vec, challenge: F) { }); } -fn parallel_pi_field_type(evals: &mut FieldType, eq: &mut Vec) -> Vec { +fn parallel_pi_field_type(evals: &mut FieldType, eq: &mut [E]) -> Vec { match evals { - FieldType::Ext(evals) => parallel_pi(evals, &eq), - FieldType::Base(evals) => parallel_pi_base(evals, &eq), + FieldType::Ext(evals) => parallel_pi(evals, eq), + FieldType::Base(evals) => parallel_pi_base(evals, eq), _ => unreachable!(), } } -fn parallel_pi(evals: &Vec, eq: &Vec) -> Vec { +fn parallel_pi(evals: &[F], eq: &[F]) -> Vec { if evals.len() == 1 { return vec![evals[0], evals[0], evals[0]]; } @@ -111,7 +108,7 @@ fn parallel_pi(evals: &Vec, eq: &Vec) -> Vec { coeffs } -fn parallel_pi_base(evals: &Vec, eq: &Vec) -> Vec { +fn parallel_pi_base(evals: &[E::BaseField], eq: &[E]) -> Vec { if evals.len() == 1 { return vec![E::from(evals[0]), E::from(evals[0]), E::from(evals[0])]; } @@ -147,31 +144,27 @@ fn parallel_pi_base(evals: &Vec, eq: &Vec) - } pub fn sum_check_challenge_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, + eq: &mut Vec, + bh_values: &mut Vec, challenge: F, ) -> Vec { // Note that when the last round ends, every two elements are in // the coefficient form. Use the challenge to reduce the two elements // into a single value. This is equivalent to substituting the challenge // to the first variable of the poly. - one_level_eval_hc(&mut bh_values, challenge); - one_level_eval_hc(&mut eq, challenge); + one_level_eval_hc(bh_values, challenge); + one_level_eval_hc(eq, challenge); - one_level_interp_hc(&mut eq); - one_level_interp_hc(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc(bh_values); - parallel_pi(&bh_values, &eq) + parallel_pi(bh_values, eq) // p_i(&bh_values,&eq) } -pub fn sum_check_last_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, - challenge: F, -) { - one_level_eval_hc(&mut bh_values, challenge); - one_level_eval_hc(&mut eq, challenge); +pub fn sum_check_last_round(eq: &mut Vec, bh_values: &mut Vec, challenge: F) { + one_level_eval_hc(bh_values, challenge); + one_level_eval_hc(eq, challenge); } #[cfg(test)] @@ -184,7 +177,7 @@ mod tests { use super::*; use crate::util::test::rand_vec; - pub fn p_i(evals: &Vec, eq: &Vec) -> Vec { + pub fn p_i(evals: &[F], eq: &[F]) -> Vec { if evals.len() == 1 { return vec![evals[0], evals[0], evals[0]]; } diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index d596b8472..277f98fd0 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -51,14 +51,14 @@ pub fn pcs_commit_and_write>( pp: &Pcs::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], ) -> Result { Pcs::batch_commit(pp, polys) } -pub fn pcs_batch_commit_and_write<'a, E: ExtensionField, Pcs: PolynomialCommitmentScheme>( +pub fn pcs_batch_commit_and_write>( pp: &Pcs::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], transcript: &mut impl TranscriptWrite, ) -> Result { Pcs::batch_commit_and_write(pp, polys, transcript) @@ -77,8 +77,8 @@ pub fn pcs_open>( pub fn pcs_batch_open>( pp: &Pcs::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Pcs::CommitmentWithData], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptWrite, @@ -115,7 +115,7 @@ pub fn pcs_verify>( pub fn pcs_batch_verify<'a, E: ExtensionField, Pcs: PolynomialCommitmentScheme>( vp: &Pcs::VerifierParam, - comms: &Vec, + comms: &[Pcs::Commitment], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptRead, @@ -155,12 +155,12 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_commit( pp: &Self::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], ) -> Result; fn batch_commit_and_write( pp: &Self::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], transcript: &mut impl TranscriptWrite, ) -> Result; @@ -175,8 +175,8 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptWrite, @@ -188,7 +188,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { /// 3. The point is already a random point generated by a sum-check. fn simple_batch_open( pp: &Self::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], comm: &Self::CommitmentWithData, point: &[E], evals: &[E], @@ -220,7 +220,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_verify( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptRead, @@ -259,8 +259,8 @@ where fn ni_batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], ) -> Result, Error> { @@ -282,7 +282,7 @@ where fn ni_batch_verify<'a>( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], proof: &PCSProof, @@ -341,11 +341,11 @@ pub use basefold::{ fn validate_input( function: &str, param_num_vars: usize, - polys: &Vec>, - points: &Vec>, + polys: &[DenseMultilinearExtension], + points: &[Vec], ) -> Result<(), Error> { - let polys = polys.into_iter().collect_vec(); - let points = points.into_iter().collect_vec(); + let polys = polys.iter().collect_vec(); + let points = points.iter().collect_vec(); for poly in polys.iter() { if param_num_vars < poly.num_vars { return Err(err_too_many_variates( @@ -428,15 +428,13 @@ pub mod test_util { // Verify let result = { let mut transcript = T::from_proof(proof.as_slice()); - let result = Pcs::verify( + Pcs::verify( &vp, &Pcs::read_commitment(&vp, &mut transcript).unwrap(), &transcript.squeeze_challenges(num_vars), &transcript.read_field_element_ext().unwrap(), &mut transcript, - ); - - result + ) }; result.unwrap(); } @@ -598,8 +596,7 @@ pub mod test_util { let point = transcript.squeeze_challenges(num_vars); let evals = transcript.read_field_elements_ext(batch_size).unwrap(); - let result = Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript); - result + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript) }; result.unwrap(); diff --git a/mpcs/src/sum_check.rs b/mpcs/src/sum_check.rs index e12233cf9..8ea3517af 100644 --- a/mpcs/src/sum_check.rs +++ b/mpcs/src/sum_check.rs @@ -93,8 +93,8 @@ pub fn evaluate( &|query| evals[&query], &|idx| challenges[idx], &|scalar| -scalar, - &|lhs, rhs| lhs + &rhs, - &|lhs, rhs| lhs * &rhs, + &|lhs, rhs| lhs + rhs, + &|lhs, rhs| lhs * rhs, &|value, scalar| scalar * value, ) } @@ -104,11 +104,7 @@ pub fn lagrange_eval(x: &[F], b: usize) -> F { product(x.iter().enumerate().map( |(idx, x_i)| { - if b.nth_bit(idx) { - *x_i - } else { - F::ONE - x_i - } + if b.nth_bit(idx) { *x_i } else { F::ONE - x_i } }, )) } diff --git a/mpcs/src/sum_check/classic.rs b/mpcs/src/sum_check/classic.rs index 439535e71..853d25602 100644 --- a/mpcs/src/sum_check/classic.rs +++ b/mpcs/src/sum_check/classic.rs @@ -116,10 +116,8 @@ impl<'a, E: ExtensionField> ProverState<'a, E> { .expression .used_rotation() .into_iter() - .filter_map(|rotation| { - (rotation != Rotation::cur()) - .then(|| (rotation, self.bh.rotation_map(rotation))) - }) + .filter(|&rotation| (rotation != Rotation::cur())) + .map(|rotation| (rotation, self.bh.rotation_map(rotation))) .collect::>(); for query in self.expression.used_query() { if query.rotation() != Rotation::cur() { @@ -312,7 +310,7 @@ mod tests { #[test] fn test_sum_check_protocol() { - let polys = vec![ + let polys = [ DenseMultilinearExtension::::from_evaluations_vec( 2, vec![Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)], diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 9e596fe08..470883d7f 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -205,7 +205,7 @@ impl ClassicSumCheckProver for CoefficientsProver { products.iter_mut().for_each(|(lhs, _)| { *lhs *= &rhs; }); - (constant * &rhs, products) + (constant * rhs, products) }, ); Self(constant, flattened) @@ -215,7 +215,7 @@ impl ClassicSumCheckProver for CoefficientsProver { // Initialize h(X) to zero let mut coeffs = Coefficients(FieldType::Ext(vec![E::ZERO; state.expression.degree() + 1])); // First, sum the constant over the hypercube and add to h(X) - coeffs += &(E::from(state.size() as u64) * &self.0); + coeffs += &(E::from(state.size() as u64) * self.0); // Next, for every product of polynomials, where each product is assumed to be exactly 2 // put this into h(X). if self.1.iter().all(|(_, products)| products.len() == 2) { @@ -287,11 +287,11 @@ impl CoefficientsProver { .take(n) .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { let coeff_0 = lhs_0 * rhs_0; - let coeff_2 = (lhs_1 - lhs_0) * &(rhs_1 - rhs_0); + let coeff_2 = (lhs_1 - lhs_0) * (rhs_1 - rhs_0); coeffs[0] += &coeff_0; coeffs[2] += &coeff_2; if !LAZY { - coeffs[1] += &(lhs_1 * rhs_1 - &coeff_0 - &coeff_2); + coeffs[1] += &(lhs_1 * rhs_1 - coeff_0 - coeff_2); } }); }; diff --git a/mpcs/src/util/arithmetic.rs b/mpcs/src/util/arithmetic.rs index f4ca552d8..609f65455 100644 --- a/mpcs/src/util/arithmetic.rs +++ b/mpcs/src/util/arithmetic.rs @@ -80,7 +80,7 @@ pub fn inner_product<'a, 'b, F: Field>( rhs: impl IntoIterator, ) -> F { lhs.into_iter() - .zip_eq(rhs.into_iter()) + .zip_eq(rhs) .map(|(lhs, rhs)| *lhs * rhs) .reduce(|acc, product| acc + product) .unwrap_or_default() @@ -92,8 +92,8 @@ pub fn inner_product_three<'a, 'b, 'c, F: Field>( c: impl IntoIterator, ) -> F { a.into_iter() - .zip_eq(b.into_iter()) - .zip_eq(c.into_iter()) + .zip_eq(b) + .zip_eq(c) .map(|((a, b), c)| *a * b * c) .reduce(|acc, product| acc + product) .unwrap_or_default() @@ -107,8 +107,9 @@ pub fn barycentric_weights(points: &[F]) -> Vec { points .iter() .enumerate() - .filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i)) - .reduce(|acc, value| acc * &value) + .filter(|&(i, _point_i)| (i != j)) + .map(|(_i, point_i)| *point_j - point_i) + .reduce(|acc, value| acc * value) .unwrap_or(F::ONE) }) .collect_vec(); @@ -126,7 +127,7 @@ pub fn barycentric_interpolate(weights: &[F], points: &[F], evals: &[F let sum_inv = coeffs.iter().fold(F::ZERO, |sum, coeff| sum + coeff); (coeffs, sum_inv.invert().unwrap()) }; - inner_product(&coeffs, evals) * &sum_inv + inner_product(&coeffs, evals) * sum_inv } pub fn modulus() -> BigUint { @@ -134,11 +135,7 @@ pub fn modulus() -> BigUint { } pub fn fe_from_bool(value: bool) -> F { - if value { - F::ONE - } else { - F::ZERO - } + if value { F::ONE } else { F::ZERO } } pub fn fe_mod_from_le_bytes(bytes: impl AsRef<[u8]>) -> F { @@ -221,17 +218,17 @@ pub fn interpolate2(points: [(F, F); 2], x: F) -> F { a1 + (x - a0) * (b1 - a1) * (b0 - a0).invert().unwrap() } -pub fn degree_2_zero_plus_one(poly: &Vec) -> F { +pub fn degree_2_zero_plus_one(poly: &[F]) -> F { poly[0] + poly[0] + poly[1] + poly[2] } -pub fn degree_2_eval(poly: &Vec, point: F) -> F { +pub fn degree_2_eval(poly: &[F], point: F) -> F { poly[0] + point * poly[1] + point * point * poly[2] } -pub fn base_from_raw_bytes(bytes: &Vec) -> E::BaseField { +pub fn base_from_raw_bytes(bytes: &[u8]) -> E::BaseField { let mut res = E::BaseField::ZERO; - bytes.into_iter().for_each(|b| { + bytes.iter().for_each(|b| { res += E::BaseField::from(u64::from(*b)); }); res diff --git a/mpcs/src/util/arithmetic/hypercube.rs b/mpcs/src/util/arithmetic/hypercube.rs index 8de970499..ccd6294e3 100644 --- a/mpcs/src/util/arithmetic/hypercube.rs +++ b/mpcs/src/util/arithmetic/hypercube.rs @@ -29,7 +29,7 @@ pub fn interpolate_over_boolean_hypercube(evals: &mut Vec) { evals.par_chunks_mut(chunk_size).for_each(|chunk| { let half_chunk = chunk_size >> 1; for j in half_chunk..chunk_size { - chunk[j] = chunk[j] - chunk[j - half_chunk]; + chunk[j] -= chunk[j - half_chunk]; } }); } diff --git a/mpcs/src/util/hash.rs b/mpcs/src/util/hash.rs index 0583f579d..fa8010970 100644 --- a/mpcs/src/util/hash.rs +++ b/mpcs/src/util/hash.rs @@ -41,8 +41,8 @@ pub fn hash_two_leaves_base( } pub fn hash_two_leaves_batch_ext( - a: &Vec, - b: &Vec, + a: &[E], + b: &[E], hasher: &Hasher, ) -> Digest { let mut left_hasher = hasher.clone(); diff --git a/mpcs/src/util/merkle_tree.rs b/mpcs/src/util/merkle_tree.rs index c1bff62d0..b3d1bf899 100644 --- a/mpcs/src/util/merkle_tree.rs +++ b/mpcs/src/util/merkle_tree.rs @@ -1,4 +1,5 @@ use ff_ext::ExtensionField; +use itertools::Itertools; use multilinear_extensions::mle::FieldType; use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, @@ -34,14 +35,14 @@ where { pub fn from_leaves(leaves: FieldType, hasher: &Hasher) -> Self { Self { - inner: merkelize::(&vec![&leaves], hasher), + inner: merkelize::(&[&leaves], hasher), leaves: vec![leaves], } } pub fn from_batch_leaves(leaves: Vec>, hasher: &Hasher) -> Self { Self { - inner: merkelize::(&leaves.iter().collect(), hasher), + inner: merkelize::(&leaves.iter().collect_vec(), hasher), leaves, } } @@ -80,8 +81,14 @@ where pub fn get_leaf_as_base(&self, index: usize) -> Vec { match &self.leaves[0] { - FieldType::Base(_) => self.leaves.iter().map(|leaves| field_type_index_base(leaves, index)).collect(), - FieldType::Ext(_) => panic!("Mismatching field type, calling get_leaf_as_base on a Merkle tree over extension fields"), + FieldType::Base(_) => self + .leaves + .iter() + .map(|leaves| field_type_index_base(leaves, index)) + .collect(), + FieldType::Ext(_) => panic!( + "Mismatching field type, calling get_leaf_as_base on a Merkle tree over extension fields" + ), FieldType::Unreachable => unreachable!(), } } @@ -136,6 +143,10 @@ where Self { inner } } + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + pub fn len(&self) -> usize { self.inner.len() } @@ -234,7 +245,7 @@ where } fn merkelize( - values: &Vec<&FieldType>, + values: &[&FieldType], hasher: &Hasher, ) -> Vec>> { #[cfg(feature = "sanity-check")] @@ -273,14 +284,16 @@ fn merkelize( hasher, ), FieldType::Ext(_) => hash_two_leaves_batch_ext::( - &values + values .iter() .map(|values| field_type_index_ext(values, i << 1)) - .collect(), - &values + .collect_vec() + .as_slice(), + values .iter() .map(|values| field_type_index_ext(values, (i << 1) + 1)) - .collect(), + .collect_vec() + .as_slice(), hasher, ), FieldType::Unreachable => unreachable!(), @@ -303,7 +316,7 @@ fn merkelize( } fn authenticate_merkle_path_root( - path: &Vec>, + path: &[Digest], leaves: FieldType, x_index: usize, root: &Digest, @@ -319,11 +332,11 @@ fn authenticate_merkle_path_root( // The lowest bit in the index is ignored. It can point to either leaves x_index >>= 1; - for i in 0..path.len() { + for path_i in path.iter() { hash = if x_index & 1 == 0 { - hash_two_digests(&hash, &path[i], hasher) + hash_two_digests(&hash, path_i, hasher) } else { - hash_two_digests(&path[i], &hash, hasher) + hash_two_digests(path_i, &hash, hasher) }; x_index >>= 1; } @@ -331,7 +344,7 @@ fn authenticate_merkle_path_root( } fn authenticate_merkle_path_root_batch( - path: &Vec>, + path: &[Digest], left: FieldType, right: FieldType, x_index: usize, @@ -351,11 +364,11 @@ fn authenticate_merkle_path_root_batch( // The lowest bit in the index is ignored. It can point to either leaves x_index >>= 1; - for i in 0..path.len() { + for path_i in path.iter() { hash = if x_index & 1 == 0 { - hash_two_digests(&hash, &path[i], hasher) + hash_two_digests(&hash, path_i, hasher) } else { - hash_two_digests(&path[i], &hash, hasher) + hash_two_digests(path_i, &hash, hasher) }; x_index >>= 1; } diff --git a/mpcs/src/util/parallel.rs b/mpcs/src/util/parallel.rs index 5f44d473d..950d43797 100644 --- a/mpcs/src/util/parallel.rs +++ b/mpcs/src/util/parallel.rs @@ -1,10 +1,7 @@ pub fn num_threads() -> usize { #[cfg(feature = "parallel")] let nt = rayon::current_num_threads(); - return nt; - - #[cfg(not(feature = "parallel"))] - return 1; + if cfg!(feature = "parallel") { nt } else { 1 } } pub fn parallelize_iter(iter: I, f: F) diff --git a/mpcs/src/util/transcript.rs b/mpcs/src/util/transcript.rs index 343bfffba..2edf082cf 100644 --- a/mpcs/src/util/transcript.rs +++ b/mpcs/src/util/transcript.rs @@ -21,17 +21,16 @@ pub trait FieldTranscript { fn common_field_element_ext(&mut self, fe: &E) -> Result<(), Error>; fn common_field_elements(&mut self, fes: FieldType) -> Result<(), Error> { - Ok(match fes { + match fes { FieldType::Base(fes) => fes .iter() - .map(|fe| self.common_field_element_base(fe)) - .try_collect()?, + .try_for_each(|fe| self.common_field_element_base(fe))?, FieldType::Ext(fes) => fes .iter() - .map(|fe| self.common_field_element_ext(fe)) - .try_collect()?, + .try_for_each(|fe| self.common_field_element_ext(fe))?, FieldType::Unreachable => unreachable!(), - }) + }; + Ok(()) } } @@ -146,7 +145,7 @@ impl Stream { self.inner.len() - self.pointer } - pub fn read_exact(&mut self, output: &mut Vec) -> Result<(), Error> { + pub fn read_exact(&mut self, output: &mut [T]) -> Result<(), Error> { let left = self.left(); if left < output.len() { return Err(Error::Transcript(