diff --git a/Makefile.toml b/Makefile.toml index 8d6ab5edf..a62827fcc 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -11,10 +11,6 @@ CUR_TARGET = { script = [''' '''] } RAYON_NUM_THREADS = "${CORE}" -[tasks.build] -# Override the default `--all-features`, that's broken, because some of our features are mutually exclusive. -args = ["build"] - [tasks.tests] args = [ "test", diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 4bb427eb9..f3217a3bc 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -52,12 +52,9 @@ pprof2.workspace = true glob = "0.3" [features] -default = ["riv32", "forbid_overflow"] +default = ["forbid_overflow"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] -non_pow2_rayon_thread = [] -riv32 = [] -riv64 = [] [[bench]] harness = false diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index 8d16f342d..991fad3d7 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -3,7 +3,7 @@ use ff_ext::ExtensionField; use crate::{ error::ZKVMError, expression::{Expression, ToExpr}, - gadgets::AssertLTConfig, + gadgets::AssertLtConfig, instructions::riscv::constants::UINT_LIMBS, }; @@ -34,7 +34,7 @@ pub trait RegisterChipOperations, N: FnOnce( prev_ts: Expression, ts: Expression, value: RegisterExpr, - ) -> Result<(Expression, AssertLTConfig), ZKVMError>; + ) -> Result<(Expression, AssertLtConfig), ZKVMError>; #[allow(clippy::too_many_arguments)] fn register_write( @@ -45,7 +45,7 @@ pub trait RegisterChipOperations, N: FnOnce( ts: Expression, prev_values: RegisterExpr, value: RegisterExpr, - ) -> Result<(Expression, AssertLTConfig), ZKVMError>; + ) -> Result<(Expression, AssertLtConfig), ZKVMError>; } /// The common representation of a memory address. @@ -62,7 +62,7 @@ pub trait MemoryChipOperations, N: FnOnce() prev_ts: Expression, ts: Expression, value: MemoryExpr, - ) -> Result<(Expression, AssertLTConfig), ZKVMError>; + ) -> Result<(Expression, AssertLtConfig), ZKVMError>; #[allow(clippy::too_many_arguments)] fn memory_write( @@ -73,5 +73,5 @@ pub trait MemoryChipOperations, N: FnOnce() ts: Expression, prev_values: MemoryExpr, value: MemoryExpr, - ) -> Result<(Expression, AssertLTConfig), ZKVMError>; + ) -> Result<(Expression, AssertLtConfig), ZKVMError>; } diff --git a/ceno_zkvm/src/chip_handler/memory.rs b/ceno_zkvm/src/chip_handler/memory.rs index 3b6b2ec53..2969eeb7a 100644 --- a/ceno_zkvm/src/chip_handler/memory.rs +++ b/ceno_zkvm/src/chip_handler/memory.rs @@ -3,7 +3,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::Expression, - gadgets::AssertLTConfig, + gadgets::AssertLtConfig, instructions::riscv::constants::UINT_LIMBS, structs::RAMType, }; @@ -19,7 +19,7 @@ impl, N: FnOnce() -> NR> MemoryChipOperation prev_ts: Expression, ts: Expression, value: MemoryExpr, - ) -> Result<(Expression, AssertLTConfig), ZKVMError> { + ) -> Result<(Expression, AssertLtConfig), ZKVMError> { self.namespace(name_fn, |cb| { // READ (a, v, t) let read_record = [ @@ -39,7 +39,7 @@ impl, N: FnOnce() -> NR> MemoryChipOperation cb.write_record(|| "write_record", RAMType::Memory, write_record)?; // assert prev_ts < current_ts - let lt_cfg = AssertLTConfig::construct_circuit( + let lt_cfg = AssertLtConfig::construct_circuit( cb, || "prev_ts < ts", prev_ts, @@ -61,7 +61,7 @@ impl, N: FnOnce() -> NR> MemoryChipOperation ts: Expression, prev_values: MemoryExpr, value: MemoryExpr, - ) -> Result<(Expression, AssertLTConfig), ZKVMError> { + ) -> Result<(Expression, AssertLtConfig), ZKVMError> { self.namespace(name_fn, |cb| { // READ (a, v, t) let read_record = [ @@ -80,7 +80,7 @@ impl, N: FnOnce() -> NR> MemoryChipOperation cb.read_record(|| "read_record", RAMType::Memory, read_record)?; cb.write_record(|| "write_record", RAMType::Memory, write_record)?; - let lt_cfg = AssertLTConfig::construct_circuit( + let lt_cfg = AssertLtConfig::construct_circuit( cb, || "prev_ts < ts", prev_ts, diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index d2f1ebf14..c85060cc6 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -4,7 +4,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr}, - gadgets::AssertLTConfig, + gadgets::AssertLtConfig, instructions::riscv::constants::UINT_LIMBS, structs::RAMType, }; @@ -21,7 +21,7 @@ impl, N: FnOnce() -> NR> RegisterChipOperati prev_ts: Expression, ts: Expression, value: RegisterExpr, - ) -> Result<(Expression, AssertLTConfig), ZKVMError> { + ) -> Result<(Expression, AssertLtConfig), ZKVMError> { self.namespace(name_fn, |cb| { // READ (a, v, t) let read_record = [ @@ -43,7 +43,7 @@ impl, N: FnOnce() -> NR> RegisterChipOperati cb.write_record(|| "write_record", RAMType::Register, write_record)?; // assert prev_ts < current_ts - let lt_cfg = AssertLTConfig::construct_circuit( + let lt_cfg = AssertLtConfig::construct_circuit( cb, || "prev_ts < ts", prev_ts, @@ -65,7 +65,7 @@ impl, N: FnOnce() -> NR> RegisterChipOperati ts: Expression, prev_values: RegisterExpr, value: RegisterExpr, - ) -> Result<(Expression, AssertLTConfig), ZKVMError> { + ) -> Result<(Expression, AssertLtConfig), ZKVMError> { assert!(register_id.expr().degree() <= 1); self.namespace(name_fn, |cb| { // READ (a, v, t) @@ -87,7 +87,7 @@ impl, N: FnOnce() -> NR> RegisterChipOperati cb.read_record(|| "read_record", RAMType::Register, read_record)?; cb.write_record(|| "write_record", RAMType::Register, write_record)?; - let lt_cfg = AssertLTConfig::construct_circuit( + let lt_cfg = AssertLtConfig::construct_circuit( cb, || "prev_ts < ts", prev_ts, diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index bad41b1b4..c2b523014 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -590,17 +590,17 @@ macro_rules! mixed_binop_instances { mixed_binop_instances!( Add, add, - (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) + (u8, u16, u32, u64, usize, i8, i16, i32, i64, isize) ); mixed_binop_instances!( Sub, sub, - (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) + (u8, u16, u32, u64, usize, i8, i16, i32, i64, isize) ); mixed_binop_instances!( Mul, mul, - (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) + (u8, u16, u32, u64, usize, i8, i16, i32, i64, isize) ); impl Mul for Expression { @@ -835,14 +835,6 @@ macro_rules! impl_from_unsigned { } impl_from_unsigned!(u8, u16, u32, u64, usize, RAMType, InsnKind); -// Implement From trait for u128 separately since it requires explicit reduction -impl> From for Expression { - fn from(value: u128) -> Self { - let reduced = value.rem_euclid(F::MODULUS_U64 as u128) as u64; - Expression::Constant(F::from(reduced)) - } -} - // Implement From trait for signed types macro_rules! impl_from_signed { ($($t:ty),*) => { @@ -856,7 +848,7 @@ macro_rules! impl_from_signed { )* }; } -impl_from_signed!(i8, i16, i32, i64, i128, isize); +impl_from_signed!(i8, i16, i32, i64, isize); impl Display for Expression { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/ceno_zkvm/src/gadgets/div.rs b/ceno_zkvm/src/gadgets/div.rs index 95b7cc49f..d3808349b 100644 --- a/ceno_zkvm/src/gadgets/div.rs +++ b/ceno_zkvm/src/gadgets/div.rs @@ -10,13 +10,13 @@ use crate::{ witness::LkMultiplicity, }; -use super::AssertLTConfig; +use super::AssertLtConfig; /// divide gadget #[derive(Debug, Clone)] pub struct DivConfig { pub dividend: UInt, - pub r_lt: AssertLTConfig, + pub r_lt: AssertLtConfig, pub intermediate_mul: UInt, } @@ -35,7 +35,7 @@ impl DivConfig { let (dividend, intermediate_mul) = divisor.mul_add(|| "divisor * outcome + r", cb, quotient, remainder, true)?; - let r_lt = AssertLTConfig::construct_circuit( + let r_lt = AssertLtConfig::construct_circuit( cb, || "remainder < divisor", remainder.value(), diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index 967efc868..c9dedaa03 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -19,9 +19,9 @@ use crate::{ use super::SignedExtendConfig; #[derive(Debug, Clone)] -pub struct AssertLTConfig(InnerLtConfig); +pub struct AssertLtConfig(InnerLtConfig); -impl AssertLTConfig { +impl AssertLtConfig { pub fn construct_circuit< E: ExtensionField, NR: Into + Display + Clone, diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 60846581e..0f21b3ce2 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -5,7 +5,7 @@ mod signed_ext; pub use div::DivConfig; pub use is_lt::{ - AssertLTConfig, AssertSignedLtConfig, InnerLtConfig, IsLtConfig, SignedLtConfig, cal_lt_diff, + AssertLtConfig, AssertSignedLtConfig, InnerLtConfig, IsLtConfig, SignedLtConfig, cal_lt_diff, }; pub use is_zero::{IsEqualConfig, IsZeroConfig}; pub use signed_ext::SignedExtendConfig; diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index fb3858ddc..b604f7c92 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -14,10 +14,8 @@ pub const PUBLIC_IO_IDX: usize = 6; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; -#[cfg(feature = "riv32")] pub const BIT_WIDTH: usize = 32usize; -#[cfg(feature = "riv64")] -pub const BIT_WIDTH: usize = 64usize; + pub type UInt = UIntLimbs; pub type UIntMul = UIntLimbs<{ 2 * BIT_WIDTH }, LIMB_BITS, E>; /// use UInt for x bits limb size diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index 95dba096d..d3d0d90a7 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -3,7 +3,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{ToExpr, WitIn}, - gadgets::AssertLTConfig, + gadgets::AssertLtConfig, instructions::{ Instruction, riscv::{ @@ -21,7 +21,7 @@ use std::marker::PhantomData; pub struct HaltConfig { ecall_cfg: EcallInstructionConfig, prev_x10_ts: WitIn, - lt_x10_cfg: AssertLTConfig, + lt_x10_cfg: AssertLtConfig, } pub struct HaltInstruction(PhantomData); diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs index 123667ed1..078223617 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -5,7 +5,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, - gadgets::AssertLTConfig, + gadgets::AssertLtConfig, set_val, tables::InsnRecord, witness::LkMultiplicity, @@ -17,7 +17,7 @@ pub struct EcallInstructionConfig { pub pc: WitIn, pub ts: WitIn, prev_x5_ts: WitIn, - lt_x5_cfg: AssertLTConfig, + lt_x5_cfg: AssertLtConfig, } impl EcallInstructionConfig { diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index e94785937..468eb3623 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -12,7 +12,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, - gadgets::AssertLTConfig, + gadgets::AssertLtConfig, set_val, uint::Value, witness::LkMultiplicity, @@ -76,7 +76,7 @@ impl StateInOut { pub struct ReadRS1 { pub id: WitIn, pub prev_ts: WitIn, - pub lt_cfg: AssertLTConfig, + pub lt_cfg: AssertLtConfig, _field_type: PhantomData, } @@ -130,7 +130,7 @@ impl ReadRS1 { pub struct ReadRS2 { pub id: WitIn, pub prev_ts: WitIn, - pub lt_cfg: AssertLTConfig, + pub lt_cfg: AssertLtConfig, _field_type: PhantomData, } @@ -185,7 +185,7 @@ pub struct WriteRD { pub id: WitIn, pub prev_ts: WitIn, pub prev_value: UInt, - pub lt_cfg: AssertLTConfig, + pub lt_cfg: AssertLtConfig, } impl WriteRD { @@ -245,7 +245,7 @@ impl WriteRD { #[derive(Debug)] pub struct ReadMEM { pub prev_ts: WitIn, - pub lt_cfg: AssertLTConfig, + pub lt_cfg: AssertLtConfig, _field_type: PhantomData, } @@ -300,7 +300,7 @@ impl ReadMEM { #[derive(Debug)] pub struct WriteMEM { pub prev_ts: WitIn, - pub lt_cfg: AssertLTConfig, + pub lt_cfg: AssertLtConfig, } impl WriteMEM { diff --git a/ceno_zkvm/src/instructions/riscv/mul.rs b/ceno_zkvm/src/instructions/riscv/mul.rs index 536ab15ce..2214c437b 100644 --- a/ceno_zkvm/src/instructions/riscv/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/mul.rs @@ -431,9 +431,7 @@ impl Signed { lkm, *val.as_u16_limbs().last().unwrap() as u64, )?; - let signed_val = val.as_u32() as i32; - - Ok(signed_val) + Ok(i32::from(val)) } pub fn expr(&self) -> Expression { diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 792472c4d..e375912af 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -7,7 +7,7 @@ use crate::{ Value, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, - gadgets::{AssertLTConfig, SignedExtendConfig}, + gadgets::{AssertLtConfig, SignedExtendConfig}, instructions::Instruction, set_val, }; @@ -26,7 +26,7 @@ pub struct ShiftConfig { pow2_rs2_low5: WitIn, outflow: WitIn, - assert_lt_config: AssertLTConfig, + assert_lt_config: AssertLtConfig, // SRA signed_extend_config: Option>, @@ -90,7 +90,7 @@ impl Instruction for ShiftLogicalInstru let rs2_high = UInt::new(|| "rs2_high", circuit_builder)?; let outflow = circuit_builder.create_witin(|| "outflow"); - let assert_lt_config = AssertLTConfig::construct_circuit( + let assert_lt_config = AssertLtConfig::construct_circuit( circuit_builder, || "outflow < pow2_rs2_low5", outflow.expr(), diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 771121f89..50ebe4ff0 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -4,7 +4,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, - gadgets::{AssertLTConfig, SignedExtendConfig}, + gadgets::{AssertLtConfig, SignedExtendConfig}, instructions::{ Instruction, riscv::{constants::UInt, i_insn::IInstructionConfig}, @@ -24,7 +24,7 @@ pub struct ShiftImmConfig { rs1_read: UInt, rd_written: UInt, outflow: WitIn, - assert_lt_config: AssertLTConfig, + assert_lt_config: AssertLtConfig, // SRAI is_lt_config: Option>, @@ -83,7 +83,7 @@ impl Instruction for ShiftImmInstructio let rd_written = UInt::new(|| "rd_written", circuit_builder)?; let outflow = circuit_builder.create_witin(|| "outflow"); - let assert_lt_config = AssertLTConfig::construct_circuit( + let assert_lt_config = AssertLtConfig::construct_circuit( circuit_builder, || "outflow < imm", outflow.expr(), diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index e59aa6411..21b6d4c05 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1208,7 +1208,7 @@ mod tests { ROMType::U5, error::ZKVMError, expression::{ToExpr, WitIn}, - gadgets::{AssertLTConfig, IsLtConfig}, + gadgets::{AssertLtConfig, IsLtConfig}, instructions::InstancePaddingStrategy, set_val, witness::{LkMultiplicity, RowMajorMatrix}, @@ -1353,7 +1353,7 @@ mod tests { struct AssertLtCircuit { pub a: WitIn, pub b: WitIn, - pub lt_wtns: AssertLTConfig, + pub lt_wtns: AssertLtConfig, } struct AssertLtCircuitInput { @@ -1365,7 +1365,7 @@ mod tests { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let a = cb.create_witin(|| "a"); let b = cb.create_witin(|| "b"); - let lt_wtns = AssertLTConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; + let lt_wtns = AssertLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; Ok(Self { a, b, lt_wtns }) } diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 0af75208d..b11490395 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -8,7 +8,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::{UtilError, ZKVMError}, expression::{Expression, ToExpr, WitIn}, - gadgets::{AssertLTConfig, SignedExtendConfig}, + gadgets::{AssertLtConfig, SignedExtendConfig}, instructions::riscv::constants::UInt, utils::add_one_to_big_num, witness::LkMultiplicity, @@ -83,7 +83,7 @@ pub struct UIntLimbs { // We don't need `overflow` witness since the last element of `carries` represents it. pub carries: Option>, // for carry range check using lt tricks - pub carries_auxiliary_lt_config: Option>, + pub carries_auxiliary_lt_config: Option>, } impl UIntLimbs { @@ -131,7 +131,7 @@ impl UIntLimbs { pub fn from_witins_unchecked( limbs: Vec, carries: Option>, - carries_auxiliary_lt_config: Option>, + carries_auxiliary_lt_config: Option>, ) -> Self { assert!(limbs.len() == Self::NUM_LIMBS); if let Some(carries) = &carries { @@ -606,6 +606,30 @@ pub struct Value<'a, T: Into + From + Copy + Default> { pub limbs: Cow<'a, [u16]>, } +impl<'a, T: Into + From + Copy + Default> From<&'a Value<'a, T>> for &'a [u16] { + fn from(v: &'a Value<'a, T>) -> Self { + v.as_u16_limbs() + } +} + +impl<'a, T: Into + From + Copy + Default> From<&Value<'a, T>> for u64 { + fn from(v: &Value<'a, T>) -> Self { + v.as_u64() + } +} + +impl<'a, T: Into + From + Copy + Default> From<&Value<'a, T>> for u32 { + fn from(v: &Value<'a, T>) -> Self { + v.as_u32() + } +} + +impl<'a, T: Into + From + Copy + Default> From<&Value<'a, T>> for i32 { + fn from(v: &Value<'a, T>) -> Self { + v.as_i32() + } +} + // TODO generalize to support non 16 bit limbs // TODO optimize api with fixed size array impl<'a, T: Into + From + Copy + Default> Value<'a, T> { @@ -616,10 +640,7 @@ impl<'a, T: Into + From + Copy + Default> Value<'a, T> { const LIMBS: usize = (Self::M + 15) / 16; pub fn new(val: T, lkm: &mut LkMultiplicity) -> Self { - let uint = Value:: { - val, - limbs: Cow::Owned(Self::split_to_u16(val)), - }; + let uint = Self::new_unchecked(val); Self::assert_u16(&uint.limbs, lkm); uint } @@ -684,6 +705,11 @@ impl<'a, T: Into + From + Copy + Default> Value<'a, T> { self.as_u64() as u32 } + /// Convert the limbs to an i32 value + pub fn as_i32(&self) -> i32 { + self.as_u32() as i32 + } + pub fn u16_fields(&self) -> Vec { self.limbs.iter().map(|v| F::from(*v as u64)).collect_vec() } diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index dfe33b076..5e48b0319 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -7,7 +7,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, - gadgets::AssertLTConfig, + gadgets::AssertLtConfig, instructions::riscv::config::IsEqualConfig, }; @@ -137,7 +137,7 @@ impl UIntLimbs { .iter() .enumerate() .map(|(i, carry)| { - AssertLTConfig::construct_circuit( + AssertLtConfig::construct_circuit( circuit_builder, || format!("carry_{i}_in_less_than"), carry.expr(), @@ -145,7 +145,7 @@ impl UIntLimbs { Self::MAX_DEGREE_2_MUL_CARRY_U16_LIMB, ) }) - .collect::, ZKVMError>>()?; + .collect::, ZKVMError>>()?; // creating a witness constrained as expression to reduce overall degree let mut swap_witin = |name: &str, diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 092187cba..5d747be06 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -29,6 +29,3 @@ criterion.workspace = true [[bench]] harness = false name = "devirgo_sumcheck" - -[features] -non_pow2_rayon_thread = [] diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index 09d916e21..4578b7ec2 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -1,7 +1,5 @@ #![deny(clippy::cargo)] #![feature(decl_macro)] -#[cfg(feature = "non_pow2_rayon_thread")] -pub mod local_thread_pool; pub mod macros; mod prover; mod prover_v2; diff --git a/sumcheck/src/local_thread_pool.rs b/sumcheck/src/local_thread_pool.rs deleted file mode 100644 index 8df2f620e..000000000 --- a/sumcheck/src/local_thread_pool.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::sync::{Arc, Once}; - -use rayon::ThreadPool; - -pub(crate) static mut LOCAL_THREAD_POOL: Option> = None; -static LOCAL_THREAD_POOL_SET: Once = Once::new(); - -pub fn create_local_pool_once(size: usize, in_place: bool) { - unsafe { - let size = if in_place { size - 1 } else { size }; - let pool_size = LOCAL_THREAD_POOL - .as_ref() - .map(|a| a.current_num_threads()) - .unwrap_or(0); - if pool_size > 0 && pool_size != size { - panic!( - "calling prove_batch_polys with different polys size. prev size {} vs now size {}", - pool_size, size - ); - } - LOCAL_THREAD_POOL_SET.call_once(|| { - let _ = Some(&*LOCAL_THREAD_POOL.get_or_insert_with(|| { - Arc::new( - rayon::ThreadPoolBuilder::new() - .num_threads(size) - .build() - .unwrap(), - ) - })); - }); - } -} diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 7cee2169b..0dc5de870 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -12,9 +12,6 @@ use rayon::{ }; use transcript::{Challenge, Transcript, TranscriptSyncronized}; -#[cfg(feature = "non_pow2_rayon_thread")] -use crate::local_thread_pool::{LOCAL_THREAD_POOL, create_local_pool_once}; - use crate::{ macros::{entered_span, exit_span}, structs::{IOPProof, IOPProverMessage, IOPProverState}, @@ -119,25 +116,11 @@ impl IOPProverState { if rayon::current_num_threads() >= max_thread_id { rayon::spawn(spawn_task); } else { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "rayon global thread pool size {} mismatch with desired poly size {}, add --features non_pow2_rayon_thread", - rayon::current_num_threads(), - polys.len() - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - unsafe { - create_local_pool_once(max_thread_id, true); - - if let Some(pool) = LOCAL_THREAD_POOL.as_ref() { - pool.spawn(spawn_task) - } else { - panic!("empty local pool") - } - } + panic!( + "rayon global thread pool size {} mismatch with desired poly size {}.", + rayon::current_num_threads(), + polys.len() + ); } } diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index 3dae3a101..b4021b77d 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -18,9 +18,6 @@ use rayon::{ }; use transcript::{Challenge, Transcript, TranscriptSyncronized}; -#[cfg(feature = "non_pow2_rayon_thread")] -use crate::local_thread_pool::{LOCAL_THREAD_POOL, create_local_pool_once}; - use crate::{ macros::{entered_span, exit_span}, structs::{IOPProof, IOPProverMessage, IOPProverStateV2}, @@ -235,26 +232,11 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { { rayon::in_place_scope(scoped_fn) } else { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "rayon global thread pool size {} mismatch with desired poly size {}, add - --features non_pow2_rayon_thread", - rayon::current_num_threads(), - polys.len() - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - unsafe { - create_local_pool_once(max_thread_id, true); - - if let Some(pool) = LOCAL_THREAD_POOL.as_ref() { - pool.scope(scoped_fn) - } else { - panic!("empty local pool") - } - } + panic!( + "rayon global thread pool size {} mismatch with desired poly size {}.", + rayon::current_num_threads(), + polys.len() + ); }; if log2_max_thread_id == 0 {