diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 0e88f9361..893a02eec 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -13,7 +13,7 @@ use crate::{ #[derive(Debug)] pub struct LtWtns { - pub is_lt: WitIn, + pub is_lt: Option, pub diff_lo: WitIn, pub diff_hi: WitIn, #[cfg(feature = "riv64")] @@ -277,12 +277,13 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Ok(()) } - // TODO support riv64 feature + /// less_than pub(crate) fn less_than( &mut self, name_fn: N, lhs: Expression, rhs: Expression, + assert_less_than: Option, ) -> Result where NR: Into + Display + Clone, @@ -292,7 +293,19 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { || "less_than", |cb| { let name = name_fn(); - let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?; + let (is_lt, is_lt_expr) = if let Some(lt) = assert_less_than { + ( + None, + if lt { + Expression::ONE + } else { + Expression::ZERO + }, + ) + } else { + let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?; + (Some(is_lt), is_lt.expr()) + }; let mut witin_u16 = |var_name: &str| -> Result { cb.namespace( @@ -322,7 +335,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { lhs - rhs, #[cfg(feature = "riv32")] { - diff_lo.expr() + diff_hi.expr() * (1 << 16).into() - is_lt.expr() * range + diff_lo.expr() + diff_hi.expr() * (1 << 16).into() - is_lt_expr * range }, #[cfg(feature = "riv64")] { @@ -347,27 +360,6 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { ) } - pub(crate) fn assert_less_than( - &mut self, - name_fn: N, - lhs: Expression, - rhs: Expression, - ) -> Result - where - NR: Into + Clone + Display, - N: FnOnce() -> NR, - { - self.namespace( - || "assert_less_than", - |cb| { - let name = name_fn(); - let lt_wtns = cb.less_than::<_, _>(|| name.clone(), lhs, rhs)?; - cb.require_one(|| name, lt_wtns.is_lt.expr())?; - Ok(lt_wtns) - }, - ) - } - pub(crate) fn is_equal( &mut self, lhs: Expression, diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 16be4128e..07bc7153c 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -38,6 +38,9 @@ enum MonomialState { } impl Expression { + pub const ZERO: Expression = Expression::Constant(E::BaseField::ZERO); + pub const ONE: Expression = Expression::Constant(E::BaseField::ONE); + pub fn degree(&self) -> usize { match self { Expression::Fixed(_) => 1, diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 99f4fc00f..d0e288052 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -126,18 +126,24 @@ fn add_sub_gadget( let next_ts = ts + 1.into(); circuit_builder.state_out(next_pc, next_ts)?; - let lt_wtns_rs1 = circuit_builder.assert_less_than( + let lt_wtns_rs1 = circuit_builder.less_than( || "prev_rs1_ts < ts", prev_rs1_ts.expr(), cur_ts.expr(), + Some(true), )?; - let lt_wtns_rs2 = circuit_builder.assert_less_than( + let lt_wtns_rs2 = circuit_builder.less_than( || "prev_rs2_ts < ts", prev_rs2_ts.expr(), cur_ts.expr(), + Some(true), + )?; + let lt_wtns_rd = circuit_builder.less_than( + || "prev_rd_ts < ts", + prev_rd_ts.expr(), + cur_ts.expr(), + Some(true), )?; - let lt_wtns_rd = - circuit_builder.assert_less_than(|| "prev_rd_ts < ts", prev_rd_ts.expr(), cur_ts.expr())?; Ok(InstructionConfig { pc, diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 7d22260a9..d2a3676ad 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -537,7 +537,7 @@ mod tests { ) -> Result { let a = cb.create_witin(|| "a")?; let b = cb.create_witin(|| "b")?; - let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr())?; + let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), Some(true))?; Ok(Self { a, b, lt_wtns }) } }