diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index 773fbcf01..af9af772c 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -141,12 +141,12 @@ fn main() { let timer = Instant::now(); - let mut transcript = Transcript::new(b"riscv"); + let transcript = Transcript::new(b"riscv"); let mut rng = test_rng(); let real_challenges = [E::random(&mut rng), E::random(&mut rng)]; let zkvm_proof = prover - .create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges) + .create_proof(zkvm_witness, max_threads, transcript, &real_challenges) .expect("create_proof failed"); println!( @@ -155,10 +155,10 @@ fn main() { timer.elapsed().as_secs_f64() ); - let mut transcript = Transcript::new(b"riscv"); + let transcript = Transcript::new(b"riscv"); assert!( verifier - .verify_proof(zkvm_proof, &mut transcript, &real_challenges) + .verify_proof(zkvm_proof, transcript, &real_challenges) .expect("verify proof return with error"), ); } diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index cc9b7d353..67ec7a134 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -60,8 +60,17 @@ pub struct ZKVMTableProof { pub wits_in_evals: Vec, } +/// Map circuit names to +/// - an opcode or table proof, +/// - an index unique across both types. #[derive(Default, Clone)] pub struct ZKVMProof { - opcode_proofs: HashMap>, - table_proofs: HashMap>, + opcode_proofs: HashMap)>, + table_proofs: HashMap)>, +} + +impl ZKVMProof { + pub fn num_circuits(&self) -> usize { + self.opcode_proofs.len() + self.table_proofs.len() + } } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index fdf5a9aa2..31ff3b4a5 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -47,11 +47,18 @@ impl ZKVMProver { &self, mut witnesses: ZKVMWitnesses, max_threads: usize, - transcript: &mut Transcript, + transcript: Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { let mut vm_proof = ZKVMProof::default(); - for (circuit_name, pk) in self.pk.circuit_pks.iter() { + let mut transcripts = transcript.fork(self.pk.circuit_pks.len()); + + for ((circuit_name, pk), (i, transcript)) in self + .pk + .circuit_pks + .iter() // Sorted by key. + .zip_eq(transcripts.iter_mut().enumerate()) + { let witness = witnesses .witnesses .remove(circuit_name) @@ -94,7 +101,7 @@ impl ZKVMProver { ); vm_proof .opcode_proofs - .insert(circuit_name.clone(), opcode_proof); + .insert(circuit_name.clone(), (i, opcode_proof)); } else { let table_proof = self.create_table_proof( pk, @@ -116,7 +123,7 @@ impl ZKVMProver { ); vm_proof .table_proofs - .insert(circuit_name.clone(), table_proof); + .insert(circuit_name.clone(), (i, table_proof)); } } diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index fb2f2e97a..28fc0dc56 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -39,7 +39,7 @@ impl ZKVMVerifier { pub fn verify_proof( &self, vm_proof: ZKVMProof, - transcript: &mut Transcript, + transcript: Transcript, challenges: &[E; 2], ) -> Result { let mut prod_r = E::ONE; @@ -48,7 +48,11 @@ impl ZKVMVerifier { let dummy_table_item = challenges[0]; let point_eval = PointAndEval::default(); let mut dummy_table_item_multiplicity = 0; - for (name, opcode_proof) in vm_proof.opcode_proofs { + let mut transcripts = transcript.fork(vm_proof.num_circuits()); + + for (name, (i, opcode_proof)) in vm_proof.opcode_proofs { + let transcript = &mut transcripts[i]; + let circuit_vk = self .vk .circuit_vks @@ -82,7 +86,9 @@ impl ZKVMVerifier { opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); } - for (name, table_proof) in vm_proof.table_proofs { + for (name, (i, table_proof)) in vm_proof.table_proofs { + let transcript = &mut transcripts[i]; + let circuit_vk = self .vk .circuit_vks diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index 5a71602c2..ec78a5488 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -27,6 +27,17 @@ impl Transcript { } impl Transcript { + /// Fork this transcript into n different threads. + pub fn fork(self, n: usize) -> Vec { + let mut forks = Vec::with_capacity(n); + for i in 0..n { + let mut fork = self.clone(); + fork.append_field_element(&(i as u64).into()); + forks.push(fork); + } + forks + } + // Append the message to the transcript. pub fn append_message(&mut self, msg: &[u8]) { let msg_f = E::BaseField::bytes_to_field_elements(msg);