diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 16d5cfe67..8b948a48a 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -45,7 +45,7 @@ fn bench_add(c: &mut Criterion) { zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); let param = Pcs::setup(1 << MAX_NUM_VARIABLES).unwrap(); - let (pp, vp) = Pcs::trim(¶m, 1 << MAX_NUM_VARIABLES).unwrap(); + let (pp, vp) = Pcs::trim(param, 1 << MAX_NUM_VARIABLES).unwrap(); let pk = zkvm_cs .clone() diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index fbb0e0a83..beba16b22 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -105,7 +105,7 @@ fn main() { // keygen let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); - let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); + let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); diff --git a/ceno_zkvm/src/instructions/riscv/test.rs b/ceno_zkvm/src/instructions/riscv/test.rs index 6513cd1f7..f4b8f8824 100644 --- a/ceno_zkvm/src/instructions/riscv/test.rs +++ b/ceno_zkvm/src/instructions/riscv/test.rs @@ -23,6 +23,6 @@ fn test_multiple_opcode() { |cs| SubInstruction::construct_circuit(&mut CircuitBuilder::::new(cs)), ); let param = Pcs::setup(1 << 10).unwrap(); - let (pp, _) = Pcs::trim(¶m, 1 << 10).unwrap(); + let (pp, _) = Pcs::trim(param, 1 << 10).unwrap(); cs.key_gen::(&pp, None); } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 07ae5cb99..04edee440 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -89,7 +89,7 @@ fn test_rw_lk_expression_combination() { // pcs setup let param = Pcs::setup(1 << 13).unwrap(); - let (pp, vp) = Pcs::trim(¶m, 1 << 13).unwrap(); + let (pp, vp) = Pcs::trim(param, 1 << 13).unwrap(); // configure let name = TestCircuit::::name(); @@ -223,7 +223,7 @@ fn test_single_add_instance_e2e() { ); let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); - let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); + let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); // opcode circuits let add_config = zkvm_cs.register_opcode_circuit::>(); diff --git a/mpcs/benches/basecode.rs b/mpcs/benches/basecode.rs index 193d15f88..9ef1896f6 100644 --- a/mpcs/benches/basecode.rs +++ b/mpcs/benches/basecode.rs @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) { let (pp, _) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let polys = (0..batch_size) .map(|_| { diff --git a/mpcs/benches/commit_open_verify_basecode.rs b/mpcs/benches/commit_open_verify_basecode.rs index 1a22e9171..91baa5f73 100644 --- a/mpcs/benches/commit_open_verify_basecode.rs +++ b/mpcs/benches/commit_open_verify_basecode.rs @@ -42,7 +42,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Pcs::setup(poly_size).unwrap(); }) }); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); @@ -118,7 +118,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -258,7 +258,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); let polys = (0..batch_size) diff --git a/mpcs/benches/commit_open_verify_rs.rs b/mpcs/benches/commit_open_verify_rs.rs index 686253218..1401f5127 100644 --- a/mpcs/benches/commit_open_verify_rs.rs +++ b/mpcs/benches/commit_open_verify_rs.rs @@ -46,7 +46,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Pcs::setup(poly_size).unwrap(); }) }); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); @@ -125,7 +125,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -266,7 +266,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); let polys = (0..batch_size) diff --git a/mpcs/benches/rscode.rs b/mpcs/benches/rscode.rs index ac9870d84..2d284d177 100644 --- a/mpcs/benches/rscode.rs +++ b/mpcs/benches/rscode.rs @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) { let (pp, _) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let polys = (0..batch_size) .map(|_| { diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 3f85022cc..5c225c75a 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -287,10 +287,10 @@ where /// Derive the proving key and verification key from the public parameter. /// This step simultaneously trims the parameter for the particular size. fn trim( - pp: &Self::Param, + pp: Self::Param, poly_size: usize, ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { - >::trim(&pp.params, log2_strict(poly_size)).map( + >::trim(pp.params, log2_strict(poly_size)).map( |(pp, vp)| { ( BasefoldProverParams { diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs index 706ec9f9f..410d35970 100644 --- a/mpcs/src/basefold/encoding.rs +++ b/mpcs/src/basefold/encoding.rs @@ -35,7 +35,7 @@ pub trait EncodingScheme: std::fmt::Debug + Clone { fn setup(max_msg_size_log: usize) -> Self::PublicParameters; fn trim( - pp: &Self::PublicParameters, + pp: Self::PublicParameters, max_msg_size_log: usize, ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error>; @@ -177,7 +177,7 @@ pub(crate) mod test_util { let mut poly = FieldType::Ext(poly); let pp: Code::PublicParameters = Code::setup(num_vars); - let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); reverse_index_bits_in_place_field_type(&mut codeword); if Code::message_is_left_and_right_folding() { diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs index 44b91ba76..9fbee84f1 100644 --- a/mpcs/src/basefold/encoding/basecode.rs +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -117,7 +117,7 @@ where } fn trim( - pp: &Self::PublicParameters, + mut pp: Self::PublicParameters, max_msg_size_log: usize, ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { if pp.table.len() < Spec::get_rate_log() + max_msg_size_log { @@ -127,6 +127,9 @@ where max_msg_size_log, ))); } + pp.table_w_weights + .truncate(Spec::get_rate_log() + max_msg_size_log); + pp.table.truncate(Spec::get_rate_log() + max_msg_size_log); let mut key: [u8; 16] = [0u8; 16]; let mut iv: [u8; 16] = [0u8; 16]; let mut rng = ChaCha8Rng::from_seed(pp.rng_seed); @@ -135,8 +138,8 @@ where rng.fill_bytes(&mut iv); Ok(( Self::ProverParameters { - table_w_weights: pp.table_w_weights.clone(), - table: pp.table.clone(), + table_w_weights: pp.table_w_weights, + table: pp.table, rng_seed: pp.rng_seed, _phantom: PhantomData, }, @@ -430,7 +433,7 @@ mod tests { fn prover_verifier_consistency() { type Code = Basecode; let pp: BasecodeParameters = Code::setup(10); - let (pp, vp) = Code::trim(&pp, 10).unwrap(); + let (pp, vp) = Code::trim(pp, 10).unwrap(); for level in 0..(10 + >::get_rate_log()) { for index in 0..(1 << level) { assert_eq!( diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs index 8535ce23e..2bcac0826 100644 --- a/mpcs/src/basefold/encoding/rs.rs +++ b/mpcs/src/basefold/encoding/rs.rs @@ -280,7 +280,7 @@ where } fn trim( - pp: &Self::PublicParameters, + mut pp: Self::PublicParameters, max_message_size_log: usize, ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { if pp.fft_root_table.len() < max_message_size_log + Spec::get_rate_log() { @@ -308,7 +308,6 @@ where }, )); } - let mut gamma_powers = Vec::with_capacity(max_message_size_log); let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log); gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR); @@ -319,26 +318,27 @@ where } let inv_of_two = E::BaseField::from(2).invert().unwrap(); gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two); + pp.fft_root_table + .truncate(max_message_size_log + Spec::get_rate_log()); + let verifier_fft_root_table = pp.fft_root_table + [..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()] + .iter() + .cloned() + .chain( + pp.fft_root_table[Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..] + .iter() + .map(|v| vec![v[1]]), + ) + .collect(); Ok(( Self::ProverParameters { - fft_root_table: pp.fft_root_table[..max_message_size_log + Spec::get_rate_log()] - .to_vec(), + fft_root_table: pp.fft_root_table, gamma_powers: gamma_powers.clone(), gamma_powers_inv_div_two: gamma_powers_inv.clone(), full_message_size_log: max_message_size_log, }, Self::VerifierParameters { - fft_root_table: pp.fft_root_table - [..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()] - .iter() - .cloned() - .chain( - pp.fft_root_table - [Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..] - .iter() - .map(|v| vec![v[1]]), - ) - .collect(), + fft_root_table: verifier_fft_root_table, full_message_size_log: max_message_size_log, gamma_powers, gamma_powers_inv_div_two: gamma_powers_inv, @@ -653,7 +653,7 @@ mod tests { fn prover_verifier_consistency() { type Code = RSCode; let pp: RSCodeParameters = Code::setup(10); - let (pp, vp) = Code::trim(&pp, 10).unwrap(); + let (pp, vp) = Code::trim(pp, 10).unwrap(); for level in 0..(10 + >::get_rate_log()) { for index in 0..(1 << level) { let (naive_x0, naive_x1, naive_w) = @@ -690,7 +690,7 @@ mod tests { let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); - let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); reverse_index_bits_in_place_field_type(&mut codeword); let challenge = E::from(2); @@ -728,7 +728,7 @@ mod tests { let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); - let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); check_low_degree(&codeword, "low degree check for original codeword"); let c0 = field_type_index_ext(&codeword, 0); diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 46fbea0ff..19b3d16b6 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -24,7 +24,7 @@ pub fn pcs_setup>( } pub fn pcs_trim>( - param: &Pcs::Param, + param: Pcs::Param, poly_size: usize, ) -> Result<(Pcs::ProverParam, Pcs::VerifierParam), Error> { Pcs::trim(param, poly_size) @@ -119,7 +119,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn setup(poly_size: usize) -> Result; fn trim( - param: &Self::Param, + param: Self::Param, poly_size: usize, ) -> Result<(Self::ProverParam, Self::VerifierParam), Error>; @@ -380,7 +380,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Commit and open let (comm, eval, proof, challenge) = { @@ -442,7 +442,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -556,7 +556,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let (comm, evals, proof, challenge) = {