diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index 6d9c65f9a..bb2dbacab 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -10,14 +10,12 @@ use const_env::from_env; use ceno_emul::StepRecord; use ceno_zkvm::{ - circuit_builder::{ZKVMConstraintSystem, ZKVMVerifyingKey}, + circuit_builder::ZKVMConstraintSystem, scheme::verifier::ZKVMVerifier, tables::{RangeTableCircuit, TableCircuit}, }; use ff_ext::ff::Field; -use goldilocks::{Goldilocks, GoldilocksExt2}; -use itertools::Itertools; -use multilinear_extensions::mle::IntoMLE; +use goldilocks::GoldilocksExt2; use sumcheck::util::is_power_of_2; use tracing_flame::FlameLayer; use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry}; @@ -98,7 +96,7 @@ fn main() { let prover = ZKVMProver::new(pk); let verifier = ZKVMVerifier::new(vk); - for instance_num_vars in 20..22 { + for instance_num_vars in 15..22 { // TODO: witness generation from step records emitted by tracer let num_instances = 1 << instance_num_vars; let mut zkvm_witness = BTreeMap::default(); @@ -111,7 +109,10 @@ fn main() { let range_witness = RangeTableCircuit::::assign_instances( &range_config, range_cs.num_witin as usize, - &[], + // TODO: use real data + vec![vec![0; num_instances * 2], vec![4; num_instances * 6]] + .concat() + .as_slice(), ) .unwrap(); @@ -128,9 +129,10 @@ fn main() { .create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges) .expect("create_proof failed"); + let mut transcript = Transcript::new(b"riscv"); assert!( verifier - .verify_proof(zkvm_proof, &mut transcript, &real_challenges,) + .verify_proof(zkvm_proof, &mut transcript, &real_challenges) .expect("verify proof return with error"), ); diff --git a/ceno_zkvm/src/error.rs b/ceno_zkvm/src/error.rs index d623364c9..ea59969f8 100644 --- a/ceno_zkvm/src/error.rs +++ b/ceno_zkvm/src/error.rs @@ -7,7 +7,7 @@ pub enum UtilError { pub enum ZKVMError { CircuitError, UtilError(UtilError), - VerifyError(&'static str), + VerifyError(String), } impl From for ZKVMError { diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 4771df151..685e6f1cb 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,6 +1,6 @@ use ff_ext::ExtensionField; use std::{ - collections::{BTreeMap, BTreeSet, HashMap}, + collections::{BTreeMap, BTreeSet}, sync::Arc, }; @@ -65,6 +65,17 @@ impl ZKVMProver { let num_instances = witness.num_instances(); if is_opcode_circuit { + tracing::debug!( + "opcode circuit {} has {} witnesses, {} reads, {} writes, {} lookups", + circuit_name, + cs.num_witin, + cs.r_expressions.len(), + cs.w_expressions.len(), + cs.lk_expressions.len(), + ); + for lk_s in cs.lk_expressions_namespace_map.iter() { + tracing::debug!("opcode circuit {}: {}", circuit_name, lk_s); + } let opcode_proof = self.create_opcode_proof( pk, witness @@ -78,6 +89,11 @@ impl ZKVMProver { transcript, challenges, )?; + tracing::info!( + "generated proof for opcode {} with num_instances={}", + circuit_name, + num_instances + ); vm_proof .opcode_proofs .insert(circuit_name.clone(), opcode_proof); @@ -95,6 +111,11 @@ impl ZKVMProver { transcript, challenges, )?; + tracing::info!( + "generated proof for table {} with num_instances={}", + circuit_name, + num_instances + ); vm_proof .table_proofs .insert(circuit_name.clone(), table_proof); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 40b4a818e..92c4905f7 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -46,7 +46,9 @@ impl ZKVMVerifier { let mut prod_r = E::ONE; let mut prod_w = E::ONE; let mut logup_sum = E::ZERO; + let dummy_table_item = challenges[0]; let point_eval = PointAndEval::default(); + let mut dummy_table_item_multiplicity = 0; for (name, opcode_proof) in vm_proof.opcode_proofs { let circuit_vk = self .vk @@ -61,6 +63,16 @@ impl ZKVMVerifier { &point_eval, challenges, )?; + tracing::info!("verified proof for opcode {}", name); + + // getting the number of dummy padding item that we used in this opcode circuit + let num_lks = circuit_vk.get_cs().lk_expressions.len(); + let num_padded_lks_per_instance = num_lks.next_power_of_two() - num_lks; + let num_padded_instance = + opcode_proof.num_instances.next_power_of_two() - opcode_proof.num_instances; + dummy_table_item_multiplicity += num_padded_lks_per_instance + * opcode_proof.num_instances + + num_lks.next_power_of_two() * num_padded_instance; prod_r *= opcode_proof.record_r_out_evals.iter().product::(); prod_w *= opcode_proof.record_w_out_evals.iter().product::(); @@ -85,18 +97,26 @@ impl ZKVMVerifier { &point_eval, challenges, )?; + tracing::info!("verified proof for table {}", name); logup_sum -= table_proof.lk_p1_out_eval * table_proof.lk_q1_out_eval.invert().unwrap(); logup_sum -= table_proof.lk_p2_out_eval * table_proof.lk_q2_out_eval.invert().unwrap(); } + logup_sum -= + E::from(dummy_table_item_multiplicity as u64) * dummy_table_item.invert().unwrap(); + // check rw_set equality across all proofs - if prod_r != prod_w { - return Ok(false); - } + // TODO: enable this when we have cpu init/finalize and mem init/finalize + // if prod_r != prod_w { + // return Err(ZKVMError::VerifyError("prod_r != prod_w".into())); + // } // check logup relation across all proofs if logup_sum != E::ZERO { - return Ok(false); + return Err(ZKVMError::VerifyError(format!( + "logup_sum({:?}) != 0", + logup_sum + ))); } Ok(true) @@ -159,7 +179,7 @@ impl ZKVMVerifier { // index 0 is LogUp witness for Fixed Lookup table if logup_p_evals[0].eval != E::ONE { return Err(ZKVMError::VerifyError( - "Lookup table witness p(x) != constant 1", + "Lookup table witness p(x) != constant 1".into(), )); } @@ -279,7 +299,7 @@ impl ZKVMVerifier { .sum::(); if computed_evals != expected_evaluation { return Err(ZKVMError::VerifyError( - "main + sel evaluation verify failed", + "main + sel evaluation verify failed".into(), )); } // verify records (degree = 1) statement, thus no sumcheck @@ -298,7 +318,9 @@ impl ZKVMVerifier { eval_by_expr(&proof.wits_in_evals, challenges, expr) != *expected_evals }) { - return Err(ZKVMError::VerifyError("record evaluate != expected_evals")); + return Err(ZKVMError::VerifyError( + "record evaluate != expected_evals".into(), + )); } // verify zero expression (degree = 1) statement, thus no sumcheck @@ -326,7 +348,6 @@ impl ZKVMVerifier { let cs = circuit_vk.get_cs(); let lk_counts_per_instance = cs.lk_table_expressions.len(); let log2_lk_count = ceil_log2(lk_counts_per_instance); - let (chip_record_alpha, _) = (challenges[0], challenges[1]); let num_instances = proof.num_instances; let log2_num_instances = ceil_log2(num_instances); @@ -394,8 +415,7 @@ impl ZKVMVerifier { * ((0..lk_counts_per_instance) .map(|i| proof.lk_d_in_evals[i] * eq_lk[i]) .sum::() - + chip_record_alpha - * (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE)), + + (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE)), *alpha_lk_n * sel_lk * ((0..lk_counts_per_instance) @@ -405,7 +425,9 @@ impl ZKVMVerifier { .iter() .sum::(); if computed_evals != expected_evaluation { - return Err(ZKVMError::VerifyError("sel evaluation verify failed")); + return Err(ZKVMError::VerifyError( + "sel evaluation verify failed".into(), + )); } // verify records (degree = 1) statement, thus no sumcheck if cs @@ -427,7 +449,9 @@ impl ZKVMVerifier { ) != *expected_evals }) { - return Err(ZKVMError::VerifyError("record evaluate != expected_evals")); + return Err(ZKVMError::VerifyError( + "record evaluate != expected_evals".into(), + )); } Ok(input_opening_point) @@ -524,6 +548,7 @@ impl TowerVerify { }, transcript, ); + tracing::debug!("verified tower proof at layer {}/{}", round + 1, expected_max_round-1); // check expected_evaluation let rt: Point = sumcheck_claim.point.iter().map(|c| c.elements).collect(); @@ -555,7 +580,7 @@ impl TowerVerify { }) .sum::(); if expected_evaluation != sumcheck_claim.expected_evaluation { - return Err(ZKVMError::VerifyError("mismatch tower evaluation")); + return Err(ZKVMError::VerifyError("mismatch tower evaluation".into())); } // derive single eval diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index 0dabf5791..4c1916265 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -69,6 +69,7 @@ impl TableCircuit for RangeTableCircuit { for limb in inputs { u16_mlt[*limb] += 1; } + tracing::debug!("u16_mult[4] = {}", u16_mlt[4]); let mut witness = RowMajorMatrix::::new(u16_mlt.len(), num_witin); witness