diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 61f8c4822..d40e92941 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -2,7 +2,7 @@ use std::time::{Duration, Instant}; use ark_std::test_rng; use ceno_zkvm::{ - circuit_builder::{CircuitBuilder, ConstraintSystem, ProvingKey}, + circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{riscv::addsub::AddInstruction, Instruction}, scheme::prover::ZKVMProver, }; @@ -65,8 +65,7 @@ fn bench_add(c: &mut Criterion) { let mut cs = ConstraintSystem::new(|| "risv_add"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let _ = AddInstruction::construct_circuit(&mut circuit_builder); - let vk = cs.key_gen(); - let pk = ProvingKey::create_pk(vk); + let pk = cs.key_gen(None); let num_witin = pk.get_cs().num_witin; let prover = ZKVMProver::new(pk); diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index 6a31a5abb..b7c547782 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -2,7 +2,7 @@ use std::time::Instant; use ark_std::test_rng; use ceno_zkvm::{ - circuit_builder::{CircuitBuilder, ConstraintSystem, ProvingKey}, + circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{riscv::addsub::AddInstruction, Instruction}, scheme::prover::ZKVMProver, }; @@ -44,8 +44,7 @@ fn main() { let mut cs = ConstraintSystem::new(|| "risv_add"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let _ = AddInstruction::construct_circuit(&mut circuit_builder); - let vk = cs.key_gen(); - let pk = ProvingKey::create_pk(vk); + let pk = cs.key_gen(None); let num_witin = pk.get_cs().num_witin; let prover = ZKVMProver::new(pk); diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index dc4b26e4f..5b9f0b84d 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -5,7 +5,7 @@ use ff::Field; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, - expression::{Expression, WitIn}, + expression::{Expression, Fixed, WitIn}, structs::ROMType, }; @@ -22,6 +22,14 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.create_witin(name_fn) } + pub fn create_fixed(&mut self, name_fn: N) -> Result + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs.create_fixed(name_fn) + } + pub fn lk_record( &mut self, name_fn: N, @@ -34,6 +42,19 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.lk_record(name_fn, rlc_record) } + pub fn lk_table_record( + &mut self, + name_fn: N, + rlc_record: Expression, + multiplicity: Expression, + ) -> Result<(), ZKVMError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs.lk_table_record(name_fn, rlc_record, multiplicity) + } + pub fn read_record( &mut self, name_fn: N, diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 8a075101f..794ababb4 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -1,10 +1,11 @@ use std::marker::PhantomData; use ff_ext::ExtensionField; +use multilinear_extensions::mle::DenseMultilinearExtension; use crate::{ error::ZKVMError, - expression::{Expression, WitIn}, + expression::{Expression, Fixed, WitIn}, structs::WitnessId, }; @@ -60,6 +61,12 @@ impl NameSpace { } } +#[derive(Clone, Debug)] +pub struct LogupTableExpression { + pub multiplicity: Expression, + pub values: Expression, +} + #[derive(Clone, Debug)] pub struct ConstraintSystem { pub(crate) ns: NameSpace, @@ -67,6 +74,9 @@ pub struct ConstraintSystem { pub num_witin: WitnessId, pub witin_namespace_map: Vec, + pub num_fixed: usize, + pub fixed_namespace_map: Vec, + pub r_expressions: Vec>, pub r_expressions_namespace_map: Vec, @@ -76,6 +86,8 @@ pub struct ConstraintSystem { /// lookup expression pub lk_expressions: Vec>, pub lk_expressions_namespace_map: Vec, + pub lk_table_expressions: Vec>, + pub lk_table_expressions_namespace_map: Vec, /// main constraints zero expression pub assert_zero_expressions: Vec>, @@ -100,6 +112,8 @@ impl ConstraintSystem { Self { num_witin: 0, witin_namespace_map: vec![], + num_fixed: 0, + fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), r_expressions: vec![], r_expressions_namespace_map: vec![], @@ -107,6 +121,8 @@ impl ConstraintSystem { w_expressions_namespace_map: vec![], lk_expressions: vec![], lk_expressions_namespace_map: vec![], + lk_table_expressions: vec![], + lk_table_expressions_namespace_map: vec![], assert_zero_expressions: vec![], assert_zero_expressions_namespace_map: vec![], assert_zero_sumcheck_expressions: vec![], @@ -118,8 +134,12 @@ impl ConstraintSystem { phantom: std::marker::PhantomData, } } - pub fn key_gen(self) -> VerifyingKey { - VerifyingKey { cs: self } + + pub fn key_gen(self, fixed_traces: Option>>) -> ProvingKey { + ProvingKey { + fixed_traces, + vk: VerifyingKey { cs: self }, + } } pub fn create_witin, N: FnOnce() -> NR>( @@ -140,6 +160,19 @@ impl ConstraintSystem { Ok(wit_in) } + pub fn create_fixed, N: FnOnce() -> NR>( + &mut self, + n: N, + ) -> Result { + let f = Fixed(self.num_fixed); + self.num_fixed += 1; + + let path = self.ns.compute_path(n().into()); + self.fixed_namespace_map.push(path); + + Ok(f) + } + pub fn lk_record, N: FnOnce() -> NR>( &mut self, name_fn: N, @@ -157,6 +190,32 @@ impl ConstraintSystem { Ok(()) } + pub fn lk_table_record( + &mut self, + name_fn: N, + rlc_record: Expression, + multiplicity: Expression, + ) -> Result<(), ZKVMError> + where + NR: Into, + N: FnOnce() -> NR, + { + assert_eq!( + rlc_record.degree(), + 1, + "rlc record degree {} != 1", + rlc_record.degree() + ); + self.lk_table_expressions.push(LogupTableExpression { + values: rlc_record, + multiplicity, + }); + let path = self.ns.compute_path(name_fn().into()); + self.lk_table_expressions_namespace_map.push(path); + + Ok(()) + } + pub fn read_record, N: FnOnce() -> NR>( &mut self, name_fn: N, @@ -237,13 +296,14 @@ pub struct CircuitBuilder<'a, E: ExtensionField> { #[derive(Clone, Debug)] pub struct ProvingKey { + pub fixed_traces: Option>>, pub vk: VerifyingKey, } impl ProvingKey { - pub fn create_pk(vk: VerifyingKey) -> Self { - Self { vk } - } + // pub fn create_pk(vk: VerifyingKey) -> Self { + // Self { vk } + // } pub fn get_cs(&self) -> &ConstraintSystem { self.vk.get_cs() } diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index aa46d467c..7e0c2bdab 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -13,6 +13,8 @@ use crate::structs::{ChallengeId, WitnessId}; pub enum Expression { /// WitIn(Id) WitIn(WitnessId), + /// Fixed + Fixed(Fixed), /// Constant poly Constant(E::BaseField), /// This is the sum of two expression @@ -25,7 +27,7 @@ pub enum Expression { } /// this is used as finite state machine state -/// for differentiate a expression is in monomial form or not +/// for differentiate an expression is in monomial form or not enum MonomialState { SumTerm, ProductTerm, @@ -34,6 +36,7 @@ enum MonomialState { impl Expression { pub fn degree(&self) -> usize { match self { + Expression::Fixed(_) => 1, Expression::WitIn(_) => 1, Expression::Constant(_) => 0, Expression::Sum(a_expr, b_expr) => max(a_expr.degree(), b_expr.degree()), @@ -46,6 +49,7 @@ impl Expression { #[allow(clippy::too_many_arguments)] pub fn evaluate( &self, + fixed_in: &impl Fn(&Fixed) -> T, wit_in: &impl Fn(WitnessId) -> T, // witin id constant: &impl Fn(E::BaseField) -> T, challenge: &impl Fn(ChallengeId, usize, E, E) -> T, @@ -54,22 +58,23 @@ impl Expression { scaled: &impl Fn(T, T, T) -> T, ) -> T { match self { + Expression::Fixed(f) => fixed_in(f), Expression::WitIn(witness_id) => wit_in(*witness_id), Expression::Constant(scalar) => constant(*scalar), Expression::Sum(a, b) => { - let a = a.evaluate(wit_in, constant, challenge, sum, product, scaled); - let b = b.evaluate(wit_in, constant, challenge, sum, product, scaled); + let a = a.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); + let b = b.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); sum(a, b) } Expression::Product(a, b) => { - let a = a.evaluate(wit_in, constant, challenge, sum, product, scaled); - let b = b.evaluate(wit_in, constant, challenge, sum, product, scaled); + let a = a.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); + let b = b.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); product(a, b) } Expression::ScaledSum(x, a, b) => { - let x = x.evaluate(wit_in, constant, challenge, sum, product, scaled); - let a = a.evaluate(wit_in, constant, challenge, sum, product, scaled); - let b = b.evaluate(wit_in, constant, challenge, sum, product, scaled); + let x = x.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); + let a = a.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); + let b = b.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); scaled(x, a, b) } Expression::Challenge(challenge_id, pow, scalar, offset) => { @@ -91,6 +96,7 @@ impl Expression { fn is_zero_expr(expr: &Expression) -> bool { match expr { + Expression::Fixed(_) => false, Expression::WitIn(_) => false, Expression::Constant(c) => *c == E::BaseField::ZERO, Expression::Sum(a, b) => Self::is_zero_expr(a) && Self::is_zero_expr(b), @@ -102,10 +108,13 @@ impl Expression { fn is_monomial_form_inner(s: MonomialState, expr: &Expression) -> bool { match (expr, s) { - (Expression::WitIn(_), MonomialState::SumTerm) => true, - (Expression::WitIn(_), MonomialState::ProductTerm) => true, - (Expression::Constant(_), MonomialState::SumTerm) => true, - (Expression::Constant(_), MonomialState::ProductTerm) => true, + ( + Expression::Fixed(_) + | Expression::WitIn(_) + | Expression::Challenge(..) + | Expression::Constant(_), + _, + ) => true, (Expression::Sum(a, b), MonomialState::SumTerm) => { Self::is_monomial_form_inner(MonomialState::SumTerm, a) && Self::is_monomial_form_inner(MonomialState::SumTerm, b) @@ -121,8 +130,6 @@ impl Expression { } (Expression::ScaledSum(_, _, _), MonomialState::SumTerm) => true, (Expression::ScaledSum(_, _, b), MonomialState::ProductTerm) => Self::is_zero_expr(b), - (Expression::Challenge(_, _, _, _), MonomialState::SumTerm) => true, - (Expression::Challenge(_, _, _, _), MonomialState::ProductTerm) => true, } } } @@ -131,7 +138,7 @@ impl Neg for Expression { type Output = Expression; fn neg(self) -> Self::Output { match self { - Expression::WitIn(_) => Expression::ScaledSum( + Expression::Fixed(_) | Expression::WitIn(_) => Expression::ScaledSum( Box::new(self), Box::new(Expression::Constant(E::BaseField::ONE.neg())), Box::new(Expression::Constant(E::BaseField::ZERO)), @@ -386,6 +393,9 @@ pub struct WitIn { pub id: WitnessId, } +#[derive(Clone, Debug, Ord, PartialOrd, Eq, PartialEq)] +pub struct Fixed(pub usize); + pub trait ToExpr { type Output; fn expr(&self) -> Self::Output; diff --git a/ceno_zkvm/src/instructions/riscv/test.rs b/ceno_zkvm/src/instructions/riscv/test.rs index 07224bcce..3b20098f6 100644 --- a/ceno_zkvm/src/instructions/riscv/test.rs +++ b/ceno_zkvm/src/instructions/riscv/test.rs @@ -1,7 +1,7 @@ use goldilocks::GoldilocksExt2; use crate::{ - circuit_builder::{CircuitBuilder, ConstraintSystem, ProvingKey}, + circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::Instruction, }; @@ -26,6 +26,5 @@ fn test_multiple_opcode() { Ok(config) }, ); - let vk = cs.key_gen(); - let _pk = ProvingKey::create_pk(vk); + cs.key_gen(None); } diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 54275bf88..1399c2b42 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -3,6 +3,7 @@ pub mod error; pub mod instructions; pub mod scheme; +pub mod tables; // #[cfg(test)] pub use utils::u64vec; mod chip_handler; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 223183da8..ea55b1ec2 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -37,3 +37,23 @@ pub struct ZKVMProof { pub wits_in_evals: Vec, } + +#[derive(Clone)] +pub struct ZKVMTableProof { + pub num_instances: usize, + // logup sum at layer 1 + pub lk_p1_out_eval: E, + pub lk_p2_out_eval: E, + pub lk_q1_out_eval: E, + pub lk_q2_out_eval: E, + + pub tower_proof: TowerProofs, + + // select layer sumcheck proof + pub sel_sumcheck_proofs: Vec>, + pub lk_d_in_evals: Vec, + pub lk_n_in_evals: Vec, + + pub fixed_in_evals: Vec, + pub wits_in_evals: Vec, +} diff --git a/ceno_zkvm/src/scheme/constants.rs b/ceno_zkvm/src/scheme/constants.rs index 93a86e660..07e047b38 100644 --- a/ceno_zkvm/src/scheme/constants.rs +++ b/ceno_zkvm/src/scheme/constants.rs @@ -3,3 +3,4 @@ pub(crate) const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/loo pub(crate) const SEL_DEGREE: usize = 2; pub const NUM_FANIN: usize = 2; +pub const NUM_FANIN_LOGUP: usize = 2; diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 5698a2c71..464e8e60d 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -105,10 +105,10 @@ impl<'a, E: ExtensionField> MockProver { let left = left.neg().neg(); // TODO get_ext_field_vec doesn't work without this let right = right.neg(); - let left_evaluated = wit_infer_by_expr(wits_in, &challenge, &left); + let left_evaluated = wit_infer_by_expr(&[], wits_in, &challenge, &left); let left_evaluated = left_evaluated.get_ext_field_vec(); - let right_evaluated = wit_infer_by_expr(wits_in, &challenge, &right); + let right_evaluated = wit_infer_by_expr(&[], wits_in, &challenge, &right); let right_evaluated = right_evaluated.get_ext_field_vec(); for (left_element, right_element) in left_evaluated.iter().zip_eq(right_evaluated) { @@ -124,7 +124,7 @@ impl<'a, E: ExtensionField> MockProver { } } else { let expr = expr.clone().neg().neg(); // TODO get_ext_field_vec doesn't work without this - let expr_evaluated = wit_infer_by_expr(wits_in, &challenge, &expr); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, &challenge, &expr); let expr_evaluated = expr_evaluated.get_ext_field_vec(); for element in expr_evaluated { @@ -150,7 +150,7 @@ impl<'a, E: ExtensionField> MockProver { .iter() .zip_eq(cb.cs.lk_expressions_namespace_map.iter()) { - let expr_evaluated = wit_infer_by_expr(wits_in, &challenge, expr); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, &challenge, expr); let expr_evaluated = expr_evaluated.get_ext_field_vec(); // Check each lookup expr exists in t vec diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 6be540d42..e080e0c42 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,10 +1,12 @@ -use std::collections::BTreeSet; +use std::{collections::BTreeSet, sync::Arc}; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ - mle::IntoMLE, util::ceil_log2, virtual_poly::build_eq_x_r_vec, + mle::{IntoMLE, MultilinearExtension}, + util::ceil_log2, + virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, }; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; @@ -18,7 +20,7 @@ use crate::{ circuit_builder::ProvingKey, error::ZKVMError, scheme::{ - constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN}, + constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, wit_infer_by_expr, @@ -29,7 +31,7 @@ use crate::{ virtual_polys::VirtualPolynomials, }; -use super::ZKVMProof; +use super::{ZKVMProof, ZKVMTableProof}; pub struct ZKVMProver { pk: ProvingKey, @@ -72,7 +74,7 @@ impl ZKVMProver { .chain(cs.lk_expressions.par_iter()) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&witnesses, challenges, expr) + wit_infer_by_expr(&[], &witnesses, challenges, expr) }) .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); @@ -135,7 +137,7 @@ impl ZKVMProver { exit_span!(span); let span = entered_span!("wit_inference::tower_witness_lk_layers"); - let lk_wit_layers = infer_tower_logup_witness(lk_records_last_layer); + let lk_wit_layers = infer_tower_logup_witness(None, lk_records_last_layer); exit_span!(span); if cfg!(test) { @@ -283,7 +285,8 @@ impl ZKVMProver { if num_instances < sel_non_lc_zero_sumcheck.len() { sel_non_lc_zero_sumcheck.splice( num_instances..sel_non_lc_zero_sumcheck.len(), - std::iter::repeat(E::ZERO), + std::iter::repeat(E::ZERO) + .take(sel_non_lc_zero_sumcheck.len() - num_instances), ); } let sel_non_lc_zero_sumcheck: ArcMultilinearExtension = @@ -324,7 +327,7 @@ impl ZKVMProver { *alpha_write * eq_w[w_counts_per_instance..].iter().sum::() - *alpha_write, ); - // lk + // lk denominator // rt := rt || rs for i in 0..lk_counts_per_instance { // \sum_t (sel(rt, t) * (\sum_i alpha_lk* eq(rs, i) * record_w[i])) @@ -427,6 +430,200 @@ impl ZKVMProver { wits_in_evals, }) } + + pub fn create_table_proof( + &self, + witnesses: Vec>, + num_instances: usize, + max_threads: usize, + transcript: &mut Transcript, + challenges: &[E; 2], + ) -> Result, ZKVMError> { + let cs = self.pk.get_cs(); + let fixed = self + .pk + .fixed_traces + .as_ref() + .expect("pk.fixed_traces must not be none for table circuit") + .iter() + .map(|f| -> ArcMultilinearExtension { Arc::new(f.get_ranged_mle(1, 0)) }) + .collect::>>(); + let log2_num_instances = ceil_log2(num_instances); + let next_pow2_instances = 1 << log2_num_instances; + + // sanity check + assert_eq!(witnesses.len(), cs.num_witin as usize); + assert_eq!(fixed.len(), cs.num_fixed); + assert!(witnesses.iter().all(|v| { + v.num_vars() == log2_num_instances && v.evaluations().len() == next_pow2_instances + })); + assert!(!cs.lk_table_expressions.is_empty()); + + // main constraint: lookup denominator and numerator record witness inference + let span = entered_span!("wit_inference::record"); + let records_wit: Vec> = cs + .lk_table_expressions + .par_iter() + .map(|lk| &lk.values) + .chain( + cs.lk_table_expressions + .par_iter() + .map(|lk| &lk.multiplicity), + ) + .map(|expr| { + assert_eq!(expr.degree(), 1); + wit_infer_by_expr(&fixed, &witnesses, challenges, expr) + }) + .collect(); + let (lk_d_wit, lk_n_wit) = records_wit.split_at(cs.lk_table_expressions.len()); + exit_span!(span); + + let lk_counts_per_instance = cs.lk_table_expressions.len(); + let log2_lk_count = ceil_log2(lk_counts_per_instance); + + // infer all tower witness after last layer + let span = entered_span!("wit_inference::tower_witness_lk_last_layer"); + // TODO optimize last layer to avoid alloc new vector to save memory + let lk_denominator_last_layer = + interleaving_mles_to_mles(lk_d_wit, log2_num_instances, NUM_FANIN_LOGUP, E::ONE); + let lk_numerator_last_layer = + interleaving_mles_to_mles(lk_n_wit, log2_num_instances, NUM_FANIN_LOGUP, E::ZERO); + assert_eq!(lk_denominator_last_layer.len(), NUM_FANIN_LOGUP); + assert_eq!(lk_numerator_last_layer.len(), NUM_FANIN_LOGUP); + exit_span!(span); + + let span = entered_span!("wit_inference::tower_witness_lk_layers"); + let lk_wit_layers = + infer_tower_logup_witness(Some(lk_numerator_last_layer), lk_denominator_last_layer); + exit_span!(span); + + if cfg!(test) { + // sanity check + assert_eq!(lk_wit_layers.len(), log2_num_instances + log2_lk_count); + assert!(lk_wit_layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + let (p1, p2, q1, q2) = (&w[0], &w[1], &w[2], &w[3]); + p1.evaluations().len() == expected_size + && p2.evaluations().len() == expected_size + && q1.evaluations().len() == expected_size + && q2.evaluations().len() == expected_size + })); + } + + // product constraint tower sumcheck + let span = entered_span!("sumcheck::tower"); + // final evals for verifier + let lk_p1_out_eval = lk_wit_layers[0][0].get_ext_field_vec()[0]; + let lk_p2_out_eval = lk_wit_layers[0][1].get_ext_field_vec()[0]; + let lk_q1_out_eval = lk_wit_layers[0][2].get_ext_field_vec()[0]; + let lk_q2_out_eval = lk_wit_layers[0][3].get_ext_field_vec()[0]; + let (rt_tower, tower_proof) = TowerProver::create_proof( + max_threads, + vec![], + vec![TowerProverSpec { + witness: lk_wit_layers, + }], + NUM_FANIN_LOGUP, + transcript, + ); + assert_eq!(rt_tower.len(), log2_num_instances + log2_lk_count); + exit_span!(span); + + // selector layer sumcheck + let span = entered_span!("sumcheck::main_sel"); + let rt_lk: Vec = tower_proof.logup_specs_points[0] + .last() + .expect("error getting rt_lk") + .to_vec(); + + let num_threads = proper_num_threads(log2_num_instances, max_threads); + // 2 for denominator and numerator + let alpha_pow = get_challenge_pows(2, transcript); + let mut alpha_pow_iter = alpha_pow.iter(); + let (alpha_lk_d, alpha_lk_n) = ( + alpha_pow_iter.next().unwrap(), + alpha_pow_iter.next().unwrap(), + ); + // create selector: all ONE, but padding ZERO to ceil_log2 + let sel_lk: ArcMultilinearExtension = { + let mut sel_lk = build_eq_x_r_vec(&rt_lk[log2_lk_count..]); + if num_instances < sel_lk.len() { + sel_lk.splice( + num_instances..sel_lk.len(), + std::iter::repeat(E::ZERO).take(sel_lk.len() - num_instances), + ); + } + sel_lk.into_mle().into() + }; + + let mut virtual_polys = VirtualPolynomials::::new(num_threads, log2_num_instances); + + let eq_lk = build_eq_x_r_vec(&rt_lk[..log2_lk_count]); + // lk denominator + // rt := rt || rs + for i in 0..lk_counts_per_instance { + // \sum_t (sel(rt, t) * (\sum_i alpha_lk_d * eq(rs, i) * lk_d_record[i])) + virtual_polys.add_mle_list(vec![&sel_lk, &lk_d_wit[i]], eq_lk[i] * alpha_lk_d); + } + // \sum_t alpha_lk_d * sel(rt, t) * (\sum_i (eq(rs, i)) - 1) + virtual_polys.add_mle_list( + vec![&sel_lk], + *alpha_lk_d * (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE), + ); + + // lk numerator + for i in 0..lk_counts_per_instance { + // \sum_t (sel(rt, t) * (\sum_i alpha_lk_n * eq(rs, i) * lk_n_record[i])) + virtual_polys.add_mle_list(vec![&sel_lk, &lk_n_wit[i]], eq_lk[i] * alpha_lk_n); + } + + let (sel_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( + num_threads, + virtual_polys.get_batched_polys(), + transcript, + ); + let sel_evals = state.get_mle_final_evaluations(); + assert_eq!( + sel_evals.len(), + lk_counts_per_instance * 2 + 1 // 1 for sel_lk + ); + let mut sel_evals_iter = sel_evals.into_iter(); + sel_evals_iter.next(); // skip sel_lk + let lk_d_in_evals = (0..lk_counts_per_instance) + .map(|_| sel_evals_iter.next().unwrap()) + .collect_vec(); + let lk_n_in_evals = (0..lk_counts_per_instance) + .map(|_| sel_evals_iter.next().unwrap()) + .collect_vec(); + assert!(sel_evals_iter.count() == 0); + let input_open_point = sel_sumcheck_proofs.point.clone(); + assert!(input_open_point.len() == log2_num_instances); + exit_span!(span); + + let span = entered_span!("fixed::evals + witin::evals"); + let mut evals = witnesses + .par_iter() + .chain(fixed.par_iter()) + .map(|poly| poly.evaluate(&input_open_point)) + .collect::>(); + let fixed_in_evals = evals.split_off(witnesses.len()); + let wits_in_evals = evals; + exit_span!(span); + + Ok(ZKVMTableProof { + num_instances, + lk_p1_out_eval, + lk_p2_out_eval, + lk_q1_out_eval, + lk_q2_out_eval, + tower_proof, + sel_sumcheck_proofs: sel_sumcheck_proofs.proofs, + lk_d_in_evals, + lk_n_in_evals, + fixed_in_evals, + wits_in_evals, + }) + } } /// TowerProofs @@ -477,7 +674,6 @@ impl TowerProver { assert_eq!(num_fanin, 2); let mut proofs = TowerProofs::new(prod_specs.len(), logup_specs.len()); - assert!(!prod_specs.is_empty()); let log_num_fanin = ceil_log2(num_fanin); // -1 for sliding windows size 2: (cur_layer, next_layer) w.r.t total size let max_round_index = prod_specs diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 14a98cbcd..3a2008a36 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -8,7 +8,7 @@ use multilinear_extensions::mle::IntoMLE; use transcript::Transcript; use crate::{ - circuit_builder::{CircuitBuilder, ConstraintSystem, ProvingKey}, + circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, expression::{Expression, ToExpr}, structs::PointAndEval, @@ -53,8 +53,8 @@ fn test_rw_lk_expression_combination() { let mut cs = ConstraintSystem::new(|| "test"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let _ = TestCircuit::construct_circuit::(&mut circuit_builder); - let vk = cs.key_gen(); - let pk = ProvingKey::create_pk(vk.clone()); + let pk = cs.key_gen(None); + let vk = pk.vk.clone(); // generate mock witness let num_instances = 1 << 2; diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 2fe4e9457..3bf2f5396 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -64,9 +64,18 @@ pub(crate) fn interleaving_mles_to_mles<'a, E: ExtensionField>( assert_eq!(instance.len(), per_instance_size); instance[i] = *value; }), - _ => { - unreachable!("must be extension field") - } + FieldType::Base(mle) => mle + .get(start..(start + per_fanin_len)) + .unwrap_or(&[]) + .par_iter() + .zip(evaluations.par_chunks_mut(per_instance_size)) + .with_min_len(MIN_PAR_SIZE) + .for_each(|(value, instance)| { + assert_eq!(instance.len(), per_instance_size); + instance[i] = + <::BaseField as Into>::into(*value); + }), + _ => unreachable!(), }); evaluations.into_mle().into() }) @@ -75,83 +84,80 @@ pub(crate) fn interleaving_mles_to_mles<'a, E: ExtensionField>( /// infer logup witness from last layer /// return is the ([p1,p2], [q1,q2]) for each layer -pub(crate) fn infer_tower_logup_witness( - q_mles: Vec>, -) -> Vec>> { +pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( + p_mles: Option>>, + q_mles: Vec>, +) -> Vec>> { if cfg!(test) { assert_eq!(q_mles.len(), 2); assert!(q_mles.iter().map(|q| q.evaluations().len()).all_equal()); } let num_vars = ceil_log2(q_mles[0].evaluations().len()); - let mut wit_layers = (0..num_vars).fold( - vec![(Option::>>::None, q_mles)], - |mut acc, _| { - let (p, q): &( - Option>>, - Vec>, - ) = acc.last().unwrap(); - let (q1, q2) = (&q[0], &q[1]); - let cur_len = q1.evaluations().len() / 2; - let (next_p, next_q): ( - Vec>, - Vec>, - ) = (0..2) - .map(|index| { - let mut p_evals = vec![E::ZERO; cur_len]; - let mut q_evals = vec![E::ZERO; cur_len]; - let start_index = cur_len * index; - if let Some(p) = p { - let (p1, p2) = (&p[0], &p[1]); - match ( - p1.evaluations(), - p2.evaluations(), - q1.evaluations(), - q2.evaluations(), - ) { - ( - FieldType::Ext(p1), - FieldType::Ext(p2), - FieldType::Ext(q1), - FieldType::Ext(q2), - ) => q1[start_index..][..cur_len] - .par_iter() - .zip(q2[start_index..][..cur_len].par_iter()) - .zip(p1[start_index..][..cur_len].par_iter()) - .zip(p2[start_index..][..cur_len].par_iter()) - .zip(p_evals.par_iter_mut()) - .zip(q_evals.par_iter_mut()) - .with_min_len(MIN_PAR_SIZE) - .for_each(|(((((q1, q2), p1), p2), p_eval), q_eval)| { - *p_eval = *p2 * q1 + *p1 * q2; - *q_eval = *q1 * q2; - }), - _ => unreachable!(), - }; - } else { - match (q1.evaluations(), q2.evaluations()) { - (FieldType::Ext(q1), FieldType::Ext(q2)) => q1[start_index..] - [..cur_len] - .par_iter() - .zip(q2[start_index..][..cur_len].par_iter()) - .zip(p_evals.par_iter_mut()) - .zip(q_evals.par_iter_mut()) - .with_min_len(MIN_PAR_SIZE) - .for_each(|(((q1, q2), p_res), q_res)| { - // 1 / q1 + 1 / q2 = (q1+q2) / q1*q2 - // p is numerator and q is denominator - *p_res = *q1 + q2; - *q_res = *q1 * q2; - }), - _ => unreachable!(), - }; - } - (p_evals.into_mle().into(), q_evals.into_mle().into()) - }) - .unzip(); // vec[vec[p1, p2], vec[q1, q2]] - acc.push((Some(next_p), next_q)); - acc - }, - ); + let mut wit_layers = (0..num_vars).fold(vec![(p_mles, q_mles)], |mut acc, _| { + let (p, q): &( + Option>>, + Vec>, + ) = acc.last().unwrap(); + let (q1, q2) = (&q[0], &q[1]); + let cur_len = q1.evaluations().len() / 2; + let (next_p, next_q): ( + Vec>, + Vec>, + ) = (0..2) + .map(|index| { + let mut p_evals = vec![E::ZERO; cur_len]; + let mut q_evals = vec![E::ZERO; cur_len]; + let start_index = cur_len * index; + if let Some(p) = p { + let (p1, p2) = (&p[0], &p[1]); + match ( + p1.evaluations(), + p2.evaluations(), + q1.evaluations(), + q2.evaluations(), + ) { + ( + FieldType::Ext(p1), + FieldType::Ext(p2), + FieldType::Ext(q1), + FieldType::Ext(q2), + ) => q1[start_index..][..cur_len] + .par_iter() + .zip(q2[start_index..][..cur_len].par_iter()) + .zip(p1[start_index..][..cur_len].par_iter()) + .zip(p2[start_index..][..cur_len].par_iter()) + .zip(p_evals.par_iter_mut()) + .zip(q_evals.par_iter_mut()) + .with_min_len(MIN_PAR_SIZE) + .for_each(|(((((q1, q2), p1), p2), p_eval), q_eval)| { + *p_eval = *p2 * q1 + *p1 * q2; + *q_eval = *q1 * q2; + }), + _ => unreachable!(), + }; + } else { + match (q1.evaluations(), q2.evaluations()) { + (FieldType::Ext(q1), FieldType::Ext(q2)) => q1[start_index..][..cur_len] + .par_iter() + .zip(q2[start_index..][..cur_len].par_iter()) + .zip(p_evals.par_iter_mut()) + .zip(q_evals.par_iter_mut()) + .with_min_len(MIN_PAR_SIZE) + .for_each(|(((q1, q2), p_res), q_res)| { + // 1 / q1 + 1 / q2 = (q1+q2) / q1*q2 + // p is numerator and q is denominator + *p_res = *q1 + q2; + *q_res = *q1 * q2; + }), + _ => unreachable!(), + }; + } + (p_evals.into_mle().into(), q_evals.into_mle().into()) + }) + .unzip(); // vec[vec[p1, p2], vec[q1, q2]] + acc.push((Some(next_p), next_q)); + acc + }); wit_layers.reverse(); wit_layers .into_iter() @@ -211,11 +217,13 @@ pub(crate) fn infer_tower_product_witness( } pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( + fixed: &[ArcMultilinearExtension<'a, E>], witnesses: &[ArcMultilinearExtension<'a, E>], challenges: &[E; N], expr: &Expression, ) -> ArcMultilinearExtension<'a, E> { expr.evaluate::>( + &|f| fixed[f.0].clone(), &|witness_id| witnesses[witness_id as usize].clone(), &|scalar| { let scalar: ArcMultilinearExtension = Arc::new( @@ -334,8 +342,18 @@ pub(crate) fn eval_by_expr( witnesses: &[E], challenges: &[E], expr: &Expression, +) -> E { + eval_by_expr_with_fixed(&[], witnesses, challenges, expr) +} + +pub(crate) fn eval_by_expr_with_fixed( + fixed: &[E], + witnesses: &[E], + challenges: &[E], + expr: &Expression, ) -> E { expr.evaluate::( + &|f| fixed[f.0], &|witness_id| witnesses[witness_id as usize], &|scalar| scalar.into(), &|challenge_id, pow, scalar, offset| { @@ -479,7 +497,7 @@ mod tests { .into_mle() .into(), ]; - let mut res = infer_tower_logup_witness(q); + let mut res = infer_tower_logup_witness(None, q); assert_eq!(num_vars + 1, res.len()); // input layer let layer = res.pop().unwrap(); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index da157fc22..43f87ded0 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -15,12 +15,17 @@ use transcript::Transcript; use crate::{ circuit_builder::VerifyingKey, error::ZKVMError, - scheme::constants::{NUM_FANIN, SEL_DEGREE}, + scheme::{ + constants::{NUM_FANIN, SEL_DEGREE}, + utils::eval_by_expr_with_fixed, + }, structs::{Point, PointAndEval, TowerProofs}, utils::{get_challenge_pows, sel_eval}, }; -use super::{constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZKVMProof}; +use super::{ + constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZKVMProof, ZKVMTableProof, +}; pub struct ZKVMVerifier { vk: VerifyingKey, @@ -244,6 +249,124 @@ impl ZKVMVerifier { Ok(input_opening_point) } + + pub fn verify_table_proof( + &self, + proof: &ZKVMTableProof, + transcript: &mut Transcript, + num_logup_fanin: usize, + _out_evals: &PointAndEval, + challenges: &[E; 2], // TODO: derive challenge from PCS + ) -> Result, ZKVMError> { + let cs = self.vk.get_cs(); + let lk_counts_per_instance = cs.lk_table_expressions.len(); + let log2_lk_count = ceil_log2(lk_counts_per_instance); + let (chip_record_alpha, _) = (challenges[0], challenges[1]); + + let num_instances = proof.num_instances; + let log2_num_instances = ceil_log2(num_instances); + + // verify and reduce product tower sumcheck + let tower_proofs = &proof.tower_proof; + + let expected_max_round = log2_num_instances + log2_lk_count; + let (_, _, logup_p_evals, logup_q_evals) = TowerVerify::verify( + vec![], + vec![vec![ + proof.lk_p1_out_eval, + proof.lk_p2_out_eval, + proof.lk_q1_out_eval, + proof.lk_q2_out_eval, + ]], + tower_proofs, + vec![expected_max_round], + num_logup_fanin, + transcript, + )?; + assert!(logup_q_evals.len() == 1, "[lk_q_record]"); + assert!(logup_p_evals.len() == 1, "[lk_p_record]"); + assert_eq!(logup_p_evals[0].point, logup_q_evals[0].point); + + // verify selector layer sumcheck + let rt_lk: Vec = logup_p_evals[0].point.to_vec(); + + // 2 for denominator and numerator + let alpha_pow = get_challenge_pows(2, transcript); + let mut alpha_pow_iter = alpha_pow.iter(); + let (alpha_lk_d, alpha_lk_n) = ( + alpha_pow_iter.next().unwrap(), + alpha_pow_iter.next().unwrap(), + ); + // alpha_lk * (out_lk_q - one) + alpha_lk_n * out_lk_p + let claim_sum = + *alpha_lk_d * (logup_q_evals[0].eval - E::ONE) + *alpha_lk_n * logup_p_evals[0].eval; + let sel_subclaim = IOPVerifierState::verify( + claim_sum, + &IOPProof { + point: vec![], // final claimed point will be derived from sumcheck protocol + proofs: proof.sel_sumcheck_proofs.clone(), + }, + &VPAuxInfo { + max_degree: SEL_DEGREE.max(cs.max_non_lc_degree), + num_variables: log2_num_instances, + phantom: PhantomData, + }, + transcript, + ); + let (input_opening_point, expected_evaluation) = ( + sel_subclaim.point.iter().map(|c| c.elements).collect_vec(), + sel_subclaim.expected_evaluation, + ); + let eq_lk = build_eq_x_r_vec_sequential(&rt_lk[..log2_lk_count]); + + let sel_lk = eq_eval(&rt_lk[log2_lk_count..], &input_opening_point) + * sel_eval(num_instances, &rt_lk[log2_lk_count..]); + + let computed_evals = [ + // lookup denominator + *alpha_lk_d + * sel_lk + * ((0..lk_counts_per_instance) + .map(|i| proof.lk_d_in_evals[i] * eq_lk[i]) + .sum::() + + chip_record_alpha + * (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE)), + *alpha_lk_n + * sel_lk + * ((0..lk_counts_per_instance) + .map(|i| proof.lk_n_in_evals[i] * eq_lk[i]) + .sum::()), + ] + .iter() + .sum::(); + if computed_evals != expected_evaluation { + return Err(ZKVMError::VerifyError("sel evaluation verify failed")); + } + // verify records (degree = 1) statement, thus no sumcheck + if cs + .lk_table_expressions + .iter() + .map(|lk| &lk.values) + .chain(cs.lk_table_expressions.iter().map(|lk| &lk.multiplicity)) + .zip_eq( + proof.lk_d_in_evals[..lk_counts_per_instance] + .iter() + .chain(proof.lk_n_in_evals[..lk_counts_per_instance].iter()), + ) + .any(|(expr, expected_evals)| { + eval_by_expr_with_fixed( + &proof.fixed_in_evals, + &proof.wits_in_evals, + challenges, + expr, + ) != *expected_evals + }) + { + return Err(ZKVMError::VerifyError("record evaluate != expected_evals")); + } + + Ok(input_opening_point) + } } pub struct TowerVerify; @@ -326,7 +449,7 @@ impl TowerVerify { let sumcheck_claim = IOPVerifierState::verify( *out_claim, &IOPProof { - point: vec![], // final claimed point will be derive from sumcheck protocol + point: vec![], // final claimed point will be derived from sumcheck protocol proofs: tower_proofs.proofs[round].clone(), }, &VPAuxInfo { diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs new file mode 100644 index 000000000..b2277ba15 --- /dev/null +++ b/ceno_zkvm/src/tables/mod.rs @@ -0,0 +1 @@ +mod range; diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs new file mode 100644 index 000000000..a081ee549 --- /dev/null +++ b/ceno_zkvm/src/tables/range.rs @@ -0,0 +1,131 @@ +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, Fixed, ToExpr, WitIn}, + structs::{ROMType, WitnessId}, +}; +use ff_ext::ExtensionField; +use itertools::Itertools; +use multilinear_extensions::mle::DenseMultilinearExtension; +use std::{collections::BTreeMap, marker::PhantomData}; + +#[derive(Clone, Debug)] +pub struct RangeTableConfig { + u16_tbl: Fixed, + u16_mlt: WitIn, + _marker: PhantomData, +} + +#[derive(Default)] +pub struct RangeTableTrace { + pub fixed: BTreeMap>, + pub wits: BTreeMap>, +} + +impl RangeTableConfig { + #[allow(unused)] + fn construct_circuit(cb: &mut CircuitBuilder) -> Result, ZKVMError> { + let u16_tbl = cb.create_fixed(|| "u16_tbl")?; + let u16_mlt = cb.create_witin(|| "u16_mlt")?; + + let u16_table_values = cb.rlc_chip_record(vec![ + Expression::Constant(E::BaseField::from(ROMType::U16 as u64)), + Expression::Fixed(u16_tbl.clone()), + ]); + + cb.lk_table_record(|| "u16 table", u16_table_values, u16_mlt.expr())?; + + Ok(RangeTableConfig { + u16_tbl, + u16_mlt, + _marker: Default::default(), + }) + } + + #[allow(unused)] + fn generate_traces(self, inputs: &[u16]) -> RangeTableTrace { + let mut u16_mlt = vec![0; 1 << 16]; + for limb in inputs { + u16_mlt[*limb as usize] += 1; + } + + let u16_tbl = DenseMultilinearExtension::from_evaluations_vec( + 16, + (0..(1 << 16)).map(E::BaseField::from).collect_vec(), + ); + let u16_mlt = DenseMultilinearExtension::from_evaluations_vec( + 16, + u16_mlt.into_iter().map(E::BaseField::from).collect_vec(), + ); + + let config = self.clone(); + let mut traces = RangeTableTrace::default(); + traces.fixed.insert(config.u16_tbl, u16_tbl); + traces.wits.insert(config.u16_mlt.id, u16_mlt); + + traces + } +} + +#[cfg(test)] +mod tests { + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + scheme::{constants::NUM_FANIN_LOGUP, prover::ZKVMProver, verifier::ZKVMVerifier}, + structs::PointAndEval, + tables::range::RangeTableConfig, + }; + use goldilocks::GoldilocksExt2; + use itertools::Itertools; + use transcript::Transcript; + + #[test] + fn test_range_circuit() { + let mut cs = ConstraintSystem::new(|| "riscv"); + let config = cs + .namespace( + || "range", + |cs| { + let mut cb = CircuitBuilder::::new(cs); + RangeTableConfig::construct_circuit(&mut cb) + }, + ) + .expect("construct range table circuit"); + + let traces = config.generate_traces((0..1 << 8).collect_vec().as_slice()); + + let pk = cs.key_gen(Some(traces.fixed.clone().into_values().collect_vec())); + let vk = pk.vk.clone(); + let prover = ZKVMProver::new(pk); + + let mut transcript = Transcript::new(b"range"); + let challenges = [1.into(), 2.into()]; + + let proof = prover + .create_table_proof( + traces + .wits + .into_values() + .map(|mle| mle.into()) + .collect_vec(), + // TODO: fix the verification error for num_instances is not power-of-two case + 1 << 16, + 1, + &mut transcript, + &challenges, + ) + .expect("create proof"); + + let mut transcript = Transcript::new(b"range"); + let verifier = ZKVMVerifier::new(vk); + verifier + .verify_table_proof( + &proof, + &mut transcript, + NUM_FANIN_LOGUP, + &PointAndEval::default(), + &challenges, + ) + .expect("verify proof failed"); + } +} diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 20bbbe389..375f15b8c 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -94,6 +94,7 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { ) -> BTreeSet { assert!(expr.is_monomial_form()); let monomial_terms = expr.evaluate( + &|_| unreachable!(), &|witness_id| vec![(E::ONE, { vec![witness_id] })], &|scalar| vec![(E::from(scalar), { vec![] })], &|challenge_id, pow, scalar, offset| {