Skip to content

Commit

Permalink
fix/transcript_fork
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurélien Nicolas committed Sep 14, 2024
1 parent 0f50995 commit 7642b52
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 23 deletions.
8 changes: 4 additions & 4 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,12 @@ fn main() {

let timer = Instant::now();

let mut transcript = Transcript::new(b"riscv");
let transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let zkvm_proof = prover
.create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges)
.create_proof(zkvm_witness, max_threads, transcript, &real_challenges)
.expect("create_proof failed");

println!(
Expand All @@ -155,10 +155,10 @@ fn main() {
timer.elapsed().as_secs_f64()
);

let mut transcript = Transcript::new(b"riscv");
let transcript = Transcript::new(b"riscv");
assert!(
verifier
.verify_proof(zkvm_proof, &mut transcript, &real_challenges)
.verify_proof(zkvm_proof, transcript, &real_challenges)
.expect("verify proof return with error"),
);
}
Expand Down
13 changes: 11 additions & 2 deletions ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,17 @@ pub struct ZKVMTableProof<E: ExtensionField> {
pub wits_in_evals: Vec<E>,
}

/// Map circuit names to
/// - an opcode or table proof,
/// - an index unique across both types.
#[derive(Default, Clone)]
pub struct ZKVMProof<E: ExtensionField> {
opcode_proofs: HashMap<String, ZKVMOpcodeProof<E>>,
table_proofs: HashMap<String, ZKVMTableProof<E>>,
opcode_proofs: HashMap<String, (usize, ZKVMOpcodeProof<E>)>,
table_proofs: HashMap<String, (usize, ZKVMTableProof<E>)>,
}

impl<E: ExtensionField> ZKVMProof<E> {
pub fn num_circuits(&self) -> usize {
self.opcode_proofs.len() + self.table_proofs.len()
}
}
13 changes: 8 additions & 5 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,17 @@ impl<E: ExtensionField> ZKVMProver<E> {
&self,
mut witnesses: ZKVMWitnesses<E>,
max_threads: usize,
transcript: &mut Transcript<E>,
transcript: Transcript<E>,
challenges: &[E; 2],
) -> Result<ZKVMProof<E>, ZKVMError> {
let mut vm_proof = ZKVMProof::default();
let mut transcripts = transcript.fork(self.pk.circuit_pks.len());

for ((circuit_name, pk), transcript) in
self.pk.circuit_pks.iter().zip_eq(transcripts.iter_mut())
for ((circuit_name, pk), (i, transcript)) in self
.pk
.circuit_pks
.iter()
.zip_eq(transcripts.iter_mut().enumerate())
{
let witness = witnesses
.witnesses
Expand Down Expand Up @@ -98,7 +101,7 @@ impl<E: ExtensionField> ZKVMProver<E> {
);
vm_proof
.opcode_proofs
.insert(circuit_name.clone(), opcode_proof);
.insert(circuit_name.clone(), (i, opcode_proof));
} else {
let table_proof = self.create_table_proof(
pk,
Expand All @@ -120,7 +123,7 @@ impl<E: ExtensionField> ZKVMProver<E> {
);
vm_proof
.table_proofs
.insert(circuit_name.clone(), table_proof);
.insert(circuit_name.clone(), (i, table_proof));
}
}

Expand Down
12 changes: 9 additions & 3 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
pub fn verify_proof(
&self,
vm_proof: ZKVMProof<E>,
transcript: &mut Transcript<E>,
transcript: Transcript<E>,
challenges: &[E; 2],
) -> Result<bool, ZKVMError> {
let mut prod_r = E::ONE;
Expand All @@ -48,7 +48,11 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
let dummy_table_item = challenges[0];
let point_eval = PointAndEval::default();
let mut dummy_table_item_multiplicity = 0;
for (name, opcode_proof) in vm_proof.opcode_proofs {
let mut transcripts = transcript.fork(vm_proof.num_circuits());

for (name, (i, opcode_proof)) in vm_proof.opcode_proofs {
let transcript = &mut transcripts[i];

let circuit_vk = self
.vk
.circuit_vks
Expand Down Expand Up @@ -82,7 +86,9 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap();
}

for (name, table_proof) in vm_proof.table_proofs {
for (name, (i, table_proof)) in vm_proof.table_proofs {
let transcript = &mut transcripts[i];

let circuit_vk = self
.vk
.circuit_vks
Expand Down
10 changes: 1 addition & 9 deletions transcript/src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl<E: ExtensionField> Transcript<E> {

impl<E: ExtensionField> Transcript<E> {
/// Fork this transcript into n different threads.
pub fn fork(&mut self, n: usize) -> Vec<Self> {
pub fn fork(self, n: usize) -> Vec<Self> {
let mut forks = Vec::with_capacity(n);
for i in 0..n {
let mut t = self.clone();
Expand All @@ -38,14 +38,6 @@ impl<E: ExtensionField> Transcript<E> {
forks
}

/// Include the history of the forks into the current transcript.
/// NOT IMPLEMENTED.
pub fn merge(&mut self, forks: Vec<Self>) {
for fork in forks {
self.append_field_element_ext(&fork.read_field_element_ext());
}
}

// Append the message to the transcript.
pub fn append_message(&mut self, msg: &[u8]) {
let msg_f = E::BaseField::bytes_to_field_elements(msg);
Expand Down

0 comments on commit 7642b52

Please sign in to comment.