From 7642b52296ef621619e2f5c7510ca5c9bea5494d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Nicolas?= Date: Sat, 14 Sep 2024 11:22:44 +0200 Subject: [PATCH] fix/transcript_fork --- ceno_zkvm/examples/riscv_add.rs | 8 ++++---- ceno_zkvm/src/scheme.rs | 13 +++++++++++-- ceno_zkvm/src/scheme/prover.rs | 13 ++++++++----- ceno_zkvm/src/scheme/verifier.rs | 12 +++++++++--- transcript/src/basic.rs | 10 +--------- 5 files changed, 33 insertions(+), 23 deletions(-) diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index 773fbcf01..af9af772c 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -141,12 +141,12 @@ fn main() { let timer = Instant::now(); - let mut transcript = Transcript::new(b"riscv"); + let transcript = Transcript::new(b"riscv"); let mut rng = test_rng(); let real_challenges = [E::random(&mut rng), E::random(&mut rng)]; let zkvm_proof = prover - .create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges) + .create_proof(zkvm_witness, max_threads, transcript, &real_challenges) .expect("create_proof failed"); println!( @@ -155,10 +155,10 @@ fn main() { timer.elapsed().as_secs_f64() ); - let mut transcript = Transcript::new(b"riscv"); + let transcript = Transcript::new(b"riscv"); assert!( verifier - .verify_proof(zkvm_proof, &mut transcript, &real_challenges) + .verify_proof(zkvm_proof, transcript, &real_challenges) .expect("verify proof return with error"), ); } diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index cc9b7d353..67ec7a134 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -60,8 +60,17 @@ pub struct ZKVMTableProof { pub wits_in_evals: Vec, } +/// Map circuit names to +/// - an opcode or table proof, +/// - an index unique across both types. #[derive(Default, Clone)] pub struct ZKVMProof { - opcode_proofs: HashMap>, - table_proofs: HashMap>, + opcode_proofs: HashMap)>, + table_proofs: HashMap)>, +} + +impl ZKVMProof { + pub fn num_circuits(&self) -> usize { + self.opcode_proofs.len() + self.table_proofs.len() + } } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 7379e6b61..c3592142e 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -47,14 +47,17 @@ impl ZKVMProver { &self, mut witnesses: ZKVMWitnesses, max_threads: usize, - transcript: &mut Transcript, + transcript: Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { let mut vm_proof = ZKVMProof::default(); let mut transcripts = transcript.fork(self.pk.circuit_pks.len()); - for ((circuit_name, pk), transcript) in - self.pk.circuit_pks.iter().zip_eq(transcripts.iter_mut()) + for ((circuit_name, pk), (i, transcript)) in self + .pk + .circuit_pks + .iter() + .zip_eq(transcripts.iter_mut().enumerate()) { let witness = witnesses .witnesses @@ -98,7 +101,7 @@ impl ZKVMProver { ); vm_proof .opcode_proofs - .insert(circuit_name.clone(), opcode_proof); + .insert(circuit_name.clone(), (i, opcode_proof)); } else { let table_proof = self.create_table_proof( pk, @@ -120,7 +123,7 @@ impl ZKVMProver { ); vm_proof .table_proofs - .insert(circuit_name.clone(), table_proof); + .insert(circuit_name.clone(), (i, table_proof)); } } diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index fb2f2e97a..28fc0dc56 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -39,7 +39,7 @@ impl ZKVMVerifier { pub fn verify_proof( &self, vm_proof: ZKVMProof, - transcript: &mut Transcript, + transcript: Transcript, challenges: &[E; 2], ) -> Result { let mut prod_r = E::ONE; @@ -48,7 +48,11 @@ impl ZKVMVerifier { let dummy_table_item = challenges[0]; let point_eval = PointAndEval::default(); let mut dummy_table_item_multiplicity = 0; - for (name, opcode_proof) in vm_proof.opcode_proofs { + let mut transcripts = transcript.fork(vm_proof.num_circuits()); + + for (name, (i, opcode_proof)) in vm_proof.opcode_proofs { + let transcript = &mut transcripts[i]; + let circuit_vk = self .vk .circuit_vks @@ -82,7 +86,9 @@ impl ZKVMVerifier { opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); } - for (name, table_proof) in vm_proof.table_proofs { + for (name, (i, table_proof)) in vm_proof.table_proofs { + let transcript = &mut transcripts[i]; + let circuit_vk = self .vk .circuit_vks diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index 918351705..857b5c201 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -28,7 +28,7 @@ impl Transcript { impl Transcript { /// Fork this transcript into n different threads. - pub fn fork(&mut self, n: usize) -> Vec { + pub fn fork(self, n: usize) -> Vec { let mut forks = Vec::with_capacity(n); for i in 0..n { let mut t = self.clone(); @@ -38,14 +38,6 @@ impl Transcript { forks } - /// Include the history of the forks into the current transcript. - /// NOT IMPLEMENTED. - pub fn merge(&mut self, forks: Vec) { - for fork in forks { - self.append_field_element_ext(&fork.read_field_element_ext()); - } - } - // Append the message to the transcript. pub fn append_message(&mut self, msg: &[u8]) { let msg_f = E::BaseField::bytes_to_field_elements(msg);