Skip to content

Commit

Permalink
logup multiplicity in witness assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Sep 9, 2024
1 parent a225ab7 commit 96d357e
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 19 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ceno_zkvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tracing-flame = "0.2.0"
tracing = "0.1.40"

rand = "0.8"
thread_local = "1.1.8"

[dev-dependencies]
pprof = { version = "0.13", features = ["flamegraph"]}
Expand Down
17 changes: 13 additions & 4 deletions ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ use ceno_emul::StepRecord;
use ff_ext::ExtensionField;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajorMatrix};
use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
witness::{LkMultiplicity, RowMajorMatrix},
};

pub mod riscv;

Expand All @@ -18,22 +22,27 @@ pub trait Instruction<E: ExtensionField> {
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: StepRecord,
) -> Result<(), ZKVMError>;

fn assign_instances(
config: &Self::InstructionConfig,
num_witin: usize,
steps: Vec<StepRecord>,
) -> Result<RowMajorMatrix<E::BaseField>, ZKVMError> {
) -> Result<(RowMajorMatrix<E::BaseField>, LkMultiplicity), ZKVMError> {
let lk_multiplicity = LkMultiplicity::default();
let mut raw_witin = RowMajorMatrix::<E::BaseField>::new(steps.len(), num_witin);
let raw_witin_iter = raw_witin.par_iter_mut();

raw_witin_iter
.zip_eq(steps.into_par_iter())
.map(|(instance, step)| Self::assign_instance(config, instance, step))
.map(|(instance, step)| {
let mut lk_multiplicity = lk_multiplicity.clone();
Self::assign_instance(config, instance, &mut lk_multiplicity, step)
})
.collect::<Result<(), ZKVMError>>()?;

Ok(raw_witin)
Ok((raw_witin, lk_multiplicity))
}
}
13 changes: 8 additions & 5 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::{
instructions::Instruction,
set_val,
uint::UIntValue,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;

Expand Down Expand Up @@ -151,13 +152,14 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: StepRecord,
) -> Result<(), ZKVMError> {
// TODO use fields from step
set_val!(instance, config.pc, 1);
set_val!(instance, config.ts, 2);
let addend_0 = UIntValue::new(step.rs1().unwrap().value);
let addend_1 = UIntValue::new(step.rs2().unwrap().value);
let addend_0 = UIntValue::new_unchecked(step.rs1().unwrap().value);
let addend_1 = UIntValue::new_unchecked(step.rs2().unwrap().value);
config
.prev_rd_value
.assign_limbs(instance, [0, 0].iter().map(E::BaseField::from).collect());
Expand All @@ -167,7 +169,7 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {
config
.addend_1
.assign_limbs(instance, addend_1.u16_fields());
let carries = addend_0.add_u16_carries(&addend_1);
let (_, carries) = addend_0.add(&addend_1, lk_multiplicity, true);
config.outcome.assign_carries(
instance,
carries
Expand Down Expand Up @@ -199,6 +201,7 @@ impl<E: ExtensionField> Instruction<E> for SubInstruction {
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
_lk_multiplicity: &mut LkMultiplicity,
_step: StepRecord,
) -> Result<(), ZKVMError> {
// TODO use field from step
Expand Down Expand Up @@ -263,7 +266,7 @@ mod test {
.unwrap()
.unwrap();

let raw_witin = AddInstruction::assign_instances(
let (raw_witin, _) = AddInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord {
Expand Down Expand Up @@ -310,7 +313,7 @@ mod test {
.unwrap()
.unwrap();

let raw_witin = AddInstruction::assign_instances(
let (raw_witin, _) = AddInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord {
Expand Down
5 changes: 3 additions & 2 deletions ceno_zkvm/src/instructions/riscv/blt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
Instruction,
},
set_val,
utils::{i64_to_base, limb_u8_to_u16},
utils::{i64_to_base, limb_u8_to_u16}, witness::LkMultiplicity,
};

use super::{
Expand Down Expand Up @@ -222,6 +222,7 @@ impl<E: ExtensionField> Instruction<E> for BltInstruction {
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [std::mem::MaybeUninit<E::BaseField>],
_lk_multiplicity: &mut LkMultiplicity,
_step: ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
// take input from _step
Expand Down Expand Up @@ -250,7 +251,7 @@ mod test {
let num_wits = circuit_builder.cs.num_witin as usize;
// generate mock witness
let num_instances = 1 << 4;
let raw_witin = BltInstruction::assign_instances(
let (raw_witin, _) = BltInstruction::assign_instances(
&config,
num_wits,
vec![StepRecord::default(); num_instances],
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![feature(box_patterns)]
#![feature(stmt_expr_attributes)]
#![feature(variant_count)]

pub mod error;
pub mod instructions;
Expand Down
48 changes: 40 additions & 8 deletions ceno_zkvm/src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
error::{UtilError, ZKVMError},
expression::{Expression, ToExpr, WitIn},
utils::add_one_to_big_num,
witness::LkMultiplicity,
};
use ark_std::iterable::Iterable;
use constants::BYTE_BIT_WIDTH;
Expand Down Expand Up @@ -476,13 +477,29 @@ impl<T: Into<u64> + Copy> UIntValue<T> {
mem::size_of::<T>() / u16_bytes
};

pub fn new(val: T) -> Self {
#[allow(dead_code)]
pub fn new(val: T, lkm: &mut LkMultiplicity) -> Self {
let uint = UIntValue::<T> {
val,
limbs: Self::split_to_u16(val),
};
Self::assert_u16(&uint.limbs, lkm);
uint
}

pub fn new_unchecked(val: T) -> Self {
UIntValue::<T> {
val,
limbs: Self::split_to_u16(val),
}
}

fn assert_u16(v: &[u16], lkm: &mut LkMultiplicity) {
v.iter().for_each(|v| {
lkm.assert_ux::<16>(*v as u64);
})
}

fn split_to_u16(value: T) -> Vec<u16> {
let value: u64 = value.into(); // Convert to u64 for generality
(0..Self::LIMBS)
Expand All @@ -502,20 +519,35 @@ impl<T: Into<u64> + Copy> UIntValue<T> {
self.limbs.iter().map(|v| F::from(*v as u64)).collect_vec()
}

pub fn add_u16_carries(&self, rhs: &Self) -> Vec<bool> {
self.as_u16_limbs().iter().zip(rhs.as_u16_limbs()).fold(
pub fn add(
&self,
rhs: &Self,
lkm: &mut LkMultiplicity,
with_overflow: bool,
) -> (Vec<u16>, Vec<bool>) {
let res = self.as_u16_limbs().iter().zip(rhs.as_u16_limbs()).fold(
vec![],
|mut acc, (a_limb, b_limb)| {
let (a, b) = a_limb.overflowing_add(*b_limb);
if let Some(prev_carry) = acc.last() {
let (_, d) = a.overflowing_add(*prev_carry as u16);
acc.push(b || d);
if let Some((_, prev_carry)) = acc.last() {
let (e, d) = a.overflowing_add(*prev_carry as u16);
acc.push((e, b || d));
} else {
acc.push(b);
acc.push((a, b));
}
// range check
if let Some((limb, _)) = acc.last() {
lkm.assert_ux::<16>(*limb as u64);
};
acc
},
)
);
let (limbs, mut carries): (Vec<u16>, Vec<bool>) = res.into_iter().unzip();
if !with_overflow {
carries.resize(carries.len() - 1, false);
}
carries.iter().for_each(|c| lkm.assert_ux::<16>(*c as u64));
(limbs, carries)
}
}

Expand Down
92 changes: 92 additions & 0 deletions ceno_zkvm/src/witness.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
use std::{
array,
cell::RefCell,
collections::HashMap,
mem::{self, MaybeUninit},
slice::ChunksMut,
sync::Arc,
};

use multilinear_extensions::util::create_uninit_vec;
use rayon::{
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
slice::ParallelSliceMut,
};
use thread_local::ThreadLocal;

use crate::structs::ROMType;

#[macro_export]
macro_rules! set_val {
Expand Down Expand Up @@ -51,3 +58,88 @@ impl<T: Sized + Sync + Clone + Send> RowMajorMatrix<T> {
.collect()
}
}

/// A lock-free thread safe struct to count logup multiplicity for each ROM type
/// Lock-free by thread-local such that each thread will only have its local copy
/// struct is cloneable, for internallly it use Arc so the clone will be low cost
#[derive(Clone, Default)]
#[allow(clippy::type_complexity)]
pub struct LkMultiplicity {
multiplicity: Arc<ThreadLocal<RefCell<[HashMap<u64, usize>; mem::variant_count::<ROMType>()]>>>,
}

#[allow(dead_code)]
impl LkMultiplicity {
#[inline(always)]
pub fn assert_ux<const C: usize>(&mut self, v: u64) {
match C {
16 => self.assert_u16(v),
8 => self.assert_byte(v),
5 => self.assert_u5(v),
_ => panic!("Unsupported bit range"),
}
}

fn assert_u5(&mut self, v: u64) {
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U5 as usize]
.entry(v)
.or_default()) += 1;
}

fn assert_u16(&mut self, v: u64) {
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U16 as usize]
.entry(v)
.or_default()) += 1;
}

fn assert_byte(&mut self, v: u64) {
let v = v * (1 << u8::BITS);
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U16 as usize]
.entry(v)
.or_default()) += 1;
}

/// merge result from multiple thread local to single result
fn into_finalize_result(self) -> [HashMap<u64, usize>; mem::variant_count::<ROMType>()] {
Arc::try_unwrap(self.multiplicity)
.unwrap()
.into_iter()
.fold(array::from_fn(|_| HashMap::new()), |mut x, y| {
// x.extend(y.get_mut().into_iter().map(|(k, v)| (k.clone(), v.clone())));
x.iter_mut()
.zip(y.borrow().iter())
.for_each(|(m1, m2)| m1.extend(m2.iter().map(|(k, v)| (*k, *v))));
x
})
}
}

#[cfg(test)]
mod tests {
use std::thread;

use crate::{structs::ROMType, witness::LkMultiplicity};

#[test]
fn test_lk_multiplicity_threads() {
let lkm = LkMultiplicity::default();
let thread_count = 20;
// each thread calling assert_byte once
for _ in 0..thread_count {
let mut lkm = lkm.clone();
thread::spawn(move || lkm.assert_byte(8u64)).join().unwrap();
}
let res = lkm.into_finalize_result();
// check multiplicity counts of assert_byte
assert_eq!(res[ROMType::U16 as usize][&(8 << 8)], thread_count);
}
}

0 comments on commit 96d357e

Please sign in to comment.