Skip to content

Commit

Permalink
Merge branch 'master' into mock-prover
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 authored Sep 4, 2024
2 parents c10acb6 + 4500e95 commit 246ee53
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 97 deletions.
8 changes: 0 additions & 8 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,14 +401,6 @@ pub struct WitIn {
pub struct Fixed(pub usize);

impl WitIn {
pub fn assign<V, E>(&self, witin: &mut [E], to: V)
where
V: FnOnce() -> E,
{
// TODO: handle out of bound Error
witin[self.id as usize] = to();
}

pub fn from_expr<E: ExtensionField>(
circuit_builder: &mut CircuitBuilder<E>,
input: Expression<E>,
Expand Down
114 changes: 62 additions & 52 deletions ceno_zkvm/src/instructions/riscv/blt.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::mem::MaybeUninit;

use ff_ext::ExtensionField;

use crate::{
Expand All @@ -10,6 +12,7 @@ use crate::{
riscv::config::{LtConfig, LtInput},
Instruction,
},
set_val,
utils::{i64_to_ext, limb_u8_to_u16},
};

Expand Down Expand Up @@ -50,60 +53,60 @@ pub struct BltInput {

impl BltInput {
/// TODO: refactor after formalize the interface of opcode inputs
pub fn generate_witness<E: ExtensionField>(
pub fn assign<E: ExtensionField>(
&self,
witin: &mut [E],
config: &InstructionConfig<E>,
instance: &mut [MaybeUninit<E>],
) {
assert!(!self.lhs_limb8.is_empty() && (self.lhs_limb8.len() == self.rhs_limb8.len()));
// TODO: add boundary check for witin
let lt_input = LtInput {
lhs_limbs: &self.lhs_limb8,
rhs_limbs: &self.rhs_limb8,
};
let is_lt = lt_input.generate_witness(witin, &config.is_lt);
let is_lt = lt_input.assign(instance, &config.is_lt);

config.pc.assign(witin, || i64_to_ext(self.pc as i64));
config.next_pc.assign(witin, || {
set_val!(instance, config.pc, { i64_to_ext::<E>(self.pc as i64) });
set_val!(instance, config.next_pc, {
if is_lt {
i64_to_ext(self.pc as i64 + self.imm as i64)
i64_to_ext::<E>(self.pc as i64 + self.imm as i64)
} else {
i64_to_ext(self.pc as i64 + PC_STEP_SIZE as i64)
i64_to_ext::<E>(self.pc as i64 + PC_STEP_SIZE as i64)
}
});
config.ts.assign(witin, || i64_to_ext(self.ts as i64));
config.imm.assign(witin, || i64_to_ext(self.imm as i64));
config
.rs1_id
.assign(witin, || i64_to_ext(self.rs1_id as i64));
config
.rs2_id
.assign(witin, || i64_to_ext(self.rs2_id as i64));
config
.prev_rs1_ts
.assign(witin, || i64_to_ext(self.prev_rs1_ts as i64));
config
.prev_rs2_ts
.assign(witin, || i64_to_ext(self.prev_rs2_ts as i64));

config.lhs_limb8.assign(witin, || {
set_val!(instance, config.ts, { i64_to_ext::<E>(self.ts as i64) });
set_val!(instance, config.imm, { i64_to_ext::<E>(self.imm as i64) });
set_val!(instance, config.rs1_id, {
i64_to_ext::<E>(self.rs1_id as i64)
});
set_val!(instance, config.rs2_id, {
i64_to_ext::<E>(self.rs2_id as i64)
});
set_val!(instance, config.prev_rs1_ts, {
i64_to_ext::<E>(self.prev_rs1_ts as i64)
});
set_val!(instance, config.prev_rs2_ts, {
i64_to_ext::<E>(self.prev_rs2_ts as i64)
});

config.lhs_limb8.assign(instance, {
self.lhs_limb8
.iter()
.map(|&limb| i64_to_ext(limb as i64))
.collect()
});
config.rhs_limb8.assign(witin, || {
config.rhs_limb8.assign(instance, {
self.rhs_limb8
.iter()
.map(|&limb| i64_to_ext(limb as i64))
.collect()
});
let lhs = limb_u8_to_u16(&self.lhs_limb8);
let rhs = limb_u8_to_u16(&self.rhs_limb8);
config.lhs.assign(witin, || {
config.lhs.assign(instance, {
lhs.iter().map(|&limb| i64_to_ext(limb as i64)).collect()
});
config.rhs.assign(witin, || {
config.rhs.assign(instance, {
rhs.iter().map(|&limb| i64_to_ext(limb as i64)).collect()
});
}
Expand Down Expand Up @@ -210,31 +213,35 @@ impl<E: ExtensionField> Instruction<E> for BltInstruction {
) -> Result<InstructionConfig<E>, ZKVMError> {
blt_gadget::<E>(circuit_builder)
}

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [std::mem::MaybeUninit<E>],
_step: ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
// take input from _step
let input = BltInput::random();
input.assign(config, instance);
Ok(())
}
}

#[cfg(test)]
mod test {
use super::*;
use ff::Field;
use ceno_emul::StepRecord;
use ff_ext::ExtensionField;
use goldilocks::GoldilocksExt2;
use multilinear_extensions::mle::IntoMLE;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::Instruction,
scheme::mock_prover::MockProver,
};

use super::{BltInput, BltInstruction};

fn interleave<T: Clone>(vectors: Vec<Vec<T>>) -> Vec<Vec<T>> {
let len = vectors.first().map_or(0, Vec::len);

(0..len)
.map(|i| vectors.iter().map(|vec| vec[i].clone()).collect())
.collect()
}
use super::BltInstruction;

#[test]
fn test_blt_circuit() -> Result<(), ZKVMError> {
Expand All @@ -245,21 +252,24 @@ mod test {
let num_wits = circuit_builder.cs.num_witin as usize;
// generate mock witness
let num_instances = 1 << 4;
let wits_in = (0..num_instances)
.map(|_| {
let input = BltInput::random();
let mut witin: Vec<GoldilocksExt2> = Vec::with_capacity(num_wits);
witin.resize(num_wits, GoldilocksExt2::ZERO);
input.generate_witness(&mut witin, &config);
witin
})
.collect();
let wits_in = interleave(wits_in)
.iter()
.map(|witin| witin.clone().into_mle().into())
.collect::<Vec<_>>();

MockProver::run(&mut circuit_builder, &wits_in, None).expect_err("lookup will fail");
let raw_witin = BltInstruction::assign_instances(
&config,
num_wits,
vec![StepRecord::default(); num_instances],
)
.unwrap();

MockProver::run(
&mut circuit_builder,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
)
.expect_err("lookup will fail");
Ok(())
}

Expand Down
70 changes: 40 additions & 30 deletions ceno_zkvm/src/instructions/riscv/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::{expression::WitIn, utils::i64_to_ext};
use std::mem::MaybeUninit;

use crate::{expression::WitIn, set_val, utils::i64_to_ext};
use ff_ext::ExtensionField;
use itertools::Itertools;

Expand All @@ -21,20 +23,20 @@ pub struct MsbInput<'a> {
}

impl MsbInput<'_> {
pub fn generate_witness<E: ExtensionField>(
pub fn assign<E: ExtensionField>(
&self,
witin: &mut [E],
instance: &mut [MaybeUninit<E>],
config: &MsbConfig,
) -> (u8, u8) {
let n_limbs = self.limbs.len();
assert!(n_limbs > 0);
let mut high_limb = self.limbs[n_limbs - 1];
let msb = (high_limb >> 7) & 1;
config.msb.assign(witin, || i64_to_ext(msb as i64));
set_val!(instance, config.msb, { i64_to_ext::<E>(msb as i64) });
high_limb &= 0b0111_1111;
config
.high_limb_no_msb
.assign(witin, || i64_to_ext(high_limb as i64));
set_val!(instance, config.high_limb_no_msb, {
i64_to_ext::<E>(high_limb as i64)
});
(msb, high_limb)
}
}
Expand All @@ -55,7 +57,11 @@ pub struct LtuInput<'a> {
}

impl LtuInput<'_> {
pub fn generate_witness<E: ExtensionField>(&self, witin: &mut [E], config: &LtuConfig) -> bool {
pub fn assign<E: ExtensionField>(
&self,
instance: &mut [MaybeUninit<E>],
config: &LtuConfig,
) -> bool {
let mut idx = 0;
let mut flag: bool = false;
for (i, (&lhs, &rhs)) in self
Expand All @@ -71,27 +77,29 @@ impl LtuInput<'_> {
break;
}
}
config.indexes[idx].assign(witin, || i64_to_ext(flag as i64));
set_val!(instance, config.indexes[idx], {
i64_to_ext::<E>(flag as i64)
});
config.acc_indexes.iter().enumerate().for_each(|(id, wit)| {
if id <= idx {
wit.assign(witin, || i64_to_ext(flag as i64));
set_val!(instance, wit, { i64_to_ext::<E>(flag as i64) });
} else {
wit.assign(witin, || E::ZERO);
set_val!(instance, wit, E::ZERO);
}
});
let lhs_ne_byte = i64_to_ext(self.lhs_limbs[idx] as i64);
let rhs_ne_byte = i64_to_ext(self.rhs_limbs[idx] as i64);
config.lhs_ne_byte.assign(witin, || lhs_ne_byte);
config.rhs_ne_byte.assign(witin, || rhs_ne_byte);
config.byte_diff_inv.assign(witin, || {
let lhs_ne_byte = i64_to_ext::<E>(self.lhs_limbs[idx] as i64);
let rhs_ne_byte = i64_to_ext::<E>(self.rhs_limbs[idx] as i64);
set_val!(instance, config.lhs_ne_byte, lhs_ne_byte);
set_val!(instance, config.rhs_ne_byte, rhs_ne_byte);
set_val!(instance, config.byte_diff_inv, {
if flag {
(lhs_ne_byte - rhs_ne_byte).invert().unwrap()
} else {
E::ONE
}
});
let is_ltu = self.lhs_limbs[idx] < self.rhs_limbs[idx];
config.is_ltu.assign(witin, || i64_to_ext(is_ltu as i64));
set_val!(instance, config.is_ltu, { i64_to_ext::<E>(is_ltu as i64) });
is_ltu
}
}
Expand All @@ -112,18 +120,20 @@ pub struct LtInput<'a> {
}

impl LtInput<'_> {
pub fn generate_witness<E: ExtensionField>(&self, witin: &mut [E], config: &LtConfig) -> bool {
pub fn assign<E: ExtensionField>(
&self,
instance: &mut [MaybeUninit<E>],
config: &LtConfig,
) -> bool {
let n_limbs = self.lhs_limbs.len();
let lhs_msb_input = MsbInput {
limbs: self.lhs_limbs,
};
let (lhs_msb, lhs_high_limb_no_msb) =
lhs_msb_input.generate_witness(witin, &config.lhs_msb);
let (lhs_msb, lhs_high_limb_no_msb) = lhs_msb_input.assign(instance, &config.lhs_msb);
let rhs_msb_input = MsbInput {
limbs: self.rhs_limbs,
};
let (rhs_msb, rhs_high_limb_no_msb) =
rhs_msb_input.generate_witness(witin, &config.rhs_msb);
let (rhs_msb, rhs_high_limb_no_msb) = rhs_msb_input.assign(instance, &config.rhs_msb);

let mut lhs_limbs_no_msb = self.lhs_limbs.iter().copied().collect_vec();
lhs_limbs_no_msb[n_limbs - 1] = lhs_high_limb_no_msb;
Expand All @@ -135,24 +145,24 @@ impl LtInput<'_> {
lhs_limbs: &lhs_limbs_no_msb,
rhs_limbs: &rhs_limbs_no_msb,
};
let is_ltu = ltu_input.generate_witness(witin, &config.is_ltu);
let is_ltu = ltu_input.assign(instance, &config.is_ltu);

let msb_is_equal = lhs_msb == rhs_msb;
let msb_diff_inv = if msb_is_equal {
0
} else {
lhs_msb as i64 - rhs_msb as i64
};
config
.msb_is_equal
.assign(witin, || i64_to_ext(msb_is_equal as i64));
config
.msb_diff_inv
.assign(witin, || i64_to_ext(msb_diff_inv));
set_val!(instance, config.msb_is_equal, {
i64_to_ext::<E>(msb_is_equal as i64)
});
set_val!(instance, config.msb_diff_inv, {
i64_to_ext::<E>(msb_diff_inv)
});

// is_lt = a_s\cdot (1-b_s)+eq(a_s,b_s)\cdot ltu(a_{<s},b_{<s})$
let is_lt = lhs_msb * (1 - rhs_msb) + msb_is_equal as u8 * is_ltu as u8;
config.is_lt.assign(witin, || i64_to_ext(is_lt as i64));
set_val!(instance, config.is_lt, { i64_to_ext::<E>(is_lt as i64) });

assert!(is_lt == 0 || is_lt == 1);
is_lt > 0
Expand Down
11 changes: 4 additions & 7 deletions ceno_zkvm/src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
circuit_builder::CircuitBuilder,
error::{UtilError, ZKVMError},
expression::{Expression, ToExpr, WitIn},
set_val,
utils::add_one_to_big_num,
};
use ark_std::iterable::Iterable;
Expand All @@ -14,7 +15,7 @@ use ff::Field;
use ff_ext::ExtensionField;
use goldilocks::SmallField;
use itertools::Itertools;
use std::ops::Index;
use std::{mem::MaybeUninit, ops::Index};
pub use strum::IntoEnumIterator;
use strum_macros::EnumIter;
use sumcheck::util::ceil_log2;
Expand Down Expand Up @@ -122,11 +123,7 @@ impl<const M: usize, const C: usize, E: ExtensionField> UInt<M, C, E> {
}
}

pub fn assign<V>(&self, witin: &mut [E], to: V)
where
V: FnOnce() -> Vec<E>,
{
let values = to();
pub fn assign(&self, instance: &mut [MaybeUninit<E>], values: Vec<E>) {
assert!(
values.len() == Self::NUM_CELLS,
"assign input length mismatch. input_len={}, NUM_CELLS={}",
Expand All @@ -135,7 +132,7 @@ impl<const M: usize, const C: usize, E: ExtensionField> UInt<M, C, E> {
);
if let UintLimb::WitIn(c) = &self.limbs {
for (idx, wire) in c.iter().enumerate() {
witin[wire.id as usize] = values[idx];
set_val!(instance, wire, values[idx]);
}
}
// TODO: handle carries
Expand Down

0 comments on commit 246ee53

Please sign in to comment.