diff --git a/Cargo.toml b/Cargo.toml index aa22a7c2..ffa19656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,11 +9,9 @@ readme = "README.md" repository = "https://github.com/Microsoft/Nova" license-file = "LICENSE" keywords = ["zkSNARKs", "cryptography", "proofs"] -rust-version="1.79.0" +rust-version = "1.79.0" [dependencies] -bellpepper-core = { version="0.4.0", default-features = false } -bellpepper = { version="0.4.0", default-features = false } ff = { version = "0.13.0", features = ["derive"] } digest = "0.10" sha3 = "0.10" @@ -23,7 +21,6 @@ rand_chacha = "0.3" subtle = "2.5" pasta_curves = { version = "0.5", features = ["repr-c", "serde"] } halo2curves = { version = "0.6.0", features = ["bits", "derive_serde"] } -neptune = { version = "13.0.0", default-features = false } generic-array = "1.0.0" num-bigint = { version = "0.4", features = ["serde", "rand"] } num-traits = "0.2" @@ -78,6 +75,4 @@ harness = false default = ["halo2curves/asm"] # Compiles in portable mode, w/o ISA extensions => binary can be executed on all systems. portable = ["pasta-msm/portable"] -cuda = ["neptune/cuda", "neptune/pasta", "neptune/arity24"] -opencl = ["neptune/opencl", "neptune/pasta", "neptune/arity24"] flamegraph = ["pprof/flamegraph", "pprof/criterion"] diff --git a/benches/compressed-snark.rs b/benches/compressed-snark.rs index 1a866d7d..031b195e 100644 --- a/benches/compressed-snark.rs +++ b/benches/compressed-snark.rs @@ -1,9 +1,9 @@ #![allow(non_snake_case)] -use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use core::marker::PhantomData; use criterion::{measurement::WallTime, *}; use ff::PrimeField; +use nova_snark::frontend::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use nova_snark::{ provider::{Bn256EngineKZG, GrumpkinEngine}, traits::{ diff --git a/benches/compute-digest.rs b/benches/compute-digest.rs index 64cc4851..e6ba042d 100644 --- a/benches/compute-digest.rs +++ b/benches/compute-digest.rs @@ -1,8 +1,8 @@ use std::{marker::PhantomData, time::Duration}; -use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use ff::PrimeField; +use nova_snark::frontend::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use nova_snark::{ provider::{Bn256EngineKZG, GrumpkinEngine}, traits::{ diff --git a/benches/ppsnark.rs b/benches/ppsnark.rs index 56a24ac0..bcb6fa97 100644 --- a/benches/ppsnark.rs +++ b/benches/ppsnark.rs @@ -1,9 +1,9 @@ #![allow(non_snake_case)] -use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use core::marker::PhantomData; use criterion::*; use ff::PrimeField; +use nova_snark::frontend::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use nova_snark::{ provider::Bn256EngineKZG, spartan::direct::DirectSNARK, diff --git a/benches/recursive-snark.rs b/benches/recursive-snark.rs index 72beaa16..18d052d8 100644 --- a/benches/recursive-snark.rs +++ b/benches/recursive-snark.rs @@ -1,9 +1,9 @@ #![allow(non_snake_case)] -use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use core::marker::PhantomData; use criterion::*; use ff::PrimeField; +use nova_snark::frontend::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use nova_snark::{ provider::{Bn256EngineKZG, GrumpkinEngine}, traits::{ diff --git a/benches/sha256.rs b/benches/sha256.rs index dfd03fc3..4f4792a3 100644 --- a/benches/sha256.rs +++ b/benches/sha256.rs @@ -3,16 +3,16 @@ //! 
This code invokes a hand-written SHA-256 gadget from bellman/bellperson. //! It also uses code from bellman/bellperson to compare circuit-generated digest with sha2 crate's output #![allow(non_snake_case)] -use bellpepper::gadgets::{sha256::sha256, Assignment}; -use bellpepper_core::{ - boolean::{AllocatedBit, Boolean}, - num::{AllocatedNum, Num}, - ConstraintSystem, SynthesisError, -}; use core::marker::PhantomData; use core::time::Duration; use criterion::*; use ff::{PrimeField, PrimeFieldBits}; +use nova_snark::frontend::gadgets::{sha256::sha256, Assignment}; +use nova_snark::frontend::{ + boolean::{AllocatedBit, Boolean}, + num::{AllocatedNum, Num}, + ConstraintSystem, SynthesisError, +}; use nova_snark::{ provider::{Bn256EngineKZG, GrumpkinEngine}, traits::{ diff --git a/examples/and.rs b/examples/and.rs index 33b50bb7..e3fd36de 100644 --- a/examples/and.rs +++ b/examples/and.rs @@ -1,13 +1,13 @@ //! This example executes a batch of 64-bit AND operations. //! It performs the AND operation by first decomposing the operands into bits and then performing the operation bit-by-bit. //! We execute a configurable number of AND operations per step of Nova's recursion. -use bellpepper_core::{ - boolean::AllocatedBit, num::AllocatedNum, ConstraintSystem, LinearCombination, SynthesisError, -}; use core::marker::PhantomData; use ff::Field; use ff::{PrimeField, PrimeFieldBits}; use flate2::{write::ZlibEncoder, Compression}; +use nova_snark::frontend::{ + boolean::AllocatedBit, num::AllocatedNum, ConstraintSystem, LinearCombination, SynthesisError, +}; use nova_snark::{ provider::{Bn256EngineKZG, GrumpkinEngine}, traits::{ diff --git a/examples/hashchain.rs b/examples/hashchain.rs index d0554568..6a291d1f 100644 --- a/examples/hashchain.rs +++ b/examples/hashchain.rs @@ -1,17 +1,12 @@ //! This example proves the knowledge of preimage to a hash chain tail, with a configurable number of elements per hash chain node. //! The output of each step tracks the current tail of the hash chain -use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use ff::Field; use flate2::{write::ZlibEncoder, Compression}; use generic_array::typenum::U24; -use neptune::{ - circuit2::Elt, - sponge::{ - api::{IOPattern, SpongeAPI, SpongeOp}, - circuit::SpongeCircuit, - vanilla::{Mode::Simplex, Sponge, SpongeTrait}, - }, - Strength, +use nova_snark::frontend::{num::AllocatedNum, ConstraintSystem, SynthesisError}; +use nova_snark::provider::poseidon::Elt; +use nova_snark::provider::poseidon::{ + IOPattern, Simplex, Sponge, SpongeAPI, SpongeCircuit, SpongeOp, SpongeTrait, Strength, }; use nova_snark::{ provider::{Bn256EngineKZG, GrumpkinEngine}, @@ -91,9 +86,9 @@ impl StepCircuit for HashChainCircuit { let acc = &mut ns; sponge.start(parameter, None, acc); - neptune::sponge::api::SpongeAPI::absorb(&mut sponge, num_absorbs, &elt, acc); + SpongeAPI::absorb(&mut sponge, num_absorbs, &elt, acc); - let output = neptune::sponge::api::SpongeAPI::squeeze(&mut sponge, 1, acc); + let output = SpongeAPI::squeeze(&mut sponge, 1, acc); sponge.finish(acc).unwrap(); Elt::ensure_allocated(&output[0], &mut ns.namespace(|| "ensure allocated"), true)? }; diff --git a/examples/minroot.rs b/examples/minroot.rs index 20c7275f..ff53f38b 100644 --- a/examples/minroot.rs +++ b/examples/minroot.rs @@ -1,9 +1,9 @@ //! Demonstrates how to use Nova to produce a recursive proof of the correct execution of //! iterations of the `MinRoot` function, thereby realizing a Nova-based verifiable delay function (VDF). //! 
We execute a configurable number of iterations of the `MinRoot` function per step of Nova's recursion. -use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use ff::Field; use flate2::{write::ZlibEncoder, Compression}; +use nova_snark::frontend::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use nova_snark::{ provider::{Bn256EngineKZG, GrumpkinEngine}, traits::{ diff --git a/src/circuit.rs b/src/circuit.rs index 337cdd16..f6b146a2 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -4,6 +4,12 @@ //! of the running instances. Each of these hashes is H(params = H(shape, ck), i, z0, zi, U). //! Each circuit folds the last invocation of the other into the running instance +use crate::frontend::gadgets::Assignment; +use crate::frontend::{ + boolean::{AllocatedBit, Boolean}, + num::AllocatedNum, + ConstraintSystem, SynthesisError, +}; use crate::{ constants::{NUM_FE_WITHOUT_IO_FOR_CRHF, NUM_HASH_BITS}, gadgets::{ @@ -19,12 +25,6 @@ use crate::{ }, Commitment, }; -use bellpepper::gadgets::Assignment; -use bellpepper_core::{ - boolean::{AllocatedBit, Boolean}, - num::AllocatedNum, - ConstraintSystem, SynthesisError, -}; use ff::Field; use serde::{Deserialize, Serialize}; @@ -383,12 +383,12 @@ impl<'a, E: Engine, SC: StepCircuit> NovaAugmentedCircuit<'a, E, SC> { mod tests { use super::*; use crate::{ - bellpepper::{ + constants::{BN_LIMB_WIDTH, BN_N_LIMBS}, + frontend::{ r1cs::{NovaShape, NovaWitness}, solver::SatisfyingAssignment, test_shape_cs::TestShapeCS, }, - constants::{BN_LIMB_WIDTH, BN_N_LIMBS}, gadgets::utils::scalar_as_base, provider::{ poseidon::PoseidonConstantsCircuit, diff --git a/src/errors.rs b/src/errors.rs index a779b4bb..0454a5aa 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -69,8 +69,8 @@ pub enum NovaError { InternalError, } -impl From for NovaError { - fn from(err: bellpepper_core::SynthesisError) -> Self { +impl From for NovaError { + fn from(err: crate::frontend::SynthesisError) -> Self { Self::SynthesisError { reason: err.to_string(), } diff --git a/src/frontend/constraint_system.rs b/src/frontend/constraint_system.rs new file mode 100644 index 00000000..dceede6b --- /dev/null +++ b/src/frontend/constraint_system.rs @@ -0,0 +1,473 @@ +use std::io; +use std::marker::PhantomData; + +use ff::PrimeField; + +use super::lc::{Index, LinearCombination, Variable}; + +/// Computations are expressed in terms of arithmetic circuits, in particular +/// rank-1 quadratic constraint systems. The `Circuit` trait represents a +/// circuit that can be synthesized. The `synthesize` method is called during +/// CRS generation and during proving. +pub trait Circuit { + /// Synthesize the circuit into a rank-1 quadratic constraint system. + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError>; +} + +/// This is an error that could occur during circuit synthesis contexts, +/// such as CRS generation, proving or verification. +#[allow(clippy::upper_case_acronyms)] +#[derive(thiserror::Error, Debug)] +pub enum SynthesisError { + /// During synthesis, we lacked knowledge of a variable assignment. + #[error("an assignment for a variable could not be computed")] + AssignmentMissing, + /// During synthesis, we divided by zero. + #[error("division by zero")] + DivisionByZero, + /// During synthesis, we constructed an unsatisfiable constraint system. 
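+  /// For example, `Boolean::enforce_equal` returns this error when asked to equate the constants `true` and `false`.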
+ #[error("unsatisfiable constraint system")] + Unsatisfiable, + /// During synthesis, our polynomials ended up being too high of degree + #[error("polynomial degree is too large")] + PolynomialDegreeTooLarge, + /// During proof generation, we encountered an identity in the CRS + #[error("encountered an identity element in the CRS")] + UnexpectedIdentity, + /// During proof generation, we encountered an I/O error with the CRS + #[error("encountered an I/O error: {0}")] + IoError(#[from] io::Error), + /// During verification, our verifying key was malformed. + #[error("malformed verifying key")] + MalformedVerifyingKey, + /// During CRS generation, we observed an unconstrained auxiliary variable + #[error("auxiliary variable was unconstrained")] + UnconstrainedVariable, + /// During GPU multiexp/fft, some GPU related error happened + #[error("attempted to aggregate malformed proofs: {0}")] + MalformedProofs(String), + /// Non power of two proofs given for aggregation + #[error("non power of two proofs given for aggregation")] + NonPowerOfTwo, + /// Incompatible vector length + #[error("incompatible vector length: {0}")] + IncompatibleLengthVector(String), +} + +/// Represents a constraint system which can have new variables +/// allocated and constrains between them formed. +pub trait ConstraintSystem: Sized + Send { + /// Represents the type of the "root" of this constraint system + /// so that nested namespaces can minimize indirection. + type Root: ConstraintSystem; + + /// Create a new empty constraint system + fn new() -> Self { + unimplemented!( + "ConstraintSystem::new must be implemented for extensible types implementing ConstraintSystem" + ); + } + + /// Return the "one" input variable + fn one() -> Variable { + Variable::new_unchecked(Index::Input(0)) + } + + /// Allocate a private variable in the constraint system. The provided function is used to + /// determine the assignment of the variable. The given `annotation` function is invoked + /// in testing contexts in order to derive a unique name for this variable in the current + /// namespace. + fn alloc(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into; + + /// Allocate a pre-committed variable in the constraint system. The provided function is used to + /// determine the assignment of the variable. The given `annotation` function is invoked + /// in testing contexts in order to derive a unique name for this variable in the current + /// namespace. + fn alloc_precommitted( + &mut self, + annotation: A, + f: F, + ) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into; + + /// Allocate a public variable in the constraint system. The provided function is used to + /// determine the assignment of the variable. + fn alloc_input(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into; + + /// Enforce that `A` * `B` = `C`. The `annotation` function is invoked in testing contexts + /// in order to derive a unique name for the constraint in the current namespace. + fn enforce(&mut self, annotation: A, a: LA, b: LB, c: LC) + where + A: FnOnce() -> AR, + AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination; + + /// Create a new (sub)namespace and enter into it. Not intended + /// for downstream use; use `namespace` instead. 
+ fn push_namespace(&mut self, name_fn: N) + where + NR: Into, + N: FnOnce() -> NR; + + /// Exit out of the existing namespace. Not intended for + /// downstream use; use `namespace` instead. + fn pop_namespace(&mut self); + + /// Gets the "root" constraint system, bypassing the namespacing. + /// Not intended for downstream use; use `namespace` instead. + fn get_root(&mut self) -> &mut Self::Root; + + /// Begin a namespace for this constraint system. + fn namespace(&mut self, name_fn: N) -> Namespace<'_, Scalar, Self::Root> + where + NR: Into, + N: FnOnce() -> NR, + { + self.get_root().push_namespace(name_fn); + + Namespace(self.get_root(), Default::default()) + } + + /// Most implementations of ConstraintSystem are not 'extensible': they won't implement a specialized + /// version of `extend` and should therefore also keep the default implementation of `is_extensible` + /// so callers which optionally make use of `extend` can know to avoid relying on it when unimplemented. + fn is_extensible() -> bool { + false + } + + /// Extend concatenates thew `other` constraint systems to the receiver, modifying the receiver, whose + /// inputs, allocated variables, and constraints will precede those of the `other` constraint system. + /// The primary use case for this is parallel synthesis of circuits which can be decomposed into + /// entirely independent sub-circuits. Each can be synthesized in its own thread, then the + /// original `ConstraintSystem` can be extended with each, in the same order they would have + /// been synthesized sequentially. + fn extend(&mut self, _other: &Self) { + unimplemented!( + "ConstraintSystem::extend must be implemented for types implementing ConstraintSystem" + ); + } + + /// Determines if the current `ConstraintSystem` instance is a witness generator. + /// ConstraintSystems that are witness generators need not assemble the actual constraints. Rather, they exist only + /// to efficiently create a witness. + /// + /// # Returns + /// + /// * `false` - By default, a `ConstraintSystem` is not a witness generator. + fn is_witness_generator(&self) -> bool { + false + } + + /// Extend the inputs of the `ConstraintSystem`. + /// + /// # Panics + /// + /// Panics if called on a `ConstraintSystem` that is not a witness generator. + fn extend_inputs(&mut self, _new_inputs: &[Scalar]) { + assert!(self.is_witness_generator()); + unimplemented!("ConstraintSystem::extend_inputs must be implemented for witness generators implementing ConstraintSystem") + } + + /// Extend the auxiliary inputs of the `ConstraintSystem`. + /// + /// # Panics + /// + /// Panics if called on a `ConstraintSystem` that is not a witness generator. + fn extend_aux(&mut self, _new_aux: &[Scalar]) { + assert!(self.is_witness_generator()); + unimplemented!("ConstraintSystem::extend_aux must be implemented for witness generators implementing ConstraintSystem") + } + + /// Allocate empty space for the auxiliary inputs and the main inputs of the `ConstraintSystem`. + /// + /// # Panics + /// + /// Panics if called on a `ConstraintSystem` that is not a witness generator. + fn allocate_empty(&mut self, _aux_n: usize, _inputs_n: usize) -> (&mut [Scalar], &mut [Scalar]) { + // This method should only ever be called on witness generators. + assert!(self.is_witness_generator()); + unimplemented!("ConstraintSystem::allocate_empty must be implemented for witness generators implementing ConstraintSystem") + } + + /// Allocate empty space for the main inputs of the `ConstraintSystem`. 
+ /// + /// # Panics + /// + /// Panics if called on a `ConstraintSystem` that is not a witness generator. + fn allocate_empty_inputs(&mut self, _n: usize) -> &mut [Scalar] { + // This method should only ever be called on witness generators. + assert!(self.is_witness_generator()); + unimplemented!("ConstraintSystem::allocate_empty_inputs must be implemented for witness generators implementing ConstraintSystem") + } + + /// Allocate empty space for the auxiliary inputs of the `ConstraintSystem`. + /// + /// # Panics + /// + /// Panics if called on a `ConstraintSystem` that is not a witness generator. + fn allocate_empty_aux(&mut self, _n: usize) -> &mut [Scalar] { + // This method should only ever be called on witness generators. + assert!(self.is_witness_generator()); + unimplemented!("ConstraintSystem::allocate_empty_aux must be implemented for witness generators implementing ConstraintSystem") + } + + /// Returns the constraint system's inputs as a slice of `Scalar`s. + /// + /// # Panics + /// + /// Panics if called on a `ConstraintSystem` that is not a witness generator. + fn inputs_slice(&self) -> &[Scalar] { + assert!(self.is_witness_generator()); + unimplemented!("ConstraintSystem::inputs_slice must be implemented for witness generators implementing ConstraintSystem") + } + + /// Returns the constraint system's aux witness as a slice of `Scalar`s. + /// + /// # Panics + /// + /// Panics if called on a `ConstraintSystem` that is not a witness generator. + fn aux_slice(&self) -> &[Scalar] { + assert!(self.is_witness_generator()); + unimplemented!("ConstraintSystem::aux_slice must be implemented for witness generators implementing ConstraintSystem") + } +} + +/// This is a "namespaced" constraint system which borrows a constraint system (pushing +/// a namespace context) and, when dropped, pops out of the namespace context. +#[derive(Debug)] +pub struct Namespace<'a, Scalar: PrimeField, CS: ConstraintSystem>( + &'a mut CS, + PhantomData, +); + +impl<'cs, Scalar: PrimeField, CS: ConstraintSystem> ConstraintSystem + for Namespace<'cs, Scalar, CS> +{ + type Root = CS::Root; + + fn one() -> Variable { + CS::one() + } + + fn alloc(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.0.alloc(annotation, f) + } + + fn alloc_precommitted( + &mut self, + annotation: A, + f: F, + ) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.0.alloc_precommitted(annotation, f) + } + + fn alloc_input(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.0.alloc_input(annotation, f) + } + + fn enforce(&mut self, annotation: A, a: LA, b: LB, c: LC) + where + A: FnOnce() -> AR, + AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination, + { + self.0.enforce(annotation, a, b, c) + } + + // Downstream users who use `namespace` will never interact with these + // functions and they will never be invoked because the namespace is + // never a root constraint system. 
+ + fn push_namespace(&mut self, _: N) + where + NR: Into, + N: FnOnce() -> NR, + { + panic!("only the root's push_namespace should be called"); + } + + fn pop_namespace(&mut self) { + panic!("only the root's pop_namespace should be called"); + } + + fn get_root(&mut self) -> &mut Self::Root { + self.0.get_root() + } + + fn is_witness_generator(&self) -> bool { + self.0.is_witness_generator() + } + + fn extend_inputs(&mut self, new_inputs: &[Scalar]) { + self.0.extend_inputs(new_inputs) + } + + fn extend_aux(&mut self, new_aux: &[Scalar]) { + self.0.extend_aux(new_aux) + } + + fn allocate_empty(&mut self, aux_n: usize, inputs_n: usize) -> (&mut [Scalar], &mut [Scalar]) { + self.0.allocate_empty(aux_n, inputs_n) + } + + fn inputs_slice(&self) -> &[Scalar] { + self.0.inputs_slice() + } + fn aux_slice(&self) -> &[Scalar] { + self.0.aux_slice() + } +} + +impl<'a, Scalar: PrimeField, CS: ConstraintSystem> Drop for Namespace<'a, Scalar, CS> { + fn drop(&mut self) { + self.get_root().pop_namespace() + } +} + +/// Convenience implementation of ConstraintSystem for mutable references to +/// constraint systems. +impl<'cs, Scalar: PrimeField, CS: ConstraintSystem> ConstraintSystem + for &'cs mut CS +{ + type Root = CS::Root; + + fn one() -> Variable { + CS::one() + } + + fn alloc(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + (**self).alloc(annotation, f) + } + + fn alloc_precommitted( + &mut self, + annotation: A, + f: F, + ) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + (**self).alloc_precommitted(annotation, f) + } + + fn alloc_input(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + (**self).alloc_input(annotation, f) + } + + fn enforce(&mut self, annotation: A, a: LA, b: LB, c: LC) + where + A: FnOnce() -> AR, + AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination, + { + (**self).enforce(annotation, a, b, c) + } + + fn push_namespace(&mut self, name_fn: N) + where + NR: Into, + N: FnOnce() -> NR, + { + (**self).push_namespace(name_fn) + } + + fn pop_namespace(&mut self) { + (**self).pop_namespace() + } + + fn get_root(&mut self) -> &mut Self::Root { + (**self).get_root() + } + + fn namespace(&mut self, name_fn: N) -> Namespace<'_, Scalar, Self::Root> + where + NR: Into, + N: FnOnce() -> NR, + { + (**self).namespace(name_fn) + } + + fn is_extensible() -> bool { + CS::is_extensible() + } + + fn extend(&mut self, other: &Self) { + (**self).extend(other) + } + + fn is_witness_generator(&self) -> bool { + (**self).is_witness_generator() + } + + fn extend_inputs(&mut self, new_inputs: &[Scalar]) { + (**self).extend_inputs(new_inputs) + } + + fn extend_aux(&mut self, new_aux: &[Scalar]) { + (**self).extend_aux(new_aux) + } + + fn allocate_empty(&mut self, aux_n: usize, inputs_n: usize) -> (&mut [Scalar], &mut [Scalar]) { + (**self).allocate_empty(aux_n, inputs_n) + } + + fn allocate_empty_inputs(&mut self, n: usize) -> &mut [Scalar] { + (**self).allocate_empty_inputs(n) + } + + fn allocate_empty_aux(&mut self, n: usize) -> &mut [Scalar] { + (**self).allocate_empty_aux(n) + } + + fn inputs_slice(&self) -> &[Scalar] { + (**self).inputs_slice() + } + + fn aux_slice(&self) -> &[Scalar] { + (**self).aux_slice() + } +} diff --git a/src/frontend/gadgets/boolean.rs b/src/frontend/gadgets/boolean.rs new file mode 100644 index 
00000000..f142b186 --- /dev/null +++ b/src/frontend/gadgets/boolean.rs @@ -0,0 +1,774 @@ +//! Gadgets for allocating bits in the circuit and performing boolean logic. + +use ff::{PrimeField, PrimeFieldBits}; + +use crate::frontend::{ConstraintSystem, LinearCombination, SynthesisError, Variable}; + +/// Represents a variable in the constraint system which is guaranteed +/// to be either zero or one. +#[derive(Debug, Clone)] +pub struct AllocatedBit { + variable: Variable, + value: Option, +} + +impl AllocatedBit { + /// Get inner value of [`AllocatedBit`] + pub fn get_value(&self) -> Option { + self.value + } + + /// Get inner [`Variable`] of [`AllocatedBit`] + pub fn get_variable(&self) -> Variable { + self.variable + } + + /// Allocate a variable in the constraint system which can only be a + /// boolean value. Further, constrain that the boolean is false + /// unless the condition is false. + pub fn alloc_conditionally( + mut cs: CS, + value: Option, + must_be_false: &AllocatedBit, + ) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let var = cs.alloc( + || "boolean", + || { + if value.ok_or(SynthesisError::AssignmentMissing)? { + Ok(Scalar::ONE) + } else { + Ok(Scalar::ZERO) + } + }, + )?; + + // Constrain: (1 - must_be_false - a) * a = 0 + // if must_be_false is true, the equation + // reduces to -a * a = 0, which implies a = 0. + // if must_be_false is false, the equation + // reduces to (1 - a) * a = 0, which is a + // traditional boolean constraint. + cs.enforce( + || "boolean constraint", + |lc| lc + CS::one() - must_be_false.variable - var, + |lc| lc + var, + |lc| lc, + ); + + Ok(AllocatedBit { + variable: var, + value, + }) + } + + /// Allocate a variable in the constraint system which can only be a + /// boolean value. + pub fn alloc(mut cs: CS, value: Option) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let var = cs.alloc( + || "boolean", + || { + if value.ok_or(SynthesisError::AssignmentMissing)? { + Ok(Scalar::ONE) + } else { + Ok(Scalar::ZERO) + } + }, + )?; + + // Constrain: (1 - a) * a = 0 + // This constrains a to be either 0 or 1. + cs.enforce( + || "boolean constraint", + |lc| lc + CS::one() - var, + |lc| lc + var, + |lc| lc, + ); + + Ok(AllocatedBit { + variable: var, + value, + }) + } + + /// Performs an XOR operation over the two operands, returning + /// an `AllocatedBit`. + pub fn xor(mut cs: CS, a: &Self, b: &Self) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let mut result_value = None; + + let result_var = cs.alloc( + || "xor result", + || { + if a.value.ok_or(SynthesisError::AssignmentMissing)? + ^ b.value.ok_or(SynthesisError::AssignmentMissing)? + { + result_value = Some(true); + + Ok(Scalar::ONE) + } else { + result_value = Some(false); + + Ok(Scalar::ZERO) + } + }, + )?; + + // Constrain (a + a) * (b) = (a + b - c) + // Given that a and b are boolean constrained, if they + // are equal, the only solution for c is 0, and if they + // are different, the only solution for c is 1. 
+ // + // ¬(a ∧ b) ∧ ¬(¬a ∧ ¬b) = c + // (1 - (a * b)) * (1 - ((1 - a) * (1 - b))) = c + // (1 - ab) * (1 - (1 - a - b + ab)) = c + // (1 - ab) * (a + b - ab) = c + // a + b - ab - (a^2)b - (b^2)a + (a^2)(b^2) = c + // a + b - ab - ab - ab + ab = c + // a + b - 2ab = c + // -2a * b = c - a - b + // 2a * b = a + b - c + // (a + a) * b = a + b - c + cs.enforce( + || "xor constraint", + |lc| lc + a.variable + a.variable, + |lc| lc + b.variable, + |lc| lc + a.variable + b.variable - result_var, + ); + + Ok(AllocatedBit { + variable: result_var, + value: result_value, + }) + } + + /// Performs an AND operation over the two operands, returning + /// an `AllocatedBit`. + pub fn and(mut cs: CS, a: &Self, b: &Self) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let mut result_value = None; + + let result_var = cs.alloc( + || "and result", + || { + if a.value.ok_or(SynthesisError::AssignmentMissing)? + & b.value.ok_or(SynthesisError::AssignmentMissing)? + { + result_value = Some(true); + + Ok(Scalar::ONE) + } else { + result_value = Some(false); + + Ok(Scalar::ZERO) + } + }, + )?; + + // Constrain (a) * (b) = (c), ensuring c is 1 iff + // a AND b are both 1. + cs.enforce( + || "and constraint", + |lc| lc + a.variable, + |lc| lc + b.variable, + |lc| lc + result_var, + ); + + Ok(AllocatedBit { + variable: result_var, + value: result_value, + }) + } + + /// Calculates `a AND (NOT b)`. + pub fn and_not(mut cs: CS, a: &Self, b: &Self) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let mut result_value = None; + + let result_var = cs.alloc( + || "and not result", + || { + if a.value.ok_or(SynthesisError::AssignmentMissing)? + & !b.value.ok_or(SynthesisError::AssignmentMissing)? + { + result_value = Some(true); + + Ok(Scalar::ONE) + } else { + result_value = Some(false); + + Ok(Scalar::ZERO) + } + }, + )?; + + // Constrain (a) * (1 - b) = (c), ensuring c is 1 iff + // a is true and b is false, and otherwise c is 0. + cs.enforce( + || "and not constraint", + |lc| lc + a.variable, + |lc| lc + CS::one() - b.variable, + |lc| lc + result_var, + ); + + Ok(AllocatedBit { + variable: result_var, + value: result_value, + }) + } + + /// Calculates `(NOT a) AND (NOT b)`. + pub fn nor(mut cs: CS, a: &Self, b: &Self) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let mut result_value = None; + + let result_var = cs.alloc( + || "nor result", + || { + if !a.value.ok_or(SynthesisError::AssignmentMissing)? + & !b.value.ok_or(SynthesisError::AssignmentMissing)? + { + result_value = Some(true); + + Ok(Scalar::ONE) + } else { + result_value = Some(false); + + Ok(Scalar::ZERO) + } + }, + )?; + + // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff + // a and b are both false, and otherwise c is 0. + cs.enforce( + || "nor constraint", + |lc| lc + CS::one() - a.variable, + |lc| lc + CS::one() - b.variable, + |lc| lc + result_var, + ); + + Ok(AllocatedBit { + variable: result_var, + value: result_value, + }) + } +} + +/// Convert a `u64` into a vector of `Boolean`s representing its bits. 
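+/// The bits are allocated in little-endian order, one boolean-constrained `AllocatedBit` per bit;
+/// passing `None` allocates 64 unassigned bits. For example, given a constraint system `cs`,
+/// `u64_into_boolean_vec_le(cs.namespace(|| "x"), Some(5u64))?` yields bits beginning `[true, false, true, false, ...]`.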
+pub fn u64_into_boolean_vec_le>( + mut cs: CS, + value: Option, +) -> Result, SynthesisError> { + let values = match value { + Some(ref value) => { + let mut tmp = Vec::with_capacity(64); + + for i in 0..64 { + tmp.push(Some(*value >> i & 1 == 1)); + } + + tmp + } + None => vec![None; 64], + }; + + let bits = values + .into_iter() + .enumerate() + .map(|(i, b)| { + Ok(Boolean::from(AllocatedBit::alloc( + cs.namespace(|| format!("bit {}", i)), + b, + )?)) + }) + .collect::, SynthesisError>>()?; + + Ok(bits) +} + +/// Convert a field element into a vector of `Boolean`s representing its bits. +pub fn field_into_boolean_vec_le( + cs: CS, + value: Option, +) -> Result, SynthesisError> +where + Scalar: PrimeField, + Scalar: PrimeFieldBits, + CS: ConstraintSystem, +{ + let v = field_into_allocated_bits_le::(cs, value)?; + + Ok(v.into_iter().map(Boolean::from).collect()) +} + +/// Convert a field element into a vector of [`AllocatedBit`]'s representing its bits. +pub fn field_into_allocated_bits_le( + mut cs: CS, + value: Option, +) -> Result, SynthesisError> +where + Scalar: PrimeField, + Scalar: PrimeFieldBits, + CS: ConstraintSystem, +{ + // Deconstruct in big-endian bit order + let values = match value { + Some(ref value) => { + let field_char = Scalar::char_le_bits(); + let mut field_char = field_char.into_iter().rev(); + + let mut tmp = Vec::with_capacity(Scalar::NUM_BITS as usize); + + let mut found_one = false; + for b in value.to_le_bits().into_iter().rev() { + // Skip leading bits + found_one |= field_char.next().unwrap(); + if !found_one { + continue; + } + + tmp.push(Some(b)); + } + + assert_eq!(tmp.len(), Scalar::NUM_BITS as usize); + + tmp + } + None => vec![None; Scalar::NUM_BITS as usize], + }; + + // Allocate in little-endian order + let bits = values + .into_iter() + .rev() + .enumerate() + .map(|(i, b)| AllocatedBit::alloc(cs.namespace(|| format!("bit {}", i)), b)) + .collect::, SynthesisError>>()?; + + Ok(bits) +} + +/// This is a boolean value which may be either a constant or +/// an interpretation of an `AllocatedBit`. +#[derive(Clone, Debug)] +pub enum Boolean { + /// Existential view of the boolean variable + Is(AllocatedBit), + /// Negated view of the boolean variable + Not(AllocatedBit), + /// Constant (not an allocated variable) + Constant(bool), +} + +impl Boolean { + /// Check if the boolean is a constant + pub fn is_constant(&self) -> bool { + matches!(*self, Boolean::Constant(_)) + } + + /// Constrain two booleans to be equal. + pub fn enforce_equal(mut cs: CS, a: &Self, b: &Self) -> Result<(), SynthesisError> + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + match (a, b) { + (&Boolean::Constant(a), &Boolean::Constant(b)) => { + if a == b { + Ok(()) + } else { + Err(SynthesisError::Unsatisfiable) + } + } + (&Boolean::Constant(true), a) | (a, &Boolean::Constant(true)) => { + cs.enforce( + || "enforce equal to one", + |lc| lc, + |lc| lc, + |lc| lc + CS::one() - &a.lc(CS::one(), Scalar::ONE), + ); + + Ok(()) + } + (&Boolean::Constant(false), a) | (a, &Boolean::Constant(false)) => { + cs.enforce( + || "enforce equal to zero", + |lc| lc, + |lc| lc, + |_| a.lc(CS::one(), Scalar::ONE), + ); + + Ok(()) + } + (a, b) => { + cs.enforce( + || "enforce equal", + |lc| lc, + |lc| lc, + |_| a.lc(CS::one(), Scalar::ONE) - &b.lc(CS::one(), Scalar::ONE), + ); + + Ok(()) + } + } + } + + /// Get the inner value of the boolean. 
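+  /// Returns `Some` for constants and for allocated bits with a known assignment, and `None` when
+  /// no witness is present; the `Not` view returns the negation of the underlying bit's value.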
+ pub fn get_value(&self) -> Option { + match *self { + Boolean::Constant(c) => Some(c), + Boolean::Is(ref v) => v.get_value(), + Boolean::Not(ref v) => v.get_value().map(|b| !b), + } + } + + /// Return a linear combination representing the boolean. + pub fn lc(&self, one: Variable, coeff: Scalar) -> LinearCombination { + match *self { + Boolean::Constant(c) => { + if c { + LinearCombination::::zero() + (coeff, one) + } else { + LinearCombination::::zero() + } + } + Boolean::Is(ref v) => LinearCombination::::zero() + (coeff, v.get_variable()), + Boolean::Not(ref v) => { + LinearCombination::::zero() + (coeff, one) - (coeff, v.get_variable()) + } + } + } + + /// Construct a boolean from a known constant + pub fn constant(b: bool) -> Self { + Boolean::Constant(b) + } + + /// Return a negated interpretation of this boolean. + pub fn not(&self) -> Self { + match *self { + Boolean::Constant(c) => Boolean::Constant(!c), + Boolean::Is(ref v) => Boolean::Not(v.clone()), + Boolean::Not(ref v) => Boolean::Is(v.clone()), + } + } + + /// Perform XOR over two boolean operands + pub fn xor<'a, Scalar, CS>(cs: CS, a: &'a Self, b: &'a Self) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + match (a, b) { + (&Boolean::Constant(false), x) | (x, &Boolean::Constant(false)) => Ok(x.clone()), + (&Boolean::Constant(true), x) | (x, &Boolean::Constant(true)) => Ok(x.not()), + // a XOR (NOT b) = NOT(a XOR b) + (is @ &Boolean::Is(_), not @ &Boolean::Not(_)) + | (not @ &Boolean::Not(_), is @ &Boolean::Is(_)) => { + Ok(Boolean::xor(cs, is, ¬.not())?.not()) + } + // a XOR b = (NOT a) XOR (NOT b) + (&Boolean::Is(ref a), &Boolean::Is(ref b)) | (&Boolean::Not(ref a), &Boolean::Not(ref b)) => { + Ok(Boolean::Is(AllocatedBit::xor(cs, a, b)?)) + } + } + } + + /// Perform AND over two boolean operands + pub fn and<'a, Scalar, CS>(cs: CS, a: &'a Self, b: &'a Self) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + match (a, b) { + // false AND x is always false + (&Boolean::Constant(false), _) | (_, &Boolean::Constant(false)) => { + Ok(Boolean::Constant(false)) + } + // true AND x is always x + (&Boolean::Constant(true), x) | (x, &Boolean::Constant(true)) => Ok(x.clone()), + // a AND (NOT b) + (&Boolean::Is(ref is), &Boolean::Not(ref not)) + | (&Boolean::Not(ref not), &Boolean::Is(ref is)) => { + Ok(Boolean::Is(AllocatedBit::and_not(cs, is, not)?)) + } + // (NOT a) AND (NOT b) = a NOR b + (Boolean::Not(a), Boolean::Not(b)) => Ok(Boolean::Is(AllocatedBit::nor(cs, a, b)?)), + // a AND b + (Boolean::Is(a), Boolean::Is(b)) => Ok(Boolean::Is(AllocatedBit::and(cs, a, b)?)), + } + } + + /// Perform OR over two boolean operands + pub fn or<'a, Scalar, CS>( + mut cs: CS, + a: &'a Boolean, + b: &'a Boolean, + ) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + Ok(Boolean::not(&Boolean::and( + cs.namespace(|| "not and (not a) (not b)"), + &Boolean::not(a), + &Boolean::not(b), + )?)) + } + + /// Computes (a and b) xor ((not a) and c) + pub fn sha256_ch<'a, Scalar, CS>( + mut cs: CS, + a: &'a Self, + b: &'a Self, + c: &'a Self, + ) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let ch_value = match (a.get_value(), b.get_value(), c.get_value()) { + (Some(a), Some(b), Some(c)) => { + // (a and b) xor ((not a) and c) + Some((a & b) ^ ((!a) & c)) + } + _ => None, + }; + + match (a, b, c) { + (&Boolean::Constant(_), &Boolean::Constant(_), &Boolean::Constant(_)) => { + // They're all constants, so we can just compute the value. 
+ + return Ok(Boolean::Constant(ch_value.expect("they're all constants"))); + } + (&Boolean::Constant(false), _, c) => { + // If a is false + // (a and b) xor ((not a) and c) + // equals + // (false) xor (c) + // equals + // c + return Ok(c.clone()); + } + (a, &Boolean::Constant(false), c) => { + // If b is false + // (a and b) xor ((not a) and c) + // equals + // ((not a) and c) + return Boolean::and(cs, &a.not(), c); + } + (a, b, &Boolean::Constant(false)) => { + // If c is false + // (a and b) xor ((not a) and c) + // equals + // (a and b) + return Boolean::and(cs, a, b); + } + (a, b, &Boolean::Constant(true)) => { + // If c is true + // (a and b) xor ((not a) and c) + // equals + // (a and b) xor (not a) + // equals + // not (a and (not b)) + return Ok(Boolean::and(cs, a, &b.not())?.not()); + } + (a, &Boolean::Constant(true), c) => { + // If b is true + // (a and b) xor ((not a) and c) + // equals + // a xor ((not a) and c) + // equals + // not ((not a) and (not c)) + return Ok(Boolean::and(cs, &a.not(), &c.not())?.not()); + } + (&Boolean::Constant(true), _, _) => { + // If a is true + // (a and b) xor ((not a) and c) + // equals + // b xor ((not a) and c) + // So we just continue! + } + ( + &Boolean::Is(_) | &Boolean::Not(_), + &Boolean::Is(_) | &Boolean::Not(_), + &Boolean::Is(_) | &Boolean::Not(_), + ) => {} + } + + let ch = cs.alloc( + || "ch", + || { + ch_value.ok_or(SynthesisError::AssignmentMissing).map(|v| { + if v { + Scalar::ONE + } else { + Scalar::ZERO + } + }) + }, + )?; + + // a(b - c) = ch - c + cs.enforce( + || "ch computation", + |_| b.lc(CS::one(), Scalar::ONE) - &c.lc(CS::one(), Scalar::ONE), + |_| a.lc(CS::one(), Scalar::ONE), + |lc| lc + ch - &c.lc(CS::one(), Scalar::ONE), + ); + + Ok( + AllocatedBit { + value: ch_value, + variable: ch, + } + .into(), + ) + } + + /// Computes (a and b) xor (a and c) xor (b and c) + pub fn sha256_maj<'a, Scalar, CS>( + mut cs: CS, + a: &'a Self, + b: &'a Self, + c: &'a Self, + ) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let maj_value = match (a.get_value(), b.get_value(), c.get_value()) { + (Some(a), Some(b), Some(c)) => { + // (a and b) xor (a and c) xor (b and c) + Some((a & b) ^ (a & c) ^ (b & c)) + } + _ => None, + }; + + match (a, b, c) { + (&Boolean::Constant(_), &Boolean::Constant(_), &Boolean::Constant(_)) => { + // They're all constants, so we can just compute the value. 
+ + return Ok(Boolean::Constant(maj_value.expect("they're all constants"))); + } + (&Boolean::Constant(false), b, c) => { + // If a is false, + // (a and b) xor (a and c) xor (b and c) + // equals + // (b and c) + return Boolean::and(cs, b, c); + } + (a, &Boolean::Constant(false), c) => { + // If b is false, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a and c) + return Boolean::and(cs, a, c); + } + (a, b, &Boolean::Constant(false)) => { + // If c is false, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a and b) + return Boolean::and(cs, a, b); + } + (a, b, &Boolean::Constant(true)) => { + // If c is true, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a and b) xor (a) xor (b) + // equals + // not ((not a) and (not b)) + return Ok(Boolean::and(cs, &a.not(), &b.not())?.not()); + } + (a, &Boolean::Constant(true), c) => { + // If b is true, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a) xor (a and c) xor (c) + return Ok(Boolean::and(cs, &a.not(), &c.not())?.not()); + } + (&Boolean::Constant(true), b, c) => { + // If a is true, + // (a and b) xor (a and c) xor (b and c) + // equals + // (b) xor (c) xor (b and c) + return Ok(Boolean::and(cs, &b.not(), &c.not())?.not()); + } + ( + &Boolean::Is(_) | &Boolean::Not(_), + &Boolean::Is(_) | &Boolean::Not(_), + &Boolean::Is(_) | &Boolean::Not(_), + ) => {} + } + + let maj = cs.alloc( + || "maj", + || { + maj_value.ok_or(SynthesisError::AssignmentMissing).map(|v| { + if v { + Scalar::ONE + } else { + Scalar::ZERO + } + }) + }, + )?; + + // ¬(¬a ∧ ¬b) ∧ ¬(¬a ∧ ¬c) ∧ ¬(¬b ∧ ¬c) + // (1 - ((1 - a) * (1 - b))) * (1 - ((1 - a) * (1 - c))) * (1 - ((1 - b) * (1 - c))) + // (a + b - ab) * (a + c - ac) * (b + c - bc) + // -2abc + ab + ac + bc + // a (-2bc + b + c) + bc + // + // (b) * (c) = (bc) + // (2bc - b - c) * (a) = bc - maj + + let bc = Self::and(cs.namespace(|| "b and c"), b, c)?; + + cs.enforce( + || "maj computation", + |_| { + bc.lc(CS::one(), Scalar::ONE) + &bc.lc(CS::one(), Scalar::ONE) + - &b.lc(CS::one(), Scalar::ONE) + - &c.lc(CS::one(), Scalar::ONE) + }, + |_| a.lc(CS::one(), Scalar::ONE), + |_| bc.lc(CS::one(), Scalar::ONE) - maj, + ); + + Ok( + AllocatedBit { + value: maj_value, + variable: maj, + } + .into(), + ) + } +} + +impl From for Boolean { + fn from(b: AllocatedBit) -> Boolean { + Boolean::Is(b) + } +} diff --git a/src/frontend/gadgets/mod.rs b/src/frontend/gadgets/mod.rs new file mode 100644 index 00000000..bb751231 --- /dev/null +++ b/src/frontend/gadgets/mod.rs @@ -0,0 +1,23 @@ +//! Self-contained sub-circuit implementations for various primitives. + +use super::SynthesisError; +pub mod boolean; +mod multieq; +pub mod num; +pub mod sha256; +mod uint32; + +/// A trait for representing an assignment to a variable. +pub trait Assignment { + /// Get the value of the assigned variable. 
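+  /// Returns `SynthesisError::AssignmentMissing` when no value has been assigned.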
+ fn get(&self) -> Result<&T, SynthesisError>; +} + +impl Assignment for Option { + fn get(&self) -> Result<&T, SynthesisError> { + match *self { + Some(ref v) => Ok(v), + None => Err(SynthesisError::AssignmentMissing), + } + } +} diff --git a/src/frontend/gadgets/multieq.rs b/src/frontend/gadgets/multieq.rs new file mode 100644 index 00000000..8f249f51 --- /dev/null +++ b/src/frontend/gadgets/multieq.rs @@ -0,0 +1,135 @@ +use ff::PrimeField; + +use crate::frontend::{ConstraintSystem, LinearCombination, SynthesisError, Variable}; + +#[derive(Debug)] +pub struct MultiEq> { + cs: CS, + ops: usize, + bits_used: usize, + lhs: LinearCombination, + rhs: LinearCombination, +} + +impl> MultiEq { + pub fn new(cs: CS) -> Self { + MultiEq { + cs, + ops: 0, + bits_used: 0, + lhs: LinearCombination::zero(), + rhs: LinearCombination::zero(), + } + } + + fn accumulate(&mut self) { + let ops = self.ops; + let lhs = self.lhs.clone(); + let rhs = self.rhs.clone(); + self.cs.enforce( + || format!("multieq {}", ops), + |_| lhs, + |lc| lc + CS::one(), + |_| rhs, + ); + self.lhs = LinearCombination::zero(); + self.rhs = LinearCombination::zero(); + self.bits_used = 0; + self.ops += 1; + } + + pub fn enforce_equal( + &mut self, + num_bits: usize, + lhs: &LinearCombination, + rhs: &LinearCombination, + ) { + // Check if we will exceed the capacity + if (Scalar::CAPACITY as usize) <= (self.bits_used + num_bits) { + self.accumulate(); + } + + assert!((Scalar::CAPACITY as usize) > (self.bits_used + num_bits)); + + let coeff = Scalar::from(2u64).pow_vartime([self.bits_used as u64]); + self.lhs = self.lhs.clone() + (coeff, lhs); + self.rhs = self.rhs.clone() + (coeff, rhs); + self.bits_used += num_bits; + } +} + +impl> Drop for MultiEq { + fn drop(&mut self) { + if self.bits_used > 0 { + self.accumulate(); + } + } +} + +impl> ConstraintSystem + for MultiEq +{ + type Root = Self; + + fn one() -> Variable { + CS::one() + } + + fn alloc(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.cs.alloc(annotation, f) + } + + fn alloc_precommitted( + &mut self, + annotation: A, + f: F, + ) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.cs.alloc_precommitted(annotation, f) + } + + fn alloc_input(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.cs.alloc_input(annotation, f) + } + + fn enforce(&mut self, annotation: A, a: LA, b: LB, c: LC) + where + A: FnOnce() -> AR, + AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination, + { + self.cs.enforce(annotation, a, b, c) + } + + fn push_namespace(&mut self, name_fn: N) + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs.get_root().push_namespace(name_fn) + } + + fn pop_namespace(&mut self) { + self.cs.get_root().pop_namespace() + } + + fn get_root(&mut self) -> &mut Self::Root { + self + } +} diff --git a/src/frontend/gadgets/num.rs b/src/frontend/gadgets/num.rs new file mode 100644 index 00000000..af324a56 --- /dev/null +++ b/src/frontend/gadgets/num.rs @@ -0,0 +1,552 @@ +//! Gadgets representing numbers in the scalar field of the underlying curve. 
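+//!
+//! Two representations are provided: [`AllocatedNum`], which is backed by a single constraint-system
+//! variable, and [`Num`], which carries its value as a linear combination and therefore supports
+//! addition and scaling without allocating new variables.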
+ +use ff::{PrimeField, PrimeFieldBits}; +use serde::{Deserialize, Serialize}; + +use crate::frontend::{ConstraintSystem, LinearCombination, SynthesisError, Variable}; + +use crate::frontend::gadgets::boolean::{self, AllocatedBit, Boolean}; + +/// Represents an allocated number in the circuit. +#[derive(Debug, Serialize, Deserialize)] +pub struct AllocatedNum { + value: Option, + variable: Variable, +} + +impl Clone for AllocatedNum { + fn clone(&self) -> Self { + AllocatedNum { + value: self.value, + variable: self.variable, + } + } +} + +impl AllocatedNum { + /// Allocate a `Variable(Aux)` in a `ConstraintSystem`. + pub fn alloc(mut cs: CS, value: F) -> Result + where + CS: ConstraintSystem, + F: FnOnce() -> Result, + { + let mut new_value = None; + let var = cs.alloc( + || "num", + || { + let tmp = value()?; + + new_value = Some(tmp); + + Ok(tmp) + }, + )?; + + Ok(AllocatedNum { + value: new_value, + variable: var, + }) + } + + /// Allocate a `Variable(Aux)` in a `ConstraintSystem`. Requires an + /// infallible getter for the value. + pub fn alloc_infallible(cs: CS, value: F) -> Self + where + CS: ConstraintSystem, + F: FnOnce() -> Scalar, + { + Self::alloc(cs, || Ok(value())).unwrap() + } + + /// Allocate a `Variable(Input)` in a `ConstraintSystem`. + pub fn alloc_input(mut cs: CS, value: F) -> Result + where + CS: ConstraintSystem, + F: FnOnce() -> Result, + { + let mut new_value = None; + let var = cs.alloc_input( + || "input num", + || { + let tmp = value()?; + + new_value = Some(tmp); + + Ok(tmp) + }, + )?; + + Ok(AllocatedNum { + value: new_value, + variable: var, + }) + } + + /// Allocate a `Variable` of either `Aux` or `Input` in a + /// `ConstraintSystem`. The `Variable` is a an `Input` if `is_input` is + /// true. This allows uniform creation of circuits containing components + /// which may or may not be public inputs. + pub fn alloc_maybe_input(cs: CS, is_input: bool, value: F) -> Result + where + CS: ConstraintSystem, + F: FnOnce() -> Result, + { + if is_input { + Self::alloc_input(cs, value) + } else { + Self::alloc(cs, value) + } + } + + /// Make [`AllocatedNum`] a public input. + pub fn inputize(&self, mut cs: CS) -> Result<(), SynthesisError> + where + CS: ConstraintSystem, + { + let input = cs.alloc_input( + || "input variable", + || self.value.ok_or(SynthesisError::AssignmentMissing), + )?; + + cs.enforce( + || "enforce input is correct", + |lc| lc + input, + |lc| lc + CS::one(), + |lc| lc + self.variable, + ); + + Ok(()) + } + + /// Deconstructs this allocated number into its + /// boolean representation in little-endian bit + /// order, requiring that the representation + /// strictly exists "in the field" (i.e., a + /// congruency is not allowed.) + pub fn to_bits_le_strict(&self, mut cs: CS) -> Result, SynthesisError> + where + CS: ConstraintSystem, + Scalar: PrimeFieldBits, + { + pub fn kary_and( + mut cs: CS, + v: &[AllocatedBit], + ) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + assert!(!v.is_empty()); + + // Let's keep this simple for now and just AND them all + // manually + let mut cur = None; + + for (i, v) in v.iter().enumerate() { + if cur.is_none() { + cur = Some(v.clone()); + } else { + cur = Some(AllocatedBit::and( + cs.namespace(|| format!("and {}", i)), + cur.as_ref().unwrap(), + v, + )?); + } + } + + Ok(cur.expect("v.len() > 0")) + } + + // We want to ensure that the bit representation of a is + // less than or equal to r - 1. 
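+    // We do this by walking the big-endian bits of r - 1: over each run of ones the corresponding
+    // bits of a are allocated freely, and at each zero of r - 1 the bit of a is forced to zero
+    // whenever a has matched r - 1 on every one-position seen so far (the k-ary ANDs accumulated
+    // in `last_run`), which rules out any a >= r.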
+ let a = self.value.map(|e| e.to_le_bits()); + let b = (-Scalar::ONE).to_le_bits(); + + // Get the bits of `a` in big-endian order. + let mut a = a.as_ref().map(|e| e.into_iter().rev()); + + let mut result = vec![]; + + // Runs of ones in r + let mut last_run = None; + let mut current_run = vec![]; + + let mut found_one = false; + let mut i = 0; + for b in b.into_iter().rev() { + let a_bit: Option = a.as_mut().map(|e| *e.next().unwrap()); + + // Skip over unset bits at the beginning + found_one |= b; + if !found_one { + // a_bit should also be false + if let Some(a_bit) = a_bit { + assert!(!a_bit); + } + continue; + } + + if b { + // This is part of a run of ones. Let's just + // allocate the boolean with the expected value. + let a_bit = AllocatedBit::alloc(cs.namespace(|| format!("bit {}", i)), a_bit)?; + // ... and add it to the current run of ones. + current_run.push(a_bit.clone()); + result.push(a_bit); + } else { + if !current_run.is_empty() { + // This is the start of a run of zeros, but we need + // to k-ary AND against `last_run` first. + + if last_run.is_some() { + current_run.push(last_run.clone().unwrap()); + } + last_run = Some(kary_and( + cs.namespace(|| format!("run ending at {}", i)), + ¤t_run, + )?); + current_run.truncate(0); + } + + // If `last_run` is true, `a` must be false, or it would + // not be in the field. + // + // If `last_run` is false, `a` can be true or false. + + let a_bit = AllocatedBit::alloc_conditionally( + cs.namespace(|| format!("bit {}", i)), + a_bit, + last_run.as_ref().expect("char always starts with a one"), + )?; + result.push(a_bit); + } + + i += 1; + } + + // char is prime, so we'll always end on + // a run of zeros. + assert_eq!(current_run.len(), 0); + + // Now, we have `result` in big-endian order. + // However, now we have to unpack self! + + let mut lc = LinearCombination::zero(); + let mut coeff = Scalar::ONE; + + for bit in result.iter().rev() { + lc = lc + (coeff, bit.get_variable()); + + coeff = coeff.double(); + } + + lc = lc - self.variable; + + cs.enforce(|| "unpacking constraint", |lc| lc, |lc| lc, |_| lc); + + // Convert into booleans, and reverse for little-endian bit order + Ok(result.into_iter().map(Boolean::from).rev().collect()) + } + + /// Convert the allocated number into its little-endian representation. + /// Note that this does not strongly enforce that the commitment is + /// "in the field." + pub fn to_bits_le(&self, mut cs: CS) -> Result, SynthesisError> + where + CS: ConstraintSystem, + Scalar: PrimeFieldBits, + { + let bits = boolean::field_into_allocated_bits_le(&mut cs, self.value)?; + + let mut lc = LinearCombination::zero(); + let mut coeff = Scalar::ONE; + + for bit in bits.iter() { + lc = lc + (coeff, bit.get_variable()); + + coeff = coeff.double(); + } + + lc = lc - self.variable; + + cs.enforce(|| "unpacking constraint", |lc| lc, |lc| lc, |_| lc); + + Ok(bits.into_iter().map(Boolean::from).collect()) + } + + /// Adds two allocated numbers together, returning a new allocated number. 
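+  /// For example, given allocated numbers `a` and `b` and a constraint system `cs`,
+  /// `a.add(cs.namespace(|| "sum"), &b)?` allocates the sum and enforces `(a + b) * 1 = sum`.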
+ pub fn add(&self, mut cs: CS, other: &Self) -> Result + where + CS: ConstraintSystem, + { + let mut value = None; + + let var = cs.alloc( + || "sum num", + || { + let mut tmp = self.value.ok_or(SynthesisError::AssignmentMissing)?; + tmp.add_assign(other.value.ok_or(SynthesisError::AssignmentMissing)?); + + value = Some(tmp); + + Ok(tmp) + }, + )?; + + // Constrain: (a + b) * 1 = a + b + cs.enforce( + || "addition constraint", + |lc| lc + self.variable + other.variable, + |lc| lc + CS::one(), + |lc| lc + var, + ); + + Ok(AllocatedNum { + value, + variable: var, + }) + } + + /// Multiplies two allocated numbers together, returning a new allocated number. + pub fn mul(&self, mut cs: CS, other: &Self) -> Result + where + CS: ConstraintSystem, + { + let mut value = None; + + let var = cs.alloc( + || "product num", + || { + let mut tmp = self.value.ok_or(SynthesisError::AssignmentMissing)?; + tmp.mul_assign(other.value.ok_or(SynthesisError::AssignmentMissing)?); + + value = Some(tmp); + + Ok(tmp) + }, + )?; + + // Constrain: a * b = ab + cs.enforce( + || "multiplication constraint", + |lc| lc + self.variable, + |lc| lc + other.variable, + |lc| lc + var, + ); + + Ok(AllocatedNum { + value, + variable: var, + }) + } + + /// Squares an allocated number, returning a new allocated number. + pub fn square(&self, mut cs: CS) -> Result + where + CS: ConstraintSystem, + { + let mut value = None; + + let var = cs.alloc( + || "squared num", + || { + let mut tmp = self.value.ok_or(SynthesisError::AssignmentMissing)?; + tmp = tmp.square(); + + value = Some(tmp); + + Ok(tmp) + }, + )?; + + // Constrain: a * a = aa + cs.enforce( + || "squaring constraint", + |lc| lc + self.variable, + |lc| lc + self.variable, + |lc| lc + var, + ); + + Ok(AllocatedNum { + value, + variable: var, + }) + } + + /// Asserts that the allocated number is not zero. + pub fn assert_nonzero(&self, mut cs: CS) -> Result<(), SynthesisError> + where + CS: ConstraintSystem, + { + let inv = cs.alloc( + || "ephemeral inverse", + || { + let tmp = self.value.ok_or(SynthesisError::AssignmentMissing)?; + + if tmp.is_zero().into() { + Err(SynthesisError::DivisionByZero) + } else { + Ok(tmp.invert().unwrap()) + } + }, + )?; + + // Constrain a * inv = 1, which is only valid + // iff a has a multiplicative inverse, untrue + // for zero. + cs.enforce( + || "nonzero assertion constraint", + |lc| lc + self.variable, + |lc| lc + inv, + |lc| lc + CS::one(), + ); + + Ok(()) + } + + /// Takes two allocated numbers (a, b) and returns + /// (b, a) if the condition is true, and (a, b) + /// otherwise. + pub fn conditionally_reverse( + mut cs: CS, + a: &Self, + b: &Self, + condition: &Boolean, + ) -> Result<(Self, Self), SynthesisError> + where + CS: ConstraintSystem, + { + let c = Self::alloc(cs.namespace(|| "conditional reversal result 1"), || { + if condition + .get_value() + .ok_or(SynthesisError::AssignmentMissing)? + { + Ok(b.value.ok_or(SynthesisError::AssignmentMissing)?) + } else { + Ok(a.value.ok_or(SynthesisError::AssignmentMissing)?) + } + })?; + + cs.enforce( + || "first conditional reversal", + |lc| lc + a.variable - b.variable, + |_| condition.lc(CS::one(), Scalar::ONE), + |lc| lc + a.variable - c.variable, + ); + + let d = Self::alloc(cs.namespace(|| "conditional reversal result 2"), || { + if condition + .get_value() + .ok_or(SynthesisError::AssignmentMissing)? + { + Ok(a.value.ok_or(SynthesisError::AssignmentMissing)?) + } else { + Ok(b.value.ok_or(SynthesisError::AssignmentMissing)?) 
+ } + })?; + + cs.enforce( + || "second conditional reversal", + |lc| lc + b.variable - a.variable, + |_| condition.lc(CS::one(), Scalar::ONE), + |lc| lc + b.variable - d.variable, + ); + + Ok((c, d)) + } + + /// Get scalar value of the [`AllocatedNum`]. + pub fn get_value(&self) -> Option { + self.value + } + + /// Get the inner [`Variable`] of the [`AllocatedNum`]. + pub fn get_variable(&self) -> Variable { + self.variable + } +} + +/// Represents a number in the circuit using a linear combination. +#[derive(Debug, Clone)] +pub struct Num { + value: Option, + lc: LinearCombination, +} + +impl From> for Num { + fn from(num: AllocatedNum) -> Num { + Num { + value: num.value, + lc: LinearCombination::::from_variable(num.variable), + } + } +} + +impl Num { + /// Create a zero [`Num`]. + pub fn zero() -> Self { + Num { + value: Some(Scalar::ZERO), + lc: LinearCombination::zero(), + } + } + + /// Get [`Scalar`] value of the [`Num`]. + pub fn get_value(&self) -> Option { + self.value + } + + /// Get the inner [`LinearCombination`] of the [`Num`]. + pub fn lc(&self, coeff: Scalar) -> LinearCombination { + LinearCombination::zero() + (coeff, &self.lc) + } + + /// Add a boolean to the Num with a given coefficient. + pub fn add_bool_with_coeff(self, one: Variable, bit: &Boolean, coeff: Scalar) -> Self { + let newval = match (self.value, bit.get_value()) { + (Some(mut curval), Some(bval)) => { + if bval { + curval.add_assign(&coeff); + } + + Some(curval) + } + _ => None, + }; + + Num { + value: newval, + lc: self.lc + &bit.lc(one, coeff), + } + } + + /// Add self to another Num, returning a new Num. + #[allow(clippy::should_implement_trait)] + pub fn add(self, other: &Self) -> Self { + let lc = self.lc + &other.lc; + let value = match (self.value, other.value) { + (Some(v1), Some(v2)) => { + let mut tmp = v1; + tmp.add_assign(&v2); + Some(tmp) + } + (Some(v), None) | (None, Some(v)) => Some(v), + (None, None) => None, + }; + + Num { value, lc } + } + + /// Scale the [`Num`] by a scalar. + pub fn scale(mut self, scalar: Scalar) -> Self { + for (_variable, fr) in self.lc.iter_mut() { + fr.mul_assign(&scalar); + } + + if let Some(ref mut v) = self.value { + v.mul_assign(&scalar); + } + + self + } +} diff --git a/src/frontend/gadgets/sha256.rs b/src/frontend/gadgets/sha256.rs new file mode 100644 index 00000000..3b668852 --- /dev/null +++ b/src/frontend/gadgets/sha256.rs @@ -0,0 +1,275 @@ +//! Circuits for the [SHA-256] hash function and its internal compression +//! function. +//! +//! 
[SHA-256]: https://tools.ietf.org/html/rfc6234 + +#![allow(clippy::many_single_char_names)] + +use ff::PrimeField; + +use super::boolean::Boolean; +use super::multieq::MultiEq; +use super::uint32::UInt32; +use crate::frontend::{ConstraintSystem, SynthesisError}; + +#[allow(clippy::unreadable_literal)] +const ROUND_CONSTANTS: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +#[allow(clippy::unreadable_literal)] +const IV: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +/// Compute the SHA-256 hash of the given input. +pub fn sha256_block_no_padding( + mut cs: CS, + input: &[Boolean], +) -> Result, SynthesisError> +where + Scalar: PrimeField, + CS: ConstraintSystem, +{ + assert_eq!(input.len(), 512); + + Ok( + sha256_compression_function(&mut cs, input, &get_sha256_iv())? + .into_iter() + .flat_map(|e| e.into_bits_be()) + .collect(), + ) +} + +/// Compute the SHA-256 hash of the given input. +pub fn sha256(mut cs: CS, input: &[Boolean]) -> Result, SynthesisError> +where + Scalar: PrimeField, + CS: ConstraintSystem, +{ + assert!(input.len() % 8 == 0); + + let mut padded = input.to_vec(); + let plen = padded.len() as u64; + // append a single '1' bit + padded.push(Boolean::constant(true)); + // append K '0' bits, where K is the minimum number >= 0 such that L + 1 + K + 64 is a multiple of 512 + while (padded.len() + 64) % 512 != 0 { + padded.push(Boolean::constant(false)); + } + // append L as a 64-bit big-endian integer, making the total post-processed length a multiple of 512 bits + for b in (0..64).rev().map(|i| (plen >> i) & 1 == 1) { + padded.push(Boolean::constant(b)); + } + assert!(padded.len() % 512 == 0); + + let mut cur = get_sha256_iv(); + for (i, block) in padded.chunks(512).enumerate() { + cur = sha256_compression_function(cs.namespace(|| format!("block {}", i)), block, &cur)?; + } + + Ok(cur.into_iter().flat_map(|e| e.into_bits_be()).collect()) +} + +fn get_sha256_iv() -> Vec { + IV.iter().map(|&v| UInt32::constant(v)).collect() +} + +/// Sha256 compression function +pub fn sha256_compression_function( + cs: CS, + input: &[Boolean], + current_hash_value: &[UInt32], +) -> Result, SynthesisError> +where + Scalar: PrimeField, + CS: ConstraintSystem, +{ + assert_eq!(input.len(), 512); + assert_eq!(current_hash_value.len(), 8); + + let mut w = input + .chunks(32) + .map(UInt32::from_bits_be) + .collect::>(); + + // We can save some constraints by combining some of + // the constraints in different u32 additions + let mut cs = MultiEq::new(cs); + + for i in 16..64 { + let cs = &mut cs.namespace(|| format!("w extension {}", i)); + + // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) + let mut s0 = w[i - 15].rotr(7); + s0 = 
s0.xor(cs.namespace(|| "first xor for s0"), &w[i - 15].rotr(18))?; + s0 = s0.xor(cs.namespace(|| "second xor for s0"), &w[i - 15].shr(3))?; + + // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) + let mut s1 = w[i - 2].rotr(17); + s1 = s1.xor(cs.namespace(|| "first xor for s1"), &w[i - 2].rotr(19))?; + s1 = s1.xor(cs.namespace(|| "second xor for s1"), &w[i - 2].shr(10))?; + + let tmp = UInt32::addmany( + cs.namespace(|| "computation of w[i]"), + &[w[i - 16].clone(), s0, w[i - 7].clone(), s1], + )?; + + // w[i] := w[i-16] + s0 + w[i-7] + s1 + w.push(tmp); + } + + assert_eq!(w.len(), 64); + + enum Maybe { + Deferred(Vec), + Concrete(UInt32), + } + + impl Maybe { + fn compute(self, cs: M, others: &[UInt32]) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + M: ConstraintSystem>, + { + Ok(match self { + Maybe::Concrete(ref v) => return Ok(v.clone()), + Maybe::Deferred(mut v) => { + v.extend(others.iter().cloned()); + UInt32::addmany(cs, &v)? + } + }) + } + } + + let mut a = Maybe::Concrete(current_hash_value[0].clone()); + let mut b = current_hash_value[1].clone(); + let mut c = current_hash_value[2].clone(); + let mut d = current_hash_value[3].clone(); + let mut e = Maybe::Concrete(current_hash_value[4].clone()); + let mut f = current_hash_value[5].clone(); + let mut g = current_hash_value[6].clone(); + let mut h = current_hash_value[7].clone(); + + for i in 0..64 { + let cs = &mut cs.namespace(|| format!("compression round {}", i)); + + // S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25) + let new_e = e.compute(cs.namespace(|| "deferred e computation"), &[])?; + let mut s1 = new_e.rotr(6); + s1 = s1.xor(cs.namespace(|| "first xor for s1"), &new_e.rotr(11))?; + s1 = s1.xor(cs.namespace(|| "second xor for s1"), &new_e.rotr(25))?; + + // ch := (e and f) xor ((not e) and g) + let ch = UInt32::sha256_ch(cs.namespace(|| "ch"), &new_e, &f, &g)?; + + // temp1 := h + S1 + ch + k[i] + w[i] + let temp1 = [ + h.clone(), + s1, + ch, + UInt32::constant(ROUND_CONSTANTS[i]), + w[i].clone(), + ]; + + // S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22) + let new_a = a.compute(cs.namespace(|| "deferred a computation"), &[])?; + let mut s0 = new_a.rotr(2); + s0 = s0.xor(cs.namespace(|| "first xor for s0"), &new_a.rotr(13))?; + s0 = s0.xor(cs.namespace(|| "second xor for s0"), &new_a.rotr(22))?; + + // maj := (a and b) xor (a and c) xor (b and c) + let maj = UInt32::sha256_maj(cs.namespace(|| "maj"), &new_a, &b, &c)?; + + // temp2 := S0 + maj + let temp2 = [s0, maj]; + + /* + h := g + g := f + f := e + e := d + temp1 + d := c + c := b + b := a + a := temp1 + temp2 + */ + + h = g; + g = f; + f = new_e; + e = Maybe::Deferred(temp1.iter().cloned().chain(Some(d)).collect::>()); + d = c; + c = b; + b = new_a; + a = Maybe::Deferred( + temp1 + .iter() + .cloned() + .chain(temp2.iter().cloned()) + .collect::>(), + ); + } + + /* + Add the compressed chunk to the current hash value: + h0 := h0 + a + h1 := h1 + b + h2 := h2 + c + h3 := h3 + d + h4 := h4 + e + h5 := h5 + f + h6 := h6 + g + h7 := h7 + h + */ + + let h0 = a.compute( + cs.namespace(|| "deferred h0 computation"), + &[current_hash_value[0].clone()], + )?; + + let h1 = UInt32::addmany( + cs.namespace(|| "new h1"), + &[current_hash_value[1].clone(), b], + )?; + + let h2 = UInt32::addmany( + cs.namespace(|| "new h2"), + &[current_hash_value[2].clone(), c], + )?; + + let h3 = UInt32::addmany( + cs.namespace(|| "new h3"), + &[current_hash_value[3].clone(), d], + )?; + + let 
h4 = e.compute( + cs.namespace(|| "deferred h4 computation"), + &[current_hash_value[4].clone()], + )?; + + let h5 = UInt32::addmany( + cs.namespace(|| "new h5"), + &[current_hash_value[5].clone(), f], + )?; + + let h6 = UInt32::addmany( + cs.namespace(|| "new h6"), + &[current_hash_value[6].clone(), g], + )?; + + let h7 = UInt32::addmany( + cs.namespace(|| "new h7"), + &[current_hash_value[7].clone(), h], + )?; + + Ok(vec![h0, h1, h2, h3, h4, h5, h6, h7]) +} diff --git a/src/frontend/gadgets/uint32.rs b/src/frontend/gadgets/uint32.rs new file mode 100644 index 00000000..c1833f3e --- /dev/null +++ b/src/frontend/gadgets/uint32.rs @@ -0,0 +1,402 @@ +//! Circuit representation of a [`u32`], with helpers for the [`sha256`] +//! gadgets. + +use ff::PrimeField; + +use crate::frontend::{ConstraintSystem, LinearCombination, SynthesisError}; + +use super::boolean::{AllocatedBit, Boolean}; +use super::multieq::MultiEq; + +/// Represents an interpretation of 32 `Boolean` objects as an +/// unsigned integer. +#[derive(Clone, Debug)] +pub struct UInt32 { + // Least significant bit first + bits: Vec, + value: Option, +} + +impl UInt32 { + /// Construct a constant `UInt32` from a `u32` + pub fn constant(value: u32) -> Self { + let mut bits = Vec::with_capacity(32); + + let mut tmp = value; + for _ in 0..32 { + if tmp & 1 == 1 { + bits.push(Boolean::constant(true)) + } else { + bits.push(Boolean::constant(false)) + } + + tmp >>= 1; + } + + UInt32 { + bits, + value: Some(value), + } + } + + /// Allocate a `UInt32` in the constraint system + pub fn alloc(mut cs: CS, value: Option) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let values = match value { + Some(mut val) => { + let mut v = Vec::with_capacity(32); + + for _ in 0..32 { + v.push(Some(val & 1 == 1)); + val >>= 1; + } + + v + } + None => vec![None; 32], + }; + + let bits = values + .into_iter() + .enumerate() + .map(|(i, v)| { + Ok(Boolean::from(AllocatedBit::alloc( + cs.namespace(|| format!("allocated bit {}", i)), + v, + )?)) + }) + .collect::, SynthesisError>>()?; + + Ok(UInt32 { bits, value }) + } + + pub fn into_bits_be(self) -> Vec { + let mut ret = self.bits; + ret.reverse(); + ret + } + + pub fn from_bits_be(bits: &[Boolean]) -> Self { + assert_eq!(bits.len(), 32); + + let mut value = Some(0u32); + for b in bits { + if let Some(v) = value.as_mut() { + *v <<= 1; + } + + match b.get_value() { + Some(true) => { + if let Some(v) = value.as_mut() { + *v |= 1; + } + } + Some(false) => {} + None => { + value = None; + } + } + } + + UInt32 { + value, + bits: bits.iter().rev().cloned().collect(), + } + } + + /// Turns this `UInt32` into its little-endian byte order representation. + pub fn into_bits(self) -> Vec { + self.bits + } + + /// Converts a little-endian byte order representation of bits into a + /// `UInt32`. 
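+  /// A round-trip sketch (illustrative only; it uses only methods defined in this impl):
+  ///
+  /// ```ignore
+  /// let bits = UInt32::constant(0xdead_beef).into_bits(); // little-endian Vec<Boolean>
+  /// let x = UInt32::from_bits(&bits);                     // reassembles the same 32-bit word
+  /// ```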
+ pub fn from_bits(bits: &[Boolean]) -> Self { + assert_eq!(bits.len(), 32); + + let new_bits = bits.to_vec(); + + let mut value = Some(0u32); + for b in new_bits.iter().rev() { + if let Some(v) = value.as_mut() { + *v <<= 1; + } + + match *b { + Boolean::Constant(b) => { + if b { + if let Some(v) = value.as_mut() { + *v |= 1; + } + } + } + Boolean::Is(ref b) => match b.get_value() { + Some(true) => { + if let Some(v) = value.as_mut() { + *v |= 1; + } + } + Some(false) => {} + None => value = None, + }, + Boolean::Not(ref b) => match b.get_value() { + Some(false) => { + if let Some(v) = value.as_mut() { + *v |= 1; + } + } + Some(true) => {} + None => value = None, + }, + } + } + + UInt32 { + value, + bits: new_bits, + } + } + + pub fn rotr(&self, by: usize) -> Self { + let by = by % 32; + + let new_bits = self + .bits + .iter() + .skip(by) + .chain(self.bits.iter()) + .take(32) + .cloned() + .collect(); + + UInt32 { + bits: new_bits, + value: self.value.map(|v| v.rotate_right(by as u32)), + } + } + + pub fn shr(&self, by: usize) -> Self { + let by = by % 32; + + let fill = Boolean::constant(false); + + let new_bits = self + .bits + .iter() // The bits are least significant first + .skip(by) // Skip the bits that will be lost during the shift + .chain(Some(&fill).into_iter().cycle()) // Rest will be zeros + .take(32) // Only 32 bits needed! + .cloned() + .collect(); + + UInt32 { + bits: new_bits, + value: self.value.map(|v| v >> by as u32), + } + } + + fn triop( + mut cs: CS, + a: &Self, + b: &Self, + c: &Self, + tri_fn: F, + circuit_fn: U, + ) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + F: Fn(u32, u32, u32) -> u32, + U: Fn(&mut CS, usize, &Boolean, &Boolean, &Boolean) -> Result, + { + let new_value = match (a.value, b.value, c.value) { + (Some(a), Some(b), Some(c)) => Some(tri_fn(a, b, c)), + _ => None, + }; + + let bits = a + .bits + .iter() + .zip(b.bits.iter()) + .zip(c.bits.iter()) + .enumerate() + .map(|(i, ((a, b), c))| circuit_fn(&mut cs, i, a, b, c)) + .collect::>()?; + + Ok(UInt32 { + bits, + value: new_value, + }) + } + + /// Compute the `maj` value (a and b) xor (a and c) xor (b and c) + /// during SHA256. + pub fn sha256_maj( + cs: CS, + a: &Self, + b: &Self, + c: &Self, + ) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + Self::triop( + cs, + a, + b, + c, + |a, b, c| (a & b) ^ (a & c) ^ (b & c), + |cs, i, a, b, c| Boolean::sha256_maj(cs.namespace(|| format!("maj {}", i)), a, b, c), + ) + } + + /// Compute the `ch` value `(a and b) xor ((not a) and c)` + /// during SHA256. + pub fn sha256_ch(cs: CS, a: &Self, b: &Self, c: &Self) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + Self::triop( + cs, + a, + b, + c, + |a, b, c| (a & b) ^ ((!a) & c), + |cs, i, a, b, c| Boolean::sha256_ch(cs.namespace(|| format!("ch {}", i)), a, b, c), + ) + } + + /// XOR this `UInt32` with another `UInt32` + pub fn xor(&self, mut cs: CS, other: &Self) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let new_value = match (self.value, other.value) { + (Some(a), Some(b)) => Some(a ^ b), + _ => None, + }; + + let bits = self + .bits + .iter() + .zip(other.bits.iter()) + .enumerate() + .map(|(i, (a, b))| Boolean::xor(cs.namespace(|| format!("xor of bit {}", i)), a, b)) + .collect::>()?; + + Ok(UInt32 { + bits, + value: new_value, + }) + } + + /// Perform modular addition of several `UInt32` objects. 
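+  /// Addition is performed modulo 2^32; the caller first wraps its constraint system in
+  /// [`MultiEq`], as the SHA-256 gadget in this crate does. A sketch (the operands `a`,
+  /// `b`, and `c` are assumed to be in scope):
+  ///
+  /// ```ignore
+  /// let mut cs = MultiEq::new(cs);
+  /// let sum = UInt32::addmany(cs.namespace(|| "a + b + c"), &[a, b, c])?;
+  /// ```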
+ #[allow(clippy::unnecessary_unwrap)] + pub fn addmany(mut cs: M, operands: &[Self]) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + M: ConstraintSystem>, + { + // Make some arbitrary bounds for ourselves to avoid overflows + // in the scalar field + assert!(Scalar::NUM_BITS >= 64); + assert!(operands.len() >= 2); // Weird trivial cases that should never happen + assert!(operands.len() <= 10); + + // Compute the maximum value of the sum so we allocate enough bits for + // the result + let mut max_value = (operands.len() as u64) * (u64::from(u32::max_value())); + + // Keep track of the resulting value + let mut result_value = Some(0u64); + + // This is a linear combination that we will enforce to equal the + // output + let mut lc = LinearCombination::zero(); + + let mut all_constants = true; + + // Iterate over the operands + for op in operands { + // Accumulate the value + match op.value { + Some(val) => { + if let Some(v) = result_value.as_mut() { + *v += u64::from(val); + } + } + None => { + // If any of our operands have unknown value, we won't + // know the value of the result + result_value = None; + } + } + + // Iterate over each bit of the operand and add the operand to + // the linear combination + let mut coeff = Scalar::ONE; + for bit in &op.bits { + lc = lc + &bit.lc(CS::one(), coeff); + + all_constants &= bit.is_constant(); + + coeff = coeff.double(); + } + } + + // The value of the actual result is modulo 2^32 + let modular_value = result_value.map(|v| v as u32); + + if all_constants && modular_value.is_some() { + // We can just return a constant, rather than + // unpacking the result into allocated bits. + + return Ok(UInt32::constant(modular_value.unwrap())); + } + + // Storage area for the resulting bits + let mut result_bits = vec![]; + + // Linear combination representing the output, + // for comparison with the sum of the operands + let mut result_lc = LinearCombination::zero(); + + // Allocate each bit of the result + let mut coeff = Scalar::ONE; + let mut i = 0; + while max_value != 0 { + // Allocate the bit + let b = AllocatedBit::alloc( + cs.namespace(|| format!("result bit {}", i)), + result_value.map(|v| (v >> i) & 1 == 1), + )?; + + // Add this bit to the result combination + result_lc = result_lc + (coeff, b.get_variable()); + + result_bits.push(b.into()); + + max_value >>= 1; + i += 1; + coeff = coeff.double(); + } + + // Enforce equality between the sum and result + cs.get_root().enforce_equal(i, &lc, &result_lc); + + // Discard carry bits that we don't care about + result_bits.truncate(32); + + Ok(UInt32 { + bits: result_bits, + value: modular_value, + }) + } +} diff --git a/src/frontend/lc.rs b/src/frontend/lc.rs new file mode 100644 index 00000000..3b333295 --- /dev/null +++ b/src/frontend/lc.rs @@ -0,0 +1,436 @@ +use std::ops::{Add, Sub}; + +use ff::PrimeField; +use serde::{Deserialize, Serialize}; + +/// Represents a variable in our constraint system. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct Variable(pub Index); + +impl Variable { + /// This constructs a variable with an arbitrary index. + /// Circuit implementations are not recommended to use this. + pub fn new_unchecked(idx: Index) -> Variable { + Variable(idx) + } + + /// This returns the index underlying the variable. + /// Circuit implementations are not recommended to use this. + pub fn get_unchecked(&self) -> Index { + self.0 + } +} + +/// Represents the index of either an input variable or +/// auxiliary variable. 
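+/// Consumers usually match on the variant to pick the matching assignment slice, as the
+/// constraint-system evaluators below do. A sketch (the slice names are illustrative):
+///
+/// ```ignore
+/// let value = match var.get_unchecked() {
+///   Index::Input(i) => inputs[i],
+///   Index::Aux(i) => aux[i],
+///   Index::Precommitted(i) => precommitted[i],
+/// };
+/// ```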
+#[derive(Copy, Clone, PartialEq, Debug, Eq, Hash, Serialize, Deserialize)] +pub enum Index { + /// Public input variable + Input(usize), + /// Private auxiliary variable + Aux(usize), + /// Precommitted variable + Precommitted(usize), +} + +/// This represents a linear combination of some variables, with coefficients +/// in the scalar field of a pairing-friendly elliptic curve group. +#[derive(Clone, Debug, PartialEq)] +pub struct LinearCombination { + inputs: Indexer, + aux: Indexer, + precommitted: Indexer, +} + +#[derive(Clone, Debug, PartialEq)] +struct Indexer { + /// Stores a list of `T` indexed by the number in the first slot of the tuple. + values: Vec<(usize, T)>, + /// `(index, key)` of the last insertion operation. Used to optimize consecutive operations + last_inserted: Option<(usize, usize)>, +} + +impl Default for Indexer { + fn default() -> Self { + Indexer { + values: Vec::new(), + last_inserted: None, + } + } +} + +impl Indexer { + pub fn from_value(index: usize, value: T) -> Self { + Indexer { + values: vec![(index, value)], + last_inserted: Some((0, index)), + } + } + + pub fn iter(&self) -> impl Iterator + '_ { + #[allow(clippy::map_identity)] + self.values.iter().map(|(key, value)| (key, value)) + } + + pub fn iter_mut(&mut self) -> impl Iterator + '_ { + self.values.iter_mut().map(|(key, value)| (&*key, value)) + } + + pub fn insert_or_update(&mut self, key: usize, insert: F, update: G) + where + F: FnOnce() -> T, + G: FnOnce(&mut T), + { + if let Some((last_index, last_key)) = self.last_inserted { + // Optimization to avoid doing binary search on inserts & updates that are linear, meaning + // they are adding a consecutive values. + if last_key == key { + // update the same key again + update(&mut self.values[last_index].1); + return; + } else if last_key + 1 == key { + // optimization for follow on updates + let i = last_index + 1; + if i >= self.values.len() { + // insert at the end + self.values.push((key, insert())); + self.last_inserted = Some((i, key)); + } else if self.values[i].0 == key { + // update + update(&mut self.values[i].1); + } else { + // insert + self.values.insert(i, (key, insert())); + self.last_inserted = Some((i, key)); + } + return; + } + } + match self.values.binary_search_by_key(&key, |(k, _)| *k) { + Ok(i) => { + update(&mut self.values[i].1); + } + Err(i) => { + self.values.insert(i, (key, insert())); + self.last_inserted = Some((i, key)); + } + } + } + + pub fn len(&self) -> usize { + self.values.len() + } + + pub fn is_empty(&self) -> bool { + self.values.is_empty() + } +} + +impl Default for LinearCombination { + fn default() -> Self { + Self::zero() + } +} + +impl LinearCombination { + /// This returns a zero [`LinearCombination`]. + pub fn zero() -> LinearCombination { + LinearCombination { + inputs: Default::default(), + aux: Default::default(), + precommitted: Default::default(), + } + } + + /// Create a [`LinearCombination`] from a variable and coefficient. + pub fn from_coeff(var: Variable, coeff: Scalar) -> Self { + match var { + Variable(Index::Input(i)) => Self { + inputs: Indexer::from_value(i, coeff), + aux: Default::default(), + precommitted: Default::default(), + }, + Variable(Index::Aux(i)) => Self { + inputs: Default::default(), + aux: Indexer::from_value(i, coeff), + precommitted: Default::default(), + }, + Variable(Index::Precommitted(i)) => Self { + inputs: Default::default(), + aux: Default::default(), + precommitted: Indexer::from_value(i, coeff), + }, + } + } + + /// Create a [`LinearCombination`] from a variable. 
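+  /// The coefficient defaults to `Scalar::ONE`; the result can be combined further via the
+  /// `Add`/`Sub` impls below. A sketch (`a` and `b` are assumed to be existing variables):
+  ///
+  /// ```ignore
+  /// let lc = LinearCombination::<Scalar>::from_variable(a) - b; // represents a - b
+  /// ```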
+ pub fn from_variable(var: Variable) -> Self { + Self::from_coeff(var, Scalar::ONE) + } + + /// Iter for the [`LinearCombination`]. + pub fn iter(&self) -> impl Iterator + '_ { + self + .inputs + .iter() + .map(|(k, v)| (Variable(Index::Input(*k)), v)) + .chain(self.aux.iter().map(|(k, v)| (Variable(Index::Aux(*k)), v))) + } + + /// Iter inputs for the [`LinearCombination`] + #[inline] + pub fn iter_inputs(&self) -> impl Iterator + '_ { + self.inputs.iter() + } + + /// Iter aux for the [`LinearCombination`] + #[inline] + pub fn iter_aux(&self) -> impl Iterator + '_ { + self.aux.iter() + } + + /// Iter precommitted for the [`LinearCombination`] + #[inline] + pub fn iter_precommitted(&self) -> impl Iterator + '_ { + self.precommitted.iter() + } + + /// Iter mut for the [`LinearCombination`] + pub fn iter_mut(&mut self) -> impl Iterator + '_ { + self + .inputs + .iter_mut() + .map(|(k, v)| (Variable(Index::Input(*k)), v)) + .chain( + self + .aux + .iter_mut() + .map(|(k, v)| (Variable(Index::Aux(*k)), v)), + ) + } + + #[inline] + fn add_assign_unsimplified_input(&mut self, new_var: usize, coeff: Scalar) { + self + .inputs + .insert_or_update(new_var, || coeff, |val| *val += coeff); + } + + #[inline] + fn add_assign_unsimplified_aux(&mut self, new_var: usize, coeff: Scalar) { + self + .aux + .insert_or_update(new_var, || coeff, |val| *val += coeff); + } + + #[inline] + fn add_assign_unsimplified_precommitted(&mut self, new_var: usize, coeff: Scalar) { + self + .precommitted + .insert_or_update(new_var, || coeff, |val| *val += coeff); + } + + /// Add unsimplified + pub fn add_unsimplified(mut self, (coeff, var): (Scalar, Variable)) -> LinearCombination { + match var.0 { + Index::Input(new_var) => { + self.add_assign_unsimplified_input(new_var, coeff); + } + Index::Aux(new_var) => { + self.add_assign_unsimplified_aux(new_var, coeff); + } + Index::Precommitted(new_var) => { + self.add_assign_unsimplified_precommitted(new_var, coeff); + } + } + + self + } + + #[inline] + fn sub_assign_unsimplified_input(&mut self, new_var: usize, coeff: Scalar) { + self.add_assign_unsimplified_input(new_var, -coeff); + } + + #[inline] + fn sub_assign_unsimplified_aux(&mut self, new_var: usize, coeff: Scalar) { + self.add_assign_unsimplified_aux(new_var, -coeff); + } + + #[inline] + fn sub_assign_unsimplified_precommitted(&mut self, new_var: usize, coeff: Scalar) { + self.add_assign_unsimplified_precommitted(new_var, -coeff); + } + + /// Sub unsimplified + pub fn sub_unsimplified(mut self, (coeff, var): (Scalar, Variable)) -> LinearCombination { + match var.0 { + Index::Input(new_var) => { + self.sub_assign_unsimplified_input(new_var, coeff); + } + Index::Aux(new_var) => { + self.sub_assign_unsimplified_aux(new_var, coeff); + } + Index::Precommitted(new_var) => { + self.sub_assign_unsimplified_precommitted(new_var, -coeff); + } + } + + self + } + + /// len of the [`LinearCombination`] + pub fn len(&self) -> usize { + self.inputs.len() + self.aux.len() + self.precommitted.len() + } + + /// Check if the [`LinearCombination`] is empty + pub fn is_empty(&self) -> bool { + self.inputs.is_empty() && self.aux.is_empty() && self.precommitted.is_empty() + } + + /// Evaluate the [`LinearCombination`] with the given input and aux assignments. 
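+  /// Terms indexed by `Index::Precommitted` are looked up in the precommitted assignment
+  /// slice. A sketch (the assignment slices and variables are illustrative):
+  ///
+  /// ```ignore
+  /// let lc = LinearCombination::<Scalar>::from_variable(x) + (Scalar::from(2u64), y);
+  /// let value = lc.eval(&inputs, &aux, &precommitted);
+  /// ```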
+ pub fn eval( + &self, + input_assignment: &[Scalar], + aux_assignment: &[Scalar], + precommitted_assignment: &[Scalar], + ) -> Scalar { + let mut acc = Scalar::ZERO; + + let one = Scalar::ONE; + + for (index, coeff) in self.iter_inputs() { + let mut tmp = input_assignment[*index]; + if coeff != &one { + tmp *= coeff; + } + acc += tmp; + } + + for (index, coeff) in self.iter_aux() { + let mut tmp = aux_assignment[*index]; + if coeff != &one { + tmp *= coeff; + } + acc += tmp; + } + + for (index, coeff) in self.iter_precommitted() { + let mut tmp = precommitted_assignment[*index]; + if coeff != &one { + tmp *= coeff; + } + acc += tmp; + } + + acc + } +} + +impl Add<(Scalar, Variable)> for LinearCombination { + type Output = LinearCombination; + + fn add(self, (coeff, var): (Scalar, Variable)) -> LinearCombination { + self.add_unsimplified((coeff, var)) + } +} + +impl Sub<(Scalar, Variable)> for LinearCombination { + type Output = LinearCombination; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, (coeff, var): (Scalar, Variable)) -> LinearCombination { + self.sub_unsimplified((coeff, var)) + } +} + +impl Add for LinearCombination { + type Output = LinearCombination; + + fn add(self, other: Variable) -> LinearCombination { + self + (Scalar::ONE, other) + } +} + +impl Sub for LinearCombination { + type Output = LinearCombination; + + fn sub(self, other: Variable) -> LinearCombination { + self - (Scalar::ONE, other) + } +} + +impl<'a, Scalar: PrimeField> Add<&'a LinearCombination> for LinearCombination { + type Output = LinearCombination; + + fn add(mut self, other: &'a LinearCombination) -> LinearCombination { + for (var, val) in other.inputs.iter() { + self.add_assign_unsimplified_input(*var, *val); + } + + for (var, val) in other.aux.iter() { + self.add_assign_unsimplified_aux(*var, *val); + } + + self + } +} + +impl<'a, Scalar: PrimeField> Sub<&'a LinearCombination> for LinearCombination { + type Output = LinearCombination; + + fn sub(mut self, other: &'a LinearCombination) -> LinearCombination { + for (var, val) in other.inputs.iter() { + self.sub_assign_unsimplified_input(*var, *val); + } + + for (var, val) in other.aux.iter() { + self.sub_assign_unsimplified_aux(*var, *val); + } + + self + } +} + +impl<'a, Scalar: PrimeField> Add<(Scalar, &'a LinearCombination)> + for LinearCombination +{ + type Output = LinearCombination; + + fn add( + mut self, + (coeff, other): (Scalar, &'a LinearCombination), + ) -> LinearCombination { + for (var, val) in other.inputs.iter() { + self.add_assign_unsimplified_input(*var, *val * coeff); + } + + for (var, val) in other.aux.iter() { + self.add_assign_unsimplified_aux(*var, *val * coeff); + } + + self + } +} + +impl<'a, Scalar: PrimeField> Sub<(Scalar, &'a LinearCombination)> + for LinearCombination +{ + type Output = LinearCombination; + + fn sub( + mut self, + (coeff, other): (Scalar, &'a LinearCombination), + ) -> LinearCombination { + for (var, val) in other.inputs.iter() { + self.sub_assign_unsimplified_input(*var, *val * coeff); + } + + for (var, val) in other.aux.iter() { + self.sub_assign_unsimplified_aux(*var, *val * coeff); + } + + self + } +} diff --git a/src/bellpepper/mod.rs b/src/frontend/mod.rs similarity index 82% rename from src/bellpepper/mod.rs rename to src/frontend/mod.rs index 7b0aaf07..a165b160 100644 --- a/src/bellpepper/mod.rs +++ b/src/frontend/mod.rs @@ -1,7 +1,13 @@ -//! Support for generating R1CS from [Bellpepper]. -//! -//! [Bellpepper]: https://github.com/lurk-lab/bellpepper +//! 
Support for generating R1CS +mod constraint_system; +pub use constraint_system::{Circuit, ConstraintSystem, Namespace, SynthesisError}; +pub mod gadgets; +pub use gadgets::{boolean, num}; +mod lc; +pub use lc::{Index, LinearCombination, Variable}; +pub mod util_cs; +pub use util_cs::test_cs; pub mod r1cs; pub mod shape_cs; pub mod solver; @@ -9,8 +15,9 @@ pub mod test_shape_cs; #[cfg(test)] mod tests { + use crate::frontend::{num::AllocatedNum, ConstraintSystem}; use crate::{ - bellpepper::{ + frontend::{ r1cs::{NovaShape, NovaWitness}, shape_cs::ShapeCS, solver::SatisfyingAssignment, @@ -18,7 +25,6 @@ mod tests { provider::{Bn256EngineKZG, PallasEngine, Secp256k1Engine}, traits::{snark::default_ck_hint, Engine}, }; - use bellpepper_core::{num::AllocatedNum, ConstraintSystem}; use ff::PrimeField; fn synthesize_alloc_bit>(cs: &mut CS) { diff --git a/src/bellpepper/r1cs.rs b/src/frontend/r1cs.rs similarity index 98% rename from src/bellpepper/r1cs.rs rename to src/frontend/r1cs.rs index 3a7fd535..4900fdc2 100644 --- a/src/bellpepper/r1cs.rs +++ b/src/frontend/r1cs.rs @@ -3,13 +3,13 @@ #![allow(non_snake_case)] use super::{shape_cs::ShapeCS, solver::SatisfyingAssignment, test_shape_cs::TestShapeCS}; +use crate::frontend::{Index, LinearCombination}; use crate::{ errors::NovaError, r1cs::{CommitmentKeyHint, R1CSInstance, R1CSShape, R1CSWitness, SparseMatrix, R1CS}, traits::Engine, CommitmentKey, }; -use bellpepper_core::{Index, LinearCombination}; use ff::PrimeField; /// `NovaWitness` provide a method for acquiring an `R1CSInstance` and `R1CSWitness` from implementers. @@ -125,6 +125,7 @@ fn add_constraint( M.data.push(*coeff); M.indices.push(idx); } + _ => todo!(), } } }; diff --git a/src/bellpepper/shape_cs.rs b/src/frontend/shape_cs.rs similarity index 84% rename from src/bellpepper/shape_cs.rs rename to src/frontend/shape_cs.rs index 57543401..c97b7016 100644 --- a/src/bellpepper/shape_cs.rs +++ b/src/frontend/shape_cs.rs @@ -1,7 +1,7 @@ //! Support for generating R1CS shape using bellpepper. +use crate::frontend::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; use crate::traits::Engine; -use bellpepper_core::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; use ff::PrimeField; /// `ShapeCS` is a `ConstraintSystem` for creating `R1CSShape`s for a circuit. @@ -17,6 +17,7 @@ where )>, inputs: usize, aux: usize, + precommitted: usize, } impl ShapeCS { @@ -47,6 +48,7 @@ impl Default for ShapeCS { constraints: vec![], inputs: 1, aux: 0, + precommitted: 0, } } } @@ -65,6 +67,23 @@ impl ConstraintSystem for ShapeCS { Ok(Variable::new_unchecked(Index::Aux(self.aux - 1))) } + fn alloc_precommitted( + &mut self, + _annotation: A, + _f: F, + ) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.precommitted += 1; + + Ok(Variable::new_unchecked(Index::Precommitted( + self.precommitted - 1, + ))) + } + fn alloc_input(&mut self, _annotation: A, _f: F) -> Result where F: FnOnce() -> Result, diff --git a/src/bellpepper/solver.rs b/src/frontend/solver.rs similarity index 82% rename from src/bellpepper/solver.rs rename to src/frontend/solver.rs index 4fc211a3..5f9abaf9 100644 --- a/src/bellpepper/solver.rs +++ b/src/frontend/solver.rs @@ -2,7 +2,7 @@ use crate::traits::Engine; -use bellpepper::util_cs::witness_cs::WitnessCS; +use crate::frontend::util_cs::witness_cs::WitnessCS; /// A `ConstraintSystem` which calculates witness values for a concrete instance of an R1CS circuit. 
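+/// A sketch of typical use (the engine type `E` is illustrative):
+///
+/// ```ignore
+/// let mut cs = SatisfyingAssignment::<E>::new();
+/// let x = AllocatedNum::alloc(cs.namespace(|| "x"), || Ok(E::Scalar::ONE))?;
+/// ```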
pub type SatisfyingAssignment = WitnessCS<::Scalar>; diff --git a/src/bellpepper/test_shape_cs.rs b/src/frontend/test_shape_cs.rs similarity index 87% rename from src/bellpepper/test_shape_cs.rs rename to src/frontend/test_shape_cs.rs index f4a6ae81..6a1224b9 100644 --- a/src/bellpepper/test_shape_cs.rs +++ b/src/frontend/test_shape_cs.rs @@ -6,8 +6,8 @@ use std::{ collections::{BTreeMap, HashMap}, }; +use crate::frontend::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; use crate::traits::Engine; -use bellpepper_core::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; use core::fmt::Write; use ff::{Field, PrimeField}; @@ -39,11 +39,15 @@ impl PartialOrd for OrderedVariable { impl Ord for OrderedVariable { fn cmp(&self, other: &Self) -> Ordering { match (self.0.get_unchecked(), other.0.get_unchecked()) { - (Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => { - a.cmp(b) - } + (Index::Input(ref a), Index::Input(ref b)) + | (Index::Aux(ref a), Index::Aux(ref b)) + | (Index::Precommitted(ref a), Index::Precommitted(ref b)) => a.cmp(b), (Index::Input(_), Index::Aux(_)) => Ordering::Less, (Index::Aux(_), Index::Input(_)) => Ordering::Greater, + (Index::Precommitted(_), Index::Aux(_)) => Ordering::Less, + (Index::Aux(_), Index::Precommitted(_)) => Ordering::Greater, + (Index::Input(_), Index::Precommitted(_)) => Ordering::Less, + (Index::Precommitted(_), Index::Input(_)) => Ordering::Greater, } } } @@ -61,6 +65,7 @@ pub struct TestShapeCS { )>, inputs: Vec, aux: Vec, + precommitted: Vec, } fn proc_lc( @@ -177,6 +182,9 @@ where Index::Aux(i) => { write!(s, "`A{}`", &self.aux[i]).unwrap(); } + Index::Precommitted(i) => { + write!(s, "`P{}`", &self.precommitted[i]).unwrap(); + } } } if is_first { @@ -224,6 +232,7 @@ impl Default for TestShapeCS { constraints: vec![], inputs: vec![String::from("ONE")], aux: vec![], + precommitted: vec![], } } } @@ -246,6 +255,24 @@ where Ok(Variable::new_unchecked(Index::Aux(self.aux.len() - 1))) } + fn alloc_precommitted( + &mut self, + annotation: A, + _f: F, + ) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + let path = compute_path(&self.current_namespace, &annotation().into()); + self.precommitted.push(path); + + Ok(Variable::new_unchecked(Index::Precommitted( + self.precommitted.len() - 1, + ))) + } + fn alloc_input(&mut self, annotation: A, _f: F) -> Result where F: FnOnce() -> Result, diff --git a/src/frontend/util_cs/mod.rs b/src/frontend/util_cs/mod.rs new file mode 100644 index 00000000..ed4884c7 --- /dev/null +++ b/src/frontend/util_cs/mod.rs @@ -0,0 +1,102 @@ +//! The `util_cs` module provides a set of utilities for working with constraint system +use crate::frontend::LinearCombination; +use ff::PrimeField; + +pub mod test_cs; +pub mod witness_cs; + +/// Alias for a constraint in a constraint system +pub type Constraint = ( + LinearCombination, + LinearCombination, + LinearCombination, + String, +); + +/// The `Comparable` trait allows comparison of two constraint systems which +/// implement the trait. The only non-trivial method, `delta`, has a default +/// implementation which supplies the desired behavior. +/// +/// Use `delta` to compare constraint systems. If they are not identical, the +/// returned `Delta` enum contains fine-grained information about how they +/// differ. This can be especially useful when debugging the situation in which +/// a constraint system is satisfied, but the corresponding Groth16 proof does +/// not verify. 
+/// +/// If `ignore_counts` is true, count mismatches will be ignored, and any constraint +/// mismatch will be returned. This is useful in pinpointing the source of a mismatch. +/// +/// Example usage: +/// +/// ```ignore +/// let delta = cs.delta(&cs_blank, false); +/// assert!(delta == Delta::Equal); +/// ``` +pub trait Comparable { + /// Returns the number of inputs in the constraint system + fn num_inputs(&self) -> usize; + /// Returns the number of constraints in the constraint system + fn num_constraints(&self) -> usize; + /// Returns the inputs in the constraint system + fn inputs(&self) -> Vec; + /// Returns the auxiliary variables in the constraint system + fn aux(&self) -> Vec; + /// Returns the constraints in the constraint system + fn constraints(&self) -> &[Constraint]; + + /// Compare two constraint systems and return a `Delta` enum + fn delta>(&self, other: &C, ignore_counts: bool) -> Delta + where + Scalar: PrimeField, + { + let input_count_matches = self.num_inputs() == other.num_inputs(); + let constraint_count_matches = self.num_constraints() == other.num_constraints(); + + let inputs_match = self.inputs() == other.inputs(); + let constraints_match = self.constraints() == other.constraints(); + + let equal = + input_count_matches && constraint_count_matches && inputs_match && constraints_match; + + if !ignore_counts && !input_count_matches { + Delta::InputCountMismatch(self.num_inputs(), other.num_inputs()) + } else if !ignore_counts && !constraint_count_matches { + Delta::ConstraintCountMismatch(self.num_constraints(), other.num_constraints()) + } else if !constraints_match { + let c = self.constraints(); + let o = other.constraints(); + + let mismatch = c + .iter() + .zip(o) + .enumerate() + .filter(|(_, (a, b))| a != b) + .map(|(i, (a, b))| (i, a, b)) + .next(); + + let m = mismatch.unwrap(); + + Delta::ConstraintMismatch(m.0, m.1.clone(), m.2.clone()) + } else if equal { + Delta::Equal + } else { + Delta::Different + } + } +} + +/// The `Delta` enum is used to compare two constraint systems which implement +#[allow(clippy::large_enum_variant)] +#[derive(Clone, Debug, PartialEq)] +pub enum Delta { + /// The two constraint systems are equal + Equal, + /// The two constraint systems are different + Different, + /// The two constraint systems have a mismatch in the number of inputs + InputCountMismatch(usize, usize), + /// The two constraint systems have a mismatch in the number of constraints + ConstraintCountMismatch(usize, usize), + /// The two constraint systems have a mismatch in a constraint + ConstraintMismatch(usize, Constraint, Constraint), +} diff --git a/src/frontend/util_cs/test_cs.rs b/src/frontend/util_cs/test_cs.rs new file mode 100644 index 00000000..467367e4 --- /dev/null +++ b/src/frontend/util_cs/test_cs.rs @@ -0,0 +1,445 @@ +//! Test constraint system for use in tests. + +#![allow(dead_code)] +use std::cmp::Ordering; +use std::collections::BTreeMap; +use std::collections::HashMap; + +use super::Comparable; +use crate::frontend::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; + +use ff::PrimeField; + +#[derive(Debug)] +enum NamedObject { + Constraint(usize), + Var(Variable), + Namespace, +} + +/// Constraint system for testing purposes. 
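+/// A sketch of the usual test flow (`Fr` stands in for any `PrimeField` scalar type):
+///
+/// ```ignore
+/// let mut cs = TestConstraintSystem::<Fr>::new();
+/// let a = AllocatedNum::alloc(cs.namespace(|| "a"), || Ok(Fr::from(3u64)))?;
+/// let b = a.square(cs.namespace(|| "b"))?;
+/// assert!(cs.is_satisfied());
+/// assert_eq!(cs.num_constraints(), 1);
+/// ```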
+#[derive(Debug)] +pub struct TestConstraintSystem { + named_objects: HashMap, + current_namespace: Vec, + #[allow(clippy::type_complexity)] + constraints: Vec<( + LinearCombination, + LinearCombination, + LinearCombination, + String, + )>, + inputs: Vec<(Scalar, String)>, + aux: Vec<(Scalar, String)>, + precommitted: Vec<(Scalar, String)>, +} + +#[derive(Clone, Copy)] +struct OrderedVariable(Variable); + +impl Eq for OrderedVariable {} +impl PartialEq for OrderedVariable { + fn eq(&self, other: &OrderedVariable) -> bool { + match (self.0.get_unchecked(), other.0.get_unchecked()) { + (Index::Input(ref a), Index::Input(ref b)) => a == b, + (Index::Aux(ref a), Index::Aux(ref b)) => a == b, + _ => false, + } + } +} +impl PartialOrd for OrderedVariable { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for OrderedVariable { + fn cmp(&self, other: &Self) -> Ordering { + match (self.0.get_unchecked(), other.0.get_unchecked()) { + (Index::Input(ref a), Index::Input(ref b)) => a.cmp(b), + (Index::Aux(ref a), Index::Aux(ref b)) => a.cmp(b), + (Index::Precommitted(ref a), Index::Precommitted(ref b)) => a.cmp(b), + (Index::Input(_), Index::Aux(_)) => Ordering::Less, + (Index::Aux(_), Index::Input(_)) => Ordering::Greater, + (Index::Precommitted(_), Index::Aux(_)) => Ordering::Less, + (Index::Aux(_), Index::Precommitted(_)) => Ordering::Greater, + (Index::Input(_), Index::Precommitted(_)) => Ordering::Less, + (Index::Precommitted(_), Index::Input(_)) => Ordering::Greater, + } + } +} + +fn proc_lc( + terms: &LinearCombination, +) -> BTreeMap { + let mut map = BTreeMap::new(); + for (var, &coeff) in terms.iter() { + map + .entry(OrderedVariable(var)) + .or_insert_with(|| Scalar::ZERO) + .add_assign(&coeff); + } + + // Remove terms that have a zero coefficient to normalize + let mut to_remove = vec![]; + for (var, coeff) in map.iter() { + if coeff.is_zero().into() { + to_remove.push(*var) + } + } + + for var in to_remove { + map.remove(&var); + } + + map +} + +fn _eval_lc2( + terms: &LinearCombination, + inputs: &[Scalar], + aux: &[Scalar], + precommitted: &[Scalar], +) -> Scalar { + let mut acc = Scalar::ZERO; + + for (var, coeff) in terms.iter() { + let mut tmp = match var.get_unchecked() { + Index::Input(index) => inputs[index], + Index::Aux(index) => aux[index], + Index::Precommitted(index) => precommitted[index], + }; + + tmp.mul_assign(coeff); + acc.add_assign(&tmp); + } + + acc +} + +fn eval_lc( + terms: &LinearCombination, + inputs: &[(Scalar, String)], + aux: &[(Scalar, String)], + precommitted: &[(Scalar, String)], +) -> Scalar { + let mut acc = Scalar::ZERO; + + for (var, coeff) in terms.iter() { + let mut tmp = match var.get_unchecked() { + Index::Input(index) => inputs[index].0, + Index::Aux(index) => aux[index].0, + Index::Precommitted(index) => precommitted[index].0, + }; + + tmp.mul_assign(coeff); + acc.add_assign(&tmp); + } + + acc +} + +impl Default for TestConstraintSystem { + fn default() -> Self { + let mut map = HashMap::new(); + map.insert( + "ONE".into(), + NamedObject::Var(TestConstraintSystem::::one()), + ); + + TestConstraintSystem { + named_objects: map, + current_namespace: vec![], + constraints: vec![], + inputs: vec![(Scalar::ONE, "ONE".into())], + aux: vec![], + precommitted: vec![], + } + } +} + +impl TestConstraintSystem { + /// Create a new test constraint system. 
+ pub fn new() -> Self { + Default::default() + } + + /// Get scalar inputs + pub fn scalar_inputs(&self) -> Vec { + self + .inputs + .iter() + .map(|(scalar, _string)| *scalar) + .collect() + } + + /// Get scalar aux + pub fn scalar_aux(&self) -> Vec { + self.aux.iter().map(|(scalar, _string)| *scalar).collect() + } + + /// Pretty print + pub fn pretty_print_list(&self) -> Vec { + let mut result = Vec::new(); + + for input in &self.inputs { + result.push(format!("INPUT {}", input.1)); + } + for aux in &self.aux { + result.push(format!("AUX {}", aux.1)); + } + + for (_a, _b, _c, name) in &self.constraints { + result.push(name.to_string()); + } + + result + } + + /// Pretty print + pub fn pretty_print(&self) -> String { + let res = self.pretty_print_list(); + + res.join("\n") + } + + /// Get path which is unsatisfied + pub fn which_is_unsatisfied(&self) -> Option<&str> { + for (a, b, c, path) in &self.constraints { + let mut a = eval_lc::(a, &self.inputs, &self.aux, &self.precommitted); + let b = eval_lc::(b, &self.inputs, &self.aux, &self.precommitted); + let c = eval_lc::(c, &self.inputs, &self.aux, &self.precommitted); + + a.mul_assign(&b); + + if a != c { + return Some(path); + } + } + + None + } + + /// Check if the constraint system is satisfied. + pub fn is_satisfied(&self) -> bool { + match self.which_is_unsatisfied() { + Some(b) => { + println!("fail: {:?}", b); + false + } + None => true, + } + // self.which_is_unsatisfied().is_none() + } + + /// Return the number of constraints in the constraint system. + pub fn num_constraints(&self) -> usize { + self.constraints.len() + } + + /// Create a new variable in the constraint system. + pub fn set(&mut self, path: &str, to: Scalar) { + match self.named_objects.get(path) { + Some(NamedObject::Var(v)) => match v.get_unchecked() { + Index::Input(index) => self.inputs[index].0 = to, + Index::Aux(index) => self.aux[index].0 = to, + Index::Precommitted(index) => self.precommitted[index].0 = to, + }, + Some(e) => panic!( + "tried to set path `{}` to value, but `{:?}` already exists there.", + path, e + ), + _ => panic!("no variable exists at path: {}", path), + } + } + + /// Verify expected vec == self.inputs + pub fn verify(&self, expected: &[Scalar]) -> bool { + assert_eq!(expected.len() + 1, self.inputs.len()); + for (a, b) in self.inputs.iter().skip(1).zip(expected.iter()) { + if &a.0 != b { + return false; + } + } + + true + } + + /// Return number of inputs in the constraint system. + pub fn num_inputs(&self) -> usize { + self.inputs.len() + } + + /// Get an input variable. + pub fn get_input(&mut self, index: usize, path: &str) -> Scalar { + let (assignment, name) = self.inputs[index].clone(); + + assert_eq!(path, name); + + assignment + } + + /// Get inputs + pub fn get_inputs(&self) -> &[(Scalar, String)] { + &self.inputs[..] 
+ } + + /// Get Scalar from path + pub fn get(&mut self, path: &str) -> Scalar { + match self.named_objects.get(path) { + Some(NamedObject::Var(v)) => match v.get_unchecked() { + Index::Input(index) => self.inputs[index].0, + Index::Aux(index) => self.aux[index].0, + Index::Precommitted(index) => self.precommitted[index].0, + }, + Some(e) => panic!( + "tried to get value of path `{}`, but `{:?}` exists there (not a variable)", + path, e + ), + _ => panic!("no variable exists at path: {}", path), + } + } + + fn set_named_obj(&mut self, path: String, to: NamedObject) { + assert!( + !self.named_objects.contains_key(&path), + "tried to create object at existing path: {}", + path + ); + + self.named_objects.insert(path, to); + } +} + +impl Comparable for TestConstraintSystem { + fn num_inputs(&self) -> usize { + self.num_inputs() + } + fn num_constraints(&self) -> usize { + self.num_constraints() + } + + fn aux(&self) -> Vec { + self + .aux + .iter() + .map(|(_scalar, string)| string.to_string()) + .collect() + } + + fn inputs(&self) -> Vec { + self + .inputs + .iter() + .map(|(_scalar, string)| string.to_string()) + .collect() + } + + fn constraints(&self) -> &[crate::frontend::util_cs::Constraint] { + &self.constraints + } +} + +fn compute_path(ns: &[String], this: &str) -> String { + assert!( + !this.chars().any(|a| a == '/'), + "'/' is not allowed in names" + ); + + if ns.is_empty() { + return this.to_string(); + } + + let name = ns.join("/"); + format!("{}/{}", name, this) +} + +impl ConstraintSystem for TestConstraintSystem { + type Root = Self; + + fn alloc(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + let index = self.aux.len(); + let path = compute_path(&self.current_namespace, &annotation().into()); + self.aux.push((f()?, path.clone())); + let var = Variable::new_unchecked(Index::Aux(index)); + self.set_named_obj(path, NamedObject::Var(var)); + + Ok(var) + } + + fn alloc_precommitted( + &mut self, + annotation: A, + f: F, + ) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + let index = self.precommitted.len(); + let path = compute_path(&self.current_namespace, &annotation().into()); + self.precommitted.push((f()?, path.clone())); + let var = Variable::new_unchecked(Index::Precommitted(index)); + self.set_named_obj(path, NamedObject::Var(var)); + + Ok(var) + } + + fn alloc_input(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + let index = self.inputs.len(); + let path = compute_path(&self.current_namespace, &annotation().into()); + self.inputs.push((f()?, path.clone())); + let var = Variable::new_unchecked(Index::Input(index)); + self.set_named_obj(path, NamedObject::Var(var)); + + Ok(var) + } + + fn enforce(&mut self, annotation: A, a: LA, b: LB, c: LC) + where + A: FnOnce() -> AR, + AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination, + { + let path = compute_path(&self.current_namespace, &annotation().into()); + let index = self.constraints.len(); + self.set_named_obj(path.clone(), NamedObject::Constraint(index)); + + let a = a(LinearCombination::zero()); + let b = b(LinearCombination::zero()); + let c = c(LinearCombination::zero()); + + self.constraints.push((a, b, c, path)); + } + + fn push_namespace(&mut self, name_fn: N) + where + NR: Into, + N: FnOnce() -> NR, + { + let name = name_fn().into(); + 
let path = compute_path(&self.current_namespace, &name); + self.set_named_obj(path, NamedObject::Namespace); + self.current_namespace.push(name); + } + + fn pop_namespace(&mut self) { + assert!(self.current_namespace.pop().is_some()); + } + + fn get_root(&mut self) -> &mut Self::Root { + self + } +} diff --git a/src/frontend/util_cs/witness_cs.rs b/src/frontend/util_cs/witness_cs.rs new file mode 100644 index 00000000..895196a9 --- /dev/null +++ b/src/frontend/util_cs/witness_cs.rs @@ -0,0 +1,236 @@ +//! Support for efficiently generating R1CS witness using bellperson. + +use ff::PrimeField; + +use crate::frontend::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; + +/// A [`ConstraintSystem`] trait +pub trait SizedWitness { + /// Returns the number of constraints in the constraint system + fn num_constraints(&self) -> usize; + /// Returns the number of inputs in the constraint system + fn num_inputs(&self) -> usize; + /// Returns the number of auxiliary variables in the constraint system + fn num_aux(&self) -> usize; + + /// Generate a witness for the constraint system + fn generate_witness_into(&mut self, aux: &mut [Scalar], inputs: &mut [Scalar]) -> Scalar; + /// Generate a witness for the constraint system + fn generate_witness(&mut self) -> (Vec, Vec, Scalar) { + let aux_count = self.num_aux(); + let inputs_count = self.num_inputs(); + + let mut aux = Vec::with_capacity(aux_count); + let mut inputs = Vec::with_capacity(inputs_count); + + aux.resize(aux_count, Scalar::ZERO); + inputs.resize(inputs_count, Scalar::ZERO); + + let result = self.generate_witness_into(&mut aux, &mut inputs); + + (aux, inputs, result) + } + + /// Generate a witness for the constraint system + fn generate_witness_into_cs>(&mut self, cs: &mut CS) -> Scalar { + assert!(cs.is_witness_generator()); + + let aux_count = self.num_aux(); + let inputs_count = self.num_inputs(); + + let (aux, inputs) = cs.allocate_empty(aux_count, inputs_count); + + assert_eq!(aux.len(), aux_count); + assert_eq!(inputs.len(), inputs_count); + + self.generate_witness_into(aux, inputs) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +/// A `ConstraintSystem` which calculates witness values for a concrete instance of an R1CS circuit. +pub struct WitnessCS +where + Scalar: PrimeField, +{ + // Assignments of variables + pub(crate) input_assignment: Vec, + pub(crate) aux_assignment: Vec, + pub(crate) precommitted_assignment: Vec, +} + +impl WitnessCS +where + Scalar: PrimeField, +{ + /// Get input assignment + pub fn input_assignment(&self) -> &[Scalar] { + &self.input_assignment + } + + /// Get aux assignment + pub fn aux_assignment(&self) -> &[Scalar] { + &self.aux_assignment + } + + /// Create a new [`WitnessCS`] with the specified capacity. + pub fn with_capacity(input_size: usize, aux_size: usize, precommitted_size: usize) -> Self { + let mut input_assignment = Vec::with_capacity(input_size); + input_assignment.push(Scalar::ONE); + let aux_assignment = Vec::with_capacity(aux_size); + let precommitted_assignment = Vec::with_capacity(precommitted_size); + Self { + input_assignment, + aux_assignment, + precommitted_assignment, + } + } + + /// Create a new [`WitnessCS`] from the specified assignments. + pub fn from_assignments( + input_assignment: Vec, + aux_assignment: Vec, + precommitted_assignment: Vec, + ) -> Self { + Self { + input_assignment, + aux_assignment, + precommitted_assignment, + } + } + + /// Convert the [`WitnessCS`] into a tuple of input and aux assignments. 
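+  /// Note that the precommitted assignment is not part of the returned tuple.
+  /// A sketch (the assignment vectors are illustrative):
+  ///
+  /// ```ignore
+  /// let cs = WitnessCS::from_assignments(inputs, aux, precommitted);
+  /// let (inputs, aux) = cs.to_assignments();
+  /// ```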
+ pub fn to_assignments(self) -> (Vec, Vec) { + (self.input_assignment, self.aux_assignment) + } +} + +impl ConstraintSystem for WitnessCS +where + Scalar: PrimeField, +{ + type Root = Self; + + fn new() -> Self { + let input_assignment = vec![Scalar::ONE]; + + Self { + input_assignment, + aux_assignment: vec![], + precommitted_assignment: vec![], + } + } + + fn alloc(&mut self, _: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.aux_assignment.push(f()?); + + Ok(Variable(Index::Aux(self.aux_assignment.len() - 1))) + } + + fn alloc_precommitted(&mut self, _: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.precommitted_assignment.push(f()?); + + Ok(Variable(Index::Precommitted( + self.precommitted_assignment.len() - 1, + ))) + } + + fn alloc_input(&mut self, _: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.input_assignment.push(f()?); + + Ok(Variable(Index::Input(self.input_assignment.len() - 1))) + } + + fn enforce(&mut self, _: A, _a: LA, _b: LB, _c: LC) + where + A: FnOnce() -> AR, + AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination, + { + // Do nothing: we don't care about linear-combination evaluations in this context. + } + + fn push_namespace(&mut self, _: N) + where + NR: Into, + N: FnOnce() -> NR, + { + // Do nothing; we don't care about namespaces in this context. + } + + fn pop_namespace(&mut self) { + // Do nothing; we don't care about namespaces in this context. + } + + fn get_root(&mut self) -> &mut Self::Root { + self + } + + //////////////////////////////////////////////////////////////////////////////// + // Extensible + fn is_extensible() -> bool { + true + } + + fn extend(&mut self, other: &Self) { + self.input_assignment + // Skip first input, which must have been a temporarily allocated one variable. + .extend(&other.input_assignment[1..]); + self.aux_assignment.extend(&other.aux_assignment); + } + + //////////////////////////////////////////////////////////////////////////////// + // Witness generator + fn is_witness_generator(&self) -> bool { + true + } + + fn extend_inputs(&mut self, new_inputs: &[Scalar]) { + self.input_assignment.extend(new_inputs); + } + + fn extend_aux(&mut self, new_aux: &[Scalar]) { + self.aux_assignment.extend(new_aux); + } + + fn allocate_empty(&mut self, aux_n: usize, inputs_n: usize) -> (&mut [Scalar], &mut [Scalar]) { + let allocated_aux = { + let i = self.aux_assignment.len(); + self.aux_assignment.resize(aux_n + i, Scalar::ZERO); + &mut self.aux_assignment[i..] + }; + + let allocated_inputs = { + let i = self.input_assignment.len(); + self.input_assignment.resize(inputs_n + i, Scalar::ZERO); + &mut self.input_assignment[i..] + }; + + (allocated_aux, allocated_inputs) + } + + fn inputs_slice(&self) -> &[Scalar] { + &self.input_assignment + } + + fn aux_slice(&self) -> &[Scalar] { + &self.aux_assignment + } +} diff --git a/src/gadgets/ecc.rs b/src/gadgets/ecc.rs index ebb8f2de..23cefe5c 100644 --- a/src/gadgets/ecc.rs +++ b/src/gadgets/ecc.rs @@ -1,5 +1,11 @@ //! 
This module implements various elliptic curve gadgets #![allow(non_snake_case)] +use crate::frontend::gadgets::Assignment; +use crate::frontend::{ + boolean::{AllocatedBit, Boolean}, + num::AllocatedNum, + ConstraintSystem, SynthesisError, +}; use crate::{ gadgets::utils::{ alloc_num_equals, alloc_one, alloc_zero, conditionally_select, conditionally_select2, @@ -8,12 +14,6 @@ use crate::{ }, traits::{Engine, Group}, }; -use bellpepper::gadgets::Assignment; -use bellpepper_core::{ - boolean::{AllocatedBit, Boolean}, - num::AllocatedNum, - ConstraintSystem, SynthesisError, -}; use ff::{Field, PrimeField}; /// `AllocatedPoint` provides an elliptic curve abstraction inside a circuit. @@ -783,7 +783,7 @@ where mod tests { use super::*; use crate::{ - bellpepper::{ + frontend::{ r1cs::{NovaShape, NovaWitness}, {solver::SatisfyingAssignment, test_shape_cs::TestShapeCS}, }, diff --git a/src/gadgets/nonnative/bignat.rs b/src/gadgets/nonnative/bignat.rs index 73363feb..a12a9a11 100644 --- a/src/gadgets/nonnative/bignat.rs +++ b/src/gadgets/nonnative/bignat.rs @@ -4,7 +4,7 @@ use super::{ }, OptionExt, }; -use bellpepper_core::{ConstraintSystem, LinearCombination, SynthesisError}; +use crate::frontend::{ConstraintSystem, LinearCombination, SynthesisError}; use ff::PrimeField; use num_bigint::BigInt; use num_traits::cast::ToPrimitive; @@ -782,7 +782,7 @@ impl Polynomial { #[cfg(test)] mod tests { use super::*; - use bellpepper_core::{test_cs::TestConstraintSystem, Circuit}; + use crate::frontend::{test_cs::TestConstraintSystem, Circuit}; use pasta_curves::pallas::Scalar; use proptest::prelude::*; diff --git a/src/gadgets/nonnative/mod.rs b/src/gadgets/nonnative/mod.rs index 4d611cbb..80c8e12a 100644 --- a/src/gadgets/nonnative/mod.rs +++ b/src/gadgets/nonnative/mod.rs @@ -1,7 +1,7 @@ //! This module implements various gadgets necessary for doing non-native arithmetic //! Code in this module is adapted from [bellman-bignat](https://github.com/alex-ozdemir/bellman-bignat), which is licenced under MIT -use bellpepper_core::SynthesisError; +use crate::frontend::SynthesisError; use ff::PrimeField; trait OptionExt { diff --git a/src/gadgets/nonnative/util.rs b/src/gadgets/nonnative/util.rs index 6e2ebf5e..0f601714 100644 --- a/src/gadgets/nonnative/util.rs +++ b/src/gadgets/nonnative/util.rs @@ -1,5 +1,5 @@ use super::{BitAccess, OptionExt}; -use bellpepper_core::{ +use crate::frontend::{ num::AllocatedNum, {ConstraintSystem, LinearCombination, SynthesisError, Variable}, }; diff --git a/src/gadgets/r1cs.rs b/src/gadgets/r1cs.rs index 5ad31cfa..62386ea2 100644 --- a/src/gadgets/r1cs.rs +++ b/src/gadgets/r1cs.rs @@ -3,6 +3,8 @@ use super::nonnative::{ bignat::BigNat, util::{f_to_nat, Num}, }; +use crate::frontend::gadgets::{boolean::Boolean, num::AllocatedNum, Assignment}; +use crate::frontend::{ConstraintSystem, SynthesisError}; use crate::{ constants::{NUM_CHALLENGE_BITS, NUM_FE_FOR_RO}, gadgets::{ @@ -15,8 +17,6 @@ use crate::{ r1cs::{R1CSInstance, RelaxedR1CSInstance}, traits::{commitment::CommitmentTrait, Engine, Group, ROCircuitTrait, ROConstantsCircuit}, }; -use bellpepper::gadgets::{boolean::Boolean, num::AllocatedNum, Assignment}; -use bellpepper_core::{ConstraintSystem, SynthesisError}; use ff::Field; /// An Allocated R1CS Instance diff --git a/src/gadgets/utils.rs b/src/gadgets/utils.rs index 5ec90f04..6f212638 100644 --- a/src/gadgets/utils.rs +++ b/src/gadgets/utils.rs @@ -1,12 +1,12 @@ //! 
This module implements various low-level gadgets use super::nonnative::bignat::{nat_to_limbs, BigNat}; -use crate::traits::Engine; -use bellpepper::gadgets::Assignment; -use bellpepper_core::{ +use crate::frontend::gadgets::Assignment; +use crate::frontend::{ boolean::{AllocatedBit, Boolean}, num::AllocatedNum, ConstraintSystem, LinearCombination, SynthesisError, }; +use crate::traits::Engine; use ff::{Field, PrimeField, PrimeFieldBits}; use num_bigint::BigInt; diff --git a/src/lib.rs b/src/lib.rs index 2a0201b7..3e73eeae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,6 @@ #![forbid(unsafe_code)] // private modules -mod bellpepper; mod circuit; mod constants; mod digest; @@ -20,6 +19,7 @@ mod r1cs; // public modules pub mod errors; +pub mod frontend; pub mod gadgets; pub mod provider; pub mod spartan; @@ -27,13 +27,13 @@ pub mod traits; use once_cell::sync::OnceCell; -use crate::bellpepper::{ +use crate::digest::{DigestComputer, SimpleDigestible}; +use crate::frontend::{ r1cs::{NovaShape, NovaWitness}, shape_cs::ShapeCS, solver::SatisfyingAssignment, }; -use crate::digest::{DigestComputer, SimpleDigestible}; -use bellpepper_core::{ConstraintSystem, SynthesisError}; +use crate::frontend::{ConstraintSystem, SynthesisError}; use circuit::{NovaAugmentedCircuit, NovaAugmentedCircuitInputs, NovaAugmentedCircuitParams}; use constants::{BN_LIMB_WIDTH, BN_N_LIMBS, NUM_FE_WITHOUT_IO_FOR_CRHF, NUM_HASH_BITS}; use core::marker::PhantomData; @@ -996,10 +996,10 @@ mod tests { }, traits::{circuit::TrivialCircuit, evaluation::EvaluationEngineTrait, snark::default_ck_hint}, }; - use ::bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use core::{fmt::Write, marker::PhantomData}; use expect_test::{expect, Expect}; use ff::PrimeField; + use frontend::{num::AllocatedNum, ConstraintSystem, SynthesisError}; type EE = provider::ipa_pc::EvaluationEngine; type EEPrime = provider::hyperkzg::EvaluationEngine; diff --git a/src/nifs.rs b/src/nifs.rs index 9fba6e84..fe578584 100644 --- a/src/nifs.rs +++ b/src/nifs.rs @@ -209,8 +209,9 @@ impl NIFSRelaxed { #[cfg(test)] mod tests { use super::*; + use crate::frontend::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use crate::{ - bellpepper::{ + frontend::{ r1cs::{NovaShape, NovaWitness}, solver::SatisfyingAssignment, test_shape_cs::TestShapeCS, @@ -219,7 +220,6 @@ mod tests { r1cs::{SparseMatrix, R1CS}, traits::{commitment::CommitmentEngineTrait, snark::default_ck_hint, Engine}, }; - use ::bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use ff::{Field, PrimeField}; use rand::rngs::OsRng; diff --git a/src/provider/mod.rs b/src/provider/mod.rs index dd94c657..fea6dd39 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -8,7 +8,7 @@ pub mod ipa_pc; pub(crate) mod bn256_grumpkin; pub(crate) mod pasta; pub(crate) mod pedersen; -pub(crate) mod poseidon; +pub mod poseidon; pub(crate) mod secp_secq; pub(crate) mod traits; diff --git a/src/provider/poseidon/circuit2.rs b/src/provider/poseidon/circuit2.rs new file mode 100644 index 00000000..679f41f4 --- /dev/null +++ b/src/provider/poseidon/circuit2.rs @@ -0,0 +1,650 @@ +//! The `circuit2` module implements the optimal Poseidon hash circuit. 
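// Editor's sketch (illustrative only, not part of the patch): one way to drive the
// circuit-side hash exported by this module from a gadget. The helper name `hash_pair`,
// the `U2` arity, and the two-element preimage are assumptions made for the example; the
// calls themselves (`PoseidonConstants::new_constant_length`, `poseidon_hash_allocated`)
// are the ones defined in this module and in `poseidon_inner`. In practice the constants
// would be built once and reused rather than recomputed per call.
fn hash_pair<Scalar, CS>(
  cs: CS,
  a: AllocatedNum<Scalar>,
  b: AllocatedNum<Scalar>,
) -> Result<AllocatedNum<Scalar>, SynthesisError>
where
  Scalar: PrimeField,
  CS: ConstraintSystem<Scalar>,
{
  // Constant-length-2 hashing over arity 2; the domain tag becomes 2 * 2^64.
  let constants = PoseidonConstants::<Scalar, generic_array::typenum::U2>::new_constant_length(2);
  // Returning an allocated output costs one extra constraint (see `Elt::ensure_allocated`).
  poseidon_hash_allocated(cs, vec![a, b], &constants)
}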
+use super::circuit2_witness::poseidon_hash_allocated_witness; +use super::hash_type::HashType; +use super::matrix::Matrix; +use super::mds::SparseMatrix; +use super::poseidon_inner::{Arity, PoseidonConstants}; +use crate::frontend::boolean::Boolean; +use crate::frontend::num::{self, AllocatedNum}; +use crate::frontend::{ConstraintSystem, LinearCombination, SynthesisError}; +use ff::PrimeField; +use std::marker::PhantomData; + +/// Similar to `num::Num`, we use `Elt` to accumulate both values and linear combinations, then eventually +/// extract into a `num::AllocatedNum`, enforcing that the linear combination corresponds to the result. +#[derive(Clone)] +pub enum Elt { + /// [`AllocatedNum`] variant + Allocated(AllocatedNum), + /// [`num::Num`] variant + Num(num::Num), +} + +impl From> for Elt { + fn from(allocated: AllocatedNum) -> Self { + Self::Allocated(allocated) + } +} + +impl Elt { + /// Check if the Elt is allocated. + pub const fn is_allocated(&self) -> bool { + matches!(self, Self::Allocated(_)) + } + + /// Check if the Elt is a Num. + pub const fn is_num(&self) -> bool { + matches!(self, Self::Num(_)) + } + + /// Create an Elt from a [`Scalar`]. + pub fn num_from_fr>(fr: Scalar) -> Self { + let num = num::Num::::zero(); + Self::Num(num.add_bool_with_coeff(CS::one(), &Boolean::Constant(true), fr)) + } + + /// Ensure Elt is allocated. + pub fn ensure_allocated>( + &self, + cs: &mut CS, + enforce: bool, + ) -> Result, SynthesisError> { + match self { + Self::Allocated(v) => Ok(v.clone()), + Self::Num(num) => { + let v = AllocatedNum::alloc(cs.namespace(|| "allocate for Elt::Num"), || { + num.get_value().ok_or(SynthesisError::AssignmentMissing) + })?; + + if enforce { + cs.enforce( + || "enforce num allocation preserves lc".to_string(), + |_| num.lc(Scalar::ONE), + |lc| lc + CS::one(), + |lc| lc + v.get_variable(), + ); + } + Ok(v) + } + } + } + + /// Get the value of the Elt. + pub fn val(&self) -> Option { + match self { + Self::Allocated(v) => v.get_value(), + Self::Num(num) => num.get_value(), + } + } + + /// Get the [`LinearCombination`] of the Elt. + pub fn lc(&self) -> LinearCombination { + match self { + Self::Num(num) => num.lc(Scalar::ONE), + Self::Allocated(v) => LinearCombination::::zero() + v.get_variable(), + } + } + + /// Add two Elts and return Elt::Num tracking the calculation. + #[allow(clippy::should_implement_trait)] + pub fn add(self, other: Elt) -> Result, SynthesisError> { + match (self, other) { + (Elt::Num(a), Elt::Num(b)) => Ok(Elt::Num(a.add(&b))), + (a, b) => Ok(Elt::Num(a.num().add(&b.num()))), + } + } + + /// Add two Elts and return Elt::Num tracking the calculation. 
+ pub fn add_ref(self, other: &Elt) -> Result, SynthesisError> { + match (self, other) { + (Elt::Num(a), Elt::Num(b)) => Ok(Elt::Num(a.add(b))), + (a, b) => Ok(Elt::Num(a.num().add(&b.num()))), + } + } + + /// Scale + pub fn scale>( + self, + scalar: Scalar, + ) -> Result, SynthesisError> { + match self { + Elt::Num(num) => Ok(Elt::Num(num.scale(scalar))), + Elt::Allocated(a) => Elt::Num(a.into()).scale::(scalar), + } + } + + /// Square + pub fn square>( + &self, + mut cs: CS, + ) -> Result, SynthesisError> { + match self { + Elt::Num(num) => { + let allocated = AllocatedNum::alloc(&mut cs.namespace(|| "squared num"), || { + num + .get_value() + .ok_or(SynthesisError::AssignmentMissing) + .map(|tmp| tmp * tmp) + })?; + cs.enforce( + || "squaring constraint", + |_| num.lc(Scalar::ONE), + |_| num.lc(Scalar::ONE), + |lc| lc + allocated.get_variable(), + ); + Ok(allocated) + } + Elt::Allocated(a) => a.square(cs), + } + } + + /// Return inner Num. + pub fn num(&self) -> num::Num { + match self { + Elt::Num(num) => num.clone(), + Elt::Allocated(a) => a.clone().into(), + } + } +} + +/// Circuit for Poseidon hash. +pub struct PoseidonCircuit2<'a, Scalar, A> +where + Scalar: PrimeField, + A: Arity, +{ + constants_offset: usize, + width: usize, + pub(crate) elements: Vec>, + pub(crate) pos: usize, + current_round: usize, + constants: &'a PoseidonConstants, + _w: PhantomData, +} + +/// PoseidonCircuit2 implementation. +impl<'a, Scalar, A> PoseidonCircuit2<'a, Scalar, A> +where + Scalar: PrimeField, + A: Arity, +{ + /// Create a new Poseidon hasher for `preimage`. + pub fn new(elements: Vec>, constants: &'a PoseidonConstants) -> Self { + let width = constants.width(); + + PoseidonCircuit2 { + constants_offset: 0, + width, + elements, + pos: 1, + current_round: 0, + constants, + _w: PhantomData::, + } + } + + pub fn new_empty>( + constants: &'a PoseidonConstants, + ) -> Self { + let elements = Self::initial_elements::(); + Self::new(elements, constants) + } + + pub fn hash>( + &mut self, + cs: &mut CS, + ) -> Result, SynthesisError> { + self.full_round(cs.namespace(|| "first round"), true, false)?; + + for i in 1..self.constants.full_rounds / 2 { + self.full_round( + cs.namespace(|| format!("initial full round {i}")), + false, + false, + )?; + } + + for i in 0..self.constants.partial_rounds { + self.partial_round(cs.namespace(|| format!("partial round {i}")))?; + } + + for i in 0..(self.constants.full_rounds / 2) - 1 { + self.full_round( + cs.namespace(|| format!("final full round {i}")), + false, + false, + )?; + } + self.full_round(cs.namespace(|| "terminal full round"), false, true)?; + + let elt = self.elements[1].clone(); + self.reset_offsets(); + + Ok(elt) + } + + pub fn apply_padding>(&mut self) { + if let HashType::ConstantLength(l) = self.constants.hash_type { + let final_pos = 1 + (l % self.constants.arity()); + + assert_eq!( + self.pos, final_pos, + "preimage length does not match constant length required for hash" + ); + }; + match self.constants.hash_type { + HashType::ConstantLength(_) | HashType::Encryption => { + for elt in self.elements[self.pos..].iter_mut() { + *elt = Elt::num_from_fr::(Scalar::ZERO); + } + self.pos = self.elements.len(); + } + HashType::VariableLength => todo!(), + _ => (), // incl HashType::Sponge + } + } + + pub fn hash_to_allocated>( + &mut self, + mut cs: CS, + ) -> Result, SynthesisError> { + let elt = self.hash(&mut cs).unwrap(); + elt.ensure_allocated(&mut cs, true) + } + + fn hash_to_num>( + &mut self, + mut cs: CS, + ) -> Result, SynthesisError> { + 
self.hash(&mut cs).map(|elt| elt.num()) + } + + fn full_round>( + &mut self, + mut cs: CS, + first_round: bool, + last_round: bool, + ) -> Result<(), SynthesisError> { + let mut constants_offset = self.constants_offset; + + let pre_round_keys = if first_round { + (0..self.width) + .map(|i| self.constants.compressed_round_constants[constants_offset + i]) + .collect::>() + } else { + Vec::new() + }; + constants_offset += pre_round_keys.len(); + + let post_round_keys = if first_round || !last_round { + (0..self.width) + .map(|i| self.constants.compressed_round_constants[constants_offset + i]) + .collect::>() + } else { + Vec::new() + }; + constants_offset += post_round_keys.len(); + + // Apply the quintic S-Box to all elements + for i in 0..self.elements.len() { + let pre_round_key = if first_round { + let rk = pre_round_keys[i]; + Some(rk) + } else { + None + }; + + let post_round_key = if first_round || !last_round { + let rk = post_round_keys[i]; + Some(rk) + } else { + None + }; + + if first_round { + { + self.elements[i] = quintic_s_box_pre_add( + cs.namespace(|| format!("quintic s-box {i}")), + &self.elements[i], + pre_round_key, + post_round_key, + )?; + } + } else { + self.elements[i] = quintic_s_box( + cs.namespace(|| format!("quintic s-box {i}")), + &self.elements[i], + post_round_key, + )?; + } + } + self.constants_offset = constants_offset; + + // Multiply the elements by the constant MDS matrix + self.product_mds::()?; + Ok(()) + } + + fn partial_round>( + &mut self, + mut cs: CS, + ) -> Result<(), SynthesisError> { + let round_key = self.constants.compressed_round_constants[self.constants_offset]; + self.constants_offset += 1; + // Apply the quintic S-Box to the first element. + self.elements[0] = quintic_s_box( + cs.namespace(|| "solitary quintic s-box"), + &self.elements[0], + Some(round_key), + )?; + + // Multiply the elements by the constant MDS matrix + self.product_mds::()?; + Ok(()) + } + + fn product_mds_m>(&mut self) -> Result<(), SynthesisError> { + self.product_mds_with_matrix::(&self.constants.mds_matrices.m) + } + + /// Set the provided elements with the result of the product between the elements and the appropriate + /// MDS matrix. + #[allow(clippy::collapsible_else_if)] + fn product_mds>(&mut self) -> Result<(), SynthesisError> { + let full_half = self.constants.half_full_rounds; + let sparse_offset = full_half - 1; + if self.current_round == sparse_offset { + self.product_mds_with_matrix::(&self.constants.pre_sparse_matrix)?; + } else { + if (self.current_round > sparse_offset) + && (self.current_round < full_half + self.constants.partial_rounds) + { + let index = self.current_round - sparse_offset - 1; + let sparse_matrix = &self.constants.sparse_matrixes[index]; + + self.product_mds_with_sparse_matrix::(sparse_matrix)?; + } else { + self.product_mds_m::()?; + } + }; + + self.current_round += 1; + Ok(()) + } + + #[allow(clippy::ptr_arg)] + fn product_mds_with_matrix>( + &mut self, + matrix: &Matrix, + ) -> Result<(), SynthesisError> { + let mut result: Vec> = Vec::with_capacity(self.constants.width()); + + for j in 0..self.constants.width() { + let column = (0..self.constants.width()) + .map(|i| matrix[i][j]) + .collect::>(); + + let product = scalar_product::(self.elements.as_slice(), &column)?; + + result.push(product); + } + + self.elements = result; + + Ok(()) + } + + // Sparse matrix in this context means one of the form, M''. 
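// Editor's note (illustrative): with the state treated as a row vector s, for width 3 the
// sparse matrix M'' looks like
//
//     [ w0  v0  v1 ]
//     [ w1   1   0 ]      w_hat = (w0, w1, w2) is the dense first column,
//     [ w2   0   1 ]      v_rest = (v0, v1) is the rest of the dense first row.
//
// s * M'' therefore costs one dense dot product (new[0] = s . w_hat) plus one
// multiply-add per remaining lane (new[j] = s[j] + s[0] * v_rest[j-1]), which is exactly
// what the function below allocates.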
+ fn product_mds_with_sparse_matrix>( + &mut self, + matrix: &SparseMatrix, + ) -> Result<(), SynthesisError> { + let mut result: Vec> = Vec::with_capacity(self.constants.width()); + + result.push(scalar_product::( + self.elements.as_slice(), + &matrix.w_hat, + )?); + + for j in 1..self.width { + result.push(self.elements[j].clone().add( + self.elements[0] + .clone() // First row is dense. + .scale::(matrix.v_rest[j - 1])?, // Except for first row/column, diagonals are one. + )?); + } + + self.elements = result; + + Ok(()) + } + + fn initial_elements>() -> Vec> { + std::iter::repeat(Elt::num_from_fr::(Scalar::ZERO)) + .take(A::to_usize() + 1) + .collect() + } + pub fn reset>(&mut self) { + self.reset_offsets(); + self.elements = Self::initial_elements::(); + } + + pub fn reset_offsets(&mut self) { + self.constants_offset = 0; + self.current_round = 0; + self.pos = 1; + } +} + +/// Create circuit for Poseidon hash, returning an allocated `Num` at the cost of one constraint. +pub fn poseidon_hash_allocated( + cs: CS, + preimage: Vec>, + constants: &PoseidonConstants, +) -> Result, SynthesisError> +where + CS: ConstraintSystem, + Scalar: PrimeField, + A: Arity, +{ + if cs.is_witness_generator() { + let mut cs = cs; + poseidon_hash_allocated_witness(&mut cs, &preimage, constants) + } else { + let arity = A::to_usize(); + let tag_element = Elt::num_from_fr::(constants.domain_tag); + let mut elements = Vec::with_capacity(arity + 1); + elements.push(tag_element); + elements.extend(preimage.into_iter().map(Elt::Allocated)); + + if let HashType::ConstantLength(length) = constants.hash_type { + assert!(length <= arity, "illegal length: constants are malformed"); + // Add zero-padding. + for _ in 0..(arity - length) { + let elt = Elt::Num(num::Num::zero()); + elements.push(elt); + } + } + let mut p = PoseidonCircuit2::new(elements, constants); + + p.hash_to_allocated(cs) + } +} + +/// Create circuit for Poseidon hash, minimizing constraints by returning an unallocated `Num`. +pub fn poseidon_hash_num( + cs: CS, + preimage: Vec>, + constants: &PoseidonConstants, +) -> Result, SynthesisError> +where + CS: ConstraintSystem, + Scalar: PrimeField, + A: Arity, +{ + let arity = A::to_usize(); + let tag_element = Elt::num_from_fr::(constants.domain_tag); + let mut elements = Vec::with_capacity(arity + 1); + elements.push(tag_element); + elements.extend(preimage.into_iter().map(Elt::Allocated)); + + if let HashType::ConstantLength(length) = constants.hash_type { + assert!(length <= arity, "illegal length: constants are malformed"); + // Add zero-padding. + for _ in 0..(arity - length) { + let elt = Elt::Num(num::Num::zero()); + elements.push(elt); + } + } + + let mut p = PoseidonCircuit2::new(elements, constants); + + p.hash_to_num(cs) +} + +/// Compute l^5 and enforce constraint. If round_key is supplied, add it to result. +fn quintic_s_box, Scalar: PrimeField>( + mut cs: CS, + l: &Elt, + post_round_key: Option, +) -> Result, SynthesisError> { + // If round_key was supplied, add it after all exponentiation. + let l2 = l.square(cs.namespace(|| "l^2"))?; + let l4 = l2.square(cs.namespace(|| "l^4"))?; + let l5 = mul_sum( + cs.namespace(|| "(l4 * l) + rk)"), + &l4, + l, + None, + post_round_key, + true, + ); + + Ok(Elt::Allocated(l5?)) +} + +/// Compute l^5 and enforce constraint. If round_key is supplied, add it to l first. 
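/// Editor's note: folding the round keys into the S-box keeps the cost at three R1CS
/// constraints per S-box: one for (l + k_pre)^2, one for the second squaring, and one for
/// the final multiply-add. The key additions only touch linear combinations, which are
/// free, and the witness generator in `circuit2_witness` relies on the same
/// three-values-per-S-box layout.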
+fn quintic_s_box_pre_add, Scalar: PrimeField>( + mut cs: CS, + l: &Elt, + pre_round_key: Option, + post_round_key: Option, +) -> Result, SynthesisError> { + if let (Some(pre_round_key), Some(post_round_key)) = (pre_round_key, post_round_key) { + // If round_key was supplied, add it to l before squaring. + let l2 = square_sum(cs.namespace(|| "(l+rk)^2"), pre_round_key, l, true)?; + let l4 = l2.square(cs.namespace(|| "l^4"))?; + let l5 = mul_sum( + cs.namespace(|| "l4 * (l + rk)"), + &l4, + l, + Some(pre_round_key), + Some(post_round_key), + true, + ); + + Ok(Elt::Allocated(l5?)) + } else { + panic!("pre_round_key and post_round_key must both be provided."); + } +} + +/// Calculates square of sum and enforces that constraint. +pub fn square_sum, Scalar: PrimeField>( + mut cs: CS, + to_add: Scalar, + elt: &Elt, + enforce: bool, +) -> Result, SynthesisError> { + let res = AllocatedNum::alloc(cs.namespace(|| "squared sum"), || { + let mut tmp = elt.val().ok_or(SynthesisError::AssignmentMissing)?; + tmp.add_assign(&to_add); + tmp = tmp.square(); + Ok(tmp) + })?; + + if enforce { + cs.enforce( + || "squared sum constraint", + |_| elt.lc() + (to_add, CS::one()), + |_| elt.lc() + (to_add, CS::one()), + |lc| lc + res.get_variable(), + ); + } + Ok(res) +} + +/// Calculates (a * (pre_add + b)) + post_add — and enforces that constraint. +#[allow(clippy::collapsible_else_if)] +pub fn mul_sum, Scalar: PrimeField>( + mut cs: CS, + a: &AllocatedNum, + b: &Elt, + pre_add: Option, + post_add: Option, + enforce: bool, +) -> Result, SynthesisError> { + let res = AllocatedNum::alloc(cs.namespace(|| "mul_sum"), || { + let mut tmp = b.val().ok_or(SynthesisError::AssignmentMissing)?; + if let Some(x) = pre_add { + tmp.add_assign(&x); + } + tmp.mul_assign(&a.get_value().ok_or(SynthesisError::AssignmentMissing)?); + if let Some(x) = post_add { + tmp.add_assign(&x); + } + + Ok(tmp) + })?; + + if enforce { + if let Some(x) = post_add { + let neg = -x; + + if let Some(pre) = pre_add { + cs.enforce( + || "mul sum constraint pre-post-add", + |_| b.lc() + (pre, CS::one()), + |lc| lc + a.get_variable(), + |lc| lc + res.get_variable() + (neg, CS::one()), + ); + } else { + cs.enforce( + || "mul sum constraint post-add", + |_| b.lc(), + |lc| lc + a.get_variable(), + |lc| lc + res.get_variable() + (neg, CS::one()), + ); + } + } else { + if let Some(pre) = pre_add { + cs.enforce( + || "mul sum constraint pre-add", + |_| b.lc() + (pre, CS::one()), + |lc| lc + a.get_variable(), + |lc| lc + res.get_variable(), + ); + } else { + cs.enforce( + || "mul sum constraint", + |_| b.lc(), + |lc| lc + a.get_variable(), + |lc| lc + res.get_variable(), + ); + } + } + } + Ok(res) +} + +fn scalar_product>( + elts: &[Elt], + scalars: &[Scalar], +) -> Result, SynthesisError> { + elts + .iter() + .zip(scalars) + .try_fold(Elt::Num(num::Num::zero()), |acc, (elt, &scalar)| { + acc.add(elt.clone().scale::(scalar)?) + }) +} diff --git a/src/provider/poseidon/circuit2_witness.rs b/src/provider/poseidon/circuit2_witness.rs new file mode 100644 index 00000000..859a2ce3 --- /dev/null +++ b/src/provider/poseidon/circuit2_witness.rs @@ -0,0 +1,278 @@ +/// The `circuit2_witness` module implements witness-generation for the optimal Poseidon hash circuit. 
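// Editor's note (not part of the patch): this path runs when `cs.is_witness_generator()`
// is true (see `poseidon_hash_allocated` in `circuit2`). No constraints are built; the
// hash is evaluated natively via `Poseidon::new_with_preimage`, and the intermediate
// S-box values (l^2, l^4, l^5, i.e. three per S-box) are pushed straight into the
// auxiliary assignment, which is why `num_aux()` equals `num_constraints()` below. Only
// the final digest is handed back as an `AllocatedNum`, matching the single output
// constraint paid by the constraint-building path.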
+use super::poseidon_inner::{Arity, Poseidon, PoseidonConstants}; +use crate::frontend::util_cs::witness_cs::SizedWitness; + +use crate::frontend::num::AllocatedNum; + +use crate::frontend::{ConstraintSystem, SynthesisError}; +use ff::PrimeField; +use generic_array::sequence::GenericSequence; +use generic_array::typenum::Unsigned; +use generic_array::GenericArray; + +/// Create circuit for Poseidon hash, returning an `AllocatedNum` at the cost of one constraint. +pub fn poseidon_hash_allocated_witness( + cs: &mut CS, + preimage: &[AllocatedNum], + constants: &PoseidonConstants, +) -> Result, SynthesisError> +where + CS: ConstraintSystem, + Scalar: PrimeField, + A: Arity, +{ + assert!(cs.is_witness_generator()); + let result = poseidon_hash_witness_into_cs(cs, preimage, constants); + + AllocatedNum::alloc(&mut cs.namespace(|| "result"), || Ok(result)) +} + +pub fn poseidon_hash_witness_into_cs( + cs: &mut CS, + preimage: &[AllocatedNum], + constants: &PoseidonConstants, +) -> Scalar +where + CS: ConstraintSystem, + Scalar: PrimeField, + A: Arity, +{ + let scalar_preimage = preimage + .iter() + .map(|elt| elt.get_value().unwrap()) + .collect::>(); + let mut p = Poseidon::new_with_preimage(&scalar_preimage, constants); + + p.generate_witness_into_cs(cs) +} + +impl<'a, Scalar, A> SizedWitness for Poseidon<'a, Scalar, A> +where + Scalar: PrimeField, + A: Arity, +{ + fn num_constraints(&self) -> usize { + let s_box_cost = 3; + let width = A::ConstantsSize::to_usize(); + (width * s_box_cost * self.constants.full_rounds) + (s_box_cost * self.constants.partial_rounds) + } + + fn num_inputs(&self) -> usize { + 0 + } + + fn num_aux(&self) -> usize { + self.num_constraints() + } + fn generate_witness_into(&mut self, aux: &mut [Scalar], _inputs: &mut [Scalar]) -> Scalar { + let width = A::ConstantsSize::to_usize(); + let constants = self.constants; + let elements = &mut self.elements; + + let mut elements_buffer = GenericArray::::generate(|_| Scalar::ZERO); + + let c = &constants.compressed_round_constants; + + let mut offset = 0; + let mut aux_index = 0; + macro_rules! push_aux { + ($val:expr) => { + aux[aux_index] = $val; + aux_index += 1; + }; + } + + assert_eq!(width, elements.len()); + + // First Round (Full) + { + // s-box + for elt in elements.iter_mut() { + let x = c[offset]; + let y = c[offset + width]; + let mut tmp = *elt; + + tmp.add_assign(x); + tmp = tmp.square(); + push_aux!(tmp); // l2 + + tmp = tmp.square(); + push_aux!(tmp); // l4 + + tmp = tmp * (*elt + x) + y; + push_aux!(tmp); // l5 + + *elt = tmp; + offset += 1; + } + offset += width; // post-round keys + + // mds + { + let m = &constants.mds_matrices.m; + + for j in 0..width { + let scalar_product = m + .iter() + .enumerate() + .fold(Scalar::ZERO, |acc, (n, row)| acc + (row[j] * elements[n])); + + elements_buffer[j] = scalar_product; + } + elements.copy_from_slice(&elements_buffer); + } + } + + // Remaining initial full rounds. + { + for i in 1..constants.half_full_rounds { + // Use pre-sparse matrix on last initial full round. 
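// Editor's note: `pre_sparse_matrix` is the dense matrix left over after repeatedly
// factoring M = M' x M'' in `mds::factor_to_sparse_matrices`. It is applied once here, on
// the last initial full round, after which each partial round uses its own sparse M''
// from `sparse_matrixes`.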
+ let m = if i == constants.half_full_rounds - 1 { + &constants.pre_sparse_matrix + } else { + &constants.mds_matrices.m + }; + { + // s-box + for elt in elements.iter_mut() { + let y = c[offset]; + let mut tmp = *elt; + + tmp = tmp.square(); + push_aux!(tmp); // l2 + + tmp = tmp.square(); + push_aux!(tmp); // l4 + + tmp = tmp * *elt + y; + push_aux!(tmp); // l5 + + *elt = tmp; + offset += 1; + } + } + + // mds + { + for j in 0..width { + let scalar_product = m + .iter() + .enumerate() + .fold(Scalar::ZERO, |acc, (n, row)| acc + (row[j] * elements[n])); + + elements_buffer[j] = scalar_product; + } + elements.copy_from_slice(&elements_buffer); + } + } + } + + // Partial rounds + { + for i in 0..constants.partial_rounds { + // s-box + + // FIXME: a little silly to use a loop here. + for elt in elements[0..1].iter_mut() { + let y = c[offset]; + let mut tmp = *elt; + + tmp = tmp.square(); + push_aux!(tmp); // l2 + + tmp = tmp.square(); + push_aux!(tmp); // l4 + + tmp = tmp * *elt + y; + push_aux!(tmp); // l5 + + *elt = tmp; + offset += 1; + } + let m = &constants.sparse_matrixes[i]; + + // sparse mds + { + elements_buffer[0] = elements + .iter() + .zip(&m.w_hat) + .fold(Scalar::ZERO, |acc, (&x, &y)| acc + (x * y)); + + for j in 1..width { + elements_buffer[j] = elements[j] + elements[0] * m.v_rest[j - 1]; + } + + elements.copy_from_slice(&elements_buffer); + } + } + } + // Final full rounds. + { + let m = &constants.mds_matrices.m; + for _ in 1..constants.half_full_rounds { + { + // s-box + for elt in elements.iter_mut() { + let y = c[offset]; + let mut tmp = *elt; + + tmp = tmp.square(); + push_aux!(tmp); // l2 + + tmp = tmp.square(); + push_aux!(tmp); // l4 + + tmp = tmp * *elt + y; + push_aux!(tmp); // l5 + + *elt = tmp; + offset += 1; + } + } + + // mds + { + for j in 0..width { + let scalar_product = m + .iter() + .enumerate() + .fold(Scalar::ZERO, |acc, (n, row)| acc + (row[j] * elements[n])); + + elements_buffer[j] = scalar_product; + } + elements.copy_from_slice(&elements_buffer); + } + } + + // Terminal full round + { + // s-box + for elt in elements.iter_mut() { + let mut tmp = *elt; + + tmp = tmp.square(); + push_aux!(tmp); // l2 + + tmp = tmp.square(); + push_aux!(tmp); // l4 + + tmp *= *elt; + push_aux!(tmp); // l5 + + *elt = tmp; + } + + // mds + { + for j in 0..width { + elements_buffer[j] = + (0..width).fold(Scalar::ZERO, |acc, i| acc + elements[i] * m[i][j]); + } + elements.copy_from_slice(&elements_buffer); + } + } + } + + elements[1] + } +} diff --git a/src/provider/poseidon/hash_type.rs b/src/provider/poseidon/hash_type.rs new file mode 100644 index 00000000..ff2cd664 --- /dev/null +++ b/src/provider/poseidon/hash_type.rs @@ -0,0 +1,109 @@ +/// `HashType` provides support for domain separation tags. +/// For 128-bit security, we need to reserve one (~256-bit) field element per Poseidon permutation. +/// This element cannot be used for hash preimage data — but can be assigned a constant value designating +/// the hash function built on top of the underlying permutation. +/// +/// `neptune` implements a variation of the domain separation tag scheme suggested in the updated Poseidon paper. This +/// allows for a variety of modes. This ensures that digest values produced using one hash function cannot be reused +/// where another is required. 
+/// +/// Because `neptune` also supports a first-class notion of `Strength`, we include a mechanism for composing +/// `Strength` with `HashType` so that hashes with `Strength` other than `Standard` (currently only `Strengthened`) +/// may still express the full range of hash function types. +use ff::PrimeField; +use serde::{Deserialize, Serialize}; + +use super::poseidon_inner::Arity; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))] +pub enum HashType> { + MerkleTree, + MerkleTreeSparse(u64), + VariableLength, + ConstantLength(usize), + Encryption, + Custom(CType), + Sponge, +} + +impl> HashType { + /// Implements domain separation defined in original [Poseidon paper](https://eprint.iacr.org/2019/458.pdf). + /// Calculates field element used as a zero element in underlying [`crate::poseidon::Poseidon`] buffer that holds preimage. + pub fn domain_tag(&self) -> F { + match self { + // 2^arity - 1 + HashType::MerkleTree => A::tag(), + // bitmask + HashType::MerkleTreeSparse(bitmask) => F::from(*bitmask), + // 2^64 + HashType::VariableLength => pow2::(64), + // length * 2^64 + // length of 0 denotes a duplex sponge + HashType::ConstantLength(length) => x_pow2::(*length as u64, 64), + // 2^32 or (2^32 + 2^32 = 2^33) with strength tag + HashType::Encryption => pow2::(32), + // identifier * 2^40 + // identifier must be in range [1..=256] + // If identifier == 0 then the strengthened version collides with Encryption with standard strength. + // NOTE: in order to leave room for future `Strength` tags, + // we make identifier a multiple of 2^40 rather than 2^32. + HashType::Custom(ref ctype) => ctype.domain_tag(), + HashType::Sponge => F::ZERO, + } + } + + /// Some HashTypes require more testing so are not yet supported, since they are not yet needed. + /// As and when needed, support can be added, along with tests to ensure the initial implementation + /// is sound. + pub const fn is_supported(&self) -> bool { + match self { + HashType::MerkleTreeSparse(_) | HashType::VariableLength => false, + HashType::MerkleTree + | HashType::ConstantLength(_) + | HashType::Encryption + | HashType::Custom(_) + | HashType::Sponge => true, + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum CType> { + Arbitrary(u64), + // See: https://github.com/bincode-org/bincode/issues/424 + // This is a bit of a hack, but since `serde(skip)` tags the last variant arm, + // the generated code ends up being correct. But, in the future, do not + // carelessly add new variants to this enum. + #[serde(skip)] + _Phantom((F, A)), +} + +impl> CType { + const fn identifier(&self) -> u64 { + match self { + CType::Arbitrary(id) => *id, + CType::_Phantom(_) => panic!("_Phantom is not a real custom tag type."), + } + } + + fn domain_tag(&self) -> F { + let id = self.identifier(); + assert!(id > 0, "custom domain tag id out of range"); + assert!(id <= 256, "custom domain tag id out of range"); + + x_pow2::(id, 40) + } +} + +/// pow2(n) = 2^n +fn pow2(n: u64) -> F { + F::from(2).pow_vartime([n]) +} + +/// x_pow2(x, n) = x * 2^n +fn x_pow2(coeff: u64, n: u64) -> F { + let mut tmp = pow2::(n); + tmp.mul_assign(F::from(coeff)); + tmp +} diff --git a/src/provider/poseidon/matrix.rs b/src/provider/poseidon/matrix.rs new file mode 100644 index 00000000..b260ba06 --- /dev/null +++ b/src/provider/poseidon/matrix.rs @@ -0,0 +1,321 @@ +// Allow `&Matrix` in function signatures. 
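// Editor's note: `Matrix<F>` below is a plain `Vec<Vec<F>>` of rows, so m[i][j] is row i,
// column j. The native helpers multiply with the vector on the right (M * v, see
// `left_apply_matrix`), while the circuit code in `circuit2` multiplies with the state as
// a row vector (s * M); `generate_mds` asserts the matrix is symmetric so the two
// conventions agree.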
+#![allow(clippy::ptr_arg)] + +use ff::PrimeField; + +/// Matrix functions here are, at least for now, quick and dirty — intended only to support precomputation of poseidon optimization. +/// +/// Matrix represented as a Vec of rows, so that m[i][j] represents the jth column of the ith row in Matrix, m. +pub(crate) type Matrix = Vec>; + +pub(crate) fn rows(matrix: &Matrix) -> usize { + matrix.len() +} + +/// Panics if `matrix` is not actually a matrix. So only use any of these functions on well-formed data. +/// Only use during constant calculation on matrices known to have been constructed correctly. +fn columns(matrix: &Matrix) -> usize { + if matrix.is_empty() { + 0 + } else { + let column_length = matrix[0].len(); + for row in matrix { + assert!(row.len() == column_length, "not a matrix"); + } + column_length + } +} + +// This wastefully discards the actual inverse, if it exists, so in general callers should +// just call `invert` if that result will be needed. +pub(crate) fn is_invertible(matrix: &Matrix) -> bool { + is_square(matrix) && invert(matrix).is_some() +} + +fn scalar_vec_mul(scalar: F, vec: &[F]) -> Vec { + vec + .iter() + .map(|val| { + let mut prod = scalar; + prod.mul_assign(val); + prod + }) + .collect::>() +} + +pub(crate) fn mat_mul(a: &Matrix, b: &Matrix) -> Option> { + if rows(a) != columns(b) { + return None; + }; + + let b_t = transpose(b); + + let res = a + .iter() + .map(|input_row| { + b_t + .iter() + .map(|transposed_column| vec_mul(input_row, transposed_column)) + .collect() + }) + .collect(); + + Some(res) +} + +fn vec_mul(a: &[F], b: &[F]) -> F { + a.iter().zip(b).fold(F::ZERO, |mut acc, (v1, v2)| { + let mut tmp = *v1; + tmp.mul_assign(v2); + acc.add_assign(&tmp); + acc + }) +} + +pub(crate) fn vec_add(a: &[F], b: &[F]) -> Vec { + a.iter() + .zip(b.iter()) + .map(|(a, b)| { + let mut res = *a; + res.add_assign(b); + res + }) + .collect::>() +} + +pub(crate) fn vec_sub(a: &[F], b: &[F]) -> Vec { + a.iter() + .zip(b.iter()) + .map(|(a, b)| { + let mut res = *a; + res.sub_assign(b); + res + }) + .collect::>() +} + +/// Left-multiply a vector by a square matrix of same size: MV where V is considered a column vector. +pub(crate) fn left_apply_matrix(m: &Matrix, v: &[F]) -> Vec { + assert!(is_square(m), "Only square matrix can be applied to vector."); + assert_eq!( + rows(m), + v.len(), + "Matrix can only be applied to vector of same size." 
+ ); + + let mut result = vec![F::ZERO; v.len()]; + + for (result, row) in result.iter_mut().zip(m.iter()) { + for (mat_val, vec_val) in row.iter().zip(v) { + let mut tmp = *mat_val; + tmp.mul_assign(vec_val); + result.add_assign(&tmp); + } + } + result +} + +#[allow(clippy::needless_range_loop)] +pub(crate) fn transpose(matrix: &Matrix) -> Matrix { + let size = rows(matrix); + let mut new = Vec::with_capacity(size); + for j in 0..size { + let mut row = Vec::with_capacity(size); + for i in 0..size { + row.push(matrix[i][j]) + } + new.push(row); + } + new +} + +#[allow(clippy::needless_range_loop)] +pub(crate) fn make_identity(size: usize) -> Matrix { + let mut result = vec![vec![F::ZERO; size]; size]; + for i in 0..size { + result[i][i] = F::ONE; + } + result +} + +pub(crate) fn kronecker_delta(i: usize, j: usize) -> F { + if i == j { + F::ONE + } else { + F::ZERO + } +} + +pub(crate) fn is_identity(matrix: &Matrix) -> bool { + for i in 0..rows(matrix) { + for j in 0..columns(matrix) { + if matrix[i][j] != kronecker_delta(i, j) { + return false; + } + } + } + true +} + +pub(crate) fn is_square(matrix: &Matrix) -> bool { + rows(matrix) == columns(matrix) +} + +pub(crate) fn minor(matrix: &Matrix, i: usize, j: usize) -> Matrix { + assert!(is_square(matrix)); + let size = rows(matrix); + assert!(size > 0); + let new = matrix + .iter() + .enumerate() + .filter_map(|(ii, row)| { + if ii == i { + None + } else { + let mut new_row = row.clone(); + new_row.remove(j); + Some(new_row) + } + }) + .collect(); + assert!(is_square(&new)); + new +} + +// Assumes matrix is partially reduced to upper triangular. `column` is the column to eliminate from all rows. +// Returns `None` if either: +// - no non-zero pivot can be found for `column` +// - `column` is not the first +fn eliminate( + matrix: &Matrix, + column: usize, + shadow: &mut Matrix, +) -> Option> { + let zero = F::ZERO; + let pivot_index = (0..rows(matrix)) + .find(|&i| matrix[i][column] != zero && (0..column).all(|j| matrix[i][j] == zero))?; + + let pivot = &matrix[pivot_index]; + let pivot_val = pivot[column]; + + // This should never fail since we have a non-zero `pivot_val` if we got here. + let inv_pivot = Option::from(pivot_val.invert())?; + let mut result = Vec::with_capacity(matrix.len()); + result.push(pivot.clone()); + + for (i, row) in matrix.iter().enumerate() { + if i == pivot_index { + continue; + }; + let val = row[column]; + if val == zero { + // Value is already eliminated. + result.push(row.to_vec()); + } else { + let mut factor = val; + factor.mul_assign(&inv_pivot); + + let scaled_pivot = scalar_vec_mul(factor, pivot); + let eliminated = vec_sub(row, &scaled_pivot); + result.push(eliminated); + + let shadow_pivot = &shadow[pivot_index]; + let scaled_shadow_pivot = scalar_vec_mul(factor, shadow_pivot); + let shadow_row = &shadow[i]; + shadow[i] = vec_sub(shadow_row, &scaled_shadow_pivot); + } + } + + let pivot_row = shadow.remove(pivot_index); + shadow.insert(0, pivot_row); + + Some(result) +} + +// `matrix` must be square. 
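// Editor's note: `invert` below is Gauss-Jordan elimination with a "shadow" matrix that
// starts as the identity. `upper_triangular` eliminates columns while replaying every row
// operation on the shadow, then `reduce_to_identity` back-substitutes; once the original
// matrix is reduced to the identity, the shadow holds its inverse. For instance
// (illustrative):
//   let m = vec![vec![F::from(2), F::ONE], vec![F::ONE, F::ONE]]; // det = 1
//   let m_inv = invert(&m).unwrap();
//   assert!(is_identity(&mat_mul(&m, &m_inv).unwrap()));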
+fn upper_triangular( + matrix: &Matrix, + shadow: &mut Matrix, +) -> Option> { + assert!(is_square(matrix)); + let mut result = Vec::with_capacity(matrix.len()); + let mut shadow_result = Vec::with_capacity(matrix.len()); + + let mut curr = matrix.clone(); + let mut column = 0; + while curr.len() > 1 { + let initial_rows = curr.len(); + + curr = eliminate(&curr, column, shadow)?; + result.push(curr[0].clone()); + shadow_result.push(shadow[0].clone()); + column += 1; + + curr = curr[1..].to_vec(); + *shadow = shadow[1..].to_vec(); + assert_eq!(curr.len(), initial_rows - 1); + } + result.push(curr[0].clone()); + shadow_result.push(shadow[0].clone()); + + *shadow = shadow_result; + + Some(result) +} + +// `matrix` must be upper triangular. +fn reduce_to_identity( + matrix: &Matrix, + shadow: &mut Matrix, +) -> Option> { + let size = rows(matrix); + let mut result: Matrix = Vec::new(); + let mut shadow_result: Matrix = Vec::new(); + + for i in 0..size { + let idx = size - i - 1; + let row = &matrix[idx]; + let shadow_row = &shadow[idx]; + + let val = row[idx]; + let inv = { + let inv = val.invert(); + // If `val` is zero, then there is no inverse, and we cannot compute a result. + if inv.is_none().into() { + return None; + } + inv.unwrap() + }; + + let mut normalized = scalar_vec_mul(inv, row); + let mut shadow_normalized = scalar_vec_mul(inv, shadow_row); + + for j in 0..i { + let idx = size - j - 1; + let val = normalized[idx]; + let subtracted = scalar_vec_mul(val, &result[j]); + let result_subtracted = scalar_vec_mul(val, &shadow_result[j]); + + normalized = vec_sub(&normalized, &subtracted); + shadow_normalized = vec_sub(&shadow_normalized, &result_subtracted); + } + + result.push(normalized); + shadow_result.push(shadow_normalized); + } + + result.reverse(); + shadow_result.reverse(); + + *shadow = shadow_result; + Some(result) +} + +// +pub(crate) fn invert(matrix: &Matrix) -> Option> { + let mut shadow = make_identity(columns(matrix)); + let ut = upper_triangular(matrix, &mut shadow); + + ut.and_then(|x| reduce_to_identity(&x, &mut shadow)) + .and(Some(shadow)) +} diff --git a/src/provider/poseidon/mds.rs b/src/provider/poseidon/mds.rs new file mode 100644 index 00000000..84274ca9 --- /dev/null +++ b/src/provider/poseidon/mds.rs @@ -0,0 +1,196 @@ +// Allow `&Matrix` in function signatures. +#![allow(clippy::ptr_arg)] + +use ff::PrimeField; +use serde::{Deserialize, Serialize}; + +use super::matrix; +use super::matrix::{ + invert, is_identity, is_invertible, is_square, left_apply_matrix, mat_mul, minor, transpose, + Matrix, +}; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct MdsMatrices { + pub m: Matrix, + pub m_inv: Matrix, + pub m_hat: Matrix, + pub m_hat_inv: Matrix, + pub m_prime: Matrix, + pub m_double_prime: Matrix, +} + +pub(crate) fn derive_mds_matrices(m: Matrix) -> MdsMatrices { + let m_inv = invert(&m).unwrap(); // m is MDS so invertible. + let m_hat = minor(&m, 0, 0); + let m_hat_inv = invert(&m_hat).unwrap(); // If this returns None, then `mds_matrix` was not correctly generated. + let m_prime = make_prime(&m); + let m_double_prime = make_double_prime(&m, &m_hat_inv); + + MdsMatrices { + m, + m_inv, + m_hat, + m_hat_inv, + m_prime, + m_double_prime, + } +} + +/// A `SparseMatrix` is specifically one of the form of M''. +/// This means its first row and column are each dense, and the interior matrix +/// (minor to the element in both the row and column) is the identity. 
+/// We will pluralize this compact structure `sparse_matrixes` to distinguish from `sparse_matrices` from which they are created. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct SparseMatrix { + /// `w_hat` is the first column of the M'' matrix. It will be directly multiplied (scalar product) with a row of state elements. + pub w_hat: Vec, + /// `v_rest` contains all but the first (already included in `w_hat`). + pub v_rest: Vec, +} + +impl SparseMatrix { + pub fn new_from_ref(m_double_prime: &Matrix) -> Self { + assert!(Self::is_sparse_matrix(m_double_prime)); + let size = matrix::rows(m_double_prime); + + let w_hat = (0..size).map(|i| m_double_prime[i][0]).collect::>(); + let v_rest = m_double_prime[0][1..].to_vec(); + + Self { w_hat, v_rest } + } + + pub fn is_sparse_matrix(m: &Matrix) -> bool { + is_square(m) && is_identity(&minor(m, 0, 0)) + } + + pub fn size(&self) -> usize { + self.w_hat.len() + } + + pub fn to_matrix(&self) -> Matrix { + let mut m = matrix::make_identity(self.size()); + for (j, elt) in self.w_hat.iter().enumerate() { + m[j][0] = *elt; + } + for (i, elt) in self.v_rest.iter().enumerate() { + m[0][i + 1] = *elt; + } + m + } +} + +// - Having effectively moved the round-key additions into the S-boxes, refactor MDS matrices used for partial-round mix layer to use sparse matrices. +// - This requires using a different (sparse) matrix at each partial round, rather than the same dense matrix at each. +// - The MDS matrix, M, for each such round, starting from the last, is factored into two components, such that M' x M'' = M. +// - M'' is sparse and replaces M for the round. +// - The previous layer's M is then replaced by M x M' = M*. +// - M* is likewise factored into M*' and M*'', and the process continues. +pub(crate) fn factor_to_sparse_matrixes( + base_matrix: &Matrix, + n: usize, +) -> (Matrix, Vec>) { + let (pre_sparse, sparse_matrices) = factor_to_sparse_matrices(base_matrix, n); + let sparse_matrixes = sparse_matrices + .iter() + .map(|m| SparseMatrix::::new_from_ref(m)) + .collect::>(); + + (pre_sparse, sparse_matrixes) +} + +pub(crate) fn factor_to_sparse_matrices( + base_matrix: &Matrix, + n: usize, +) -> (Matrix, Vec>) { + let (pre_sparse, mut all) = + (0..n).fold((base_matrix.clone(), Vec::new()), |(curr, mut acc), _| { + let derived = derive_mds_matrices(curr); + acc.push(derived.m_double_prime); + let new = mat_mul(base_matrix, &derived.m_prime).unwrap(); + (new, acc) + }); + all.reverse(); + (pre_sparse, all) +} + +pub(crate) fn generate_mds(t: usize) -> Matrix { + // Source: https://github.com/dusk-network/dusk-poseidon-merkle/commit/776c37734ea2e71bb608ce4bc58fdb5f208112a7#diff-2eee9b20fb23edcc0bf84b14167cbfdc + // Generate x and y values deterministically for the cauchy matrix + // where x[i] != y[i] to allow the values to be inverted + // and there are no duplicates in the x vector or y vector, so that the determinant is always non-zero + // [a b] + // [c d] + // det(M) = (ad - bc) ; if a == b and c == d => det(M) =0 + // For an MDS matrix, every possible mxm submatrix, must have det(M) != 0 + let xs: Vec = (0..t as u64).map(F::from).collect(); + let ys: Vec = (t as u64..2 * t as u64).map(F::from).collect(); + + let matrix = xs + .iter() + .map(|xs_item| { + ys.iter() + .map(|ys_item| { + // Generate the entry at (i,j) + let mut tmp = *xs_item; + tmp.add_assign(ys_item); + tmp.invert().unwrap() + }) + .collect() + }) + .collect(); + + // To ensure correctness, we would check all sub-matrices for invertibility. 
Meanwhile, this is a simple sanity check. + assert!(is_invertible(&matrix)); + + // `poseidon::product_mds_with_matrix` relies on the constructed MDS matrix being symmetric, so ensure it is. + assert_eq!(matrix, transpose(&matrix)); + matrix +} + +fn make_prime(m: &Matrix) -> Matrix { + m.iter() + .enumerate() + .map(|(i, row)| match i { + 0 => { + let mut new_row = vec![F::ZERO; row.len()]; + new_row[0] = F::ONE; + new_row + } + _ => { + let mut new_row = vec![F::ZERO; row.len()]; + new_row[1..].copy_from_slice(&row[1..]); + new_row + } + }) + .collect() +} + +fn make_double_prime(m: &Matrix, m_hat_inv: &Matrix) -> Matrix { + let (v, w) = make_v_w(m); + let w_hat = left_apply_matrix(m_hat_inv, &w); + + m.iter() + .enumerate() + .map(|(i, row)| match i { + 0 => { + let mut new_row = Vec::with_capacity(row.len()); + new_row.push(row[0]); + new_row.extend(&v); + new_row + } + _ => { + let mut new_row = vec![F::ZERO; row.len()]; + new_row[0] = w_hat[i - 1]; + new_row[i] = F::ONE; + new_row + } + }) + .collect() +} + +fn make_v_w(m: &Matrix) -> (Vec, Vec) { + let v = m[0][1..].to_vec(); + let w = m.iter().skip(1).map(|column| column[0]).collect(); + (v, w) +} diff --git a/src/provider/poseidon.rs b/src/provider/poseidon/mod.rs similarity index 73% rename from src/provider/poseidon.rs rename to src/provider/poseidon/mod.rs index b9e433a8..d7fd7c36 100644 --- a/src/provider/poseidon.rs +++ b/src/provider/poseidon/mod.rs @@ -1,24 +1,109 @@ //! Poseidon Constants and Poseidon-based RO used in Nova -use crate::traits::{ROCircuitTrait, ROTrait}; -use bellpepper_core::{ + +// TODO: remove this +#![allow(unused)] + +use ff::PrimeField; +use round_constants::generate_constants; +use round_numbers::{round_numbers_base, round_numbers_strengthened}; +use serde::{Deserialize, Serialize}; +mod circuit2; +mod circuit2_witness; +mod hash_type; +mod matrix; +mod mds; +mod poseidon_alt; +mod poseidon_inner; +mod preprocessing; +mod round_constants; +mod round_numbers; +mod serde_impl; +mod sponge; + +pub use circuit2::Elt; + +pub use sponge::{ + api::{IOPattern, SpongeAPI, SpongeOp}, + circuit::SpongeCircuit, + vanilla::{Mode::Simplex, Sponge, SpongeTrait}, +}; + +use crate::frontend::{ boolean::{AllocatedBit, Boolean}, num::AllocatedNum, ConstraintSystem, SynthesisError, }; +use crate::provider::poseidon::poseidon_inner::PoseidonConstants; +use crate::traits::{ROCircuitTrait, ROTrait}; use core::marker::PhantomData; -use ff::{PrimeField, PrimeFieldBits}; +use ff::PrimeFieldBits; use generic_array::typenum::U24; -use neptune::{ - circuit2::Elt, - poseidon::PoseidonConstants, - sponge::{ - api::{IOPattern, SpongeAPI, SpongeOp}, - circuit::SpongeCircuit, - vanilla::{Mode::Simplex, Sponge, SpongeTrait}, - }, - Strength, -}; -use serde::{Deserialize, Serialize}; + +/// The strength of the Poseidon hash function +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum Strength { + /// Standard strength + Standard, + /// Strengthened strength + Strengthened, +} + +const DEFAULT_STRENGTH: Strength = Strength::Standard; + +fn round_numbers(arity: usize, strength: &Strength) -> (usize, usize) { + match strength { + Strength::Standard => round_numbers_base(arity), + Strength::Strengthened => round_numbers_strengthened(arity), + } +} + +const SBOX: u8 = 1; // x^5 +const FIELD: u8 = 1; // Gf(p) + +fn round_constants(arity: usize, strength: &Strength) -> Vec { + let t = arity + 1; + + let (full_rounds, partial_rounds) = round_numbers(arity, strength); + + let r_f = full_rounds as u16; + let r_p = 
partial_rounds as u16; + + let fr_num_bits = F::NUM_BITS; + let field_size = { + assert!(fr_num_bits <= u32::from(u16::MAX)); + // It's safe to convert to u16 for compatibility with other types. + fr_num_bits as u16 + }; + + generate_constants::(FIELD, SBOX, field_size, t as u16, r_f, r_p) +} + +/// Apply the quintic S-Box (s^5) to a given item +pub(crate) fn quintic_s_box(l: &mut F, pre_add: Option<&F>, post_add: Option<&F>) { + if let Some(x) = pre_add { + l.add_assign(x); + } + let mut tmp = *l; + tmp = tmp.square(); // l^2 + tmp = tmp.square(); // l^4 + l.mul_assign(&tmp); // l^5 + if let Some(x) = post_add { + l.add_assign(x); + } +} + +#[derive(Debug, Clone)] +/// Possible error states for the hashing. +pub enum PoseidonError { + /// The allowed number of leaves cannot be greater than the arity of the tree. + FullBuffer, + /// Attempt to reference an index element that is out of bounds + IndexOutOfBounds, + /// GPU error + GpuError(String), + /// Other error + Other(String), +} /// All Poseidon Constants that are used in Nova #[derive(Clone, PartialEq, Serialize, Deserialize)] @@ -157,7 +242,7 @@ where assert_eq!(self.num_absorbs, self.state.len()); sponge.start(parameter, None, acc); - neptune::sponge::api::SpongeAPI::absorb( + SpongeAPI::absorb( &mut sponge, self.num_absorbs as u32, &(0..self.state.len()) @@ -166,7 +251,7 @@ where acc, ); - let output = neptune::sponge::api::SpongeAPI::squeeze(&mut sponge, 1, acc); + let output = SpongeAPI::squeeze(&mut sponge, 1, acc); sponge.finish(acc).unwrap(); output }; @@ -195,7 +280,7 @@ mod tests { Bn256EngineKZG, GrumpkinEngine, PallasEngine, Secp256k1Engine, Secq256k1Engine, VestaEngine, }; use crate::{ - bellpepper::solver::SatisfyingAssignment, constants::NUM_CHALLENGE_BITS, + constants::NUM_CHALLENGE_BITS, frontend::solver::SatisfyingAssignment, gadgets::utils::le_bits_to_num, traits::Engine, }; use ff::Field; diff --git a/src/provider/poseidon/poseidon_alt.rs b/src/provider/poseidon/poseidon_alt.rs new file mode 100644 index 00000000..263558cf --- /dev/null +++ b/src/provider/poseidon/poseidon_alt.rs @@ -0,0 +1,247 @@ +//! This module contains the 'correct' and 'dynamic' versions of Poseidon hashing. +//! These are tested (in `poseidon::test`) to be equivalent to the 'static optimized' version +//! used for actual hashing by the neptune library. +use super::poseidon_inner::{Arity, Poseidon}; +use super::{matrix, quintic_s_box}; +use ff::PrimeField; + +//////////////////////////////////////////////////////////////////////////////// +/// Correct +/// +/// This code path implements a naive and evidently correct poseidon hash. +/// +/// The returned element is the second poseidon element, the first is the arity tag. +pub(crate) fn hash_correct(p: &mut Poseidon<'_, F, A>) -> F +where + F: PrimeField, + A: Arity, +{ + // This counter is incremented when a round constants is read. Therefore, the round constants never repeat. + // The first full round should use the initial constants. + full_round(p); + + for _ in 1..p.constants.half_full_rounds { + full_round(p); + } + + partial_round(p); + + for _ in 1..p.constants.partial_rounds { + partial_round(p); + } + + for _ in 0..p.constants.half_full_rounds { + full_round(p); + } + + p.elements[1] +} + +pub(crate) fn full_round(p: &mut Poseidon<'_, F, A>) +where + F: PrimeField, + A: Arity, +{ + // Apply the quintic S-Box to all elements, after adding the round key. 
+ // Round keys are added in the S-box to match circuits (where the addition is free) + // and in preparation for the shift to adding round keys after (rather than before) applying the S-box. + + let pre_round_keys = p + .constants + .round_constants + .as_ref() + .unwrap() + .iter() + .skip(p.constants_offset) + .map(Some); + + p.elements + .iter_mut() + .zip(pre_round_keys) + .for_each(|(l, pre)| { + quintic_s_box(l, pre, None); + }); + + p.constants_offset += p.elements.len(); + + // M(B) + // Multiply the elements by the constant MDS matrix + p.product_mds(); +} + +/// The partial round is the same as the full round, with the difference that we apply the S-Box only to the first bitflags poseidon leaf. +pub(crate) fn partial_round(p: &mut Poseidon<'_, F, A>) +where + F: PrimeField, + A: Arity, +{ + // Every element of the hash buffer is incremented by the round constants + add_round_constants(p); + + // Apply the quintic S-Box to the first element + quintic_s_box(&mut p.elements[0], None, None); + + // Multiply the elements by the constant MDS matrix + p.product_mds(); +} + +//////////////////////////////////////////////////////////////////////////////// +/// Dynamic +/// +/// This code path implements a code path which dynamically calculates compressed round constants one-deep. +/// It serves as a bridge between the 'correct' and fully, statically optimized implementations. +/// Comments reference notation also expanded in matrix.rs and help clarify the relationship between +/// our optimizations and those described in the paper. +pub(crate) fn hash_optimized_dynamic(p: &mut Poseidon<'_, F, A>) -> F +where + F: PrimeField, + A: Arity, +{ + // The first full round should use the initial constants. + full_round_dynamic(p, true, true); + + for _ in 1..(p.constants.half_full_rounds) { + full_round_dynamic(p, false, true); + } + + partial_round_dynamic(p); + + for _ in 1..p.constants.partial_rounds { + partial_round(p); + } + + for _ in 0..p.constants.half_full_rounds { + full_round_dynamic(p, true, false); + } + + p.elements[1] +} + +pub(crate) fn full_round_dynamic( + p: &mut Poseidon<'_, F, A>, + add_current_round_keys: bool, + absorb_next_round_keys: bool, +) where + F: PrimeField, + A: Arity, +{ + // NOTE: decrease in performance is expected when using this pathway. + // We seek to preserve correctness while transforming the algorithm to an eventually more performant one. + + // Round keys are added in the S-box to match circuits (where the addition is free). + // If requested, add round keys synthesized from following round after (rather than before) applying the S-box. + let pre_round_keys = p + .constants + .round_constants + .as_ref() + .unwrap() + .iter() + .skip(p.constants_offset) + .map(|x| { + if add_current_round_keys { + Some(x) + } else { + None + } + }); + + if absorb_next_round_keys { + // Using the notation from `test_inverse` in matrix.rs: + // S + let post_vec = p + .constants + .round_constants + .as_ref() + .unwrap() + .iter() + .skip( + p.constants_offset + + if add_current_round_keys { + p.elements.len() + } else { + 0 + }, + ) + .take(p.elements.len()) + .copied() + .collect::>(); + + // Compute the constants which should be added *before* the next `product_mds`. + // in order to have the same effect as adding the given constants *after* the next `product_mds`. 
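// Editor's note: the trick works because the MDS mix is linear:
//   M(B + M^-1(S)) = M(B) + M(M^-1(S)) = M(B) + S,
// so pre-adding M^-1(S) before this round's mix is equivalent to adding S after it; the
// assert on `original` below double-checks the round trip M(M^-1(S)) == S.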
+ + // M^-1(S) + let inverted_vec = matrix::left_apply_matrix(&p.constants.mds_matrices.m_inv, &post_vec); + + // M(M^-1(S)) + let original = matrix::left_apply_matrix(&p.constants.mds_matrices.m, &inverted_vec); + + // S = M(M^-1(S)) + assert_eq!(&post_vec, &original, "Oh no, the inversion trick failed."); + + let post_round_keys = inverted_vec.iter(); + + // S-Box Output = B. + // With post-add, result is B + M^-1(S). + p.elements + .iter_mut() + .zip(pre_round_keys.zip(post_round_keys)) + .for_each(|(l, (pre, post))| { + quintic_s_box(l, pre, Some(post)); + }); + } else { + p.elements + .iter_mut() + .zip(pre_round_keys) + .for_each(|(l, pre)| { + quintic_s_box(l, pre, None); + }); + } + let mut consumed = 0; + if add_current_round_keys { + consumed += p.elements.len() + }; + if absorb_next_round_keys { + consumed += p.elements.len() + }; + p.constants_offset += consumed; + + // If absorb_next_round_keys + // M(B + M^-1(S) + // else + // M(B) + // Multiply the elements by the constant MDS matrix + p.product_mds(); +} + +pub(crate) fn partial_round_dynamic(p: &mut Poseidon<'_, F, A>) +where + F: PrimeField, + A: Arity, +{ + // Apply the quintic S-Box to the first element + quintic_s_box(&mut p.elements[0], None, None); + + // Multiply the elements by the constant MDS matrix + p.product_mds(); +} + +/// For every leaf, add the round constants with index defined by the constants offset, and increment the +/// offset. +fn add_round_constants(p: &mut Poseidon<'_, F, A>) +where + F: PrimeField, + A: Arity, +{ + for (element, round_constant) in p.elements.iter_mut().zip( + p.constants + .round_constants + .as_ref() + .unwrap() + .iter() + .skip(p.constants_offset), + ) { + element.add_assign(round_constant); + } + + p.constants_offset += p.elements.len(); +} diff --git a/src/provider/poseidon/poseidon_inner.rs b/src/provider/poseidon/poseidon_inner.rs new file mode 100644 index 00000000..c133f072 --- /dev/null +++ b/src/provider/poseidon/poseidon_inner.rs @@ -0,0 +1,612 @@ +use std::marker::PhantomData; + +use ff::PrimeField; +use generic_array::{sequence::GenericSequence, typenum, ArrayLength, GenericArray}; +use typenum::*; + +use super::{ + matrix::transpose, + mds::{derive_mds_matrices, factor_to_sparse_matrixes, generate_mds}, + preprocessing::compress_round_constants, +}; + +use super::{ + hash_type::HashType, + matrix::{left_apply_matrix, Matrix}, + mds::{MdsMatrices, SparseMatrix}, + poseidon_alt::{hash_correct, hash_optimized_dynamic}, + quintic_s_box, round_constants, round_numbers, PoseidonError, Strength, DEFAULT_STRENGTH, +}; + +/// Available arities for the Poseidon hasher. +pub trait Arity: ArrayLength { + /// Must be Arity + 1. + type ConstantsSize: ArrayLength; + + fn tag() -> T; +} + +macro_rules! impl_arity { + ($($a:ty),*) => { + $( + impl Arity for $a { + type ConstantsSize = Add1<$a>; + + fn tag() -> F { + F::from((1 << <$a as Unsigned>::to_usize()) - 1) + } + } + )* + }; +} + +// Dummy implementation to allow for an "optional" argument. +impl Arity for U0 { + type ConstantsSize = U0; + + fn tag() -> F { + unreachable!("dummy implementation for U0, should not be called") + } +} + +impl_arity!( + U1, U2, U3, U4, U5, U6, U7, U8, U9, U10, U11, U12, U13, U14, U15, U16, U17, U18, U19, U20, U21, + U22, U23, U24, U25, U26, U27, U28, U29, U30, U31, U32, U33, U34, U35, U36 +); + +/// Holds preimage, some utility offsets and counters along with the reference +/// to [`PoseidonConstants`] required for hashing. 
[`Poseidon`] is parameterized +/// by [`ff::PrimeField`] and [`Arity`], which should be similar to [`PoseidonConstants`]. +/// +/// [`Poseidon`] accepts input `elements` set with length equal or less than [`Arity`]. +/// +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Poseidon<'a, F, A = U2> +where + F: PrimeField, + A: Arity, +{ + pub(crate) constants_offset: usize, + pub(crate) current_round: usize, // Used in static optimization only for now. + /// the elements to permute + pub elements: GenericArray, + /// index of the next element of state to be absorbed + pub(crate) pos: usize, + pub(crate) constants: &'a PoseidonConstants, + _f: PhantomData, +} + +/// Holds constant values required for further [`Poseidon`] hashing. It contains MDS matrices, +/// round constants and numbers, parameters that specify security level ([`Strength`]) and +/// domain separation ([`HashType`]). Additional constants related to optimizations are also included. +/// +/// For correct operation, [`PoseidonConstants`] instance should be parameterized with the same [`ff::PrimeField`] +/// and [`Arity`] as [`Poseidon`] instance that consumes it. +/// +/// See original [Poseidon paper](https://eprint.iacr.org/2019/458.pdf) for more details. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PoseidonConstants> { + pub(crate) mds_matrices: MdsMatrices, + pub(crate) round_constants: Option>, // TODO: figure out how to automatically allocate `None` + pub(crate) compressed_round_constants: Vec, + pub(crate) pre_sparse_matrix: Matrix, + pub(crate) sparse_matrixes: Vec>, + pub(crate) strength: Strength, + /// The domain tag is the first element of a Poseidon permutation. + /// This extra element is necessary for 128-bit security. + pub(crate) domain_tag: F, + pub(crate) full_rounds: usize, + pub(crate) half_full_rounds: usize, + pub(crate) partial_rounds: usize, + pub(crate) hash_type: HashType, + pub(crate) _a: PhantomData, +} + +impl Default for PoseidonConstants +where + F: PrimeField, + A: Arity, +{ + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum HashMode { + // The initial and correct version of the algorithm. We should preserve the ability to hash this way for reference + // and to preserve confidence in our tests along thew way. + Correct, + // This mode is meant to be mostly synchronized with `Correct` but may reduce or simplify the work performed by the + // algorithm, if not the code implementing. Its purpose is for use during refactoring/development. + OptimizedDynamic, + // Consumes statically pre-processed constants for simplest operation. + OptimizedStatic, +} +use HashMode::{Correct, OptimizedDynamic, OptimizedStatic}; + +pub const DEFAULT_HASH_MODE: HashMode = OptimizedStatic; + +impl PoseidonConstants +where + F: PrimeField, + A: Arity, +{ + /// Generates new instance of [`PoseidonConstants`] suitable for both optimized / non-optimized hashing + /// with following default parameters: + /// - 128 bit of security; + /// - Merkle Tree (where all leafs are presented) domain separation ([`HashType`]). + pub fn new() -> Self { + Self::new_with_strength(DEFAULT_STRENGTH) + } + + /// Generates new instance of [`PoseidonConstants`] suitable for both optimized / non-optimized hashing + /// of constant-size preimages with following parameters: + /// - 128 bit of security; + /// - Constant-Input-Length Hashing domain separation ([`HashType`]). 
+ /// + /// Instantiated [`PoseidonConstants`] still calculates internal constants based on [`Arity`], but calculation of + /// [`HashType::domain_tag`] is based on input `length`. + pub fn new_constant_length(length: usize) -> Self { + Self::new_with_strength_and_type(DEFAULT_STRENGTH, HashType::ConstantLength(length)) + } + + /// Creates new instance of [`PoseidonConstants`] from already defined one with recomputed domain tag. + /// + /// It is assumed that input `length` is equal or less than [`Arity`]. + pub fn with_length(&self, length: usize) -> Self { + let arity = A::to_usize(); + assert!(length <= arity); + + let hash_type = match self.hash_type { + HashType::ConstantLength(_) => HashType::ConstantLength(length), + _ => panic!("cannot set constant length of hash without type ConstantLength."), + }; + + let domain_tag = hash_type.domain_tag(); + + Self { + hash_type, + domain_tag, + ..self.clone() + } + } + + /// Generates new instance of [`PoseidonConstants`] suitable for both optimized / non-optimized hashing + /// with Merkle Tree (where all leafs are presented) domain separation ([`HashType`]) custom security level ([`Strength`]). + pub fn new_with_strength(strength: Strength) -> Self { + Self::new_with_strength_and_type(strength, HashType::MerkleTree) + } + + /// Generates new instance of [`PoseidonConstants`] suitable for both optimized / non-optimized hashing + /// with custom domain separation ([`HashType`]) and custom security level ([`Strength`]). + pub fn new_with_strength_and_type(strength: Strength, hash_type: HashType) -> Self { + assert!(hash_type.is_supported()); + let arity = A::to_usize(); + let width = arity + 1; + let mds = generate_mds(width); + let (full_rounds, partial_rounds) = round_numbers(arity, &strength); + let round_constants = round_constants(arity, &strength); + + // Now call new_from_parameters with all the necessary parameters. + Self::new_from_parameters( + width, + mds, + round_constants, + full_rounds, + partial_rounds, + hash_type, + strength, + ) + } + + /// Generates new instance of [`PoseidonConstants`] with matrix, constants and number of rounds. + /// The matrix does not have to be symmetric. + pub fn new_from_parameters( + width: usize, + m: Matrix, + round_constants: Vec, + full_rounds: usize, + partial_rounds: usize, + hash_type: HashType, + strength: Strength, + ) -> Self { + let mds_matrices = derive_mds_matrices(m); + let half_full_rounds = full_rounds / 2; + let compressed_round_constants = compress_round_constants( + width, + full_rounds, + partial_rounds, + &round_constants, + &mds_matrices, + partial_rounds, + ); + + let (pre_sparse_matrix, sparse_matrixes) = + factor_to_sparse_matrixes(&transpose(&mds_matrices.m), partial_rounds); + + // Ensure we have enough constants for the sbox rounds + assert!( + width * (full_rounds + partial_rounds) <= round_constants.len(), + "Not enough round constants" + ); + + assert_eq!( + full_rounds * width + partial_rounds, + compressed_round_constants.len() + ); + + Self { + mds_matrices, + round_constants: Some(round_constants), + compressed_round_constants, + pre_sparse_matrix, + sparse_matrixes, + strength, + domain_tag: hash_type.domain_tag(), + full_rounds, + half_full_rounds, + partial_rounds, + hash_type, + _a: PhantomData::, + } + } + + /// Returns the [`Arity`] value represented as `usize`. + #[inline] + pub fn arity(&self) -> usize { + A::to_usize() + } + + /// Returns `width` value represented as `usize`. It equals to [`Arity`] + 1. 
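  // Illustrative usage sketch (an editorial aside, not part of this patch): hashing a
  // two-element, constant-length preimage with the types defined in this file.
  // `F` stands for any `ff::PrimeField` implementor; `a` and `b` are assumed inputs.
  //
  //     use generic_array::typenum::U2;
  //
  //     let constants = PoseidonConstants::<F, U2>::new_constant_length(2);
  //     let mut h = Poseidon::new_with_preimage(&[a, b], &constants);
  //     let digest = h.hash();
  //
  // `new_constant_length(2)` folds the preimage length into the domain tag, and
  // `hash()` runs the default `OptimizedStatic` mode.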
+ #[inline] + pub fn width(&self) -> usize { + A::ConstantsSize::to_usize() + } +} + +impl<'a, F, A> Poseidon<'a, F, A> +where + F: PrimeField, + A: Arity, +{ + /// Creates [`Poseidon`] instance using provided [`PoseidonConstants`] as input. Underlying set of + /// elements are initialized and `domain_tag` from [`PoseidonConstants`] is used as zero element in the set. + /// Therefore, hashing is eventually performed over [`Arity`] + 1 elements in fact, while [`Arity`] elements + /// are occupied by preimage data. + pub fn new(constants: &'a PoseidonConstants) -> Self { + let elements = GenericArray::generate(|i| { + if i == 0 { + constants.domain_tag + } else { + F::ZERO + } + }); + Poseidon { + constants_offset: 0, + current_round: 0, + elements, + pos: 1, + constants, + _f: PhantomData::, + } + } + + /// Creates [`Poseidon`] instance using provided preimage and [`PoseidonConstants`] as input. + /// Doesn't support [`PoseidonConstants`] with [`HashType::VariableLength`]. It is assumed that + /// size of input preimage set can't be greater than [`Arity`]. + pub fn new_with_preimage(preimage: &[F], constants: &'a PoseidonConstants) -> Self { + let elements = match constants.hash_type { + HashType::ConstantLength(constant_len) => { + assert_eq!(constant_len, preimage.len(), "Invalid preimage size"); + + GenericArray::generate(|i| { + if i == 0 { + constants.domain_tag + } else if i > preimage.len() { + F::ZERO + } else { + preimage[i - 1] + } + }) + } + HashType::MerkleTreeSparse(_) => { + panic!("Merkle Tree (with some empty leaves) hashes are not yet supported.") + } + HashType::VariableLength => panic!("variable-length hashes are not yet supported."), + _ => { + assert_eq!(preimage.len(), A::to_usize(), "Invalid preimage size"); + + GenericArray::generate(|i| { + if i == 0 { + constants.domain_tag + } else { + preimage[i - 1] + } + }) + } + }; + let width = preimage.len() + 1; + + Poseidon { + constants_offset: 0, + current_round: 0, + elements, + pos: width, + constants, + _f: PhantomData::, + } + } + + /// Replaces the elements with the provided optional items. + /// + /// # Panics + /// + /// Panics if the provided slice is not equal to the arity. + pub fn set_preimage(&mut self, preimage: &[F]) { + self.reset(); + self.elements[1..].copy_from_slice(preimage); + self.pos = self.elements.len(); + } + + /// Restore the initial state + pub fn reset(&mut self) { + self.reset_offsets(); + self.elements[1..].iter_mut().for_each(|l| *l = F::ZERO); + self.elements[0] = self.constants.domain_tag; + } + + pub(crate) fn reset_offsets(&mut self) { + self.constants_offset = 0; + self.current_round = 0; + self.pos = 1; + } + + /// Adds one more field element of preimage to the underlying [`Poseidon`] buffer for further hashing. + /// The returned `usize` represents the element position (within arity) for the input operation. + /// Returns [`PoseidonError::FullBuffer`] if no more elements can be added for hashing. + pub fn input(&mut self, element: F) -> Result { + // Cannot input more elements than the defined arity + // To hash constant-length input greater than arity, use sponge explicitly. + if self.pos >= self.constants.width() { + return Err(PoseidonError::FullBuffer); + } + + // Set current element, and increase the pointer + self.elements[self.pos] = element; + self.pos += 1; + + Ok(self.pos - 1) + } + + /// Performs hashing using underlying [`Poseidon`] buffer of the preimage' field elements + /// using provided [`HashMode`]. 
Always outputs digest expressed as a single field element + /// of concrete type specified upon [`PoseidonConstants`] and [`Poseidon`] instantiations. + pub fn hash_in_mode(&mut self, mode: HashMode) -> F { + let res = match mode { + Correct => hash_correct(self), + OptimizedDynamic => hash_optimized_dynamic(self), + OptimizedStatic => self.hash_optimized_static(), + }; + self.reset_offsets(); + res + } + + /// Performs hashing using underlying [`Poseidon`] buffer of the preimage' field elements + /// in default (optimized) mode. Always outputs digest expressed as a single field element + /// of concrete type specified upon [`PoseidonConstants`] and [`Poseidon`] instantiations. + pub fn hash(&mut self) -> F { + self.hash_in_mode(DEFAULT_HASH_MODE) + } + + pub(crate) fn apply_padding(&mut self) { + if let HashType::ConstantLength(l) = self.constants.hash_type { + let final_pos = 1 + (l % self.constants.arity()); + + assert_eq!( + self.pos, final_pos, + "preimage length does not match constant length required for hash" + ); + }; + match self.constants.hash_type { + HashType::ConstantLength(_) | HashType::Encryption => { + for elt in self.elements[self.pos..].iter_mut() { + *elt = F::ZERO; + } + self.pos = self.elements.len(); + } + HashType::VariableLength => todo!(), + _ => (), // incl. HashType::Sponge + } + } + + /// Returns 1-th element from underlying [`Poseidon`] buffer. This function is important, since + /// according to [`Poseidon`] design, after performing hashing, output digest will be stored at + /// 1-st place of underlying buffer. + #[inline] + pub fn extract_output(&self) -> F { + self.elements[1] + } + + /// Performs hashing using underlying [`Poseidon`] buffer of the preimage' field elements + /// using [`HashMode::OptimizedStatic`] mode. Always outputs digest expressed as a single field element + /// of concrete type specified upon [`PoseidonConstants`] and [`Poseidon`] instantiations. + pub fn hash_optimized_static(&mut self) -> F { + // The first full round should use the initial constants. + self.add_round_constants(); + + for _ in 0..self.constants.half_full_rounds { + self.full_round(false); + } + + for _ in 0..self.constants.partial_rounds { + self.partial_round(); + } + + // All but last full round. + for _ in 1..self.constants.half_full_rounds { + self.full_round(false); + } + self.full_round(true); + + assert_eq!( + self.constants_offset, + self.constants.compressed_round_constants.len(), + "Constants consumed ({}) must equal preprocessed constants provided ({}).", + self.constants_offset, + self.constants.compressed_round_constants.len() + ); + + self.extract_output() + } + + fn full_round(&mut self, last_round: bool) { + let to_take = self.elements.len(); + let post_round_keys = self + .constants + .compressed_round_constants + .iter() + .skip(self.constants_offset) + .take(to_take); + + if !last_round { + let needed = self.constants_offset + to_take; + assert!( + needed <= self.constants.compressed_round_constants.len(), + "Not enough preprocessed round constants ({}), need {}.", + self.constants.compressed_round_constants.len(), + needed + ); + } + self + .elements + .iter_mut() + .zip(post_round_keys) + .for_each(|(l, post)| { + // Be explicit that no round key is added after last round of S-boxes. + let post_key = if last_round { + panic!( + "Trying to skip last full round, but there is a key here! 
({:?})", + post + ); + } else { + Some(post) + }; + quintic_s_box(l, None, post_key); + }); + // We need this because post_round_keys will have been empty, so it didn't happen in the for_each. :( + if last_round { + self + .elements + .iter_mut() + .for_each(|l| quintic_s_box(l, None, None)); + } else { + self.constants_offset += self.elements.len(); + } + self.round_product_mds(); + } + + /// The partial round is the same as the full round, with the difference that we apply the S-Box only to the first (arity tag) poseidon leaf. + fn partial_round(&mut self) { + let post_round_key = self.constants.compressed_round_constants[self.constants_offset]; + + // Apply the quintic S-Box to the first element + quintic_s_box(&mut self.elements[0], None, Some(&post_round_key)); + self.constants_offset += 1; + + self.round_product_mds(); + } + + fn add_round_constants(&mut self) { + for (element, round_constant) in self.elements.iter_mut().zip( + self + .constants + .compressed_round_constants + .iter() + .skip(self.constants_offset), + ) { + element.add_assign(round_constant); + } + self.constants_offset += self.elements.len(); + } + + /// Set the provided elements with the result of the product between the elements and the appropriate + /// MDS matrix. + #[allow(clippy::collapsible_else_if)] + fn round_product_mds(&mut self) { + let full_half = self.constants.half_full_rounds; + let sparse_offset = full_half - 1; + if self.current_round == sparse_offset { + self.product_mds_with_matrix(&self.constants.pre_sparse_matrix); + } else { + if (self.current_round > sparse_offset) + && (self.current_round < full_half + self.constants.partial_rounds) + { + let index = self.current_round - sparse_offset - 1; + let sparse_matrix = &self.constants.sparse_matrixes[index]; + + self.product_mds_with_sparse_matrix(sparse_matrix); + } else { + self.product_mds(); + } + }; + + self.current_round += 1; + } + + /// Set the provided elements with the result of the product between the elements and the constant + /// MDS matrix. + pub(crate) fn product_mds(&mut self) { + self.product_mds_with_matrix_left(&self.constants.mds_matrices.m); + } + + /// NOTE: This calculates a vector-matrix product (`elements * matrix`) rather than the + /// expected matrix-vector `(matrix * elements)`. This is a performance optimization which + /// exploits the fact that our MDS matrices are symmetric by construction. + #[allow(clippy::ptr_arg)] + pub(crate) fn product_mds_with_matrix(&mut self, matrix: &Matrix) { + let mut result = GenericArray::::generate(|_| F::ZERO); + + for (j, val) in result.iter_mut().enumerate() { + for (i, row) in matrix.iter().enumerate() { + let mut tmp = row[j]; + tmp.mul_assign(&self.elements[i]); + val.add_assign(&tmp); + } + } + + let _ = std::mem::replace(&mut self.elements, result); + } + + pub(crate) fn product_mds_with_matrix_left(&mut self, matrix: &Matrix) { + let result = left_apply_matrix(matrix, &self.elements); + let _ = std::mem::replace( + &mut self.elements, + GenericArray::::generate(|i| result[i]), + ); + } + + // Sparse matrix in this context means one of the form, M''. + fn product_mds_with_sparse_matrix(&mut self, sparse_matrix: &SparseMatrix) { + let mut result = GenericArray::::generate(|_| F::ZERO); + + // First column is dense. + for (i, val) in sparse_matrix.w_hat.iter().enumerate() { + let mut tmp = *val; + tmp.mul_assign(&self.elements[i]); + result[0].add_assign(&tmp); + } + + for (j, val) in result.iter_mut().enumerate().skip(1) { + // Except for first row/column, diagonals are one. 
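      // Concretely (an illustrative note, not part of this patch): under the row-vector
      // convention used here, the sparse matrix M'' has the shape
      //
      //     [ w_hat[0]  v_rest[0]  v_rest[1]  ... ]
      //     [ w_hat[1]      1          0      ... ]
      //     [ w_hat[2]      0          1      ... ]
      //     [   ...        ...        ...     ... ]
      //
      // so only its first row and first column are dense and the rest is the identity,
      // which is why this loop adds `self.elements[j]` straight through and then folds
      // in the dense first-row term from `v_rest`.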
+ val.add_assign(&self.elements[j]); + + // First row is dense. + let mut tmp = sparse_matrix.v_rest[j - 1]; + tmp.mul_assign(&self.elements[0]); + val.add_assign(&tmp); + } + + let _ = std::mem::replace(&mut self.elements, result); + } +} diff --git a/src/provider/poseidon/preprocessing.rs b/src/provider/poseidon/preprocessing.rs new file mode 100644 index 00000000..e2ff3a5a --- /dev/null +++ b/src/provider/poseidon/preprocessing.rs @@ -0,0 +1,173 @@ +use super::matrix::{left_apply_matrix, vec_add}; +use super::mds::MdsMatrices; +use super::quintic_s_box; +use ff::PrimeField; + +// - Compress constants by pushing them back through linear layers and through the identity components of partial layers. +// - As a result, constants need only be added after each S-box. +#[allow(clippy::ptr_arg)] +pub(crate) fn compress_round_constants( + width: usize, + full_rounds: usize, + partial_rounds: usize, + round_constants: &Vec, + mds_matrices: &MdsMatrices, + partial_preprocessed: usize, +) -> Vec { + let mds_matrix = &mds_matrices.m; + let inverse_matrix = &mds_matrices.m_inv; + + let mut res = Vec::new(); + + let round_keys = |r: usize| &round_constants[r * width..(r + 1) * width]; + + let half_full_rounds = full_rounds / 2; // Not half-full rounds; half full-rounds. + + // First round constants are unchanged. + res.extend(round_keys(0)); + + let unpreprocessed = partial_rounds - partial_preprocessed; + + // Post S-box adds for the first set of full rounds should be 'inverted' from next round. + // The final round is skipped when fully preprocessing because that value must be obtained from the result of preprocessing the partial rounds. + let end = if unpreprocessed > 0 { + half_full_rounds + } else { + half_full_rounds - 1 + }; + for i in 0..end { + let next_round = round_keys(i + 1); // First round was added before any S-boxes. + let inverted = left_apply_matrix(inverse_matrix, next_round); + res.extend(inverted); + } + + // The plan: + // - Work backwards from last row in this group + // - Invert the row. + // - Save first constant (corresponding to the one S-box performed). + // - Add inverted result to previous row. + // - Repeat until all partial round key rows have been consumed. + // - Extend the preprocessed result by the final resultant row. + // - Move the accumulated list of single round keys to the preprocessed result. + // - (Last produced should be first applied, so either pop until empty, or reverse and extend, etc. + + // `partial_keys` will accumulate the single post-S-box constant for each partial-round, in reverse order. + let mut partial_keys: Vec = Vec::new(); + + let final_round = half_full_rounds + partial_rounds; + let final_round_key = round_keys(final_round).to_vec(); + + // `round_acc` holds the accumulated result of inverting and adding subsequent round constants (in reverse). + let round_acc = (0..partial_preprocessed) + .map(|i| round_keys(final_round - i - 1)) + .fold(final_round_key, |acc, previous_round_keys| { + let mut inverted = left_apply_matrix(inverse_matrix, &acc); + + partial_keys.push(inverted[0]); + inverted[0] = F::ZERO; + + vec_add(previous_round_keys, &inverted) + }); + + // Everything in here is dev-driven testing. + // Dev test case only checks one deep. + if partial_preprocessed == 1 { + // Check assumptions about how the fold calculating round_acc manifested. + + // The last round containing unpreprocessed constants which should be compressed. 
+ let terminal_constants_round = half_full_rounds + partial_rounds; + + // Constants from the last round (of two) which should be compressed. + // T + let terminal_round_keys = round_keys(terminal_constants_round); + + // Constants from the first round (of two) which should be compressed. + // I + let initial_round_keys = round_keys(terminal_constants_round - 1); + + // M^-1(T) + let mut inv = left_apply_matrix(inverse_matrix, terminal_round_keys); + + // M^-1(T)[0] + let pk = inv[0]; + + // M^-1(T) - pk (kinda) + inv[0] = F::ZERO; + + // (M^-1(T) - pk) - I + let result_key = vec_add(initial_round_keys, &inv); + + assert_eq!(&result_key, &round_acc, "Acc assumption failed."); + assert_eq!(pk, partial_keys[0], "Partial-key assumption failed."); + assert_eq!( + 1, + partial_keys.len(), + "Partial-keys length assumption failed." + ); + + //////////////////////////////////////////////////////////////////////////////// + // Shared between branches, arbitrary initial state representing the output of a previous round's S-Box layer. + // X + let initial_state = vec![F::ONE; width]; + + //////////////////////////////////////////////////////////////////////////////// + // Compute one step with the given (unpreprocessed) constants. + + // ARK + // I + X + let mut q_state = vec_add(initial_round_keys, &initial_state); + + // S-Box (partial layer) + // S((I + X)[0]) = S(I[0] + X[0]) + quintic_s_box(&mut q_state[0], None, None); + + // Mix with mds_matrix + let mixed = left_apply_matrix(mds_matrix, &q_state); + + // Ark + let plain_result = vec_add(terminal_round_keys, &mixed); + + //////////////////////////////////////////////////////////////////////////////// + // Compute the same step using the preprocessed constants. + // M'(initial_state) + (inverted_id - initial_state) = inverted_id + //let initial_state1 = apply_matrix::(&m_prime, &initial_state); + let mut p_state = vec_add(&result_key, &initial_state); + + // In order for the S-box result to be correct, it must have the same input as in the plain path. + // That means its input (the first component of the state) must have been constructed by + // adding the same single round constant in that position. + // NOTE: this assertion uncovered a bug which was causing failure. + assert_eq!( + &result_key[0], &initial_round_keys[0], + "S-box inputs did not match." + ); + + quintic_s_box(&mut p_state[0], None, Some(&pk)); + + let preprocessed_result = left_apply_matrix(mds_matrix, &p_state); + + assert_eq!( + plain_result, preprocessed_result, + "Single preprocessing step couldn't be verified." + ); + } + + for i in 1..unpreprocessed { + res.extend(round_keys(half_full_rounds + i)); + } + res.extend(left_apply_matrix(inverse_matrix, &round_acc)); + + while let Some(x) = partial_keys.pop() { + res.push(x) + } + + // Post S-box adds for the first set of full rounds should be 'inverted' from next round. + for i in 1..(half_full_rounds) { + let start = half_full_rounds + partial_rounds; + let next_round = round_keys(i + start); + let inverted = left_apply_matrix(inverse_matrix, next_round); + res.extend(inverted); + } + + res +} diff --git a/src/provider/poseidon/round_constants.rs b/src/provider/poseidon/round_constants.rs new file mode 100644 index 00000000..20d196f0 --- /dev/null +++ b/src/provider/poseidon/round_constants.rs @@ -0,0 +1,195 @@ +use ff::PrimeField; + +/// From the paper (): +/// The round constants are generated using the Grain LFSR [23] in a self-shrinking +/// mode: +/// 1. Initialize the state with 80 bits b0, b1, . . . 
, b79, where +/// (a) b0, b1 describe the field, +/// (b) bi for 2 ≤ i ≤ 5 describe the S-Box, +/// (c) bi for 6 ≤ i ≤ 17 are the binary representation of n, +/// (d) bi for 18 ≤ i ≤ 29 are the binary representation of t, +/// (e) bi for 30 ≤ i ≤ 39 are the binary representation of RF , +/// (f) bi for 40 ≤ i ≤ 49 are the binary representation of RP , and +/// (g) bi for 50 ≤ i ≤ 79 are set to 1. +/// 2. Update the bits using bi+80 = bi+62 ⊕ bi+51 ⊕ bi+38 ⊕ bi+23 ⊕ bi+13 ⊕ bi +/// +/// 3. Discard the first 160 bits. +/// 4. Evaluate bits in pairs: If the first bit is a 1, output the second bit. If it is a +/// 0, discard the second bit. +/// Using this method, the generation of round constants depends on the specific +/// instance, and thus different round constants are used even if some of the chosen +/// parameters (e.g., n and t) are the same. +/// If a randomly sampled integer is not in Fp, we discard this value and take the +/// next one. Note that cryptographically strong randomness is not needed for the +/// round constants, and other methods can also be used. +/// +/// Following https://extgit.iaik.tugraz.at/krypto/hadeshash/blob/master/code/scripts/create_rcs_grain.sage +/// The script was updated and can currently be found at: +/// https://extgit.iaik.tugraz.at/krypto/hadeshash/blob/master/code/generate_parameters_grain.sage +pub(crate) fn generate_constants( + field: u8, + sbox: u8, + field_size: u16, + t: u16, + r_f: u16, + r_p: u16, +) -> Vec { + let n_bytes = F::Repr::default().as_ref().len(); + if n_bytes != 32 { + unimplemented!("neptune currently supports 32-byte fields exclusively"); + } + assert_eq!((f32::from(field_size) / 8.0).ceil() as usize, n_bytes); + + let num_constants = (r_f + r_p) * t; + let mut init_sequence: Vec = Vec::new(); + append_bits(&mut init_sequence, 2, field); // Bits 0-1 + append_bits(&mut init_sequence, 4, sbox); // Bits 2-5 + append_bits(&mut init_sequence, 12, field_size); // Bits 6-17 + append_bits(&mut init_sequence, 12, t); // Bits 18-29 + append_bits(&mut init_sequence, 10, r_f); // Bits 30-39 + append_bits(&mut init_sequence, 10, r_p); // Bits 40-49 + append_bits(&mut init_sequence, 30, 0b111111111111111111111111111111u128); // Bits 50-79 + + let mut grain = Grain::new(init_sequence, field_size); + let mut round_constants: Vec = Vec::new(); + match field { + 1 => { + for _ in 0..num_constants { + loop { + // Generate 32 bytes and interpret them as a big-endian integer. 
Bytes are + // big-endian to agree with the integers generated by grain_random_bits in the + // reference implementation: + // + // def grain_random_bits(num_bits): + // random_bits = [grain_gen.next() for i in range(0, num_bits)] + // random_int = int("".join(str(i) for i in random_bits), 2) + // return random_int + let mut repr = F::Repr::default(); + grain.get_next_bytes(repr.as_mut()); + repr.as_mut().reverse(); + if let Some(f) = F::from_repr_vartime(repr) { + round_constants.push(f); + break; + } + } + } + } + _ => { + panic!("Only prime fields are supported."); + } + } + round_constants +} + +fn append_bits>(vec: &mut Vec, n: usize, from: T) { + let val = from.into(); + for i in (0..n).rev() { + vec.push((val >> i) & 1 != 0); + } +} + +struct Grain { + state: Vec, + field_size: u16, +} + +impl Grain { + fn new(init_sequence: Vec, field_size: u16) -> Self { + assert_eq!(80, init_sequence.len()); + let mut g = Grain { + state: init_sequence, + field_size, + }; + for _ in 0..160 { + g.generate_new_bit(); + } + assert_eq!(80, g.state.len()); + g + } + + fn generate_new_bit(&mut self) -> bool { + let new_bit = + self.bit(62) ^ self.bit(51) ^ self.bit(38) ^ self.bit(23) ^ self.bit(13) ^ self.bit(0); + self.state.remove(0); + self.state.push(new_bit); + new_bit + } + + fn bit(&self, index: usize) -> bool { + self.state[index] + } + + fn next_byte(&mut self, bit_count: usize) -> u8 { + // Accumulate bits from most to least significant, so the most significant bit is the one generated first by the bit stream. + let mut acc: u8 = 0; + self.take(bit_count).for_each(|bit| { + acc <<= 1; + if bit { + acc += 1; + } + }); + + acc + } + + fn get_next_bytes(&mut self, result: &mut [u8]) { + let remainder_bits = self.field_size as usize % 8; + + // Prime fields will always have remainder bits, + // but other field types could be supported in the future. + if remainder_bits > 0 { + // If there is an unfull byte, it should be the first. + // Subsequent bytes are packed into result in the order generated. + result[0] = self.next_byte(remainder_bits); + } else { + result[0] = self.next_byte(8); + } + + // First byte is already set + for item in result.iter_mut().skip(1) { + *item = self.next_byte(8) + } + } +} + +impl Iterator for Grain { + type Item = bool; + + fn next(&mut self) -> Option { + let mut new_bit = self.generate_new_bit(); + while !new_bit { + let _new_bit = self.generate_new_bit(); + new_bit = self.generate_new_bit(); + } + new_bit = self.generate_new_bit(); + Some(new_bit) + } +} + +#[allow(dead_code)] +#[inline] +const fn bool_to_u8(bit: bool, offset: usize) -> u8 { + if bit { + 1u8 << offset + } else { + 0u8 + } +} + +/// Converts a slice of bools into their byte representation, in little endian. +#[allow(dead_code)] +pub(crate) fn bits_to_bytes(bits: &[bool]) -> Vec { + bits + .chunks(8) + .map(|bits| { + bool_to_u8(bits[7], 7) + | bool_to_u8(bits[6], 6) + | bool_to_u8(bits[5], 5) + | bool_to_u8(bits[4], 4) + | bool_to_u8(bits[3], 3) + | bool_to_u8(bits[2], 2) + | bool_to_u8(bits[1], 1) + | bool_to_u8(bits[0], 0) + }) + .collect() +} diff --git a/src/provider/poseidon/round_numbers.rs b/src/provider/poseidon/round_numbers.rs new file mode 100644 index 00000000..a4434ca3 --- /dev/null +++ b/src/provider/poseidon/round_numbers.rs @@ -0,0 +1,91 @@ +//! A port of `calc_round_numbers.py` +//! https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/9d80ec0473ad7cde5a12f3aac46439ad0da68c0a/code/scripts/calc_round_numbers.py +//! from Python2 to Rust for a (roughly) 256-bit prime field (e.g. 
BLS12-381's scalar field) and +//! 128-bit security level. + +// The number of bits of the Poseidon prime field modulus. Denoted `n` in the Poseidon paper +// (where `n = ceil(log2(p))`). Note that BLS12-381's scalar field modulus is 255 bits, however we +// use 256 bits for simplicity when operating on bytes as the single bit difference does not affect +// the round number security properties. +const PRIME_BITLEN: usize = 256; + +// Security level (in bits), denoted `M` in the Poseidon paper. +const M: usize = 128; + +// The number of S-boxes (also called the "cost") given by equation (14) in the Poseidon paper: +// `cost = t * R_F + R_P`. +#[inline] +const fn n_sboxes(t: usize, rf: usize, rp: usize) -> usize { + t * rf + rp +} + +// Returns the round numbers for a given arity `(R_F, R_P)`. +pub(crate) fn round_numbers_base(arity: usize) -> (usize, usize) { + let t = arity + 1; + calc_round_numbers(t, true) +} + +// In case of newly-discovered attacks, we may need stronger security. +// This option exists so we can preemptively create circuits in order to switch +// to them quickly if needed. +// +// "A realistic alternative is to increase the number of partial rounds by 25%. +// Then it is unlikely that a new attack breaks through this number, +// but even if this happens then the complexity is almost surely above 2^64, and you will be safe." +// - D Khovratovich +pub(crate) fn round_numbers_strengthened(arity: usize) -> (usize, usize) { + let (full_round, partial_rounds) = round_numbers_base(arity); + + // Increase by 25%, rounding up. + let strengthened_partial_rounds = f64::ceil(partial_rounds as f64 * 1.25) as usize; + + (full_round, strengthened_partial_rounds) +} + +// Returns the round numbers for a given width `t`. Here, the `security_margin` parameter does not +// indicate that we are calculating `R_F` and `R_P` for the "strengthened" round numbers, done in +// the function `round_numbers_strengthened()`. +pub(crate) fn calc_round_numbers(t: usize, security_margin: bool) -> (usize, usize) { + let mut rf = 0; + let mut rp = 0; + let mut n_sboxes_min = usize::MAX; + + for mut rf_test in (2..=1000).step_by(2) { + for mut rp_test in 4..200 { + if round_numbers_are_secure(t, rf_test, rp_test) { + if security_margin { + rf_test += 2; + rp_test = (1.075 * rp_test as f32).ceil() as usize; + } + let n_sboxes = n_sboxes(t, rf_test, rp_test); + if n_sboxes < n_sboxes_min || (n_sboxes == n_sboxes_min && rf_test < rf) { + rf = rf_test; + rp = rp_test; + n_sboxes_min = n_sboxes; + } + } + } + } + + (rf, rp) +} + +// Returns `true` if the provided round numbers satisfy the security inequalities specified in the +// Poseidon paper. 
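// Restated informally (an editorial transcription of the checks implemented in
// `round_numbers_are_secure` below, in the paper's notation with security level M,
// field bit-length n, and width t): the chosen R_F must satisfy
//
//     R_F >= rf_stat, where rf_stat = 6 if M <= (n - 3)(t + 1) and 10 otherwise
//     R_F >= ceil(0.43*M + log2(t) - R_P)           (interpolation attack)
//     R_F >= ceil(0.21*n - R_P)                     (Groebner-basis attack, bound 1)
//     R_F >= ceil((0.14*n - 1 - R_P) / (t - 1))     (Groebner-basis attack, bound 2)
//
// i.e. R_F must be at least the maximum of these four bounds.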
+fn round_numbers_are_secure(t: usize, rf: usize, rp: usize) -> bool { + let (rp, t, n, m) = (rp as f32, t as f32, PRIME_BITLEN as f32, M as f32); + let rf_stat = if m <= (n - 3.0) * (t + 1.0) { + 6.0 + } else { + 10.0 + }; + let rf_interp = 0.43 * m + t.log2() - rp; + let rf_grob_1 = 0.21 * n - rp; + let rf_grob_2 = (0.14 * n - 1.0 - rp) / (t - 1.0); + let rf_max = [rf_stat, rf_interp, rf_grob_1, rf_grob_2] + .iter() + .map(|rf| rf.ceil() as usize) + .max() + .unwrap(); + rf >= rf_max +} diff --git a/src/provider/poseidon/serde_impl.rs b/src/provider/poseidon/serde_impl.rs new file mode 100644 index 00000000..736aa096 --- /dev/null +++ b/src/provider/poseidon/serde_impl.rs @@ -0,0 +1,238 @@ +use ff::PrimeField; +use serde::{ + de::{self, Deserializer, MapAccess, SeqAccess, Visitor}, + ser::{SerializeStruct, Serializer}, + Deserialize, Serialize, +}; +use std::fmt; +use std::marker::PhantomData; + +use super::hash_type::HashType; +use super::poseidon_inner::Arity; +use super::poseidon_inner::PoseidonConstants; + +impl Serialize for PoseidonConstants +where + F: PrimeField + Serialize, + A: Arity, +{ + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("PoseidonConstants", 8)?; + state.serialize_field("mds", &self.mds_matrices)?; + state.serialize_field("crc", &self.compressed_round_constants)?; + state.serialize_field("psm", &self.pre_sparse_matrix)?; + state.serialize_field("sm", &self.sparse_matrixes)?; + state.serialize_field("s", &self.strength)?; + state.serialize_field("rf", &self.full_rounds)?; + state.serialize_field("rp", &self.partial_rounds)?; + state.serialize_field("ht", &self.hash_type)?; + state.end() + } +} + +impl<'de, F, A> Deserialize<'de> for PoseidonConstants +where + F: PrimeField + Deserialize<'de>, + A: Arity, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(field_identifier, rename_all = "lowercase")] + enum Field { + Mds, + Crc, + Psm, + Sm, + S, + Rf, + Rp, + Ht, + } + + struct PoseidonConstantsVisitor + where + F: PrimeField, + A: Arity, + { + _f: PhantomData, + _a: PhantomData, + } + + impl<'de, F, A> Visitor<'de> for PoseidonConstantsVisitor + where + F: PrimeField + Deserialize<'de>, + A: Arity, + { + type Value = PoseidonConstants; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("struct PoseidonConstants") + } + + fn visit_seq(self, mut seq: V) -> Result, V::Error> + where + V: SeqAccess<'de>, + { + let mds_matrices = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let compressed_round_constants = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + let pre_sparse_matrix = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?; + let sparse_matrixes = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(3, &self))?; + let strength = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(4, &self))?; + let full_rounds = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(5, &self))?; + let partial_rounds = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(6, &self))?; + let hash_type: HashType = seq + .next_element()? 
+ .ok_or_else(|| de::Error::invalid_length(7, &self))?; + + Ok(PoseidonConstants { + mds_matrices, + round_constants: None, + compressed_round_constants, + pre_sparse_matrix, + sparse_matrixes, + strength, + domain_tag: hash_type.domain_tag(), + full_rounds, + half_full_rounds: full_rounds / 2, + partial_rounds, + hash_type, + _a: PhantomData::, + }) + } + + fn visit_map(self, mut map: V) -> Result, V::Error> + where + V: MapAccess<'de>, + { + let mut mds_matrices = None; + let mut compressed_round_constants = None; + let mut pre_sparse_matrix = None; + let mut sparse_matrixes = None; + let mut strength = None; + let mut full_rounds = None; + let mut partial_rounds = None; + let mut hash_type = None; + + while let Some(key) = map.next_key()? { + match key { + Field::Mds => { + if mds_matrices.is_some() { + return Err(de::Error::duplicate_field("mds_matrices")); + } + mds_matrices = Some(map.next_value()?); + } + Field::Crc => { + if compressed_round_constants.is_some() { + return Err(de::Error::duplicate_field("compressed_round_constants")); + } + compressed_round_constants = Some(map.next_value()?); + } + Field::Psm => { + if pre_sparse_matrix.is_some() { + return Err(de::Error::duplicate_field("pre_sparse_matrix")); + } + pre_sparse_matrix = Some(map.next_value()?); + } + Field::Sm => { + if sparse_matrixes.is_some() { + return Err(de::Error::duplicate_field("sparse_matrixes")); + } + sparse_matrixes = Some(map.next_value()?); + } + Field::S => { + if strength.is_some() { + return Err(de::Error::duplicate_field("strength")); + } + strength = Some(map.next_value()?); + } + Field::Rf => { + if full_rounds.is_some() { + return Err(de::Error::duplicate_field("full_rounds")); + } + full_rounds = Some(map.next_value()?); + } + Field::Rp => { + if partial_rounds.is_some() { + return Err(de::Error::duplicate_field("partial_rounds")); + } + partial_rounds = Some(map.next_value()?); + } + Field::Ht => { + if hash_type.is_some() { + return Err(de::Error::duplicate_field("hash_type")); + } + hash_type = Some(map.next_value()?); + } + } + } + + let mds_matrices = mds_matrices.ok_or_else(|| de::Error::missing_field("mds_matrices"))?; + let compressed_round_constants = compressed_round_constants + .ok_or_else(|| de::Error::missing_field("compressed_round_constants"))?; + let pre_sparse_matrix = + pre_sparse_matrix.ok_or_else(|| de::Error::missing_field("pre_sparse_matrix"))?; + let sparse_matrixes = + sparse_matrixes.ok_or_else(|| de::Error::missing_field("sparse_matrixes"))?; + let strength = strength.ok_or_else(|| de::Error::missing_field("strength"))?; + let full_rounds = full_rounds.ok_or_else(|| de::Error::missing_field("full_rounds"))?; + let partial_rounds = + partial_rounds.ok_or_else(|| de::Error::missing_field("partial_rounds"))?; + let hash_type: HashType = + hash_type.ok_or_else(|| de::Error::missing_field("hash_type"))?; + Ok(PoseidonConstants { + mds_matrices, + round_constants: None, + compressed_round_constants, + pre_sparse_matrix, + sparse_matrixes, + strength, + domain_tag: hash_type.domain_tag(), + full_rounds, + half_full_rounds: full_rounds / 2, + partial_rounds, + hash_type, + _a: PhantomData::, + }) + } + } + + const FIELDS: &[&str] = &[ + "mds_matrices", + "compressed_round_constants", + "pre_sparse_matrix", + "sparse_matrixes", + "strength", + "full_rounds", + "partial_rounds", + "hash_type", + ]; + deserializer.deserialize_struct( + "PoseidonConstants", + FIELDS, + PoseidonConstantsVisitor { + _f: PhantomData, + _a: PhantomData, + }, + ) + } +} diff --git 
a/src/provider/poseidon/sponge/api.rs b/src/provider/poseidon/sponge/api.rs new file mode 100644 index 00000000..009bd0f9 --- /dev/null +++ b/src/provider/poseidon/sponge/api.rs @@ -0,0 +1,330 @@ +/// This module implements a variant of the 'Secure Sponge API for Field Elements': https://hackmd.io/bHgsH6mMStCVibM_wYvb2w +/// +/// The API is defined by the `SpongeAPI` trait, which is implemented in terms of the `InnerSpongeAPI` trait. +/// `Neptune` provides implementations of `InnerSpongeAPI` for both `sponge::Sponge` and `sponge_circuit::SpongeCircuit`. +use crate::provider::poseidon::poseidon_inner::Arity; +use ff::PrimeField; + +#[derive(Debug)] +pub enum Error { + ParameterUsageMismatch, +} + +/// Sponge operations +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum SpongeOp { + /// Absorb + Absorb(u32), + /// Squeeze + Squeeze(u32), +} + +/// A sequence of sponge operations +#[derive(Clone, Debug)] +pub struct IOPattern(pub Vec); + +impl IOPattern { + /// Compute the value of the pattern given a domain separator + pub fn value(&self, domain_separator: u32) -> u128 { + let mut hasher = Hasher::new(); + + for op in self.0.iter() { + hasher.update_op(*op); + } + hasher.finalize(domain_separator) + } + + /// Get the operation at a given index + pub fn op_at(&self, i: usize) -> Option<&SpongeOp> { + self.0.get(i) + } +} + +// A large 128-bit prime, per https://primes.utm.edu/lists/2small/100bit.html. +const HASHER_BASE: u128 = (0 - 159) as u128; + +#[derive(Clone, Copy, Debug)] +pub(crate) struct Hasher { + x: u128, + x_i: u128, + state: u128, + current_op: SpongeOp, +} + +impl Default for Hasher { + fn default() -> Self { + Self { + x: HASHER_BASE, + x_i: 1, + state: 0, + current_op: SpongeOp::Absorb(0), + } + } +} + +impl Hasher { + pub(crate) fn new() -> Self { + Default::default() + } + + /// Update hasher's current op to coalesce absorb/squeeze runs. 
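  // Worked example (an editorial note, cross-checked against `test_tag_values` at the
  // bottom of this file): the hasher evaluates a polynomial in x = HASHER_BASE =
  // 2^128 - 159 over the encoded op values, with all arithmetic wrapping mod 2^128,
  // and the domain separator is absorbed last. For an empty IO pattern and domain
  // separator 123, only the separator contributes, so
  //
  //     tag = 123 * x mod 2^128
  //         = 123 * (2^128 - 159) mod 2^128
  //         = 2^128 - 159 * 123
  //         = 340282366920938463463374607431768191899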
+ pub(crate) fn update_op(&mut self, op: SpongeOp) { + if self.current_op.matches(op) { + self.current_op = self.current_op.combine(op) + } else { + self.finish_op(); + self.current_op = op; + } + } + + fn finish_op(&mut self) { + if self.current_op.count() == 0 { + return; + }; + let op_value = self.current_op.value(); + + self.update(op_value); + } + + pub(crate) fn update(&mut self, a: u32) { + self.x_i = self.x_i.overflowing_mul(self.x).0; + self.state = self + .state + .overflowing_add(self.x_i.overflowing_mul(u128::from(a)).0) + .0; + } + + pub(crate) fn finalize(&mut self, domain_separator: u32) -> u128 { + self.finish_op(); + self.update(domain_separator); + self.state + } +} + +impl SpongeOp { + /// Reset the SpongeOp + pub const fn reset(&self) -> Self { + match self { + Self::Absorb(_) => Self::Squeeze(0), + Self::Squeeze(_) => Self::Absorb(0), + } + } + + /// Return the count of the SpongeOp + pub const fn count(&self) -> u32 { + match self { + Self::Absorb(n) | Self::Squeeze(n) => *n, + } + } + + /// Return true if the SpongeOp is absorb + pub const fn is_absorb(&self) -> bool { + matches!(self, Self::Absorb(_)) + } + + /// Return true if the SpongeOp is squeeze + pub const fn is_squeeze(&self) -> bool { + matches!(self, Self::Squeeze(_)) + } + + /// Combine two SpongeOps + pub fn combine(&self, other: Self) -> Self { + assert!(self.matches(other)); + + match self { + Self::Absorb(n) => Self::Absorb(n + other.count()), + Self::Squeeze(n) => Self::Squeeze(n + other.count()), + } + } + + /// Check if two SpongeOps match + pub const fn matches(&self, other: Self) -> bool { + self.is_absorb() == other.is_absorb() + } + + /// Return the value of the SpongeOp + pub fn value(&self) -> u32 { + match self { + Self::Absorb(n) => { + assert_eq!(0, n >> 31); + n + (1 << 31) + } + Self::Squeeze(n) => { + assert_eq!(0, n >> 31); + *n + } + } + } +} + +/// Sponge API trait +pub trait SpongeAPI> { + /// Accumulator type + type Acc; + /// Value type + type Value; + + /// Optional `domain_separator` defaults to 0 + fn start(&mut self, p: IOPattern, domain_separator: Option, _: &mut Self::Acc); + /// Perform an absorb operation + fn absorb(&mut self, length: u32, elements: &[Self::Value], acc: &mut Self::Acc); + /// Perform a squeeze operation + fn squeeze(&mut self, length: u32, acc: &mut Self::Acc) -> Vec; + /// Finish the sponge operation + fn finish(&mut self, _: &mut Self::Acc) -> Result<(), Error>; +} + +pub trait InnerSpongeAPI> { + type Acc; + type Value; + + fn initialize_capacity(&mut self, tag: u128, acc: &mut Self::Acc); + fn read_rate_element(&mut self, offset: usize) -> Self::Value; + fn add_rate_element(&mut self, offset: usize, x: &Self::Value); + fn permute(&mut self, acc: &mut Self::Acc); + + // Supplemental methods needed for a generic implementation. 
+ fn rate(&self) -> usize; + fn absorb_pos(&self) -> usize; + fn squeeze_pos(&self) -> usize; + fn set_absorb_pos(&mut self, pos: usize); + fn set_squeeze_pos(&mut self, pos: usize); + + fn add(a: Self::Value, b: &Self::Value) -> Self::Value; + + fn initialize_state(&mut self, p_value: u128, acc: &mut Self::Acc) { + self.initialize_capacity(p_value, acc); + + for i in 0..self.rate() { + self.add_rate_element(i, &Self::zero()); + } + } + + fn pattern(&self) -> &IOPattern; + fn set_pattern(&mut self, pattern: IOPattern); + + fn increment_io_count(&mut self) -> usize; + + fn zero() -> Self::Value; +} + +impl, S: InnerSpongeAPI> SpongeAPI for S { + type Acc = >::Acc; + type Value = >::Value; + + fn start(&mut self, p: IOPattern, domain_separator: Option, acc: &mut Self::Acc) { + let p_value = p.value(domain_separator.unwrap_or(0)); + + self.set_pattern(p); + self.initialize_state(p_value, acc); + + self.set_absorb_pos(0); + self.set_squeeze_pos(0); + } + + fn absorb(&mut self, length: u32, elements: &[Self::Value], acc: &mut Self::Acc) { + assert_eq!(length as usize, elements.len()); + let rate = self.rate(); + + for element in elements.iter() { + if self.absorb_pos() == rate { + self.permute(acc); + self.set_absorb_pos(0); + } + let old = self.read_rate_element(self.absorb_pos()); + self.add_rate_element(self.absorb_pos(), &S::add(old, element)); + self.set_absorb_pos(self.absorb_pos() + 1); + } + let op = SpongeOp::Absorb(length); + let old_count = self.increment_io_count(); + assert_eq!(Some(&op), self.pattern().op_at(old_count)); + + self.set_squeeze_pos(rate); + } + + fn squeeze(&mut self, length: u32, acc: &mut Self::Acc) -> Vec { + let rate = self.rate(); + + let mut out = Vec::with_capacity(length as usize); + + for _ in 0..length { + if self.squeeze_pos() == rate { + self.permute(acc); + self.set_squeeze_pos(0); + self.set_absorb_pos(0); + } + out.push(self.read_rate_element(self.squeeze_pos())); + self.set_squeeze_pos(self.squeeze_pos() + 1); + } + let op = SpongeOp::Squeeze(length); + let old_count = self.increment_io_count(); + assert_eq!(Some(&op), self.pattern().op_at(old_count)); + + out + } + + fn finish(&mut self, acc: &mut Self::Acc) -> Result<(), Error> { + // Clear state. 
+ self.initialize_state(0, acc); + let final_io_count = self.increment_io_count(); + + if final_io_count == self.pattern().0.len() { + Ok(()) + } else { + Err(Error::ParameterUsageMismatch) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_tag_values() { + let test = |expected_value: u128, pattern: IOPattern, domain_separator: u32| { + assert_eq!(expected_value, pattern.value(domain_separator)); + }; + + test(0, IOPattern(vec![]), 0); + test( + 340282366920938463463374607431768191899, + IOPattern(vec![]), + 123, + ); + test( + 340282366920938463463374607090318361668, + IOPattern(vec![SpongeOp::Absorb(2), SpongeOp::Squeeze(2)]), + 0, + ); + test( + 340282366920938463463374607090314341989, + IOPattern(vec![SpongeOp::Absorb(2), SpongeOp::Squeeze(2)]), + 1, + ); + test( + 340282366920938463463374607090318361668, + IOPattern(vec![SpongeOp::Absorb(2), SpongeOp::Squeeze(2)]), + 0, + ); + test( + 340282366920938463463374607090318361668, + IOPattern(vec![ + SpongeOp::Absorb(1), + SpongeOp::Absorb(1), + SpongeOp::Squeeze(2), + ]), + 0, + ); + test( + 340282366920938463463374607090318361668, + IOPattern(vec![ + SpongeOp::Absorb(1), + SpongeOp::Absorb(1), + SpongeOp::Squeeze(1), + SpongeOp::Squeeze(1), + ]), + 0, + ); + } +} diff --git a/src/provider/poseidon/sponge/circuit.rs b/src/provider/poseidon/sponge/circuit.rs new file mode 100644 index 00000000..36023fda --- /dev/null +++ b/src/provider/poseidon/sponge/circuit.rs @@ -0,0 +1,249 @@ +use crate::frontend::num::AllocatedNum; +use crate::frontend::util_cs::witness_cs::SizedWitness; +use crate::frontend::{ConstraintSystem, Namespace, SynthesisError}; +use crate::provider::poseidon::circuit2::{Elt, PoseidonCircuit2}; +use crate::provider::poseidon::poseidon_inner::{Arity, Poseidon, PoseidonConstants}; +use crate::provider::poseidon::sponge::{ + api::{IOPattern, InnerSpongeAPI}, + vanilla::{Direction, Mode, SpongeTrait}, +}; + +use ff::PrimeField; +use std::collections::VecDeque; +use std::marker::PhantomData; + +/// The Poseidon sponge circuit +pub struct SpongeCircuit<'a, F, A, C> +where + F: PrimeField, + A: Arity, + C: ConstraintSystem, +{ + constants: &'a PoseidonConstants, + mode: Mode, + direction: Direction, + absorbed: usize, + squeezed: usize, + squeeze_pos: usize, + permutation_count: usize, + state: PoseidonCircuit2<'a, F, A>, + queue: VecDeque>, + pattern: IOPattern, + io_count: usize, + poseidon: Poseidon<'a, F, A>, + _c: PhantomData, +} + +impl<'a, F: PrimeField, A: Arity, CS: 'a + ConstraintSystem> SpongeTrait<'a, F, A> + for SpongeCircuit<'a, F, A, CS> +{ + type Acc = Namespace<'a, F, CS>; + type Elt = Elt; + type Error = SynthesisError; + + fn new_with_constants(constants: &'a PoseidonConstants, mode: Mode) -> Self { + Self { + mode, + direction: Direction::Absorbing, + constants, + absorbed: 0, + squeezed: 0, + squeeze_pos: 0, + permutation_count: 0, + state: PoseidonCircuit2::new_empty::(constants), + queue: VecDeque::with_capacity(A::to_usize()), + pattern: IOPattern(Vec::new()), + poseidon: Poseidon::new(constants), + io_count: 0, + _c: Default::default(), + } + } + + fn mode(&self) -> Mode { + self.mode + } + fn direction(&self) -> Direction { + self.direction + } + fn set_direction(&mut self, direction: Direction) { + self.direction = direction; + } + fn absorbed(&self) -> usize { + self.absorbed + } + fn set_absorbed(&mut self, absorbed: usize) { + self.absorbed = absorbed; + } + fn squeezed(&self) -> usize { + self.squeezed + } + fn set_squeezed(&mut self, squeezed: usize) { + self.squeezed = 
squeezed; + } + fn squeeze_pos(&self) -> usize { + self.squeeze_pos + } + fn set_squeeze_pos(&mut self, squeeze_pos: usize) { + self.squeeze_pos = squeeze_pos; + } + fn absorb_pos(&self) -> usize { + self.state.pos - 1 + } + fn set_absorb_pos(&mut self, pos: usize) { + self.state.pos = pos + 1; + } + + fn element(&self, index: usize) -> Self::Elt { + self.state.elements[index].clone() + } + + fn set_element(&mut self, index: usize, elt: Self::Elt) { + self.poseidon.elements[index] = elt.val().unwrap(); + self.state.elements[index] = elt; + } + + fn make_elt(&self, val: F, ns: &mut Self::Acc) -> Self::Elt { + let allocated = AllocatedNum::alloc_infallible(ns, || val); + Elt::Allocated(allocated) + } + + fn rate(&self) -> usize { + A::to_usize() + } + + fn capacity(&self) -> usize { + 1 + } + + fn size(&self) -> usize { + self.constants.width() + } + + fn constants(&self) -> &PoseidonConstants { + self.constants + } + + fn pad(&mut self) { + self.state.apply_padding::(); + } + + fn permute_state(&mut self, ns: &mut Self::Acc) -> Result<(), Self::Error> { + self.permutation_count += 1; + + if ns.is_witness_generator() { + self.poseidon.generate_witness_into_cs(ns); + + for (elt, scalar) in self + .state + .elements + .iter_mut() + .zip(self.poseidon.elements.iter()) + { + *elt = Elt::num_from_fr::(*scalar); + } + } else { + self + .state + .hash(&mut ns.namespace(|| format!("permutation {}", self.permutation_count)))?; + }; + + Ok(()) + } + + fn enqueue(&mut self, elt: Self::Elt) { + self.queue.push_back(elt); + } + fn dequeue(&mut self) -> Option { + self.queue.pop_front() + } + + fn squeeze_aux(&mut self) -> Self::Elt { + let squeezed = self.element(SpongeTrait::squeeze_pos(self) + SpongeTrait::capacity(self)); + SpongeTrait::set_squeeze_pos(self, SpongeTrait::squeeze_pos(self) + 1); + + squeezed + } + + fn absorb_aux(&mut self, elt: &Self::Elt) -> Self::Elt { + // Elt::add always returns `Ok`, so `unwrap` is safe. + self + .element(SpongeTrait::absorb_pos(self) + SpongeTrait::capacity(self)) + .add_ref(elt) + .unwrap() + } + + fn squeeze_elements(&mut self, count: usize, ns: &mut Self::Acc) -> Vec { + let mut elements = Vec::with_capacity(count); + for _ in 0..count { + if let Ok(Some(squeezed)) = self.squeeze(ns) { + elements.push(squeezed); + } + } + elements + } +} + +impl<'a, F: PrimeField, A: Arity, CS: 'a + ConstraintSystem> InnerSpongeAPI + for SpongeCircuit<'a, F, A, CS> +{ + type Acc = Namespace<'a, F, CS>; + type Value = Elt; + + fn initialize_capacity(&mut self, tag: u128, _acc: &mut Self::Acc) { + let mut repr = F::Repr::default(); + repr.as_mut()[..16].copy_from_slice(&tag.to_le_bytes()); + + let f = F::from_repr(repr).unwrap(); + let elt = Elt::num_from_fr::(f); + self.set_element(0, elt); + } + + fn read_rate_element(&mut self, offset: usize) -> Self::Value { + self.element(offset + SpongeTrait::capacity(self)) + } + fn add_rate_element(&mut self, offset: usize, x: &Self::Value) { + self.set_element(offset + SpongeTrait::capacity(self), x.clone()); + } + fn permute(&mut self, acc: &mut Self::Acc) { + SpongeTrait::permute(self, acc).unwrap(); + } + + // Supplemental methods needed for a generic implementation. 
+ + fn zero() -> Elt { + Elt::num_from_fr::(F::ZERO) + } + + fn rate(&self) -> usize { + SpongeTrait::rate(self) + } + fn absorb_pos(&self) -> usize { + SpongeTrait::absorb_pos(self) + } + fn squeeze_pos(&self) -> usize { + SpongeTrait::squeeze_pos(self) + } + fn set_absorb_pos(&mut self, pos: usize) { + SpongeTrait::set_absorb_pos(self, pos); + } + fn set_squeeze_pos(&mut self, pos: usize) { + SpongeTrait::set_squeeze_pos(self, pos); + } + fn add(a: Elt, b: &Elt) -> Elt { + a.add_ref(b).unwrap() + } + + fn pattern(&self) -> &IOPattern { + &self.pattern + } + + fn set_pattern(&mut self, pattern: IOPattern) { + self.pattern = pattern + } + + fn increment_io_count(&mut self) -> usize { + let old_count = self.io_count; + self.io_count += 1; + old_count + } +} diff --git a/src/provider/poseidon/sponge/mod.rs b/src/provider/poseidon/sponge/mod.rs new file mode 100644 index 00000000..6e83c524 --- /dev/null +++ b/src/provider/poseidon/sponge/mod.rs @@ -0,0 +1,3 @@ +pub mod api; +pub mod circuit; +pub mod vanilla; diff --git a/src/provider/poseidon/sponge/vanilla.rs b/src/provider/poseidon/sponge/vanilla.rs new file mode 100644 index 00000000..bda984cd --- /dev/null +++ b/src/provider/poseidon/sponge/vanilla.rs @@ -0,0 +1,551 @@ +use crate::provider::poseidon::hash_type::HashType; +use crate::provider::poseidon::poseidon_inner::{Arity, Poseidon, PoseidonConstants}; +use crate::provider::poseidon::sponge::api::{IOPattern, InnerSpongeAPI}; +use crate::provider::poseidon::{PoseidonError, Strength}; +use ff::PrimeField; +use std::collections::VecDeque; + +// General information on sponge construction: https://keccak.team/files/CSF-0.1.pdf + +/* +A sponge can be instantiated in either simplex or duplex mode. Once instantiated, a sponge's mode never changes. + +At any time, a sponge is operating in one of two directions: squeezing or absorbing. All sponges are initialized in the +absorbing direction. The number of absorbed field elements is incremented each time an element is absorbed and +decremented each time an element is squeezed. In duplex mode, the count of currently absorbed elements can never +decrease below zero, so only as many elements as have been absorbed can be squeezed at any time. In simplex mode, there +is no limit on the number of elements that can be squeezed, once absorption is complete. + +In simplex mode, absorbing and squeezing cannot be interleaved. First all elements are absorbed, then all needed +elements are squeezed. At most the number of elements which were absorbed can be squeezed. Elements must be absorbed in +chunks of R (rate). After every R chunks have been absorbed, the state is permuted. After the final element has been +absorbed, any needed padding is added, and the final permutation (or two -- if required by padding) is performed. Then +groups of R field elements are squeezed, and the state is permuted after each group of R elements has been squeezed. +After squeezing is complete, a simplex sponge is exhausted, and no further absorption is possible. + +In duplex mode, absorbing and squeezing can be interleaved. The state is permuted after every R elements have been +absorbed. This makes R elements available to be squeezed. If elements remain to be squeezed when the state is permuted, +remaining unsqueezed elements are queued. Otherwise they would be lost when permuting. 
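As an illustrative sketch (an editorial aside, not part of this patch), here is how the
vanilla sponge below can be driven in simplex mode through the SpongeAPI layer defined in
`sponge::api`. `F` is any `ff::PrimeField` implementor, `a` and `b` are assumed inputs, and
the accumulator type used by the vanilla sponge's SpongeAPI implementation is assumed to
be `()`:

  use crate::provider::poseidon::sponge::api::{IOPattern, SpongeAPI, SpongeOp};
  use crate::provider::poseidon::sponge::vanilla::{Mode, Sponge, SpongeTrait};
  use crate::provider::poseidon::Strength;
  use generic_array::typenum::U2;

  fn sponge_demo<F: ff::PrimeField>(a: F, b: F) -> F {
    // HashType::Sponge constants sized for the SpongeAPI (Strength::Standard assumed).
    let constants = Sponge::<F, U2>::api_constants(Strength::Standard);
    let mut sponge = Sponge::new_with_constants(&constants, Mode::Simplex);
    let acc = &mut ();

    // Declare the IO pattern up front, then follow it exactly.
    sponge.start(IOPattern(vec![SpongeOp::Absorb(2), SpongeOp::Squeeze(1)]), None, acc);
    SpongeAPI::absorb(&mut sponge, 2, &[a, b], acc);
    let out = SpongeAPI::squeeze(&mut sponge, 1, acc);
    sponge.finish(acc).unwrap();
    out[0]
  }

The absorb and squeeze calls must match the declared IOPattern (each call is checked
against the pattern as it proceeds), and `finish` returns `Error::ParameterUsageMismatch`
if the number of operations performed differs from the pattern's length.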
+ +*/ + +pub enum SpongeMode { + SimplexAbsorb, + SimplexSqueeze, + DuplexAbsorb, + DuplexSqueeze, +} + +/// Mode of the sponge +#[derive(Clone, Copy)] +pub enum Mode { + /// Simplex mode + Simplex, + /// Duplex mode + Duplex, +} + +#[derive(Clone, Copy)] +pub enum Direction { + Absorbing, + Squeezing, +} + +/// Poseidon sponge +pub struct Sponge<'a, F: PrimeField, A: Arity> { + absorbed: usize, + squeezed: usize, + /// Poseidon state + pub state: Poseidon<'a, F, A>, + mode: Mode, + direction: Direction, + squeeze_pos: usize, + queue: VecDeque, + pattern: IOPattern, + io_count: usize, +} + +/// Sponge trait +pub trait SpongeTrait<'a, F: PrimeField, A: Arity> +where + Self: Sized, +{ + /// Accumulator type + type Acc; + /// Value type + type Elt; + /// Error type + type Error; + + /// Create a new sponge with the given constants and mode + fn new_with_constants(constants: &'a PoseidonConstants, mode: Mode) -> Self; + + /// Return simplex constants + fn simplex_constants(size: usize) -> PoseidonConstants { + PoseidonConstants::new_constant_length(size) + } + + /// Return duplex constants + fn duplex_constants() -> PoseidonConstants { + PoseidonConstants::new_constant_length(0) + } + + /// Return API constants + fn api_constants(strength: Strength) -> PoseidonConstants { + PoseidonConstants::new_with_strength_and_type(strength, HashType::Sponge) + } + + /// Return the mode of the sponge + fn mode(&self) -> Mode; + /// Return the direction of the sponge + fn direction(&self) -> Direction; + /// Set the direction of the sponge + fn set_direction(&mut self, direction: Direction); + /// Return the number of absorbed elements + fn absorbed(&self) -> usize; + /// Set the number of absorbed elements + fn set_absorbed(&mut self, absorbed: usize); + /// Return the number of squeezed elements + fn squeezed(&self) -> usize; + /// Set the number of squeezed elements + fn set_squeezed(&mut self, squeezed: usize); + /// Return the squeeze position + fn squeeze_pos(&self) -> usize; + /// Set the squeeze position + fn set_squeeze_pos(&mut self, squeeze_pos: usize); + /// Return the absorb position + fn absorb_pos(&self) -> usize; + /// Set the absorb position + fn set_absorb_pos(&mut self, pos: usize); + + /// Return the element at the given index + fn element(&self, index: usize) -> Self::Elt; + /// Set the element at the given index + fn set_element(&mut self, index: usize, elt: Self::Elt); + + /// Make elt + #[deprecated(since = "0.10.0")] + fn make_elt(&self, val: F, acc: &mut Self::Acc) -> Self::Elt; + + /// Return whether the sponge is in simplex mode + fn is_simplex(&self) -> bool { + match self.mode() { + Mode::Simplex => true, + Mode::Duplex => false, + } + } + /// Return whether the sponge is in duplex mode + fn is_duplex(&self) -> bool { + match self.mode() { + Mode::Duplex => true, + Mode::Simplex => false, + } + } + + /// Return whether the sponge is absorbing + fn is_absorbing(&self) -> bool { + match self.direction() { + Direction::Absorbing => true, + Direction::Squeezing => false, + } + } + + /// Return whether the sponge is squeezing + fn is_squeezing(&self) -> bool { + match self.direction() { + Direction::Squeezing => true, + Direction::Absorbing => false, + } + } + + /// Return the number of available elements + fn available(&self) -> usize { + self.absorbed() - self.squeezed() + } + + /// Return whether the sponge is immediately squeezable + fn is_immediately_squeezable(&self) -> bool { + self.squeeze_pos() < self.absorb_pos() + } + + /// Return the rate of the sponge + fn rate(&self) 
+
+  /// Return the capacity of the sponge
+  fn capacity(&self) -> usize;
+
+  /// Return the size of the sponge
+  fn size(&self) -> usize;
+
+  /// Return the total size of the sponge
+  fn total_size(&self) -> usize {
+    assert!(self.is_simplex());
+    match self.constants().hash_type {
+      HashType::ConstantLength(l) => l,
+      HashType::VariableLength => unimplemented!(),
+      _ => A::to_usize(),
+    }
+  }
+
+  /// Return the constants of the sponge
+  fn constants(&self) -> &PoseidonConstants<F, A>;
+
+  /// Return whether the sponge can squeeze without permuting
+  fn can_squeeze_without_permuting(&self) -> bool {
+    self.squeeze_pos() < self.size() - self.capacity()
+  }
+
+  /// Return whether the sponge is exhausted
+  fn is_exhausted(&self) -> bool {
+    // Exhaustion only applies to simplex.
+    self.is_simplex() && self.squeezed() >= self.total_size()
+  }
+
+  /// Ensure the sponge is absorbing
+  fn ensure_absorbing(&mut self) {
+    match self.direction() {
+      Direction::Absorbing => (),
+      Direction::Squeezing => {
+        if self.is_simplex() {
+          panic!("Simplex sponge cannot absorb after squeezing.");
+        } else {
+          self.set_direction(Direction::Absorbing);
+        }
+      }
+    }
+  }
+
+  /// Permute the sponge
+  fn permute(&mut self, acc: &mut Self::Acc) -> Result<(), Self::Error> {
+    // NOTE: this will apply any needed padding in the partially-absorbed case.
+    // However, padding should only be applied when no more elements will be absorbed.
+    // A duplex sponge should never apply padding implicitly, and a simplex sponge should only do so when it is
+    // about to apply its final permutation.
+    let unpermuted = self.absorb_pos();
+    let needs_padding = self.is_absorbing() && unpermuted < self.rate();
+
+    if needs_padding {
+      match self.mode() {
+        Mode::Duplex => {
+          panic!("Duplex sponge must permute exactly `rate` absorbed elements.")
+        }
+        Mode::Simplex => {
+          let final_permutation = self.squeezed() % self.total_size() <= self.rate();
+          assert!(
+            final_permutation,
+            "Simplex sponge may only pad before final permutation"
+          );
+          self.pad();
+        }
+      }
+    }
+
+    self.permute_state(acc)?;
+    self.set_absorb_pos(0);
+    self.set_squeeze_pos(0);
+    Ok(())
+  }
+
+  /// Apply any needed padding to the sponge state
+  fn pad(&mut self);
+
+  /// Permute the sponge state
+  fn permute_state(&mut self, acc: &mut Self::Acc) -> Result<(), Self::Error>;
+
+  /// Ensure the sponge is squeezing
+  fn ensure_squeezing(&mut self, acc: &mut Self::Acc) -> Result<(), Self::Error> {
+    match self.direction() {
+      Direction::Squeezing => (),
+      Direction::Absorbing => {
+        match self.mode() {
+          Mode::Simplex => {
+            let done_squeezing_previous = self.squeeze_pos() >= self.rate();
+            let partially_absorbed = self.absorb_pos() > 0;
+
+            if done_squeezing_previous || partially_absorbed {
+              self.permute(acc)?;
+            }
+          }
+          Mode::Duplex => (),
+        }
+        self.set_direction(Direction::Squeezing);
+      }
+    }
+    Ok(())
+  }
+
+  /// Squeeze a single element from the state (internal helper)
+  fn squeeze_aux(&mut self) -> Self::Elt;
+
+  /// Absorb a single element into the state (internal helper)
+  fn absorb_aux(&mut self, elt: &Self::Elt) -> Self::Elt;
+
+  /// Absorb one field element
+  fn absorb(&mut self, elt: &Self::Elt, acc: &mut Self::Acc) -> Result<(), Self::Error> {
+    self.ensure_absorbing();
+
+    // Add input element to state and advance absorption position.
+    let tmp = self.absorb_aux(elt);
+    self.set_element(self.absorb_pos() + self.capacity(), tmp);
+    self.set_absorb_pos(self.absorb_pos() + 1);
+
+    // When position equals size, we need to permute.
+    if self.absorb_pos() >= self.rate() {
+      if self.is_duplex() {
+        // When we permute, existing unsqueezed elements will be lost. Enqueue them.
+        while self.is_immediately_squeezable() {
+          let elt = self.squeeze_aux();
+          self.enqueue(elt);
+        }
+      }
+
+      self.permute(acc)?;
+    }
+
+    self.set_absorbed(self.absorbed() + 1);
+    Ok(())
+  }
+
+  /// Perform a squeeze operation
+  fn squeeze(&mut self, acc: &mut Self::Acc) -> Result<Option<Self::Elt>, Self::Error> {
+    self.ensure_squeezing(acc)?;
+
+    if self.is_duplex() && self.available() == 0 {
+      // What has not yet been absorbed cannot be squeezed.
+      return Ok(None);
+    };
+
+    self.set_squeezed(self.squeezed() + 1);
+
+    if let Some(queued) = self.dequeue() {
+      return Ok(Some(queued));
+    }
+
+    if !self.can_squeeze_without_permuting() && self.is_simplex() {
+      self.permute(acc)?;
+    }
+
+    let squeezed = self.squeeze_aux();
+
+    Ok(Some(squeezed))
+  }
+
+  /// Enqueue an element
+  fn enqueue(&mut self, elt: Self::Elt);
+
+  /// Dequeue an element
+  fn dequeue(&mut self) -> Option<Self::Elt>;
+
+  /// Absorb a slice of field elements
+  fn absorb_elements(
+    &mut self,
+    elts: &[Self::Elt],
+    acc: &mut Self::Acc,
+  ) -> Result<(), Self::Error> {
+    for elt in elts {
+      self.absorb(elt, acc)?;
+    }
+    Ok(())
+  }
+
+  /// Squeeze a number of elements
+  fn squeeze_elements(&mut self, count: usize, acc: &mut Self::Acc) -> Vec<Self::Elt>;
+}
+
+impl<'a, F: PrimeField, A: Arity<F>> SpongeTrait<'a, F, A> for Sponge<'a, F, A> {
+  type Acc = ();
+  type Elt = F;
+  type Error = PoseidonError;
+
+  fn new_with_constants(constants: &'a PoseidonConstants<F, A>, mode: Mode) -> Self {
+    let poseidon = Poseidon::new(constants);
+
+    Self {
+      mode,
+      direction: Direction::Absorbing,
+      state: poseidon,
+      absorbed: 0,
+      squeezed: 0,
+      squeeze_pos: 0,
+      queue: VecDeque::with_capacity(A::to_usize()),
+      pattern: IOPattern(Vec::new()),
+      io_count: 0,
+    }
+  }
+
+  fn mode(&self) -> Mode {
+    self.mode
+  }
+  fn direction(&self) -> Direction {
+    self.direction
+  }
+  fn set_direction(&mut self, direction: Direction) {
+    self.direction = direction;
+  }
+  fn absorbed(&self) -> usize {
+    self.absorbed
+  }
+  fn set_absorbed(&mut self, absorbed: usize) {
+    self.absorbed = absorbed;
+  }
+  fn squeezed(&self) -> usize {
+    self.squeezed
+  }
+  fn set_squeezed(&mut self, squeezed: usize) {
+    self.squeezed = squeezed;
+  }
+  fn squeeze_pos(&self) -> usize {
+    self.squeeze_pos
+  }
+  fn set_squeeze_pos(&mut self, squeeze_pos: usize) {
+    self.squeeze_pos = squeeze_pos;
+  }
+  fn absorb_pos(&self) -> usize {
+    self.state.pos - 1
+  }
+  fn set_absorb_pos(&mut self, pos: usize) {
+    self.state.pos = pos + 1;
+  }
+
+  fn element(&self, index: usize) -> Self::Elt {
+    self.state.elements[index]
+  }
+  fn set_element(&mut self, index: usize, elt: Self::Elt) {
+    self.state.elements[index] = elt;
+  }
+
+  fn make_elt(&self, val: F, _acc: &mut Self::Acc) -> Self::Elt {
+    val
+  }
+
+  fn rate(&self) -> usize {
+    A::to_usize()
+  }
+
+  fn capacity(&self) -> usize {
+    1
+  }
+
+  fn size(&self) -> usize {
+    self.state.constants.width()
+  }
+
+  fn constants(&self) -> &PoseidonConstants<F, A> {
+    self.state.constants
+  }
+
+  fn pad(&mut self) {
+    self.state.apply_padding();
+  }
+
+  fn permute_state(&mut self, _acc: &mut Self::Acc) -> Result<(), Self::Error> {
+    self.state.hash();
+    Ok(())
+  }
+
+  fn enqueue(&mut self, elt: Self::Elt) {
+    self.queue.push_back(elt);
+  }
+  fn dequeue(&mut self) -> Option<Self::Elt> {
+    self.queue.pop_front()
+  }
+
+  fn squeeze_aux(&mut self) -> Self::Elt {
+    let squeezed = self.element(SpongeTrait::squeeze_pos(self) + SpongeTrait::capacity(self));
+    SpongeTrait::set_squeeze_pos(self, SpongeTrait::squeeze_pos(self) + 1);
+
+    squeezed
+  }
+
+  fn absorb_aux(&mut self, elt: &Self::Elt) -> Self::Elt {
+    self.element(SpongeTrait::absorb_pos(self) + SpongeTrait::capacity(self)) + elt
+  }
+
+  fn absorb_elements(&mut self, elts: &[F], acc: &mut Self::Acc) -> Result<(), Self::Error> {
+    for elt in elts {
+      self.absorb(elt, acc)?;
+    }
+    Ok(())
+  }
+
+  fn squeeze_elements(&mut self, count: usize, _acc: &mut ()) -> Vec<F> {
+    self.take(count).collect()
+  }
+}
+
+impl<F: PrimeField, A: Arity<F>> Iterator for Sponge<'_, F, A> {
+  type Item = F;
+
+  fn next(&mut self) -> Option<F> {
+    self.squeeze(&mut ()).unwrap_or(None)
+  }
+
+  fn size_hint(&self) -> (usize, Option<usize>) {
+    match self.mode {
+      Mode::Duplex => (self.available(), None),
+      Mode::Simplex => (0, None),
+    }
+  }
+}
+
+impl<F: PrimeField, A: Arity<F>> InnerSpongeAPI<F, A> for Sponge<'_, F, A> {
+  type Acc = ();
+  type Value = F;
+
+  fn initialize_capacity(&mut self, tag: u128, _: &mut ()) {
+    let mut repr = F::Repr::default();
+    repr.as_mut()[..16].copy_from_slice(&tag.to_le_bytes());
+
+    let f = F::from_repr(repr).unwrap();
+    self.set_element(0, f);
+  }
+
+  fn read_rate_element(&mut self, offset: usize) -> F {
+    self.element(offset + SpongeTrait::capacity(self))
+  }
+  fn add_rate_element(&mut self, offset: usize, x: &F) {
+    self.set_element(offset + SpongeTrait::capacity(self), *x);
+  }
+  fn permute(&mut self, acc: &mut ()) {
+    SpongeTrait::permute(self, acc).unwrap();
+  }
+
+  // Supplemental methods needed for a generic implementation.
+
+  fn zero() -> F {
+    F::ZERO
+  }
+
+  fn rate(&self) -> usize {
+    SpongeTrait::rate(self)
+  }
+  fn absorb_pos(&self) -> usize {
+    SpongeTrait::absorb_pos(self)
+  }
+  fn squeeze_pos(&self) -> usize {
+    SpongeTrait::squeeze_pos(self)
+  }
+  fn set_absorb_pos(&mut self, pos: usize) {
+    SpongeTrait::set_absorb_pos(self, pos);
+  }
+  fn set_squeeze_pos(&mut self, pos: usize) {
+    SpongeTrait::set_squeeze_pos(self, pos);
+  }
+  fn add(a: F, b: &F) -> F {
+    a + b
+  }
+
+  fn pattern(&self) -> &IOPattern {
+    &self.pattern
+  }
+
+  fn set_pattern(&mut self, pattern: IOPattern) {
+    self.pattern = pattern
+  }
+
+  fn increment_io_count(&mut self) -> usize {
+    let old_count = self.io_count;
+    self.io_count += 1;
+    old_count
+  }
+}
diff --git a/src/spartan/direct.rs b/src/spartan/direct.rs
index 7cc2cf30..fef5a062 100644
--- a/src/spartan/direct.rs
+++ b/src/spartan/direct.rs
@@ -1,13 +1,14 @@
 //! This module provides interfaces to directly prove a step circuit by using Spartan SNARK.
 //! In particular, it supports any SNARK that implements `RelaxedR1CSSNARK` trait
 //! (e.g., with the SNARKs implemented in ppsnark.rs or snark.rs).
+use crate::frontend::{num::AllocatedNum, Circuit, ConstraintSystem, SynthesisError};
 use crate::{
-  bellpepper::{
+  errors::NovaError,
+  frontend::{
     r1cs::{NovaShape, NovaWitness},
     shape_cs::ShapeCS,
     solver::SatisfyingAssignment,
   },
-  errors::NovaError,
   r1cs::{R1CSShape, RelaxedR1CSInstance, RelaxedR1CSWitness},
   traits::{
     circuit::StepCircuit,
@@ -17,7 +18,6 @@ use crate::{
   },
   Commitment, CommitmentKey, DerandKey,
 };
-use bellpepper_core::{num::AllocatedNum, Circuit, ConstraintSystem, SynthesisError};
 use core::marker::PhantomData;
 use ff::Field;
 use serde::{Deserialize, Serialize};
@@ -184,8 +184,8 @@ impl<E: Engine, S: RelaxedR1CSSNARK<E>, C: StepCircuit<E::Scalar>> DirectSN
 #[cfg(test)]
 mod tests {
   use super::*;
+  use crate::frontend::{num::AllocatedNum, ConstraintSystem, SynthesisError};
   use crate::provider::{Bn256EngineKZG, PallasEngine, Secp256k1Engine};
-  use ::bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError};
   use core::marker::PhantomData;
   use ff::PrimeField;
diff --git a/src/traits/circuit.rs b/src/traits/circuit.rs
index e7dfa61e..4f8a2da0 100644
--- a/src/traits/circuit.rs
+++ b/src/traits/circuit.rs
@@ -1,5 +1,5 @@
 //! This module defines traits that a step function must implement
-use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError};
+use crate::frontend::{num::AllocatedNum, ConstraintSystem, SynthesisError};
 use core::marker::PhantomData;
 use ff::PrimeField;
diff --git a/src/traits/mod.rs b/src/traits/mod.rs
index f26572ee..b9dc7eec 100644
--- a/src/traits/mod.rs
+++ b/src/traits/mod.rs
@@ -1,6 +1,6 @@
 //! This module defines various traits required by the users of the library to implement.
 use crate::errors::NovaError;
-use bellpepper_core::{boolean::AllocatedBit, num::AllocatedNum, ConstraintSystem, SynthesisError};
+use crate::frontend::{boolean::AllocatedBit, num::AllocatedNum, ConstraintSystem, SynthesisError};
 use core::fmt::Debug;
 use ff::{PrimeField, PrimeFieldBits};
 use num_bigint::BigInt;
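Editor's note (not part of the patch): the vendored vanilla sponge above is driven entirely through `SpongeTrait` -- absorb first, then squeeze, with `Acc = ()`. The sketch below is illustrative only; it assumes the module paths introduced in this diff (`crate::provider::poseidon::sponge::vanilla` and `crate::provider::poseidon::poseidon_inner`) are reachable from where it is compiled, and that `constants` was built for exactly `inputs.len()` elements (e.g. via `simplex_constants`). The hypothetical helper `simplex_hash` is not part of the diff.

// Illustrative sketch: drive the vanilla (non-circuit) sponge in simplex mode.
use ff::PrimeField;

use crate::provider::poseidon::poseidon_inner::{Arity, PoseidonConstants};
use crate::provider::poseidon::sponge::vanilla::{Mode, Sponge, SpongeTrait};

/// Absorb `inputs` into a fresh simplex sponge and squeeze one digest element.
/// `constants` is assumed to be simplex constants sized for `inputs.len()` elements.
fn simplex_hash<F: PrimeField, A: Arity<F>>(
  constants: &PoseidonConstants<F, A>,
  inputs: &[F],
) -> F {
  let mut sponge = Sponge::new_with_constants(constants, Mode::Simplex);
  let acc = &mut (); // the vanilla sponge threads no accumulator (Acc = ())

  sponge
    .absorb_elements(inputs, acc)
    .expect("absorbing should not fail for the vanilla sponge");
  sponge
    .squeeze(acc)
    .expect("squeezing should not fail for the vanilla sponge")
    .expect("a simplex sponge yields an element once absorption is complete")
}

In duplex mode the same calls can be interleaved; `squeeze` then returns `Ok(None)` once everything absorbed so far has been squeezed, matching the queueing behaviour described in the header comment of vanilla.rs.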