Skip to content

Commit

Permalink
support lt checks upto u32 range
Browse files Browse the repository at this point in the history
  • Loading branch information
zemse committed Sep 6, 2024
1 parent 4724cca commit 3bd648a
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 24 deletions.
36 changes: 25 additions & 11 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,13 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
Ok(())
}

pub(crate) fn less_than<N, NR, const C: usize>(
// TODO support riv64 feature
pub(crate) fn less_than<N, NR>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
) -> Result<(WitIn, WitIn), ZKVMError>
) -> Result<(WitIn, WitIn, WitIn), ZKVMError>
where
NR: Into<String> + Display + Clone,
N: FnOnce() -> NR,
Expand All @@ -281,25 +282,38 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
|cb| {
let name = name_fn();
let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?;
let diff = cb.create_witin(|| format!("{name} diff witin"))?;
let range = Expression::Constant(2u64.pow(C as u32).into());

let mut witin_u16 = |var_name: &str| -> Result<WitIn, ZKVMError> {
cb.namespace(
|| format!("var {var_name}"),
|cb| {
let witin = cb.create_witin(|| var_name.to_string())?;
cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?;
Ok(witin)
},
)
};

let diff_lo = witin_u16("diff_lo")?;
let diff_hi = witin_u16("diff_lo")?;

let range = Expression::Constant((u32::MAX as u64).into());
cb.require_equal(
|| name.clone(),
lhs - rhs,
diff.expr() - is_lt.expr() * range,
diff_lo.expr() + diff_hi.expr() * 2usize.pow(16).into() - is_lt.expr() * range,
)?;
cb.assert_ux::<_, _, C>(|| name, diff.expr())?;
Ok((is_lt, diff))
Ok((is_lt, diff_lo, diff_hi))
},
)
}

pub(crate) fn assert_less_than<N, NR, const C: usize>(
pub(crate) fn assert_less_than<N, NR>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
) -> Result<(WitIn, WitIn), ZKVMError>
) -> Result<(WitIn, WitIn, WitIn), ZKVMError>
where
NR: Into<String> + Clone + Display,
N: FnOnce() -> NR,
Expand All @@ -308,9 +322,9 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
|| "assert_less_than",
|cb| {
let name = name_fn();
let (is_lt, diff) = cb.less_than::<_, _, C>(|| name.clone(), lhs, rhs)?;
let (is_lt, diff_lo, diff_hi) = cb.less_than::<_, _>(|| name.clone(), lhs, rhs)?;
cb.require_one(|| name, is_lt.expr())?;
Ok((is_lt, diff))
Ok((is_lt, diff_lo, diff_hi))
},
)
}
Expand Down
28 changes: 15 additions & 13 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ pub struct InstructionConfig<E: ExtensionField> {
pub prev_rs1_ts: WitIn,
pub prev_rs2_ts: WitIn,
pub prev_rd_ts: WitIn,
pub lt_wtns_rs1: (WitIn, WitIn),
pub lt_wtns_rs2: (WitIn, WitIn),
pub lt_wtns_rd: (WitIn, WitIn),
pub lt_wtns_rs1: (WitIn, WitIn, WitIn),
pub lt_wtns_rs2: (WitIn, WitIn, WitIn),
pub lt_wtns_rd: (WitIn, WitIn, WitIn),
phantom: PhantomData<E>,
}

Expand Down Expand Up @@ -124,21 +124,18 @@ fn add_sub_gadget<E: ExtensionField, const IS_ADD: bool>(
let next_ts = ts + 1.into();
circuit_builder.state_out(next_pc, next_ts)?;

let lt_wtns_rs1 = circuit_builder.assert_less_than::<_, _, 16>(
let lt_wtns_rs1 = circuit_builder.assert_less_than(
|| "prev_rs1_ts < ts",
prev_rs1_ts.expr(),
cur_ts.expr(),
)?;
let lt_wtns_rs2 = circuit_builder.assert_less_than::<_, _, 16>(
let lt_wtns_rs2 = circuit_builder.assert_less_than(
|| "prev_rs2_ts < ts",
prev_rs2_ts.expr(),
cur_ts.expr(),
)?;
let lt_wtns_rd = circuit_builder.assert_less_than::<_, _, 16>(
|| "prev_rd_ts < ts",
prev_rd_ts.expr(),
cur_ts.expr(),
)?;
let lt_wtns_rd =
circuit_builder.assert_less_than(|| "prev_rd_ts < ts", prev_rd_ts.expr(), cur_ts.expr())?;

Ok(InstructionConfig {
pc,
Expand Down Expand Up @@ -205,14 +202,19 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {
set_val!(instance, config.prev_rs2_ts, 2);
set_val!(instance, config.prev_rd_ts, 2);

let u16_max = u16::MAX as u64;

set_val!(instance, config.lt_wtns_rs1.0, 1);
set_val!(instance, config.lt_wtns_rs1.1, 2u64.pow(16) - 2 + 1); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rs1.1, u16_max - 2 + 1); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rs1.2, u16_max);

set_val!(instance, config.lt_wtns_rs2.0, 1);
set_val!(instance, config.lt_wtns_rs2.1, 2u64.pow(16) - 3 + 2); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rs2.1, u16_max - 3 + 2); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rs2.2, u16_max);

set_val!(instance, config.lt_wtns_rd.0, 1);
set_val!(instance, config.lt_wtns_rd.1, 2u64.pow(16) - 3 + 2); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rd.1, u16_max - 3 + 2); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rd.2, u16_max);
Ok(())
}
}
Expand Down
97 changes: 97 additions & 0 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ mod tests {
circuit_builder::{CircuitBuilder, ConstraintSystem},
error::ZKVMError,
expression::{ToExpr, WitIn},
instructions::Instruction,
set_val,
};
use ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
Expand Down Expand Up @@ -519,4 +521,99 @@ mod tests {
}]
);
}

#[allow(dead_code)]
#[derive(Debug)]
struct LtCircuit {
pub a: WitIn,
pub b: WitIn,
pub lt: WitIn,
pub diff_lo: WitIn,
pub diff_hi: WitIn,
}

impl LtCircuit {
fn construct_circuit(
cb: &mut CircuitBuilder<GoldilocksExt2>,
) -> Result<LtCircuit, ZKVMError> {
let a = cb.create_witin(|| "a")?;
let b = cb.create_witin(|| "b")?;
let (lt, diff_lo, diff_hi) = cb.less_than(|| "lt", a.expr(), b.expr())?;
Ok(Self {
a,
b,
lt,
diff_lo,
diff_hi,
})
}
}

#[test]
fn test_lt_1() {
let mut cs = ConstraintSystem::new(|| "test_lt_1");
let mut builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);

let _ = LtCircuit::construct_circuit(&mut builder).unwrap();

let wits_in = vec![
vec![Goldilocks::from(3u64), Goldilocks::from(5u64)]
.into_mle()
.into(),
vec![Goldilocks::from(5u64), Goldilocks::from(3u64)]
.into_mle()
.into(),
vec![Goldilocks::from(1u64), Goldilocks::from(0u64)]
.into_mle()
.into(),
vec![
Goldilocks::from(u16::MAX as u64 + 3 - 5),
Goldilocks::from(5 - 3),
]
.into_mle()
.into(),
vec![Goldilocks::from(u16::MAX as u64), Goldilocks::from(0)]
.into_mle()
.into(),
];

MockProver::assert_satisfied(&mut builder, &wits_in, None);
}

#[test]
fn test_lt_2() {
let mut cs = ConstraintSystem::new(|| "test_lt_1");
let mut builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);

let _ = LtCircuit::construct_circuit(&mut builder).unwrap();

let wits_in = vec![
vec![
Goldilocks::from(u32::MAX as u64 - 5),
Goldilocks::from(u32::MAX as u64 - 3),
]
.into_mle()
.into(),
vec![
Goldilocks::from(u32::MAX as u64 - 3),
Goldilocks::from(u32::MAX as u64 - 5),
]
.into_mle()
.into(),
vec![Goldilocks::from(1u64), Goldilocks::from(0u64)]
.into_mle()
.into(),
vec![
Goldilocks::from(u16::MAX as u64 + 3 - 5),
Goldilocks::from(5 - 3),
]
.into_mle()
.into(),
vec![Goldilocks::from(u16::MAX as u64), Goldilocks::from(0)]
.into_mle()
.into(),
];

MockProver::assert_satisfied(&mut builder, &wits_in, None);
}
}

0 comments on commit 3bd648a

Please sign in to comment.