diff --git a/Cargo.toml b/Cargo.toml index 552212b3..53c98f1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "snark-verifier", "snark-verifier-sdk", ] +resolver = "2" [profile.dev] opt-level = 3 diff --git a/rust-toolchain b/rust-toolchain index 51ab4759..ee2d639b 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2022-10-28 \ No newline at end of file +nightly-2023-08-12 \ No newline at end of file diff --git a/snark-verifier-sdk/Cargo.toml b/snark-verifier-sdk/Cargo.toml index e85f614f..adf30bec 100644 --- a/snark-verifier-sdk/Cargo.toml +++ b/snark-verifier-sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "snark-verifier-sdk" -version = "0.1.1" +version = "0.1.2" edition = "2021" [dependencies] @@ -53,7 +53,7 @@ parallel = ["snark-verifier/parallel"] halo2-pse = ["snark-verifier/halo2-pse", "dep:serde_with"] halo2-axiom = ["snark-verifier/halo2-axiom"] -zkevm = ["dep:zkevm-circuits", "dep:bus-mapping", "dep:mock", "dep:eth-types"] +# zkevm = ["dep:zkevm-circuits", "dep:bus-mapping", "dep:mock", "dep:eth-types"] [[bench]] name = "standard_plonk" diff --git a/snark-verifier-sdk/benches/read_pk.rs b/snark-verifier-sdk/benches/read_pk.rs index 02d25ec6..55154a2e 100644 --- a/snark-verifier-sdk/benches/read_pk.rs +++ b/snark-verifier-sdk/benches/read_pk.rs @@ -1,7 +1,7 @@ use ark_std::{end_timer, start_timer}; use criterion::Criterion; use criterion::{criterion_group, criterion_main}; -use halo2_base::gates::builder::BASE_CONFIG_PARAMS; +use halo2_base::gates::builder::CircuitBuilderStage; use halo2_base::halo2_proofs; use halo2_base::utils::fs::gen_srs; use halo2_proofs::halo2curves as halo2_curves; @@ -182,30 +182,37 @@ fn bench(c: &mut Criterion) { let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app)); let agg_config = AggregationConfigParams::from_path(path); - BASE_CONFIG_PARAMS.with(|params| *params.borrow_mut() = agg_config.into()); let params = gen_srs(agg_config.degree); - let agg_circuit = AggregationCircuit::keygen::(¶ms, snarks); + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Keygen, + agg_config, + None, + ¶ms, + snarks, + ); + std::fs::remove_file("examples/agg.pk").ok(); let start0 = start_timer!(|| "gen vk & pk"); gen_pk(¶ms, &agg_circuit, Some(Path::new("examples/agg.pk"))); end_timer!(start0); let mut group = c.benchmark_group("read-pk"); group.sample_size(10); - group.bench_with_input("1mb", &(1024 * 1024), |b, &c| { - b.iter(|| read_pk_with_capacity::(c, "examples/agg.pk")) + group.bench_with_input("buffer 1mb capacity", &(1024 * 1024), |b, &c| { + b.iter(|| read_pk_with_capacity::(c, "examples/agg.pk", agg_config)) }); - group.bench_with_input("10mb", &(10 * 1024 * 1024), |b, &c| { - b.iter(|| read_pk_with_capacity::(c, "examples/agg.pk")) + group.bench_with_input("buffer 10mb capacity", &(10 * 1024 * 1024), |b, &c| { + b.iter(|| read_pk_with_capacity::(c, "examples/agg.pk", agg_config)) }); - group.bench_with_input("100mb", &(100 * 1024 * 1024), |b, &c| { - b.iter(|| read_pk_with_capacity::(c, "examples/agg.pk")) + group.bench_with_input("buffer 100mb capacity", &(100 * 1024 * 1024), |b, &c| { + b.iter(|| read_pk_with_capacity::(c, "examples/agg.pk", agg_config)) }); - group.bench_with_input("1gb", &(1024 * 1024 * 1024), |b, &c| { - b.iter(|| read_pk_with_capacity::(c, "examples/agg.pk")) + group.bench_with_input("buffer 1gb capacity", &(1024 * 1024 * 1024), |b, &c| { + b.iter(|| read_pk_with_capacity::(c, "examples/agg.pk", agg_config)) }); group.finish(); + std::fs::remove_file("examples/agg.pk").unwrap(); } criterion_group! { diff --git a/snark-verifier-sdk/benches/standard_plonk.rs b/snark-verifier-sdk/benches/standard_plonk.rs index f696f822..eecb7140 100644 --- a/snark-verifier-sdk/benches/standard_plonk.rs +++ b/snark-verifier-sdk/benches/standard_plonk.rs @@ -1,7 +1,7 @@ use ark_std::{end_timer, start_timer}; use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion}; -use halo2_base::gates::builder::{CircuitBuilderStage, BASE_CONFIG_PARAMS}; +use halo2_base::gates::builder::CircuitBuilderStage; use halo2_base::halo2_proofs; use halo2_base::utils::fs::gen_srs; use halo2_proofs::halo2curves as halo2_curves; @@ -185,11 +185,15 @@ fn bench(c: &mut Criterion) { let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app)); let agg_config = AggregationConfigParams::from_path(path); - BASE_CONFIG_PARAMS.with(|params| *params.borrow_mut() = agg_config.into()); let params = gen_srs(agg_config.degree); - let lookup_bits = params.k() as usize - 1; - let agg_circuit = AggregationCircuit::keygen::(¶ms, snarks.clone()); + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Keygen, + agg_config, + None, + ¶ms, + snarks.clone(), + ); let start0 = start_timer!(|| "gen vk & pk"); let pk = gen_pk(¶ms, &agg_circuit, Some(Path::new("agg.pk"))); @@ -205,8 +209,8 @@ fn bench(c: &mut Criterion) { b.iter(|| { let agg_circuit = AggregationCircuit::new::( CircuitBuilderStage::Prover, + agg_config, Some(break_points.clone()), - lookup_bits, params, snarks.clone(), ); @@ -222,8 +226,8 @@ fn bench(c: &mut Criterion) { // do one more time to verify let agg_circuit = AggregationCircuit::new::( CircuitBuilderStage::Prover, + agg_config, Some(break_points), - lookup_bits, ¶ms, snarks.clone(), ); diff --git a/snark-verifier-sdk/src/halo2/aggregation.rs b/snark-verifier-sdk/src/halo2/aggregation.rs index 9f9ac64a..2d093f7b 100644 --- a/snark-verifier-sdk/src/halo2/aggregation.rs +++ b/snark-verifier-sdk/src/halo2/aggregation.rs @@ -4,8 +4,7 @@ use halo2_base::{ gates::{ builder::{ BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - PublicBaseConfig, RangeCircuitBuilder, RangeWithInstanceCircuitBuilder, - BASE_CONFIG_PARAMS, + PublicBaseConfig, RangeWithInstanceCircuitBuilder, }, flex_gate::GateStrategy, RangeChip, @@ -14,10 +13,7 @@ use halo2_base::{ circuit::{Layouter, SimpleFloorPlanner}, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::{self, Circuit, ConstraintSystem, Selector}, - poly::{ - commitment::{Params, ParamsProver}, - kzg::commitment::ParamsKZG, - }, + poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, }, utils::ScalarField, AssignedValue, @@ -139,7 +135,7 @@ where /// Same as `FlexGateConfigParams` except we assume a single Phase and default 'Vertical' strategy. /// Also adds `lookup_bits` field. -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize)] pub struct AggregationConfigParams { pub degree: u32, pub num_advice: usize, @@ -168,6 +164,40 @@ impl From for BaseConfigParams { } } +impl TryFrom<&BaseConfigParams> for AggregationConfigParams { + type Error = &'static str; + + fn try_from(params: &BaseConfigParams) -> Result { + if params.num_advice_per_phase.iter().skip(1).any(|&n| n != 0) { + return Err("AggregationConfigParams only supports 1 phase"); + } + if params.num_lookup_advice_per_phase.iter().skip(1).any(|&n| n != 0) { + return Err("AggregationConfigParams only supports 1 phase"); + } + if params.lookup_bits.is_none() { + return Err("AggregationConfigParams requires lookup_bits"); + } + Ok(Self { + degree: params.k as u32, + num_advice: params.num_advice_per_phase[0], + num_lookup_advice: params.num_lookup_advice_per_phase[0], + num_fixed: params.num_fixed, + lookup_bits: params.lookup_bits.unwrap(), + }) + } +} + +/// Holds virtual contexts for the cells used to verify a collection of snarks +#[derive(Clone, Debug)] +pub struct AggregationCtxBuilder { + /// Virtual region with virtual contexts (columns) + pub builder: GateThreadBuilder, + /// The limbs of the pair of elliptic curve points that need to be verified in a final pairing check. + pub accumulator: Vec>, + // the public instances from previous snarks that were aggregated + pub previous_instances: Vec>>, +} + #[derive(Clone, Debug)] pub struct AggregationCircuit { pub inner: RangeWithInstanceCircuitBuilder, @@ -175,7 +205,7 @@ pub struct AggregationCircuit { // the user can optionally append these to `inner.assigned_instances` to expose them pub previous_instances: Vec>>, // accumulation scheme proof, private input - pub as_proof: Vec, // not sure this needs to be stored, keeping for now + // pub as_proof: Vec, } // trait just so we can have a generic that is either SHPLONK or GWC @@ -201,17 +231,15 @@ pub trait Halo2KzgAccumulationScheme<'a> = PolynomialCommitmentScheme< VerifyingKey = KzgAsVerifyingKey, > + AccumulationSchemeProver>; -impl AggregationCircuit { - /// Given snarks, this creates a circuit and runs the `GateThreadBuilder` to verify all the snarks. - /// By default, the returned circuit has public instances equal to the limbs of the pair of elliptic curve points, referred to as the `accumulator`, that need to be verified in a final pairing check. +impl AggregationCtxBuilder { + /// Given snarks, this runs the `GateThreadBuilder` to verify all the snarks. /// - /// The user can optionally modify the circuit after calling this function to add more instances to `assigned_instances` to expose. + /// Also returns the limbs of the pair of elliptic curve points, referred to as the `accumulator`, that need to be verified in a final pairing check. /// /// Warning: will fail silently if `snarks` were created using a different multi-open scheme than `AS` /// where `AS` can be either [`crate::SHPLONK`] or [`crate::GWC`] (for original PLONK multi-open scheme) pub fn new( - stage: CircuitBuilderStage, - break_points: Option, + witness_gen_only: bool, lookup_bits: usize, params: &ParamsKZG, snarks: impl IntoIterator, @@ -254,11 +282,7 @@ impl AggregationCircuit { }; // create thread builder and run aggregation witness gen - let builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; + let builder = GateThreadBuilder::new(witness_gen_only); // create halo2loader let range = RangeChip::::default(lookup_bits); let fp_chip = FpChip::::new(&range, BITS, LIMBS); @@ -269,7 +293,7 @@ impl AggregationCircuit { aggregate::(&svk, &loader, &snarks, as_proof.as_slice()); let lhs = accumulator.lhs.assigned(); let rhs = accumulator.rhs.assigned(); - let assigned_instances = lhs + let accumulator = lhs .x() .limbs() .iter() @@ -284,73 +308,65 @@ impl AggregationCircuit { let KzgAccumulator { lhs, rhs } = _accumulator; let instances = [lhs.x, lhs.y, rhs.x, rhs.y].map(fe_to_limbs::<_, Fr, LIMBS, BITS>).concat(); - for (lhs, rhs) in instances.iter().zip(assigned_instances.iter()) { + for (lhs, rhs) in instances.iter().zip(accumulator.iter()) { assert_eq!(lhs, rhs.value()); } } let builder = loader.take_ctx(); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder), - CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, break_points.unwrap()) - } - }; - let inner = RangeWithInstanceCircuitBuilder::new(circuit, assigned_instances); - Self { inner, previous_instances, as_proof } + Self { builder, accumulator, previous_instances } } +} - pub fn public( +impl AggregationCircuit { + /// Given snarks, this creates a circuit and runs the `GateThreadBuilder` to verify all the snarks. + /// By default, the returned circuit has public instances equal to the limbs of the pair of elliptic curve points, referred to as the `accumulator`, that need to be verified in a final pairing check. + /// + /// The user can optionally modify the circuit after calling this function to add more instances to `assigned_instances` to expose. + /// + /// Warning: will fail silently if `snarks` were created using a different multi-open scheme than `AS` + /// where `AS` can be either [`crate::SHPLONK`] or [`crate::GWC`] (for original PLONK multi-open scheme) + pub fn new( stage: CircuitBuilderStage, + agg_config: AggregationConfigParams, break_points: Option, - lookup_bits: usize, params: &ParamsKZG, snarks: impl IntoIterator, - has_prev_accumulator: bool, ) -> Self where AS: for<'a> Halo2KzgAccumulationScheme<'a>, { - let mut private = Self::new::(stage, break_points, lookup_bits, params, snarks); - private.expose_previous_instances(has_prev_accumulator); - private - } - - // this function is for convenience - /// `params` should be the universal trusted setup to be used for the aggregation circuit, not the one used to generate the previous snarks, although we assume both use the same generator g[0] - pub fn keygen(params: &ParamsKZG, snarks: impl IntoIterator) -> Self - where - AS: for<'a> Halo2KzgAccumulationScheme<'a>, - { - let lookup_bits = BASE_CONFIG_PARAMS - .with(|conf| conf.borrow().lookup_bits) - .unwrap_or(params.k() as usize - 1); - let circuit = - Self::new::(CircuitBuilderStage::Keygen, None, lookup_bits, params, snarks); - circuit.config(params.k(), Some(10)); - circuit + let AggregationCtxBuilder { builder, accumulator, previous_instances } = + AggregationCtxBuilder::new::( + stage == CircuitBuilderStage::Prover, + agg_config.lookup_bits, + params, + snarks, + ); + let inner = RangeWithInstanceCircuitBuilder::from_stage( + stage, + builder, + agg_config.into(), + break_points, + accumulator, + ); + Self { inner, previous_instances } } - // this function is for convenience - pub fn prover( + pub fn public( + stage: CircuitBuilderStage, + agg_config: AggregationConfigParams, + break_points: Option, params: &ParamsKZG, snarks: impl IntoIterator, - break_points: MultiPhaseThreadBreakPoints, + has_prev_accumulator: bool, ) -> Self where AS: for<'a> Halo2KzgAccumulationScheme<'a>, { - let lookup_bits = BASE_CONFIG_PARAMS - .with(|conf| conf.borrow().lookup_bits) - .unwrap_or(params.k() as usize - 1); - Self::new::( - CircuitBuilderStage::Prover, - Some(break_points), - lookup_bits, - params, - snarks, - ) + let mut private = Self::new::(stage, agg_config, break_points, params, snarks); + private.expose_previous_instances(has_prev_accumulator); + private } /// Re-expose the previous public instances of aggregated snarks again. @@ -363,12 +379,12 @@ impl AggregationCircuit { } } - pub fn as_proof(&self) -> &[u8] { - &self.as_proof[..] - } - - pub fn config(&self, k: u32, minimum_rows: Option) -> BaseConfigParams { - self.inner.config(k, minimum_rows) + /// Auto-configure the circuit and change the circuit's internal configuration parameters. + pub fn config(&mut self, k: u32, minimum_rows: Option) -> BaseConfigParams { + let mut new_config = self.inner.circuit.0.builder.borrow().config(k as usize, minimum_rows); + new_config.lookup_bits = self.inner.circuit.0.config_params.lookup_bits; + self.inner.circuit.0.config_params = new_config.clone(); + new_config } pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { @@ -401,13 +417,25 @@ impl CircuitExt for RangeWithInstanceCircuitBuilder { impl Circuit for AggregationCircuit { type Config = PublicBaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = AggregationConfigParams; + + fn params(&self) -> Self::Params { + (&self.inner.circuit.0.config_params).try_into().unwrap() + } fn without_witnesses(&self) -> Self { unimplemented!() } - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - RangeWithInstanceCircuitBuilder::configure(meta) + fn configure_with_params( + meta: &mut ConstraintSystem, + params: Self::Params, + ) -> Self::Config { + RangeWithInstanceCircuitBuilder::configure_with_params(meta, params.into()) + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!() } fn synthesize( diff --git a/snark-verifier-sdk/src/lib.rs b/snark-verifier-sdk/src/lib.rs index 07a15bc7..40401a4c 100644 --- a/snark-verifier-sdk/src/lib.rs +++ b/snark-verifier-sdk/src/lib.rs @@ -78,13 +78,14 @@ pub trait CircuitExt: Circuit { } } -pub fn read_pk>(path: &Path) -> io::Result> { - read_pk_with_capacity::(BUFFER_SIZE, path) +pub fn read_pk>(path: &Path, params: C::Params) -> io::Result> { + read_pk_with_capacity::(BUFFER_SIZE, path, params) } pub fn read_pk_with_capacity>( capacity: usize, path: impl AsRef, + params: C::Params, ) -> io::Result> { let f = File::open(path.as_ref())?; #[cfg(feature = "display")] @@ -97,7 +98,8 @@ pub fn read_pk_with_capacity>( // let initial_buffer_size = f.metadata().map(|m| m.len() as usize + 1).unwrap_or(0); // let mut bufreader = Vec::with_capacity(initial_buffer_size); // f.read_to_end(&mut bufreader)?; - let pk = ProvingKey::read::<_, C>(&mut bufreader, SerdeFormat::RawBytesUnchecked).unwrap(); + let pk = + ProvingKey::read::<_, C>(&mut bufreader, SerdeFormat::RawBytesUnchecked, params).unwrap(); #[cfg(feature = "display")] end_timer!(read_time); @@ -112,7 +114,7 @@ pub fn gen_pk>( path: Option<&Path>, ) -> ProvingKey { if let Some(path) = path { - if let Ok(pk) = read_pk::(path) { + if let Ok(pk) = read_pk::(path, circuit.params()) { return pk; } } diff --git a/snark-verifier/Cargo.toml b/snark-verifier/Cargo.toml index 313fd038..5606df37 100644 --- a/snark-verifier/Cargo.toml +++ b/snark-verifier/Cargo.toml @@ -1,25 +1,28 @@ [package] name = "snark-verifier" -version = "0.1.1" +version = "0.1.2" edition = "2021" [dependencies] -itertools = "0.10.5" -lazy_static = "1.4.0" -num-bigint = "0.4.3" +itertools = "0.11" +lazy_static = "1.4" +num-bigint = "0.4" num-integer = "0.1.45" num-traits = "0.2.15" hex = "0.4" rand = "0.8" serde = { version = "1.0", features = ["derive"] } +pairing = { version = "0.23" } # Use halo2-base as non-optional dependency because it re-exports halo2_proofs, halo2curves, and poseidon, using different repos based on feature flag "halo2-axiom" or "halo2-pse" halo2-base = { git = "https://github.com/axiom-crypto/halo2-lib.git", branch = "develop", default-features = false } # This is Scroll's audited poseidon circuit. We only use it for the Native Poseidon spec. We do not use the halo2 circuit at all (and it wouldn't even work because the halo2_proofs tag is not compatbile). -poseidon-circuit = { git = "https://github.com/scroll-tech/poseidon-circuit.git", rev = "50015b7" } +# We forked it to upgrade to ff v0.13 and removed the circuit module +poseidon-rs = { git = "https://github.com/axiom-crypto/poseidon-circuit.git", rev = "1aee4a1" } +# poseidon-circuit = { git = "https://github.com/scroll-tech/poseidon-circuit.git", rev = "50015b7" } # parallel -rayon = { version = "1.7.0", optional = true } +rayon = { version = "1.7", optional = true } # loader_evm sha3 = { version = "0.10.8", optional = true } diff --git a/snark-verifier/examples/evm-verifier-with-accumulator.rs b/snark-verifier/examples/evm-verifier-with-accumulator.rs index 4ffa6459..41493efa 100644 --- a/snark-verifier/examples/evm-verifier-with-accumulator.rs +++ b/snark-verifier/examples/evm-verifier-with-accumulator.rs @@ -1,6 +1,6 @@ use aggregation::{AggregationCircuit, AggregationConfigParams}; use halo2_base::{ - gates::builder::{set_lookup_bits, CircuitBuilderStage}, + gates::builder::{BaseConfigParams, CircuitBuilderStage}, halo2_proofs, utils::fs::gen_srs, }; @@ -300,7 +300,7 @@ mod aggregation { As::verify(&Default::default(), &accumulators, &proof).unwrap() } - #[derive(serde::Serialize, serde::Deserialize)] + #[derive(serde::Serialize, serde::Deserialize, Default)] pub struct AggregationConfigParams { pub degree: u32, pub num_advice: usize, @@ -318,8 +318,8 @@ mod aggregation { impl AggregationCircuit { pub fn new( stage: CircuitBuilderStage, + config_params: BaseConfigParams, break_points: Option, - lookup_bits: usize, params_g0: G1Affine, snarks: impl IntoIterator, ) -> Self { @@ -355,13 +355,9 @@ mod aggregation { }; // create thread builder and run aggregation witness gen - let builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; + let builder = GateThreadBuilder::from_stage(stage); // create halo2loader - let range = RangeChip::::default(lookup_bits); + let range = RangeChip::::default(config_params.lookup_bits.unwrap()); let fp_chip = FpChip::::new(&range, BITS, LIMBS); let ecc_chip = BaseFieldEccChip::new(&fp_chip); let loader = Halo2Loader::new(ecc_chip, builder); @@ -391,24 +387,25 @@ mod aggregation { } let builder = loader.take_ctx(); - let inner = match stage { - CircuitBuilderStage::Mock => { - RangeWithInstanceCircuitBuilder::mock(builder, assigned_instances) - } - CircuitBuilderStage::Keygen => { - RangeWithInstanceCircuitBuilder::keygen(builder, assigned_instances) - } - CircuitBuilderStage::Prover => RangeWithInstanceCircuitBuilder::prover( - builder, - assigned_instances, - break_points.unwrap(), - ), - }; + let inner = RangeWithInstanceCircuitBuilder::from_stage( + stage, + builder, + config_params, + break_points, + assigned_instances, + ); Self { inner, as_proof } } - pub fn config(&self, k: u32, minimum_rows: Option) -> BaseConfigParams { - self.inner.circuit.0.builder.borrow().config(k as usize, minimum_rows) + pub fn config( + &self, + k: u32, + minimum_rows: Option, + lookup_bits: usize, + ) -> BaseConfigParams { + let mut params = self.inner.circuit.0.builder.borrow().config(k as usize, minimum_rows); + params.lookup_bits = Some(lookup_bits); + params } pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { @@ -432,13 +429,25 @@ mod aggregation { impl Circuit for AggregationCircuit { type Config = PublicBaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = BaseConfigParams; + + fn params(&self) -> Self::Params { + self.inner.circuit.params() + } + + fn configure_with_params( + meta: &mut plonk::ConstraintSystem, + params: Self::Params, + ) -> Self::Config { + RangeWithInstanceCircuitBuilder::configure_with_params(meta, params) + } fn without_witnesses(&self) -> Self { unimplemented!() } - fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - RangeWithInstanceCircuitBuilder::configure(meta) + fn configure(_: &mut plonk::ConstraintSystem) -> Self::Config { + unimplemented!() } fn synthesize( @@ -574,15 +583,23 @@ fn main() { File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - set_lookup_bits(agg_config.lookup_bits); - let agg_circuit = AggregationCircuit::new( + let mut config_params = BaseConfigParams { + k: agg_config.degree as usize, + strategy: Default::default(), + num_advice_per_phase: vec![agg_config.num_advice], + num_lookup_advice_per_phase: vec![agg_config.num_lookup_advice], + num_fixed: agg_config.num_fixed, + lookup_bits: Some(agg_config.lookup_bits), + }; + let mut agg_circuit = AggregationCircuit::new( CircuitBuilderStage::Mock, + config_params, None, - agg_config.lookup_bits, params_app.get_g()[0], snarks.clone(), ); - agg_circuit.config(agg_config.degree, Some(6)); + config_params = agg_circuit.config(agg_config.degree, Some(6), agg_config.lookup_bits); + agg_circuit.inner.circuit.0.config_params = config_params.clone(); #[cfg(debug_assertions)] { MockProver::run(agg_config.degree, &agg_circuit, agg_circuit.instances()) @@ -605,8 +622,8 @@ fn main() { let agg_circuit = AggregationCircuit::new( CircuitBuilderStage::Prover, + config_params, Some(break_points), - agg_config.lookup_bits, params_app.get_g()[0], snarks, ); diff --git a/snark-verifier/examples/recursion.rs b/snark-verifier/examples/recursion.rs index 7415e1ab..b469a51a 100644 --- a/snark-verifier/examples/recursion.rs +++ b/snark-verifier/examples/recursion.rs @@ -3,7 +3,7 @@ use ark_std::{end_timer, start_timer}; use common::*; use halo2_base::gates::builder::BaseConfigParams; -use halo2_base::gates::{builder::BASE_CONFIG_PARAMS, flex_gate::GateStrategy}; +use halo2_base::gates::flex_gate::GateStrategy; use halo2_base::halo2_proofs; use halo2_base::utils::fs::gen_srs; use halo2_proofs::{ @@ -12,11 +12,10 @@ use halo2_proofs::{ halo2curves::{ bn256::{Bn256, Fr, G1Affine}, group::ff::Field, - FieldExt, }, plonk::{ - self, create_proof, keygen_pk, keygen_vk, Circuit, ConstraintSystem, Error, ProvingKey, - Selector, VerifyingKey, + create_proof, keygen_pk, keygen_vk, Circuit, ConstraintSystem, Error, ProvingKey, Selector, + VerifyingKey, }, poly::{ commitment::ParamsProver, @@ -192,19 +191,38 @@ mod common { pub fn gen_dummy_snark>( params: &ParamsKZG, vk: Option<&VerifyingKey>, - ) -> Snark { - struct CsProxy(PhantomData<(F, C)>); + config_params: ConcreteCircuit::Params, + ) -> Snark + where + ConcreteCircuit::Params: Clone, + { + struct CsProxy>(C::Params, PhantomData<(F, C)>); - impl> Circuit for CsProxy { + impl> Circuit for CsProxy + where + C::Params: Clone, + { type Config = C::Config; type FloorPlanner = C::FloorPlanner; + type Params = C::Params; + + fn params(&self) -> Self::Params { + self.0.clone() + } fn without_witnesses(&self) -> Self { - CsProxy(PhantomData) + CsProxy(self.0.clone(), PhantomData) + } + + fn configure_with_params( + meta: &mut ConstraintSystem, + params: Self::Params, + ) -> Self::Config { + C::configure_with_params(meta, params) } - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - C::configure(meta) + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!() } fn synthesize( @@ -227,9 +245,9 @@ mod common { } } - let dummy_vk = vk - .is_none() - .then(|| keygen_vk(params, &CsProxy::(PhantomData)).unwrap()); + let dummy_vk = vk.is_none().then(|| { + keygen_vk(params, &CsProxy::(config_params, PhantomData)).unwrap() + }); let protocol = compile( params, vk.or(dummy_vk.as_ref()).unwrap(), @@ -326,10 +344,7 @@ mod application { mod recursion { use halo2_base::{ gates::{ - builder::{ - GateThreadBuilder, PublicBaseConfig, RangeWithInstanceCircuitBuilder, - BASE_CONFIG_PARAMS, - }, + builder::{GateThreadBuilder, PublicBaseConfig, RangeWithInstanceCircuitBuilder}, GateInstructions, RangeChip, RangeInstructions, }, AssignedValue, @@ -456,6 +471,7 @@ mod recursion { round: usize, instances: Vec, as_proof: Vec, + lookup_bits: usize, inner: RangeWithInstanceCircuitBuilder, } @@ -473,6 +489,7 @@ mod recursion { initial_state: Fr, state: Fr, round: usize, + config_params: BaseConfigParams, ) -> Self { let svk = params.get_g()[0].into(); let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); @@ -528,17 +545,25 @@ mod recursion { .collect(); let builder = GateThreadBuilder::mock(); - let inner = RangeWithInstanceCircuitBuilder::mock(builder, vec![]); - let mut circuit = - Self { svk, default_accumulator, app, previous, round, instances, as_proof, inner }; + let lookup_bits = config_params.lookup_bits.unwrap(); + let inner = RangeWithInstanceCircuitBuilder::mock(builder, config_params, vec![]); + let mut circuit = Self { + svk, + default_accumulator, + app, + previous, + round, + instances, + as_proof, + inner, + lookup_bits, + }; circuit.build(); circuit } fn build(&mut self) { - let lookup_bits = - BASE_CONFIG_PARAMS.with(|params| params.borrow().lookup_bits.unwrap()); - let range = RangeChip::::default(lookup_bits); + let range = RangeChip::::default(self.lookup_bits); let main_gate = range.gate(); let mut builder = GateThreadBuilder::mock(); let ctx = &mut builder; @@ -626,8 +651,12 @@ mod recursion { ); } - fn initial_snark(params: &ParamsKZG, vk: Option<&VerifyingKey>) -> Snark { - let mut snark = gen_dummy_snark::(params, vk); + fn initial_snark( + params: &ParamsKZG, + vk: Option<&VerifyingKey>, + config_params: BaseConfigParams, + ) -> Snark { + let mut snark = gen_dummy_snark::(params, vk, config_params); let g = params.get_g(); snark.instances = vec![[g[1].x, g[1].y, g[0].x, g[0].y] .into_iter() @@ -658,13 +687,25 @@ mod recursion { impl Circuit for RecursionCircuit { type Config = PublicBaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = BaseConfigParams; + + fn params(&self) -> Self::Params { + self.inner.circuit.params() + } fn without_witnesses(&self) -> Self { unimplemented!() } - fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - RangeWithInstanceCircuitBuilder::configure(meta) + fn configure_with_params( + meta: &mut ConstraintSystem, + params: Self::Params, + ) -> Self::Config { + RangeWithInstanceCircuitBuilder::configure_with_params(meta, params) + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!() } fn synthesize( @@ -699,14 +740,20 @@ mod recursion { recursion_params: &ParamsKZG, app_params: &ParamsKZG, app_vk: &VerifyingKey, - ) -> ProvingKey { + recursion_config: BaseConfigParams, + app_config: ConcreteCircuit::Params, + ) -> ProvingKey + where + ConcreteCircuit::Params: Clone, + { let recursion = RecursionCircuit::new( recursion_params, - gen_dummy_snark::(app_params, Some(app_vk)), - RecursionCircuit::initial_snark(recursion_params, None), + gen_dummy_snark::(app_params, Some(app_vk), app_config), + RecursionCircuit::initial_snark(recursion_params, None, recursion_config.clone()), Fr::zero(), Fr::zero(), 0, + recursion_config, ); // we cannot auto-configure the circuit because dummy_snark must know the configuration beforehand // uncomment the following line only in development to test and print out the optimal configuration ahead of time @@ -721,11 +768,15 @@ mod recursion { recursion_pk: &ProvingKey, initial_state: Fr, inputs: Vec, + config_params: BaseConfigParams, ) -> (Fr, Snark) { let mut state = initial_state; let mut app = ConcreteCircuit::new(state); - let mut previous = - RecursionCircuit::initial_snark(recursion_params, Some(recursion_pk.get_vk())); + let mut previous = RecursionCircuit::initial_snark( + recursion_params, + Some(recursion_pk.get_vk()), + config_params.clone(), + ); for (round, input) in inputs.into_iter().enumerate() { state = app.state_transition(input); println!("Generate app snark"); @@ -737,6 +788,7 @@ mod recursion { initial_state, state, round, + config_params.clone(), ); println!("Generate recursion snark"); previous = gen_snark(recursion_params, recursion_pk, recursion); @@ -752,16 +804,14 @@ fn main() { serde_json::from_reader(fs::File::open("configs/example_recursion.json").unwrap()).unwrap(); let k = recursion_config.degree; let recursion_params = gen_srs(k); - BASE_CONFIG_PARAMS.with(|params| { - *params.borrow_mut() = BaseConfigParams { - strategy: GateStrategy::Vertical, - k: k as usize, - num_advice_per_phase: vec![recursion_config.num_advice], - num_lookup_advice_per_phase: vec![recursion_config.num_lookup_advice], - num_fixed: recursion_config.num_fixed, - lookup_bits: Some(recursion_config.lookup_bits), - } - }); + let config_params = BaseConfigParams { + strategy: GateStrategy::Vertical, + k: k as usize, + num_advice_per_phase: vec![recursion_config.num_advice], + num_lookup_advice_per_phase: vec![recursion_config.num_lookup_advice], + num_fixed: recursion_config.num_fixed, + lookup_bits: Some(recursion_config.lookup_bits), + }; let app_pk = gen_pk(&app_params, &application::Square::default()); @@ -770,6 +820,8 @@ fn main() { &recursion_params, &app_params, app_pk.get_vk(), + config_params.clone(), + (), ); end_timer!(pk_time); @@ -782,9 +834,10 @@ fn main() { &recursion_pk, Fr::from(2u64), vec![(); num_round], + config_params.clone(), ); end_timer!(pf_time); - assert_eq!(final_state, Fr::from(2u64).pow(&[1 << num_round, 0, 0, 0])); + assert_eq!(final_state, Fr::from(2u64).pow([1 << num_round])); { let dk = diff --git a/snark-verifier/src/loader.rs b/snark-verifier/src/loader.rs index 77a8f54b..a3637f08 100644 --- a/snark-verifier/src/loader.rs +++ b/snark-verifier/src/loader.rs @@ -122,12 +122,12 @@ pub trait ScalarLoader { /// Load `zero` as constant. fn load_zero(&self) -> Self::LoadedScalar { - self.load_const(&F::zero()) + self.load_const(&F::ZERO) } /// Load `one` as constant. fn load_one(&self) -> Self::LoadedScalar { - self.load_const(&F::one()) + self.load_const(&F::ONE) } /// Assert lhs and rhs field elements are equal. @@ -145,13 +145,13 @@ pub trait ScalarLoader { let loader = values.first().unwrap().1.loader(); iter::empty() - .chain(if constant == F::zero() { + .chain(if constant == F::ZERO { None } else { Some(Cow::Owned(loader.load_const(&constant))) }) .chain(values.iter().map(|&(coeff, value)| { - if coeff == F::one() { + if coeff == F::ONE { Cow::Borrowed(value) } else { Cow::Owned(loader.load_const(&coeff) * value) @@ -174,9 +174,9 @@ pub trait ScalarLoader { let loader = values.first().unwrap().1.loader(); iter::empty() - .chain(if constant == F::zero() { None } else { Some(loader.load_const(&constant)) }) + .chain(if constant == F::ZERO { None } else { Some(loader.load_const(&constant)) }) .chain(values.iter().map(|&(coeff, lhs, rhs)| { - if coeff == F::one() { + if coeff == F::ONE { lhs.clone() * rhs } else { loader.load_const(&coeff) * lhs * rhs @@ -188,20 +188,20 @@ pub trait ScalarLoader { /// Sum field elements with coefficients. fn sum_with_coeff(&self, values: &[(F, &Self::LoadedScalar)]) -> Self::LoadedScalar { - self.sum_with_coeff_and_const(values, F::zero()) + self.sum_with_coeff_and_const(values, F::ZERO) } /// Sum field elements and constant. fn sum_with_const(&self, values: &[&Self::LoadedScalar], constant: F) -> Self::LoadedScalar { self.sum_with_coeff_and_const( - &values.iter().map(|&value| (F::one(), value)).collect_vec(), + &values.iter().map(|&value| (F::ONE, value)).collect_vec(), constant, ) } /// Sum field elements. fn sum(&self, values: &[&Self::LoadedScalar]) -> Self::LoadedScalar { - self.sum_with_const(values, F::zero()) + self.sum_with_const(values, F::ZERO) } /// Sum product of field elements with coefficients. @@ -209,7 +209,7 @@ pub trait ScalarLoader { &self, values: &[(F, &Self::LoadedScalar, &Self::LoadedScalar)], ) -> Self::LoadedScalar { - self.sum_products_with_coeff_and_const(values, F::zero()) + self.sum_products_with_coeff_and_const(values, F::ZERO) } /// Sum product of field elements and constant. @@ -219,7 +219,7 @@ pub trait ScalarLoader { constant: F, ) -> Self::LoadedScalar { self.sum_products_with_coeff_and_const( - &values.iter().map(|&(lhs, rhs)| (F::one(), lhs, rhs)).collect_vec(), + &values.iter().map(|&(lhs, rhs)| (F::ONE, lhs, rhs)).collect_vec(), constant, ) } @@ -229,7 +229,7 @@ pub trait ScalarLoader { &self, values: &[(&Self::LoadedScalar, &Self::LoadedScalar)], ) -> Self::LoadedScalar { - self.sum_products_with_const(values, F::zero()) + self.sum_products_with_const(values, F::ZERO) } /// Product of field elements. diff --git a/snark-verifier/src/loader/evm/loader.rs b/snark-verifier/src/loader/evm/loader.rs index 98ca5ca4..bfb37c8c 100644 --- a/snark-verifier/src/loader/evm/loader.rs +++ b/snark-verifier/src/loader/evm/loader.rs @@ -684,8 +684,8 @@ impl> ScalarLoader for Rc { } let push_addend = |(coeff, value): &(F, &Scalar)| { - assert_ne!(*coeff, F::zero()); - match (*coeff == F::one(), &value.value) { + assert_ne!(*coeff, F::ZERO); + match (*coeff == F::ONE, &value.value) { (true, _) => self.push(value), (false, Value::Constant(value)) => self.push( &self.scalar(Value::Constant(fe_to_u256(*coeff * u256_to_fe::(*value)))), @@ -699,7 +699,7 @@ impl> ScalarLoader for Rc { }; let mut values = values.iter(); - let initial_value = if constant == F::zero() { + let initial_value = if constant == F::ZERO { push_addend(values.next().unwrap()) } else { self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))) @@ -733,8 +733,8 @@ impl> ScalarLoader for Rc { } let push_addend = |(coeff, lhs, rhs): &(F, &Scalar, &Scalar)| { - assert_ne!(*coeff, F::zero()); - match (*coeff == F::one(), &lhs.value, &rhs.value) { + assert_ne!(*coeff, F::ZERO); + match (*coeff == F::ONE, &lhs.value, &rhs.value) { (_, Value::Constant(lhs), Value::Constant(rhs)) => { self.push(&self.scalar(Value::Constant(fe_to_u256( *coeff * u256_to_fe::(*lhs) * u256_to_fe::(*rhs), @@ -764,7 +764,7 @@ impl> ScalarLoader for Rc { }; let mut values = values.iter(); - let initial_value = if constant == F::zero() { + let initial_value = if constant == F::ZERO { push_addend(values.next().unwrap()) } else { self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))) diff --git a/snark-verifier/src/loader/evm/test/tui.rs b/snark-verifier/src/loader/evm/test/tui.rs index 328082c7..9bd68bb9 100644 --- a/snark-verifier/src/loader/evm/test/tui.rs +++ b/snark-verifier/src/loader/evm/test/tui.rs @@ -45,12 +45,7 @@ impl Tui { let backend = CrosstermBackend::new(stdout); let mut terminal = Terminal::new(backend).unwrap(); terminal.hide_cursor().unwrap(); - Tui { - debug_arena, - terminal, - key_buffer: String::new(), - current_step, - } + Tui { debug_arena, terminal, key_buffer: String::new(), current_step } } pub fn start(mut self) { @@ -91,11 +86,8 @@ impl Tui { let mut draw_memory: DrawMemory = DrawMemory::default(); let debug_call = &self.debug_arena; - let mut opcode_list: Vec = debug_call[0] - .1 - .iter() - .map(|step| step.pretty_opcode()) - .collect(); + let mut opcode_list: Vec = + debug_call[0].1.iter().map(|step| step.pretty_opcode()).collect(); let mut last_index = 0; let mut stack_labels = false; @@ -385,12 +377,8 @@ impl Tui { if let [op_pane, stack_pane, memory_pane] = Layout::default() .direction(Direction::Vertical) .constraints( - [ - Constraint::Ratio(1, 3), - Constraint::Ratio(1, 3), - Constraint::Ratio(1, 3), - ] - .as_ref(), + [Constraint::Ratio(1, 3), Constraint::Ratio(1, 3), Constraint::Ratio(1, 3)] + .as_ref(), ) .split(app)[..] { @@ -412,14 +400,7 @@ impl Tui { stack_labels, draw_memory, ); - Tui::draw_memory( - f, - debug_steps, - current_step, - memory_pane, - mem_utf, - draw_memory, - ); + Tui::draw_memory(f, debug_steps, current_step, memory_pane, mem_utf, draw_memory); } else { panic!("unable to create vertical panes") } @@ -536,15 +517,11 @@ impl Tui { let prev_start = draw_memory.current_startline; let abs_min_start = 0; let abs_max_start = (opcode_list.len() as i32 - 1) - (height / 2); - let mut min_start = max( - current_step as i32 - height + extra_top_lines, - abs_min_start, - ) as usize; + let mut min_start = + max(current_step as i32 - height + extra_top_lines, abs_min_start) as usize; - let mut max_start = max( - min(current_step as i32 - extra_top_lines, abs_max_start), - abs_min_start, - ) as usize; + let mut max_start = + max(min(current_step as i32 - extra_top_lines, abs_max_start), abs_min_start) as usize; if min_start > max_start { std::mem::swap(&mut min_start, &mut max_start); @@ -559,18 +536,11 @@ impl Tui { } draw_memory.current_startline = display_start; - let max_pc_len = debug_steps - .iter() - .fold(0, |max_val, val| val.pc.max(max_val)) - .to_string() - .len(); + let max_pc_len = + debug_steps.iter().fold(0, |max_val, val| val.pc.max(max_val)).to_string().len(); let mut add_new_line = |line_number| { - let bg_color = if line_number == current_step { - Color::DarkGray - } else { - Color::Reset - }; + let bg_color = if line_number == current_step { Color::DarkGray } else { Color::Reset }; let line_number_format = if line_number == current_step { let step: &DebugStep = &debug_steps[line_number]; @@ -598,9 +568,8 @@ impl Tui { add_new_line(number); } add_new_line(opcode_list.len()); - let paragraph = Paragraph::new(text_output) - .block(block_source_code) - .wrap(Wrap { trim: true }); + let paragraph = + Paragraph::new(text_output).block(block_source_code).wrap(Wrap { trim: true }); f.render_widget(paragraph, area); } @@ -610,12 +579,11 @@ impl Tui { current_step: usize, area: Rect, stack_labels: bool, - draw_memory: &mut DrawMemory, + draw_memory: &DrawMemory, ) { let stack = &debug_steps[current_step].stack; - let stack_space = Block::default() - .title(format!("Stack: {}", stack.len())) - .borders(Borders::ALL); + let stack_space = + Block::default().title(format!("Stack: {}", stack.len())).borders(Borders::ALL); let min_len = usize::max(format!("{}", stack.len()).len(), 2); let indices_affected = stack_indices_affected(debug_steps[current_step].instruction.0); @@ -626,12 +594,10 @@ impl Tui { .enumerate() .skip(draw_memory.current_stack_startline) .map(|(i, stack_item)| { - let affected = indices_affected - .iter() - .find(|(affected_index, _name)| *affected_index == i); + let affected = + indices_affected.iter().find(|(affected_index, _name)| *affected_index == i); let mut words: Vec = (0..32) - .into_iter() .rev() .map(|i| stack_item.byte(i)) .map(|byte| { @@ -667,9 +633,7 @@ impl Tui { }) .collect(); - let paragraph = Paragraph::new(text) - .block(stack_space) - .wrap(Wrap { trim: true }); + let paragraph = Paragraph::new(text).block(stack_space).wrap(Wrap { trim: true }); f.render_widget(paragraph, area); } @@ -679,14 +643,11 @@ impl Tui { current_step: usize, area: Rect, mem_utf8: bool, - draw_mem: &mut DrawMemory, + draw_mem: &DrawMemory, ) { let memory = &debug_steps[current_step].memory; let stack_space = Block::default() - .title(format!( - "Memory (max expansion: {} bytes)", - memory.effective_len() - )) + .title(format!("Memory (max expansion: {} bytes)", memory.effective_len())) .borders(Borders::ALL); let memory = memory.data(); let max_i = memory.len() / 32; @@ -773,9 +734,7 @@ impl Tui { Spans::from(spans) }) .collect(); - let paragraph = Paragraph::new(text) - .block(stack_space) - .wrap(Wrap { trim: true }); + let paragraph = Paragraph::new(text).block(stack_space).wrap(Wrap { trim: true }); f.render_widget(paragraph, area); } } @@ -884,13 +843,7 @@ fn stack_indices_affected(op: u8) -> Vec<(usize, &'static str)> { 0xa0 => vec![(0, "offset"), (1, "length")], 0xa1 => vec![(0, "offset"), (1, "length"), (2, "topic")], 0xa2 => vec![(0, "offset"), (1, "length"), (2, "topic1"), (3, "topic2")], - 0xa3 => vec![ - (0, "offset"), - (1, "length"), - (2, "topic1"), - (3, "topic2"), - (4, "topic3"), - ], + 0xa3 => vec![(0, "offset"), (1, "length"), (2, "topic1"), (3, "topic2"), (4, "topic3")], 0xa4 => vec![ (0, "offset"), (1, "length"), diff --git a/snark-verifier/src/loader/evm/util.rs b/snark-verifier/src/loader/evm/util.rs index a84df4c3..5df077f6 100644 --- a/snark-verifier/src/loader/evm/util.rs +++ b/snark-verifier/src/loader/evm/util.rs @@ -74,7 +74,7 @@ pub fn modulus() -> U256 where F: PrimeField, { - U256::from_little_endian((-F::one()).to_repr().as_ref()) + 1 + U256::from_little_endian((-F::ONE).to_repr().as_ref()) + 1 } /// Encode instances and proof into calldata. diff --git a/snark-verifier/src/loader/evm/util/executor.rs b/snark-verifier/src/loader/evm/util/executor.rs index a7697a0e..17062028 100644 --- a/snark-verifier/src/loader/evm/util/executor.rs +++ b/snark-verifier/src/loader/evm/util/executor.rs @@ -47,13 +47,8 @@ fn get_create2_address_from_hash( salt: [u8; 32], init_code_hash: impl Into, ) -> Address { - let bytes = [ - &[0xff], - from.into().as_bytes(), - salt.as_slice(), - init_code_hash.into().as_ref(), - ] - .concat(); + let bytes = + [&[0xff], from.into().as_bytes(), salt.as_slice(), init_code_hash.into().as_ref()].concat(); let hash = keccak256(bytes); @@ -87,11 +82,7 @@ struct LogCollector { impl Inspector for LogCollector { fn log(&mut self, _: &mut EVMData<'_, DB>, address: &Address, topics: &[H256], data: &Bytes) { - self.logs.push(Log { - address: *address, - topics: topics.to_vec(), - data: data.clone(), - }); + self.logs.push(Log { address: *address, topics: topics.to_vec(), data: data.clone() }); } fn call( @@ -114,6 +105,7 @@ pub enum CallKind { Create2, } +#[allow(clippy::derivable_impls)] impl Default for CallKind { fn default() -> Self { CallKind::Call @@ -284,29 +276,15 @@ impl Debugger { fn enter(&mut self, depth: usize, address: Address, kind: CallKind) { self.context = address; - self.head = self.arena.push_node(DebugNode { - depth, - address, - kind, - ..Default::default() - }); + self.head = self.arena.push_node(DebugNode { depth, address, kind, ..Default::default() }); } fn exit(&mut self) { if let Some(parent_id) = self.arena.arena[self.head].parent { - let DebugNode { - depth, - address, - kind, - .. - } = self.arena.arena[parent_id]; + let DebugNode { depth, address, kind, .. } = self.arena.arena[parent_id]; self.context = address; - self.head = self.arena.push_node(DebugNode { - depth, - address, - kind, - ..Default::default() - }); + self.head = + self.arena.push_node(DebugNode { depth, address, kind, ..Default::default() }); } } } @@ -324,11 +302,7 @@ impl Inspector for Debugger { let opcode_infos = spec_opcode_gas(data.env.cfg.spec_id); let opcode_info = &opcode_infos[op as usize]; - let push_size = if opcode_info.is_push() { - (op - opcode::PUSH1 + 1) as usize - } else { - 0 - }; + let push_size = if opcode_info.is_push() { (op - opcode::PUSH1 + 1) as usize } else { 0 }; let push_bytes = match push_size { 0 => None, n => { @@ -394,12 +368,7 @@ impl Inspector for Debugger { CallKind::Create, ); - ( - Return::Continue, - None, - Gas::new(call.gas_limit), - Bytes::new(), - ) + (Return::Continue, None, Gas::new(call.gas_limit), Bytes::new()) } fn create_end( @@ -619,12 +588,7 @@ impl Inspector for InspectorStack { } ); - ( - Return::Continue, - None, - Gas::new(call.gas_limit), - Bytes::new(), - ) + (Return::Continue, None, Gas::new(call.gas_limit), Bytes::new()) } fn create_end( @@ -741,11 +705,7 @@ pub struct Executor { impl Executor { fn new(debugger: bool, gas_limit: U256) -> Self { - Executor { - db: InMemoryDB::default(), - debugger, - gas_limit, - } + Executor { db: InMemoryDB::default(), debugger, gas_limit } } pub fn db_mut(&mut self) -> &mut InMemoryDB { @@ -757,16 +717,8 @@ impl Executor { let result = self.call_raw_with_env(env); self.commit(&result); - let RawCallResult { - exit_reason, - out, - gas_used, - gas_refunded, - logs, - debug, - env, - .. - } = result; + let RawCallResult { exit_reason, out, gas_used, gas_refunded, logs, debug, env, .. } = + result; let address = match (exit_reason, out) { (return_ok!(), TransactOut::Create(_, Some(address))) => Some(address), @@ -801,13 +753,7 @@ impl Executor { let result = evm_inner::<_, true>(&mut env, &mut self.db.clone(), &mut inspector).transact(); let (exec_result, state_changeset) = result; - let ExecutionResult { - exit_reason, - gas_refunded, - gas_used, - out, - .. - } = exec_result; + let ExecutionResult { exit_reason, gas_refunded, gas_used, out, .. } = exec_result; let result = match out { TransactOut::Call(ref data) => data.to_owned(), @@ -831,16 +777,13 @@ impl Executor { fn commit(&mut self, result: &RawCallResult) { if let Some(state_changeset) = result.state_changeset.as_ref() { - self.db - .commit(state_changeset.clone().into_iter().collect()); + self.db.commit(state_changeset.clone().into_iter().collect()); } } fn inspector(&self) -> InspectorStack { - let mut stack = InspectorStack { - logs: Some(LogCollector::default()), - ..Default::default() - }; + let mut stack = + InspectorStack { logs: Some(LogCollector::default()), ..Default::default() }; if self.debugger { let gas_inspector = Rc::new(RefCell::new(GasInspector::default())); stack.gas = Some(gas_inspector.clone()); @@ -857,10 +800,7 @@ impl Executor { value: U256, ) -> Env { Env { - block: BlockEnv { - gas_limit: self.gas_limit, - ..BlockEnv::default() - }, + block: BlockEnv { gas_limit: self.gas_limit, ..BlockEnv::default() }, tx: TxEnv { caller, transact_to, diff --git a/snark-verifier/src/loader/halo2/loader.rs b/snark-verifier/src/loader/halo2/loader.rs index 31be9841..105972c0 100644 --- a/snark-verifier/src/loader/halo2/loader.rs +++ b/snark-verifier/src/loader/halo2/loader.rs @@ -136,15 +136,15 @@ impl> Halo2Loader { | (Value::Constant(constant), Value::Assigned(assigned)) => { Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(C::Scalar::one(), assigned)], + &[(C::Scalar::ONE, assigned)], *constant, )) } (Value::Assigned(lhs), Value::Assigned(rhs)) => { Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(C::Scalar::one(), lhs), (C::Scalar::one(), rhs)], - C::Scalar::zero(), + &[(C::Scalar::ONE, lhs), (C::Scalar::ONE, rhs)], + C::Scalar::ZERO, )) } }; @@ -161,14 +161,14 @@ impl> Halo2Loader { (Value::Constant(constant), Value::Assigned(assigned)) => { Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(-C::Scalar::one(), assigned)], + &[(-C::Scalar::ONE, assigned)], *constant, )) } (Value::Assigned(assigned), Value::Constant(constant)) => { Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(C::Scalar::one(), assigned)], + &[(C::Scalar::ONE, assigned)], -*constant, )) } @@ -191,14 +191,14 @@ impl> Halo2Loader { Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( &mut self.ctx_mut(), &[(*constant, assigned)], - C::Scalar::zero(), + C::Scalar::ZERO, )) } (Value::Assigned(lhs), Value::Assigned(rhs)) => { Value::Assigned(self.scalar_chip().sum_products_with_coeff_and_const( &mut self.ctx_mut(), - &[(C::Scalar::one(), lhs, rhs)], - C::Scalar::zero(), + &[(C::Scalar::ONE, lhs, rhs)], + C::Scalar::ZERO, )) } }; @@ -557,7 +557,7 @@ impl> EcPointLoader for Rc + if scalar.eq(&C::Scalar::ONE) => { variable_base_non_scaled.push(base); } diff --git a/snark-verifier/src/loader/halo2/shim.rs b/snark-verifier/src/loader/halo2/shim.rs index 790c9e22..9d010d2b 100644 --- a/snark-verifier/src/loader/halo2/shim.rs +++ b/snark-verifier/src/loader/halo2/shim.rs @@ -1,8 +1,8 @@ -use crate::util::arithmetic::{CurveAffine, FieldExt}; +use crate::util::arithmetic::{CurveAffine, PrimeField}; use std::{fmt::Debug, ops::Deref}; /// Instructions to handle field element operations. -pub trait IntegerInstructions: Clone + Debug { +pub trait IntegerInstructions: Clone + Debug { /// Context (either enhanced `region` or some kind of builder). type Context: Debug; /// Assigned cell. @@ -24,8 +24,8 @@ pub trait IntegerInstructions: Clone + Debug { fn sum_with_coeff_and_const( &self, ctx: &mut Self::Context, - values: &[(F::Scalar, impl Deref)], - constant: F::Scalar, + values: &[(F, impl Deref)], + constant: F, ) -> Self::AssignedInteger; /// Sum product of integers with coefficients and constant. @@ -33,11 +33,11 @@ pub trait IntegerInstructions: Clone + Debug { &self, ctx: &mut Self::Context, values: &[( - F::Scalar, + F, impl Deref, impl Deref, )], - constant: F::Scalar, + constant: F, ) -> Self::AssignedInteger; /// Returns `lhs - rhs`. @@ -132,25 +132,26 @@ mod halo2_lib { use crate::halo2_proofs::halo2curves::CurveAffineExt; use crate::{ loader::halo2::{EccInstructions, IntegerInstructions}, - util::arithmetic::CurveAffine, + util::arithmetic::{CurveAffine, PrimeField}, }; use halo2_base::{ self, gates::{builder::GateThreadBuilder, GateChip, GateInstructions, RangeInstructions}, + utils::BigPrimeField, AssignedValue, QuantumCell::{Constant, Existing}, }; use halo2_ecc::bigint::ProperCrtUint; use halo2_ecc::{ ecc::{BaseFieldEccChip, EcPoint}, - fields::{FieldChip, PrimeField}, + fields::FieldChip, }; use std::ops::Deref; type AssignedInteger = ProperCrtUint<::ScalarExt>; type AssignedEcPoint = EcPoint<::ScalarExt, AssignedInteger>; - impl IntegerInstructions for GateChip { + impl IntegerInstructions for GateChip { type Context = GateThreadBuilder; type AssignedCell = AssignedValue; type AssignedInteger = AssignedValue; @@ -166,14 +167,14 @@ mod halo2_lib { fn sum_with_coeff_and_const( &self, ctx: &mut Self::Context, - values: &[(F::Scalar, impl Deref)], + values: &[(F, impl Deref)], constant: F, ) -> Self::AssignedInteger { let mut a = Vec::with_capacity(values.len() + 1); let mut b = Vec::with_capacity(values.len() + 1); - if constant != F::zero() { + if constant != F::ZERO { a.push(Constant(constant)); - b.push(Constant(F::one())); + b.push(Constant(F::ONE)); } a.extend(values.iter().map(|(_, a)| Existing(*a.deref()))); b.extend(values.iter().map(|(c, _)| Constant(*c))); @@ -184,7 +185,7 @@ mod halo2_lib { &self, ctx: &mut Self::Context, values: &[( - F::Scalar, + F, impl Deref, impl Deref, )], @@ -220,8 +221,8 @@ mod halo2_lib { ) -> Self::AssignedInteger { // make sure scalar != 0 let is_zero = self.is_zero(ctx.main(0), *a); - self.assert_is_const(ctx.main(0), &is_zero, &F::zero()); - GateInstructions::div_unsafe(self, ctx.main(0), Constant(F::one()), Existing(*a)) + self.assert_is_const(ctx.main(0), &is_zero, &F::ZERO); + GateInstructions::div_unsafe(self, ctx.main(0), Constant(F::ONE), Existing(*a)) } fn assert_equal( @@ -236,8 +237,8 @@ mod halo2_lib { impl<'chip, C: CurveAffineExt> EccInstructions for BaseFieldEccChip<'chip, C> where - C::ScalarExt: PrimeField, - C::Base: PrimeField, + C::ScalarExt: BigPrimeField, + C::Base: BigPrimeField, { type Context = GateThreadBuilder; type ScalarChip = GateChip; diff --git a/snark-verifier/src/pcs/ipa.rs b/snark-verifier/src/pcs/ipa.rs index 6358e15d..288745d7 100644 --- a/snark-verifier/src/pcs/ipa.rs +++ b/snark-verifier/src/pcs/ipa.rs @@ -379,7 +379,7 @@ fn h_eval>(xi: &[T], z: &T) -> T { fn h_coeffs(xi: &[F], scalar: F) -> Vec { assert!(!xi.is_empty()); - let mut coeffs = vec![F::zero(); 1 << xi.len()]; + let mut coeffs = vec![F::ZERO; 1 << xi.len()]; coeffs[0] = scalar; for (len, xi) in xi.iter().rev().enumerate().map(|(i, xi)| (1 << i, xi)) { diff --git a/snark-verifier/src/pcs/ipa/accumulation.rs b/snark-verifier/src/pcs/ipa/accumulation.rs index 56d61aa7..51434541 100644 --- a/snark-verifier/src/pcs/ipa/accumulation.rs +++ b/snark-verifier/src/pcs/ipa/accumulation.rs @@ -186,13 +186,13 @@ where let (u, h) = instances .iter() - .map(|IpaAccumulator { u, xi }| (*u, h_coeffs(xi, C::Scalar::one()))) + .map(|IpaAccumulator { u, xi }| (*u, h_coeffs(xi, C::Scalar::ONE))) .chain(a_b_u.map(|(a, b, u)| { ( u, iter::empty() .chain([b, a]) - .chain(iter::repeat_with(C::Scalar::zero).take(pk.domain.n - 2)) + .chain(iter::repeat(C::Scalar::ZERO).take(pk.domain.n - 2)) .collect(), ) })) diff --git a/snark-verifier/src/pcs/ipa/decider.rs b/snark-verifier/src/pcs/ipa/decider.rs index 5235a857..6fd7026b 100644 --- a/snark-verifier/src/pcs/ipa/decider.rs +++ b/snark-verifier/src/pcs/ipa/decider.rs @@ -48,7 +48,7 @@ mod native { dk: &Self::DecidingKey, IpaAccumulator { u, xi }: IpaAccumulator, ) -> Result<(), Error> { - let h = h_coeffs(&xi, C::Scalar::one()); + let h = h_coeffs(&xi, C::Scalar::ONE); (u == multi_scalar_multiplication(&h, &dk.g).to_affine()) .then_some(()) .ok_or_else(|| Error::AssertionFailure("U == commit(G, h)".to_string())) diff --git a/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs index cae77a5f..538aa4fb 100644 --- a/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs +++ b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs @@ -5,7 +5,7 @@ use crate::{ PolynomialCommitmentScheme, Query, }, util::{ - arithmetic::{CurveAffine, FieldExt, Fraction}, + arithmetic::{CurveAffine, Fraction, PrimeField}, msm::Msm, transcript::TranscriptRead, Itertools, @@ -159,7 +159,7 @@ where fn query_sets(queries: &[Query]) -> Vec> where - F: FieldExt, + F: PrimeField + Ord, T: Clone, { let poly_shifts = @@ -203,7 +203,7 @@ where fn query_set_coeffs(sets: &[QuerySet], x: &T, x_3: &T) -> Vec> where - F: FieldExt, + F: PrimeField + Ord, T: LoadedScalar, { let loader = x.loader(); @@ -236,7 +236,7 @@ struct QuerySet<'a, F, T> { impl<'a, F, T> QuerySet<'a, F, T> where - F: FieldExt, + F: PrimeField, T: LoadedScalar, { fn msm>( @@ -288,7 +288,7 @@ struct QuerySetCoeff { impl QuerySetCoeff where - F: FieldExt, + F: PrimeField + Ord, T: LoadedScalar, { fn new(shifts: &[F], powers_of_x: &[T], x_3: &T, x_3_minus_x_shift_i: &BTreeMap) -> Self { @@ -303,7 +303,7 @@ where .filter(|&(i, _)| i != j) .map(|(_, shift_i)| (*shift_j - shift_i)) .reduce(|acc, value| acc * value) - .unwrap_or_else(|| F::one()) + .unwrap_or(F::ONE) }) .collect_vec(); diff --git a/snark-verifier/src/pcs/kzg.rs b/snark-verifier/src/pcs/kzg.rs index 8f416ee3..387a108c 100644 --- a/snark-verifier/src/pcs/kzg.rs +++ b/snark-verifier/src/pcs/kzg.rs @@ -18,7 +18,7 @@ pub use accumulator::LimbsEncodingInstructions; /// KZG succinct verifying key. #[derive(Clone, Copy, Debug)] -pub struct KzgSuccinctVerifyingKey { +pub struct KzgSuccinctVerifyingKey { /// Generator. pub g: C, } diff --git a/snark-verifier/src/pcs/kzg/accumulation.rs b/snark-verifier/src/pcs/kzg/accumulation.rs index 1f901568..d71e366e 100644 --- a/snark-verifier/src/pcs/kzg/accumulation.rs +++ b/snark-verifier/src/pcs/kzg/accumulation.rs @@ -19,6 +19,7 @@ pub struct KzgAs(PhantomData<(M, MOS)>); impl AccumulationScheme for KzgAs where M: MultiMillerLoop, + M::G1Affine: CurveAffine, L: Loader, MOS: Clone + Debug, { @@ -139,6 +140,7 @@ where impl AccumulationSchemeProver for KzgAs where M: MultiMillerLoop, + M::G1Affine: CurveAffine, MOS: Clone + Debug, { type ProvingKey = KzgAsProvingKey; @@ -163,7 +165,7 @@ where let blind = pk .zk() .then(|| { - let s = M::Scalar::random(rng); + let s = M::Fr::random(rng); let (g, s_g) = pk.0.unwrap(); let lhs = (s_g * s).to_affine(); let rhs = (g * s).to_affine(); diff --git a/snark-verifier/src/pcs/kzg/accumulator.rs b/snark-verifier/src/pcs/kzg/accumulator.rs index 82d1454b..423ae7d5 100644 --- a/snark-verifier/src/pcs/kzg/accumulator.rs +++ b/snark-verifier/src/pcs/kzg/accumulator.rs @@ -59,7 +59,6 @@ mod native { let [lhs_x, lhs_y, rhs_x, rhs_y]: [_; 4] = limbs .chunks(LIMBS) - .into_iter() .map(|limbs| { fe_from_limbs::<_, _, LIMBS, BITS>( limbs.iter().map(|limb| **limb).collect_vec().try_into().unwrap(), @@ -109,7 +108,6 @@ mod evm { let [lhs_x, lhs_y, rhs_x, rhs_y]: [[_; LIMBS]; 4] = limbs .chunks(LIMBS) - .into_iter() .map(|limbs| limbs.to_vec().try_into().unwrap()) .collect_vec() .try_into() @@ -204,14 +202,15 @@ mod halo2 { mod halo2_lib { use super::*; use halo2_base::halo2_proofs::halo2curves::CurveAffineExt; - use halo2_ecc::{ecc::BaseFieldEccChip, fields::PrimeField}; + use halo2_base::utils::BigPrimeField; + use halo2_ecc::ecc::BaseFieldEccChip; impl<'chip, C, const LIMBS: usize, const BITS: usize> LimbsEncodingInstructions for BaseFieldEccChip<'chip, C> where C: CurveAffineExt, - C::ScalarExt: PrimeField, - C::Base: PrimeField, + C::ScalarExt: BigPrimeField, + C::Base: BigPrimeField, { fn assign_ec_point_from_limbs( &self, diff --git a/snark-verifier/src/pcs/kzg/decider.rs b/snark-verifier/src/pcs/kzg/decider.rs index 59f1afbf..04f2caaf 100644 --- a/snark-verifier/src/pcs/kzg/decider.rs +++ b/snark-verifier/src/pcs/kzg/decider.rs @@ -1,4 +1,7 @@ -use crate::{pcs::kzg::KzgSuccinctVerifyingKey, util::arithmetic::MultiMillerLoop}; +use crate::{ + pcs::kzg::KzgSuccinctVerifyingKey, + util::arithmetic::{CurveAffine, MultiMillerLoop}, +}; use std::marker::PhantomData; /// KZG deciding key. @@ -23,7 +26,10 @@ impl KzgDecidingKey { } } -impl From<(M::G1Affine, M::G2Affine, M::G2Affine)> for KzgDecidingKey { +impl From<(M::G1Affine, M::G2Affine, M::G2Affine)> for KzgDecidingKey +where + M::G1Affine: CurveAffine, +{ fn from((g1, g2, s_g2): (M::G1Affine, M::G2Affine, M::G2Affine)) -> KzgDecidingKey { KzgDecidingKey::new(g1, g2, s_g2) } @@ -43,7 +49,7 @@ mod native { AccumulationDecider, }, util::{ - arithmetic::{Group, MillerLoopResult, MultiMillerLoop}, + arithmetic::{CurveAffine, Group, MillerLoopResult, MultiMillerLoop}, Itertools, }, Error, @@ -53,6 +59,7 @@ mod native { impl AccumulationDecider for KzgAs where M: MultiMillerLoop, + M::G1Affine: CurveAffine, MOS: Clone + Debug, { type DecidingKey = KzgDecidingKey; @@ -103,7 +110,9 @@ mod evm { impl AccumulationDecider> for KzgAs where M: MultiMillerLoop, - M::Scalar: PrimeField, + M::G1Affine: CurveAffine, + M::G2Affine: CurveAffine, + M::Fr: PrimeField, MOS: Clone + Debug, { type DecidingKey = KzgDecidingKey; @@ -152,7 +161,7 @@ mod evm { loader.code_mut().runtime_append(code); let challenge = loader.scalar(Value::Memory(challenge_ptr)); - let powers_of_challenge = LoadedScalar::::powers(&challenge, lhs.len()); + let powers_of_challenge = LoadedScalar::::powers(&challenge, lhs.len()); let [lhs, rhs] = [lhs, rhs].map(|msms| { msms.iter() .zip(powers_of_challenge.iter()) diff --git a/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs index 3a448056..d1398ebb 100644 --- a/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs +++ b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs @@ -6,7 +6,7 @@ use crate::{ PolynomialCommitmentScheme, Query, }, util::{ - arithmetic::{CurveAffine, FieldExt, Fraction, MultiMillerLoop}, + arithmetic::{CurveAffine, Fraction, MultiMillerLoop, PrimeField}, msm::Msm, transcript::TranscriptRead, Itertools, @@ -27,6 +27,8 @@ pub struct Bdfg21; impl PolynomialCommitmentScheme for KzgAs where M: MultiMillerLoop, + M::G1Affine: CurveAffine, + M::Fr: Ord, L: Loader, { type VerifyingKey = KzgSuccinctVerifyingKey; @@ -35,7 +37,7 @@ where fn read_proof( _: &KzgSuccinctVerifyingKey, - _: &[Query], + _: &[Query], transcript: &mut T, ) -> Result, Error> where @@ -48,22 +50,21 @@ where svk: &KzgSuccinctVerifyingKey, commitments: &[Msm], z: &L::LoadedScalar, - queries: &[Query], + queries: &[Query], proof: &Bdfg21Proof, ) -> Result { let sets = query_sets(queries); let f = { let coeffs = query_set_coeffs(&sets, z, &proof.z_prime); - let powers_of_mu = proof - .mu - .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let powers_of_mu = + proof.mu.powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); let msms = sets .iter() .zip(coeffs.iter()) .map(|(set, coeff)| set.msm(coeff, commitments, &powers_of_mu)); - msms.zip(proof.gamma.powers(sets.len()).into_iter()) + msms.zip(proof.gamma.powers(sets.len())) .map(|(msm, power_of_gamma)| msm * &power_of_gamma) .sum::>() - Msm::base(&proof.w) * &coeffs[0].z_s @@ -72,10 +73,7 @@ where let rhs = Msm::base(&proof.w_prime); let lhs = f + rhs.clone() * &proof.z_prime; - Ok(KzgAccumulator::new( - lhs.evaluate(Some(svk.g)), - rhs.evaluate(Some(svk.g)), - )) + Ok(KzgAccumulator::new(lhs.evaluate(Some(svk.g)), rhs.evaluate(Some(svk.g)))) } } @@ -104,24 +102,14 @@ where let w = transcript.read_ec_point()?; let z_prime = transcript.squeeze_challenge(); let w_prime = transcript.read_ec_point()?; - Ok(Bdfg21Proof { - mu, - gamma, - w, - z_prime, - w_prime, - }) + Ok(Bdfg21Proof { mu, gamma, w, z_prime, w_prime }) } } -fn query_sets(queries: &[Query]) -> Vec> { - let poly_shifts = queries.iter().fold( - Vec::<(usize, Vec, Vec<&T>)>::new(), - |mut poly_shifts, query| { - if let Some(pos) = poly_shifts - .iter() - .position(|(poly, _, _)| *poly == query.poly) - { +fn query_sets(queries: &[Query]) -> Vec> { + let poly_shifts = + queries.iter().fold(Vec::<(usize, Vec, Vec<&T>)>::new(), |mut poly_shifts, query| { + if let Some(pos) = poly_shifts.iter().position(|(poly, _, _)| *poly == query.poly) { let (_, shifts, evals) = &mut poly_shifts[pos]; if !shifts.contains(&query.shift) { shifts.push(query.shift); @@ -131,67 +119,47 @@ fn query_sets(queries: &[Query]) -> Vec>::new(), - |mut sets, (poly, shifts, evals)| { - if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) - }) { - let set = &mut sets[pos]; - if !set.polys.contains(&poly) { - set.polys.push(poly); - set.evals.push( - set.shifts - .iter() - .map(|lhs| { - let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); - evals[idx] - }) - .collect(), - ); - } - } else { - let set = QuerySet { - shifts, - polys: vec![poly], - evals: vec![evals], - }; - sets.push(set); + }); + + poly_shifts.into_iter().fold(Vec::>::new(), |mut sets, (poly, shifts, evals)| { + if let Some(pos) = sets.iter().position(|set| { + BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + }) { + let set = &mut sets[pos]; + if !set.polys.contains(&poly) { + set.polys.push(poly); + set.evals.push( + set.shifts + .iter() + .map(|lhs| { + let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); + evals[idx] + }) + .collect(), + ); } - sets - }, - ) + } else { + let set = QuerySet { shifts, polys: vec![poly], evals: vec![evals] }; + sets.push(set); + } + sets + }) } -fn query_set_coeffs<'a, F: FieldExt, T: LoadedScalar>( - sets: &[QuerySet<'a, F, T>], +fn query_set_coeffs>( + sets: &[QuerySet], z: &T, z_prime: &T, ) -> Vec> { let loader = z.loader(); - let superset = sets - .iter() - .flat_map(|set| set.shifts.clone()) - .sorted() - .dedup(); + let superset = sets.iter().flat_map(|set| set.shifts.clone()).sorted().dedup(); - let size = sets - .iter() - .map(|set| set.shifts.len()) - .chain(Some(2)) - .max() - .unwrap(); + let size = sets.iter().map(|set| set.shifts.len()).chain(Some(2)).max().unwrap(); let powers_of_z = z.powers(size); - let z_prime_minus_z_shift_i = BTreeMap::from_iter(superset.map(|shift| { - ( - shift, - z_prime.clone() - z.clone() * loader.load_const(&shift), - ) - })); + let z_prime_minus_z_shift_i = BTreeMap::from_iter( + superset.map(|shift| (shift, z_prime.clone() - z.clone() * loader.load_const(&shift))), + ); let mut z_s_1 = None; let mut coeffs = sets @@ -225,7 +193,7 @@ struct QuerySet<'a, F, T> { evals: Vec>, } -impl<'a, F: FieldExt, T: LoadedScalar> QuerySet<'a, F, T> { +impl<'a, F: PrimeField, T: LoadedScalar> QuerySet<'a, F, T> { fn msm>( &self, coeff: &QuerySetCoeff, @@ -270,7 +238,7 @@ struct QuerySetCoeff { impl QuerySetCoeff where - F: FieldExt, + F: PrimeField + Ord, T: LoadedScalar, { fn new( @@ -292,7 +260,7 @@ where .filter(|&(i, _)| i != j) .map(|(_, shift_i)| (*shift_j - shift_i)) .reduce(|acc, value| acc * value) - .unwrap_or_else(|| F::one()) + .unwrap_or(F::ONE) }) .collect_vec(); @@ -312,10 +280,7 @@ where .collect_vec(); let z_s = loader.product( - &shifts - .iter() - .map(|shift| z_prime_minus_z_shift_i.get(shift).unwrap()) - .collect_vec(), + &shifts.iter().map(|shift| z_prime_minus_z_shift_i.get(shift).unwrap()).collect_vec(), ); let z_s_1_over_z_s = z_s_1.clone().map(|z_s_1| Fraction::new(z_s_1, z_s.clone())); @@ -344,13 +309,8 @@ where .iter_mut() .chain(self.commitment_coeff.as_mut()) .for_each(Fraction::evaluate); - let barycentric_weights_sum = loader.sum( - &self - .eval_coeffs - .iter() - .map(Fraction::evaluated) - .collect_vec(), - ); + let barycentric_weights_sum = + loader.sum(&self.eval_coeffs.iter().map(Fraction::evaluated).collect_vec()); self.r_eval_coeff = Some(match self.commitment_coeff.as_ref() { Some(coeff) => Fraction::new(coeff.evaluated().clone(), barycentric_weights_sum), None => Fraction::one_over(barycentric_weights_sum), @@ -370,13 +330,9 @@ impl CostEstimation for KzgAs where M: MultiMillerLoop, { - type Input = Vec>; + type Input = Vec>; - fn estimate_cost(_: &Vec>) -> Cost { - Cost { - num_commitment: 2, - num_msm: 2, - ..Default::default() - } + fn estimate_cost(_: &Vec>) -> Cost { + Cost { num_commitment: 2, num_msm: 2, ..Default::default() } } } diff --git a/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs index e5741163..da5d51a6 100644 --- a/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs +++ b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs @@ -23,6 +23,7 @@ pub struct Gwc19; impl PolynomialCommitmentScheme for KzgAs where M: MultiMillerLoop, + M::G1Affine: CurveAffine, L: Loader, { type VerifyingKey = KzgSuccinctVerifyingKey; @@ -31,7 +32,7 @@ where fn read_proof( _: &Self::VerifyingKey, - queries: &[Query], + queries: &[Query], transcript: &mut T, ) -> Result where @@ -44,15 +45,13 @@ where svk: &Self::VerifyingKey, commitments: &[Msm], z: &L::LoadedScalar, - queries: &[Query], + queries: &[Query], proof: &Self::Proof, ) -> Result { let sets = query_sets(queries); let powers_of_u = &proof.u.powers(sets.len()); let f = { - let powers_of_v = proof - .v - .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let powers_of_v = proof.v.powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); sets.iter() .map(|set| set.msm(commitments, &powers_of_v)) .zip(powers_of_u.iter()) @@ -67,11 +66,7 @@ where .zip(powers_of_u.iter()) .map(|(w, power_of_u)| Msm::base(w) * power_of_u) .collect_vec(); - let lhs = f + rhs - .iter() - .zip(z_omegas) - .map(|(uw, z_omega)| uw.clone() * &z_omega) - .sum(); + let lhs = f + rhs.iter().zip(z_omegas).map(|(uw, z_omega)| uw.clone() * &z_omega).sum(); Ok(KzgAccumulator::new( lhs.evaluate(Some(svk.g)), @@ -161,14 +156,10 @@ impl CostEstimation for KzgAs where M: MultiMillerLoop, { - type Input = Vec>; + type Input = Vec>; - fn estimate_cost(queries: &Vec>) -> Cost { + fn estimate_cost(queries: &Vec>) -> Cost { let num_w = query_sets(queries).len(); - Cost { - num_commitment: num_w, - num_msm: num_w, - ..Default::default() - } + Cost { num_commitment: num_w, num_msm: num_w, ..Default::default() } } } diff --git a/snark-verifier/src/system/halo2.rs b/snark-verifier/src/system/halo2.rs index 98f4488c..2dc5751d 100644 --- a/snark-verifier/src/system/halo2.rs +++ b/snark-verifier/src/system/halo2.rs @@ -7,7 +7,7 @@ use crate::halo2_proofs::{ }; use crate::{ util::{ - arithmetic::{root_of_unity, CurveAffine, Domain, FieldExt, Rotation}, + arithmetic::{root_of_unity, CurveAffine, Domain, PrimeField, Rotation}, Itertools, }, verifier::plonk::protocol::{ @@ -161,7 +161,7 @@ impl From for Rotation { } } -struct Polynomials<'a, F: FieldExt> { +struct Polynomials<'a, F: PrimeField> { cs: &'a ConstraintSystem, zk: bool, query_instance: bool, @@ -179,7 +179,7 @@ struct Polynomials<'a, F: FieldExt> { num_lookup_z: usize, } -impl<'a, F: FieldExt> Polynomials<'a, F> { +impl<'a, F: PrimeField> Polynomials<'a, F> { fn new( cs: &'a ConstraintSystem, zk: bool, @@ -474,7 +474,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { } fn l_active(&self) -> Expression { - Expression::Constant(F::one()) - self.l_last() - self.l_blind() + Expression::Constant(F::ONE) - self.l_last() - self.l_blind() } fn system_challenge_offset(&self) -> usize { @@ -499,7 +499,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { } fn permutation_constraints(&'a self, t: usize) -> impl IntoIterator> + 'a { - let one = &Expression::Constant(F::one()); + let one = &Expression::Constant(F::ONE); let l_0 = &Expression::::CommonPolynomial(CommonPolynomial::Lagrange(0)); let l_last = &self.l_last(); let l_active = &self.l_active(); @@ -591,7 +591,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { } fn lookup_constraints(&'a self, t: usize) -> impl IntoIterator> + 'a { - let one = &Expression::Constant(F::one()); + let one = &Expression::Constant(F::ONE); let l_0 = &Expression::::CommonPolynomial(CommonPolynomial::Lagrange(0)); let l_last = &self.l_last(); let l_active = &self.l_active(); @@ -698,7 +698,7 @@ impl EncodedChallenge for MockChallenge { } #[derive(Default)] -struct MockTranscript(F); +struct MockTranscript(F); impl Transcript for MockTranscript { fn squeeze_challenge(&mut self) -> MockChallenge { diff --git a/snark-verifier/src/system/halo2/transcript.rs b/snark-verifier/src/system/halo2/transcript.rs index 9cfd6b89..10da3a22 100644 --- a/snark-verifier/src/system/halo2/transcript.rs +++ b/snark-verifier/src/system/halo2/transcript.rs @@ -1,7 +1,6 @@ //! Transcripts implemented with both `halo2_proofs::transcript` and //! `crate::util::transcript`. use crate::halo2_proofs; -use halo2_proofs::transcript::{Blake2bRead, Blake2bWrite, Challenge255}; use crate::{ loader::native::{self, NativeLoader}, util::{ @@ -10,6 +9,8 @@ use crate::{ }, Error, }; +use halo2_proofs::transcript::{Blake2bRead, Blake2bWrite, Challenge255}; +use pairing::group::ff::FromUniformBytes; use std::io::{Read, Write}; #[cfg(feature = "loader_evm")] @@ -18,7 +19,10 @@ pub mod evm; #[cfg(feature = "loader_halo2")] pub mod halo2; -impl Transcript for Blake2bRead> { +impl Transcript for Blake2bRead> +where + C::Scalar: FromUniformBytes<64>, +{ fn loader(&self) -> &NativeLoader { &native::LOADER } @@ -38,8 +42,9 @@ impl Transcript for Blake2bRead TranscriptRead - for Blake2bRead> +impl TranscriptRead for Blake2bRead> +where + C::Scalar: FromUniformBytes<64>, { fn read_scalar(&mut self) -> Result { halo2_proofs::transcript::TranscriptRead::read_scalar(self) @@ -52,7 +57,10 @@ impl TranscriptRead } } -impl Transcript for Blake2bWrite> { +impl Transcript for Blake2bWrite> +where + C::Scalar: FromUniformBytes<64>, +{ fn loader(&self) -> &NativeLoader { &native::LOADER } @@ -72,7 +80,10 @@ impl Transcript for Blake2bWrite TranscriptWrite for Blake2bWrite, C, Challenge255> { +impl TranscriptWrite for Blake2bWrite, C, Challenge255> +where + C::Scalar: FromUniformBytes<64>, +{ fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error> { halo2_proofs::transcript::TranscriptWrite::write_scalar(self, scalar) .map_err(|err| Error::Transcript(err.kind(), err.to_string())) diff --git a/snark-verifier/src/system/halo2/transcript/halo2.rs b/snark-verifier/src/system/halo2/transcript/halo2.rs index 86b1929c..8a0ce6d4 100644 --- a/snark-verifier/src/system/halo2/transcript/halo2.rs +++ b/snark-verifier/src/system/halo2/transcript/halo2.rs @@ -1,6 +1,7 @@ //! Transcript for verifier in [`halo2_proofs`] circuit. use crate::halo2_proofs; +use crate::util::arithmetic::FieldExt; use crate::{ loader::{ halo2::{EcPoint, EccInstructions, Halo2Loader, Scalar}, @@ -64,7 +65,10 @@ where { /// Initialize [`PoseidonTranscript`] given readable or writeable stream for /// verifying or proving with [`NativeLoader`]. - pub fn new(loader: &Rc>, stream: R) -> Self { + pub fn new(loader: &Rc>, stream: R) -> Self + where + C::Scalar: FieldExt, + { let buf = Poseidon::new::(loader); Self { loader: loader.clone(), stream, buf } } @@ -165,7 +169,10 @@ impl(stream: S) -> Self { + pub fn new(stream: S) -> Self + where + C::Scalar: FieldExt, + { Self { loader: NativeLoader, stream, @@ -375,6 +382,7 @@ impl where C: CurveAffine, + C::Scalar: FieldExt, R: Read, { fn init(reader: R) -> Self { @@ -409,6 +417,7 @@ impl where C: CurveAffine, + C::Scalar: FieldExt, W: Write, { fn init(writer: W) -> Self { @@ -423,12 +432,13 @@ where mod halo2_lib { use crate::halo2_curves::CurveAffineExt; use crate::system::halo2::transcript::halo2::NativeEncoding; - use halo2_ecc::{ecc::BaseFieldEccChip, fields::PrimeField}; + use halo2_base::utils::BigPrimeField; + use halo2_ecc::ecc::BaseFieldEccChip; impl<'chip, C: CurveAffineExt> NativeEncoding for BaseFieldEccChip<'chip, C> where - C::Scalar: PrimeField, - C::Base: PrimeField, + C::Scalar: BigPrimeField, + C::Base: BigPrimeField, { fn encode( &self, diff --git a/snark-verifier/src/util/arithmetic.rs b/snark-verifier/src/util/arithmetic.rs index 97962e32..070277a5 100644 --- a/snark-verifier/src/util/arithmetic.rs +++ b/snark-verifier/src/util/arithmetic.rs @@ -4,15 +4,15 @@ use crate::halo2_curves; use crate::util::Itertools; pub use halo2_curves::{ group::{ - ff::{BatchInvert, Field, PrimeField}, + ff::{BatchInvert, Field, FromUniformBytes, PrimeField}, prime::PrimeCurveAffine, Curve, Group, GroupEncoding, }, - pairing::MillerLoopResult, - Coordinates, CurveAffine, CurveExt, FieldExt, + Coordinates, CurveAffine, CurveExt, }; use num_bigint::BigUint; use num_traits::One; +pub use pairing::MillerLoopResult; use serde::{Deserialize, Serialize}; use std::{ cmp::Ordering, @@ -22,9 +22,14 @@ use std::{ }; /// [`halo2_curves::pairing::MultiMillerLoop`] with [`std::fmt::Debug`]. -pub trait MultiMillerLoop: halo2_curves::pairing::MultiMillerLoop + Debug {} +pub trait MultiMillerLoop: pairing::MultiMillerLoop + Debug {} -impl MultiMillerLoop for M {} +impl MultiMillerLoop for M {} + +/// Trait for fields that can implement Poseidon hash +pub trait FieldExt: PrimeField + FromUniformBytes<64> + Ord {} + +impl + Ord> FieldExt for F {} /// Operations that could be done with field elements. pub trait FieldOps: @@ -54,7 +59,7 @@ pub fn batch_invert_and_mul(values: &mut [F], coeff: &F) { } let products = values .iter() - .scan(F::one(), |acc, value| { + .scan(F::ONE, |acc, value| { *acc *= value; Some(*acc) }) @@ -65,7 +70,7 @@ pub fn batch_invert_and_mul(values: &mut [F], coeff: &F) { * coeff; for (value, product) in - values.iter_mut().rev().zip(products.into_iter().rev().skip(1).chain(Some(F::one()))) + values.iter_mut().rev().zip(products.into_iter().rev().skip(1).chain(Some(F::ONE))) { let mut inv = all_product_inv * product; mem::swap(value, &mut inv); @@ -75,7 +80,7 @@ pub fn batch_invert_and_mul(values: &mut [F], coeff: &F) { /// Batch invert [`PrimeField`] elements. pub fn batch_invert(values: &mut [F]) { - batch_invert_and_mul(values, &F::one()) + batch_invert_and_mul(values, &F::ONE) } /// Root of unity of 2^k-sized multiplicative subgroup of [`PrimeField`] by @@ -88,7 +93,7 @@ pub fn batch_invert(values: &mut [F]) { pub fn root_of_unity(k: usize) -> F { assert!(k <= F::S as usize); - iter::successors(Some(F::root_of_unity()), |acc| Some(acc.square())) + iter::successors(Some(F::ROOT_OF_UNITY), |acc| Some(acc.square())) .take(F::S as usize - k + 1) .last() .unwrap() @@ -230,7 +235,7 @@ impl Fraction { /// Modulus of a [`PrimeField`] pub fn modulus() -> BigUint { - fe_to_big(-F::one()) + 1usize + fe_to_big(-F::ONE) + 1usize } /// Convert a [`BigUint`] into a [`PrimeField`] . @@ -286,7 +291,7 @@ pub fn fe_to_limbs(scalar: F) -> impl Iterator { - iter::successors(Some(F::one()), move |power| Some(scalar * power)) + iter::successors(Some(F::ONE), move |power| Some(scalar * power)) } /// Compute inner product of 2 slice of [`Field`]. diff --git a/snark-verifier/src/util/hash/poseidon.rs b/snark-verifier/src/util/hash/poseidon.rs index 1ff06ab9..e826ac52 100644 --- a/snark-verifier/src/util/hash/poseidon.rs +++ b/snark-verifier/src/util/hash/poseidon.rs @@ -1,9 +1,12 @@ #![allow(clippy::needless_range_loop)] // for clarity of matrix operations use crate::{ loader::{LoadedScalar, ScalarLoader}, - util::{arithmetic::FieldExt, Itertools}, + util::{ + arithmetic::{FieldExt, PrimeField}, + Itertools, + }, }; -use poseidon_circuit::poseidon::primitives::Spec as PoseidonSpec; // trait +use poseidon_rs::poseidon::primitives::Spec as PoseidonSpec; // trait use std::{iter, marker::PhantomData, mem}; #[cfg(test)] @@ -12,7 +15,7 @@ mod tests; // struct so we can use PoseidonSpec trait to generate round constants and MDS matrix #[derive(Debug)] pub struct Poseidon128Pow5Gen< - F: FieldExt, + F: PrimeField, const T: usize, const RATE: usize, const R_F: usize, @@ -23,7 +26,7 @@ pub struct Poseidon128Pow5Gen< } impl< - F: FieldExt, + F: PrimeField, const T: usize, const RATE: usize, const R_F: usize, @@ -57,7 +60,7 @@ impl< /// `OptimizedPoseidonSpec` holds construction parameters as well as constants that are used in /// permutation step. #[derive(Debug, Clone)] -pub struct OptimizedPoseidonSpec { +pub struct OptimizedPoseidonSpec { pub(crate) r_f: usize, pub(crate) mds_matrices: MDSMatrices, pub(crate) constants: OptimizedConstants, @@ -67,7 +70,7 @@ pub struct OptimizedPoseidonSpec /// full rounds has T sized constants there is a single constant for each /// partial round #[derive(Debug, Clone)] -pub struct OptimizedConstants { +pub struct OptimizedConstants { pub(crate) start: Vec<[F; T]>, pub(crate) partial: Vec, pub(crate) end: Vec<[F; T]>, @@ -80,7 +83,7 @@ pub(crate) type Mds = [[F; T]; T]; /// also called `pre_sparse_mds` and sparse matrices that enables us to reduce /// number of multiplications in apply MDS step #[derive(Debug, Clone)] -pub struct MDSMatrices { +pub struct MDSMatrices { pub(crate) mds: MDSMatrix, pub(crate) pre_sparse_mds: MDSMatrix, pub(crate) sparse_matrices: Vec>, @@ -89,18 +92,18 @@ pub struct MDSMatrices { /// `SparseMDSMatrix` are in `[row], [hat | identity]` form and used in linear /// layer of partial rounds instead of the original MDS #[derive(Debug, Clone)] -pub struct SparseMDSMatrix { +pub struct SparseMDSMatrix { pub(crate) row: [F; T], pub(crate) col_hat: [F; RATE], } /// `MDSMatrix` is applied to `State` to achive linear layer of Poseidon #[derive(Clone, Debug)] -pub struct MDSMatrix(pub(crate) Mds); +pub struct MDSMatrix(pub(crate) Mds); -impl MDSMatrix { +impl MDSMatrix { pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] { - let mut res = [F::zero(); T]; + let mut res = [F::ZERO; T]; for i in 0..T { for j in 0..T { res[i] += self.0[i][j] * v[j]; @@ -110,16 +113,16 @@ impl MDSMatrix { } fn identity() -> Mds { - let mut mds = [[F::zero(); T]; T]; + let mut mds = [[F::ZERO; T]; T]; for i in 0..T { - mds[i][i] = F::one(); + mds[i][i] = F::ONE; } mds } /// Multiplies two MDS matrices. Used in sparse matrix calculations fn mul(&self, other: &Self) -> Self { - let mut res = [[F::zero(); T]; T]; + let mut res = [[F::ZERO; T]; T]; for i in 0..T { for j in 0..T { for k in 0..T { @@ -131,7 +134,7 @@ impl MDSMatrix { } fn transpose(&self) -> Self { - let mut res = [[F::zero(); T]; T]; + let mut res = [[F::ZERO; T]; T]; for i in 0..T { for j in 0..T { res[i][j] = self.0[j][i]; @@ -141,11 +144,11 @@ impl MDSMatrix { } fn determinant(m: [[F; N]; N]) -> F { - let mut res = F::one(); + let mut res = F::ONE; let mut m = m; for i in 0..N { let mut pivot = i; - while m[pivot][i] == F::zero() { + while m[pivot][i] == F::ZERO { pivot += 1; assert!(pivot < N, "matrix is not invertible"); } @@ -196,7 +199,7 @@ impl MDSMatrix { let w = self.0.iter().skip(1).map(|row| row[0]).collect::>(); // m_hat is the `(t-1 * t-1)` right bottom sub-matrix of m := self.0 - let mut m_hat = [[F::zero(); RATE]; RATE]; + let mut m_hat = [[F::ZERO; RATE]; RATE]; for i in 0..RATE { for j in 0..RATE { m_hat[i][j] = self.0[i + 1][j + 1]; @@ -204,7 +207,7 @@ impl MDSMatrix { } // w_hat = m_hat^{-1} * w, where m_hat^{-1} is matrix inverse and * is matrix mult // we avoid computing m_hat^{-1} explicitly by using Cramer's rule: https://en.wikipedia.org/wiki/Cramer%27s_rule - let mut w_hat = [F::zero(); RATE]; + let mut w_hat = [F::ZERO; RATE]; let det = Self::determinant(m_hat); let det_inv = Option::::from(det.invert()).expect("matrix is not invertible"); for j in 0..RATE { @@ -225,9 +228,12 @@ impl MDSMatrix { } } -impl OptimizedPoseidonSpec { +impl OptimizedPoseidonSpec { /// Generate new spec with specific number of full and partial rounds. `SECURE_MDS` is usually 0, but may need to be specified because insecure matrices may sometimes be generated - pub fn new() -> Self { + pub fn new() -> Self + where + F: FieldExt, + { let (round_constants, mds, mds_inv) = Poseidon128Pow5Gen::::constants(); let mds = MDSMatrix(mds); @@ -254,7 +260,7 @@ impl OptimizedPoseidonSpec = vec![[F::zero(); T]; r_f_half]; + let mut constants_start: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half]; constants_start[0] = constants[0]; for (optimized, constants) in constants_start.iter_mut().skip(1).zip(constants.iter().skip(1)) @@ -264,7 +270,7 @@ impl OptimizedPoseidonSpec OptimizedPoseidonSpec = vec![[F::zero(); T]; r_f_half - 1]; + let mut constants_end: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half - 1]; for (optimized, constants) in constants_end.iter_mut().zip(constants.iter().skip(r_f_half + r_p + 1)) { @@ -320,20 +325,20 @@ impl OptimizedPoseidonSpec { +struct State { inner: [L; T], _marker: PhantomData, } // the transcript hash implementation is the one suggested in the original paper https://eprint.iacr.org/2019/458.pdf // another reference implementation is https://github.com/privacy-scaling-explorations/halo2wrong/tree/master/transcript/src -impl, const T: usize, const RATE: usize> State { +impl, const T: usize, const RATE: usize> State { fn new(inner: [L; T]) -> Self { Self { inner, _marker: PhantomData } } fn default(loader: &L::Loader) -> Self { - let mut default_state = [F::zero(); T]; + let mut default_state = [F::ZERO; T]; // from Section 4.2 of https://eprint.iacr.org/2019/458.pdf // • Variable-Input-Length Hashing. The capacity value is 2^64 + (o−1) where o the output length. // for our transcript use cases, o = 1 @@ -376,8 +381,8 @@ impl, const T: usize, const RATE: usize> State, const T: usize, const RATE: usize> State, const T: usize, const RATE: usize> State { +pub struct Poseidon { spec: OptimizedPoseidonSpec, default_state: State, state: State, buf: Vec, } -impl, const T: usize, const RATE: usize> Poseidon { +impl, const T: usize, const RATE: usize> Poseidon { /// Initialize a poseidon hasher. /// Generates a new spec with specific number of full and partial rounds. `SECURE_MDS` is usually 0, but may need to be specified because insecure matrices may sometimes be generated pub fn new( loader: &L::Loader, - ) -> Self { + ) -> Self + where + F: FieldExt, + { let default_state = State::default(loader); Self { spec: OptimizedPoseidonSpec::new::(), @@ -495,7 +503,7 @@ impl, const T: usize, const RATE: usize> Poseido self.state.sbox_full(constants); self.state.apply_mds(&mds); } - self.state.sbox_full(&[F::zero(); T]); + self.state.sbox_full(&[F::ZERO; T]); self.state.apply_mds(&mds); } } diff --git a/snark-verifier/src/util/hash/poseidon/tests.rs b/snark-verifier/src/util/hash/poseidon/tests.rs index cf4712bc..76793c91 100644 --- a/snark-verifier/src/util/hash/poseidon/tests.rs +++ b/snark-verifier/src/util/hash/poseidon/tests.rs @@ -47,7 +47,7 @@ fn test_poseidon_against_test_vectors() { hasher.state = State::new(state.try_into().unwrap()); hasher.permutation(&[(); RATE].map(|_| Fr::zero())); // avoid padding let state_0 = hasher.state.inner; - let expected = vec![ + let expected = [ "7853200120776062878684798364095072458815029376092732009249414926327459813530", "7142104613055408817911962100316808866448378443474503659992478482890339429929", "6549537674122432311777789598043107870002137484850126429160507761192163713804", @@ -71,7 +71,7 @@ fn test_poseidon_against_test_vectors() { hasher.state = State::new(state.try_into().unwrap()); hasher.permutation(&[(); RATE].map(|_| Fr::zero())); let state_0 = hasher.state.inner; - let expected = vec![ + let expected = [ "18821383157269793795438455681495246036402687001665670618754263018637548127333", "7817711165059374331357136443537800893307845083525445872661165200086166013245", "16733335996448830230979566039396561240864200624113062088822991822580465420551", diff --git a/snark-verifier/src/util/msm.rs b/snark-verifier/src/util/msm.rs index 8d18cdf8..fef53e59 100644 --- a/snark-verifier/src/util/msm.rs +++ b/snark-verifier/src/util/msm.rs @@ -71,7 +71,7 @@ where let gen = gen.map(|gen| self.bases.first().unwrap().loader().ec_point_load_const(&gen)); let pairs = iter::empty() .chain(self.constant.as_ref().map(|constant| (constant, gen.as_ref().unwrap()))) - .chain(self.scalars.iter().zip(self.bases.into_iter())) + .chain(self.scalars.iter().zip(self.bases)) .collect_vec(); L::multi_scalar_multiplication(&pairs) } diff --git a/snark-verifier/src/util/poly.rs b/snark-verifier/src/util/poly.rs index 17a065f9..9d688a4e 100644 --- a/snark-verifier/src/util/poly.rs +++ b/snark-verifier/src/util/poly.rs @@ -55,7 +55,7 @@ impl Polynomial { /// Returns evaluation at given `x`. pub fn evaluate(&self, x: F) -> F { let evaluate_serial = - |coeffs: &[F]| coeffs.iter().rev().fold(F::zero(), |acc, coeff| acc * x + coeff); + |coeffs: &[F]| coeffs.iter().rev().fold(F::ZERO, |acc, coeff| acc * x + coeff); #[cfg(feature = "parallel")] { @@ -68,7 +68,7 @@ impl Polynomial { } let chunk_size = Integer::div_ceil(&self.len(), &num_threads); - let mut results = vec![F::zero(); num_threads]; + let mut results = vec![F::ZERO; num_threads]; parallelize_iter( results.iter_mut().zip(self.0.chunks(chunk_size)).zip(powers(x.pow_vartime(&[ chunk_size as u64, @@ -78,7 +78,7 @@ impl Polynomial { ]))), |((result, coeffs), scalar)| *result = evaluate_serial(coeffs) * scalar, ); - results.iter().fold(F::zero(), |acc, result| acc + result) + results.iter().fold(F::ZERO, |acc, result| acc + result) } #[cfg(not(feature = "parallel"))] evaluate_serial(&self.0) @@ -133,10 +133,10 @@ impl Mul for Polynomial { type Output = Polynomial; fn mul(mut self, rhs: F) -> Polynomial { - if rhs == F::zero() { - return Polynomial::new(vec![F::zero(); self.len()]); + if rhs == F::ZERO { + return Polynomial::new(vec![F::ZERO; self.len()]); } - if rhs == F::one() { + if rhs == F::ONE { return self; } parallelize(&mut self.0, |(lhs, _)| { diff --git a/snark-verifier/src/verifier/plonk/proof.rs b/snark-verifier/src/verifier/plonk/proof.rs index 7adba7ac..a42a56ae 100644 --- a/snark-verifier/src/verifier/plonk/proof.rs +++ b/snark-verifier/src/verifier/plonk/proof.rs @@ -158,7 +158,7 @@ where .queries .iter() .map(|query| { - let shift = protocol.domain.rotate_scalar(C::Scalar::one(), query.rotation); + let shift = protocol.domain.rotate_scalar(C::Scalar::ONE, query.rotation); pcs::Query::new(query.poly, shift) }) .collect() diff --git a/snark-verifier/src/verifier/plonk/protocol.rs b/snark-verifier/src/verifier/plonk/protocol.rs index a3a84346..97f6f336 100644 --- a/snark-verifier/src/verifier/plonk/protocol.rs +++ b/snark-verifier/src/verifier/plonk/protocol.rs @@ -219,7 +219,7 @@ where let numer = zn_minus_one.clone() * &n_inv; let omegas = langranges .iter() - .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::one(), Rotation(i)))) + .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::ONE, Rotation(i)))) .collect_vec(); let lagrange_evals = omegas .iter() @@ -478,7 +478,7 @@ impl Sum for Expression { impl One for Expression { fn one() -> Self { - Expression::Constant(F::one()) + Expression::Constant(F::ONE) } }