diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 9385c0294..63b9c3f84 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -1,3 +1,5 @@ +use std::fmt::Display; + use ff_ext::ExtensionField; use ff::Field; @@ -264,6 +266,49 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Ok(()) } + pub(crate) fn less_than( + &mut self, + name_fn: N, + lhs: Expression, + rhs: Expression, + ) -> Result + where + NR: Into + Display, + N: FnOnce() -> NR, + { + self.namespace( + || "less_than", + |cb| { + let name = name_fn(); + let is_lt = cb.create_witin(|| format!("{name} witin"))?; + // TODO add name_fn to lookup_ltu_limb8, not done yet to avoid merge conflicts + cb.lookup_ltu_limb8(is_lt.expr(), lhs, rhs)?; + Ok(is_lt) + }, + ) + } + + pub(crate) fn assert_less_than( + &mut self, + name_fn: N, + lhs: Expression, + rhs: Expression, + ) -> Result + where + NR: Into + Clone + Display, + N: FnOnce() -> NR, + { + self.namespace( + || "assert_less_than", + |cb| { + let name = name_fn(); + let is_lt = cb.less_than(|| name.clone(), lhs, rhs)?; + cb.require_one(|| name, is_lt.expr())?; + Ok(is_lt) + }, + ) + } + pub(crate) fn is_equal( &mut self, lhs: Expression, diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 20609ff0d..d039e987d 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -34,6 +34,9 @@ pub struct InstructionConfig { pub prev_rs1_ts: WitIn, pub prev_rs2_ts: WitIn, pub prev_rd_ts: WitIn, + pub lt_wtns_rs1: WitIn, + pub lt_wtns_rs2: WitIn, + pub lt_wtns_rd: WitIn, phantom: PhantomData, } @@ -117,6 +120,19 @@ fn add_sub_gadget( let next_ts = ts + 1.into(); circuit_builder.state_out(next_pc, next_ts)?; + let lt_wtns_rs1 = circuit_builder.assert_less_than( + || "prev_rs1_ts < ts", + prev_rs1_ts.expr(), + cur_ts.expr(), + )?; + let lt_wtns_rs2 = circuit_builder.assert_less_than( + || "prev_rs2_ts < ts", + prev_rs2_ts.expr(), + cur_ts.expr(), + )?; + let lt_wtns_rd = + circuit_builder.assert_less_than(|| "prev_rd_ts < ts", prev_rd_ts.expr(), cur_ts.expr())?; + Ok(InstructionConfig { pc, ts: cur_ts, @@ -130,6 +146,9 @@ fn add_sub_gadget( prev_rs1_ts, prev_rs2_ts, prev_rd_ts, + lt_wtns_rs1, + lt_wtns_rs2, + lt_wtns_rd, phantom: PhantomData, }) } @@ -151,7 +170,7 @@ impl Instruction for AddInstruction { ) -> Result<(), ZKVMError> { // TODO use field from step set_val!(instance, config.pc, 1); - set_val!(instance, config.ts, 2); + set_val!(instance, config.ts, 3); config.prev_rd_value.wits_in().map(|prev_rd_value| { set_val!(instance, prev_rd_value[0], 4); set_val!(instance, prev_rd_value[1], 4); @@ -176,6 +195,9 @@ impl Instruction for AddInstruction { set_val!(instance, config.prev_rs1_ts, 2); set_val!(instance, config.prev_rs2_ts, 2); set_val!(instance, config.prev_rd_ts, 2); + set_val!(instance, config.lt_wtns_rs1, 1); + set_val!(instance, config.lt_wtns_rs2, 1); + set_val!(instance, config.lt_wtns_rd, 1); Ok(()) } } @@ -272,7 +294,7 @@ mod test { .into_iter() .map(|v| v.into()) .collect_vec(), - None, + Some([100.into(), 100000.into()]), ); } } diff --git a/ceno_zkvm/src/instructions/riscv/blt.rs b/ceno_zkvm/src/instructions/riscv/blt.rs index 72743b46e..2e5a71d2d 100644 --- a/ceno_zkvm/src/instructions/riscv/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/blt.rs @@ -210,6 +210,14 @@ impl Instruction for BltInstruction { ) -> Result, ZKVMError> { blt_gadget::(circuit_builder) } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [std::mem::MaybeUninit], + step: ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + todo!() + } } #[cfg(test)] diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index f703bf785..9c58043d4 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -264,6 +264,8 @@ impl<'a, E: ExtensionField> MockProver { // TODO load more tables here let mut table_vec = vec![]; load_u5_table(&mut table_vec, cb, challenge); + load_u16_table(&mut table_vec, cb, challenge); + load_lt_table(&mut table_vec, cb, challenge); // Lookup expressions for (expr, name) in cb @@ -333,6 +335,43 @@ pub fn load_u5_table( } } +pub fn load_u16_table( + t_vec: &mut Vec, + cb: &CircuitBuilder, + challenge: [E; 2], +) { + t_vec.reserve(65536); + for i in 0..65536 { + let rlc_record = cb.rlc_chip_record(vec![ + Expression::Constant(E::BaseField::from(ROMType::U16 as u64)), + i.into(), + ]); + let rlc_record = eval_by_expr(&[], &challenge, &rlc_record); + t_vec.push(rlc_record); + } +} + +pub fn load_lt_table>( + t_vec: &mut Vec, + cb: &CircuitBuilder, + challenge: [E; 2], +) { + t_vec.reserve(65536); + for lhs in 0..256 { + for rhs in 0..256 { + let is_lt = if lhs < rhs { 1 } else { 0 }; + let lhs_rhs = lhs * 256 + rhs; + let rlc_record = cb.rlc_chip_record(vec![ + Expression::Constant(E::BaseField::from(ROMType::Ltu as u64)), + lhs_rhs.into(), + is_lt.into(), + ]); + let rlc_record = eval_by_expr(&[], &challenge, &rlc_record); + t_vec.push(rlc_record); + } + } +} + #[allow(unused_imports)] #[cfg(test)] mod tests {