diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index f6a94bba7..e10ed4fed 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -3,6 +3,7 @@ use ff_ext::ExtensionField; use crate::{ error::ZKVMError, expression::{Expression, ToExpr, WitIn}, + instructions::riscv::config::ExprLtConfig, }; pub mod general; @@ -23,8 +24,9 @@ pub trait RegisterChipOperations, N: FnOnce( prev_ts: Expression, ts: Expression, values: &V, - ) -> Result, ZKVMError>; + ) -> Result<(Expression, ExprLtConfig), ZKVMError>; + #[allow(clippy::too_many_arguments)] fn register_write>>>( &mut self, name_fn: N, @@ -33,5 +35,5 @@ pub trait RegisterChipOperations, N: FnOnce( ts: Expression, prev_values: &V, values: &V, - ) -> Result, ZKVMError>; + ) -> Result<(Expression, ExprLtConfig), ZKVMError>; } diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 0b4a672a8..53e44b978 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -1,3 +1,5 @@ +use std::fmt::Display; + use ff_ext::ExtensionField; use ff::Field; @@ -6,6 +8,7 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, expression::{Expression, Fixed, ToExpr, WitIn}, + instructions::riscv::config::ExprLtConfig, structs::ROMType, }; @@ -264,6 +267,72 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Ok(()) } + /// less_than + pub(crate) fn less_than( + &mut self, + name_fn: N, + lhs: Expression, + rhs: Expression, + assert_less_than: Option, + ) -> Result + where + NR: Into + Display + Clone, + N: FnOnce() -> NR, + { + #[cfg(feature = "riv64")] + panic!("less_than is not supported for riv64 yet"); + + #[cfg(feature = "riv32")] + self.namespace( + || "less_than", + |cb| { + let name = name_fn(); + let (is_lt, is_lt_expr) = if let Some(lt) = assert_less_than { + ( + None, + if lt { + Expression::ONE + } else { + Expression::ZERO + }, + ) + } else { + let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?; + (Some(is_lt), is_lt.expr()) + }; + + let mut witin_u16 = |var_name: String| -> Result { + cb.namespace( + || format!("var {var_name}"), + |cb| { + let witin = cb.create_witin(|| var_name.to_string())?; + cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?; + Ok(witin) + }, + ) + }; + + let diff = (0..2) + .map(|i| witin_u16(format!("diff_{i}"))) + .collect::, _>>()?; + + let diff_expr = diff + .iter() + .enumerate() + .map(|(i, diff)| (i, diff.expr())) + .fold(Expression::ZERO, |sum, (i, a)| { + sum + if i > 0 { a * (1 << (16 * i)).into() } else { a } + }); + + let range = Expression::Constant((1 << 32).into()); + + cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?; + + Ok(ExprLtConfig { is_lt, diff }) + }, + ) + } + pub(crate) fn is_equal( &mut self, lhs: Expression, diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index 712bfda1c..bdc978710 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -4,6 +4,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, + instructions::riscv::config::ExprLtConfig, structs::RAMType, }; @@ -19,7 +20,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe prev_ts: Expression, ts: Expression, values: &V, - ) -> Result, ZKVMError> { + ) -> Result<(Expression, ExprLtConfig), ZKVMError> { self.namespace(name_fn, |cb| { // READ (a, v, t) let read_record = cb.rlc_chip_record( @@ -29,7 +30,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe ))], vec![register_id.expr()], values.expr(), - vec![prev_ts], + vec![prev_ts.clone()], ] .concat(), ); @@ -49,12 +50,11 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe cb.write_record(|| "write_record", write_record)?; // assert prev_ts < current_ts - // TODO implement lt gadget - // let is_lt = prev_ts.lt(self, ts)?; - // self.require_one(is_lt)?; + let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?; + let next_ts = ts + 1.into(); - Ok(next_ts) + Ok((next_ts, lt_cfg)) }) } @@ -66,7 +66,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe ts: Expression, prev_values: &V, values: &V, - ) -> Result, ZKVMError> { + ) -> Result<(Expression, ExprLtConfig), ZKVMError> { self.namespace(name_fn, |cb| { // READ (a, v, t) let read_record = cb.rlc_chip_record( @@ -76,7 +76,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe ))], vec![register_id.expr()], prev_values.expr(), - vec![prev_ts], + vec![prev_ts.clone()], ] .concat(), ); @@ -95,13 +95,11 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe cb.read_record(|| "read_record", read_record)?; cb.write_record(|| "write_record", write_record)?; - // assert prev_ts < current_ts - // TODO implement lt gadget - // let is_lt = prev_ts.lt(self, ts)?; - // self.require_one(is_lt)?; + let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?; + let next_ts = ts + 1.into(); - Ok(next_ts) + Ok((next_ts, lt_cfg)) }) } } diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 16be4128e..07bc7153c 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -38,6 +38,9 @@ enum MonomialState { } impl Expression { + pub const ZERO: Expression = Expression::Constant(E::BaseField::ZERO); + pub const ONE: Expression = Expression::Constant(E::BaseField::ONE); + pub fn degree(&self) -> usize { match self { Expression::Fixed(_) => 1, diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index f563ba68b..670a247d9 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -5,6 +5,7 @@ use ff_ext::ExtensionField; use itertools::Itertools; use super::{ + config::ExprLtConfig, constants::{OPType, OpcodeType, RegUInt, PC_STEP_SIZE}, RIVInstruction, }; @@ -13,7 +14,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{ToExpr, WitIn}, - instructions::Instruction, + instructions::{riscv::config::ExprLtInput, Instruction}, set_val, uint::UIntValue, witness::LkMultiplicity, @@ -37,6 +38,9 @@ pub struct InstructionConfig { pub prev_rs1_ts: WitIn, pub prev_rs2_ts: WitIn, pub prev_rd_ts: WitIn, + pub lt_rs1_cfg: ExprLtConfig, + pub lt_rs2_cfg: ExprLtConfig, + pub lt_prev_ts_cfg: ExprLtConfig, phantom: PhantomData, } @@ -99,17 +103,17 @@ fn add_sub_gadget( let prev_rs2_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?; let prev_rd_ts = circuit_builder.create_witin(|| "prev_rd_ts")?; - let ts = circuit_builder.register_read( + let (ts, lt_rs1_cfg) = circuit_builder.register_read( || "read_rs1", &rs1_id, prev_rs1_ts.expr(), cur_ts.expr(), &addend_0, )?; - let ts = + let (ts, lt_rs2_cfg) = circuit_builder.register_read(|| "read_rs2", &rs2_id, prev_rs2_ts.expr(), ts, &addend_1)?; - let ts = circuit_builder.register_write( + let (ts, lt_prev_ts_cfg) = circuit_builder.register_write( || "write_rd", &rd_id, prev_rd_ts.expr(), @@ -134,6 +138,9 @@ fn add_sub_gadget( prev_rs1_ts, prev_rs2_ts, prev_rd_ts, + lt_rs1_cfg, + lt_rs2_cfg, + lt_prev_ts_cfg, phantom: PhantomData, }) } @@ -159,7 +166,7 @@ impl Instruction for AddInstruction { ) -> Result<(), ZKVMError> { // TODO use fields from step set_val!(instance, config.pc, 1); - set_val!(instance, config.ts, 2); + set_val!(instance, config.ts, 3); let addend_0 = UIntValue::new_unchecked(step.rs1().unwrap().value); let addend_1 = UIntValue::new_unchecked(step.rs2().unwrap().value); let rd_prev = UIntValue::new_unchecked(step.rd().unwrap().value.before); @@ -187,6 +194,23 @@ impl Instruction for AddInstruction { set_val!(instance, config.prev_rs1_ts, 2); set_val!(instance, config.prev_rs2_ts, 2); set_val!(instance, config.prev_rd_ts, 2); + + ExprLtInput { + lhs: 2, // rs1 + rhs: 3, // cur_ts + } + .assign(instance, &config.lt_rs1_cfg); + ExprLtInput { + lhs: 2, // rs2 + rhs: 4, // cur_ts + } + .assign(instance, &config.lt_rs2_cfg); + ExprLtInput { + lhs: 2, // rd + rhs: 5, // cur_ts + } + .assign(instance, &config.lt_prev_ts_cfg); + Ok(()) } } @@ -362,7 +386,7 @@ mod test { .into_iter() .map(|v| v.into()) .collect_vec(), - None, + Some([100.into(), 100000.into()]), ); } } diff --git a/ceno_zkvm/src/instructions/riscv/blt.rs b/ceno_zkvm/src/instructions/riscv/blt.rs index c6bc8ab53..112c756d0 100644 --- a/ceno_zkvm/src/instructions/riscv/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/blt.rs @@ -10,7 +10,7 @@ use crate::{ error::ZKVMError, expression::{ToExpr, WitIn}, instructions::{ - riscv::config::{LtConfig, LtInput}, + riscv::config::{UIntLtConfig, UIntLtInput}, Instruction, }, set_val, @@ -19,6 +19,7 @@ use crate::{ }; use super::{ + config::ExprLtConfig, constants::{OPType, OpcodeType, RegUInt, RegUInt8, PC_STEP_SIZE}, RIVInstruction, }; @@ -38,7 +39,9 @@ pub struct InstructionConfig { pub rs2_id: WitIn, pub prev_rs1_ts: WitIn, pub prev_rs2_ts: WitIn, - pub is_lt: LtConfig, + pub is_lt: UIntLtConfig, + pub lt_rs1_cfg: ExprLtConfig, + pub lt_rs2_cfg: ExprLtConfig, } pub struct BltInput { @@ -62,7 +65,7 @@ impl BltInput { ) { assert!(!self.lhs_limb8.is_empty() && (self.lhs_limb8.len() == self.rhs_limb8.len())); // TODO: add boundary check for witin - let lt_input = LtInput { + let lt_input = UIntLtInput { lhs_limbs: &self.lhs_limb8, rhs_limbs: &self.rhs_limb8, }; @@ -175,14 +178,14 @@ fn blt_gadget( let lhs = RegUInt::from_u8_limbs(circuit_builder, &lhs_limb8); let rhs = RegUInt::from_u8_limbs(circuit_builder, &rhs_limb8); - let ts = circuit_builder.register_read( + let (ts, lt_rs1_cfg) = circuit_builder.register_read( || "read ts for lhs", &rs1_id, prev_rs1_ts.expr(), cur_ts.expr(), &lhs, )?; - let ts = circuit_builder.register_read( + let (ts, lt_rs2_cfg) = circuit_builder.register_read( || "read ts for rhs", &rs2_id, prev_rs2_ts.expr(), @@ -208,6 +211,8 @@ fn blt_gadget( prev_rs1_ts, prev_rs2_ts, is_lt, + lt_rs1_cfg, + lt_rs2_cfg, }) } @@ -270,7 +275,7 @@ mod test { .into_iter() .map(|v| v.into()) .collect_vec(), - None, + Some([1.into(), 1000.into()]), ) .expect_err("lookup will fail"); Ok(()) diff --git a/ceno_zkvm/src/instructions/riscv/config.rs b/ceno_zkvm/src/instructions/riscv/config.rs index e5816b2d3..77972e7a9 100644 --- a/ceno_zkvm/src/instructions/riscv/config.rs +++ b/ceno_zkvm/src/instructions/riscv/config.rs @@ -42,7 +42,7 @@ impl MsbInput<'_> { } #[derive(Clone)] -pub struct LtuConfig { +pub struct UIntLtuConfig { pub indexes: Vec, pub acc_indexes: Vec, pub byte_diff_inv: WitIn, @@ -51,16 +51,16 @@ pub struct LtuConfig { pub is_ltu: WitIn, } -pub struct LtuInput<'a> { +pub struct UIntLtuInput<'a> { pub lhs_limbs: &'a [u8], pub rhs_limbs: &'a [u8], } -impl LtuInput<'_> { +impl UIntLtuInput<'_> { pub fn assign( &self, instance: &mut [MaybeUninit], - config: &LtuConfig, + config: &UIntLtuConfig, ) -> bool { let mut idx = 0; let mut flag: bool = false; @@ -105,25 +105,25 @@ impl LtuInput<'_> { } #[derive(Clone)] -pub struct LtConfig { +pub struct UIntLtConfig { pub lhs_msb: MsbConfig, pub rhs_msb: MsbConfig, pub msb_is_equal: WitIn, pub msb_diff_inv: WitIn, - pub is_ltu: LtuConfig, + pub is_ltu: UIntLtuConfig, pub is_lt: WitIn, } -pub struct LtInput<'a> { +pub struct UIntLtInput<'a> { pub lhs_limbs: &'a [u8], pub rhs_limbs: &'a [u8], } -impl LtInput<'_> { +impl UIntLtInput<'_> { pub fn assign( &self, instance: &mut [MaybeUninit], - config: &LtConfig, + config: &UIntLtConfig, ) -> bool { let n_limbs = self.lhs_limbs.len(); let lhs_msb_input = MsbInput { @@ -141,7 +141,7 @@ impl LtInput<'_> { let mut rhs_limbs_no_msb = self.rhs_limbs.iter().copied().collect_vec(); rhs_limbs_no_msb[n_limbs - 1] = rhs_high_limb_no_msb; - let ltu_input = LtuInput { + let ltu_input = UIntLtuInput { lhs_limbs: &lhs_limbs_no_msb, rhs_limbs: &rhs_limbs_no_msb, }; @@ -168,3 +168,30 @@ impl LtInput<'_> { is_lt > 0 } } + +#[derive(Debug)] +pub struct ExprLtConfig { + pub is_lt: Option, + pub diff: Vec, +} + +pub struct ExprLtInput { + pub lhs: u64, + pub rhs: u64, +} + +impl ExprLtInput { + pub fn assign(&self, instance: &mut [MaybeUninit], config: &ExprLtConfig) { + if let Some(is_lt) = config.is_lt { + set_val!(instance, is_lt, { if self.lhs < self.rhs { 1 } else { 0 } }); + } + + let diff = self.lhs as i64 - self.rhs as i64; + config.diff.iter().enumerate().for_each(|(i, wit)| { + // extract the 16 bit limb from diff and assign to instance + set_val!(instance, wit, { + i64_to_base::((diff >> (i * 16)) & 0xffff) + }); + }); + } +} diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index f116afb01..4f8ca3f01 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -8,7 +8,7 @@ use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; -use std::{marker::PhantomData, ops::Neg}; +use std::{collections::HashSet, hash::Hash, marker::PhantomData, ops::Neg}; #[allow(clippy::enum_variant_names)] #[derive(Debug, PartialEq, Clone)] @@ -109,7 +109,7 @@ impl MockProverError { format!("WitIn({})", wit_in) } Expression::Challenge(id, _, _, _) => format!("Challenge({})", id), - Expression::Constant(constant) => fmt_base_field::(constant).to_string(), + Expression::Constant(constant) => fmt_base_field::(constant, true).to_string(), Expression::Fixed(fixed) => format!("{:?}", fixed), Expression::Sum(left, right) => { let s = format!( @@ -146,38 +146,50 @@ impl MockProverError { field .as_bases() .iter() - .map(fmt_base_field::) + .map(|b| fmt_base_field::(b, false)) .collect::>() .join(",") ) } - fn fmt_base_field(base_field: &E::BaseField) -> String { + fn fmt_base_field(base_field: &E::BaseField, add_prn: bool) -> String { let value = base_field.to_canonical_u64(); if value > E::BaseField::MODULUS_U64 - u16::MAX as u64 { // beautiful format for negative number > -65536 - format!("(-{})", E::BaseField::MODULUS_U64 - value) + fmt_prn(format!("-{}", E::BaseField::MODULUS_U64 - value), add_prn) } else if value < u16::MAX as u64 { format!("{value}") } else { // hex - format!("{value:#x}") + if value > E::BaseField::MODULUS_U64 - (u32::MAX as u64 + u16::MAX as u64) { + fmt_prn( + format!("-{:#x}", E::BaseField::MODULUS_U64 - value), + add_prn, + ) + } else { + format!("{value:#x}") + } } } + fn fmt_prn(s: String, add_prn: bool) -> String { + if add_prn { format!("({})", s) } else { s } + } + fn fmt_wtns( wtns: &[WitnessId], wits_in: &[ArcMultilinearExtension], inst_id: usize, ) -> String { wtns.iter() + .sorted() .map(|wt_id| { let wit = &wits_in[*wt_id as usize]; let value_fmt = if let Some(e) = wit.get_ext_field_vec_optn() { fmt_field(&e[inst_id]) } else if let Some(bf) = wit.get_base_field_vec_optn() { - fmt_base_field::(&bf[inst_id]) + fmt_base_field::(&bf[inst_id], true) } else { "Unknown".to_string() }; @@ -192,7 +204,10 @@ pub(crate) struct MockProver { _phantom: PhantomData, } -impl<'a, E: ExtensionField> MockProver { +impl<'a, E: ExtensionField> MockProver +where + E: Hash, +{ #[allow(dead_code)] pub fn run( cb: &mut CircuitBuilder, @@ -266,6 +281,8 @@ impl<'a, E: ExtensionField> MockProver { let mut table_vec = vec![]; load_u5_table(&mut table_vec, cb, challenge); load_u16_table(&mut table_vec, cb, challenge); + load_lt_table(&mut table_vec, cb, challenge); + let table: HashSet = table_vec.into_iter().collect(); // Lookup expressions for (expr, name) in cb @@ -279,7 +296,7 @@ impl<'a, E: ExtensionField> MockProver { // Check each lookup expr exists in t vec for (inst_id, element) in expr_evaluated.iter().enumerate() { - if !table_vec.contains(element) { + if !table.contains(element) { errors.push(MockProverError::LookupError { expression: expr.clone(), evaluated: *element, @@ -340,6 +357,7 @@ pub fn load_u16_table( cb: &CircuitBuilder, challenge: [E; 2], ) { + t_vec.reserve(1 << 16); for i in 0..(1 << 16) { let rlc_record = cb.rlc_chip_record(vec![ Expression::Constant(E::BaseField::from(ROMType::U16 as u64)), @@ -350,18 +368,48 @@ pub fn load_u16_table( } } +pub fn load_lt_table( + t_vec: &mut Vec, + cb: &CircuitBuilder, + challenge: [E; 2], +) { + t_vec.reserve(1 << 16); + for lhs in 0..(1 << 8) { + for rhs in 0..(1 << 8) { + let is_lt = if lhs < rhs { 1 } else { 0 }; + let lhs_rhs = lhs * 256 + rhs; + let rlc_record = cb.rlc_chip_record(vec![ + Expression::Constant(E::BaseField::from(ROMType::Ltu as u64)), + lhs_rhs.into(), + is_lt.into(), + ]); + let rlc_record = eval_by_expr(&[], &challenge, &rlc_record); + t_vec.push(rlc_record); + } + } +} + #[allow(unused_imports)] #[cfg(test)] mod tests { + use std::mem::MaybeUninit; + use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, expression::{ToExpr, WitIn}, + instructions::{ + riscv::config::{ExprLtConfig, ExprLtInput}, + Instruction, + }, + set_val, + witness::RowMajorMatrix, }; use ff::Field; use goldilocks::{Goldilocks, GoldilocksExt2}; - use multilinear_extensions::mle::IntoMLE; + use multilinear_extensions::mle::{IntoMLE, IntoMLEs}; + use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; #[derive(Debug)] #[allow(dead_code)] @@ -492,4 +540,239 @@ mod tests { }] ); } + + #[allow(dead_code)] + #[derive(Debug)] + struct AssertLtCircuit { + pub a: WitIn, + pub b: WitIn, + pub lt_wtns: ExprLtConfig, + } + + struct AssertLtCircuitInput { + pub a: u64, + pub b: u64, + } + + impl AssertLtCircuit { + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let a = cb.create_witin(|| "a")?; + let b = cb.create_witin(|| "b")?; + let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), Some(true))?; + Ok(Self { a, b, lt_wtns }) + } + + fn assign_instance( + &self, + instance: &mut [MaybeUninit], + input: AssertLtCircuitInput, + ) -> Result<(), ZKVMError> { + set_val!(instance, self.a, input.a); + set_val!(instance, self.b, input.b); + ExprLtInput { + lhs: input.a, + rhs: input.b, + } + .assign(instance, &self.lt_wtns); + + Ok(()) + } + + fn assign_instances( + &self, + num_witin: usize, + instances: Vec, + ) -> Result, ZKVMError> { + let mut raw_witin = RowMajorMatrix::::new(instances.len(), num_witin); + let raw_witin_iter = raw_witin.par_iter_mut(); + + raw_witin_iter + .zip_eq(instances.into_par_iter()) + .map(|(instance, input)| self.assign_instance::(instance, input)) + .collect::>()?; + + Ok(raw_witin) + } + } + + #[test] + fn test_assert_lt_1() { + let mut cs = ConstraintSystem::new(|| "test_assert_lt_1"); + let mut builder = CircuitBuilder::::new(&mut cs); + + let circuit = AssertLtCircuit::construct_circuit(&mut builder).unwrap(); + + let raw_witin = circuit + .assign_instances::( + builder.cs.num_witin as usize, + vec![ + AssertLtCircuitInput { a: 3, b: 5 }, + AssertLtCircuitInput { a: 7, b: 11 }, + ], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut builder, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + Some([1.into(), 1000.into()]), + ); + } + + #[test] + fn test_assert_lt_u32() { + let mut cs = ConstraintSystem::new(|| "test_assert_lt_u32"); + let mut builder = CircuitBuilder::::new(&mut cs); + + let circuit = AssertLtCircuit::construct_circuit(&mut builder).unwrap(); + let raw_witin = circuit + .assign_instances::( + builder.cs.num_witin as usize, + vec![ + AssertLtCircuitInput { + a: u32::MAX as u64 - 5, + b: u32::MAX as u64 - 3, + }, + AssertLtCircuitInput { + a: u32::MAX as u64 - 3, + b: u32::MAX as u64 - 2, + }, + ], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut builder, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + Some([1.into(), 1000.into()]), + ); + } + + #[allow(dead_code)] + #[derive(Debug)] + struct LtCircuit { + pub a: WitIn, + pub b: WitIn, + pub lt_wtns: ExprLtConfig, + } + + struct LtCircuitInput { + pub a: u64, + pub b: u64, + } + + impl LtCircuit { + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let a = cb.create_witin(|| "a")?; + let b = cb.create_witin(|| "b")?; + let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), None)?; + Ok(Self { a, b, lt_wtns }) + } + + fn assign_instance( + &self, + instance: &mut [MaybeUninit], + input: LtCircuitInput, + ) -> Result<(), ZKVMError> { + set_val!(instance, self.a, input.a); + set_val!(instance, self.b, input.b); + ExprLtInput { + lhs: input.a, + rhs: input.b, + } + .assign(instance, &self.lt_wtns); + + Ok(()) + } + + fn assign_instances( + &self, + num_witin: usize, + instances: Vec, + ) -> Result, ZKVMError> { + let mut raw_witin = RowMajorMatrix::::new(instances.len(), num_witin); + let raw_witin_iter = raw_witin.par_iter_mut(); + + raw_witin_iter + .zip_eq(instances.into_par_iter()) + .map(|(instance, input)| self.assign_instance::(instance, input)) + .collect::>()?; + + Ok(raw_witin) + } + } + + #[test] + fn test_lt_1() { + let mut cs = ConstraintSystem::new(|| "test_lt_1"); + let mut builder = CircuitBuilder::::new(&mut cs); + + let circuit = LtCircuit::construct_circuit(&mut builder).unwrap(); + + let raw_witin = circuit + .assign_instances::( + builder.cs.num_witin as usize, + vec![ + LtCircuitInput { a: 3, b: 5 }, + LtCircuitInput { a: 7, b: 11 }, + ], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut builder, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + Some([1.into(), 1000.into()]), + ); + } + + #[test] + fn test_lt_u32() { + let mut cs = ConstraintSystem::new(|| "test_lt_u32"); + let mut builder = CircuitBuilder::::new(&mut cs); + + let circuit = LtCircuit::construct_circuit(&mut builder).unwrap(); + + let raw_witin = circuit + .assign_instances::( + builder.cs.num_witin as usize, + vec![ + LtCircuitInput { + a: u32::MAX as u64 - 5, + b: u32::MAX as u64 - 3, + }, + LtCircuitInput { + a: u32::MAX as u64 - 3, + b: u32::MAX as u64 - 5, + }, + ], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut builder, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + Some([1.into(), 1000.into()]), + ); + } } diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index f362ae12d..3a443a3d3 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -8,7 +8,7 @@ use crate::{ create_witin_from_expr, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, - instructions::riscv::config::{IsEqualConfig, LtConfig, LtuConfig, MsbConfig}, + instructions::riscv::config::{IsEqualConfig, MsbConfig, UIntLtConfig, UIntLtuConfig}, }; impl UInt { @@ -265,7 +265,7 @@ impl UInt { &self, circuit_builder: &mut CircuitBuilder, rhs: &UInt, - ) -> Result { + ) -> Result { let n_bytes = Self::NUM_CELLS; let indexes: Vec = (0..n_bytes) .map(|_| circuit_builder.create_witin(|| "index")) @@ -343,7 +343,7 @@ impl UInt { // circuit_builder.assert_bit(is_ltu.expr())?; // lookup ensure it is bit // now we know the first non-equal byte pairs is (lhs_ne_byte, rhs_ne_byte) circuit_builder.lookup_ltu_limb8(is_ltu.expr(), lhs_ne_byte.expr(), rhs_ne_byte.expr())?; - Ok(LtuConfig { + Ok(UIntLtuConfig { byte_diff_inv, indexes, acc_indexes: si, @@ -357,7 +357,7 @@ impl UInt { &self, circuit_builder: &mut CircuitBuilder, rhs: &UInt, - ) -> Result { + ) -> Result { let is_lt = circuit_builder.create_witin(|| "is_lt")?; circuit_builder.assert_bit(|| "assert_bit", is_lt.expr())?; @@ -383,7 +383,7 @@ impl UInt { + msb_is_equal.expr() * is_ltu.is_ltu.expr() - is_lt.expr(), )?; - Ok(LtConfig { + Ok(UIntLtConfig { lhs_msb, rhs_msb, msb_is_equal,