Skip to content

Commit

Permalink
Change BaseFold trim API to consume pp & truncate public parameters d…
Browse files Browse the repository at this point in the history
…irectly in `trim` (#248)

Considering that BaseFold trim is almost always used only once in each
program execution, consume the input public parameter instead of clone
it to save memory.

This API change also allows the `trim` function to directly truncate the
input public parameters.
  • Loading branch information
yczhangsjtu authored Nov 11, 2024
1 parent 8e00028 commit 853f1ba
Show file tree
Hide file tree
Showing 13 changed files with 47 additions and 44 deletions.
2 changes: 1 addition & 1 deletion ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fn bench_add(c: &mut Criterion) {
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);

let param = Pcs::setup(1 << MAX_NUM_VARIABLES).unwrap();
let (pp, vp) = Pcs::trim(&param, 1 << MAX_NUM_VARIABLES).unwrap();
let (pp, vp) = Pcs::trim(param, 1 << MAX_NUM_VARIABLES).unwrap();

let pk = zkvm_cs
.clone()
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn main() {

// keygen
let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let mut zkvm_cs = ZKVMConstraintSystem::default();

let config = Rv32imConfig::<E>::construct_circuits(&mut zkvm_cs);
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ fn test_multiple_opcode() {
|cs| SubInstruction::construct_circuit(&mut CircuitBuilder::<E>::new(cs)),
);
let param = Pcs::setup(1 << 10).unwrap();
let (pp, _) = Pcs::trim(&param, 1 << 10).unwrap();
let (pp, _) = Pcs::trim(param, 1 << 10).unwrap();
cs.key_gen::<Pcs>(&pp, None);
}
4 changes: 2 additions & 2 deletions ceno_zkvm/src/scheme/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn test_rw_lk_expression_combination() {

// pcs setup
let param = Pcs::setup(1 << 13).unwrap();
let (pp, vp) = Pcs::trim(&param, 1 << 13).unwrap();
let (pp, vp) = Pcs::trim(param, 1 << 13).unwrap();

// configure
let name = TestCircuit::<E, RW, L>::name();
Expand Down Expand Up @@ -223,7 +223,7 @@ fn test_single_add_instance_e2e() {
);

let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let mut zkvm_cs = ZKVMConstraintSystem::default();
// opcode circuits
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
Expand Down
2 changes: 1 addition & 1 deletion mpcs/benches/basecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) {
let (pp, _) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
let polys = (0..batch_size)
.map(|_| {
Expand Down
6 changes: 3 additions & 3 deletions mpcs/benches/commit_open_verify_basecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) {
Pcs::setup(poly_size).unwrap();
})
});
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};

let mut transcript = T::new(b"BaseFold");
Expand Down Expand Up @@ -118,7 +118,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
// Batch commit and open
let evals = chain![
Expand Down Expand Up @@ -258,7 +258,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base:
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
let mut transcript = T::new(b"BaseFold");
let polys = (0..batch_size)
Expand Down
6 changes: 3 additions & 3 deletions mpcs/benches/commit_open_verify_rs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) {
Pcs::setup(poly_size).unwrap();
})
});
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};

let mut transcript = T::new(b"BaseFold");
Expand Down Expand Up @@ -125,7 +125,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
// Batch commit and open
let evals = chain![
Expand Down Expand Up @@ -266,7 +266,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base:
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
let mut transcript = T::new(b"BaseFold");
let polys = (0..batch_size)
Expand Down
2 changes: 1 addition & 1 deletion mpcs/benches/rscode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) {
let (pp, _) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
let polys = (0..batch_size)
.map(|_| {
Expand Down
4 changes: 2 additions & 2 deletions mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ where
/// Derive the proving key and verification key from the public parameter.
/// This step simultaneously trims the parameter for the particular size.
fn trim(
pp: &Self::Param,
pp: Self::Param,
poly_size: usize,
) -> Result<(Self::ProverParam, Self::VerifierParam), Error> {
<Spec::EncodingScheme as EncodingScheme<E>>::trim(&pp.params, log2_strict(poly_size)).map(
<Spec::EncodingScheme as EncodingScheme<E>>::trim(pp.params, log2_strict(poly_size)).map(
|(pp, vp)| {
(
BasefoldProverParams {
Expand Down
4 changes: 2 additions & 2 deletions mpcs/src/basefold/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub trait EncodingScheme<E: ExtensionField>: std::fmt::Debug + Clone {
fn setup(max_msg_size_log: usize) -> Self::PublicParameters;

fn trim(
pp: &Self::PublicParameters,
pp: Self::PublicParameters,
max_msg_size_log: usize,
) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error>;

Expand Down Expand Up @@ -177,7 +177,7 @@ pub(crate) mod test_util {
let mut poly = FieldType::Ext(poly);

let pp: Code::PublicParameters = Code::setup(num_vars);
let (pp, _) = Code::trim(&pp, num_vars).unwrap();
let (pp, _) = Code::trim(pp, num_vars).unwrap();
let mut codeword = Code::encode(&pp, &poly);
reverse_index_bits_in_place_field_type(&mut codeword);
if Code::message_is_left_and_right_folding() {
Expand Down
11 changes: 7 additions & 4 deletions mpcs/src/basefold/encoding/basecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ where
}

fn trim(
pp: &Self::PublicParameters,
mut pp: Self::PublicParameters,
max_msg_size_log: usize,
) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> {
if pp.table.len() < Spec::get_rate_log() + max_msg_size_log {
Expand All @@ -127,6 +127,9 @@ where
max_msg_size_log,
)));
}
pp.table_w_weights
.truncate(Spec::get_rate_log() + max_msg_size_log);
pp.table.truncate(Spec::get_rate_log() + max_msg_size_log);
let mut key: [u8; 16] = [0u8; 16];
let mut iv: [u8; 16] = [0u8; 16];
let mut rng = ChaCha8Rng::from_seed(pp.rng_seed);
Expand All @@ -135,8 +138,8 @@ where
rng.fill_bytes(&mut iv);
Ok((
Self::ProverParameters {
table_w_weights: pp.table_w_weights.clone(),
table: pp.table.clone(),
table_w_weights: pp.table_w_weights,
table: pp.table,
rng_seed: pp.rng_seed,
_phantom: PhantomData,
},
Expand Down Expand Up @@ -430,7 +433,7 @@ mod tests {
fn prover_verifier_consistency() {
type Code = Basecode<BasecodeDefaultSpec>;
let pp: BasecodeParameters<GoldilocksExt2> = Code::setup(10);
let (pp, vp) = Code::trim(&pp, 10).unwrap();
let (pp, vp) = Code::trim(pp, 10).unwrap();
for level in 0..(10 + <Code as EncodingScheme<GoldilocksExt2>>::get_rate_log()) {
for index in 0..(1 << level) {
assert_eq!(
Expand Down
36 changes: 18 additions & 18 deletions mpcs/src/basefold/encoding/rs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ where
}

fn trim(
pp: &Self::PublicParameters,
mut pp: Self::PublicParameters,
max_message_size_log: usize,
) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> {
if pp.fft_root_table.len() < max_message_size_log + Spec::get_rate_log() {
Expand Down Expand Up @@ -308,7 +308,6 @@ where
},
));
}

let mut gamma_powers = Vec::with_capacity(max_message_size_log);
let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log);
gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR);
Expand All @@ -319,26 +318,27 @@ where
}
let inv_of_two = E::BaseField::from(2).invert().unwrap();
gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two);
pp.fft_root_table
.truncate(max_message_size_log + Spec::get_rate_log());
let verifier_fft_root_table = pp.fft_root_table
[..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()]
.iter()
.cloned()
.chain(
pp.fft_root_table[Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..]
.iter()
.map(|v| vec![v[1]]),
)
.collect();
Ok((
Self::ProverParameters {
fft_root_table: pp.fft_root_table[..max_message_size_log + Spec::get_rate_log()]
.to_vec(),
fft_root_table: pp.fft_root_table,
gamma_powers: gamma_powers.clone(),
gamma_powers_inv_div_two: gamma_powers_inv.clone(),
full_message_size_log: max_message_size_log,
},
Self::VerifierParameters {
fft_root_table: pp.fft_root_table
[..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()]
.iter()
.cloned()
.chain(
pp.fft_root_table
[Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..]
.iter()
.map(|v| vec![v[1]]),
)
.collect(),
fft_root_table: verifier_fft_root_table,
full_message_size_log: max_message_size_log,
gamma_powers,
gamma_powers_inv_div_two: gamma_powers_inv,
Expand Down Expand Up @@ -653,7 +653,7 @@ mod tests {
fn prover_verifier_consistency() {
type Code = RSCode<RSCodeDefaultSpec>;
let pp: RSCodeParameters<GoldilocksExt2> = Code::setup(10);
let (pp, vp) = Code::trim(&pp, 10).unwrap();
let (pp, vp) = Code::trim(pp, 10).unwrap();
for level in 0..(10 + <Code as EncodingScheme<GoldilocksExt2>>::get_rate_log()) {
for index in 0..(1 << level) {
let (naive_x0, naive_x1, naive_w) =
Expand Down Expand Up @@ -690,7 +690,7 @@ mod tests {
let poly = FieldType::Ext(poly);

let pp = <Code as EncodingScheme<E>>::setup(num_vars);
let (pp, _) = Code::trim(&pp, num_vars).unwrap();
let (pp, _) = Code::trim(pp, num_vars).unwrap();
let mut codeword = Code::encode(&pp, &poly);
reverse_index_bits_in_place_field_type(&mut codeword);
let challenge = E::from(2);
Expand Down Expand Up @@ -728,7 +728,7 @@ mod tests {
let poly = FieldType::Ext(poly);

let pp = <Code as EncodingScheme<E>>::setup(num_vars);
let (pp, _) = Code::trim(&pp, num_vars).unwrap();
let (pp, _) = Code::trim(pp, num_vars).unwrap();
let mut codeword = Code::encode(&pp, &poly);
check_low_degree(&codeword, "low degree check for original codeword");
let c0 = field_type_index_ext(&codeword, 0);
Expand Down
10 changes: 5 additions & 5 deletions mpcs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub fn pcs_setup<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(
}

pub fn pcs_trim<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(
param: &Pcs::Param,
param: Pcs::Param,
poly_size: usize,
) -> Result<(Pcs::ProverParam, Pcs::VerifierParam), Error> {
Pcs::trim(param, poly_size)
Expand Down Expand Up @@ -119,7 +119,7 @@ pub trait PolynomialCommitmentScheme<E: ExtensionField>: Clone + Debug {
fn setup(poly_size: usize) -> Result<Self::Param, Error>;

fn trim(
param: &Self::Param,
param: Self::Param,
poly_size: usize,
) -> Result<(Self::ProverParam, Self::VerifierParam), Error>;

Expand Down Expand Up @@ -380,7 +380,7 @@ pub mod test_util {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
// Commit and open
let (comm, eval, proof, challenge) = {
Expand Down Expand Up @@ -442,7 +442,7 @@ pub mod test_util {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
// Batch commit and open
let evals = chain![
Expand Down Expand Up @@ -556,7 +556,7 @@ pub mod test_util {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};

let (comm, evals, proof, challenge) = {
Expand Down

0 comments on commit 853f1ba

Please sign in to comment.