From 52baa91fa05f40e76751d242b2dd5263d97ac467 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Fri, 23 Aug 2024 19:41:03 +0800 Subject: [PATCH] add verify_table_proof --- ceno_zkvm/src/scheme/prover.rs | 3 +- ceno_zkvm/src/scheme/verifier.rs | 187 ++++++++++++++++++++++++++++++- ceno_zkvm/src/tables/range.rs | 28 +++-- 3 files changed, 202 insertions(+), 16 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 168ebd9a7..cf5ab747b 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -513,7 +513,7 @@ impl ZKVMProver { vec![TowerProverSpec { witness: lk_wit_layers, }], - NUM_FANIN, + NUM_FANIN_LOGUP, transcript, ); assert_eq!(rt_tower.len(), log2_num_instances + log2_lk_count); @@ -589,7 +589,6 @@ impl ZKVMProver { // \sum_t (sel(rt, t) * (\sum_i alpha_lk_n * eq(rs, i) * lk_n_record[i])) virtual_polys.add_mle_list(vec![&sel_lk, &lk_n_wit[i]], eq_lk[i] * alpha_lk_n); } - // \sum_t alpha_lk * sel(rt, t) * 0 * (\sum_i (eq(rs, i)) - 1) let mut distrinct_zerocheck_terms_set = BTreeSet::new(); // degree > 1 zero expression sumcheck diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index b6d2730fd..266b23688 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -20,7 +20,9 @@ use crate::{ utils::{get_challenge_pows, sel_eval}, }; -use super::{constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZKVMProof}; +use super::{ + constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZKVMProof, ZKVMTableProof, +}; pub struct ZKVMVerifier { circuit: Circuit, @@ -259,6 +261,187 @@ impl ZKVMVerifier { Ok(input_opening_point) } + + pub fn verify_table_proof( + &self, + proof: &ZKVMTableProof, + transcript: &mut Transcript, + num_logup_fanin: usize, + _out_evals: &PointAndEval, + challenges: &[E; 2], // derive challenge from PCS + ) -> Result, ZKVMError> { + let lk_counts_per_instance = self.circuit.lk_expressions.len(); + let log2_lk_count = ceil_log2(lk_counts_per_instance); + let (chip_record_alpha, _) = (challenges[0], challenges[1]); + + let num_instances = proof.num_instances; + let log2_num_instances = ceil_log2(num_instances); + + // verify and reduce product tower sumcheck + let tower_proofs = &proof.tower_proof; + + let expected_max_round = log2_num_instances + log2_lk_count; + let (rt_tower, record_evals, logup_p_evals, logup_q_evals) = TowerVerify::verify( + vec![], + vec![vec![ + proof.lk_p1_out_eval, + proof.lk_p2_out_eval, + proof.lk_q1_out_eval, + proof.lk_q2_out_eval, + ]], + tower_proofs, + expected_max_round, + num_logup_fanin, + transcript, + )?; + assert!(logup_q_evals.len() == 1, "[lk_q_record]"); + assert!(logup_p_evals.len() == 1, "[lk_p_record]"); + + // verify zero statement (degree > 1) + sel sumcheck + let rt_lk: Vec = rt_tower[..log2_num_instances + log2_lk_count].to_vec(); + + let alpha_pow = get_challenge_pows( + 2 + self.circuit.assert_zero_sumcheck_expressions.len(), + transcript, + ); + let mut alpha_pow_iter = alpha_pow.iter(); + let (alpha_lk_d, alpha_lk_n) = ( + alpha_pow_iter.next().unwrap(), + alpha_pow_iter.next().unwrap(), + ); + // alpha_lk * (out_lk_q - chip_record_alpha) + alpha_lk_n * out_lk_p + // + 0 // 0 come from zero check + let claim_sum = + *alpha_lk_d * (logup_q_evals[0] - chip_record_alpha) + *alpha_lk_n * logup_p_evals[0]; + let main_sel_subclaim = IOPVerifierState::verify( + claim_sum, + &IOPProof { + point: vec![], // final claimed point will be derived from sumcheck protocol + proofs: proof.main_sel_sumcheck_proofs.clone(), + }, + &VPAuxInfo { + max_degree: SEL_DEGREE.max(self.circuit.max_non_lc_degree), + num_variables: log2_num_instances, + phantom: PhantomData, + }, + transcript, + ); + let (input_opening_point, expected_evaluation) = ( + main_sel_subclaim + .point + .iter() + .map(|c| c.elements) + .collect_vec(), + main_sel_subclaim.expected_evaluation, + ); + let eq_lk = build_eq_x_r_vec_sequential(&rt_lk[..log2_lk_count]); + + let (sel_lk, sel_non_lc_zero_sumcheck) = { + ( + eq_eval(&rt_lk[log2_lk_count..], &input_opening_point) + * sel_eval(num_instances, &rt_lk[log2_lk_count..]), + // only initialize when circuit got non empty assert_zero_sumcheck_expressions + { + let rt_non_lc_sumcheck = rt_tower[..log2_num_instances].to_vec(); + if !self.circuit.assert_zero_sumcheck_expressions.is_empty() { + Some( + eq_eval(&rt_non_lc_sumcheck, &input_opening_point) + * sel_eval(num_instances, &rt_non_lc_sumcheck), + ) + } else { + None + } + }, + ) + }; + + let computed_evals = [ + // lookup denominator + *alpha_lk_d + * sel_lk + * ((0..lk_counts_per_instance) + .map(|i| proof.lk_d_in_evals[i] * eq_lk[i]) + .sum::() + + chip_record_alpha + * (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE)), + *alpha_lk_n + * sel_lk + * ((0..lk_counts_per_instance) + .map(|i| proof.lk_n_in_evals[i] * eq_lk[i]) + .sum::()), + // degree > 1 zero exp sumcheck + { + // sel(rt_non_lc_sumcheck, main_sel_eval_point) * \sum_j (alpha{j} * expr(main_sel_eval_point)) + sel_non_lc_zero_sumcheck.unwrap_or(E::ZERO) + * self + .circuit + .assert_zero_sumcheck_expressions + .iter() + .zip_eq(alpha_pow_iter) + .map(|(expr, alpha)| { + // evaluate zero expression by all wits_in_evals because they share the unique input_opening_point opening + *alpha + * eval_by_expr( + &proof.fixed_in_evals, + &proof.wits_in_evals, + challenges, + expr, + ) + }) + .sum::() + }, + ] + .iter() + .sum::(); + if computed_evals != expected_evaluation { + return Err(ZKVMError::VerifyError( + "main + sel evaluation verify failed", + )); + } + // verify records (degree = 1) statement, thus no sumcheck + if self + .circuit + .lk_expressions + .iter() + .map(|lk| &lk.values) + .chain( + self.circuit + .lk_expressions + .iter() + .map(|lk| &lk.multiplicity), + ) + .zip_eq( + proof.lk_d_in_evals[..lk_counts_per_instance] + .iter() + .chain(proof.lk_n_in_evals[..lk_counts_per_instance].iter()), + ) + .any(|(expr, expected_evals)| { + eval_by_expr( + &proof.fixed_in_evals, + &proof.wits_in_evals, + challenges, + expr, + ) != *expected_evals + }) + { + return Err(ZKVMError::VerifyError("record evaluate != expected_evals")); + } + + // verify zero expression (degree = 1) statement, thus no sumcheck + if self.circuit.assert_zero_expressions.iter().any(|expr| { + eval_by_expr( + &proof.fixed_in_evals, + &proof.wits_in_evals, + challenges, + expr, + ) != E::ZERO + }) { + // TODO add me back + // return Err(ZKVMError::VerifyError("zero expression != 0")); + } + + Ok(input_opening_point) + } } pub struct TowerVerify; @@ -338,7 +521,7 @@ impl TowerVerify { let sumcheck_claim = IOPVerifierState::verify( *out_claim, &IOPProof { - point: vec![], // final claimed point will be derive from sumcheck protocol + point: vec![], // final claimed point will be derived from sumcheck protocol proofs: tower_proofs.proofs[round].clone(), }, &VPAuxInfo { diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index 55a56d1ce..d336dd2ec 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -74,7 +74,11 @@ impl RangeTableConfig { mod tests { use crate::{ circuit_builder::CircuitBuilder, - scheme::{constants::NUM_FANIN, prover::ZKVMProver, verifier::ZKVMVerifier}, + scheme::{ + constants::{NUM_FANIN, NUM_FANIN_LOGUP}, + prover::ZKVMProver, + verifier::ZKVMVerifier, + }, structs::PointAndEval, tables::range::RangeTableConfig, }; @@ -88,7 +92,6 @@ mod tests { let config = RangeTableConfig::construct_circuit(&mut cb).unwrap(); let circuit = cb.finalize_circuit(); - println!("circuit: {:?}", circuit); let traces = config.generate_traces((0..1 << 8).into_iter().collect_vec().as_slice()); let prover = ZKVMProver::new(circuit.clone()); @@ -116,15 +119,16 @@ mod tests { ) .expect("create proof"); - // let verifier = ZKVMVerifier::new(circuit); - // verifier - // .verify( - // &proof, - // &mut transcript, - // NUM_FANIN, - // &PointAndEval::default(), - // &challenges, - // ) - // .expect("verify proof failed"); + let mut transcript = Transcript::new(b"range"); + let verifier = ZKVMVerifier::new(circuit); + verifier + .verify_table_proof( + &proof, + &mut transcript, + NUM_FANIN_LOGUP, + &PointAndEval::default(), + &challenges, + ) + .expect("verify proof failed"); } }