diff --git a/Cargo.lock b/Cargo.lock index 002169d90..ea8775ddd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -300,6 +300,7 @@ dependencies = [ "strum 0.25.0", "strum_macros 0.25.3", "sumcheck", + "thread_local", "tracing", "tracing-flame", "tracing-subscriber", diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index bd13aab86..eb9d256f4 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -28,6 +28,7 @@ tracing-flame = "0.2.0" tracing = "0.1.40" rand = "0.8" +thread_local = "1.1.8" [dev-dependencies] pprof = { version = "0.13", features = ["flamegraph"]} diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 9385c0294..0b4a672a8 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -246,7 +246,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Ok(()) } - /// lookup a < b as usigned byte + /// lookup a < b as unsigned byte pub(crate) fn lookup_ltu_limb8( &mut self, res: Expression, diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 7b44d052c..7a37fc39a 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -4,7 +4,11 @@ use ceno_emul::StepRecord; use ff_ext::ExtensionField; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; -use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajorMatrix}; +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + witness::{LkMultiplicity, RowMajorMatrix}, +}; pub mod riscv; @@ -18,6 +22,7 @@ pub trait Instruction { fn assign_instance( config: &Self::InstructionConfig, instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, step: StepRecord, ) -> Result<(), ZKVMError>; @@ -25,15 +30,19 @@ pub trait Instruction { config: &Self::InstructionConfig, num_witin: usize, steps: Vec, - ) -> Result, ZKVMError> { + ) -> Result<(RowMajorMatrix, LkMultiplicity), ZKVMError> { + let lk_multiplicity = LkMultiplicity::default(); let mut raw_witin = RowMajorMatrix::::new(steps.len(), num_witin); let raw_witin_iter = raw_witin.par_iter_mut(); raw_witin_iter .zip_eq(steps.into_par_iter()) - .map(|(instance, step)| Self::assign_instance(config, instance, step)) + .map(|(instance, step)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + Self::assign_instance(config, instance, &mut lk_multiplicity, step) + }) .collect::>()?; - Ok(raw_witin) + Ok((raw_witin, lk_multiplicity)) } } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 47bbd49d8..84cf4f36c 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -17,6 +17,7 @@ use crate::{ instructions::Instruction, set_val, uint::UIntValue, + witness::LkMultiplicity, }; use core::mem::MaybeUninit; @@ -151,13 +152,14 @@ impl Instruction for AddInstruction { fn assign_instance( config: &Self::InstructionConfig, instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, step: StepRecord, ) -> Result<(), ZKVMError> { // TODO use fields from step set_val!(instance, config.pc, 1); set_val!(instance, config.ts, 2); - let addend_0 = UIntValue::new(step.rs1().unwrap().value); - let addend_1 = UIntValue::new(step.rs2().unwrap().value); + let addend_0 = UIntValue::new_unchecked(step.rs1().unwrap().value); + let addend_1 = UIntValue::new_unchecked(step.rs2().unwrap().value); config .prev_rd_value .assign_limbs(instance, [0, 0].iter().map(E::BaseField::from).collect()); @@ -167,7 +169,7 @@ impl Instruction for AddInstruction { config .addend_1 .assign_limbs(instance, addend_1.u16_fields()); - let carries = addend_0.add_u16_carries(&addend_1); + let (_, carries) = addend_0.add(&addend_1, lk_multiplicity, true); config.outcome.assign_carries( instance, carries @@ -199,6 +201,7 @@ impl Instruction for SubInstruction { fn assign_instance( config: &Self::InstructionConfig, instance: &mut [MaybeUninit], + _lk_multiplicity: &mut LkMultiplicity, _step: StepRecord, ) -> Result<(), ZKVMError> { // TODO use field from step @@ -263,7 +266,7 @@ mod test { .unwrap() .unwrap(); - let raw_witin = AddInstruction::assign_instances( + let (raw_witin, _) = AddInstruction::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord { @@ -310,7 +313,7 @@ mod test { .unwrap() .unwrap(); - let raw_witin = AddInstruction::assign_instances( + let (raw_witin, _) = AddInstruction::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord { diff --git a/ceno_zkvm/src/instructions/riscv/blt.rs b/ceno_zkvm/src/instructions/riscv/blt.rs index fc0181d4b..007126d4d 100644 --- a/ceno_zkvm/src/instructions/riscv/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/blt.rs @@ -15,6 +15,7 @@ use crate::{ }, set_val, utils::{i64_to_base, limb_u8_to_u16}, + witness::LkMultiplicity, }; use super::{ @@ -222,6 +223,7 @@ impl Instruction for BltInstruction { fn assign_instance( config: &Self::InstructionConfig, instance: &mut [std::mem::MaybeUninit], + _lk_multiplicity: &mut LkMultiplicity, _step: ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { // take input from _step @@ -250,7 +252,7 @@ mod test { let num_wits = circuit_builder.cs.num_witin as usize; // generate mock witness let num_instances = 1 << 4; - let raw_witin = BltInstruction::assign_instances( + let (raw_witin, _) = BltInstruction::assign_instances( &config, num_wits, vec![StepRecord::default(); num_instances], diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 3e9612f1a..811132e63 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -1,5 +1,6 @@ #![feature(box_patterns)] #![feature(stmt_expr_attributes)] +#![feature(variant_count)] pub mod error; pub mod instructions; diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 8a27fdb9a..44927014d 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -7,6 +7,7 @@ use crate::{ error::{UtilError, ZKVMError}, expression::{Expression, ToExpr, WitIn}, utils::add_one_to_big_num, + witness::LkMultiplicity, }; use ark_std::iterable::Iterable; use constants::BYTE_BIT_WIDTH; @@ -476,13 +477,29 @@ impl + Copy> UIntValue { mem::size_of::() / u16_bytes }; - pub fn new(val: T) -> Self { + #[allow(dead_code)] + pub fn new(val: T, lkm: &mut LkMultiplicity) -> Self { + let uint = UIntValue:: { + val, + limbs: Self::split_to_u16(val), + }; + Self::assert_u16(&uint.limbs, lkm); + uint + } + + pub fn new_unchecked(val: T) -> Self { UIntValue:: { val, limbs: Self::split_to_u16(val), } } + fn assert_u16(v: &[u16], lkm: &mut LkMultiplicity) { + v.iter().for_each(|v| { + lkm.assert_ux::<16>(*v as u64); + }) + } + fn split_to_u16(value: T) -> Vec { let value: u64 = value.into(); // Convert to u64 for generality (0..Self::LIMBS) @@ -502,20 +519,35 @@ impl + Copy> UIntValue { self.limbs.iter().map(|v| F::from(*v as u64)).collect_vec() } - pub fn add_u16_carries(&self, rhs: &Self) -> Vec { - self.as_u16_limbs().iter().zip(rhs.as_u16_limbs()).fold( + pub fn add( + &self, + rhs: &Self, + lkm: &mut LkMultiplicity, + with_overflow: bool, + ) -> (Vec, Vec) { + let res = self.as_u16_limbs().iter().zip(rhs.as_u16_limbs()).fold( vec![], |mut acc, (a_limb, b_limb)| { let (a, b) = a_limb.overflowing_add(*b_limb); - if let Some(prev_carry) = acc.last() { - let (_, d) = a.overflowing_add(*prev_carry as u16); - acc.push(b || d); + if let Some((_, prev_carry)) = acc.last() { + let (e, d) = a.overflowing_add(*prev_carry as u16); + acc.push((e, b || d)); } else { - acc.push(b); + acc.push((a, b)); } + // range check + if let Some((limb, _)) = acc.last() { + lkm.assert_ux::<16>(*limb as u64); + }; acc }, - ) + ); + let (limbs, mut carries): (Vec, Vec) = res.into_iter().unzip(); + if !with_overflow { + carries.resize(carries.len() - 1, false); + } + carries.iter().for_each(|c| lkm.assert_ux::<16>(*c as u64)); + (limbs, carries) } } diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 9b44586da..718145ddc 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -1,6 +1,10 @@ use std::{ + array, + cell::RefCell, + collections::HashMap, mem::{self, MaybeUninit}, slice::ChunksMut, + sync::Arc, }; use multilinear_extensions::util::create_uninit_vec; @@ -8,6 +12,9 @@ use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, slice::ParallelSliceMut, }; +use thread_local::ThreadLocal; + +use crate::structs::ROMType; #[macro_export] macro_rules! set_val { @@ -51,3 +58,102 @@ impl RowMajorMatrix { .collect() } } + +/// A lock-free thread safe struct to count logup multiplicity for each ROM type +/// Lock-free by thread-local such that each thread will only have its local copy +/// struct is cloneable, for internallly it use Arc so the clone will be low cost +#[derive(Clone, Default)] +#[allow(clippy::type_complexity)] +pub struct LkMultiplicity { + multiplicity: Arc; mem::variant_count::()]>>>, +} + +#[allow(dead_code)] +impl LkMultiplicity { + /// assert within range + #[inline(always)] + pub fn assert_ux(&mut self, v: u64) { + match C { + 16 => self.assert_u16(v), + 8 => self.assert_byte(v), + 5 => self.assert_u5(v), + _ => panic!("Unsupported bit range"), + } + } + + fn assert_u5(&mut self, v: u64) { + let multiplicity = self + .multiplicity + .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); + (*multiplicity.borrow_mut()[ROMType::U5 as usize] + .entry(v) + .or_default()) += 1; + } + + fn assert_u16(&mut self, v: u64) { + let multiplicity = self + .multiplicity + .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); + (*multiplicity.borrow_mut()[ROMType::U16 as usize] + .entry(v) + .or_default()) += 1; + } + + fn assert_byte(&mut self, v: u64) { + let v = v * (1 << u8::BITS); + let multiplicity = self + .multiplicity + .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); + (*multiplicity.borrow_mut()[ROMType::U16 as usize] + .entry(v) + .or_default()) += 1; + } + + /// lookup a < b as unsigned byte + pub fn lookup_ltu_limb8(&mut self, a: u64, b: u64) { + let key = a.wrapping_mul(256) + b; + let multiplicity = self + .multiplicity + .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); + (*multiplicity.borrow_mut()[ROMType::Ltu as usize] + .entry(key) + .or_default()) += 1; + } + + /// merge result from multiple thread local to single result + fn into_finalize_result(self) -> [HashMap; mem::variant_count::()] { + Arc::try_unwrap(self.multiplicity) + .unwrap() + .into_iter() + .fold(array::from_fn(|_| HashMap::new()), |mut x, y| { + x.iter_mut().zip(y.borrow().iter()).for_each(|(m1, m2)| { + for (key, value) in m2 { + *m1.entry(*key).or_insert(0) += value; + } + }); + x + }) + } +} + +#[cfg(test)] +mod tests { + use std::thread; + + use crate::{structs::ROMType, witness::LkMultiplicity}; + + #[test] + fn test_lk_multiplicity_threads() { + // TODO figure out a way to verify thread_local hit/miss in unittest env + let lkm = LkMultiplicity::default(); + let thread_count = 20; + // each thread calling assert_byte once + for _ in 0..thread_count { + let mut lkm = lkm.clone(); + thread::spawn(move || lkm.assert_byte(8u64)).join().unwrap(); + } + let res = lkm.into_finalize_result(); + // check multiplicity counts of assert_byte + assert_eq!(res[ROMType::U16 as usize][&(8 << 8)], thread_count); + } +} diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 18cecc5f0..4ff1ce8f3 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -46,5 +46,25 @@ sanity-check = [] print-trace = [ "ark-std/print-trace" ] [[bench]] -name = "commit_open_verify" +name = "commit_open_verify_rs" +harness = false + +[[bench]] +name = "commit_open_verify_basecode" +harness = false + +[[bench]] +name = "basecode" +harness = false + +[[bench]] +name = "rscode" +harness = false + +[[bench]] +name = "interpolate" +harness = false + +[[bench]] +name = "fft" harness = false \ No newline at end of file diff --git a/mpcs/benches/basecode.rs b/mpcs/benches/basecode.rs new file mode 100644 index 000000000..7e8ccca99 --- /dev/null +++ b/mpcs/benches/basecode.rs @@ -0,0 +1,110 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::Itertools; +use mpcs::{ + util::{ + arithmetic::interpolate_field_type_over_boolean_hypercube, + plonky2_util::reverse_index_bits_in_place_field_type, + }, Basefold, BasefoldBasecodeParams, BasefoldSpec, EncodingScheme, PolynomialCommitmentScheme +}; + +use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +type Pcs = Basefold; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_encoding(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "encoding_basecode_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, _) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys + .par_iter() + .map(|poly| { + // Switch to coefficient form + let mut coeffs = poly.evaluations.clone(); + interpolate_field_type_over_boolean_hypercube(&mut coeffs); + + let mut codeword = + <>::EncodingScheme as EncodingScheme>::encode( + &pp.encoding_params, + &coeffs, + ); + + // If using repetition code as basecode, it may be faster to use the following line of code to create the commitment and comment out the two lines above + // let mut codeword = evaluate_over_foldable_domain(pp.log_rate, coeffs, &pp.table); + + // The sum-check protocol starts from the first variable, but the FRI part + // will eventually produce the evaluation at (alpha_k, ..., alpha_1), so apply + // the bit-reversion to reverse the variable indices of the polynomial. + // In short: store the poly and codeword in big endian + reverse_index_bits_in_place_field_type(&mut coeffs); + reverse_index_bits_in_place_field_type(&mut codeword); + + (coeffs, codeword) + }) + .collect::<(Vec>, Vec>)>(); + }) + }, + ); + } + } +} + +fn bench_encoding_goldilocks_2(c: &mut Criterion) { + bench_encoding(c, false); +} + +fn bench_encoding_base(c: &mut Criterion) { + bench_encoding(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_encoding_base, bench_encoding_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/commit_open_verify.rs b/mpcs/benches/commit_open_verify.rs deleted file mode 100644 index 8d731f514..000000000 --- a/mpcs/benches/commit_open_verify.rs +++ /dev/null @@ -1,278 +0,0 @@ -use std::time::Duration; - -use criterion::*; -use ff::Field; -use goldilocks::GoldilocksExt2; - -use itertools::{chain, Itertools}; -use mpcs::{ - util::transcript::{ - FieldTranscript, FieldTranscriptRead, FieldTranscriptWrite, InMemoryTranscript, - PoseidonTranscript, - }, - Basefold, BasefoldDefaultParams, Evaluation, PolynomialCommitmentScheme, -}; - -use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; -use rand::{rngs::OsRng, SeedableRng}; -use rand_chacha::ChaCha8Rng; - -type Pcs = Basefold; -type T = PoseidonTranscript; -type E = GoldilocksExt2; - -const NUM_SAMPLES: usize = 10; - -fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { - let mut group = c.benchmark_group(format!( - "commit_open_verify_goldilocks_{}", - if is_base { "base" } else { "ext2" } - )); - group.sample_size(NUM_SAMPLES); - // Challenge is over extension field, poly over the base field - for num_vars in 10..=20 { - let (pp, vp) = { - let rng = ChaCha8Rng::from_seed([0u8; 32]); - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size, &rng).unwrap(); - - group.bench_function(BenchmarkId::new("setup", format!("{}", num_vars)), |b| { - b.iter(|| { - Pcs::setup(poly_size, &rng).unwrap(); - }) - }); - Pcs::trim(¶m).unwrap() - }; - - let proof = { - let mut transcript = T::new(); - let poly = if is_base { - DenseMultilinearExtension::random(num_vars, &mut OsRng) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - }; - - let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); - - group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| { - b.iter(|| { - Pcs::commit(&pp, &poly).unwrap(); - }) - }); - - let point = transcript.squeeze_challenges(num_vars); - let eval = poly.evaluate(point.as_slice()); - transcript.write_field_element_ext(&eval).unwrap(); - Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); - - group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| { - b.iter_batched( - || transcript.clone(), - |mut transcript| { - Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); - }, - BatchSize::SmallInput, - ); - }); - - transcript.into_proof() - }; - // Verify - let mut transcript = T::from_proof(proof.as_slice()); - Pcs::verify( - &vp, - &Pcs::read_commitment(&vp, &mut transcript).unwrap(), - &transcript.squeeze_challenges(num_vars), - &transcript.read_field_element_ext().unwrap(), - &mut transcript, - ) - .unwrap(); - group.bench_function(BenchmarkId::new("verify", format!("{}", num_vars)), |b| { - b.iter_batched( - || T::from_proof(proof.as_slice()), - |mut transcript| { - Pcs::verify( - &vp, - &Pcs::read_commitment(&vp, &mut transcript).unwrap(), - &transcript.squeeze_challenges(num_vars), - &transcript.read_field_element_ext().unwrap(), - &mut transcript, - ) - .unwrap(); - }, - BatchSize::SmallInput, - ); - }); - } -} - -fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { - let mut group = c.benchmark_group(format!( - "commit_batch_open_verify_goldilocks_{}", - if is_base { "base" } else { "extension" } - )); - group.sample_size(NUM_SAMPLES); - // Challenge is over extension field, poly over the base field - for num_vars in 10..=20 { - for batch_size_log in 1..=6 { - let batch_size = 1 << batch_size_log; - let num_points = batch_size >> 1; - let rng = ChaCha8Rng::from_seed([0u8; 32]); - // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size, &rng).unwrap(); - Pcs::trim(¶m).unwrap() - }; - // Batch commit and open - let evals = chain![ - (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys - (0..num_points).map(|point| (point * 2 + 1, point)), - ] - .unique() - .collect_vec(); - - let proof = { - let mut transcript = T::new(); - let polys = (0..batch_size) - .map(|i| { - if is_base { - DenseMultilinearExtension::random(num_vars - (i >> 1), &mut rng.clone()) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - } - }) - .collect_vec(); - let comms = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_commit", format!("{}", num_vars)), - |b| { - b.iter(|| { - Pcs::batch_commit(&pp, &polys).unwrap(); - }) - }, - ); - - let points = (0..num_points) - .map(|i| transcript.squeeze_challenges(num_vars - i)) - .take(num_points) - .collect_vec(); - - let evals = evals - .iter() - .copied() - .map(|(poly, point)| { - Evaluation::new(poly, point, polys[poly].evaluate(&points[point])) - }) - .collect_vec(); - transcript - .write_field_elements_ext(evals.iter().map(Evaluation::value)) - .unwrap(); - Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_open", format!("{}", num_vars)), - |b| { - b.iter_batched( - || transcript.clone(), - |mut transcript| { - Pcs::batch_open( - &pp, - &polys, - &comms, - &points, - &evals, - &mut transcript, - ) - .unwrap(); - }, - BatchSize::SmallInput, - ); - }, - ); - - transcript.into_proof() - }; - // Batch verify - let mut transcript = T::from_proof(proof.as_slice()); - let comms = &Pcs::read_commitments(&vp, batch_size, &mut transcript).unwrap(); - - let points = (0..num_points) - .map(|i| transcript.squeeze_challenges(num_vars - i)) - .take(num_points) - .collect_vec(); - - let evals2 = transcript.read_field_elements_ext(evals.len()).unwrap(); - - Pcs::batch_verify( - &vp, - comms, - &points, - &evals - .iter() - .copied() - .zip(evals2.clone()) - .map(|((poly, point), eval)| Evaluation::new(poly, point, eval)) - .collect_vec(), - &mut transcript, - ) - .unwrap(); - - group.bench_function( - BenchmarkId::new("batch_verify", format!("{}", num_vars)), - |b| { - b.iter_batched( - || transcript.clone(), - |mut transcript| { - Pcs::batch_verify( - &vp, - comms, - &points, - &evals - .iter() - .copied() - .zip(evals2.clone()) - .map(|((poly, point), eval)| Evaluation::new(poly, point, eval)) - .collect_vec(), - &mut transcript, - ) - .unwrap(); - }, - BatchSize::SmallInput, - ); - }, - ); - } - } -} - -fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_commit_open_verify_goldilocks(c, false); -} - -fn bench_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_commit_open_verify_goldilocks(c, true); -} - -fn bench_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_batch_commit_open_verify_goldilocks(c, false); -} - -fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_batch_commit_open_verify_goldilocks(c, true); -} - -criterion_group! { - name = bench_basefold; - config = Criterion::default().warm_up_time(Duration::from_millis(3000)); - targets = bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2 -} - -criterion_main!(bench_basefold); diff --git a/mpcs/benches/commit_open_verify_basecode.rs b/mpcs/benches/commit_open_verify_basecode.rs new file mode 100644 index 000000000..2087397b4 --- /dev/null +++ b/mpcs/benches/commit_open_verify_basecode.rs @@ -0,0 +1,389 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::{chain, Itertools}; +use mpcs::{ + util::plonky2_util::log2_ceil, Basefold, BasefoldBasecodeParams, Evaluation, + PolynomialCommitmentScheme, +}; + +use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use transcript::Transcript; + +type Pcs = Basefold; +type T = Transcript; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "commit_open_verify_goldilocks_basecode_{}", + if is_base { "base" } else { "ext2" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + let (pp, vp) = { + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + + group.bench_function(BenchmarkId::new("setup", format!("{}", num_vars)), |b| { + b.iter(|| { + Pcs::setup(poly_size, &rng).unwrap(); + }) + }); + Pcs::trim(¶m, poly_size).unwrap() + }; + + let mut transcript = T::new(b"BaseFold"); + let poly = if is_base { + DenseMultilinearExtension::random(num_vars, &mut OsRng) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + }; + + let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); + + group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| { + b.iter(|| { + Pcs::commit(&pp, &poly).unwrap(); + }) + }); + + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + let eval = poly.evaluate(point.as_slice()); + transcript.append_field_element_ext(&eval); + let transcript_for_bench = transcript.clone(); + let proof = Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); + + group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| { + b.iter_batched( + || transcript_for_bench.clone(), + |mut transcript| { + Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); + }, + BatchSize::SmallInput, + ); + }); + // Verify + let comm = Pcs::get_pure_commitment(&comm); + let mut transcript = T::new(b"BaseFold"); + Pcs::write_commitment(&comm, &mut transcript).unwrap(); + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + transcript.append_field_element_ext(&eval); + let transcript_for_bench = transcript.clone(); + Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); + group.bench_function(BenchmarkId::new("verify", format!("{}", num_vars)), |b| { + b.iter_batched( + || transcript_for_bench.clone(), + |mut transcript| { + Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); + }, + BatchSize::SmallInput, + ); + }); + } +} + +fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "batch_commit_open_verify_goldilocks_basecode_{}", + if is_base { "base" } else { "ext2" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let num_points = batch_size >> 1; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + // Setup + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + // Batch commit and open + let evals = chain![ + (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys + (0..num_points).map(|point| (point * 2 + 1, point)), + ] + .unique() + .collect_vec(); + + let mut transcript = T::new(b"BaseFold"); + let polys = (0..batch_size) + .map(|i| { + if is_base { + DenseMultilinearExtension::random( + num_vars - log2_ceil((i >> 1) + 1), + &mut rng.clone(), + ) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars - log2_ceil((i >> 1) + 1), + (0..1 << (num_vars - log2_ceil((i >> 1) + 1))) + .map(|_| E::random(&mut OsRng)) + .collect(), + ) + } + }) + .collect_vec(); + let comms = polys + .iter() + .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) + .collect_vec(); + + let points = (0..num_points) + .map(|i| { + (0..num_vars - i) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>() + }) + .take(num_points) + .collect_vec(); + + let evals = evals + .iter() + .copied() + .map(|(poly, point)| { + Evaluation::new(poly, point, polys[poly].evaluate(&points[point])) + }) + .collect_vec(); + let values: Vec = evals + .iter() + .map(Evaluation::value) + .map(|x| *x) + .collect::>(); + transcript.append_field_element_exts(values.as_slice()); + let transcript_for_bench = transcript.clone(); + let proof = + Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || transcript_for_bench.clone(), + |mut transcript| { + Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + // Batch verify + let mut transcript = T::new(b"BaseFold"); + let comms = comms + .iter() + .map(|comm| { + let comm = Pcs::get_pure_commitment(comm); + Pcs::write_commitment(&comm, &mut transcript).unwrap(); + comm + }) + .collect_vec(); + let points = (0..num_points) + .map(|i| { + (0..num_vars - i) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>() + }) + .take(num_points) + .collect_vec(); + + let values: Vec = evals + .iter() + .map(Evaluation::value) + .map(|x| *x) + .collect::>(); + transcript.append_field_element_exts(values.as_slice()); + + let backup_transcript = transcript.clone(); + + Pcs::batch_verify(&vp, &comms, &points, &evals, &proof, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || backup_transcript.clone(), + |mut transcript| { + Pcs::batch_verify( + &vp, + &comms, + &points, + &evals, + &proof, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + +fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "simple_batch_commit_open_verify_goldilocks_basecode_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let mut transcript = T::new(b"BaseFold"); + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + Pcs::batch_commit(&pp, &polys).unwrap(); + }) + }, + ); + + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + + let evals = (0..batch_size) + .map(|i| polys[i].evaluate(&point)) + .collect_vec(); + + transcript.append_field_element_exts(&evals); + let transcript_for_bench = transcript.clone(); + let proof = Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); + + group.bench_function( + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || transcript_for_bench.clone(), + |mut transcript| { + Pcs::simple_batch_open( + &pp, + &polys, + &comm, + &point, + &evals, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + let comm = Pcs::get_pure_commitment(&comm); + + // Batch verify + let mut transcript = Transcript::new(b"BaseFold"); + Pcs::write_commitment(&comm, &mut transcript).unwrap(); + + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + transcript.append_field_element_exts(&evals); + let backup_transcript = transcript.clone(); + + Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || backup_transcript.clone(), + |mut transcript| { + Pcs::simple_batch_verify( + &vp, + &comm, + &point, + &evals, + &proof, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + +fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_commit_open_verify_goldilocks(c, false); +} + +fn bench_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_commit_open_verify_goldilocks(c, true); +} + +fn bench_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks(c, false); +} + +fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks(c, true); +} + +fn bench_simple_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, false); +} + +fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_simple_batch_commit_open_verify_goldilocks_base, bench_simple_batch_commit_open_verify_goldilocks_2,bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2, bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/commit_open_verify_rs.rs b/mpcs/benches/commit_open_verify_rs.rs new file mode 100644 index 000000000..6cbc0d1e1 --- /dev/null +++ b/mpcs/benches/commit_open_verify_rs.rs @@ -0,0 +1,389 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::{chain, Itertools}; +use mpcs::{ + util::plonky2_util::log2_ceil, Basefold, BasefoldRSParams, Evaluation, + PolynomialCommitmentScheme, +}; + +use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use transcript::Transcript; + +type Pcs = Basefold; +type T = Transcript; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "ext2" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + let (pp, vp) = { + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + + group.bench_function(BenchmarkId::new("setup", format!("{}", num_vars)), |b| { + b.iter(|| { + Pcs::setup(poly_size, &rng).unwrap(); + }) + }); + Pcs::trim(¶m, poly_size).unwrap() + }; + + let mut transcript = T::new(b"BaseFold"); + let poly = if is_base { + DenseMultilinearExtension::random(num_vars, &mut OsRng) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + }; + + let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); + + group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| { + b.iter(|| { + Pcs::commit(&pp, &poly).unwrap(); + }) + }); + + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + let eval = poly.evaluate(point.as_slice()); + transcript.append_field_element_ext(&eval); + let transcript_for_bench = transcript.clone(); + let proof = Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); + + group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| { + b.iter_batched( + || transcript_for_bench.clone(), + |mut transcript| { + Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); + }, + BatchSize::SmallInput, + ); + }); + // Verify + let comm = Pcs::get_pure_commitment(&comm); + let mut transcript = T::new(b"BaseFold"); + Pcs::write_commitment(&comm, &mut transcript).unwrap(); + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + transcript.append_field_element_ext(&eval); + let transcript_for_bench = transcript.clone(); + Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); + group.bench_function(BenchmarkId::new("verify", format!("{}", num_vars)), |b| { + b.iter_batched( + || transcript_for_bench.clone(), + |mut transcript| { + Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); + }, + BatchSize::SmallInput, + ); + }); + } +} + +fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "batch_commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "ext2" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let num_points = batch_size >> 1; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + // Setup + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + // Batch commit and open + let evals = chain![ + (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys + (0..num_points).map(|point| (point * 2 + 1, point)), + ] + .unique() + .collect_vec(); + + let mut transcript = T::new(b"BaseFold"); + let polys = (0..batch_size) + .map(|i| { + if is_base { + DenseMultilinearExtension::random( + num_vars - log2_ceil((i >> 1) + 1), + &mut rng.clone(), + ) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars - log2_ceil((i >> 1) + 1), + (0..1 << (num_vars - log2_ceil((i >> 1) + 1))) + .map(|_| E::random(&mut OsRng)) + .collect(), + ) + } + }) + .collect_vec(); + let comms = polys + .iter() + .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) + .collect_vec(); + + let points = (0..num_points) + .map(|i| { + (0..num_vars - i) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>() + }) + .take(num_points) + .collect_vec(); + + let evals = evals + .iter() + .copied() + .map(|(poly, point)| { + Evaluation::new(poly, point, polys[poly].evaluate(&points[point])) + }) + .collect_vec(); + let values: Vec = evals + .iter() + .map(Evaluation::value) + .map(|x| *x) + .collect::>(); + transcript.append_field_element_exts(values.as_slice()); + let transcript_for_bench = transcript.clone(); + let proof = + Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || transcript_for_bench.clone(), + |mut transcript| { + Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + // Batch verify + let mut transcript = T::new(b"BaseFold"); + let comms = comms + .iter() + .map(|comm| { + let comm = Pcs::get_pure_commitment(comm); + Pcs::write_commitment(&comm, &mut transcript).unwrap(); + comm + }) + .collect_vec(); + let points = (0..num_points) + .map(|i| { + (0..num_vars - i) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>() + }) + .take(num_points) + .collect_vec(); + + let values: Vec = evals + .iter() + .map(Evaluation::value) + .map(|x| *x) + .collect::>(); + transcript.append_field_element_exts(values.as_slice()); + + let backup_transcript = transcript.clone(); + + Pcs::batch_verify(&vp, &comms, &points, &evals, &proof, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || backup_transcript.clone(), + |mut transcript| { + Pcs::batch_verify( + &vp, + &comms, + &points, + &evals, + &proof, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + +fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "simple_batch_commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let mut transcript = T::new(b"BaseFold"); + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + Pcs::batch_commit(&pp, &polys).unwrap(); + }) + }, + ); + + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + + let evals = (0..batch_size) + .map(|i| polys[i].evaluate(&point)) + .collect_vec(); + + transcript.append_field_element_exts(&evals); + let transcript_for_bench = transcript.clone(); + let proof = Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); + + group.bench_function( + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || transcript_for_bench.clone(), + |mut transcript| { + Pcs::simple_batch_open( + &pp, + &polys, + &comm, + &point, + &evals, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + let comm = Pcs::get_pure_commitment(&comm); + + // Batch verify + let mut transcript = Transcript::new(b"BaseFold"); + Pcs::write_commitment(&comm, &mut transcript).unwrap(); + + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + transcript.append_field_element_exts(&evals); + let backup_transcript = transcript.clone(); + + Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || backup_transcript.clone(), + |mut transcript| { + Pcs::simple_batch_verify( + &vp, + &comm, + &point, + &evals, + &proof, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + +fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_commit_open_verify_goldilocks(c, false); +} + +fn bench_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_commit_open_verify_goldilocks(c, true); +} + +fn bench_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks(c, false); +} + +fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks(c, true); +} + +fn bench_simple_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, false); +} + +fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_simple_batch_commit_open_verify_goldilocks_base, bench_simple_batch_commit_open_verify_goldilocks_2,bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2, bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/fft.rs b/mpcs/benches/fft.rs new file mode 100644 index 000000000..2588a11e2 --- /dev/null +++ b/mpcs/benches/fft.rs @@ -0,0 +1,80 @@ +use std::time::Duration; + +use criterion::*; +use ff::{Field, PrimeField}; +use goldilocks::{Goldilocks, GoldilocksExt2}; + +use itertools::Itertools; +use mpcs::{coset_fft, fft_root_table}; + +use multilinear_extensions::mle::DenseMultilinearExtension; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 6; + +fn bench_fft(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "fft_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + let root_table = fft_root_table(num_vars); + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let mut polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys.par_iter_mut().for_each(|poly| { + coset_fft::( + &mut poly.evaluations, + Goldilocks::MULTIPLICATIVE_GENERATOR, + 0, + &root_table, + ); + }); + }) + }, + ); + } + } +} + +fn bench_fft_goldilocks_2(c: &mut Criterion) { + bench_fft(c, false); +} + +fn bench_fft_base(c: &mut Criterion) { + bench_fft(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_fft_base, bench_fft_goldilocks_2 +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/interpolate.rs b/mpcs/benches/interpolate.rs new file mode 100644 index 000000000..00af7c092 --- /dev/null +++ b/mpcs/benches/interpolate.rs @@ -0,0 +1,81 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::Itertools; +use mpcs::util::arithmetic::interpolate_field_type_over_boolean_hypercube; + +use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_encoding(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "interpolate_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys + .par_iter() + .map(|poly| { + // Switch to coefficient form + let mut coeffs = poly.evaluations.clone(); + interpolate_field_type_over_boolean_hypercube(&mut coeffs); + coeffs + }) + .collect::>>(); + }) + }, + ); + } + } +} + +fn bench_encoding_goldilocks_2(c: &mut Criterion) { + bench_encoding(c, false); +} + +fn bench_encoding_base(c: &mut Criterion) { + bench_encoding(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_encoding_base, bench_encoding_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/rscode.rs b/mpcs/benches/rscode.rs new file mode 100644 index 000000000..ad42c7f39 --- /dev/null +++ b/mpcs/benches/rscode.rs @@ -0,0 +1,112 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::Itertools; +use mpcs::{ + util::{ + arithmetic::interpolate_field_type_over_boolean_hypercube, + plonky2_util::reverse_index_bits_in_place_field_type, + }, + Basefold, BasefoldRSParams, BasefoldSpec, EncodingScheme, + PolynomialCommitmentScheme, +}; + +use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +type Pcs = Basefold; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_encoding(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "encoding_rscode_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, _) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys + .par_iter() + .map(|poly| { + // Switch to coefficient form + let mut coeffs = poly.evaluations.clone(); + interpolate_field_type_over_boolean_hypercube(&mut coeffs); + + let mut codeword = + <>::EncodingScheme as EncodingScheme>::encode( + &pp.encoding_params, + &coeffs, + ); + + // If using repetition code as basecode, it may be faster to use the following line of code to create the commitment and comment out the two lines above + // let mut codeword = evaluate_over_foldable_domain(pp.log_rate, coeffs, &pp.table); + + // The sum-check protocol starts from the first variable, but the FRI part + // will eventually produce the evaluation at (alpha_k, ..., alpha_1), so apply + // the bit-reversion to reverse the variable indices of the polynomial. + // In short: store the poly and codeword in big endian + reverse_index_bits_in_place_field_type(&mut coeffs); + reverse_index_bits_in_place_field_type(&mut codeword); + + (coeffs, codeword) + }) + .collect::<(Vec>, Vec>)>(); + }) + }, + ); + } + } +} + +fn bench_encoding_goldilocks_2(c: &mut Criterion) { + bench_encoding(c, false); +} + +fn bench_encoding_base(c: &mut Criterion) { + bench_encoding(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_encoding_base, bench_encoding_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 29cc51d1d..9824bded9 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -8,28 +8,34 @@ use crate::{ arithmetic::{ inner_product, inner_product_three, interpolate_field_type_over_boolean_hypercube, }, - base_to_usize, expression::{Expression, Query, Rotation}, ext_to_usize, - hash::{new_hasher, Digest}, + hash::{new_hasher, write_digest_to_transcript, Digest}, log2_strict, merkle_tree::MerkleTree, multiply_poly, plonky2_util::reverse_index_bits_in_place_field_type, poly_index_ext, poly_iter_ext, - transcript::{TranscriptRead, TranscriptWrite}, - u32_to_field, }, validate_input, Error, Evaluation, NoninteractivePCS, PolynomialCommitmentScheme, }; use ark_std::{end_timer, start_timer}; +pub use encoding::{ + Basecode, BasecodeDefaultSpec, EncodingProverParameters, EncodingScheme, RSCode, + RSCodeDefaultSpec, +}; use ff_ext::ExtensionField; use multilinear_extensions::mle::MultilinearExtension; use query_phase::{ - batch_query_phase, batch_verifier_query_phase, query_phase, verifier_query_phase, + batch_prover_query_phase, batch_verifier_query_phase, prover_query_phase, + simple_batch_prover_query_phase, simple_batch_verifier_query_phase, verifier_query_phase, BatchedQueriesResultWithMerklePath, QueriesResultWithMerklePath, + SimpleBatchQueriesResultWithMerklePath, }; use std::{borrow::BorrowMut, ops::Deref}; +pub use structure::BasefoldSpec; +use structure::{BasefoldProof, ProofQueriesResultWithMerklePath}; +use transcript::Transcript; use itertools::Itertools; use serde::{de::DeserializeOwned, Serialize}; @@ -39,102 +45,67 @@ use multilinear_extensions::{ virtual_poly::build_eq_x_r_vec, }; -use rand_chacha::ChaCha8Rng; -use rayon::prelude::{IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator}; +use rand_chacha::{rand_core::RngCore, ChaCha8Rng}; +use rayon::{ + iter::IntoParallelIterator, + prelude::{IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator}, +}; use std::borrow::Cow; + type SumCheck = ClassicSumCheck>; -mod basecode; mod structure; -use basecode::{ - encode_field_type_rs_basecode, evaluate_over_foldable_domain_generic_basecode, get_table_aes, -}; pub use structure::{ - Basefold, BasefoldCommitment, BasefoldCommitmentWithData, BasefoldDefault, - BasefoldDefaultParams, BasefoldExtParams, BasefoldParams, BasefoldProverParams, + Basefold, BasefoldBasecodeParams, BasefoldCommitment, BasefoldCommitmentWithData, + BasefoldDefault, BasefoldParams, BasefoldProverParams, BasefoldRSParams, BasefoldVerifierParams, }; mod commit_phase; -use commit_phase::{batch_commit_phase, commit_phase}; +use commit_phase::{batch_commit_phase, commit_phase, simple_batch_commit_phase}; +mod encoding; +pub use encoding::{coset_fft, fft, fft_root_table}; mod query_phase; // This sumcheck module is different from the mpcs::sumcheck module, in that // it deals only with the special case of the form \sum eq(r_i)f_i(). mod sumcheck; -impl PolynomialCommitmentScheme for Basefold +impl, Rng: RngCore> Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, { - type Param = BasefoldParams; - type ProverParam = BasefoldProverParams; - type VerifierParam = BasefoldVerifierParams; - type CommitmentWithData = BasefoldCommitmentWithData; - type Commitment = BasefoldCommitment; - type CommitmentChunk = Digest; - type Rng = ChaCha8Rng; - - fn setup(poly_size: usize, rng: &Self::Rng) -> Result { - let log_rate = V::get_rate(); - let (table_w_weights, table) = get_table_aes::(poly_size, log_rate, &mut rng.clone()); - - Ok(BasefoldParams { - log_rate, - num_verifier_queries: V::get_reps(), - max_num_vars: log2_strict(poly_size), - table_w_weights, - table, - rng: rng.clone(), - }) - } - - fn trim(param: &Self::Param) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { - Ok(( - BasefoldProverParams { - log_rate: param.log_rate, - table_w_weights: param.table_w_weights.clone(), - table: param.table.clone(), - num_verifier_queries: param.num_verifier_queries, - max_num_vars: param.max_num_vars, - }, - BasefoldVerifierParams { - rng: param.rng.clone(), - max_num_vars: param.max_num_vars, - log_rate: param.log_rate, - num_verifier_queries: param.num_verifier_queries, - }, - )) - } - - fn commit( - pp: &Self::ProverParam, + /// Converts a polynomial to a code word, also returns the evaluations over the boolean hypercube + /// for said polynomial + fn get_poly_bh_evals_and_codeword( + pp: &BasefoldProverParams, poly: &DenseMultilinearExtension, - ) -> Result { - let timer = start_timer!(|| "Basefold::commit"); + ) -> (FieldType, FieldType) { // bh_evals is just a copy of poly.evals(). // Note that this function implicitly assumes that the size of poly.evals() is a // power of two. Otherwise, the function crashes with index out of bound. let mut bh_evals = poly.evaluations.clone(); - let num_vars = log2_strict(bh_evals.len()); - assert!(num_vars <= pp.max_num_vars && num_vars >= V::get_basecode()); + let num_vars = poly.num_vars; + assert!( + num_vars <= pp.encoding_params.get_max_message_size_log(), + "num_vars {} > pp.max_num_vars {}", + num_vars, + pp.encoding_params.get_max_message_size_log() + ); + assert!( + num_vars >= Spec::get_basecode_msg_size_log(), + "num_vars {} < Spec::get_basecode_msg_size_log() {}", + num_vars, + Spec::get_basecode_msg_size_log() + ); // Switch to coefficient form let mut coeffs = bh_evals.clone(); interpolate_field_type_over_boolean_hypercube(&mut coeffs); - // Split the input into chunks of message size, encode each message, and return the codewords - let basecode = - encode_field_type_rs_basecode(&coeffs, 1 << pp.log_rate, 1 << V::get_basecode()); - - // Apply the recursive definition of the BaseFold code to the list of base codewords, - // and produce the final codeword - let mut codeword = evaluate_over_foldable_domain_generic_basecode::( - 1 << V::get_basecode(), - coeffs.len(), - pp.log_rate, - basecode, - &pp.table, - ); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place_field_type(&mut coeffs); + } + let mut codeword = Spec::EncodingScheme::encode(&pp.encoding_params, &coeffs); // If using repetition code as basecode, it may be faster to use the following line of code to create the commitment and comment out the two lines above // let mut codeword = evaluate_over_foldable_domain(pp.log_rate, coeffs, &pp.table); @@ -146,6 +117,108 @@ where reverse_index_bits_in_place_field_type(&mut bh_evals); reverse_index_bits_in_place_field_type(&mut codeword); + (bh_evals, codeword) + } + + /// Transpose a matrix of field elements, generic over the type of field element + pub fn transpose_field_type( + matrix: &[FieldType], + ) -> Result>, Error> { + let transpose_fn = match matrix[0] { + FieldType::Ext(_) => Self::get_column_ext, + FieldType::Base(_) => Self::get_column_base, + FieldType::Unreachable => unreachable!(), + }; + + let len = matrix[0].len(); + (0..len) + .into_par_iter() + .map(|i| (transpose_fn)(matrix, i)) + .collect() + } + + fn get_column_base( + matrix: &[FieldType], + column_index: usize, + ) -> Result, Error> { + Ok(FieldType::Base( + matrix + .par_iter() + .map(|row| match row { + FieldType::Base(content) => Ok(content[column_index]), + _ => Err(Error::InvalidPcsParam( + "expected base field type".to_string(), + )), + }) + .collect::, Error>>()?, + )) + } + + fn get_column_ext(matrix: &[FieldType], column_index: usize) -> Result, Error> { + Ok(FieldType::Ext( + matrix + .par_iter() + .map(|row| match row { + FieldType::Ext(content) => Ok(content[column_index]), + _ => Err(Error::InvalidPcsParam( + "expected ext field type".to_string(), + )), + }) + .collect::, Error>>()?, + )) + } +} + +impl, Rng: RngCore + std::fmt::Debug> + PolynomialCommitmentScheme for Basefold +where + E: Serialize + DeserializeOwned, + E::BaseField: Serialize + DeserializeOwned, +{ + type Param = BasefoldParams; + type ProverParam = BasefoldProverParams; + type VerifierParam = BasefoldVerifierParams; + type CommitmentWithData = BasefoldCommitmentWithData; + type Commitment = BasefoldCommitment; + type CommitmentChunk = Digest; + type Proof = BasefoldProof; + type Rng = ChaCha8Rng; + + fn setup(poly_size: usize, rng: &Self::Rng) -> Result { + let mut seed = [0u8; 32]; + let mut rng = rng.clone(); + rng.fill_bytes(&mut seed); + let pp = >::setup(log2_strict(poly_size), seed); + + Ok(BasefoldParams { params: pp }) + } + + fn trim( + pp: &Self::Param, + poly_size: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { + >::trim(&pp.params, log2_strict(poly_size)).map( + |(pp, vp)| { + ( + BasefoldProverParams { + encoding_params: pp, + }, + BasefoldVerifierParams { + encoding_params: vp, + }, + ) + }, + ) + } + + fn commit( + pp: &Self::ProverParam, + poly: &DenseMultilinearExtension, + ) -> Result { + let timer = start_timer!(|| "Basefold::commit"); + + let (bh_evals, codeword) = Self::get_poly_bh_evals_and_codeword(pp, poly); + // Compute and store all the layers of the Merkle tree let hasher = new_hasher::(); let codeword_tree = MerkleTree::::from_leaves(codeword, &hasher); @@ -160,70 +233,106 @@ where Ok(Self::CommitmentWithData { codeword_tree, - bh_evals, - num_vars, + polynomials_bh_evals: vec![bh_evals], + num_vars: poly.num_vars, is_base, + num_polys: 1, }) } - fn batch_commit_and_write( - pp: &Self::ProverParam, - polys: &Vec>, - transcript: &mut impl TranscriptWrite, - ) -> Result, Error> { - let timer = start_timer!(|| "Basefold::batch_commit_and_write"); - let comms = Self::batch_commit(pp, polys)?; - comms.iter().for_each(|comm| { - transcript.write_commitment(&comm.get_root_as()).unwrap(); - transcript - .write_field_element_base(&u32_to_field::(comm.num_vars as u32)) - .unwrap(); - transcript - .write_field_element_base(&u32_to_field::(comm.is_base as u32)) - .unwrap(); - }); - end_timer!(timer); - Ok(comms) - } - fn batch_commit( pp: &Self::ProverParam, - polys: &Vec>, - ) -> Result, Error> { - let polys_vec: Vec<&DenseMultilinearExtension> = - polys.into_iter().map(|poly| poly).collect(); - polys_vec + polys: &[DenseMultilinearExtension], + ) -> Result { + // assumptions + // 1. there must be at least one polynomial + // 2. all polynomials must exist in the same field type + // 3. all polynomials must have the same number of variables + + if polys.is_empty() { + return Err(Error::InvalidPcsParam( + "cannot batch commit to zero polynomials".to_string(), + )); + } + + for i in 1..polys.len() { + if polys[i].num_vars != polys[0].num_vars { + return Err(Error::InvalidPcsParam( + "cannot batch commit to polynomials with different number of variables" + .to_string(), + )); + } + } + + // convert each polynomial to a code word + let (bh_evals, codewords) = polys .par_iter() - .map(|poly| Self::commit(pp, poly)) - .collect() + .map(|poly| Self::get_poly_bh_evals_and_codeword(pp, poly)) + .collect::<(Vec>, Vec>)>(); + + // transpose the codewords, to group evaluations at the same point + // let leaves = Self::transpose_field_type::(codewords.as_slice())?; + + // build merkle tree from leaves + let hasher = new_hasher::(); + let codeword_tree = MerkleTree::::from_batch_leaves(codewords, &hasher); + + let is_base = match polys[0].evaluations { + FieldType::Ext(_) => false, + FieldType::Base(_) => true, + _ => unreachable!(), + }; + + Ok(Self::CommitmentWithData { + codeword_tree, + polynomials_bh_evals: bh_evals, + num_vars: polys[0].num_vars, + is_base, + num_polys: polys.len(), + }) + } + + fn write_commitment( + comm: &Self::Commitment, + transcript: &mut Transcript, + ) -> Result<(), Error> { + write_digest_to_transcript(&comm.root(), transcript); + Ok(()) + } + + fn get_pure_commitment(comm: &Self::CommitmentWithData) -> Self::Commitment { + comm.to_commitment() } + /// Open a single polynomial commitment at one point. If the given + /// commitment with data contains more than one polynomial, this function + /// will panic. fn open( pp: &Self::ProverParam, poly: &DenseMultilinearExtension, comm: &Self::CommitmentWithData, point: &[E], _eval: &E, // Opening does not need eval, except for sanity check - transcript: &mut impl TranscriptWrite, - ) -> Result<(), Error> { + transcript: &mut Transcript, + ) -> Result { let hasher = new_hasher::(); let timer = start_timer!(|| "Basefold::open"); - assert!(comm.num_vars >= V::get_basecode()); - let (trees, oracles) = commit_phase( - &point, - &comm, + assert!(comm.num_vars >= Spec::get_basecode_msg_size_log()); + assert!(comm.num_polys == 1); + let (trees, oracles, commit_phase_proof) = commit_phase::( + &pp.encoding_params, + point, + comm, transcript, poly.num_vars, - poly.num_vars - V::get_basecode(), - &pp.table_w_weights, - pp.log_rate, + poly.num_vars - Spec::get_basecode_msg_size_log(), &hasher, ); let query_timer = start_timer!(|| "Basefold::open::query_phase"); // Each entry in queried_els stores a list of triples (F, F, i) indicating the // position opened at each round and the two values at that round - let queries = query_phase(transcript, &comm, &oracles, pp.num_verifier_queries); + let queries = prover_query_phase(transcript, comm, &oracles, Spec::get_number_queries()); end_timer!(query_timer); let query_timer = start_timer!(|| "Basefold::open::build_query_result"); @@ -238,23 +347,39 @@ where end_timer!(timer); - Ok(()) + Ok(Self::Proof { + sumcheck_messages: commit_phase_proof.sumcheck_messages, + roots: commit_phase_proof.roots, + final_message: commit_phase_proof.final_message, + query_result_with_merkle_path: ProofQueriesResultWithMerklePath::Single( + queries_with_merkle_path, + ), + sumcheck_proof: None, + }) } + /// Open a batch of polynomial commitments at several points. + /// The current version only supports one polynomial per commitment. + /// Because otherwise it is complex to match the polynomials and + /// the commitments, and because currently this high flexibility is + /// not very useful in ceno. fn batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], - transcript: &mut impl TranscriptWrite, - ) -> Result<(), Error> { + transcript: &mut Transcript, + ) -> Result { let hasher = new_hasher::(); let timer = start_timer!(|| "Basefold::batch_open"); let num_vars = polys.iter().map(|poly| poly.num_vars).max().unwrap(); - let comms = comms.into_iter().collect_vec(); let min_num_vars = polys.iter().map(|p| p.num_vars).min().unwrap(); - assert!(min_num_vars >= V::get_basecode()); + assert!(min_num_vars >= Spec::get_basecode_msg_size_log()); + + comms.iter().for_each(|comm| { + assert!(comm.num_polys == 1); + }); if cfg!(feature = "sanity-check") { evals.iter().for_each(|eval| { @@ -265,17 +390,18 @@ where }) } - validate_input( - "batch open", - pp.max_num_vars, - &polys.clone(), - &points.to_vec(), - )?; + validate_input("batch open", pp.get_max_message_size_log(), polys, points)?; let sumcheck_timer = start_timer!(|| "Basefold::batch_open::initial sumcheck"); // evals.len() is the batch size, i.e., how many polynomials are being opened together let batch_size_log = evals.len().next_power_of_two().ilog2() as usize; - let t = transcript.squeeze_challenges(batch_size_log); + let t = (0..batch_size_log) + .map(|_| { + transcript + .get_and_append_challenge(b"batch coeffs") + .elements + }) + .collect::>(); // Use eq(X,t) where t is random to batch the different evaluation queries. // Note that this is a small polynomial (only batch_size) compared to the polynomials @@ -337,7 +463,7 @@ where .map(|((scalar, poly), point)| { inner_product( &poly_iter_ext(poly).collect_vec(), - build_eq_x_r_vec(&point).iter(), + build_eq_x_r_vec(point).iter(), ) * scalar * E::from(1 << (num_vars - poly.num_vars)) // When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube @@ -366,7 +492,7 @@ where let virtual_poly = VirtualPolynomial::new(&expression, sumcheck_polys, &[], points.as_slice()); - let (challenges, merged_poly_evals) = + let (challenges, merged_poly_evals, sumcheck_proof) = SumCheck::prove(&(), num_vars, virtual_poly, target_sum, transcript)?; end_timer!(sumcheck_timer); @@ -402,7 +528,7 @@ where ); *scalar * evals_from_sum_check - * &eq_xy_eval(point.as_slice(), &challenges[0..point.len()]) + * eq_xy_eval(point.as_slice(), &challenges[0..point.len()]) }) .sum::(); assert_eq!(new_target_sum, desired_sum); @@ -412,25 +538,24 @@ where let point = challenges; - let (trees, oracles) = batch_commit_phase( + let (trees, oracles, commit_phase_proof) = batch_commit_phase::( + &pp.encoding_params, &point, - comms.as_slice(), + comms, transcript, num_vars, - num_vars - V::get_basecode(), - &pp.table_w_weights, - pp.log_rate, + num_vars - Spec::get_basecode_msg_size_log(), coeffs.as_slice(), &hasher, ); let query_timer = start_timer!(|| "Basefold::batch_open query phase"); - let query_result = batch_query_phase( + let query_result = batch_prover_query_phase( transcript, - 1 << (num_vars + pp.log_rate), - comms.as_slice(), + 1 << (num_vars + Spec::get_rate_log()), + comms, &oracles, - pp.num_verifier_queries, + Spec::get_number_queries(), ); end_timer!(query_timer); @@ -439,7 +564,7 @@ where BatchedQueriesResultWithMerklePath::from_batched_query_result( query_result, &trees, - &comms, + comms, ); end_timer!(query_timer); @@ -448,44 +573,107 @@ where end_timer!(query_timer); end_timer!(timer); - Ok(()) + Ok(Self::Proof { + sumcheck_messages: commit_phase_proof.sumcheck_messages, + roots: commit_phase_proof.roots, + final_message: commit_phase_proof.final_message, + query_result_with_merkle_path: ProofQueriesResultWithMerklePath::Batched( + query_result_with_merkle_path, + ), + sumcheck_proof: Some(sumcheck_proof), + }) } - fn read_commitments( - _: &Self::VerifierParam, - num_polys: usize, - transcript: &mut impl TranscriptRead, - ) -> Result, Error> { - let roots = (0..num_polys) - .map(|_| { - let commitment = transcript.read_commitment().unwrap(); - let num_vars = base_to_usize::(&transcript.read_field_element_base().unwrap()); - let is_base = - base_to_usize::(&transcript.read_field_element_base().unwrap()) != 0; - (num_vars, commitment, is_base) - }) - .collect_vec(); + /// This is a simple version of batch open: + /// 1. Open at one point + /// 2. All the polynomials share the same commitment and have the same + /// number of variables. + /// 3. The point is already a random point generated by a sum-check. + fn simple_batch_open( + pp: &Self::ProverParam, + polys: &[DenseMultilinearExtension], + comm: &Self::CommitmentWithData, + point: &[E], + evals: &[E], + transcript: &mut Transcript, + ) -> Result { + let hasher = new_hasher::(); + let timer = start_timer!(|| "Basefold::batch_open"); + let num_vars = polys[0].num_vars; - Ok(roots + polys .iter() - .map(|(num_vars, commitment, is_base)| { - BasefoldCommitment::new(commitment.clone(), *num_vars, *is_base) + .for_each(|poly| assert_eq!(poly.num_vars, num_vars)); + assert!(num_vars >= Spec::get_basecode_msg_size_log()); + assert_eq!(comm.num_polys, polys.len()); + assert_eq!(comm.num_polys, evals.len()); + + if cfg!(feature = "sanity-check") { + evals + .iter() + .zip(polys) + .for_each(|(eval, poly)| assert_eq!(&poly.evaluate(point), eval)) + } + // evals.len() is the batch size, i.e., how many polynomials are being opened together + let batch_size_log = evals.len().next_power_of_two().ilog2() as usize; + let t = (0..batch_size_log) + .map(|_| { + transcript + .get_and_append_challenge(b"batch coeffs") + .elements }) - .collect_vec()) - } + .collect::>(); - fn commit_and_write( - pp: &Self::ProverParam, - poly: &DenseMultilinearExtension, - transcript: &mut impl TranscriptWrite, - ) -> Result { - let comm = Self::commit(pp, poly)?; + // Use eq(X,t) where t is random to batch the different evaluation queries. + // Note that this is a small polynomial (only batch_size) compared to the polynomials + // to open. + let eq_xt = build_eq_x_r_vec(&t)[..evals.len()].to_vec(); + let _target_sum = inner_product(evals, &eq_xt); + + // Now the verifier has obtained the new target sum, and is able to compute the random + // linear coefficients. + // The remaining tasks for the prover is to prove that + // sum_i coeffs[i] poly_evals[i] is equal to + // the new target sum, where coeffs is computed as follows + let (trees, oracles, commit_phase_proof) = simple_batch_commit_phase::( + &pp.encoding_params, + point, + &eq_xt, + comm, + transcript, + num_vars, + num_vars - Spec::get_basecode_msg_size_log(), + &hasher, + ); + + let query_timer = start_timer!(|| "Basefold::open::query_phase"); + // Each entry in queried_els stores a list of triples (F, F, i) indicating the + // position opened at each round and the two values at that round + let queries = + simple_batch_prover_query_phase(transcript, comm, &oracles, Spec::get_number_queries()); + end_timer!(query_timer); - transcript.write_commitment(&comm.get_root_as())?; - transcript.write_field_element_base(&u32_to_field::(comm.num_vars as u32))?; - transcript.write_field_element_base(&u32_to_field::(comm.is_base as u32))?; + let query_timer = start_timer!(|| "Basefold::open::build_query_result"); - Ok(comm) + let queries_with_merkle_path = + SimpleBatchQueriesResultWithMerklePath::from_query_result(queries, &trees, comm); + end_timer!(query_timer); + + let query_timer = start_timer!(|| "Basefold::open::write_queries"); + queries_with_merkle_path.write_transcript(transcript); + end_timer!(query_timer); + + end_timer!(timer); + + Ok(Self::Proof { + sumcheck_messages: commit_phase_proof.sumcheck_messages, + roots: commit_phase_proof.roots, + final_message: commit_phase_proof.final_message, + query_result_with_merkle_path: ProofQueriesResultWithMerklePath::SimpleBatched( + queries_with_merkle_path, + ), + sumcheck_proof: None, + }) } fn verify( @@ -493,59 +681,47 @@ where comm: &Self::Commitment, point: &[E], eval: &E, - transcript: &mut impl TranscriptRead, + proof: &Self::Proof, + transcript: &mut Transcript, ) -> Result<(), Error> { let timer = start_timer!(|| "Basefold::verify"); - assert!(comm.num_vars().unwrap() >= V::get_basecode()); + assert!(comm.num_vars().unwrap() >= Spec::get_basecode_msg_size_log()); let hasher = new_hasher::(); - let _field_size = 255; let num_vars = point.len(); - let num_rounds = num_vars - V::get_basecode(); + if let Some(comm_num_vars) = comm.num_vars() { + assert_eq!(num_vars, comm_num_vars); + } + let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); - let mut fold_challenges: Vec = Vec::with_capacity(vp.max_num_vars); - let _size = 0; - let mut roots = Vec::new(); - let mut sumcheck_messages = Vec::with_capacity(num_rounds); - let sumcheck_timer = start_timer!(|| "Basefold::verify::interaction"); + let mut fold_challenges: Vec = Vec::with_capacity(num_vars); + let roots = &proof.roots; + let sumcheck_messages = &proof.sumcheck_messages; for i in 0..num_rounds { - sumcheck_messages.push(transcript.read_field_elements_ext(3).unwrap()); - fold_challenges.push(transcript.squeeze_challenge()); + transcript.append_field_element_exts(sumcheck_messages[i].as_slice()); + fold_challenges.push( + transcript + .get_and_append_challenge(b"commit round") + .elements, + ); if i < num_rounds - 1 { - roots.push(transcript.read_commitment().unwrap()); + write_digest_to_transcript(&roots[i], transcript); } } - end_timer!(sumcheck_timer); - let read_timer = start_timer!(|| "Basefold::verify::read transcript"); - let final_message = transcript - .read_field_elements_ext(1 << V::get_basecode()) - .unwrap(); - let query_challenges = transcript - .squeeze_challenges(vp.num_verifier_queries) - .iter() - .map(|index| ext_to_usize(index) % (1 << (num_vars + vp.log_rate))) - .collect_vec(); - let read_query_timer = start_timer!(|| "Basefold::verify::read query"); - let query_result_with_merkle_path = if comm.is_base() { - QueriesResultWithMerklePath::read_transcript_base( - transcript, - num_rounds, - vp.log_rate, - num_vars, - query_challenges.as_slice(), - ) - } else { - QueriesResultWithMerklePath::read_transcript_ext( - transcript, - num_rounds, - vp.log_rate, - num_vars, - query_challenges.as_slice(), - ) - }; - end_timer!(read_query_timer); - end_timer!(read_timer); + let final_message = &proof.final_message; + transcript.append_field_element_exts(final_message.as_slice()); + + let queries: Vec<_> = (0..Spec::get_number_queries()) + .map(|_| { + ext_to_usize( + &transcript + .get_and_append_challenge(b"query indices") + .elements, + ) % (1 << (num_vars + Spec::get_rate_log())) + }) + .collect(); + let query_result_with_merkle_path = proof.query_result_with_merkle_path.as_single(); // coeff is the eq polynomial evaluated at the last challenge.len() variables // in reverse order. @@ -558,19 +734,19 @@ where let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]); eq.par_iter_mut().for_each(|e| *e *= coeff); - verifier_query_phase( + verifier_query_phase::( + queries.as_slice(), + &vp.encoding_params, &query_result_with_merkle_path, &sumcheck_messages, &fold_challenges, num_rounds, num_vars, - vp.log_rate, &final_message, &roots, comm, eq.as_slice(), - vp.rng.clone(), - &eval, + eval, &hasher, ); end_timer!(timer); @@ -580,19 +756,20 @@ where fn batch_verify( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], - transcript: &mut impl TranscriptRead, + proof: &Self::Proof, + transcript: &mut Transcript, ) -> Result<(), Error> { let timer = start_timer!(|| "Basefold::batch_verify"); // let key = "RAYON_NUM_THREADS"; // env::set_var(key, "32"); let hasher = new_hasher::(); - let comms = comms.into_iter().collect_vec(); + let comms = comms.iter().collect_vec(); let num_vars = points.iter().map(|point| point.len()).max().unwrap(); - let num_rounds = num_vars - V::get_basecode(); - validate_input("batch verify", vp.max_num_vars, &vec![], &points.to_vec())?; + let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); + validate_input("batch verify", num_vars, &[], points)?; let poly_num_vars = comms.iter().map(|c| c.num_vars().unwrap()).collect_vec(); evals.iter().for_each(|eval| { assert_eq!( @@ -600,11 +777,17 @@ where comms[eval.poly()].num_vars().unwrap() ); }); - assert!(poly_num_vars.iter().min().unwrap() >= &V::get_basecode()); + assert!(poly_num_vars.iter().min().unwrap() >= &Spec::get_basecode_msg_size_log()); let sumcheck_timer = start_timer!(|| "Basefold::batch_verify::initial sumcheck"); let batch_size_log = evals.len().next_power_of_two().ilog2() as usize; - let t = transcript.squeeze_challenges(batch_size_log); + let t = (0..batch_size_log) + .map(|_| { + transcript + .get_and_append_challenge(b"batch coeffs") + .elements + }) + .collect::>(); let eq_xt = DenseMultilinearExtension::from_evaluations_ext_vec(t.len(), build_eq_x_r_vec(&t)); @@ -617,8 +800,14 @@ where &poly_iter_ext(&eq_xt).take(evals.len()).collect_vec(), ); - let (new_target_sum, verify_point) = - SumCheck::verify(&(), num_vars, 2, target_sum, transcript)?; + let (new_target_sum, verify_point) = SumCheck::verify( + &(), + num_vars, + 2, + target_sum, + proof.sumcheck_proof.as_ref().unwrap(), + transcript, + )?; end_timer!(sumcheck_timer); // Now the goal is to use the BaseFold to check the new target sum. Note that this time @@ -632,52 +821,33 @@ where coeffs[eval.poly()] += eq_xy_evals[eval.point()] * poly_index_ext(&eq_xt, i) }); - // start of verify - // read first $(num_var - 1) commitments - let read_timer = start_timer!(|| "Basefold::verify::read transcript"); - let mut sumcheck_messages: Vec> = Vec::with_capacity(num_rounds); - let mut roots: Vec> = Vec::with_capacity(num_rounds - 1); - let mut fold_challenges: Vec = Vec::with_capacity(num_rounds); + let mut fold_challenges: Vec = Vec::with_capacity(num_vars); + let roots = &proof.roots; + let sumcheck_messages = &proof.sumcheck_messages; for i in 0..num_rounds { - sumcheck_messages.push(transcript.read_field_elements_ext(3).unwrap()); - fold_challenges.push(transcript.squeeze_challenge()); + transcript.append_field_element_exts(sumcheck_messages[i].as_slice()); + fold_challenges.push( + transcript + .get_and_append_challenge(b"commit round") + .elements, + ); if i < num_rounds - 1 { - roots.push(transcript.read_commitment().unwrap()); + write_digest_to_transcript(&roots[i], transcript); } } - let final_message = transcript - .read_field_elements_ext(1 << V::get_basecode()) - .unwrap(); + let final_message = &proof.final_message; + transcript.append_field_element_exts(final_message.as_slice()); - let query_challenges = transcript - .squeeze_challenges(vp.num_verifier_queries) - .iter() - .map(|index| ext_to_usize(index) % (1 << (num_vars + vp.log_rate))) - .collect_vec(); - - let read_query_timer = start_timer!(|| "Basefold::verify::read query"); - // Here we assumed that all the commitments have the same type: - // either all base field or all extension field. Need to handle - // more complex case later. - let query_result_with_merkle_path = if comms[0].is_base { - BatchedQueriesResultWithMerklePath::read_transcript_base( - transcript, - num_rounds, - vp.log_rate, - poly_num_vars.as_slice(), - query_challenges.as_slice(), - ) - } else { - BatchedQueriesResultWithMerklePath::read_transcript_ext( - transcript, - num_rounds, - vp.log_rate, - poly_num_vars.as_slice(), - query_challenges.as_slice(), - ) - }; - end_timer!(read_query_timer); - end_timer!(read_timer); + let queries: Vec<_> = (0..Spec::get_number_queries()) + .map(|_| { + ext_to_usize( + &transcript + .get_and_append_challenge(b"query indices") + .elements, + ) % (1 << (num_vars + Spec::get_rate_log())) + }) + .collect(); + let query_result_with_merkle_path = proof.query_result_with_merkle_path.as_batched(); // coeff is the eq polynomial evaluated at the last challenge.len() variables // in reverse order. @@ -692,28 +862,122 @@ where ); eq.par_iter_mut().for_each(|e| *e *= coeff); - batch_verifier_query_phase( + batch_verifier_query_phase::( + queries.as_slice(), + &vp.encoding_params, &query_result_with_merkle_path, &sumcheck_messages, &fold_challenges, num_rounds, num_vars, - vp.log_rate, &final_message, &roots, &comms, &coeffs, eq.as_slice(), - vp.rng.clone(), &new_target_sum, &hasher, ); end_timer!(timer); Ok(()) } + + fn simple_batch_verify( + vp: &Self::VerifierParam, + comm: &Self::Commitment, + point: &[E], + evals: &[E], + proof: &Self::Proof, + transcript: &mut Transcript, + ) -> Result<(), Error> { + let timer = start_timer!(|| "Basefold::simple batch verify"); + assert!(comm.num_vars().unwrap() >= Spec::get_basecode_msg_size_log()); + let batch_size = evals.len(); + if let Some(num_polys) = comm.num_polys { + assert_eq!(num_polys, batch_size); + } + let hasher = new_hasher::(); + + let num_vars = point.len(); + if let Some(comm_num_vars) = comm.num_vars { + assert_eq!(num_vars, comm_num_vars); + } + let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); + + // evals.len() is the batch size, i.e., how many polynomials are being opened together + let batch_size_log = evals.len().next_power_of_two().ilog2() as usize; + let t = (0..batch_size_log) + .map(|_| { + transcript + .get_and_append_challenge(b"batch coeffs") + .elements + }) + .collect::>(); + let eq_xt = build_eq_x_r_vec(&t)[..evals.len()].to_vec(); + + let mut fold_challenges: Vec = Vec::with_capacity(num_vars); + let roots = &proof.roots; + let sumcheck_messages = &proof.sumcheck_messages; + for i in 0..num_rounds { + transcript.append_field_element_exts(sumcheck_messages[i].as_slice()); + fold_challenges.push( + transcript + .get_and_append_challenge(b"commit round") + .elements, + ); + if i < num_rounds - 1 { + write_digest_to_transcript(&roots[i], transcript); + } + } + let final_message = &proof.final_message; + transcript.append_field_element_exts(final_message.as_slice()); + + let queries: Vec<_> = (0..Spec::get_number_queries()) + .map(|_| { + ext_to_usize( + &transcript + .get_and_append_challenge(b"query indices") + .elements, + ) % (1 << (num_vars + Spec::get_rate_log())) + }) + .collect(); + let query_result_with_merkle_path = proof.query_result_with_merkle_path.as_simple_batched(); + + // coeff is the eq polynomial evaluated at the last challenge.len() variables + // in reverse order. + let rev_challenges = fold_challenges.clone().into_iter().rev().collect_vec(); + let coeff = eq_xy_eval( + &point[point.len() - fold_challenges.len()..], + &rev_challenges, + ); + // Compute eq as the partially evaluated eq polynomial + let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]); + eq.par_iter_mut().for_each(|e| *e *= coeff); + + simple_batch_verifier_query_phase::( + queries.as_slice(), + &vp.encoding_params, + &query_result_with_merkle_path, + &sumcheck_messages, + &fold_challenges, + &eq_xt, + num_rounds, + num_vars, + &final_message, + &roots, + comm, + eq.as_slice(), + evals, + &hasher, + ); + end_timer!(timer); + + Ok(()) + } } -impl NoninteractivePCS for Basefold +impl, Rng: RngCore + std::fmt::Debug> NoninteractivePCS + for Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, @@ -724,46 +988,94 @@ where mod test { use crate::{ basefold::Basefold, - test_util::{run_batch_commit_open_verify, run_commit_open_verify}, - util::transcript::PoseidonTranscript, + test_util::{ + run_batch_commit_open_verify, run_commit_open_verify, + run_simple_batch_commit_open_verify, + }, }; use goldilocks::GoldilocksExt2; + use rand_chacha::ChaCha8Rng; - use super::BasefoldDefaultParams; + use super::{structure::BasefoldBasecodeParams, BasefoldRSParams}; - type PcsGoldilocks = Basefold; + type PcsGoldilocksRSCode = Basefold; + type PcsGoldilocksBaseCode = Basefold; #[test] - fn commit_open_verify_goldilocks_base() { + fn commit_open_verify_goldilocks_basecode_base() { // Challenge is over extension field, poly over the base field - run_commit_open_verify::>( - true, 10, 11, - ); + run_commit_open_verify::(true, 10, 11); } #[test] - fn commit_open_verify_goldilocks_2() { + fn commit_open_verify_goldilocks_rscode_base() { + // Challenge is over extension field, poly over the base field + run_commit_open_verify::(true, 10, 11); + } + + #[test] + fn commit_open_verify_goldilocks_basecode_2() { // Both challenge and poly are over extension field - run_commit_open_verify::>( - false, 10, 11, + run_commit_open_verify::(false, 10, 11); + } + + #[test] + fn commit_open_verify_goldilocks_rscode_2() { + // Both challenge and poly are over extension field + run_commit_open_verify::(false, 10, 11); + } + + #[test] + fn simple_batch_commit_open_verify_goldilocks_basecode_base() { + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::( + true, 10, 11, 4, ); } #[test] - fn batch_commit_open_verify_goldilocks_base() { + fn simple_batch_commit_open_verify_goldilocks_rscode_base() { // Both challenge and poly are over base field - run_batch_commit_open_verify::< - GoldilocksExt2, - PcsGoldilocks, - PoseidonTranscript, - >(true, 10, 11); + run_simple_batch_commit_open_verify::(true, 10, 11, 4); + } + + #[test] + fn simple_batch_commit_open_verify_goldilocks_basecode_2() { + // Both challenge and poly are over extension field + run_simple_batch_commit_open_verify::( + false, 10, 11, 4, + ); } #[test] - fn batch_commit_open_verify_goldilocks_2() { + fn simple_batch_commit_open_verify_goldilocks_rscode_2() { // Both challenge and poly are over extension field - run_batch_commit_open_verify::>( - false, 10, 11, + run_simple_batch_commit_open_verify::( + false, 10, 11, 4, ); } + + #[test] + fn batch_commit_open_verify_goldilocks_basecode_base() { + // Both challenge and poly are over base field + run_batch_commit_open_verify::(true, 10, 11); + } + + #[test] + fn batch_commit_open_verify_goldilocks_rscode_base() { + // Both challenge and poly are over base field + run_batch_commit_open_verify::(true, 10, 11); + } + + #[test] + fn batch_commit_open_verify_goldilocks_basecode_2() { + // Both challenge and poly are over extension field + run_batch_commit_open_verify::(false, 10, 11); + } + + #[test] + fn batch_commit_open_verify_goldilocks_rscode_2() { + // Both challenge and poly are over extension field + run_batch_commit_open_verify::(false, 10, 11); + } } diff --git a/mpcs/src/basefold/basecode.rs b/mpcs/src/basefold/basecode.rs deleted file mode 100644 index a10679ede..000000000 --- a/mpcs/src/basefold/basecode.rs +++ /dev/null @@ -1,293 +0,0 @@ -use crate::util::{arithmetic::base_from_raw_bytes, log2_strict, num_of_bytes}; -use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; -use ark_std::{end_timer, start_timer}; -use ctr; -use ff::{BatchInverter, Field, PrimeField}; -use ff_ext::ExtensionField; -use generic_array::GenericArray; -use multilinear_extensions::mle::FieldType; -use rayon::prelude::{ParallelIterator, ParallelSlice, ParallelSliceMut}; - -use itertools::Itertools; - -use crate::util::plonky2_util::reverse_index_bits_in_place; -use rand_chacha::rand_core::RngCore; -use rayon::prelude::IntoParallelRefIterator; - -use crate::util::arithmetic::{horner, steps}; - -pub fn encode_field_type_rs_basecode( - poly: &FieldType, - rate: usize, - message_size: usize, -) -> Vec> { - match poly { - FieldType::Ext(poly) => encode_rs_basecode(poly, rate, message_size) - .iter() - .map(|x| FieldType::Ext(x.clone())) - .collect(), - FieldType::Base(poly) => encode_rs_basecode(poly, rate, message_size) - .iter() - .map(|x| FieldType::Base(x.clone())) - .collect(), - _ => panic!("Unsupported field type"), - } -} - -// Split the input into chunks of message size, encode each message, and return the codewords -pub fn encode_rs_basecode( - poly: &Vec, - rate: usize, - message_size: usize, -) -> Vec> { - let timer = start_timer!(|| "Encode basecode"); - // The domain is just counting 1, 2, 3, ... , domain_size - let domain: Vec = steps(F::ONE).take(message_size * rate).collect(); - let res = poly - .par_chunks_exact(message_size) - .map(|chunk| { - let mut target = vec![F::ZERO; message_size * rate]; - // Just Reed-Solomon code, but with the naive domain - target - .iter_mut() - .enumerate() - .for_each(|(i, target)| *target = horner(&chunk[..], &domain[i])); - target - }) - .collect::>>(); - end_timer!(timer); - - res -} - -fn concatenate_field_types(coeffs: &Vec>) -> FieldType { - match coeffs[0] { - FieldType::Ext(_) => { - let res = coeffs - .iter() - .map(|x| match x { - FieldType::Ext(x) => x.iter().map(|x| *x), - _ => unreachable!(), - }) - .flatten() - .collect::>(); - FieldType::Ext(res) - } - FieldType::Base(_) => { - let res = coeffs - .iter() - .map(|x| match x { - FieldType::Base(x) => x.iter().map(|x| *x), - _ => unreachable!(), - }) - .flatten() - .collect::>(); - FieldType::Base(res) - } - _ => unreachable!(), - } -} - -// this function assumes all codewords in base_codeword has equivalent length -pub fn evaluate_over_foldable_domain_generic_basecode( - base_message_length: usize, - num_coeffs: usize, - log_rate: usize, - base_codewords: Vec>, - table: &Vec>, -) -> FieldType { - let timer = start_timer!(|| "evaluate over foldable domain"); - let k = num_coeffs; - let logk = log2_strict(k); - let base_log_k = log2_strict(base_message_length); - // concatenate together all base codewords - // let now = Instant::now(); - let mut coeffs_with_bc = concatenate_field_types(&base_codewords); - // println!("concatenate base codewords {:?}", now.elapsed()); - // iterate over array, replacing even indices with (evals[i] - evals[(i+1)]) - let mut chunk_size = base_codewords[0].len(); // block length of the base code - for i in base_log_k..logk { - // In beginning of each iteration, the current codeword size is 1<> 1); - match coeffs_with_bc { - FieldType::Ext(ref mut coeffs_with_bc) => { - coeffs_with_bc.par_chunks_mut(chunk_size).for_each(|chunk| { - let half_chunk = chunk_size >> 1; - for j in half_chunk..chunk_size { - // Suppose the current codewords are (a, b) - // The new codeword is computed by two halves: - // left = a + t * b - // right = a - t * b - let rhs = chunk[j] * E::from(level[j - half_chunk]); - chunk[j] = chunk[j - half_chunk] - rhs; - chunk[j - half_chunk] = chunk[j - half_chunk] + rhs; - } - }); - } - FieldType::Base(ref mut coeffs_with_bc) => { - coeffs_with_bc.par_chunks_mut(chunk_size).for_each(|chunk| { - let half_chunk = chunk_size >> 1; - for j in half_chunk..chunk_size { - // Suppose the current codewords are (a, b) - // The new codeword is computed by two halves: - // left = a + t * b - // right = a - t * b - let rhs = chunk[j] * level[j - half_chunk]; - chunk[j] = chunk[j - half_chunk] - rhs; - chunk[j - half_chunk] = chunk[j - half_chunk] + rhs; - } - }); - } - _ => unreachable!(), - } - } - end_timer!(timer); - coeffs_with_bc -} - -pub fn get_table_aes( - poly_size: usize, - rate: usize, - rng: &mut Rng, -) -> ( - Vec>, - Vec>, -) { - // The size (logarithmic) of the codeword for the polynomial - let lg_n: usize = rate + log2_strict(poly_size); - - let mut key: [u8; 16] = [0u8; 16]; - let mut iv: [u8; 16] = [0u8; 16]; - rng.fill_bytes(&mut key); - rng.fill_bytes(&mut iv); - - type Aes128Ctr64LE = ctr::Ctr32LE; - - let mut cipher = Aes128Ctr64LE::new( - GenericArray::from_slice(&key[..]), - GenericArray::from_slice(&iv[..]), - ); - - // Allocate the buffer for storing n field elements (the entire codeword) - let bytes = num_of_bytes::(1 << lg_n); - let mut dest: Vec = vec![0u8; bytes]; - cipher.apply_keystream(&mut dest[..]); - - // Now, dest is a vector filled with random data for a field vector of size n - - // Collect the bytes into field elements - let flat_table: Vec = dest - .par_chunks_exact(num_of_bytes::(1)) - .map(|chunk| base_from_raw_bytes::(&chunk.to_vec())) - .collect::>(); - - // Now, flat_table is a field vector of size n, filled with random field elements - assert_eq!(flat_table.len(), 1 << lg_n); - - // Multiply -2 to every element to get the weights. Now weights = { -2x } - let mut weights: Vec = flat_table - .par_iter() - .map(|el| E::BaseField::ZERO - *el - *el) - .collect(); - - // Then invert all the elements. Now weights = { -1/2x } - let mut scratch_space = vec![E::BaseField::ZERO; weights.len()]; - BatchInverter::invert_with_external_scratch(&mut weights, &mut scratch_space); - - // Zip x and -1/2x together. The result is the list { (x, -1/2x) } - // What is this -1/2x? It is used in linear interpolation over the domain (x, -x), which - // involves computing 1/(b-a) where b=-x and a=x, and 1/(b-a) here is exactly -1/2x - let flat_table_w_weights = flat_table - .iter() - .zip(weights) - .map(|(el, w)| (*el, w)) - .collect_vec(); - - // Split the positions from 0 to n-1 into slices of sizes: - // 2, 2, 4, 8, ..., n/2, exactly lg_n number of them - // The weights are (x, -1/2x), the table elements are just x - - let mut unflattened_table_w_weights = vec![Vec::new(); lg_n]; - let mut unflattened_table = vec![Vec::new(); lg_n]; - - let mut level_weights = flat_table_w_weights[0..2].to_vec(); - // Apply the reverse-bits permutation to a vector of size 2, equivalent to just swapping - reverse_index_bits_in_place(&mut level_weights); - unflattened_table_w_weights[0] = level_weights; - - unflattened_table[0] = flat_table[0..2].to_vec(); - for i in 1..lg_n { - unflattened_table[i] = flat_table[(1 << i)..(1 << (i + 1))].to_vec(); - let mut level = flat_table_w_weights[(1 << i)..(1 << (i + 1))].to_vec(); - reverse_index_bits_in_place(&mut level); - unflattened_table_w_weights[i] = level; - } - - return (unflattened_table_w_weights, unflattened_table); -} - -pub fn query_point( - block_length: usize, - eval_index: usize, - level: usize, - mut cipher: &mut ctr::Ctr32LE, -) -> E::BaseField { - let level_index = eval_index % (block_length); - let mut el = - query_root_table_from_rng_aes::(level, level_index % (block_length >> 1), &mut cipher); - - if level_index >= (block_length >> 1) { - el = -E::BaseField::ONE * el; - } - - return el; -} - -pub fn query_root_table_from_rng_aes( - level: usize, - index: usize, - cipher: &mut ctr::Ctr32LE, -) -> E::BaseField { - let mut level_offset: u128 = 1; - for lg_m in 1..=level { - let half_m = 1 << (lg_m - 1); - level_offset += half_m; - } - - let pos = ((level_offset + (index as u128)) - * ((E::BaseField::NUM_BITS as usize).next_power_of_two() as u128)) - .checked_div(8) - .unwrap(); - - cipher.seek(pos); - - let bytes = (E::BaseField::NUM_BITS as usize).next_power_of_two() / 8; - let mut dest: Vec = vec![0u8; bytes]; - cipher.apply_keystream(&mut dest); - - let res = base_from_raw_bytes::(&dest); - - res -} - -#[cfg(test)] -mod tests { - use super::*; - use goldilocks::GoldilocksExt2; - use multilinear_extensions::mle::DenseMultilinearExtension; - - #[test] - fn time_rs_code() { - use rand::rngs::OsRng; - - let poly = DenseMultilinearExtension::random(20, &mut OsRng); - - encode_field_type_rs_basecode::(&poly.evaluations, 2, 64); - } -} diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index 1716edaea..7134a8717 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -1,5 +1,6 @@ use super::{ - basecode::encode_rs_basecode, + encoding::EncodingScheme, + structure::{BasefoldCommitPhaseProof, BasefoldSpec}, sumcheck::{ sum_check_challenge_round, sum_check_first_round, sum_check_first_round_field_type, sum_check_last_round, @@ -8,50 +9,58 @@ use super::{ use crate::util::{ arithmetic::{interpolate2_weights, interpolate_over_boolean_hypercube}, field_type_index_ext, field_type_iter_ext, - hash::{Digest, Hasher}, + hash::{write_digest_to_transcript, Hasher}, log2_strict, merkle_tree::MerkleTree, - transcript::TranscriptWrite, }; use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; - use itertools::Itertools; use serde::{de::DeserializeOwned, Serialize}; +use transcript::Transcript; use multilinear_extensions::{mle::FieldType, virtual_poly::build_eq_x_r_vec}; use crate::util::plonky2_util::reverse_index_bits_in_place; use rayon::prelude::{ - IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator, ParallelSlice, + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, + ParallelSlice, }; use super::structure::BasefoldCommitmentWithData; // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) -pub fn commit_phase( +pub fn commit_phase>( + pp: &>::ProverParameters, point: &[E], comm: &BasefoldCommitmentWithData, - transcript: &mut impl TranscriptWrite, E>, + transcript: &mut Transcript, num_vars: usize, num_rounds: usize, - table_w_weights: &Vec>, - log_rate: usize, hasher: &Hasher, -) -> (Vec>, Vec>) +) -> (Vec>, Vec>, BasefoldCommitPhaseProof) where E::BaseField: Serialize + DeserializeOwned, { let timer = start_timer!(|| "Commit phase"); + #[cfg(feature = "sanity-check")] assert_eq!(point.len(), num_vars); let mut oracles = Vec::with_capacity(num_vars); let mut trees = Vec::with_capacity(num_vars); - let mut running_oracle = field_type_iter_ext(comm.get_codeword()).collect_vec(); - let mut running_evals = comm.bh_evals.clone(); + let mut running_oracle = field_type_iter_ext(&comm.get_codewords()[0]).collect_vec(); + let mut running_evals = comm.polynomials_bh_evals[0].clone(); + + #[cfg(feature = "sanity-check")] + assert_eq!( + running_oracle.len(), + running_evals.len() << Spec::get_rate_log() + ); + #[cfg(feature = "sanity-check")] + assert_eq!(running_evals.len(), 1 << num_vars); // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube let build_eq_timer = start_timer!(|| "Basefold::open"); - let mut eq = build_eq_x_r_vec(&point); + let mut eq = build_eq_x_r_vec(point); end_timer!(build_eq_timer); reverse_index_bits_in_place(&mut eq); @@ -59,38 +68,44 @@ where let mut last_sumcheck_message = sum_check_first_round_field_type(&mut eq, &mut running_evals); end_timer!(sumcheck_timer); + #[cfg(feature = "sanity-check")] + assert_eq!(last_sumcheck_message.len(), 3); + let mut running_evals = match running_evals { FieldType::Ext(evals) => evals, FieldType::Base(evals) => evals.iter().map(|x| E::from(*x)).collect_vec(), _ => unreachable!(), }; + let mut sumcheck_messages = Vec::with_capacity(num_rounds); + let mut roots = Vec::with_capacity(num_rounds - 1); + let mut final_message = Vec::new(); for i in 0..num_rounds { let sumcheck_timer = start_timer!(|| format!("Basefold round {}", i)); // For the first round, no need to send the running root, because this root is // committing to a vector that can be recovered from linearly combining other // already-committed vectors. - transcript - .write_field_elements_ext(&last_sumcheck_message) - .unwrap(); + transcript.append_field_element_exts(&last_sumcheck_message); + sumcheck_messages.push(last_sumcheck_message.clone()); - let challenge = transcript.squeeze_challenge(); + let challenge = transcript.get_and_append_challenge(b"commit round"); // Fold the current oracle for FRI - running_oracle = basefold_one_round_by_interpolation_weights::( - &table_w_weights, + running_oracle = basefold_one_round_by_interpolation_weights::( + pp, log2_strict(running_oracle.len()) - 1, &running_oracle, - challenge, + challenge.elements, ); if i < num_rounds - 1 { last_sumcheck_message = - sum_check_challenge_round(&mut eq, &mut running_evals, challenge); + sum_check_challenge_round(&mut eq, &mut running_evals, challenge.elements); let running_tree = MerkleTree::::from_leaves(FieldType::Ext(running_oracle.clone()), hasher); let running_root = running_tree.root(); - transcript.write_commitment(&running_root).unwrap(); + write_digest_to_transcript(&running_root, transcript); + roots.push(running_root.clone()); oracles.push(running_oracle.clone()); trees.push(running_tree); @@ -99,11 +114,12 @@ where // and we don't interpolate the small polynomials. So after the last round, // running_evals is exactly the evaluation representation of the // folded polynomial so far. - sum_check_last_round(&mut eq, &mut running_evals, challenge); + sum_check_last_round(&mut eq, &mut running_evals, challenge.elements); // For the FRI part, we send the current polynomial as the message. // Transform it back into little endiean before sending it reverse_index_bits_in_place(&mut running_evals); - transcript.write_field_elements_ext(&running_evals).unwrap(); + transcript.append_field_element_exts(&running_evals); + final_message = running_evals.clone(); if cfg!(feature = "sanity-check") { // If the prover is honest, in the last round, the running oracle @@ -111,9 +127,17 @@ where let mut coeffs = running_evals.clone(); interpolate_over_boolean_hypercube(&mut coeffs); - let basecode = encode_rs_basecode(&coeffs, 1 << log_rate, coeffs.len()); - assert_eq!(basecode.len(), 1); - let basecode = basecode[0].clone(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut coeffs); + } + let basecode = >::encode( + pp, + &FieldType::Ext(coeffs), + ); + let basecode = match basecode { + FieldType::Ext(b) => b, + _ => panic!("Should be ext field"), + }; reverse_index_bits_in_place(&mut running_oracle); assert_eq!(basecode, running_oracle); @@ -122,21 +146,29 @@ where end_timer!(sumcheck_timer); } end_timer!(timer); - return (trees, oracles); + return ( + trees, + oracles, + BasefoldCommitPhaseProof { + sumcheck_messages, + roots, + final_message, + }, + ); } // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) -pub fn batch_commit_phase( +#[allow(clippy::too_many_arguments)] +pub fn batch_commit_phase>( + pp: &>::ProverParameters, point: &[E], - comms: &[&BasefoldCommitmentWithData], - transcript: &mut impl TranscriptWrite, E>, + comms: &[BasefoldCommitmentWithData], + transcript: &mut Transcript, num_vars: usize, num_rounds: usize, - table_w_weights: &Vec>, - log_rate: usize, coeffs: &[E], hasher: &Hasher, -) -> (Vec>, Vec>) +) -> (Vec>, Vec>, BasefoldCommitPhaseProof) where E::BaseField: Serialize + DeserializeOwned, { @@ -144,7 +176,7 @@ where assert_eq!(point.len(), num_vars); let mut oracles = Vec::with_capacity(num_vars); let mut trees = Vec::with_capacity(num_vars); - let mut running_oracle = vec![E::ZERO; 1 << (num_vars + log_rate)]; + let mut running_oracle = vec![E::ZERO; 1 << (num_vars + Spec::get_rate_log())]; let build_oracle_timer = start_timer!(|| "Basefold build initial oracle"); // Before the interaction, collect all the polynomials whose num variables match the @@ -157,8 +189,8 @@ where .for_each(|(index, comm)| { running_oracle .iter_mut() - .zip_eq(field_type_iter_ext(comm.get_codeword())) - .for_each(|(r, a)| *r += E::from(a) * coeffs[index]); + .zip_eq(field_type_iter_ext(&comm.get_codewords()[0])) + .for_each(|(r, a)| *r += a * coeffs[index]); }); end_timer!(build_oracle_timer); @@ -177,16 +209,16 @@ where // to align the polynomials to the variable with index 0 before adding them // together. So each element is repeated by // sum_of_all_evals_for_sumcheck.len() / bh_evals.len() times - *r += E::from(field_type_index_ext( - &comm.bh_evals, - pos >> (num_vars - log2_strict(comm.bh_evals.len())), - )) * coeffs[index] + *r += field_type_index_ext( + &comm.polynomials_bh_evals[0], + pos >> (num_vars - log2_strict(comm.polynomials_bh_evals[0].len())), + ) * coeffs[index] }); }); end_timer!(build_oracle_timer); // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube - let mut eq = build_eq_x_r_vec(&point); + let mut eq = build_eq_x_r_vec(point); reverse_index_bits_in_place(&mut eq); let sumcheck_timer = start_timer!(|| "Basefold first round"); @@ -196,20 +228,22 @@ where sumcheck_messages.push(last_sumcheck_message.clone()); end_timer!(sumcheck_timer); + let mut roots = Vec::with_capacity(num_rounds - 1); + let mut final_message = Vec::new(); for i in 0..num_rounds { let sumcheck_timer = start_timer!(|| format!("Batch basefold round {}", i)); // For the first round, no need to send the running root, because this root is // committing to a vector that can be recovered from linearly combining other // already-committed vectors. - transcript - .write_field_elements_ext(&last_sumcheck_message) - .unwrap(); + transcript.append_field_element_exts(&last_sumcheck_message); - let challenge = transcript.squeeze_challenge(); + let challenge = transcript + .get_and_append_challenge(b"commit round") + .elements; // Fold the current oracle for FRI - running_oracle = basefold_one_round_by_interpolation_weights::( - &table_w_weights, + running_oracle = basefold_one_round_by_interpolation_weights::( + pp, log2_strict(running_oracle.len()) - 1, &running_oracle, challenge, @@ -222,7 +256,8 @@ where let running_tree = MerkleTree::::from_leaves(FieldType::Ext(running_oracle.clone()), hasher); let running_root = running_tree.root(); - transcript.write_commitment(&running_root).unwrap(); + write_digest_to_transcript(&running_root, transcript); + roots.push(running_root); oracles.push(running_oracle.clone()); trees.push(running_tree); @@ -236,8 +271,8 @@ where .for_each(|(index, comm)| { running_oracle .iter_mut() - .zip_eq(field_type_iter_ext(comm.get_codeword())) - .for_each(|(r, a)| *r += E::from(a) * coeffs[index]); + .zip_eq(field_type_iter_ext(&comm.get_codewords()[0])) + .for_each(|(r, a)| *r += a * coeffs[index]); }); } else { // The difference of the last round is that we don't need to compute the message, @@ -248,19 +283,157 @@ where // For the FRI part, we send the current polynomial as the message. // Transform it back into little endiean before sending it reverse_index_bits_in_place(&mut sum_of_all_evals_for_sumcheck); - transcript - .write_field_elements_ext(&sum_of_all_evals_for_sumcheck) - .unwrap(); + transcript.append_field_element_exts(&sum_of_all_evals_for_sumcheck); + final_message = sum_of_all_evals_for_sumcheck.clone(); if cfg!(feature = "sanity-check") { // If the prover is honest, in the last round, the running oracle // on the prover side should be exactly the encoding of the folded polynomial. let mut coeffs = sum_of_all_evals_for_sumcheck.clone(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut coeffs); + } + interpolate_over_boolean_hypercube(&mut coeffs); + let basecode = >::encode( + pp, + &FieldType::Ext(coeffs), + ); + let basecode = match basecode { + FieldType::Ext(x) => x, + _ => panic!("Expected ext field"), + }; + + reverse_index_bits_in_place(&mut running_oracle); + assert_eq!(basecode, running_oracle); + } + } + end_timer!(sumcheck_timer); + } + end_timer!(timer); + return ( + trees, + oracles, + BasefoldCommitPhaseProof { + sumcheck_messages, + roots, + final_message, + }, + ); +} + +// outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) +#[allow(clippy::too_many_arguments)] +pub fn simple_batch_commit_phase>( + pp: &>::ProverParameters, + point: &[E], + batch_coeffs: &[E], + comm: &BasefoldCommitmentWithData, + transcript: &mut Transcript, + num_vars: usize, + num_rounds: usize, + hasher: &Hasher, +) -> (Vec>, Vec>, BasefoldCommitPhaseProof) +where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Simple batch commit phase"); + assert_eq!(point.len(), num_vars); + assert_eq!(comm.num_polys, batch_coeffs.len()); + let prepare_timer = start_timer!(|| "Prepare"); + let mut oracles = Vec::with_capacity(num_vars); + let mut trees = Vec::with_capacity(num_vars); + let batch_codewords_timer = start_timer!(|| "Batch codewords"); + let mut running_oracle = comm.batch_codewords(batch_coeffs); + end_timer!(batch_codewords_timer); + let mut running_evals = (0..(1 << num_vars)) + .into_par_iter() + .map(|i| { + comm.polynomials_bh_evals + .iter() + .zip(batch_coeffs) + .map(|(eval, coeff)| field_type_index_ext(eval, i) * *coeff) + .sum() + }) + .collect(); + end_timer!(prepare_timer); + + // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube + let build_eq_timer = start_timer!(|| "Basefold::build eq"); + let mut eq = build_eq_x_r_vec(point); + end_timer!(build_eq_timer); + + let reverse_bits_timer = start_timer!(|| "Basefold::reverse bits"); + reverse_index_bits_in_place(&mut eq); + end_timer!(reverse_bits_timer); + + let sumcheck_timer = start_timer!(|| "Basefold sumcheck first round"); + let mut last_sumcheck_message = sum_check_first_round(&mut eq, &mut running_evals); + end_timer!(sumcheck_timer); + + let mut sumcheck_messages = Vec::with_capacity(num_rounds); + let mut roots = Vec::with_capacity(num_rounds - 1); + let mut final_message = Vec::new(); + for i in 0..num_rounds { + let sumcheck_timer = start_timer!(|| format!("Basefold round {}", i)); + // For the first round, no need to send the running root, because this root is + // committing to a vector that can be recovered from linearly combining other + // already-committed vectors. + transcript.append_field_element_exts(&last_sumcheck_message); + sumcheck_messages.push(last_sumcheck_message.clone()); + + let challenge = transcript + .get_and_append_challenge(b"commit round") + .elements; + + // Fold the current oracle for FRI + running_oracle = basefold_one_round_by_interpolation_weights::( + pp, + log2_strict(running_oracle.len()) - 1, + &running_oracle, + challenge, + ); + + if i < num_rounds - 1 { + last_sumcheck_message = + sum_check_challenge_round(&mut eq, &mut running_evals, challenge); + let running_tree = + MerkleTree::::from_leaves(FieldType::Ext(running_oracle.clone()), hasher); + let running_root = running_tree.root(); + write_digest_to_transcript(&running_root, transcript); + roots.push(running_root); + + oracles.push(running_oracle.clone()); + trees.push(running_tree); + } else { + // The difference of the last round is that we don't need to compute the message, + // and we don't interpolate the small polynomials. So after the last round, + // running_evals is exactly the evaluation representation of the + // folded polynomial so far. + sum_check_last_round(&mut eq, &mut running_evals, challenge); + // For the FRI part, we send the current polynomial as the message. + // Transform it back into little endiean before sending it + reverse_index_bits_in_place(&mut running_evals); + transcript.append_field_element_exts(&running_evals); + final_message = running_evals.clone(); + + if cfg!(feature = "sanity-check") { + // If the prover is honest, in the last round, the running oracle + // on the prover side should be exactly the encoding of the folded polynomial. + + let mut coeffs = running_evals.clone(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut coeffs); + } interpolate_over_boolean_hypercube(&mut coeffs); - let basecode = encode_rs_basecode(&coeffs, 1 << log_rate, coeffs.len()); - assert_eq!(basecode.len(), 1); - let basecode = basecode[0].clone(); + let basecode = >::encode( + pp, + &FieldType::Ext(coeffs), + ); + let basecode = match basecode { + FieldType::Ext(basecode) => basecode, + _ => panic!("Should be ext field"), + }; reverse_index_bits_in_place(&mut running_oracle); assert_eq!(basecode, running_oracle); @@ -269,28 +442,30 @@ where end_timer!(sumcheck_timer); } end_timer!(timer); - return (trees, oracles); + return ( + trees, + oracles, + BasefoldCommitPhaseProof { + sumcheck_messages, + roots, + final_message, + }, + ); } -fn basefold_one_round_by_interpolation_weights( - table: &Vec>, - level_index: usize, +fn basefold_one_round_by_interpolation_weights>( + pp: &>::ProverParameters, + level: usize, values: &Vec, challenge: E, ) -> Vec { - let level = &table[level_index]; values .par_chunks_exact(2) .enumerate() .map(|(i, ys)| { - interpolate2_weights( - [ - (E::from(level[i].0), ys[0]), - (E::from(-(level[i].0)), ys[1]), - ], - E::from(level[i].1), - challenge, - ) + let (x0, x1, w) = + >::prover_folding_coeffs(pp, level, i); + interpolate2_weights([(x0, ys[0]), (x1, ys[1])], w, challenge) }) .collect::>() } diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs new file mode 100644 index 000000000..9fa651e2f --- /dev/null +++ b/mpcs/src/basefold/encoding.rs @@ -0,0 +1,237 @@ +use ff_ext::ExtensionField; +use multilinear_extensions::mle::FieldType; + +mod utils; + +mod basecode; +pub use basecode::{Basecode, BasecodeDefaultSpec}; + +mod rs; +use plonky2::util::log2_strict; +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSlice, +}; +pub use rs::{coset_fft, fft, fft_root_table, RSCode, RSCodeDefaultSpec}; + +use serde::{de::DeserializeOwned, Serialize}; + +use crate::{util::arithmetic::interpolate2_weights, Error}; + +pub trait EncodingProverParameters { + fn get_max_message_size_log(&self) -> usize; +} + +pub trait EncodingScheme: std::fmt::Debug + Clone { + type PublicParameters: Clone + std::fmt::Debug + Serialize + DeserializeOwned; + type ProverParameters: Clone + + std::fmt::Debug + + Serialize + + DeserializeOwned + + EncodingProverParameters + + Sync; + type VerifierParameters: Clone + std::fmt::Debug + Serialize + DeserializeOwned + Sync; + + fn setup(max_msg_size_log: usize, rng_seed: [u8; 32]) -> Self::PublicParameters; + + fn trim( + pp: &Self::PublicParameters, + max_msg_size_log: usize, + ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error>; + + fn encode(pp: &Self::ProverParameters, coeffs: &FieldType) -> FieldType; + + /// Encodes a message of small length, such that the verifier is also able + /// to execute the encoding. + fn encode_small(vp: &Self::VerifierParameters, coeffs: &FieldType) -> FieldType; + + fn get_number_queries() -> usize; + + fn get_rate_log() -> usize; + + fn get_basecode_msg_size_log() -> usize; + + /// Whether the message needs to be bit-reversed to allow even-odd + /// folding. If the folding is already even-odd style (like RS code), + /// then set this function to return false. If the folding is originally + /// left-right, like basefold, then return true. + fn message_is_left_and_right_folding() -> bool; + + fn message_is_even_and_odd_folding() -> bool { + !Self::message_is_left_and_right_folding() + } + + /// Returns three values: x0, x1 and 1/(x1-x0). Note that although + /// 1/(x1-x0) can be computed from the other two values, we return it + /// separately because inversion is expensive. + /// These three values can be used to interpolate a linear function + /// that passes through the two points (x0, y0) and (x1, y1), for the + /// given y0 and y1, then compute the value of the linear function at + /// any give x. + /// Params: + /// - level: which particular code in this family of codes? + /// - index: position in the codeword (after folded) + fn prover_folding_coeffs(pp: &Self::ProverParameters, level: usize, index: usize) -> (E, E, E); + + /// The same as `prover_folding_coeffs`, but for the verifier. The two + /// functions, although provide the same functionality, may use different + /// implementations. For example, prover can use precomputed values stored + /// in the parameters, but the verifier may need to recompute them. + fn verifier_folding_coeffs( + vp: &Self::VerifierParameters, + level: usize, + index: usize, + ) -> (E, E, E); + + /// Fold the given codeword into a smaller codeword of half size, using + /// the folding coefficients computed by `prover_folding_coeffs`. + /// The given codeword is assumed to be bit-reversed on the original + /// codeword directly produced from the `encode` method. + fn fold_bitreversed_codeword( + pp: &Self::ProverParameters, + codeword: &FieldType, + challenge: E, + ) -> Vec { + let level = log2_strict(codeword.len()) - 1; + match codeword { + FieldType::Ext(codeword) => codeword + .par_chunks_exact(2) + .enumerate() + .map(|(i, ys)| { + let (x0, x1, w) = Self::prover_folding_coeffs(pp, level, i); + interpolate2_weights([(x0, ys[0]), (x1, ys[1])], w, challenge) + }) + .collect::>(), + FieldType::Base(codeword) => codeword + .par_chunks_exact(2) + .enumerate() + .map(|(i, ys)| { + let (x0, x1, w) = Self::prover_folding_coeffs(pp, level, i); + interpolate2_weights([(x0, E::from(ys[0])), (x1, E::from(ys[1]))], w, challenge) + }) + .collect::>(), + _ => panic!("Unsupported field type"), + } + } + + /// Fold the given message into a smaller message of half size using challenge + /// as the random linear combination coefficient. + /// Note that this is always even-odd fold, assuming the message has + /// been bit-reversed (or not) according to the setting + /// of the `message_need_bit_reversion` function. + fn fold_message(msg: &FieldType, challenge: E) -> Vec { + match msg { + FieldType::Ext(msg) => msg + .par_chunks_exact(2) + .map(|ys| ys[0] + ys[1] * challenge) + .collect::>(), + FieldType::Base(msg) => msg + .par_chunks_exact(2) + .map(|ys| E::from(ys[0]) + E::from(ys[1]) * challenge) + .collect::>(), + _ => panic!("Unsupported field type"), + } + } +} + +fn concatenate_field_types(coeffs: &[FieldType]) -> FieldType { + match coeffs[0] { + FieldType::Ext(_) => { + let res = coeffs + .iter() + .flat_map(|x| match x { + FieldType::Ext(x) => x.iter().copied(), + _ => unreachable!(), + }) + .collect::>(); + FieldType::Ext(res) + } + FieldType::Base(_) => { + let res = coeffs + .iter() + .flat_map(|x| match x { + FieldType::Base(x) => x.iter().copied(), + _ => unreachable!(), + }) + .collect::>(); + FieldType::Base(res) + } + _ => unreachable!(), + } +} + +#[cfg(test)] +pub(crate) mod test_util { + use ff_ext::ExtensionField; + use multilinear_extensions::mle::FieldType; + use rand::rngs::OsRng; + + use crate::util::plonky2_util::reverse_index_bits_in_place_field_type; + + use super::EncodingScheme; + + pub fn test_codeword_folding>() { + let num_vars = 12; + + let poly: Vec = (0..(1 << num_vars)).map(|i| E::from(i)).collect(); + let mut poly = FieldType::Ext(poly); + + let rng_seed = [0; 32]; + let pp: Code::PublicParameters = Code::setup(num_vars, rng_seed); + let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let mut codeword = Code::encode(&pp, &poly); + reverse_index_bits_in_place_field_type(&mut codeword); + if Code::message_is_left_and_right_folding() { + reverse_index_bits_in_place_field_type(&mut poly); + } + let challenge = E::random(&mut OsRng); + let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); + let mut folded_message = FieldType::Ext(Code::fold_message(&poly, challenge)); + if Code::message_is_left_and_right_folding() { + // Reverse the message back before encoding if it has been + // bit-reversed + reverse_index_bits_in_place_field_type(&mut folded_message); + } + let mut encoded_folded_message = Code::encode(&pp, &folded_message); + reverse_index_bits_in_place_field_type(&mut encoded_folded_message); + let encoded_folded_message = match encoded_folded_message { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + for (i, (a, b)) in folded_codeword + .iter() + .zip(encoded_folded_message.iter()) + .enumerate() + { + assert_eq!(a, b, "Failed at index {}", i); + } + + let mut folded_codeword = FieldType::Ext(folded_codeword); + for round in 0..4 { + let folded_codeword_vec = + Code::fold_bitreversed_codeword(&pp, &folded_codeword, challenge); + + if Code::message_is_left_and_right_folding() { + reverse_index_bits_in_place_field_type(&mut folded_message); + } + folded_message = FieldType::Ext(Code::fold_message(&folded_message, challenge)); + if Code::message_is_left_and_right_folding() { + reverse_index_bits_in_place_field_type(&mut folded_message); + } + let mut encoded_folded_message = Code::encode(&pp, &folded_message); + reverse_index_bits_in_place_field_type(&mut encoded_folded_message); + let encoded_folded_message = match encoded_folded_message { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + for (i, (a, b)) in folded_codeword_vec + .iter() + .zip(encoded_folded_message.iter()) + .enumerate() + { + assert_eq!(a, b, "Failed at index {} in round {}", i, round); + } + folded_codeword = FieldType::Ext(folded_codeword_vec); + } + } +} diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs new file mode 100644 index 000000000..051a510e6 --- /dev/null +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -0,0 +1,451 @@ +use std::marker::PhantomData; + +use super::{concatenate_field_types, EncodingProverParameters, EncodingScheme}; +use crate::{ + util::{ + arithmetic::base_from_raw_bytes, log2_strict, num_of_bytes, plonky2_util::reverse_bits, + }, + vec_mut, Error, +}; +use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; +use ark_std::{end_timer, start_timer}; +use ff::{BatchInverter, Field, PrimeField}; +use ff_ext::ExtensionField; +use generic_array::GenericArray; +use multilinear_extensions::mle::FieldType; +use rand::SeedableRng; +use rayon::prelude::{ParallelIterator, ParallelSlice, ParallelSliceMut}; + +use itertools::Itertools; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use crate::util::plonky2_util::reverse_index_bits_in_place; +use rand_chacha::{rand_core::RngCore, ChaCha8Rng}; +use rayon::prelude::IntoParallelRefIterator; + +use crate::util::arithmetic::{horner, steps}; + +pub trait BasecodeSpec: std::fmt::Debug + Clone { + fn get_number_queries() -> usize; + + fn get_rate_log() -> usize; + + fn get_basecode_msg_size_log() -> usize; +} + +#[derive(Debug, Clone)] +pub struct BasecodeDefaultSpec {} + +impl BasecodeSpec for BasecodeDefaultSpec { + fn get_number_queries() -> usize { + 766 + } + + fn get_rate_log() -> usize { + 3 + } + + fn get_basecode_msg_size_log() -> usize { + 7 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasecodeParameters { + pub(crate) table: Vec>, + pub(crate) table_w_weights: Vec>, + pub(crate) rng_seed: [u8; 32], +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasecodeProverParameters { + pub(crate) table: Vec>, + pub(crate) table_w_weights: Vec>, + pub(crate) rng_seed: [u8; 32], + #[serde(skip)] + _phantom: PhantomData Spec>, +} + +impl EncodingProverParameters + for BasecodeProverParameters +{ + fn get_max_message_size_log(&self) -> usize { + self.table.len() - Spec::get_rate_log() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BasecodeVerifierParameters { + pub(crate) rng_seed: [u8; 32], + pub(crate) aes_key: [u8; 16], + pub(crate) aes_iv: [u8; 16], +} + +#[derive(Debug, Clone)] +pub struct Basecode { + _phantom_data: PhantomData, +} + +impl EncodingScheme for Basecode +where + E::BaseField: Serialize + DeserializeOwned, +{ + type PublicParameters = BasecodeParameters; + + type ProverParameters = BasecodeProverParameters; + + type VerifierParameters = BasecodeVerifierParameters; + + fn setup(max_msg_size_log: usize, rng_seed: [u8; 32]) -> Self::PublicParameters { + let rng = ChaCha8Rng::from_seed(rng_seed); + let (table_w_weights, table) = + get_table_aes::(max_msg_size_log, Spec::get_rate_log(), &mut rng.clone()); + BasecodeParameters { + table, + table_w_weights, + rng_seed, + } + } + + fn trim( + pp: &Self::PublicParameters, + max_msg_size_log: usize, + ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { + if pp.table.len() < Spec::get_rate_log() + max_msg_size_log { + return Err(Error::InvalidPcsParam(format!( + "Public parameter is setup for a smaller message size (log={}) than the trimmed message size (log={})", + pp.table.len() - Spec::get_rate_log(), + max_msg_size_log, + ))); + } + let mut key: [u8; 16] = [0u8; 16]; + let mut iv: [u8; 16] = [0u8; 16]; + let mut rng = ChaCha8Rng::from_seed(pp.rng_seed); + rng.set_word_pos(0); + rng.fill_bytes(&mut key); + rng.fill_bytes(&mut iv); + Ok(( + Self::ProverParameters { + table_w_weights: pp.table_w_weights.clone(), + table: pp.table.clone(), + rng_seed: pp.rng_seed, + _phantom: PhantomData, + }, + Self::VerifierParameters { + rng_seed: pp.rng_seed, + aes_key: key, + aes_iv: iv, + }, + )) + } + + fn encode(pp: &Self::ProverParameters, coeffs: &FieldType) -> FieldType { + // Split the input into chunks of message size, encode each message, and return the codewords + let basecode = encode_field_type_rs_basecode( + coeffs, + 1 << Spec::get_rate_log(), + 1 << Spec::get_basecode_msg_size_log(), + ); + + // Apply the recursive definition of the BaseFold code to the list of base codewords, + // and produce the final codeword + evaluate_over_foldable_domain_generic_basecode::( + 1 << Spec::get_basecode_msg_size_log(), + coeffs.len(), + Spec::get_rate_log(), + basecode, + &pp.table, + ) + } + + fn encode_small(_vp: &Self::VerifierParameters, coeffs: &FieldType) -> FieldType { + let mut basecodes = + encode_field_type_rs_basecode(coeffs, 1 << Spec::get_rate_log(), coeffs.len()); + assert_eq!(basecodes.len(), 1); + basecodes.remove(0) + } + + fn get_number_queries() -> usize { + Spec::get_number_queries() + } + + fn get_rate_log() -> usize { + Spec::get_rate_log() + } + + fn get_basecode_msg_size_log() -> usize { + Spec::get_basecode_msg_size_log() + } + + fn message_is_left_and_right_folding() -> bool { + true + } + + fn prover_folding_coeffs(pp: &Self::ProverParameters, level: usize, index: usize) -> (E, E, E) { + let level = &pp.table_w_weights[level]; + ( + E::from(level[index].0), + E::from(-level[index].0), + E::from(level[index].1), + ) + } + + fn verifier_folding_coeffs( + vp: &Self::VerifierParameters, + level: usize, + index: usize, + ) -> (E, E, E) { + type Aes128Ctr64LE = ctr::Ctr32LE; + let mut cipher = Aes128Ctr64LE::new( + GenericArray::from_slice(&vp.aes_key[..]), + GenericArray::from_slice(&vp.aes_iv[..]), + ); + + let x0: E::BaseField = query_root_table_from_rng_aes::(level, index, &mut cipher); + let x1 = -x0; + + let w = (x1 - x0).invert().unwrap(); + + (E::from(x0), E::from(x1), E::from(w)) + } +} + +fn encode_field_type_rs_basecode( + poly: &FieldType, + rate: usize, + message_size: usize, +) -> Vec> { + match poly { + FieldType::Ext(poly) => get_basecode(poly, rate, message_size) + .iter() + .map(|x| FieldType::Ext(x.clone())) + .collect(), + FieldType::Base(poly) => get_basecode(poly, rate, message_size) + .iter() + .map(|x| FieldType::Base(x.clone())) + .collect(), + _ => panic!("Unsupported field type"), + } +} + +// Split the input into chunks of message size, encode each message, and return the codewords +// FIXME: It is expensive for now because it is using naive FFT (although it is +// over a small domain) +fn get_basecode(poly: &Vec, rate: usize, message_size: usize) -> Vec> { + let timer = start_timer!(|| "Encode basecode"); + // The domain is just counting 1, 2, 3, ... , domain_size + let domain: Vec = steps(F::ONE).take(message_size * rate).collect(); + let res = poly + .par_chunks_exact(message_size) + .map(|chunk| { + let mut target = vec![F::ZERO; message_size * rate]; + // Just Reed-Solomon code, but with the naive domain + target + .iter_mut() + .enumerate() + .for_each(|(i, target)| *target = horner(chunk, &domain[i])); + target + }) + .collect::>>(); + end_timer!(timer); + + res +} + +// this function assumes all codewords in base_codeword has equivalent length +pub fn evaluate_over_foldable_domain_generic_basecode( + base_message_length: usize, + num_coeffs: usize, + log_rate: usize, + base_codewords: Vec>, + table: &[Vec], +) -> FieldType { + let timer = start_timer!(|| "evaluate over foldable domain"); + let k = num_coeffs; + let logk = log2_strict(k); + let base_log_k = log2_strict(base_message_length); + // concatenate together all base codewords + // let now = Instant::now(); + let mut coeffs_with_bc = concatenate_field_types(&base_codewords); + // println!("concatenate base codewords {:?}", now.elapsed()); + // iterate over array, replacing even indices with (evals[i] - evals[(i+1)]) + let mut chunk_size = base_codewords[0].len(); // block length of the base code + for i in base_log_k..logk { + // In beginning of each iteration, the current codeword size is 1<> 1); + vec_mut!(coeffs_with_bc, |c| { + c.par_chunks_mut(chunk_size).for_each(|chunk| { + let half_chunk = chunk_size >> 1; + for j in half_chunk..chunk_size { + // Suppose the current codewords are (a, b) + // The new codeword is computed by two halves: + // left = a + t * b + // right = a - t * b + let rhs = chunk[j] * level[j - half_chunk]; + chunk[j] = chunk[j - half_chunk] - rhs; + chunk[j - half_chunk] += rhs; + } + }); + }); + } + end_timer!(timer); + coeffs_with_bc +} + +#[allow(clippy::type_complexity)] +pub fn get_table_aes( + poly_size_log: usize, + rate: usize, + rng: &mut Rng, +) -> ( + Vec>, + Vec>, +) { + // The size (logarithmic) of the codeword for the polynomial + let lg_n: usize = rate + poly_size_log; + + let mut key: [u8; 16] = [0u8; 16]; + let mut iv: [u8; 16] = [0u8; 16]; + rng.fill_bytes(&mut key); + rng.fill_bytes(&mut iv); + + type Aes128Ctr64LE = ctr::Ctr32LE; + + let mut cipher = Aes128Ctr64LE::new( + GenericArray::from_slice(&key[..]), + GenericArray::from_slice(&iv[..]), + ); + + // Allocate the buffer for storing n field elements (the entire codeword) + let bytes = num_of_bytes::(1 << lg_n); + let mut dest: Vec = vec![0u8; bytes]; + cipher.apply_keystream(&mut dest[..]); + + // Now, dest is a vector filled with random data for a field vector of size n + + // Collect the bytes into field elements + let flat_table: Vec = dest + .par_chunks_exact(num_of_bytes::(1)) + .map(|chunk| base_from_raw_bytes::(chunk)) + .collect::>(); + + // Now, flat_table is a field vector of size n, filled with random field elements + assert_eq!(flat_table.len(), 1 << lg_n); + + // Multiply -2 to every element to get the weights. Now weights = { -2x } + let mut weights: Vec = flat_table + .par_iter() + .map(|el| E::BaseField::ZERO - *el - *el) + .collect(); + + // Then invert all the elements. Now weights = { -1/2x } + let mut scratch_space = vec![E::BaseField::ZERO; weights.len()]; + BatchInverter::invert_with_external_scratch(&mut weights, &mut scratch_space); + + // Zip x and -1/2x together. The result is the list { (x, -1/2x) } + // What is this -1/2x? It is used in linear interpolation over the domain (x, -x), which + // involves computing 1/(b-a) where b=-x and a=x, and 1/(b-a) here is exactly -1/2x + let flat_table_w_weights = flat_table + .iter() + .zip(weights) + .map(|(el, w)| (*el, w)) + .collect_vec(); + + // Split the positions from 0 to n-1 into slices of sizes: + // 2, 2, 4, 8, ..., n/2, exactly lg_n number of them + // The weights are (x, -1/2x), the table elements are just x + + let mut unflattened_table_w_weights = vec![Vec::new(); lg_n]; + let mut unflattened_table = vec![Vec::new(); lg_n]; + + unflattened_table_w_weights[0] = flat_table_w_weights[1..2].to_vec(); + unflattened_table[0] = flat_table[1..2].to_vec(); + for i in 1..lg_n { + unflattened_table[i] = flat_table[(1 << i)..(1 << (i + 1))].to_vec(); + let mut level = flat_table_w_weights[(1 << i)..(1 << (i + 1))].to_vec(); + reverse_index_bits_in_place(&mut level); + unflattened_table_w_weights[i] = level; + } + + (unflattened_table_w_weights, unflattened_table) +} + +pub fn query_root_table_from_rng_aes( + level: usize, + index: usize, + cipher: &mut ctr::Ctr32LE, +) -> E::BaseField { + let mut level_offset: u128 = 1; + for lg_m in 1..=level { + let half_m = 1 << (lg_m - 1); + level_offset += half_m; + } + + let pos = ((level_offset + (reverse_bits(index, level) as u128)) + * ((E::BaseField::NUM_BITS as usize).next_power_of_two() as u128)) + .checked_div(8) + .unwrap(); + + cipher.seek(pos); + + let bytes = (E::BaseField::NUM_BITS as usize).next_power_of_two() / 8; + let mut dest: Vec = vec![0u8; bytes]; + cipher.apply_keystream(&mut dest); + + base_from_raw_bytes::(&dest) +} + +#[cfg(test)] +mod tests { + use crate::basefold::encoding::test_util::test_codeword_folding; + + use super::*; + use goldilocks::GoldilocksExt2; + use multilinear_extensions::mle::DenseMultilinearExtension; + + #[test] + fn time_rs_code() { + use rand::rngs::OsRng; + + let poly = DenseMultilinearExtension::random(20, &mut OsRng); + + encode_field_type_rs_basecode::(&poly.evaluations, 2, 64); + } + + #[test] + fn prover_verifier_consistency() { + type Code = Basecode; + let pp: BasecodeParameters = Code::setup(10, [0; 32]); + let (pp, vp) = Code::trim(&pp, 10).unwrap(); + for level in 0..(10 + >::get_rate_log()) { + for index in 0..(1 << level) { + assert_eq!( + Code::prover_folding_coeffs(&pp, level, index), + Code::verifier_folding_coeffs(&vp, level, index), + "failed for level = {}, index = {}", + level, + index + ); + } + } + } + + #[test] + fn test_basecode_codeword_folding() { + test_codeword_folding::>(); + } +} diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs new file mode 100644 index 000000000..79c047060 --- /dev/null +++ b/mpcs/src/basefold/encoding/rs.rs @@ -0,0 +1,877 @@ +use std::marker::PhantomData; + +use super::{EncodingProverParameters, EncodingScheme}; +use crate::{ + util::{field_type_index_mul_base, log2_strict, plonky2_util::reverse_bits}, + vec_mut, Error, +}; +use ark_std::{end_timer, start_timer}; +use ff::{Field, PrimeField}; +use ff_ext::ExtensionField; +use multilinear_extensions::mle::FieldType; + +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use crate::util::plonky2_util::reverse_index_bits_in_place; + +use crate::util::arithmetic::horner; + +pub trait RSCodeSpec: std::fmt::Debug + Clone { + fn get_number_queries() -> usize; + + fn get_rate_log() -> usize; + + fn get_basecode_msg_size_log() -> usize; +} + +/// The FFT codes in this file are borrowed and adapted from Plonky2. +type FftRootTable = Vec>; + +pub fn fft_root_table(lg_n: usize) -> FftRootTable { + // bases[i] = g^2^i, for i = 0, ..., lg_n - 1 + // Note that the end of bases is g^{n/2} = -1 + let mut bases = Vec::with_capacity(lg_n); + let mut base = F::ROOT_OF_UNITY.pow([(1 << (F::S - lg_n as u32)) as u64]); + bases.push(base); + for _ in 1..lg_n { + base = base.square(); // base = g^2^_ + bases.push(base); + } + + // The result table looks like this: + // len=2: [1, g^{n/2}=-1] + // len=2: [1, g^{n/4}] + // len=4: [1, g^{n/8}, g^{n/4}, g^{3n/8}] + // len=8: [1, g^{n/16}, ..., g^{7n/16}] + // ... + // len=n/2: [1, g, ..., g^{n/2-1}] + // There is no need to compute the other halves of these powers, because + // those would be simply the negations of the previous halves. + let mut root_table = Vec::with_capacity(lg_n); + for lg_m in 1..=lg_n { + let half_m = 1 << (lg_m - 1); + let base = bases[lg_n - lg_m]; + let mut root_row = Vec::with_capacity(half_m.max(2)); + root_row.push(F::ONE); + for i in 1..half_m.max(2) { + root_row.push(root_row[i - 1] * base); + } + root_table.push(root_row); + } + root_table +} + +#[allow(unused)] +fn ifft( + poly: &mut FieldType, + zero_factor: usize, + root_table: &FftRootTable, +) { + let n = poly.len(); + let lg_n = log2_strict(n); + let n_inv = (E::BaseField::ONE + E::BaseField::ONE) + .invert() + .unwrap() + .pow([lg_n as u64]); + + fft(poly, zero_factor, root_table); + + // We reverse all values except the first, and divide each by n. + field_type_index_mul_base(poly, 0, &n_inv); + field_type_index_mul_base(poly, n / 2, &n_inv); + vec_mut!(|poly| for i in 1..(n / 2) { + let j = n - i; + let coeffs_i = poly[j] * n_inv; + let coeffs_j = poly[i] * n_inv; + poly[i] = coeffs_i; + poly[j] = coeffs_j; + }) +} + +/// Core FFT implementation. +fn fft_classic_inner( + values: &mut FieldType, + r: usize, + lg_n: usize, + root_table: &[Vec], +) { + // We've already done the first lg_packed_width (if they were required) iterations. + + for (lg_half_m, cur_root_table) in root_table.iter().enumerate().take(lg_n).skip(r) { + let n = 1 << lg_n; + let lg_m = lg_half_m + 1; + let m = 1 << lg_m; // Subarray size (in field elements). + let half_m = m / 2; + debug_assert!(half_m != 0); + + // omega values for this iteration, as slice of vectors + let omega_table = &cur_root_table[..]; + vec_mut!(|values| { + for k in (0..n).step_by(m) { + for j in 0..half_m { + let omega = omega_table[j]; + let t = values[k + half_m + j] * omega; + let u = values[k + j]; + values[k + j] = u + t; + values[k + half_m + j] = u - t; + } + } + }) + } +} + +/// FFT implementation based on Section 32.3 of "Introduction to +/// Algorithms" by Cormen et al. +/// +/// The parameter r signifies that the first 1/2^r of the entries of +/// input may be non-zero, but the last 1 - 1/2^r entries are +/// definitely zero. +pub fn fft( + values: &mut FieldType, + r: usize, + root_table: &[Vec], +) { + vec_mut!(|values| reverse_index_bits_in_place(values)); + + let n = values.len(); + let lg_n = log2_strict(n); + + if root_table.len() != lg_n { + panic!( + "Expected root table of length {}, but it was {}.", + lg_n, + root_table.len() + ); + } + + // After reverse_index_bits, the only non-zero elements of values + // are at indices i*2^r for i = 0..n/2^r. The loop below copies + // the value at i*2^r to the positions [i*2^r + 1, i*2^r + 2, ..., + // (i+1)*2^r - 1]; i.e. it replaces the 2^r - 1 zeros following + // element i*2^r with the value at i*2^r. This corresponds to the + // first r rounds of the FFT when there are 2^r zeros at the end + // of the original input. + if r > 0 { + // if r == 0 then this loop is a noop. + let mask = !((1 << r) - 1); + match values { + FieldType::Base(values) => { + for i in 0..n { + values[i] = values[i & mask]; + } + } + FieldType::Ext(values) => { + for i in 0..n { + values[i] = values[i & mask]; + } + } + _ => panic!("Unsupported field type"), + } + } + + fft_classic_inner::(values, r, lg_n, root_table); +} + +pub fn coset_fft( + coeffs: &mut FieldType, + shift: E::BaseField, + zero_factor: usize, + root_table: &[Vec], +) { + let mut shift_power = E::BaseField::ONE; + vec_mut!(|coeffs| { + for coeff in coeffs.iter_mut() { + *coeff *= shift_power; + shift_power *= shift; + } + }); + fft(coeffs, zero_factor, root_table); +} + +#[derive(Debug, Clone)] +pub struct RSCodeDefaultSpec {} + +impl RSCodeSpec for RSCodeDefaultSpec { + fn get_number_queries() -> usize { + 336 + } + + fn get_rate_log() -> usize { + 3 + } + + fn get_basecode_msg_size_log() -> usize { + 7 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct RSCodeParameters { + pub(crate) fft_root_table: FftRootTable, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct RSCodeProverParameters { + pub(crate) fft_root_table: FftRootTable, + pub(crate) gamma_powers: Vec, + pub(crate) gamma_powers_inv_div_two: Vec, + pub(crate) full_message_size_log: usize, +} + +impl EncodingProverParameters for RSCodeProverParameters { + fn get_max_message_size_log(&self) -> usize { + self.full_message_size_log + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RSCodeVerifierParameters +where + E::BaseField: Serialize + DeserializeOwned, +{ + /// The verifier also needs a FFT table (much smaller) + /// for small-size encoding. It contains the same roots as the + /// prover's version for the first few levels (i < basecode_msg_size_log) + /// For the other levels (i >= basecode_msg_size_log), + /// it contains only the g^(2^i). + pub(crate) fft_root_table: FftRootTable, + pub(crate) full_message_size_log: usize, + pub(crate) gamma_powers: Vec, + pub(crate) gamma_powers_inv_div_two: Vec, +} + +#[derive(Debug, Clone)] +pub struct RSCode { + _phantom_data: PhantomData, +} + +impl EncodingScheme for RSCode +where + E::BaseField: Serialize + DeserializeOwned, +{ + type PublicParameters = RSCodeParameters; + + type ProverParameters = RSCodeProverParameters; + + type VerifierParameters = RSCodeVerifierParameters; + + fn setup(max_message_size_log: usize, _rng_seed: [u8; 32]) -> Self::PublicParameters { + RSCodeParameters { + fft_root_table: fft_root_table(max_message_size_log + Spec::get_rate_log()), + } + } + + fn trim( + pp: &Self::PublicParameters, + max_message_size_log: usize, + ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { + if pp.fft_root_table.len() < max_message_size_log + Spec::get_rate_log() { + return Err(Error::InvalidPcsParam(format!( + "Public parameter is setup for a smaller message size (log={}) than the trimmed message size (log={})", + pp.fft_root_table.len() - Spec::get_rate_log(), + max_message_size_log, + ))); + } + let mut gamma_powers = Vec::with_capacity(max_message_size_log); + let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log); + gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR); + gamma_powers_inv.push(E::BaseField::MULTIPLICATIVE_GENERATOR.invert().unwrap()); + for i in 1..max_message_size_log + Spec::get_rate_log() { + gamma_powers.push(gamma_powers[i - 1].square()); + gamma_powers_inv.push(gamma_powers_inv[i - 1].square()); + } + let inv_of_two = E::BaseField::from(2).invert().unwrap(); + gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two); + Ok(( + Self::ProverParameters { + fft_root_table: pp.fft_root_table[..max_message_size_log + Spec::get_rate_log()] + .to_vec(), + gamma_powers: gamma_powers.clone(), + gamma_powers_inv_div_two: gamma_powers_inv.clone(), + full_message_size_log: max_message_size_log, + }, + Self::VerifierParameters { + fft_root_table: pp.fft_root_table + [..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()] + .iter() + .cloned() + .chain( + pp.fft_root_table + [Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..] + .iter() + .map(|v| vec![v[1]]), + ) + .collect(), + full_message_size_log: max_message_size_log, + gamma_powers, + gamma_powers_inv_div_two: gamma_powers_inv, + }, + )) + } + + fn encode(pp: &Self::ProverParameters, coeffs: &FieldType) -> FieldType { + // Use the full message size to determine the shift factor. + Self::encode_internal(&pp.fft_root_table, coeffs, pp.full_message_size_log) + } + + fn encode_small(vp: &Self::VerifierParameters, coeffs: &FieldType) -> FieldType { + // Use the full message size to determine the shift factor. + Self::encode_internal(&vp.fft_root_table, coeffs, vp.full_message_size_log) + } + + fn get_number_queries() -> usize { + Spec::get_number_queries() + } + + fn get_rate_log() -> usize { + Spec::get_rate_log() + } + + fn get_basecode_msg_size_log() -> usize { + Spec::get_basecode_msg_size_log() + } + + fn message_is_left_and_right_folding() -> bool { + false + } + + fn prover_folding_coeffs(pp: &Self::ProverParameters, level: usize, index: usize) -> (E, E, E) { + // The coefficients are for the bit-reversed codeword, so reverse the + // bits before providing the coefficients. + let index = reverse_bits(index, level); + // level is the logarithmic of the codeword size after folded. + // Therefore, the domain after folded is gamma^2^(full_log_n - level) H + // where H is the multiplicative subgroup of size 2^level. + // The element at index i in this domain is + // gamma^2^(full_log_n - level) * ((2^level)-th root of unity)^i + // The x0 and x1 are exactly the two square roots, i.e., + // x0 = gamma^2^(full_log_n - level - 1) * ((2^(level+1))-th root of unity)^i + // Since root_table[i] stores the first half of the powers of + // the 2^(i+1)-th roots of unity, we can avoid recomputing them. + let x0 = if index < (1 << level) { + pp.fft_root_table[level][index] + } else { + -pp.fft_root_table[level][index - (1 << level)] + } * pp.gamma_powers[pp.full_message_size_log + Spec::get_rate_log() - level - 1]; + let x1 = -x0; + // The weight is 1/(x1-x0) = -1/(2x0) + // = -1/2 * (gamma^{-1})^2^(full_codeword_log_n - level - 1) * ((2^(level+1))-th root of unity)^{2^(level+1)-i} + let w = -pp.gamma_powers_inv_div_two + [pp.full_message_size_log + Spec::get_rate_log() - level - 1] + * if index == 0 { + E::BaseField::ONE + } else if index < (1 << level) { + -pp.fft_root_table[level][(1 << level) - index] + } else if index == 1 << level { + -E::BaseField::ONE + } else { + pp.fft_root_table[level][(1 << (level + 1)) - index] + }; + (E::from(x0), E::from(x1), E::from(w)) + } + + fn verifier_folding_coeffs( + vp: &Self::VerifierParameters, + level: usize, + index: usize, + ) -> (E, E, E) { + // The coefficients are for the bit-reversed codeword, so reverse the + // bits before providing the coefficients. + let index = reverse_bits(index, level); + // The same as prover_folding_coeffs, exept that the powers of + // g is computed on the fly for levels exceeding the root table. + let x0 = if level < Spec::get_basecode_msg_size_log() + Spec::get_rate_log() { + if index < (1 << level) { + vp.fft_root_table[level][index] + } else { + -vp.fft_root_table[level][index - (1 << level)] + } + } else { + // In this case, the level-th row of fft root table of the verifier + // only stores the first 2^(level+1)-th roots of unity. + vp.fft_root_table[level][0].pow([index as u64]) + } * vp.gamma_powers[vp.full_message_size_log + Spec::get_rate_log() - level - 1]; + let x1 = -x0; + // The weight is 1/(x1-x0) = -1/(2x0) + // = -1/2 * (gamma^{-1})^2^(full_log_n - level - 1) * ((2^(level+1))-th root of unity)^{2^(level+1)-i} + let w = -vp.gamma_powers_inv_div_two + [vp.full_message_size_log + Spec::get_rate_log() - level - 1] + * if level < Spec::get_basecode_msg_size_log() + Spec::get_rate_log() { + if index == 0 { + E::BaseField::ONE + } else if index < (1 << level) { + -vp.fft_root_table[level][(1 << level) - index] + } else if index == 1 << level { + -E::BaseField::ONE + } else { + vp.fft_root_table[level][(1 << (level + 1)) - index] + } + } else { + // In this case, this level of fft root table of the verifier + // only stores the first 2^(level+1)-th root of unity. + vp.fft_root_table[level][0].pow([(1 << (level + 1)) - index as u64]) + }; + (E::from(x0), E::from(x1), E::from(w)) + } +} + +impl RSCode { + fn encode_internal( + fft_root_table: &FftRootTable, + coeffs: &FieldType, + full_message_size_log: usize, + ) -> FieldType + where + E::BaseField: Serialize + DeserializeOwned, + { + let lg_m = log2_strict(coeffs.len()); + let fft_root_table = &fft_root_table[..lg_m + Spec::get_rate_log()]; + assert!( + lg_m <= full_message_size_log, + "Encoded message exceeds the maximum supported message size of the table." + ); + let rate = 1 << Spec::get_rate_log(); + let mut ret = match coeffs { + FieldType::Base(coeffs) => { + let mut coeffs = coeffs.clone(); + coeffs.extend(itertools::repeat_n( + E::BaseField::ZERO, + coeffs.len() * (rate - 1), + )); + FieldType::Base(coeffs) + } + FieldType::Ext(coeffs) => { + let mut coeffs = coeffs.clone(); + coeffs.extend(itertools::repeat_n(E::ZERO, coeffs.len() * (rate - 1))); + FieldType::Ext(coeffs) + } + _ => panic!("Unsupported field type"), + }; + // Let gamma be the multiplicative generator of the base field. + // The full domain is gamma H where H is the multiplicative subgroup + // of size n * rate. + // When the input message size is not n, but n/2^k, then the domain is + // gamma^2^k H. + let k = 1 << (full_message_size_log - lg_m); + coset_fft( + &mut ret, + E::BaseField::MULTIPLICATIVE_GENERATOR.pow([k]), + Spec::get_rate_log(), + fft_root_table, + ); + ret + } + + #[allow(unused)] + fn folding_coeffs_naive( + level: usize, + index: usize, + full_message_size_log: usize, + ) -> (E, E, E) { + // The coefficients are for the bit-reversed codeword, so reverse the + // bits before providing the coefficients. + let index = reverse_bits(index, level); + // x0 is the index-th 2^(level+1)-th root of unity, multiplied by + // the shift factor at level+1, which is gamma^2^(full_codeword_log_n - level - 1). + let x0 = E::BaseField::ROOT_OF_UNITY + .pow([1 << (E::BaseField::S - (level as u32 + 1))]) + .pow([index as u64]) + * E::BaseField::MULTIPLICATIVE_GENERATOR + .pow([1 << (full_message_size_log + Spec::get_rate_log() - level - 1)]); + let x1 = -x0; + let w = (x1 - x0).invert().unwrap(); + (E::from(x0), E::from(x1), E::from(w)) + } +} + +#[allow(unused)] +fn naive_fft(poly: &[E], rate: usize, shift: E::BaseField) -> Vec { + let timer = start_timer!(|| "Encode RSCode"); + let message_size = poly.len(); + let domain_size_bit = log2_strict(message_size * rate); + let root = E::BaseField::ROOT_OF_UNITY.pow([1 << (E::BaseField::S - domain_size_bit as u32)]); + // The domain is shift * H where H is the multiplicative subgroup of size + // message_size * rate. + let mut domain = Vec::::with_capacity(message_size * rate); + domain.push(shift); + for i in 1..message_size * rate { + domain.push(domain[i - 1] * root); + } + let mut res = vec![E::ZERO; message_size * rate]; + res.iter_mut() + .enumerate() + .for_each(|(i, target)| *target = horner(poly, &E::from(domain[i]))); + end_timer!(timer); + + res +} + +#[cfg(test)] +mod tests { + use crate::{ + basefold::encoding::test_util::test_codeword_folding, + util::{field_type_index_ext, plonky2_util::reverse_index_bits_in_place_field_type}, + }; + + use super::*; + use goldilocks::{Goldilocks, GoldilocksExt2}; + + #[test] + fn test_naive_fft() { + let num_vars = 5; + + let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let mut poly2 = FieldType::Ext(poly.clone()); + + let naive = naive_fft::(&poly, 1, Goldilocks::ONE); + + let root_table = fft_root_table(num_vars); + fft::(&mut poly2, 0, &root_table); + + let poly2 = match poly2 { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + assert_eq!(naive, poly2); + } + + #[test] + fn test_naive_fft_with_shift() { + use rand::rngs::OsRng; + let num_vars = 5; + + let poly: Vec = (0..(1 << num_vars)) + .map(|_| GoldilocksExt2::random(&mut OsRng)) + .collect(); + let mut poly2 = FieldType::Ext(poly.clone()); + + let naive = naive_fft::(&poly, 1, Goldilocks::MULTIPLICATIVE_GENERATOR); + + let root_table = fft_root_table(num_vars); + coset_fft::( + &mut poly2, + Goldilocks::MULTIPLICATIVE_GENERATOR, + 0, + &root_table, + ); + + let poly2 = match poly2 { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + assert_eq!(naive, poly2); + } + + #[test] + fn test_naive_fft_with_rate() { + use rand::rngs::OsRng; + let num_vars = 5; + let rate_bits = 1; + + let poly: Vec = (0..(1 << num_vars)) + .map(|_| GoldilocksExt2::random(&mut OsRng)) + .collect(); + let mut poly2 = vec![GoldilocksExt2::ZERO; poly.len() * (1 << rate_bits)]; + poly2.as_mut_slice()[..poly.len()].copy_from_slice(poly.as_slice()); + let mut poly2 = FieldType::Ext(poly2.clone()); + + let naive = naive_fft::( + &poly, + 1 << rate_bits, + Goldilocks::MULTIPLICATIVE_GENERATOR, + ); + + let root_table = fft_root_table(num_vars + rate_bits); + coset_fft::( + &mut poly2, + Goldilocks::MULTIPLICATIVE_GENERATOR, + rate_bits, + &root_table, + ); + + let poly2 = match poly2 { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + assert_eq!(naive, poly2); + } + + #[test] + fn test_ifft() { + let num_vars = 5; + + let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let mut poly = FieldType::Ext(poly); + let original = poly.clone(); + + let root_table = fft_root_table(num_vars); + fft::(&mut poly, 0, &root_table); + ifft::(&mut poly, 0, &root_table); + + assert_eq!(original, poly); + } + + #[test] + fn prover_verifier_consistency() { + type Code = RSCode; + let pp: RSCodeParameters = Code::setup(10, [0; 32]); + let (pp, vp) = Code::trim(&pp, 10).unwrap(); + for level in 0..(10 + >::get_rate_log()) { + for index in 0..(1 << level) { + let (naive_x0, naive_x1, naive_w) = + Code::folding_coeffs_naive(level, index, pp.full_message_size_log); + let (p_x0, p_x1, p_w) = Code::prover_folding_coeffs(&pp, level, index); + let (v_x0, v_x1, v_w) = Code::verifier_folding_coeffs(&vp, level, index); + // assert_eq!(v_w * (v_x1 - v_x0), GoldilocksExt2::ONE); + // assert_eq!(p_w * (p_x1 - p_x0), GoldilocksExt2::ONE); + assert_eq!( + (v_x0, v_x1, v_w, p_x0, p_x1, p_w), + (naive_x0, naive_x1, naive_w, naive_x0, naive_x1, naive_w), + "failed for level = {}, index = {}", + level, + index + ); + } + } + } + + #[test] + fn test_rs_codeword_folding() { + test_codeword_folding::>(); + } + + type E = GoldilocksExt2; + type F = Goldilocks; + type Code = RSCode; + + #[test] + pub fn test_colinearity() { + let num_vars = 10; + + let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly = FieldType::Ext(poly); + + let rng_seed = [0; 32]; + let pp = >::setup(num_vars, rng_seed); + let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let mut codeword = Code::encode(&pp, &poly); + reverse_index_bits_in_place_field_type(&mut codeword); + let challenge = E::from(2); + let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); + let codeword = match codeword { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + + for (i, (a, b)) in folded_codeword.iter().zip(codeword.chunks(2)).enumerate() { + let (x0, x1, _) = Code::prover_folding_coeffs( + &pp, + num_vars + >::get_rate_log() - 1, + i, + ); + // Check that (x0, b[0]), (x1, b[1]) and (challenge, a) are + // on the same line, i.e., + // (b[0]-a)/(x0-challenge) = (b[1]-a)/(x1-challenge) + // which is equivalent to + // (x0-challenge)*(b[1]-a) = (x1-challenge)*(b[0]-a) + assert_eq!( + (x0 - challenge) * (b[1] - a), + (x1 - challenge) * (b[0] - a), + "failed for i = {}", + i + ); + } + } + + #[test] + pub fn test_low_degree() { + let num_vars = 10; + + let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly = FieldType::Ext(poly); + + let rng_seed = [0; 32]; + let pp = >::setup(num_vars, rng_seed); + let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let mut codeword = Code::encode(&pp, &poly); + check_low_degree(&codeword, "low degree check for original codeword"); + let c0 = field_type_index_ext(&codeword, 0); + let c_mid = field_type_index_ext(&codeword, codeword.len() >> 1); + let c1 = field_type_index_ext(&codeword, 1); + let c_mid1 = field_type_index_ext(&codeword, (codeword.len() >> 1) + 1); + + reverse_index_bits_in_place_field_type(&mut codeword); + // After the bit inversion, the first element is still the first, + // but the middle one is switched to the second. + assert_eq!(c0, field_type_index_ext(&codeword, 0)); + assert_eq!(c_mid, field_type_index_ext(&codeword, 1)); + // The second element is placed at the middle, and the next to middle + // element is still at the place. + assert_eq!(c1, field_type_index_ext(&codeword, codeword.len() >> 1)); + assert_eq!( + c_mid1, + field_type_index_ext(&codeword, (codeword.len() >> 1) + 1) + ); + + // For RS codeword, the addition of the left and right halves is also + // a valid codeword + let codeword_vec = match &codeword { + FieldType::Ext(coeffs) => coeffs.clone(), + _ => panic!("Wrong field type"), + }; + let mut left_right_sum: Vec = codeword_vec + .chunks(2) + .map(|chunk| chunk[0] + chunk[1]) + .collect(); + assert_eq!(left_right_sum[0], c0 + c_mid); + reverse_index_bits_in_place(&mut left_right_sum); + assert_eq!(left_right_sum[1], c1 + c_mid1); + check_low_degree( + &FieldType::Ext(left_right_sum.clone()), + "check low degree of left+right", + ); + + // The the difference of the left and right halves is also + // a valid codeword after twisted by omega^(-i), regardless of the + // shift of the coset. + let mut left_right_diff: Vec = codeword_vec + .chunks(2) + .map(|chunk| chunk[0] - chunk[1]) + .collect(); + assert_eq!(left_right_diff[0], c0 - c_mid); + reverse_index_bits_in_place(&mut left_right_diff); + assert_eq!(left_right_diff[1], c1 - c_mid1); + let root_of_unity_inv = F::ROOT_OF_UNITY_INV + .pow([1 << (F::S as usize - log2_strict(left_right_diff.len()) - 1)]); + for (i, coeff) in left_right_diff.iter_mut().enumerate() { + *coeff *= root_of_unity_inv.pow([i as u64]); + } + assert_eq!(left_right_diff[0], c0 - c_mid); + assert_eq!(left_right_diff[1], (c1 - c_mid1) * root_of_unity_inv); + check_low_degree( + &FieldType::Ext(left_right_diff.clone()), + "check low degree of (left-right)*omega^(-i)", + ); + + let challenge = E::from(2); + let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); + let c_fold = folded_codeword[0]; + let c_fold1 = folded_codeword[folded_codeword.len() >> 1]; + let mut folded_codeword = FieldType::Ext(folded_codeword); + reverse_index_bits_in_place_field_type(&mut folded_codeword); + assert_eq!(c_fold, field_type_index_ext(&folded_codeword, 0)); + assert_eq!(c_fold1, field_type_index_ext(&folded_codeword, 1)); + + // The top level folding coefficient should have shift factor gamma + let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 0); + assert_eq!(folding_coeffs.0, E::from(F::MULTIPLICATIVE_GENERATOR)); + assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); + assert_eq!( + (folding_coeffs.1 - folding_coeffs.0) * folding_coeffs.2, + E::ONE + ); + // The three points (x0, c0), (x1, c_mid), (challenge, c_fold) should + // be colinear + assert_eq!( + (c_mid - c_fold) * (folding_coeffs.0 - challenge), + (c0 - c_fold) * (folding_coeffs.1 - challenge), + ); + // So the folded value should be equal to + // (gamma^{-1} * alpha * (c0 - c_mid) + (c0 + c_mid)) / 2 + assert_eq!( + c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), + challenge * (c0 - c_mid) + (c0 + c_mid) * F::MULTIPLICATIVE_GENERATOR + ); + assert_eq!( + c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), + challenge * left_right_diff[0] + left_right_sum[0] * F::MULTIPLICATIVE_GENERATOR + ); + assert_eq!( + c_fold * F::from(2), + challenge * left_right_diff[0] * F::MULTIPLICATIVE_GENERATOR.invert().unwrap() + + left_right_sum[0] + ); + + let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 1); + let root_of_unity = + F::ROOT_OF_UNITY.pow([1 << (F::S as usize - log2_strict(codeword.len()))]); + assert_eq!(root_of_unity.pow([codeword.len() as u64]), F::ONE); + assert_eq!(root_of_unity.pow([(codeword.len() >> 1) as u64]), -F::ONE); + assert_eq!( + folding_coeffs.0, + E::from(F::MULTIPLICATIVE_GENERATOR) + * E::from(root_of_unity).pow([(codeword.len() >> 2) as u64]) + ); + assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); + assert_eq!( + (folding_coeffs.1 - folding_coeffs.0) * folding_coeffs.2, + E::ONE + ); + + // The folded codeword is the linear combination of the left+right and the + // twisted left-right vectors. + // The coefficients are respectively 1/2 and gamma^{-1}/2 * alpha. + // In another word, the folded codeword multipled by 2 is the linear + // combination by coeffs: 1 and gamma^{-1} * alpha + let gamma_inv = F::MULTIPLICATIVE_GENERATOR.invert().unwrap(); + let b = challenge * gamma_inv; + let folded_codeword_vec = match &folded_codeword { + FieldType::Ext(coeffs) => coeffs.clone(), + _ => panic!("Wrong field type"), + }; + assert_eq!( + c_fold * F::from(2), + left_right_diff[0] * b + left_right_sum[0] + ); + for (i, (c, (diff, sum))) in folded_codeword_vec + .iter() + .zip(left_right_diff.iter().zip(left_right_sum.iter())) + .enumerate() + { + assert_eq!(*c + c, *sum + b * diff, "failed for i = {}", i); + } + + check_low_degree(&folded_codeword, "low degree check for folded"); + } + + fn check_low_degree(codeword: &FieldType, message: &str) { + let mut codeword = codeword.clone(); + let codeword_bits = log2_strict(codeword.len()); + let root_table = fft_root_table(codeword_bits); + let original = codeword.clone(); + ifft(&mut codeword, 0, &root_table); + for i in (codeword.len() >> >::get_rate_log())..codeword.len() { + assert_eq!( + field_type_index_ext(&codeword, i), + E::ZERO, + "{}: zero check failed for i = {}", + message, + i + ) + } + fft(&mut codeword, 0, &root_table); + let original = match original { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + let codeword = match codeword { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + original + .iter() + .zip(codeword.iter()) + .enumerate() + .for_each(|(i, (a, b))| { + assert_eq!(a, b, "{}: failed for i = {}", message, i); + }); + } +} diff --git a/mpcs/src/basefold/encoding/utils.rs b/mpcs/src/basefold/encoding/utils.rs new file mode 100644 index 000000000..36137526a --- /dev/null +++ b/mpcs/src/basefold/encoding/utils.rs @@ -0,0 +1,35 @@ +#[macro_export] +macro_rules! vec_mut { + ($a:ident, |$tmp_a:ident| $op:expr) => { + match $a { + multilinear_extensions::mle::FieldType::Base(ref mut $tmp_a) => $op, + multilinear_extensions::mle::FieldType::Ext(ref mut $tmp_a) => $op, + _ => unreachable!(), + } + }; + (|$a:ident| $op:expr) => { + vec_mut!($a, |$a| $op) + }; +} + +#[macro_export] +macro_rules! vec_map { + ($a:ident, |$tmp_a:ident| $op:expr) => { + match &$a { + multilinear_extensions::mle::FieldType::Base(a) => { + let $tmp_a = &a[..]; + let out = $op; + multilinear_extensions::mle::FieldType::Base(out) + } + multilinear_extensions::mle::FieldType::Ext(a) => { + let $tmp_a = &a[..]; + let out = $op; + multilinear_extensions::mle::FieldType::Base(out) + } + _ => unreachable!(), + } + }; + (|$a:ident| $op:expr) => { + vec_map!($a, |$a| $op) + }; +} diff --git a/mpcs/src/basefold/query_phase.rs b/mpcs/src/basefold/query_phase.rs index dacd254a1..3a8796f16 100644 --- a/mpcs/src/basefold/query_phase.rs +++ b/mpcs/src/basefold/query_phase.rs @@ -1,43 +1,49 @@ -use super::basecode::{encode_rs_basecode, query_point}; use crate::util::{ arithmetic::{ - degree_2_eval, degree_2_zero_plus_one, inner_product, interpolate2, + degree_2_eval, degree_2_zero_plus_one, inner_product, interpolate2_weights, interpolate_over_boolean_hypercube, }, - ext_to_usize, + ext_to_usize, field_type_index_base, field_type_index_ext, hash::{Digest, Hasher}, log2_strict, merkle_tree::{MerklePathWithoutLeafOrRoot, MerkleTree}, - transcript::{TranscriptRead, TranscriptWrite}, }; -use aes::cipher::KeyIvInit; use ark_std::{end_timer, start_timer}; use core::fmt::Debug; -use ctr; use ff_ext::ExtensionField; -use generic_array::GenericArray; - use itertools::Itertools; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use transcript::Transcript; use multilinear_extensions::mle::FieldType; -use crate::util::plonky2_util::{reverse_bits, reverse_index_bits_in_place}; -use rand_chacha::{rand_core::RngCore, ChaCha8Rng}; -use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; +use crate::util::plonky2_util::reverse_index_bits_in_place; +use rayon::{ + iter::IndexedParallelIterator, + prelude::{IntoParallelRefIterator, ParallelIterator}, +}; -use super::structure::{BasefoldCommitment, BasefoldCommitmentWithData}; +use super::{ + encoding::EncodingScheme, + structure::{BasefoldCommitment, BasefoldCommitmentWithData, BasefoldSpec}, +}; -pub fn query_phase( - transcript: &mut impl TranscriptWrite, E>, +pub fn prover_query_phase( + transcript: &mut Transcript, comm: &BasefoldCommitmentWithData, - oracles: &Vec>, + oracles: &[Vec], num_verifier_queries: usize, ) -> QueriesResult where E::BaseField: Serialize + DeserializeOwned, { - let queries = transcript.squeeze_challenges(num_verifier_queries); + let queries: Vec<_> = (0..num_verifier_queries) + .map(|_| { + transcript + .get_and_append_challenge(b"query indices") + .elements + }) + .collect(); // Transform the challenge queries from field elements into integers let queries_usize: Vec = queries @@ -51,16 +57,316 @@ where .map(|x_index| { ( *x_index, - basefold_get_query::(comm.get_codeword(), &oracles, *x_index), + basefold_get_query::(&comm.get_codewords()[0], oracles, *x_index), + ) + }) + .collect(), + } +} + +pub fn batch_prover_query_phase( + transcript: &mut Transcript, + codeword_size: usize, + comms: &[BasefoldCommitmentWithData], + oracles: &[Vec], + num_verifier_queries: usize, +) -> BatchedQueriesResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + let queries: Vec<_> = (0..num_verifier_queries) + .map(|_| { + transcript + .get_and_append_challenge(b"query indices") + .elements + }) + .collect(); + + // Transform the challenge queries from field elements into integers + let queries_usize: Vec = queries + .iter() + .map(|x_index| ext_to_usize(x_index) % codeword_size) + .collect_vec(); + + BatchedQueriesResult { + inner: queries_usize + .par_iter() + .map(|x_index| { + ( + *x_index, + batch_basefold_get_query::(comms, oracles, codeword_size, *x_index), + ) + }) + .collect(), + } +} + +pub fn simple_batch_prover_query_phase( + transcript: &mut Transcript, + comm: &BasefoldCommitmentWithData, + oracles: &[Vec], + num_verifier_queries: usize, +) -> SimpleBatchQueriesResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + let queries: Vec<_> = (0..num_verifier_queries) + .map(|_| { + transcript + .get_and_append_challenge(b"query indices") + .elements + }) + .collect(); + + // Transform the challenge queries from field elements into integers + let queries_usize: Vec = queries + .iter() + .map(|x_index| ext_to_usize(x_index) % comm.codeword_size()) + .collect_vec(); + + SimpleBatchQueriesResult { + inner: queries_usize + .par_iter() + .map(|x_index| { + ( + *x_index, + simple_batch_basefold_get_query::(comm.get_codewords(), oracles, *x_index), ) }) .collect(), } } +#[allow(clippy::too_many_arguments)] +pub fn verifier_query_phase>( + indices: &[usize], + vp: &>::VerifierParameters, + queries: &QueriesResultWithMerklePath, + sum_check_messages: &[Vec], + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_message: &[E], + roots: &Vec>, + comm: &BasefoldCommitment, + partial_eq: &[E], + eval: &E, + hasher: &Hasher, +) where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Verifier query phase"); + + let encode_timer = start_timer!(|| "Encode final codeword"); + let mut message = final_message.to_vec(); + interpolate_over_boolean_hypercube(&mut message); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut message); + } + let final_codeword = + >::encode_small(vp, &FieldType::Ext(message)); + let mut final_codeword = match final_codeword { + FieldType::Ext(final_codeword) => final_codeword, + _ => panic!("Final codeword must be extension field"), + }; + reverse_index_bits_in_place(&mut final_codeword); + end_timer!(encode_timer); + + queries.check::( + indices, + vp, + fold_challenges, + num_rounds, + num_vars, + &final_codeword, + roots, + comm, + hasher, + ); + + let final_timer = start_timer!(|| "Final checks"); + assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); + + // The sum-check part of the protocol + for i in 0..fold_challenges.len() - 1 { + assert_eq!( + degree_2_eval(&sum_check_messages[i], fold_challenges[i]), + degree_2_zero_plus_one(&sum_check_messages[i + 1]) + ); + } + + // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial + // sent from the prover + assert_eq!( + degree_2_eval( + &sum_check_messages[fold_challenges.len() - 1], + fold_challenges[fold_challenges.len() - 1] + ), + inner_product(final_message, partial_eq) + ); + end_timer!(final_timer); + + end_timer!(timer); +} + +#[allow(clippy::too_many_arguments)] +pub fn batch_verifier_query_phase>( + indices: &[usize], + vp: &>::VerifierParameters, + queries: &BatchedQueriesResultWithMerklePath, + sum_check_messages: &[Vec], + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_message: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], + coeffs: &[E], + partial_eq: &[E], + eval: &E, + hasher: &Hasher, +) where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Verifier batch query phase"); + let encode_timer = start_timer!(|| "Encode final codeword"); + let mut message = final_message.to_vec(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut message); + } + interpolate_over_boolean_hypercube(&mut message); + let final_codeword = + >::encode_small(vp, &FieldType::Ext(message)); + let mut final_codeword = match final_codeword { + FieldType::Ext(final_codeword) => final_codeword, + _ => panic!("Final codeword must be extension field"), + }; + reverse_index_bits_in_place(&mut final_codeword); + end_timer!(encode_timer); + + // For computing the weights on the fly, because the verifier is incapable of storing + // the weights. + + queries.check::( + indices, + vp, + fold_challenges, + num_rounds, + num_vars, + &final_codeword, + roots, + comms, + coeffs, + hasher, + ); + + #[allow(unused)] + let final_timer = start_timer!(|| "Final checks"); + assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); + + // The sum-check part of the protocol + for i in 0..fold_challenges.len() - 1 { + assert_eq!( + degree_2_eval(&sum_check_messages[i], fold_challenges[i]), + degree_2_zero_plus_one(&sum_check_messages[i + 1]) + ); + } + + // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial + // sent from the prover + assert_eq!( + degree_2_eval( + &sum_check_messages[fold_challenges.len() - 1], + fold_challenges[fold_challenges.len() - 1] + ), + inner_product(final_message, partial_eq) + ); + end_timer!(final_timer); + end_timer!(timer); +} + +#[allow(clippy::too_many_arguments)] +pub fn simple_batch_verifier_query_phase>( + indices: &[usize], + vp: &>::VerifierParameters, + queries: &SimpleBatchQueriesResultWithMerklePath, + sum_check_messages: &[Vec], + fold_challenges: &[E], + batch_coeffs: &[E], + num_rounds: usize, + num_vars: usize, + final_message: &[E], + roots: &[Digest], + comm: &BasefoldCommitment, + partial_eq: &[E], + evals: &[E], + hasher: &Hasher, +) where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Verifier query phase"); + + let encode_timer = start_timer!(|| "Encode final codeword"); + let mut message = final_message.to_vec(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut message); + } + interpolate_over_boolean_hypercube(&mut message); + let final_codeword = + >::encode_small(vp, &FieldType::Ext(message)); + let mut final_codeword = match final_codeword { + FieldType::Ext(final_codeword) => final_codeword, + _ => panic!("Final codeword must be extension field"), + }; + reverse_index_bits_in_place(&mut final_codeword); + end_timer!(encode_timer); + + // For computing the weights on the fly, because the verifier is incapable of storing + // the weights. + queries.check::( + indices, + vp, + fold_challenges, + batch_coeffs, + num_rounds, + num_vars, + &final_codeword, + roots, + comm, + hasher, + ); + + let final_timer = start_timer!(|| "Final checks"); + assert_eq!( + &inner_product(batch_coeffs, evals), + °ree_2_zero_plus_one(&sum_check_messages[0]) + ); + + // The sum-check part of the protocol + for i in 0..fold_challenges.len() - 1 { + assert_eq!( + degree_2_eval(&sum_check_messages[i], fold_challenges[i]), + degree_2_zero_plus_one(&sum_check_messages[i + 1]) + ); + } + + // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial + // sent from the prover + assert_eq!( + degree_2_eval( + &sum_check_messages[fold_challenges.len() - 1], + fold_challenges[fold_challenges.len() - 1] + ), + inner_product(final_message, partial_eq) + ); + end_timer!(final_timer); + + end_timer!(timer); +} + fn basefold_get_query( poly_codeword: &FieldType, - oracles: &Vec>, + oracles: &[Vec], x_index: usize, ) -> SingleQueryResult where @@ -96,15 +402,15 @@ where inner: oracle_queries, }; - return SingleQueryResult { + SingleQueryResult { oracle_query, commitment_query, - }; + } } fn batch_basefold_get_query( - comms: &[&BasefoldCommitmentWithData], - oracles: &Vec>, + comms: &[BasefoldCommitmentWithData], + oracles: &[Vec], codeword_size: usize, x_index: usize, ) -> BatchedSingleQueryResult @@ -133,7 +439,7 @@ where let x_index = x_index >> (log2_strict(codeword_size) - comm.codeword_size_log()); let p1 = x_index | 1; let p0 = p1 - 1; - match comm.get_codeword() { + match &comm.get_codewords()[0] { FieldType::Ext(poly_codeword) => { CodewordSingleQueryResult::new_ext(poly_codeword[p0], poly_codeword[p1], p0) } @@ -155,42 +461,147 @@ where } } -#[derive(Debug, Copy, Clone, Serialize, Deserialize)] -enum CodewordPointPair { - Ext(E, E), - Base(E::BaseField, E::BaseField), -} - -impl CodewordPointPair { - pub fn as_ext(&self) -> (E, E) { - match self { - CodewordPointPair::Ext(x, y) => (*x, *y), - CodewordPointPair::Base(x, y) => (E::from(*x), E::from(*y)), - } - } -} - -#[derive(Debug, Copy, Clone, Serialize, Deserialize)] -struct CodewordSingleQueryResult -where - E::BaseField: Serialize + DeserializeOwned, -{ - codepoints: CodewordPointPair, - index: usize, -} - -impl CodewordSingleQueryResult +fn simple_batch_basefold_get_query( + poly_codewords: &[FieldType], + oracles: &[Vec], + x_index: usize, +) -> SimpleBatchSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, { - fn new_ext(left: E, right: E, index: usize) -> Self { - Self { - codepoints: CodewordPointPair::Ext(left, right), - index, - } - } + let mut index = x_index; + let p1 = index | 1; + let p0 = p1 - 1; - fn new_base(left: E::BaseField, right: E::BaseField, index: usize) -> Self { + let commitment_query = match poly_codewords[0] { + FieldType::Ext(_) => SimpleBatchCommitmentSingleQueryResult::new_ext( + poly_codewords + .iter() + .map(|c| field_type_index_ext(c, p0)) + .collect(), + poly_codewords + .iter() + .map(|c| field_type_index_ext(c, p1)) + .collect(), + p0, + ), + FieldType::Base(_) => SimpleBatchCommitmentSingleQueryResult::new_base( + poly_codewords + .iter() + .map(|c| field_type_index_base(c, p0)) + .collect(), + poly_codewords + .iter() + .map(|c| field_type_index_base(c, p1)) + .collect(), + p0, + ), + _ => unreachable!(), + }; + index >>= 1; + + let mut oracle_queries = Vec::with_capacity(oracles.len() + 1); + for oracle in oracles { + let p1 = index | 1; + let p0 = p1 - 1; + + oracle_queries.push(CodewordSingleQueryResult::new_ext( + oracle[p0], oracle[p1], p0, + )); + index >>= 1; + } + + let oracle_query = OracleListQueryResult { + inner: oracle_queries, + }; + + SimpleBatchSingleQueryResult { + oracle_query, + commitment_query, + } +} + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +enum CodewordPointPair { + Ext(E, E), + Base(E::BaseField, E::BaseField), +} + +impl CodewordPointPair { + pub fn as_ext(&self) -> (E, E) { + match self { + CodewordPointPair::Ext(x, y) => (*x, *y), + CodewordPointPair::Base(x, y) => (E::from(*x), E::from(*y)), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +enum SimpleBatchLeavesPair +where + E::BaseField: Serialize + DeserializeOwned, +{ + Ext(Vec<(E, E)>), + Base(Vec<(E::BaseField, E::BaseField)>), +} + +impl SimpleBatchLeavesPair +where + E::BaseField: Serialize + DeserializeOwned, +{ + #[allow(unused)] + pub fn as_ext(&self) -> Vec<(E, E)> { + match self { + SimpleBatchLeavesPair::Ext(x) => x.clone(), + SimpleBatchLeavesPair::Base(x) => { + x.iter().map(|(x, y)| ((*x).into(), (*y).into())).collect() + } + } + } + + pub fn batch(&self, coeffs: &[E]) -> (E, E) { + match self { + SimpleBatchLeavesPair::Ext(x) => { + let mut result = (E::ZERO, E::ZERO); + for (i, (x, y)) in x.iter().enumerate() { + result.0 += coeffs[i] * *x; + result.1 += coeffs[i] * *y; + } + result + } + SimpleBatchLeavesPair::Base(x) => { + let mut result = (E::ZERO, E::ZERO); + for (i, (x, y)) in x.iter().enumerate() { + result.0 += coeffs[i] * *x; + result.1 += coeffs[i] * *y; + } + result + } + } + } +} + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +struct CodewordSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + codepoints: CodewordPointPair, + index: usize, +} + +impl CodewordSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + fn new_ext(left: E, right: E, index: usize) -> Self { + Self { + codepoints: CodewordPointPair::Ext(left, right), + index, + } + } + + fn new_base(left: E::BaseField, right: E::BaseField, index: usize) -> Self { Self { codepoints: CodewordPointPair::Base(left, right), index, @@ -211,47 +622,21 @@ where } } - pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + pub fn write_transcript(&self, transcript: &mut Transcript) { match self.codepoints { CodewordPointPair::Ext(x, y) => { - transcript.write_field_element_ext(&x).unwrap(); - transcript.write_field_element_ext(&y).unwrap(); + transcript.append_field_element_ext(&x); + transcript.append_field_element_ext(&y); } CodewordPointPair::Base(x, y) => { - transcript.write_field_element_base(&x).unwrap(); - transcript.write_field_element_base(&y).unwrap(); + transcript.append_field_element(&x); + transcript.append_field_element(&y); } }; } - - pub fn read_transcript_ext( - transcript: &mut impl TranscriptRead, E>, - full_codeword_size_log: usize, - codeword_size_log: usize, - index: usize, - ) -> Self { - Self::new_ext( - transcript.read_field_element_ext().unwrap(), - transcript.read_field_element_ext().unwrap(), - index >> (full_codeword_size_log - codeword_size_log), - ) - } - - pub fn read_transcript_base( - transcript: &mut impl TranscriptRead, E>, - full_codeword_size_log: usize, - codeword_size_log: usize, - index: usize, - ) -> Self { - Self::new_base( - transcript.read_field_element_base().unwrap(), - transcript.read_field_element_base().unwrap(), - index >> (full_codeword_size_log - codeword_size_log), - ) - } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct CodewordSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, @@ -264,51 +649,11 @@ impl CodewordSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, { - pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + pub fn write_transcript(&self, transcript: &mut Transcript) { self.query.write_transcript(transcript); self.merkle_path.write_transcript(transcript); } - pub fn read_transcript_base( - transcript: &mut impl TranscriptRead, E>, - full_codeword_size_log: usize, - codeword_size_log: usize, - index: usize, - ) -> Self { - Self { - query: CodewordSingleQueryResult::read_transcript_base( - transcript, - full_codeword_size_log, - codeword_size_log, - index, - ), - merkle_path: MerklePathWithoutLeafOrRoot::read_transcript( - transcript, - codeword_size_log, - ), - } - } - - pub fn read_transcript_ext( - transcript: &mut impl TranscriptRead, E>, - full_codeword_size_log: usize, - codeword_size_log: usize, - index: usize, - ) -> Self { - Self { - query: CodewordSingleQueryResult::read_transcript_ext( - transcript, - full_codeword_size_log, - codeword_size_log, - index, - ), - merkle_path: MerklePathWithoutLeafOrRoot::read_transcript( - transcript, - codeword_size_log, - ), - } - } - pub fn check_merkle_path(&self, root: &Digest, hasher: &Hasher) { // let timer = start_timer!(|| "CodewordSingleQuery::Check Merkle Path"); match self.query.codepoints { @@ -351,7 +696,7 @@ where inner: Vec>, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct OracleListQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, @@ -359,35 +704,7 @@ where inner: Vec>, } -impl OracleListQueryResultWithMerklePath -where - E::BaseField: Serialize + DeserializeOwned, -{ - pub fn read_transcript( - transcript: &mut impl TranscriptRead, E>, - num_rounds: usize, - codeword_size_log: usize, - index: usize, - ) -> Self { - // Remember that the prover doesn't send the commitment in the last round. - // In the first round, the oracle is sent after folding, so the first oracle - // has half the size of the full codeword size. - Self { - inner: (0..num_rounds - 1) - .map(|round| { - CodewordSingleQueryResultWithMerklePath::read_transcript_ext( - transcript, - codeword_size_log, - codeword_size_log - round - 1, - index, - ) - }) - .collect(), - } - } -} - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct CommitmentsQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, @@ -395,55 +712,6 @@ where inner: Vec>, } -impl CommitmentsQueryResultWithMerklePath -where - E::BaseField: Serialize + DeserializeOwned, -{ - pub fn read_transcript_base( - transcript: &mut impl TranscriptRead, E>, - max_num_vars: usize, - poly_num_vars: &[usize], - log_rate: usize, - index: usize, - ) -> Self { - Self { - inner: poly_num_vars - .iter() - .map(|num_vars| { - CodewordSingleQueryResultWithMerklePath::read_transcript_base( - transcript, - max_num_vars + log_rate, - num_vars + log_rate, - index, - ) - }) - .collect(), - } - } - - pub fn read_transcript_ext( - transcript: &mut impl TranscriptRead, E>, - max_num_vars: usize, - poly_num_vars: &[usize], - log_rate: usize, - index: usize, - ) -> Self { - Self { - inner: poly_num_vars - .iter() - .map(|num_vars| { - CodewordSingleQueryResultWithMerklePath::read_transcript_ext( - transcript, - max_num_vars + log_rate, - num_vars + log_rate, - index, - ) - }) - .collect(), - } - } -} - impl ListQueryResult for OracleListQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -510,12 +778,9 @@ where ) -> Vec> { let ret = self .get_inner() - .into_iter() + .iter() .enumerate() - .map(|(i, query_result)| { - let path = path(i, query_result.index); - path - }) + .map(|(i, query_result)| path(i, query_result.index)) .collect_vec(); ret } @@ -537,7 +802,7 @@ where query_result .merkle_path(path) .into_iter() - .zip(query_result.get_inner_into().into_iter()) + .zip(query_result.get_inner_into()) .map( |(path, codeword_result)| CodewordSingleQueryResultWithMerklePath { query: codeword_result, @@ -548,13 +813,13 @@ where ) } - fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + fn write_transcript(&self, transcript: &mut Transcript) { self.get_inner() .iter() .for_each(|q| q.write_transcript(transcript)); } - fn check_merkle_paths(&self, roots: &Vec>, hasher: &Hasher) { + fn check_merkle_paths(&self, roots: &[Digest], hasher: &Hasher) { // let timer = start_timer!(|| "ListQuery::Check Merkle Path"); self.get_inner() .iter() @@ -575,7 +840,7 @@ where commitment_query: CodewordSingleQueryResult, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct SingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, @@ -590,16 +855,17 @@ where { pub fn from_single_query_result( single_query_result: SingleQueryResult, - oracle_trees: &Vec>, + oracle_trees: &[MerkleTree], commitment: &BasefoldCommitmentWithData, ) -> Self { + assert!(commitment.codeword_tree.height() > 0); Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( single_query_result.oracle_query, |i, j| oracle_trees[i].merkle_path_without_leaf_sibling_or_root(j), ), commitment_query: CodewordSingleQueryResultWithMerklePath { - query: single_query_result.commitment_query.clone(), + query: single_query_result.commitment_query, merkle_path: commitment .codeword_tree .merkle_path_without_leaf_sibling_or_root( @@ -609,74 +875,28 @@ where } } - pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + pub fn write_transcript(&self, transcript: &mut Transcript) { self.oracle_query.write_transcript(transcript); self.commitment_query.write_transcript(transcript); } - pub fn read_transcript_base( - transcript: &mut impl TranscriptRead, E>, - num_rounds: usize, - log_rate: usize, - num_vars: usize, - index: usize, - ) -> Self { - Self { - oracle_query: OracleListQueryResultWithMerklePath::read_transcript( - transcript, - num_rounds, - num_vars + log_rate, - index, - ), - commitment_query: CodewordSingleQueryResultWithMerklePath::read_transcript_base( - transcript, - num_vars + log_rate, - num_vars + log_rate, - index, - ), - } - } - - pub fn read_transcript_ext( - transcript: &mut impl TranscriptRead, E>, - num_rounds: usize, - log_rate: usize, - num_vars: usize, - index: usize, - ) -> Self { - Self { - oracle_query: OracleListQueryResultWithMerklePath::read_transcript( - transcript, - num_rounds, - num_vars + log_rate, - index, - ), - commitment_query: CodewordSingleQueryResultWithMerklePath::read_transcript_ext( - transcript, - num_vars + log_rate, - num_vars + log_rate, - index, - ), - } - } - - pub fn check( + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + vp: &>::VerifierParameters, + fold_challenges: &[E], num_rounds: usize, num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, + final_codeword: &[E], + roots: &[Digest], comm: &BasefoldCommitment, - mut cipher: ctr::Ctr32LE, index: usize, hasher: &Hasher, ) { // let timer = start_timer!(|| "Checking codeword single query"); self.oracle_query.check_merkle_paths(roots, hasher); self.commitment_query - .check_merkle_path(&Digest(comm.root().0.try_into().unwrap()), hasher); + .check_merkle_path(&Digest(comm.root().0), hasher); let (mut curr_left, mut curr_right) = self.commitment_query.query.codepoints.as_ext(); @@ -684,18 +904,14 @@ where let mut left_index = right_index - 1; for i in 0..num_rounds { - // let round_timer = start_timer!(|| format!("SingleQueryResult::round {}", i)); - let ri0 = reverse_bits(left_index, num_vars + log_rate - i); - - let x0 = E::from(query_point::( - 1 << (num_vars + log_rate - i), - ri0, - num_vars + log_rate - i - 1, - &mut cipher, - )); - let x1 = -x0; + let (x0, x1, w) = >::verifier_folding_coeffs( + vp, + num_vars + Spec::get_rate_log() - i - 1, + left_index >> 1, + ); - let res = interpolate2([(x0, curr_left), (x1, curr_right)], fold_challenges[i]); + let res = + interpolate2_weights([(x0, curr_left), (x1, curr_right)], w, fold_challenges[i]); let next_index = right_index >> 1; let next_oracle_value = if i < num_rounds - 1 { @@ -720,6 +936,88 @@ where } } +pub struct QueriesResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + inner: Vec<(usize, SingleQueryResult)>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QueriesResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + inner: Vec<(usize, SingleQueryResultWithMerklePath)>, +} + +impl QueriesResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub fn from_query_result( + query_result: QueriesResult, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithData, + ) -> Self { + Self { + inner: query_result + .inner + .into_iter() + .map(|(i, q)| { + ( + i, + SingleQueryResultWithMerklePath::from_single_query_result( + q, + oracle_trees, + commitment, + ), + ) + }) + .collect(), + } + } + + pub fn write_transcript(&self, transcript: &mut Transcript) { + self.inner.iter().for_each(|(_, q)| { + q.write_transcript(transcript); + }); + } + + #[allow(clippy::too_many_arguments)] + pub fn check>( + &self, + indices: &[usize], + vp: &>::VerifierParameters, + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_codeword: &[E], + roots: &[Digest], + comm: &BasefoldCommitment, + hasher: &Hasher, + ) { + let timer = start_timer!(|| "QueriesResult::check"); + self.inner.par_iter().zip(indices.par_iter()).for_each( + |((index, query), index_in_proof)| { + assert_eq!(index_in_proof, index); + query.check::( + vp, + fold_challenges, + num_rounds, + num_vars, + final_codeword, + roots, + comm, + *index, + hasher, + ); + }, + ); + end_timer!(timer); + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] struct BatchedSingleQueryResult where @@ -729,7 +1027,7 @@ where commitments_query: CommitmentsQueryResult, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct BatchedSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, @@ -744,8 +1042,8 @@ where { pub fn from_batched_single_query_result( batched_single_query_result: BatchedSingleQueryResult, - oracle_trees: &Vec>, - commitments: &Vec<&BasefoldCommitmentWithData>, + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithData], ) -> Self { Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( @@ -763,78 +1061,34 @@ where } } - pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + pub fn write_transcript(&self, transcript: &mut Transcript) { self.oracle_query.write_transcript(transcript); self.commitments_query.write_transcript(transcript); } - pub fn read_transcript_base( - transcript: &mut impl TranscriptRead, E>, + #[allow(clippy::too_many_arguments)] + pub fn check>( + &self, + vp: &>::VerifierParameters, + fold_challenges: &[E], num_rounds: usize, - log_rate: usize, - poly_num_vars: &[usize], - index: usize, - ) -> Self { - let num_vars = poly_num_vars.iter().max().unwrap(); - Self { - oracle_query: OracleListQueryResultWithMerklePath::read_transcript( - transcript, - num_rounds, - *num_vars + log_rate, - index, - ), - commitments_query: CommitmentsQueryResultWithMerklePath::read_transcript_base( - transcript, - *num_vars, - poly_num_vars, - log_rate, - index, - ), - } - } - - pub fn read_transcript_ext( - transcript: &mut impl TranscriptRead, E>, - num_rounds: usize, - log_rate: usize, - poly_num_vars: &[usize], - index: usize, - ) -> Self { - let num_vars = poly_num_vars.iter().max().unwrap(); - Self { - oracle_query: OracleListQueryResultWithMerklePath::read_transcript( - transcript, - num_rounds, - *num_vars + log_rate, - index, - ), - commitments_query: CommitmentsQueryResultWithMerklePath::read_transcript_ext( - transcript, - *num_vars, - poly_num_vars, - log_rate, - index, - ), - } - } - - pub fn check( - &self, - fold_challenges: &Vec, - num_rounds: usize, - num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, - coeffs: &[E], - mut cipher: ctr::Ctr32LE, + num_vars: usize, + final_codeword: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], + coeffs: &[E], index: usize, hasher: &Hasher, ) { self.oracle_query.check_merkle_paths(roots, hasher); - self.commitments_query - .check_merkle_paths(&comms.iter().map(|comm| comm.root()).collect(), hasher); + self.commitments_query.check_merkle_paths( + comms + .iter() + .map(|comm| comm.root()) + .collect_vec() + .as_slice(), + hasher, + ); // end_timer!(commit_timer); let mut curr_left = E::ZERO; @@ -845,7 +1099,6 @@ where for i in 0..num_rounds { // let round_timer = start_timer!(|| format!("BatchedSingleQueryResult::round {}", i)); - let ri0 = reverse_bits(left_index, num_vars + log_rate - i); let matching_comms = comms .iter() .enumerate() @@ -854,21 +1107,20 @@ where .collect_vec(); matching_comms.iter().for_each(|index| { - let query = self.commitments_query.get_inner()[*index].query.clone(); + let query = self.commitments_query.get_inner()[*index].query; assert_eq!(query.index >> 1, left_index >> 1); curr_left += query.left_ext() * coeffs[*index]; curr_right += query.right_ext() * coeffs[*index]; }); - let x0: E = E::from(query_point::( - 1 << (num_vars + log_rate - i), - ri0, - num_vars + log_rate - i - 1, - &mut cipher, - )); - let x1 = -x0; + let (x0, x1, w) = >::verifier_folding_coeffs( + vp, + num_vars + Spec::get_rate_log() - i - 1, + left_index >> 1, + ); - let mut res = interpolate2([(x0, curr_left), (x1, curr_right)], fold_challenges[i]); + let mut res = + interpolate2_weights([(x0, curr_left), (x1, curr_right)], w, fold_challenges[i]); let next_index = right_index >> 1; @@ -901,7 +1153,7 @@ where matching_comms.iter().for_each(|index| { let query: CodewordSingleQueryResult = - self.commitments_query.get_inner()[*index].query.clone(); + self.commitments_query.get_inner()[*index].query; assert_eq!(query.index >> 1, next_index >> 1); if next_index & 1 == 0 { res += query.left_ext() * coeffs[*index]; @@ -928,6 +1180,7 @@ where inner: Vec<(usize, BatchedSingleQueryResult)>, } +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct BatchedQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, @@ -941,8 +1194,8 @@ where { pub fn from_batched_query_result( batched_query_result: BatchedQueriesResult, - oracle_trees: &Vec>, - commitments: &Vec<&BasefoldCommitmentWithData>, + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithData], ) -> Self { Self { inner: batched_query_result @@ -962,118 +1215,283 @@ where } } - pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + pub fn write_transcript(&self, transcript: &mut Transcript) { self.inner .iter() .for_each(|(_, q)| q.write_transcript(transcript)); } - pub fn read_transcript_base( - transcript: &mut impl TranscriptRead, E>, - num_rounds: usize, - log_rate: usize, - poly_num_vars: &[usize], + #[allow(clippy::too_many_arguments)] + pub fn check>( + &self, indices: &[usize], - ) -> Self { + vp: &>::VerifierParameters, + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_codeword: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], + coeffs: &[E], + hasher: &Hasher, + ) { + let timer = start_timer!(|| "BatchedQueriesResult::check"); + self.inner.par_iter().zip(indices.par_iter()).for_each( + |((index, query), index_in_proof)| { + assert_eq!(index, index_in_proof); + query.check::( + vp, + fold_challenges, + num_rounds, + num_vars, + final_codeword, + roots, + comms, + coeffs, + *index, + hasher, + ); + }, + ); + end_timer!(timer); + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleBatchCommitmentSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + leaves: SimpleBatchLeavesPair, + index: usize, +} + +impl SimpleBatchCommitmentSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + fn new_ext(left: Vec, right: Vec, index: usize) -> Self { Self { - inner: indices - .iter() - .map(|index| { - ( - *index, - BatchedSingleQueryResultWithMerklePath::read_transcript_base( - transcript, - num_rounds, - log_rate, - poly_num_vars, - *index, - ), - ) - }) - .collect(), + leaves: SimpleBatchLeavesPair::Ext(left.into_iter().zip(right).collect()), + index, } } - pub fn read_transcript_ext( - transcript: &mut impl TranscriptRead, E>, - num_rounds: usize, - log_rate: usize, - poly_num_vars: &[usize], - indices: &[usize], + fn new_base(left: Vec, right: Vec, index: usize) -> Self { + Self { + leaves: SimpleBatchLeavesPair::Base(left.into_iter().zip(right).collect()), + index, + } + } + + #[allow(unused)] + fn left_ext(&self) -> Vec { + match &self.leaves { + SimpleBatchLeavesPair::Ext(x) => x.iter().map(|(x, _)| *x).collect(), + SimpleBatchLeavesPair::Base(x) => x.iter().map(|(x, _)| E::from(*x)).collect(), + } + } + + #[allow(unused)] + fn right_ext(&self) -> Vec { + match &self.leaves { + SimpleBatchLeavesPair::Ext(x) => x.iter().map(|(_, x)| *x).collect(), + SimpleBatchLeavesPair::Base(x) => x.iter().map(|(_, x)| E::from(*x)).collect(), + } + } + + pub fn write_transcript(&self, transcript: &mut Transcript) { + match &self.leaves { + SimpleBatchLeavesPair::Ext(x) => { + x.iter().for_each(|(x, y)| { + transcript.append_field_element_ext(x); + transcript.append_field_element_ext(y); + }); + } + SimpleBatchLeavesPair::Base(x) => { + x.iter().for_each(|(x, y)| { + transcript.append_field_element(x); + transcript.append_field_element(y); + }); + } + }; + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleBatchCommitmentSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + query: SimpleBatchCommitmentSingleQueryResult, + merkle_path: MerklePathWithoutLeafOrRoot, +} + +impl SimpleBatchCommitmentSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub fn write_transcript(&self, transcript: &mut Transcript) { + self.query.write_transcript(transcript); + self.merkle_path.write_transcript(transcript); + } + + pub fn check_merkle_path(&self, root: &Digest, hasher: &Hasher) { + // let timer = start_timer!(|| "CodewordSingleQuery::Check Merkle Path"); + match &self.query.leaves { + SimpleBatchLeavesPair::Ext(inner) => { + self.merkle_path.authenticate_batch_leaves_root_ext( + inner.iter().map(|(x, _)| *x).collect(), + inner.iter().map(|(_, x)| *x).collect(), + self.query.index, + root, + hasher, + ); + } + SimpleBatchLeavesPair::Base(inner) => { + self.merkle_path.authenticate_batch_leaves_root_base( + inner.iter().map(|(x, _)| *x).collect(), + inner.iter().map(|(_, x)| *x).collect(), + self.query.index, + root, + hasher, + ); + } + } + // end_timer!(timer); + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleBatchSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + oracle_query: OracleListQueryResult, + commitment_query: SimpleBatchCommitmentSingleQueryResult, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleBatchSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + oracle_query: OracleListQueryResultWithMerklePath, + commitment_query: SimpleBatchCommitmentSingleQueryResultWithMerklePath, +} + +impl SimpleBatchSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub fn from_single_query_result( + single_query_result: SimpleBatchSingleQueryResult, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithData, ) -> Self { Self { - inner: indices - .iter() - .map(|index| { - ( - *index, - BatchedSingleQueryResultWithMerklePath::read_transcript_ext( - transcript, - num_rounds, - log_rate, - poly_num_vars, - *index, - ), - ) - }) - .collect(), + oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( + single_query_result.oracle_query, + |i, j| oracle_trees[i].merkle_path_without_leaf_sibling_or_root(j), + ), + commitment_query: SimpleBatchCommitmentSingleQueryResultWithMerklePath { + query: single_query_result.commitment_query.clone(), + merkle_path: commitment + .codeword_tree + .merkle_path_without_leaf_sibling_or_root( + single_query_result.commitment_query.index, + ), + }, } } - pub fn check( + pub fn write_transcript(&self, transcript: &mut Transcript) { + self.oracle_query.write_transcript(transcript); + self.commitment_query.write_transcript(transcript); + } + + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + vp: &>::VerifierParameters, + fold_challenges: &[E], + batch_coeffs: &[E], num_rounds: usize, num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, - coeffs: &[E], - cipher: ctr::Ctr32LE, + final_codeword: &[E], + roots: &[Digest], + comm: &BasefoldCommitment, + index: usize, hasher: &Hasher, ) { - let timer = start_timer!(|| "BatchedQueriesResult::check"); - self.inner.par_iter().for_each(|(index, query)| { - query.check( - fold_challenges, - num_rounds, - num_vars, - log_rate, - final_codeword, - roots, - comms, - coeffs, - cipher.clone(), - *index, - hasher, + let timer = start_timer!(|| "Checking codeword single query"); + self.oracle_query.check_merkle_paths(roots, hasher); + self.commitment_query + .check_merkle_path(&Digest(comm.root().0), hasher); + + let (mut curr_left, mut curr_right) = + self.commitment_query.query.leaves.batch(batch_coeffs); + + let mut right_index = index | 1; + let mut left_index = right_index - 1; + + for i in 0..num_rounds { + // let round_timer = start_timer!(|| format!("SingleQueryResult::round {}", i)); + + let (x0, x1, w) = >::verifier_folding_coeffs( + vp, + num_vars + Spec::get_rate_log() - i - 1, + left_index >> 1, ); - }); + + let res = + interpolate2_weights([(x0, curr_left), (x1, curr_right)], w, fold_challenges[i]); + + let next_index = right_index >> 1; + let next_oracle_value = if i < num_rounds - 1 { + right_index = next_index | 1; + left_index = right_index - 1; + let next_oracle_query = self.oracle_query.get_inner()[i].clone(); + (curr_left, curr_right) = next_oracle_query.query.codepoints.as_ext(); + if next_index & 1 == 0 { + curr_left + } else { + curr_right + } + } else { + // Note that final_codeword has been bit-reversed, so no need to bit-reverse + // next_index here. + final_codeword[next_index] + }; + assert_eq!(res, next_oracle_value, "Failed at round {}", i); + // end_timer!(round_timer); + } end_timer!(timer); } } -pub struct QueriesResult +pub struct SimpleBatchQueriesResult where E::BaseField: Serialize + DeserializeOwned, { - inner: Vec<(usize, SingleQueryResult)>, + inner: Vec<(usize, SimpleBatchSingleQueryResult)>, } -pub struct QueriesResultWithMerklePath +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SimpleBatchQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, { - inner: Vec<(usize, SingleQueryResultWithMerklePath)>, + inner: Vec<(usize, SimpleBatchSingleQueryResultWithMerklePath)>, } -impl QueriesResultWithMerklePath +impl SimpleBatchQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, { pub fn from_query_result( - query_result: QueriesResult, - oracle_trees: &Vec>, + query_result: SimpleBatchQueriesResult, + oracle_trees: &[MerkleTree], commitment: &BasefoldCommitmentWithData, ) -> Self { Self { @@ -1083,7 +1501,7 @@ where .map(|(i, q)| { ( i, - SingleQueryResultWithMerklePath::from_single_query_result( + SimpleBatchSingleQueryResultWithMerklePath::from_single_query_result( q, oracle_trees, commitment, @@ -1094,287 +1512,44 @@ where } } - pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + pub fn write_transcript(&self, transcript: &mut Transcript) { self.inner .iter() .for_each(|(_, q)| q.write_transcript(transcript)); } - pub fn read_transcript_base( - transcript: &mut impl TranscriptRead, E>, - num_rounds: usize, - log_rate: usize, - poly_num_vars: usize, - indices: &[usize], - ) -> Self { - Self { - inner: indices - .iter() - .map(|index| { - ( - *index, - SingleQueryResultWithMerklePath::read_transcript_base( - transcript, - num_rounds, - log_rate, - poly_num_vars, - *index, - ), - ) - }) - .collect(), - } - } - - pub fn read_transcript_ext( - transcript: &mut impl TranscriptRead, E>, - num_rounds: usize, - log_rate: usize, - poly_num_vars: usize, - indices: &[usize], - ) -> Self { - Self { - inner: indices - .iter() - .map(|index| { - ( - *index, - SingleQueryResultWithMerklePath::read_transcript_ext( - transcript, - num_rounds, - log_rate, - poly_num_vars, - *index, - ), - ) - }) - .collect(), - } - } - - pub fn check( + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + indices: &[usize], + vp: &>::VerifierParameters, + fold_challenges: &[E], + batch_coeffs: &[E], num_rounds: usize, num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, + final_codeword: &[E], + roots: &[Digest], comm: &BasefoldCommitment, - cipher: ctr::Ctr32LE, hasher: &Hasher, ) { let timer = start_timer!(|| "QueriesResult::check"); - self.inner.par_iter().for_each(|(index, query)| { - query.check( - fold_challenges, - num_rounds, - num_vars, - log_rate, - final_codeword, - roots, - comm, - cipher.clone(), - *index, - hasher, - ); - }); - end_timer!(timer); - } -} - -pub fn batch_query_phase( - transcript: &mut impl TranscriptWrite, E>, - codeword_size: usize, - comms: &[&BasefoldCommitmentWithData], - oracles: &Vec>, - num_verifier_queries: usize, -) -> BatchedQueriesResult -where - E::BaseField: Serialize + DeserializeOwned, -{ - let queries = transcript.squeeze_challenges(num_verifier_queries); - - // Transform the challenge queries from field elements into integers - let queries_usize: Vec = queries - .iter() - .map(|x_index| ext_to_usize(x_index) % codeword_size) - .collect_vec(); - - BatchedQueriesResult { - inner: queries_usize - .par_iter() - .map(|x_index| { - ( - *x_index, - batch_basefold_get_query::(comms, &oracles, codeword_size, *x_index), - ) - }) - .collect(), - } -} - -pub fn verifier_query_phase( - queries: &QueriesResultWithMerklePath, - sum_check_messages: &Vec>, - fold_challenges: &Vec, - num_rounds: usize, - num_vars: usize, - log_rate: usize, - final_message: &Vec, - roots: &Vec>, - comm: &BasefoldCommitment, - partial_eq: &[E], - rng: ChaCha8Rng, - eval: &E, - hasher: &Hasher, -) where - E::BaseField: Serialize + DeserializeOwned, -{ - let timer = start_timer!(|| "Verifier query phase"); - - let encode_timer = start_timer!(|| "Encode final codeword"); - let mut message = final_message.clone(); - interpolate_over_boolean_hypercube(&mut message); - let mut final_codeword = encode_rs_basecode(&message, 1 << log_rate, message.len()); - assert_eq!(final_codeword.len(), 1); - let mut final_codeword = final_codeword.remove(0); - reverse_index_bits_in_place(&mut final_codeword); - end_timer!(encode_timer); - - // For computing the weights on the fly, because the verifier is incapable of storing - // the weights. - let aes_timer = start_timer!(|| "Initialize AES"); - let mut key: [u8; 16] = [0u8; 16]; - let mut iv: [u8; 16] = [0u8; 16]; - let mut rng = rng.clone(); - rng.set_word_pos(0); - rng.fill_bytes(&mut key); - rng.fill_bytes(&mut iv); - - type Aes128Ctr64LE = ctr::Ctr32LE; - let cipher = Aes128Ctr64LE::new( - GenericArray::from_slice(&key[..]), - GenericArray::from_slice(&iv[..]), - ); - end_timer!(aes_timer); - - queries.check( - fold_challenges, - num_rounds, - num_vars, - log_rate, - &final_codeword, - roots, - comm, - cipher, - hasher, - ); - - let final_timer = start_timer!(|| "Final checks"); - assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); - - // The sum-check part of the protocol - for i in 0..fold_challenges.len() - 1 { - assert_eq!( - degree_2_eval(&sum_check_messages[i], fold_challenges[i]), - degree_2_zero_plus_one(&sum_check_messages[i + 1]) - ); - } - - // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial - // sent from the prover - assert_eq!( - degree_2_eval( - &sum_check_messages[fold_challenges.len() - 1], - fold_challenges[fold_challenges.len() - 1] - ), - inner_product(final_message, partial_eq) - ); - end_timer!(final_timer); - - end_timer!(timer); -} - -pub fn batch_verifier_query_phase( - queries: &BatchedQueriesResultWithMerklePath, - sum_check_messages: &Vec>, - fold_challenges: &Vec, - num_rounds: usize, - num_vars: usize, - log_rate: usize, - final_message: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, - coeffs: &[E], - partial_eq: &[E], - rng: ChaCha8Rng, - eval: &E, - hasher: &Hasher, -) where - E::BaseField: Serialize + DeserializeOwned, -{ - let timer = start_timer!(|| "Verifier batch query phase"); - let encode_timer = start_timer!(|| "Encode final codeword"); - let mut message = final_message.clone(); - interpolate_over_boolean_hypercube(&mut message); - let mut final_codeword = encode_rs_basecode(&message, 1 << log_rate, message.len()); - assert_eq!(final_codeword.len(), 1); - let mut final_codeword = final_codeword.remove(0); - reverse_index_bits_in_place(&mut final_codeword); - end_timer!(encode_timer); - - // For computing the weights on the fly, because the verifier is incapable of storing - // the weights. - let aes_timer = start_timer!(|| "Initialize AES"); - let mut key: [u8; 16] = [0u8; 16]; - let mut iv: [u8; 16] = [0u8; 16]; - let mut rng = rng.clone(); - rng.set_word_pos(0); - rng.fill_bytes(&mut key); - rng.fill_bytes(&mut iv); - - type Aes128Ctr64LE = ctr::Ctr32LE; - let cipher = Aes128Ctr64LE::new( - GenericArray::from_slice(&key[..]), - GenericArray::from_slice(&iv[..]), - ); - end_timer!(aes_timer); - - queries.check( - fold_challenges, - num_rounds, - num_vars, - log_rate, - &final_codeword, - roots, - comms, - coeffs, - cipher, - hasher, - ); - - #[allow(unused)] - let final_timer = start_timer!(|| "Final checks"); - assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); - - // The sum-check part of the protocol - for i in 0..fold_challenges.len() - 1 { - assert_eq!( - degree_2_eval(&sum_check_messages[i], fold_challenges[i]), - degree_2_zero_plus_one(&sum_check_messages[i + 1]) + self.inner.par_iter().zip(indices.par_iter()).for_each( + |((index, query), index_in_proof)| { + assert_eq!(index, index_in_proof); + query.check::( + vp, + fold_challenges, + batch_coeffs, + num_rounds, + num_vars, + final_codeword, + roots, + comm, + *index, + hasher, + ); + }, ); + end_timer!(timer); } - - // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial - // sent from the prover - assert_eq!( - degree_2_eval( - &sum_check_messages[fold_challenges.len() - 1], - fold_challenges[fold_challenges.len() - 1] - ), - inner_product(final_message, partial_eq) - ); - end_timer!(final_timer); - end_timer!(timer); } diff --git a/mpcs/src/basefold/structure.rs b/mpcs/src/basefold/structure.rs index 8dd39e75b..e84b8dd76 100644 --- a/mpcs/src/basefold/structure.rs +++ b/mpcs/src/basefold/structure.rs @@ -1,45 +1,61 @@ -use crate::util::{hash::Digest, merkle_tree::MerkleTree}; +use crate::{ + sum_check::classic::{Coefficients, SumcheckProof}, + util::{hash::Digest, merkle_tree::MerkleTree}, +}; use core::fmt::Debug; use ff_ext::ExtensionField; +use rand::RngCore; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use multilinear_extensions::mle::FieldType; -use rand_chacha::rand_core::RngCore; +use rand_chacha::ChaCha8Rng; use std::{marker::PhantomData, slice}; +pub use super::encoding::{EncodingProverParameters, EncodingScheme, RSCode, RSCodeDefaultSpec}; +use super::{ + query_phase::{ + BatchedQueriesResultWithMerklePath, QueriesResultWithMerklePath, + SimpleBatchQueriesResultWithMerklePath, + }, + Basecode, BasecodeDefaultSpec, +}; + #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BasefoldParams +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasefoldParams> where E::BaseField: Serialize + DeserializeOwned, { - pub(super) log_rate: usize, - pub(super) num_verifier_queries: usize, - pub(super) max_num_vars: usize, - pub(super) table_w_weights: Vec>, - pub(super) table: Vec>, - pub(super) rng: Rng, + pub(super) params: >::PublicParameters, } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BasefoldProverParams -where - E::BaseField: Serialize + DeserializeOwned, -{ - pub(super) log_rate: usize, - pub(super) table_w_weights: Vec>, - pub(super) table: Vec>, - pub(super) num_verifier_queries: usize, - pub(super) max_num_vars: usize, +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasefoldProverParams> { + pub encoding_params: >::ProverParameters, +} + +impl> BasefoldProverParams { + pub fn get_max_message_size_log(&self) -> usize { + self.encoding_params.get_max_message_size_log() + } } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BasefoldVerifierParams { - pub(super) rng: Rng, - pub(super) max_num_vars: usize, - pub(super) log_rate: usize, - pub(super) num_verifier_queries: usize, +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasefoldVerifierParams> { + pub(super) encoding_params: >::VerifierParameters, } /// A polynomial commitment together with all the data (e.g., the codeword, and Merkle tree) @@ -51,9 +67,10 @@ where E::BaseField: Serialize + DeserializeOwned, { pub(crate) codeword_tree: MerkleTree, - pub(crate) bh_evals: FieldType, + pub(crate) polynomials_bh_evals: Vec>, pub(crate) num_vars: usize, pub(crate) is_base: bool, + pub(crate) num_polys: usize, } impl BasefoldCommitmentWithData @@ -61,7 +78,12 @@ where E::BaseField: Serialize + DeserializeOwned, { pub fn to_commitment(&self) -> BasefoldCommitment { - BasefoldCommitment::new(self.codeword_tree.root(), self.num_vars, self.is_base) + BasefoldCommitment::new( + self.codeword_tree.root(), + self.num_vars, + self.is_base, + self.num_polys, + ) } pub fn get_root_ref(&self) -> &Digest { @@ -72,12 +94,16 @@ where Digest::(self.get_root_ref().0) } - pub fn get_codeword(&self) -> &FieldType { + pub fn get_codewords(&self) -> &Vec> { self.codeword_tree.leaves() } + pub fn batch_codewords(&self, coeffs: &[E]) -> Vec { + self.codeword_tree.batch_leaves(coeffs) + } + pub fn codeword_size(&self) -> usize { - self.codeword_tree.size() + self.codeword_tree.size().1 } pub fn codeword_size_log(&self) -> usize { @@ -85,14 +111,14 @@ where } pub fn poly_size(&self) -> usize { - self.bh_evals.len() + 1 << self.num_vars } - pub fn get_codeword_entry_base(&self, index: usize) -> E::BaseField { + pub fn get_codeword_entry_base(&self, index: usize) -> Vec { self.codeword_tree.get_leaf_as_base(index) } - pub fn get_codeword_entry_ext(&self, index: usize) -> E { + pub fn get_codeword_entry_ext(&self, index: usize) -> Vec { self.codeword_tree.get_leaf_as_extension(index) } @@ -101,21 +127,21 @@ where } } -impl Into> for BasefoldCommitmentWithData +impl From> for Digest where E::BaseField: Serialize + DeserializeOwned, { - fn into(self) -> Digest { - self.get_root_as() + fn from(val: BasefoldCommitmentWithData) -> Self { + val.get_root_as() } } -impl Into> for &BasefoldCommitmentWithData +impl From<&BasefoldCommitmentWithData> for BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, { - fn into(self) -> BasefoldCommitment { - self.to_commitment() + fn from(val: &BasefoldCommitmentWithData) -> Self { + val.to_commitment() } } @@ -128,17 +154,24 @@ where pub(super) root: Digest, pub(super) num_vars: Option, pub(super) is_base: bool, + pub(super) num_polys: Option, } impl BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, { - pub fn new(root: Digest, num_vars: usize, is_base: bool) -> Self { + pub fn new( + root: Digest, + num_vars: usize, + is_base: bool, + num_polys: usize, + ) -> Self { Self { root, num_vars: Some(num_vars), is_base, + num_polys: Some(num_polys), } } @@ -153,14 +186,6 @@ where pub fn is_base(&self) -> bool { self.is_base } - - pub fn as_challenge_field(&self) -> BasefoldCommitment { - BasefoldCommitment:: { - root: Digest::(self.root().0), - num_vars: self.num_vars, - is_base: self.is_base, - } - } } impl PartialEq for BasefoldCommitmentWithData @@ -168,7 +193,8 @@ where E::BaseField: Serialize + DeserializeOwned, { fn eq(&self, other: &Self) -> bool { - self.get_codeword().eq(other.get_codeword()) && self.bh_evals.eq(&other.bh_evals) + self.get_codewords().eq(other.get_codewords()) + && self.polynomials_bh_evals.eq(&other.polynomials_bh_evals) } } @@ -177,37 +203,50 @@ impl Eq for BasefoldCommitmentWithData where { } -pub trait BasefoldExtParams: Debug { - fn get_reps() -> usize; +pub trait BasefoldSpec: Debug + Clone { + type EncodingScheme: EncodingScheme; + + fn get_number_queries() -> usize { + Self::EncodingScheme::get_number_queries() + } - fn get_rate() -> usize; + fn get_rate_log() -> usize { + Self::EncodingScheme::get_rate_log() + } - fn get_basecode() -> usize; + fn get_basecode_msg_size_log() -> usize { + Self::EncodingScheme::get_basecode_msg_size_log() + } } -#[derive(Debug)] -pub struct BasefoldDefaultParams; +#[derive(Debug, Clone)] +pub struct BasefoldBasecodeParams; -impl BasefoldExtParams for BasefoldDefaultParams { - fn get_reps() -> usize { - return 260; - } +impl BasefoldSpec for BasefoldBasecodeParams +where + E::BaseField: Serialize + DeserializeOwned, +{ + type EncodingScheme = Basecode; +} - fn get_rate() -> usize { - return 3; - } +#[derive(Debug, Clone)] +pub struct BasefoldRSParams; - fn get_basecode() -> usize { - return 7; - } +impl BasefoldSpec for BasefoldRSParams +where + E::BaseField: Serialize + DeserializeOwned, +{ + type EncodingScheme = RSCode; } #[derive(Debug)] -pub struct Basefold(PhantomData<(E, V)>); +pub struct Basefold, Rng: RngCore>( + PhantomData<(E, Spec, Rng)>, +); -pub type BasefoldDefault = Basefold; +pub type BasefoldDefault = Basefold; -impl Clone for Basefold { +impl, Rng: RngCore> Clone for Basefold { fn clone(&self) -> Self { Self(PhantomData) } @@ -232,3 +271,61 @@ where slice::from_ref(root) } } + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ProofQueriesResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + Single(QueriesResultWithMerklePath), + Batched(BatchedQueriesResultWithMerklePath), + SimpleBatched(SimpleBatchQueriesResultWithMerklePath), +} + +impl ProofQueriesResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub fn as_single<'a>(&'a self) -> &'a QueriesResultWithMerklePath { + match self { + Self::Single(x) => x, + _ => panic!("Not a single query result"), + } + } + + pub fn as_batched<'a>(&'a self) -> &'a BatchedQueriesResultWithMerklePath { + match self { + Self::Batched(x) => x, + _ => panic!("Not a batched query result"), + } + } + + pub fn as_simple_batched<'a>(&'a self) -> &'a SimpleBatchQueriesResultWithMerklePath { + match self { + Self::SimpleBatched(x) => x, + _ => panic!("Not a simple batched query result"), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BasefoldProof +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub(crate) sumcheck_messages: Vec>, + pub(crate) roots: Vec>, + pub(crate) final_message: Vec, + pub(crate) query_result_with_merkle_path: ProofQueriesResultWithMerklePath, + pub(crate) sumcheck_proof: Option>>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BasefoldCommitPhaseProof +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub(crate) sumcheck_messages: Vec>, + pub(crate) roots: Vec>, + pub(crate) final_message: Vec, +} diff --git a/mpcs/src/basefold/sumcheck.rs b/mpcs/src/basefold/sumcheck.rs index a6a84f0eb..bb5312111 100644 --- a/mpcs/src/basefold/sumcheck.rs +++ b/mpcs/src/basefold/sumcheck.rs @@ -7,8 +7,8 @@ use rayon::prelude::{ }; pub fn sum_check_first_round_field_type( - mut eq: &mut Vec, - mut bh_values: &mut FieldType, + eq: &mut Vec, + bh_values: &mut FieldType, ) -> Vec { // The input polynomials are in the form of evaluations. Instead of viewing // every one element as the evaluation of the polynomial at a single point, @@ -16,24 +16,21 @@ pub fn sum_check_first_round_field_type( // a single point, leaving the first variable free, and obtaining a univariate // polynomial. The one_level_interp_hc transforms the evaluation forms into // the coefficient forms, for every of these partial polynomials. - one_level_interp_hc(&mut eq); - one_level_interp_hc_field_type(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc_field_type(bh_values); parallel_pi_field_type(bh_values, eq) // p_i(&bh_values, &eq) } -pub fn sum_check_first_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, -) -> Vec { +pub fn sum_check_first_round(eq: &mut Vec, bh_values: &mut Vec) -> Vec { // The input polynomials are in the form of evaluations. Instead of viewing // every one element as the evaluation of the polynomial at a single point, // we can view every two elements as partially evaluating the polynomial at // a single point, leaving the first variable free, and obtaining a univariate // polynomial. The one_level_interp_hc transforms the evaluation forms into // the coefficient forms, for every of these partial polynomials. - one_level_interp_hc(&mut eq); - one_level_interp_hc(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc(bh_values); parallel_pi(bh_values, eq) // p_i(&bh_values, &eq) } @@ -51,7 +48,7 @@ pub fn one_level_interp_hc(evals: &mut Vec) { return; } evals.par_chunks_mut(2).for_each(|chunk| { - chunk[1] = chunk[1] - chunk[0]; + chunk[1] -= chunk[0]; }); } @@ -68,15 +65,15 @@ pub fn one_level_eval_hc(evals: &mut Vec, challenge: F) { }); } -fn parallel_pi_field_type(evals: &mut FieldType, eq: &mut Vec) -> Vec { +fn parallel_pi_field_type(evals: &mut FieldType, eq: &mut [E]) -> Vec { match evals { - FieldType::Ext(evals) => parallel_pi(evals, &eq), - FieldType::Base(evals) => parallel_pi_base(evals, &eq), + FieldType::Ext(evals) => parallel_pi(evals, eq), + FieldType::Base(evals) => parallel_pi_base(evals, eq), _ => unreachable!(), } } -fn parallel_pi(evals: &Vec, eq: &Vec) -> Vec { +fn parallel_pi(evals: &[F], eq: &[F]) -> Vec { if evals.len() == 1 { return vec![evals[0], evals[0], evals[0]]; } @@ -111,7 +108,7 @@ fn parallel_pi(evals: &Vec, eq: &Vec) -> Vec { coeffs } -fn parallel_pi_base(evals: &Vec, eq: &Vec) -> Vec { +fn parallel_pi_base(evals: &[E::BaseField], eq: &[E]) -> Vec { if evals.len() == 1 { return vec![E::from(evals[0]), E::from(evals[0]), E::from(evals[0])]; } @@ -147,31 +144,27 @@ fn parallel_pi_base(evals: &Vec, eq: &Vec) - } pub fn sum_check_challenge_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, + eq: &mut Vec, + bh_values: &mut Vec, challenge: F, ) -> Vec { // Note that when the last round ends, every two elements are in // the coefficient form. Use the challenge to reduce the two elements // into a single value. This is equivalent to substituting the challenge // to the first variable of the poly. - one_level_eval_hc(&mut bh_values, challenge); - one_level_eval_hc(&mut eq, challenge); + one_level_eval_hc(bh_values, challenge); + one_level_eval_hc(eq, challenge); - one_level_interp_hc(&mut eq); - one_level_interp_hc(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc(bh_values); - parallel_pi(&bh_values, &eq) + parallel_pi(bh_values, eq) // p_i(&bh_values,&eq) } -pub fn sum_check_last_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, - challenge: F, -) { - one_level_eval_hc(&mut bh_values, challenge); - one_level_eval_hc(&mut eq, challenge); +pub fn sum_check_last_round(eq: &mut Vec, bh_values: &mut Vec, challenge: F) { + one_level_eval_hc(bh_values, challenge); + one_level_eval_hc(eq, challenge); } #[cfg(test)] @@ -184,7 +177,7 @@ mod tests { use super::*; use crate::util::test::rand_vec; - pub fn p_i(evals: &Vec, eq: &Vec) -> Vec { + pub fn p_i(evals: &[F], eq: &[F]) -> Vec { if evals.len() == 1 { return vec![evals[0], evals[0], evals[0]]; } diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 216b50d52..7ebe96203 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -2,12 +2,10 @@ use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::DenseMultilinearExtension; use rand::RngCore; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Serialize}; use std::fmt::Debug; -use util::{ - hash::Digest, - transcript::{InMemoryTranscript, PoseidonTranscript, TranscriptRead, TranscriptWrite}, -}; +use transcript::Transcript; +use util::hash::Digest; pub mod sum_check; pub mod util; @@ -29,8 +27,9 @@ pub fn pcs_setup>( pub fn pcs_trim>( param: &Pcs::Param, + poly_size: usize, ) -> Result<(Pcs::ProverParam, Pcs::VerifierParam), Error> { - Pcs::trim(param) + Pcs::trim(param, poly_size) } pub fn pcs_commit>( @@ -43,23 +42,23 @@ pub fn pcs_commit>( pub fn pcs_commit_and_write>( pp: &Pcs::ProverParam, poly: &DenseMultilinearExtension, - transcript: &mut impl TranscriptWrite, + transcript: &mut Transcript, ) -> Result { Pcs::commit_and_write(pp, poly, transcript) } pub fn pcs_batch_commit>( pp: &Pcs::ProverParam, - polys: &Vec>, -) -> Result, Error> { + polys: &[DenseMultilinearExtension], +) -> Result { Pcs::batch_commit(pp, polys) } -pub fn pcs_batch_commit_and_write<'a, E: ExtensionField, Pcs: PolynomialCommitmentScheme>( +pub fn pcs_batch_commit_and_write>( pp: &Pcs::ProverParam, - polys: &Vec>, - transcript: &mut impl TranscriptWrite, -) -> Result, Error> { + polys: &[DenseMultilinearExtension], + transcript: &mut Transcript, +) -> Result { Pcs::batch_commit_and_write(pp, polys, transcript) } @@ -69,60 +68,45 @@ pub fn pcs_open>( comm: &Pcs::CommitmentWithData, point: &[E], eval: &E, - transcript: &mut impl TranscriptWrite, -) -> Result<(), Error> { + transcript: &mut Transcript, +) -> Result { Pcs::open(pp, poly, comm, point, eval, transcript) } pub fn pcs_batch_open>( pp: &Pcs::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Pcs::CommitmentWithData], points: &[Vec], evals: &[Evaluation], - transcript: &mut impl TranscriptWrite, -) -> Result<(), Error> { + transcript: &mut Transcript, +) -> Result { Pcs::batch_open(pp, polys, comms, points, evals, transcript) } -pub fn pcs_read_commitment>( - vp: &Pcs::VerifierParam, - transcript: &mut impl TranscriptRead, -) -> Result { - let comms = Pcs::read_commitments(vp, 1, transcript)?; - assert_eq!(comms.len(), 1); - Ok(comms.into_iter().next().unwrap()) -} - -pub fn pcs_read_commitments>( - vp: &Pcs::VerifierParam, - num_polys: usize, - transcript: &mut impl TranscriptRead, -) -> Result, Error> { - Pcs::read_commitments(vp, num_polys, transcript) -} - pub fn pcs_verify>( vp: &Pcs::VerifierParam, comm: &Pcs::Commitment, point: &[E], eval: &E, - transcript: &mut impl TranscriptRead, + proof: &Pcs::Proof, + transcript: &mut Transcript, ) -> Result<(), Error> { - Pcs::verify(vp, comm, point, eval, transcript) + Pcs::verify(vp, comm, point, eval, proof, transcript) } pub fn pcs_batch_verify<'a, E: ExtensionField, Pcs: PolynomialCommitmentScheme>( vp: &Pcs::VerifierParam, - comms: &Vec, + comms: &[Pcs::Commitment], points: &[Vec], evals: &[Evaluation], - transcript: &mut impl TranscriptRead, + proof: &Pcs::Proof, + transcript: &mut Transcript, ) -> Result<(), Error> where Pcs::Commitment: 'a, { - Pcs::batch_verify(vp, comms, points, evals, transcript) + Pcs::batch_verify(vp, comms, points, evals, proof, transcript) } pub trait PolynomialCommitmentScheme: Clone + Debug { @@ -132,11 +116,15 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { type CommitmentWithData: Clone + Debug + Default + Serialize + DeserializeOwned; type Commitment: Clone + Debug + Default + Serialize + DeserializeOwned; type CommitmentChunk: Clone + Debug + Default; + type Proof: Clone + Debug + Serialize + DeserializeOwned; type Rng: RngCore + Clone; fn setup(poly_size: usize, rng: &Self::Rng) -> Result; - fn trim(param: &Self::Param) -> Result<(Self::ProverParam, Self::VerifierParam), Error>; + fn trim( + param: &Self::Param, + poly_size: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), Error>; fn commit( pp: &Self::ProverParam, @@ -146,19 +134,34 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn commit_and_write( pp: &Self::ProverParam, poly: &DenseMultilinearExtension, - transcript: &mut impl TranscriptWrite, - ) -> Result; + transcript: &mut Transcript, + ) -> Result { + let comm = Self::commit(pp, poly)?; + Self::write_commitment(&Self::get_pure_commitment(&comm), transcript)?; + Ok(comm) + } + + fn write_commitment( + comm: &Self::Commitment, + transcript: &mut Transcript, + ) -> Result<(), Error>; + + fn get_pure_commitment(comm: &Self::CommitmentWithData) -> Self::Commitment; fn batch_commit( pp: &Self::ProverParam, - polys: &Vec>, - ) -> Result, Error>; + polys: &[DenseMultilinearExtension], + ) -> Result; fn batch_commit_and_write( pp: &Self::ProverParam, - polys: &Vec>, - transcript: &mut impl TranscriptWrite, - ) -> Result, Error>; + polys: &[DenseMultilinearExtension], + transcript: &mut Transcript, + ) -> Result { + let comm = Self::batch_commit(pp, polys)?; + Self::write_commitment(&Self::get_pure_commitment(&comm), transcript)?; + Ok(comm) + } fn open( pp: &Self::ProverParam, @@ -166,54 +169,58 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { comm: &Self::CommitmentWithData, point: &[E], eval: &E, - transcript: &mut impl TranscriptWrite, - ) -> Result<(), Error>; + transcript: &mut Transcript, + ) -> Result; fn batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], - transcript: &mut impl TranscriptWrite, - ) -> Result<(), Error>; - - fn read_commitment( - vp: &Self::VerifierParam, - transcript: &mut impl TranscriptRead, - ) -> Result { - let comms = Self::read_commitments(vp, 1, transcript)?; - assert_eq!(comms.len(), 1); - Ok(comms.into_iter().next().unwrap()) - } - - fn read_commitments( - vp: &Self::VerifierParam, - num_polys: usize, - transcript: &mut impl TranscriptRead, - ) -> Result, Error>; + transcript: &mut Transcript, + ) -> Result; + + /// This is a simple version of batch open: + /// 1. Open at one point + /// 2. All the polynomials share the same commitment. + /// 3. The point is already a random point generated by a sum-check. + fn simple_batch_open( + pp: &Self::ProverParam, + polys: &[DenseMultilinearExtension], + comm: &Self::CommitmentWithData, + point: &[E], + evals: &[E], + transcript: &mut Transcript, + ) -> Result; fn verify( vp: &Self::VerifierParam, comm: &Self::Commitment, point: &[E], eval: &E, - transcript: &mut impl TranscriptRead, + proof: &Self::Proof, + transcript: &mut Transcript, ) -> Result<(), Error>; fn batch_verify( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], - transcript: &mut impl TranscriptRead, + proof: &Self::Proof, + transcript: &mut Transcript, ) -> Result<(), Error>; -} -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct PCSProof(Vec) -where - E::BaseField: Serialize + DeserializeOwned; + fn simple_batch_verify( + vp: &Self::VerifierParam, + comm: &Self::Commitment, + point: &[E], + evals: &[E], + proof: &Self::Proof, + transcript: &mut Transcript, + ) -> Result<(), Error>; +} pub trait NoninteractivePCS: PolynomialCommitmentScheme> @@ -226,22 +233,20 @@ where comm: &Self::CommitmentWithData, point: &[E], eval: &E, - ) -> Result, Error> { - let mut transcript = PoseidonTranscript::::new(); - Self::open(pp, poly, comm, point, eval, &mut transcript)?; - Ok(PCSProof(transcript.into_proof())) + ) -> Result { + let mut transcript = Transcript::::new(b"BaseFold"); + Self::open(pp, poly, comm, point, eval, &mut transcript) } fn ni_batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], - ) -> Result, Error> { - let mut transcript = PoseidonTranscript::::new(); - Self::batch_open(pp, polys, comms, points, evals, &mut transcript)?; - Ok(PCSProof(transcript.into_proof())) + ) -> Result { + let mut transcript = Transcript::::new(b"BaseFold"); + Self::batch_open(pp, polys, comms, points, evals, &mut transcript) } fn ni_verify( @@ -249,24 +254,24 @@ where comm: &Self::Commitment, point: &[E], eval: &E, - proof: &PCSProof, + proof: &Self::Proof, ) -> Result<(), Error> { - let mut transcript = PoseidonTranscript::::from_proof(proof.0.as_slice()); - Self::verify(vp, comm, point, eval, &mut transcript) + let mut transcript = Transcript::::new(b"BaseFold"); + Self::verify(vp, comm, point, eval, proof, &mut transcript) } fn ni_batch_verify<'a>( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], - proof: &PCSProof, + proof: &Self::Proof, ) -> Result<(), Error> where Self::Commitment: 'a, { - let mut transcript = PoseidonTranscript::::from_proof(proof.0.as_slice()); - Self::batch_verify(vp, comms, points, evals, &mut transcript) + let mut transcript = Transcript::::new(b"BaseFold"); + Self::batch_verify(vp, comms, points, evals, proof, &mut transcript) } } @@ -308,18 +313,19 @@ pub enum Error { mod basefold; pub use basefold::{ - Basefold, BasefoldCommitment, BasefoldCommitmentWithData, BasefoldDefault, - BasefoldDefaultParams, BasefoldExtParams, BasefoldParams, + coset_fft, fft, fft_root_table, Basecode, BasecodeDefaultSpec, Basefold, + BasefoldBasecodeParams, BasefoldCommitment, BasefoldCommitmentWithData, BasefoldDefault, + BasefoldParams, BasefoldRSParams, BasefoldSpec, EncodingScheme, RSCode, RSCodeDefaultSpec, }; fn validate_input( function: &str, param_num_vars: usize, - polys: &Vec>, - points: &Vec>, + polys: &[DenseMultilinearExtension], + points: &[Vec], ) -> Result<(), Error> { - let polys = polys.into_iter().collect_vec(); - let points = points.into_iter().collect_vec(); + let polys = polys.iter().collect_vec(); + let points = points.iter().collect_vec(); for poly in polys.iter() { if param_num_vars < poly.num_vars { return Err(err_too_many_variates( @@ -351,25 +357,20 @@ fn err_too_many_variates(function: &str, upto: usize, got: usize) -> Error { #[cfg(test)] pub mod test_util { - use crate::{ - util::transcript::{InMemoryTranscript, TranscriptRead, TranscriptWrite}, - Evaluation, PolynomialCommitmentScheme, - }; + use crate::{Evaluation, PolynomialCommitmentScheme}; use ff_ext::ExtensionField; use itertools::{chain, Itertools}; use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; use rand::{prelude::*, rngs::OsRng}; use rand_chacha::ChaCha8Rng; + use transcript::Transcript; - pub fn run_commit_open_verify( + pub fn run_commit_open_verify( base: bool, num_vars_start: usize, num_vars_end: usize, ) where Pcs: PolynomialCommitmentScheme, - T: TranscriptRead - + TranscriptWrite - + InMemoryTranscript, { for num_vars in num_vars_start..num_vars_end { // Setup @@ -377,11 +378,11 @@ pub mod test_util { let rng = ChaCha8Rng::from_seed([0u8; 32]); let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size, &rng).unwrap(); - Pcs::trim(¶m).unwrap() + Pcs::trim(¶m, poly_size).unwrap() }; // Commit and open - let proof = { - let mut transcript = T::new(); + let (comm, eval, proof) = { + let mut transcript = Transcript::new(b"BaseFold"); let poly = if base { DenseMultilinearExtension::random(num_vars, &mut OsRng) } else { @@ -392,23 +393,26 @@ pub mod test_util { }; let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); - let point = transcript.squeeze_challenges(num_vars); + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); let eval = poly.evaluate(point.as_slice()); - transcript.write_field_element_ext(&eval).unwrap(); - Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); - - transcript.into_proof() + transcript.append_field_element_ext(&eval); + ( + Pcs::get_pure_commitment(&comm), + eval, + Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(), + ) }; // Verify let result = { - let mut transcript = T::from_proof(proof.as_slice()); - let result = Pcs::verify( - &vp, - &Pcs::read_commitment(&vp, &mut transcript).unwrap(), - &transcript.squeeze_challenges(num_vars), - &transcript.read_field_element_ext().unwrap(), - &mut transcript, - ); + let mut transcript = Transcript::new(b"BaseFold"); + Pcs::write_commitment(&comm, &mut transcript).unwrap(); + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + transcript.append_field_element_ext(&eval); + let result = Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript); result }; @@ -416,16 +420,13 @@ pub mod test_util { } } - pub fn run_batch_commit_open_verify( + pub fn run_batch_commit_open_verify( base: bool, num_vars_start: usize, num_vars_end: usize, ) where E: ExtensionField, Pcs: PolynomialCommitmentScheme, - T: TranscriptRead - + TranscriptWrite - + InMemoryTranscript, { for num_vars in num_vars_start..num_vars_end { let batch_size = 2; @@ -435,7 +436,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size, &rng).unwrap(); - Pcs::trim(¶m).unwrap() + Pcs::trim(¶m, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -445,8 +446,8 @@ pub mod test_util { .unique() .collect_vec(); - let proof = { - let mut transcript = T::new(); + let (comms, points, evals, proof) = { + let mut transcript = Transcript::new(b"BaseFold"); let polys = (0..batch_size) .map(|i| { if base { @@ -459,10 +460,18 @@ pub mod test_util { } }) .collect_vec(); - let comms = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + let comms = polys + .iter() + .map(|poly| Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap()) + .collect_vec(); let points = (0..num_points) - .map(|i| transcript.squeeze_challenges(num_vars - i)) + .map(|i| { + (0..num_vars - i) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>() + }) .take(num_points) .collect_vec(); @@ -475,350 +484,120 @@ pub mod test_util { value: polys[poly].evaluate(&points[point]), }) .collect_vec(); - transcript - .write_field_elements_ext(evals.iter().map(Evaluation::value)) - .unwrap(); - Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); - transcript.into_proof() + let values: Vec = evals + .iter() + .map(Evaluation::value) + .map(|x| *x) + .collect::>(); + transcript.append_field_element_exts(values.as_slice()); + + let proof = + Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); + (comms, points, evals, proof) }; // Batch verify let result = { - let mut transcript = T::from_proof(proof.as_slice()); - let comms = &Pcs::read_commitments(&vp, batch_size, &mut transcript).unwrap(); + let mut transcript = Transcript::new(b"BaseFold"); + let comms = comms + .iter() + .map(|comm| { + let comm = Pcs::get_pure_commitment(comm); + Pcs::write_commitment(&comm, &mut transcript).unwrap(); + comm + }) + .collect_vec(); + let old_points = points; let points = (0..num_points) - .map(|i| transcript.squeeze_challenges(num_vars - i)) + .map(|i| { + (0..num_vars - i) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>() + }) .take(num_points) .collect_vec(); + assert_eq!(points, old_points); + let values: Vec = evals + .iter() + .map(Evaluation::value) + .map(|x| *x) + .collect::>(); + transcript.append_field_element_exts(values.as_slice()); - let evals2 = transcript.read_field_elements_ext(evals.len()).unwrap(); - - let result = Pcs::batch_verify( - &vp, - comms, - &points, - &evals - .iter() - .copied() - .zip(evals2) - .map(|((poly, point), eval)| Evaluation::new(poly, point, eval)) - .collect_vec(), - &mut transcript, - ); + let result = + Pcs::batch_verify(&vp, &comms, &points, &evals, &proof, &mut transcript); result }; result.unwrap(); } } -} -#[cfg(test)] -mod test { - use crate::{ - basefold::{Basefold, BasefoldExtParams}, - util::transcript::{FieldTranscript, InMemoryTranscript, PoseidonTranscript}, - PolynomialCommitmentScheme, - }; - use goldilocks::GoldilocksExt2; - use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; - use rand::{prelude::*, rngs::OsRng}; - use rand_chacha::ChaCha8Rng; - #[test] - fn test_transcript() { - #[derive(Debug)] - pub struct Five {} - - impl BasefoldExtParams for Five { - fn get_reps() -> usize { - return 5; - } - fn get_rate() -> usize { - return 3; - } - fn get_basecode() -> usize { - return 2; - } - } + pub(super) fn run_simple_batch_commit_open_verify( + base: bool, + num_vars_start: usize, + num_vars_end: usize, + batch_size: usize, + ) where + E: ExtensionField, + Pcs: PolynomialCommitmentScheme, + { + for num_vars in num_vars_start..num_vars_end { + let rng = ChaCha8Rng::from_seed([0u8; 32]); + // Setup + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; - type Pcs = Basefold; - let num_vars = 10; - let rng = ChaCha8Rng::from_seed([0u8; 32]); - let poly_size = 1 << num_vars; - let mut transcript = PoseidonTranscript::new(); - let poly = DenseMultilinearExtension::random(num_vars, &mut OsRng); - let param = - >::setup(poly_size, &rng).unwrap(); - - let (pp, vp) = >::trim(¶m).unwrap(); - println!("before commit"); - let comm = >::commit_and_write( - &pp, - &poly, - &mut transcript, - ) - .unwrap(); - let point = transcript.squeeze_challenges(num_vars); - let eval = poly.evaluate(point.as_slice()); - >::open( - &pp, - &poly, - &comm, - &point, - &eval, - &mut transcript, - ) - .unwrap(); - let proof = transcript.into_proof(); - println!("transcript commit len {:?}", proof.len() * 8); - assert!(comm.is_base()); - let mut transcript = PoseidonTranscript::::from_proof(proof.as_slice()); - let comm = >::read_commitment( - &vp, - &mut transcript, - ) - .unwrap(); - assert!(comm.is_base()); - assert_eq!(comm.num_vars().unwrap(), num_vars); - } + let (comm, evals, proof) = { + let mut transcript = Transcript::new(b"BaseFold"); + let polys = (0..batch_size) + .map(|_| { + if base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + + let evals = (0..batch_size) + .map(|i| polys[i].evaluate(&point)) + .collect_vec(); - // use gkr::structs::{Circuit, CircuitWitness, IOPProverState, IOPVerifierState}; - // use gkr::utils::MultilinearExtensionFromVectors; - // use simple_frontend::structs::{CircuitBuilder, ConstantType}; - // use transcript::Transcript; - - // enum TableType { - // FakeHashTable, - // } - - // struct AllInputIndex { - // // public - // inputs_idx: usize, - - // // private - // other_x_pows_idx: usize, - // count_idx: usize, - // } - - // fn construct_circuit() -> (Circuit, AllInputIndex) { - // let mut circuit_builder = CircuitBuilder::::new(); - // let one = F::BaseField::ONE; - // let neg_one = -F::BaseField::ONE; - - // let table_size = 4; - // let x = circuit_builder.create_constant_in(1, 2); - // let (other_x_pows_idx, other_pows_of_x) = circuit_builder.create_wire_in(table_size - 1); - // let pow_of_xs = [x, other_pows_of_x].concat(); - // for i in 0..table_size - 1 { - // // circuit_builder.mul2( - // // pow_of_xs[i + 1], - // // pow_of_xs[i], - // // pow_of_xs[i], - // // Goldilocks::ONE, - // // ); - // let tmp = circuit_builder.create_cell(); - // circuit_builder.mul2(tmp, pow_of_xs[i], pow_of_xs[i], F::BaseField::ONE); - // let diff = circuit_builder.create_cell(); - // circuit_builder.add(diff, pow_of_xs[i + 1], one); - // circuit_builder.add(diff, tmp, neg_one); - // circuit_builder.assert_const(diff, F::BaseField::ZERO); - // } - - // let table_type = TableType::FakeHashTable as usize; - // let count_idx = circuit_builder.define_table_type(table_type); - // for i in 0..table_size { - // circuit_builder.add_table_item(table_type, pow_of_xs[i]); - // } - - // let (inputs_idx, inputs) = circuit_builder.create_wire_in(5); - // inputs.iter().for_each(|input| { - // circuit_builder.add_input_item(table_type, *input); - // }); - - // circuit_builder.assign_table_challenge(table_type, ConstantType::Challenge(0)); - - // circuit_builder.configure(); - // // circuit_builder.print_info(); - // ( - // Circuit::::new(&circuit_builder), - // AllInputIndex { - // other_x_pows_idx, - // inputs_idx, - // count_idx, - // }, - // ) - // } - - // pub(super) fn test_with_gkr() - // where - // F: SmallField + FromUniformBytes<64>, - // F::BaseField: Into, - // Pcs: NoninteractivePCS, Rng = ChaCha8Rng>, - // for<'a> &'a Pcs::CommitmentWithData: Into, - // for<'de> ::BaseField: Deserialize<'de>, - // T: TranscriptRead - // + TranscriptWrite - // + InMemoryTranscript, - // { - // // This test is copied from examples/fake_hash_lookup_par, which is currently - // // not using PCS for the check. The verifier outputs a GKRInputClaims that the - // // verifier is unable to check without the PCS. - - // let rng = ChaCha8Rng::from_seed([0u8; 32]); - // // Setup - // let (pp, vp) = { - // let poly_size = 1 << 10; - // let param = Pcs::setup(poly_size, &rng).unwrap(); - // Pcs::trim(¶m).unwrap() - // }; - - // let (circuit, all_input_index) = construct_circuit::(); - // // println!("circuit: {:?}", circuit); - // let mut wires_in = vec![vec![]; circuit.n_wires_in]; - // wires_in[all_input_index.inputs_idx] = vec![ - // F::from(2u64), - // F::from(2u64), - // F::from(4u64), - // F::from(16u64), - // F::from(2u64), - // ]; - // // x = 2, 2^2 = 4, 2^2^2 = 16, 2^2^2^2 = 256 - // wires_in[all_input_index.other_x_pows_idx] = - // vec![F::from(4u64), F::from(16u64), F::from(256u64)]; - // wires_in[all_input_index.count_idx] = - // vec![F::from(3u64), F::from(1u64), F::from(1u64), F::from(0u64)]; - - // let circuit_witness = { - // let challenge = F::from(9); - // let mut circuit_witness = CircuitWitness::new(&circuit, vec![challenge]); - // for _ in 0..4 { - // circuit_witness.add_instance(&circuit, &wires_in); - // } - // circuit_witness - // }; - - // #[cfg(feature = "sanity-check")] - // circuit_witness.check_correctness(&circuit); - - // let instance_num_vars = circuit_witness.instance_num_vars(); - - // // Commit to the input wires - - // let polys = circuit_witness - // .wires_in_ref() - // .iter() - // .map(|values| { - // MultilinearPolynomial::new( - // values - // .as_slice() - // .mle(circuit.max_wires_in_num_vars, instance_num_vars) - // .evaluations - // .clone(), - // ) - // }) - // .collect_vec(); - // println!( - // "Polynomial num vars: {:?}", - // polys.iter().map(|p| p.num_vars()).collect_vec() - // ); - // let comms_with_data = Pcs::batch_commit(&pp, &polys).unwrap(); - // let comms: Vec = comms_with_data.iter().map(|cm| cm.into()).collect_vec(); - // println!("Finish commitment"); - - // // Commitments should be part of the proof, which is not yet - - // let (proof, output_num_vars, output_eval) = { - // let mut prover_transcript = Transcript::new(b"example"); - // let output_num_vars = instance_num_vars + circuit.last_layer_ref().num_vars(); - - // let output_point = (0..output_num_vars) - // .map(|_| { - // prover_transcript - // .get_and_append_challenge(b"output point") - // .elements[0] - // }) - // .collect_vec(); - - // let output_eval = circuit_witness - // .layer_poly(0, circuit.last_layer_ref().num_vars()) - // .evaluate(&output_point); - // ( - // IOPProverState::prove_parallel( - // &circuit, - // &circuit_witness, - // &[(output_point, output_eval)], - // &[], - // &mut prover_transcript, - // ), - // output_num_vars, - // output_eval, - // ) - // }; - - // let gkr_input_claims = { - // let mut verifier_transcript = &mut Transcript::new(b"example"); - // let output_point = (0..output_num_vars) - // .map(|_| { - // verifier_transcript - // .get_and_append_challenge(b"output point") - // .elements[0] - // }) - // .collect_vec(); - // IOPVerifierState::verify_parallel( - // &circuit, - // circuit_witness.challenges(), - // &[(output_point, output_eval)], - // &[], - // &proof, - // instance_num_vars, - // &mut verifier_transcript, - // ) - // .expect("verification failed") - // }; - - // // Generate pcs proof - // let expected_values = circuit_witness - // .wires_in_ref() - // .iter() - // .map(|witness| { - // witness - // .as_slice() - // .mle(circuit.max_wires_in_num_vars, instance_num_vars) - // .evaluate(&gkr_input_claims.point) - // }) - // .collect_vec(); - // let points = vec![gkr_input_claims.point]; - // let evals = expected_values - // .iter() - // .enumerate() - // .map(|(i, e)| Evaluation { - // poly: i, - // point: 0, - // value: *e, - // }) - // .collect_vec(); - // // This should be part of the GKR proof - // let pcs_proof = Pcs::ni_batch_open(&pp, &polys, &comms_with_data, &points, &evals).unwrap(); - // println!("Finish opening"); - - // // Check outside of the GKR verifier - // for i in 0..gkr_input_claims.values.len() { - // assert_eq!(expected_values[i], gkr_input_claims.values[i]); - // } - - // // This should be part of the GKR verifier - // let evals = gkr_input_claims - // .values - // .iter() - // .enumerate() - // .map(|(i, e)| Evaluation { - // poly: i, - // point: 0, - // value: *e, - // }) - // .collect_vec(); - // Pcs::ni_batch_verify(&vp, &comms, &points, &evals, &pcs_proof).unwrap(); - - // println!("verification succeeded"); - // } + transcript.append_field_element_exts(&evals); + let proof = + Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); + (Pcs::get_pure_commitment(&comm), evals, proof) + }; + // Batch verify + let result = { + let mut transcript = Transcript::new(b"BaseFold"); + Pcs::write_commitment(&comm, &mut transcript).unwrap(); + + let point = (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect::>(); + + transcript.append_field_element_exts(&evals); + + let result = + Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript); + result + }; + + result.unwrap(); + } + } } diff --git a/mpcs/src/sum_check.rs b/mpcs/src/sum_check.rs index e12233cf9..a3e1b460c 100644 --- a/mpcs/src/sum_check.rs +++ b/mpcs/src/sum_check.rs @@ -2,17 +2,19 @@ use crate::{ util::{ arithmetic::{inner_product, powers, product, BooleanHypercube}, expression::{CommonPolynomial, Expression, Query}, - transcript::{FieldTranscriptRead, FieldTranscriptWrite}, BitIndex, }, Error, }; use std::{collections::HashMap, fmt::Debug}; +use classic::{ClassicSumCheckRoundMessage, SumcheckProof}; use ff::PrimeField; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::DenseMultilinearExtension; +use serde::{de::DeserializeOwned, Serialize}; +use transcript::Transcript; pub mod classic; @@ -40,24 +42,29 @@ impl<'a, E: ExtensionField> VirtualPolynomial<'a, E> { } } -pub trait SumCheck: Clone + Debug { +pub trait SumCheck: Clone + Debug +where + E::BaseField: Serialize + DeserializeOwned, +{ type ProverParam: Clone + Debug; type VerifierParam: Clone + Debug; + type RoundMessage: ClassicSumCheckRoundMessage + Clone + Debug; fn prove( pp: &Self::ProverParam, num_vars: usize, virtual_poly: VirtualPolynomial, sum: E, - transcript: &mut impl FieldTranscriptWrite, - ) -> Result<(Vec, Vec), Error>; + transcript: &mut Transcript, + ) -> Result<(Vec, Vec, SumcheckProof), Error>; fn verify( vp: &Self::VerifierParam, num_vars: usize, degree: usize, sum: E, - transcript: &mut impl FieldTranscriptRead, + proof: &SumcheckProof, + transcript: &mut Transcript, ) -> Result<(E, Vec), Error>; } @@ -93,8 +100,8 @@ pub fn evaluate( &|query| evals[&query], &|idx| challenges[idx], &|scalar| -scalar, - &|lhs, rhs| lhs + &rhs, - &|lhs, rhs| lhs * &rhs, + &|lhs, rhs| lhs + rhs, + &|lhs, rhs| lhs * rhs, &|value, scalar| scalar * value, ) } @@ -104,11 +111,7 @@ pub fn lagrange_eval(x: &[F], b: usize) -> F { product(x.iter().enumerate().map( |(idx, x_i)| { - if b.nth_bit(idx) { - *x_i - } else { - F::ONE - x_i - } + if b.nth_bit(idx) { *x_i } else { F::ONE - x_i } }, )) } diff --git a/mpcs/src/sum_check/classic.rs b/mpcs/src/sum_check/classic.rs index 439535e71..28b03fdad 100644 --- a/mpcs/src/sum_check/classic.rs +++ b/mpcs/src/sum_check/classic.rs @@ -5,7 +5,6 @@ use crate::{ expression::{Expression, Rotation}, parallel::par_map_collect, poly_index_ext, - transcript::{FieldTranscriptRead, FieldTranscriptWrite}, }, Error, }; @@ -14,13 +13,16 @@ use ff::Field; use ff_ext::ExtensionField; use itertools::Itertools; use num_integer::Integer; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::{borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData}; +use transcript::Transcript; mod coeff; use multilinear_extensions::{ mle::{DenseMultilinearExtension, MultilinearExtension}, virtual_poly::build_eq_x_r_vec, }; +pub(crate) use coeff::Coefficients; pub use coeff::CoefficientsProver; #[derive(Debug)] @@ -116,10 +118,8 @@ impl<'a, E: ExtensionField> ProverState<'a, E> { .expression .used_rotation() .into_iter() - .filter_map(|rotation| { - (rotation != Rotation::cur()) - .then(|| (rotation, self.bh.rotation_map(rotation))) - }) + .filter(|&rotation| (rotation != Rotation::cur())) + .map(|rotation| (rotation, self.bh.rotation_map(rotation))) .collect::>(); for query in self.expression.used_query() { if query.rotation() != Rotation::cur() { @@ -164,8 +164,17 @@ impl<'a, E: ExtensionField> ProverState<'a, E> { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SumcheckProof> +where + E::BaseField: Serialize + DeserializeOwned, +{ + rounds: Vec, + phantom: PhantomData, +} + pub trait ClassicSumCheckProver: Clone + Debug { - type RoundMessage: ClassicSumCheckRoundMessage; + type RoundMessage: ClassicSumCheckRoundMessage + Clone + Debug; fn new(state: &ProverState) -> Self; @@ -177,15 +186,7 @@ pub trait ClassicSumCheckProver: Clone + Debug { pub trait ClassicSumCheckRoundMessage: Sized + Debug { type Auxiliary: Default; - fn write(&self, transcript: &mut impl FieldTranscriptWrite) -> Result<(), Error>; - - fn read_base( - degree: usize, - transcript: &mut impl FieldTranscriptRead, - ) -> Result; - - fn read_ext(degree: usize, transcript: &mut impl FieldTranscriptRead) - -> Result; + fn write(&self, transcript: &mut Transcript) -> Result<(), Error>; fn sum(&self) -> E; @@ -220,17 +221,21 @@ pub trait ClassicSumCheckRoundMessage: Sized + Debug { #[derive(Clone, Debug)] pub struct ClassicSumCheck

(PhantomData

); -impl> SumCheck for ClassicSumCheck

{ +impl> SumCheck for ClassicSumCheck

+where + E::BaseField: Serialize + DeserializeOwned, +{ type ProverParam = (); type VerifierParam = (); + type RoundMessage = P::RoundMessage; fn prove( _: &Self::ProverParam, num_vars: usize, virtual_poly: VirtualPolynomial, sum: E, - transcript: &mut impl FieldTranscriptWrite, - ) -> Result<(Vec, Vec), Error> { + transcript: &mut Transcript, + ) -> Result<(Vec, Vec, SumcheckProof), Error> { let _timer = start_timer!(|| { let degree = virtual_poly.expression.degree(); format!("sum_check_prove-{num_vars}-{degree}") @@ -246,6 +251,8 @@ impl> SumCheck for ClassicSumC let aux = P::RoundMessage::auxiliary(state.degree); + let mut prover_messages = Vec::with_capacity(num_vars); + for _round in 0..num_vars { let timer = start_timer!(|| format!("sum_check_prove_round-{_round}")); let msg = prover.prove_round(&state); @@ -259,15 +266,22 @@ impl> SumCheck for ClassicSumC ); } - let challenge = transcript.squeeze_challenge(); + let challenge = transcript + .get_and_append_challenge(b"sumcheck round") + .elements; challenges.push(challenge); let timer = start_timer!(|| format!("sum_check_next_round-{_round}")); state.next_round(msg.evaluate(&aux, &challenge), &challenge); end_timer!(timer); + prover_messages.push(msg); } - Ok((challenges, state.into_evals())) + let proof = SumcheckProof { + rounds: prover_messages, + phantom: PhantomData, + }; + Ok((challenges, state.into_evals(), proof)) } fn verify( @@ -275,20 +289,26 @@ impl> SumCheck for ClassicSumC num_vars: usize, degree: usize, sum: E, - transcript: &mut impl FieldTranscriptRead, + proof: &SumcheckProof, + transcript: &mut Transcript, ) -> Result<(E, Vec), Error> { let (msgs, challenges) = { let mut msgs = Vec::with_capacity(num_vars); let mut challenges = Vec::with_capacity(num_vars); - for _ in 0..num_vars { - msgs.push(P::RoundMessage::read_ext(degree, transcript)?); - challenges.push(transcript.squeeze_challenge()); + for i in 0..num_vars { + proof.rounds[i].write(transcript)?; + msgs.push(proof.rounds[i].clone()); + challenges.push( + transcript + .get_and_append_challenge(b"sumcheck round") + .elements, + ); } (msgs, challenges) }; Ok(( - P::RoundMessage::verify_consistency(degree, sum, &msgs, &challenges)?, + P::RoundMessage::verify_consistency(degree, sum, msgs.as_slice(), &challenges)?, challenges, )) } @@ -299,20 +319,16 @@ mod tests { use crate::{ sum_check::eq_xy_eval, - util::{ - arithmetic::inner_product, - expression::Query, - poly_iter_ext, - transcript::{InMemoryTranscript, PoseidonTranscript}, - }, + util::{arithmetic::inner_product, expression::Query, poly_iter_ext}, }; + use transcript::Transcript; use super::*; use goldilocks::{Goldilocks as Fr, GoldilocksExt2 as E}; #[test] fn test_sum_check_protocol() { - let polys = vec![ + let polys = [ DenseMultilinearExtension::::from_evaluations_vec( 2, vec![Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)], @@ -348,26 +364,28 @@ mod tests { &build_eq_x_r_vec(&points[1]), ) * Fr::from(4) * Fr::from(2); // The third polynomial is summed twice because the hypercube is larger - let mut transcript = PoseidonTranscript::::new(); - let (challenges, evals) = > as SumCheck>::prove( - &(), - 2, - virtual_poly.clone(), - sum, - &mut transcript, - ) - .unwrap(); + let mut transcript = Transcript::::new(b"sumcheck"); + let (challenges, evals, proof) = + > as SumCheck>::prove( + &(), + 2, + virtual_poly.clone(), + sum, + &mut transcript, + ) + .unwrap(); assert_eq!(polys[0].evaluate(&challenges), evals[0]); assert_eq!(polys[1].evaluate(&challenges), evals[1]); assert_eq!(polys[2].evaluate(&challenges[..1]), evals[2]); - let proof = transcript.into_proof(); - let mut transcript = PoseidonTranscript::::from_proof(&proof); + let mut transcript = Transcript::::new(b"sumcheck"); let (new_sum, verifier_challenges) = > as SumCheck< E, - >>::verify(&(), 2, 2, sum, &mut transcript) + >>::verify( + &(), 2, 2, sum, &proof, &mut transcript + ) .unwrap(); assert_eq!(verifier_challenges, challenges); @@ -378,13 +396,14 @@ mod tests { + evals[2] * eq_xy_eval(&points[1], &challenges[..1]) * Fr::from(4) ); - let mut transcript = PoseidonTranscript::::from_proof(&proof); + let mut transcript = Transcript::::new(b"sumcheck"); > as SumCheck>::verify( &(), 2, 2, sum + Fr::ONE, + &proof, &mut transcript, ) .expect_err("Should panic"); diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 9e596fe08..40bbf2d25 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -6,14 +6,15 @@ use crate::{ impl_index, parallel::{num_threads, parallelize_iter}, poly_index_ext, poly_iter_ext, - transcript::{FieldTranscriptRead, FieldTranscriptWrite}, }, Error, }; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::FieldType; +use serde::{Deserialize, Serialize}; use std::{fmt::Debug, iter, ops::AddAssign}; +use transcript::Transcript; macro_rules! zip_self { (@ $iter:expr, $step:expr, $skip:expr) => { @@ -30,40 +31,21 @@ macro_rules! zip_self { }; } -#[derive(Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Coefficients(FieldType); impl ClassicSumCheckRoundMessage for Coefficients { type Auxiliary = (); - fn write(&self, transcript: &mut impl FieldTranscriptWrite) -> Result<(), Error> { + fn write(&self, transcript: &mut Transcript) -> Result<(), Error> { match &self.0 { - FieldType::Ext(coeffs) => transcript.write_field_elements_ext(coeffs), - FieldType::Base(coeffs) => transcript.write_field_elements_base(coeffs), + FieldType::Ext(coeffs) => transcript.append_field_element_exts(coeffs), + FieldType::Base(coeffs) => coeffs + .iter() + .for_each(|c| transcript.append_field_element(c)), FieldType::Unreachable => unreachable!(), - } - } - - fn read_base( - degree: usize, - transcript: &mut impl FieldTranscriptRead, - ) -> Result { - Ok(Self( - transcript - .read_field_elements_base(degree + 1) - .map(FieldType::Base)?, - )) - } - - fn read_ext( - degree: usize, - transcript: &mut impl FieldTranscriptRead, - ) -> Result { - Ok(Self( - transcript - .read_field_elements_ext(degree + 1) - .map(FieldType::Ext)?, - )) + }; + Ok(()) } fn sum(&self) -> E { @@ -205,7 +187,7 @@ impl ClassicSumCheckProver for CoefficientsProver { products.iter_mut().for_each(|(lhs, _)| { *lhs *= &rhs; }); - (constant * &rhs, products) + (constant * rhs, products) }, ); Self(constant, flattened) @@ -215,7 +197,7 @@ impl ClassicSumCheckProver for CoefficientsProver { // Initialize h(X) to zero let mut coeffs = Coefficients(FieldType::Ext(vec![E::ZERO; state.expression.degree() + 1])); // First, sum the constant over the hypercube and add to h(X) - coeffs += &(E::from(state.size() as u64) * &self.0); + coeffs += &(E::from(state.size() as u64) * self.0); // Next, for every product of polynomials, where each product is assumed to be exactly 2 // put this into h(X). if self.1.iter().all(|(_, products)| products.len() == 2) { @@ -287,11 +269,11 @@ impl CoefficientsProver { .take(n) .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { let coeff_0 = lhs_0 * rhs_0; - let coeff_2 = (lhs_1 - lhs_0) * &(rhs_1 - rhs_0); + let coeff_2 = (lhs_1 - lhs_0) * (rhs_1 - rhs_0); coeffs[0] += &coeff_0; coeffs[2] += &coeff_2; if !LAZY { - coeffs[1] += &(lhs_1 * rhs_1 - &coeff_0 - &coeff_2); + coeffs[1] += &(lhs_1 * rhs_1 - coeff_0 - coeff_2); } }); }; diff --git a/mpcs/src/util.rs b/mpcs/src/util.rs index cf056d51a..c796baa39 100644 --- a/mpcs/src/util.rs +++ b/mpcs/src/util.rs @@ -3,7 +3,6 @@ pub mod expression; pub mod hash; pub mod parallel; pub mod plonky2_util; -pub mod transcript; use ff::{Field, PrimeField}; use ff_ext::ExtensionField; use goldilocks::SmallField; @@ -113,6 +112,42 @@ pub fn field_type_index_ext(poly: &FieldType, index: usize } } +pub fn field_type_index_mul_base( + poly: &mut FieldType, + index: usize, + scalar: &E::BaseField, +) { + match poly { + FieldType::Ext(coeffs) => coeffs[index] *= scalar, + FieldType::Base(coeffs) => coeffs[index] *= scalar, + _ => unreachable!(), + } +} + +pub fn field_type_index_set_base( + poly: &mut FieldType, + index: usize, + scalar: &E::BaseField, +) { + match poly { + FieldType::Ext(coeffs) => coeffs[index] = E::from(*scalar), + FieldType::Base(coeffs) => coeffs[index] = *scalar, + _ => unreachable!(), + } +} + +pub fn field_type_index_set_ext( + poly: &mut FieldType, + index: usize, + scalar: &E, +) { + match poly { + FieldType::Ext(coeffs) => coeffs[index] = *scalar, + FieldType::Base(_) => panic!("Cannot set base field from extension field"), + _ => unreachable!(), + } +} + pub struct FieldTypeIterExt<'a, E: ExtensionField> { inner: &'a FieldType, index: usize, diff --git a/mpcs/src/util/arithmetic.rs b/mpcs/src/util/arithmetic.rs index f4ca552d8..609f65455 100644 --- a/mpcs/src/util/arithmetic.rs +++ b/mpcs/src/util/arithmetic.rs @@ -80,7 +80,7 @@ pub fn inner_product<'a, 'b, F: Field>( rhs: impl IntoIterator, ) -> F { lhs.into_iter() - .zip_eq(rhs.into_iter()) + .zip_eq(rhs) .map(|(lhs, rhs)| *lhs * rhs) .reduce(|acc, product| acc + product) .unwrap_or_default() @@ -92,8 +92,8 @@ pub fn inner_product_three<'a, 'b, 'c, F: Field>( c: impl IntoIterator, ) -> F { a.into_iter() - .zip_eq(b.into_iter()) - .zip_eq(c.into_iter()) + .zip_eq(b) + .zip_eq(c) .map(|((a, b), c)| *a * b * c) .reduce(|acc, product| acc + product) .unwrap_or_default() @@ -107,8 +107,9 @@ pub fn barycentric_weights(points: &[F]) -> Vec { points .iter() .enumerate() - .filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i)) - .reduce(|acc, value| acc * &value) + .filter(|&(i, _point_i)| (i != j)) + .map(|(_i, point_i)| *point_j - point_i) + .reduce(|acc, value| acc * value) .unwrap_or(F::ONE) }) .collect_vec(); @@ -126,7 +127,7 @@ pub fn barycentric_interpolate(weights: &[F], points: &[F], evals: &[F let sum_inv = coeffs.iter().fold(F::ZERO, |sum, coeff| sum + coeff); (coeffs, sum_inv.invert().unwrap()) }; - inner_product(&coeffs, evals) * &sum_inv + inner_product(&coeffs, evals) * sum_inv } pub fn modulus() -> BigUint { @@ -134,11 +135,7 @@ pub fn modulus() -> BigUint { } pub fn fe_from_bool(value: bool) -> F { - if value { - F::ONE - } else { - F::ZERO - } + if value { F::ONE } else { F::ZERO } } pub fn fe_mod_from_le_bytes(bytes: impl AsRef<[u8]>) -> F { @@ -221,17 +218,17 @@ pub fn interpolate2(points: [(F, F); 2], x: F) -> F { a1 + (x - a0) * (b1 - a1) * (b0 - a0).invert().unwrap() } -pub fn degree_2_zero_plus_one(poly: &Vec) -> F { +pub fn degree_2_zero_plus_one(poly: &[F]) -> F { poly[0] + poly[0] + poly[1] + poly[2] } -pub fn degree_2_eval(poly: &Vec, point: F) -> F { +pub fn degree_2_eval(poly: &[F], point: F) -> F { poly[0] + point * poly[1] + point * point * poly[2] } -pub fn base_from_raw_bytes(bytes: &Vec) -> E::BaseField { +pub fn base_from_raw_bytes(bytes: &[u8]) -> E::BaseField { let mut res = E::BaseField::ZERO; - bytes.into_iter().for_each(|b| { + bytes.iter().for_each(|b| { res += E::BaseField::from(u64::from(*b)); }); res diff --git a/mpcs/src/util/arithmetic/hypercube.rs b/mpcs/src/util/arithmetic/hypercube.rs index 16c9868af..ccd6294e3 100644 --- a/mpcs/src/util/arithmetic/hypercube.rs +++ b/mpcs/src/util/arithmetic/hypercube.rs @@ -1,4 +1,3 @@ -use ark_std::{end_timer, start_timer}; use ff::Field; use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; @@ -30,7 +29,7 @@ pub fn interpolate_over_boolean_hypercube(evals: &mut Vec) { evals.par_chunks_mut(chunk_size).for_each(|chunk| { let half_chunk = chunk_size >> 1; for j in half_chunk..chunk_size { - chunk[j] = chunk[j] - chunk[j - half_chunk]; + chunk[j] -= chunk[j - half_chunk]; } }); } diff --git a/mpcs/src/util/hash.rs b/mpcs/src/util/hash.rs index 4842f6e09..712e6f49c 100644 --- a/mpcs/src/util/hash.rs +++ b/mpcs/src/util/hash.rs @@ -5,12 +5,23 @@ use goldilocks::SmallField; use poseidon::Poseidon; use serde::{Deserialize, Serialize}; +use transcript::Transcript; -pub const DIGEST_WIDTH: usize = super::transcript::OUTPUT_WIDTH; +pub const DIGEST_WIDTH: usize = transcript::basic::OUTPUT_WIDTH; #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] pub struct Digest(pub [F; DIGEST_WIDTH]); pub type Hasher = Poseidon; +pub fn write_digest_to_transcript( + digest: &Digest, + transcript: &mut Transcript, +) { + digest + .0 + .iter() + .for_each(|x| transcript.append_field_element(x)); +} + pub fn new_hasher() -> Hasher { // FIXME: Change to the right parameter Hasher::::new(8, 22) @@ -40,6 +51,54 @@ pub fn hash_two_leaves_base( Digest(result) } +pub fn hash_two_leaves_batch_ext( + a: &[E], + b: &[E], + hasher: &Hasher, +) -> Digest { + let mut left_hasher = hasher.clone(); + a.iter().for_each(|a| left_hasher.update(a.as_bases())); + let left = Digest( + left_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + let mut right_hasher = hasher.clone(); + b.iter().for_each(|b| right_hasher.update(b.as_bases())); + let right = Digest( + right_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + hash_two_digests(&left, &right, hasher) +} + +pub fn hash_two_leaves_batch_base( + a: &Vec, + b: &Vec, + hasher: &Hasher, +) -> Digest { + let mut left_hasher = hasher.clone(); + left_hasher.update(a.as_slice()); + let left = Digest( + left_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + let mut right_hasher = hasher.clone(); + right_hasher.update(b.as_slice()); + let right = Digest( + right_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + hash_two_digests(&left, &right, hasher) +} + pub fn hash_two_digests( a: &Digest, b: &Digest, diff --git a/mpcs/src/util/merkle_tree.rs b/mpcs/src/util/merkle_tree.rs index 1c563f126..d486e4361 100644 --- a/mpcs/src/util/merkle_tree.rs +++ b/mpcs/src/util/merkle_tree.rs @@ -1,19 +1,27 @@ use ff_ext::ExtensionField; +use itertools::Itertools; use multilinear_extensions::mle::FieldType; use rayon::{ - iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, + iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, + }, slice::ParallelSlice, }; use crate::util::{ - hash::{hash_two_digests, hash_two_leaves_base, hash_two_leaves_ext, Digest, Hasher}, - log2_strict, - transcript::{TranscriptRead, TranscriptWrite}, - Deserialize, DeserializeOwned, Serialize, + field_type_index_base, field_type_index_ext, + hash::{ + hash_two_digests, hash_two_leaves_base, hash_two_leaves_batch_base, + hash_two_leaves_batch_ext, hash_two_leaves_ext, Digest, Hasher, + }, + log2_strict, Deserialize, DeserializeOwned, Serialize, }; +use transcript::Transcript; use ark_std::{end_timer, start_timer}; +use super::hash::write_digest_to_transcript; + #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(bound(serialize = "E: Serialize", deserialize = "E: DeserializeOwned"))] pub struct MerkleTree @@ -21,7 +29,7 @@ where E::BaseField: Serialize + DeserializeOwned, { inner: Vec>>, - leaves: FieldType, + leaves: Vec>, } impl MerkleTree @@ -30,7 +38,14 @@ where { pub fn from_leaves(leaves: FieldType, hasher: &Hasher) -> Self { Self { - inner: merkelize::(&leaves, hasher), + inner: merkelize::(&[&leaves], hasher), + leaves: vec![leaves], + } + } + + pub fn from_batch_leaves(leaves: Vec>, hasher: &Hasher) -> Self { + Self { + inner: merkelize::(&leaves.iter().collect_vec(), hasher), leaves, } } @@ -47,26 +62,53 @@ where self.inner.len() } - pub fn leaves(&self) -> &FieldType { + pub fn leaves(&self) -> &Vec> { &self.leaves } - pub fn size(&self) -> usize { - self.leaves.len() + pub fn batch_leaves(&self, coeffs: &[E]) -> Vec { + (0..self.leaves[0].len()) + .into_par_iter() + .map(|i| { + self.leaves + .iter() + .zip(coeffs.iter()) + .map(|(leaf, coeff)| field_type_index_ext(leaf, i) * *coeff) + .sum() + }) + .collect() + } + + pub fn size(&self) -> (usize, usize) { + (self.leaves.len(), self.leaves[0].len()) } - pub fn get_leaf_as_base(&self, index: usize) -> E::BaseField { - match &self.leaves { - FieldType::Base(leaves) => leaves[index], - FieldType::Ext(_) => panic!("Mismatching field type, calling get_leaf_as_base on a Merkle tree over extension fields"), + pub fn get_leaf_as_base(&self, index: usize) -> Vec { + match &self.leaves[0] { + FieldType::Base(_) => self + .leaves + .iter() + .map(|leaves| field_type_index_base(leaves, index)) + .collect(), + FieldType::Ext(_) => panic!( + "Mismatching field type, calling get_leaf_as_base on a Merkle tree over extension fields" + ), FieldType::Unreachable => unreachable!(), } } - pub fn get_leaf_as_extension(&self, index: usize) -> E { - match &self.leaves { - FieldType::Base(leaves) => E::from(leaves[index]), - FieldType::Ext(leaves) => leaves[index], + pub fn get_leaf_as_extension(&self, index: usize) -> Vec { + match &self.leaves[0] { + FieldType::Base(_) => self + .leaves + .iter() + .map(|leaves| field_type_index_ext(leaves, index)) + .collect(), + FieldType::Ext(_) => self + .leaves + .iter() + .map(|leaves| field_type_index_ext(leaves, index)) + .collect(), FieldType::Unreachable => unreachable!(), } } @@ -75,7 +117,7 @@ where &self, leaf_index: usize, ) -> MerklePathWithoutLeafOrRoot { - assert!(leaf_index < self.size()); + assert!(leaf_index < self.size().1); MerklePathWithoutLeafOrRoot::::new( self.inner .iter() @@ -105,6 +147,10 @@ where Self { inner } } + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + pub fn len(&self) -> usize { self.inner.len() } @@ -113,22 +159,10 @@ where self.inner.iter() } - pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + pub fn write_transcript(&self, transcript: &mut Transcript) { self.inner .iter() - .for_each(|hash| transcript.write_commitment(hash).unwrap()); - } - - pub fn read_transcript( - transcript: &mut impl TranscriptRead, E>, - height: usize, - ) -> Self { - // Since no root, the number of digests is height - 1 - let mut inner = Vec::with_capacity(height - 1); - for _ in 0..(height - 1) { - inner.push(transcript.read_commitment().unwrap()); - } - Self { inner } + .for_each(|hash| write_digest_to_transcript(hash, transcript)); } pub fn authenticate_leaves_root_ext( @@ -164,28 +198,100 @@ where hasher, ) } + + pub fn authenticate_batch_leaves_root_ext( + &self, + left: Vec, + right: Vec, + index: usize, + root: &Digest, + hasher: &Hasher, + ) { + authenticate_merkle_path_root_batch::( + &self.inner, + FieldType::Ext(left), + FieldType::Ext(right), + index, + root, + hasher, + ) + } + + pub fn authenticate_batch_leaves_root_base( + &self, + left: Vec, + right: Vec, + index: usize, + root: &Digest, + hasher: &Hasher, + ) { + authenticate_merkle_path_root_batch::( + &self.inner, + FieldType::Base(left), + FieldType::Base(right), + index, + root, + hasher, + ) + } } fn merkelize( - values: &FieldType, + values: &[&FieldType], hasher: &Hasher, ) -> Vec>> { - let timer = start_timer!(|| format!("merkelize {} values", values.len())); - let log_v = log2_strict(values.len()); + #[cfg(feature = "sanity-check")] + for i in 0..(values.len() - 1) { + assert_eq!(values[i].len(), values[i + 1].len()); + } + let timer = start_timer!(|| format!("merkelize {} values", values[0].len() * values.len())); + let log_v = log2_strict(values[0].len()); let mut tree = Vec::with_capacity(log_v); // The first layer of hashes, half the number of leaves - let mut hashes = vec![Digest::default(); values.len() >> 1]; - hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = match values { - FieldType::Base(values) => { - hash_two_leaves_base::(&values[i << 1], &values[(i << 1) + 1], hasher) - } - FieldType::Ext(values) => { - hash_two_leaves_ext::(&values[i << 1], &values[(i << 1) + 1], hasher) - } - FieldType::Unreachable => unreachable!(), - }; - }); + let mut hashes = vec![Digest::default(); values[0].len() >> 1]; + if values.len() == 1 { + hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { + *hash = match &values[0] { + FieldType::Base(values) => { + hash_two_leaves_base::(&values[i << 1], &values[(i << 1) + 1], hasher) + } + FieldType::Ext(values) => { + hash_two_leaves_ext::(&values[i << 1], &values[(i << 1) + 1], hasher) + } + FieldType::Unreachable => unreachable!(), + }; + }); + } else { + hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { + *hash = match &values[0] { + FieldType::Base(_) => hash_two_leaves_batch_base::( + &values + .iter() + .map(|values| field_type_index_base(values, i << 1)) + .collect(), + &values + .iter() + .map(|values| field_type_index_base(values, (i << 1) + 1)) + .collect(), + hasher, + ), + FieldType::Ext(_) => hash_two_leaves_batch_ext::( + values + .iter() + .map(|values| field_type_index_ext(values, i << 1)) + .collect_vec() + .as_slice(), + values + .iter() + .map(|values| field_type_index_ext(values, (i << 1) + 1)) + .collect_vec() + .as_slice(), + hasher, + ), + FieldType::Unreachable => unreachable!(), + }; + }); + } tree.push(hashes); @@ -202,7 +308,7 @@ fn merkelize( } fn authenticate_merkle_path_root( - path: &Vec>, + path: &[Digest], leaves: FieldType, x_index: usize, root: &Digest, @@ -218,11 +324,43 @@ fn authenticate_merkle_path_root( // The lowest bit in the index is ignored. It can point to either leaves x_index >>= 1; - for i in 0..path.len() { + for path_i in path.iter() { + hash = if x_index & 1 == 0 { + hash_two_digests(&hash, path_i, hasher) + } else { + hash_two_digests(path_i, &hash, hasher) + }; + x_index >>= 1; + } + assert_eq!(&hash, root); +} + +fn authenticate_merkle_path_root_batch( + path: &[Digest], + left: FieldType, + right: FieldType, + x_index: usize, + root: &Digest, + hasher: &Hasher, +) { + let mut x_index = x_index; + let mut hash = match (left, right) { + (FieldType::Base(left), FieldType::Base(right)) => { + hash_two_leaves_batch_base::(&left, &right, hasher) + } + (FieldType::Ext(left), FieldType::Ext(right)) => { + hash_two_leaves_batch_ext::(&left, &right, hasher) + } + _ => unreachable!(), + }; + + // The lowest bit in the index is ignored. It can point to either leaves + x_index >>= 1; + for path_i in path.iter() { hash = if x_index & 1 == 0 { - hash_two_digests(&hash, &path[i], hasher) + hash_two_digests(&hash, path_i, hasher) } else { - hash_two_digests(&path[i], &hash, hasher) + hash_two_digests(path_i, &hash, hasher) }; x_index >>= 1; } diff --git a/mpcs/src/util/parallel.rs b/mpcs/src/util/parallel.rs index 5f44d473d..950d43797 100644 --- a/mpcs/src/util/parallel.rs +++ b/mpcs/src/util/parallel.rs @@ -1,10 +1,7 @@ pub fn num_threads() -> usize { #[cfg(feature = "parallel")] let nt = rayon::current_num_threads(); - return nt; - - #[cfg(not(feature = "parallel"))] - return 1; + if cfg!(feature = "parallel") { nt } else { 1 } } pub fn parallelize_iter(iter: I, f: F) diff --git a/mpcs/src/util/transcript.rs b/mpcs/src/util/transcript.rs index 343bfffba..2edf082cf 100644 --- a/mpcs/src/util/transcript.rs +++ b/mpcs/src/util/transcript.rs @@ -21,17 +21,16 @@ pub trait FieldTranscript { fn common_field_element_ext(&mut self, fe: &E) -> Result<(), Error>; fn common_field_elements(&mut self, fes: FieldType) -> Result<(), Error> { - Ok(match fes { + match fes { FieldType::Base(fes) => fes .iter() - .map(|fe| self.common_field_element_base(fe)) - .try_collect()?, + .try_for_each(|fe| self.common_field_element_base(fe))?, FieldType::Ext(fes) => fes .iter() - .map(|fe| self.common_field_element_ext(fe)) - .try_collect()?, + .try_for_each(|fe| self.common_field_element_ext(fe))?, FieldType::Unreachable => unreachable!(), - }) + }; + Ok(()) } } @@ -146,7 +145,7 @@ impl Stream { self.inner.len() - self.pointer } - pub fn read_exact(&mut self, output: &mut Vec) -> Result<(), Error> { + pub fn read_exact(&mut self, output: &mut [T]) -> Result<(), Error> { let left = self.left(); if left < output.len() { return Err(Error::Transcript(