From e73cbd0f6f597dc7c4609637fb54fb2c77424a01 Mon Sep 17 00:00:00 2001 From: Soham Zemse <22412996+zemse@users.noreply.github.com> Date: Fri, 6 Sep 2024 17:26:09 +0530 Subject: [PATCH] support lt checks upto u32 range --- ceno_zkvm/src/chip_handler/general.rs | 36 +++++--- ceno_zkvm/src/instructions/riscv/addsub.rs | 28 ++++--- ceno_zkvm/src/scheme/mock_prover.rs | 97 ++++++++++++++++++++++ 3 files changed, 137 insertions(+), 24 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index ebb856dfd..4c3ccc208 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -266,12 +266,13 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Ok(()) } - pub(crate) fn less_than( + // TODO support riv64 feature + pub(crate) fn less_than( &mut self, name_fn: N, lhs: Expression, rhs: Expression, - ) -> Result<(WitIn, WitIn), ZKVMError> + ) -> Result<(WitIn, WitIn, WitIn), ZKVMError> where NR: Into + Display + Clone, N: FnOnce() -> NR, @@ -281,25 +282,38 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { |cb| { let name = name_fn(); let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?; - let diff = cb.create_witin(|| format!("{name} diff witin"))?; - let range = Expression::Constant(2u64.pow(C as u32).into()); + + let mut witin_u16 = |var_name: &str| -> Result { + cb.namespace( + || format!("var {var_name}"), + |cb| { + let witin = cb.create_witin(|| var_name.to_string())?; + cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?; + Ok(witin) + }, + ) + }; + + let diff_lo = witin_u16("diff_lo")?; + let diff_hi = witin_u16("diff_lo")?; + + let range = Expression::Constant((u32::MAX as u64).into()); cb.require_equal( || name.clone(), lhs - rhs, - diff.expr() - is_lt.expr() * range, + diff_lo.expr() + diff_hi.expr() * 2usize.pow(16).into() - is_lt.expr() * range, )?; - cb.assert_ux::<_, _, C>(|| name, diff.expr())?; - Ok((is_lt, diff)) + Ok((is_lt, diff_lo, diff_hi)) }, ) } - pub(crate) fn assert_less_than( + pub(crate) fn assert_less_than( &mut self, name_fn: N, lhs: Expression, rhs: Expression, - ) -> Result<(WitIn, WitIn), ZKVMError> + ) -> Result<(WitIn, WitIn, WitIn), ZKVMError> where NR: Into + Clone + Display, N: FnOnce() -> NR, @@ -308,9 +322,9 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { || "assert_less_than", |cb| { let name = name_fn(); - let (is_lt, diff) = cb.less_than::<_, _, C>(|| name.clone(), lhs, rhs)?; + let (is_lt, diff_lo, diff_hi) = cb.less_than::<_, _>(|| name.clone(), lhs, rhs)?; cb.require_one(|| name, is_lt.expr())?; - Ok((is_lt, diff)) + Ok((is_lt, diff_lo, diff_hi)) }, ) } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 2992dc738..67825b230 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -34,9 +34,9 @@ pub struct InstructionConfig { pub prev_rs1_ts: WitIn, pub prev_rs2_ts: WitIn, pub prev_rd_ts: WitIn, - pub lt_wtns_rs1: (WitIn, WitIn), - pub lt_wtns_rs2: (WitIn, WitIn), - pub lt_wtns_rd: (WitIn, WitIn), + pub lt_wtns_rs1: (WitIn, WitIn, WitIn), + pub lt_wtns_rs2: (WitIn, WitIn, WitIn), + pub lt_wtns_rd: (WitIn, WitIn, WitIn), phantom: PhantomData, } @@ -120,21 +120,18 @@ fn add_sub_gadget( let next_ts = ts + 1.into(); circuit_builder.state_out(next_pc, next_ts)?; - let lt_wtns_rs1 = circuit_builder.assert_less_than::<_, _, 16>( + let lt_wtns_rs1 = circuit_builder.assert_less_than( || "prev_rs1_ts < ts", prev_rs1_ts.expr(), cur_ts.expr(), )?; - let lt_wtns_rs2 = circuit_builder.assert_less_than::<_, _, 16>( + let lt_wtns_rs2 = circuit_builder.assert_less_than( || "prev_rs2_ts < ts", prev_rs2_ts.expr(), cur_ts.expr(), )?; - let lt_wtns_rd = circuit_builder.assert_less_than::<_, _, 16>( - || "prev_rd_ts < ts", - prev_rd_ts.expr(), - cur_ts.expr(), - )?; + let lt_wtns_rd = + circuit_builder.assert_less_than(|| "prev_rd_ts < ts", prev_rd_ts.expr(), cur_ts.expr())?; Ok(InstructionConfig { pc, @@ -199,14 +196,19 @@ impl Instruction for AddInstruction { set_val!(instance, config.prev_rs2_ts, 2); set_val!(instance, config.prev_rd_ts, 2); + let u16_max = u16::MAX as u64; + set_val!(instance, config.lt_wtns_rs1.0, 1); - set_val!(instance, config.lt_wtns_rs1.1, 2u64.pow(16) - 2 + 1); // range - lhs + rhs + set_val!(instance, config.lt_wtns_rs1.1, u16_max - 2 + 1); // range - lhs + rhs + set_val!(instance, config.lt_wtns_rs1.2, u16_max); set_val!(instance, config.lt_wtns_rs2.0, 1); - set_val!(instance, config.lt_wtns_rs2.1, 2u64.pow(16) - 3 + 2); // range - lhs + rhs + set_val!(instance, config.lt_wtns_rs2.1, u16_max - 3 + 2); // range - lhs + rhs + set_val!(instance, config.lt_wtns_rs2.2, u16_max); set_val!(instance, config.lt_wtns_rd.0, 1); - set_val!(instance, config.lt_wtns_rd.1, 2u64.pow(16) - 3 + 2); // range - lhs + rhs + set_val!(instance, config.lt_wtns_rd.1, u16_max - 3 + 2); // range - lhs + rhs + set_val!(instance, config.lt_wtns_rd.2, u16_max); Ok(()) } } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 5a7cbb2f8..aa976439b 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -384,6 +384,8 @@ mod tests { circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, expression::{ToExpr, WitIn}, + instructions::Instruction, + set_val, }; use ff::Field; use goldilocks::{Goldilocks, GoldilocksExt2}; @@ -518,4 +520,99 @@ mod tests { }] ); } + + #[allow(dead_code)] + #[derive(Debug)] + struct LtCircuit { + pub a: WitIn, + pub b: WitIn, + pub lt: WitIn, + pub diff_lo: WitIn, + pub diff_hi: WitIn, + } + + impl LtCircuit { + fn construct_circuit( + cb: &mut CircuitBuilder, + ) -> Result { + let a = cb.create_witin(|| "a")?; + let b = cb.create_witin(|| "b")?; + let (lt, diff_lo, diff_hi) = cb.less_than(|| "lt", a.expr(), b.expr())?; + Ok(Self { + a, + b, + lt, + diff_lo, + diff_hi, + }) + } + } + + #[test] + fn test_lt_1() { + let mut cs = ConstraintSystem::new(|| "test_lt_1"); + let mut builder = CircuitBuilder::::new(&mut cs); + + let _ = LtCircuit::construct_circuit(&mut builder).unwrap(); + + let wits_in = vec![ + vec![Goldilocks::from(3u64), Goldilocks::from(5u64)] + .into_mle() + .into(), + vec![Goldilocks::from(5u64), Goldilocks::from(3u64)] + .into_mle() + .into(), + vec![Goldilocks::from(1u64), Goldilocks::from(0u64)] + .into_mle() + .into(), + vec![ + Goldilocks::from(u16::MAX as u64 + 3 - 5), + Goldilocks::from(5 - 3), + ] + .into_mle() + .into(), + vec![Goldilocks::from(u16::MAX as u64), Goldilocks::from(0)] + .into_mle() + .into(), + ]; + + MockProver::assert_satisfied(&mut builder, &wits_in, None); + } + + #[test] + fn test_lt_2() { + let mut cs = ConstraintSystem::new(|| "test_lt_1"); + let mut builder = CircuitBuilder::::new(&mut cs); + + let _ = LtCircuit::construct_circuit(&mut builder).unwrap(); + + let wits_in = vec![ + vec![ + Goldilocks::from(u32::MAX as u64 - 5), + Goldilocks::from(u32::MAX as u64 - 3), + ] + .into_mle() + .into(), + vec![ + Goldilocks::from(u32::MAX as u64 - 3), + Goldilocks::from(u32::MAX as u64 - 5), + ] + .into_mle() + .into(), + vec![Goldilocks::from(1u64), Goldilocks::from(0u64)] + .into_mle() + .into(), + vec![ + Goldilocks::from(u16::MAX as u64 + 3 - 5), + Goldilocks::from(5 - 3), + ] + .into_mle() + .into(), + vec![Goldilocks::from(u16::MAX as u64), Goldilocks::from(0)] + .into_mle() + .into(), + ]; + + MockProver::assert_satisfied(&mut builder, &wits_in, None); + } }