Skip to content

Commit

Permalink
remove unnecessary wtns in assert less than case
Browse files Browse the repository at this point in the history
  • Loading branch information
zemse committed Sep 9, 2024
1 parent 41fc120 commit f9f1e85
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 30 deletions.
42 changes: 17 additions & 25 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{

#[derive(Debug)]
pub struct LtWtns {
pub is_lt: WitIn,
pub is_lt: Option<WitIn>,
pub diff_lo: WitIn,
pub diff_hi: WitIn,
#[cfg(feature = "riv64")]
Expand Down Expand Up @@ -277,12 +277,13 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
Ok(())
}

// TODO support riv64 feature
/// less_than
pub(crate) fn less_than<N, NR>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
assert_less_than: Option<bool>,
) -> Result<LtWtns, ZKVMError>
where
NR: Into<String> + Display + Clone,
Expand All @@ -292,7 +293,19 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
|| "less_than",
|cb| {
let name = name_fn();
let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?;
let (is_lt, is_lt_expr) = if let Some(lt) = assert_less_than {
(
None,
if lt {
Expression::ONE
} else {
Expression::ZERO
},
)
} else {
let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?;
(Some(is_lt), is_lt.expr())
};

let mut witin_u16 = |var_name: &str| -> Result<WitIn, ZKVMError> {
cb.namespace(
Expand Down Expand Up @@ -322,7 +335,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
lhs - rhs,
#[cfg(feature = "riv32")]
{
diff_lo.expr() + diff_hi.expr() * (1 << 16).into() - is_lt.expr() * range
diff_lo.expr() + diff_hi.expr() * (1 << 16).into() - is_lt_expr * range
},
#[cfg(feature = "riv64")]
{
Expand All @@ -347,27 +360,6 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
)
}

pub(crate) fn assert_less_than<N, NR>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
) -> Result<LtWtns, ZKVMError>
where
NR: Into<String> + Clone + Display,
N: FnOnce() -> NR,
{
self.namespace(
|| "assert_less_than",
|cb| {
let name = name_fn();
let lt_wtns = cb.less_than::<_, _>(|| name.clone(), lhs, rhs)?;
cb.require_one(|| name, lt_wtns.is_lt.expr())?;
Ok(lt_wtns)
},
)
}

pub(crate) fn is_equal(
&mut self,
lhs: Expression<E>,
Expand Down
3 changes: 3 additions & 0 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ enum MonomialState {
}

impl<E: ExtensionField> Expression<E> {
pub const ZERO: Expression<E> = Expression::Constant(E::BaseField::ZERO);
pub const ONE: Expression<E> = Expression::Constant(E::BaseField::ONE);

pub fn degree(&self) -> usize {
match self {
Expression::Fixed(_) => 1,
Expand Down
14 changes: 10 additions & 4 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,24 @@ fn add_sub_gadget<E: ExtensionField, const IS_ADD: bool>(
let next_ts = ts + 1.into();
circuit_builder.state_out(next_pc, next_ts)?;

let lt_wtns_rs1 = circuit_builder.assert_less_than(
let lt_wtns_rs1 = circuit_builder.less_than(
|| "prev_rs1_ts < ts",
prev_rs1_ts.expr(),
cur_ts.expr(),
Some(true),
)?;
let lt_wtns_rs2 = circuit_builder.assert_less_than(
let lt_wtns_rs2 = circuit_builder.less_than(
|| "prev_rs2_ts < ts",
prev_rs2_ts.expr(),
cur_ts.expr(),
Some(true),
)?;
let lt_wtns_rd = circuit_builder.less_than(
|| "prev_rd_ts < ts",
prev_rd_ts.expr(),
cur_ts.expr(),
Some(true),
)?;
let lt_wtns_rd =
circuit_builder.assert_less_than(|| "prev_rd_ts < ts", prev_rd_ts.expr(), cur_ts.expr())?;

Ok(InstructionConfig {
pc,
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ mod tests {
) -> Result<LtCircuit, ZKVMError> {
let a = cb.create_witin(|| "a")?;
let b = cb.create_witin(|| "b")?;
let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr())?;
let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), Some(true))?;
Ok(Self { a, b, lt_wtns })
}
}
Expand Down

0 comments on commit f9f1e85

Please sign in to comment.