Skip to content

Commit

Permalink
Feat: witness should be basefield elements (#192)
Browse files Browse the repository at this point in the history
The matrix for holding circuit witness should over base field instead of
extension field. This can save proving time when the prover infers each
record's MLE as `E x B` is much faster than `E x E`.
  • Loading branch information
kunxian-xia authored and hero78119 committed Sep 30, 2024
1 parent d472e27 commit 2bce265
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 59 deletions.
6 changes: 3 additions & 3 deletions ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ pub trait Instruction<E: ExtensionField> {
// assign single instance giving step from trace
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E>],
instance: &mut [MaybeUninit<E::BaseField>],
step: StepRecord,
) -> Result<(), ZKVMError>;

fn assign_instances(
config: &Self::InstructionConfig,
num_witin: usize,
steps: Vec<StepRecord>,
) -> Result<RowMajorMatrix<E>, ZKVMError> {
let mut raw_witin = RowMajorMatrix::<E>::new(steps.len(), num_witin);
) -> Result<RowMajorMatrix<E::BaseField>, ZKVMError> {
let mut raw_witin = RowMajorMatrix::<E::BaseField>::new(steps.len(), num_witin);
let raw_witin_iter = raw_witin.par_iter_mut();

raw_witin_iter
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {
#[allow(clippy::option_map_unit_fn)]
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E>],
instance: &mut [MaybeUninit<E::BaseField>],
_step: StepRecord,
) -> Result<(), ZKVMError> {
// TODO use field from step
Expand Down Expand Up @@ -192,7 +192,7 @@ impl<E: ExtensionField> Instruction<E> for SubInstruction {
#[allow(clippy::option_map_unit_fn)]
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E>],
instance: &mut [MaybeUninit<E::BaseField>],
_step: StepRecord,
) -> Result<(), ZKVMError> {
// TODO use field from step
Expand Down
50 changes: 24 additions & 26 deletions ceno_zkvm/src/instructions/riscv/blt.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use goldilocks::SmallField;
use std::mem::MaybeUninit;

use ff_ext::ExtensionField;
Expand All @@ -13,7 +14,7 @@ use crate::{
Instruction,
},
set_val,
utils::{i64_to_ext, limb_u8_to_u16},
utils::{i64_to_base, limb_u8_to_u16},
};

use super::{
Expand Down Expand Up @@ -53,61 +54,65 @@ pub struct BltInput {

impl BltInput {
/// TODO: refactor after formalize the interface of opcode inputs
pub fn assign<E: ExtensionField>(
pub fn assign<F: SmallField, E: ExtensionField<BaseField = F>>(
&self,
config: &InstructionConfig<E>,
instance: &mut [MaybeUninit<E>],
instance: &mut [MaybeUninit<F>],
) {
assert!(!self.lhs_limb8.is_empty() && (self.lhs_limb8.len() == self.rhs_limb8.len()));
// TODO: add boundary check for witin
let lt_input = LtInput {
lhs_limbs: &self.lhs_limb8,
rhs_limbs: &self.rhs_limb8,
};
let is_lt = lt_input.assign(instance, &config.is_lt);
let is_lt = lt_input.assign::<E::BaseField>(instance, &config.is_lt);

set_val!(instance, config.pc, { i64_to_ext::<E>(self.pc as i64) });
set_val!(instance, config.pc, { i64_to_base::<F>(self.pc as i64) });
set_val!(instance, config.next_pc, {
if is_lt {
i64_to_ext::<E>(self.pc as i64 + self.imm as i64)
i64_to_base::<F>(self.pc as i64 + self.imm as i64)
} else {
i64_to_ext::<E>(self.pc as i64 + PC_STEP_SIZE as i64)
i64_to_base::<F>(self.pc as i64 + PC_STEP_SIZE as i64)
}
});
set_val!(instance, config.ts, { i64_to_ext::<E>(self.ts as i64) });
set_val!(instance, config.imm, { i64_to_ext::<E>(self.imm as i64) });
set_val!(instance, config.ts, { i64_to_base::<F>(self.ts as i64) });
set_val!(instance, config.imm, { i64_to_base::<F>(self.imm as i64) });
set_val!(instance, config.rs1_id, {
i64_to_ext::<E>(self.rs1_id as i64)
i64_to_base::<F>(self.rs1_id as i64)
});
set_val!(instance, config.rs2_id, {
i64_to_ext::<E>(self.rs2_id as i64)
i64_to_base::<F>(self.rs2_id as i64)
});
set_val!(instance, config.prev_rs1_ts, {
i64_to_ext::<E>(self.prev_rs1_ts as i64)
i64_to_base::<F>(self.prev_rs1_ts as i64)
});
set_val!(instance, config.prev_rs2_ts, {
i64_to_ext::<E>(self.prev_rs2_ts as i64)
i64_to_base::<F>(self.prev_rs2_ts as i64)
});

config.lhs_limb8.assign(instance, {
self.lhs_limb8
.iter()
.map(|&limb| i64_to_ext(limb as i64))
.map(|&limb| i64_to_base::<F>(limb as i64))
.collect()
});
config.rhs_limb8.assign(instance, {
self.rhs_limb8
.iter()
.map(|&limb| i64_to_ext(limb as i64))
.map(|&limb| i64_to_base::<F>(limb as i64))
.collect()
});
let lhs = limb_u8_to_u16(&self.lhs_limb8);
let rhs = limb_u8_to_u16(&self.rhs_limb8);
config.lhs.assign(instance, {
lhs.iter().map(|&limb| i64_to_ext(limb as i64)).collect()
lhs.iter()
.map(|&limb| i64_to_base::<F>(limb as i64))
.collect()
});
config.rhs.assign(instance, {
rhs.iter().map(|&limb| i64_to_ext(limb as i64)).collect()
rhs.iter()
.map(|&limb| i64_to_base::<F>(limb as i64))
.collect()
});
}

Expand Down Expand Up @@ -216,7 +221,7 @@ impl<E: ExtensionField> Instruction<E> for BltInstruction {

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [std::mem::MaybeUninit<E>],
instance: &mut [std::mem::MaybeUninit<E::BaseField>],
_step: ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
// take input from _step
Expand All @@ -230,18 +235,11 @@ impl<E: ExtensionField> Instruction<E> for BltInstruction {
mod test {
use super::*;
use ceno_emul::StepRecord;
use ff_ext::ExtensionField;
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::Instruction,
scheme::mock_prover::MockProver,
};

use super::BltInstruction;
use crate::{circuit_builder::ConstraintSystem, scheme::mock_prover::MockProver};

#[test]
fn test_blt_circuit() -> Result<(), ZKVMError> {
Expand Down
42 changes: 21 additions & 21 deletions ceno_zkvm/src/instructions/riscv/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::mem::MaybeUninit;

use crate::{expression::WitIn, set_val, utils::i64_to_ext};
use ff_ext::ExtensionField;
use crate::{expression::WitIn, set_val, utils::i64_to_base};
use goldilocks::SmallField;
use itertools::Itertools;

#[derive(Clone)]
Expand All @@ -23,19 +23,19 @@ pub struct MsbInput<'a> {
}

impl MsbInput<'_> {
pub fn assign<E: ExtensionField>(
pub fn assign<F: SmallField>(
&self,
instance: &mut [MaybeUninit<E>],
instance: &mut [MaybeUninit<F>],
config: &MsbConfig,
) -> (u8, u8) {
let n_limbs = self.limbs.len();
assert!(n_limbs > 0);
let mut high_limb = self.limbs[n_limbs - 1];
let msb = (high_limb >> 7) & 1;
set_val!(instance, config.msb, { i64_to_ext::<E>(msb as i64) });
set_val!(instance, config.msb, { i64_to_base::<F>(msb as i64) });
high_limb &= 0b0111_1111;
set_val!(instance, config.high_limb_no_msb, {
i64_to_ext::<E>(high_limb as i64)
i64_to_base::<F>(high_limb as i64)
});
(msb, high_limb)
}
Expand All @@ -57,9 +57,9 @@ pub struct LtuInput<'a> {
}

impl LtuInput<'_> {
pub fn assign<E: ExtensionField>(
pub fn assign<F: SmallField>(
&self,
instance: &mut [MaybeUninit<E>],
instance: &mut [MaybeUninit<F>],
config: &LtuConfig,
) -> bool {
let mut idx = 0;
Expand All @@ -78,28 +78,28 @@ impl LtuInput<'_> {
}
}
set_val!(instance, config.indexes[idx], {
i64_to_ext::<E>(flag as i64)
i64_to_base::<F>(flag as i64)
});
config.acc_indexes.iter().enumerate().for_each(|(id, wit)| {
if id <= idx {
set_val!(instance, wit, { i64_to_ext::<E>(flag as i64) });
set_val!(instance, wit, { i64_to_base::<F>(flag as i64) });
} else {
set_val!(instance, wit, E::ZERO);
set_val!(instance, wit, 0);
}
});
let lhs_ne_byte = i64_to_ext::<E>(self.lhs_limbs[idx] as i64);
let rhs_ne_byte = i64_to_ext::<E>(self.rhs_limbs[idx] as i64);
let lhs_ne_byte = i64_to_base::<F>(self.lhs_limbs[idx] as i64);
let rhs_ne_byte = i64_to_base::<F>(self.rhs_limbs[idx] as i64);
set_val!(instance, config.lhs_ne_byte, lhs_ne_byte);
set_val!(instance, config.rhs_ne_byte, rhs_ne_byte);
set_val!(instance, config.byte_diff_inv, {
if flag {
(lhs_ne_byte - rhs_ne_byte).invert().unwrap()
} else {
E::ONE
F::ONE
}
});
let is_ltu = self.lhs_limbs[idx] < self.rhs_limbs[idx];
set_val!(instance, config.is_ltu, { i64_to_ext::<E>(is_ltu as i64) });
set_val!(instance, config.is_ltu, { i64_to_base::<F>(is_ltu as i64) });
is_ltu
}
}
Expand All @@ -120,9 +120,9 @@ pub struct LtInput<'a> {
}

impl LtInput<'_> {
pub fn assign<E: ExtensionField>(
pub fn assign<F: SmallField>(
&self,
instance: &mut [MaybeUninit<E>],
instance: &mut [MaybeUninit<F>],
config: &LtConfig,
) -> bool {
let n_limbs = self.lhs_limbs.len();
Expand All @@ -145,7 +145,7 @@ impl LtInput<'_> {
lhs_limbs: &lhs_limbs_no_msb,
rhs_limbs: &rhs_limbs_no_msb,
};
let is_ltu = ltu_input.assign(instance, &config.is_ltu);
let is_ltu = ltu_input.assign::<F>(instance, &config.is_ltu);

let msb_is_equal = lhs_msb == rhs_msb;
let msb_diff_inv = if msb_is_equal {
Expand All @@ -154,15 +154,15 @@ impl LtInput<'_> {
lhs_msb as i64 - rhs_msb as i64
};
set_val!(instance, config.msb_is_equal, {
i64_to_ext::<E>(msb_is_equal as i64)
i64_to_base::<F>(msb_is_equal as i64)
});
set_val!(instance, config.msb_diff_inv, {
i64_to_ext::<E>(msb_diff_inv)
i64_to_base::<F>(msb_diff_inv)
});

// is_lt = a_s\cdot (1-b_s)+eq(a_s,b_s)\cdot ltu(a_{<s},b_{<s})$
let is_lt = lhs_msb * (1 - rhs_msb) + msb_is_equal as u8 * is_ltu as u8;
set_val!(instance, config.is_lt, { i64_to_ext::<E>(is_lt as i64) });
set_val!(instance, config.is_lt, { i64_to_base::<F>(is_lt as i64) });

assert!(is_lt == 0 || is_lt == 1);
is_lt > 0
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl<const M: usize, const C: usize, E: ExtensionField> UInt<M, C, E> {
}
}

pub fn assign(&self, instance: &mut [MaybeUninit<E>], values: Vec<E>) {
pub fn assign(&self, instance: &mut [MaybeUninit<E::BaseField>], values: Vec<E::BaseField>) {
assert!(
values.len() == Self::NUM_CELLS,
"assign input length mismatch. input_len={}, NUM_CELLS={}",
Expand Down
11 changes: 5 additions & 6 deletions ceno_zkvm/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@ pub fn ext_to_u64<E: ExtensionField>(x: &E) -> u64 {
bases[0].to_canonical_u64()
}

pub fn i64_to_ext<E: ExtensionField>(x: i64) -> E {
let x0 = if x >= 0 {
E::BaseField::from(x as u64)
pub fn i64_to_base<F: SmallField>(x: i64) -> F {
if x >= 0 {
F::from(x as u64)
} else {
-E::BaseField::from((-x) as u64)
};
E::from_bases(&[x0, E::BaseField::ZERO])
-F::from((-x) as u64)
}
}

/// This is helper function to convert witness of u8 limb into u16 limb
Expand Down

0 comments on commit 2bce265

Please sign in to comment.