From 853f1bac0d9ef5010b34e6546a9eea264ed1d7c1 Mon Sep 17 00:00:00 2001 From: Cyte Zhang Date: Mon, 11 Nov 2024 15:55:49 +0800 Subject: [PATCH] Change BaseFold trim API to consume pp & truncate public parameters directly in `trim` (#248) Considering that BaseFold trim is almost always used only once in each program execution, consume the input public parameter instead of clone it to save memory. This API change also allows the `trim` function to directly truncate the input public parameters. --- ceno_zkvm/benches/riscv_add.rs | 2 +- ceno_zkvm/examples/riscv_opcodes.rs | 2 +- ceno_zkvm/src/instructions/riscv/test.rs | 2 +- ceno_zkvm/src/scheme/tests.rs | 4 +-- mpcs/benches/basecode.rs | 2 +- mpcs/benches/commit_open_verify_basecode.rs | 6 ++-- mpcs/benches/commit_open_verify_rs.rs | 6 ++-- mpcs/benches/rscode.rs | 2 +- mpcs/src/basefold.rs | 4 +-- mpcs/src/basefold/encoding.rs | 4 +-- mpcs/src/basefold/encoding/basecode.rs | 11 ++++--- mpcs/src/basefold/encoding/rs.rs | 36 ++++++++++----------- mpcs/src/lib.rs | 10 +++--- 13 files changed, 47 insertions(+), 44 deletions(-) diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 16d5cfe67..8b948a48a 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -45,7 +45,7 @@ fn bench_add(c: &mut Criterion) { zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); let param = Pcs::setup(1 << MAX_NUM_VARIABLES).unwrap(); - let (pp, vp) = Pcs::trim(¶m, 1 << MAX_NUM_VARIABLES).unwrap(); + let (pp, vp) = Pcs::trim(param, 1 << MAX_NUM_VARIABLES).unwrap(); let pk = zkvm_cs .clone() diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index fbb0e0a83..beba16b22 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -105,7 +105,7 @@ fn main() { // keygen let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); - let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); + let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); diff --git a/ceno_zkvm/src/instructions/riscv/test.rs b/ceno_zkvm/src/instructions/riscv/test.rs index 6513cd1f7..f4b8f8824 100644 --- a/ceno_zkvm/src/instructions/riscv/test.rs +++ b/ceno_zkvm/src/instructions/riscv/test.rs @@ -23,6 +23,6 @@ fn test_multiple_opcode() { |cs| SubInstruction::construct_circuit(&mut CircuitBuilder::::new(cs)), ); let param = Pcs::setup(1 << 10).unwrap(); - let (pp, _) = Pcs::trim(¶m, 1 << 10).unwrap(); + let (pp, _) = Pcs::trim(param, 1 << 10).unwrap(); cs.key_gen::(&pp, None); } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 07ae5cb99..04edee440 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -89,7 +89,7 @@ fn test_rw_lk_expression_combination() { // pcs setup let param = Pcs::setup(1 << 13).unwrap(); - let (pp, vp) = Pcs::trim(¶m, 1 << 13).unwrap(); + let (pp, vp) = Pcs::trim(param, 1 << 13).unwrap(); // configure let name = TestCircuit::::name(); @@ -223,7 +223,7 @@ fn test_single_add_instance_e2e() { ); let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); - let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); + let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); // opcode circuits let add_config = zkvm_cs.register_opcode_circuit::>(); diff --git a/mpcs/benches/basecode.rs b/mpcs/benches/basecode.rs index 193d15f88..9ef1896f6 100644 --- a/mpcs/benches/basecode.rs +++ b/mpcs/benches/basecode.rs @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) { let (pp, _) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let polys = (0..batch_size) .map(|_| { diff --git a/mpcs/benches/commit_open_verify_basecode.rs b/mpcs/benches/commit_open_verify_basecode.rs index 1a22e9171..91baa5f73 100644 --- a/mpcs/benches/commit_open_verify_basecode.rs +++ b/mpcs/benches/commit_open_verify_basecode.rs @@ -42,7 +42,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Pcs::setup(poly_size).unwrap(); }) }); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); @@ -118,7 +118,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -258,7 +258,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); let polys = (0..batch_size) diff --git a/mpcs/benches/commit_open_verify_rs.rs b/mpcs/benches/commit_open_verify_rs.rs index 686253218..1401f5127 100644 --- a/mpcs/benches/commit_open_verify_rs.rs +++ b/mpcs/benches/commit_open_verify_rs.rs @@ -46,7 +46,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Pcs::setup(poly_size).unwrap(); }) }); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); @@ -125,7 +125,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -266,7 +266,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); let polys = (0..batch_size) diff --git a/mpcs/benches/rscode.rs b/mpcs/benches/rscode.rs index ac9870d84..2d284d177 100644 --- a/mpcs/benches/rscode.rs +++ b/mpcs/benches/rscode.rs @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) { let (pp, _) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let polys = (0..batch_size) .map(|_| { diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 3f85022cc..5c225c75a 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -287,10 +287,10 @@ where /// Derive the proving key and verification key from the public parameter. /// This step simultaneously trims the parameter for the particular size. fn trim( - pp: &Self::Param, + pp: Self::Param, poly_size: usize, ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { - >::trim(&pp.params, log2_strict(poly_size)).map( + >::trim(pp.params, log2_strict(poly_size)).map( |(pp, vp)| { ( BasefoldProverParams { diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs index 706ec9f9f..410d35970 100644 --- a/mpcs/src/basefold/encoding.rs +++ b/mpcs/src/basefold/encoding.rs @@ -35,7 +35,7 @@ pub trait EncodingScheme: std::fmt::Debug + Clone { fn setup(max_msg_size_log: usize) -> Self::PublicParameters; fn trim( - pp: &Self::PublicParameters, + pp: Self::PublicParameters, max_msg_size_log: usize, ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error>; @@ -177,7 +177,7 @@ pub(crate) mod test_util { let mut poly = FieldType::Ext(poly); let pp: Code::PublicParameters = Code::setup(num_vars); - let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); reverse_index_bits_in_place_field_type(&mut codeword); if Code::message_is_left_and_right_folding() { diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs index 44b91ba76..9fbee84f1 100644 --- a/mpcs/src/basefold/encoding/basecode.rs +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -117,7 +117,7 @@ where } fn trim( - pp: &Self::PublicParameters, + mut pp: Self::PublicParameters, max_msg_size_log: usize, ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { if pp.table.len() < Spec::get_rate_log() + max_msg_size_log { @@ -127,6 +127,9 @@ where max_msg_size_log, ))); } + pp.table_w_weights + .truncate(Spec::get_rate_log() + max_msg_size_log); + pp.table.truncate(Spec::get_rate_log() + max_msg_size_log); let mut key: [u8; 16] = [0u8; 16]; let mut iv: [u8; 16] = [0u8; 16]; let mut rng = ChaCha8Rng::from_seed(pp.rng_seed); @@ -135,8 +138,8 @@ where rng.fill_bytes(&mut iv); Ok(( Self::ProverParameters { - table_w_weights: pp.table_w_weights.clone(), - table: pp.table.clone(), + table_w_weights: pp.table_w_weights, + table: pp.table, rng_seed: pp.rng_seed, _phantom: PhantomData, }, @@ -430,7 +433,7 @@ mod tests { fn prover_verifier_consistency() { type Code = Basecode; let pp: BasecodeParameters = Code::setup(10); - let (pp, vp) = Code::trim(&pp, 10).unwrap(); + let (pp, vp) = Code::trim(pp, 10).unwrap(); for level in 0..(10 + >::get_rate_log()) { for index in 0..(1 << level) { assert_eq!( diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs index 8535ce23e..2bcac0826 100644 --- a/mpcs/src/basefold/encoding/rs.rs +++ b/mpcs/src/basefold/encoding/rs.rs @@ -280,7 +280,7 @@ where } fn trim( - pp: &Self::PublicParameters, + mut pp: Self::PublicParameters, max_message_size_log: usize, ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { if pp.fft_root_table.len() < max_message_size_log + Spec::get_rate_log() { @@ -308,7 +308,6 @@ where }, )); } - let mut gamma_powers = Vec::with_capacity(max_message_size_log); let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log); gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR); @@ -319,26 +318,27 @@ where } let inv_of_two = E::BaseField::from(2).invert().unwrap(); gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two); + pp.fft_root_table + .truncate(max_message_size_log + Spec::get_rate_log()); + let verifier_fft_root_table = pp.fft_root_table + [..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()] + .iter() + .cloned() + .chain( + pp.fft_root_table[Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..] + .iter() + .map(|v| vec![v[1]]), + ) + .collect(); Ok(( Self::ProverParameters { - fft_root_table: pp.fft_root_table[..max_message_size_log + Spec::get_rate_log()] - .to_vec(), + fft_root_table: pp.fft_root_table, gamma_powers: gamma_powers.clone(), gamma_powers_inv_div_two: gamma_powers_inv.clone(), full_message_size_log: max_message_size_log, }, Self::VerifierParameters { - fft_root_table: pp.fft_root_table - [..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()] - .iter() - .cloned() - .chain( - pp.fft_root_table - [Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..] - .iter() - .map(|v| vec![v[1]]), - ) - .collect(), + fft_root_table: verifier_fft_root_table, full_message_size_log: max_message_size_log, gamma_powers, gamma_powers_inv_div_two: gamma_powers_inv, @@ -653,7 +653,7 @@ mod tests { fn prover_verifier_consistency() { type Code = RSCode; let pp: RSCodeParameters = Code::setup(10); - let (pp, vp) = Code::trim(&pp, 10).unwrap(); + let (pp, vp) = Code::trim(pp, 10).unwrap(); for level in 0..(10 + >::get_rate_log()) { for index in 0..(1 << level) { let (naive_x0, naive_x1, naive_w) = @@ -690,7 +690,7 @@ mod tests { let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); - let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); reverse_index_bits_in_place_field_type(&mut codeword); let challenge = E::from(2); @@ -728,7 +728,7 @@ mod tests { let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); - let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); check_low_degree(&codeword, "low degree check for original codeword"); let c0 = field_type_index_ext(&codeword, 0); diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 46fbea0ff..19b3d16b6 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -24,7 +24,7 @@ pub fn pcs_setup>( } pub fn pcs_trim>( - param: &Pcs::Param, + param: Pcs::Param, poly_size: usize, ) -> Result<(Pcs::ProverParam, Pcs::VerifierParam), Error> { Pcs::trim(param, poly_size) @@ -119,7 +119,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn setup(poly_size: usize) -> Result; fn trim( - param: &Self::Param, + param: Self::Param, poly_size: usize, ) -> Result<(Self::ProverParam, Self::VerifierParam), Error>; @@ -380,7 +380,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Commit and open let (comm, eval, proof, challenge) = { @@ -442,7 +442,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -556,7 +556,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let (comm, evals, proof, challenge) = {