From a225ab73792c2036737689ffb81d8ac07038b915 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Mon, 9 Sep 2024 16:22:31 +0800 Subject: [PATCH] Better batch commit and switch to Reed Solomon code. (#155) Main tasks accomplished by this PR: - [x] Replace the naive batch commit (committing to individual polys) to real batch commit, i.e., committing to multiple polynomials in a single Merkle tree. - [x] Add the `simple_batch_prove` and `simple_batch_verify` methods. These methods support opening: - One commitment that commits to multiple polynomials of the same size. - One opening point. - [x] Switch the encoding algorithm from the one in BaseFold paper to Reed Solomon code. The encoding algorithm of RS code is much faster, and RS code has better distance so allows a better parameter. - [x] Estimate the appropriate parameter for RS code. (The original `batch_prove` and `batch_verify` methods supports opening multiple commitments, multiple points and a flexible combination between polys and points, but only allow each input commitment to contain only one polynomial) --------- Co-authored-by: Wisdom Ogwu <40731160+iammadab@users.noreply.github.com> Co-authored-by: dreamATD --- mpcs/Cargo.toml | 22 +- mpcs/benches/basecode.rs | 110 ++ mpcs/benches/commit_open_verify_basecode.rs | 395 +++++ ...pen_verify.rs => commit_open_verify_rs.rs} | 181 ++- mpcs/benches/fft.rs | 80 + mpcs/benches/interpolate.rs | 81 + mpcs/benches/rscode.rs | 112 ++ mpcs/src/basefold.rs | 715 +++++++-- mpcs/src/basefold/basecode.rs | 293 ---- mpcs/src/basefold/commit_phase.rs | 221 ++- mpcs/src/basefold/encoding.rs | 237 +++ mpcs/src/basefold/encoding/basecode.rs | 451 ++++++ mpcs/src/basefold/encoding/rs.rs | 877 +++++++++++ mpcs/src/basefold/encoding/utils.rs | 35 + mpcs/src/basefold/query_phase.rs | 1357 ++++++++++++----- mpcs/src/basefold/structure.rs | 158 +- mpcs/src/basefold/sumcheck.rs | 55 +- mpcs/src/lib.rs | 182 ++- mpcs/src/sum_check.rs | 10 +- mpcs/src/sum_check/classic.rs | 8 +- 
mpcs/src/sum_check/classic/coeff.rs | 8 +- mpcs/src/util.rs | 36 + mpcs/src/util/arithmetic.rs | 27 +- mpcs/src/util/arithmetic/hypercube.rs | 3 +- mpcs/src/util/hash.rs | 48 + mpcs/src/util/merkle_tree.rs | 214 ++- mpcs/src/util/parallel.rs | 5 +- mpcs/src/util/transcript.rs | 13 +- 28 files changed, 4829 insertions(+), 1105 deletions(-) create mode 100644 mpcs/benches/basecode.rs create mode 100644 mpcs/benches/commit_open_verify_basecode.rs rename mpcs/benches/{commit_open_verify.rs => commit_open_verify_rs.rs} (60%) create mode 100644 mpcs/benches/fft.rs create mode 100644 mpcs/benches/interpolate.rs create mode 100644 mpcs/benches/rscode.rs delete mode 100644 mpcs/src/basefold/basecode.rs create mode 100644 mpcs/src/basefold/encoding.rs create mode 100644 mpcs/src/basefold/encoding/basecode.rs create mode 100644 mpcs/src/basefold/encoding/rs.rs create mode 100644 mpcs/src/basefold/encoding/utils.rs diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 18cecc5f0..4ff1ce8f3 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -46,5 +46,25 @@ sanity-check = [] print-trace = [ "ark-std/print-trace" ] [[bench]] -name = "commit_open_verify" +name = "commit_open_verify_rs" +harness = false + +[[bench]] +name = "commit_open_verify_basecode" +harness = false + +[[bench]] +name = "basecode" +harness = false + +[[bench]] +name = "rscode" +harness = false + +[[bench]] +name = "interpolate" +harness = false + +[[bench]] +name = "fft" harness = false \ No newline at end of file diff --git a/mpcs/benches/basecode.rs b/mpcs/benches/basecode.rs new file mode 100644 index 000000000..7e8ccca99 --- /dev/null +++ b/mpcs/benches/basecode.rs @@ -0,0 +1,110 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::Itertools; +use mpcs::{ + util::{ + arithmetic::interpolate_field_type_over_boolean_hypercube, + plonky2_util::reverse_index_bits_in_place_field_type, + }, Basefold, BasefoldBasecodeParams, BasefoldSpec, 
EncodingScheme, PolynomialCommitmentScheme +}; + +use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +type Pcs = Basefold; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_encoding(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "encoding_basecode_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, _) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys + .par_iter() + .map(|poly| { + // Switch to coefficient form + let mut coeffs = poly.evaluations.clone(); + interpolate_field_type_over_boolean_hypercube(&mut coeffs); + + let mut codeword = + <>::EncodingScheme as EncodingScheme>::encode( + &pp.encoding_params, + &coeffs, + ); + + // If using repetition code as basecode, it may be faster to use the following line of code to create the commitment and comment out the two lines above + // let mut codeword = 
evaluate_over_foldable_domain(pp.log_rate, coeffs, &pp.table); + + // The sum-check protocol starts from the first variable, but the FRI part + // will eventually produce the evaluation at (alpha_k, ..., alpha_1), so apply + // the bit-reversion to reverse the variable indices of the polynomial. + // In short: store the poly and codeword in big endian + reverse_index_bits_in_place_field_type(&mut coeffs); + reverse_index_bits_in_place_field_type(&mut codeword); + + (coeffs, codeword) + }) + .collect::<(Vec>, Vec>)>(); + }) + }, + ); + } + } +} + +fn bench_encoding_goldilocks_2(c: &mut Criterion) { + bench_encoding(c, false); +} + +fn bench_encoding_base(c: &mut Criterion) { + bench_encoding(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_encoding_base, bench_encoding_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/commit_open_verify_basecode.rs b/mpcs/benches/commit_open_verify_basecode.rs new file mode 100644 index 000000000..922d90001 --- /dev/null +++ b/mpcs/benches/commit_open_verify_basecode.rs @@ -0,0 +1,395 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::{chain, Itertools}; +use mpcs::{ + util::{ + plonky2_util::log2_ceil, + transcript::{ + FieldTranscript, FieldTranscriptRead, FieldTranscriptWrite, InMemoryTranscript, + PoseidonTranscript, + }, + }, + Basefold, BasefoldBasecodeParams, Evaluation, PolynomialCommitmentScheme, +}; + +use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Pcs = Basefold; +type T = PoseidonTranscript; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn 
bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "ext2" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + let (pp, vp) = { + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + + group.bench_function(BenchmarkId::new("setup", format!("{}", num_vars)), |b| { + b.iter(|| { + Pcs::setup(poly_size, &rng).unwrap(); + }) + }); + Pcs::trim(¶m, poly_size).unwrap() + }; + + let proof = { + let mut transcript = T::new(); + let poly = if is_base { + DenseMultilinearExtension::random(num_vars, &mut OsRng) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + }; + + let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); + + group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| { + b.iter(|| { + Pcs::commit(&pp, &poly).unwrap(); + }) + }); + + let point = transcript.squeeze_challenges(num_vars); + let eval = poly.evaluate(point.as_slice()); + transcript.write_field_element_ext(&eval).unwrap(); + Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); + + group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| { + b.iter_batched( + || transcript.clone(), + |mut transcript| { + Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); + }, + BatchSize::SmallInput, + ); + }); + + transcript.into_proof() + }; + // Verify + let mut transcript = T::from_proof(proof.as_slice()); + Pcs::verify( + &vp, + &Pcs::read_commitment(&vp, &mut transcript).unwrap(), + &transcript.squeeze_challenges(num_vars), + &transcript.read_field_element_ext().unwrap(), + &mut transcript, + ) + .unwrap(); + 
group.bench_function(BenchmarkId::new("verify", format!("{}", num_vars)), |b| { + b.iter_batched( + || T::from_proof(proof.as_slice()), + |mut transcript| { + Pcs::verify( + &vp, + &Pcs::read_commitment(&vp, &mut transcript).unwrap(), + &transcript.squeeze_challenges(num_vars), + &transcript.read_field_element_ext().unwrap(), + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }); + } +} + +fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "batch_commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "ext2" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let num_points = batch_size >> 1; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + // Setup + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + // Batch commit and open + let evals = chain![ + (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys + (0..num_points).map(|point| (point * 2 + 1, point)), + ] + .unique() + .collect_vec(); + + let proof = { + let mut transcript = T::new(); + let polys = (0..batch_size) + .map(|i| { + if is_base { + DenseMultilinearExtension::random( + num_vars - log2_ceil((i >> 1) + 1), + &mut rng.clone(), + ) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars - log2_ceil((i >> 1) + 1), + (0..1 << (num_vars - log2_ceil((i >> 1) + 1))) + .map(|_| E::random(&mut OsRng)) + .collect(), + ) + } + }) + .collect_vec(); + let comms = polys + .iter() + .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) + .collect_vec(); + + let points = (0..num_points) + .map(|i| transcript.squeeze_challenges(num_vars - 
log2_ceil(i + 1))) + .take(num_points) + .collect_vec(); + + let evals = evals + .iter() + .copied() + .map(|(poly, point)| { + Evaluation::new(poly, point, polys[poly].evaluate(&points[point])) + }) + .collect_vec(); + transcript + .write_field_elements_ext(evals.iter().map(Evaluation::value)) + .unwrap(); + Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || transcript.clone(), + |mut transcript| { + Pcs::batch_open( + &pp, + &polys, + &comms, + &points, + &evals, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + + transcript.into_proof() + }; + // Batch verify + let mut transcript = T::from_proof(proof.as_slice()); + let comms = &Pcs::read_commitments(&vp, batch_size, &mut transcript).unwrap(); + + let points = (0..num_points) + .map(|i| transcript.squeeze_challenges(num_vars - log2_ceil(i + 1))) + .take(num_points) + .collect_vec(); + + let evals2 = transcript.read_field_elements_ext(evals.len()).unwrap(); + + let backup_transcript = transcript.clone(); + + Pcs::batch_verify( + &vp, + comms, + &points, + &evals + .iter() + .copied() + .zip(evals2.clone()) + .map(|((poly, point), eval)| Evaluation::new(poly, point, eval)) + .collect_vec(), + &mut transcript, + ) + .unwrap(); + + group.bench_function( + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || backup_transcript.clone(), + |mut transcript| { + Pcs::batch_verify( + &vp, + comms, + &points, + &evals + .iter() + .copied() + .zip(evals2.clone()) + .map(|((poly, point), eval)| Evaluation::new(poly, point, eval)) + .collect_vec(), + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + +fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + 
"simple_batch_commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let proof = { + let mut transcript = T::new(); + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + Pcs::batch_commit(&pp, &polys).unwrap(); + }) + }, + ); + + let point = transcript.squeeze_challenges(num_vars); + + let evals = (0..batch_size) + .map(|i| polys[i].evaluate(&point)) + .collect_vec(); + + transcript.write_field_elements_ext(&evals).unwrap(); + Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); + + group.bench_function( + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || transcript.clone(), + |mut transcript| { + Pcs::simple_batch_open( + &pp, + &polys, + &comm, + &point, + &evals, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + transcript.into_proof() + }; + // Batch verify + let mut transcript = T::from_proof(proof.as_slice()); + let comms = &Pcs::read_commitment(&vp, &mut transcript).unwrap(); + + let point = 
transcript.squeeze_challenges(num_vars); + let evals = transcript.read_field_elements_ext(batch_size).unwrap(); + + let backup_transcript = transcript.clone(); + + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || backup_transcript.clone(), + |mut transcript| { + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + +fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_commit_open_verify_goldilocks(c, false); +} + +fn bench_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_commit_open_verify_goldilocks(c, true); +} + +fn bench_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks(c, false); +} + +fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks(c, true); +} + +fn bench_simple_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, false); +} + +fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, true); +} + +criterion_group! 
{ + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_simple_batch_commit_open_verify_goldilocks_base, bench_simple_batch_commit_open_verify_goldilocks_2,bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2, bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/commit_open_verify.rs b/mpcs/benches/commit_open_verify_rs.rs similarity index 60% rename from mpcs/benches/commit_open_verify.rs rename to mpcs/benches/commit_open_verify_rs.rs index 8d731f514..e372e2d11 100644 --- a/mpcs/benches/commit_open_verify.rs +++ b/mpcs/benches/commit_open_verify_rs.rs @@ -6,31 +6,38 @@ use goldilocks::GoldilocksExt2; use itertools::{chain, Itertools}; use mpcs::{ - util::transcript::{ - FieldTranscript, FieldTranscriptRead, FieldTranscriptWrite, InMemoryTranscript, - PoseidonTranscript, + util::{ + plonky2_util::log2_ceil, + transcript::{ + FieldTranscript, FieldTranscriptRead, FieldTranscriptWrite, InMemoryTranscript, + PoseidonTranscript, + }, }, - Basefold, BasefoldDefaultParams, Evaluation, PolynomialCommitmentScheme, + Basefold, BasefoldRSParams, Evaluation, PolynomialCommitmentScheme, }; use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; use rand::{rngs::OsRng, SeedableRng}; use rand_chacha::ChaCha8Rng; -type Pcs = Basefold; +type Pcs = Basefold; type T = PoseidonTranscript; type E = GoldilocksExt2; const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let mut group = c.benchmark_group(format!( - "commit_open_verify_goldilocks_{}", + "commit_open_verify_goldilocks_rs_{}", if is_base { "base" } else { "ext2" } )); group.sample_size(NUM_SAMPLES); // 
Challenge is over extension field, poly over the base field - for num_vars in 10..=20 { + for num_vars in NUM_VARS_START..=NUM_VARS_END { let (pp, vp) = { let rng = ChaCha8Rng::from_seed([0u8; 32]); let poly_size = 1 << num_vars; @@ -41,7 +48,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Pcs::setup(poly_size, &rng).unwrap(); }) }); - Pcs::trim(¶m).unwrap() + Pcs::trim(¶m, poly_size).unwrap() }; let proof = { @@ -111,13 +118,13 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let mut group = c.benchmark_group(format!( - "commit_batch_open_verify_goldilocks_{}", - if is_base { "base" } else { "extension" } + "batch_commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "ext2" } )); group.sample_size(NUM_SAMPLES); // Challenge is over extension field, poly over the base field - for num_vars in 10..=20 { - for batch_size_log in 1..=6 { + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { let batch_size = 1 << batch_size_log; let num_points = batch_size >> 1; let rng = ChaCha8Rng::from_seed([0u8; 32]); @@ -125,7 +132,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size, &rng).unwrap(); - Pcs::trim(¶m).unwrap() + Pcs::trim(¶m, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -140,28 +147,27 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let polys = (0..batch_size) .map(|i| { if is_base { - DenseMultilinearExtension::random(num_vars - (i >> 1), &mut rng.clone()) + DenseMultilinearExtension::random( + num_vars - log2_ceil((i >> 1) + 1), + &mut rng.clone(), + ) } else { DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut 
OsRng)).collect(), + num_vars - log2_ceil((i >> 1) + 1), + (0..1 << (num_vars - log2_ceil((i >> 1) + 1))) + .map(|_| E::random(&mut OsRng)) + .collect(), ) } }) .collect_vec(); - let comms = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_commit", format!("{}", num_vars)), - |b| { - b.iter(|| { - Pcs::batch_commit(&pp, &polys).unwrap(); - }) - }, - ); + let comms = polys + .iter() + .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) + .collect_vec(); let points = (0..num_points) - .map(|i| transcript.squeeze_challenges(num_vars - i)) + .map(|i| transcript.squeeze_challenges(num_vars - log2_ceil(i + 1))) .take(num_points) .collect_vec(); @@ -178,7 +184,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); group.bench_function( - BenchmarkId::new("batch_open", format!("{}", num_vars)), + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), |b| { b.iter_batched( || transcript.clone(), @@ -205,12 +211,14 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let comms = &Pcs::read_commitments(&vp, batch_size, &mut transcript).unwrap(); let points = (0..num_points) - .map(|i| transcript.squeeze_challenges(num_vars - i)) + .map(|i| transcript.squeeze_challenges(num_vars - log2_ceil(i + 1))) .take(num_points) .collect_vec(); let evals2 = transcript.read_field_elements_ext(evals.len()).unwrap(); + let backup_transcript = transcript.clone(); + Pcs::batch_verify( &vp, comms, @@ -226,10 +234,10 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { .unwrap(); group.bench_function( - BenchmarkId::new("batch_verify", format!("{}", num_vars)), + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), |b| { b.iter_batched( - || transcript.clone(), + || backup_transcript.clone(), |mut 
transcript| { Pcs::batch_verify( &vp, @@ -253,6 +261,107 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { } } +fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "simple_batch_commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let proof = { + let mut transcript = T::new(); + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + Pcs::batch_commit(&pp, &polys).unwrap(); + }) + }, + ); + + let point = transcript.squeeze_challenges(num_vars); + + let evals = (0..batch_size) + .map(|i| polys[i].evaluate(&point)) + .collect_vec(); + + transcript.write_field_elements_ext(&evals).unwrap(); + Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); + + group.bench_function( + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || transcript.clone(), + |mut transcript| { + Pcs::simple_batch_open( + &pp, + &polys, + &comm, + &point, + &evals, + &mut transcript, + ) 
+ .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + transcript.into_proof() + }; + // Batch verify + let mut transcript = T::from_proof(proof.as_slice()); + let comms = &Pcs::read_commitment(&vp, &mut transcript).unwrap(); + + let point = transcript.squeeze_challenges(num_vars); + let evals = transcript.read_field_elements_ext(batch_size).unwrap(); + + let backup_transcript = transcript.clone(); + + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || backup_transcript.clone(), + |mut transcript| { + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { bench_commit_open_verify_goldilocks(c, false); } @@ -269,10 +378,18 @@ fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { bench_batch_commit_open_verify_goldilocks(c, true); } +fn bench_simple_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, false); +} + +fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, true); +} + criterion_group! 
{ name = bench_basefold; config = Criterion::default().warm_up_time(Duration::from_millis(3000)); - targets = bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2 + targets = bench_simple_batch_commit_open_verify_goldilocks_base, bench_simple_batch_commit_open_verify_goldilocks_2,bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2, bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, } criterion_main!(bench_basefold); diff --git a/mpcs/benches/fft.rs b/mpcs/benches/fft.rs new file mode 100644 index 000000000..2588a11e2 --- /dev/null +++ b/mpcs/benches/fft.rs @@ -0,0 +1,80 @@ +use std::time::Duration; + +use criterion::*; +use ff::{Field, PrimeField}; +use goldilocks::{Goldilocks, GoldilocksExt2}; + +use itertools::Itertools; +use mpcs::{coset_fft, fft_root_table}; + +use multilinear_extensions::mle::DenseMultilinearExtension; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 6; + +fn bench_fft(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "fft_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + let root_table = fft_root_table(num_vars); + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let mut polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut 
rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys.par_iter_mut().for_each(|poly| { + coset_fft::( + &mut poly.evaluations, + Goldilocks::MULTIPLICATIVE_GENERATOR, + 0, + &root_table, + ); + }); + }) + }, + ); + } + } +} + +fn bench_fft_goldilocks_2(c: &mut Criterion) { + bench_fft(c, false); +} + +fn bench_fft_base(c: &mut Criterion) { + bench_fft(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_fft_base, bench_fft_goldilocks_2 +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/interpolate.rs b/mpcs/benches/interpolate.rs new file mode 100644 index 000000000..00af7c092 --- /dev/null +++ b/mpcs/benches/interpolate.rs @@ -0,0 +1,81 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::Itertools; +use mpcs::util::arithmetic::interpolate_field_type_over_boolean_hypercube; + +use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_encoding(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "interpolate_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in 
BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys + .par_iter() + .map(|poly| { + // Switch to coefficient form + let mut coeffs = poly.evaluations.clone(); + interpolate_field_type_over_boolean_hypercube(&mut coeffs); + coeffs + }) + .collect::>>(); + }) + }, + ); + } + } +} + +fn bench_encoding_goldilocks_2(c: &mut Criterion) { + bench_encoding(c, false); +} + +fn bench_encoding_base(c: &mut Criterion) { + bench_encoding(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_encoding_base, bench_encoding_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/rscode.rs b/mpcs/benches/rscode.rs new file mode 100644 index 000000000..ad42c7f39 --- /dev/null +++ b/mpcs/benches/rscode.rs @@ -0,0 +1,112 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::Itertools; +use mpcs::{ + util::{ + arithmetic::interpolate_field_type_over_boolean_hypercube, + plonky2_util::reverse_index_bits_in_place_field_type, + }, + Basefold, BasefoldRSParams, BasefoldSpec, EncodingScheme, + PolynomialCommitmentScheme, +}; + +use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +type Pcs = Basefold; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; 
+const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_encoding(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "encoding_rscode_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, _) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys + .par_iter() + .map(|poly| { + // Switch to coefficient form + let mut coeffs = poly.evaluations.clone(); + interpolate_field_type_over_boolean_hypercube(&mut coeffs); + + let mut codeword = + <>::EncodingScheme as EncodingScheme>::encode( + &pp.encoding_params, + &coeffs, + ); + + // If using repetition code as basecode, it may be faster to use the following line of code to create the commitment and comment out the two lines above + // let mut codeword = evaluate_over_foldable_domain(pp.log_rate, coeffs, &pp.table); + + // The sum-check protocol starts from the first variable, but the FRI part + // will eventually produce the evaluation at (alpha_k, ..., alpha_1), so apply + // the bit-reversion to reverse the variable indices of the polynomial. 
+ // In short: store the poly and codeword in big endian + reverse_index_bits_in_place_field_type(&mut coeffs); + reverse_index_bits_in_place_field_type(&mut codeword); + + (coeffs, codeword) + }) + .collect::<(Vec>, Vec>)>(); + }) + }, + ); + } + } +} + +fn bench_encoding_goldilocks_2(c: &mut Criterion) { + bench_encoding(c, false); +} + +fn bench_encoding_base(c: &mut Criterion) { + bench_encoding(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_encoding_base, bench_encoding_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 29cc51d1d..bb923ebd7 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -23,13 +23,20 @@ use crate::{ validate_input, Error, Evaluation, NoninteractivePCS, PolynomialCommitmentScheme, }; use ark_std::{end_timer, start_timer}; +pub use encoding::{ + Basecode, BasecodeDefaultSpec, EncodingProverParameters, EncodingScheme, RSCode, + RSCodeDefaultSpec, +}; use ff_ext::ExtensionField; use multilinear_extensions::mle::MultilinearExtension; use query_phase::{ - batch_query_phase, batch_verifier_query_phase, query_phase, verifier_query_phase, + batch_prover_query_phase, batch_verifier_query_phase, prover_query_phase, + simple_batch_prover_query_phase, simple_batch_verifier_query_phase, verifier_query_phase, BatchedQueriesResultWithMerklePath, QueriesResultWithMerklePath, + SimpleBatchQueriesResultWithMerklePath, }; use std::{borrow::BorrowMut, ops::Deref}; +pub use structure::BasefoldSpec; use itertools::Itertools; use serde::{de::DeserializeOwned, Serialize}; @@ -39,102 +46,67 @@ use multilinear_extensions::{ virtual_poly::build_eq_x_r_vec, }; -use rand_chacha::ChaCha8Rng; -use rayon::prelude::{IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator}; +use rand_chacha::{rand_core::RngCore, ChaCha8Rng}; +use rayon::{ + iter::IntoParallelIterator, + 
prelude::{IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator}, +}; use std::borrow::Cow; + type SumCheck = ClassicSumCheck>; -mod basecode; mod structure; -use basecode::{ - encode_field_type_rs_basecode, evaluate_over_foldable_domain_generic_basecode, get_table_aes, -}; pub use structure::{ - Basefold, BasefoldCommitment, BasefoldCommitmentWithData, BasefoldDefault, - BasefoldDefaultParams, BasefoldExtParams, BasefoldParams, BasefoldProverParams, + Basefold, BasefoldBasecodeParams, BasefoldCommitment, BasefoldCommitmentWithData, + BasefoldDefault, BasefoldParams, BasefoldProverParams, BasefoldRSParams, BasefoldVerifierParams, }; mod commit_phase; -use commit_phase::{batch_commit_phase, commit_phase}; +use commit_phase::{batch_commit_phase, commit_phase, simple_batch_commit_phase}; +mod encoding; +pub use encoding::{coset_fft, fft, fft_root_table}; mod query_phase; // This sumcheck module is different from the mpcs::sumcheck module, in that // it deals only with the special case of the form \sum eq(r_i)f_i(). 
mod sumcheck; -impl PolynomialCommitmentScheme for Basefold +impl, Rng: RngCore> Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, { - type Param = BasefoldParams; - type ProverParam = BasefoldProverParams; - type VerifierParam = BasefoldVerifierParams; - type CommitmentWithData = BasefoldCommitmentWithData; - type Commitment = BasefoldCommitment; - type CommitmentChunk = Digest; - type Rng = ChaCha8Rng; - - fn setup(poly_size: usize, rng: &Self::Rng) -> Result { - let log_rate = V::get_rate(); - let (table_w_weights, table) = get_table_aes::(poly_size, log_rate, &mut rng.clone()); - - Ok(BasefoldParams { - log_rate, - num_verifier_queries: V::get_reps(), - max_num_vars: log2_strict(poly_size), - table_w_weights, - table, - rng: rng.clone(), - }) - } - - fn trim(param: &Self::Param) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { - Ok(( - BasefoldProverParams { - log_rate: param.log_rate, - table_w_weights: param.table_w_weights.clone(), - table: param.table.clone(), - num_verifier_queries: param.num_verifier_queries, - max_num_vars: param.max_num_vars, - }, - BasefoldVerifierParams { - rng: param.rng.clone(), - max_num_vars: param.max_num_vars, - log_rate: param.log_rate, - num_verifier_queries: param.num_verifier_queries, - }, - )) - } - - fn commit( - pp: &Self::ProverParam, + /// Converts a polynomial to a code word, also returns the evaluations over the boolean hypercube + /// for said polynomial + fn get_poly_bh_evals_and_codeword( + pp: &BasefoldProverParams, poly: &DenseMultilinearExtension, - ) -> Result { - let timer = start_timer!(|| "Basefold::commit"); + ) -> (FieldType, FieldType) { // bh_evals is just a copy of poly.evals(). // Note that this function implicitly assumes that the size of poly.evals() is a // power of two. Otherwise, the function crashes with index out of bound. 
let mut bh_evals = poly.evaluations.clone(); - let num_vars = log2_strict(bh_evals.len()); - assert!(num_vars <= pp.max_num_vars && num_vars >= V::get_basecode()); + let num_vars = poly.num_vars; + assert!( + num_vars <= pp.encoding_params.get_max_message_size_log(), + "num_vars {} > pp.max_num_vars {}", + num_vars, + pp.encoding_params.get_max_message_size_log() + ); + assert!( + num_vars >= Spec::get_basecode_msg_size_log(), + "num_vars {} < Spec::get_basecode_msg_size_log() {}", + num_vars, + Spec::get_basecode_msg_size_log() + ); // Switch to coefficient form let mut coeffs = bh_evals.clone(); interpolate_field_type_over_boolean_hypercube(&mut coeffs); - // Split the input into chunks of message size, encode each message, and return the codewords - let basecode = - encode_field_type_rs_basecode(&coeffs, 1 << pp.log_rate, 1 << V::get_basecode()); - - // Apply the recursive definition of the BaseFold code to the list of base codewords, - // and produce the final codeword - let mut codeword = evaluate_over_foldable_domain_generic_basecode::( - 1 << V::get_basecode(), - coeffs.len(), - pp.log_rate, - basecode, - &pp.table, - ); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place_field_type(&mut coeffs); + } + let mut codeword = Spec::EncodingScheme::encode(&pp.encoding_params, &coeffs); // If using repetition code as basecode, it may be faster to use the following line of code to create the commitment and comment out the two lines above // let mut codeword = evaluate_over_foldable_domain(pp.log_rate, coeffs, &pp.table); @@ -146,6 +118,107 @@ where reverse_index_bits_in_place_field_type(&mut bh_evals); reverse_index_bits_in_place_field_type(&mut codeword); + (bh_evals, codeword) + } + + /// Transpose a matrix of field elements, generic over the type of field element + pub fn transpose_field_type( + matrix: &[FieldType], + ) -> Result>, Error> { + let transpose_fn = match matrix[0] { + FieldType::Ext(_) => Self::get_column_ext, + 
FieldType::Base(_) => Self::get_column_base, + FieldType::Unreachable => unreachable!(), + }; + + let len = matrix[0].len(); + (0..len) + .into_par_iter() + .map(|i| (transpose_fn)(matrix, i)) + .collect() + } + + fn get_column_base( + matrix: &[FieldType], + column_index: usize, + ) -> Result, Error> { + Ok(FieldType::Base( + matrix + .par_iter() + .map(|row| match row { + FieldType::Base(content) => Ok(content[column_index]), + _ => Err(Error::InvalidPcsParam( + "expected base field type".to_string(), + )), + }) + .collect::, Error>>()?, + )) + } + + fn get_column_ext(matrix: &[FieldType], column_index: usize) -> Result, Error> { + Ok(FieldType::Ext( + matrix + .par_iter() + .map(|row| match row { + FieldType::Ext(content) => Ok(content[column_index]), + _ => Err(Error::InvalidPcsParam( + "expected ext field type".to_string(), + )), + }) + .collect::, Error>>()?, + )) + } +} + +impl, Rng: RngCore + std::fmt::Debug> + PolynomialCommitmentScheme for Basefold +where + E: Serialize + DeserializeOwned, + E::BaseField: Serialize + DeserializeOwned, +{ + type Param = BasefoldParams; + type ProverParam = BasefoldProverParams; + type VerifierParam = BasefoldVerifierParams; + type CommitmentWithData = BasefoldCommitmentWithData; + type Commitment = BasefoldCommitment; + type CommitmentChunk = Digest; + type Rng = ChaCha8Rng; + + fn setup(poly_size: usize, rng: &Self::Rng) -> Result { + let mut seed = [0u8; 32]; + let mut rng = rng.clone(); + rng.fill_bytes(&mut seed); + let pp = >::setup(log2_strict(poly_size), seed); + + Ok(BasefoldParams { params: pp }) + } + + fn trim( + pp: &Self::Param, + poly_size: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { + >::trim(&pp.params, log2_strict(poly_size)).map( + |(pp, vp)| { + ( + BasefoldProverParams { + encoding_params: pp, + }, + BasefoldVerifierParams { + encoding_params: vp, + }, + ) + }, + ) + } + + fn commit( + pp: &Self::ProverParam, + poly: &DenseMultilinearExtension, + ) -> Result { + let timer = 
start_timer!(|| "Basefold::commit"); + + let (bh_evals, codeword) = Self::get_poly_bh_evals_and_codeword(pp, poly); + // Compute and store all the layers of the Merkle tree let hasher = new_hasher::(); let codeword_tree = MerkleTree::::from_leaves(codeword, &hasher); @@ -160,44 +233,89 @@ where Ok(Self::CommitmentWithData { codeword_tree, - bh_evals, - num_vars, + polynomials_bh_evals: vec![bh_evals], + num_vars: poly.num_vars, is_base, + num_polys: 1, }) } fn batch_commit_and_write( pp: &Self::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], transcript: &mut impl TranscriptWrite, - ) -> Result, Error> { + ) -> Result { let timer = start_timer!(|| "Basefold::batch_commit_and_write"); - let comms = Self::batch_commit(pp, polys)?; - comms.iter().for_each(|comm| { - transcript.write_commitment(&comm.get_root_as()).unwrap(); - transcript - .write_field_element_base(&u32_to_field::(comm.num_vars as u32)) - .unwrap(); - transcript - .write_field_element_base(&u32_to_field::(comm.is_base as u32)) - .unwrap(); - }); + let comm = Self::batch_commit(pp, polys)?; + transcript.write_commitment(&comm.get_root_as()).unwrap(); + transcript + .write_field_element_base(&u32_to_field::(comm.num_vars as u32)) + .unwrap(); + transcript + .write_field_element_base(&u32_to_field::(comm.is_base as u32)) + .unwrap(); + transcript + .write_field_element_base(&u32_to_field::(comm.num_polys as u32)) + .unwrap(); end_timer!(timer); - Ok(comms) + Ok(comm) } fn batch_commit( pp: &Self::ProverParam, - polys: &Vec>, - ) -> Result, Error> { - let polys_vec: Vec<&DenseMultilinearExtension> = - polys.into_iter().map(|poly| poly).collect(); - polys_vec + polys: &[DenseMultilinearExtension], + ) -> Result { + // assumptions + // 1. there must be at least one polynomial + // 2. all polynomials must exist in the same field type + // 3. 
all polynomials must have the same number of variables + + if polys.is_empty() { + return Err(Error::InvalidPcsParam( + "cannot batch commit to zero polynomials".to_string(), + )); + } + + for i in 1..polys.len() { + if polys[i].num_vars != polys[0].num_vars { + return Err(Error::InvalidPcsParam( + "cannot batch commit to polynomials with different number of variables" + .to_string(), + )); + } + } + + // convert each polynomial to a code word + let (bh_evals, codewords) = polys .par_iter() - .map(|poly| Self::commit(pp, poly)) - .collect() + .map(|poly| Self::get_poly_bh_evals_and_codeword(pp, poly)) + .collect::<(Vec>, Vec>)>(); + + // transpose the codewords, to group evaluations at the same point + // let leaves = Self::transpose_field_type::(codewords.as_slice())?; + + // build merkle tree from leaves + let hasher = new_hasher::(); + let codeword_tree = MerkleTree::::from_batch_leaves(codewords, &hasher); + + let is_base = match polys[0].evaluations { + FieldType::Ext(_) => false, + FieldType::Base(_) => true, + _ => unreachable!(), + }; + + Ok(Self::CommitmentWithData { + codeword_tree, + polynomials_bh_evals: bh_evals, + num_vars: polys[0].num_vars, + is_base, + num_polys: polys.len(), + }) } + /// Open a single polynomial commitment at one point. If the given + /// commitment with data contains more than one polynomial, this function + /// will panic. 
fn open( pp: &Self::ProverParam, poly: &DenseMultilinearExtension, @@ -208,22 +326,22 @@ where ) -> Result<(), Error> { let hasher = new_hasher::(); let timer = start_timer!(|| "Basefold::open"); - assert!(comm.num_vars >= V::get_basecode()); - let (trees, oracles) = commit_phase( - &point, - &comm, + assert!(comm.num_vars >= Spec::get_basecode_msg_size_log()); + assert!(comm.num_polys == 1); + let (trees, oracles) = commit_phase::( + &pp.encoding_params, + point, + comm, transcript, poly.num_vars, - poly.num_vars - V::get_basecode(), - &pp.table_w_weights, - pp.log_rate, + poly.num_vars - Spec::get_basecode_msg_size_log(), &hasher, ); let query_timer = start_timer!(|| "Basefold::open::query_phase"); // Each entry in queried_els stores a list of triples (F, F, i) indicating the // position opened at each round and the two values at that round - let queries = query_phase(transcript, &comm, &oracles, pp.num_verifier_queries); + let queries = prover_query_phase(transcript, comm, &oracles, Spec::get_number_queries()); end_timer!(query_timer); let query_timer = start_timer!(|| "Basefold::open::build_query_result"); @@ -241,10 +359,15 @@ where Ok(()) } + /// Open a batch of polynomial commitments at several points. + /// The current version only supports one polynomial per commitment. + /// Because otherwise it is complex to match the polynomials and + /// the commitments, and because currently this high flexibility is + /// not very useful in ceno. 
fn batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptWrite, @@ -252,9 +375,12 @@ where let hasher = new_hasher::(); let timer = start_timer!(|| "Basefold::batch_open"); let num_vars = polys.iter().map(|poly| poly.num_vars).max().unwrap(); - let comms = comms.into_iter().collect_vec(); let min_num_vars = polys.iter().map(|p| p.num_vars).min().unwrap(); - assert!(min_num_vars >= V::get_basecode()); + assert!(min_num_vars >= Spec::get_basecode_msg_size_log()); + + comms.iter().for_each(|comm| { + assert!(comm.num_polys == 1); + }); if cfg!(feature = "sanity-check") { evals.iter().for_each(|eval| { @@ -265,12 +391,7 @@ where }) } - validate_input( - "batch open", - pp.max_num_vars, - &polys.clone(), - &points.to_vec(), - )?; + validate_input("batch open", pp.get_max_message_size_log(), polys, points)?; let sumcheck_timer = start_timer!(|| "Basefold::batch_open::initial sumcheck"); // evals.len() is the batch size, i.e., how many polynomials are being opened together @@ -337,7 +458,7 @@ where .map(|((scalar, poly), point)| { inner_product( &poly_iter_ext(poly).collect_vec(), - build_eq_x_r_vec(&point).iter(), + build_eq_x_r_vec(point).iter(), ) * scalar * E::from(1 << (num_vars - poly.num_vars)) // When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube @@ -402,7 +523,7 @@ where ); *scalar * evals_from_sum_check - * &eq_xy_eval(point.as_slice(), &challenges[0..point.len()]) + * eq_xy_eval(point.as_slice(), &challenges[0..point.len()]) }) .sum::(); assert_eq!(new_target_sum, desired_sum); @@ -412,25 +533,24 @@ where let point = challenges; - let (trees, oracles) = batch_commit_phase( + let (trees, oracles) = batch_commit_phase::( + &pp.encoding_params, &point, - comms.as_slice(), + comms, transcript, num_vars, - num_vars - V::get_basecode(), - &pp.table_w_weights, - 
pp.log_rate, + num_vars - Spec::get_basecode_msg_size_log(), coeffs.as_slice(), &hasher, ); let query_timer = start_timer!(|| "Basefold::batch_open query phase"); - let query_result = batch_query_phase( + let query_result = batch_prover_query_phase( transcript, - 1 << (num_vars + pp.log_rate), - comms.as_slice(), + 1 << (num_vars + Spec::get_rate_log()), + comms, &oracles, - pp.num_verifier_queries, + Spec::get_number_queries(), ); end_timer!(query_timer); @@ -439,7 +559,7 @@ where BatchedQueriesResultWithMerklePath::from_batched_query_result( query_result, &trees, - &comms, + comms, ); end_timer!(query_timer); @@ -451,6 +571,84 @@ where Ok(()) } + /// This is a simple version of batch open: + /// 1. Open at one point + /// 2. All the polynomials share the same commitment and have the same + /// number of variables. + /// 3. The point is already a random point generated by a sum-check. + fn simple_batch_open( + pp: &Self::ProverParam, + polys: &[DenseMultilinearExtension], + comm: &Self::CommitmentWithData, + point: &[E], + evals: &[E], + transcript: &mut impl TranscriptWrite, + ) -> Result<(), Error> { + let hasher = new_hasher::(); + let timer = start_timer!(|| "Basefold::batch_open"); + let num_vars = polys[0].num_vars; + + polys + .iter() + .for_each(|poly| assert_eq!(poly.num_vars, num_vars)); + assert!(num_vars >= Spec::get_basecode_msg_size_log()); + assert_eq!(comm.num_polys, polys.len()); + assert_eq!(comm.num_polys, evals.len()); + + if cfg!(feature = "sanity-check") { + evals + .iter() + .zip(polys) + .for_each(|(eval, poly)| assert_eq!(&poly.evaluate(point), eval)) + } + // evals.len() is the batch size, i.e., how many polynomials are being opened together + let batch_size_log = evals.len().next_power_of_two().ilog2() as usize; + let t = transcript.squeeze_challenges(batch_size_log); + + // Use eq(X,t) where t is random to batch the different evaluation queries. 
+ // Note that this is a small polynomial (only batch_size) compared to the polynomials + // to open. + let eq_xt = build_eq_x_r_vec(&t)[..evals.len()].to_vec(); + let _target_sum = inner_product(evals, &eq_xt); + + // Now the verifier has obtained the new target sum, and is able to compute the random + // linear coefficients. + // The remaining tasks for the prover is to prove that + // sum_i coeffs[i] poly_evals[i] is equal to + // the new target sum, where coeffs is computed as follows + let (trees, oracles) = simple_batch_commit_phase::( + &pp.encoding_params, + point, + &eq_xt, + comm, + transcript, + num_vars, + num_vars - Spec::get_basecode_msg_size_log(), + &hasher, + ); + + let query_timer = start_timer!(|| "Basefold::open::query_phase"); + // Each entry in queried_els stores a list of triples (F, F, i) indicating the + // position opened at each round and the two values at that round + let queries = + simple_batch_prover_query_phase(transcript, comm, &oracles, Spec::get_number_queries()); + end_timer!(query_timer); + + let query_timer = start_timer!(|| "Basefold::open::build_query_result"); + + let queries_with_merkle_path = + SimpleBatchQueriesResultWithMerklePath::from_query_result(queries, &trees, comm); + end_timer!(query_timer); + + let query_timer = start_timer!(|| "Basefold::open::write_queries"); + queries_with_merkle_path.write_transcript(transcript); + end_timer!(query_timer); + + end_timer!(timer); + + Ok(()) + } + fn read_commitments( _: &Self::VerifierParam, num_polys: usize, @@ -462,14 +660,15 @@ where let num_vars = base_to_usize::(&transcript.read_field_element_base().unwrap()); let is_base = base_to_usize::(&transcript.read_field_element_base().unwrap()) != 0; - (num_vars, commitment, is_base) + let num_polys = base_to_usize::(&transcript.read_field_element_base().unwrap()); + (num_vars, commitment, is_base, num_polys) }) .collect_vec(); Ok(roots .iter() - .map(|(num_vars, commitment, is_base)| { - 
BasefoldCommitment::new(commitment.clone(), *num_vars, *is_base) + .map(|(num_vars, commitment, is_base, num_polys)| { + BasefoldCommitment::new(commitment.clone(), *num_vars, *is_base, *num_polys) }) .collect_vec()) } @@ -484,6 +683,7 @@ where transcript.write_commitment(&comm.get_root_as())?; transcript.write_field_element_base(&u32_to_field::(comm.num_vars as u32))?; transcript.write_field_element_base(&u32_to_field::(comm.is_base as u32))?; + transcript.write_field_element_base(&u32_to_field::(comm.num_polys as u32))?; Ok(comm) } @@ -496,15 +696,16 @@ where transcript: &mut impl TranscriptRead, ) -> Result<(), Error> { let timer = start_timer!(|| "Basefold::verify"); - assert!(comm.num_vars().unwrap() >= V::get_basecode()); + assert!(comm.num_vars().unwrap() >= Spec::get_basecode_msg_size_log()); let hasher = new_hasher::(); - let _field_size = 255; let num_vars = point.len(); - let num_rounds = num_vars - V::get_basecode(); + if let Some(comm_num_vars) = comm.num_vars() { + assert_eq!(num_vars, comm_num_vars); + } + let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); - let mut fold_challenges: Vec = Vec::with_capacity(vp.max_num_vars); - let _size = 0; + let mut fold_challenges: Vec = Vec::with_capacity(num_vars); let mut roots = Vec::new(); let mut sumcheck_messages = Vec::with_capacity(num_rounds); let sumcheck_timer = start_timer!(|| "Basefold::verify::interaction"); @@ -519,19 +720,19 @@ where let read_timer = start_timer!(|| "Basefold::verify::read transcript"); let final_message = transcript - .read_field_elements_ext(1 << V::get_basecode()) + .read_field_elements_ext(1 << Spec::get_basecode_msg_size_log()) .unwrap(); let query_challenges = transcript - .squeeze_challenges(vp.num_verifier_queries) + .squeeze_challenges(Spec::get_number_queries()) .iter() - .map(|index| ext_to_usize(index) % (1 << (num_vars + vp.log_rate))) + .map(|index| ext_to_usize(index) % (1 << (num_vars + Spec::get_rate_log()))) .collect_vec(); let read_query_timer = 
start_timer!(|| "Basefold::verify::read query"); let query_result_with_merkle_path = if comm.is_base() { QueriesResultWithMerklePath::read_transcript_base( transcript, num_rounds, - vp.log_rate, + Spec::get_rate_log(), num_vars, query_challenges.as_slice(), ) @@ -539,7 +740,7 @@ where QueriesResultWithMerklePath::read_transcript_ext( transcript, num_rounds, - vp.log_rate, + Spec::get_rate_log(), num_vars, query_challenges.as_slice(), ) @@ -558,19 +759,18 @@ where let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]); eq.par_iter_mut().for_each(|e| *e *= coeff); - verifier_query_phase( + verifier_query_phase::( + &vp.encoding_params, &query_result_with_merkle_path, &sumcheck_messages, &fold_challenges, num_rounds, num_vars, - vp.log_rate, &final_message, &roots, comm, eq.as_slice(), - vp.rng.clone(), - &eval, + eval, &hasher, ); end_timer!(timer); @@ -580,7 +780,7 @@ where fn batch_verify( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptRead, @@ -589,10 +789,10 @@ where // let key = "RAYON_NUM_THREADS"; // env::set_var(key, "32"); let hasher = new_hasher::(); - let comms = comms.into_iter().collect_vec(); + let comms = comms.iter().collect_vec(); let num_vars = points.iter().map(|point| point.len()).max().unwrap(); - let num_rounds = num_vars - V::get_basecode(); - validate_input("batch verify", vp.max_num_vars, &vec![], &points.to_vec())?; + let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); + validate_input("batch verify", num_vars, &[], points)?; let poly_num_vars = comms.iter().map(|c| c.num_vars().unwrap()).collect_vec(); evals.iter().for_each(|eval| { assert_eq!( @@ -600,7 +800,7 @@ where comms[eval.poly()].num_vars().unwrap() ); }); - assert!(poly_num_vars.iter().min().unwrap() >= &V::get_basecode()); + assert!(poly_num_vars.iter().min().unwrap() >= &Spec::get_basecode_msg_size_log()); let sumcheck_timer = start_timer!(|| 
"Basefold::batch_verify::initial sumcheck"); let batch_size_log = evals.len().next_power_of_two().ilog2() as usize; @@ -646,13 +846,13 @@ where } } let final_message = transcript - .read_field_elements_ext(1 << V::get_basecode()) + .read_field_elements_ext(1 << Spec::get_basecode_msg_size_log()) .unwrap(); let query_challenges = transcript - .squeeze_challenges(vp.num_verifier_queries) + .squeeze_challenges(Spec::get_number_queries()) .iter() - .map(|index| ext_to_usize(index) % (1 << (num_vars + vp.log_rate))) + .map(|index| ext_to_usize(index) % (1 << (num_vars + Spec::get_rate_log()))) .collect_vec(); let read_query_timer = start_timer!(|| "Basefold::verify::read query"); @@ -663,7 +863,7 @@ where BatchedQueriesResultWithMerklePath::read_transcript_base( transcript, num_rounds, - vp.log_rate, + Spec::get_rate_log(), poly_num_vars.as_slice(), query_challenges.as_slice(), ) @@ -671,7 +871,7 @@ where BatchedQueriesResultWithMerklePath::read_transcript_ext( transcript, num_rounds, - vp.log_rate, + Spec::get_rate_log(), poly_num_vars.as_slice(), query_challenges.as_slice(), ) @@ -692,28 +892,130 @@ where ); eq.par_iter_mut().for_each(|e| *e *= coeff); - batch_verifier_query_phase( + batch_verifier_query_phase::( + &vp.encoding_params, &query_result_with_merkle_path, &sumcheck_messages, &fold_challenges, num_rounds, num_vars, - vp.log_rate, &final_message, &roots, &comms, &coeffs, eq.as_slice(), - vp.rng.clone(), &new_target_sum, &hasher, ); end_timer!(timer); Ok(()) } + + fn simple_batch_verify( + vp: &Self::VerifierParam, + comm: &Self::Commitment, + point: &[E], + evals: &[E], + transcript: &mut impl TranscriptRead, + ) -> Result<(), Error> { + let timer = start_timer!(|| "Basefold::simple batch verify"); + assert!(comm.num_vars().unwrap() >= Spec::get_basecode_msg_size_log()); + let batch_size = evals.len(); + if let Some(num_polys) = comm.num_polys { + assert_eq!(num_polys, batch_size); + } + let hasher = new_hasher::(); + + let num_vars = point.len(); + if let 
Some(comm_num_vars) = comm.num_vars { + assert_eq!(num_vars, comm_num_vars); + } + let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); + + // evals.len() is the batch size, i.e., how many polynomials are being opened together + let batch_size_log = evals.len().next_power_of_two().ilog2() as usize; + let t = transcript.squeeze_challenges(batch_size_log); + let eq_xt = build_eq_x_r_vec(&t)[..evals.len()].to_vec(); + + let mut fold_challenges: Vec = Vec::with_capacity(num_vars); + let mut roots = Vec::new(); + let mut sumcheck_messages = Vec::with_capacity(num_rounds); + let sumcheck_timer = start_timer!(|| "Basefold::simple_batch_verify::interaction"); + for i in 0..num_rounds { + sumcheck_messages.push(transcript.read_field_elements_ext(3).unwrap()); + fold_challenges.push(transcript.squeeze_challenge()); + if i < num_rounds - 1 { + roots.push(transcript.read_commitment().unwrap()); + } + } + end_timer!(sumcheck_timer); + + let read_timer = start_timer!(|| "Basefold::verify::read transcript"); + let final_message = transcript + .read_field_elements_ext(1 << Spec::get_basecode_msg_size_log()) + .unwrap(); + let query_challenges = transcript + .squeeze_challenges(Spec::get_number_queries()) + .iter() + .map(|index| ext_to_usize(index) % (1 << (num_vars + Spec::get_rate_log()))) + .collect_vec(); + let read_query_timer = start_timer!(|| "Basefold::verify::read query"); + let query_result_with_merkle_path = if comm.is_base() { + SimpleBatchQueriesResultWithMerklePath::read_transcript_base( + transcript, + num_rounds, + Spec::get_rate_log(), + num_vars, + query_challenges.as_slice(), + batch_size, + ) + } else { + SimpleBatchQueriesResultWithMerklePath::read_transcript_ext( + transcript, + num_rounds, + Spec::get_rate_log(), + num_vars, + query_challenges.as_slice(), + batch_size, + ) + }; + end_timer!(read_query_timer); + end_timer!(read_timer); + + // coeff is the eq polynomial evaluated at the last challenge.len() variables + // in reverse order. 
+ let rev_challenges = fold_challenges.clone().into_iter().rev().collect_vec(); + let coeff = eq_xy_eval( + &point[point.len() - fold_challenges.len()..], + &rev_challenges, + ); + // Compute eq as the partially evaluated eq polynomial + let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]); + eq.par_iter_mut().for_each(|e| *e *= coeff); + + simple_batch_verifier_query_phase::( + &vp.encoding_params, + &query_result_with_merkle_path, + &sumcheck_messages, + &fold_challenges, + &eq_xt, + num_rounds, + num_vars, + &final_message, + &roots, + comm, + eq.as_slice(), + evals, + &hasher, + ); + end_timer!(timer); + + Ok(()) + } } -impl NoninteractivePCS for Basefold +impl, Rng: RngCore + std::fmt::Debug> NoninteractivePCS + for Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, @@ -724,45 +1026,128 @@ where mod test { use crate::{ basefold::Basefold, - test_util::{run_batch_commit_open_verify, run_commit_open_verify}, + test_util::{ + run_batch_commit_open_verify, run_commit_open_verify, + run_simple_batch_commit_open_verify, + }, util::transcript::PoseidonTranscript, }; use goldilocks::GoldilocksExt2; + use rand_chacha::ChaCha8Rng; - use super::BasefoldDefaultParams; + use super::{structure::BasefoldBasecodeParams, BasefoldRSParams}; - type PcsGoldilocks = Basefold; + type PcsGoldilocksRSCode = Basefold; + type PcsGoldilocksBaseCode = Basefold; #[test] - fn commit_open_verify_goldilocks_base() { + fn commit_open_verify_goldilocks_basecode_base() { // Challenge is over extension field, poly over the base field - run_commit_open_verify::>( - true, 10, 11, + run_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + PoseidonTranscript, + >(true, 10, 11); + } + + #[test] + fn commit_open_verify_goldilocks_rscode_base() { + // Challenge is over extension field, poly over the base field + run_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + PoseidonTranscript, + >(true, 10, 11); + } + 
+ #[test] + fn commit_open_verify_goldilocks_basecode_2() { + // Both challenge and poly are over extension field + run_commit_open_verify::>( + false, 10, 11, ); } #[test] - fn commit_open_verify_goldilocks_2() { + fn commit_open_verify_goldilocks_rscode_2() { // Both challenge and poly are over extension field - run_commit_open_verify::>( + run_commit_open_verify::>( false, 10, 11, ); } #[test] - fn batch_commit_open_verify_goldilocks_base() { + fn simple_batch_commit_open_verify_goldilocks_basecode_base() { + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + PoseidonTranscript, + >(true, 10, 11, 4); + } + + #[test] + fn simple_batch_commit_open_verify_goldilocks_rscode_base() { + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + PoseidonTranscript, + >(true, 10, 11, 4); + } + + #[test] + fn simple_batch_commit_open_verify_goldilocks_basecode_2() { + // Both challenge and poly are over extension field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + PoseidonTranscript<_>, + >(false, 10, 11, 4); + } + + #[test] + fn simple_batch_commit_open_verify_goldilocks_rscode_2() { + // Both challenge and poly are over extension field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + PoseidonTranscript<_>, + >(false, 10, 11, 4); + } + + #[test] + fn batch_commit_open_verify_goldilocks_basecode_base() { // Both challenge and poly are over base field run_batch_commit_open_verify::< GoldilocksExt2, - PcsGoldilocks, + PcsGoldilocksBaseCode, PoseidonTranscript, >(true, 10, 11); } #[test] - fn batch_commit_open_verify_goldilocks_2() { + fn batch_commit_open_verify_goldilocks_rscode_base() { + // Both challenge and poly are over base field + run_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + PoseidonTranscript, + >(true, 10, 
11); + } + + #[test] + fn batch_commit_open_verify_goldilocks_basecode_2() { + // Both challenge and poly are over extension field + run_batch_commit_open_verify::>( + false, 10, 11, + ); + } + + #[test] + fn batch_commit_open_verify_goldilocks_rscode_2() { // Both challenge and poly are over extension field - run_batch_commit_open_verify::>( + run_batch_commit_open_verify::>( false, 10, 11, ); } diff --git a/mpcs/src/basefold/basecode.rs b/mpcs/src/basefold/basecode.rs deleted file mode 100644 index a10679ede..000000000 --- a/mpcs/src/basefold/basecode.rs +++ /dev/null @@ -1,293 +0,0 @@ -use crate::util::{arithmetic::base_from_raw_bytes, log2_strict, num_of_bytes}; -use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; -use ark_std::{end_timer, start_timer}; -use ctr; -use ff::{BatchInverter, Field, PrimeField}; -use ff_ext::ExtensionField; -use generic_array::GenericArray; -use multilinear_extensions::mle::FieldType; -use rayon::prelude::{ParallelIterator, ParallelSlice, ParallelSliceMut}; - -use itertools::Itertools; - -use crate::util::plonky2_util::reverse_index_bits_in_place; -use rand_chacha::rand_core::RngCore; -use rayon::prelude::IntoParallelRefIterator; - -use crate::util::arithmetic::{horner, steps}; - -pub fn encode_field_type_rs_basecode( - poly: &FieldType, - rate: usize, - message_size: usize, -) -> Vec> { - match poly { - FieldType::Ext(poly) => encode_rs_basecode(poly, rate, message_size) - .iter() - .map(|x| FieldType::Ext(x.clone())) - .collect(), - FieldType::Base(poly) => encode_rs_basecode(poly, rate, message_size) - .iter() - .map(|x| FieldType::Base(x.clone())) - .collect(), - _ => panic!("Unsupported field type"), - } -} - -// Split the input into chunks of message size, encode each message, and return the codewords -pub fn encode_rs_basecode( - poly: &Vec, - rate: usize, - message_size: usize, -) -> Vec> { - let timer = start_timer!(|| "Encode basecode"); - // The domain is just counting 1, 2, 3, ... 
, domain_size - let domain: Vec = steps(F::ONE).take(message_size * rate).collect(); - let res = poly - .par_chunks_exact(message_size) - .map(|chunk| { - let mut target = vec![F::ZERO; message_size * rate]; - // Just Reed-Solomon code, but with the naive domain - target - .iter_mut() - .enumerate() - .for_each(|(i, target)| *target = horner(&chunk[..], &domain[i])); - target - }) - .collect::>>(); - end_timer!(timer); - - res -} - -fn concatenate_field_types(coeffs: &Vec>) -> FieldType { - match coeffs[0] { - FieldType::Ext(_) => { - let res = coeffs - .iter() - .map(|x| match x { - FieldType::Ext(x) => x.iter().map(|x| *x), - _ => unreachable!(), - }) - .flatten() - .collect::>(); - FieldType::Ext(res) - } - FieldType::Base(_) => { - let res = coeffs - .iter() - .map(|x| match x { - FieldType::Base(x) => x.iter().map(|x| *x), - _ => unreachable!(), - }) - .flatten() - .collect::>(); - FieldType::Base(res) - } - _ => unreachable!(), - } -} - -// this function assumes all codewords in base_codeword has equivalent length -pub fn evaluate_over_foldable_domain_generic_basecode( - base_message_length: usize, - num_coeffs: usize, - log_rate: usize, - base_codewords: Vec>, - table: &Vec>, -) -> FieldType { - let timer = start_timer!(|| "evaluate over foldable domain"); - let k = num_coeffs; - let logk = log2_strict(k); - let base_log_k = log2_strict(base_message_length); - // concatenate together all base codewords - // let now = Instant::now(); - let mut coeffs_with_bc = concatenate_field_types(&base_codewords); - // println!("concatenate base codewords {:?}", now.elapsed()); - // iterate over array, replacing even indices with (evals[i] - evals[(i+1)]) - let mut chunk_size = base_codewords[0].len(); // block length of the base code - for i in base_log_k..logk { - // In beginning of each iteration, the current codeword size is 1<> 1); - match coeffs_with_bc { - FieldType::Ext(ref mut coeffs_with_bc) => { - coeffs_with_bc.par_chunks_mut(chunk_size).for_each(|chunk| { - 
let half_chunk = chunk_size >> 1; - for j in half_chunk..chunk_size { - // Suppose the current codewords are (a, b) - // The new codeword is computed by two halves: - // left = a + t * b - // right = a - t * b - let rhs = chunk[j] * E::from(level[j - half_chunk]); - chunk[j] = chunk[j - half_chunk] - rhs; - chunk[j - half_chunk] = chunk[j - half_chunk] + rhs; - } - }); - } - FieldType::Base(ref mut coeffs_with_bc) => { - coeffs_with_bc.par_chunks_mut(chunk_size).for_each(|chunk| { - let half_chunk = chunk_size >> 1; - for j in half_chunk..chunk_size { - // Suppose the current codewords are (a, b) - // The new codeword is computed by two halves: - // left = a + t * b - // right = a - t * b - let rhs = chunk[j] * level[j - half_chunk]; - chunk[j] = chunk[j - half_chunk] - rhs; - chunk[j - half_chunk] = chunk[j - half_chunk] + rhs; - } - }); - } - _ => unreachable!(), - } - } - end_timer!(timer); - coeffs_with_bc -} - -pub fn get_table_aes( - poly_size: usize, - rate: usize, - rng: &mut Rng, -) -> ( - Vec>, - Vec>, -) { - // The size (logarithmic) of the codeword for the polynomial - let lg_n: usize = rate + log2_strict(poly_size); - - let mut key: [u8; 16] = [0u8; 16]; - let mut iv: [u8; 16] = [0u8; 16]; - rng.fill_bytes(&mut key); - rng.fill_bytes(&mut iv); - - type Aes128Ctr64LE = ctr::Ctr32LE; - - let mut cipher = Aes128Ctr64LE::new( - GenericArray::from_slice(&key[..]), - GenericArray::from_slice(&iv[..]), - ); - - // Allocate the buffer for storing n field elements (the entire codeword) - let bytes = num_of_bytes::(1 << lg_n); - let mut dest: Vec = vec![0u8; bytes]; - cipher.apply_keystream(&mut dest[..]); - - // Now, dest is a vector filled with random data for a field vector of size n - - // Collect the bytes into field elements - let flat_table: Vec = dest - .par_chunks_exact(num_of_bytes::(1)) - .map(|chunk| base_from_raw_bytes::(&chunk.to_vec())) - .collect::>(); - - // Now, flat_table is a field vector of size n, filled with random field elements - 
assert_eq!(flat_table.len(), 1 << lg_n); - - // Multiply -2 to every element to get the weights. Now weights = { -2x } - let mut weights: Vec = flat_table - .par_iter() - .map(|el| E::BaseField::ZERO - *el - *el) - .collect(); - - // Then invert all the elements. Now weights = { -1/2x } - let mut scratch_space = vec![E::BaseField::ZERO; weights.len()]; - BatchInverter::invert_with_external_scratch(&mut weights, &mut scratch_space); - - // Zip x and -1/2x together. The result is the list { (x, -1/2x) } - // What is this -1/2x? It is used in linear interpolation over the domain (x, -x), which - // involves computing 1/(b-a) where b=-x and a=x, and 1/(b-a) here is exactly -1/2x - let flat_table_w_weights = flat_table - .iter() - .zip(weights) - .map(|(el, w)| (*el, w)) - .collect_vec(); - - // Split the positions from 0 to n-1 into slices of sizes: - // 2, 2, 4, 8, ..., n/2, exactly lg_n number of them - // The weights are (x, -1/2x), the table elements are just x - - let mut unflattened_table_w_weights = vec![Vec::new(); lg_n]; - let mut unflattened_table = vec![Vec::new(); lg_n]; - - let mut level_weights = flat_table_w_weights[0..2].to_vec(); - // Apply the reverse-bits permutation to a vector of size 2, equivalent to just swapping - reverse_index_bits_in_place(&mut level_weights); - unflattened_table_w_weights[0] = level_weights; - - unflattened_table[0] = flat_table[0..2].to_vec(); - for i in 1..lg_n { - unflattened_table[i] = flat_table[(1 << i)..(1 << (i + 1))].to_vec(); - let mut level = flat_table_w_weights[(1 << i)..(1 << (i + 1))].to_vec(); - reverse_index_bits_in_place(&mut level); - unflattened_table_w_weights[i] = level; - } - - return (unflattened_table_w_weights, unflattened_table); -} - -pub fn query_point( - block_length: usize, - eval_index: usize, - level: usize, - mut cipher: &mut ctr::Ctr32LE, -) -> E::BaseField { - let level_index = eval_index % (block_length); - let mut el = - query_root_table_from_rng_aes::(level, level_index % (block_length 
>> 1), &mut cipher); - - if level_index >= (block_length >> 1) { - el = -E::BaseField::ONE * el; - } - - return el; -} - -pub fn query_root_table_from_rng_aes( - level: usize, - index: usize, - cipher: &mut ctr::Ctr32LE, -) -> E::BaseField { - let mut level_offset: u128 = 1; - for lg_m in 1..=level { - let half_m = 1 << (lg_m - 1); - level_offset += half_m; - } - - let pos = ((level_offset + (index as u128)) - * ((E::BaseField::NUM_BITS as usize).next_power_of_two() as u128)) - .checked_div(8) - .unwrap(); - - cipher.seek(pos); - - let bytes = (E::BaseField::NUM_BITS as usize).next_power_of_two() / 8; - let mut dest: Vec = vec![0u8; bytes]; - cipher.apply_keystream(&mut dest); - - let res = base_from_raw_bytes::(&dest); - - res -} - -#[cfg(test)] -mod tests { - use super::*; - use goldilocks::GoldilocksExt2; - use multilinear_extensions::mle::DenseMultilinearExtension; - - #[test] - fn time_rs_code() { - use rand::rngs::OsRng; - - let poly = DenseMultilinearExtension::random(20, &mut OsRng); - - encode_field_type_rs_basecode::(&poly.evaluations, 2, 64); - } -} diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index 1716edaea..da8b4ad75 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -1,5 +1,6 @@ use super::{ - basecode::encode_rs_basecode, + encoding::EncodingScheme, + structure::BasefoldSpec, sumcheck::{ sum_check_challenge_round, sum_check_first_round, sum_check_first_round_field_type, sum_check_last_round, @@ -29,29 +30,37 @@ use rayon::prelude::{ use super::structure::BasefoldCommitmentWithData; // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) -pub fn commit_phase( +pub fn commit_phase>( + pp: &>::ProverParameters, point: &[E], comm: &BasefoldCommitmentWithData, transcript: &mut impl TranscriptWrite, E>, num_vars: usize, num_rounds: usize, - table_w_weights: &Vec>, - log_rate: usize, hasher: &Hasher, ) -> (Vec>, Vec>) where E::BaseField: Serialize + DeserializeOwned, { 
let timer = start_timer!(|| "Commit phase"); + #[cfg(feature = "sanity-check")] assert_eq!(point.len(), num_vars); let mut oracles = Vec::with_capacity(num_vars); let mut trees = Vec::with_capacity(num_vars); - let mut running_oracle = field_type_iter_ext(comm.get_codeword()).collect_vec(); - let mut running_evals = comm.bh_evals.clone(); + let mut running_oracle = field_type_iter_ext(&comm.get_codewords()[0]).collect_vec(); + let mut running_evals = comm.polynomials_bh_evals[0].clone(); + + #[cfg(feature = "sanity-check")] + assert_eq!( + running_oracle.len(), + running_evals.len() << Spec::get_rate_log() + ); + #[cfg(feature = "sanity-check")] + assert_eq!(running_evals.len(), 1 << num_vars); // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube let build_eq_timer = start_timer!(|| "Basefold::open"); - let mut eq = build_eq_x_r_vec(&point); + let mut eq = build_eq_x_r_vec(point); end_timer!(build_eq_timer); reverse_index_bits_in_place(&mut eq); @@ -59,6 +68,9 @@ where let mut last_sumcheck_message = sum_check_first_round_field_type(&mut eq, &mut running_evals); end_timer!(sumcheck_timer); + #[cfg(feature = "sanity-check")] + assert_eq!(last_sumcheck_message.len(), 3); + let mut running_evals = match running_evals { FieldType::Ext(evals) => evals, FieldType::Base(evals) => evals.iter().map(|x| E::from(*x)).collect_vec(), @@ -77,8 +89,8 @@ where let challenge = transcript.squeeze_challenge(); // Fold the current oracle for FRI - running_oracle = basefold_one_round_by_interpolation_weights::( - &table_w_weights, + running_oracle = basefold_one_round_by_interpolation_weights::( + pp, log2_strict(running_oracle.len()) - 1, &running_oracle, challenge, @@ -111,9 +123,17 @@ where let mut coeffs = running_evals.clone(); interpolate_over_boolean_hypercube(&mut coeffs); - let basecode = encode_rs_basecode(&coeffs, 1 << log_rate, coeffs.len()); - assert_eq!(basecode.len(), 1); - let basecode = basecode[0].clone(); + if 
>::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut coeffs); + } + let basecode = >::encode( + pp, + &FieldType::Ext(coeffs), + ); + let basecode = match basecode { + FieldType::Ext(b) => b, + _ => panic!("Should be ext field"), + }; reverse_index_bits_in_place(&mut running_oracle); assert_eq!(basecode, running_oracle); @@ -122,18 +142,18 @@ where end_timer!(sumcheck_timer); } end_timer!(timer); - return (trees, oracles); + (trees, oracles) } // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) -pub fn batch_commit_phase( +#[allow(clippy::too_many_arguments)] +pub fn batch_commit_phase>( + pp: &>::ProverParameters, point: &[E], - comms: &[&BasefoldCommitmentWithData], + comms: &[BasefoldCommitmentWithData], transcript: &mut impl TranscriptWrite, E>, num_vars: usize, num_rounds: usize, - table_w_weights: &Vec>, - log_rate: usize, coeffs: &[E], hasher: &Hasher, ) -> (Vec>, Vec>) @@ -144,7 +164,7 @@ where assert_eq!(point.len(), num_vars); let mut oracles = Vec::with_capacity(num_vars); let mut trees = Vec::with_capacity(num_vars); - let mut running_oracle = vec![E::ZERO; 1 << (num_vars + log_rate)]; + let mut running_oracle = vec![E::ZERO; 1 << (num_vars + Spec::get_rate_log())]; let build_oracle_timer = start_timer!(|| "Basefold build initial oracle"); // Before the interaction, collect all the polynomials whose num variables match the @@ -157,8 +177,8 @@ where .for_each(|(index, comm)| { running_oracle .iter_mut() - .zip_eq(field_type_iter_ext(comm.get_codeword())) - .for_each(|(r, a)| *r += E::from(a) * coeffs[index]); + .zip_eq(field_type_iter_ext(&comm.get_codewords()[0])) + .for_each(|(r, a)| *r += a * coeffs[index]); }); end_timer!(build_oracle_timer); @@ -177,16 +197,16 @@ where // to align the polynomials to the variable with index 0 before adding them // together. 
So each element is repeated by // sum_of_all_evals_for_sumcheck.len() / bh_evals.len() times - *r += E::from(field_type_index_ext( - &comm.bh_evals, - pos >> (num_vars - log2_strict(comm.bh_evals.len())), - )) * coeffs[index] + *r += field_type_index_ext( + &comm.polynomials_bh_evals[0], + pos >> (num_vars - log2_strict(comm.polynomials_bh_evals[0].len())), + ) * coeffs[index] }); }); end_timer!(build_oracle_timer); // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube - let mut eq = build_eq_x_r_vec(&point); + let mut eq = build_eq_x_r_vec(point); reverse_index_bits_in_place(&mut eq); let sumcheck_timer = start_timer!(|| "Basefold first round"); @@ -208,8 +228,8 @@ where let challenge = transcript.squeeze_challenge(); // Fold the current oracle for FRI - running_oracle = basefold_one_round_by_interpolation_weights::( - &table_w_weights, + running_oracle = basefold_one_round_by_interpolation_weights::( + pp, log2_strict(running_oracle.len()) - 1, &running_oracle, challenge, @@ -236,8 +256,8 @@ where .for_each(|(index, comm)| { running_oracle .iter_mut() - .zip_eq(field_type_iter_ext(comm.get_codeword())) - .for_each(|(r, a)| *r += E::from(a) * coeffs[index]); + .zip_eq(field_type_iter_ext(&comm.get_codewords()[0])) + .for_each(|(r, a)| *r += a * coeffs[index]); }); } else { // The difference of the last round is that we don't need to compute the message, @@ -257,10 +277,127 @@ where // on the prover side should be exactly the encoding of the folded polynomial. 
let mut coeffs = sum_of_all_evals_for_sumcheck.clone(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut coeffs); + } + interpolate_over_boolean_hypercube(&mut coeffs); + let basecode = >::encode( + pp, + &FieldType::Ext(coeffs), + ); + let basecode = match basecode { + FieldType::Ext(x) => x, + _ => panic!("Expected ext field"), + }; + + reverse_index_bits_in_place(&mut running_oracle); + assert_eq!(basecode, running_oracle); + } + } + end_timer!(sumcheck_timer); + } + end_timer!(timer); + (trees, oracles) +} + +// outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) +#[allow(clippy::too_many_arguments)] +pub fn simple_batch_commit_phase>( + pp: &>::ProverParameters, + point: &[E], + batch_coeffs: &[E], + comm: &BasefoldCommitmentWithData, + transcript: &mut impl TranscriptWrite, E>, + num_vars: usize, + num_rounds: usize, + hasher: &Hasher, +) -> (Vec>, Vec>) +where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Simple batch commit phase"); + assert_eq!(point.len(), num_vars); + assert_eq!(comm.num_polys, batch_coeffs.len()); + let mut oracles = Vec::with_capacity(num_vars); + let mut trees = Vec::with_capacity(num_vars); + let mut running_oracle = comm.batch_codewords(&batch_coeffs.to_vec()); + let mut running_evals = (0..(1 << num_vars)) + .map(|i| { + comm.polynomials_bh_evals + .iter() + .zip(batch_coeffs) + .map(|(eval, coeff)| field_type_index_ext(eval, i) * *coeff) + .sum() + }) + .collect_vec(); + + // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube + let build_eq_timer = start_timer!(|| "Basefold::open"); + let mut eq = build_eq_x_r_vec(point); + end_timer!(build_eq_timer); + reverse_index_bits_in_place(&mut eq); + + let sumcheck_timer = start_timer!(|| "Basefold sumcheck first round"); + let mut last_sumcheck_message = sum_check_first_round(&mut eq, &mut running_evals); + end_timer!(sumcheck_timer); + + for i in 0..num_rounds { + let sumcheck_timer 
= start_timer!(|| format!("Basefold round {}", i)); + // For the first round, no need to send the running root, because this root is + // committing to a vector that can be recovered from linearly combining other + // already-committed vectors. + transcript + .write_field_elements_ext(&last_sumcheck_message) + .unwrap(); + + let challenge = transcript.squeeze_challenge(); + + // Fold the current oracle for FRI + running_oracle = basefold_one_round_by_interpolation_weights::( + pp, + log2_strict(running_oracle.len()) - 1, + &running_oracle, + challenge, + ); + + if i < num_rounds - 1 { + last_sumcheck_message = + sum_check_challenge_round(&mut eq, &mut running_evals, challenge); + let running_tree = + MerkleTree::::from_leaves(FieldType::Ext(running_oracle.clone()), hasher); + let running_root = running_tree.root(); + transcript.write_commitment(&running_root).unwrap(); + + oracles.push(running_oracle.clone()); + trees.push(running_tree); + } else { + // The difference of the last round is that we don't need to compute the message, + // and we don't interpolate the small polynomials. So after the last round, + // running_evals is exactly the evaluation representation of the + // folded polynomial so far. + sum_check_last_round(&mut eq, &mut running_evals, challenge); + // For the FRI part, we send the current polynomial as the message. + // Transform it back into little endiean before sending it + reverse_index_bits_in_place(&mut running_evals); + transcript.write_field_elements_ext(&running_evals).unwrap(); + + if cfg!(feature = "sanity-check") { + // If the prover is honest, in the last round, the running oracle + // on the prover side should be exactly the encoding of the folded polynomial. 
+ + let mut coeffs = running_evals.clone(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut coeffs); + } interpolate_over_boolean_hypercube(&mut coeffs); - let basecode = encode_rs_basecode(&coeffs, 1 << log_rate, coeffs.len()); - assert_eq!(basecode.len(), 1); - let basecode = basecode[0].clone(); + let basecode = >::encode( + pp, + &FieldType::Ext(coeffs), + ); + let basecode = match basecode { + FieldType::Ext(basecode) => basecode, + _ => panic!("Should be ext field"), + }; reverse_index_bits_in_place(&mut running_oracle); assert_eq!(basecode, running_oracle); @@ -269,28 +406,22 @@ where end_timer!(sumcheck_timer); } end_timer!(timer); - return (trees, oracles); + (trees, oracles) } -fn basefold_one_round_by_interpolation_weights( - table: &Vec>, - level_index: usize, +fn basefold_one_round_by_interpolation_weights>( + pp: &>::ProverParameters, + level: usize, values: &Vec, challenge: E, ) -> Vec { - let level = &table[level_index]; values .par_chunks_exact(2) .enumerate() .map(|(i, ys)| { - interpolate2_weights( - [ - (E::from(level[i].0), ys[0]), - (E::from(-(level[i].0)), ys[1]), - ], - E::from(level[i].1), - challenge, - ) + let (x0, x1, w) = + >::prover_folding_coeffs(pp, level, i); + interpolate2_weights([(x0, ys[0]), (x1, ys[1])], w, challenge) }) .collect::>() } diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs new file mode 100644 index 000000000..9fa651e2f --- /dev/null +++ b/mpcs/src/basefold/encoding.rs @@ -0,0 +1,237 @@ +use ff_ext::ExtensionField; +use multilinear_extensions::mle::FieldType; + +mod utils; + +mod basecode; +pub use basecode::{Basecode, BasecodeDefaultSpec}; + +mod rs; +use plonky2::util::log2_strict; +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSlice, +}; +pub use rs::{coset_fft, fft, fft_root_table, RSCode, RSCodeDefaultSpec}; + +use serde::{de::DeserializeOwned, Serialize}; + +use crate::{util::arithmetic::interpolate2_weights, Error}; 
+ +pub trait EncodingProverParameters { + fn get_max_message_size_log(&self) -> usize; +} + +pub trait EncodingScheme: std::fmt::Debug + Clone { + type PublicParameters: Clone + std::fmt::Debug + Serialize + DeserializeOwned; + type ProverParameters: Clone + + std::fmt::Debug + + Serialize + + DeserializeOwned + + EncodingProverParameters + + Sync; + type VerifierParameters: Clone + std::fmt::Debug + Serialize + DeserializeOwned + Sync; + + fn setup(max_msg_size_log: usize, rng_seed: [u8; 32]) -> Self::PublicParameters; + + fn trim( + pp: &Self::PublicParameters, + max_msg_size_log: usize, + ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error>; + + fn encode(pp: &Self::ProverParameters, coeffs: &FieldType) -> FieldType; + + /// Encodes a message of small length, such that the verifier is also able + /// to execute the encoding. + fn encode_small(vp: &Self::VerifierParameters, coeffs: &FieldType) -> FieldType; + + fn get_number_queries() -> usize; + + fn get_rate_log() -> usize; + + fn get_basecode_msg_size_log() -> usize; + + /// Whether the message needs to be bit-reversed to allow even-odd + /// folding. If the folding is already even-odd style (like RS code), + /// then set this function to return false. If the folding is originally + /// left-right, like basefold, then return true. + fn message_is_left_and_right_folding() -> bool; + + fn message_is_even_and_odd_folding() -> bool { + !Self::message_is_left_and_right_folding() + } + + /// Returns three values: x0, x1 and 1/(x1-x0). Note that although + /// 1/(x1-x0) can be computed from the other two values, we return it + /// separately because inversion is expensive. + /// These three values can be used to interpolate a linear function + /// that passes through the two points (x0, y0) and (x1, y1), for the + /// given y0 and y1, then compute the value of the linear function at + /// any give x. + /// Params: + /// - level: which particular code in this family of codes? 
+ /// - index: position in the codeword (after folded) + fn prover_folding_coeffs(pp: &Self::ProverParameters, level: usize, index: usize) -> (E, E, E); + + /// The same as `prover_folding_coeffs`, but for the verifier. The two + /// functions, although provide the same functionality, may use different + /// implementations. For example, prover can use precomputed values stored + /// in the parameters, but the verifier may need to recompute them. + fn verifier_folding_coeffs( + vp: &Self::VerifierParameters, + level: usize, + index: usize, + ) -> (E, E, E); + + /// Fold the given codeword into a smaller codeword of half size, using + /// the folding coefficients computed by `prover_folding_coeffs`. + /// The given codeword is assumed to be bit-reversed on the original + /// codeword directly produced from the `encode` method. + fn fold_bitreversed_codeword( + pp: &Self::ProverParameters, + codeword: &FieldType, + challenge: E, + ) -> Vec { + let level = log2_strict(codeword.len()) - 1; + match codeword { + FieldType::Ext(codeword) => codeword + .par_chunks_exact(2) + .enumerate() + .map(|(i, ys)| { + let (x0, x1, w) = Self::prover_folding_coeffs(pp, level, i); + interpolate2_weights([(x0, ys[0]), (x1, ys[1])], w, challenge) + }) + .collect::>(), + FieldType::Base(codeword) => codeword + .par_chunks_exact(2) + .enumerate() + .map(|(i, ys)| { + let (x0, x1, w) = Self::prover_folding_coeffs(pp, level, i); + interpolate2_weights([(x0, E::from(ys[0])), (x1, E::from(ys[1]))], w, challenge) + }) + .collect::>(), + _ => panic!("Unsupported field type"), + } + } + + /// Fold the given message into a smaller message of half size using challenge + /// as the random linear combination coefficient. + /// Note that this is always even-odd fold, assuming the message has + /// been bit-reversed (or not) according to the setting + /// of the `message_need_bit_reversion` function. 
+ fn fold_message(msg: &FieldType, challenge: E) -> Vec { + match msg { + FieldType::Ext(msg) => msg + .par_chunks_exact(2) + .map(|ys| ys[0] + ys[1] * challenge) + .collect::>(), + FieldType::Base(msg) => msg + .par_chunks_exact(2) + .map(|ys| E::from(ys[0]) + E::from(ys[1]) * challenge) + .collect::>(), + _ => panic!("Unsupported field type"), + } + } +} + +fn concatenate_field_types(coeffs: &[FieldType]) -> FieldType { + match coeffs[0] { + FieldType::Ext(_) => { + let res = coeffs + .iter() + .flat_map(|x| match x { + FieldType::Ext(x) => x.iter().copied(), + _ => unreachable!(), + }) + .collect::>(); + FieldType::Ext(res) + } + FieldType::Base(_) => { + let res = coeffs + .iter() + .flat_map(|x| match x { + FieldType::Base(x) => x.iter().copied(), + _ => unreachable!(), + }) + .collect::>(); + FieldType::Base(res) + } + _ => unreachable!(), + } +} + +#[cfg(test)] +pub(crate) mod test_util { + use ff_ext::ExtensionField; + use multilinear_extensions::mle::FieldType; + use rand::rngs::OsRng; + + use crate::util::plonky2_util::reverse_index_bits_in_place_field_type; + + use super::EncodingScheme; + + pub fn test_codeword_folding>() { + let num_vars = 12; + + let poly: Vec = (0..(1 << num_vars)).map(|i| E::from(i)).collect(); + let mut poly = FieldType::Ext(poly); + + let rng_seed = [0; 32]; + let pp: Code::PublicParameters = Code::setup(num_vars, rng_seed); + let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let mut codeword = Code::encode(&pp, &poly); + reverse_index_bits_in_place_field_type(&mut codeword); + if Code::message_is_left_and_right_folding() { + reverse_index_bits_in_place_field_type(&mut poly); + } + let challenge = E::random(&mut OsRng); + let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); + let mut folded_message = FieldType::Ext(Code::fold_message(&poly, challenge)); + if Code::message_is_left_and_right_folding() { + // Reverse the message back before encoding if it has been + // bit-reversed + 
reverse_index_bits_in_place_field_type(&mut folded_message); + } + let mut encoded_folded_message = Code::encode(&pp, &folded_message); + reverse_index_bits_in_place_field_type(&mut encoded_folded_message); + let encoded_folded_message = match encoded_folded_message { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + for (i, (a, b)) in folded_codeword + .iter() + .zip(encoded_folded_message.iter()) + .enumerate() + { + assert_eq!(a, b, "Failed at index {}", i); + } + + let mut folded_codeword = FieldType::Ext(folded_codeword); + for round in 0..4 { + let folded_codeword_vec = + Code::fold_bitreversed_codeword(&pp, &folded_codeword, challenge); + + if Code::message_is_left_and_right_folding() { + reverse_index_bits_in_place_field_type(&mut folded_message); + } + folded_message = FieldType::Ext(Code::fold_message(&folded_message, challenge)); + if Code::message_is_left_and_right_folding() { + reverse_index_bits_in_place_field_type(&mut folded_message); + } + let mut encoded_folded_message = Code::encode(&pp, &folded_message); + reverse_index_bits_in_place_field_type(&mut encoded_folded_message); + let encoded_folded_message = match encoded_folded_message { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + for (i, (a, b)) in folded_codeword_vec + .iter() + .zip(encoded_folded_message.iter()) + .enumerate() + { + assert_eq!(a, b, "Failed at index {} in round {}", i, round); + } + folded_codeword = FieldType::Ext(folded_codeword_vec); + } + } +} diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs new file mode 100644 index 000000000..051a510e6 --- /dev/null +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -0,0 +1,451 @@ +use std::marker::PhantomData; + +use super::{concatenate_field_types, EncodingProverParameters, EncodingScheme}; +use crate::{ + util::{ + arithmetic::base_from_raw_bytes, log2_strict, num_of_bytes, plonky2_util::reverse_bits, + }, + vec_mut, Error, +}; +use 
aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; +use ark_std::{end_timer, start_timer}; +use ff::{BatchInverter, Field, PrimeField}; +use ff_ext::ExtensionField; +use generic_array::GenericArray; +use multilinear_extensions::mle::FieldType; +use rand::SeedableRng; +use rayon::prelude::{ParallelIterator, ParallelSlice, ParallelSliceMut}; + +use itertools::Itertools; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use crate::util::plonky2_util::reverse_index_bits_in_place; +use rand_chacha::{rand_core::RngCore, ChaCha8Rng}; +use rayon::prelude::IntoParallelRefIterator; + +use crate::util::arithmetic::{horner, steps}; + +pub trait BasecodeSpec: std::fmt::Debug + Clone { + fn get_number_queries() -> usize; + + fn get_rate_log() -> usize; + + fn get_basecode_msg_size_log() -> usize; +} + +#[derive(Debug, Clone)] +pub struct BasecodeDefaultSpec {} + +impl BasecodeSpec for BasecodeDefaultSpec { + fn get_number_queries() -> usize { + 766 + } + + fn get_rate_log() -> usize { + 3 + } + + fn get_basecode_msg_size_log() -> usize { + 7 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasecodeParameters { + pub(crate) table: Vec>, + pub(crate) table_w_weights: Vec>, + pub(crate) rng_seed: [u8; 32], +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasecodeProverParameters { + pub(crate) table: Vec>, + pub(crate) table_w_weights: Vec>, + pub(crate) rng_seed: [u8; 32], + #[serde(skip)] + _phantom: PhantomData Spec>, +} + +impl EncodingProverParameters + for BasecodeProverParameters +{ + fn get_max_message_size_log(&self) -> usize { + self.table.len() - Spec::get_rate_log() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BasecodeVerifierParameters { + pub(crate) rng_seed: [u8; 32], 
+ pub(crate) aes_key: [u8; 16], + pub(crate) aes_iv: [u8; 16], +} + +#[derive(Debug, Clone)] +pub struct Basecode { + _phantom_data: PhantomData, +} + +impl EncodingScheme for Basecode +where + E::BaseField: Serialize + DeserializeOwned, +{ + type PublicParameters = BasecodeParameters; + + type ProverParameters = BasecodeProverParameters; + + type VerifierParameters = BasecodeVerifierParameters; + + fn setup(max_msg_size_log: usize, rng_seed: [u8; 32]) -> Self::PublicParameters { + let rng = ChaCha8Rng::from_seed(rng_seed); + let (table_w_weights, table) = + get_table_aes::(max_msg_size_log, Spec::get_rate_log(), &mut rng.clone()); + BasecodeParameters { + table, + table_w_weights, + rng_seed, + } + } + + fn trim( + pp: &Self::PublicParameters, + max_msg_size_log: usize, + ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { + if pp.table.len() < Spec::get_rate_log() + max_msg_size_log { + return Err(Error::InvalidPcsParam(format!( + "Public parameter is setup for a smaller message size (log={}) than the trimmed message size (log={})", + pp.table.len() - Spec::get_rate_log(), + max_msg_size_log, + ))); + } + let mut key: [u8; 16] = [0u8; 16]; + let mut iv: [u8; 16] = [0u8; 16]; + let mut rng = ChaCha8Rng::from_seed(pp.rng_seed); + rng.set_word_pos(0); + rng.fill_bytes(&mut key); + rng.fill_bytes(&mut iv); + Ok(( + Self::ProverParameters { + table_w_weights: pp.table_w_weights.clone(), + table: pp.table.clone(), + rng_seed: pp.rng_seed, + _phantom: PhantomData, + }, + Self::VerifierParameters { + rng_seed: pp.rng_seed, + aes_key: key, + aes_iv: iv, + }, + )) + } + + fn encode(pp: &Self::ProverParameters, coeffs: &FieldType) -> FieldType { + // Split the input into chunks of message size, encode each message, and return the codewords + let basecode = encode_field_type_rs_basecode( + coeffs, + 1 << Spec::get_rate_log(), + 1 << Spec::get_basecode_msg_size_log(), + ); + + // Apply the recursive definition of the BaseFold code to the list of base 
codewords, + // and produce the final codeword + evaluate_over_foldable_domain_generic_basecode::( + 1 << Spec::get_basecode_msg_size_log(), + coeffs.len(), + Spec::get_rate_log(), + basecode, + &pp.table, + ) + } + + fn encode_small(_vp: &Self::VerifierParameters, coeffs: &FieldType) -> FieldType { + let mut basecodes = + encode_field_type_rs_basecode(coeffs, 1 << Spec::get_rate_log(), coeffs.len()); + assert_eq!(basecodes.len(), 1); + basecodes.remove(0) + } + + fn get_number_queries() -> usize { + Spec::get_number_queries() + } + + fn get_rate_log() -> usize { + Spec::get_rate_log() + } + + fn get_basecode_msg_size_log() -> usize { + Spec::get_basecode_msg_size_log() + } + + fn message_is_left_and_right_folding() -> bool { + true + } + + fn prover_folding_coeffs(pp: &Self::ProverParameters, level: usize, index: usize) -> (E, E, E) { + let level = &pp.table_w_weights[level]; + ( + E::from(level[index].0), + E::from(-level[index].0), + E::from(level[index].1), + ) + } + + fn verifier_folding_coeffs( + vp: &Self::VerifierParameters, + level: usize, + index: usize, + ) -> (E, E, E) { + type Aes128Ctr64LE = ctr::Ctr32LE; + let mut cipher = Aes128Ctr64LE::new( + GenericArray::from_slice(&vp.aes_key[..]), + GenericArray::from_slice(&vp.aes_iv[..]), + ); + + let x0: E::BaseField = query_root_table_from_rng_aes::(level, index, &mut cipher); + let x1 = -x0; + + let w = (x1 - x0).invert().unwrap(); + + (E::from(x0), E::from(x1), E::from(w)) + } +} + +fn encode_field_type_rs_basecode( + poly: &FieldType, + rate: usize, + message_size: usize, +) -> Vec> { + match poly { + FieldType::Ext(poly) => get_basecode(poly, rate, message_size) + .iter() + .map(|x| FieldType::Ext(x.clone())) + .collect(), + FieldType::Base(poly) => get_basecode(poly, rate, message_size) + .iter() + .map(|x| FieldType::Base(x.clone())) + .collect(), + _ => panic!("Unsupported field type"), + } +} + +// Split the input into chunks of message size, encode each message, and return the codewords +// FIXME: 
It is expensive for now because it is using naive FFT (although it is +// over a small domain) +fn get_basecode(poly: &Vec, rate: usize, message_size: usize) -> Vec> { + let timer = start_timer!(|| "Encode basecode"); + // The domain is just counting 1, 2, 3, ... , domain_size + let domain: Vec = steps(F::ONE).take(message_size * rate).collect(); + let res = poly + .par_chunks_exact(message_size) + .map(|chunk| { + let mut target = vec![F::ZERO; message_size * rate]; + // Just Reed-Solomon code, but with the naive domain + target + .iter_mut() + .enumerate() + .for_each(|(i, target)| *target = horner(chunk, &domain[i])); + target + }) + .collect::>>(); + end_timer!(timer); + + res +} + +// this function assumes all codewords in base_codeword has equivalent length +pub fn evaluate_over_foldable_domain_generic_basecode( + base_message_length: usize, + num_coeffs: usize, + log_rate: usize, + base_codewords: Vec>, + table: &[Vec], +) -> FieldType { + let timer = start_timer!(|| "evaluate over foldable domain"); + let k = num_coeffs; + let logk = log2_strict(k); + let base_log_k = log2_strict(base_message_length); + // concatenate together all base codewords + // let now = Instant::now(); + let mut coeffs_with_bc = concatenate_field_types(&base_codewords); + // println!("concatenate base codewords {:?}", now.elapsed()); + // iterate over array, replacing even indices with (evals[i] - evals[(i+1)]) + let mut chunk_size = base_codewords[0].len(); // block length of the base code + for i in base_log_k..logk { + // In beginning of each iteration, the current codeword size is 1<> 1); + vec_mut!(coeffs_with_bc, |c| { + c.par_chunks_mut(chunk_size).for_each(|chunk| { + let half_chunk = chunk_size >> 1; + for j in half_chunk..chunk_size { + // Suppose the current codewords are (a, b) + // The new codeword is computed by two halves: + // left = a + t * b + // right = a - t * b + let rhs = chunk[j] * level[j - half_chunk]; + chunk[j] = chunk[j - half_chunk] - rhs; + chunk[j - 
half_chunk] += rhs; + } + }); + }); + } + end_timer!(timer); + coeffs_with_bc +} + +#[allow(clippy::type_complexity)] +pub fn get_table_aes( + poly_size_log: usize, + rate: usize, + rng: &mut Rng, +) -> ( + Vec>, + Vec>, +) { + // The size (logarithmic) of the codeword for the polynomial + let lg_n: usize = rate + poly_size_log; + + let mut key: [u8; 16] = [0u8; 16]; + let mut iv: [u8; 16] = [0u8; 16]; + rng.fill_bytes(&mut key); + rng.fill_bytes(&mut iv); + + type Aes128Ctr64LE = ctr::Ctr32LE; + + let mut cipher = Aes128Ctr64LE::new( + GenericArray::from_slice(&key[..]), + GenericArray::from_slice(&iv[..]), + ); + + // Allocate the buffer for storing n field elements (the entire codeword) + let bytes = num_of_bytes::(1 << lg_n); + let mut dest: Vec = vec![0u8; bytes]; + cipher.apply_keystream(&mut dest[..]); + + // Now, dest is a vector filled with random data for a field vector of size n + + // Collect the bytes into field elements + let flat_table: Vec = dest + .par_chunks_exact(num_of_bytes::(1)) + .map(|chunk| base_from_raw_bytes::(chunk)) + .collect::>(); + + // Now, flat_table is a field vector of size n, filled with random field elements + assert_eq!(flat_table.len(), 1 << lg_n); + + // Multiply -2 to every element to get the weights. Now weights = { -2x } + let mut weights: Vec = flat_table + .par_iter() + .map(|el| E::BaseField::ZERO - *el - *el) + .collect(); + + // Then invert all the elements. Now weights = { -1/2x } + let mut scratch_space = vec![E::BaseField::ZERO; weights.len()]; + BatchInverter::invert_with_external_scratch(&mut weights, &mut scratch_space); + + // Zip x and -1/2x together. The result is the list { (x, -1/2x) } + // What is this -1/2x? 
It is used in linear interpolation over the domain (x, -x), which + // involves computing 1/(b-a) where b=-x and a=x, and 1/(b-a) here is exactly -1/2x + let flat_table_w_weights = flat_table + .iter() + .zip(weights) + .map(|(el, w)| (*el, w)) + .collect_vec(); + + // Split the positions from 0 to n-1 into slices of sizes: + // 2, 2, 4, 8, ..., n/2, exactly lg_n number of them + // The weights are (x, -1/2x), the table elements are just x + + let mut unflattened_table_w_weights = vec![Vec::new(); lg_n]; + let mut unflattened_table = vec![Vec::new(); lg_n]; + + unflattened_table_w_weights[0] = flat_table_w_weights[1..2].to_vec(); + unflattened_table[0] = flat_table[1..2].to_vec(); + for i in 1..lg_n { + unflattened_table[i] = flat_table[(1 << i)..(1 << (i + 1))].to_vec(); + let mut level = flat_table_w_weights[(1 << i)..(1 << (i + 1))].to_vec(); + reverse_index_bits_in_place(&mut level); + unflattened_table_w_weights[i] = level; + } + + (unflattened_table_w_weights, unflattened_table) +} + +pub fn query_root_table_from_rng_aes( + level: usize, + index: usize, + cipher: &mut ctr::Ctr32LE, +) -> E::BaseField { + let mut level_offset: u128 = 1; + for lg_m in 1..=level { + let half_m = 1 << (lg_m - 1); + level_offset += half_m; + } + + let pos = ((level_offset + (reverse_bits(index, level) as u128)) + * ((E::BaseField::NUM_BITS as usize).next_power_of_two() as u128)) + .checked_div(8) + .unwrap(); + + cipher.seek(pos); + + let bytes = (E::BaseField::NUM_BITS as usize).next_power_of_two() / 8; + let mut dest: Vec = vec![0u8; bytes]; + cipher.apply_keystream(&mut dest); + + base_from_raw_bytes::(&dest) +} + +#[cfg(test)] +mod tests { + use crate::basefold::encoding::test_util::test_codeword_folding; + + use super::*; + use goldilocks::GoldilocksExt2; + use multilinear_extensions::mle::DenseMultilinearExtension; + + #[test] + fn time_rs_code() { + use rand::rngs::OsRng; + + let poly = DenseMultilinearExtension::random(20, &mut OsRng); + + 
encode_field_type_rs_basecode::(&poly.evaluations, 2, 64); + } + + #[test] + fn prover_verifier_consistency() { + type Code = Basecode; + let pp: BasecodeParameters = Code::setup(10, [0; 32]); + let (pp, vp) = Code::trim(&pp, 10).unwrap(); + for level in 0..(10 + >::get_rate_log()) { + for index in 0..(1 << level) { + assert_eq!( + Code::prover_folding_coeffs(&pp, level, index), + Code::verifier_folding_coeffs(&vp, level, index), + "failed for level = {}, index = {}", + level, + index + ); + } + } + } + + #[test] + fn test_basecode_codeword_folding() { + test_codeword_folding::>(); + } +} diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs new file mode 100644 index 000000000..79c047060 --- /dev/null +++ b/mpcs/src/basefold/encoding/rs.rs @@ -0,0 +1,877 @@ +use std::marker::PhantomData; + +use super::{EncodingProverParameters, EncodingScheme}; +use crate::{ + util::{field_type_index_mul_base, log2_strict, plonky2_util::reverse_bits}, + vec_mut, Error, +}; +use ark_std::{end_timer, start_timer}; +use ff::{Field, PrimeField}; +use ff_ext::ExtensionField; +use multilinear_extensions::mle::FieldType; + +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use crate::util::plonky2_util::reverse_index_bits_in_place; + +use crate::util::arithmetic::horner; + +pub trait RSCodeSpec: std::fmt::Debug + Clone { + fn get_number_queries() -> usize; + + fn get_rate_log() -> usize; + + fn get_basecode_msg_size_log() -> usize; +} + +/// The FFT codes in this file are borrowed and adapted from Plonky2. 
+type FftRootTable = Vec>; + +pub fn fft_root_table(lg_n: usize) -> FftRootTable { + // bases[i] = g^2^i, for i = 0, ..., lg_n - 1 + // Note that the end of bases is g^{n/2} = -1 + let mut bases = Vec::with_capacity(lg_n); + let mut base = F::ROOT_OF_UNITY.pow([(1 << (F::S - lg_n as u32)) as u64]); + bases.push(base); + for _ in 1..lg_n { + base = base.square(); // base = g^2^_ + bases.push(base); + } + + // The result table looks like this: + // len=2: [1, g^{n/2}=-1] + // len=2: [1, g^{n/4}] + // len=4: [1, g^{n/8}, g^{n/4}, g^{3n/8}] + // len=8: [1, g^{n/16}, ..., g^{7n/16}] + // ... + // len=n/2: [1, g, ..., g^{n/2-1}] + // There is no need to compute the other halves of these powers, because + // those would be simply the negations of the previous halves. + let mut root_table = Vec::with_capacity(lg_n); + for lg_m in 1..=lg_n { + let half_m = 1 << (lg_m - 1); + let base = bases[lg_n - lg_m]; + let mut root_row = Vec::with_capacity(half_m.max(2)); + root_row.push(F::ONE); + for i in 1..half_m.max(2) { + root_row.push(root_row[i - 1] * base); + } + root_table.push(root_row); + } + root_table +} + +#[allow(unused)] +fn ifft( + poly: &mut FieldType, + zero_factor: usize, + root_table: &FftRootTable, +) { + let n = poly.len(); + let lg_n = log2_strict(n); + let n_inv = (E::BaseField::ONE + E::BaseField::ONE) + .invert() + .unwrap() + .pow([lg_n as u64]); + + fft(poly, zero_factor, root_table); + + // We reverse all values except the first, and divide each by n. + field_type_index_mul_base(poly, 0, &n_inv); + field_type_index_mul_base(poly, n / 2, &n_inv); + vec_mut!(|poly| for i in 1..(n / 2) { + let j = n - i; + let coeffs_i = poly[j] * n_inv; + let coeffs_j = poly[i] * n_inv; + poly[i] = coeffs_i; + poly[j] = coeffs_j; + }) +} + +/// Core FFT implementation. +fn fft_classic_inner( + values: &mut FieldType, + r: usize, + lg_n: usize, + root_table: &[Vec], +) { + // We've already done the first lg_packed_width (if they were required) iterations. 
+ + for (lg_half_m, cur_root_table) in root_table.iter().enumerate().take(lg_n).skip(r) { + let n = 1 << lg_n; + let lg_m = lg_half_m + 1; + let m = 1 << lg_m; // Subarray size (in field elements). + let half_m = m / 2; + debug_assert!(half_m != 0); + + // omega values for this iteration, as slice of vectors + let omega_table = &cur_root_table[..]; + vec_mut!(|values| { + for k in (0..n).step_by(m) { + for j in 0..half_m { + let omega = omega_table[j]; + let t = values[k + half_m + j] * omega; + let u = values[k + j]; + values[k + j] = u + t; + values[k + half_m + j] = u - t; + } + } + }) + } +} + +/// FFT implementation based on Section 32.3 of "Introduction to +/// Algorithms" by Cormen et al. +/// +/// The parameter r signifies that the first 1/2^r of the entries of +/// input may be non-zero, but the last 1 - 1/2^r entries are +/// definitely zero. +pub fn fft( + values: &mut FieldType, + r: usize, + root_table: &[Vec], +) { + vec_mut!(|values| reverse_index_bits_in_place(values)); + + let n = values.len(); + let lg_n = log2_strict(n); + + if root_table.len() != lg_n { + panic!( + "Expected root table of length {}, but it was {}.", + lg_n, + root_table.len() + ); + } + + // After reverse_index_bits, the only non-zero elements of values + // are at indices i*2^r for i = 0..n/2^r. The loop below copies + // the value at i*2^r to the positions [i*2^r + 1, i*2^r + 2, ..., + // (i+1)*2^r - 1]; i.e. it replaces the 2^r - 1 zeros following + // element i*2^r with the value at i*2^r. This corresponds to the + // first r rounds of the FFT when there are 2^r zeros at the end + // of the original input. + if r > 0 { + // if r == 0 then this loop is a noop. 
+ let mask = !((1 << r) - 1); + match values { + FieldType::Base(values) => { + for i in 0..n { + values[i] = values[i & mask]; + } + } + FieldType::Ext(values) => { + for i in 0..n { + values[i] = values[i & mask]; + } + } + _ => panic!("Unsupported field type"), + } + } + + fft_classic_inner::(values, r, lg_n, root_table); +} + +pub fn coset_fft( + coeffs: &mut FieldType, + shift: E::BaseField, + zero_factor: usize, + root_table: &[Vec], +) { + let mut shift_power = E::BaseField::ONE; + vec_mut!(|coeffs| { + for coeff in coeffs.iter_mut() { + *coeff *= shift_power; + shift_power *= shift; + } + }); + fft(coeffs, zero_factor, root_table); +} + +#[derive(Debug, Clone)] +pub struct RSCodeDefaultSpec {} + +impl RSCodeSpec for RSCodeDefaultSpec { + fn get_number_queries() -> usize { + 336 + } + + fn get_rate_log() -> usize { + 3 + } + + fn get_basecode_msg_size_log() -> usize { + 7 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct RSCodeParameters { + pub(crate) fft_root_table: FftRootTable, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct RSCodeProverParameters { + pub(crate) fft_root_table: FftRootTable, + pub(crate) gamma_powers: Vec, + pub(crate) gamma_powers_inv_div_two: Vec, + pub(crate) full_message_size_log: usize, +} + +impl EncodingProverParameters for RSCodeProverParameters { + fn get_max_message_size_log(&self) -> usize { + self.full_message_size_log + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RSCodeVerifierParameters +where + E::BaseField: Serialize + DeserializeOwned, +{ + /// The verifier also needs a FFT table (much smaller) + /// for small-size encoding. 
It contains the same roots as the + /// prover's version for the first few levels (i < basecode_msg_size_log) + /// For the other levels (i >= basecode_msg_size_log), + /// it contains only the g^(2^i). + pub(crate) fft_root_table: FftRootTable, + pub(crate) full_message_size_log: usize, + pub(crate) gamma_powers: Vec, + pub(crate) gamma_powers_inv_div_two: Vec, +} + +#[derive(Debug, Clone)] +pub struct RSCode { + _phantom_data: PhantomData, +} + +impl EncodingScheme for RSCode +where + E::BaseField: Serialize + DeserializeOwned, +{ + type PublicParameters = RSCodeParameters; + + type ProverParameters = RSCodeProverParameters; + + type VerifierParameters = RSCodeVerifierParameters; + + fn setup(max_message_size_log: usize, _rng_seed: [u8; 32]) -> Self::PublicParameters { + RSCodeParameters { + fft_root_table: fft_root_table(max_message_size_log + Spec::get_rate_log()), + } + } + + fn trim( + pp: &Self::PublicParameters, + max_message_size_log: usize, + ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { + if pp.fft_root_table.len() < max_message_size_log + Spec::get_rate_log() { + return Err(Error::InvalidPcsParam(format!( + "Public parameter is setup for a smaller message size (log={}) than the trimmed message size (log={})", + pp.fft_root_table.len() - Spec::get_rate_log(), + max_message_size_log, + ))); + } + let mut gamma_powers = Vec::with_capacity(max_message_size_log); + let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log); + gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR); + gamma_powers_inv.push(E::BaseField::MULTIPLICATIVE_GENERATOR.invert().unwrap()); + for i in 1..max_message_size_log + Spec::get_rate_log() { + gamma_powers.push(gamma_powers[i - 1].square()); + gamma_powers_inv.push(gamma_powers_inv[i - 1].square()); + } + let inv_of_two = E::BaseField::from(2).invert().unwrap(); + gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two); + Ok(( + Self::ProverParameters { + fft_root_table: 
pp.fft_root_table[..max_message_size_log + Spec::get_rate_log()] + .to_vec(), + gamma_powers: gamma_powers.clone(), + gamma_powers_inv_div_two: gamma_powers_inv.clone(), + full_message_size_log: max_message_size_log, + }, + Self::VerifierParameters { + fft_root_table: pp.fft_root_table + [..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()] + .iter() + .cloned() + .chain( + pp.fft_root_table + [Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..] + .iter() + .map(|v| vec![v[1]]), + ) + .collect(), + full_message_size_log: max_message_size_log, + gamma_powers, + gamma_powers_inv_div_two: gamma_powers_inv, + }, + )) + } + + fn encode(pp: &Self::ProverParameters, coeffs: &FieldType) -> FieldType { + // Use the full message size to determine the shift factor. + Self::encode_internal(&pp.fft_root_table, coeffs, pp.full_message_size_log) + } + + fn encode_small(vp: &Self::VerifierParameters, coeffs: &FieldType) -> FieldType { + // Use the full message size to determine the shift factor. + Self::encode_internal(&vp.fft_root_table, coeffs, vp.full_message_size_log) + } + + fn get_number_queries() -> usize { + Spec::get_number_queries() + } + + fn get_rate_log() -> usize { + Spec::get_rate_log() + } + + fn get_basecode_msg_size_log() -> usize { + Spec::get_basecode_msg_size_log() + } + + fn message_is_left_and_right_folding() -> bool { + false + } + + fn prover_folding_coeffs(pp: &Self::ProverParameters, level: usize, index: usize) -> (E, E, E) { + // The coefficients are for the bit-reversed codeword, so reverse the + // bits before providing the coefficients. + let index = reverse_bits(index, level); + // level is the logarithmic of the codeword size after folded. + // Therefore, the domain after folded is gamma^2^(full_log_n - level) H + // where H is the multiplicative subgroup of size 2^level. 
+ // The element at index i in this domain is + // gamma^2^(full_log_n - level) * ((2^level)-th root of unity)^i + // The x0 and x1 are exactly the two square roots, i.e., + // x0 = gamma^2^(full_log_n - level - 1) * ((2^(level+1))-th root of unity)^i + // Since root_table[i] stores the first half of the powers of + // the 2^(i+1)-th roots of unity, we can avoid recomputing them. + let x0 = if index < (1 << level) { + pp.fft_root_table[level][index] + } else { + -pp.fft_root_table[level][index - (1 << level)] + } * pp.gamma_powers[pp.full_message_size_log + Spec::get_rate_log() - level - 1]; + let x1 = -x0; + // The weight is 1/(x1-x0) = -1/(2x0) + // = -1/2 * (gamma^{-1})^2^(full_codeword_log_n - level - 1) * ((2^(level+1))-th root of unity)^{2^(level+1)-i} + let w = -pp.gamma_powers_inv_div_two + [pp.full_message_size_log + Spec::get_rate_log() - level - 1] + * if index == 0 { + E::BaseField::ONE + } else if index < (1 << level) { + -pp.fft_root_table[level][(1 << level) - index] + } else if index == 1 << level { + -E::BaseField::ONE + } else { + pp.fft_root_table[level][(1 << (level + 1)) - index] + }; + (E::from(x0), E::from(x1), E::from(w)) + } + + fn verifier_folding_coeffs( + vp: &Self::VerifierParameters, + level: usize, + index: usize, + ) -> (E, E, E) { + // The coefficients are for the bit-reversed codeword, so reverse the + // bits before providing the coefficients. + let index = reverse_bits(index, level); + // The same as prover_folding_coeffs, except that the powers of + // g are computed on the fly for levels exceeding the root table. + let x0 = if level < Spec::get_basecode_msg_size_log() + Spec::get_rate_log() { + if index < (1 << level) { + vp.fft_root_table[level][index] + } else { + -vp.fft_root_table[level][index - (1 << level)] + } + } else { + // In this case, the level-th row of fft root table of the verifier + // only stores the first 2^(level+1)-th roots of unity. 
+ vp.fft_root_table[level][0].pow([index as u64]) + } * vp.gamma_powers[vp.full_message_size_log + Spec::get_rate_log() - level - 1]; + let x1 = -x0; + // The weight is 1/(x1-x0) = -1/(2x0) + // = -1/2 * (gamma^{-1})^2^(full_log_n - level - 1) * ((2^(level+1))-th root of unity)^{2^(level+1)-i} + let w = -vp.gamma_powers_inv_div_two + [vp.full_message_size_log + Spec::get_rate_log() - level - 1] + * if level < Spec::get_basecode_msg_size_log() + Spec::get_rate_log() { + if index == 0 { + E::BaseField::ONE + } else if index < (1 << level) { + -vp.fft_root_table[level][(1 << level) - index] + } else if index == 1 << level { + -E::BaseField::ONE + } else { + vp.fft_root_table[level][(1 << (level + 1)) - index] + } + } else { + // In this case, this level of fft root table of the verifier + // only stores the first 2^(level+1)-th root of unity. + vp.fft_root_table[level][0].pow([(1 << (level + 1)) - index as u64]) + }; + (E::from(x0), E::from(x1), E::from(w)) + } +} + +impl RSCode { + fn encode_internal( + fft_root_table: &FftRootTable, + coeffs: &FieldType, + full_message_size_log: usize, + ) -> FieldType + where + E::BaseField: Serialize + DeserializeOwned, + { + let lg_m = log2_strict(coeffs.len()); + let fft_root_table = &fft_root_table[..lg_m + Spec::get_rate_log()]; + assert!( + lg_m <= full_message_size_log, + "Encoded message exceeds the maximum supported message size of the table." + ); + let rate = 1 << Spec::get_rate_log(); + let mut ret = match coeffs { + FieldType::Base(coeffs) => { + let mut coeffs = coeffs.clone(); + coeffs.extend(itertools::repeat_n( + E::BaseField::ZERO, + coeffs.len() * (rate - 1), + )); + FieldType::Base(coeffs) + } + FieldType::Ext(coeffs) => { + let mut coeffs = coeffs.clone(); + coeffs.extend(itertools::repeat_n(E::ZERO, coeffs.len() * (rate - 1))); + FieldType::Ext(coeffs) + } + _ => panic!("Unsupported field type"), + }; + // Let gamma be the multiplicative generator of the base field. 
+ // The full domain is gamma H where H is the multiplicative subgroup + // of size n * rate. + // When the input message size is not n, but n/2^k, then the domain is + // gamma^2^k H. + let k = 1 << (full_message_size_log - lg_m); + coset_fft( + &mut ret, + E::BaseField::MULTIPLICATIVE_GENERATOR.pow([k]), + Spec::get_rate_log(), + fft_root_table, + ); + ret + } + + #[allow(unused)] + fn folding_coeffs_naive( + level: usize, + index: usize, + full_message_size_log: usize, + ) -> (E, E, E) { + // The coefficients are for the bit-reversed codeword, so reverse the + // bits before providing the coefficients. + let index = reverse_bits(index, level); + // x0 is the index-th 2^(level+1)-th root of unity, multiplied by + // the shift factor at level+1, which is gamma^2^(full_codeword_log_n - level - 1). + let x0 = E::BaseField::ROOT_OF_UNITY + .pow([1 << (E::BaseField::S - (level as u32 + 1))]) + .pow([index as u64]) + * E::BaseField::MULTIPLICATIVE_GENERATOR + .pow([1 << (full_message_size_log + Spec::get_rate_log() - level - 1)]); + let x1 = -x0; + let w = (x1 - x0).invert().unwrap(); + (E::from(x0), E::from(x1), E::from(w)) + } +} + +#[allow(unused)] +fn naive_fft(poly: &[E], rate: usize, shift: E::BaseField) -> Vec { + let timer = start_timer!(|| "Encode RSCode"); + let message_size = poly.len(); + let domain_size_bit = log2_strict(message_size * rate); + let root = E::BaseField::ROOT_OF_UNITY.pow([1 << (E::BaseField::S - domain_size_bit as u32)]); + // The domain is shift * H where H is the multiplicative subgroup of size + // message_size * rate. 
+ let mut domain = Vec::::with_capacity(message_size * rate); + domain.push(shift); + for i in 1..message_size * rate { + domain.push(domain[i - 1] * root); + } + let mut res = vec![E::ZERO; message_size * rate]; + res.iter_mut() + .enumerate() + .for_each(|(i, target)| *target = horner(poly, &E::from(domain[i]))); + end_timer!(timer); + + res +} + +#[cfg(test)] +mod tests { + use crate::{ + basefold::encoding::test_util::test_codeword_folding, + util::{field_type_index_ext, plonky2_util::reverse_index_bits_in_place_field_type}, + }; + + use super::*; + use goldilocks::{Goldilocks, GoldilocksExt2}; + + #[test] + fn test_naive_fft() { + let num_vars = 5; + + let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let mut poly2 = FieldType::Ext(poly.clone()); + + let naive = naive_fft::(&poly, 1, Goldilocks::ONE); + + let root_table = fft_root_table(num_vars); + fft::(&mut poly2, 0, &root_table); + + let poly2 = match poly2 { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + assert_eq!(naive, poly2); + } + + #[test] + fn test_naive_fft_with_shift() { + use rand::rngs::OsRng; + let num_vars = 5; + + let poly: Vec = (0..(1 << num_vars)) + .map(|_| GoldilocksExt2::random(&mut OsRng)) + .collect(); + let mut poly2 = FieldType::Ext(poly.clone()); + + let naive = naive_fft::(&poly, 1, Goldilocks::MULTIPLICATIVE_GENERATOR); + + let root_table = fft_root_table(num_vars); + coset_fft::( + &mut poly2, + Goldilocks::MULTIPLICATIVE_GENERATOR, + 0, + &root_table, + ); + + let poly2 = match poly2 { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + assert_eq!(naive, poly2); + } + + #[test] + fn test_naive_fft_with_rate() { + use rand::rngs::OsRng; + let num_vars = 5; + let rate_bits = 1; + + let poly: Vec = (0..(1 << num_vars)) + .map(|_| GoldilocksExt2::random(&mut OsRng)) + .collect(); + let mut poly2 = vec![GoldilocksExt2::ZERO; poly.len() * (1 << rate_bits)]; + 
poly2.as_mut_slice()[..poly.len()].copy_from_slice(poly.as_slice()); + let mut poly2 = FieldType::Ext(poly2.clone()); + + let naive = naive_fft::( + &poly, + 1 << rate_bits, + Goldilocks::MULTIPLICATIVE_GENERATOR, + ); + + let root_table = fft_root_table(num_vars + rate_bits); + coset_fft::( + &mut poly2, + Goldilocks::MULTIPLICATIVE_GENERATOR, + rate_bits, + &root_table, + ); + + let poly2 = match poly2 { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + assert_eq!(naive, poly2); + } + + #[test] + fn test_ifft() { + let num_vars = 5; + + let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let mut poly = FieldType::Ext(poly); + let original = poly.clone(); + + let root_table = fft_root_table(num_vars); + fft::(&mut poly, 0, &root_table); + ifft::(&mut poly, 0, &root_table); + + assert_eq!(original, poly); + } + + #[test] + fn prover_verifier_consistency() { + type Code = RSCode; + let pp: RSCodeParameters = Code::setup(10, [0; 32]); + let (pp, vp) = Code::trim(&pp, 10).unwrap(); + for level in 0..(10 + >::get_rate_log()) { + for index in 0..(1 << level) { + let (naive_x0, naive_x1, naive_w) = + Code::folding_coeffs_naive(level, index, pp.full_message_size_log); + let (p_x0, p_x1, p_w) = Code::prover_folding_coeffs(&pp, level, index); + let (v_x0, v_x1, v_w) = Code::verifier_folding_coeffs(&vp, level, index); + // assert_eq!(v_w * (v_x1 - v_x0), GoldilocksExt2::ONE); + // assert_eq!(p_w * (p_x1 - p_x0), GoldilocksExt2::ONE); + assert_eq!( + (v_x0, v_x1, v_w, p_x0, p_x1, p_w), + (naive_x0, naive_x1, naive_w, naive_x0, naive_x1, naive_w), + "failed for level = {}, index = {}", + level, + index + ); + } + } + } + + #[test] + fn test_rs_codeword_folding() { + test_codeword_folding::>(); + } + + type E = GoldilocksExt2; + type F = Goldilocks; + type Code = RSCode; + + #[test] + pub fn test_colinearity() { + let num_vars = 10; + + let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly = 
FieldType::Ext(poly); + + let rng_seed = [0; 32]; + let pp = >::setup(num_vars, rng_seed); + let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let mut codeword = Code::encode(&pp, &poly); + reverse_index_bits_in_place_field_type(&mut codeword); + let challenge = E::from(2); + let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); + let codeword = match codeword { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + + for (i, (a, b)) in folded_codeword.iter().zip(codeword.chunks(2)).enumerate() { + let (x0, x1, _) = Code::prover_folding_coeffs( + &pp, + num_vars + >::get_rate_log() - 1, + i, + ); + // Check that (x0, b[0]), (x1, b[1]) and (challenge, a) are + // on the same line, i.e., + // (b[0]-a)/(x0-challenge) = (b[1]-a)/(x1-challenge) + // which is equivalent to + // (x0-challenge)*(b[1]-a) = (x1-challenge)*(b[0]-a) + assert_eq!( + (x0 - challenge) * (b[1] - a), + (x1 - challenge) * (b[0] - a), + "failed for i = {}", + i + ); + } + } + + #[test] + pub fn test_low_degree() { + let num_vars = 10; + + let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly = FieldType::Ext(poly); + + let rng_seed = [0; 32]; + let pp = >::setup(num_vars, rng_seed); + let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let mut codeword = Code::encode(&pp, &poly); + check_low_degree(&codeword, "low degree check for original codeword"); + let c0 = field_type_index_ext(&codeword, 0); + let c_mid = field_type_index_ext(&codeword, codeword.len() >> 1); + let c1 = field_type_index_ext(&codeword, 1); + let c_mid1 = field_type_index_ext(&codeword, (codeword.len() >> 1) + 1); + + reverse_index_bits_in_place_field_type(&mut codeword); + // After the bit inversion, the first element is still the first, + // but the middle one is switched to the second. 
+ assert_eq!(c0, field_type_index_ext(&codeword, 0)); + assert_eq!(c_mid, field_type_index_ext(&codeword, 1)); + // The second element is placed at the middle, and the next to middle + // element is still in place. + assert_eq!(c1, field_type_index_ext(&codeword, codeword.len() >> 1)); + assert_eq!( + c_mid1, + field_type_index_ext(&codeword, (codeword.len() >> 1) + 1) + ); + + // For RS codeword, the addition of the left and right halves is also + // a valid codeword + let codeword_vec = match &codeword { + FieldType::Ext(coeffs) => coeffs.clone(), + _ => panic!("Wrong field type"), + }; + let mut left_right_sum: Vec<E> = codeword_vec + .chunks(2) + .map(|chunk| chunk[0] + chunk[1]) + .collect(); + assert_eq!(left_right_sum[0], c0 + c_mid); + reverse_index_bits_in_place(&mut left_right_sum); + assert_eq!(left_right_sum[1], c1 + c_mid1); + check_low_degree( + &FieldType::Ext(left_right_sum.clone()), + "check low degree of left+right", + ); + + // The difference of the left and right halves is also + // a valid codeword after being twisted by omega^(-i), regardless of the + // shift of the coset. 
+ let mut left_right_diff: Vec = codeword_vec + .chunks(2) + .map(|chunk| chunk[0] - chunk[1]) + .collect(); + assert_eq!(left_right_diff[0], c0 - c_mid); + reverse_index_bits_in_place(&mut left_right_diff); + assert_eq!(left_right_diff[1], c1 - c_mid1); + let root_of_unity_inv = F::ROOT_OF_UNITY_INV + .pow([1 << (F::S as usize - log2_strict(left_right_diff.len()) - 1)]); + for (i, coeff) in left_right_diff.iter_mut().enumerate() { + *coeff *= root_of_unity_inv.pow([i as u64]); + } + assert_eq!(left_right_diff[0], c0 - c_mid); + assert_eq!(left_right_diff[1], (c1 - c_mid1) * root_of_unity_inv); + check_low_degree( + &FieldType::Ext(left_right_diff.clone()), + "check low degree of (left-right)*omega^(-i)", + ); + + let challenge = E::from(2); + let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); + let c_fold = folded_codeword[0]; + let c_fold1 = folded_codeword[folded_codeword.len() >> 1]; + let mut folded_codeword = FieldType::Ext(folded_codeword); + reverse_index_bits_in_place_field_type(&mut folded_codeword); + assert_eq!(c_fold, field_type_index_ext(&folded_codeword, 0)); + assert_eq!(c_fold1, field_type_index_ext(&folded_codeword, 1)); + + // The top level folding coefficient should have shift factor gamma + let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 0); + assert_eq!(folding_coeffs.0, E::from(F::MULTIPLICATIVE_GENERATOR)); + assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); + assert_eq!( + (folding_coeffs.1 - folding_coeffs.0) * folding_coeffs.2, + E::ONE + ); + // The three points (x0, c0), (x1, c_mid), (challenge, c_fold) should + // be colinear + assert_eq!( + (c_mid - c_fold) * (folding_coeffs.0 - challenge), + (c0 - c_fold) * (folding_coeffs.1 - challenge), + ); + // So the folded value should be equal to + // (gamma^{-1} * alpha * (c0 - c_mid) + (c0 + c_mid)) / 2 + assert_eq!( + c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), + challenge * (c0 - c_mid) + (c0 + c_mid) * 
F::MULTIPLICATIVE_GENERATOR + ); + assert_eq!( + c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), + challenge * left_right_diff[0] + left_right_sum[0] * F::MULTIPLICATIVE_GENERATOR + ); + assert_eq!( + c_fold * F::from(2), + challenge * left_right_diff[0] * F::MULTIPLICATIVE_GENERATOR.invert().unwrap() + + left_right_sum[0] + ); + + let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 1); + let root_of_unity = + F::ROOT_OF_UNITY.pow([1 << (F::S as usize - log2_strict(codeword.len()))]); + assert_eq!(root_of_unity.pow([codeword.len() as u64]), F::ONE); + assert_eq!(root_of_unity.pow([(codeword.len() >> 1) as u64]), -F::ONE); + assert_eq!( + folding_coeffs.0, + E::from(F::MULTIPLICATIVE_GENERATOR) + * E::from(root_of_unity).pow([(codeword.len() >> 2) as u64]) + ); + assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); + assert_eq!( + (folding_coeffs.1 - folding_coeffs.0) * folding_coeffs.2, + E::ONE + ); + + // The folded codeword is the linear combination of the left+right and the + // twisted left-right vectors. + // The coefficients are respectively 1/2 and gamma^{-1}/2 * alpha. 
+ // In another word, the folded codeword multipled by 2 is the linear + // combination by coeffs: 1 and gamma^{-1} * alpha + let gamma_inv = F::MULTIPLICATIVE_GENERATOR.invert().unwrap(); + let b = challenge * gamma_inv; + let folded_codeword_vec = match &folded_codeword { + FieldType::Ext(coeffs) => coeffs.clone(), + _ => panic!("Wrong field type"), + }; + assert_eq!( + c_fold * F::from(2), + left_right_diff[0] * b + left_right_sum[0] + ); + for (i, (c, (diff, sum))) in folded_codeword_vec + .iter() + .zip(left_right_diff.iter().zip(left_right_sum.iter())) + .enumerate() + { + assert_eq!(*c + c, *sum + b * diff, "failed for i = {}", i); + } + + check_low_degree(&folded_codeword, "low degree check for folded"); + } + + fn check_low_degree(codeword: &FieldType, message: &str) { + let mut codeword = codeword.clone(); + let codeword_bits = log2_strict(codeword.len()); + let root_table = fft_root_table(codeword_bits); + let original = codeword.clone(); + ifft(&mut codeword, 0, &root_table); + for i in (codeword.len() >> >::get_rate_log())..codeword.len() { + assert_eq!( + field_type_index_ext(&codeword, i), + E::ZERO, + "{}: zero check failed for i = {}", + message, + i + ) + } + fft(&mut codeword, 0, &root_table); + let original = match original { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + let codeword = match codeword { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + original + .iter() + .zip(codeword.iter()) + .enumerate() + .for_each(|(i, (a, b))| { + assert_eq!(a, b, "{}: failed for i = {}", message, i); + }); + } +} diff --git a/mpcs/src/basefold/encoding/utils.rs b/mpcs/src/basefold/encoding/utils.rs new file mode 100644 index 000000000..36137526a --- /dev/null +++ b/mpcs/src/basefold/encoding/utils.rs @@ -0,0 +1,35 @@ +#[macro_export] +macro_rules! 
vec_mut { + ($a:ident, |$tmp_a:ident| $op:expr) => { + match $a { + multilinear_extensions::mle::FieldType::Base(ref mut $tmp_a) => $op, + multilinear_extensions::mle::FieldType::Ext(ref mut $tmp_a) => $op, + _ => unreachable!(), + } + }; + (|$a:ident| $op:expr) => { + vec_mut!($a, |$a| $op) + }; +} + +#[macro_export] +macro_rules! vec_map { + ($a:ident, |$tmp_a:ident| $op:expr) => { + match &$a { + multilinear_extensions::mle::FieldType::Base(a) => { + let $tmp_a = &a[..]; + let out = $op; + multilinear_extensions::mle::FieldType::Base(out) + } + multilinear_extensions::mle::FieldType::Ext(a) => { + let $tmp_a = &a[..]; + let out = $op; + multilinear_extensions::mle::FieldType::Base(out) + } + _ => unreachable!(), + } + }; + (|$a:ident| $op:expr) => { + vec_map!($a, |$a| $op) + }; +} diff --git a/mpcs/src/basefold/query_phase.rs b/mpcs/src/basefold/query_phase.rs index dacd254a1..57d41db68 100644 --- a/mpcs/src/basefold/query_phase.rs +++ b/mpcs/src/basefold/query_phase.rs @@ -1,37 +1,35 @@ -use super::basecode::{encode_rs_basecode, query_point}; use crate::util::{ arithmetic::{ - degree_2_eval, degree_2_zero_plus_one, inner_product, interpolate2, + degree_2_eval, degree_2_zero_plus_one, inner_product, interpolate2_weights, interpolate_over_boolean_hypercube, }, - ext_to_usize, + ext_to_usize, field_type_index_base, field_type_index_ext, hash::{Digest, Hasher}, log2_strict, merkle_tree::{MerklePathWithoutLeafOrRoot, MerkleTree}, transcript::{TranscriptRead, TranscriptWrite}, }; -use aes::cipher::KeyIvInit; use ark_std::{end_timer, start_timer}; use core::fmt::Debug; -use ctr; use ff_ext::ExtensionField; -use generic_array::GenericArray; use itertools::Itertools; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use multilinear_extensions::mle::FieldType; -use crate::util::plonky2_util::{reverse_bits, reverse_index_bits_in_place}; -use rand_chacha::{rand_core::RngCore, ChaCha8Rng}; +use crate::util::plonky2_util::reverse_index_bits_in_place; use 
rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; -use super::structure::{BasefoldCommitment, BasefoldCommitmentWithData}; +use super::{ + encoding::EncodingScheme, + structure::{BasefoldCommitment, BasefoldCommitmentWithData, BasefoldSpec}, +}; -pub fn query_phase( +pub fn prover_query_phase( transcript: &mut impl TranscriptWrite, E>, comm: &BasefoldCommitmentWithData, - oracles: &Vec>, + oracles: &[Vec], num_verifier_queries: usize, ) -> QueriesResult where @@ -51,16 +49,298 @@ where .map(|x_index| { ( *x_index, - basefold_get_query::(comm.get_codeword(), &oracles, *x_index), + basefold_get_query::(&comm.get_codewords()[0], oracles, *x_index), ) }) .collect(), } } +pub fn batch_prover_query_phase( + transcript: &mut impl TranscriptWrite, E>, + codeword_size: usize, + comms: &[BasefoldCommitmentWithData], + oracles: &[Vec], + num_verifier_queries: usize, +) -> BatchedQueriesResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + let queries = transcript.squeeze_challenges(num_verifier_queries); + + // Transform the challenge queries from field elements into integers + let queries_usize: Vec = queries + .iter() + .map(|x_index| ext_to_usize(x_index) % codeword_size) + .collect_vec(); + + BatchedQueriesResult { + inner: queries_usize + .par_iter() + .map(|x_index| { + ( + *x_index, + batch_basefold_get_query::(comms, oracles, codeword_size, *x_index), + ) + }) + .collect(), + } +} + +pub fn simple_batch_prover_query_phase( + transcript: &mut impl TranscriptWrite, E>, + comm: &BasefoldCommitmentWithData, + oracles: &[Vec], + num_verifier_queries: usize, +) -> SimpleBatchQueriesResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + let queries = transcript.squeeze_challenges(num_verifier_queries); + + // Transform the challenge queries from field elements into integers + let queries_usize: Vec = queries + .iter() + .map(|x_index| ext_to_usize(x_index) % comm.codeword_size()) + .collect_vec(); + + SimpleBatchQueriesResult { + inner: 
queries_usize + .par_iter() + .map(|x_index| { + ( + *x_index, + simple_batch_basefold_get_query::(comm.get_codewords(), oracles, *x_index), + ) + }) + .collect(), + } +} + +#[allow(clippy::too_many_arguments)] +pub fn verifier_query_phase>( + vp: &>::VerifierParameters, + queries: &QueriesResultWithMerklePath, + sum_check_messages: &[Vec], + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_message: &[E], + roots: &Vec>, + comm: &BasefoldCommitment, + partial_eq: &[E], + eval: &E, + hasher: &Hasher, +) where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Verifier query phase"); + + let encode_timer = start_timer!(|| "Encode final codeword"); + let mut message = final_message.to_vec(); + interpolate_over_boolean_hypercube(&mut message); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut message); + } + let final_codeword = + >::encode_small(vp, &FieldType::Ext(message)); + let mut final_codeword = match final_codeword { + FieldType::Ext(final_codeword) => final_codeword, + _ => panic!("Final codeword must be extension field"), + }; + reverse_index_bits_in_place(&mut final_codeword); + end_timer!(encode_timer); + + queries.check::( + vp, + fold_challenges, + num_rounds, + num_vars, + &final_codeword, + roots, + comm, + hasher, + ); + + let final_timer = start_timer!(|| "Final checks"); + assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); + + // The sum-check part of the protocol + for i in 0..fold_challenges.len() - 1 { + assert_eq!( + degree_2_eval(&sum_check_messages[i], fold_challenges[i]), + degree_2_zero_plus_one(&sum_check_messages[i + 1]) + ); + } + + // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial + // sent from the prover + assert_eq!( + degree_2_eval( + &sum_check_messages[fold_challenges.len() - 1], + fold_challenges[fold_challenges.len() - 1] + ), + inner_product(final_message, partial_eq) + ); + 
end_timer!(final_timer); + + end_timer!(timer); +} + +#[allow(clippy::too_many_arguments)] +pub fn batch_verifier_query_phase>( + vp: &>::VerifierParameters, + queries: &BatchedQueriesResultWithMerklePath, + sum_check_messages: &[Vec], + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_message: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], + coeffs: &[E], + partial_eq: &[E], + eval: &E, + hasher: &Hasher, +) where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Verifier batch query phase"); + let encode_timer = start_timer!(|| "Encode final codeword"); + let mut message = final_message.to_vec(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut message); + } + interpolate_over_boolean_hypercube(&mut message); + let final_codeword = + >::encode_small(vp, &FieldType::Ext(message)); + let mut final_codeword = match final_codeword { + FieldType::Ext(final_codeword) => final_codeword, + _ => panic!("Final codeword must be extension field"), + }; + reverse_index_bits_in_place(&mut final_codeword); + end_timer!(encode_timer); + + // For computing the weights on the fly, because the verifier is incapable of storing + // the weights. 
+ + queries.check::( + vp, + fold_challenges, + num_rounds, + num_vars, + &final_codeword, + roots, + comms, + coeffs, + hasher, + ); + + #[allow(unused)] + let final_timer = start_timer!(|| "Final checks"); + assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); + + // The sum-check part of the protocol + for i in 0..fold_challenges.len() - 1 { + assert_eq!( + degree_2_eval(&sum_check_messages[i], fold_challenges[i]), + degree_2_zero_plus_one(&sum_check_messages[i + 1]) + ); + } + + // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial + // sent from the prover + assert_eq!( + degree_2_eval( + &sum_check_messages[fold_challenges.len() - 1], + fold_challenges[fold_challenges.len() - 1] + ), + inner_product(final_message, partial_eq) + ); + end_timer!(final_timer); + end_timer!(timer); +} + +#[allow(clippy::too_many_arguments)] +pub fn simple_batch_verifier_query_phase>( + vp: &>::VerifierParameters, + queries: &SimpleBatchQueriesResultWithMerklePath, + sum_check_messages: &[Vec], + fold_challenges: &[E], + batch_coeffs: &[E], + num_rounds: usize, + num_vars: usize, + final_message: &[E], + roots: &[Digest], + comm: &BasefoldCommitment, + partial_eq: &[E], + evals: &[E], + hasher: &Hasher, +) where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Verifier query phase"); + + let encode_timer = start_timer!(|| "Encode final codeword"); + let mut message = final_message.to_vec(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut message); + } + interpolate_over_boolean_hypercube(&mut message); + let final_codeword = + >::encode_small(vp, &FieldType::Ext(message)); + let mut final_codeword = match final_codeword { + FieldType::Ext(final_codeword) => final_codeword, + _ => panic!("Final codeword must be extension field"), + }; + reverse_index_bits_in_place(&mut final_codeword); + end_timer!(encode_timer); + + // For computing the weights on the fly, because the 
verifier is incapable of storing + // the weights. + queries.check::( + vp, + fold_challenges, + batch_coeffs, + num_rounds, + num_vars, + &final_codeword, + roots, + comm, + hasher, + ); + + let final_timer = start_timer!(|| "Final checks"); + assert_eq!( + &inner_product(batch_coeffs, evals), + °ree_2_zero_plus_one(&sum_check_messages[0]) + ); + + // The sum-check part of the protocol + for i in 0..fold_challenges.len() - 1 { + assert_eq!( + degree_2_eval(&sum_check_messages[i], fold_challenges[i]), + degree_2_zero_plus_one(&sum_check_messages[i + 1]) + ); + } + + // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial + // sent from the prover + assert_eq!( + degree_2_eval( + &sum_check_messages[fold_challenges.len() - 1], + fold_challenges[fold_challenges.len() - 1] + ), + inner_product(final_message, partial_eq) + ); + end_timer!(final_timer); + + end_timer!(timer); +} + fn basefold_get_query( poly_codeword: &FieldType, - oracles: &Vec>, + oracles: &[Vec], x_index: usize, ) -> SingleQueryResult where @@ -96,15 +376,15 @@ where inner: oracle_queries, }; - return SingleQueryResult { + SingleQueryResult { oracle_query, commitment_query, - }; + } } fn batch_basefold_get_query( - comms: &[&BasefoldCommitmentWithData], - oracles: &Vec>, + comms: &[BasefoldCommitmentWithData], + oracles: &[Vec], codeword_size: usize, x_index: usize, ) -> BatchedSingleQueryResult @@ -133,7 +413,7 @@ where let x_index = x_index >> (log2_strict(codeword_size) - comm.codeword_size_log()); let p1 = x_index | 1; let p0 = p1 - 1; - match comm.get_codeword() { + match &comm.get_codewords()[0] { FieldType::Ext(poly_codeword) => { CodewordSingleQueryResult::new_ext(poly_codeword[p0], poly_codeword[p1], p0) } @@ -155,6 +435,66 @@ where } } +fn simple_batch_basefold_get_query( + poly_codewords: &[FieldType], + oracles: &[Vec], + x_index: usize, +) -> SimpleBatchSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + let mut index = 
x_index; + let p1 = index | 1; + let p0 = p1 - 1; + + let commitment_query = match poly_codewords[0] { + FieldType::Ext(_) => SimpleBatchCommitmentSingleQueryResult::new_ext( + poly_codewords + .iter() + .map(|c| field_type_index_ext(c, p0)) + .collect(), + poly_codewords + .iter() + .map(|c| field_type_index_ext(c, p1)) + .collect(), + p0, + ), + FieldType::Base(_) => SimpleBatchCommitmentSingleQueryResult::new_base( + poly_codewords + .iter() + .map(|c| field_type_index_base(c, p0)) + .collect(), + poly_codewords + .iter() + .map(|c| field_type_index_base(c, p1)) + .collect(), + p0, + ), + _ => unreachable!(), + }; + index >>= 1; + + let mut oracle_queries = Vec::with_capacity(oracles.len() + 1); + for oracle in oracles { + let p1 = index | 1; + let p0 = p1 - 1; + + oracle_queries.push(CodewordSingleQueryResult::new_ext( + oracle[p0], oracle[p1], p0, + )); + index >>= 1; + } + + let oracle_query = OracleListQueryResult { + inner: oracle_queries, + }; + + SimpleBatchSingleQueryResult { + oracle_query, + commitment_query, + } +} + #[derive(Debug, Copy, Clone, Serialize, Deserialize)] enum CodewordPointPair { Ext(E, E), @@ -170,6 +510,51 @@ impl CodewordPointPair { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +enum SimpleBatchLeavesPair +where + E::BaseField: Serialize + DeserializeOwned, +{ + Ext(Vec<(E, E)>), + Base(Vec<(E::BaseField, E::BaseField)>), +} + +impl SimpleBatchLeavesPair +where + E::BaseField: Serialize + DeserializeOwned, +{ + #[allow(unused)] + pub fn as_ext(&self) -> Vec<(E, E)> { + match self { + SimpleBatchLeavesPair::Ext(x) => x.clone(), + SimpleBatchLeavesPair::Base(x) => { + x.iter().map(|(x, y)| ((*x).into(), (*y).into())).collect() + } + } + } + + pub fn batch(&self, coeffs: &[E]) -> (E, E) { + match self { + SimpleBatchLeavesPair::Ext(x) => { + let mut result = (E::ZERO, E::ZERO); + for (i, (x, y)) in x.iter().enumerate() { + result.0 += coeffs[i] * *x; + result.1 += coeffs[i] * *y; + } + result + } + SimpleBatchLeavesPair::Base(x) 
=> { + let mut result = (E::ZERO, E::ZERO); + for (i, (x, y)) in x.iter().enumerate() { + result.0 += coeffs[i] * *x; + result.1 += coeffs[i] * *y; + } + result + } + } + } +} + #[derive(Debug, Copy, Clone, Serialize, Deserialize)] struct CodewordSingleQueryResult where @@ -510,12 +895,9 @@ where ) -> Vec> { let ret = self .get_inner() - .into_iter() + .iter() .enumerate() - .map(|(i, query_result)| { - let path = path(i, query_result.index); - path - }) + .map(|(i, query_result)| path(i, query_result.index)) .collect_vec(); ret } @@ -537,7 +919,7 @@ where query_result .merkle_path(path) .into_iter() - .zip(query_result.get_inner_into().into_iter()) + .zip(query_result.get_inner_into()) .map( |(path, codeword_result)| CodewordSingleQueryResultWithMerklePath { query: codeword_result, @@ -554,7 +936,7 @@ where .for_each(|q| q.write_transcript(transcript)); } - fn check_merkle_paths(&self, roots: &Vec>, hasher: &Hasher) { + fn check_merkle_paths(&self, roots: &[Digest], hasher: &Hasher) { // let timer = start_timer!(|| "ListQuery::Check Merkle Path"); self.get_inner() .iter() @@ -590,16 +972,17 @@ where { pub fn from_single_query_result( single_query_result: SingleQueryResult, - oracle_trees: &Vec>, + oracle_trees: &[MerkleTree], commitment: &BasefoldCommitmentWithData, ) -> Self { + assert!(commitment.codeword_tree.height() > 0); Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( single_query_result.oracle_query, |i, j| oracle_trees[i].merkle_path_without_leaf_sibling_or_root(j), ), commitment_query: CodewordSingleQueryResultWithMerklePath { - query: single_query_result.commitment_query.clone(), + query: single_query_result.commitment_query, merkle_path: commitment .codeword_tree .merkle_path_without_leaf_sibling_or_root( @@ -660,23 +1043,23 @@ where } } - pub fn check( + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + vp: &>::VerifierParameters, + fold_challenges: &[E], num_rounds: usize, num_vars: 
usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, + final_codeword: &[E], + roots: &[Digest], comm: &BasefoldCommitment, - mut cipher: ctr::Ctr32LE, index: usize, hasher: &Hasher, ) { // let timer = start_timer!(|| "Checking codeword single query"); self.oracle_query.check_merkle_paths(roots, hasher); self.commitment_query - .check_merkle_path(&Digest(comm.root().0.try_into().unwrap()), hasher); + .check_merkle_path(&Digest(comm.root().0), hasher); let (mut curr_left, mut curr_right) = self.commitment_query.query.codepoints.as_ext(); @@ -684,18 +1067,14 @@ where let mut left_index = right_index - 1; for i in 0..num_rounds { - // let round_timer = start_timer!(|| format!("SingleQueryResult::round {}", i)); - let ri0 = reverse_bits(left_index, num_vars + log_rate - i); - - let x0 = E::from(query_point::( - 1 << (num_vars + log_rate - i), - ri0, - num_vars + log_rate - i - 1, - &mut cipher, - )); - let x1 = -x0; + let (x0, x1, w) = >::verifier_folding_coeffs( + vp, + num_vars + Spec::get_rate_log() - i - 1, + left_index >> 1, + ); - let res = interpolate2([(x0, curr_left), (x1, curr_right)], fold_challenges[i]); + let res = + interpolate2_weights([(x0, curr_left), (x1, curr_right)], w, fold_challenges[i]); let next_index = right_index >> 1; let next_oracle_value = if i < num_rounds - 1 { @@ -720,6 +1099,135 @@ where } } +pub struct QueriesResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + inner: Vec<(usize, SingleQueryResult)>, +} + +pub struct QueriesResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + inner: Vec<(usize, SingleQueryResultWithMerklePath)>, +} + +impl QueriesResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub fn from_query_result( + query_result: QueriesResult, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithData, + ) -> Self { + Self { + inner: query_result + .inner + .into_iter() + .map(|(i, q)| { + ( + i, + 
SingleQueryResultWithMerklePath::from_single_query_result( + q, + oracle_trees, + commitment, + ), + ) + }) + .collect(), + } + } + + pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + self.inner.iter().for_each(|(_, q)| { + q.write_transcript(transcript); + }); + } + + pub fn read_transcript_base( + transcript: &mut impl TranscriptRead, E>, + num_rounds: usize, + log_rate: usize, + poly_num_vars: usize, + indices: &[usize], + ) -> Self { + Self { + inner: indices + .iter() + .map(|index| { + ( + *index, + SingleQueryResultWithMerklePath::read_transcript_base( + transcript, + num_rounds, + log_rate, + poly_num_vars, + *index, + ), + ) + }) + .collect(), + } + } + + pub fn read_transcript_ext( + transcript: &mut impl TranscriptRead, E>, + num_rounds: usize, + log_rate: usize, + poly_num_vars: usize, + indices: &[usize], + ) -> Self { + Self { + inner: indices + .iter() + .map(|index| { + ( + *index, + SingleQueryResultWithMerklePath::read_transcript_ext( + transcript, + num_rounds, + log_rate, + poly_num_vars, + *index, + ), + ) + }) + .collect(), + } + } + + #[allow(clippy::too_many_arguments)] + pub fn check>( + &self, + vp: &>::VerifierParameters, + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_codeword: &[E], + roots: &[Digest], + comm: &BasefoldCommitment, + hasher: &Hasher, + ) { + let timer = start_timer!(|| "QueriesResult::check"); + self.inner.par_iter().for_each(|(index, query)| { + query.check::( + vp, + fold_challenges, + num_rounds, + num_vars, + final_codeword, + roots, + comm, + *index, + hasher, + ); + }); + end_timer!(timer); + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] struct BatchedSingleQueryResult where @@ -744,8 +1252,8 @@ where { pub fn from_batched_single_query_result( batched_single_query_result: BatchedSingleQueryResult, - oracle_trees: &Vec>, - commitments: &Vec<&BasefoldCommitmentWithData>, + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithData], ) -> Self 
{ Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( @@ -818,23 +1326,29 @@ where } } - pub fn check( + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + vp: &>::VerifierParameters, + fold_challenges: &[E], num_rounds: usize, num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, + final_codeword: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], coeffs: &[E], - mut cipher: ctr::Ctr32LE, index: usize, hasher: &Hasher, ) { self.oracle_query.check_merkle_paths(roots, hasher); - self.commitments_query - .check_merkle_paths(&comms.iter().map(|comm| comm.root()).collect(), hasher); + self.commitments_query.check_merkle_paths( + comms + .iter() + .map(|comm| comm.root()) + .collect_vec() + .as_slice(), + hasher, + ); // end_timer!(commit_timer); let mut curr_left = E::ZERO; @@ -845,7 +1359,6 @@ where for i in 0..num_rounds { // let round_timer = start_timer!(|| format!("BatchedSingleQueryResult::round {}", i)); - let ri0 = reverse_bits(left_index, num_vars + log_rate - i); let matching_comms = comms .iter() .enumerate() @@ -854,21 +1367,20 @@ where .collect_vec(); matching_comms.iter().for_each(|index| { - let query = self.commitments_query.get_inner()[*index].query.clone(); + let query = self.commitments_query.get_inner()[*index].query; assert_eq!(query.index >> 1, left_index >> 1); curr_left += query.left_ext() * coeffs[*index]; curr_right += query.right_ext() * coeffs[*index]; }); - let x0: E = E::from(query_point::( - 1 << (num_vars + log_rate - i), - ri0, - num_vars + log_rate - i - 1, - &mut cipher, - )); - let x1 = -x0; + let (x0, x1, w) = >::verifier_folding_coeffs( + vp, + num_vars + Spec::get_rate_log() - i - 1, + left_index >> 1, + ); - let mut res = interpolate2([(x0, curr_left), (x1, curr_right)], fold_challenges[i]); + let mut res = + interpolate2_weights([(x0, curr_left), (x1, curr_right)], w, fold_challenges[i]); let 
next_index = right_index >> 1; @@ -901,7 +1413,7 @@ where matching_comms.iter().for_each(|index| { let query: CodewordSingleQueryResult = - self.commitments_query.get_inner()[*index].query.clone(); + self.commitments_query.get_inner()[*index].query; assert_eq!(query.index >> 1, next_index >> 1); if next_index & 1 == 0 { res += query.left_ext() * coeffs[*index]; @@ -941,8 +1453,8 @@ where { pub fn from_batched_query_result( batched_query_result: BatchedQueriesResult, - oracle_trees: &Vec>, - commitments: &Vec<&BasefoldCommitmentWithData>, + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithData], ) -> Self { Self { inner: batched_query_result @@ -963,34 +1475,355 @@ where } pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { - self.inner - .iter() - .for_each(|(_, q)| q.write_transcript(transcript)); + self.inner + .iter() + .for_each(|(_, q)| q.write_transcript(transcript)); + } + + pub fn read_transcript_base( + transcript: &mut impl TranscriptRead, E>, + num_rounds: usize, + log_rate: usize, + poly_num_vars: &[usize], + indices: &[usize], + ) -> Self { + Self { + inner: indices + .iter() + .map(|index| { + ( + *index, + BatchedSingleQueryResultWithMerklePath::read_transcript_base( + transcript, + num_rounds, + log_rate, + poly_num_vars, + *index, + ), + ) + }) + .collect(), + } + } + + pub fn read_transcript_ext( + transcript: &mut impl TranscriptRead, E>, + num_rounds: usize, + log_rate: usize, + poly_num_vars: &[usize], + indices: &[usize], + ) -> Self { + Self { + inner: indices + .iter() + .map(|index| { + ( + *index, + BatchedSingleQueryResultWithMerklePath::read_transcript_ext( + transcript, + num_rounds, + log_rate, + poly_num_vars, + *index, + ), + ) + }) + .collect(), + } + } + + #[allow(clippy::too_many_arguments)] + pub fn check>( + &self, + vp: &>::VerifierParameters, + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_codeword: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], + 
coeffs: &[E], + hasher: &Hasher, + ) { + let timer = start_timer!(|| "BatchedQueriesResult::check"); + self.inner.par_iter().for_each(|(index, query)| { + query.check::( + vp, + fold_challenges, + num_rounds, + num_vars, + final_codeword, + roots, + comms, + coeffs, + *index, + hasher, + ); + }); + end_timer!(timer); + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleBatchCommitmentSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + leaves: SimpleBatchLeavesPair, + index: usize, +} + +impl SimpleBatchCommitmentSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + fn new_ext(left: Vec, right: Vec, index: usize) -> Self { + Self { + leaves: SimpleBatchLeavesPair::Ext(left.into_iter().zip(right).collect()), + index, + } + } + + fn new_base(left: Vec, right: Vec, index: usize) -> Self { + Self { + leaves: SimpleBatchLeavesPair::Base(left.into_iter().zip(right).collect()), + index, + } + } + + #[allow(unused)] + fn left_ext(&self) -> Vec { + match &self.leaves { + SimpleBatchLeavesPair::Ext(x) => x.iter().map(|(x, _)| *x).collect(), + SimpleBatchLeavesPair::Base(x) => x.iter().map(|(x, _)| E::from(*x)).collect(), + } + } + + #[allow(unused)] + fn right_ext(&self) -> Vec { + match &self.leaves { + SimpleBatchLeavesPair::Ext(x) => x.iter().map(|(_, x)| *x).collect(), + SimpleBatchLeavesPair::Base(x) => x.iter().map(|(_, x)| E::from(*x)).collect(), + } + } + + pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + match &self.leaves { + SimpleBatchLeavesPair::Ext(x) => { + x.iter().for_each(|(x, y)| { + transcript.write_field_element_ext(x).unwrap(); + transcript.write_field_element_ext(y).unwrap(); + }); + } + SimpleBatchLeavesPair::Base(x) => { + x.iter().for_each(|(x, y)| { + transcript.write_field_element_base(x).unwrap(); + transcript.write_field_element_base(y).unwrap(); + }); + } + }; + } + + pub fn read_transcript_ext( + transcript: &mut impl TranscriptRead, E>, + 
full_codeword_size_log: usize, + codeword_size_log: usize, + index: usize, + batch_size: usize, + ) -> Self { + let mut left = vec![]; + let mut right = vec![]; + (0..batch_size).for_each(|_| { + left.push(transcript.read_field_element_ext().unwrap()); + right.push(transcript.read_field_element_ext().unwrap()); + }); + Self::new_ext( + left, + right, + index >> (full_codeword_size_log - codeword_size_log), + ) + } + + pub fn read_transcript_base( + transcript: &mut impl TranscriptRead, E>, + full_codeword_size_log: usize, + codeword_size_log: usize, + index: usize, + batch_size: usize, + ) -> Self { + let mut left = vec![]; + let mut right = vec![]; + (0..batch_size).for_each(|_| { + left.push(transcript.read_field_element_base().unwrap()); + right.push(transcript.read_field_element_base().unwrap()); + }); + Self::new_base( + left, + right, + index >> (full_codeword_size_log - codeword_size_log), + ) + } +} + +#[derive(Debug, Clone)] +struct SimpleBatchCommitmentSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + query: SimpleBatchCommitmentSingleQueryResult, + merkle_path: MerklePathWithoutLeafOrRoot, +} + +impl SimpleBatchCommitmentSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + self.query.write_transcript(transcript); + self.merkle_path.write_transcript(transcript); + } + + pub fn read_transcript_base( + transcript: &mut impl TranscriptRead, E>, + full_codeword_size_log: usize, + codeword_size_log: usize, + index: usize, + batch_size: usize, + ) -> Self { + Self { + query: SimpleBatchCommitmentSingleQueryResult::read_transcript_base( + transcript, + full_codeword_size_log, + codeword_size_log, + index, + batch_size, + ), + merkle_path: MerklePathWithoutLeafOrRoot::read_transcript( + transcript, + codeword_size_log, + ), + } + } + + pub fn read_transcript_ext( + transcript: &mut impl TranscriptRead, E>, + 
full_codeword_size_log: usize, + codeword_size_log: usize, + index: usize, + batch_size: usize, + ) -> Self { + Self { + query: SimpleBatchCommitmentSingleQueryResult::read_transcript_ext( + transcript, + full_codeword_size_log, + codeword_size_log, + index, + batch_size, + ), + merkle_path: MerklePathWithoutLeafOrRoot::read_transcript( + transcript, + codeword_size_log, + ), + } + } + + pub fn check_merkle_path(&self, root: &Digest, hasher: &Hasher) { + // let timer = start_timer!(|| "CodewordSingleQuery::Check Merkle Path"); + match &self.query.leaves { + SimpleBatchLeavesPair::Ext(inner) => { + self.merkle_path.authenticate_batch_leaves_root_ext( + inner.iter().map(|(x, _)| *x).collect(), + inner.iter().map(|(_, x)| *x).collect(), + self.query.index, + root, + hasher, + ); + } + SimpleBatchLeavesPair::Base(inner) => { + self.merkle_path.authenticate_batch_leaves_root_base( + inner.iter().map(|(x, _)| *x).collect(), + inner.iter().map(|(_, x)| *x).collect(), + self.query.index, + root, + hasher, + ); + } + } + // end_timer!(timer); + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleBatchSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + oracle_query: OracleListQueryResult, + commitment_query: SimpleBatchCommitmentSingleQueryResult, +} + +#[derive(Debug, Clone)] +struct SimpleBatchSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + oracle_query: OracleListQueryResultWithMerklePath, + commitment_query: SimpleBatchCommitmentSingleQueryResultWithMerklePath, +} + +impl SimpleBatchSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub fn from_single_query_result( + single_query_result: SimpleBatchSingleQueryResult, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithData, + ) -> Self { + Self { + oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( + single_query_result.oracle_query, + |i, j| 
oracle_trees[i].merkle_path_without_leaf_sibling_or_root(j), + ), + commitment_query: SimpleBatchCommitmentSingleQueryResultWithMerklePath { + query: single_query_result.commitment_query.clone(), + merkle_path: commitment + .codeword_tree + .merkle_path_without_leaf_sibling_or_root( + single_query_result.commitment_query.index, + ), + }, + } + } + + pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + self.oracle_query.write_transcript(transcript); + self.commitment_query.write_transcript(transcript); } pub fn read_transcript_base( transcript: &mut impl TranscriptRead, E>, num_rounds: usize, log_rate: usize, - poly_num_vars: &[usize], - indices: &[usize], + num_vars: usize, + index: usize, + batch_size: usize, ) -> Self { Self { - inner: indices - .iter() - .map(|index| { - ( - *index, - BatchedSingleQueryResultWithMerklePath::read_transcript_base( - transcript, - num_rounds, - log_rate, - poly_num_vars, - *index, - ), - ) - }) - .collect(), + oracle_query: OracleListQueryResultWithMerklePath::read_transcript( + transcript, + num_rounds, + num_vars + log_rate, + index, + ), + commitment_query: + SimpleBatchCommitmentSingleQueryResultWithMerklePath::read_transcript_base( + transcript, + num_vars + log_rate, + num_vars + log_rate, + index, + batch_size, + ), } } @@ -998,82 +1831,109 @@ where transcript: &mut impl TranscriptRead, E>, num_rounds: usize, log_rate: usize, - poly_num_vars: &[usize], - indices: &[usize], + num_vars: usize, + index: usize, + batch_size: usize, ) -> Self { Self { - inner: indices - .iter() - .map(|index| { - ( - *index, - BatchedSingleQueryResultWithMerklePath::read_transcript_ext( - transcript, - num_rounds, - log_rate, - poly_num_vars, - *index, - ), - ) - }) - .collect(), + oracle_query: OracleListQueryResultWithMerklePath::read_transcript( + transcript, + num_rounds, + num_vars + log_rate, + index, + ), + commitment_query: + SimpleBatchCommitmentSingleQueryResultWithMerklePath::read_transcript_ext( + transcript, 
+ num_vars + log_rate, + num_vars + log_rate, + index, + batch_size, + ), } } - pub fn check( + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + vp: &>::VerifierParameters, + fold_challenges: &[E], + batch_coeffs: &[E], num_rounds: usize, num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, - coeffs: &[E], - cipher: ctr::Ctr32LE, + final_codeword: &[E], + roots: &[Digest], + comm: &BasefoldCommitment, + index: usize, hasher: &Hasher, ) { - let timer = start_timer!(|| "BatchedQueriesResult::check"); - self.inner.par_iter().for_each(|(index, query)| { - query.check( - fold_challenges, - num_rounds, - num_vars, - log_rate, - final_codeword, - roots, - comms, - coeffs, - cipher.clone(), - *index, - hasher, + let timer = start_timer!(|| "Checking codeword single query"); + self.oracle_query.check_merkle_paths(roots, hasher); + self.commitment_query + .check_merkle_path(&Digest(comm.root().0), hasher); + + let (mut curr_left, mut curr_right) = + self.commitment_query.query.leaves.batch(batch_coeffs); + + let mut right_index = index | 1; + let mut left_index = right_index - 1; + + for i in 0..num_rounds { + // let round_timer = start_timer!(|| format!("SingleQueryResult::round {}", i)); + + let (x0, x1, w) = >::verifier_folding_coeffs( + vp, + num_vars + Spec::get_rate_log() - i - 1, + left_index >> 1, ); - }); + + let res = + interpolate2_weights([(x0, curr_left), (x1, curr_right)], w, fold_challenges[i]); + + let next_index = right_index >> 1; + let next_oracle_value = if i < num_rounds - 1 { + right_index = next_index | 1; + left_index = right_index - 1; + let next_oracle_query = self.oracle_query.get_inner()[i].clone(); + (curr_left, curr_right) = next_oracle_query.query.codepoints.as_ext(); + if next_index & 1 == 0 { + curr_left + } else { + curr_right + } + } else { + // Note that final_codeword has been bit-reversed, so no need to bit-reverse + // next_index here. 
+ final_codeword[next_index] + }; + assert_eq!(res, next_oracle_value, "Failed at round {}", i); + // end_timer!(round_timer); + } end_timer!(timer); } } -pub struct QueriesResult +pub struct SimpleBatchQueriesResult where E::BaseField: Serialize + DeserializeOwned, { - inner: Vec<(usize, SingleQueryResult)>, + inner: Vec<(usize, SimpleBatchSingleQueryResult)>, } -pub struct QueriesResultWithMerklePath +pub struct SimpleBatchQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, { - inner: Vec<(usize, SingleQueryResultWithMerklePath)>, + inner: Vec<(usize, SimpleBatchSingleQueryResultWithMerklePath)>, } -impl QueriesResultWithMerklePath +impl SimpleBatchQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, { pub fn from_query_result( - query_result: QueriesResult, - oracle_trees: &Vec>, + query_result: SimpleBatchQueriesResult, + oracle_trees: &[MerkleTree], commitment: &BasefoldCommitmentWithData, ) -> Self { Self { @@ -1083,7 +1943,7 @@ where .map(|(i, q)| { ( i, - SingleQueryResultWithMerklePath::from_single_query_result( + SimpleBatchSingleQueryResultWithMerklePath::from_single_query_result( q, oracle_trees, commitment, @@ -1106,6 +1966,7 @@ where log_rate: usize, poly_num_vars: usize, indices: &[usize], + batch_size: usize, ) -> Self { Self { inner: indices @@ -1113,12 +1974,13 @@ where .map(|index| { ( *index, - SingleQueryResultWithMerklePath::read_transcript_base( + SimpleBatchSingleQueryResultWithMerklePath::read_transcript_base( transcript, num_rounds, log_rate, poly_num_vars, *index, + batch_size, ), ) }) @@ -1132,6 +1994,7 @@ where log_rate: usize, poly_num_vars: usize, indices: &[usize], + batch_size: usize, ) -> Self { Self { inner: indices @@ -1139,12 +2002,13 @@ where .map(|index| { ( *index, - SingleQueryResultWithMerklePath::read_transcript_ext( + SimpleBatchSingleQueryResultWithMerklePath::read_transcript_ext( transcript, num_rounds, log_rate, poly_num_vars, *index, + batch_size, ), ) }) @@ -1152,29 
+2016,30 @@ where } } - pub fn check( + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + vp: &>::VerifierParameters, + fold_challenges: &[E], + batch_coeffs: &[E], num_rounds: usize, num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, + final_codeword: &[E], + roots: &[Digest], comm: &BasefoldCommitment, - cipher: ctr::Ctr32LE, hasher: &Hasher, ) { let timer = start_timer!(|| "QueriesResult::check"); self.inner.par_iter().for_each(|(index, query)| { - query.check( + query.check::( + vp, fold_challenges, + batch_coeffs, num_rounds, num_vars, - log_rate, final_codeword, roots, comm, - cipher.clone(), *index, hasher, ); @@ -1182,199 +2047,3 @@ where end_timer!(timer); } } - -pub fn batch_query_phase( - transcript: &mut impl TranscriptWrite, E>, - codeword_size: usize, - comms: &[&BasefoldCommitmentWithData], - oracles: &Vec>, - num_verifier_queries: usize, -) -> BatchedQueriesResult -where - E::BaseField: Serialize + DeserializeOwned, -{ - let queries = transcript.squeeze_challenges(num_verifier_queries); - - // Transform the challenge queries from field elements into integers - let queries_usize: Vec = queries - .iter() - .map(|x_index| ext_to_usize(x_index) % codeword_size) - .collect_vec(); - - BatchedQueriesResult { - inner: queries_usize - .par_iter() - .map(|x_index| { - ( - *x_index, - batch_basefold_get_query::(comms, &oracles, codeword_size, *x_index), - ) - }) - .collect(), - } -} - -pub fn verifier_query_phase( - queries: &QueriesResultWithMerklePath, - sum_check_messages: &Vec>, - fold_challenges: &Vec, - num_rounds: usize, - num_vars: usize, - log_rate: usize, - final_message: &Vec, - roots: &Vec>, - comm: &BasefoldCommitment, - partial_eq: &[E], - rng: ChaCha8Rng, - eval: &E, - hasher: &Hasher, -) where - E::BaseField: Serialize + DeserializeOwned, -{ - let timer = start_timer!(|| "Verifier query phase"); - - let encode_timer = start_timer!(|| "Encode final codeword"); - let mut message = 
final_message.clone(); - interpolate_over_boolean_hypercube(&mut message); - let mut final_codeword = encode_rs_basecode(&message, 1 << log_rate, message.len()); - assert_eq!(final_codeword.len(), 1); - let mut final_codeword = final_codeword.remove(0); - reverse_index_bits_in_place(&mut final_codeword); - end_timer!(encode_timer); - - // For computing the weights on the fly, because the verifier is incapable of storing - // the weights. - let aes_timer = start_timer!(|| "Initialize AES"); - let mut key: [u8; 16] = [0u8; 16]; - let mut iv: [u8; 16] = [0u8; 16]; - let mut rng = rng.clone(); - rng.set_word_pos(0); - rng.fill_bytes(&mut key); - rng.fill_bytes(&mut iv); - - type Aes128Ctr64LE = ctr::Ctr32LE; - let cipher = Aes128Ctr64LE::new( - GenericArray::from_slice(&key[..]), - GenericArray::from_slice(&iv[..]), - ); - end_timer!(aes_timer); - - queries.check( - fold_challenges, - num_rounds, - num_vars, - log_rate, - &final_codeword, - roots, - comm, - cipher, - hasher, - ); - - let final_timer = start_timer!(|| "Final checks"); - assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); - - // The sum-check part of the protocol - for i in 0..fold_challenges.len() - 1 { - assert_eq!( - degree_2_eval(&sum_check_messages[i], fold_challenges[i]), - degree_2_zero_plus_one(&sum_check_messages[i + 1]) - ); - } - - // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial - // sent from the prover - assert_eq!( - degree_2_eval( - &sum_check_messages[fold_challenges.len() - 1], - fold_challenges[fold_challenges.len() - 1] - ), - inner_product(final_message, partial_eq) - ); - end_timer!(final_timer); - - end_timer!(timer); -} - -pub fn batch_verifier_query_phase( - queries: &BatchedQueriesResultWithMerklePath, - sum_check_messages: &Vec>, - fold_challenges: &Vec, - num_rounds: usize, - num_vars: usize, - log_rate: usize, - final_message: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, - coeffs: &[E], - partial_eq: &[E], 
- rng: ChaCha8Rng, - eval: &E, - hasher: &Hasher, -) where - E::BaseField: Serialize + DeserializeOwned, -{ - let timer = start_timer!(|| "Verifier batch query phase"); - let encode_timer = start_timer!(|| "Encode final codeword"); - let mut message = final_message.clone(); - interpolate_over_boolean_hypercube(&mut message); - let mut final_codeword = encode_rs_basecode(&message, 1 << log_rate, message.len()); - assert_eq!(final_codeword.len(), 1); - let mut final_codeword = final_codeword.remove(0); - reverse_index_bits_in_place(&mut final_codeword); - end_timer!(encode_timer); - - // For computing the weights on the fly, because the verifier is incapable of storing - // the weights. - let aes_timer = start_timer!(|| "Initialize AES"); - let mut key: [u8; 16] = [0u8; 16]; - let mut iv: [u8; 16] = [0u8; 16]; - let mut rng = rng.clone(); - rng.set_word_pos(0); - rng.fill_bytes(&mut key); - rng.fill_bytes(&mut iv); - - type Aes128Ctr64LE = ctr::Ctr32LE; - let cipher = Aes128Ctr64LE::new( - GenericArray::from_slice(&key[..]), - GenericArray::from_slice(&iv[..]), - ); - end_timer!(aes_timer); - - queries.check( - fold_challenges, - num_rounds, - num_vars, - log_rate, - &final_codeword, - roots, - comms, - coeffs, - cipher, - hasher, - ); - - #[allow(unused)] - let final_timer = start_timer!(|| "Final checks"); - assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); - - // The sum-check part of the protocol - for i in 0..fold_challenges.len() - 1 { - assert_eq!( - degree_2_eval(&sum_check_messages[i], fold_challenges[i]), - degree_2_zero_plus_one(&sum_check_messages[i + 1]) - ); - } - - // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial - // sent from the prover - assert_eq!( - degree_2_eval( - &sum_check_messages[fold_challenges.len() - 1], - fold_challenges[fold_challenges.len() - 1] - ), - inner_product(final_message, partial_eq) - ); - end_timer!(final_timer); - end_timer!(timer); -} diff --git 
a/mpcs/src/basefold/structure.rs b/mpcs/src/basefold/structure.rs index 8dd39e75b..0a1972c38 100644 --- a/mpcs/src/basefold/structure.rs +++ b/mpcs/src/basefold/structure.rs @@ -2,44 +2,51 @@ use crate::util::{hash::Digest, merkle_tree::MerkleTree}; use core::fmt::Debug; use ff_ext::ExtensionField; +use rand::RngCore; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use multilinear_extensions::mle::FieldType; -use rand_chacha::rand_core::RngCore; +use rand_chacha::ChaCha8Rng; use std::{marker::PhantomData, slice}; +pub use super::encoding::{EncodingProverParameters, EncodingScheme, RSCode, RSCodeDefaultSpec}; +use super::{Basecode, BasecodeDefaultSpec}; + #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BasefoldParams +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasefoldParams> where E::BaseField: Serialize + DeserializeOwned, { - pub(super) log_rate: usize, - pub(super) num_verifier_queries: usize, - pub(super) max_num_vars: usize, - pub(super) table_w_weights: Vec>, - pub(super) table: Vec>, - pub(super) rng: Rng, + pub(super) params: >::PublicParameters, } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BasefoldProverParams -where - E::BaseField: Serialize + DeserializeOwned, -{ - pub(super) log_rate: usize, - pub(super) table_w_weights: Vec>, - pub(super) table: Vec>, - pub(super) num_verifier_queries: usize, - pub(super) max_num_vars: usize, +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasefoldProverParams> { + pub encoding_params: >::ProverParameters, +} + +impl> BasefoldProverParams { + pub fn get_max_message_size_log(&self) -> usize { + self.encoding_params.get_max_message_size_log() + } } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BasefoldVerifierParams { - pub(super) rng: Rng, - pub(super) max_num_vars: usize, - pub(super) log_rate: usize, - pub(super) 
num_verifier_queries: usize, +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasefoldVerifierParams> { + pub(super) encoding_params: >::VerifierParameters, } /// A polynomial commitment together with all the data (e.g., the codeword, and Merkle tree) @@ -51,9 +58,10 @@ where E::BaseField: Serialize + DeserializeOwned, { pub(crate) codeword_tree: MerkleTree, - pub(crate) bh_evals: FieldType, + pub(crate) polynomials_bh_evals: Vec>, pub(crate) num_vars: usize, pub(crate) is_base: bool, + pub(crate) num_polys: usize, } impl BasefoldCommitmentWithData @@ -61,7 +69,12 @@ where E::BaseField: Serialize + DeserializeOwned, { pub fn to_commitment(&self) -> BasefoldCommitment { - BasefoldCommitment::new(self.codeword_tree.root(), self.num_vars, self.is_base) + BasefoldCommitment::new( + self.codeword_tree.root(), + self.num_vars, + self.is_base, + self.num_polys, + ) } pub fn get_root_ref(&self) -> &Digest { @@ -72,12 +85,16 @@ where Digest::(self.get_root_ref().0) } - pub fn get_codeword(&self) -> &FieldType { + pub fn get_codewords(&self) -> &Vec> { self.codeword_tree.leaves() } + pub fn batch_codewords(&self, coeffs: &Vec) -> Vec { + self.codeword_tree.batch_leaves(coeffs) + } + pub fn codeword_size(&self) -> usize { - self.codeword_tree.size() + self.codeword_tree.size().1 } pub fn codeword_size_log(&self) -> usize { @@ -85,14 +102,14 @@ where } pub fn poly_size(&self) -> usize { - self.bh_evals.len() + 1 << self.num_vars } - pub fn get_codeword_entry_base(&self, index: usize) -> E::BaseField { + pub fn get_codeword_entry_base(&self, index: usize) -> Vec { self.codeword_tree.get_leaf_as_base(index) } - pub fn get_codeword_entry_ext(&self, index: usize) -> E { + pub fn get_codeword_entry_ext(&self, index: usize) -> Vec { self.codeword_tree.get_leaf_as_extension(index) } @@ -101,21 +118,21 @@ where } } -impl Into> for BasefoldCommitmentWithData +impl From> for Digest where E::BaseField: Serialize + 
DeserializeOwned, { - fn into(self) -> Digest { - self.get_root_as() + fn from(val: BasefoldCommitmentWithData) -> Self { + val.get_root_as() } } -impl Into> for &BasefoldCommitmentWithData +impl From<&BasefoldCommitmentWithData> for BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, { - fn into(self) -> BasefoldCommitment { - self.to_commitment() + fn from(val: &BasefoldCommitmentWithData) -> Self { + val.to_commitment() } } @@ -128,17 +145,24 @@ where pub(super) root: Digest, pub(super) num_vars: Option, pub(super) is_base: bool, + pub(super) num_polys: Option, } impl BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, { - pub fn new(root: Digest, num_vars: usize, is_base: bool) -> Self { + pub fn new( + root: Digest, + num_vars: usize, + is_base: bool, + num_polys: usize, + ) -> Self { Self { root, num_vars: Some(num_vars), is_base, + num_polys: Some(num_polys), } } @@ -153,14 +177,6 @@ where pub fn is_base(&self) -> bool { self.is_base } - - pub fn as_challenge_field(&self) -> BasefoldCommitment { - BasefoldCommitment:: { - root: Digest::(self.root().0), - num_vars: self.num_vars, - is_base: self.is_base, - } - } } impl PartialEq for BasefoldCommitmentWithData @@ -168,7 +184,8 @@ where E::BaseField: Serialize + DeserializeOwned, { fn eq(&self, other: &Self) -> bool { - self.get_codeword().eq(other.get_codeword()) && self.bh_evals.eq(&other.bh_evals) + self.get_codewords().eq(other.get_codewords()) + && self.polynomials_bh_evals.eq(&other.polynomials_bh_evals) } } @@ -177,37 +194,50 @@ impl Eq for BasefoldCommitmentWithData where { } -pub trait BasefoldExtParams: Debug { - fn get_reps() -> usize; +pub trait BasefoldSpec: Debug + Clone { + type EncodingScheme: EncodingScheme; - fn get_rate() -> usize; + fn get_number_queries() -> usize { + Self::EncodingScheme::get_number_queries() + } - fn get_basecode() -> usize; + fn get_rate_log() -> usize { + Self::EncodingScheme::get_rate_log() + } + + fn get_basecode_msg_size_log() -> 
usize { + Self::EncodingScheme::get_basecode_msg_size_log() + } } -#[derive(Debug)] -pub struct BasefoldDefaultParams; +#[derive(Debug, Clone)] +pub struct BasefoldBasecodeParams; -impl BasefoldExtParams for BasefoldDefaultParams { - fn get_reps() -> usize { - return 260; - } +impl BasefoldSpec for BasefoldBasecodeParams +where + E::BaseField: Serialize + DeserializeOwned, +{ + type EncodingScheme = Basecode; +} - fn get_rate() -> usize { - return 3; - } +#[derive(Debug, Clone)] +pub struct BasefoldRSParams; - fn get_basecode() -> usize { - return 7; - } +impl BasefoldSpec for BasefoldRSParams +where + E::BaseField: Serialize + DeserializeOwned, +{ + type EncodingScheme = RSCode; } #[derive(Debug)] -pub struct Basefold(PhantomData<(E, V)>); +pub struct Basefold, Rng: RngCore>( + PhantomData<(E, Spec, Rng)>, +); -pub type BasefoldDefault = Basefold; +pub type BasefoldDefault = Basefold; -impl Clone for Basefold { +impl, Rng: RngCore> Clone for Basefold { fn clone(&self) -> Self { Self(PhantomData) } diff --git a/mpcs/src/basefold/sumcheck.rs b/mpcs/src/basefold/sumcheck.rs index a6a84f0eb..bb5312111 100644 --- a/mpcs/src/basefold/sumcheck.rs +++ b/mpcs/src/basefold/sumcheck.rs @@ -7,8 +7,8 @@ use rayon::prelude::{ }; pub fn sum_check_first_round_field_type( - mut eq: &mut Vec, - mut bh_values: &mut FieldType, + eq: &mut Vec, + bh_values: &mut FieldType, ) -> Vec { // The input polynomials are in the form of evaluations. Instead of viewing // every one element as the evaluation of the polynomial at a single point, @@ -16,24 +16,21 @@ pub fn sum_check_first_round_field_type( // a single point, leaving the first variable free, and obtaining a univariate // polynomial. The one_level_interp_hc transforms the evaluation forms into // the coefficient forms, for every of these partial polynomials. 
- one_level_interp_hc(&mut eq); - one_level_interp_hc_field_type(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc_field_type(bh_values); parallel_pi_field_type(bh_values, eq) // p_i(&bh_values, &eq) } -pub fn sum_check_first_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, -) -> Vec { +pub fn sum_check_first_round(eq: &mut Vec, bh_values: &mut Vec) -> Vec { // The input polynomials are in the form of evaluations. Instead of viewing // every one element as the evaluation of the polynomial at a single point, // we can view every two elements as partially evaluating the polynomial at // a single point, leaving the first variable free, and obtaining a univariate // polynomial. The one_level_interp_hc transforms the evaluation forms into // the coefficient forms, for every of these partial polynomials. - one_level_interp_hc(&mut eq); - one_level_interp_hc(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc(bh_values); parallel_pi(bh_values, eq) // p_i(&bh_values, &eq) } @@ -51,7 +48,7 @@ pub fn one_level_interp_hc(evals: &mut Vec) { return; } evals.par_chunks_mut(2).for_each(|chunk| { - chunk[1] = chunk[1] - chunk[0]; + chunk[1] -= chunk[0]; }); } @@ -68,15 +65,15 @@ pub fn one_level_eval_hc(evals: &mut Vec, challenge: F) { }); } -fn parallel_pi_field_type(evals: &mut FieldType, eq: &mut Vec) -> Vec { +fn parallel_pi_field_type(evals: &mut FieldType, eq: &mut [E]) -> Vec { match evals { - FieldType::Ext(evals) => parallel_pi(evals, &eq), - FieldType::Base(evals) => parallel_pi_base(evals, &eq), + FieldType::Ext(evals) => parallel_pi(evals, eq), + FieldType::Base(evals) => parallel_pi_base(evals, eq), _ => unreachable!(), } } -fn parallel_pi(evals: &Vec, eq: &Vec) -> Vec { +fn parallel_pi(evals: &[F], eq: &[F]) -> Vec { if evals.len() == 1 { return vec![evals[0], evals[0], evals[0]]; } @@ -111,7 +108,7 @@ fn parallel_pi(evals: &Vec, eq: &Vec) -> Vec { coeffs } -fn parallel_pi_base(evals: &Vec, eq: &Vec) -> Vec { +fn 
parallel_pi_base(evals: &[E::BaseField], eq: &[E]) -> Vec { if evals.len() == 1 { return vec![E::from(evals[0]), E::from(evals[0]), E::from(evals[0])]; } @@ -147,31 +144,27 @@ fn parallel_pi_base(evals: &Vec, eq: &Vec) - } pub fn sum_check_challenge_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, + eq: &mut Vec, + bh_values: &mut Vec, challenge: F, ) -> Vec { // Note that when the last round ends, every two elements are in // the coefficient form. Use the challenge to reduce the two elements // into a single value. This is equivalent to substituting the challenge // to the first variable of the poly. - one_level_eval_hc(&mut bh_values, challenge); - one_level_eval_hc(&mut eq, challenge); + one_level_eval_hc(bh_values, challenge); + one_level_eval_hc(eq, challenge); - one_level_interp_hc(&mut eq); - one_level_interp_hc(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc(bh_values); - parallel_pi(&bh_values, &eq) + parallel_pi(bh_values, eq) // p_i(&bh_values,&eq) } -pub fn sum_check_last_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, - challenge: F, -) { - one_level_eval_hc(&mut bh_values, challenge); - one_level_eval_hc(&mut eq, challenge); +pub fn sum_check_last_round(eq: &mut Vec, bh_values: &mut Vec, challenge: F) { + one_level_eval_hc(bh_values, challenge); + one_level_eval_hc(eq, challenge); } #[cfg(test)] @@ -184,7 +177,7 @@ mod tests { use super::*; use crate::util::test::rand_vec; - pub fn p_i(evals: &Vec, eq: &Vec) -> Vec { + pub fn p_i(evals: &[F], eq: &[F]) -> Vec { if evals.len() == 1 { return vec![evals[0], evals[0], evals[0]]; } diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 216b50d52..277f98fd0 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -29,8 +29,9 @@ pub fn pcs_setup>( pub fn pcs_trim>( param: &Pcs::Param, + poly_size: usize, ) -> Result<(Pcs::ProverParam, Pcs::VerifierParam), Error> { - Pcs::trim(param) + Pcs::trim(param, poly_size) } pub fn pcs_commit>( @@ -50,16 +51,16 @@ pub fn pcs_commit_and_write>( 
pp: &Pcs::ProverParam, - polys: &Vec>, -) -> Result, Error> { + polys: &[DenseMultilinearExtension], +) -> Result { Pcs::batch_commit(pp, polys) } -pub fn pcs_batch_commit_and_write<'a, E: ExtensionField, Pcs: PolynomialCommitmentScheme>( +pub fn pcs_batch_commit_and_write>( pp: &Pcs::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], transcript: &mut impl TranscriptWrite, -) -> Result, Error> { +) -> Result { Pcs::batch_commit_and_write(pp, polys, transcript) } @@ -76,8 +77,8 @@ pub fn pcs_open>( pub fn pcs_batch_open>( pp: &Pcs::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Pcs::CommitmentWithData], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptWrite, @@ -114,7 +115,7 @@ pub fn pcs_verify>( pub fn pcs_batch_verify<'a, E: ExtensionField, Pcs: PolynomialCommitmentScheme>( vp: &Pcs::VerifierParam, - comms: &Vec, + comms: &[Pcs::Commitment], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptRead, @@ -136,7 +137,10 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn setup(poly_size: usize, rng: &Self::Rng) -> Result; - fn trim(param: &Self::Param) -> Result<(Self::ProverParam, Self::VerifierParam), Error>; + fn trim( + param: &Self::Param, + poly_size: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), Error>; fn commit( pp: &Self::ProverParam, @@ -151,14 +155,14 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_commit( pp: &Self::ProverParam, - polys: &Vec>, - ) -> Result, Error>; + polys: &[DenseMultilinearExtension], + ) -> Result; fn batch_commit_and_write( pp: &Self::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], transcript: &mut impl TranscriptWrite, - ) -> Result, Error>; + ) -> Result; fn open( pp: &Self::ProverParam, @@ -171,13 +175,26 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: 
&[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptWrite, ) -> Result<(), Error>; + /// This is a simple version of batch open: + /// 1. Open at one point + /// 2. All the polynomials share the same commitment. + /// 3. The point is already a random point generated by a sum-check. + fn simple_batch_open( + pp: &Self::ProverParam, + polys: &[DenseMultilinearExtension], + comm: &Self::CommitmentWithData, + point: &[E], + evals: &[E], + transcript: &mut impl TranscriptWrite, + ) -> Result<(), Error>; + fn read_commitment( vp: &Self::VerifierParam, transcript: &mut impl TranscriptRead, @@ -203,11 +220,19 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_verify( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptRead, ) -> Result<(), Error>; + + fn simple_batch_verify( + vp: &Self::VerifierParam, + comm: &Self::Commitment, + point: &[E], + evals: &[E], + transcript: &mut impl TranscriptRead, + ) -> Result<(), Error>; } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -234,8 +259,8 @@ where fn ni_batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], ) -> Result, Error> { @@ -257,7 +282,7 @@ where fn ni_batch_verify<'a>( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], proof: &PCSProof, @@ -308,18 +333,19 @@ pub enum Error { mod basefold; pub use basefold::{ - Basefold, BasefoldCommitment, BasefoldCommitmentWithData, BasefoldDefault, - BasefoldDefaultParams, BasefoldExtParams, BasefoldParams, + coset_fft, fft, fft_root_table, Basecode, BasecodeDefaultSpec, Basefold, + BasefoldBasecodeParams, BasefoldCommitment, BasefoldCommitmentWithData, BasefoldDefault, + BasefoldParams, 
BasefoldRSParams, BasefoldSpec, EncodingScheme, RSCode, RSCodeDefaultSpec, }; fn validate_input( function: &str, param_num_vars: usize, - polys: &Vec>, - points: &Vec>, + polys: &[DenseMultilinearExtension], + points: &[Vec], ) -> Result<(), Error> { - let polys = polys.into_iter().collect_vec(); - let points = points.into_iter().collect_vec(); + let polys = polys.iter().collect_vec(); + let points = points.iter().collect_vec(); for poly in polys.iter() { if param_num_vars < poly.num_vars { return Err(err_too_many_variates( @@ -377,7 +403,7 @@ pub mod test_util { let rng = ChaCha8Rng::from_seed([0u8; 32]); let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size, &rng).unwrap(); - Pcs::trim(¶m).unwrap() + Pcs::trim(¶m, poly_size).unwrap() }; // Commit and open let proof = { @@ -402,15 +428,13 @@ pub mod test_util { // Verify let result = { let mut transcript = T::from_proof(proof.as_slice()); - let result = Pcs::verify( + Pcs::verify( &vp, &Pcs::read_commitment(&vp, &mut transcript).unwrap(), &transcript.squeeze_challenges(num_vars), &transcript.read_field_element_ext().unwrap(), &mut transcript, - ); - - result + ) }; result.unwrap(); } @@ -435,7 +459,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size, &rng).unwrap(); - Pcs::trim(¶m).unwrap() + Pcs::trim(¶m, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -459,7 +483,11 @@ pub mod test_util { } }) .collect_vec(); - let comms = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + let comms = polys + .iter() + .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) + .collect_vec(); let points = (0..num_points) .map(|i| transcript.squeeze_challenges(num_vars - i)) @@ -511,14 +539,77 @@ pub mod test_util { result.unwrap(); } } + + pub(super) fn run_simple_batch_commit_open_verify( + base: bool, + num_vars_start: usize, + num_vars_end: usize, + batch_size: usize, + ) where + E: ExtensionField, + Pcs: 
PolynomialCommitmentScheme, + T: TranscriptRead + + TranscriptWrite + + InMemoryTranscript, + { + for num_vars in num_vars_start..num_vars_end { + let rng = ChaCha8Rng::from_seed([0u8; 32]); + // Setup + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + + let proof = { + let mut transcript = T::new(); + let polys = (0..batch_size) + .map(|_| { + if base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + let point = transcript.squeeze_challenges(num_vars); + + let evals = (0..batch_size) + .map(|i| polys[i].evaluate(&point)) + .collect_vec(); + + transcript.write_field_elements_ext(&evals).unwrap(); + Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); + transcript.into_proof() + }; + // Batch verify + let result = { + let mut transcript = T::from_proof(proof.as_slice()); + let comms = &Pcs::read_commitment(&vp, &mut transcript).unwrap(); + + let point = transcript.squeeze_challenges(num_vars); + let evals = transcript.read_field_elements_ext(batch_size).unwrap(); + + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript) + }; + + result.unwrap(); + } + } } #[cfg(test)] mod test { use crate::{ - basefold::{Basefold, BasefoldExtParams}, + basefold::Basefold, util::transcript::{FieldTranscript, InMemoryTranscript, PoseidonTranscript}, - PolynomialCommitmentScheme, + BasefoldRSParams, PolynomialCommitmentScheme, }; use goldilocks::GoldilocksExt2; use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; @@ -526,22 +617,7 @@ mod test { use rand_chacha::ChaCha8Rng; #[test] fn test_transcript() { - #[derive(Debug)] - pub struct Five {} 
- - impl BasefoldExtParams for Five { - fn get_reps() -> usize { - return 5; - } - fn get_rate() -> usize { - return 3; - } - fn get_basecode() -> usize { - return 2; - } - } - - type Pcs = Basefold; + type Pcs = Basefold; let num_vars = 10; let rng = ChaCha8Rng::from_seed([0u8; 32]); let poly_size = 1 << num_vars; @@ -550,7 +626,9 @@ mod test { let param = >::setup(poly_size, &rng).unwrap(); - let (pp, vp) = >::trim(¶m).unwrap(); + let (pp, vp) = + >::trim(¶m, 1 << num_vars) + .unwrap(); println!("before commit"); let comm = >::commit_and_write( &pp, diff --git a/mpcs/src/sum_check.rs b/mpcs/src/sum_check.rs index e12233cf9..8ea3517af 100644 --- a/mpcs/src/sum_check.rs +++ b/mpcs/src/sum_check.rs @@ -93,8 +93,8 @@ pub fn evaluate( &|query| evals[&query], &|idx| challenges[idx], &|scalar| -scalar, - &|lhs, rhs| lhs + &rhs, - &|lhs, rhs| lhs * &rhs, + &|lhs, rhs| lhs + rhs, + &|lhs, rhs| lhs * rhs, &|value, scalar| scalar * value, ) } @@ -104,11 +104,7 @@ pub fn lagrange_eval(x: &[F], b: usize) -> F { product(x.iter().enumerate().map( |(idx, x_i)| { - if b.nth_bit(idx) { - *x_i - } else { - F::ONE - x_i - } + if b.nth_bit(idx) { *x_i } else { F::ONE - x_i } }, )) } diff --git a/mpcs/src/sum_check/classic.rs b/mpcs/src/sum_check/classic.rs index 439535e71..853d25602 100644 --- a/mpcs/src/sum_check/classic.rs +++ b/mpcs/src/sum_check/classic.rs @@ -116,10 +116,8 @@ impl<'a, E: ExtensionField> ProverState<'a, E> { .expression .used_rotation() .into_iter() - .filter_map(|rotation| { - (rotation != Rotation::cur()) - .then(|| (rotation, self.bh.rotation_map(rotation))) - }) + .filter(|&rotation| (rotation != Rotation::cur())) + .map(|rotation| (rotation, self.bh.rotation_map(rotation))) .collect::>(); for query in self.expression.used_query() { if query.rotation() != Rotation::cur() { @@ -312,7 +310,7 @@ mod tests { #[test] fn test_sum_check_protocol() { - let polys = vec![ + let polys = [ DenseMultilinearExtension::::from_evaluations_vec( 2, vec![Fr::from(1), 
Fr::from(2), Fr::from(3), Fr::from(4)], diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 9e596fe08..470883d7f 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -205,7 +205,7 @@ impl ClassicSumCheckProver for CoefficientsProver { products.iter_mut().for_each(|(lhs, _)| { *lhs *= &rhs; }); - (constant * &rhs, products) + (constant * rhs, products) }, ); Self(constant, flattened) @@ -215,7 +215,7 @@ impl ClassicSumCheckProver for CoefficientsProver { // Initialize h(X) to zero let mut coeffs = Coefficients(FieldType::Ext(vec![E::ZERO; state.expression.degree() + 1])); // First, sum the constant over the hypercube and add to h(X) - coeffs += &(E::from(state.size() as u64) * &self.0); + coeffs += &(E::from(state.size() as u64) * self.0); // Next, for every product of polynomials, where each product is assumed to be exactly 2 // put this into h(X). if self.1.iter().all(|(_, products)| products.len() == 2) { @@ -287,11 +287,11 @@ impl CoefficientsProver { .take(n) .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { let coeff_0 = lhs_0 * rhs_0; - let coeff_2 = (lhs_1 - lhs_0) * &(rhs_1 - rhs_0); + let coeff_2 = (lhs_1 - lhs_0) * (rhs_1 - rhs_0); coeffs[0] += &coeff_0; coeffs[2] += &coeff_2; if !LAZY { - coeffs[1] += &(lhs_1 * rhs_1 - &coeff_0 - &coeff_2); + coeffs[1] += &(lhs_1 * rhs_1 - coeff_0 - coeff_2); } }); }; diff --git a/mpcs/src/util.rs b/mpcs/src/util.rs index cf056d51a..260986b2f 100644 --- a/mpcs/src/util.rs +++ b/mpcs/src/util.rs @@ -113,6 +113,42 @@ pub fn field_type_index_ext(poly: &FieldType, index: usize } } +pub fn field_type_index_mul_base( + poly: &mut FieldType, + index: usize, + scalar: &E::BaseField, +) { + match poly { + FieldType::Ext(coeffs) => coeffs[index] *= scalar, + FieldType::Base(coeffs) => coeffs[index] *= scalar, + _ => unreachable!(), + } +} + +pub fn field_type_index_set_base( + poly: &mut FieldType, + index: usize, + scalar: &E::BaseField, +) { + match 
poly { + FieldType::Ext(coeffs) => coeffs[index] = E::from(*scalar), + FieldType::Base(coeffs) => coeffs[index] = *scalar, + _ => unreachable!(), + } +} + +pub fn field_type_index_set_ext( + poly: &mut FieldType, + index: usize, + scalar: &E, +) { + match poly { + FieldType::Ext(coeffs) => coeffs[index] = *scalar, + FieldType::Base(_) => panic!("Cannot set base field from extension field"), + _ => unreachable!(), + } +} + pub struct FieldTypeIterExt<'a, E: ExtensionField> { inner: &'a FieldType, index: usize, diff --git a/mpcs/src/util/arithmetic.rs b/mpcs/src/util/arithmetic.rs index f4ca552d8..609f65455 100644 --- a/mpcs/src/util/arithmetic.rs +++ b/mpcs/src/util/arithmetic.rs @@ -80,7 +80,7 @@ pub fn inner_product<'a, 'b, F: Field>( rhs: impl IntoIterator, ) -> F { lhs.into_iter() - .zip_eq(rhs.into_iter()) + .zip_eq(rhs) .map(|(lhs, rhs)| *lhs * rhs) .reduce(|acc, product| acc + product) .unwrap_or_default() @@ -92,8 +92,8 @@ pub fn inner_product_three<'a, 'b, 'c, F: Field>( c: impl IntoIterator, ) -> F { a.into_iter() - .zip_eq(b.into_iter()) - .zip_eq(c.into_iter()) + .zip_eq(b) + .zip_eq(c) .map(|((a, b), c)| *a * b * c) .reduce(|acc, product| acc + product) .unwrap_or_default() @@ -107,8 +107,9 @@ pub fn barycentric_weights(points: &[F]) -> Vec { points .iter() .enumerate() - .filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i)) - .reduce(|acc, value| acc * &value) + .filter(|&(i, _point_i)| (i != j)) + .map(|(_i, point_i)| *point_j - point_i) + .reduce(|acc, value| acc * value) .unwrap_or(F::ONE) }) .collect_vec(); @@ -126,7 +127,7 @@ pub fn barycentric_interpolate(weights: &[F], points: &[F], evals: &[F let sum_inv = coeffs.iter().fold(F::ZERO, |sum, coeff| sum + coeff); (coeffs, sum_inv.invert().unwrap()) }; - inner_product(&coeffs, evals) * &sum_inv + inner_product(&coeffs, evals) * sum_inv } pub fn modulus() -> BigUint { @@ -134,11 +135,7 @@ pub fn modulus() -> BigUint { } pub fn fe_from_bool(value: bool) -> F { - if value { - F::ONE - } 
else { - F::ZERO - } + if value { F::ONE } else { F::ZERO } } pub fn fe_mod_from_le_bytes(bytes: impl AsRef<[u8]>) -> F { @@ -221,17 +218,17 @@ pub fn interpolate2(points: [(F, F); 2], x: F) -> F { a1 + (x - a0) * (b1 - a1) * (b0 - a0).invert().unwrap() } -pub fn degree_2_zero_plus_one(poly: &Vec) -> F { +pub fn degree_2_zero_plus_one(poly: &[F]) -> F { poly[0] + poly[0] + poly[1] + poly[2] } -pub fn degree_2_eval(poly: &Vec, point: F) -> F { +pub fn degree_2_eval(poly: &[F], point: F) -> F { poly[0] + point * poly[1] + point * point * poly[2] } -pub fn base_from_raw_bytes(bytes: &Vec) -> E::BaseField { +pub fn base_from_raw_bytes(bytes: &[u8]) -> E::BaseField { let mut res = E::BaseField::ZERO; - bytes.into_iter().for_each(|b| { + bytes.iter().for_each(|b| { res += E::BaseField::from(u64::from(*b)); }); res diff --git a/mpcs/src/util/arithmetic/hypercube.rs b/mpcs/src/util/arithmetic/hypercube.rs index 16c9868af..ccd6294e3 100644 --- a/mpcs/src/util/arithmetic/hypercube.rs +++ b/mpcs/src/util/arithmetic/hypercube.rs @@ -1,4 +1,3 @@ -use ark_std::{end_timer, start_timer}; use ff::Field; use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; @@ -30,7 +29,7 @@ pub fn interpolate_over_boolean_hypercube(evals: &mut Vec) { evals.par_chunks_mut(chunk_size).for_each(|chunk| { let half_chunk = chunk_size >> 1; for j in half_chunk..chunk_size { - chunk[j] = chunk[j] - chunk[j - half_chunk]; + chunk[j] -= chunk[j - half_chunk]; } }); } diff --git a/mpcs/src/util/hash.rs b/mpcs/src/util/hash.rs index 4842f6e09..fa8010970 100644 --- a/mpcs/src/util/hash.rs +++ b/mpcs/src/util/hash.rs @@ -40,6 +40,54 @@ pub fn hash_two_leaves_base( Digest(result) } +pub fn hash_two_leaves_batch_ext( + a: &[E], + b: &[E], + hasher: &Hasher, +) -> Digest { + let mut left_hasher = hasher.clone(); + a.iter().for_each(|a| left_hasher.update(a.as_bases())); + let left = Digest( + left_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + let mut right_hasher = 
hasher.clone(); + b.iter().for_each(|b| right_hasher.update(b.as_bases())); + let right = Digest( + right_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + hash_two_digests(&left, &right, hasher) +} + +pub fn hash_two_leaves_batch_base( + a: &Vec, + b: &Vec, + hasher: &Hasher, +) -> Digest { + let mut left_hasher = hasher.clone(); + left_hasher.update(a.as_slice()); + let left = Digest( + left_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + let mut right_hasher = hasher.clone(); + right_hasher.update(b.as_slice()); + let right = Digest( + right_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + hash_two_digests(&left, &right, hasher) +} + pub fn hash_two_digests( a: &Digest, b: &Digest, diff --git a/mpcs/src/util/merkle_tree.rs b/mpcs/src/util/merkle_tree.rs index 1c563f126..b3d1bf899 100644 --- a/mpcs/src/util/merkle_tree.rs +++ b/mpcs/src/util/merkle_tree.rs @@ -1,4 +1,5 @@ use ff_ext::ExtensionField; +use itertools::Itertools; use multilinear_extensions::mle::FieldType; use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, @@ -6,7 +7,11 @@ use rayon::{ }; use crate::util::{ - hash::{hash_two_digests, hash_two_leaves_base, hash_two_leaves_ext, Digest, Hasher}, + field_type_index_base, field_type_index_ext, + hash::{ + hash_two_digests, hash_two_leaves_base, hash_two_leaves_batch_base, + hash_two_leaves_batch_ext, hash_two_leaves_ext, Digest, Hasher, + }, log2_strict, transcript::{TranscriptRead, TranscriptWrite}, Deserialize, DeserializeOwned, Serialize, @@ -21,7 +26,7 @@ where E::BaseField: Serialize + DeserializeOwned, { inner: Vec>>, - leaves: FieldType, + leaves: Vec>, } impl MerkleTree @@ -30,7 +35,14 @@ where { pub fn from_leaves(leaves: FieldType, hasher: &Hasher) -> Self { Self { - inner: merkelize::(&leaves, hasher), + inner: merkelize::(&[&leaves], hasher), + leaves: vec![leaves], + } + } + + pub fn from_batch_leaves(leaves: Vec>, hasher: 
&Hasher) -> Self { + Self { + inner: merkelize::(&leaves.iter().collect_vec(), hasher), leaves, } } @@ -47,26 +59,52 @@ where self.inner.len() } - pub fn leaves(&self) -> &FieldType { + pub fn leaves(&self) -> &Vec> { &self.leaves } - pub fn size(&self) -> usize { - self.leaves.len() + pub fn batch_leaves(&self, coeffs: &Vec) -> Vec { + (0..self.leaves[0].len()) + .map(|i| { + self.leaves + .iter() + .zip(coeffs) + .map(|(leaf, coeff)| field_type_index_ext(leaf, i) * *coeff) + .sum() + }) + .collect() + } + + pub fn size(&self) -> (usize, usize) { + (self.leaves.len(), self.leaves[0].len()) } - pub fn get_leaf_as_base(&self, index: usize) -> E::BaseField { - match &self.leaves { - FieldType::Base(leaves) => leaves[index], - FieldType::Ext(_) => panic!("Mismatching field type, calling get_leaf_as_base on a Merkle tree over extension fields"), + pub fn get_leaf_as_base(&self, index: usize) -> Vec { + match &self.leaves[0] { + FieldType::Base(_) => self + .leaves + .iter() + .map(|leaves| field_type_index_base(leaves, index)) + .collect(), + FieldType::Ext(_) => panic!( + "Mismatching field type, calling get_leaf_as_base on a Merkle tree over extension fields" + ), FieldType::Unreachable => unreachable!(), } } - pub fn get_leaf_as_extension(&self, index: usize) -> E { - match &self.leaves { - FieldType::Base(leaves) => E::from(leaves[index]), - FieldType::Ext(leaves) => leaves[index], + pub fn get_leaf_as_extension(&self, index: usize) -> Vec { + match &self.leaves[0] { + FieldType::Base(_) => self + .leaves + .iter() + .map(|leaves| field_type_index_ext(leaves, index)) + .collect(), + FieldType::Ext(_) => self + .leaves + .iter() + .map(|leaves| field_type_index_ext(leaves, index)) + .collect(), FieldType::Unreachable => unreachable!(), } } @@ -75,7 +113,7 @@ where &self, leaf_index: usize, ) -> MerklePathWithoutLeafOrRoot { - assert!(leaf_index < self.size()); + assert!(leaf_index < self.size().1); MerklePathWithoutLeafOrRoot::::new( self.inner .iter() @@ -105,6 
+143,10 @@ where Self { inner } } + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + pub fn len(&self) -> usize { self.inner.len() } @@ -164,28 +206,100 @@ where hasher, ) } + + pub fn authenticate_batch_leaves_root_ext( + &self, + left: Vec, + right: Vec, + index: usize, + root: &Digest, + hasher: &Hasher, + ) { + authenticate_merkle_path_root_batch::( + &self.inner, + FieldType::Ext(left), + FieldType::Ext(right), + index, + root, + hasher, + ) + } + + pub fn authenticate_batch_leaves_root_base( + &self, + left: Vec, + right: Vec, + index: usize, + root: &Digest, + hasher: &Hasher, + ) { + authenticate_merkle_path_root_batch::( + &self.inner, + FieldType::Base(left), + FieldType::Base(right), + index, + root, + hasher, + ) + } } fn merkelize( - values: &FieldType, + values: &[&FieldType], hasher: &Hasher, ) -> Vec>> { - let timer = start_timer!(|| format!("merkelize {} values", values.len())); - let log_v = log2_strict(values.len()); + #[cfg(feature = "sanity-check")] + for i in 0..(values.len() - 1) { + assert_eq!(values[i].len(), values[i + 1].len()); + } + let timer = start_timer!(|| format!("merkelize {} values", values[0].len() * values.len())); + let log_v = log2_strict(values[0].len()); let mut tree = Vec::with_capacity(log_v); // The first layer of hashes, half the number of leaves - let mut hashes = vec![Digest::default(); values.len() >> 1]; - hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = match values { - FieldType::Base(values) => { - hash_two_leaves_base::(&values[i << 1], &values[(i << 1) + 1], hasher) - } - FieldType::Ext(values) => { - hash_two_leaves_ext::(&values[i << 1], &values[(i << 1) + 1], hasher) - } - FieldType::Unreachable => unreachable!(), - }; - }); + let mut hashes = vec![Digest::default(); values[0].len() >> 1]; + if values.len() == 1 { + hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { + *hash = match &values[0] { + FieldType::Base(values) => { + hash_two_leaves_base::(&values[i << 1], 
&values[(i << 1) + 1], hasher) + } + FieldType::Ext(values) => { + hash_two_leaves_ext::(&values[i << 1], &values[(i << 1) + 1], hasher) + } + FieldType::Unreachable => unreachable!(), + }; + }); + } else { + hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { + *hash = match &values[0] { + FieldType::Base(_) => hash_two_leaves_batch_base::( + &values + .iter() + .map(|values| field_type_index_base(values, i << 1)) + .collect(), + &values + .iter() + .map(|values| field_type_index_base(values, (i << 1) + 1)) + .collect(), + hasher, + ), + FieldType::Ext(_) => hash_two_leaves_batch_ext::( + values + .iter() + .map(|values| field_type_index_ext(values, i << 1)) + .collect_vec() + .as_slice(), + values + .iter() + .map(|values| field_type_index_ext(values, (i << 1) + 1)) + .collect_vec() + .as_slice(), + hasher, + ), + FieldType::Unreachable => unreachable!(), + }; + }); + } tree.push(hashes); @@ -202,7 +316,7 @@ fn merkelize( } fn authenticate_merkle_path_root( - path: &Vec>, + path: &[Digest], leaves: FieldType, x_index: usize, root: &Digest, @@ -218,11 +332,43 @@ fn authenticate_merkle_path_root( // The lowest bit in the index is ignored. It can point to either leaves x_index >>= 1; - for i in 0..path.len() { + for path_i in path.iter() { + hash = if x_index & 1 == 0 { + hash_two_digests(&hash, path_i, hasher) + } else { + hash_two_digests(path_i, &hash, hasher) + }; + x_index >>= 1; + } + assert_eq!(&hash, root); +} + +fn authenticate_merkle_path_root_batch( + path: &[Digest], + left: FieldType, + right: FieldType, + x_index: usize, + root: &Digest, + hasher: &Hasher, +) { + let mut x_index = x_index; + let mut hash = match (left, right) { + (FieldType::Base(left), FieldType::Base(right)) => { + hash_two_leaves_batch_base::(&left, &right, hasher) + } + (FieldType::Ext(left), FieldType::Ext(right)) => { + hash_two_leaves_batch_ext::(&left, &right, hasher) + } + _ => unreachable!(), + }; + + // The lowest bit in the index is ignored. 
It can point to either leaves + x_index >>= 1; + for path_i in path.iter() { hash = if x_index & 1 == 0 { - hash_two_digests(&hash, &path[i], hasher) + hash_two_digests(&hash, path_i, hasher) } else { - hash_two_digests(&path[i], &hash, hasher) + hash_two_digests(path_i, &hash, hasher) }; x_index >>= 1; } diff --git a/mpcs/src/util/parallel.rs b/mpcs/src/util/parallel.rs index 5f44d473d..950d43797 100644 --- a/mpcs/src/util/parallel.rs +++ b/mpcs/src/util/parallel.rs @@ -1,10 +1,7 @@ pub fn num_threads() -> usize { #[cfg(feature = "parallel")] let nt = rayon::current_num_threads(); - return nt; - - #[cfg(not(feature = "parallel"))] - return 1; + if cfg!(feature = "parallel") { nt } else { 1 } } pub fn parallelize_iter(iter: I, f: F) diff --git a/mpcs/src/util/transcript.rs b/mpcs/src/util/transcript.rs index 343bfffba..2edf082cf 100644 --- a/mpcs/src/util/transcript.rs +++ b/mpcs/src/util/transcript.rs @@ -21,17 +21,16 @@ pub trait FieldTranscript { fn common_field_element_ext(&mut self, fe: &E) -> Result<(), Error>; fn common_field_elements(&mut self, fes: FieldType) -> Result<(), Error> { - Ok(match fes { + match fes { FieldType::Base(fes) => fes .iter() - .map(|fe| self.common_field_element_base(fe)) - .try_collect()?, + .try_for_each(|fe| self.common_field_element_base(fe))?, FieldType::Ext(fes) => fes .iter() - .map(|fe| self.common_field_element_ext(fe)) - .try_collect()?, + .try_for_each(|fe| self.common_field_element_ext(fe))?, FieldType::Unreachable => unreachable!(), - }) + }; + Ok(()) } } @@ -146,7 +145,7 @@ impl Stream { self.inner.len() - self.pointer } - pub fn read_exact(&mut self, output: &mut Vec) -> Result<(), Error> { + pub fn read_exact(&mut self, output: &mut [T]) -> Result<(), Error> { let left = self.left(); if left < output.len() { return Err(Error::Transcript(