From 1637f4c122f3476de0788870362616818469a615 Mon Sep 17 00:00:00 2001 From: naure Date: Tue, 17 Sep 2024 12:42:07 +0200 Subject: [PATCH 1/3] Common circuit for R-Instructions (#231) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _Issues #121, #133, #134, #123, …_ Extract generic instruction handling out of the Add/Sub circuit. --------- Co-authored-by: Aurélien Nicolas --- Cargo.lock | 2 + ceno_emul/Cargo.toml | 2 + ceno_emul/src/lib.rs | 2 +- ceno_emul/src/rv32im.rs | 30 +- ceno_zkvm/src/chip_handler.rs | 10 +- ceno_zkvm/src/chip_handler/register.rs | 10 +- ceno_zkvm/src/instructions/riscv.rs | 10 +- ceno_zkvm/src/instructions/riscv/addsub.rs | 360 ++++++------------ ceno_zkvm/src/instructions/riscv/blt.rs | 7 +- ceno_zkvm/src/instructions/riscv/constants.rs | 29 -- ceno_zkvm/src/instructions/riscv/r_insn.rs | 176 +++++++++ 11 files changed, 332 insertions(+), 306 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/r_insn.rs diff --git a/Cargo.lock b/Cargo.lock index 3c030bf42..172cfec85 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -310,6 +310,8 @@ version = "0.1.0" dependencies = [ "anyhow", "elf", + "strum 0.25.0", + "strum_macros 0.25.3", "tracing", ] diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index ed1096822..81730c113 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -6,6 +6,8 @@ license.workspace = true [dependencies] anyhow = { version = "1.0", default-features = false } +strum = "0.25.0" +strum_macros = "0.25.3" tracing = { version = "0.1", default-features = false, features = [ "attributes", ] } diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index c5359e442..8a21a0ac2 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -11,7 +11,7 @@ mod vm_state; pub use vm_state::VMState; mod rv32im; -pub use rv32im::{DecodedInstruction, EmuContext, InsnCategory, InsnKind}; +pub use rv32im::{DecodedInstruction, EmuContext, InsnCodes, InsnCategory, InsnKind}; mod elf; pub use elf::Program; diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 08f21302d..c173f1f21 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -16,6 +16,7 @@ use anyhow::{anyhow, Result}; use std::sync::OnceLock; +use strum_macros::EnumIter; use super::addr::{ByteAddr, RegIdx, Word, WordAddr, WORD_SIZE}; @@ -121,7 +122,7 @@ pub enum InsnCategory { Invalid, } -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq, EnumIter)] #[allow(clippy::upper_case_acronyms)] pub enum InsnKind { INVALID, @@ -174,8 +175,14 @@ pub enum InsnKind { MRET, } +impl InsnKind { + pub const fn codes(self) -> InsnCodes { + RV32IM_ISA[self as usize] + } +} + #[derive(Clone, Copy, Debug)] -struct FastDecodeEntry { +pub struct InsnCodes { pub kind: InsnKind, category: InsnCategory, pub opcode: u32, @@ -269,8 +276,8 @@ const fn insn( opcode: u32, func3: i32, func7: i32, -) -> FastDecodeEntry { - FastDecodeEntry { +) -> InsnCodes { + InsnCodes { kind, category, opcode, @@ -279,7 +286,7 @@ const fn insn( } } -type InstructionTable = [FastDecodeEntry; 48]; +type InstructionTable = [InsnCodes; 48]; type FastInstructionTable = [u8; 1 << 10]; const RV32IM_ISA: InstructionTable = [ @@ -333,6 +340,15 @@ const RV32IM_ISA: InstructionTable = [ insn(InsnKind::MRET, InsnCategory::System, 0x73, 0x0, 0x18), ]; +#[cfg(test)] +#[test] +fn test_isa_table() { + use strum::IntoEnumIterator; + for kind in InsnKind::iter() { + assert_eq!(kind.codes().kind, kind); + } +} + // RISC-V instruction are determined by 3 parts: // - Opcode: 7 bits // - Func3: 3 bits @@ -373,7 +389,7 @@ impl FastDecodeTable { ((op_high << 5) | (func72bits << 3) | func3) as usize } - fn add_insn(table: &mut FastInstructionTable, insn: &FastDecodeEntry, isa_idx: usize) { + fn add_insn(table: &mut FastInstructionTable, insn: &InsnCodes, isa_idx: usize) { let op_high = insn.opcode >> 2; if (insn.func3 as i32) < 0 { for f3 in 0..8 { @@ -392,7 +408,7 @@ impl FastDecodeTable { } } - fn lookup(&self, decoded: &DecodedInstruction) -> FastDecodeEntry { + fn lookup(&self, decoded: &DecodedInstruction) -> InsnCodes { let isa_idx = self.table[Self::map10(decoded.opcode, decoded.func3, decoded.func7)]; RV32IM_ISA[isa_idx as usize] } diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index b3ac7cd52..76739a317 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -18,23 +18,23 @@ pub trait GlobalStateRegisterMachineChipOperations { } pub trait RegisterChipOperations, N: FnOnce() -> NR> { - fn register_read>>>( + fn register_read( &mut self, name_fn: N, register_id: &WitIn, prev_ts: Expression, ts: Expression, - values: &V, + values: &impl ToExpr>>, ) -> Result<(Expression, ExprLtConfig), ZKVMError>; #[allow(clippy::too_many_arguments)] - fn register_write>>>( + fn register_write( &mut self, name_fn: N, register_id: &WitIn, prev_ts: Expression, ts: Expression, - prev_values: &V, - values: &V, + prev_values: &impl ToExpr>>, + values: &impl ToExpr>>, ) -> Result<(Expression, ExprLtConfig), ZKVMError>; } diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index bdc978710..7efe3bb24 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -13,13 +13,13 @@ use super::RegisterChipOperations; impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOperations for CircuitBuilder<'a, E> { - fn register_read>>>( + fn register_read( &mut self, name_fn: N, register_id: &WitIn, prev_ts: Expression, ts: Expression, - values: &V, + values: &impl ToExpr>>, ) -> Result<(Expression, ExprLtConfig), ZKVMError> { self.namespace(name_fn, |cb| { // READ (a, v, t) @@ -58,14 +58,14 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe }) } - fn register_write>>>( + fn register_write( &mut self, name_fn: N, register_id: &WitIn, prev_ts: Expression, ts: Expression, - prev_values: &V, - values: &V, + prev_values: &impl ToExpr>>, + values: &impl ToExpr>>, ) -> Result<(Expression, ExprLtConfig), ZKVMError> { self.namespace(name_fn, |cb| { // READ (a, v, t) diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index bef83bd52..3d66f101a 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -1,16 +1,14 @@ -use constants::OpcodeType; -use ff_ext::ExtensionField; - -use super::Instruction; +use ceno_emul::InsnKind; pub mod addsub; pub mod blt; pub mod config; pub mod constants; +mod r_insn; #[cfg(test)] mod test; -pub trait RIVInstruction: Instruction { - const OPCODE_TYPE: OpcodeType; +pub trait RIVInstruction { + const INST_KIND: InsnKind; } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 5525fdfec..4d08945ad 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -1,284 +1,144 @@ use std::marker::PhantomData; -use ceno_emul::StepRecord; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use itertools::Itertools; -use super::{ - config::ExprLtConfig, - constants::{ - OPType, OpcodeType, RegUInt, FUNCT3_ADD_SUB, FUNCT7_ADD, FUNCT7_SUB, OPCODE_OP, - PC_STEP_SIZE, - }, - RIVInstruction, -}; +use super::{constants::RegUInt, r_insn::RInstructionConfig, RIVInstruction}; use crate::{ - chip_handler::{GlobalStateRegisterMachineChipOperations, RegisterChipOperations}, - circuit_builder::CircuitBuilder, - error::ZKVMError, - expression::{ToExpr, WitIn}, - instructions::{riscv::config::ExprLtInput, Instruction}, - set_val, - tables::InsnRecord, - uint::UIntValue, + circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, uint::UIntValue, witness::LkMultiplicity, }; use core::mem::MaybeUninit; -pub struct AddInstruction(PhantomData); -pub struct SubInstruction(PhantomData); - +/// This config handles R-Instructions that represent registers values as 2 * u16. #[derive(Debug)] -pub struct InstructionConfig { - pub pc: WitIn, - pub ts: WitIn, - pub prev_rd_value: RegUInt, - pub addend_0: RegUInt, - pub addend_1: RegUInt, - pub outcome: RegUInt, - pub rs1_id: WitIn, - pub rs2_id: WitIn, - pub rd_id: WitIn, - pub prev_rs1_ts: WitIn, - pub prev_rs2_ts: WitIn, - pub prev_rd_ts: WitIn, - pub lt_rs1_cfg: ExprLtConfig, - pub lt_rs2_cfg: ExprLtConfig, - pub lt_prev_ts_cfg: ExprLtConfig, - phantom: PhantomData, -} +pub struct ArithConfig { + r_insn: RInstructionConfig, -impl RIVInstruction for AddInstruction { - const OPCODE_TYPE: OpcodeType = OpcodeType::RType(OPType::Op, 0x000, 0x0000000); + rs1_read: RegUInt, + rs2_read: RegUInt, + rd_written: RegUInt, } -impl RIVInstruction for SubInstruction { - const OPCODE_TYPE: OpcodeType = OpcodeType::RType(OPType::Op, 0x000, 0x0100000); -} - -fn add_sub_gadget( - circuit_builder: &mut CircuitBuilder, -) -> Result, ZKVMError> { - let pc = circuit_builder.create_witin(|| "pc")?; - let cur_ts = circuit_builder.create_witin(|| "cur_ts")?; - - // state in - circuit_builder.state_in(pc.expr(), cur_ts.expr())?; +pub struct ArithInstruction(PhantomData<(E, I)>); - let next_pc = pc.expr() + PC_STEP_SIZE.into(); - - // Execution result = addend0 + addend1, with carry. - let prev_rd_value = RegUInt::new_unchecked(|| "prev_rd_value", circuit_builder)?; - - let (addend_0, addend_1, outcome) = if IS_ADD { - // outcome = addend_0 + addend_1 - let addend_0 = RegUInt::new_unchecked(|| "addend_0", circuit_builder)?; - let addend_1 = RegUInt::new_unchecked(|| "addend_1", circuit_builder)?; - ( - addend_0.clone(), - addend_1.clone(), - addend_0.add(|| "outcome", circuit_builder, &addend_1, true)?, - ) - } else { - // outcome + addend_1 = addend_0 - // outcome is the new value to be updated in register so we need to constrain its range - let outcome = RegUInt::new(|| "outcome", circuit_builder)?; - let addend_1 = RegUInt::new_unchecked(|| "addend_1", circuit_builder)?; - ( - addend_1 - .clone() - .add(|| "addend_0", circuit_builder, &outcome.clone(), true)?, - addend_1, - outcome, - ) - }; - - let rs1_id = circuit_builder.create_witin(|| "rs1_id")?; - let rs2_id = circuit_builder.create_witin(|| "rs2_id")?; - let rd_id = circuit_builder.create_witin(|| "rd_id")?; - - // Fetch the instruction. - circuit_builder.lk_fetch(&InsnRecord::new( - pc.expr(), - OPCODE_OP.into(), - rd_id.expr(), - FUNCT3_ADD_SUB.into(), - rs1_id.expr(), - rs2_id.expr(), - (if IS_ADD { FUNCT7_ADD } else { FUNCT7_SUB }).into(), - ))?; - - let prev_rs1_ts = circuit_builder.create_witin(|| "prev_rs1_ts")?; - let prev_rs2_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?; - let prev_rd_ts = circuit_builder.create_witin(|| "prev_rd_ts")?; - - let (ts, lt_rs1_cfg) = circuit_builder.register_read( - || "read_rs1", - &rs1_id, - prev_rs1_ts.expr(), - cur_ts.expr(), - &addend_0, - )?; - let (ts, lt_rs2_cfg) = - circuit_builder.register_read(|| "read_rs2", &rs2_id, prev_rs2_ts.expr(), ts, &addend_1)?; - - let (ts, lt_prev_ts_cfg) = circuit_builder.register_write( - || "write_rd", - &rd_id, - prev_rd_ts.expr(), - ts, - &prev_rd_value, - &outcome, - )?; - - let next_ts = ts + 1.into(); - circuit_builder.state_out(next_pc, next_ts)?; - - Ok(InstructionConfig { - pc, - ts: cur_ts, - prev_rd_value, - addend_0, - addend_1, - outcome, - rs1_id, - rs2_id, - rd_id, - prev_rs1_ts, - prev_rs2_ts, - prev_rd_ts, - lt_rs1_cfg, - lt_rs2_cfg, - lt_prev_ts_cfg, - phantom: PhantomData, - }) +pub struct AddOp; +impl RIVInstruction for AddOp { + const INST_KIND: InsnKind = InsnKind::ADD; } +pub type AddInstruction = ArithInstruction; -fn add_sub_assignment( - config: &InstructionConfig, - instance: &mut [MaybeUninit], - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, -) -> Result<(), ZKVMError> { - lk_multiplicity.fetch(step.pc().before.0); - set_val!(instance, config.pc, step.pc().before.0 as u64); - set_val!(instance, config.ts, step.cycle()); - let addend_1 = UIntValue::new_unchecked(step.rs2().unwrap().value); - let rd_prev = UIntValue::new_unchecked(step.rd().unwrap().value.before); - config - .prev_rd_value - .assign_limbs(instance, rd_prev.u16_fields()); - - config - .addend_1 - .assign_limbs(instance, addend_1.u16_fields()); - - if IS_ADD { - // addend_0 + addend_1 = outcome - let addend_0 = UIntValue::new_unchecked(step.rs1().unwrap().value); - config - .addend_0 - .assign_limbs(instance, addend_0.u16_fields()); - let (_, outcome_carries) = addend_0.add(&addend_1, lk_multiplicity, true); - config.outcome.assign_carries( - instance, - outcome_carries - .into_iter() - .map(|carry| E::BaseField::from(carry as u64)) - .collect_vec(), - ); - } else { - // addend_0 = outcome + addend_1 - let outcome = UIntValue::new(step.rd().unwrap().value.after, lk_multiplicity); - config.outcome.assign_limbs(instance, outcome.u16_fields()); - let (_, addend_0_carries) = addend_1.add(&outcome, lk_multiplicity, true); - config.addend_0.assign_carries( - instance, - addend_0_carries - .into_iter() - .map(|carry| E::BaseField::from(carry as u64)) - .collect_vec(), - ); - } - set_val!(instance, config.rs1_id, step.insn().rs1() as u64); - set_val!(instance, config.rs2_id, step.insn().rs2() as u64); - set_val!(instance, config.rd_id, step.insn().rd() as u64); - ExprLtInput { - lhs: step.rs1().unwrap().previous_cycle, - rhs: step.cycle(), - } - .assign(instance, &config.lt_rs1_cfg, lk_multiplicity); - ExprLtInput { - lhs: step.rs2().unwrap().previous_cycle, - rhs: step.cycle() + 1, - } - .assign(instance, &config.lt_rs2_cfg, lk_multiplicity); - ExprLtInput { - lhs: step.rd().unwrap().previous_cycle, - rhs: step.cycle() + 2, - } - .assign(instance, &config.lt_prev_ts_cfg, lk_multiplicity); - set_val!( - instance, - config.prev_rs1_ts, - step.rs1().unwrap().previous_cycle - ); - set_val!( - instance, - config.prev_rs2_ts, - step.rs2().unwrap().previous_cycle - ); - set_val!( - instance, - config.prev_rd_ts, - step.rd().unwrap().previous_cycle - ); - Ok(()) +pub struct SubOp; +impl RIVInstruction for SubOp { + const INST_KIND: InsnKind = InsnKind::SUB; } +pub type SubInstruction = ArithInstruction; -impl Instruction for AddInstruction { - // const NAME: &'static str = "ADD"; - fn name() -> String { - "ADD".into() - } - type InstructionConfig = InstructionConfig; - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - ) -> Result, ZKVMError> { - add_sub_gadget::(circuit_builder) - } +impl Instruction for ArithInstruction { + type InstructionConfig = ArithConfig; - #[allow(clippy::option_map_unit_fn)] - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [MaybeUninit], - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - add_sub_assignment::<_, true>(config, instance, lk_multiplicity, step) - } -} - -impl Instruction for SubInstruction { - // const NAME: &'static str = "ADD"; fn name() -> String { - "SUB".into() + format!("{:?}", I::INST_KIND) } - type InstructionConfig = InstructionConfig; + fn construct_circuit( circuit_builder: &mut CircuitBuilder, - ) -> Result, ZKVMError> { - add_sub_gadget::(circuit_builder) + ) -> Result { + let (rs1_read, rs2_read, rd_written) = match I::INST_KIND { + InsnKind::ADD => { + // rd_written = rs1_read + rs2_read + let rs1_read = RegUInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rs2_read = RegUInt::new_unchecked(|| "rs2_read", circuit_builder)?; + let rd_written = rs1_read.add(|| "rd_written", circuit_builder, &rs2_read, true)?; + (rs1_read, rs2_read, rd_written) + } + + InsnKind::SUB => { + // rd_written + rs2_read = rs1_read + // rd_written is the new value to be updated in register so we need to constrain its range. + let rd_written = RegUInt::new(|| "rd_written", circuit_builder)?; + let rs2_read = RegUInt::new_unchecked(|| "rs2_read", circuit_builder)?; + let rs1_read = rs2_read.clone().add( + || "rs1_read", + circuit_builder, + &rd_written.clone(), + true, + )?; + (rs1_read, rs2_read, rd_written) + } + + _ => unreachable!("Unsupported instruction kind"), + }; + + let r_insn = RInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + &rs1_read, + &rs2_read, + &rd_written, + )?; + + Ok(ArithConfig { + r_insn, + rs1_read, + rs2_read, + rd_written, + }) } - #[allow(clippy::option_map_unit_fn)] fn assign_instance( config: &Self::InstructionConfig, - instance: &mut [MaybeUninit], + instance: &mut [MaybeUninit<::BaseField>], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - add_sub_assignment::<_, false>(config, instance, lk_multiplicity, step) + config + .r_insn + .assign_instance(instance, lk_multiplicity, step)?; + + let rs2_read = UIntValue::new_unchecked(step.rs2().unwrap().value); + config + .rs2_read + .assign_limbs(instance, rs2_read.u16_fields()); + + match I::INST_KIND { + InsnKind::ADD => { + // rs1_read + rs2_read = rd_written + let rs1_read = UIntValue::new_unchecked(step.rs1().unwrap().value); + config + .rs1_read + .assign_limbs(instance, rs1_read.u16_fields()); + let (_, outcome_carries) = rs1_read.add(&rs2_read, lk_multiplicity, true); + config.rd_written.assign_carries( + instance, + outcome_carries + .into_iter() + .map(|carry| E::BaseField::from(carry as u64)) + .collect_vec(), + ); + } + + InsnKind::SUB => { + // rs1_read = rd_written + rs2_read + let rd_written = UIntValue::new(step.rd().unwrap().value.after, lk_multiplicity); + config + .rd_written + .assign_limbs(instance, rd_written.u16_fields()); + let (_, addend_0_carries) = rs2_read.add(&rd_written, lk_multiplicity, true); + config.rs1_read.assign_carries( + instance, + addend_0_carries + .into_iter() + .map(|carry| E::BaseField::from(carry as u64)) + .collect_vec(), + ); + } + + _ => unreachable!("Unsupported instruction kind"), + }; + + Ok(()) } } diff --git a/ceno_zkvm/src/instructions/riscv/blt.rs b/ceno_zkvm/src/instructions/riscv/blt.rs index 30f4cc333..342bb6090 100644 --- a/ceno_zkvm/src/instructions/riscv/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/blt.rs @@ -1,3 +1,4 @@ +use ceno_emul::InsnKind; use goldilocks::SmallField; use std::mem::MaybeUninit; @@ -20,7 +21,7 @@ use crate::{ use super::{ config::ExprLtConfig, - constants::{OPType, OpcodeType, RegUInt, RegUInt8, PC_STEP_SIZE}, + constants::{RegUInt, RegUInt8, PC_STEP_SIZE}, RIVInstruction, }; @@ -140,8 +141,8 @@ impl BltInput { } } -impl RIVInstruction for BltInstruction { - const OPCODE_TYPE: OpcodeType = OpcodeType::BType(OPType::Branch, 0x004); +impl RIVInstruction for BltInstruction { + const INST_KIND: InsnKind = InsnKind::BLT; } /// if (rs1 < rs2) PC += sext(imm) diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index e8c12cdb0..fe082e6b5 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -1,35 +1,6 @@ -use std::fmt; - use crate::uint::UInt; pub use ceno_emul::PC_STEP_SIZE; -pub const OPCODE_OP: usize = 0x33; -pub const FUNCT3_ADD_SUB: usize = 0; -pub const FUNCT7_ADD: usize = 0; -pub const FUNCT7_SUB: usize = 0x20; - -#[allow(clippy::upper_case_acronyms)] -#[derive(Debug, Clone, Copy)] -pub enum OPType { - Op, - Opimm, - Jal, - Jalr, - Branch, -} - -#[derive(Debug, Clone, Copy)] -pub enum OpcodeType { - RType(OPType, usize, usize), // (OP, func3, func7) - BType(OPType, usize), // (OP, func3) -} - -impl fmt::Display for OpcodeType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:?}", self) - } -} - pub const VALUE_BIT_WIDTH: usize = 16; #[cfg(feature = "riv32")] diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs new file mode 100644 index 000000000..61baa02b4 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -0,0 +1,176 @@ +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::ExtensionField; + +use super::{ + config::ExprLtConfig, + constants::{RegUInt, PC_STEP_SIZE}, +}; +use crate::{ + chip_handler::{GlobalStateRegisterMachineChipOperations, RegisterChipOperations}, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + instructions::riscv::config::ExprLtInput, + set_val, + tables::InsnRecord, + uint::UIntValue, + witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; + +/// This config handles the common part of R-type instructions: +/// - PC, cycle, fetch. +/// - Registers read and write. +/// +/// It does not witness of the register values, nor the actual function (e.g. add, sub, etc). +#[derive(Debug)] +pub struct RInstructionConfig { + pc: WitIn, + ts: WitIn, + rs1_id: WitIn, + rs2_id: WitIn, + rd_id: WitIn, + prev_rd_value: RegUInt, + prev_rs1_ts: WitIn, + prev_rs2_ts: WitIn, + prev_rd_ts: WitIn, + lt_rs1_cfg: ExprLtConfig, + lt_rs2_cfg: ExprLtConfig, + lt_prev_ts_cfg: ExprLtConfig, +} + +impl RInstructionConfig { + pub fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + insn_kind: InsnKind, + rs1_read: &impl ToExpr>>, + rs2_read: &impl ToExpr>>, + rd_written: &impl ToExpr>>, + ) -> Result { + // State in. + let pc = circuit_builder.create_witin(|| "pc")?; + let cur_ts = circuit_builder.create_witin(|| "cur_ts")?; + circuit_builder.state_in(pc.expr(), cur_ts.expr())?; + + // Register indexes. + let rs1_id = circuit_builder.create_witin(|| "rs1_id")?; + let rs2_id = circuit_builder.create_witin(|| "rs2_id")?; + let rd_id = circuit_builder.create_witin(|| "rd_id")?; + + // Fetch the instruction. + circuit_builder.lk_fetch(&InsnRecord::new( + pc.expr(), + (insn_kind.codes().opcode as usize).into(), + rd_id.expr(), + (insn_kind.codes().func3 as usize).into(), + rs1_id.expr(), + rs2_id.expr(), + (insn_kind.codes().func7 as usize).into(), + ))?; + + // Register state. + let prev_rs1_ts = circuit_builder.create_witin(|| "prev_rs1_ts")?; + let prev_rs2_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?; + let prev_rd_ts = circuit_builder.create_witin(|| "prev_rd_ts")?; + let prev_rd_value = RegUInt::new_unchecked(|| "prev_rd_value", circuit_builder)?; + + // Register read and write. + let (ts, lt_rs1_cfg) = circuit_builder.register_read( + || "read_rs1", + &rs1_id, + prev_rs1_ts.expr(), + cur_ts.expr(), + rs1_read, + )?; + let (ts, lt_rs2_cfg) = circuit_builder.register_read( + || "read_rs2", + &rs2_id, + prev_rs2_ts.expr(), + ts, + rs2_read, + )?; + let (ts, lt_prev_ts_cfg) = circuit_builder.register_write( + || "write_rd", + &rd_id, + prev_rd_ts.expr(), + ts, + &prev_rd_value, + rd_written, + )?; + + // State out. + let next_pc = pc.expr() + PC_STEP_SIZE.into(); + let next_ts = ts + 1.into(); + circuit_builder.state_out(next_pc, next_ts)?; + + Ok(RInstructionConfig { + pc, + ts: cur_ts, + rs1_id, + rs2_id, + rd_id, + prev_rd_value, + prev_rs1_ts, + prev_rs2_ts, + prev_rd_ts, + lt_rs1_cfg, + lt_rs2_cfg, + lt_prev_ts_cfg, + }) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit<::BaseField>], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + // State in. + set_val!(instance, self.pc, step.pc().before.0 as u64); + set_val!(instance, self.ts, step.cycle()); + + // Register indexes. + set_val!(instance, self.rs1_id, step.insn().rs1() as u64); + set_val!(instance, self.rs2_id, step.insn().rs2() as u64); + set_val!(instance, self.rd_id, step.insn().rd() as u64); + + // Fetch the instruction. + lk_multiplicity.fetch(step.pc().before.0); + + // Register state. + set_val!( + instance, + self.prev_rs1_ts, + step.rs1().unwrap().previous_cycle + ); + set_val!( + instance, + self.prev_rs2_ts, + step.rs2().unwrap().previous_cycle + ); + set_val!(instance, self.prev_rd_ts, step.rd().unwrap().previous_cycle); + self.prev_rd_value.assign_limbs( + instance, + UIntValue::new_unchecked(step.rd().unwrap().value.before).u16_fields(), + ); + + // Register read and write. + ExprLtInput { + lhs: step.rs1().unwrap().previous_cycle, + rhs: step.cycle(), + } + .assign(instance, &self.lt_rs1_cfg, lk_multiplicity); + ExprLtInput { + lhs: step.rs2().unwrap().previous_cycle, + rhs: step.cycle() + 1, + } + .assign(instance, &self.lt_rs2_cfg, lk_multiplicity); + ExprLtInput { + lhs: step.rd().unwrap().previous_cycle, + rhs: step.cycle() + 2, + } + .assign(instance, &self.lt_prev_ts_cfg, lk_multiplicity); + + Ok(()) + } +} From 713461b9123ded701ec29eeca0f0df1a10247ea7 Mon Sep 17 00:00:00 2001 From: Kimi Wu Date: Wed, 18 Sep 2024 15:26:36 +0800 Subject: [PATCH 2/3] Feat/#98 riscv mul opcode (#219) close #98 --- ceno_emul/src/addr.rs | 6 + ceno_emul/src/rv32im.rs | 22 ++ ceno_emul/src/tracer.rs | 7 +- ceno_zkvm/src/expression.rs | 9 + ceno_zkvm/src/instructions/riscv.rs | 2 + ceno_zkvm/src/instructions/riscv/addsub.rs | 4 + ceno_zkvm/src/instructions/riscv/mul.rs | 218 ++++++++++++ ceno_zkvm/src/scheme/mock_prover.rs | 13 +- ceno_zkvm/src/uint.rs | 366 +++++++++++++-------- ceno_zkvm/src/uint/arithmetic.rs | 12 +- 10 files changed, 523 insertions(+), 136 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/mul.rs diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index aaa4ef48f..7ed47bad7 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -67,6 +67,12 @@ impl From for u32 { } } +impl From for u64 { + fn from(addr: WordAddr) -> Self { + addr.baddr().0 as u64 + } +} + impl ByteAddr { pub const fn waddr(self) -> WordAddr { WordAddr(self.0 / WORD_SIZE as u32) diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index c173f1f21..db2a8699d 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -204,6 +204,28 @@ impl DecodedInstruction { } } + #[allow(dead_code)] + pub fn from_raw(kind: InsnKind, rs1: u32, rs2: u32, rd: u32) -> Self { + // limit the range of inputs + let rs2 = rs2 & 0x1f; // 5bits mask + let rs1 = rs1 & 0x1f; + let rd = rd & 0x1f; + let func7 = kind.codes().func7; + let func3 = kind.codes().func3; + let opcode = kind.codes().opcode; + let insn = func7 << 25 | rs2 << 20 | rs1 << 15 | func3 << 12 | rd << 7 | opcode; + Self { + insn, + top_bit: func7 | 0x80, + func7, + rs2, + rs1, + func3, + rd, + opcode, + } + } + pub fn encoded(&self) -> u32 { self.insn } diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 1391c33eb..2a73ca4ac 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -60,6 +60,7 @@ impl StepRecord { rs1_read: Word, rs2_read: Word, rd: Change, + previous_cycle: Cycle, ) -> StepRecord { let insn = DecodedInstruction::new(insn_code); StepRecord { @@ -69,17 +70,17 @@ impl StepRecord { rs1: Some(ReadOp { addr: CENO_PLATFORM.register_vma(insn.rs1() as RegIdx).into(), value: rs1_read, - previous_cycle: 0, + previous_cycle, }), rs2: Some(ReadOp { addr: CENO_PLATFORM.register_vma(insn.rs2() as RegIdx).into(), value: rs2_read, - previous_cycle: 0, + previous_cycle, }), rd: Some(WriteOp { addr: CENO_PLATFORM.register_vma(insn.rd() as RegIdx).into(), value: rd, - previous_cycle: 0, + previous_cycle, }), memory_op: None, } diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 3aec9b842..9b8f29679 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -1,5 +1,6 @@ use std::{ cmp::max, + mem::MaybeUninit, ops::{Add, Deref, Mul, Neg, Sub}, }; @@ -426,6 +427,14 @@ impl WitIn { }, ) } + + pub fn assign( + &self, + instance: &mut [MaybeUninit], + value: E::BaseField, + ) { + instance[self.id as usize] = MaybeUninit::new(value); + } } #[macro_export] diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 3d66f101a..6674c35d7 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -4,6 +4,8 @@ pub mod addsub; pub mod blt; pub mod config; pub mod constants; +pub mod mul; + mod r_insn; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 4d08945ad..6c33200cd 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -183,6 +183,7 @@ mod test { 11, 0xfffffffe, Change::new(0, 11_u32.wrapping_add(0xfffffffe)), + 0, )], ) .unwrap(); @@ -225,6 +226,7 @@ mod test { u32::MAX - 1, u32::MAX - 1, Change::new(0, (u32::MAX - 1).wrapping_add(u32::MAX - 1)), + 0, )], ) .unwrap(); @@ -267,6 +269,7 @@ mod test { 11, 2, Change::new(0, 11_u32.wrapping_sub(2)), + 0, )], ) .unwrap(); @@ -309,6 +312,7 @@ mod test { 3, 11, Change::new(0, 3_u32.wrapping_sub(11)), + 0, )], ) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/mul.rs b/ceno_zkvm/src/instructions/riscv/mul.rs new file mode 100644 index 000000000..5c628db8f --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/mul.rs @@ -0,0 +1,218 @@ +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::ExtensionField; +use itertools::Itertools; + +use super::{constants::RegUInt, r_insn::RInstructionConfig, RIVInstruction}; +use crate::{ + circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, uint::UIntValue, + witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; +use std::marker::PhantomData; + +#[derive(Debug)] +pub struct ArithConfig { + r_insn: RInstructionConfig, + + multiplier_1: RegUInt, + multiplier_2: RegUInt, + outcome: RegUInt, +} + +pub struct ArithInstruction(PhantomData<(E, I)>); + +pub struct MulOp; +impl RIVInstruction for MulOp { + const INST_KIND: InsnKind = InsnKind::MUL; +} +pub type MulInstruction = ArithInstruction; + +impl Instruction for ArithInstruction { + type InstructionConfig = ArithConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result { + let mut multiplier_1 = RegUInt::new_unchecked(|| "multiplier_1", circuit_builder)?; + let mut multiplier_2 = RegUInt::new_unchecked(|| "multiplier_2", circuit_builder)?; + let outcome = multiplier_1.mul(|| "outcome", circuit_builder, &mut multiplier_2, true)?; + + let r_insn = RInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + &multiplier_1, + &multiplier_2, + &outcome, + )?; + + Ok(ArithConfig { + r_insn, + multiplier_1, + multiplier_2, + outcome, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config.r_insn.assign_instance(instance, lkm, step)?; + + let multiplier_1 = UIntValue::new_unchecked(step.rs1().unwrap().value); + let multiplier_2 = UIntValue::new_unchecked(step.rs2().unwrap().value); + let outcome = UIntValue::new_unchecked(step.rd().unwrap().value.after); + + config + .multiplier_1 + .assign_limbs(instance, multiplier_1.u16_fields()); + config + .multiplier_2 + .assign_limbs(instance, multiplier_2.u16_fields()); + let (_, carries) = multiplier_1.mul(&multiplier_2, lkm, true); + + config.outcome.assign_limbs(instance, outcome.u16_fields()); + config.outcome.assign_carries( + instance, + carries + .into_iter() + .map(|carry| E::BaseField::from(carry as u64)) + .collect_vec(), + ); + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use ceno_emul::{Change, StepRecord}; + use goldilocks::GoldilocksExt2; + use itertools::Itertools; + use multilinear_extensions::mle::IntoMLEs; + + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::Instruction, + scheme::mock_prover::{MockProver, MOCK_PC_MUL, MOCK_PROGRAM}, + }; + + use super::MulInstruction; + + #[test] + fn test_opcode_mul() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb))) + .unwrap() + .unwrap(); + + // values assignment + let (raw_witin, _) = MulInstruction::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_r_instruction( + 3, + MOCK_PC_MUL, + MOCK_PROGRAM[2], + 11, + 2, + Change::new(0, 22), + 0, + )], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); + } + + #[test] + fn test_opcode_mul_overflow() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb))) + .unwrap() + .unwrap(); + + // values assignment + let (raw_witin, _) = MulInstruction::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_r_instruction( + 3, + MOCK_PC_MUL, + MOCK_PROGRAM[2], + u32::MAX / 2 + 1, + 2, + Change::new(0, 0), + 0, + )], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); + } + + #[test] + fn test_opcode_mul_overflow2() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb))) + .unwrap() + .unwrap(); + + // values assignment + let (raw_witin, _) = MulInstruction::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_r_instruction( + 3, + MOCK_PC_MUL, + MOCK_PROGRAM[2], + 4294901760, + 4294901760, + Change::new(0, 0), + 0, + )], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); + } +} diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 8dea82b70..2d6c77107 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -15,17 +15,26 @@ use itertools::Itertools; use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; use std::{collections::HashSet, hash::Hash, marker::PhantomData, ops::Neg, sync::OnceLock}; +pub const MOCK_RS1: u32 = 2; +pub const MOCK_RS2: u32 = 3; +pub const MOCK_RD: u32 = 4; /// The program baked in the MockProver. /// TODO: Make this a parameter? pub const MOCK_PROGRAM: &[u32] = &[ + // R-Type + // funct7 | rs2 | rs1 | funct3 | rd | opcode + // ----------------------------------------- // add x4, x2, x3 - 0x00 << 25 | 3 << 20 | 2 << 15 | 4 << 7 | 0x33, + 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33, // sub x4, x2, x3 - 0x20 << 25 | 3 << 20 | 2 << 15 | 4 << 7 | 0x33, + 0x20 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33, + // mul (0x01, 0x00, 0x33) + 0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33, ]; // Addresses of particular instructions in the mock program. pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start()); pub const MOCK_PC_SUB: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 4); +pub const MOCK_PC_MUL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 8); #[allow(clippy::enum_variant_names)] #[derive(Debug, PartialEq, Clone)] diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index c89aa329b..519d54bbb 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -545,134 +545,242 @@ impl + Copy> UIntValue { carries.iter().for_each(|c| lkm.assert_ux::<16>(*c as u64)); (limbs, carries) } + + pub fn mul( + &self, + rhs: &Self, + lkm: &mut LkMultiplicity, + with_overflow: bool, + ) -> (Vec, Vec) { + let a_limbs = self.as_u16_limbs(); + let b_limbs = rhs.as_u16_limbs(); + + let num_limbs = a_limbs.len(); + let mut c_limbs = vec![0u16; num_limbs]; + let mut carries = vec![0u16; num_limbs]; + a_limbs.iter().enumerate().for_each(|(i, a_limb)| { + b_limbs.iter().enumerate().for_each(|(j, b_limb)| { + let idx = i + j; + if idx < num_limbs { + let (c, overflow_mul) = a_limb.overflowing_mul(*b_limb); + let (ret, overflow_add) = c_limbs[idx].overflowing_add(c); + + c_limbs[idx] = ret; + carries[idx] += (overflow_add as u16) + (overflow_mul as u16); + } + }) + }); + // complete the computation by adding prev_carry + (1..num_limbs).for_each(|i| { + if carries[i - 1] > 0 { + let (ret, overflow) = c_limbs[i].overflowing_add(carries[i - 1]); + c_limbs[i] = ret; + carries[i] += overflow as u16; + } + }); + + if !with_overflow { + // If the outcome overflows, `with_overflow` can't be false + assert_eq!(carries[carries.len() - 1], 0, "incorrect overflow flag"); + carries.resize(carries.len() - 1, 0); + } + + // range check + c_limbs.iter().for_each(|c| lkm.assert_ux::<16>(*c as u64)); + carries.iter().for_each(|c| lkm.assert_ux::<16>(*c as u64)); + + (c_limbs, carries) + } } -// #[cfg(test)] -// mod tests { -// use crate::uint::uint::UInt; -// use gkr::structs::{Circuit, CircuitWitness}; -// use goldilocks::{Goldilocks, GoldilocksExt2}; -// use itertools::Itertools; -// use simple_frontend::structs::CircuitBuilder; - -// #[test] -// fn test_uint_from_cell_ids() { -// // 33 total bits and each cells holds just 4 bits -// // to hold all 33 bits without truncations, we'd need 9 cells -// // 9 * 4 = 36 > 33 -// type UInt33 = UInt<33, 4>; -// assert!(UInt33::try_from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]).is_ok()); -// assert!(UInt33::try_from(vec![1, 2, 3]).is_err()); -// assert!(UInt33::try_from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).is_err()); -// } - -// #[test] -// fn test_uint_from_different_sized_cell_values() { -// // build circuit -// let mut circuit_builder = CircuitBuilder::::new(); -// let (_, small_values) = circuit_builder.create_witness_in(8); -// type UInt30 = UInt<30, 6>; -// let uint_instance = -// UInt30::from_different_sized_cell_values(&mut circuit_builder, &small_values, 2, true) -// .unwrap(); -// circuit_builder.configure(); -// let circuit = Circuit::new(&circuit_builder); - -// // input -// // we start with cells of bit width 2 (8 of them) -// // 11 00 10 11 01 10 01 01 (bit representation) -// // 3 0 2 3 1 2 1 1 (field representation) -// // -// // repacking into cells of bit width 6 -// // 110010 110110 010100 -// // since total bit = 30 then expect 5 cells ( 30 / 6) -// // since we have 3 cells, we need to pad with 2 more -// // hence expected output: -// // 100011 100111 000101 000000 000000(bit representation) -// // 35 39 5 0 0 - -// let witness_values = vec![3, 0, 2, 3, 1, 2, 1, 1] -// .into_iter() -// .map(|v| Goldilocks::from(v)) -// .collect_vec(); -// let circuit_witness = { -// let challenges = vec![GoldilocksExt2::from(2)]; -// let mut circuit_witness = CircuitWitness::new(&circuit, challenges); -// circuit_witness.add_instance(&circuit, vec![witness_values]); -// circuit_witness -// }; -// circuit_witness.check_correctness(&circuit); - -// let output = circuit_witness.output_layer_witness_ref().instances[0].to_vec(); -// assert_eq!( -// &output[..5], -// vec![35, 39, 5, 0, 0] -// .into_iter() -// .map(|v| Goldilocks::from(v)) -// .collect_vec() -// ); - -// // padding to power of 2 -// assert_eq!( -// &output[5..], -// vec![0, 0, 0] -// .into_iter() -// .map(|v| Goldilocks::from(v)) -// .collect_vec() -// ); -// } - -// #[test] -// fn test_counter_vector() { -// // each limb has 5 bits so all number from 0..3 should require only 1 limb -// type UInt30 = UInt<30, 5>; -// let res = UInt30::counter_vector::(3); -// assert_eq!( -// res, -// vec![ -// vec![Goldilocks::from(0)], -// vec![Goldilocks::from(1)], -// vec![Goldilocks::from(2)] -// ] -// ); - -// // each limb has a single bit, number from 0..5 should require 3 limbs each -// type UInt50 = UInt<50, 1>; -// let res = UInt50::counter_vector::(5); -// assert_eq!( -// res, -// vec![ -// // 0 -// vec![ -// Goldilocks::from(0), -// Goldilocks::from(0), -// Goldilocks::from(0) -// ], -// // 1 -// vec![ -// Goldilocks::from(1), -// Goldilocks::from(0), -// Goldilocks::from(0) -// ], -// // 2 -// vec![ -// Goldilocks::from(0), -// Goldilocks::from(1), -// Goldilocks::from(0) -// ], -// // 3 -// vec![ -// Goldilocks::from(1), -// Goldilocks::from(1), -// Goldilocks::from(0) -// ], -// // 4 -// vec![ -// Goldilocks::from(0), -// Goldilocks::from(0), -// Goldilocks::from(1) -// ], -// ] -// ); -// } -// } +#[cfg(test)] +mod tests { + use crate::witness::LkMultiplicity; + + use super::UIntValue; + + #[test] + fn test_add() { + let a = UIntValue::new_unchecked(1u32); + let b = UIntValue::new_unchecked(2u32); + let mut lkm = LkMultiplicity::default(); + + let (c, carries) = a.add(&b, &mut lkm, true); + assert_eq!(c[0], 3); + assert_eq!(c[1], 0); + assert_eq!(carries[0], false); + assert_eq!(carries[1], false); + } + + #[test] + fn test_add_carry() { + let a = UIntValue::new_unchecked(u16::MAX as u32); + let b = UIntValue::new_unchecked(2u32); + let mut lkm = LkMultiplicity::default(); + + let (c, carries) = a.add(&b, &mut lkm, true); + assert_eq!(c[0], 1); + assert_eq!(c[1], 1); + assert_eq!(carries[0], true); + assert_eq!(carries[1], false); + } + + #[test] + fn test_mul() { + let a = UIntValue::new_unchecked(1u32); + let b = UIntValue::new_unchecked(2u32); + let mut lkm = LkMultiplicity::default(); + + let (c, carries) = a.mul(&b, &mut lkm, true); + assert_eq!(c[0], 2); + assert_eq!(c[1], 0); + assert_eq!(carries[0], 0); + assert_eq!(carries[1], 0); + } + + #[test] + fn test_mul_carry() { + let a = UIntValue::new_unchecked(u16::MAX as u32); + let b = UIntValue::new_unchecked(2u32); + let mut lkm = LkMultiplicity::default(); + + let (c, carries) = a.mul(&b, &mut lkm, true); + assert_eq!(c[0], u16::MAX - 1); + assert_eq!(c[1], 1); + assert_eq!(carries[0], 1); + assert_eq!(carries[1], 0); + } + + #[test] + fn test_mul_overflow() { + let a = UIntValue::new_unchecked(u32::MAX / 2 + 1); + let b = UIntValue::new_unchecked(2u32); + let mut lkm = LkMultiplicity::default(); + + let (c, carries) = a.mul(&b, &mut lkm, true); + assert_eq!(c[0], 0); + assert_eq!(c[1], 0); + assert_eq!(carries[0], 0); + assert_eq!(carries[1], 1); + } + // #[test] + // fn test_uint_from_cell_ids() { + // // 33 total bits and each cells holds just 4 bits + // // to hold all 33 bits without truncations, we'd need 9 cells + // // 9 * 4 = 36 > 33 + // type UInt33 = UInt<33, 4>; + // assert!(UInt33::try_from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]).is_ok()); + // assert!(UInt33::try_from(vec![1, 2, 3]).is_err()); + // assert!(UInt33::try_from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).is_err()); + // } + + // #[test] + // fn test_uint_from_different_sized_cell_values() { + // // build circuit + // let mut circuit_builder = CircuitBuilder::::new(); + // let (_, small_values) = circuit_builder.create_witness_in(8); + // type UInt30 = UInt<30, 6>; + // let uint_instance = + // UInt30::from_different_sized_cell_values(&mut circuit_builder, &small_values, 2, true) + // .unwrap(); + // circuit_builder.configure(); + // let circuit = Circuit::new(&circuit_builder); + + // // input + // // we start with cells of bit width 2 (8 of them) + // // 11 00 10 11 01 10 01 01 (bit representation) + // // 3 0 2 3 1 2 1 1 (field representation) + // // + // // repacking into cells of bit width 6 + // // 110010 110110 010100 + // // since total bit = 30 then expect 5 cells ( 30 / 6) + // // since we have 3 cells, we need to pad with 2 more + // // hence expected output: + // // 100011 100111 000101 000000 000000(bit representation) + // // 35 39 5 0 0 + + // let witness_values = vec![3, 0, 2, 3, 1, 2, 1, 1] + // .into_iter() + // .map(|v| Goldilocks::from(v)) + // .collect_vec(); + // let circuit_witness = { + // let challenges = vec![GoldilocksExt2::from(2)]; + // let mut circuit_witness = CircuitWitness::new(&circuit, challenges); + // circuit_witness.add_instance(&circuit, vec![witness_values]); + // circuit_witness + // }; + // circuit_witness.check_correctness(&circuit); + + // let output = circuit_witness.output_layer_witness_ref().instances[0].to_vec(); + // assert_eq!( + // &output[..5], + // vec![35, 39, 5, 0, 0] + // .into_iter() + // .map(|v| Goldilocks::from(v)) + // .collect_vec() + // ); + + // // padding to power of 2 + // assert_eq!( + // &output[5..], + // vec![0, 0, 0] + // .into_iter() + // .map(|v| Goldilocks::from(v)) + // .collect_vec() + // ); + // } + + // #[test] + // fn test_counter_vector() { + // // each limb has 5 bits so all number from 0..3 should require only 1 limb + // type UInt30 = UInt<30, 5>; + // let res = UInt30::counter_vector::(3); + // assert_eq!( + // res, + // vec![ + // vec![Goldilocks::from(0)], + // vec![Goldilocks::from(1)], + // vec![Goldilocks::from(2)] + // ] + // ); + + // // each limb has a single bit, number from 0..5 should require 3 limbs each + // type UInt50 = UInt<50, 1>; + // let res = UInt50::counter_vector::(5); + // assert_eq!( + // res, + // vec![ + // // 0 + // vec![ + // Goldilocks::from(0), + // Goldilocks::from(0), + // Goldilocks::from(0) + // ], + // // 1 + // vec![ + // Goldilocks::from(1), + // Goldilocks::from(0), + // Goldilocks::from(0) + // ], + // // 2 + // vec![ + // Goldilocks::from(0), + // Goldilocks::from(1), + // Goldilocks::from(0) + // ], + // // 3 + // vec![ + // Goldilocks::from(1), + // Goldilocks::from(1), + // Goldilocks::from(0) + // ], + // // 4 + // vec![ + // Goldilocks::from(0), + // Goldilocks::from(0), + // Goldilocks::from(1) + // ], + // ] + // ); + // } +} diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 0bbb1daff..4bff1907e 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -721,6 +721,16 @@ mod tests { witness_values: Vec, overflow: bool, ) { + let pow_of_c: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; + let single_wit_size = UInt::::NUM_CELLS; + if overflow { + assert_eq!( + witness_values.len() % single_wit_size, + 0, + "witness len is incorrect" + ) + } + let mut cs = ConstraintSystem::new(|| "test_mul"); let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); @@ -731,8 +741,6 @@ mod tests { .mul(|| "uint_c", &mut cb, &mut uint_b, overflow) .unwrap(); - let pow_of_c: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; - let single_wit_size = UInt::::NUM_CELLS; let wit_end_idx = if overflow { 4 * single_wit_size } else { From 36392db8d63e214c44366b45e5fa229a0c60a22d Mon Sep 17 00:00:00 2001 From: naure Date: Wed, 18 Sep 2024 09:54:01 +0200 Subject: [PATCH 3/3] All range tables (#236) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _Issue #214_ A generalization of #234 for all range or scalar tables. - Moved and refactored the implementation of u16. - Implementation without generics because I find it cleaner and it compiles faster. - Definition of separate circuits `U5TableCircuit`, `U8TableCircuit`, `U16TableCircuit` using a parameter trait. - Fix soundness of `assert_u8`. --------- Co-authored-by: Aurélien Nicolas --- ceno_zkvm/examples/riscv_add.rs | 8 +- ceno_zkvm/src/chip_handler/general.rs | 6 +- ceno_zkvm/src/scheme/mock_prover.rs | 13 +++ ceno_zkvm/src/structs.rs | 3 +- ceno_zkvm/src/tables/mod.rs | 2 +- ceno_zkvm/src/tables/range.rs | 104 +++++--------------- ceno_zkvm/src/tables/range/range_circuit.rs | 59 +++++++++++ ceno_zkvm/src/tables/range/range_impl.rs | 93 +++++++++++++++++ ceno_zkvm/src/witness.rs | 5 +- 9 files changed, 204 insertions(+), 89 deletions(-) create mode 100644 ceno_zkvm/src/tables/range/range_circuit.rs create mode 100644 ceno_zkvm/src/tables/range/range_impl.rs diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index d42367172..4a7f3b0f3 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -12,7 +12,7 @@ use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM}; use ceno_zkvm::{ scheme::verifier::ZKVMVerifier, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, - tables::RangeTableCircuit, + tables::U16TableCircuit, }; use ff_ext::ff::Field; use goldilocks::GoldilocksExt2; @@ -92,12 +92,12 @@ fn main() { // keygen let mut zkvm_cs = ZKVMConstraintSystem::default(); let add_config = zkvm_cs.register_opcode_circuit::>(); - let range_config = zkvm_cs.register_table_circuit::>(); + let range_config = zkvm_cs.register_table_circuit::>(); let prog_config = zkvm_cs.register_table_circuit::>(); let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); - zkvm_fixed_traces.register_table_circuit::>( + zkvm_fixed_traces.register_table_circuit::>( &zkvm_cs, range_config.clone(), &(), @@ -148,7 +148,7 @@ fn main() { zkvm_witness.finalize_lk_multiplicities(); // assign table circuits zkvm_witness - .assign_table_circuit::>(&zkvm_cs, &range_config, &()) + .assign_table_circuit::>(&zkvm_cs, &range_config, &()) .unwrap(); zkvm_witness .assign_table_circuit::>( diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 29c86e463..30265af78 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -216,7 +216,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { - self.assert_u16(name_fn, expr * Expression::from(1 << 8)) + let items: Vec> = vec![(ROMType::U8 as usize).into(), expr]; + let rlc_record = self.rlc_chip_record(items); + self.lk_record(name_fn, rlc_record)?; + Ok(()) } pub(crate) fn assert_bit( @@ -228,6 +231,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { + // TODO: Replace with `x * (1 - x)` or a multi-bit lookup similar to assert_u8_pair. self.assert_u16(name_fn, expr * Expression::from(1 << 15)) } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 2d6c77107..68846499d 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -248,6 +248,18 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> } } + fn load_u8_table( + t_vec: &mut Vec>, + cb: &CircuitBuilder, + challenge: [E; 2], + ) { + for i in 0..=u8::MAX as usize { + let rlc_record = cb.rlc_chip_record(vec![(ROMType::U8 as usize).into(), i.into()]); + let rlc_record = eval_by_expr(&[], &challenge, &rlc_record); + t_vec.push(rlc_record.to_repr().as_ref().to_vec()); + } + } + fn load_u16_table( t_vec: &mut Vec>, cb: &CircuitBuilder, @@ -347,6 +359,7 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> let mut table_vec = vec![]; // TODO load more tables here load_u5_table(&mut table_vec, cb, challenge); + load_u8_table(&mut table_vec, cb, challenge); load_u16_table(&mut table_vec, cb, challenge); load_lt_table(&mut table_vec, cb, challenge); load_and_table(&mut table_vec, cb, challenge); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1a522df88..2e2c3e7ce 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -39,9 +39,10 @@ pub struct TowerProverSpec<'a, E: ExtensionField> { pub type WitnessId = u16; pub type ChallengeId = u16; -#[derive(Debug)] +#[derive(Copy, Clone, Debug)] pub enum ROMType { U5 = 0, // 2^5 = 32 + U8, // 2^8 = 256 U16, // 2^16 = 65,536 And, // a ^ b where a, b are bytes Ltu, // a <(usign) b where a, b are bytes diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 7c9a29ed4..de92495ff 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -3,7 +3,7 @@ use ff_ext::ExtensionField; use std::collections::HashMap; mod range; -pub use range::RangeTableCircuit; +pub use range::*; mod program; pub use program::{InsnRecord, ProgramTableCircuit}; diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index 2b195ea63..c4b73f840 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -1,89 +1,35 @@ -use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; +//! Definition of the range tables and their circuits. -use crate::{ - circuit_builder::CircuitBuilder, - error::ZKVMError, - expression::{Expression, Fixed, ToExpr, WitIn}, - scheme::constants::MIN_PAR_SIZE, - set_fixed_val, set_val, - structs::ROMType, - tables::TableCircuit, - uint::constants::RANGE_CHIP_BIT_WIDTH, - witness::RowMajorMatrix, -}; -use ff_ext::ExtensionField; -use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +mod range_impl; -#[derive(Clone, Debug)] -pub struct RangeTableConfig { - u16_tbl: Fixed, - u16_mlt: WitIn, -} - -pub struct RangeTableCircuit(PhantomData); - -impl TableCircuit for RangeTableCircuit { - type TableConfig = RangeTableConfig; - type FixedInput = (); - type WitnessInput = (); - - fn name() -> String { - "RANGE".into() - } - - fn construct_circuit(cb: &mut CircuitBuilder) -> Result { - let u16_tbl = cb.create_fixed(|| "u16_tbl")?; - let u16_mlt = cb.create_witin(|| "u16_mlt")?; +mod range_circuit; +use range_circuit::{RangeTable, RangeTableCircuit}; - let u16_table_values = cb.rlc_chip_record(vec![ - Expression::Constant(E::BaseField::from(ROMType::U16 as u64)), - Expression::Fixed(u16_tbl.clone()), - ]); +use crate::structs::ROMType; - cb.lk_table_record(|| "u16 table", u16_table_values, u16_mlt.expr())?; - - Ok(RangeTableConfig { u16_tbl, u16_mlt }) +pub struct U5Table; +impl RangeTable for U5Table { + const ROM_TYPE: ROMType = ROMType::U5; + fn len() -> usize { + 1 << 5 } +} +pub type U5TableCircuit = RangeTableCircuit; - fn generate_fixed_traces( - config: &RangeTableConfig, - num_fixed: usize, - _input: &(), - ) -> RowMajorMatrix { - let num_u16s = 1 << 16; - let mut fixed = RowMajorMatrix::::new(num_u16s, num_fixed); - fixed - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip((0..num_u16s).into_par_iter()) - .for_each(|(row, i)| { - set_fixed_val!(row, config.u16_tbl, E::BaseField::from(i as u64)); - }); - - fixed +pub struct U8Table; +impl RangeTable for U8Table { + const ROM_TYPE: ROMType = ROMType::U8; + fn len() -> usize { + 1 << 8 } +} +pub type U8TableCircuit = RangeTableCircuit; - fn assign_instances( - config: &Self::TableConfig, - num_witin: usize, - multiplicity: &[HashMap], - _input: &(), - ) -> Result, ZKVMError> { - let multiplicity = &multiplicity[ROMType::U16 as usize]; - let mut u16_mlt = vec![0; 1 << RANGE_CHIP_BIT_WIDTH]; - for (limb, mlt) in multiplicity { - u16_mlt[*limb as usize] = *mlt; - } - - let mut witness = RowMajorMatrix::::new(u16_mlt.len(), num_witin); - witness - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip(u16_mlt.into_par_iter()) - .for_each(|(row, mlt)| { - set_val!(row, config.u16_mlt, E::BaseField::from(mlt as u64)); - }); - - Ok(witness) +pub struct U16Table; +impl RangeTable for U16Table { + const ROM_TYPE: ROMType = ROMType::U16; + fn len() -> usize { + 1 << 16 } } +pub type U16TableCircuit = RangeTableCircuit; diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs new file mode 100644 index 000000000..e1132f2be --- /dev/null +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -0,0 +1,59 @@ +//! Range tables as circuits with trait TableCircuit. + +use super::range_impl::RangeTableConfig; + +use std::{collections::HashMap, marker::PhantomData}; + +use crate::{ + circuit_builder::CircuitBuilder, error::ZKVMError, structs::ROMType, tables::TableCircuit, + witness::RowMajorMatrix, +}; +use ff_ext::ExtensionField; + +/// Use this trait as parameter to RangeTableCircuit. +pub trait RangeTable { + const ROM_TYPE: ROMType; + + fn len() -> usize; + + fn content() -> Vec { + (0..Self::len() as u64).collect() + } +} + +pub struct RangeTableCircuit(PhantomData<(E, R)>); + +impl TableCircuit for RangeTableCircuit { + type TableConfig = RangeTableConfig; + type FixedInput = (); + type WitnessInput = (); + + fn name() -> String { + format!("RANGE_{:?}", RANGE::ROM_TYPE) + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + cb.namespace( + || Self::name(), + |cb| RangeTableConfig::construct_circuit(cb, RANGE::ROM_TYPE), + ) + } + + fn generate_fixed_traces( + config: &RangeTableConfig, + num_fixed: usize, + _input: &(), + ) -> RowMajorMatrix { + config.generate_fixed_traces(num_fixed, RANGE::content()) + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + multiplicity: &[HashMap], + _input: &(), + ) -> Result, ZKVMError> { + let multiplicity = &multiplicity[RANGE::ROM_TYPE as usize]; + config.assign_instances(num_witin, multiplicity, RANGE::len()) + } +} diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs new file mode 100644 index 000000000..8344c43f1 --- /dev/null +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -0,0 +1,93 @@ +//! The implementation of range tables. No generics. + +use ff_ext::ExtensionField; +use goldilocks::SmallField; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +use std::{collections::HashMap, mem::MaybeUninit}; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, Fixed, ToExpr, WitIn}, + scheme::constants::MIN_PAR_SIZE, + set_fixed_val, set_val, + structs::ROMType, + witness::RowMajorMatrix, +}; + +#[derive(Clone, Debug)] +pub struct RangeTableConfig { + fixed: Fixed, + mlt: WitIn, +} + +impl RangeTableConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + rom_type: ROMType, + ) -> Result { + let fixed = cb.create_fixed(|| "fixed")?; + let mlt = cb.create_witin(|| "mlt")?; + + let rlc_record = cb.rlc_chip_record(vec![ + (rom_type as usize).into(), + Expression::Fixed(fixed.clone()), + ]); + + cb.lk_table_record(|| "record", rlc_record, mlt.expr())?; + + Ok(Self { fixed, mlt }) + } + + pub fn generate_fixed_traces( + &self, + num_fixed: usize, + content: Vec, + ) -> RowMajorMatrix { + let mut fixed = RowMajorMatrix::::new(content.len(), num_fixed); + + // Fill the padding with zeros, if any. + fixed.par_iter_mut().skip(content.len()).for_each(|row| { + set_fixed_val!(row, self.fixed, F::ZERO); + }); + + fixed + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip(content.into_par_iter()) + .for_each(|(row, i)| { + set_fixed_val!(row, self.fixed, F::from(i)); + }); + + fixed + } + + pub fn assign_instances( + &self, + num_witin: usize, + multiplicity: &HashMap, + length: usize, + ) -> Result, ZKVMError> { + let mut witness = RowMajorMatrix::::new(length, num_witin); + + let mut mlts = vec![0; length]; + for (idx, mlt) in multiplicity { + mlts[*idx as usize] = *mlt; + } + + // Fill the padding with zeros, if any. + witness.par_iter_mut().skip(length).for_each(|row| { + set_val!(row, self.mlt, F::ZERO); + }); + + witness + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip(mlts.into_par_iter()) + .for_each(|(row, mlt)| { + set_val!(row, self.mlt, F::from(mlt as u64)); + }); + + Ok(witness) + } +} diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 263b6645c..b49694874 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -126,11 +126,10 @@ impl LkMultiplicity { } fn assert_byte(&mut self, v: u64) { - let v = v * (1 << u8::BITS); let multiplicity = self .multiplicity .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); - (*multiplicity.borrow_mut()[ROMType::U16 as usize] + (*multiplicity.borrow_mut()[ROMType::U8 as usize] .entry(v) .or_default()) += 1; } @@ -189,6 +188,6 @@ mod tests { } let res = lkm.into_finalize_result(); // check multiplicity counts of assert_byte - assert_eq!(res[ROMType::U16 as usize][&(8 << 8)], thread_count); + assert_eq!(res[ROMType::U8 as usize][&8], thread_count); } }