Skip to content

Commit

Permalink
Fix/transcript fork (#224)
Browse files Browse the repository at this point in the history
This is another approach to #223, fixing #222 / #210. Features:

- Prepare for parallel proving with independent transcripts per thread.
- Soundness with different transcripts in each thread.
- Type-safe API which captures the transcript by value.

There is no function to merge the threads back into one transcript, but
that can be added if ever needed.

---------

Co-authored-by: Aurélien Nicolas <[email protected]>
  • Loading branch information
2 people authored and hero78119 committed Sep 30, 2024
1 parent d646e95 commit 891959f
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 13 deletions.
8 changes: 4 additions & 4 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,12 @@ fn main() {

let timer = Instant::now();

let mut transcript = Transcript::new(b"riscv");
let transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let zkvm_proof = prover
.create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges)
.create_proof(zkvm_witness, max_threads, transcript, &real_challenges)
.expect("create_proof failed");

println!(
Expand All @@ -155,10 +155,10 @@ fn main() {
timer.elapsed().as_secs_f64()
);

let mut transcript = Transcript::new(b"riscv");
let transcript = Transcript::new(b"riscv");
assert!(
verifier
.verify_proof(zkvm_proof, &mut transcript, &real_challenges)
.verify_proof(zkvm_proof, transcript, &real_challenges)
.expect("verify proof return with error"),
);
}
Expand Down
13 changes: 11 additions & 2 deletions ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,17 @@ pub struct ZKVMTableProof<E: ExtensionField> {
pub wits_in_evals: Vec<E>,
}

/// Map circuit names to
/// - an opcode or table proof,
/// - an index unique across both types.
#[derive(Default, Clone)]
pub struct ZKVMProof<E: ExtensionField> {
opcode_proofs: HashMap<String, ZKVMOpcodeProof<E>>,
table_proofs: HashMap<String, ZKVMTableProof<E>>,
opcode_proofs: HashMap<String, (usize, ZKVMOpcodeProof<E>)>,
table_proofs: HashMap<String, (usize, ZKVMTableProof<E>)>,
}

impl<E: ExtensionField> ZKVMProof<E> {
pub fn num_circuits(&self) -> usize {
self.opcode_proofs.len() + self.table_proofs.len()
}
}
15 changes: 11 additions & 4 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,18 @@ impl<E: ExtensionField> ZKVMProver<E> {
&self,
mut witnesses: ZKVMWitnesses<E>,
max_threads: usize,
transcript: &mut Transcript<E>,
transcript: Transcript<E>,
challenges: &[E; 2],
) -> Result<ZKVMProof<E>, ZKVMError> {
let mut vm_proof = ZKVMProof::default();
for (circuit_name, pk) in self.pk.circuit_pks.iter() {
let mut transcripts = transcript.fork(self.pk.circuit_pks.len());

for ((circuit_name, pk), (i, transcript)) in self
.pk
.circuit_pks
.iter() // Sorted by key.
.zip_eq(transcripts.iter_mut().enumerate())
{
let witness = witnesses
.witnesses
.remove(circuit_name)
Expand Down Expand Up @@ -94,7 +101,7 @@ impl<E: ExtensionField> ZKVMProver<E> {
);
vm_proof
.opcode_proofs
.insert(circuit_name.clone(), opcode_proof);
.insert(circuit_name.clone(), (i, opcode_proof));
} else {
let table_proof = self.create_table_proof(
pk,
Expand All @@ -116,7 +123,7 @@ impl<E: ExtensionField> ZKVMProver<E> {
);
vm_proof
.table_proofs
.insert(circuit_name.clone(), table_proof);
.insert(circuit_name.clone(), (i, table_proof));
}
}

Expand Down
12 changes: 9 additions & 3 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
pub fn verify_proof(
&self,
vm_proof: ZKVMProof<E>,
transcript: &mut Transcript<E>,
transcript: Transcript<E>,
challenges: &[E; 2],
) -> Result<bool, ZKVMError> {
let mut prod_r = E::ONE;
Expand All @@ -48,7 +48,11 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
let dummy_table_item = challenges[0];
let point_eval = PointAndEval::default();
let mut dummy_table_item_multiplicity = 0;
for (name, opcode_proof) in vm_proof.opcode_proofs {
let mut transcripts = transcript.fork(vm_proof.num_circuits());

for (name, (i, opcode_proof)) in vm_proof.opcode_proofs {
let transcript = &mut transcripts[i];

let circuit_vk = self
.vk
.circuit_vks
Expand Down Expand Up @@ -82,7 +86,9 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap();
}

for (name, table_proof) in vm_proof.table_proofs {
for (name, (i, table_proof)) in vm_proof.table_proofs {
let transcript = &mut transcripts[i];

let circuit_vk = self
.vk
.circuit_vks
Expand Down
11 changes: 11 additions & 0 deletions transcript/src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ impl<E: ExtensionField> Transcript<E> {
}

impl<E: ExtensionField> Transcript<E> {
/// Fork this transcript into n different threads.
pub fn fork(self, n: usize) -> Vec<Self> {
let mut forks = Vec::with_capacity(n);
for i in 0..n {
let mut fork = self.clone();
fork.append_field_element(&(i as u64).into());
forks.push(fork);
}
forks
}

// Append the message to the transcript.
pub fn append_message(&mut self, msg: &[u8]) {
let msg_f = E::BaseField::bytes_to_field_elements(msg);
Expand Down

0 comments on commit 891959f

Please sign in to comment.