Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change BaseFold trim API to consume pp & truncate public parameters directly in trim #248

Merged
merged 10 commits into from
Nov 11, 2024
2 changes: 1 addition & 1 deletion ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fn bench_add(c: &mut Criterion) {
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);

let param = Pcs::setup(1 << MAX_NUM_VARIABLES).unwrap();
let (pp, vp) = Pcs::trim(&param, 1 << MAX_NUM_VARIABLES).unwrap();
let (pp, vp) = Pcs::trim(param, 1 << MAX_NUM_VARIABLES).unwrap();

let pk = zkvm_cs
.clone()
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn main() {

// keygen
let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let mut zkvm_cs = ZKVMConstraintSystem::default();

let config = Rv32imConfig::<E>::construct_circuits(&mut zkvm_cs);
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ fn test_multiple_opcode() {
|cs| SubInstruction::construct_circuit(&mut CircuitBuilder::<E>::new(cs)),
);
let param = Pcs::setup(1 << 10).unwrap();
let (pp, _) = Pcs::trim(&param, 1 << 10).unwrap();
let (pp, _) = Pcs::trim(param, 1 << 10).unwrap();
cs.key_gen::<Pcs>(&pp, None);
}
4 changes: 2 additions & 2 deletions ceno_zkvm/src/scheme/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ fn test_rw_lk_expression_combination() {

// pcs setup
let param = Pcs::setup(1 << 13).unwrap();
let (pp, vp) = Pcs::trim(&param, 1 << 13).unwrap();
let (pp, vp) = Pcs::trim(param, 1 << 13).unwrap();

// configure
let name = TestCircuit::<E, RW, L>::name();
Expand Down Expand Up @@ -221,7 +221,7 @@ fn test_single_add_instance_e2e() {
);

let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let mut zkvm_cs = ZKVMConstraintSystem::default();
// opcode circuits
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
Expand Down
2 changes: 1 addition & 1 deletion mpcs/benches/basecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) {
let (pp, _) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
let polys = (0..batch_size)
.map(|_| {
Expand Down
6 changes: 3 additions & 3 deletions mpcs/benches/commit_open_verify_basecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) {
Pcs::setup(poly_size).unwrap();
})
});
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};

let mut transcript = T::new(b"BaseFold");
Expand Down Expand Up @@ -118,7 +118,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
// Batch commit and open
let evals = chain![
Expand Down Expand Up @@ -258,7 +258,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base:
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
let mut transcript = T::new(b"BaseFold");
let polys = (0..batch_size)
Expand Down
6 changes: 3 additions & 3 deletions mpcs/benches/commit_open_verify_rs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) {
Pcs::setup(poly_size).unwrap();
})
});
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};

let mut transcript = T::new(b"BaseFold");
Expand Down Expand Up @@ -125,7 +125,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
// Batch commit and open
let evals = chain![
Expand Down Expand Up @@ -266,7 +266,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base:
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
let mut transcript = T::new(b"BaseFold");
let polys = (0..batch_size)
Expand Down
2 changes: 1 addition & 1 deletion mpcs/benches/rscode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) {
let (pp, _) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
let polys = (0..batch_size)
.map(|_| {
Expand Down
4 changes: 2 additions & 2 deletions mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ where
/// Derive the proving key and verification key from the public parameter.
/// This step simultaneously trims the parameter for the particular size.
fn trim(
pp: &Self::Param,
pp: Self::Param,
poly_size: usize,
) -> Result<(Self::ProverParam, Self::VerifierParam), Error> {
<Spec::EncodingScheme as EncodingScheme<E>>::trim(&pp.params, log2_strict(poly_size)).map(
<Spec::EncodingScheme as EncodingScheme<E>>::trim(pp.params, log2_strict(poly_size)).map(
|(pp, vp)| {
(
BasefoldProverParams {
Expand Down
4 changes: 2 additions & 2 deletions mpcs/src/basefold/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub trait EncodingScheme<E: ExtensionField>: std::fmt::Debug + Clone {
fn setup(max_msg_size_log: usize) -> Self::PublicParameters;

fn trim(
pp: &Self::PublicParameters,
pp: Self::PublicParameters,
max_msg_size_log: usize,
) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error>;

Expand Down Expand Up @@ -177,7 +177,7 @@ pub(crate) mod test_util {
let mut poly = FieldType::Ext(poly);

let pp: Code::PublicParameters = Code::setup(num_vars);
let (pp, _) = Code::trim(&pp, num_vars).unwrap();
let (pp, _) = Code::trim(pp, num_vars).unwrap();
let mut codeword = Code::encode(&pp, &poly);
reverse_index_bits_in_place_field_type(&mut codeword);
if Code::message_is_left_and_right_folding() {
Expand Down
11 changes: 7 additions & 4 deletions mpcs/src/basefold/encoding/basecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ where
}

fn trim(
pp: &Self::PublicParameters,
mut pp: Self::PublicParameters,
max_msg_size_log: usize,
) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> {
if pp.table.len() < Spec::get_rate_log() + max_msg_size_log {
Expand All @@ -127,6 +127,9 @@ where
max_msg_size_log,
)));
}
pp.table_w_weights
yczhangsjtu marked this conversation as resolved.
Show resolved Hide resolved
.truncate(Spec::get_rate_log() + max_msg_size_log);
pp.table.truncate(Spec::get_rate_log() + max_msg_size_log);
let mut key: [u8; 16] = [0u8; 16];
let mut iv: [u8; 16] = [0u8; 16];
let mut rng = ChaCha8Rng::from_seed(pp.rng_seed);
Expand All @@ -135,8 +138,8 @@ where
rng.fill_bytes(&mut iv);
Ok((
Self::ProverParameters {
table_w_weights: pp.table_w_weights.clone(),
table: pp.table.clone(),
table_w_weights: pp.table_w_weights,
table: pp.table,
rng_seed: pp.rng_seed,
_phantom: PhantomData,
},
Expand Down Expand Up @@ -430,7 +433,7 @@ mod tests {
fn prover_verifier_consistency() {
type Code = Basecode<BasecodeDefaultSpec>;
let pp: BasecodeParameters<GoldilocksExt2> = Code::setup(10);
let (pp, vp) = Code::trim(&pp, 10).unwrap();
let (pp, vp) = Code::trim(pp, 10).unwrap();
for level in 0..(10 + <Code as EncodingScheme<GoldilocksExt2>>::get_rate_log()) {
for index in 0..(1 << level) {
assert_eq!(
Expand Down
36 changes: 18 additions & 18 deletions mpcs/src/basefold/encoding/rs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ where
}

fn trim(
pp: &Self::PublicParameters,
mut pp: Self::PublicParameters,
max_message_size_log: usize,
) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> {
if pp.fft_root_table.len() < max_message_size_log + Spec::get_rate_log() {
Expand Down Expand Up @@ -308,7 +308,6 @@ where
},
));
}

let mut gamma_powers = Vec::with_capacity(max_message_size_log);
let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log);
gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR);
Expand All @@ -319,26 +318,27 @@ where
}
let inv_of_two = E::BaseField::from(2).invert().unwrap();
gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two);
pp.fft_root_table
yczhangsjtu marked this conversation as resolved.
Show resolved Hide resolved
.truncate(max_message_size_log + Spec::get_rate_log());
let verifier_fft_root_table = pp.fft_root_table
[..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()]
.iter()
.cloned()
.chain(
pp.fft_root_table[Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..]
.iter()
.map(|v| vec![v[1]]),
)
.collect();
Ok((
Self::ProverParameters {
fft_root_table: pp.fft_root_table[..max_message_size_log + Spec::get_rate_log()]
.to_vec(),
fft_root_table: pp.fft_root_table,
gamma_powers: gamma_powers.clone(),
gamma_powers_inv_div_two: gamma_powers_inv.clone(),
full_message_size_log: max_message_size_log,
},
Self::VerifierParameters {
fft_root_table: pp.fft_root_table
[..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()]
.iter()
.cloned()
.chain(
pp.fft_root_table
[Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..]
.iter()
.map(|v| vec![v[1]]),
)
.collect(),
fft_root_table: verifier_fft_root_table,
full_message_size_log: max_message_size_log,
gamma_powers,
gamma_powers_inv_div_two: gamma_powers_inv,
Expand Down Expand Up @@ -653,7 +653,7 @@ mod tests {
fn prover_verifier_consistency() {
type Code = RSCode<RSCodeDefaultSpec>;
let pp: RSCodeParameters<GoldilocksExt2> = Code::setup(10);
let (pp, vp) = Code::trim(&pp, 10).unwrap();
let (pp, vp) = Code::trim(pp, 10).unwrap();
for level in 0..(10 + <Code as EncodingScheme<GoldilocksExt2>>::get_rate_log()) {
for index in 0..(1 << level) {
let (naive_x0, naive_x1, naive_w) =
Expand Down Expand Up @@ -690,7 +690,7 @@ mod tests {
let poly = FieldType::Ext(poly);

let pp = <Code as EncodingScheme<E>>::setup(num_vars);
let (pp, _) = Code::trim(&pp, num_vars).unwrap();
let (pp, _) = Code::trim(pp, num_vars).unwrap();
let mut codeword = Code::encode(&pp, &poly);
reverse_index_bits_in_place_field_type(&mut codeword);
let challenge = E::from(2);
Expand Down Expand Up @@ -728,7 +728,7 @@ mod tests {
let poly = FieldType::Ext(poly);

let pp = <Code as EncodingScheme<E>>::setup(num_vars);
let (pp, _) = Code::trim(&pp, num_vars).unwrap();
let (pp, _) = Code::trim(pp, num_vars).unwrap();
let mut codeword = Code::encode(&pp, &poly);
check_low_degree(&codeword, "low degree check for original codeword");
let c0 = field_type_index_ext(&codeword, 0);
Expand Down
10 changes: 5 additions & 5 deletions mpcs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub fn pcs_setup<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(
}

pub fn pcs_trim<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(
param: &Pcs::Param,
param: Pcs::Param,
poly_size: usize,
) -> Result<(Pcs::ProverParam, Pcs::VerifierParam), Error> {
Pcs::trim(param, poly_size)
Expand Down Expand Up @@ -119,7 +119,7 @@ pub trait PolynomialCommitmentScheme<E: ExtensionField>: Clone + Debug {
fn setup(poly_size: usize) -> Result<Self::Param, Error>;

fn trim(
param: &Self::Param,
param: Self::Param,
poly_size: usize,
) -> Result<(Self::ProverParam, Self::VerifierParam), Error>;

Expand Down Expand Up @@ -380,7 +380,7 @@ pub mod test_util {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
// Commit and open
let (comm, eval, proof, challenge) = {
Expand Down Expand Up @@ -442,7 +442,7 @@ pub mod test_util {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};
// Batch commit and open
let evals = chain![
Expand Down Expand Up @@ -556,7 +556,7 @@ pub mod test_util {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();
Pcs::trim(&param, poly_size).unwrap()
Pcs::trim(param, poly_size).unwrap()
};

let (comm, evals, proof, challenge) = {
Expand Down
Loading