From a8174ae1221852901398db185dec35de931a3e96 Mon Sep 17 00:00:00 2001 From: naure Date: Sat, 14 Sep 2024 15:42:16 +0200 Subject: [PATCH] program-table (#210) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _Issue #110._ --------- Co-authored-by: Aurélien Nicolas --- ceno_emul/src/rv32im.rs | 4 +- ceno_zkvm/.gitignore | 1 + ceno_zkvm/examples/riscv_add.rs | 27 ++- ceno_zkvm/src/chip_handler/general.rs | 12 ++ ceno_zkvm/src/instructions/riscv/addsub.rs | 70 ++++--- ceno_zkvm/src/instructions/riscv/constants.rs | 5 + ceno_zkvm/src/scheme/mock_prover.rs | 49 ++++- ceno_zkvm/src/structs.rs | 14 +- ceno_zkvm/src/tables/mod.rs | 8 +- ceno_zkvm/src/tables/program.rs | 177 ++++++++++++++++++ ceno_zkvm/src/tables/range.rs | 7 +- ceno_zkvm/src/witness.rs | 17 +- 12 files changed, 339 insertions(+), 52 deletions(-) create mode 100644 ceno_zkvm/.gitignore create mode 100644 ceno_zkvm/src/tables/program.rs diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 0f30285fe..08f21302d 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -209,7 +209,7 @@ impl DecodedInstruction { self.rd } - pub fn func3(&self) -> u32 { + pub fn funct3(&self) -> u32 { self.func3 } @@ -221,7 +221,7 @@ impl DecodedInstruction { self.rs2 } - pub fn func7(&self) -> u32 { + pub fn funct7(&self) -> u32 { self.func7 } diff --git a/ceno_zkvm/.gitignore b/ceno_zkvm/.gitignore new file mode 100644 index 000000000..8db038a03 --- /dev/null +++ b/ceno_zkvm/.gitignore @@ -0,0 +1 @@ +tracing.folded diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index af9af772c..d42367172 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -1,7 +1,10 @@ use std::time::Instant; use ark_std::test_rng; -use ceno_zkvm::{instructions::riscv::addsub::AddInstruction, scheme::prover::ZKVMProver}; +use ceno_zkvm::{ + instructions::riscv::addsub::AddInstruction, scheme::prover::ZKVMProver, + tables::ProgramTableCircuit, +}; use clap::Parser; use const_env::from_env; @@ -90,11 +93,20 @@ fn main() { let mut zkvm_cs = ZKVMConstraintSystem::default(); let add_config = zkvm_cs.register_opcode_circuit::>(); let range_config = zkvm_cs.register_table_circuit::>(); + let prog_config = zkvm_cs.register_table_circuit::>(); let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); - zkvm_fixed_traces - .register_table_circuit::>(&zkvm_cs, range_config.clone()); + zkvm_fixed_traces.register_table_circuit::>( + &zkvm_cs, + range_config.clone(), + &(), + ); + zkvm_fixed_traces.register_table_circuit::>( + &zkvm_cs, + prog_config.clone(), + &PROGRAM_ADD_LOOP, + ); let pk = zkvm_cs .clone() @@ -136,7 +148,14 @@ fn main() { zkvm_witness.finalize_lk_multiplicities(); // assign table circuits zkvm_witness - .assign_table_circuit::>(&zkvm_cs, &range_config) + .assign_table_circuit::>(&zkvm_cs, &range_config, &()) + .unwrap(); + zkvm_witness + .assign_table_circuit::>( + &zkvm_cs, + &prog_config, + &PROGRAM_ADD_LOOP.len(), + ) .unwrap(); let timer = Instant::now(); diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index c1ec9b2a9..29c86e463 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -10,6 +10,7 @@ use crate::{ expression::{Expression, Fixed, ToExpr, WitIn}, instructions::riscv::config::ExprLtConfig, structs::ROMType, + tables::InsnRecord, }; use super::utils::rlc_chip_record; @@ -60,6 +61,17 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.lk_table_record(name_fn, rlc_record, multiplicity) } + /// Fetch an instruction at a given PC from the Program table. + pub fn lk_fetch(&mut self, record: &InsnRecord>) -> Result<(), ZKVMError> { + let rlc_record = { + let mut fields = vec![E::BaseField::from(ROMType::Instruction as u64).expr()]; + fields.extend_from_slice(record.as_slice()); + self.rlc_chip_record(fields) + }; + + self.cs.lk_record(|| "fetch", rlc_record) + } + pub fn read_record( &mut self, name_fn: N, diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 2bccb8576..993cc5b6a 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -6,7 +6,10 @@ use itertools::Itertools; use super::{ config::ExprLtConfig, - constants::{OPType, OpcodeType, RegUInt, PC_STEP_SIZE}, + constants::{ + OPType, OpcodeType, RegUInt, FUNCT3_ADD_SUB, FUNCT7_ADD, FUNCT7_SUB, OPCODE_OP, + PC_STEP_SIZE, + }, RIVInstruction, }; use crate::{ @@ -16,6 +19,7 @@ use crate::{ expression::{ToExpr, WitIn}, instructions::{riscv::config::ExprLtInput, Instruction}, set_val, + tables::InsnRecord, uint::UIntValue, witness::LkMultiplicity, }; @@ -93,6 +97,17 @@ fn add_sub_gadget( let rs2_id = circuit_builder.create_witin(|| "rs2_id")?; let rd_id = circuit_builder.create_witin(|| "rd_id")?; + // Fetch the instruction. + circuit_builder.lk_fetch(&InsnRecord::new( + pc.expr(), + OPCODE_OP.into(), + rd_id.expr(), + FUNCT3_ADD_SUB.into(), + rs1_id.expr(), + rs2_id.expr(), + (if IS_ADD { FUNCT7_ADD } else { FUNCT7_SUB }).into(), + ))?; + let prev_rs1_ts = circuit_builder.create_witin(|| "prev_rs1_ts")?; let prev_rs2_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?; let prev_rd_ts = circuit_builder.create_witin(|| "prev_rd_ts")?; @@ -145,6 +160,7 @@ fn add_sub_assignment( lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + lk_multiplicity.fetch(step.pc().before.0); set_val!(instance, config.pc, step.pc().before.0 as u64); set_val!(instance, config.ts, step.cycle()); let addend_1 = UIntValue::new_unchecked(step.rs2().unwrap().value); @@ -268,7 +284,7 @@ impl Instruction for SubInstruction { #[cfg(test)] mod test { - use ceno_emul::{Change, ReadOp, StepRecord, WriteOp}; + use ceno_emul::{Change, ReadOp, StepRecord, WriteOp, CENO_PLATFORM}; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; @@ -276,7 +292,7 @@ mod test { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{riscv::constants::PC_STEP_SIZE, Instruction}, - scheme::mock_prover::MockProver, + scheme::mock_prover::{MockProver, MOCK_PC_ADD, MOCK_PC_SUB, MOCK_PROGRAM}, }; use super::{AddInstruction, SubInstruction}; @@ -302,22 +318,20 @@ mod test { cb.cs.num_witin as usize, vec![StepRecord { cycle: 3, - pc: Change { - before: 2u32.into(), - after: (2u32 + PC_STEP_SIZE as u32).into(), - }, + pc: Change::new(MOCK_PC_ADD, MOCK_PC_ADD + PC_STEP_SIZE), + insn_code: MOCK_PROGRAM[0], rs1: Some(ReadOp { - addr: 2.into(), + addr: CENO_PLATFORM.register_vma(2).into(), value: 11u32, previous_cycle: 0, }), rs2: Some(ReadOp { - addr: 3.into(), + addr: CENO_PLATFORM.register_vma(3).into(), value: 0xfffffffeu32, previous_cycle: 0, }), rd: Some(WriteOp { - addr: 4.into(), + addr: CENO_PLATFORM.register_vma(4).into(), value: Change { before: 0u32, after: 11u32.wrapping_add(0xfffffffeu32), @@ -362,22 +376,20 @@ mod test { cb.cs.num_witin as usize, vec![StepRecord { cycle: 3, - pc: Change { - before: 2u32.into(), - after: (2u32 + PC_STEP_SIZE as u32).into(), - }, + pc: Change::new(MOCK_PC_ADD, MOCK_PC_ADD + PC_STEP_SIZE), + insn_code: MOCK_PROGRAM[0], rs1: Some(ReadOp { - addr: 2.into(), + addr: CENO_PLATFORM.register_vma(2).into(), value: u32::MAX - 1, previous_cycle: 0, }), rs2: Some(ReadOp { - addr: 3.into(), + addr: CENO_PLATFORM.register_vma(3).into(), value: u32::MAX - 1, previous_cycle: 0, }), rd: Some(WriteOp { - addr: 4.into(), + addr: CENO_PLATFORM.register_vma(4).into(), value: Change { before: 0u32, after: (u32::MAX - 1).wrapping_add(u32::MAX - 1), @@ -422,22 +434,20 @@ mod test { cb.cs.num_witin as usize, vec![StepRecord { cycle: 3, - pc: Change { - before: 2u32.into(), - after: (2u32 + PC_STEP_SIZE as u32).into(), - }, + pc: Change::new(MOCK_PC_SUB, MOCK_PC_SUB + PC_STEP_SIZE), + insn_code: MOCK_PROGRAM[1], rs1: Some(ReadOp { - addr: 2.into(), + addr: CENO_PLATFORM.register_vma(2).into(), value: 11u32, previous_cycle: 0, }), rs2: Some(ReadOp { - addr: 3.into(), + addr: CENO_PLATFORM.register_vma(3).into(), value: 2u32, previous_cycle: 0, }), rd: Some(WriteOp { - addr: 4.into(), + addr: CENO_PLATFORM.register_vma(4).into(), value: Change { before: 0u32, after: 11u32.wrapping_sub(2u32), @@ -482,22 +492,20 @@ mod test { cb.cs.num_witin as usize, vec![StepRecord { cycle: 3, - pc: Change { - before: 2u32.into(), - after: (2u32 + PC_STEP_SIZE as u32).into(), - }, + pc: Change::new(MOCK_PC_SUB, MOCK_PC_SUB + PC_STEP_SIZE), + insn_code: MOCK_PROGRAM[1], rs1: Some(ReadOp { - addr: 2.into(), + addr: CENO_PLATFORM.register_vma(2).into(), value: 3u32, previous_cycle: 0, }), rs2: Some(ReadOp { - addr: 3.into(), + addr: CENO_PLATFORM.register_vma(3).into(), value: 11u32, previous_cycle: 0, }), rd: Some(WriteOp { - addr: 4.into(), + addr: CENO_PLATFORM.register_vma(4).into(), value: Change { before: 0u32, after: 3u32.wrapping_sub(11u32), diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 37222be8b..813f2b340 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -4,6 +4,11 @@ use crate::uint::UInt; pub(crate) const PC_STEP_SIZE: usize = 4; +pub const OPCODE_OP: usize = 0x33; +pub const FUNCT3_ADD_SUB: usize = 0; +pub const FUNCT7_ADD: usize = 0; +pub const FUNCT7_SUB: usize = 0x20; + #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, Copy)] pub enum OPType { diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 55386ce8f..8dea82b70 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1,10 +1,13 @@ use super::utils::{eval_by_expr, wit_infer_by_expr}; use crate::{ - circuit_builder::CircuitBuilder, + circuit_builder::{CircuitBuilder, ConstraintSystem}, expression::Expression, + scheme::utils::eval_by_expr_with_fixed, structs::{ROMType, WitnessId}, + tables::{ProgramTableCircuit, TableCircuit}, }; use ark_std::test_rng; +use ceno_emul::{ByteAddr, CENO_PLATFORM}; use ff_ext::ExtensionField; use generic_static::StaticTypeMap; use goldilocks::SmallField; @@ -12,6 +15,18 @@ use itertools::Itertools; use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; use std::{collections::HashSet, hash::Hash, marker::PhantomData, ops::Neg, sync::OnceLock}; +/// The program baked in the MockProver. +/// TODO: Make this a parameter? +pub const MOCK_PROGRAM: &[u32] = &[ + // add x4, x2, x3 + 0x00 << 25 | 3 << 20 | 2 << 15 | 4 << 7 | 0x33, + // sub x4, x2, x3 + 0x20 << 25 | 3 << 20 | 2 << 15 | 4 << 7 | 0x33, +]; +// Addresses of particular instructions in the mock program. +pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start()); +pub const MOCK_PC_SUB: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 4); + #[allow(clippy::enum_variant_names)] #[derive(Debug, PartialEq, Clone)] pub(crate) enum MockProverError { @@ -297,6 +312,29 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> } } + fn load_program_table( + t_vec: &mut Vec>, + _cb: &CircuitBuilder, + challenge: [E; 2], + ) { + let mut cs = ConstraintSystem::::new(|| "mock_program"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = ProgramTableCircuit::construct_circuit(&mut cb).unwrap(); + let fixed = + ProgramTableCircuit::::generate_fixed_traces(&config, cs.num_fixed, MOCK_PROGRAM); + for table_expr in &cs.lk_table_expressions { + for row in fixed.iter_rows() { + // TODO: Find a better way to obtain the row content. + let row = row + .iter() + .map(|v| unsafe { v.clone().assume_init() }.into()) + .collect::>(); + let rlc_record = eval_by_expr_with_fixed(&row, &[], &challenge, &table_expr.values); + t_vec.push(rlc_record.to_repr().as_ref().to_vec()); + } + } + } + let mut table_vec = vec![]; // TODO load more tables here load_u5_table(&mut table_vec, cb, challenge); @@ -304,6 +342,7 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> load_lt_table(&mut table_vec, cb, challenge); load_and_table(&mut table_vec, cb, challenge); load_ltu_table(&mut table_vec, cb, challenge); + load_program_table(&mut table_vec, cb, challenge); HashSet::from_iter(table_vec) } @@ -336,7 +375,7 @@ fn load_once_tables( impl<'a, E: ExtensionField + Hash> MockProver { pub fn run_with_challenge( - cb: &mut CircuitBuilder, + cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], challenge: [E; 2], ) -> Result<(), Vec>> { @@ -344,14 +383,14 @@ impl<'a, E: ExtensionField + Hash> MockProver { } pub fn run( - cb: &mut CircuitBuilder, + cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], ) -> Result<(), Vec>> { Self::run_maybe_challenge(cb, wits_in, None) } fn run_maybe_challenge( - cb: &mut CircuitBuilder, + cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], challenge: Option<[E; 2]>, ) -> Result<(), Vec>> { @@ -451,7 +490,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { } pub fn assert_satisfied( - cb: &mut CircuitBuilder, + cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], challenge: Option<[E; 2]>, ) { diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index e2bbe0b56..1a522df88 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -41,10 +41,11 @@ pub type ChallengeId = u16; #[derive(Debug)] pub enum ROMType { - U5 = 0, // 2^5 = 32 - U16, // 2^16 = 65,536 - And, // a ^ b where a, b are bytes - Ltu, // a <(usign) b where a, b are bytes + U5 = 0, // 2^5 = 32 + U16, // 2^16 = 65,536 + And, // a ^ b where a, b are bytes + Ltu, // a <(usign) b where a, b are bytes + Instruction, // Decoded instruction from the fixed program. } #[derive(Clone, Debug, Copy)] @@ -155,13 +156,14 @@ impl ZKVMFixedTraces { &mut self, cs: &ZKVMConstraintSystem, config: TC::TableConfig, + input: &TC::FixedInput, ) { let cs = cs.get_cs(&TC::name()).expect("cs not found"); assert!( self.circuit_fixed_traces .insert( TC::name(), - Some(TC::generate_fixed_traces(&config, cs.num_fixed,)), + Some(TC::generate_fixed_traces(&config, cs.num_fixed, input)), ) .is_none() ); @@ -227,6 +229,7 @@ impl ZKVMWitnesses { &mut self, cs: &ZKVMConstraintSystem, config: &TC::TableConfig, + input: &TC::WitnessInput, ) -> Result<(), ZKVMError> { assert!(self.combined_lk_mlt.is_some()); @@ -235,6 +238,7 @@ impl ZKVMWitnesses { config, cs.num_witin as usize, self.combined_lk_mlt.as_ref().unwrap(), + input, )?; assert!(self.witnesses.insert(TC::name(), witness).is_none()); diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index a66094952..7c9a29ed4 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -5,9 +5,13 @@ use std::collections::HashMap; mod range; pub use range::RangeTableCircuit; +mod program; +pub use program::{InsnRecord, ProgramTableCircuit}; + pub trait TableCircuit { type TableConfig: Send + Sync; - type Input: Send + Sync; + type FixedInput: Send + Sync + ?Sized; + type WitnessInput: Send + Sync + ?Sized; fn name() -> String; @@ -18,11 +22,13 @@ pub trait TableCircuit { fn generate_fixed_traces( config: &Self::TableConfig, num_fixed: usize, + input: &Self::FixedInput, ) -> RowMajorMatrix; fn assign_instances( config: &Self::TableConfig, num_witin: usize, multiplicity: &[HashMap], + input: &Self::WitnessInput, ) -> Result, ZKVMError>; } diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs new file mode 100644 index 000000000..6acb8991c --- /dev/null +++ b/ceno_zkvm/src/tables/program.rs @@ -0,0 +1,177 @@ +use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, Fixed, ToExpr, WitIn}, + scheme::constants::MIN_PAR_SIZE, + set_fixed_val, set_val, + structs::ROMType, + tables::TableCircuit, + witness::RowMajorMatrix, +}; +use ceno_emul::{DecodedInstruction, Word, CENO_PLATFORM, WORD_SIZE}; +use ff_ext::ExtensionField; +use itertools::Itertools; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; + +#[derive(Clone, Debug)] +pub struct InsnRecord([T; 7]); + +impl InsnRecord { + pub fn new(pc: T, opcode: T, rd: T, funct3: T, rs1: T, rs2: T, imm_or_funct7: T) -> Self { + InsnRecord([pc, opcode, rd, funct3, rs1, rs2, imm_or_funct7]) + } + + pub fn as_slice(&self) -> &[T] { + &self.0 + } + + pub fn pc(&self) -> &T { + &self.0[0] + } + + pub fn opcode(&self) -> &T { + &self.0[1] + } + + pub fn rd(&self) -> &T { + &self.0[2] + } + + pub fn funct3(&self) -> &T { + &self.0[3] + } + + pub fn rs1(&self) -> &T { + &self.0[4] + } + + pub fn rs2(&self) -> &T { + &self.0[5] + } + + /// The complete immediate value, for instruction types I/S/B/U/J. + /// Otherwise, the field funct7 of R-Type instructions. + pub fn imm_or_funct7(&self) -> &T { + &self.0[6] + } +} + +impl InsnRecord { + fn from_decoded(pc: u32, insn: &DecodedInstruction) -> Self { + InsnRecord::new( + pc, + insn.opcode(), + insn.rd(), + insn.funct3(), + insn.rs1(), + insn.rs2(), + insn.funct7(), // TODO: get immediate for all types. + ) + } +} + +#[derive(Clone, Debug)] +pub struct ProgramTableConfig { + /// The fixed table of instruction records. + record: InsnRecord, + + /// Multiplicity of the record - how many times an instruction is visited. + mlt: WitIn, +} + +pub struct ProgramTableCircuit(PhantomData); + +impl TableCircuit for ProgramTableCircuit { + type TableConfig = ProgramTableConfig; + type FixedInput = [u32]; + type WitnessInput = usize; + + fn name() -> String { + "PROGRAM".into() + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let record = InsnRecord([ + cb.create_fixed(|| "pc")?, + cb.create_fixed(|| "opcode")?, + cb.create_fixed(|| "rd")?, + cb.create_fixed(|| "funct3")?, + cb.create_fixed(|| "rs1")?, + cb.create_fixed(|| "rs2")?, + cb.create_fixed(|| "imm_or_funct7")?, + ]); + + let mlt = cb.create_witin(|| "mlt")?; + + let record_exprs = { + let mut fields = vec![E::BaseField::from(ROMType::Instruction as u64).expr()]; + fields.extend( + record + .as_slice() + .iter() + .map(|f| Expression::Fixed(f.clone())), + ); + cb.rlc_chip_record(fields) + }; + + cb.lk_table_record(|| "prog table", record_exprs, mlt.expr())?; + + Ok(ProgramTableConfig { record, mlt }) + } + + fn generate_fixed_traces( + config: &ProgramTableConfig, + num_fixed: usize, + program: &[Word], + ) -> RowMajorMatrix { + // TODO: get bytecode of the program. + let num_instructions = program.len(); + let pc_start = CENO_PLATFORM.pc_start(); + + let mut fixed = RowMajorMatrix::::new(num_instructions, num_fixed); + + fixed + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip((0..num_instructions).into_par_iter()) + .for_each(|(row, i)| { + let pc = pc_start + (i * WORD_SIZE) as u32; + let insn = DecodedInstruction::new(program[i]); + let values = InsnRecord::from_decoded(pc, &insn); + + for (col, val) in config.record.as_slice().iter().zip_eq(values.as_slice()) { + set_fixed_val!(row, *col, E::BaseField::from(*val as u64)); + } + }); + + fixed + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + multiplicity: &[HashMap], + num_instructions: &usize, + ) -> Result, ZKVMError> { + let multiplicity = &multiplicity[ROMType::Instruction as usize]; + + let mut prog_mlt = vec![0_usize; *num_instructions]; + for (pc, mlt) in multiplicity { + let i = (*pc as usize - CENO_PLATFORM.pc_start() as usize) / WORD_SIZE; + prog_mlt[i] = *mlt; + } + + let mut witness = RowMajorMatrix::::new(prog_mlt.len(), num_witin); + witness + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip(prog_mlt.into_par_iter()) + .for_each(|(row, mlt)| { + set_val!(row, config.mlt, E::BaseField::from(mlt as u64)); + }); + + Ok(witness) + } +} diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index 293077f46..2b195ea63 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -24,7 +24,8 @@ pub struct RangeTableCircuit(PhantomData); impl TableCircuit for RangeTableCircuit { type TableConfig = RangeTableConfig; - type Input = u64; + type FixedInput = (); + type WitnessInput = (); fn name() -> String { "RANGE".into() @@ -47,6 +48,7 @@ impl TableCircuit for RangeTableCircuit { fn generate_fixed_traces( config: &RangeTableConfig, num_fixed: usize, + _input: &(), ) -> RowMajorMatrix { let num_u16s = 1 << 16; let mut fixed = RowMajorMatrix::::new(num_u16s, num_fixed); @@ -55,7 +57,7 @@ impl TableCircuit for RangeTableCircuit { .with_min_len(MIN_PAR_SIZE) .zip((0..num_u16s).into_par_iter()) .for_each(|(row, i)| { - set_fixed_val!(row, config.u16_tbl.0, E::BaseField::from(i as u64)); + set_fixed_val!(row, config.u16_tbl, E::BaseField::from(i as u64)); }); fixed @@ -65,6 +67,7 @@ impl TableCircuit for RangeTableCircuit { config: &Self::TableConfig, num_witin: usize, multiplicity: &[HashMap], + _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[ROMType::U16 as usize]; let mut u16_mlt = vec![0; 1 << RANGE_CHIP_BIT_WIDTH]; diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index a77d9c4c3..263b6645c 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -3,7 +3,7 @@ use std::{ cell::RefCell, collections::HashMap, mem::{self, MaybeUninit}, - slice::ChunksMut, + slice::{Chunks, ChunksMut}, sync::Arc, }; @@ -26,7 +26,7 @@ macro_rules! set_val { #[macro_export] macro_rules! set_fixed_val { ($ins:ident, $field:expr, $val:expr) => { - $ins[$field as usize] = MaybeUninit::new($val); + $ins[$field.0] = MaybeUninit::new($val); }; } @@ -52,6 +52,10 @@ impl RowMajorMatrix { self.values.len() / self.num_col - self.num_padding_rows } + pub fn iter_rows(&self) -> Chunks> { + self.values.chunks(self.num_col) + } + pub fn iter_mut(&mut self) -> ChunksMut> { self.values.chunks_mut(self.num_col) } @@ -142,6 +146,15 @@ impl LkMultiplicity { .or_default()) += 1; } + pub fn fetch(&mut self, pc: u32) { + let multiplicity = self + .multiplicity + .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); + (*multiplicity.borrow_mut()[ROMType::Instruction as usize] + .entry(pc as u64) + .or_default()) += 1; + } + /// merge result from multiple thread local to single result pub fn into_finalize_result(self) -> [HashMap; mem::variant_count::()] { Arc::try_unwrap(self.multiplicity)