Skip to content

Commit

Permalink
All range tables (#236)
Browse files Browse the repository at this point in the history
_Issue #214_

A generalization of #234 for all range or scalar tables.

- Moved and refactored the implementation of u16.
- Implementation without generics because I find it cleaner and it
compiles faster.
- Definition of separate circuits `U5TableCircuit`, `U8TableCircuit`,
`U16TableCircuit` using a parameter trait.
- Fix soundness of `assert_u8`.

---------

Co-authored-by: Aurélien Nicolas <[email protected]>
  • Loading branch information
naure and Aurélien Nicolas authored Sep 18, 2024
1 parent 713461b commit 36392db
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 89 deletions.
8 changes: 4 additions & 4 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM};
use ceno_zkvm::{
scheme::verifier::ZKVMVerifier,
structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::RangeTableCircuit,
tables::U16TableCircuit,
};
use ff_ext::ff::Field;
use goldilocks::GoldilocksExt2;
Expand Down Expand Up @@ -92,12 +92,12 @@ fn main() {
// keygen
let mut zkvm_cs = ZKVMConstraintSystem::default();
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let range_config = zkvm_cs.register_table_circuit::<RangeTableCircuit<E>>();
let range_config = zkvm_cs.register_table_circuit::<U16TableCircuit<E>>();
let prog_config = zkvm_cs.register_table_circuit::<ProgramTableCircuit<E>>();

let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);
zkvm_fixed_traces.register_table_circuit::<RangeTableCircuit<E>>(
zkvm_fixed_traces.register_table_circuit::<U16TableCircuit<E>>(
&zkvm_cs,
range_config.clone(),
&(),
Expand Down Expand Up @@ -148,7 +148,7 @@ fn main() {
zkvm_witness.finalize_lk_multiplicities();
// assign table circuits
zkvm_witness
.assign_table_circuit::<RangeTableCircuit<E>>(&zkvm_cs, &range_config, &())
.assign_table_circuit::<U16TableCircuit<E>>(&zkvm_cs, &range_config, &())
.unwrap();
zkvm_witness
.assign_table_circuit::<ProgramTableCircuit<E>>(
Expand Down
6 changes: 5 additions & 1 deletion ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR,
{
self.assert_u16(name_fn, expr * Expression::from(1 << 8))
let items: Vec<Expression<E>> = vec![(ROMType::U8 as usize).into(), expr];
let rlc_record = self.rlc_chip_record(items);
self.lk_record(name_fn, rlc_record)?;
Ok(())
}

pub(crate) fn assert_bit<NR, N>(
Expand All @@ -228,6 +231,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR,
{
// TODO: Replace with `x * (1 - x)` or a multi-bit lookup similar to assert_u8_pair.
self.assert_u16(name_fn, expr * Expression::from(1 << 15))
}

Expand Down
13 changes: 13 additions & 0 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,18 @@ fn load_tables<E: ExtensionField>(cb: &CircuitBuilder<E>, challenge: [E; 2]) ->
}
}

fn load_u8_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
for i in 0..=u8::MAX as usize {
let rlc_record = cb.rlc_chip_record(vec![(ROMType::U8 as usize).into(), i.into()]);
let rlc_record = eval_by_expr(&[], &challenge, &rlc_record);
t_vec.push(rlc_record.to_repr().as_ref().to_vec());
}
}

fn load_u16_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
Expand Down Expand Up @@ -347,6 +359,7 @@ fn load_tables<E: ExtensionField>(cb: &CircuitBuilder<E>, challenge: [E; 2]) ->
let mut table_vec = vec![];
// TODO load more tables here
load_u5_table(&mut table_vec, cb, challenge);
load_u8_table(&mut table_vec, cb, challenge);
load_u16_table(&mut table_vec, cb, challenge);
load_lt_table(&mut table_vec, cb, challenge);
load_and_table(&mut table_vec, cb, challenge);
Expand Down
3 changes: 2 additions & 1 deletion ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ pub struct TowerProverSpec<'a, E: ExtensionField> {
pub type WitnessId = u16;
pub type ChallengeId = u16;

#[derive(Debug)]
#[derive(Copy, Clone, Debug)]
pub enum ROMType {
U5 = 0, // 2^5 = 32
U8, // 2^8 = 256
U16, // 2^16 = 65,536
And, // a ^ b where a, b are bytes
Ltu, // a <(usign) b where a, b are bytes
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/tables/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use ff_ext::ExtensionField;
use std::collections::HashMap;

mod range;
pub use range::RangeTableCircuit;
pub use range::*;

mod program;
pub use program::{InsnRecord, ProgramTableCircuit};
Expand Down
104 changes: 25 additions & 79 deletions ceno_zkvm/src/tables/range.rs
Original file line number Diff line number Diff line change
@@ -1,89 +1,35 @@
use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit};
//! Definition of the range tables and their circuits.
use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
scheme::constants::MIN_PAR_SIZE,
set_fixed_val, set_val,
structs::ROMType,
tables::TableCircuit,
uint::constants::RANGE_CHIP_BIT_WIDTH,
witness::RowMajorMatrix,
};
use ff_ext::ExtensionField;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
mod range_impl;

#[derive(Clone, Debug)]
pub struct RangeTableConfig {
u16_tbl: Fixed,
u16_mlt: WitIn,
}

pub struct RangeTableCircuit<E>(PhantomData<E>);

impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {
type TableConfig = RangeTableConfig;
type FixedInput = ();
type WitnessInput = ();

fn name() -> String {
"RANGE".into()
}

fn construct_circuit(cb: &mut CircuitBuilder<E>) -> Result<RangeTableConfig, ZKVMError> {
let u16_tbl = cb.create_fixed(|| "u16_tbl")?;
let u16_mlt = cb.create_witin(|| "u16_mlt")?;
mod range_circuit;
use range_circuit::{RangeTable, RangeTableCircuit};

let u16_table_values = cb.rlc_chip_record(vec![
Expression::Constant(E::BaseField::from(ROMType::U16 as u64)),
Expression::Fixed(u16_tbl.clone()),
]);
use crate::structs::ROMType;

cb.lk_table_record(|| "u16 table", u16_table_values, u16_mlt.expr())?;

Ok(RangeTableConfig { u16_tbl, u16_mlt })
pub struct U5Table;
impl RangeTable for U5Table {
const ROM_TYPE: ROMType = ROMType::U5;
fn len() -> usize {
1 << 5
}
}
pub type U5TableCircuit<E> = RangeTableCircuit<E, U5Table>;

fn generate_fixed_traces(
config: &RangeTableConfig,
num_fixed: usize,
_input: &(),
) -> RowMajorMatrix<E::BaseField> {
let num_u16s = 1 << 16;
let mut fixed = RowMajorMatrix::<E::BaseField>::new(num_u16s, num_fixed);
fixed
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip((0..num_u16s).into_par_iter())
.for_each(|(row, i)| {
set_fixed_val!(row, config.u16_tbl, E::BaseField::from(i as u64));
});

fixed
pub struct U8Table;
impl RangeTable for U8Table {
const ROM_TYPE: ROMType = ROMType::U8;
fn len() -> usize {
1 << 8
}
}
pub type U8TableCircuit<E> = RangeTableCircuit<E, U8Table>;

fn assign_instances(
config: &Self::TableConfig,
num_witin: usize,
multiplicity: &[HashMap<u64, usize>],
_input: &(),
) -> Result<RowMajorMatrix<E::BaseField>, ZKVMError> {
let multiplicity = &multiplicity[ROMType::U16 as usize];
let mut u16_mlt = vec![0; 1 << RANGE_CHIP_BIT_WIDTH];
for (limb, mlt) in multiplicity {
u16_mlt[*limb as usize] = *mlt;
}

let mut witness = RowMajorMatrix::<E::BaseField>::new(u16_mlt.len(), num_witin);
witness
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(u16_mlt.into_par_iter())
.for_each(|(row, mlt)| {
set_val!(row, config.u16_mlt, E::BaseField::from(mlt as u64));
});

Ok(witness)
pub struct U16Table;
impl RangeTable for U16Table {
const ROM_TYPE: ROMType = ROMType::U16;
fn len() -> usize {
1 << 16
}
}
pub type U16TableCircuit<E> = RangeTableCircuit<E, U16Table>;
59 changes: 59 additions & 0 deletions ceno_zkvm/src/tables/range/range_circuit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//! Range tables as circuits with trait TableCircuit.
use super::range_impl::RangeTableConfig;

use std::{collections::HashMap, marker::PhantomData};

use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, structs::ROMType, tables::TableCircuit,
witness::RowMajorMatrix,
};
use ff_ext::ExtensionField;

/// Use this trait as parameter to RangeTableCircuit.
pub trait RangeTable {
const ROM_TYPE: ROMType;

fn len() -> usize;

fn content() -> Vec<u64> {
(0..Self::len() as u64).collect()
}
}

pub struct RangeTableCircuit<E, R>(PhantomData<(E, R)>);

impl<E: ExtensionField, RANGE: RangeTable> TableCircuit<E> for RangeTableCircuit<E, RANGE> {
type TableConfig = RangeTableConfig;
type FixedInput = ();
type WitnessInput = ();

fn name() -> String {
format!("RANGE_{:?}", RANGE::ROM_TYPE)
}

fn construct_circuit(cb: &mut CircuitBuilder<E>) -> Result<RangeTableConfig, ZKVMError> {
cb.namespace(
|| Self::name(),
|cb| RangeTableConfig::construct_circuit(cb, RANGE::ROM_TYPE),
)
}

fn generate_fixed_traces(
config: &RangeTableConfig,
num_fixed: usize,
_input: &(),
) -> RowMajorMatrix<E::BaseField> {
config.generate_fixed_traces(num_fixed, RANGE::content())
}

fn assign_instances(
config: &Self::TableConfig,
num_witin: usize,
multiplicity: &[HashMap<u64, usize>],
_input: &(),
) -> Result<RowMajorMatrix<E::BaseField>, ZKVMError> {
let multiplicity = &multiplicity[RANGE::ROM_TYPE as usize];
config.assign_instances(num_witin, multiplicity, RANGE::len())
}
}
93 changes: 93 additions & 0 deletions ceno_zkvm/src/tables/range/range_impl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//! The implementation of range tables. No generics.
use ff_ext::ExtensionField;
use goldilocks::SmallField;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use std::{collections::HashMap, mem::MaybeUninit};

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
scheme::constants::MIN_PAR_SIZE,
set_fixed_val, set_val,
structs::ROMType,
witness::RowMajorMatrix,
};

#[derive(Clone, Debug)]
pub struct RangeTableConfig {
fixed: Fixed,
mlt: WitIn,
}

impl RangeTableConfig {
pub fn construct_circuit<E: ExtensionField>(
cb: &mut CircuitBuilder<E>,
rom_type: ROMType,
) -> Result<Self, ZKVMError> {
let fixed = cb.create_fixed(|| "fixed")?;
let mlt = cb.create_witin(|| "mlt")?;

let rlc_record = cb.rlc_chip_record(vec![
(rom_type as usize).into(),
Expression::Fixed(fixed.clone()),
]);

cb.lk_table_record(|| "record", rlc_record, mlt.expr())?;

Ok(Self { fixed, mlt })
}

pub fn generate_fixed_traces<F: SmallField>(
&self,
num_fixed: usize,
content: Vec<u64>,
) -> RowMajorMatrix<F> {
let mut fixed = RowMajorMatrix::<F>::new(content.len(), num_fixed);

// Fill the padding with zeros, if any.
fixed.par_iter_mut().skip(content.len()).for_each(|row| {
set_fixed_val!(row, self.fixed, F::ZERO);
});

fixed
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(content.into_par_iter())
.for_each(|(row, i)| {
set_fixed_val!(row, self.fixed, F::from(i));
});

fixed
}

pub fn assign_instances<F: SmallField>(
&self,
num_witin: usize,
multiplicity: &HashMap<u64, usize>,
length: usize,
) -> Result<RowMajorMatrix<F>, ZKVMError> {
let mut witness = RowMajorMatrix::<F>::new(length, num_witin);

let mut mlts = vec![0; length];
for (idx, mlt) in multiplicity {
mlts[*idx as usize] = *mlt;
}

// Fill the padding with zeros, if any.
witness.par_iter_mut().skip(length).for_each(|row| {
set_val!(row, self.mlt, F::ZERO);
});

witness
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(mlts.into_par_iter())
.for_each(|(row, mlt)| {
set_val!(row, self.mlt, F::from(mlt as u64));
});

Ok(witness)
}
}
5 changes: 2 additions & 3 deletions ceno_zkvm/src/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,10 @@ impl LkMultiplicity {
}

fn assert_byte(&mut self, v: u64) {
let v = v * (1 << u8::BITS);
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U16 as usize]
(*multiplicity.borrow_mut()[ROMType::U8 as usize]
.entry(v)
.or_default()) += 1;
}
Expand Down Expand Up @@ -189,6 +188,6 @@ mod tests {
}
let res = lkm.into_finalize_result();
// check multiplicity counts of assert_byte
assert_eq!(res[ROMType::U16 as usize][&(8 << 8)], thread_count);
assert_eq!(res[ROMType::U8 as usize][&8], thread_count);
}
}

0 comments on commit 36392db

Please sign in to comment.