Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add range table circuit #154

Merged
merged 11 commits into from
Sep 3, 2024
5 changes: 2 additions & 3 deletions ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::time::{Duration, Instant};

use ark_std::test_rng;
use ceno_zkvm::{
circuit_builder::{CircuitBuilder, ConstraintSystem, ProvingKey},
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::{riscv::addsub::AddInstruction, Instruction},
scheme::prover::ZKVMProver,
};
Expand Down Expand Up @@ -65,8 +65,7 @@ fn bench_add(c: &mut Criterion) {
let mut cs = ConstraintSystem::new(|| "risv_add");
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);
let _ = AddInstruction::construct_circuit(&mut circuit_builder);
let vk = cs.key_gen();
let pk = ProvingKey::create_pk(vk);
let pk = cs.key_gen(None);
let num_witin = pk.get_cs().num_witin;

let prover = ZKVMProver::new(pk);
Expand Down
5 changes: 2 additions & 3 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::time::Instant;

use ark_std::test_rng;
use ceno_zkvm::{
circuit_builder::{CircuitBuilder, ConstraintSystem, ProvingKey},
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::{riscv::addsub::AddInstruction, Instruction},
scheme::prover::ZKVMProver,
};
Expand Down Expand Up @@ -44,8 +44,7 @@ fn main() {
let mut cs = ConstraintSystem::new(|| "risv_add");
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);
let _ = AddInstruction::construct_circuit(&mut circuit_builder);
let vk = cs.key_gen();
let pk = ProvingKey::create_pk(vk);
let pk = cs.key_gen(None);
let num_witin = pk.get_cs().num_witin;

let prover = ZKVMProver::new(pk);
Expand Down
23 changes: 22 additions & 1 deletion ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ff::Field;
use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
error::ZKVMError,
expression::{Expression, WitIn},
expression::{Expression, Fixed, WitIn},
structs::ROMType,
};

Expand All @@ -22,6 +22,14 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
self.cs.create_witin(name_fn)
}

pub fn create_fixed<NR, N>(&mut self, name_fn: N) -> Result<Fixed, ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
self.cs.create_fixed(name_fn)
}

pub fn lk_record<NR, N>(
&mut self,
name_fn: N,
Expand All @@ -34,6 +42,19 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
self.cs.lk_record(name_fn, rlc_record)
}

pub fn lk_table_record<NR, N>(
&mut self,
name_fn: N,
rlc_record: Expression<E>,
multiplicity: Expression<E>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
self.cs.lk_table_record(name_fn, rlc_record, multiplicity)
}

pub fn read_record<NR, N>(
&mut self,
name_fn: N,
Expand Down
72 changes: 66 additions & 6 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::marker::PhantomData;

use ff_ext::ExtensionField;
use multilinear_extensions::mle::DenseMultilinearExtension;

use crate::{
error::ZKVMError,
expression::{Expression, WitIn},
expression::{Expression, Fixed, WitIn},
structs::WitnessId,
};

Expand Down Expand Up @@ -60,13 +61,22 @@ impl NameSpace {
}
}

#[derive(Clone, Debug)]
pub struct LogupTableExpression<E: ExtensionField> {
pub multiplicity: Expression<E>,
pub values: Expression<E>,
}

#[derive(Clone, Debug)]
pub struct ConstraintSystem<E: ExtensionField> {
pub(crate) ns: NameSpace,

pub num_witin: WitnessId,
pub witin_namespace_map: Vec<String>,

pub num_fixed: usize,
pub fixed_namespace_map: Vec<String>,

pub r_expressions: Vec<Expression<E>>,
pub r_expressions_namespace_map: Vec<String>,

Expand All @@ -76,6 +86,8 @@ pub struct ConstraintSystem<E: ExtensionField> {
/// lookup expression
pub lk_expressions: Vec<Expression<E>>,
pub lk_expressions_namespace_map: Vec<String>,
pub lk_table_expressions: Vec<LogupTableExpression<E>>,
pub lk_table_expressions_namespace_map: Vec<String>,

/// main constraints zero expression
pub assert_zero_expressions: Vec<Expression<E>>,
Expand All @@ -100,13 +112,17 @@ impl<E: ExtensionField> ConstraintSystem<E> {
Self {
num_witin: 0,
witin_namespace_map: vec![],
num_fixed: 0,
fixed_namespace_map: vec![],
ns: NameSpace::new(root_name_fn),
r_expressions: vec![],
r_expressions_namespace_map: vec![],
w_expressions: vec![],
w_expressions_namespace_map: vec![],
lk_expressions: vec![],
lk_expressions_namespace_map: vec![],
lk_table_expressions: vec![],
lk_table_expressions_namespace_map: vec![],
assert_zero_expressions: vec![],
assert_zero_expressions_namespace_map: vec![],
assert_zero_sumcheck_expressions: vec![],
Expand All @@ -118,8 +134,12 @@ impl<E: ExtensionField> ConstraintSystem<E> {
phantom: std::marker::PhantomData,
}
}
pub fn key_gen(self) -> VerifyingKey<E> {
VerifyingKey { cs: self }

pub fn key_gen(self, fixed_traces: Option<Vec<DenseMultilinearExtension<E>>>) -> ProvingKey<E> {
ProvingKey {
fixed_traces,
vk: VerifyingKey { cs: self },
}
}

pub fn create_witin<NR: Into<String>, N: FnOnce() -> NR>(
Expand All @@ -140,6 +160,19 @@ impl<E: ExtensionField> ConstraintSystem<E> {
Ok(wit_in)
}

pub fn create_fixed<NR: Into<String>, N: FnOnce() -> NR>(
&mut self,
n: N,
) -> Result<Fixed, ZKVMError> {
let f = Fixed(self.num_fixed);
self.num_fixed += 1;

let path = self.ns.compute_path(n().into());
self.fixed_namespace_map.push(path);

Ok(f)
}

pub fn lk_record<NR: Into<String>, N: FnOnce() -> NR>(
&mut self,
name_fn: N,
Expand All @@ -157,6 +190,32 @@ impl<E: ExtensionField> ConstraintSystem<E> {
Ok(())
}

pub fn lk_table_record<NR, N>(
&mut self,
name_fn: N,
rlc_record: Expression<E>,
multiplicity: Expression<E>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
assert_eq!(
rlc_record.degree(),
1,
"rlc record degree {} != 1",
rlc_record.degree()
);
self.lk_table_expressions.push(LogupTableExpression {
values: rlc_record,
multiplicity,
});
let path = self.ns.compute_path(name_fn().into());
self.lk_table_expressions_namespace_map.push(path);

Ok(())
}

pub fn read_record<NR: Into<String>, N: FnOnce() -> NR>(
&mut self,
name_fn: N,
Expand Down Expand Up @@ -237,13 +296,14 @@ pub struct CircuitBuilder<'a, E: ExtensionField> {

#[derive(Clone, Debug)]
pub struct ProvingKey<E: ExtensionField> {
pub fixed_traces: Option<Vec<DenseMultilinearExtension<E>>>,
pub vk: VerifyingKey<E>,
}

impl<E: ExtensionField> ProvingKey<E> {
pub fn create_pk(vk: VerifyingKey<E>) -> Self {
Self { vk }
}
// pub fn create_pk(vk: VerifyingKey<E>) -> Self {
// Self { vk }
// }
pub fn get_cs(&self) -> &ConstraintSystem<E> {
self.vk.get_cs()
}
Expand Down
40 changes: 25 additions & 15 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use crate::structs::{ChallengeId, WitnessId};
pub enum Expression<E: ExtensionField> {
/// WitIn(Id)
WitIn(WitnessId),
/// Fixed
Fixed(Fixed),
/// Constant poly
Constant(E::BaseField),
/// This is the sum of two expression
Expand All @@ -25,7 +27,7 @@ pub enum Expression<E: ExtensionField> {
}

/// this is used as finite state machine state
/// for differentiate a expression is in monomial form or not
/// for differentiate an expression is in monomial form or not
enum MonomialState {
SumTerm,
ProductTerm,
Expand All @@ -34,6 +36,7 @@ enum MonomialState {
impl<E: ExtensionField> Expression<E> {
pub fn degree(&self) -> usize {
match self {
Expression::Fixed(_) => 1,
Expression::WitIn(_) => 1,
Expression::Constant(_) => 0,
Expression::Sum(a_expr, b_expr) => max(a_expr.degree(), b_expr.degree()),
Expand All @@ -46,6 +49,7 @@ impl<E: ExtensionField> Expression<E> {
#[allow(clippy::too_many_arguments)]
pub fn evaluate<T>(
&self,
fixed_in: &impl Fn(&Fixed) -> T,
wit_in: &impl Fn(WitnessId) -> T, // witin id
constant: &impl Fn(E::BaseField) -> T,
challenge: &impl Fn(ChallengeId, usize, E, E) -> T,
Expand All @@ -54,22 +58,23 @@ impl<E: ExtensionField> Expression<E> {
scaled: &impl Fn(T, T, T) -> T,
) -> T {
match self {
Expression::Fixed(f) => fixed_in(f),
Expression::WitIn(witness_id) => wit_in(*witness_id),
Expression::Constant(scalar) => constant(*scalar),
Expression::Sum(a, b) => {
let a = a.evaluate(wit_in, constant, challenge, sum, product, scaled);
let b = b.evaluate(wit_in, constant, challenge, sum, product, scaled);
let a = a.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled);
let b = b.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled);
sum(a, b)
}
Expression::Product(a, b) => {
let a = a.evaluate(wit_in, constant, challenge, sum, product, scaled);
let b = b.evaluate(wit_in, constant, challenge, sum, product, scaled);
let a = a.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled);
let b = b.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled);
product(a, b)
}
Expression::ScaledSum(x, a, b) => {
let x = x.evaluate(wit_in, constant, challenge, sum, product, scaled);
let a = a.evaluate(wit_in, constant, challenge, sum, product, scaled);
let b = b.evaluate(wit_in, constant, challenge, sum, product, scaled);
let x = x.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled);
let a = a.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled);
let b = b.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled);
scaled(x, a, b)
}
Expression::Challenge(challenge_id, pow, scalar, offset) => {
Expand All @@ -91,6 +96,7 @@ impl<E: ExtensionField> Expression<E> {

fn is_zero_expr(expr: &Expression<E>) -> bool {
match expr {
Expression::Fixed(_) => false,
Expression::WitIn(_) => false,
Expression::Constant(c) => *c == E::BaseField::ZERO,
Expression::Sum(a, b) => Self::is_zero_expr(a) && Self::is_zero_expr(b),
Expand All @@ -102,10 +108,13 @@ impl<E: ExtensionField> Expression<E> {

fn is_monomial_form_inner(s: MonomialState, expr: &Expression<E>) -> bool {
match (expr, s) {
(Expression::WitIn(_), MonomialState::SumTerm) => true,
(Expression::WitIn(_), MonomialState::ProductTerm) => true,
(Expression::Constant(_), MonomialState::SumTerm) => true,
(Expression::Constant(_), MonomialState::ProductTerm) => true,
(
Expression::Fixed(_)
| Expression::WitIn(_)
| Expression::Challenge(..)
| Expression::Constant(_),
_,
) => true,
(Expression::Sum(a, b), MonomialState::SumTerm) => {
Self::is_monomial_form_inner(MonomialState::SumTerm, a)
&& Self::is_monomial_form_inner(MonomialState::SumTerm, b)
Expand All @@ -121,8 +130,6 @@ impl<E: ExtensionField> Expression<E> {
}
(Expression::ScaledSum(_, _, _), MonomialState::SumTerm) => true,
(Expression::ScaledSum(_, _, b), MonomialState::ProductTerm) => Self::is_zero_expr(b),
(Expression::Challenge(_, _, _, _), MonomialState::SumTerm) => true,
(Expression::Challenge(_, _, _, _), MonomialState::ProductTerm) => true,
}
}
}
Expand All @@ -131,7 +138,7 @@ impl<E: ExtensionField> Neg for Expression<E> {
type Output = Expression<E>;
fn neg(self) -> Self::Output {
match self {
Expression::WitIn(_) => Expression::ScaledSum(
Expression::Fixed(_) | Expression::WitIn(_) => Expression::ScaledSum(
Box::new(self),
Box::new(Expression::Constant(E::BaseField::ONE.neg())),
Box::new(Expression::Constant(E::BaseField::ZERO)),
Expand Down Expand Up @@ -386,6 +393,9 @@ pub struct WitIn {
pub id: WitnessId,
}

#[derive(Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
pub struct Fixed(pub usize);

pub trait ToExpr<E: ExtensionField> {
type Output;
fn expr(&self) -> Self::Output;
Expand Down
5 changes: 2 additions & 3 deletions ceno_zkvm/src/instructions/riscv/test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use goldilocks::GoldilocksExt2;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem, ProvingKey},
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::Instruction,
};

Expand All @@ -26,6 +26,5 @@ fn test_multiple_opcode() {
Ok(config)
},
);
let vk = cs.key_gen();
let _pk = ProvingKey::create_pk(vk);
cs.key_gen(None);
}
1 change: 1 addition & 0 deletions ceno_zkvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pub mod error;
pub mod instructions;
pub mod scheme;
pub mod tables;
// #[cfg(test)]
pub use utils::u64vec;
mod chip_handler;
Expand Down
20 changes: 20 additions & 0 deletions ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,23 @@ pub struct ZKVMProof<E: ExtensionField> {

pub wits_in_evals: Vec<E>,
}

#[derive(Clone)]
pub struct ZKVMTableProof<E: ExtensionField> {
pub num_instances: usize,
// logup sum at layer 1
pub lk_p1_out_eval: E,
pub lk_p2_out_eval: E,
pub lk_q1_out_eval: E,
pub lk_q2_out_eval: E,

pub tower_proof: TowerProofs<E>,

// select layer sumcheck proof
pub sel_sumcheck_proofs: Vec<IOPProverMessage<E>>,
pub lk_d_in_evals: Vec<E>,
pub lk_n_in_evals: Vec<E>,

pub fixed_in_evals: Vec<E>,
pub wits_in_evals: Vec<E>,
}
Loading
Loading