diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 18cecc5f0..4ff1ce8f3 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -46,5 +46,25 @@ sanity-check = [] print-trace = [ "ark-std/print-trace" ] [[bench]] -name = "commit_open_verify" +name = "commit_open_verify_rs" +harness = false + +[[bench]] +name = "commit_open_verify_basecode" +harness = false + +[[bench]] +name = "basecode" +harness = false + +[[bench]] +name = "rscode" +harness = false + +[[bench]] +name = "interpolate" +harness = false + +[[bench]] +name = "fft" harness = false \ No newline at end of file diff --git a/mpcs/benches/basecode.rs b/mpcs/benches/basecode.rs new file mode 100644 index 000000000..7e8ccca99 --- /dev/null +++ b/mpcs/benches/basecode.rs @@ -0,0 +1,110 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::Itertools; +use mpcs::{ + util::{ + arithmetic::interpolate_field_type_over_boolean_hypercube, + plonky2_util::reverse_index_bits_in_place_field_type, + }, Basefold, BasefoldBasecodeParams, BasefoldSpec, EncodingScheme, PolynomialCommitmentScheme +}; + +use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +type Pcs = Basefold; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_encoding(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "encoding_basecode_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, _) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys + .par_iter() + .map(|poly| { + // Switch to coefficient form + let mut coeffs = poly.evaluations.clone(); + interpolate_field_type_over_boolean_hypercube(&mut coeffs); + + let mut codeword = + <>::EncodingScheme as EncodingScheme>::encode( + &pp.encoding_params, + &coeffs, + ); + + // If using repetition code as basecode, it may be faster to use the following line of code to create the commitment and comment out the two lines above + // let mut codeword = evaluate_over_foldable_domain(pp.log_rate, coeffs, &pp.table); + + // The sum-check protocol starts from the first variable, but the FRI part + // will eventually produce the evaluation at (alpha_k, ..., alpha_1), so apply + // the bit-reversion to reverse the variable indices of the polynomial. + // In short: store the poly and codeword in big endian + reverse_index_bits_in_place_field_type(&mut coeffs); + reverse_index_bits_in_place_field_type(&mut codeword); + + (coeffs, codeword) + }) + .collect::<(Vec>, Vec>)>(); + }) + }, + ); + } + } +} + +fn bench_encoding_goldilocks_2(c: &mut Criterion) { + bench_encoding(c, false); +} + +fn bench_encoding_base(c: &mut Criterion) { + bench_encoding(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_encoding_base, bench_encoding_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/commit_open_verify_basecode.rs b/mpcs/benches/commit_open_verify_basecode.rs new file mode 100644 index 000000000..922d90001 --- /dev/null +++ b/mpcs/benches/commit_open_verify_basecode.rs @@ -0,0 +1,395 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::{chain, Itertools}; +use mpcs::{ + util::{ + plonky2_util::log2_ceil, + transcript::{ + FieldTranscript, FieldTranscriptRead, FieldTranscriptWrite, InMemoryTranscript, + PoseidonTranscript, + }, + }, + Basefold, BasefoldBasecodeParams, Evaluation, PolynomialCommitmentScheme, +}; + +use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Pcs = Basefold; +type T = PoseidonTranscript; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "ext2" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + let (pp, vp) = { + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + + group.bench_function(BenchmarkId::new("setup", format!("{}", num_vars)), |b| { + b.iter(|| { + Pcs::setup(poly_size, &rng).unwrap(); + }) + }); + Pcs::trim(¶m, poly_size).unwrap() + }; + + let proof = { + let mut transcript = T::new(); + let poly = if is_base { + DenseMultilinearExtension::random(num_vars, &mut OsRng) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + }; + + let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); + + group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| { + b.iter(|| { + Pcs::commit(&pp, &poly).unwrap(); + }) + }); + + let point = transcript.squeeze_challenges(num_vars); + let eval = poly.evaluate(point.as_slice()); + transcript.write_field_element_ext(&eval).unwrap(); + Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); + + group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| { + b.iter_batched( + || transcript.clone(), + |mut transcript| { + Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); + }, + BatchSize::SmallInput, + ); + }); + + transcript.into_proof() + }; + // Verify + let mut transcript = T::from_proof(proof.as_slice()); + Pcs::verify( + &vp, + &Pcs::read_commitment(&vp, &mut transcript).unwrap(), + &transcript.squeeze_challenges(num_vars), + &transcript.read_field_element_ext().unwrap(), + &mut transcript, + ) + .unwrap(); + group.bench_function(BenchmarkId::new("verify", format!("{}", num_vars)), |b| { + b.iter_batched( + || T::from_proof(proof.as_slice()), + |mut transcript| { + Pcs::verify( + &vp, + &Pcs::read_commitment(&vp, &mut transcript).unwrap(), + &transcript.squeeze_challenges(num_vars), + &transcript.read_field_element_ext().unwrap(), + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }); + } +} + +fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "batch_commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "ext2" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let num_points = batch_size >> 1; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + // Setup + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + // Batch commit and open + let evals = chain![ + (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys + (0..num_points).map(|point| (point * 2 + 1, point)), + ] + .unique() + .collect_vec(); + + let proof = { + let mut transcript = T::new(); + let polys = (0..batch_size) + .map(|i| { + if is_base { + DenseMultilinearExtension::random( + num_vars - log2_ceil((i >> 1) + 1), + &mut rng.clone(), + ) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars - log2_ceil((i >> 1) + 1), + (0..1 << (num_vars - log2_ceil((i >> 1) + 1))) + .map(|_| E::random(&mut OsRng)) + .collect(), + ) + } + }) + .collect_vec(); + let comms = polys + .iter() + .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) + .collect_vec(); + + let points = (0..num_points) + .map(|i| transcript.squeeze_challenges(num_vars - log2_ceil(i + 1))) + .take(num_points) + .collect_vec(); + + let evals = evals + .iter() + .copied() + .map(|(poly, point)| { + Evaluation::new(poly, point, polys[poly].evaluate(&points[point])) + }) + .collect_vec(); + transcript + .write_field_elements_ext(evals.iter().map(Evaluation::value)) + .unwrap(); + Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || transcript.clone(), + |mut transcript| { + Pcs::batch_open( + &pp, + &polys, + &comms, + &points, + &evals, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + + transcript.into_proof() + }; + // Batch verify + let mut transcript = T::from_proof(proof.as_slice()); + let comms = &Pcs::read_commitments(&vp, batch_size, &mut transcript).unwrap(); + + let points = (0..num_points) + .map(|i| transcript.squeeze_challenges(num_vars - log2_ceil(i + 1))) + .take(num_points) + .collect_vec(); + + let evals2 = transcript.read_field_elements_ext(evals.len()).unwrap(); + + let backup_transcript = transcript.clone(); + + Pcs::batch_verify( + &vp, + comms, + &points, + &evals + .iter() + .copied() + .zip(evals2.clone()) + .map(|((poly, point), eval)| Evaluation::new(poly, point, eval)) + .collect_vec(), + &mut transcript, + ) + .unwrap(); + + group.bench_function( + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || backup_transcript.clone(), + |mut transcript| { + Pcs::batch_verify( + &vp, + comms, + &points, + &evals + .iter() + .copied() + .zip(evals2.clone()) + .map(|((poly, point), eval)| Evaluation::new(poly, point, eval)) + .collect_vec(), + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + +fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "simple_batch_commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let proof = { + let mut transcript = T::new(); + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + Pcs::batch_commit(&pp, &polys).unwrap(); + }) + }, + ); + + let point = transcript.squeeze_challenges(num_vars); + + let evals = (0..batch_size) + .map(|i| polys[i].evaluate(&point)) + .collect_vec(); + + transcript.write_field_elements_ext(&evals).unwrap(); + Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); + + group.bench_function( + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || transcript.clone(), + |mut transcript| { + Pcs::simple_batch_open( + &pp, + &polys, + &comm, + &point, + &evals, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + transcript.into_proof() + }; + // Batch verify + let mut transcript = T::from_proof(proof.as_slice()); + let comms = &Pcs::read_commitment(&vp, &mut transcript).unwrap(); + + let point = transcript.squeeze_challenges(num_vars); + let evals = transcript.read_field_elements_ext(batch_size).unwrap(); + + let backup_transcript = transcript.clone(); + + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || backup_transcript.clone(), + |mut transcript| { + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + +fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_commit_open_verify_goldilocks(c, false); +} + +fn bench_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_commit_open_verify_goldilocks(c, true); +} + +fn bench_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks(c, false); +} + +fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks(c, true); +} + +fn bench_simple_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, false); +} + +fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_simple_batch_commit_open_verify_goldilocks_base, bench_simple_batch_commit_open_verify_goldilocks_2,bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2, bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/commit_open_verify.rs b/mpcs/benches/commit_open_verify_rs.rs similarity index 60% rename from mpcs/benches/commit_open_verify.rs rename to mpcs/benches/commit_open_verify_rs.rs index 8d731f514..e372e2d11 100644 --- a/mpcs/benches/commit_open_verify.rs +++ b/mpcs/benches/commit_open_verify_rs.rs @@ -6,31 +6,38 @@ use goldilocks::GoldilocksExt2; use itertools::{chain, Itertools}; use mpcs::{ - util::transcript::{ - FieldTranscript, FieldTranscriptRead, FieldTranscriptWrite, InMemoryTranscript, - PoseidonTranscript, + util::{ + plonky2_util::log2_ceil, + transcript::{ + FieldTranscript, FieldTranscriptRead, FieldTranscriptWrite, InMemoryTranscript, + PoseidonTranscript, + }, }, - Basefold, BasefoldDefaultParams, Evaluation, PolynomialCommitmentScheme, + Basefold, BasefoldRSParams, Evaluation, PolynomialCommitmentScheme, }; use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; use rand::{rngs::OsRng, SeedableRng}; use rand_chacha::ChaCha8Rng; -type Pcs = Basefold; +type Pcs = Basefold; type T = PoseidonTranscript; type E = GoldilocksExt2; const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let mut group = c.benchmark_group(format!( - "commit_open_verify_goldilocks_{}", + "commit_open_verify_goldilocks_rs_{}", if is_base { "base" } else { "ext2" } )); group.sample_size(NUM_SAMPLES); // Challenge is over extension field, poly over the base field - for num_vars in 10..=20 { + for num_vars in NUM_VARS_START..=NUM_VARS_END { let (pp, vp) = { let rng = ChaCha8Rng::from_seed([0u8; 32]); let poly_size = 1 << num_vars; @@ -41,7 +48,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Pcs::setup(poly_size, &rng).unwrap(); }) }); - Pcs::trim(¶m).unwrap() + Pcs::trim(¶m, poly_size).unwrap() }; let proof = { @@ -111,13 +118,13 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let mut group = c.benchmark_group(format!( - "commit_batch_open_verify_goldilocks_{}", - if is_base { "base" } else { "extension" } + "batch_commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "ext2" } )); group.sample_size(NUM_SAMPLES); // Challenge is over extension field, poly over the base field - for num_vars in 10..=20 { - for batch_size_log in 1..=6 { + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { let batch_size = 1 << batch_size_log; let num_points = batch_size >> 1; let rng = ChaCha8Rng::from_seed([0u8; 32]); @@ -125,7 +132,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size, &rng).unwrap(); - Pcs::trim(¶m).unwrap() + Pcs::trim(¶m, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -140,28 +147,27 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let polys = (0..batch_size) .map(|i| { if is_base { - DenseMultilinearExtension::random(num_vars - (i >> 1), &mut rng.clone()) + DenseMultilinearExtension::random( + num_vars - log2_ceil((i >> 1) + 1), + &mut rng.clone(), + ) } else { DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + num_vars - log2_ceil((i >> 1) + 1), + (0..1 << (num_vars - log2_ceil((i >> 1) + 1))) + .map(|_| E::random(&mut OsRng)) + .collect(), ) } }) .collect_vec(); - let comms = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_commit", format!("{}", num_vars)), - |b| { - b.iter(|| { - Pcs::batch_commit(&pp, &polys).unwrap(); - }) - }, - ); + let comms = polys + .iter() + .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) + .collect_vec(); let points = (0..num_points) - .map(|i| transcript.squeeze_challenges(num_vars - i)) + .map(|i| transcript.squeeze_challenges(num_vars - log2_ceil(i + 1))) .take(num_points) .collect_vec(); @@ -178,7 +184,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); group.bench_function( - BenchmarkId::new("batch_open", format!("{}", num_vars)), + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), |b| { b.iter_batched( || transcript.clone(), @@ -205,12 +211,14 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let comms = &Pcs::read_commitments(&vp, batch_size, &mut transcript).unwrap(); let points = (0..num_points) - .map(|i| transcript.squeeze_challenges(num_vars - i)) + .map(|i| transcript.squeeze_challenges(num_vars - log2_ceil(i + 1))) .take(num_points) .collect_vec(); let evals2 = transcript.read_field_elements_ext(evals.len()).unwrap(); + let backup_transcript = transcript.clone(); + Pcs::batch_verify( &vp, comms, @@ -226,10 +234,10 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { .unwrap(); group.bench_function( - BenchmarkId::new("batch_verify", format!("{}", num_vars)), + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), |b| { b.iter_batched( - || transcript.clone(), + || backup_transcript.clone(), |mut transcript| { Pcs::batch_verify( &vp, @@ -253,6 +261,107 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { } } +fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "simple_batch_commit_open_verify_goldilocks_rs_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let proof = { + let mut transcript = T::new(); + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + Pcs::batch_commit(&pp, &polys).unwrap(); + }) + }, + ); + + let point = transcript.squeeze_challenges(num_vars); + + let evals = (0..batch_size) + .map(|i| polys[i].evaluate(&point)) + .collect_vec(); + + transcript.write_field_elements_ext(&evals).unwrap(); + Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); + + group.bench_function( + BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || transcript.clone(), + |mut transcript| { + Pcs::simple_batch_open( + &pp, + &polys, + &comm, + &point, + &evals, + &mut transcript, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + transcript.into_proof() + }; + // Batch verify + let mut transcript = T::from_proof(proof.as_slice()); + let comms = &Pcs::read_commitment(&vp, &mut transcript).unwrap(); + + let point = transcript.squeeze_challenges(num_vars); + let evals = transcript.read_field_elements_ext(batch_size).unwrap(); + + let backup_transcript = transcript.clone(); + + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript).unwrap(); + + group.bench_function( + BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter_batched( + || backup_transcript.clone(), + |mut transcript| { + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { bench_commit_open_verify_goldilocks(c, false); } @@ -269,10 +378,18 @@ fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { bench_batch_commit_open_verify_goldilocks(c, true); } +fn bench_simple_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, false); +} + +fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks(c, true); +} + criterion_group! { name = bench_basefold; config = Criterion::default().warm_up_time(Duration::from_millis(3000)); - targets = bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2 + targets = bench_simple_batch_commit_open_verify_goldilocks_base, bench_simple_batch_commit_open_verify_goldilocks_2,bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2, bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, } criterion_main!(bench_basefold); diff --git a/mpcs/benches/fft.rs b/mpcs/benches/fft.rs new file mode 100644 index 000000000..2588a11e2 --- /dev/null +++ b/mpcs/benches/fft.rs @@ -0,0 +1,80 @@ +use std::time::Duration; + +use criterion::*; +use ff::{Field, PrimeField}; +use goldilocks::{Goldilocks, GoldilocksExt2}; + +use itertools::Itertools; +use mpcs::{coset_fft, fft_root_table}; + +use multilinear_extensions::mle::DenseMultilinearExtension; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 6; + +fn bench_fft(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "fft_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + let root_table = fft_root_table(num_vars); + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let mut polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys.par_iter_mut().for_each(|poly| { + coset_fft::( + &mut poly.evaluations, + Goldilocks::MULTIPLICATIVE_GENERATOR, + 0, + &root_table, + ); + }); + }) + }, + ); + } + } +} + +fn bench_fft_goldilocks_2(c: &mut Criterion) { + bench_fft(c, false); +} + +fn bench_fft_base(c: &mut Criterion) { + bench_fft(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_fft_base, bench_fft_goldilocks_2 +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/interpolate.rs b/mpcs/benches/interpolate.rs new file mode 100644 index 000000000..00af7c092 --- /dev/null +++ b/mpcs/benches/interpolate.rs @@ -0,0 +1,81 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::Itertools; +use mpcs::util::arithmetic::interpolate_field_type_over_boolean_hypercube; + +use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_encoding(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "interpolate_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys + .par_iter() + .map(|poly| { + // Switch to coefficient form + let mut coeffs = poly.evaluations.clone(); + interpolate_field_type_over_boolean_hypercube(&mut coeffs); + coeffs + }) + .collect::>>(); + }) + }, + ); + } + } +} + +fn bench_encoding_goldilocks_2(c: &mut Criterion) { + bench_encoding(c, false); +} + +fn bench_encoding_base(c: &mut Criterion) { + bench_encoding(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_encoding_base, bench_encoding_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/benches/rscode.rs b/mpcs/benches/rscode.rs new file mode 100644 index 000000000..ad42c7f39 --- /dev/null +++ b/mpcs/benches/rscode.rs @@ -0,0 +1,112 @@ +use std::time::Duration; + +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; + +use itertools::Itertools; +use mpcs::{ + util::{ + arithmetic::interpolate_field_type_over_boolean_hypercube, + plonky2_util::reverse_index_bits_in_place_field_type, + }, + Basefold, BasefoldRSParams, BasefoldSpec, EncodingScheme, + PolynomialCommitmentScheme, +}; + +use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +type Pcs = Basefold; +type E = GoldilocksExt2; + +const NUM_SAMPLES: usize = 10; +const NUM_VARS_START: usize = 15; +const NUM_VARS_END: usize = 20; +const BATCH_SIZE_LOG_START: usize = 3; +const BATCH_SIZE_LOG_END: usize = 5; + +fn bench_encoding(c: &mut Criterion, is_base: bool) { + let mut group = c.benchmark_group(format!( + "encoding_rscode_{}", + if is_base { "base" } else { "extension" } + )); + group.sample_size(NUM_SAMPLES); + // Challenge is over extension field, poly over the base field + for num_vars in NUM_VARS_START..=NUM_VARS_END { + for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { + let batch_size = 1 << batch_size_log; + let rng = ChaCha8Rng::from_seed([0u8; 32]); + let (pp, _) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + let polys = (0..batch_size) + .map(|_| { + if is_base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + + + group.bench_function( + BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), + |b| { + b.iter(|| { + polys + .par_iter() + .map(|poly| { + // Switch to coefficient form + let mut coeffs = poly.evaluations.clone(); + interpolate_field_type_over_boolean_hypercube(&mut coeffs); + + let mut codeword = + <>::EncodingScheme as EncodingScheme>::encode( + &pp.encoding_params, + &coeffs, + ); + + // If using repetition code as basecode, it may be faster to use the following line of code to create the commitment and comment out the two lines above + // let mut codeword = evaluate_over_foldable_domain(pp.log_rate, coeffs, &pp.table); + + // The sum-check protocol starts from the first variable, but the FRI part + // will eventually produce the evaluation at (alpha_k, ..., alpha_1), so apply + // the bit-reversion to reverse the variable indices of the polynomial. + // In short: store the poly and codeword in big endian + reverse_index_bits_in_place_field_type(&mut coeffs); + reverse_index_bits_in_place_field_type(&mut codeword); + + (coeffs, codeword) + }) + .collect::<(Vec>, Vec>)>(); + }) + }, + ); + } + } +} + +fn bench_encoding_goldilocks_2(c: &mut Criterion) { + bench_encoding(c, false); +} + +fn bench_encoding_base(c: &mut Criterion) { + bench_encoding(c, true); +} + +criterion_group! { + name = bench_basefold; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_encoding_base, bench_encoding_goldilocks_2, +} + +criterion_main!(bench_basefold); diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 29cc51d1d..bb923ebd7 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -23,13 +23,20 @@ use crate::{ validate_input, Error, Evaluation, NoninteractivePCS, PolynomialCommitmentScheme, }; use ark_std::{end_timer, start_timer}; +pub use encoding::{ + Basecode, BasecodeDefaultSpec, EncodingProverParameters, EncodingScheme, RSCode, + RSCodeDefaultSpec, +}; use ff_ext::ExtensionField; use multilinear_extensions::mle::MultilinearExtension; use query_phase::{ - batch_query_phase, batch_verifier_query_phase, query_phase, verifier_query_phase, + batch_prover_query_phase, batch_verifier_query_phase, prover_query_phase, + simple_batch_prover_query_phase, simple_batch_verifier_query_phase, verifier_query_phase, BatchedQueriesResultWithMerklePath, QueriesResultWithMerklePath, + SimpleBatchQueriesResultWithMerklePath, }; use std::{borrow::BorrowMut, ops::Deref}; +pub use structure::BasefoldSpec; use itertools::Itertools; use serde::{de::DeserializeOwned, Serialize}; @@ -39,102 +46,67 @@ use multilinear_extensions::{ virtual_poly::build_eq_x_r_vec, }; -use rand_chacha::ChaCha8Rng; -use rayon::prelude::{IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator}; +use rand_chacha::{rand_core::RngCore, ChaCha8Rng}; +use rayon::{ + iter::IntoParallelIterator, + prelude::{IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator}, +}; use std::borrow::Cow; + type SumCheck = ClassicSumCheck>; -mod basecode; mod structure; -use basecode::{ - encode_field_type_rs_basecode, evaluate_over_foldable_domain_generic_basecode, get_table_aes, -}; pub use structure::{ - Basefold, BasefoldCommitment, BasefoldCommitmentWithData, BasefoldDefault, - BasefoldDefaultParams, BasefoldExtParams, BasefoldParams, BasefoldProverParams, + Basefold, BasefoldBasecodeParams, BasefoldCommitment, BasefoldCommitmentWithData, + BasefoldDefault, BasefoldParams, BasefoldProverParams, BasefoldRSParams, BasefoldVerifierParams, }; mod commit_phase; -use commit_phase::{batch_commit_phase, commit_phase}; +use commit_phase::{batch_commit_phase, commit_phase, simple_batch_commit_phase}; +mod encoding; +pub use encoding::{coset_fft, fft, fft_root_table}; mod query_phase; // This sumcheck module is different from the mpcs::sumcheck module, in that // it deals only with the special case of the form \sum eq(r_i)f_i(). mod sumcheck; -impl PolynomialCommitmentScheme for Basefold +impl, Rng: RngCore> Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, { - type Param = BasefoldParams; - type ProverParam = BasefoldProverParams; - type VerifierParam = BasefoldVerifierParams; - type CommitmentWithData = BasefoldCommitmentWithData; - type Commitment = BasefoldCommitment; - type CommitmentChunk = Digest; - type Rng = ChaCha8Rng; - - fn setup(poly_size: usize, rng: &Self::Rng) -> Result { - let log_rate = V::get_rate(); - let (table_w_weights, table) = get_table_aes::(poly_size, log_rate, &mut rng.clone()); - - Ok(BasefoldParams { - log_rate, - num_verifier_queries: V::get_reps(), - max_num_vars: log2_strict(poly_size), - table_w_weights, - table, - rng: rng.clone(), - }) - } - - fn trim(param: &Self::Param) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { - Ok(( - BasefoldProverParams { - log_rate: param.log_rate, - table_w_weights: param.table_w_weights.clone(), - table: param.table.clone(), - num_verifier_queries: param.num_verifier_queries, - max_num_vars: param.max_num_vars, - }, - BasefoldVerifierParams { - rng: param.rng.clone(), - max_num_vars: param.max_num_vars, - log_rate: param.log_rate, - num_verifier_queries: param.num_verifier_queries, - }, - )) - } - - fn commit( - pp: &Self::ProverParam, + /// Converts a polynomial to a code word, also returns the evaluations over the boolean hypercube + /// for said polynomial + fn get_poly_bh_evals_and_codeword( + pp: &BasefoldProverParams, poly: &DenseMultilinearExtension, - ) -> Result { - let timer = start_timer!(|| "Basefold::commit"); + ) -> (FieldType, FieldType) { // bh_evals is just a copy of poly.evals(). // Note that this function implicitly assumes that the size of poly.evals() is a // power of two. Otherwise, the function crashes with index out of bound. let mut bh_evals = poly.evaluations.clone(); - let num_vars = log2_strict(bh_evals.len()); - assert!(num_vars <= pp.max_num_vars && num_vars >= V::get_basecode()); + let num_vars = poly.num_vars; + assert!( + num_vars <= pp.encoding_params.get_max_message_size_log(), + "num_vars {} > pp.max_num_vars {}", + num_vars, + pp.encoding_params.get_max_message_size_log() + ); + assert!( + num_vars >= Spec::get_basecode_msg_size_log(), + "num_vars {} < Spec::get_basecode_msg_size_log() {}", + num_vars, + Spec::get_basecode_msg_size_log() + ); // Switch to coefficient form let mut coeffs = bh_evals.clone(); interpolate_field_type_over_boolean_hypercube(&mut coeffs); - // Split the input into chunks of message size, encode each message, and return the codewords - let basecode = - encode_field_type_rs_basecode(&coeffs, 1 << pp.log_rate, 1 << V::get_basecode()); - - // Apply the recursive definition of the BaseFold code to the list of base codewords, - // and produce the final codeword - let mut codeword = evaluate_over_foldable_domain_generic_basecode::( - 1 << V::get_basecode(), - coeffs.len(), - pp.log_rate, - basecode, - &pp.table, - ); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place_field_type(&mut coeffs); + } + let mut codeword = Spec::EncodingScheme::encode(&pp.encoding_params, &coeffs); // If using repetition code as basecode, it may be faster to use the following line of code to create the commitment and comment out the two lines above // let mut codeword = evaluate_over_foldable_domain(pp.log_rate, coeffs, &pp.table); @@ -146,6 +118,107 @@ where reverse_index_bits_in_place_field_type(&mut bh_evals); reverse_index_bits_in_place_field_type(&mut codeword); + (bh_evals, codeword) + } + + /// Transpose a matrix of field elements, generic over the type of field element + pub fn transpose_field_type( + matrix: &[FieldType], + ) -> Result>, Error> { + let transpose_fn = match matrix[0] { + FieldType::Ext(_) => Self::get_column_ext, + FieldType::Base(_) => Self::get_column_base, + FieldType::Unreachable => unreachable!(), + }; + + let len = matrix[0].len(); + (0..len) + .into_par_iter() + .map(|i| (transpose_fn)(matrix, i)) + .collect() + } + + fn get_column_base( + matrix: &[FieldType], + column_index: usize, + ) -> Result, Error> { + Ok(FieldType::Base( + matrix + .par_iter() + .map(|row| match row { + FieldType::Base(content) => Ok(content[column_index]), + _ => Err(Error::InvalidPcsParam( + "expected base field type".to_string(), + )), + }) + .collect::, Error>>()?, + )) + } + + fn get_column_ext(matrix: &[FieldType], column_index: usize) -> Result, Error> { + Ok(FieldType::Ext( + matrix + .par_iter() + .map(|row| match row { + FieldType::Ext(content) => Ok(content[column_index]), + _ => Err(Error::InvalidPcsParam( + "expected ext field type".to_string(), + )), + }) + .collect::, Error>>()?, + )) + } +} + +impl, Rng: RngCore + std::fmt::Debug> + PolynomialCommitmentScheme for Basefold +where + E: Serialize + DeserializeOwned, + E::BaseField: Serialize + DeserializeOwned, +{ + type Param = BasefoldParams; + type ProverParam = BasefoldProverParams; + type VerifierParam = BasefoldVerifierParams; + type CommitmentWithData = BasefoldCommitmentWithData; + type Commitment = BasefoldCommitment; + type CommitmentChunk = Digest; + type Rng = ChaCha8Rng; + + fn setup(poly_size: usize, rng: &Self::Rng) -> Result { + let mut seed = [0u8; 32]; + let mut rng = rng.clone(); + rng.fill_bytes(&mut seed); + let pp = >::setup(log2_strict(poly_size), seed); + + Ok(BasefoldParams { params: pp }) + } + + fn trim( + pp: &Self::Param, + poly_size: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { + >::trim(&pp.params, log2_strict(poly_size)).map( + |(pp, vp)| { + ( + BasefoldProverParams { + encoding_params: pp, + }, + BasefoldVerifierParams { + encoding_params: vp, + }, + ) + }, + ) + } + + fn commit( + pp: &Self::ProverParam, + poly: &DenseMultilinearExtension, + ) -> Result { + let timer = start_timer!(|| "Basefold::commit"); + + let (bh_evals, codeword) = Self::get_poly_bh_evals_and_codeword(pp, poly); + // Compute and store all the layers of the Merkle tree let hasher = new_hasher::(); let codeword_tree = MerkleTree::::from_leaves(codeword, &hasher); @@ -160,44 +233,89 @@ where Ok(Self::CommitmentWithData { codeword_tree, - bh_evals, - num_vars, + polynomials_bh_evals: vec![bh_evals], + num_vars: poly.num_vars, is_base, + num_polys: 1, }) } fn batch_commit_and_write( pp: &Self::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], transcript: &mut impl TranscriptWrite, - ) -> Result, Error> { + ) -> Result { let timer = start_timer!(|| "Basefold::batch_commit_and_write"); - let comms = Self::batch_commit(pp, polys)?; - comms.iter().for_each(|comm| { - transcript.write_commitment(&comm.get_root_as()).unwrap(); - transcript - .write_field_element_base(&u32_to_field::(comm.num_vars as u32)) - .unwrap(); - transcript - .write_field_element_base(&u32_to_field::(comm.is_base as u32)) - .unwrap(); - }); + let comm = Self::batch_commit(pp, polys)?; + transcript.write_commitment(&comm.get_root_as()).unwrap(); + transcript + .write_field_element_base(&u32_to_field::(comm.num_vars as u32)) + .unwrap(); + transcript + .write_field_element_base(&u32_to_field::(comm.is_base as u32)) + .unwrap(); + transcript + .write_field_element_base(&u32_to_field::(comm.num_polys as u32)) + .unwrap(); end_timer!(timer); - Ok(comms) + Ok(comm) } fn batch_commit( pp: &Self::ProverParam, - polys: &Vec>, - ) -> Result, Error> { - let polys_vec: Vec<&DenseMultilinearExtension> = - polys.into_iter().map(|poly| poly).collect(); - polys_vec + polys: &[DenseMultilinearExtension], + ) -> Result { + // assumptions + // 1. there must be at least one polynomial + // 2. all polynomials must exist in the same field type + // 3. all polynomials must have the same number of variables + + if polys.is_empty() { + return Err(Error::InvalidPcsParam( + "cannot batch commit to zero polynomials".to_string(), + )); + } + + for i in 1..polys.len() { + if polys[i].num_vars != polys[0].num_vars { + return Err(Error::InvalidPcsParam( + "cannot batch commit to polynomials with different number of variables" + .to_string(), + )); + } + } + + // convert each polynomial to a code word + let (bh_evals, codewords) = polys .par_iter() - .map(|poly| Self::commit(pp, poly)) - .collect() + .map(|poly| Self::get_poly_bh_evals_and_codeword(pp, poly)) + .collect::<(Vec>, Vec>)>(); + + // transpose the codewords, to group evaluations at the same point + // let leaves = Self::transpose_field_type::(codewords.as_slice())?; + + // build merkle tree from leaves + let hasher = new_hasher::(); + let codeword_tree = MerkleTree::::from_batch_leaves(codewords, &hasher); + + let is_base = match polys[0].evaluations { + FieldType::Ext(_) => false, + FieldType::Base(_) => true, + _ => unreachable!(), + }; + + Ok(Self::CommitmentWithData { + codeword_tree, + polynomials_bh_evals: bh_evals, + num_vars: polys[0].num_vars, + is_base, + num_polys: polys.len(), + }) } + /// Open a single polynomial commitment at one point. If the given + /// commitment with data contains more than one polynomial, this function + /// will panic. fn open( pp: &Self::ProverParam, poly: &DenseMultilinearExtension, @@ -208,22 +326,22 @@ where ) -> Result<(), Error> { let hasher = new_hasher::(); let timer = start_timer!(|| "Basefold::open"); - assert!(comm.num_vars >= V::get_basecode()); - let (trees, oracles) = commit_phase( - &point, - &comm, + assert!(comm.num_vars >= Spec::get_basecode_msg_size_log()); + assert!(comm.num_polys == 1); + let (trees, oracles) = commit_phase::( + &pp.encoding_params, + point, + comm, transcript, poly.num_vars, - poly.num_vars - V::get_basecode(), - &pp.table_w_weights, - pp.log_rate, + poly.num_vars - Spec::get_basecode_msg_size_log(), &hasher, ); let query_timer = start_timer!(|| "Basefold::open::query_phase"); // Each entry in queried_els stores a list of triples (F, F, i) indicating the // position opened at each round and the two values at that round - let queries = query_phase(transcript, &comm, &oracles, pp.num_verifier_queries); + let queries = prover_query_phase(transcript, comm, &oracles, Spec::get_number_queries()); end_timer!(query_timer); let query_timer = start_timer!(|| "Basefold::open::build_query_result"); @@ -241,10 +359,15 @@ where Ok(()) } + /// Open a batch of polynomial commitments at several points. + /// The current version only supports one polynomial per commitment. + /// Because otherwise it is complex to match the polynomials and + /// the commitments, and because currently this high flexibility is + /// not very useful in ceno. fn batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptWrite, @@ -252,9 +375,12 @@ where let hasher = new_hasher::(); let timer = start_timer!(|| "Basefold::batch_open"); let num_vars = polys.iter().map(|poly| poly.num_vars).max().unwrap(); - let comms = comms.into_iter().collect_vec(); let min_num_vars = polys.iter().map(|p| p.num_vars).min().unwrap(); - assert!(min_num_vars >= V::get_basecode()); + assert!(min_num_vars >= Spec::get_basecode_msg_size_log()); + + comms.iter().for_each(|comm| { + assert!(comm.num_polys == 1); + }); if cfg!(feature = "sanity-check") { evals.iter().for_each(|eval| { @@ -265,12 +391,7 @@ where }) } - validate_input( - "batch open", - pp.max_num_vars, - &polys.clone(), - &points.to_vec(), - )?; + validate_input("batch open", pp.get_max_message_size_log(), polys, points)?; let sumcheck_timer = start_timer!(|| "Basefold::batch_open::initial sumcheck"); // evals.len() is the batch size, i.e., how many polynomials are being opened together @@ -337,7 +458,7 @@ where .map(|((scalar, poly), point)| { inner_product( &poly_iter_ext(poly).collect_vec(), - build_eq_x_r_vec(&point).iter(), + build_eq_x_r_vec(point).iter(), ) * scalar * E::from(1 << (num_vars - poly.num_vars)) // When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube @@ -402,7 +523,7 @@ where ); *scalar * evals_from_sum_check - * &eq_xy_eval(point.as_slice(), &challenges[0..point.len()]) + * eq_xy_eval(point.as_slice(), &challenges[0..point.len()]) }) .sum::(); assert_eq!(new_target_sum, desired_sum); @@ -412,25 +533,24 @@ where let point = challenges; - let (trees, oracles) = batch_commit_phase( + let (trees, oracles) = batch_commit_phase::( + &pp.encoding_params, &point, - comms.as_slice(), + comms, transcript, num_vars, - num_vars - V::get_basecode(), - &pp.table_w_weights, - pp.log_rate, + num_vars - Spec::get_basecode_msg_size_log(), coeffs.as_slice(), &hasher, ); let query_timer = start_timer!(|| "Basefold::batch_open query phase"); - let query_result = batch_query_phase( + let query_result = batch_prover_query_phase( transcript, - 1 << (num_vars + pp.log_rate), - comms.as_slice(), + 1 << (num_vars + Spec::get_rate_log()), + comms, &oracles, - pp.num_verifier_queries, + Spec::get_number_queries(), ); end_timer!(query_timer); @@ -439,7 +559,7 @@ where BatchedQueriesResultWithMerklePath::from_batched_query_result( query_result, &trees, - &comms, + comms, ); end_timer!(query_timer); @@ -451,6 +571,84 @@ where Ok(()) } + /// This is a simple version of batch open: + /// 1. Open at one point + /// 2. All the polynomials share the same commitment and have the same + /// number of variables. + /// 3. The point is already a random point generated by a sum-check. + fn simple_batch_open( + pp: &Self::ProverParam, + polys: &[DenseMultilinearExtension], + comm: &Self::CommitmentWithData, + point: &[E], + evals: &[E], + transcript: &mut impl TranscriptWrite, + ) -> Result<(), Error> { + let hasher = new_hasher::(); + let timer = start_timer!(|| "Basefold::batch_open"); + let num_vars = polys[0].num_vars; + + polys + .iter() + .for_each(|poly| assert_eq!(poly.num_vars, num_vars)); + assert!(num_vars >= Spec::get_basecode_msg_size_log()); + assert_eq!(comm.num_polys, polys.len()); + assert_eq!(comm.num_polys, evals.len()); + + if cfg!(feature = "sanity-check") { + evals + .iter() + .zip(polys) + .for_each(|(eval, poly)| assert_eq!(&poly.evaluate(point), eval)) + } + // evals.len() is the batch size, i.e., how many polynomials are being opened together + let batch_size_log = evals.len().next_power_of_two().ilog2() as usize; + let t = transcript.squeeze_challenges(batch_size_log); + + // Use eq(X,t) where t is random to batch the different evaluation queries. + // Note that this is a small polynomial (only batch_size) compared to the polynomials + // to open. + let eq_xt = build_eq_x_r_vec(&t)[..evals.len()].to_vec(); + let _target_sum = inner_product(evals, &eq_xt); + + // Now the verifier has obtained the new target sum, and is able to compute the random + // linear coefficients. + // The remaining tasks for the prover is to prove that + // sum_i coeffs[i] poly_evals[i] is equal to + // the new target sum, where coeffs is computed as follows + let (trees, oracles) = simple_batch_commit_phase::( + &pp.encoding_params, + point, + &eq_xt, + comm, + transcript, + num_vars, + num_vars - Spec::get_basecode_msg_size_log(), + &hasher, + ); + + let query_timer = start_timer!(|| "Basefold::open::query_phase"); + // Each entry in queried_els stores a list of triples (F, F, i) indicating the + // position opened at each round and the two values at that round + let queries = + simple_batch_prover_query_phase(transcript, comm, &oracles, Spec::get_number_queries()); + end_timer!(query_timer); + + let query_timer = start_timer!(|| "Basefold::open::build_query_result"); + + let queries_with_merkle_path = + SimpleBatchQueriesResultWithMerklePath::from_query_result(queries, &trees, comm); + end_timer!(query_timer); + + let query_timer = start_timer!(|| "Basefold::open::write_queries"); + queries_with_merkle_path.write_transcript(transcript); + end_timer!(query_timer); + + end_timer!(timer); + + Ok(()) + } + fn read_commitments( _: &Self::VerifierParam, num_polys: usize, @@ -462,14 +660,15 @@ where let num_vars = base_to_usize::(&transcript.read_field_element_base().unwrap()); let is_base = base_to_usize::(&transcript.read_field_element_base().unwrap()) != 0; - (num_vars, commitment, is_base) + let num_polys = base_to_usize::(&transcript.read_field_element_base().unwrap()); + (num_vars, commitment, is_base, num_polys) }) .collect_vec(); Ok(roots .iter() - .map(|(num_vars, commitment, is_base)| { - BasefoldCommitment::new(commitment.clone(), *num_vars, *is_base) + .map(|(num_vars, commitment, is_base, num_polys)| { + BasefoldCommitment::new(commitment.clone(), *num_vars, *is_base, *num_polys) }) .collect_vec()) } @@ -484,6 +683,7 @@ where transcript.write_commitment(&comm.get_root_as())?; transcript.write_field_element_base(&u32_to_field::(comm.num_vars as u32))?; transcript.write_field_element_base(&u32_to_field::(comm.is_base as u32))?; + transcript.write_field_element_base(&u32_to_field::(comm.num_polys as u32))?; Ok(comm) } @@ -496,15 +696,16 @@ where transcript: &mut impl TranscriptRead, ) -> Result<(), Error> { let timer = start_timer!(|| "Basefold::verify"); - assert!(comm.num_vars().unwrap() >= V::get_basecode()); + assert!(comm.num_vars().unwrap() >= Spec::get_basecode_msg_size_log()); let hasher = new_hasher::(); - let _field_size = 255; let num_vars = point.len(); - let num_rounds = num_vars - V::get_basecode(); + if let Some(comm_num_vars) = comm.num_vars() { + assert_eq!(num_vars, comm_num_vars); + } + let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); - let mut fold_challenges: Vec = Vec::with_capacity(vp.max_num_vars); - let _size = 0; + let mut fold_challenges: Vec = Vec::with_capacity(num_vars); let mut roots = Vec::new(); let mut sumcheck_messages = Vec::with_capacity(num_rounds); let sumcheck_timer = start_timer!(|| "Basefold::verify::interaction"); @@ -519,19 +720,19 @@ where let read_timer = start_timer!(|| "Basefold::verify::read transcript"); let final_message = transcript - .read_field_elements_ext(1 << V::get_basecode()) + .read_field_elements_ext(1 << Spec::get_basecode_msg_size_log()) .unwrap(); let query_challenges = transcript - .squeeze_challenges(vp.num_verifier_queries) + .squeeze_challenges(Spec::get_number_queries()) .iter() - .map(|index| ext_to_usize(index) % (1 << (num_vars + vp.log_rate))) + .map(|index| ext_to_usize(index) % (1 << (num_vars + Spec::get_rate_log()))) .collect_vec(); let read_query_timer = start_timer!(|| "Basefold::verify::read query"); let query_result_with_merkle_path = if comm.is_base() { QueriesResultWithMerklePath::read_transcript_base( transcript, num_rounds, - vp.log_rate, + Spec::get_rate_log(), num_vars, query_challenges.as_slice(), ) @@ -539,7 +740,7 @@ where QueriesResultWithMerklePath::read_transcript_ext( transcript, num_rounds, - vp.log_rate, + Spec::get_rate_log(), num_vars, query_challenges.as_slice(), ) @@ -558,19 +759,18 @@ where let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]); eq.par_iter_mut().for_each(|e| *e *= coeff); - verifier_query_phase( + verifier_query_phase::( + &vp.encoding_params, &query_result_with_merkle_path, &sumcheck_messages, &fold_challenges, num_rounds, num_vars, - vp.log_rate, &final_message, &roots, comm, eq.as_slice(), - vp.rng.clone(), - &eval, + eval, &hasher, ); end_timer!(timer); @@ -580,7 +780,7 @@ where fn batch_verify( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptRead, @@ -589,10 +789,10 @@ where // let key = "RAYON_NUM_THREADS"; // env::set_var(key, "32"); let hasher = new_hasher::(); - let comms = comms.into_iter().collect_vec(); + let comms = comms.iter().collect_vec(); let num_vars = points.iter().map(|point| point.len()).max().unwrap(); - let num_rounds = num_vars - V::get_basecode(); - validate_input("batch verify", vp.max_num_vars, &vec![], &points.to_vec())?; + let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); + validate_input("batch verify", num_vars, &[], points)?; let poly_num_vars = comms.iter().map(|c| c.num_vars().unwrap()).collect_vec(); evals.iter().for_each(|eval| { assert_eq!( @@ -600,7 +800,7 @@ where comms[eval.poly()].num_vars().unwrap() ); }); - assert!(poly_num_vars.iter().min().unwrap() >= &V::get_basecode()); + assert!(poly_num_vars.iter().min().unwrap() >= &Spec::get_basecode_msg_size_log()); let sumcheck_timer = start_timer!(|| "Basefold::batch_verify::initial sumcheck"); let batch_size_log = evals.len().next_power_of_two().ilog2() as usize; @@ -646,13 +846,13 @@ where } } let final_message = transcript - .read_field_elements_ext(1 << V::get_basecode()) + .read_field_elements_ext(1 << Spec::get_basecode_msg_size_log()) .unwrap(); let query_challenges = transcript - .squeeze_challenges(vp.num_verifier_queries) + .squeeze_challenges(Spec::get_number_queries()) .iter() - .map(|index| ext_to_usize(index) % (1 << (num_vars + vp.log_rate))) + .map(|index| ext_to_usize(index) % (1 << (num_vars + Spec::get_rate_log()))) .collect_vec(); let read_query_timer = start_timer!(|| "Basefold::verify::read query"); @@ -663,7 +863,7 @@ where BatchedQueriesResultWithMerklePath::read_transcript_base( transcript, num_rounds, - vp.log_rate, + Spec::get_rate_log(), poly_num_vars.as_slice(), query_challenges.as_slice(), ) @@ -671,7 +871,7 @@ where BatchedQueriesResultWithMerklePath::read_transcript_ext( transcript, num_rounds, - vp.log_rate, + Spec::get_rate_log(), poly_num_vars.as_slice(), query_challenges.as_slice(), ) @@ -692,28 +892,130 @@ where ); eq.par_iter_mut().for_each(|e| *e *= coeff); - batch_verifier_query_phase( + batch_verifier_query_phase::( + &vp.encoding_params, &query_result_with_merkle_path, &sumcheck_messages, &fold_challenges, num_rounds, num_vars, - vp.log_rate, &final_message, &roots, &comms, &coeffs, eq.as_slice(), - vp.rng.clone(), &new_target_sum, &hasher, ); end_timer!(timer); Ok(()) } + + fn simple_batch_verify( + vp: &Self::VerifierParam, + comm: &Self::Commitment, + point: &[E], + evals: &[E], + transcript: &mut impl TranscriptRead, + ) -> Result<(), Error> { + let timer = start_timer!(|| "Basefold::simple batch verify"); + assert!(comm.num_vars().unwrap() >= Spec::get_basecode_msg_size_log()); + let batch_size = evals.len(); + if let Some(num_polys) = comm.num_polys { + assert_eq!(num_polys, batch_size); + } + let hasher = new_hasher::(); + + let num_vars = point.len(); + if let Some(comm_num_vars) = comm.num_vars { + assert_eq!(num_vars, comm_num_vars); + } + let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); + + // evals.len() is the batch size, i.e., how many polynomials are being opened together + let batch_size_log = evals.len().next_power_of_two().ilog2() as usize; + let t = transcript.squeeze_challenges(batch_size_log); + let eq_xt = build_eq_x_r_vec(&t)[..evals.len()].to_vec(); + + let mut fold_challenges: Vec = Vec::with_capacity(num_vars); + let mut roots = Vec::new(); + let mut sumcheck_messages = Vec::with_capacity(num_rounds); + let sumcheck_timer = start_timer!(|| "Basefold::simple_batch_verify::interaction"); + for i in 0..num_rounds { + sumcheck_messages.push(transcript.read_field_elements_ext(3).unwrap()); + fold_challenges.push(transcript.squeeze_challenge()); + if i < num_rounds - 1 { + roots.push(transcript.read_commitment().unwrap()); + } + } + end_timer!(sumcheck_timer); + + let read_timer = start_timer!(|| "Basefold::verify::read transcript"); + let final_message = transcript + .read_field_elements_ext(1 << Spec::get_basecode_msg_size_log()) + .unwrap(); + let query_challenges = transcript + .squeeze_challenges(Spec::get_number_queries()) + .iter() + .map(|index| ext_to_usize(index) % (1 << (num_vars + Spec::get_rate_log()))) + .collect_vec(); + let read_query_timer = start_timer!(|| "Basefold::verify::read query"); + let query_result_with_merkle_path = if comm.is_base() { + SimpleBatchQueriesResultWithMerklePath::read_transcript_base( + transcript, + num_rounds, + Spec::get_rate_log(), + num_vars, + query_challenges.as_slice(), + batch_size, + ) + } else { + SimpleBatchQueriesResultWithMerklePath::read_transcript_ext( + transcript, + num_rounds, + Spec::get_rate_log(), + num_vars, + query_challenges.as_slice(), + batch_size, + ) + }; + end_timer!(read_query_timer); + end_timer!(read_timer); + + // coeff is the eq polynomial evaluated at the last challenge.len() variables + // in reverse order. + let rev_challenges = fold_challenges.clone().into_iter().rev().collect_vec(); + let coeff = eq_xy_eval( + &point[point.len() - fold_challenges.len()..], + &rev_challenges, + ); + // Compute eq as the partially evaluated eq polynomial + let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]); + eq.par_iter_mut().for_each(|e| *e *= coeff); + + simple_batch_verifier_query_phase::( + &vp.encoding_params, + &query_result_with_merkle_path, + &sumcheck_messages, + &fold_challenges, + &eq_xt, + num_rounds, + num_vars, + &final_message, + &roots, + comm, + eq.as_slice(), + evals, + &hasher, + ); + end_timer!(timer); + + Ok(()) + } } -impl NoninteractivePCS for Basefold +impl, Rng: RngCore + std::fmt::Debug> NoninteractivePCS + for Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, @@ -724,45 +1026,128 @@ where mod test { use crate::{ basefold::Basefold, - test_util::{run_batch_commit_open_verify, run_commit_open_verify}, + test_util::{ + run_batch_commit_open_verify, run_commit_open_verify, + run_simple_batch_commit_open_verify, + }, util::transcript::PoseidonTranscript, }; use goldilocks::GoldilocksExt2; + use rand_chacha::ChaCha8Rng; - use super::BasefoldDefaultParams; + use super::{structure::BasefoldBasecodeParams, BasefoldRSParams}; - type PcsGoldilocks = Basefold; + type PcsGoldilocksRSCode = Basefold; + type PcsGoldilocksBaseCode = Basefold; #[test] - fn commit_open_verify_goldilocks_base() { + fn commit_open_verify_goldilocks_basecode_base() { // Challenge is over extension field, poly over the base field - run_commit_open_verify::>( - true, 10, 11, + run_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + PoseidonTranscript, + >(true, 10, 11); + } + + #[test] + fn commit_open_verify_goldilocks_rscode_base() { + // Challenge is over extension field, poly over the base field + run_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + PoseidonTranscript, + >(true, 10, 11); + } + + #[test] + fn commit_open_verify_goldilocks_basecode_2() { + // Both challenge and poly are over extension field + run_commit_open_verify::>( + false, 10, 11, ); } #[test] - fn commit_open_verify_goldilocks_2() { + fn commit_open_verify_goldilocks_rscode_2() { // Both challenge and poly are over extension field - run_commit_open_verify::>( + run_commit_open_verify::>( false, 10, 11, ); } #[test] - fn batch_commit_open_verify_goldilocks_base() { + fn simple_batch_commit_open_verify_goldilocks_basecode_base() { + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + PoseidonTranscript, + >(true, 10, 11, 4); + } + + #[test] + fn simple_batch_commit_open_verify_goldilocks_rscode_base() { + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + PoseidonTranscript, + >(true, 10, 11, 4); + } + + #[test] + fn simple_batch_commit_open_verify_goldilocks_basecode_2() { + // Both challenge and poly are over extension field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + PoseidonTranscript<_>, + >(false, 10, 11, 4); + } + + #[test] + fn simple_batch_commit_open_verify_goldilocks_rscode_2() { + // Both challenge and poly are over extension field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + PoseidonTranscript<_>, + >(false, 10, 11, 4); + } + + #[test] + fn batch_commit_open_verify_goldilocks_basecode_base() { // Both challenge and poly are over base field run_batch_commit_open_verify::< GoldilocksExt2, - PcsGoldilocks, + PcsGoldilocksBaseCode, PoseidonTranscript, >(true, 10, 11); } #[test] - fn batch_commit_open_verify_goldilocks_2() { + fn batch_commit_open_verify_goldilocks_rscode_base() { + // Both challenge and poly are over base field + run_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + PoseidonTranscript, + >(true, 10, 11); + } + + #[test] + fn batch_commit_open_verify_goldilocks_basecode_2() { + // Both challenge and poly are over extension field + run_batch_commit_open_verify::>( + false, 10, 11, + ); + } + + #[test] + fn batch_commit_open_verify_goldilocks_rscode_2() { // Both challenge and poly are over extension field - run_batch_commit_open_verify::>( + run_batch_commit_open_verify::>( false, 10, 11, ); } diff --git a/mpcs/src/basefold/basecode.rs b/mpcs/src/basefold/basecode.rs deleted file mode 100644 index a10679ede..000000000 --- a/mpcs/src/basefold/basecode.rs +++ /dev/null @@ -1,293 +0,0 @@ -use crate::util::{arithmetic::base_from_raw_bytes, log2_strict, num_of_bytes}; -use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; -use ark_std::{end_timer, start_timer}; -use ctr; -use ff::{BatchInverter, Field, PrimeField}; -use ff_ext::ExtensionField; -use generic_array::GenericArray; -use multilinear_extensions::mle::FieldType; -use rayon::prelude::{ParallelIterator, ParallelSlice, ParallelSliceMut}; - -use itertools::Itertools; - -use crate::util::plonky2_util::reverse_index_bits_in_place; -use rand_chacha::rand_core::RngCore; -use rayon::prelude::IntoParallelRefIterator; - -use crate::util::arithmetic::{horner, steps}; - -pub fn encode_field_type_rs_basecode( - poly: &FieldType, - rate: usize, - message_size: usize, -) -> Vec> { - match poly { - FieldType::Ext(poly) => encode_rs_basecode(poly, rate, message_size) - .iter() - .map(|x| FieldType::Ext(x.clone())) - .collect(), - FieldType::Base(poly) => encode_rs_basecode(poly, rate, message_size) - .iter() - .map(|x| FieldType::Base(x.clone())) - .collect(), - _ => panic!("Unsupported field type"), - } -} - -// Split the input into chunks of message size, encode each message, and return the codewords -pub fn encode_rs_basecode( - poly: &Vec, - rate: usize, - message_size: usize, -) -> Vec> { - let timer = start_timer!(|| "Encode basecode"); - // The domain is just counting 1, 2, 3, ... , domain_size - let domain: Vec = steps(F::ONE).take(message_size * rate).collect(); - let res = poly - .par_chunks_exact(message_size) - .map(|chunk| { - let mut target = vec![F::ZERO; message_size * rate]; - // Just Reed-Solomon code, but with the naive domain - target - .iter_mut() - .enumerate() - .for_each(|(i, target)| *target = horner(&chunk[..], &domain[i])); - target - }) - .collect::>>(); - end_timer!(timer); - - res -} - -fn concatenate_field_types(coeffs: &Vec>) -> FieldType { - match coeffs[0] { - FieldType::Ext(_) => { - let res = coeffs - .iter() - .map(|x| match x { - FieldType::Ext(x) => x.iter().map(|x| *x), - _ => unreachable!(), - }) - .flatten() - .collect::>(); - FieldType::Ext(res) - } - FieldType::Base(_) => { - let res = coeffs - .iter() - .map(|x| match x { - FieldType::Base(x) => x.iter().map(|x| *x), - _ => unreachable!(), - }) - .flatten() - .collect::>(); - FieldType::Base(res) - } - _ => unreachable!(), - } -} - -// this function assumes all codewords in base_codeword has equivalent length -pub fn evaluate_over_foldable_domain_generic_basecode( - base_message_length: usize, - num_coeffs: usize, - log_rate: usize, - base_codewords: Vec>, - table: &Vec>, -) -> FieldType { - let timer = start_timer!(|| "evaluate over foldable domain"); - let k = num_coeffs; - let logk = log2_strict(k); - let base_log_k = log2_strict(base_message_length); - // concatenate together all base codewords - // let now = Instant::now(); - let mut coeffs_with_bc = concatenate_field_types(&base_codewords); - // println!("concatenate base codewords {:?}", now.elapsed()); - // iterate over array, replacing even indices with (evals[i] - evals[(i+1)]) - let mut chunk_size = base_codewords[0].len(); // block length of the base code - for i in base_log_k..logk { - // In beginning of each iteration, the current codeword size is 1<> 1); - match coeffs_with_bc { - FieldType::Ext(ref mut coeffs_with_bc) => { - coeffs_with_bc.par_chunks_mut(chunk_size).for_each(|chunk| { - let half_chunk = chunk_size >> 1; - for j in half_chunk..chunk_size { - // Suppose the current codewords are (a, b) - // The new codeword is computed by two halves: - // left = a + t * b - // right = a - t * b - let rhs = chunk[j] * E::from(level[j - half_chunk]); - chunk[j] = chunk[j - half_chunk] - rhs; - chunk[j - half_chunk] = chunk[j - half_chunk] + rhs; - } - }); - } - FieldType::Base(ref mut coeffs_with_bc) => { - coeffs_with_bc.par_chunks_mut(chunk_size).for_each(|chunk| { - let half_chunk = chunk_size >> 1; - for j in half_chunk..chunk_size { - // Suppose the current codewords are (a, b) - // The new codeword is computed by two halves: - // left = a + t * b - // right = a - t * b - let rhs = chunk[j] * level[j - half_chunk]; - chunk[j] = chunk[j - half_chunk] - rhs; - chunk[j - half_chunk] = chunk[j - half_chunk] + rhs; - } - }); - } - _ => unreachable!(), - } - } - end_timer!(timer); - coeffs_with_bc -} - -pub fn get_table_aes( - poly_size: usize, - rate: usize, - rng: &mut Rng, -) -> ( - Vec>, - Vec>, -) { - // The size (logarithmic) of the codeword for the polynomial - let lg_n: usize = rate + log2_strict(poly_size); - - let mut key: [u8; 16] = [0u8; 16]; - let mut iv: [u8; 16] = [0u8; 16]; - rng.fill_bytes(&mut key); - rng.fill_bytes(&mut iv); - - type Aes128Ctr64LE = ctr::Ctr32LE; - - let mut cipher = Aes128Ctr64LE::new( - GenericArray::from_slice(&key[..]), - GenericArray::from_slice(&iv[..]), - ); - - // Allocate the buffer for storing n field elements (the entire codeword) - let bytes = num_of_bytes::(1 << lg_n); - let mut dest: Vec = vec![0u8; bytes]; - cipher.apply_keystream(&mut dest[..]); - - // Now, dest is a vector filled with random data for a field vector of size n - - // Collect the bytes into field elements - let flat_table: Vec = dest - .par_chunks_exact(num_of_bytes::(1)) - .map(|chunk| base_from_raw_bytes::(&chunk.to_vec())) - .collect::>(); - - // Now, flat_table is a field vector of size n, filled with random field elements - assert_eq!(flat_table.len(), 1 << lg_n); - - // Multiply -2 to every element to get the weights. Now weights = { -2x } - let mut weights: Vec = flat_table - .par_iter() - .map(|el| E::BaseField::ZERO - *el - *el) - .collect(); - - // Then invert all the elements. Now weights = { -1/2x } - let mut scratch_space = vec![E::BaseField::ZERO; weights.len()]; - BatchInverter::invert_with_external_scratch(&mut weights, &mut scratch_space); - - // Zip x and -1/2x together. The result is the list { (x, -1/2x) } - // What is this -1/2x? It is used in linear interpolation over the domain (x, -x), which - // involves computing 1/(b-a) where b=-x and a=x, and 1/(b-a) here is exactly -1/2x - let flat_table_w_weights = flat_table - .iter() - .zip(weights) - .map(|(el, w)| (*el, w)) - .collect_vec(); - - // Split the positions from 0 to n-1 into slices of sizes: - // 2, 2, 4, 8, ..., n/2, exactly lg_n number of them - // The weights are (x, -1/2x), the table elements are just x - - let mut unflattened_table_w_weights = vec![Vec::new(); lg_n]; - let mut unflattened_table = vec![Vec::new(); lg_n]; - - let mut level_weights = flat_table_w_weights[0..2].to_vec(); - // Apply the reverse-bits permutation to a vector of size 2, equivalent to just swapping - reverse_index_bits_in_place(&mut level_weights); - unflattened_table_w_weights[0] = level_weights; - - unflattened_table[0] = flat_table[0..2].to_vec(); - for i in 1..lg_n { - unflattened_table[i] = flat_table[(1 << i)..(1 << (i + 1))].to_vec(); - let mut level = flat_table_w_weights[(1 << i)..(1 << (i + 1))].to_vec(); - reverse_index_bits_in_place(&mut level); - unflattened_table_w_weights[i] = level; - } - - return (unflattened_table_w_weights, unflattened_table); -} - -pub fn query_point( - block_length: usize, - eval_index: usize, - level: usize, - mut cipher: &mut ctr::Ctr32LE, -) -> E::BaseField { - let level_index = eval_index % (block_length); - let mut el = - query_root_table_from_rng_aes::(level, level_index % (block_length >> 1), &mut cipher); - - if level_index >= (block_length >> 1) { - el = -E::BaseField::ONE * el; - } - - return el; -} - -pub fn query_root_table_from_rng_aes( - level: usize, - index: usize, - cipher: &mut ctr::Ctr32LE, -) -> E::BaseField { - let mut level_offset: u128 = 1; - for lg_m in 1..=level { - let half_m = 1 << (lg_m - 1); - level_offset += half_m; - } - - let pos = ((level_offset + (index as u128)) - * ((E::BaseField::NUM_BITS as usize).next_power_of_two() as u128)) - .checked_div(8) - .unwrap(); - - cipher.seek(pos); - - let bytes = (E::BaseField::NUM_BITS as usize).next_power_of_two() / 8; - let mut dest: Vec = vec![0u8; bytes]; - cipher.apply_keystream(&mut dest); - - let res = base_from_raw_bytes::(&dest); - - res -} - -#[cfg(test)] -mod tests { - use super::*; - use goldilocks::GoldilocksExt2; - use multilinear_extensions::mle::DenseMultilinearExtension; - - #[test] - fn time_rs_code() { - use rand::rngs::OsRng; - - let poly = DenseMultilinearExtension::random(20, &mut OsRng); - - encode_field_type_rs_basecode::(&poly.evaluations, 2, 64); - } -} diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index 1716edaea..da8b4ad75 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -1,5 +1,6 @@ use super::{ - basecode::encode_rs_basecode, + encoding::EncodingScheme, + structure::BasefoldSpec, sumcheck::{ sum_check_challenge_round, sum_check_first_round, sum_check_first_round_field_type, sum_check_last_round, @@ -29,29 +30,37 @@ use rayon::prelude::{ use super::structure::BasefoldCommitmentWithData; // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) -pub fn commit_phase( +pub fn commit_phase>( + pp: &>::ProverParameters, point: &[E], comm: &BasefoldCommitmentWithData, transcript: &mut impl TranscriptWrite, E>, num_vars: usize, num_rounds: usize, - table_w_weights: &Vec>, - log_rate: usize, hasher: &Hasher, ) -> (Vec>, Vec>) where E::BaseField: Serialize + DeserializeOwned, { let timer = start_timer!(|| "Commit phase"); + #[cfg(feature = "sanity-check")] assert_eq!(point.len(), num_vars); let mut oracles = Vec::with_capacity(num_vars); let mut trees = Vec::with_capacity(num_vars); - let mut running_oracle = field_type_iter_ext(comm.get_codeword()).collect_vec(); - let mut running_evals = comm.bh_evals.clone(); + let mut running_oracle = field_type_iter_ext(&comm.get_codewords()[0]).collect_vec(); + let mut running_evals = comm.polynomials_bh_evals[0].clone(); + + #[cfg(feature = "sanity-check")] + assert_eq!( + running_oracle.len(), + running_evals.len() << Spec::get_rate_log() + ); + #[cfg(feature = "sanity-check")] + assert_eq!(running_evals.len(), 1 << num_vars); // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube let build_eq_timer = start_timer!(|| "Basefold::open"); - let mut eq = build_eq_x_r_vec(&point); + let mut eq = build_eq_x_r_vec(point); end_timer!(build_eq_timer); reverse_index_bits_in_place(&mut eq); @@ -59,6 +68,9 @@ where let mut last_sumcheck_message = sum_check_first_round_field_type(&mut eq, &mut running_evals); end_timer!(sumcheck_timer); + #[cfg(feature = "sanity-check")] + assert_eq!(last_sumcheck_message.len(), 3); + let mut running_evals = match running_evals { FieldType::Ext(evals) => evals, FieldType::Base(evals) => evals.iter().map(|x| E::from(*x)).collect_vec(), @@ -77,8 +89,8 @@ where let challenge = transcript.squeeze_challenge(); // Fold the current oracle for FRI - running_oracle = basefold_one_round_by_interpolation_weights::( - &table_w_weights, + running_oracle = basefold_one_round_by_interpolation_weights::( + pp, log2_strict(running_oracle.len()) - 1, &running_oracle, challenge, @@ -111,9 +123,17 @@ where let mut coeffs = running_evals.clone(); interpolate_over_boolean_hypercube(&mut coeffs); - let basecode = encode_rs_basecode(&coeffs, 1 << log_rate, coeffs.len()); - assert_eq!(basecode.len(), 1); - let basecode = basecode[0].clone(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut coeffs); + } + let basecode = >::encode( + pp, + &FieldType::Ext(coeffs), + ); + let basecode = match basecode { + FieldType::Ext(b) => b, + _ => panic!("Should be ext field"), + }; reverse_index_bits_in_place(&mut running_oracle); assert_eq!(basecode, running_oracle); @@ -122,18 +142,18 @@ where end_timer!(sumcheck_timer); } end_timer!(timer); - return (trees, oracles); + (trees, oracles) } // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) -pub fn batch_commit_phase( +#[allow(clippy::too_many_arguments)] +pub fn batch_commit_phase>( + pp: &>::ProverParameters, point: &[E], - comms: &[&BasefoldCommitmentWithData], + comms: &[BasefoldCommitmentWithData], transcript: &mut impl TranscriptWrite, E>, num_vars: usize, num_rounds: usize, - table_w_weights: &Vec>, - log_rate: usize, coeffs: &[E], hasher: &Hasher, ) -> (Vec>, Vec>) @@ -144,7 +164,7 @@ where assert_eq!(point.len(), num_vars); let mut oracles = Vec::with_capacity(num_vars); let mut trees = Vec::with_capacity(num_vars); - let mut running_oracle = vec![E::ZERO; 1 << (num_vars + log_rate)]; + let mut running_oracle = vec![E::ZERO; 1 << (num_vars + Spec::get_rate_log())]; let build_oracle_timer = start_timer!(|| "Basefold build initial oracle"); // Before the interaction, collect all the polynomials whose num variables match the @@ -157,8 +177,8 @@ where .for_each(|(index, comm)| { running_oracle .iter_mut() - .zip_eq(field_type_iter_ext(comm.get_codeword())) - .for_each(|(r, a)| *r += E::from(a) * coeffs[index]); + .zip_eq(field_type_iter_ext(&comm.get_codewords()[0])) + .for_each(|(r, a)| *r += a * coeffs[index]); }); end_timer!(build_oracle_timer); @@ -177,16 +197,16 @@ where // to align the polynomials to the variable with index 0 before adding them // together. So each element is repeated by // sum_of_all_evals_for_sumcheck.len() / bh_evals.len() times - *r += E::from(field_type_index_ext( - &comm.bh_evals, - pos >> (num_vars - log2_strict(comm.bh_evals.len())), - )) * coeffs[index] + *r += field_type_index_ext( + &comm.polynomials_bh_evals[0], + pos >> (num_vars - log2_strict(comm.polynomials_bh_evals[0].len())), + ) * coeffs[index] }); }); end_timer!(build_oracle_timer); // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube - let mut eq = build_eq_x_r_vec(&point); + let mut eq = build_eq_x_r_vec(point); reverse_index_bits_in_place(&mut eq); let sumcheck_timer = start_timer!(|| "Basefold first round"); @@ -208,8 +228,8 @@ where let challenge = transcript.squeeze_challenge(); // Fold the current oracle for FRI - running_oracle = basefold_one_round_by_interpolation_weights::( - &table_w_weights, + running_oracle = basefold_one_round_by_interpolation_weights::( + pp, log2_strict(running_oracle.len()) - 1, &running_oracle, challenge, @@ -236,8 +256,8 @@ where .for_each(|(index, comm)| { running_oracle .iter_mut() - .zip_eq(field_type_iter_ext(comm.get_codeword())) - .for_each(|(r, a)| *r += E::from(a) * coeffs[index]); + .zip_eq(field_type_iter_ext(&comm.get_codewords()[0])) + .for_each(|(r, a)| *r += a * coeffs[index]); }); } else { // The difference of the last round is that we don't need to compute the message, @@ -257,10 +277,127 @@ where // on the prover side should be exactly the encoding of the folded polynomial. let mut coeffs = sum_of_all_evals_for_sumcheck.clone(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut coeffs); + } + interpolate_over_boolean_hypercube(&mut coeffs); + let basecode = >::encode( + pp, + &FieldType::Ext(coeffs), + ); + let basecode = match basecode { + FieldType::Ext(x) => x, + _ => panic!("Expected ext field"), + }; + + reverse_index_bits_in_place(&mut running_oracle); + assert_eq!(basecode, running_oracle); + } + } + end_timer!(sumcheck_timer); + } + end_timer!(timer); + (trees, oracles) +} + +// outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) +#[allow(clippy::too_many_arguments)] +pub fn simple_batch_commit_phase>( + pp: &>::ProverParameters, + point: &[E], + batch_coeffs: &[E], + comm: &BasefoldCommitmentWithData, + transcript: &mut impl TranscriptWrite, E>, + num_vars: usize, + num_rounds: usize, + hasher: &Hasher, +) -> (Vec>, Vec>) +where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Simple batch commit phase"); + assert_eq!(point.len(), num_vars); + assert_eq!(comm.num_polys, batch_coeffs.len()); + let mut oracles = Vec::with_capacity(num_vars); + let mut trees = Vec::with_capacity(num_vars); + let mut running_oracle = comm.batch_codewords(&batch_coeffs.to_vec()); + let mut running_evals = (0..(1 << num_vars)) + .map(|i| { + comm.polynomials_bh_evals + .iter() + .zip(batch_coeffs) + .map(|(eval, coeff)| field_type_index_ext(eval, i) * *coeff) + .sum() + }) + .collect_vec(); + + // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube + let build_eq_timer = start_timer!(|| "Basefold::open"); + let mut eq = build_eq_x_r_vec(point); + end_timer!(build_eq_timer); + reverse_index_bits_in_place(&mut eq); + + let sumcheck_timer = start_timer!(|| "Basefold sumcheck first round"); + let mut last_sumcheck_message = sum_check_first_round(&mut eq, &mut running_evals); + end_timer!(sumcheck_timer); + + for i in 0..num_rounds { + let sumcheck_timer = start_timer!(|| format!("Basefold round {}", i)); + // For the first round, no need to send the running root, because this root is + // committing to a vector that can be recovered from linearly combining other + // already-committed vectors. + transcript + .write_field_elements_ext(&last_sumcheck_message) + .unwrap(); + + let challenge = transcript.squeeze_challenge(); + + // Fold the current oracle for FRI + running_oracle = basefold_one_round_by_interpolation_weights::( + pp, + log2_strict(running_oracle.len()) - 1, + &running_oracle, + challenge, + ); + + if i < num_rounds - 1 { + last_sumcheck_message = + sum_check_challenge_round(&mut eq, &mut running_evals, challenge); + let running_tree = + MerkleTree::::from_leaves(FieldType::Ext(running_oracle.clone()), hasher); + let running_root = running_tree.root(); + transcript.write_commitment(&running_root).unwrap(); + + oracles.push(running_oracle.clone()); + trees.push(running_tree); + } else { + // The difference of the last round is that we don't need to compute the message, + // and we don't interpolate the small polynomials. So after the last round, + // running_evals is exactly the evaluation representation of the + // folded polynomial so far. + sum_check_last_round(&mut eq, &mut running_evals, challenge); + // For the FRI part, we send the current polynomial as the message. + // Transform it back into little endiean before sending it + reverse_index_bits_in_place(&mut running_evals); + transcript.write_field_elements_ext(&running_evals).unwrap(); + + if cfg!(feature = "sanity-check") { + // If the prover is honest, in the last round, the running oracle + // on the prover side should be exactly the encoding of the folded polynomial. + + let mut coeffs = running_evals.clone(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut coeffs); + } interpolate_over_boolean_hypercube(&mut coeffs); - let basecode = encode_rs_basecode(&coeffs, 1 << log_rate, coeffs.len()); - assert_eq!(basecode.len(), 1); - let basecode = basecode[0].clone(); + let basecode = >::encode( + pp, + &FieldType::Ext(coeffs), + ); + let basecode = match basecode { + FieldType::Ext(basecode) => basecode, + _ => panic!("Should be ext field"), + }; reverse_index_bits_in_place(&mut running_oracle); assert_eq!(basecode, running_oracle); @@ -269,28 +406,22 @@ where end_timer!(sumcheck_timer); } end_timer!(timer); - return (trees, oracles); + (trees, oracles) } -fn basefold_one_round_by_interpolation_weights( - table: &Vec>, - level_index: usize, +fn basefold_one_round_by_interpolation_weights>( + pp: &>::ProverParameters, + level: usize, values: &Vec, challenge: E, ) -> Vec { - let level = &table[level_index]; values .par_chunks_exact(2) .enumerate() .map(|(i, ys)| { - interpolate2_weights( - [ - (E::from(level[i].0), ys[0]), - (E::from(-(level[i].0)), ys[1]), - ], - E::from(level[i].1), - challenge, - ) + let (x0, x1, w) = + >::prover_folding_coeffs(pp, level, i); + interpolate2_weights([(x0, ys[0]), (x1, ys[1])], w, challenge) }) .collect::>() } diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs new file mode 100644 index 000000000..9fa651e2f --- /dev/null +++ b/mpcs/src/basefold/encoding.rs @@ -0,0 +1,237 @@ +use ff_ext::ExtensionField; +use multilinear_extensions::mle::FieldType; + +mod utils; + +mod basecode; +pub use basecode::{Basecode, BasecodeDefaultSpec}; + +mod rs; +use plonky2::util::log2_strict; +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSlice, +}; +pub use rs::{coset_fft, fft, fft_root_table, RSCode, RSCodeDefaultSpec}; + +use serde::{de::DeserializeOwned, Serialize}; + +use crate::{util::arithmetic::interpolate2_weights, Error}; + +pub trait EncodingProverParameters { + fn get_max_message_size_log(&self) -> usize; +} + +pub trait EncodingScheme: std::fmt::Debug + Clone { + type PublicParameters: Clone + std::fmt::Debug + Serialize + DeserializeOwned; + type ProverParameters: Clone + + std::fmt::Debug + + Serialize + + DeserializeOwned + + EncodingProverParameters + + Sync; + type VerifierParameters: Clone + std::fmt::Debug + Serialize + DeserializeOwned + Sync; + + fn setup(max_msg_size_log: usize, rng_seed: [u8; 32]) -> Self::PublicParameters; + + fn trim( + pp: &Self::PublicParameters, + max_msg_size_log: usize, + ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error>; + + fn encode(pp: &Self::ProverParameters, coeffs: &FieldType) -> FieldType; + + /// Encodes a message of small length, such that the verifier is also able + /// to execute the encoding. + fn encode_small(vp: &Self::VerifierParameters, coeffs: &FieldType) -> FieldType; + + fn get_number_queries() -> usize; + + fn get_rate_log() -> usize; + + fn get_basecode_msg_size_log() -> usize; + + /// Whether the message needs to be bit-reversed to allow even-odd + /// folding. If the folding is already even-odd style (like RS code), + /// then set this function to return false. If the folding is originally + /// left-right, like basefold, then return true. + fn message_is_left_and_right_folding() -> bool; + + fn message_is_even_and_odd_folding() -> bool { + !Self::message_is_left_and_right_folding() + } + + /// Returns three values: x0, x1 and 1/(x1-x0). Note that although + /// 1/(x1-x0) can be computed from the other two values, we return it + /// separately because inversion is expensive. + /// These three values can be used to interpolate a linear function + /// that passes through the two points (x0, y0) and (x1, y1), for the + /// given y0 and y1, then compute the value of the linear function at + /// any give x. + /// Params: + /// - level: which particular code in this family of codes? + /// - index: position in the codeword (after folded) + fn prover_folding_coeffs(pp: &Self::ProverParameters, level: usize, index: usize) -> (E, E, E); + + /// The same as `prover_folding_coeffs`, but for the verifier. The two + /// functions, although provide the same functionality, may use different + /// implementations. For example, prover can use precomputed values stored + /// in the parameters, but the verifier may need to recompute them. + fn verifier_folding_coeffs( + vp: &Self::VerifierParameters, + level: usize, + index: usize, + ) -> (E, E, E); + + /// Fold the given codeword into a smaller codeword of half size, using + /// the folding coefficients computed by `prover_folding_coeffs`. + /// The given codeword is assumed to be bit-reversed on the original + /// codeword directly produced from the `encode` method. + fn fold_bitreversed_codeword( + pp: &Self::ProverParameters, + codeword: &FieldType, + challenge: E, + ) -> Vec { + let level = log2_strict(codeword.len()) - 1; + match codeword { + FieldType::Ext(codeword) => codeword + .par_chunks_exact(2) + .enumerate() + .map(|(i, ys)| { + let (x0, x1, w) = Self::prover_folding_coeffs(pp, level, i); + interpolate2_weights([(x0, ys[0]), (x1, ys[1])], w, challenge) + }) + .collect::>(), + FieldType::Base(codeword) => codeword + .par_chunks_exact(2) + .enumerate() + .map(|(i, ys)| { + let (x0, x1, w) = Self::prover_folding_coeffs(pp, level, i); + interpolate2_weights([(x0, E::from(ys[0])), (x1, E::from(ys[1]))], w, challenge) + }) + .collect::>(), + _ => panic!("Unsupported field type"), + } + } + + /// Fold the given message into a smaller message of half size using challenge + /// as the random linear combination coefficient. + /// Note that this is always even-odd fold, assuming the message has + /// been bit-reversed (or not) according to the setting + /// of the `message_need_bit_reversion` function. + fn fold_message(msg: &FieldType, challenge: E) -> Vec { + match msg { + FieldType::Ext(msg) => msg + .par_chunks_exact(2) + .map(|ys| ys[0] + ys[1] * challenge) + .collect::>(), + FieldType::Base(msg) => msg + .par_chunks_exact(2) + .map(|ys| E::from(ys[0]) + E::from(ys[1]) * challenge) + .collect::>(), + _ => panic!("Unsupported field type"), + } + } +} + +fn concatenate_field_types(coeffs: &[FieldType]) -> FieldType { + match coeffs[0] { + FieldType::Ext(_) => { + let res = coeffs + .iter() + .flat_map(|x| match x { + FieldType::Ext(x) => x.iter().copied(), + _ => unreachable!(), + }) + .collect::>(); + FieldType::Ext(res) + } + FieldType::Base(_) => { + let res = coeffs + .iter() + .flat_map(|x| match x { + FieldType::Base(x) => x.iter().copied(), + _ => unreachable!(), + }) + .collect::>(); + FieldType::Base(res) + } + _ => unreachable!(), + } +} + +#[cfg(test)] +pub(crate) mod test_util { + use ff_ext::ExtensionField; + use multilinear_extensions::mle::FieldType; + use rand::rngs::OsRng; + + use crate::util::plonky2_util::reverse_index_bits_in_place_field_type; + + use super::EncodingScheme; + + pub fn test_codeword_folding>() { + let num_vars = 12; + + let poly: Vec = (0..(1 << num_vars)).map(|i| E::from(i)).collect(); + let mut poly = FieldType::Ext(poly); + + let rng_seed = [0; 32]; + let pp: Code::PublicParameters = Code::setup(num_vars, rng_seed); + let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let mut codeword = Code::encode(&pp, &poly); + reverse_index_bits_in_place_field_type(&mut codeword); + if Code::message_is_left_and_right_folding() { + reverse_index_bits_in_place_field_type(&mut poly); + } + let challenge = E::random(&mut OsRng); + let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); + let mut folded_message = FieldType::Ext(Code::fold_message(&poly, challenge)); + if Code::message_is_left_and_right_folding() { + // Reverse the message back before encoding if it has been + // bit-reversed + reverse_index_bits_in_place_field_type(&mut folded_message); + } + let mut encoded_folded_message = Code::encode(&pp, &folded_message); + reverse_index_bits_in_place_field_type(&mut encoded_folded_message); + let encoded_folded_message = match encoded_folded_message { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + for (i, (a, b)) in folded_codeword + .iter() + .zip(encoded_folded_message.iter()) + .enumerate() + { + assert_eq!(a, b, "Failed at index {}", i); + } + + let mut folded_codeword = FieldType::Ext(folded_codeword); + for round in 0..4 { + let folded_codeword_vec = + Code::fold_bitreversed_codeword(&pp, &folded_codeword, challenge); + + if Code::message_is_left_and_right_folding() { + reverse_index_bits_in_place_field_type(&mut folded_message); + } + folded_message = FieldType::Ext(Code::fold_message(&folded_message, challenge)); + if Code::message_is_left_and_right_folding() { + reverse_index_bits_in_place_field_type(&mut folded_message); + } + let mut encoded_folded_message = Code::encode(&pp, &folded_message); + reverse_index_bits_in_place_field_type(&mut encoded_folded_message); + let encoded_folded_message = match encoded_folded_message { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + for (i, (a, b)) in folded_codeword_vec + .iter() + .zip(encoded_folded_message.iter()) + .enumerate() + { + assert_eq!(a, b, "Failed at index {} in round {}", i, round); + } + folded_codeword = FieldType::Ext(folded_codeword_vec); + } + } +} diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs new file mode 100644 index 000000000..051a510e6 --- /dev/null +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -0,0 +1,451 @@ +use std::marker::PhantomData; + +use super::{concatenate_field_types, EncodingProverParameters, EncodingScheme}; +use crate::{ + util::{ + arithmetic::base_from_raw_bytes, log2_strict, num_of_bytes, plonky2_util::reverse_bits, + }, + vec_mut, Error, +}; +use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; +use ark_std::{end_timer, start_timer}; +use ff::{BatchInverter, Field, PrimeField}; +use ff_ext::ExtensionField; +use generic_array::GenericArray; +use multilinear_extensions::mle::FieldType; +use rand::SeedableRng; +use rayon::prelude::{ParallelIterator, ParallelSlice, ParallelSliceMut}; + +use itertools::Itertools; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use crate::util::plonky2_util::reverse_index_bits_in_place; +use rand_chacha::{rand_core::RngCore, ChaCha8Rng}; +use rayon::prelude::IntoParallelRefIterator; + +use crate::util::arithmetic::{horner, steps}; + +pub trait BasecodeSpec: std::fmt::Debug + Clone { + fn get_number_queries() -> usize; + + fn get_rate_log() -> usize; + + fn get_basecode_msg_size_log() -> usize; +} + +#[derive(Debug, Clone)] +pub struct BasecodeDefaultSpec {} + +impl BasecodeSpec for BasecodeDefaultSpec { + fn get_number_queries() -> usize { + 766 + } + + fn get_rate_log() -> usize { + 3 + } + + fn get_basecode_msg_size_log() -> usize { + 7 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasecodeParameters { + pub(crate) table: Vec>, + pub(crate) table_w_weights: Vec>, + pub(crate) rng_seed: [u8; 32], +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasecodeProverParameters { + pub(crate) table: Vec>, + pub(crate) table_w_weights: Vec>, + pub(crate) rng_seed: [u8; 32], + #[serde(skip)] + _phantom: PhantomData Spec>, +} + +impl EncodingProverParameters + for BasecodeProverParameters +{ + fn get_max_message_size_log(&self) -> usize { + self.table.len() - Spec::get_rate_log() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BasecodeVerifierParameters { + pub(crate) rng_seed: [u8; 32], + pub(crate) aes_key: [u8; 16], + pub(crate) aes_iv: [u8; 16], +} + +#[derive(Debug, Clone)] +pub struct Basecode { + _phantom_data: PhantomData, +} + +impl EncodingScheme for Basecode +where + E::BaseField: Serialize + DeserializeOwned, +{ + type PublicParameters = BasecodeParameters; + + type ProverParameters = BasecodeProverParameters; + + type VerifierParameters = BasecodeVerifierParameters; + + fn setup(max_msg_size_log: usize, rng_seed: [u8; 32]) -> Self::PublicParameters { + let rng = ChaCha8Rng::from_seed(rng_seed); + let (table_w_weights, table) = + get_table_aes::(max_msg_size_log, Spec::get_rate_log(), &mut rng.clone()); + BasecodeParameters { + table, + table_w_weights, + rng_seed, + } + } + + fn trim( + pp: &Self::PublicParameters, + max_msg_size_log: usize, + ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { + if pp.table.len() < Spec::get_rate_log() + max_msg_size_log { + return Err(Error::InvalidPcsParam(format!( + "Public parameter is setup for a smaller message size (log={}) than the trimmed message size (log={})", + pp.table.len() - Spec::get_rate_log(), + max_msg_size_log, + ))); + } + let mut key: [u8; 16] = [0u8; 16]; + let mut iv: [u8; 16] = [0u8; 16]; + let mut rng = ChaCha8Rng::from_seed(pp.rng_seed); + rng.set_word_pos(0); + rng.fill_bytes(&mut key); + rng.fill_bytes(&mut iv); + Ok(( + Self::ProverParameters { + table_w_weights: pp.table_w_weights.clone(), + table: pp.table.clone(), + rng_seed: pp.rng_seed, + _phantom: PhantomData, + }, + Self::VerifierParameters { + rng_seed: pp.rng_seed, + aes_key: key, + aes_iv: iv, + }, + )) + } + + fn encode(pp: &Self::ProverParameters, coeffs: &FieldType) -> FieldType { + // Split the input into chunks of message size, encode each message, and return the codewords + let basecode = encode_field_type_rs_basecode( + coeffs, + 1 << Spec::get_rate_log(), + 1 << Spec::get_basecode_msg_size_log(), + ); + + // Apply the recursive definition of the BaseFold code to the list of base codewords, + // and produce the final codeword + evaluate_over_foldable_domain_generic_basecode::( + 1 << Spec::get_basecode_msg_size_log(), + coeffs.len(), + Spec::get_rate_log(), + basecode, + &pp.table, + ) + } + + fn encode_small(_vp: &Self::VerifierParameters, coeffs: &FieldType) -> FieldType { + let mut basecodes = + encode_field_type_rs_basecode(coeffs, 1 << Spec::get_rate_log(), coeffs.len()); + assert_eq!(basecodes.len(), 1); + basecodes.remove(0) + } + + fn get_number_queries() -> usize { + Spec::get_number_queries() + } + + fn get_rate_log() -> usize { + Spec::get_rate_log() + } + + fn get_basecode_msg_size_log() -> usize { + Spec::get_basecode_msg_size_log() + } + + fn message_is_left_and_right_folding() -> bool { + true + } + + fn prover_folding_coeffs(pp: &Self::ProverParameters, level: usize, index: usize) -> (E, E, E) { + let level = &pp.table_w_weights[level]; + ( + E::from(level[index].0), + E::from(-level[index].0), + E::from(level[index].1), + ) + } + + fn verifier_folding_coeffs( + vp: &Self::VerifierParameters, + level: usize, + index: usize, + ) -> (E, E, E) { + type Aes128Ctr64LE = ctr::Ctr32LE; + let mut cipher = Aes128Ctr64LE::new( + GenericArray::from_slice(&vp.aes_key[..]), + GenericArray::from_slice(&vp.aes_iv[..]), + ); + + let x0: E::BaseField = query_root_table_from_rng_aes::(level, index, &mut cipher); + let x1 = -x0; + + let w = (x1 - x0).invert().unwrap(); + + (E::from(x0), E::from(x1), E::from(w)) + } +} + +fn encode_field_type_rs_basecode( + poly: &FieldType, + rate: usize, + message_size: usize, +) -> Vec> { + match poly { + FieldType::Ext(poly) => get_basecode(poly, rate, message_size) + .iter() + .map(|x| FieldType::Ext(x.clone())) + .collect(), + FieldType::Base(poly) => get_basecode(poly, rate, message_size) + .iter() + .map(|x| FieldType::Base(x.clone())) + .collect(), + _ => panic!("Unsupported field type"), + } +} + +// Split the input into chunks of message size, encode each message, and return the codewords +// FIXME: It is expensive for now because it is using naive FFT (although it is +// over a small domain) +fn get_basecode(poly: &Vec, rate: usize, message_size: usize) -> Vec> { + let timer = start_timer!(|| "Encode basecode"); + // The domain is just counting 1, 2, 3, ... , domain_size + let domain: Vec = steps(F::ONE).take(message_size * rate).collect(); + let res = poly + .par_chunks_exact(message_size) + .map(|chunk| { + let mut target = vec![F::ZERO; message_size * rate]; + // Just Reed-Solomon code, but with the naive domain + target + .iter_mut() + .enumerate() + .for_each(|(i, target)| *target = horner(chunk, &domain[i])); + target + }) + .collect::>>(); + end_timer!(timer); + + res +} + +// this function assumes all codewords in base_codeword has equivalent length +pub fn evaluate_over_foldable_domain_generic_basecode( + base_message_length: usize, + num_coeffs: usize, + log_rate: usize, + base_codewords: Vec>, + table: &[Vec], +) -> FieldType { + let timer = start_timer!(|| "evaluate over foldable domain"); + let k = num_coeffs; + let logk = log2_strict(k); + let base_log_k = log2_strict(base_message_length); + // concatenate together all base codewords + // let now = Instant::now(); + let mut coeffs_with_bc = concatenate_field_types(&base_codewords); + // println!("concatenate base codewords {:?}", now.elapsed()); + // iterate over array, replacing even indices with (evals[i] - evals[(i+1)]) + let mut chunk_size = base_codewords[0].len(); // block length of the base code + for i in base_log_k..logk { + // In beginning of each iteration, the current codeword size is 1<> 1); + vec_mut!(coeffs_with_bc, |c| { + c.par_chunks_mut(chunk_size).for_each(|chunk| { + let half_chunk = chunk_size >> 1; + for j in half_chunk..chunk_size { + // Suppose the current codewords are (a, b) + // The new codeword is computed by two halves: + // left = a + t * b + // right = a - t * b + let rhs = chunk[j] * level[j - half_chunk]; + chunk[j] = chunk[j - half_chunk] - rhs; + chunk[j - half_chunk] += rhs; + } + }); + }); + } + end_timer!(timer); + coeffs_with_bc +} + +#[allow(clippy::type_complexity)] +pub fn get_table_aes( + poly_size_log: usize, + rate: usize, + rng: &mut Rng, +) -> ( + Vec>, + Vec>, +) { + // The size (logarithmic) of the codeword for the polynomial + let lg_n: usize = rate + poly_size_log; + + let mut key: [u8; 16] = [0u8; 16]; + let mut iv: [u8; 16] = [0u8; 16]; + rng.fill_bytes(&mut key); + rng.fill_bytes(&mut iv); + + type Aes128Ctr64LE = ctr::Ctr32LE; + + let mut cipher = Aes128Ctr64LE::new( + GenericArray::from_slice(&key[..]), + GenericArray::from_slice(&iv[..]), + ); + + // Allocate the buffer for storing n field elements (the entire codeword) + let bytes = num_of_bytes::(1 << lg_n); + let mut dest: Vec = vec![0u8; bytes]; + cipher.apply_keystream(&mut dest[..]); + + // Now, dest is a vector filled with random data for a field vector of size n + + // Collect the bytes into field elements + let flat_table: Vec = dest + .par_chunks_exact(num_of_bytes::(1)) + .map(|chunk| base_from_raw_bytes::(chunk)) + .collect::>(); + + // Now, flat_table is a field vector of size n, filled with random field elements + assert_eq!(flat_table.len(), 1 << lg_n); + + // Multiply -2 to every element to get the weights. Now weights = { -2x } + let mut weights: Vec = flat_table + .par_iter() + .map(|el| E::BaseField::ZERO - *el - *el) + .collect(); + + // Then invert all the elements. Now weights = { -1/2x } + let mut scratch_space = vec![E::BaseField::ZERO; weights.len()]; + BatchInverter::invert_with_external_scratch(&mut weights, &mut scratch_space); + + // Zip x and -1/2x together. The result is the list { (x, -1/2x) } + // What is this -1/2x? It is used in linear interpolation over the domain (x, -x), which + // involves computing 1/(b-a) where b=-x and a=x, and 1/(b-a) here is exactly -1/2x + let flat_table_w_weights = flat_table + .iter() + .zip(weights) + .map(|(el, w)| (*el, w)) + .collect_vec(); + + // Split the positions from 0 to n-1 into slices of sizes: + // 2, 2, 4, 8, ..., n/2, exactly lg_n number of them + // The weights are (x, -1/2x), the table elements are just x + + let mut unflattened_table_w_weights = vec![Vec::new(); lg_n]; + let mut unflattened_table = vec![Vec::new(); lg_n]; + + unflattened_table_w_weights[0] = flat_table_w_weights[1..2].to_vec(); + unflattened_table[0] = flat_table[1..2].to_vec(); + for i in 1..lg_n { + unflattened_table[i] = flat_table[(1 << i)..(1 << (i + 1))].to_vec(); + let mut level = flat_table_w_weights[(1 << i)..(1 << (i + 1))].to_vec(); + reverse_index_bits_in_place(&mut level); + unflattened_table_w_weights[i] = level; + } + + (unflattened_table_w_weights, unflattened_table) +} + +pub fn query_root_table_from_rng_aes( + level: usize, + index: usize, + cipher: &mut ctr::Ctr32LE, +) -> E::BaseField { + let mut level_offset: u128 = 1; + for lg_m in 1..=level { + let half_m = 1 << (lg_m - 1); + level_offset += half_m; + } + + let pos = ((level_offset + (reverse_bits(index, level) as u128)) + * ((E::BaseField::NUM_BITS as usize).next_power_of_two() as u128)) + .checked_div(8) + .unwrap(); + + cipher.seek(pos); + + let bytes = (E::BaseField::NUM_BITS as usize).next_power_of_two() / 8; + let mut dest: Vec = vec![0u8; bytes]; + cipher.apply_keystream(&mut dest); + + base_from_raw_bytes::(&dest) +} + +#[cfg(test)] +mod tests { + use crate::basefold::encoding::test_util::test_codeword_folding; + + use super::*; + use goldilocks::GoldilocksExt2; + use multilinear_extensions::mle::DenseMultilinearExtension; + + #[test] + fn time_rs_code() { + use rand::rngs::OsRng; + + let poly = DenseMultilinearExtension::random(20, &mut OsRng); + + encode_field_type_rs_basecode::(&poly.evaluations, 2, 64); + } + + #[test] + fn prover_verifier_consistency() { + type Code = Basecode; + let pp: BasecodeParameters = Code::setup(10, [0; 32]); + let (pp, vp) = Code::trim(&pp, 10).unwrap(); + for level in 0..(10 + >::get_rate_log()) { + for index in 0..(1 << level) { + assert_eq!( + Code::prover_folding_coeffs(&pp, level, index), + Code::verifier_folding_coeffs(&vp, level, index), + "failed for level = {}, index = {}", + level, + index + ); + } + } + } + + #[test] + fn test_basecode_codeword_folding() { + test_codeword_folding::>(); + } +} diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs new file mode 100644 index 000000000..79c047060 --- /dev/null +++ b/mpcs/src/basefold/encoding/rs.rs @@ -0,0 +1,877 @@ +use std::marker::PhantomData; + +use super::{EncodingProverParameters, EncodingScheme}; +use crate::{ + util::{field_type_index_mul_base, log2_strict, plonky2_util::reverse_bits}, + vec_mut, Error, +}; +use ark_std::{end_timer, start_timer}; +use ff::{Field, PrimeField}; +use ff_ext::ExtensionField; +use multilinear_extensions::mle::FieldType; + +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use crate::util::plonky2_util::reverse_index_bits_in_place; + +use crate::util::arithmetic::horner; + +pub trait RSCodeSpec: std::fmt::Debug + Clone { + fn get_number_queries() -> usize; + + fn get_rate_log() -> usize; + + fn get_basecode_msg_size_log() -> usize; +} + +/// The FFT codes in this file are borrowed and adapted from Plonky2. +type FftRootTable = Vec>; + +pub fn fft_root_table(lg_n: usize) -> FftRootTable { + // bases[i] = g^2^i, for i = 0, ..., lg_n - 1 + // Note that the end of bases is g^{n/2} = -1 + let mut bases = Vec::with_capacity(lg_n); + let mut base = F::ROOT_OF_UNITY.pow([(1 << (F::S - lg_n as u32)) as u64]); + bases.push(base); + for _ in 1..lg_n { + base = base.square(); // base = g^2^_ + bases.push(base); + } + + // The result table looks like this: + // len=2: [1, g^{n/2}=-1] + // len=2: [1, g^{n/4}] + // len=4: [1, g^{n/8}, g^{n/4}, g^{3n/8}] + // len=8: [1, g^{n/16}, ..., g^{7n/16}] + // ... + // len=n/2: [1, g, ..., g^{n/2-1}] + // There is no need to compute the other halves of these powers, because + // those would be simply the negations of the previous halves. + let mut root_table = Vec::with_capacity(lg_n); + for lg_m in 1..=lg_n { + let half_m = 1 << (lg_m - 1); + let base = bases[lg_n - lg_m]; + let mut root_row = Vec::with_capacity(half_m.max(2)); + root_row.push(F::ONE); + for i in 1..half_m.max(2) { + root_row.push(root_row[i - 1] * base); + } + root_table.push(root_row); + } + root_table +} + +#[allow(unused)] +fn ifft( + poly: &mut FieldType, + zero_factor: usize, + root_table: &FftRootTable, +) { + let n = poly.len(); + let lg_n = log2_strict(n); + let n_inv = (E::BaseField::ONE + E::BaseField::ONE) + .invert() + .unwrap() + .pow([lg_n as u64]); + + fft(poly, zero_factor, root_table); + + // We reverse all values except the first, and divide each by n. + field_type_index_mul_base(poly, 0, &n_inv); + field_type_index_mul_base(poly, n / 2, &n_inv); + vec_mut!(|poly| for i in 1..(n / 2) { + let j = n - i; + let coeffs_i = poly[j] * n_inv; + let coeffs_j = poly[i] * n_inv; + poly[i] = coeffs_i; + poly[j] = coeffs_j; + }) +} + +/// Core FFT implementation. +fn fft_classic_inner( + values: &mut FieldType, + r: usize, + lg_n: usize, + root_table: &[Vec], +) { + // We've already done the first lg_packed_width (if they were required) iterations. + + for (lg_half_m, cur_root_table) in root_table.iter().enumerate().take(lg_n).skip(r) { + let n = 1 << lg_n; + let lg_m = lg_half_m + 1; + let m = 1 << lg_m; // Subarray size (in field elements). + let half_m = m / 2; + debug_assert!(half_m != 0); + + // omega values for this iteration, as slice of vectors + let omega_table = &cur_root_table[..]; + vec_mut!(|values| { + for k in (0..n).step_by(m) { + for j in 0..half_m { + let omega = omega_table[j]; + let t = values[k + half_m + j] * omega; + let u = values[k + j]; + values[k + j] = u + t; + values[k + half_m + j] = u - t; + } + } + }) + } +} + +/// FFT implementation based on Section 32.3 of "Introduction to +/// Algorithms" by Cormen et al. +/// +/// The parameter r signifies that the first 1/2^r of the entries of +/// input may be non-zero, but the last 1 - 1/2^r entries are +/// definitely zero. +pub fn fft( + values: &mut FieldType, + r: usize, + root_table: &[Vec], +) { + vec_mut!(|values| reverse_index_bits_in_place(values)); + + let n = values.len(); + let lg_n = log2_strict(n); + + if root_table.len() != lg_n { + panic!( + "Expected root table of length {}, but it was {}.", + lg_n, + root_table.len() + ); + } + + // After reverse_index_bits, the only non-zero elements of values + // are at indices i*2^r for i = 0..n/2^r. The loop below copies + // the value at i*2^r to the positions [i*2^r + 1, i*2^r + 2, ..., + // (i+1)*2^r - 1]; i.e. it replaces the 2^r - 1 zeros following + // element i*2^r with the value at i*2^r. This corresponds to the + // first r rounds of the FFT when there are 2^r zeros at the end + // of the original input. + if r > 0 { + // if r == 0 then this loop is a noop. + let mask = !((1 << r) - 1); + match values { + FieldType::Base(values) => { + for i in 0..n { + values[i] = values[i & mask]; + } + } + FieldType::Ext(values) => { + for i in 0..n { + values[i] = values[i & mask]; + } + } + _ => panic!("Unsupported field type"), + } + } + + fft_classic_inner::(values, r, lg_n, root_table); +} + +pub fn coset_fft( + coeffs: &mut FieldType, + shift: E::BaseField, + zero_factor: usize, + root_table: &[Vec], +) { + let mut shift_power = E::BaseField::ONE; + vec_mut!(|coeffs| { + for coeff in coeffs.iter_mut() { + *coeff *= shift_power; + shift_power *= shift; + } + }); + fft(coeffs, zero_factor, root_table); +} + +#[derive(Debug, Clone)] +pub struct RSCodeDefaultSpec {} + +impl RSCodeSpec for RSCodeDefaultSpec { + fn get_number_queries() -> usize { + 336 + } + + fn get_rate_log() -> usize { + 3 + } + + fn get_basecode_msg_size_log() -> usize { + 7 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct RSCodeParameters { + pub(crate) fft_root_table: FftRootTable, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct RSCodeProverParameters { + pub(crate) fft_root_table: FftRootTable, + pub(crate) gamma_powers: Vec, + pub(crate) gamma_powers_inv_div_two: Vec, + pub(crate) full_message_size_log: usize, +} + +impl EncodingProverParameters for RSCodeProverParameters { + fn get_max_message_size_log(&self) -> usize { + self.full_message_size_log + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RSCodeVerifierParameters +where + E::BaseField: Serialize + DeserializeOwned, +{ + /// The verifier also needs a FFT table (much smaller) + /// for small-size encoding. It contains the same roots as the + /// prover's version for the first few levels (i < basecode_msg_size_log) + /// For the other levels (i >= basecode_msg_size_log), + /// it contains only the g^(2^i). + pub(crate) fft_root_table: FftRootTable, + pub(crate) full_message_size_log: usize, + pub(crate) gamma_powers: Vec, + pub(crate) gamma_powers_inv_div_two: Vec, +} + +#[derive(Debug, Clone)] +pub struct RSCode { + _phantom_data: PhantomData, +} + +impl EncodingScheme for RSCode +where + E::BaseField: Serialize + DeserializeOwned, +{ + type PublicParameters = RSCodeParameters; + + type ProverParameters = RSCodeProverParameters; + + type VerifierParameters = RSCodeVerifierParameters; + + fn setup(max_message_size_log: usize, _rng_seed: [u8; 32]) -> Self::PublicParameters { + RSCodeParameters { + fft_root_table: fft_root_table(max_message_size_log + Spec::get_rate_log()), + } + } + + fn trim( + pp: &Self::PublicParameters, + max_message_size_log: usize, + ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { + if pp.fft_root_table.len() < max_message_size_log + Spec::get_rate_log() { + return Err(Error::InvalidPcsParam(format!( + "Public parameter is setup for a smaller message size (log={}) than the trimmed message size (log={})", + pp.fft_root_table.len() - Spec::get_rate_log(), + max_message_size_log, + ))); + } + let mut gamma_powers = Vec::with_capacity(max_message_size_log); + let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log); + gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR); + gamma_powers_inv.push(E::BaseField::MULTIPLICATIVE_GENERATOR.invert().unwrap()); + for i in 1..max_message_size_log + Spec::get_rate_log() { + gamma_powers.push(gamma_powers[i - 1].square()); + gamma_powers_inv.push(gamma_powers_inv[i - 1].square()); + } + let inv_of_two = E::BaseField::from(2).invert().unwrap(); + gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two); + Ok(( + Self::ProverParameters { + fft_root_table: pp.fft_root_table[..max_message_size_log + Spec::get_rate_log()] + .to_vec(), + gamma_powers: gamma_powers.clone(), + gamma_powers_inv_div_two: gamma_powers_inv.clone(), + full_message_size_log: max_message_size_log, + }, + Self::VerifierParameters { + fft_root_table: pp.fft_root_table + [..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()] + .iter() + .cloned() + .chain( + pp.fft_root_table + [Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..] + .iter() + .map(|v| vec![v[1]]), + ) + .collect(), + full_message_size_log: max_message_size_log, + gamma_powers, + gamma_powers_inv_div_two: gamma_powers_inv, + }, + )) + } + + fn encode(pp: &Self::ProverParameters, coeffs: &FieldType) -> FieldType { + // Use the full message size to determine the shift factor. + Self::encode_internal(&pp.fft_root_table, coeffs, pp.full_message_size_log) + } + + fn encode_small(vp: &Self::VerifierParameters, coeffs: &FieldType) -> FieldType { + // Use the full message size to determine the shift factor. + Self::encode_internal(&vp.fft_root_table, coeffs, vp.full_message_size_log) + } + + fn get_number_queries() -> usize { + Spec::get_number_queries() + } + + fn get_rate_log() -> usize { + Spec::get_rate_log() + } + + fn get_basecode_msg_size_log() -> usize { + Spec::get_basecode_msg_size_log() + } + + fn message_is_left_and_right_folding() -> bool { + false + } + + fn prover_folding_coeffs(pp: &Self::ProverParameters, level: usize, index: usize) -> (E, E, E) { + // The coefficients are for the bit-reversed codeword, so reverse the + // bits before providing the coefficients. + let index = reverse_bits(index, level); + // level is the logarithmic of the codeword size after folded. + // Therefore, the domain after folded is gamma^2^(full_log_n - level) H + // where H is the multiplicative subgroup of size 2^level. + // The element at index i in this domain is + // gamma^2^(full_log_n - level) * ((2^level)-th root of unity)^i + // The x0 and x1 are exactly the two square roots, i.e., + // x0 = gamma^2^(full_log_n - level - 1) * ((2^(level+1))-th root of unity)^i + // Since root_table[i] stores the first half of the powers of + // the 2^(i+1)-th roots of unity, we can avoid recomputing them. + let x0 = if index < (1 << level) { + pp.fft_root_table[level][index] + } else { + -pp.fft_root_table[level][index - (1 << level)] + } * pp.gamma_powers[pp.full_message_size_log + Spec::get_rate_log() - level - 1]; + let x1 = -x0; + // The weight is 1/(x1-x0) = -1/(2x0) + // = -1/2 * (gamma^{-1})^2^(full_codeword_log_n - level - 1) * ((2^(level+1))-th root of unity)^{2^(level+1)-i} + let w = -pp.gamma_powers_inv_div_two + [pp.full_message_size_log + Spec::get_rate_log() - level - 1] + * if index == 0 { + E::BaseField::ONE + } else if index < (1 << level) { + -pp.fft_root_table[level][(1 << level) - index] + } else if index == 1 << level { + -E::BaseField::ONE + } else { + pp.fft_root_table[level][(1 << (level + 1)) - index] + }; + (E::from(x0), E::from(x1), E::from(w)) + } + + fn verifier_folding_coeffs( + vp: &Self::VerifierParameters, + level: usize, + index: usize, + ) -> (E, E, E) { + // The coefficients are for the bit-reversed codeword, so reverse the + // bits before providing the coefficients. + let index = reverse_bits(index, level); + // The same as prover_folding_coeffs, exept that the powers of + // g is computed on the fly for levels exceeding the root table. + let x0 = if level < Spec::get_basecode_msg_size_log() + Spec::get_rate_log() { + if index < (1 << level) { + vp.fft_root_table[level][index] + } else { + -vp.fft_root_table[level][index - (1 << level)] + } + } else { + // In this case, the level-th row of fft root table of the verifier + // only stores the first 2^(level+1)-th roots of unity. + vp.fft_root_table[level][0].pow([index as u64]) + } * vp.gamma_powers[vp.full_message_size_log + Spec::get_rate_log() - level - 1]; + let x1 = -x0; + // The weight is 1/(x1-x0) = -1/(2x0) + // = -1/2 * (gamma^{-1})^2^(full_log_n - level - 1) * ((2^(level+1))-th root of unity)^{2^(level+1)-i} + let w = -vp.gamma_powers_inv_div_two + [vp.full_message_size_log + Spec::get_rate_log() - level - 1] + * if level < Spec::get_basecode_msg_size_log() + Spec::get_rate_log() { + if index == 0 { + E::BaseField::ONE + } else if index < (1 << level) { + -vp.fft_root_table[level][(1 << level) - index] + } else if index == 1 << level { + -E::BaseField::ONE + } else { + vp.fft_root_table[level][(1 << (level + 1)) - index] + } + } else { + // In this case, this level of fft root table of the verifier + // only stores the first 2^(level+1)-th root of unity. + vp.fft_root_table[level][0].pow([(1 << (level + 1)) - index as u64]) + }; + (E::from(x0), E::from(x1), E::from(w)) + } +} + +impl RSCode { + fn encode_internal( + fft_root_table: &FftRootTable, + coeffs: &FieldType, + full_message_size_log: usize, + ) -> FieldType + where + E::BaseField: Serialize + DeserializeOwned, + { + let lg_m = log2_strict(coeffs.len()); + let fft_root_table = &fft_root_table[..lg_m + Spec::get_rate_log()]; + assert!( + lg_m <= full_message_size_log, + "Encoded message exceeds the maximum supported message size of the table." + ); + let rate = 1 << Spec::get_rate_log(); + let mut ret = match coeffs { + FieldType::Base(coeffs) => { + let mut coeffs = coeffs.clone(); + coeffs.extend(itertools::repeat_n( + E::BaseField::ZERO, + coeffs.len() * (rate - 1), + )); + FieldType::Base(coeffs) + } + FieldType::Ext(coeffs) => { + let mut coeffs = coeffs.clone(); + coeffs.extend(itertools::repeat_n(E::ZERO, coeffs.len() * (rate - 1))); + FieldType::Ext(coeffs) + } + _ => panic!("Unsupported field type"), + }; + // Let gamma be the multiplicative generator of the base field. + // The full domain is gamma H where H is the multiplicative subgroup + // of size n * rate. + // When the input message size is not n, but n/2^k, then the domain is + // gamma^2^k H. + let k = 1 << (full_message_size_log - lg_m); + coset_fft( + &mut ret, + E::BaseField::MULTIPLICATIVE_GENERATOR.pow([k]), + Spec::get_rate_log(), + fft_root_table, + ); + ret + } + + #[allow(unused)] + fn folding_coeffs_naive( + level: usize, + index: usize, + full_message_size_log: usize, + ) -> (E, E, E) { + // The coefficients are for the bit-reversed codeword, so reverse the + // bits before providing the coefficients. + let index = reverse_bits(index, level); + // x0 is the index-th 2^(level+1)-th root of unity, multiplied by + // the shift factor at level+1, which is gamma^2^(full_codeword_log_n - level - 1). + let x0 = E::BaseField::ROOT_OF_UNITY + .pow([1 << (E::BaseField::S - (level as u32 + 1))]) + .pow([index as u64]) + * E::BaseField::MULTIPLICATIVE_GENERATOR + .pow([1 << (full_message_size_log + Spec::get_rate_log() - level - 1)]); + let x1 = -x0; + let w = (x1 - x0).invert().unwrap(); + (E::from(x0), E::from(x1), E::from(w)) + } +} + +#[allow(unused)] +fn naive_fft(poly: &[E], rate: usize, shift: E::BaseField) -> Vec { + let timer = start_timer!(|| "Encode RSCode"); + let message_size = poly.len(); + let domain_size_bit = log2_strict(message_size * rate); + let root = E::BaseField::ROOT_OF_UNITY.pow([1 << (E::BaseField::S - domain_size_bit as u32)]); + // The domain is shift * H where H is the multiplicative subgroup of size + // message_size * rate. + let mut domain = Vec::::with_capacity(message_size * rate); + domain.push(shift); + for i in 1..message_size * rate { + domain.push(domain[i - 1] * root); + } + let mut res = vec![E::ZERO; message_size * rate]; + res.iter_mut() + .enumerate() + .for_each(|(i, target)| *target = horner(poly, &E::from(domain[i]))); + end_timer!(timer); + + res +} + +#[cfg(test)] +mod tests { + use crate::{ + basefold::encoding::test_util::test_codeword_folding, + util::{field_type_index_ext, plonky2_util::reverse_index_bits_in_place_field_type}, + }; + + use super::*; + use goldilocks::{Goldilocks, GoldilocksExt2}; + + #[test] + fn test_naive_fft() { + let num_vars = 5; + + let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let mut poly2 = FieldType::Ext(poly.clone()); + + let naive = naive_fft::(&poly, 1, Goldilocks::ONE); + + let root_table = fft_root_table(num_vars); + fft::(&mut poly2, 0, &root_table); + + let poly2 = match poly2 { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + assert_eq!(naive, poly2); + } + + #[test] + fn test_naive_fft_with_shift() { + use rand::rngs::OsRng; + let num_vars = 5; + + let poly: Vec = (0..(1 << num_vars)) + .map(|_| GoldilocksExt2::random(&mut OsRng)) + .collect(); + let mut poly2 = FieldType::Ext(poly.clone()); + + let naive = naive_fft::(&poly, 1, Goldilocks::MULTIPLICATIVE_GENERATOR); + + let root_table = fft_root_table(num_vars); + coset_fft::( + &mut poly2, + Goldilocks::MULTIPLICATIVE_GENERATOR, + 0, + &root_table, + ); + + let poly2 = match poly2 { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + assert_eq!(naive, poly2); + } + + #[test] + fn test_naive_fft_with_rate() { + use rand::rngs::OsRng; + let num_vars = 5; + let rate_bits = 1; + + let poly: Vec = (0..(1 << num_vars)) + .map(|_| GoldilocksExt2::random(&mut OsRng)) + .collect(); + let mut poly2 = vec![GoldilocksExt2::ZERO; poly.len() * (1 << rate_bits)]; + poly2.as_mut_slice()[..poly.len()].copy_from_slice(poly.as_slice()); + let mut poly2 = FieldType::Ext(poly2.clone()); + + let naive = naive_fft::( + &poly, + 1 << rate_bits, + Goldilocks::MULTIPLICATIVE_GENERATOR, + ); + + let root_table = fft_root_table(num_vars + rate_bits); + coset_fft::( + &mut poly2, + Goldilocks::MULTIPLICATIVE_GENERATOR, + rate_bits, + &root_table, + ); + + let poly2 = match poly2 { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + assert_eq!(naive, poly2); + } + + #[test] + fn test_ifft() { + let num_vars = 5; + + let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let mut poly = FieldType::Ext(poly); + let original = poly.clone(); + + let root_table = fft_root_table(num_vars); + fft::(&mut poly, 0, &root_table); + ifft::(&mut poly, 0, &root_table); + + assert_eq!(original, poly); + } + + #[test] + fn prover_verifier_consistency() { + type Code = RSCode; + let pp: RSCodeParameters = Code::setup(10, [0; 32]); + let (pp, vp) = Code::trim(&pp, 10).unwrap(); + for level in 0..(10 + >::get_rate_log()) { + for index in 0..(1 << level) { + let (naive_x0, naive_x1, naive_w) = + Code::folding_coeffs_naive(level, index, pp.full_message_size_log); + let (p_x0, p_x1, p_w) = Code::prover_folding_coeffs(&pp, level, index); + let (v_x0, v_x1, v_w) = Code::verifier_folding_coeffs(&vp, level, index); + // assert_eq!(v_w * (v_x1 - v_x0), GoldilocksExt2::ONE); + // assert_eq!(p_w * (p_x1 - p_x0), GoldilocksExt2::ONE); + assert_eq!( + (v_x0, v_x1, v_w, p_x0, p_x1, p_w), + (naive_x0, naive_x1, naive_w, naive_x0, naive_x1, naive_w), + "failed for level = {}, index = {}", + level, + index + ); + } + } + } + + #[test] + fn test_rs_codeword_folding() { + test_codeword_folding::>(); + } + + type E = GoldilocksExt2; + type F = Goldilocks; + type Code = RSCode; + + #[test] + pub fn test_colinearity() { + let num_vars = 10; + + let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly = FieldType::Ext(poly); + + let rng_seed = [0; 32]; + let pp = >::setup(num_vars, rng_seed); + let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let mut codeword = Code::encode(&pp, &poly); + reverse_index_bits_in_place_field_type(&mut codeword); + let challenge = E::from(2); + let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); + let codeword = match codeword { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + + for (i, (a, b)) in folded_codeword.iter().zip(codeword.chunks(2)).enumerate() { + let (x0, x1, _) = Code::prover_folding_coeffs( + &pp, + num_vars + >::get_rate_log() - 1, + i, + ); + // Check that (x0, b[0]), (x1, b[1]) and (challenge, a) are + // on the same line, i.e., + // (b[0]-a)/(x0-challenge) = (b[1]-a)/(x1-challenge) + // which is equivalent to + // (x0-challenge)*(b[1]-a) = (x1-challenge)*(b[0]-a) + assert_eq!( + (x0 - challenge) * (b[1] - a), + (x1 - challenge) * (b[0] - a), + "failed for i = {}", + i + ); + } + } + + #[test] + pub fn test_low_degree() { + let num_vars = 10; + + let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly = FieldType::Ext(poly); + + let rng_seed = [0; 32]; + let pp = >::setup(num_vars, rng_seed); + let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let mut codeword = Code::encode(&pp, &poly); + check_low_degree(&codeword, "low degree check for original codeword"); + let c0 = field_type_index_ext(&codeword, 0); + let c_mid = field_type_index_ext(&codeword, codeword.len() >> 1); + let c1 = field_type_index_ext(&codeword, 1); + let c_mid1 = field_type_index_ext(&codeword, (codeword.len() >> 1) + 1); + + reverse_index_bits_in_place_field_type(&mut codeword); + // After the bit inversion, the first element is still the first, + // but the middle one is switched to the second. + assert_eq!(c0, field_type_index_ext(&codeword, 0)); + assert_eq!(c_mid, field_type_index_ext(&codeword, 1)); + // The second element is placed at the middle, and the next to middle + // element is still at the place. + assert_eq!(c1, field_type_index_ext(&codeword, codeword.len() >> 1)); + assert_eq!( + c_mid1, + field_type_index_ext(&codeword, (codeword.len() >> 1) + 1) + ); + + // For RS codeword, the addition of the left and right halves is also + // a valid codeword + let codeword_vec = match &codeword { + FieldType::Ext(coeffs) => coeffs.clone(), + _ => panic!("Wrong field type"), + }; + let mut left_right_sum: Vec = codeword_vec + .chunks(2) + .map(|chunk| chunk[0] + chunk[1]) + .collect(); + assert_eq!(left_right_sum[0], c0 + c_mid); + reverse_index_bits_in_place(&mut left_right_sum); + assert_eq!(left_right_sum[1], c1 + c_mid1); + check_low_degree( + &FieldType::Ext(left_right_sum.clone()), + "check low degree of left+right", + ); + + // The the difference of the left and right halves is also + // a valid codeword after twisted by omega^(-i), regardless of the + // shift of the coset. + let mut left_right_diff: Vec = codeword_vec + .chunks(2) + .map(|chunk| chunk[0] - chunk[1]) + .collect(); + assert_eq!(left_right_diff[0], c0 - c_mid); + reverse_index_bits_in_place(&mut left_right_diff); + assert_eq!(left_right_diff[1], c1 - c_mid1); + let root_of_unity_inv = F::ROOT_OF_UNITY_INV + .pow([1 << (F::S as usize - log2_strict(left_right_diff.len()) - 1)]); + for (i, coeff) in left_right_diff.iter_mut().enumerate() { + *coeff *= root_of_unity_inv.pow([i as u64]); + } + assert_eq!(left_right_diff[0], c0 - c_mid); + assert_eq!(left_right_diff[1], (c1 - c_mid1) * root_of_unity_inv); + check_low_degree( + &FieldType::Ext(left_right_diff.clone()), + "check low degree of (left-right)*omega^(-i)", + ); + + let challenge = E::from(2); + let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); + let c_fold = folded_codeword[0]; + let c_fold1 = folded_codeword[folded_codeword.len() >> 1]; + let mut folded_codeword = FieldType::Ext(folded_codeword); + reverse_index_bits_in_place_field_type(&mut folded_codeword); + assert_eq!(c_fold, field_type_index_ext(&folded_codeword, 0)); + assert_eq!(c_fold1, field_type_index_ext(&folded_codeword, 1)); + + // The top level folding coefficient should have shift factor gamma + let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 0); + assert_eq!(folding_coeffs.0, E::from(F::MULTIPLICATIVE_GENERATOR)); + assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); + assert_eq!( + (folding_coeffs.1 - folding_coeffs.0) * folding_coeffs.2, + E::ONE + ); + // The three points (x0, c0), (x1, c_mid), (challenge, c_fold) should + // be colinear + assert_eq!( + (c_mid - c_fold) * (folding_coeffs.0 - challenge), + (c0 - c_fold) * (folding_coeffs.1 - challenge), + ); + // So the folded value should be equal to + // (gamma^{-1} * alpha * (c0 - c_mid) + (c0 + c_mid)) / 2 + assert_eq!( + c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), + challenge * (c0 - c_mid) + (c0 + c_mid) * F::MULTIPLICATIVE_GENERATOR + ); + assert_eq!( + c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), + challenge * left_right_diff[0] + left_right_sum[0] * F::MULTIPLICATIVE_GENERATOR + ); + assert_eq!( + c_fold * F::from(2), + challenge * left_right_diff[0] * F::MULTIPLICATIVE_GENERATOR.invert().unwrap() + + left_right_sum[0] + ); + + let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 1); + let root_of_unity = + F::ROOT_OF_UNITY.pow([1 << (F::S as usize - log2_strict(codeword.len()))]); + assert_eq!(root_of_unity.pow([codeword.len() as u64]), F::ONE); + assert_eq!(root_of_unity.pow([(codeword.len() >> 1) as u64]), -F::ONE); + assert_eq!( + folding_coeffs.0, + E::from(F::MULTIPLICATIVE_GENERATOR) + * E::from(root_of_unity).pow([(codeword.len() >> 2) as u64]) + ); + assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); + assert_eq!( + (folding_coeffs.1 - folding_coeffs.0) * folding_coeffs.2, + E::ONE + ); + + // The folded codeword is the linear combination of the left+right and the + // twisted left-right vectors. + // The coefficients are respectively 1/2 and gamma^{-1}/2 * alpha. + // In another word, the folded codeword multipled by 2 is the linear + // combination by coeffs: 1 and gamma^{-1} * alpha + let gamma_inv = F::MULTIPLICATIVE_GENERATOR.invert().unwrap(); + let b = challenge * gamma_inv; + let folded_codeword_vec = match &folded_codeword { + FieldType::Ext(coeffs) => coeffs.clone(), + _ => panic!("Wrong field type"), + }; + assert_eq!( + c_fold * F::from(2), + left_right_diff[0] * b + left_right_sum[0] + ); + for (i, (c, (diff, sum))) in folded_codeword_vec + .iter() + .zip(left_right_diff.iter().zip(left_right_sum.iter())) + .enumerate() + { + assert_eq!(*c + c, *sum + b * diff, "failed for i = {}", i); + } + + check_low_degree(&folded_codeword, "low degree check for folded"); + } + + fn check_low_degree(codeword: &FieldType, message: &str) { + let mut codeword = codeword.clone(); + let codeword_bits = log2_strict(codeword.len()); + let root_table = fft_root_table(codeword_bits); + let original = codeword.clone(); + ifft(&mut codeword, 0, &root_table); + for i in (codeword.len() >> >::get_rate_log())..codeword.len() { + assert_eq!( + field_type_index_ext(&codeword, i), + E::ZERO, + "{}: zero check failed for i = {}", + message, + i + ) + } + fft(&mut codeword, 0, &root_table); + let original = match original { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + let codeword = match codeword { + FieldType::Ext(coeffs) => coeffs, + _ => panic!("Wrong field type"), + }; + original + .iter() + .zip(codeword.iter()) + .enumerate() + .for_each(|(i, (a, b))| { + assert_eq!(a, b, "{}: failed for i = {}", message, i); + }); + } +} diff --git a/mpcs/src/basefold/encoding/utils.rs b/mpcs/src/basefold/encoding/utils.rs new file mode 100644 index 000000000..36137526a --- /dev/null +++ b/mpcs/src/basefold/encoding/utils.rs @@ -0,0 +1,35 @@ +#[macro_export] +macro_rules! vec_mut { + ($a:ident, |$tmp_a:ident| $op:expr) => { + match $a { + multilinear_extensions::mle::FieldType::Base(ref mut $tmp_a) => $op, + multilinear_extensions::mle::FieldType::Ext(ref mut $tmp_a) => $op, + _ => unreachable!(), + } + }; + (|$a:ident| $op:expr) => { + vec_mut!($a, |$a| $op) + }; +} + +#[macro_export] +macro_rules! vec_map { + ($a:ident, |$tmp_a:ident| $op:expr) => { + match &$a { + multilinear_extensions::mle::FieldType::Base(a) => { + let $tmp_a = &a[..]; + let out = $op; + multilinear_extensions::mle::FieldType::Base(out) + } + multilinear_extensions::mle::FieldType::Ext(a) => { + let $tmp_a = &a[..]; + let out = $op; + multilinear_extensions::mle::FieldType::Base(out) + } + _ => unreachable!(), + } + }; + (|$a:ident| $op:expr) => { + vec_map!($a, |$a| $op) + }; +} diff --git a/mpcs/src/basefold/query_phase.rs b/mpcs/src/basefold/query_phase.rs index dacd254a1..57d41db68 100644 --- a/mpcs/src/basefold/query_phase.rs +++ b/mpcs/src/basefold/query_phase.rs @@ -1,37 +1,35 @@ -use super::basecode::{encode_rs_basecode, query_point}; use crate::util::{ arithmetic::{ - degree_2_eval, degree_2_zero_plus_one, inner_product, interpolate2, + degree_2_eval, degree_2_zero_plus_one, inner_product, interpolate2_weights, interpolate_over_boolean_hypercube, }, - ext_to_usize, + ext_to_usize, field_type_index_base, field_type_index_ext, hash::{Digest, Hasher}, log2_strict, merkle_tree::{MerklePathWithoutLeafOrRoot, MerkleTree}, transcript::{TranscriptRead, TranscriptWrite}, }; -use aes::cipher::KeyIvInit; use ark_std::{end_timer, start_timer}; use core::fmt::Debug; -use ctr; use ff_ext::ExtensionField; -use generic_array::GenericArray; use itertools::Itertools; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use multilinear_extensions::mle::FieldType; -use crate::util::plonky2_util::{reverse_bits, reverse_index_bits_in_place}; -use rand_chacha::{rand_core::RngCore, ChaCha8Rng}; +use crate::util::plonky2_util::reverse_index_bits_in_place; use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; -use super::structure::{BasefoldCommitment, BasefoldCommitmentWithData}; +use super::{ + encoding::EncodingScheme, + structure::{BasefoldCommitment, BasefoldCommitmentWithData, BasefoldSpec}, +}; -pub fn query_phase( +pub fn prover_query_phase( transcript: &mut impl TranscriptWrite, E>, comm: &BasefoldCommitmentWithData, - oracles: &Vec>, + oracles: &[Vec], num_verifier_queries: usize, ) -> QueriesResult where @@ -51,16 +49,298 @@ where .map(|x_index| { ( *x_index, - basefold_get_query::(comm.get_codeword(), &oracles, *x_index), + basefold_get_query::(&comm.get_codewords()[0], oracles, *x_index), ) }) .collect(), } } +pub fn batch_prover_query_phase( + transcript: &mut impl TranscriptWrite, E>, + codeword_size: usize, + comms: &[BasefoldCommitmentWithData], + oracles: &[Vec], + num_verifier_queries: usize, +) -> BatchedQueriesResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + let queries = transcript.squeeze_challenges(num_verifier_queries); + + // Transform the challenge queries from field elements into integers + let queries_usize: Vec = queries + .iter() + .map(|x_index| ext_to_usize(x_index) % codeword_size) + .collect_vec(); + + BatchedQueriesResult { + inner: queries_usize + .par_iter() + .map(|x_index| { + ( + *x_index, + batch_basefold_get_query::(comms, oracles, codeword_size, *x_index), + ) + }) + .collect(), + } +} + +pub fn simple_batch_prover_query_phase( + transcript: &mut impl TranscriptWrite, E>, + comm: &BasefoldCommitmentWithData, + oracles: &[Vec], + num_verifier_queries: usize, +) -> SimpleBatchQueriesResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + let queries = transcript.squeeze_challenges(num_verifier_queries); + + // Transform the challenge queries from field elements into integers + let queries_usize: Vec = queries + .iter() + .map(|x_index| ext_to_usize(x_index) % comm.codeword_size()) + .collect_vec(); + + SimpleBatchQueriesResult { + inner: queries_usize + .par_iter() + .map(|x_index| { + ( + *x_index, + simple_batch_basefold_get_query::(comm.get_codewords(), oracles, *x_index), + ) + }) + .collect(), + } +} + +#[allow(clippy::too_many_arguments)] +pub fn verifier_query_phase>( + vp: &>::VerifierParameters, + queries: &QueriesResultWithMerklePath, + sum_check_messages: &[Vec], + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_message: &[E], + roots: &Vec>, + comm: &BasefoldCommitment, + partial_eq: &[E], + eval: &E, + hasher: &Hasher, +) where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Verifier query phase"); + + let encode_timer = start_timer!(|| "Encode final codeword"); + let mut message = final_message.to_vec(); + interpolate_over_boolean_hypercube(&mut message); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut message); + } + let final_codeword = + >::encode_small(vp, &FieldType::Ext(message)); + let mut final_codeword = match final_codeword { + FieldType::Ext(final_codeword) => final_codeword, + _ => panic!("Final codeword must be extension field"), + }; + reverse_index_bits_in_place(&mut final_codeword); + end_timer!(encode_timer); + + queries.check::( + vp, + fold_challenges, + num_rounds, + num_vars, + &final_codeword, + roots, + comm, + hasher, + ); + + let final_timer = start_timer!(|| "Final checks"); + assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); + + // The sum-check part of the protocol + for i in 0..fold_challenges.len() - 1 { + assert_eq!( + degree_2_eval(&sum_check_messages[i], fold_challenges[i]), + degree_2_zero_plus_one(&sum_check_messages[i + 1]) + ); + } + + // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial + // sent from the prover + assert_eq!( + degree_2_eval( + &sum_check_messages[fold_challenges.len() - 1], + fold_challenges[fold_challenges.len() - 1] + ), + inner_product(final_message, partial_eq) + ); + end_timer!(final_timer); + + end_timer!(timer); +} + +#[allow(clippy::too_many_arguments)] +pub fn batch_verifier_query_phase>( + vp: &>::VerifierParameters, + queries: &BatchedQueriesResultWithMerklePath, + sum_check_messages: &[Vec], + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_message: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], + coeffs: &[E], + partial_eq: &[E], + eval: &E, + hasher: &Hasher, +) where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Verifier batch query phase"); + let encode_timer = start_timer!(|| "Encode final codeword"); + let mut message = final_message.to_vec(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut message); + } + interpolate_over_boolean_hypercube(&mut message); + let final_codeword = + >::encode_small(vp, &FieldType::Ext(message)); + let mut final_codeword = match final_codeword { + FieldType::Ext(final_codeword) => final_codeword, + _ => panic!("Final codeword must be extension field"), + }; + reverse_index_bits_in_place(&mut final_codeword); + end_timer!(encode_timer); + + // For computing the weights on the fly, because the verifier is incapable of storing + // the weights. + + queries.check::( + vp, + fold_challenges, + num_rounds, + num_vars, + &final_codeword, + roots, + comms, + coeffs, + hasher, + ); + + #[allow(unused)] + let final_timer = start_timer!(|| "Final checks"); + assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); + + // The sum-check part of the protocol + for i in 0..fold_challenges.len() - 1 { + assert_eq!( + degree_2_eval(&sum_check_messages[i], fold_challenges[i]), + degree_2_zero_plus_one(&sum_check_messages[i + 1]) + ); + } + + // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial + // sent from the prover + assert_eq!( + degree_2_eval( + &sum_check_messages[fold_challenges.len() - 1], + fold_challenges[fold_challenges.len() - 1] + ), + inner_product(final_message, partial_eq) + ); + end_timer!(final_timer); + end_timer!(timer); +} + +#[allow(clippy::too_many_arguments)] +pub fn simple_batch_verifier_query_phase>( + vp: &>::VerifierParameters, + queries: &SimpleBatchQueriesResultWithMerklePath, + sum_check_messages: &[Vec], + fold_challenges: &[E], + batch_coeffs: &[E], + num_rounds: usize, + num_vars: usize, + final_message: &[E], + roots: &[Digest], + comm: &BasefoldCommitment, + partial_eq: &[E], + evals: &[E], + hasher: &Hasher, +) where + E::BaseField: Serialize + DeserializeOwned, +{ + let timer = start_timer!(|| "Verifier query phase"); + + let encode_timer = start_timer!(|| "Encode final codeword"); + let mut message = final_message.to_vec(); + if >::message_is_even_and_odd_folding() { + reverse_index_bits_in_place(&mut message); + } + interpolate_over_boolean_hypercube(&mut message); + let final_codeword = + >::encode_small(vp, &FieldType::Ext(message)); + let mut final_codeword = match final_codeword { + FieldType::Ext(final_codeword) => final_codeword, + _ => panic!("Final codeword must be extension field"), + }; + reverse_index_bits_in_place(&mut final_codeword); + end_timer!(encode_timer); + + // For computing the weights on the fly, because the verifier is incapable of storing + // the weights. + queries.check::( + vp, + fold_challenges, + batch_coeffs, + num_rounds, + num_vars, + &final_codeword, + roots, + comm, + hasher, + ); + + let final_timer = start_timer!(|| "Final checks"); + assert_eq!( + &inner_product(batch_coeffs, evals), + °ree_2_zero_plus_one(&sum_check_messages[0]) + ); + + // The sum-check part of the protocol + for i in 0..fold_challenges.len() - 1 { + assert_eq!( + degree_2_eval(&sum_check_messages[i], fold_challenges[i]), + degree_2_zero_plus_one(&sum_check_messages[i + 1]) + ); + } + + // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial + // sent from the prover + assert_eq!( + degree_2_eval( + &sum_check_messages[fold_challenges.len() - 1], + fold_challenges[fold_challenges.len() - 1] + ), + inner_product(final_message, partial_eq) + ); + end_timer!(final_timer); + + end_timer!(timer); +} + fn basefold_get_query( poly_codeword: &FieldType, - oracles: &Vec>, + oracles: &[Vec], x_index: usize, ) -> SingleQueryResult where @@ -96,15 +376,15 @@ where inner: oracle_queries, }; - return SingleQueryResult { + SingleQueryResult { oracle_query, commitment_query, - }; + } } fn batch_basefold_get_query( - comms: &[&BasefoldCommitmentWithData], - oracles: &Vec>, + comms: &[BasefoldCommitmentWithData], + oracles: &[Vec], codeword_size: usize, x_index: usize, ) -> BatchedSingleQueryResult @@ -133,7 +413,7 @@ where let x_index = x_index >> (log2_strict(codeword_size) - comm.codeword_size_log()); let p1 = x_index | 1; let p0 = p1 - 1; - match comm.get_codeword() { + match &comm.get_codewords()[0] { FieldType::Ext(poly_codeword) => { CodewordSingleQueryResult::new_ext(poly_codeword[p0], poly_codeword[p1], p0) } @@ -155,6 +435,66 @@ where } } +fn simple_batch_basefold_get_query( + poly_codewords: &[FieldType], + oracles: &[Vec], + x_index: usize, +) -> SimpleBatchSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + let mut index = x_index; + let p1 = index | 1; + let p0 = p1 - 1; + + let commitment_query = match poly_codewords[0] { + FieldType::Ext(_) => SimpleBatchCommitmentSingleQueryResult::new_ext( + poly_codewords + .iter() + .map(|c| field_type_index_ext(c, p0)) + .collect(), + poly_codewords + .iter() + .map(|c| field_type_index_ext(c, p1)) + .collect(), + p0, + ), + FieldType::Base(_) => SimpleBatchCommitmentSingleQueryResult::new_base( + poly_codewords + .iter() + .map(|c| field_type_index_base(c, p0)) + .collect(), + poly_codewords + .iter() + .map(|c| field_type_index_base(c, p1)) + .collect(), + p0, + ), + _ => unreachable!(), + }; + index >>= 1; + + let mut oracle_queries = Vec::with_capacity(oracles.len() + 1); + for oracle in oracles { + let p1 = index | 1; + let p0 = p1 - 1; + + oracle_queries.push(CodewordSingleQueryResult::new_ext( + oracle[p0], oracle[p1], p0, + )); + index >>= 1; + } + + let oracle_query = OracleListQueryResult { + inner: oracle_queries, + }; + + SimpleBatchSingleQueryResult { + oracle_query, + commitment_query, + } +} + #[derive(Debug, Copy, Clone, Serialize, Deserialize)] enum CodewordPointPair { Ext(E, E), @@ -170,6 +510,51 @@ impl CodewordPointPair { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +enum SimpleBatchLeavesPair +where + E::BaseField: Serialize + DeserializeOwned, +{ + Ext(Vec<(E, E)>), + Base(Vec<(E::BaseField, E::BaseField)>), +} + +impl SimpleBatchLeavesPair +where + E::BaseField: Serialize + DeserializeOwned, +{ + #[allow(unused)] + pub fn as_ext(&self) -> Vec<(E, E)> { + match self { + SimpleBatchLeavesPair::Ext(x) => x.clone(), + SimpleBatchLeavesPair::Base(x) => { + x.iter().map(|(x, y)| ((*x).into(), (*y).into())).collect() + } + } + } + + pub fn batch(&self, coeffs: &[E]) -> (E, E) { + match self { + SimpleBatchLeavesPair::Ext(x) => { + let mut result = (E::ZERO, E::ZERO); + for (i, (x, y)) in x.iter().enumerate() { + result.0 += coeffs[i] * *x; + result.1 += coeffs[i] * *y; + } + result + } + SimpleBatchLeavesPair::Base(x) => { + let mut result = (E::ZERO, E::ZERO); + for (i, (x, y)) in x.iter().enumerate() { + result.0 += coeffs[i] * *x; + result.1 += coeffs[i] * *y; + } + result + } + } + } +} + #[derive(Debug, Copy, Clone, Serialize, Deserialize)] struct CodewordSingleQueryResult where @@ -510,12 +895,9 @@ where ) -> Vec> { let ret = self .get_inner() - .into_iter() + .iter() .enumerate() - .map(|(i, query_result)| { - let path = path(i, query_result.index); - path - }) + .map(|(i, query_result)| path(i, query_result.index)) .collect_vec(); ret } @@ -537,7 +919,7 @@ where query_result .merkle_path(path) .into_iter() - .zip(query_result.get_inner_into().into_iter()) + .zip(query_result.get_inner_into()) .map( |(path, codeword_result)| CodewordSingleQueryResultWithMerklePath { query: codeword_result, @@ -554,7 +936,7 @@ where .for_each(|q| q.write_transcript(transcript)); } - fn check_merkle_paths(&self, roots: &Vec>, hasher: &Hasher) { + fn check_merkle_paths(&self, roots: &[Digest], hasher: &Hasher) { // let timer = start_timer!(|| "ListQuery::Check Merkle Path"); self.get_inner() .iter() @@ -590,16 +972,17 @@ where { pub fn from_single_query_result( single_query_result: SingleQueryResult, - oracle_trees: &Vec>, + oracle_trees: &[MerkleTree], commitment: &BasefoldCommitmentWithData, ) -> Self { + assert!(commitment.codeword_tree.height() > 0); Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( single_query_result.oracle_query, |i, j| oracle_trees[i].merkle_path_without_leaf_sibling_or_root(j), ), commitment_query: CodewordSingleQueryResultWithMerklePath { - query: single_query_result.commitment_query.clone(), + query: single_query_result.commitment_query, merkle_path: commitment .codeword_tree .merkle_path_without_leaf_sibling_or_root( @@ -660,23 +1043,23 @@ where } } - pub fn check( + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + vp: &>::VerifierParameters, + fold_challenges: &[E], num_rounds: usize, num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, + final_codeword: &[E], + roots: &[Digest], comm: &BasefoldCommitment, - mut cipher: ctr::Ctr32LE, index: usize, hasher: &Hasher, ) { // let timer = start_timer!(|| "Checking codeword single query"); self.oracle_query.check_merkle_paths(roots, hasher); self.commitment_query - .check_merkle_path(&Digest(comm.root().0.try_into().unwrap()), hasher); + .check_merkle_path(&Digest(comm.root().0), hasher); let (mut curr_left, mut curr_right) = self.commitment_query.query.codepoints.as_ext(); @@ -684,18 +1067,14 @@ where let mut left_index = right_index - 1; for i in 0..num_rounds { - // let round_timer = start_timer!(|| format!("SingleQueryResult::round {}", i)); - let ri0 = reverse_bits(left_index, num_vars + log_rate - i); - - let x0 = E::from(query_point::( - 1 << (num_vars + log_rate - i), - ri0, - num_vars + log_rate - i - 1, - &mut cipher, - )); - let x1 = -x0; + let (x0, x1, w) = >::verifier_folding_coeffs( + vp, + num_vars + Spec::get_rate_log() - i - 1, + left_index >> 1, + ); - let res = interpolate2([(x0, curr_left), (x1, curr_right)], fold_challenges[i]); + let res = + interpolate2_weights([(x0, curr_left), (x1, curr_right)], w, fold_challenges[i]); let next_index = right_index >> 1; let next_oracle_value = if i < num_rounds - 1 { @@ -720,6 +1099,135 @@ where } } +pub struct QueriesResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + inner: Vec<(usize, SingleQueryResult)>, +} + +pub struct QueriesResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + inner: Vec<(usize, SingleQueryResultWithMerklePath)>, +} + +impl QueriesResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub fn from_query_result( + query_result: QueriesResult, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithData, + ) -> Self { + Self { + inner: query_result + .inner + .into_iter() + .map(|(i, q)| { + ( + i, + SingleQueryResultWithMerklePath::from_single_query_result( + q, + oracle_trees, + commitment, + ), + ) + }) + .collect(), + } + } + + pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + self.inner.iter().for_each(|(_, q)| { + q.write_transcript(transcript); + }); + } + + pub fn read_transcript_base( + transcript: &mut impl TranscriptRead, E>, + num_rounds: usize, + log_rate: usize, + poly_num_vars: usize, + indices: &[usize], + ) -> Self { + Self { + inner: indices + .iter() + .map(|index| { + ( + *index, + SingleQueryResultWithMerklePath::read_transcript_base( + transcript, + num_rounds, + log_rate, + poly_num_vars, + *index, + ), + ) + }) + .collect(), + } + } + + pub fn read_transcript_ext( + transcript: &mut impl TranscriptRead, E>, + num_rounds: usize, + log_rate: usize, + poly_num_vars: usize, + indices: &[usize], + ) -> Self { + Self { + inner: indices + .iter() + .map(|index| { + ( + *index, + SingleQueryResultWithMerklePath::read_transcript_ext( + transcript, + num_rounds, + log_rate, + poly_num_vars, + *index, + ), + ) + }) + .collect(), + } + } + + #[allow(clippy::too_many_arguments)] + pub fn check>( + &self, + vp: &>::VerifierParameters, + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_codeword: &[E], + roots: &[Digest], + comm: &BasefoldCommitment, + hasher: &Hasher, + ) { + let timer = start_timer!(|| "QueriesResult::check"); + self.inner.par_iter().for_each(|(index, query)| { + query.check::( + vp, + fold_challenges, + num_rounds, + num_vars, + final_codeword, + roots, + comm, + *index, + hasher, + ); + }); + end_timer!(timer); + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] struct BatchedSingleQueryResult where @@ -744,8 +1252,8 @@ where { pub fn from_batched_single_query_result( batched_single_query_result: BatchedSingleQueryResult, - oracle_trees: &Vec>, - commitments: &Vec<&BasefoldCommitmentWithData>, + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithData], ) -> Self { Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( @@ -818,23 +1326,29 @@ where } } - pub fn check( + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + vp: &>::VerifierParameters, + fold_challenges: &[E], num_rounds: usize, num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, + final_codeword: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], coeffs: &[E], - mut cipher: ctr::Ctr32LE, index: usize, hasher: &Hasher, ) { self.oracle_query.check_merkle_paths(roots, hasher); - self.commitments_query - .check_merkle_paths(&comms.iter().map(|comm| comm.root()).collect(), hasher); + self.commitments_query.check_merkle_paths( + comms + .iter() + .map(|comm| comm.root()) + .collect_vec() + .as_slice(), + hasher, + ); // end_timer!(commit_timer); let mut curr_left = E::ZERO; @@ -845,7 +1359,6 @@ where for i in 0..num_rounds { // let round_timer = start_timer!(|| format!("BatchedSingleQueryResult::round {}", i)); - let ri0 = reverse_bits(left_index, num_vars + log_rate - i); let matching_comms = comms .iter() .enumerate() @@ -854,21 +1367,20 @@ where .collect_vec(); matching_comms.iter().for_each(|index| { - let query = self.commitments_query.get_inner()[*index].query.clone(); + let query = self.commitments_query.get_inner()[*index].query; assert_eq!(query.index >> 1, left_index >> 1); curr_left += query.left_ext() * coeffs[*index]; curr_right += query.right_ext() * coeffs[*index]; }); - let x0: E = E::from(query_point::( - 1 << (num_vars + log_rate - i), - ri0, - num_vars + log_rate - i - 1, - &mut cipher, - )); - let x1 = -x0; + let (x0, x1, w) = >::verifier_folding_coeffs( + vp, + num_vars + Spec::get_rate_log() - i - 1, + left_index >> 1, + ); - let mut res = interpolate2([(x0, curr_left), (x1, curr_right)], fold_challenges[i]); + let mut res = + interpolate2_weights([(x0, curr_left), (x1, curr_right)], w, fold_challenges[i]); let next_index = right_index >> 1; @@ -901,7 +1413,7 @@ where matching_comms.iter().for_each(|index| { let query: CodewordSingleQueryResult = - self.commitments_query.get_inner()[*index].query.clone(); + self.commitments_query.get_inner()[*index].query; assert_eq!(query.index >> 1, next_index >> 1); if next_index & 1 == 0 { res += query.left_ext() * coeffs[*index]; @@ -941,8 +1453,8 @@ where { pub fn from_batched_query_result( batched_query_result: BatchedQueriesResult, - oracle_trees: &Vec>, - commitments: &Vec<&BasefoldCommitmentWithData>, + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithData], ) -> Self { Self { inner: batched_query_result @@ -963,34 +1475,355 @@ where } pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { - self.inner - .iter() - .for_each(|(_, q)| q.write_transcript(transcript)); + self.inner + .iter() + .for_each(|(_, q)| q.write_transcript(transcript)); + } + + pub fn read_transcript_base( + transcript: &mut impl TranscriptRead, E>, + num_rounds: usize, + log_rate: usize, + poly_num_vars: &[usize], + indices: &[usize], + ) -> Self { + Self { + inner: indices + .iter() + .map(|index| { + ( + *index, + BatchedSingleQueryResultWithMerklePath::read_transcript_base( + transcript, + num_rounds, + log_rate, + poly_num_vars, + *index, + ), + ) + }) + .collect(), + } + } + + pub fn read_transcript_ext( + transcript: &mut impl TranscriptRead, E>, + num_rounds: usize, + log_rate: usize, + poly_num_vars: &[usize], + indices: &[usize], + ) -> Self { + Self { + inner: indices + .iter() + .map(|index| { + ( + *index, + BatchedSingleQueryResultWithMerklePath::read_transcript_ext( + transcript, + num_rounds, + log_rate, + poly_num_vars, + *index, + ), + ) + }) + .collect(), + } + } + + #[allow(clippy::too_many_arguments)] + pub fn check>( + &self, + vp: &>::VerifierParameters, + fold_challenges: &[E], + num_rounds: usize, + num_vars: usize, + final_codeword: &[E], + roots: &[Digest], + comms: &[&BasefoldCommitment], + coeffs: &[E], + hasher: &Hasher, + ) { + let timer = start_timer!(|| "BatchedQueriesResult::check"); + self.inner.par_iter().for_each(|(index, query)| { + query.check::( + vp, + fold_challenges, + num_rounds, + num_vars, + final_codeword, + roots, + comms, + coeffs, + *index, + hasher, + ); + }); + end_timer!(timer); + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleBatchCommitmentSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + leaves: SimpleBatchLeavesPair, + index: usize, +} + +impl SimpleBatchCommitmentSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + fn new_ext(left: Vec, right: Vec, index: usize) -> Self { + Self { + leaves: SimpleBatchLeavesPair::Ext(left.into_iter().zip(right).collect()), + index, + } + } + + fn new_base(left: Vec, right: Vec, index: usize) -> Self { + Self { + leaves: SimpleBatchLeavesPair::Base(left.into_iter().zip(right).collect()), + index, + } + } + + #[allow(unused)] + fn left_ext(&self) -> Vec { + match &self.leaves { + SimpleBatchLeavesPair::Ext(x) => x.iter().map(|(x, _)| *x).collect(), + SimpleBatchLeavesPair::Base(x) => x.iter().map(|(x, _)| E::from(*x)).collect(), + } + } + + #[allow(unused)] + fn right_ext(&self) -> Vec { + match &self.leaves { + SimpleBatchLeavesPair::Ext(x) => x.iter().map(|(_, x)| *x).collect(), + SimpleBatchLeavesPair::Base(x) => x.iter().map(|(_, x)| E::from(*x)).collect(), + } + } + + pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + match &self.leaves { + SimpleBatchLeavesPair::Ext(x) => { + x.iter().for_each(|(x, y)| { + transcript.write_field_element_ext(x).unwrap(); + transcript.write_field_element_ext(y).unwrap(); + }); + } + SimpleBatchLeavesPair::Base(x) => { + x.iter().for_each(|(x, y)| { + transcript.write_field_element_base(x).unwrap(); + transcript.write_field_element_base(y).unwrap(); + }); + } + }; + } + + pub fn read_transcript_ext( + transcript: &mut impl TranscriptRead, E>, + full_codeword_size_log: usize, + codeword_size_log: usize, + index: usize, + batch_size: usize, + ) -> Self { + let mut left = vec![]; + let mut right = vec![]; + (0..batch_size).for_each(|_| { + left.push(transcript.read_field_element_ext().unwrap()); + right.push(transcript.read_field_element_ext().unwrap()); + }); + Self::new_ext( + left, + right, + index >> (full_codeword_size_log - codeword_size_log), + ) + } + + pub fn read_transcript_base( + transcript: &mut impl TranscriptRead, E>, + full_codeword_size_log: usize, + codeword_size_log: usize, + index: usize, + batch_size: usize, + ) -> Self { + let mut left = vec![]; + let mut right = vec![]; + (0..batch_size).for_each(|_| { + left.push(transcript.read_field_element_base().unwrap()); + right.push(transcript.read_field_element_base().unwrap()); + }); + Self::new_base( + left, + right, + index >> (full_codeword_size_log - codeword_size_log), + ) + } +} + +#[derive(Debug, Clone)] +struct SimpleBatchCommitmentSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + query: SimpleBatchCommitmentSingleQueryResult, + merkle_path: MerklePathWithoutLeafOrRoot, +} + +impl SimpleBatchCommitmentSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + self.query.write_transcript(transcript); + self.merkle_path.write_transcript(transcript); + } + + pub fn read_transcript_base( + transcript: &mut impl TranscriptRead, E>, + full_codeword_size_log: usize, + codeword_size_log: usize, + index: usize, + batch_size: usize, + ) -> Self { + Self { + query: SimpleBatchCommitmentSingleQueryResult::read_transcript_base( + transcript, + full_codeword_size_log, + codeword_size_log, + index, + batch_size, + ), + merkle_path: MerklePathWithoutLeafOrRoot::read_transcript( + transcript, + codeword_size_log, + ), + } + } + + pub fn read_transcript_ext( + transcript: &mut impl TranscriptRead, E>, + full_codeword_size_log: usize, + codeword_size_log: usize, + index: usize, + batch_size: usize, + ) -> Self { + Self { + query: SimpleBatchCommitmentSingleQueryResult::read_transcript_ext( + transcript, + full_codeword_size_log, + codeword_size_log, + index, + batch_size, + ), + merkle_path: MerklePathWithoutLeafOrRoot::read_transcript( + transcript, + codeword_size_log, + ), + } + } + + pub fn check_merkle_path(&self, root: &Digest, hasher: &Hasher) { + // let timer = start_timer!(|| "CodewordSingleQuery::Check Merkle Path"); + match &self.query.leaves { + SimpleBatchLeavesPair::Ext(inner) => { + self.merkle_path.authenticate_batch_leaves_root_ext( + inner.iter().map(|(x, _)| *x).collect(), + inner.iter().map(|(_, x)| *x).collect(), + self.query.index, + root, + hasher, + ); + } + SimpleBatchLeavesPair::Base(inner) => { + self.merkle_path.authenticate_batch_leaves_root_base( + inner.iter().map(|(x, _)| *x).collect(), + inner.iter().map(|(_, x)| *x).collect(), + self.query.index, + root, + hasher, + ); + } + } + // end_timer!(timer); + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleBatchSingleQueryResult +where + E::BaseField: Serialize + DeserializeOwned, +{ + oracle_query: OracleListQueryResult, + commitment_query: SimpleBatchCommitmentSingleQueryResult, +} + +#[derive(Debug, Clone)] +struct SimpleBatchSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + oracle_query: OracleListQueryResultWithMerklePath, + commitment_query: SimpleBatchCommitmentSingleQueryResultWithMerklePath, +} + +impl SimpleBatchSingleQueryResultWithMerklePath +where + E::BaseField: Serialize + DeserializeOwned, +{ + pub fn from_single_query_result( + single_query_result: SimpleBatchSingleQueryResult, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithData, + ) -> Self { + Self { + oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( + single_query_result.oracle_query, + |i, j| oracle_trees[i].merkle_path_without_leaf_sibling_or_root(j), + ), + commitment_query: SimpleBatchCommitmentSingleQueryResultWithMerklePath { + query: single_query_result.commitment_query.clone(), + merkle_path: commitment + .codeword_tree + .merkle_path_without_leaf_sibling_or_root( + single_query_result.commitment_query.index, + ), + }, + } + } + + pub fn write_transcript(&self, transcript: &mut impl TranscriptWrite, E>) { + self.oracle_query.write_transcript(transcript); + self.commitment_query.write_transcript(transcript); } pub fn read_transcript_base( transcript: &mut impl TranscriptRead, E>, num_rounds: usize, log_rate: usize, - poly_num_vars: &[usize], - indices: &[usize], + num_vars: usize, + index: usize, + batch_size: usize, ) -> Self { Self { - inner: indices - .iter() - .map(|index| { - ( - *index, - BatchedSingleQueryResultWithMerklePath::read_transcript_base( - transcript, - num_rounds, - log_rate, - poly_num_vars, - *index, - ), - ) - }) - .collect(), + oracle_query: OracleListQueryResultWithMerklePath::read_transcript( + transcript, + num_rounds, + num_vars + log_rate, + index, + ), + commitment_query: + SimpleBatchCommitmentSingleQueryResultWithMerklePath::read_transcript_base( + transcript, + num_vars + log_rate, + num_vars + log_rate, + index, + batch_size, + ), } } @@ -998,82 +1831,109 @@ where transcript: &mut impl TranscriptRead, E>, num_rounds: usize, log_rate: usize, - poly_num_vars: &[usize], - indices: &[usize], + num_vars: usize, + index: usize, + batch_size: usize, ) -> Self { Self { - inner: indices - .iter() - .map(|index| { - ( - *index, - BatchedSingleQueryResultWithMerklePath::read_transcript_ext( - transcript, - num_rounds, - log_rate, - poly_num_vars, - *index, - ), - ) - }) - .collect(), + oracle_query: OracleListQueryResultWithMerklePath::read_transcript( + transcript, + num_rounds, + num_vars + log_rate, + index, + ), + commitment_query: + SimpleBatchCommitmentSingleQueryResultWithMerklePath::read_transcript_ext( + transcript, + num_vars + log_rate, + num_vars + log_rate, + index, + batch_size, + ), } } - pub fn check( + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + vp: &>::VerifierParameters, + fold_challenges: &[E], + batch_coeffs: &[E], num_rounds: usize, num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, - coeffs: &[E], - cipher: ctr::Ctr32LE, + final_codeword: &[E], + roots: &[Digest], + comm: &BasefoldCommitment, + index: usize, hasher: &Hasher, ) { - let timer = start_timer!(|| "BatchedQueriesResult::check"); - self.inner.par_iter().for_each(|(index, query)| { - query.check( - fold_challenges, - num_rounds, - num_vars, - log_rate, - final_codeword, - roots, - comms, - coeffs, - cipher.clone(), - *index, - hasher, + let timer = start_timer!(|| "Checking codeword single query"); + self.oracle_query.check_merkle_paths(roots, hasher); + self.commitment_query + .check_merkle_path(&Digest(comm.root().0), hasher); + + let (mut curr_left, mut curr_right) = + self.commitment_query.query.leaves.batch(batch_coeffs); + + let mut right_index = index | 1; + let mut left_index = right_index - 1; + + for i in 0..num_rounds { + // let round_timer = start_timer!(|| format!("SingleQueryResult::round {}", i)); + + let (x0, x1, w) = >::verifier_folding_coeffs( + vp, + num_vars + Spec::get_rate_log() - i - 1, + left_index >> 1, ); - }); + + let res = + interpolate2_weights([(x0, curr_left), (x1, curr_right)], w, fold_challenges[i]); + + let next_index = right_index >> 1; + let next_oracle_value = if i < num_rounds - 1 { + right_index = next_index | 1; + left_index = right_index - 1; + let next_oracle_query = self.oracle_query.get_inner()[i].clone(); + (curr_left, curr_right) = next_oracle_query.query.codepoints.as_ext(); + if next_index & 1 == 0 { + curr_left + } else { + curr_right + } + } else { + // Note that final_codeword has been bit-reversed, so no need to bit-reverse + // next_index here. + final_codeword[next_index] + }; + assert_eq!(res, next_oracle_value, "Failed at round {}", i); + // end_timer!(round_timer); + } end_timer!(timer); } } -pub struct QueriesResult +pub struct SimpleBatchQueriesResult where E::BaseField: Serialize + DeserializeOwned, { - inner: Vec<(usize, SingleQueryResult)>, + inner: Vec<(usize, SimpleBatchSingleQueryResult)>, } -pub struct QueriesResultWithMerklePath +pub struct SimpleBatchQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, { - inner: Vec<(usize, SingleQueryResultWithMerklePath)>, + inner: Vec<(usize, SimpleBatchSingleQueryResultWithMerklePath)>, } -impl QueriesResultWithMerklePath +impl SimpleBatchQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, { pub fn from_query_result( - query_result: QueriesResult, - oracle_trees: &Vec>, + query_result: SimpleBatchQueriesResult, + oracle_trees: &[MerkleTree], commitment: &BasefoldCommitmentWithData, ) -> Self { Self { @@ -1083,7 +1943,7 @@ where .map(|(i, q)| { ( i, - SingleQueryResultWithMerklePath::from_single_query_result( + SimpleBatchSingleQueryResultWithMerklePath::from_single_query_result( q, oracle_trees, commitment, @@ -1106,6 +1966,7 @@ where log_rate: usize, poly_num_vars: usize, indices: &[usize], + batch_size: usize, ) -> Self { Self { inner: indices @@ -1113,12 +1974,13 @@ where .map(|index| { ( *index, - SingleQueryResultWithMerklePath::read_transcript_base( + SimpleBatchSingleQueryResultWithMerklePath::read_transcript_base( transcript, num_rounds, log_rate, poly_num_vars, *index, + batch_size, ), ) }) @@ -1132,6 +1994,7 @@ where log_rate: usize, poly_num_vars: usize, indices: &[usize], + batch_size: usize, ) -> Self { Self { inner: indices @@ -1139,12 +2002,13 @@ where .map(|index| { ( *index, - SingleQueryResultWithMerklePath::read_transcript_ext( + SimpleBatchSingleQueryResultWithMerklePath::read_transcript_ext( transcript, num_rounds, log_rate, poly_num_vars, *index, + batch_size, ), ) }) @@ -1152,29 +2016,30 @@ where } } - pub fn check( + #[allow(clippy::too_many_arguments)] + pub fn check>( &self, - fold_challenges: &Vec, + vp: &>::VerifierParameters, + fold_challenges: &[E], + batch_coeffs: &[E], num_rounds: usize, num_vars: usize, - log_rate: usize, - final_codeword: &Vec, - roots: &Vec>, + final_codeword: &[E], + roots: &[Digest], comm: &BasefoldCommitment, - cipher: ctr::Ctr32LE, hasher: &Hasher, ) { let timer = start_timer!(|| "QueriesResult::check"); self.inner.par_iter().for_each(|(index, query)| { - query.check( + query.check::( + vp, fold_challenges, + batch_coeffs, num_rounds, num_vars, - log_rate, final_codeword, roots, comm, - cipher.clone(), *index, hasher, ); @@ -1182,199 +2047,3 @@ where end_timer!(timer); } } - -pub fn batch_query_phase( - transcript: &mut impl TranscriptWrite, E>, - codeword_size: usize, - comms: &[&BasefoldCommitmentWithData], - oracles: &Vec>, - num_verifier_queries: usize, -) -> BatchedQueriesResult -where - E::BaseField: Serialize + DeserializeOwned, -{ - let queries = transcript.squeeze_challenges(num_verifier_queries); - - // Transform the challenge queries from field elements into integers - let queries_usize: Vec = queries - .iter() - .map(|x_index| ext_to_usize(x_index) % codeword_size) - .collect_vec(); - - BatchedQueriesResult { - inner: queries_usize - .par_iter() - .map(|x_index| { - ( - *x_index, - batch_basefold_get_query::(comms, &oracles, codeword_size, *x_index), - ) - }) - .collect(), - } -} - -pub fn verifier_query_phase( - queries: &QueriesResultWithMerklePath, - sum_check_messages: &Vec>, - fold_challenges: &Vec, - num_rounds: usize, - num_vars: usize, - log_rate: usize, - final_message: &Vec, - roots: &Vec>, - comm: &BasefoldCommitment, - partial_eq: &[E], - rng: ChaCha8Rng, - eval: &E, - hasher: &Hasher, -) where - E::BaseField: Serialize + DeserializeOwned, -{ - let timer = start_timer!(|| "Verifier query phase"); - - let encode_timer = start_timer!(|| "Encode final codeword"); - let mut message = final_message.clone(); - interpolate_over_boolean_hypercube(&mut message); - let mut final_codeword = encode_rs_basecode(&message, 1 << log_rate, message.len()); - assert_eq!(final_codeword.len(), 1); - let mut final_codeword = final_codeword.remove(0); - reverse_index_bits_in_place(&mut final_codeword); - end_timer!(encode_timer); - - // For computing the weights on the fly, because the verifier is incapable of storing - // the weights. - let aes_timer = start_timer!(|| "Initialize AES"); - let mut key: [u8; 16] = [0u8; 16]; - let mut iv: [u8; 16] = [0u8; 16]; - let mut rng = rng.clone(); - rng.set_word_pos(0); - rng.fill_bytes(&mut key); - rng.fill_bytes(&mut iv); - - type Aes128Ctr64LE = ctr::Ctr32LE; - let cipher = Aes128Ctr64LE::new( - GenericArray::from_slice(&key[..]), - GenericArray::from_slice(&iv[..]), - ); - end_timer!(aes_timer); - - queries.check( - fold_challenges, - num_rounds, - num_vars, - log_rate, - &final_codeword, - roots, - comm, - cipher, - hasher, - ); - - let final_timer = start_timer!(|| "Final checks"); - assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); - - // The sum-check part of the protocol - for i in 0..fold_challenges.len() - 1 { - assert_eq!( - degree_2_eval(&sum_check_messages[i], fold_challenges[i]), - degree_2_zero_plus_one(&sum_check_messages[i + 1]) - ); - } - - // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial - // sent from the prover - assert_eq!( - degree_2_eval( - &sum_check_messages[fold_challenges.len() - 1], - fold_challenges[fold_challenges.len() - 1] - ), - inner_product(final_message, partial_eq) - ); - end_timer!(final_timer); - - end_timer!(timer); -} - -pub fn batch_verifier_query_phase( - queries: &BatchedQueriesResultWithMerklePath, - sum_check_messages: &Vec>, - fold_challenges: &Vec, - num_rounds: usize, - num_vars: usize, - log_rate: usize, - final_message: &Vec, - roots: &Vec>, - comms: &Vec<&BasefoldCommitment>, - coeffs: &[E], - partial_eq: &[E], - rng: ChaCha8Rng, - eval: &E, - hasher: &Hasher, -) where - E::BaseField: Serialize + DeserializeOwned, -{ - let timer = start_timer!(|| "Verifier batch query phase"); - let encode_timer = start_timer!(|| "Encode final codeword"); - let mut message = final_message.clone(); - interpolate_over_boolean_hypercube(&mut message); - let mut final_codeword = encode_rs_basecode(&message, 1 << log_rate, message.len()); - assert_eq!(final_codeword.len(), 1); - let mut final_codeword = final_codeword.remove(0); - reverse_index_bits_in_place(&mut final_codeword); - end_timer!(encode_timer); - - // For computing the weights on the fly, because the verifier is incapable of storing - // the weights. - let aes_timer = start_timer!(|| "Initialize AES"); - let mut key: [u8; 16] = [0u8; 16]; - let mut iv: [u8; 16] = [0u8; 16]; - let mut rng = rng.clone(); - rng.set_word_pos(0); - rng.fill_bytes(&mut key); - rng.fill_bytes(&mut iv); - - type Aes128Ctr64LE = ctr::Ctr32LE; - let cipher = Aes128Ctr64LE::new( - GenericArray::from_slice(&key[..]), - GenericArray::from_slice(&iv[..]), - ); - end_timer!(aes_timer); - - queries.check( - fold_challenges, - num_rounds, - num_vars, - log_rate, - &final_codeword, - roots, - comms, - coeffs, - cipher, - hasher, - ); - - #[allow(unused)] - let final_timer = start_timer!(|| "Final checks"); - assert_eq!(eval, °ree_2_zero_plus_one(&sum_check_messages[0])); - - // The sum-check part of the protocol - for i in 0..fold_challenges.len() - 1 { - assert_eq!( - degree_2_eval(&sum_check_messages[i], fold_challenges[i]), - degree_2_zero_plus_one(&sum_check_messages[i + 1]) - ); - } - - // Finally, the last sumcheck poly evaluation should be the same as the sum of the polynomial - // sent from the prover - assert_eq!( - degree_2_eval( - &sum_check_messages[fold_challenges.len() - 1], - fold_challenges[fold_challenges.len() - 1] - ), - inner_product(final_message, partial_eq) - ); - end_timer!(final_timer); - end_timer!(timer); -} diff --git a/mpcs/src/basefold/structure.rs b/mpcs/src/basefold/structure.rs index 8dd39e75b..0a1972c38 100644 --- a/mpcs/src/basefold/structure.rs +++ b/mpcs/src/basefold/structure.rs @@ -2,44 +2,51 @@ use crate::util::{hash::Digest, merkle_tree::MerkleTree}; use core::fmt::Debug; use ff_ext::ExtensionField; +use rand::RngCore; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use multilinear_extensions::mle::FieldType; -use rand_chacha::rand_core::RngCore; +use rand_chacha::ChaCha8Rng; use std::{marker::PhantomData, slice}; +pub use super::encoding::{EncodingProverParameters, EncodingScheme, RSCode, RSCodeDefaultSpec}; +use super::{Basecode, BasecodeDefaultSpec}; + #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BasefoldParams +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasefoldParams> where E::BaseField: Serialize + DeserializeOwned, { - pub(super) log_rate: usize, - pub(super) num_verifier_queries: usize, - pub(super) max_num_vars: usize, - pub(super) table_w_weights: Vec>, - pub(super) table: Vec>, - pub(super) rng: Rng, + pub(super) params: >::PublicParameters, } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BasefoldProverParams -where - E::BaseField: Serialize + DeserializeOwned, -{ - pub(super) log_rate: usize, - pub(super) table_w_weights: Vec>, - pub(super) table: Vec>, - pub(super) num_verifier_queries: usize, - pub(super) max_num_vars: usize, +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasefoldProverParams> { + pub encoding_params: >::ProverParameters, +} + +impl> BasefoldProverParams { + pub fn get_max_message_size_log(&self) -> usize { + self.encoding_params.get_max_message_size_log() + } } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BasefoldVerifierParams { - pub(super) rng: Rng, - pub(super) max_num_vars: usize, - pub(super) log_rate: usize, - pub(super) num_verifier_queries: usize, +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct BasefoldVerifierParams> { + pub(super) encoding_params: >::VerifierParameters, } /// A polynomial commitment together with all the data (e.g., the codeword, and Merkle tree) @@ -51,9 +58,10 @@ where E::BaseField: Serialize + DeserializeOwned, { pub(crate) codeword_tree: MerkleTree, - pub(crate) bh_evals: FieldType, + pub(crate) polynomials_bh_evals: Vec>, pub(crate) num_vars: usize, pub(crate) is_base: bool, + pub(crate) num_polys: usize, } impl BasefoldCommitmentWithData @@ -61,7 +69,12 @@ where E::BaseField: Serialize + DeserializeOwned, { pub fn to_commitment(&self) -> BasefoldCommitment { - BasefoldCommitment::new(self.codeword_tree.root(), self.num_vars, self.is_base) + BasefoldCommitment::new( + self.codeword_tree.root(), + self.num_vars, + self.is_base, + self.num_polys, + ) } pub fn get_root_ref(&self) -> &Digest { @@ -72,12 +85,16 @@ where Digest::(self.get_root_ref().0) } - pub fn get_codeword(&self) -> &FieldType { + pub fn get_codewords(&self) -> &Vec> { self.codeword_tree.leaves() } + pub fn batch_codewords(&self, coeffs: &Vec) -> Vec { + self.codeword_tree.batch_leaves(coeffs) + } + pub fn codeword_size(&self) -> usize { - self.codeword_tree.size() + self.codeword_tree.size().1 } pub fn codeword_size_log(&self) -> usize { @@ -85,14 +102,14 @@ where } pub fn poly_size(&self) -> usize { - self.bh_evals.len() + 1 << self.num_vars } - pub fn get_codeword_entry_base(&self, index: usize) -> E::BaseField { + pub fn get_codeword_entry_base(&self, index: usize) -> Vec { self.codeword_tree.get_leaf_as_base(index) } - pub fn get_codeword_entry_ext(&self, index: usize) -> E { + pub fn get_codeword_entry_ext(&self, index: usize) -> Vec { self.codeword_tree.get_leaf_as_extension(index) } @@ -101,21 +118,21 @@ where } } -impl Into> for BasefoldCommitmentWithData +impl From> for Digest where E::BaseField: Serialize + DeserializeOwned, { - fn into(self) -> Digest { - self.get_root_as() + fn from(val: BasefoldCommitmentWithData) -> Self { + val.get_root_as() } } -impl Into> for &BasefoldCommitmentWithData +impl From<&BasefoldCommitmentWithData> for BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, { - fn into(self) -> BasefoldCommitment { - self.to_commitment() + fn from(val: &BasefoldCommitmentWithData) -> Self { + val.to_commitment() } } @@ -128,17 +145,24 @@ where pub(super) root: Digest, pub(super) num_vars: Option, pub(super) is_base: bool, + pub(super) num_polys: Option, } impl BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, { - pub fn new(root: Digest, num_vars: usize, is_base: bool) -> Self { + pub fn new( + root: Digest, + num_vars: usize, + is_base: bool, + num_polys: usize, + ) -> Self { Self { root, num_vars: Some(num_vars), is_base, + num_polys: Some(num_polys), } } @@ -153,14 +177,6 @@ where pub fn is_base(&self) -> bool { self.is_base } - - pub fn as_challenge_field(&self) -> BasefoldCommitment { - BasefoldCommitment:: { - root: Digest::(self.root().0), - num_vars: self.num_vars, - is_base: self.is_base, - } - } } impl PartialEq for BasefoldCommitmentWithData @@ -168,7 +184,8 @@ where E::BaseField: Serialize + DeserializeOwned, { fn eq(&self, other: &Self) -> bool { - self.get_codeword().eq(other.get_codeword()) && self.bh_evals.eq(&other.bh_evals) + self.get_codewords().eq(other.get_codewords()) + && self.polynomials_bh_evals.eq(&other.polynomials_bh_evals) } } @@ -177,37 +194,50 @@ impl Eq for BasefoldCommitmentWithData where { } -pub trait BasefoldExtParams: Debug { - fn get_reps() -> usize; +pub trait BasefoldSpec: Debug + Clone { + type EncodingScheme: EncodingScheme; - fn get_rate() -> usize; + fn get_number_queries() -> usize { + Self::EncodingScheme::get_number_queries() + } - fn get_basecode() -> usize; + fn get_rate_log() -> usize { + Self::EncodingScheme::get_rate_log() + } + + fn get_basecode_msg_size_log() -> usize { + Self::EncodingScheme::get_basecode_msg_size_log() + } } -#[derive(Debug)] -pub struct BasefoldDefaultParams; +#[derive(Debug, Clone)] +pub struct BasefoldBasecodeParams; -impl BasefoldExtParams for BasefoldDefaultParams { - fn get_reps() -> usize { - return 260; - } +impl BasefoldSpec for BasefoldBasecodeParams +where + E::BaseField: Serialize + DeserializeOwned, +{ + type EncodingScheme = Basecode; +} - fn get_rate() -> usize { - return 3; - } +#[derive(Debug, Clone)] +pub struct BasefoldRSParams; - fn get_basecode() -> usize { - return 7; - } +impl BasefoldSpec for BasefoldRSParams +where + E::BaseField: Serialize + DeserializeOwned, +{ + type EncodingScheme = RSCode; } #[derive(Debug)] -pub struct Basefold(PhantomData<(E, V)>); +pub struct Basefold, Rng: RngCore>( + PhantomData<(E, Spec, Rng)>, +); -pub type BasefoldDefault = Basefold; +pub type BasefoldDefault = Basefold; -impl Clone for Basefold { +impl, Rng: RngCore> Clone for Basefold { fn clone(&self) -> Self { Self(PhantomData) } diff --git a/mpcs/src/basefold/sumcheck.rs b/mpcs/src/basefold/sumcheck.rs index a6a84f0eb..bb5312111 100644 --- a/mpcs/src/basefold/sumcheck.rs +++ b/mpcs/src/basefold/sumcheck.rs @@ -7,8 +7,8 @@ use rayon::prelude::{ }; pub fn sum_check_first_round_field_type( - mut eq: &mut Vec, - mut bh_values: &mut FieldType, + eq: &mut Vec, + bh_values: &mut FieldType, ) -> Vec { // The input polynomials are in the form of evaluations. Instead of viewing // every one element as the evaluation of the polynomial at a single point, @@ -16,24 +16,21 @@ pub fn sum_check_first_round_field_type( // a single point, leaving the first variable free, and obtaining a univariate // polynomial. The one_level_interp_hc transforms the evaluation forms into // the coefficient forms, for every of these partial polynomials. - one_level_interp_hc(&mut eq); - one_level_interp_hc_field_type(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc_field_type(bh_values); parallel_pi_field_type(bh_values, eq) // p_i(&bh_values, &eq) } -pub fn sum_check_first_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, -) -> Vec { +pub fn sum_check_first_round(eq: &mut Vec, bh_values: &mut Vec) -> Vec { // The input polynomials are in the form of evaluations. Instead of viewing // every one element as the evaluation of the polynomial at a single point, // we can view every two elements as partially evaluating the polynomial at // a single point, leaving the first variable free, and obtaining a univariate // polynomial. The one_level_interp_hc transforms the evaluation forms into // the coefficient forms, for every of these partial polynomials. - one_level_interp_hc(&mut eq); - one_level_interp_hc(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc(bh_values); parallel_pi(bh_values, eq) // p_i(&bh_values, &eq) } @@ -51,7 +48,7 @@ pub fn one_level_interp_hc(evals: &mut Vec) { return; } evals.par_chunks_mut(2).for_each(|chunk| { - chunk[1] = chunk[1] - chunk[0]; + chunk[1] -= chunk[0]; }); } @@ -68,15 +65,15 @@ pub fn one_level_eval_hc(evals: &mut Vec, challenge: F) { }); } -fn parallel_pi_field_type(evals: &mut FieldType, eq: &mut Vec) -> Vec { +fn parallel_pi_field_type(evals: &mut FieldType, eq: &mut [E]) -> Vec { match evals { - FieldType::Ext(evals) => parallel_pi(evals, &eq), - FieldType::Base(evals) => parallel_pi_base(evals, &eq), + FieldType::Ext(evals) => parallel_pi(evals, eq), + FieldType::Base(evals) => parallel_pi_base(evals, eq), _ => unreachable!(), } } -fn parallel_pi(evals: &Vec, eq: &Vec) -> Vec { +fn parallel_pi(evals: &[F], eq: &[F]) -> Vec { if evals.len() == 1 { return vec![evals[0], evals[0], evals[0]]; } @@ -111,7 +108,7 @@ fn parallel_pi(evals: &Vec, eq: &Vec) -> Vec { coeffs } -fn parallel_pi_base(evals: &Vec, eq: &Vec) -> Vec { +fn parallel_pi_base(evals: &[E::BaseField], eq: &[E]) -> Vec { if evals.len() == 1 { return vec![E::from(evals[0]), E::from(evals[0]), E::from(evals[0])]; } @@ -147,31 +144,27 @@ fn parallel_pi_base(evals: &Vec, eq: &Vec) - } pub fn sum_check_challenge_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, + eq: &mut Vec, + bh_values: &mut Vec, challenge: F, ) -> Vec { // Note that when the last round ends, every two elements are in // the coefficient form. Use the challenge to reduce the two elements // into a single value. This is equivalent to substituting the challenge // to the first variable of the poly. - one_level_eval_hc(&mut bh_values, challenge); - one_level_eval_hc(&mut eq, challenge); + one_level_eval_hc(bh_values, challenge); + one_level_eval_hc(eq, challenge); - one_level_interp_hc(&mut eq); - one_level_interp_hc(&mut bh_values); + one_level_interp_hc(eq); + one_level_interp_hc(bh_values); - parallel_pi(&bh_values, &eq) + parallel_pi(bh_values, eq) // p_i(&bh_values,&eq) } -pub fn sum_check_last_round( - mut eq: &mut Vec, - mut bh_values: &mut Vec, - challenge: F, -) { - one_level_eval_hc(&mut bh_values, challenge); - one_level_eval_hc(&mut eq, challenge); +pub fn sum_check_last_round(eq: &mut Vec, bh_values: &mut Vec, challenge: F) { + one_level_eval_hc(bh_values, challenge); + one_level_eval_hc(eq, challenge); } #[cfg(test)] @@ -184,7 +177,7 @@ mod tests { use super::*; use crate::util::test::rand_vec; - pub fn p_i(evals: &Vec, eq: &Vec) -> Vec { + pub fn p_i(evals: &[F], eq: &[F]) -> Vec { if evals.len() == 1 { return vec![evals[0], evals[0], evals[0]]; } diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 216b50d52..277f98fd0 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -29,8 +29,9 @@ pub fn pcs_setup>( pub fn pcs_trim>( param: &Pcs::Param, + poly_size: usize, ) -> Result<(Pcs::ProverParam, Pcs::VerifierParam), Error> { - Pcs::trim(param) + Pcs::trim(param, poly_size) } pub fn pcs_commit>( @@ -50,16 +51,16 @@ pub fn pcs_commit_and_write>( pp: &Pcs::ProverParam, - polys: &Vec>, -) -> Result, Error> { + polys: &[DenseMultilinearExtension], +) -> Result { Pcs::batch_commit(pp, polys) } -pub fn pcs_batch_commit_and_write<'a, E: ExtensionField, Pcs: PolynomialCommitmentScheme>( +pub fn pcs_batch_commit_and_write>( pp: &Pcs::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], transcript: &mut impl TranscriptWrite, -) -> Result, Error> { +) -> Result { Pcs::batch_commit_and_write(pp, polys, transcript) } @@ -76,8 +77,8 @@ pub fn pcs_open>( pub fn pcs_batch_open>( pp: &Pcs::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Pcs::CommitmentWithData], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptWrite, @@ -114,7 +115,7 @@ pub fn pcs_verify>( pub fn pcs_batch_verify<'a, E: ExtensionField, Pcs: PolynomialCommitmentScheme>( vp: &Pcs::VerifierParam, - comms: &Vec, + comms: &[Pcs::Commitment], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptRead, @@ -136,7 +137,10 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn setup(poly_size: usize, rng: &Self::Rng) -> Result; - fn trim(param: &Self::Param) -> Result<(Self::ProverParam, Self::VerifierParam), Error>; + fn trim( + param: &Self::Param, + poly_size: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), Error>; fn commit( pp: &Self::ProverParam, @@ -151,14 +155,14 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_commit( pp: &Self::ProverParam, - polys: &Vec>, - ) -> Result, Error>; + polys: &[DenseMultilinearExtension], + ) -> Result; fn batch_commit_and_write( pp: &Self::ProverParam, - polys: &Vec>, + polys: &[DenseMultilinearExtension], transcript: &mut impl TranscriptWrite, - ) -> Result, Error>; + ) -> Result; fn open( pp: &Self::ProverParam, @@ -171,13 +175,26 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptWrite, ) -> Result<(), Error>; + /// This is a simple version of batch open: + /// 1. Open at one point + /// 2. All the polynomials share the same commitment. + /// 3. The point is already a random point generated by a sum-check. + fn simple_batch_open( + pp: &Self::ProverParam, + polys: &[DenseMultilinearExtension], + comm: &Self::CommitmentWithData, + point: &[E], + evals: &[E], + transcript: &mut impl TranscriptWrite, + ) -> Result<(), Error>; + fn read_commitment( vp: &Self::VerifierParam, transcript: &mut impl TranscriptRead, @@ -203,11 +220,19 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_verify( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], transcript: &mut impl TranscriptRead, ) -> Result<(), Error>; + + fn simple_batch_verify( + vp: &Self::VerifierParam, + comm: &Self::Commitment, + point: &[E], + evals: &[E], + transcript: &mut impl TranscriptRead, + ) -> Result<(), Error>; } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -234,8 +259,8 @@ where fn ni_batch_open( pp: &Self::ProverParam, - polys: &Vec>, - comms: &Vec, + polys: &[DenseMultilinearExtension], + comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], ) -> Result, Error> { @@ -257,7 +282,7 @@ where fn ni_batch_verify<'a>( vp: &Self::VerifierParam, - comms: &Vec, + comms: &[Self::Commitment], points: &[Vec], evals: &[Evaluation], proof: &PCSProof, @@ -308,18 +333,19 @@ pub enum Error { mod basefold; pub use basefold::{ - Basefold, BasefoldCommitment, BasefoldCommitmentWithData, BasefoldDefault, - BasefoldDefaultParams, BasefoldExtParams, BasefoldParams, + coset_fft, fft, fft_root_table, Basecode, BasecodeDefaultSpec, Basefold, + BasefoldBasecodeParams, BasefoldCommitment, BasefoldCommitmentWithData, BasefoldDefault, + BasefoldParams, BasefoldRSParams, BasefoldSpec, EncodingScheme, RSCode, RSCodeDefaultSpec, }; fn validate_input( function: &str, param_num_vars: usize, - polys: &Vec>, - points: &Vec>, + polys: &[DenseMultilinearExtension], + points: &[Vec], ) -> Result<(), Error> { - let polys = polys.into_iter().collect_vec(); - let points = points.into_iter().collect_vec(); + let polys = polys.iter().collect_vec(); + let points = points.iter().collect_vec(); for poly in polys.iter() { if param_num_vars < poly.num_vars { return Err(err_too_many_variates( @@ -377,7 +403,7 @@ pub mod test_util { let rng = ChaCha8Rng::from_seed([0u8; 32]); let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size, &rng).unwrap(); - Pcs::trim(¶m).unwrap() + Pcs::trim(¶m, poly_size).unwrap() }; // Commit and open let proof = { @@ -402,15 +428,13 @@ pub mod test_util { // Verify let result = { let mut transcript = T::from_proof(proof.as_slice()); - let result = Pcs::verify( + Pcs::verify( &vp, &Pcs::read_commitment(&vp, &mut transcript).unwrap(), &transcript.squeeze_challenges(num_vars), &transcript.read_field_element_ext().unwrap(), &mut transcript, - ); - - result + ) }; result.unwrap(); } @@ -435,7 +459,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size, &rng).unwrap(); - Pcs::trim(¶m).unwrap() + Pcs::trim(¶m, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -459,7 +483,11 @@ pub mod test_util { } }) .collect_vec(); - let comms = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + let comms = polys + .iter() + .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) + .collect_vec(); let points = (0..num_points) .map(|i| transcript.squeeze_challenges(num_vars - i)) @@ -511,14 +539,77 @@ pub mod test_util { result.unwrap(); } } + + pub(super) fn run_simple_batch_commit_open_verify( + base: bool, + num_vars_start: usize, + num_vars_end: usize, + batch_size: usize, + ) where + E: ExtensionField, + Pcs: PolynomialCommitmentScheme, + T: TranscriptRead + + TranscriptWrite + + InMemoryTranscript, + { + for num_vars in num_vars_start..num_vars_end { + let rng = ChaCha8Rng::from_seed([0u8; 32]); + // Setup + let (pp, vp) = { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size, &rng).unwrap(); + Pcs::trim(¶m, poly_size).unwrap() + }; + + let proof = { + let mut transcript = T::new(); + let polys = (0..batch_size) + .map(|_| { + if base { + DenseMultilinearExtension::random(num_vars, &mut rng.clone()) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), + ) + } + }) + .collect_vec(); + let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + + let point = transcript.squeeze_challenges(num_vars); + + let evals = (0..batch_size) + .map(|i| polys[i].evaluate(&point)) + .collect_vec(); + + transcript.write_field_elements_ext(&evals).unwrap(); + Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); + transcript.into_proof() + }; + // Batch verify + let result = { + let mut transcript = T::from_proof(proof.as_slice()); + let comms = &Pcs::read_commitment(&vp, &mut transcript).unwrap(); + + let point = transcript.squeeze_challenges(num_vars); + let evals = transcript.read_field_elements_ext(batch_size).unwrap(); + + Pcs::simple_batch_verify(&vp, comms, &point, &evals, &mut transcript) + }; + + result.unwrap(); + } + } } #[cfg(test)] mod test { use crate::{ - basefold::{Basefold, BasefoldExtParams}, + basefold::Basefold, util::transcript::{FieldTranscript, InMemoryTranscript, PoseidonTranscript}, - PolynomialCommitmentScheme, + BasefoldRSParams, PolynomialCommitmentScheme, }; use goldilocks::GoldilocksExt2; use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; @@ -526,22 +617,7 @@ mod test { use rand_chacha::ChaCha8Rng; #[test] fn test_transcript() { - #[derive(Debug)] - pub struct Five {} - - impl BasefoldExtParams for Five { - fn get_reps() -> usize { - return 5; - } - fn get_rate() -> usize { - return 3; - } - fn get_basecode() -> usize { - return 2; - } - } - - type Pcs = Basefold; + type Pcs = Basefold; let num_vars = 10; let rng = ChaCha8Rng::from_seed([0u8; 32]); let poly_size = 1 << num_vars; @@ -550,7 +626,9 @@ mod test { let param = >::setup(poly_size, &rng).unwrap(); - let (pp, vp) = >::trim(¶m).unwrap(); + let (pp, vp) = + >::trim(¶m, 1 << num_vars) + .unwrap(); println!("before commit"); let comm = >::commit_and_write( &pp, diff --git a/mpcs/src/sum_check.rs b/mpcs/src/sum_check.rs index e12233cf9..8ea3517af 100644 --- a/mpcs/src/sum_check.rs +++ b/mpcs/src/sum_check.rs @@ -93,8 +93,8 @@ pub fn evaluate( &|query| evals[&query], &|idx| challenges[idx], &|scalar| -scalar, - &|lhs, rhs| lhs + &rhs, - &|lhs, rhs| lhs * &rhs, + &|lhs, rhs| lhs + rhs, + &|lhs, rhs| lhs * rhs, &|value, scalar| scalar * value, ) } @@ -104,11 +104,7 @@ pub fn lagrange_eval(x: &[F], b: usize) -> F { product(x.iter().enumerate().map( |(idx, x_i)| { - if b.nth_bit(idx) { - *x_i - } else { - F::ONE - x_i - } + if b.nth_bit(idx) { *x_i } else { F::ONE - x_i } }, )) } diff --git a/mpcs/src/sum_check/classic.rs b/mpcs/src/sum_check/classic.rs index 439535e71..853d25602 100644 --- a/mpcs/src/sum_check/classic.rs +++ b/mpcs/src/sum_check/classic.rs @@ -116,10 +116,8 @@ impl<'a, E: ExtensionField> ProverState<'a, E> { .expression .used_rotation() .into_iter() - .filter_map(|rotation| { - (rotation != Rotation::cur()) - .then(|| (rotation, self.bh.rotation_map(rotation))) - }) + .filter(|&rotation| (rotation != Rotation::cur())) + .map(|rotation| (rotation, self.bh.rotation_map(rotation))) .collect::>(); for query in self.expression.used_query() { if query.rotation() != Rotation::cur() { @@ -312,7 +310,7 @@ mod tests { #[test] fn test_sum_check_protocol() { - let polys = vec![ + let polys = [ DenseMultilinearExtension::::from_evaluations_vec( 2, vec![Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)], diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 9e596fe08..470883d7f 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -205,7 +205,7 @@ impl ClassicSumCheckProver for CoefficientsProver { products.iter_mut().for_each(|(lhs, _)| { *lhs *= &rhs; }); - (constant * &rhs, products) + (constant * rhs, products) }, ); Self(constant, flattened) @@ -215,7 +215,7 @@ impl ClassicSumCheckProver for CoefficientsProver { // Initialize h(X) to zero let mut coeffs = Coefficients(FieldType::Ext(vec![E::ZERO; state.expression.degree() + 1])); // First, sum the constant over the hypercube and add to h(X) - coeffs += &(E::from(state.size() as u64) * &self.0); + coeffs += &(E::from(state.size() as u64) * self.0); // Next, for every product of polynomials, where each product is assumed to be exactly 2 // put this into h(X). if self.1.iter().all(|(_, products)| products.len() == 2) { @@ -287,11 +287,11 @@ impl CoefficientsProver { .take(n) .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { let coeff_0 = lhs_0 * rhs_0; - let coeff_2 = (lhs_1 - lhs_0) * &(rhs_1 - rhs_0); + let coeff_2 = (lhs_1 - lhs_0) * (rhs_1 - rhs_0); coeffs[0] += &coeff_0; coeffs[2] += &coeff_2; if !LAZY { - coeffs[1] += &(lhs_1 * rhs_1 - &coeff_0 - &coeff_2); + coeffs[1] += &(lhs_1 * rhs_1 - coeff_0 - coeff_2); } }); }; diff --git a/mpcs/src/util.rs b/mpcs/src/util.rs index cf056d51a..260986b2f 100644 --- a/mpcs/src/util.rs +++ b/mpcs/src/util.rs @@ -113,6 +113,42 @@ pub fn field_type_index_ext(poly: &FieldType, index: usize } } +pub fn field_type_index_mul_base( + poly: &mut FieldType, + index: usize, + scalar: &E::BaseField, +) { + match poly { + FieldType::Ext(coeffs) => coeffs[index] *= scalar, + FieldType::Base(coeffs) => coeffs[index] *= scalar, + _ => unreachable!(), + } +} + +pub fn field_type_index_set_base( + poly: &mut FieldType, + index: usize, + scalar: &E::BaseField, +) { + match poly { + FieldType::Ext(coeffs) => coeffs[index] = E::from(*scalar), + FieldType::Base(coeffs) => coeffs[index] = *scalar, + _ => unreachable!(), + } +} + +pub fn field_type_index_set_ext( + poly: &mut FieldType, + index: usize, + scalar: &E, +) { + match poly { + FieldType::Ext(coeffs) => coeffs[index] = *scalar, + FieldType::Base(_) => panic!("Cannot set base field from extension field"), + _ => unreachable!(), + } +} + pub struct FieldTypeIterExt<'a, E: ExtensionField> { inner: &'a FieldType, index: usize, diff --git a/mpcs/src/util/arithmetic.rs b/mpcs/src/util/arithmetic.rs index f4ca552d8..609f65455 100644 --- a/mpcs/src/util/arithmetic.rs +++ b/mpcs/src/util/arithmetic.rs @@ -80,7 +80,7 @@ pub fn inner_product<'a, 'b, F: Field>( rhs: impl IntoIterator, ) -> F { lhs.into_iter() - .zip_eq(rhs.into_iter()) + .zip_eq(rhs) .map(|(lhs, rhs)| *lhs * rhs) .reduce(|acc, product| acc + product) .unwrap_or_default() @@ -92,8 +92,8 @@ pub fn inner_product_three<'a, 'b, 'c, F: Field>( c: impl IntoIterator, ) -> F { a.into_iter() - .zip_eq(b.into_iter()) - .zip_eq(c.into_iter()) + .zip_eq(b) + .zip_eq(c) .map(|((a, b), c)| *a * b * c) .reduce(|acc, product| acc + product) .unwrap_or_default() @@ -107,8 +107,9 @@ pub fn barycentric_weights(points: &[F]) -> Vec { points .iter() .enumerate() - .filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i)) - .reduce(|acc, value| acc * &value) + .filter(|&(i, _point_i)| (i != j)) + .map(|(_i, point_i)| *point_j - point_i) + .reduce(|acc, value| acc * value) .unwrap_or(F::ONE) }) .collect_vec(); @@ -126,7 +127,7 @@ pub fn barycentric_interpolate(weights: &[F], points: &[F], evals: &[F let sum_inv = coeffs.iter().fold(F::ZERO, |sum, coeff| sum + coeff); (coeffs, sum_inv.invert().unwrap()) }; - inner_product(&coeffs, evals) * &sum_inv + inner_product(&coeffs, evals) * sum_inv } pub fn modulus() -> BigUint { @@ -134,11 +135,7 @@ pub fn modulus() -> BigUint { } pub fn fe_from_bool(value: bool) -> F { - if value { - F::ONE - } else { - F::ZERO - } + if value { F::ONE } else { F::ZERO } } pub fn fe_mod_from_le_bytes(bytes: impl AsRef<[u8]>) -> F { @@ -221,17 +218,17 @@ pub fn interpolate2(points: [(F, F); 2], x: F) -> F { a1 + (x - a0) * (b1 - a1) * (b0 - a0).invert().unwrap() } -pub fn degree_2_zero_plus_one(poly: &Vec) -> F { +pub fn degree_2_zero_plus_one(poly: &[F]) -> F { poly[0] + poly[0] + poly[1] + poly[2] } -pub fn degree_2_eval(poly: &Vec, point: F) -> F { +pub fn degree_2_eval(poly: &[F], point: F) -> F { poly[0] + point * poly[1] + point * point * poly[2] } -pub fn base_from_raw_bytes(bytes: &Vec) -> E::BaseField { +pub fn base_from_raw_bytes(bytes: &[u8]) -> E::BaseField { let mut res = E::BaseField::ZERO; - bytes.into_iter().for_each(|b| { + bytes.iter().for_each(|b| { res += E::BaseField::from(u64::from(*b)); }); res diff --git a/mpcs/src/util/arithmetic/hypercube.rs b/mpcs/src/util/arithmetic/hypercube.rs index 16c9868af..ccd6294e3 100644 --- a/mpcs/src/util/arithmetic/hypercube.rs +++ b/mpcs/src/util/arithmetic/hypercube.rs @@ -1,4 +1,3 @@ -use ark_std::{end_timer, start_timer}; use ff::Field; use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; @@ -30,7 +29,7 @@ pub fn interpolate_over_boolean_hypercube(evals: &mut Vec) { evals.par_chunks_mut(chunk_size).for_each(|chunk| { let half_chunk = chunk_size >> 1; for j in half_chunk..chunk_size { - chunk[j] = chunk[j] - chunk[j - half_chunk]; + chunk[j] -= chunk[j - half_chunk]; } }); } diff --git a/mpcs/src/util/hash.rs b/mpcs/src/util/hash.rs index 4842f6e09..fa8010970 100644 --- a/mpcs/src/util/hash.rs +++ b/mpcs/src/util/hash.rs @@ -40,6 +40,54 @@ pub fn hash_two_leaves_base( Digest(result) } +pub fn hash_two_leaves_batch_ext( + a: &[E], + b: &[E], + hasher: &Hasher, +) -> Digest { + let mut left_hasher = hasher.clone(); + a.iter().for_each(|a| left_hasher.update(a.as_bases())); + let left = Digest( + left_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + let mut right_hasher = hasher.clone(); + b.iter().for_each(|b| right_hasher.update(b.as_bases())); + let right = Digest( + right_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + hash_two_digests(&left, &right, hasher) +} + +pub fn hash_two_leaves_batch_base( + a: &Vec, + b: &Vec, + hasher: &Hasher, +) -> Digest { + let mut left_hasher = hasher.clone(); + left_hasher.update(a.as_slice()); + let left = Digest( + left_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + let mut right_hasher = hasher.clone(); + right_hasher.update(b.as_slice()); + let right = Digest( + right_hasher.squeeze_vec()[0..DIGEST_WIDTH] + .try_into() + .unwrap(), + ); + + hash_two_digests(&left, &right, hasher) +} + pub fn hash_two_digests( a: &Digest, b: &Digest, diff --git a/mpcs/src/util/merkle_tree.rs b/mpcs/src/util/merkle_tree.rs index 1c563f126..b3d1bf899 100644 --- a/mpcs/src/util/merkle_tree.rs +++ b/mpcs/src/util/merkle_tree.rs @@ -1,4 +1,5 @@ use ff_ext::ExtensionField; +use itertools::Itertools; use multilinear_extensions::mle::FieldType; use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, @@ -6,7 +7,11 @@ use rayon::{ }; use crate::util::{ - hash::{hash_two_digests, hash_two_leaves_base, hash_two_leaves_ext, Digest, Hasher}, + field_type_index_base, field_type_index_ext, + hash::{ + hash_two_digests, hash_two_leaves_base, hash_two_leaves_batch_base, + hash_two_leaves_batch_ext, hash_two_leaves_ext, Digest, Hasher, + }, log2_strict, transcript::{TranscriptRead, TranscriptWrite}, Deserialize, DeserializeOwned, Serialize, @@ -21,7 +26,7 @@ where E::BaseField: Serialize + DeserializeOwned, { inner: Vec>>, - leaves: FieldType, + leaves: Vec>, } impl MerkleTree @@ -30,7 +35,14 @@ where { pub fn from_leaves(leaves: FieldType, hasher: &Hasher) -> Self { Self { - inner: merkelize::(&leaves, hasher), + inner: merkelize::(&[&leaves], hasher), + leaves: vec![leaves], + } + } + + pub fn from_batch_leaves(leaves: Vec>, hasher: &Hasher) -> Self { + Self { + inner: merkelize::(&leaves.iter().collect_vec(), hasher), leaves, } } @@ -47,26 +59,52 @@ where self.inner.len() } - pub fn leaves(&self) -> &FieldType { + pub fn leaves(&self) -> &Vec> { &self.leaves } - pub fn size(&self) -> usize { - self.leaves.len() + pub fn batch_leaves(&self, coeffs: &Vec) -> Vec { + (0..self.leaves[0].len()) + .map(|i| { + self.leaves + .iter() + .zip(coeffs) + .map(|(leaf, coeff)| field_type_index_ext(leaf, i) * *coeff) + .sum() + }) + .collect() + } + + pub fn size(&self) -> (usize, usize) { + (self.leaves.len(), self.leaves[0].len()) } - pub fn get_leaf_as_base(&self, index: usize) -> E::BaseField { - match &self.leaves { - FieldType::Base(leaves) => leaves[index], - FieldType::Ext(_) => panic!("Mismatching field type, calling get_leaf_as_base on a Merkle tree over extension fields"), + pub fn get_leaf_as_base(&self, index: usize) -> Vec { + match &self.leaves[0] { + FieldType::Base(_) => self + .leaves + .iter() + .map(|leaves| field_type_index_base(leaves, index)) + .collect(), + FieldType::Ext(_) => panic!( + "Mismatching field type, calling get_leaf_as_base on a Merkle tree over extension fields" + ), FieldType::Unreachable => unreachable!(), } } - pub fn get_leaf_as_extension(&self, index: usize) -> E { - match &self.leaves { - FieldType::Base(leaves) => E::from(leaves[index]), - FieldType::Ext(leaves) => leaves[index], + pub fn get_leaf_as_extension(&self, index: usize) -> Vec { + match &self.leaves[0] { + FieldType::Base(_) => self + .leaves + .iter() + .map(|leaves| field_type_index_ext(leaves, index)) + .collect(), + FieldType::Ext(_) => self + .leaves + .iter() + .map(|leaves| field_type_index_ext(leaves, index)) + .collect(), FieldType::Unreachable => unreachable!(), } } @@ -75,7 +113,7 @@ where &self, leaf_index: usize, ) -> MerklePathWithoutLeafOrRoot { - assert!(leaf_index < self.size()); + assert!(leaf_index < self.size().1); MerklePathWithoutLeafOrRoot::::new( self.inner .iter() @@ -105,6 +143,10 @@ where Self { inner } } + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + pub fn len(&self) -> usize { self.inner.len() } @@ -164,28 +206,100 @@ where hasher, ) } + + pub fn authenticate_batch_leaves_root_ext( + &self, + left: Vec, + right: Vec, + index: usize, + root: &Digest, + hasher: &Hasher, + ) { + authenticate_merkle_path_root_batch::( + &self.inner, + FieldType::Ext(left), + FieldType::Ext(right), + index, + root, + hasher, + ) + } + + pub fn authenticate_batch_leaves_root_base( + &self, + left: Vec, + right: Vec, + index: usize, + root: &Digest, + hasher: &Hasher, + ) { + authenticate_merkle_path_root_batch::( + &self.inner, + FieldType::Base(left), + FieldType::Base(right), + index, + root, + hasher, + ) + } } fn merkelize( - values: &FieldType, + values: &[&FieldType], hasher: &Hasher, ) -> Vec>> { - let timer = start_timer!(|| format!("merkelize {} values", values.len())); - let log_v = log2_strict(values.len()); + #[cfg(feature = "sanity-check")] + for i in 0..(values.len() - 1) { + assert_eq!(values[i].len(), values[i + 1].len()); + } + let timer = start_timer!(|| format!("merkelize {} values", values[0].len() * values.len())); + let log_v = log2_strict(values[0].len()); let mut tree = Vec::with_capacity(log_v); // The first layer of hashes, half the number of leaves - let mut hashes = vec![Digest::default(); values.len() >> 1]; - hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = match values { - FieldType::Base(values) => { - hash_two_leaves_base::(&values[i << 1], &values[(i << 1) + 1], hasher) - } - FieldType::Ext(values) => { - hash_two_leaves_ext::(&values[i << 1], &values[(i << 1) + 1], hasher) - } - FieldType::Unreachable => unreachable!(), - }; - }); + let mut hashes = vec![Digest::default(); values[0].len() >> 1]; + if values.len() == 1 { + hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { + *hash = match &values[0] { + FieldType::Base(values) => { + hash_two_leaves_base::(&values[i << 1], &values[(i << 1) + 1], hasher) + } + FieldType::Ext(values) => { + hash_two_leaves_ext::(&values[i << 1], &values[(i << 1) + 1], hasher) + } + FieldType::Unreachable => unreachable!(), + }; + }); + } else { + hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { + *hash = match &values[0] { + FieldType::Base(_) => hash_two_leaves_batch_base::( + &values + .iter() + .map(|values| field_type_index_base(values, i << 1)) + .collect(), + &values + .iter() + .map(|values| field_type_index_base(values, (i << 1) + 1)) + .collect(), + hasher, + ), + FieldType::Ext(_) => hash_two_leaves_batch_ext::( + values + .iter() + .map(|values| field_type_index_ext(values, i << 1)) + .collect_vec() + .as_slice(), + values + .iter() + .map(|values| field_type_index_ext(values, (i << 1) + 1)) + .collect_vec() + .as_slice(), + hasher, + ), + FieldType::Unreachable => unreachable!(), + }; + }); + } tree.push(hashes); @@ -202,7 +316,7 @@ fn merkelize( } fn authenticate_merkle_path_root( - path: &Vec>, + path: &[Digest], leaves: FieldType, x_index: usize, root: &Digest, @@ -218,11 +332,43 @@ fn authenticate_merkle_path_root( // The lowest bit in the index is ignored. It can point to either leaves x_index >>= 1; - for i in 0..path.len() { + for path_i in path.iter() { + hash = if x_index & 1 == 0 { + hash_two_digests(&hash, path_i, hasher) + } else { + hash_two_digests(path_i, &hash, hasher) + }; + x_index >>= 1; + } + assert_eq!(&hash, root); +} + +fn authenticate_merkle_path_root_batch( + path: &[Digest], + left: FieldType, + right: FieldType, + x_index: usize, + root: &Digest, + hasher: &Hasher, +) { + let mut x_index = x_index; + let mut hash = match (left, right) { + (FieldType::Base(left), FieldType::Base(right)) => { + hash_two_leaves_batch_base::(&left, &right, hasher) + } + (FieldType::Ext(left), FieldType::Ext(right)) => { + hash_two_leaves_batch_ext::(&left, &right, hasher) + } + _ => unreachable!(), + }; + + // The lowest bit in the index is ignored. It can point to either leaves + x_index >>= 1; + for path_i in path.iter() { hash = if x_index & 1 == 0 { - hash_two_digests(&hash, &path[i], hasher) + hash_two_digests(&hash, path_i, hasher) } else { - hash_two_digests(&path[i], &hash, hasher) + hash_two_digests(path_i, &hash, hasher) }; x_index >>= 1; } diff --git a/mpcs/src/util/parallel.rs b/mpcs/src/util/parallel.rs index 5f44d473d..950d43797 100644 --- a/mpcs/src/util/parallel.rs +++ b/mpcs/src/util/parallel.rs @@ -1,10 +1,7 @@ pub fn num_threads() -> usize { #[cfg(feature = "parallel")] let nt = rayon::current_num_threads(); - return nt; - - #[cfg(not(feature = "parallel"))] - return 1; + if cfg!(feature = "parallel") { nt } else { 1 } } pub fn parallelize_iter(iter: I, f: F) diff --git a/mpcs/src/util/transcript.rs b/mpcs/src/util/transcript.rs index 343bfffba..2edf082cf 100644 --- a/mpcs/src/util/transcript.rs +++ b/mpcs/src/util/transcript.rs @@ -21,17 +21,16 @@ pub trait FieldTranscript { fn common_field_element_ext(&mut self, fe: &E) -> Result<(), Error>; fn common_field_elements(&mut self, fes: FieldType) -> Result<(), Error> { - Ok(match fes { + match fes { FieldType::Base(fes) => fes .iter() - .map(|fe| self.common_field_element_base(fe)) - .try_collect()?, + .try_for_each(|fe| self.common_field_element_base(fe))?, FieldType::Ext(fes) => fes .iter() - .map(|fe| self.common_field_element_ext(fe)) - .try_collect()?, + .try_for_each(|fe| self.common_field_element_ext(fe))?, FieldType::Unreachable => unreachable!(), - }) + }; + Ok(()) } } @@ -146,7 +145,7 @@ impl Stream { self.inner.len() - self.pointer } - pub fn read_exact(&mut self, output: &mut Vec) -> Result<(), Error> { + pub fn read_exact(&mut self, output: &mut [T]) -> Result<(), Error> { let left = self.left(); if left < output.len() { return Err(Error::Transcript(