From ccbd64df76c573432699ded42580cfef9b6591ca Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Tue, 2 Apr 2024 18:46:13 +0200 Subject: [PATCH] feat(tfhe): add zk-pok code base - integration of work done by Sarah in the repo Co-authored-by: sarah el kazdadi --- .github/workflows/aws_tfhe_fast_tests.yml | 4 + .github/workflows/aws_tfhe_tests.yml | 4 + .github/workflows/m1_tests.yml | 4 + Cargo.toml | 1 + Makefile | 12 +- tfhe-zk-pok/.gitignore | 1 + tfhe-zk-pok/Cargo.toml | 21 + tfhe-zk-pok/src/curve_446/mod.rs | 772 +++++++++++++++ tfhe-zk-pok/src/curve_api.rs | 341 +++++++ tfhe-zk-pok/src/curve_api/bls12_381.rs | 776 +++++++++++++++ tfhe-zk-pok/src/curve_api/bls12_446.rs | 970 +++++++++++++++++++ tfhe-zk-pok/src/lib.rs | 3 + tfhe-zk-pok/src/proofs/binary.rs | 219 +++++ tfhe-zk-pok/src/proofs/index.rs | 127 +++ tfhe-zk-pok/src/proofs/mod.rs | 95 ++ tfhe-zk-pok/src/proofs/pke.rs | 1043 +++++++++++++++++++++ tfhe-zk-pok/src/proofs/range.rs | 360 +++++++ tfhe-zk-pok/src/proofs/rlwe.rs | 936 ++++++++++++++++++ 18 files changed, 5688 insertions(+), 1 deletion(-) create mode 100644 tfhe-zk-pok/.gitignore create mode 100644 tfhe-zk-pok/Cargo.toml create mode 100644 tfhe-zk-pok/src/curve_446/mod.rs create mode 100644 tfhe-zk-pok/src/curve_api.rs create mode 100644 tfhe-zk-pok/src/curve_api/bls12_381.rs create mode 100644 tfhe-zk-pok/src/curve_api/bls12_446.rs create mode 100644 tfhe-zk-pok/src/lib.rs create mode 100644 tfhe-zk-pok/src/proofs/binary.rs create mode 100644 tfhe-zk-pok/src/proofs/index.rs create mode 100644 tfhe-zk-pok/src/proofs/mod.rs create mode 100644 tfhe-zk-pok/src/proofs/pke.rs create mode 100644 tfhe-zk-pok/src/proofs/range.rs create mode 100644 tfhe-zk-pok/src/proofs/rlwe.rs diff --git a/.github/workflows/aws_tfhe_fast_tests.yml b/.github/workflows/aws_tfhe_fast_tests.yml index 1b0de70e16..04f277f2b4 100644 --- a/.github/workflows/aws_tfhe_fast_tests.yml +++ b/.github/workflows/aws_tfhe_fast_tests.yml @@ -60,6 +60,10 @@ jobs: run: | make test_concrete_csprng + - name: Run tfhe-zk-pok tests + run: | + make test_zk_pok + - name: Run core tests run: | AVX512_SUPPORT=ON make test_core_crypto diff --git a/.github/workflows/aws_tfhe_tests.yml b/.github/workflows/aws_tfhe_tests.yml index bb0c6e88cc..ed62543e89 100644 --- a/.github/workflows/aws_tfhe_tests.yml +++ b/.github/workflows/aws_tfhe_tests.yml @@ -61,6 +61,10 @@ jobs: run: | make test_concrete_csprng + - name: Run tfhe-zk-pok tests + run: | + make test_zk_pok + - name: Run core tests run: | AVX512_SUPPORT=ON make test_core_crypto diff --git a/.github/workflows/m1_tests.yml b/.github/workflows/m1_tests.yml index fab631a89e..b3ee586c8e 100644 --- a/.github/workflows/m1_tests.yml +++ b/.github/workflows/m1_tests.yml @@ -74,6 +74,10 @@ jobs: run: | make test_concrete_csprng + - name: Run tfhe-zk-pok tests + run: | + make test_zk_pok + - name: Run core tests run: | make test_core_crypto diff --git a/Cargo.toml b/Cargo.toml index b569ef857e..1fdc61b275 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ resolver = "2" members = [ "tfhe", + "tfhe-zk-pok", "tasks", "apps/trivium", "concrete-csprng", diff --git a/Makefile b/Makefile index 01f4c13e30..86e2a697dc 100644 --- a/Makefile +++ b/Makefile @@ -283,9 +283,14 @@ clippy_concrete_csprng: --features=$(TARGET_ARCH_FEATURE) \ -p concrete-csprng -- --no-deps -D warnings +.PHONY: clippy_zk_pok # Run clippy lints on tfhe-zk-pok +clippy_zk_pok: + RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \ + -p tfhe-zk-pok -- --no-deps -D warnings + .PHONY: clippy_all # Run all clippy targets clippy_all: clippy clippy_boolean clippy_shortint clippy_integer clippy_all_targets clippy_c_api \ -clippy_js_wasm_api clippy_tasks clippy_core clippy_concrete_csprng clippy_trivium +clippy_js_wasm_api clippy_tasks clippy_core clippy_concrete_csprng clippy_zk_pok clippy_trivium .PHONY: clippy_fast # Run main clippy targets clippy_fast: clippy clippy_all_targets clippy_c_api clippy_js_wasm_api clippy_tasks clippy_core \ @@ -627,6 +632,11 @@ test_concrete_csprng: RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ --features=$(TARGET_ARCH_FEATURE) -p concrete-csprng +.PHONY: test_zk_pok # Run tfhe-zk-pok tests +test_zk_pok: + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ + -p tfhe-zk-pok + .PHONY: doc # Build rust doc doc: install_rs_check_toolchain @# Even though we are not in docs.rs, this allows to "just" build the doc diff --git a/tfhe-zk-pok/.gitignore b/tfhe-zk-pok/.gitignore new file mode 100644 index 0000000000..ea8c4bf7f3 --- /dev/null +++ b/tfhe-zk-pok/.gitignore @@ -0,0 +1 @@ +/target diff --git a/tfhe-zk-pok/Cargo.toml b/tfhe-zk-pok/Cargo.toml new file mode 100644 index 0000000000..b1b0abc9cb --- /dev/null +++ b/tfhe-zk-pok/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tfhe-zk-pok" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ark-bls12-381 = "0.4.0" +ark-ec = "0.4.2" +ark-ff = "0.4.2" +ark-poly = "0.4.2" +rand = "0.8.5" +rayon = "1.8.0" +sha3 = "0.10.8" +serde = { version = "~1.0", features = ["derive"] } +ark-serialize = { version = "0.4.2" } +zeroize = "1.7.0" + +[dev-dependencies] +serde_json = "~1.0" diff --git a/tfhe-zk-pok/src/curve_446/mod.rs b/tfhe-zk-pok/src/curve_446/mod.rs new file mode 100644 index 0000000000..96fcecb9d0 --- /dev/null +++ b/tfhe-zk-pok/src/curve_446/mod.rs @@ -0,0 +1,772 @@ +use ark_ec::bls12::{Bls12, Bls12Config, TwistType}; +use ark_ff::fields::*; +use ark_ff::MontFp; + +#[derive(MontConfig)] +#[modulus = "645383785691237230677916041525710377746967055506026847120930304831624105190538527824412673"] +#[generator = "7"] +#[small_subgroup_base = "3"] +#[small_subgroup_power = "1"] +pub struct FrConfig; +pub type Fr = Fp320>; + +#[derive(MontConfig)] +#[modulus = "172824703542857155980071276579495962243492693522789898437834836356385656662277472896902502740297183690175962001546428467344062165330603"] +#[generator = "2"] +#[small_subgroup_base = "3"] +#[small_subgroup_power = "1"] +pub struct FqConfig; +pub type Fq = Fp448>; + +pub type Fq2 = Fp2; + +pub struct Fq2Config; + +impl Fp2Config for Fq2Config { + type Fp = Fq; + + /// NONRESIDUE = -1 + const NONRESIDUE: Fq = MontFp!("-1"); + + /// Coefficients for the Frobenius automorphism. + const FROBENIUS_COEFF_FP2_C1: &'static [Fq] = &[ + // Fq(-1)**(((q^0) - 1) / 2) + Fq::ONE, + // Fq(-1)**(((q^1) - 1) / 2) + MontFp!("-1"), + ]; + + #[inline(always)] + fn mul_fp_by_nonresidue_in_place(fp: &mut Self::Fp) -> &mut Self::Fp { + fp.neg_in_place() + } + + #[inline(always)] + fn sub_and_mul_fp_by_nonresidue(y: &mut Self::Fp, x: &Self::Fp) { + *y += x; + } + + #[inline(always)] + fn mul_fp_by_nonresidue_plus_one_and_add(y: &mut Self::Fp, x: &Self::Fp) { + *y = *x; + } + + fn mul_fp_by_nonresidue_and_add(y: &mut Self::Fp, x: &Self::Fp) { + y.neg_in_place(); + *y += x; + } +} + +pub type Fq6 = Fp6; + +#[derive(Clone, Copy)] +pub struct Fq6Config; + +impl Fp6Config for Fq6Config { + type Fp2Config = Fq2Config; + + /// NONRESIDUE = (U + 1) + const NONRESIDUE: Fq2 = Fq2::new(Fq::ONE, Fq::ONE); + + const FROBENIUS_COEFF_FP6_C1: &'static [Fq2] = &[ + // Fp2::NONRESIDUE^(((q^0) - 1) / 3) + Fq2::new( + Fq::ONE, + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^1) - 1) / 3) + Fq2::new( + Fq::ZERO, + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + ), + // Fp2::NONRESIDUE^(((q^2) - 1) / 3) + Fq2::new( + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^3) - 1) / 3) + Fq2::new( + Fq::ZERO, + Fq::ONE, + ), + // Fp2::NONRESIDUE^(((q^4) - 1) / 3) + Fq2::new( + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^5) - 1) / 3) + Fq2::new( + Fq::ZERO, + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + ), + ]; + + const FROBENIUS_COEFF_FP6_C2: &'static [Fq2] = &[ + // Fq2(u + 1)**(((2q^0) - 2) / 3) + Fq2::new( + Fq::ONE, + Fq::ZERO, + ), + // Fq2(u + 1)**(((2q^1) - 2) / 3) + Fq2::new( + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + Fq::ZERO, + ), + // Fq2(u + 1)**(((2q^2) - 2) / 3) + Fq2::new( + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + Fq::ZERO, + ), + // Fq2(u + 1)**(((2q^3) - 2) / 3) + Fq2::new( + MontFp!("-1"), + Fq::ZERO, + ), + // Fq2(u + 1)**(((2q^4) - 2) / 3) + Fq2::new( + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + Fq::ZERO, + ), + // Fq2(u + 1)**(((2q^5) - 2) / 3) + Fq2::new( + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + Fq::ZERO, + ), + ]; + + /// Multiply this element by the quadratic nonresidue 1 + u. + /// Make this generic. + fn mul_fp2_by_nonresidue_in_place(fe: &mut Fq2) -> &mut Fq2 { + let t0 = fe.c0; + fe.c0 -= &fe.c1; + fe.c1 += &t0; + fe + } +} + +pub type Fq12 = Fp12; + +#[derive(Clone, Copy)] +pub struct Fq12Config; + +impl Fp12Config for Fq12Config { + type Fp6Config = Fq6Config; + + const NONRESIDUE: Fq6 = Fq6::new(Fq2::ZERO, Fq2::ONE, Fq2::ZERO); + + const FROBENIUS_COEFF_FP12_C1: &'static [Fq2] = &[ + // Fp2::NONRESIDUE^(((q^0) - 1) / 6) + Fq2::new( + Fq::ONE, + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^1) - 1) / 6) + Fq2::new( + MontFp!("22118644822122453894295732432166425368368980329889476319266915965514828099635526724748286229964921634997234117686841299669336163301597"), + MontFp!("-22118644822122453894295732432166425368368980329889476319266915965514828099635526724748286229964921634997234117686841299669336163301597"), + ), + // Fp2::NONRESIDUE^(((q^2) - 1) / 6) + Fq2::new( + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^3) - 1) / 6) + Fq2::new( + MontFp!("-84459159508829117195668840503504856816171858703899096210464197465513610215112549935889502423482516188066933947513637464187184810836060"), + MontFp!("84459159508829117195668840503504856816171858703899096210464197465513610215112549935889502423482516188066933947513637464187184810836060"), + ), + // Fp2::NONRESIDUE^(((q^4) - 1) / 6) + Fq2::new( + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^5) - 1) / 6) + Fq2::new( + MontFp!("66246899211905584890106703643824680058951854489001325908103722925357218347529396236264714086849745867111793936345949703487541191192946"), + MontFp!("-66246899211905584890106703643824680058951854489001325908103722925357218347529396236264714086849745867111793936345949703487541191192946"), + ), + // Fp2::NONRESIDUE^(((q^6) - 1) / 6) + Fq2::new( + MontFp!("-1"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^7) - 1) / 6) + Fq2::new( + MontFp!("-22118644822122453894295732432166425368368980329889476319266915965514828099635526724748286229964921634997234117686841299669336163301597"), + MontFp!("22118644822122453894295732432166425368368980329889476319266915965514828099635526724748286229964921634997234117686841299669336163301597"), + ), + // Fp2::NONRESIDUE^(((q^8) - 1) / 6) + Fq2::new( + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^9) - 1) / 6) + Fq2::new( + MontFp!("84459159508829117195668840503504856816171858703899096210464197465513610215112549935889502423482516188066933947513637464187184810836060"), + MontFp!("-84459159508829117195668840503504856816171858703899096210464197465513610215112549935889502423482516188066933947513637464187184810836060"), + ), + // Fp2::NONRESIDUE^(((q^10) - 1) / 6) + Fq2::new( + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^11) - 1) / 6) + Fq2::new( + MontFp!("-66246899211905584890106703643824680058951854489001325908103722925357218347529396236264714086849745867111793936345949703487541191192946"), + MontFp!("66246899211905584890106703643824680058951854489001325908103722925357218347529396236264714086849745867111793936345949703487541191192946"), + ), + ]; +} + +pub type Bls12_446 = Bls12; +use g1::G1Affine; +use g2::G2Affine; + +pub struct Config; + +impl Bls12Config for Config { + const X: &'static [u64] = &[0x8204000000020001, 0x600]; + const X_IS_NEGATIVE: bool = true; + const TWIST_TYPE: TwistType = TwistType::M; + type Fp = Fq; + type Fp2Config = Fq2Config; + type Fp6Config = Fq6Config; + type Fp12Config = Fq12Config; + type G1Config = g1::Config; + type G2Config = g2::Config; +} + +pub mod util { + use ark_ec::short_weierstrass::Affine; + use ark_ec::AffineRepr; + use ark_ff::{BigInteger448, PrimeField}; + use ark_serialize::SerializationError; + + use super::g1::Config as G1Config; + use super::g2::Config as G2Config; + use super::{Fq, Fq2, G1Affine, G2Affine}; + + pub const G1_SERIALIZED_SIZE: usize = 57; + pub const G2_SERIALIZED_SIZE: usize = 114; + + pub struct EncodingFlags { + pub is_compressed: bool, + pub is_infinity: bool, + pub is_lexographically_largest: bool, + } + + impl EncodingFlags { + pub fn get_flags(bytes: &[u8]) -> Self { + let compression_flag_set = (bytes[0] >> 7) & 1; + let infinity_flag_set = (bytes[0] >> 6) & 1; + let sort_flag_set = (bytes[0] >> 5) & 1; + + Self { + is_compressed: compression_flag_set == 1, + is_infinity: infinity_flag_set == 1, + is_lexographically_largest: sort_flag_set == 1, + } + } + pub fn encode_flags(&self, bytes: &mut [u8]) { + if self.is_compressed { + bytes[0] |= 1 << 7; + } + + if self.is_infinity { + bytes[0] |= 1 << 6; + } + + if self.is_compressed && !self.is_infinity && self.is_lexographically_largest { + bytes[0] |= 1 << 5; + } + } + } + + pub(crate) fn deserialize_fq(bytes: [u8; 56]) -> Option { + let mut tmp = BigInteger448::new([0, 0, 0, 0, 0, 0, 0]); + + // Note: The following unwraps are if the compiler cannot convert + // the byte slice into [u8;8], we know this is infallible since we + // are providing the indices at compile time and bytes has a fixed size + tmp.0[6] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()); + tmp.0[5] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()); + tmp.0[4] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()); + tmp.0[3] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()); + tmp.0[2] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()); + tmp.0[1] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()); + tmp.0[0] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()); + + Fq::from_bigint(tmp) + } + + pub(crate) fn serialize_fq(field: Fq) -> [u8; 56] { + let mut result = [0u8; 56]; + + let rep = field.into_bigint(); + + result[0..8].copy_from_slice(&rep.0[6].to_be_bytes()); + result[8..16].copy_from_slice(&rep.0[5].to_be_bytes()); + result[16..24].copy_from_slice(&rep.0[4].to_be_bytes()); + result[24..32].copy_from_slice(&rep.0[3].to_be_bytes()); + result[32..40].copy_from_slice(&rep.0[2].to_be_bytes()); + result[40..48].copy_from_slice(&rep.0[1].to_be_bytes()); + result[48..56].copy_from_slice(&rep.0[0].to_be_bytes()); + + result + } + + pub(crate) fn read_fq_with_offset( + bytes: &[u8], + offset: usize, + ) -> Result { + let mut tmp = [0; G1_SERIALIZED_SIZE - 1]; + // read `G1_SERIALIZED_SIZE` bytes + tmp.copy_from_slice( + &bytes[offset * G1_SERIALIZED_SIZE + 1..G1_SERIALIZED_SIZE * (offset + 1)], + ); + + deserialize_fq(tmp).ok_or(SerializationError::InvalidData) + } + + pub(crate) fn read_g1_compressed( + mut reader: R, + ) -> Result, ark_serialize::SerializationError> { + let mut bytes = [0u8; G1_SERIALIZED_SIZE]; + reader + .read_exact(&mut bytes) + .ok() + .ok_or(SerializationError::InvalidData)?; + + // Obtain the three flags from the start of the byte sequence + let flags = EncodingFlags::get_flags(&bytes[..]); + + // we expect to be deserializing a compressed point + if !flags.is_compressed { + return Err(SerializationError::UnexpectedFlags); + } + + if flags.is_infinity { + return Ok(G1Affine::zero()); + } + + // Attempt to obtain the x-coordinate + let x = read_fq_with_offset(&bytes, 0)?; + + let p = G1Affine::get_point_from_x_unchecked(x, flags.is_lexographically_largest) + .ok_or(SerializationError::InvalidData)?; + + Ok(p) + } + + pub(crate) fn read_g1_uncompressed( + mut reader: R, + ) -> Result, ark_serialize::SerializationError> { + let mut bytes = [0u8; 2 * G1_SERIALIZED_SIZE]; + reader + .read_exact(&mut bytes) + .map_err(|_| SerializationError::InvalidData)?; + + // Obtain the three flags from the start of the byte sequence + let flags = EncodingFlags::get_flags(&bytes[..]); + + // we expect to be deserializing an uncompressed point + if flags.is_compressed { + return Err(SerializationError::UnexpectedFlags); + } + + if flags.is_infinity { + return Ok(G1Affine::zero()); + } + + // Attempt to obtain the x-coordinate + let x = read_fq_with_offset(&bytes, 0)?; + // Attempt to obtain the y-coordinate + let y = read_fq_with_offset(&bytes, 1)?; + + let p = G1Affine::new_unchecked(x, y); + + Ok(p) + } + + pub(crate) fn read_g2_compressed( + mut reader: R, + ) -> Result, ark_serialize::SerializationError> { + let mut bytes = [0u8; G2_SERIALIZED_SIZE]; + reader + .read_exact(&mut bytes) + .map_err(|_| SerializationError::InvalidData)?; + + // Obtain the three flags from the start of the byte sequence + let flags = EncodingFlags::get_flags(&bytes); + + // we expect to be deserializing a compressed point + if !flags.is_compressed { + return Err(SerializationError::UnexpectedFlags); + } + + if flags.is_infinity { + return Ok(G2Affine::zero()); + } + + // Attempt to obtain the x-coordinate + let xc1 = read_fq_with_offset(&bytes, 0)?; + let xc0 = read_fq_with_offset(&bytes, 1)?; + + let x = Fq2::new(xc0, xc1); + + let p = G2Affine::get_point_from_x_unchecked(x, flags.is_lexographically_largest) + .ok_or(SerializationError::InvalidData)?; + + Ok(p) + } + + pub(crate) fn read_g2_uncompressed( + mut reader: R, + ) -> Result, ark_serialize::SerializationError> { + let mut bytes = [0u8; 2 * G2_SERIALIZED_SIZE]; + reader + .read_exact(&mut bytes) + .map_err(|_| SerializationError::InvalidData)?; + + // Obtain the three flags from the start of the byte sequence + let flags = EncodingFlags::get_flags(&bytes); + + // we expect to be deserializing an uncompressed point + if flags.is_compressed { + return Err(SerializationError::UnexpectedFlags); + } + + if flags.is_infinity { + return Ok(G2Affine::zero()); + } + + // Attempt to obtain the x-coordinate + let xc1 = read_fq_with_offset(&bytes, 0)?; + let xc0 = read_fq_with_offset(&bytes, 1)?; + let x = Fq2::new(xc0, xc1); + + // Attempt to obtain the y-coordinate + let yc1 = read_fq_with_offset(&bytes, 2)?; + let yc0 = read_fq_with_offset(&bytes, 3)?; + let y = Fq2::new(yc0, yc1); + + let p = G2Affine::new_unchecked(x, y); + + Ok(p) + } +} + +pub mod g1 { + use super::util::{ + read_g1_compressed, read_g1_uncompressed, serialize_fq, EncodingFlags, G1_SERIALIZED_SIZE, + }; + use super::{Fq, Fr}; + use ark_ec::bls12::Bls12Config; + use ark_ec::models::CurveConfig; + use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; + use ark_ec::{bls12, AffineRepr, Group}; + use ark_ff::{Field, MontFp, One, PrimeField, Zero}; + use ark_serialize::{Compress, SerializationError}; + use core::ops::Neg; + + #[derive(Clone, Default, PartialEq, Eq)] + pub struct Config; + + impl CurveConfig for Config { + type BaseField = Fq; + type ScalarField = Fr; + + /// COFACTOR = (x - 1)^2 / 3 = 267785939737784928360481681640896166738700972 + const COFACTOR: &'static [u64] = &[0xad5aaaac0002aaac, 0x2602b0055d560ab0, 0xc0208]; + + /// COFACTOR_INV = COFACTOR^{-1} mod r + /// = 645383785691237230677779421366207365261112665008071669867241543525136277620937226389553150 + const COFACTOR_INV: Fr = + MontFp!("645383785691237230677779421366207365261112665008071669867241543525136277620937226389553150"); + } + + pub type G1Affine = bls12::G1Affine; + pub type G1Projective = bls12::G1Projective; + + impl SWCurveConfig for Config { + /// COEFF_A = 0 + const COEFF_A: Fq = Fq::ZERO; + + /// COEFF_B = 1 + const COEFF_B: Fq = MontFp!("1"); + + /// AFFINE_GENERATOR_COEFFS = (G1_GENERATOR_X, G1_GENERATOR_Y) + const GENERATOR: G1Affine = G1Affine::new_unchecked(G1_GENERATOR_X, G1_GENERATOR_Y); + + #[inline(always)] + fn mul_by_a(_: Self::BaseField) -> Self::BaseField { + Self::BaseField::zero() + } + + #[inline] + fn is_in_correct_subgroup_assuming_on_curve(p: &G1Affine) -> bool { + // Algorithm from Section 6 of https://eprint.iacr.org/2021/1130. + // + // Check that endomorphism_p(P) == -[X^2]P + + // An early-out optimization described in Section 6. + // If uP == P but P != point of infinity, then the point is not in the right + // subgroup. + let x_times_p = p.mul_bigint(super::Config::X); + if x_times_p.eq(p) && !p.infinity { + return false; + } + + let minus_x_squared_times_p = x_times_p.mul_bigint(super::Config::X).neg(); + let endomorphism_p = endomorphism(p); + minus_x_squared_times_p.eq(&endomorphism_p) + } + + #[inline] + fn clear_cofactor(p: &G1Affine) -> G1Affine { + // Using the effective cofactor, as explained in + // Section 5 of https://eprint.iacr.org/2019/403.pdf. + // + // It is enough to multiply by (1 - x), instead of (x - 1)^2 / 3 + let h_eff = one_minus_x().into_bigint(); + Config::mul_affine(p, h_eff.as_ref()).into() + } + + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result, ark_serialize::SerializationError> { + let p = if compress == ark_serialize::Compress::Yes { + read_g1_compressed(&mut reader)? + } else { + read_g1_uncompressed(&mut reader)? + }; + + if validate == ark_serialize::Validate::Yes + && !p.is_in_correct_subgroup_assuming_on_curve() + { + return Err(SerializationError::InvalidData); + } + Ok(p) + } + + fn serialize_with_mode( + item: &Affine, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), SerializationError> { + let encoding = EncodingFlags { + is_compressed: compress == ark_serialize::Compress::Yes, + is_infinity: item.is_zero(), + is_lexographically_largest: item.y > -item.y, + }; + let mut p = *item; + if encoding.is_infinity { + p = G1Affine::zero(); + } + // need to access the field struct `x` directly, otherwise we get None from xy() + // method + let x_bytes = serialize_fq(p.x); + if encoding.is_compressed { + let mut bytes = [0u8; G1_SERIALIZED_SIZE]; + bytes[1..].copy_from_slice(&x_bytes); + + encoding.encode_flags(&mut bytes); + writer.write_all(&bytes)?; + } else { + let mut bytes = [0u8; 2 * G1_SERIALIZED_SIZE]; + bytes[1..1 + G1_SERIALIZED_SIZE].copy_from_slice(&x_bytes[..]); + bytes[2 + G1_SERIALIZED_SIZE..].copy_from_slice(&serialize_fq(p.y)[..]); + + encoding.encode_flags(&mut bytes); + writer.write_all(&bytes)?; + }; + + Ok(()) + } + + fn serialized_size(compress: Compress) -> usize { + if compress == Compress::Yes { + G1_SERIALIZED_SIZE + } else { + G1_SERIALIZED_SIZE * 2 + } + } + } + + fn one_minus_x() -> Fr { + const X: Fr = Fr::from_sign_and_limbs(!super::Config::X_IS_NEGATIVE, super::Config::X); + Fr::one() - X + } + + /// G1_GENERATOR_X = + /// 143189966182216199425404656824735381247272236095050141599848381692039676741476615087722874458136990266833440576646963466074693171606778 + pub const G1_GENERATOR_X: Fq = MontFp!("143189966182216199425404656824735381247272236095050141599848381692039676741476615087722874458136990266833440576646963466074693171606778"); + + /// G1_GENERATOR_Y = + /// 75202396197342917254523279069469674666303680671605970245803554133573745859131002231546341942288521574682619325841484506619191207488304 + pub const G1_GENERATOR_Y: Fq = MontFp!("75202396197342917254523279069469674666303680671605970245803554133573745859131002231546341942288521574682619325841484506619191207488304"); + + /// BETA is a non-trivial cubic root of unity in Fq. + pub const BETA: Fq = MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"); + + pub fn endomorphism(p: &Affine) -> Affine { + // Endomorphism of the points on the curve. + // endomorphism_p(x,y) = (BETA * x, y) + // where BETA is a non-trivial cubic root of unity in Fq. + let mut res = *p; + res.x *= BETA; + res + } +} + +pub mod g2 { + use super::util::{ + read_g2_compressed, read_g2_uncompressed, serialize_fq, EncodingFlags, G2_SERIALIZED_SIZE, + }; + use super::*; + use ark_ec::models::CurveConfig; + use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; + use ark_ec::{bls12, AffineRepr}; + use ark_ff::{MontFp, Zero}; + use ark_serialize::{Compress, SerializationError}; + + pub type G2Affine = bls12::G2Affine; + pub type G2Projective = bls12::G2Projective; + + #[derive(Clone, Default, PartialEq, Eq)] + pub struct Config; + + impl CurveConfig for Config { + type BaseField = Fq2; + type ScalarField = Fr; + + /// COFACTOR = (x^8 - 4 x^7 + 5 x^6) - (4 x^4 + 6 x^3 - 4 x^2 - 4 x + 13) // + /// 9 + /// = 46280025648128091779281203587029183771098593081950199160533444883894201638329761721685747232785203763275581499269893683911356926248942802726857101798724933488377584092259436345573 + const COFACTOR: &'static [u64] = &[ + 0xce555594000638e5, + 0xa75088593e6a92ef, + 0xc81e026dd55b51d6, + 0x47f8e24b79369c54, + 0x74c3560ced298d51, + 0x7cefe5c3dd2555cb, + 0x657742bf55690156, + 0x5780484639bf731d, + 0x3988a06f1bb3444d, + 0x2daee, + ]; + + /// COFACTOR_INV = COFACTOR^{-1} mod r + /// 420747440395392227734782296805460539842466911252881029283882861015362447833828968293150382 + const COFACTOR_INV: Fr = MontFp!( + "420747440395392227734782296805460539842466911252881029283882861015362447833828968293150382" + ); + } + + impl SWCurveConfig for Config { + /// COEFF_A = [0, 0] + const COEFF_A: Fq2 = Fq2::new(g1::Config::COEFF_A, g1::Config::COEFF_A); + + /// COEFF_B = [1, 1] + const COEFF_B: Fq2 = Fq2::new(g1::Config::COEFF_B, g1::Config::COEFF_B); + + /// AFFINE_GENERATOR_COEFFS = (G2_GENERATOR_X, G2_GENERATOR_Y) + const GENERATOR: G2Affine = G2Affine::new_unchecked(G2_GENERATOR_X, G2_GENERATOR_Y); + + #[inline(always)] + fn mul_by_a(_: Self::BaseField) -> Self::BaseField { + Self::BaseField::zero() + } + + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result, ark_serialize::SerializationError> { + let p = if compress == ark_serialize::Compress::Yes { + read_g2_compressed(&mut reader)? + } else { + read_g2_uncompressed(&mut reader)? + }; + + if validate == ark_serialize::Validate::Yes + && !p.is_in_correct_subgroup_assuming_on_curve() + { + return Err(SerializationError::InvalidData); + } + Ok(p) + } + + fn serialize_with_mode( + item: &Affine, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), SerializationError> { + let encoding = EncodingFlags { + is_compressed: compress == ark_serialize::Compress::Yes, + is_infinity: item.is_zero(), + is_lexographically_largest: item.y > -item.y, + }; + let mut p = *item; + if encoding.is_infinity { + p = G2Affine::zero(); + } + + let mut x_bytes = [0u8; G2_SERIALIZED_SIZE]; + let c1_bytes = serialize_fq(p.x.c1); + let c0_bytes = serialize_fq(p.x.c0); + x_bytes[1..56 + 1].copy_from_slice(&c1_bytes[..]); + x_bytes[56 + 2..114].copy_from_slice(&c0_bytes[..]); + if encoding.is_compressed { + let mut bytes: [u8; G2_SERIALIZED_SIZE] = x_bytes; + + encoding.encode_flags(&mut bytes); + writer.write_all(&bytes)?; + } else { + let mut bytes = [0u8; 2 * G2_SERIALIZED_SIZE]; + + let mut y_bytes = [0u8; G2_SERIALIZED_SIZE]; + let c1_bytes = serialize_fq(p.y.c1); + let c0_bytes = serialize_fq(p.y.c0); + y_bytes[1..56 + 1].copy_from_slice(&c1_bytes[..]); + y_bytes[56 + 2..114].copy_from_slice(&c0_bytes[..]); + bytes[0..G2_SERIALIZED_SIZE].copy_from_slice(&x_bytes); + bytes[G2_SERIALIZED_SIZE..].copy_from_slice(&y_bytes); + + encoding.encode_flags(&mut bytes); + writer.write_all(&bytes)?; + }; + + Ok(()) + } + + fn serialized_size(compress: ark_serialize::Compress) -> usize { + if compress == Compress::Yes { + G2_SERIALIZED_SIZE + } else { + 2 * G2_SERIALIZED_SIZE + } + } + } + + pub const G2_GENERATOR_X: Fq2 = Fq2::new(G2_GENERATOR_X_C0, G2_GENERATOR_X_C1); + pub const G2_GENERATOR_Y: Fq2 = Fq2::new(G2_GENERATOR_Y_C0, G2_GENERATOR_Y_C1); + + /// G2_GENERATOR_X_C0 = + /// 96453755443802578867745476081903764610578492683850270111202389209355548711427786327510993588141991264564812146530214503491136289085725 + pub const G2_GENERATOR_X_C0: Fq = MontFp!("96453755443802578867745476081903764610578492683850270111202389209355548711427786327510993588141991264564812146530214503491136289085725"); + + /// G2_GENERATOR_X_C1 = + /// 85346509177292795277012009839788781950274202400882571466460158277083221521663169974265433098009350061415973662678938824527658049065530 + pub const G2_GENERATOR_X_C1: Fq = MontFp!("85346509177292795277012009839788781950274202400882571466460158277083221521663169974265433098009350061415973662678938824527658049065530"); + + /// G2_GENERATOR_Y_C0 = + /// 49316184343270950587272132771103279293158283984999436491292404103501221698714795975575879957605051223501287444864258801515822358837529 + pub const G2_GENERATOR_Y_C0: Fq = MontFp!("49316184343270950587272132771103279293158283984999436491292404103501221698714795975575879957605051223501287444864258801515822358837529"); + + /// G2_GENERATOR_Y_C1 = + /// 107680854723992552431070996218129928499826544031468382031848626814251381379173928074140221537929995580031433096217223703806029068859074 + pub const G2_GENERATOR_Y_C1: Fq = MontFp!("107680854723992552431070996218129928499826544031468382031848626814251381379173928074140221537929995580031433096217223703806029068859074"); +} diff --git a/tfhe-zk-pok/src/curve_api.rs b/tfhe-zk-pok/src/curve_api.rs new file mode 100644 index 0000000000..88a9ac47ef --- /dev/null +++ b/tfhe-zk-pok/src/curve_api.rs @@ -0,0 +1,341 @@ +use ark_ec::{CurveGroup, Group, VariableBaseMSM}; +use ark_ff::{BigInt, Field, MontFp, Zero}; +use ark_poly::univariate::DensePolynomial; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; +use core::fmt; +use core::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign}; +use serde::{Deserialize, Serialize}; + +fn ark_se(a: &A, s: S) -> Result +where + S: serde::Serializer, +{ + let mut bytes = vec![]; + a.serialize_with_mode(&mut bytes, Compress::Yes) + .map_err(serde::ser::Error::custom)?; + s.serialize_bytes(&bytes) +} + +fn ark_de<'de, D, A: CanonicalDeserialize>(data: D) -> Result +where + D: serde::de::Deserializer<'de>, +{ + let s: Vec = serde::de::Deserialize::deserialize(data)?; + let a = A::deserialize_with_mode(s.as_slice(), Compress::Yes, Validate::Yes); + a.map_err(serde::de::Error::custom) +} + +struct MontIntDisplay<'a, T>(&'a T); + +impl fmt::Debug for MontIntDisplay<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if *self.0 == T::ZERO { + f.write_str("0") + } else { + f.write_fmt(format_args!("{}", self.0)) + } + } +} + +pub mod bls12_381; +pub mod bls12_446; + +pub trait FieldOps: + Copy + + Send + + Sync + + core::ops::AddAssign + + core::ops::SubAssign + + core::ops::Add + + core::ops::Sub + + core::ops::Mul + + core::ops::Div + + core::ops::Neg + + core::iter::Sum +{ + const ZERO: Self; + const ONE: Self; + + fn from_u128(n: u128) -> Self; + fn from_u64(n: u64) -> Self; + fn from_i64(n: i64) -> Self; + fn to_bytes(self) -> impl AsRef<[u8]>; + fn rand(rng: &mut dyn rand::RngCore) -> Self; + fn hash(values: &mut [Self], data: &[&[u8]]); + fn poly_mul(p: &[Self], q: &[Self]) -> Vec; + fn poly_sub(p: &[Self], q: &[Self]) -> Vec { + use core::iter::zip; + let mut out = vec![Self::ZERO; Ord::max(p.len(), q.len())]; + + for (out, (p, q)) in zip( + &mut out, + zip( + p.iter().copied().chain(core::iter::repeat(Self::ZERO)), + q.iter().copied().chain(core::iter::repeat(Self::ZERO)), + ), + ) { + *out = p - q; + } + + out + } + fn poly_add(p: &[Self], q: &[Self]) -> Vec { + use core::iter::zip; + let mut out = vec![Self::ZERO; Ord::max(p.len(), q.len())]; + + for (out, (p, q)) in zip( + &mut out, + zip( + p.iter().copied().chain(core::iter::repeat(Self::ZERO)), + q.iter().copied().chain(core::iter::repeat(Self::ZERO)), + ), + ) { + *out = p + q; + } + + out + } +} + +pub trait CurveGroupOps: + Copy + + Send + + Sync + + core::fmt::Debug + + core::ops::AddAssign + + core::ops::SubAssign + + core::ops::Add + + core::ops::Sub + + core::ops::Neg + + core::iter::Sum +{ + const ZERO: Self; + const GENERATOR: Self; + const BYTE_SIZE: usize; + + fn mul_scalar(self, scalar: Zp) -> Self; + fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self; + fn to_bytes(self) -> impl AsRef<[u8]>; + fn double(self) -> Self; +} + +pub trait PairingGroupOps: + Copy + + Send + + Sync + + PartialEq + + core::fmt::Debug + + core::ops::AddAssign + + core::ops::SubAssign + + core::ops::Add + + core::ops::Sub + + core::ops::Neg +{ + fn mul_scalar(self, scalar: Zp) -> Self; + fn pairing(x: G1, y: G2) -> Self; +} + +pub trait Curve { + type Zp: FieldOps; + type G1: CurveGroupOps + serde::Serialize + for<'de> serde::Deserialize<'de>; + type G2: CurveGroupOps + serde::Serialize + for<'de> serde::Deserialize<'de>; + type Gt: PairingGroupOps; +} + +impl FieldOps for bls12_381::Zp { + const ZERO: Self = Self::ZERO; + const ONE: Self = Self::ONE; + + fn from_u128(n: u128) -> Self { + Self::from_bigint([n as u64, (n >> 64) as u64, 0, 0]) + } + fn from_u64(n: u64) -> Self { + Self::from_u64(n) + } + fn from_i64(n: i64) -> Self { + Self::from_i64(n) + } + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + fn rand(rng: &mut dyn rand::RngCore) -> Self { + Self::rand(rng) + } + fn hash(values: &mut [Self], data: &[&[u8]]) { + Self::hash(values, data) + } + + fn poly_mul(p: &[Self], q: &[Self]) -> Vec { + let p = p.iter().map(|x| x.inner).collect(); + let q = q.iter().map(|x| x.inner).collect(); + let p = DensePolynomial { coeffs: p }; + let q = DensePolynomial { coeffs: q }; + (&p * &q) + .coeffs + .into_iter() + .map(|inner| bls12_381::Zp { inner }) + .collect() + } +} + +impl CurveGroupOps for bls12_381::G1 { + const ZERO: Self = Self::ZERO; + const GENERATOR: Self = Self::GENERATOR; + const BYTE_SIZE: usize = Self::BYTE_SIZE; + + fn mul_scalar(self, scalar: bls12_381::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_381::Zp]) -> Self { + Self::multi_mul_scalar(bases, scalars) + } + + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + + fn double(self) -> Self { + self.double() + } +} + +impl CurveGroupOps for bls12_381::G2 { + const ZERO: Self = Self::ZERO; + const GENERATOR: Self = Self::GENERATOR; + const BYTE_SIZE: usize = Self::BYTE_SIZE; + + fn mul_scalar(self, scalar: bls12_381::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_381::Zp]) -> Self { + Self::multi_mul_scalar(bases, scalars) + } + + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + + fn double(self) -> Self { + self.double() + } +} + +impl PairingGroupOps for bls12_381::Gt { + fn mul_scalar(self, scalar: bls12_381::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn pairing(x: bls12_381::G1, y: bls12_381::G2) -> Self { + Self::pairing(x, y) + } +} + +impl FieldOps for bls12_446::Zp { + const ZERO: Self = Self::ZERO; + const ONE: Self = Self::ONE; + + fn from_u128(n: u128) -> Self { + Self::from_bigint([n as u64, (n >> 64) as u64, 0, 0, 0]) + } + fn from_u64(n: u64) -> Self { + Self::from_u64(n) + } + fn from_i64(n: i64) -> Self { + Self::from_i64(n) + } + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + fn rand(rng: &mut dyn rand::RngCore) -> Self { + Self::rand(rng) + } + fn hash(values: &mut [Self], data: &[&[u8]]) { + Self::hash(values, data) + } + + fn poly_mul(p: &[Self], q: &[Self]) -> Vec { + let p = p.iter().map(|x| x.inner).collect(); + let q = q.iter().map(|x| x.inner).collect(); + let p = DensePolynomial { coeffs: p }; + let q = DensePolynomial { coeffs: q }; + (&p * &q) + .coeffs + .into_iter() + .map(|inner| bls12_446::Zp { inner }) + .collect() + } +} + +impl CurveGroupOps for bls12_446::G1 { + const ZERO: Self = Self::ZERO; + const GENERATOR: Self = Self::GENERATOR; + const BYTE_SIZE: usize = Self::BYTE_SIZE; + + fn mul_scalar(self, scalar: bls12_446::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_446::Zp]) -> Self { + Self::multi_mul_scalar(bases, scalars) + } + + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + + fn double(self) -> Self { + self.double() + } +} + +impl CurveGroupOps for bls12_446::G2 { + const ZERO: Self = Self::ZERO; + const GENERATOR: Self = Self::GENERATOR; + const BYTE_SIZE: usize = Self::BYTE_SIZE; + + fn mul_scalar(self, scalar: bls12_446::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_446::Zp]) -> Self { + Self::multi_mul_scalar(bases, scalars) + } + + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + + fn double(self) -> Self { + self.double() + } +} + +impl PairingGroupOps for bls12_446::Gt { + fn mul_scalar(self, scalar: bls12_446::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn pairing(x: bls12_446::G1, y: bls12_446::G2) -> Self { + Self::pairing(x, y) + } +} + +#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)] +pub struct Bls12_381; +#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)] +pub struct Bls12_446; + +impl Curve for Bls12_381 { + type Zp = bls12_381::Zp; + type G1 = bls12_381::G1; + type G2 = bls12_381::G2; + type Gt = bls12_381::Gt; +} +impl Curve for Bls12_446 { + type Zp = bls12_446::Zp; + type G1 = bls12_446::G1; + type G2 = bls12_446::G2; + type Gt = bls12_446::Gt; +} diff --git a/tfhe-zk-pok/src/curve_api/bls12_381.rs b/tfhe-zk-pok/src/curve_api/bls12_381.rs new file mode 100644 index 0000000000..a4bb50dc05 --- /dev/null +++ b/tfhe-zk-pok/src/curve_api/bls12_381.rs @@ -0,0 +1,776 @@ +use super::*; + +/// multiply EC point with scalar (= exponentiation in multiplicative notation) +fn mul_zp + Group>(x: T, scalar: Zp) -> T { + let zero = T::zero(); + let n: BigInt<4> = scalar.inner.into(); + + if n == BigInt([0; 4]) { + return zero; + } + + let mut y = zero; + let mut x = x; + + let n = n.0; + for word in n { + for idx in 0..64 { + let bit = (word >> idx) & 1; + if bit == 1 { + y += x; + } + x.double_in_place(); + } + } + y +} + +fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] { + let mut buf = [0u8; 6 * 8]; + for (i, &xi) in x.iter().enumerate() { + buf[i * 8..][..8].copy_from_slice(&xi.to_le_bytes()); + } + buf +} + +mod g1 { + use super::*; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct G1 { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: ark_bls12_381::G1Projective, + } + + impl fmt::Debug for G1 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("G1") + .field("x", &MontIntDisplay(&self.inner.x)) + .field("y", &MontIntDisplay(&self.inner.y)) + .field("z", &MontIntDisplay(&self.inner.z)) + .finish() + } + } + + impl G1 { + pub const ZERO: Self = Self { + inner: ark_bls12_381::G1Projective { + x: MontFp!("1"), + y: MontFp!("1"), + z: MontFp!("0"), + }, + }; + + // https://github.com/zcash/librustzcash/blob/6e0364cd42a2b3d2b958a54771ef51a8db79dd29/pairing/src/bls12_381/README.md#g1 + pub const GENERATOR: Self = Self { + inner: ark_bls12_381::G1Projective { + x: MontFp!("3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507"), + y: MontFp!("1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569"), + z: MontFp!("1"), + }, + }; + + // Size in number of bytes when the [to_bytes] + // function is called. + // This is not the size after serialization! + pub const BYTE_SIZE: usize = 2 * 6 * 8 + 1; + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { + use rayon::prelude::*; + let n_threads = rayon::current_num_threads(); + let chunk_size = bases.len().div_ceil(n_threads); + bases + .par_iter() + .map(|&x| x.inner.into_affine()) + .chunks(chunk_size) + .zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size)) + .map(|(bases, scalars)| Self { + inner: ark_bls12_381::G1Projective::msm(&bases, &scalars).unwrap(), + }) + .sum::() + } + + pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + let g = self.inner.into_affine(); + let x = bigint_to_bytes(g.x.0 .0); + let y = bigint_to_bytes(g.y.0 .0); + let mut buf = [0u8; 2 * 6 * 8 + 1]; + buf[..6 * 8].copy_from_slice(&x); + buf[6 * 8..][..6 * 8].copy_from_slice(&y); + buf[2 * 6 * 8] = g.infinity as u8; + buf + } + + pub fn double(self) -> Self { + Self { + inner: self.inner.double(), + } + } + } + + impl Add for G1 { + type Output = G1; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + G1 { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for G1 { + type Output = G1; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + G1 { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for G1 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for G1 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl core::iter::Sum for G1 { + fn sum>(iter: I) -> Self { + iter.fold(G1::ZERO, Add::add) + } + } + + impl Neg for G1 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod g2 { + use super::*; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct G2 { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(super) inner: ark_bls12_381::G2Projective, + } + + impl fmt::Debug for G2 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[allow(dead_code)] + #[derive(Debug)] + struct QuadExtField { + c0: T, + c1: T, + } + + f.debug_struct("G2") + .field( + "x", + &QuadExtField { + c0: MontIntDisplay(&self.inner.x.c0), + c1: MontIntDisplay(&self.inner.x.c1), + }, + ) + .field( + "y", + &QuadExtField { + c0: MontIntDisplay(&self.inner.y.c0), + c1: MontIntDisplay(&self.inner.y.c1), + }, + ) + .field( + "z", + &QuadExtField { + c0: MontIntDisplay(&self.inner.z.c0), + c1: MontIntDisplay(&self.inner.z.c1), + }, + ) + .finish() + } + } + + impl G2 { + pub const ZERO: Self = Self { + inner: ark_bls12_381::G2Projective { + x: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0"), + }, + y: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0"), + }, + z: ark_ff::QuadExtField { + c0: MontFp!("0"), + c1: MontFp!("0"), + }, + }, + }; + + // https://github.com/zcash/librustzcash/blob/6e0364cd42a2b3d2b958a54771ef51a8db79dd29/pairing/src/bls12_381/README.md#g2 + pub const GENERATOR: Self = Self { + inner: ark_bls12_381::G2Projective { + x: ark_ff::QuadExtField { + c0: MontFp!("352701069587466618187139116011060144890029952792775240219908644239793785735715026873347600343865175952761926303160"), + c1: MontFp!("3059144344244213709971259814753781636986470325476647558659373206291635324768958432433509563104347017837885763365758"), + }, + y: ark_ff::QuadExtField { + c0: MontFp!("1985150602287291935568054521177171638300868978215655730859378665066344726373823718423869104263333984641494340347905"), + c1: MontFp!("927553665492332455747201965776037880757740193453592970025027978793976877002675564980949289727957565575433344219582"), + }, + z: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0") , + }, + }, + }; + + // Size in number of bytes when the [to_bytes] + // function is called. + // This is not the size after serialization! + pub const BYTE_SIZE: usize = 4 * 6 * 8 + 1; + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { + use rayon::prelude::*; + let n_threads = rayon::current_num_threads(); + let chunk_size = bases.len().div_ceil(n_threads); + bases + .par_iter() + .map(|&x| x.inner.into_affine()) + .chunks(chunk_size) + .zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size)) + .map(|(bases, scalars)| Self { + inner: ark_bls12_381::G2Projective::msm(&bases, &scalars).unwrap(), + }) + .sum::() + } + + pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + let g = self.inner.into_affine(); + let xc0 = bigint_to_bytes(g.x.c0.0 .0); + let xc1 = bigint_to_bytes(g.x.c1.0 .0); + let yc0 = bigint_to_bytes(g.y.c0.0 .0); + let yc1 = bigint_to_bytes(g.y.c1.0 .0); + let mut buf = [0u8; 4 * 6 * 8 + 1]; + buf[..6 * 8].copy_from_slice(&xc0); + buf[6 * 8..][..6 * 8].copy_from_slice(&xc1); + buf[2 * 6 * 8..][..6 * 8].copy_from_slice(&yc0); + buf[3 * 6 * 8..][..6 * 8].copy_from_slice(&yc1); + buf[4 * 6 * 8] = g.infinity as u8; + buf + } + + pub fn double(self) -> Self { + Self { + inner: self.inner.double(), + } + } + } + + impl Add for G2 { + type Output = G2; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + G2 { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for G2 { + type Output = G2; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + G2 { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for G2 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for G2 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl core::iter::Sum for G2 { + fn sum>(iter: I) -> Self { + iter.fold(G2::ZERO, Add::add) + } + } + + impl Neg for G2 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod gt { + use super::*; + use ark_ec::pairing::Pairing; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct Gt { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + inner: ark_ec::pairing::PairingOutput, + } + + impl fmt::Debug for Gt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[allow(dead_code)] + #[derive(Debug)] + struct QuadExtField { + c0: T, + c1: T, + } + + #[allow(dead_code)] + #[derive(Debug)] + struct CubicExtField { + c0: T, + c1: T, + c2: T, + } + + #[allow(dead_code)] + #[derive(Debug)] + pub struct Gt { + inner: T, + } + + f.debug_struct("Gt") + .field( + "inner", + &Gt { + inner: QuadExtField { + c0: CubicExtField { + c0: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c0.c0), + c1: MontIntDisplay(&self.inner.0.c0.c0.c1), + }, + c1: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c1.c0), + c1: MontIntDisplay(&self.inner.0.c0.c1.c1), + }, + c2: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c2.c0), + c1: MontIntDisplay(&self.inner.0.c0.c2.c1), + }, + }, + c1: CubicExtField { + c0: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c0.c0), + c1: MontIntDisplay(&self.inner.0.c1.c0.c1), + }, + c1: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c1.c0), + c1: MontIntDisplay(&self.inner.0.c1.c1.c1), + }, + c2: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c2.c0), + c1: MontIntDisplay(&self.inner.0.c1.c2.c1), + }, + }, + }, + }, + ) + .finish() + } + } + + impl Gt { + pub fn pairing(g1: G1, g2: G2) -> Self { + Self { + inner: ark_bls12_381::Bls12_381::pairing(g1.inner, g2.inner), + } + } + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + } + + impl Add for Gt { + type Output = Gt; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Gt { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for Gt { + type Output = Gt; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Gt { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for Gt { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for Gt { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl Neg for Gt { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod zp { + use super::*; + use ark_ff::Fp; + use zeroize::Zeroize; + + fn redc(n: [u64; 4], nprime: u64, mut t: [u64; 6]) -> [u64; 4] { + for i in 0..2 { + let mut c = 0u64; + let m = u64::wrapping_mul(t[i], nprime); + + for j in 0..4 { + let x = t[i + j] as u128 + m as u128 * n[j] as u128 + c as u128; + t[i + j] = x as u64; + c = (x >> 64) as u64; + } + + for j in 4..6 - i { + let x = t[i + j] as u128 + c as u128; + t[i + j] = x as u64; + c = (x >> 64) as u64; + } + } + + let mut t = [t[2], t[3], t[4], t[5]]; + + if t.into_iter().rev().ge(n.into_iter().rev()) { + let mut o = false; + for i in 0..4 { + let (ti, o0) = u64::overflowing_sub(t[i], n[i]); + let (ti, o1) = u64::overflowing_sub(ti, o as u64); + o = o0 | o1; + t[i] = ti; + } + } + assert!(t.into_iter().rev().lt(n.into_iter().rev())); + + t + } + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Zeroize)] + #[repr(transparent)] + pub struct Zp { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: ark_bls12_381::Fr, + } + + impl fmt::Debug for Zp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Zp") + .field(&MontIntDisplay(&self.inner)) + .finish() + } + } + + impl Zp { + pub const ZERO: Self = Self { + inner: MontFp!("0"), + }; + + pub const ONE: Self = Self { + inner: MontFp!("1"), + }; + + pub fn from_bigint(n: [u64; 4]) -> Self { + Self { + inner: BigInt(n).into(), + } + } + + pub fn from_u64(n: u64) -> Self { + Self { + inner: BigInt([n, 0, 0, 0]).into(), + } + } + + pub fn from_i64(n: i64) -> Self { + let n_abs = Self::from_u64(n.unsigned_abs()); + if n > 0 { + n_abs + } else { + -n_abs + } + } + + pub fn to_bytes(self) -> [u8; 4 * 8] { + let buf = [ + self.inner.0 .0[0].to_le_bytes(), + self.inner.0 .0[1].to_le_bytes(), + self.inner.0 .0[2].to_le_bytes(), + self.inner.0 .0[3].to_le_bytes(), + ]; + unsafe { core::mem::transmute(buf) } + } + + fn from_raw_u64x6(n: [u64; 6]) -> Self { + const MODULUS: BigInt<4> = BigInt!( + "52435875175126190479447740508185965837690552500527637822603658699938581184513" + ); + + const MODULUS_MONTGOMERY: u64 = 18446744069414584319; + + let mut n = n; + // zero the two leading bits, so the result is <= MODULUS * 2^128 + n[5] &= (1 << 62) - 1; + Zp { + inner: Fp( + BigInt(redc(MODULUS.0, MODULUS_MONTGOMERY, n)), + core::marker::PhantomData, + ), + } + } + + pub fn rand(rng: &mut dyn rand::RngCore) -> Self { + use rand::Rng; + + Self::from_raw_u64x6([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + } + + pub fn hash(values: &mut [Zp], data: &[&[u8]]) { + use sha3::digest::{ExtendableOutput, Update, XofReader}; + + let mut hasher = sha3::Shake256::default(); + for data in data { + hasher.update(data); + } + let mut reader = hasher.finalize_xof(); + + for value in values { + let mut bytes = [0u8; 6 * 8]; + reader.read(&mut bytes); + let bytes: [[u8; 8]; 6] = unsafe { core::mem::transmute(bytes) }; + *value = Zp::from_raw_u64x6(bytes.map(u64::from_le_bytes)); + } + } + } + + impl Add for Zp { + type Output = Zp; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for Zp { + type Output = Zp; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner - rhs.inner, + } + } + } + + impl Mul for Zp { + type Output = Zp; + + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner * rhs.inner, + } + } + } + + impl Div for Zp { + type Output = Zp; + + #[inline] + fn div(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner / rhs.inner, + } + } + } + impl AddAssign for Zp { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for Zp { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl Neg for Zp { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } + + impl core::iter::Sum for Zp { + fn sum>(iter: I) -> Self { + iter.fold(Zp::ZERO, Add::add) + } + } +} + +pub use g1::G1; +pub use g2::G2; +pub use gt::Gt; +pub use zp::Zp; + +#[cfg(test)] +mod tests { + use rand::rngs::StdRng; + use rand::SeedableRng; + use std::collections::HashMap; + + use super::*; + + #[test] + fn test_distributivity() { + let a = Zp { + inner: MontFp!( + "20799633726231143268782044631117354647259165363910905818134484248029981143850" + ), + }; + let b = Zp { + inner: MontFp!( + "42333504039292951860879669847432876299949385605895551964353185488509497658948" + ), + }; + let c = Zp { + inner: MontFp!( + "6797004509292554067788526429737434481164547177696793280652530849910670196287" + ), + }; + + assert_eq!((((a - b) * c) - (a * c - b * c)).inner, Zp::ZERO.inner); + } + + #[test] + fn test_serialization() { + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let g_cur = G1::GENERATOR.mul_scalar(alpha); + let g_hat_cur = G2::GENERATOR.mul_scalar(alpha); + + let alpha2: Zp = serde_json::from_str(&serde_json::to_string(&alpha).unwrap()).unwrap(); + assert_eq!(alpha, alpha2); + + let g_cur2: G1 = serde_json::from_str(&serde_json::to_string(&g_cur).unwrap()).unwrap(); + assert_eq!(g_cur, g_cur2); + + let g_hat_cur2: G2 = + serde_json::from_str(&serde_json::to_string(&g_hat_cur).unwrap()).unwrap(); + assert_eq!(g_hat_cur, g_hat_cur2); + } + + #[test] + fn test_hasher_and_eq() { + // we need to make sure if the points are the same + // but the projective representations are different + // then they still hash into the same thing + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let a = G1::GENERATOR.mul_scalar(alpha); + + // serialization should convert the point to affine representation + // after deserializing it we should have the same point + // but with a different representation + let a_affine: G1 = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap(); + + // the internal elements should be different + assert_ne!(a.inner.x, a_affine.inner.x); + assert_ne!(a.inner.y, a_affine.inner.y); + assert_ne!(a.inner.z, a_affine.inner.z); + + // but equality and hasher should see the two as the same point + assert_eq!(a, a_affine); + let mut hm = HashMap::new(); + hm.insert(a, 1); + assert_eq!(hm.len(), 1); + hm.insert(a_affine, 2); + assert_eq!(hm.len(), 1); + } +} diff --git a/tfhe-zk-pok/src/curve_api/bls12_446.rs b/tfhe-zk-pok/src/curve_api/bls12_446.rs new file mode 100644 index 0000000000..4769acb67b --- /dev/null +++ b/tfhe-zk-pok/src/curve_api/bls12_446.rs @@ -0,0 +1,970 @@ +use super::*; + +/// multiply EC point with scalar (= exponentiation in multiplicative notation) +fn mul_zp + Group>(x: T, scalar: Zp) -> T { + let zero = T::zero(); + let n: BigInt<5> = scalar.inner.into(); + + if n == BigInt([0; 5]) { + return zero; + } + + let mut y = zero; + let mut x = x; + + let n = n.0; + for word in n { + for idx in 0..64 { + let bit = (word >> idx) & 1; + if bit == 1 { + y += x; + } + x.double_in_place(); + } + } + y +} + +fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] { + let mut buf = [0u8; 7 * 8]; + for (i, &xi) in x.iter().enumerate() { + buf[i * 8..][..8].copy_from_slice(&xi.to_le_bytes()); + } + buf +} + +mod g1 { + use super::*; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct G1 { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: crate::curve_446::g1::G1Projective, + } + + impl fmt::Debug for G1 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("G1") + .field("x", &MontIntDisplay(&self.inner.x)) + .field("y", &MontIntDisplay(&self.inner.y)) + .field("z", &MontIntDisplay(&self.inner.z)) + .finish() + } + } + + impl G1 { + pub const ZERO: Self = Self { + inner: crate::curve_446::g1::G1Projective { + x: MontFp!("1"), + y: MontFp!("1"), + z: MontFp!("0"), + }, + }; + + pub const GENERATOR: Self = Self { + inner: crate::curve_446::g1::G1Projective { + x: MontFp!("143189966182216199425404656824735381247272236095050141599848381692039676741476615087722874458136990266833440576646963466074693171606778"), + y: MontFp!("75202396197342917254523279069469674666303680671605970245803554133573745859131002231546341942288521574682619325841484506619191207488304"), + z: MontFp!("1"), + }, + }; + + // Size in number of bytes when the [to_bytes] + // function is called. + // This is not the size after serialization! + pub const BYTE_SIZE: usize = 2 * 7 * 8 + 1; + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { + use rayon::prelude::*; + let n_threads = rayon::current_num_threads(); + let chunk_size = bases.len().div_ceil(n_threads); + bases + .par_iter() + .map(|&x| x.inner.into_affine()) + .chunks(chunk_size) + .zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size)) + .map(|(bases, scalars)| Self { + inner: crate::curve_446::g1::G1Projective::msm(&bases, &scalars).unwrap(), + }) + .sum::() + } + + pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + let g = self.inner.into_affine(); + let x = bigint_to_bytes(g.x.0 .0); + let y = bigint_to_bytes(g.y.0 .0); + let mut buf = [0u8; 2 * 7 * 8 + 1]; + buf[..7 * 8].copy_from_slice(&x); + buf[7 * 8..][..7 * 8].copy_from_slice(&y); + buf[2 * 7 * 8] = g.infinity as u8; + buf + } + + pub fn double(self) -> Self { + Self { + inner: self.inner.double(), + } + } + } + + impl Add for G1 { + type Output = G1; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + G1 { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for G1 { + type Output = G1; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + G1 { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for G1 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for G1 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl core::iter::Sum for G1 { + fn sum>(iter: I) -> Self { + iter.fold(G1::ZERO, Add::add) + } + } + + impl Neg for G1 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod g2 { + use super::*; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct G2 { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(super) inner: crate::curve_446::g2::G2Projective, + } + + impl fmt::Debug for G2 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[allow(dead_code)] + #[derive(Debug)] + struct QuadExtField { + c0: T, + c1: T, + } + + f.debug_struct("G2") + .field( + "x", + &QuadExtField { + c0: MontIntDisplay(&self.inner.x.c0), + c1: MontIntDisplay(&self.inner.x.c1), + }, + ) + .field( + "y", + &QuadExtField { + c0: MontIntDisplay(&self.inner.y.c0), + c1: MontIntDisplay(&self.inner.y.c1), + }, + ) + .field( + "z", + &QuadExtField { + c0: MontIntDisplay(&self.inner.z.c0), + c1: MontIntDisplay(&self.inner.z.c1), + }, + ) + .finish() + } + } + + impl G2 { + pub const ZERO: Self = Self { + inner: crate::curve_446::g2::G2Projective { + x: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0"), + }, + y: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0"), + }, + z: ark_ff::QuadExtField { + c0: MontFp!("0"), + c1: MontFp!("0"), + }, + }, + }; + + pub const GENERATOR: Self = Self { + inner: crate::curve_446::g2::G2Projective { + x: ark_ff::QuadExtField { + c0: MontFp!("96453755443802578867745476081903764610578492683850270111202389209355548711427786327510993588141991264564812146530214503491136289085725"), + c1: MontFp!("85346509177292795277012009839788781950274202400882571466460158277083221521663169974265433098009350061415973662678938824527658049065530"), + }, + y: ark_ff::QuadExtField { + c0: MontFp!("49316184343270950587272132771103279293158283984999436491292404103501221698714795975575879957605051223501287444864258801515822358837529"), + c1: MontFp!("107680854723992552431070996218129928499826544031468382031848626814251381379173928074140221537929995580031433096217223703806029068859074"), + }, + z: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0") , + }, + }, + }; + + // Size in number of bytes when the [to_bytes] + // function is called. + // This is not the size after serialization! + pub const BYTE_SIZE: usize = 4 * 7 * 8 + 1; + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { + use rayon::prelude::*; + let n_threads = rayon::current_num_threads(); + let chunk_size = bases.len().div_ceil(n_threads); + bases + .par_iter() + .map(|&x| x.inner.into_affine()) + .chunks(chunk_size) + .zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size)) + .map(|(bases, scalars)| Self { + inner: crate::curve_446::g2::G2Projective::msm(&bases, &scalars).unwrap(), + }) + .sum::() + } + + pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + let g = self.inner.into_affine(); + let xc0 = bigint_to_bytes(g.x.c0.0 .0); + let xc1 = bigint_to_bytes(g.x.c1.0 .0); + let yc0 = bigint_to_bytes(g.y.c0.0 .0); + let yc1 = bigint_to_bytes(g.y.c1.0 .0); + let mut buf = [0u8; 4 * 7 * 8 + 1]; + buf[..7 * 8].copy_from_slice(&xc0); + buf[7 * 8..][..7 * 8].copy_from_slice(&xc1); + buf[2 * 7 * 8..][..7 * 8].copy_from_slice(&yc0); + buf[3 * 7 * 8..][..7 * 8].copy_from_slice(&yc1); + buf[4 * 7 * 8] = g.infinity as u8; + buf + } + + pub fn double(self) -> Self { + Self { + inner: self.inner.double(), + } + } + } + + impl Add for G2 { + type Output = G2; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + G2 { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for G2 { + type Output = G2; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + G2 { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for G2 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for G2 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl core::iter::Sum for G2 { + fn sum>(iter: I) -> Self { + iter.fold(G2::ZERO, Add::add) + } + } + + impl Neg for G2 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod gt { + use super::*; + use ark_ec::bls12::Bls12Config; + use ark_ec::pairing::{MillerLoopOutput, Pairing}; + use ark_ff::{CubicExtField, Fp12, Fp2, QuadExtField}; + + type Bls = crate::curve_446::Bls12_446; + type Config = crate::curve_446::Config; + + const ONE: Fp2<::Fp2Config> = QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0"), + }; + const ZERO: Fp2<::Fp2Config> = QuadExtField { + c0: MontFp!("0"), + c1: MontFp!("0"), + }; + + const U1: Fp12<::Fp12Config> = QuadExtField { + c0: CubicExtField { + c0: ZERO, + c1: ZERO, + c2: ZERO, + }, + c1: CubicExtField { + c0: ONE, + c1: ZERO, + c2: ZERO, + }, + }; + const U3: Fp12<::Fp12Config> = QuadExtField { + c0: CubicExtField { + c0: ZERO, + c1: ZERO, + c2: ZERO, + }, + c1: CubicExtField { + c0: ZERO, + c1: ONE, + c2: ZERO, + }, + }; + + const fn fp2_to_fp12( + x: Fp2<::Fp2Config>, + ) -> Fp12<::Fp12Config> { + QuadExtField { + c0: CubicExtField { + c0: x, + c1: ZERO, + c2: ZERO, + }, + c1: CubicExtField { + c0: ZERO, + c1: ZERO, + c2: ZERO, + }, + } + } + + const fn fp_to_fp12( + x: ::Fp, + ) -> Fp12<::Fp12Config> { + fp2_to_fp12(QuadExtField { + c0: x, + c1: MontFp!("0"), + }) + } + + fn ate_tangent_ev(qt: G2, evpt: G1) -> Fp12<::Fp12Config> { + let qt = qt.inner.into_affine(); + let evpt = evpt.inner.into_affine(); + + let (xt, yt) = (qt.x, qt.y); + let (xe, ye) = (evpt.x, evpt.y); + + let xt = fp2_to_fp12(xt); + let yt = fp2_to_fp12(yt); + let xe = fp_to_fp12(xe); + let ye = fp_to_fp12(ye); + + let three = fp_to_fp12(MontFp!("3")); + let two = fp_to_fp12(MontFp!("2")); + + let l = three * xt.square() / (two * yt); + ye - (l * xe) / U1 + (l * xt - yt) / U3 + } + + fn ate_line_ev(q1: G2, q2: G2, evpt: G1) -> Fp12<::Fp12Config> { + let q1 = q1.inner.into_affine(); + let q2 = q2.inner.into_affine(); + let evpt = evpt.inner.into_affine(); + + let (x1, y1) = (q1.x, q1.y); + let (x2, y2) = (q2.x, q2.y); + let (xe, ye) = (evpt.x, evpt.y); + + let x1 = fp2_to_fp12(x1); + let y1 = fp2_to_fp12(y1); + let x2 = fp2_to_fp12(x2); + let y2 = fp2_to_fp12(y2); + let xe = fp_to_fp12(xe); + let ye = fp_to_fp12(ye); + + let l = (y2 - y1) / (x2 - x1); + ye - (l * xe) / U1 + (l * x1 - y1) / U3 + } + + #[allow(clippy::needless_range_loop)] + fn ate_pairing(p: G1, q: G2) -> Gt { + let t_log2 = 75; + let t_bits = b"110000000001000001000000100000000000000000000000000000000100000000000000001"; + + let mut fk = fp_to_fp12(MontFp!("1")); + let mut qk = q; + + for k in 1..t_log2 { + let lkk = ate_tangent_ev(qk, p); + qk = qk + qk; + fk = fk.square() * lkk; + + if t_bits[k] == b'1' { + assert_ne!(q, qk); + let lkp1 = if q != -qk { + ate_line_ev(q, qk, p) + } else { + fp_to_fp12(MontFp!("1")) + }; + qk += q; + fk *= lkp1; + } + } + let mlo = MillerLoopOutput(fk); + Gt { + inner: Bls::final_exponentiation(mlo).unwrap(), + } + } + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct Gt { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: ark_ec::pairing::PairingOutput, + } + + impl fmt::Debug for Gt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[allow(dead_code)] + #[derive(Debug)] + struct QuadExtField { + c0: T, + c1: T, + } + + #[allow(dead_code)] + #[derive(Debug)] + struct CubicExtField { + c0: T, + c1: T, + c2: T, + } + + #[allow(dead_code)] + #[derive(Debug)] + pub struct Gt { + inner: T, + } + + f.debug_struct("Gt") + .field( + "inner", + &Gt { + inner: QuadExtField { + c0: CubicExtField { + c0: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c0.c0), + c1: MontIntDisplay(&self.inner.0.c0.c0.c1), + }, + c1: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c1.c0), + c1: MontIntDisplay(&self.inner.0.c0.c1.c1), + }, + c2: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c2.c0), + c1: MontIntDisplay(&self.inner.0.c0.c2.c1), + }, + }, + c1: CubicExtField { + c0: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c0.c0), + c1: MontIntDisplay(&self.inner.0.c1.c0.c1), + }, + c1: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c1.c0), + c1: MontIntDisplay(&self.inner.0.c1.c1.c1), + }, + c2: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c2.c0), + c1: MontIntDisplay(&self.inner.0.c1.c2.c1), + }, + }, + }, + }, + ) + .finish() + } + } + + impl Gt { + pub fn pairing(g1: G1, g2: G2) -> Self { + ate_pairing(g1, -g2) + } + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + } + + impl Add for Gt { + type Output = Gt; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Gt { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for Gt { + type Output = Gt; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Gt { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for Gt { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for Gt { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl Neg for Gt { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod zp { + use super::*; + use ark_ff::Fp; + use zeroize::Zeroize; + + fn redc(n: [u64; 5], nprime: u64, mut t: [u64; 7]) -> [u64; 5] { + for i in 0..2 { + let mut c = 0u64; + let m = u64::wrapping_mul(t[i], nprime); + + for j in 0..5 { + let x = t[i + j] as u128 + m as u128 * n[j] as u128 + c as u128; + t[i + j] = x as u64; + c = (x >> 64) as u64; + } + + for j in 5..7 - i { + let x = t[i + j] as u128 + c as u128; + t[i + j] = x as u64; + c = (x >> 64) as u64; + } + } + + let mut t = [t[2], t[3], t[4], t[5], t[6]]; + + if t.into_iter().rev().ge(n.into_iter().rev()) { + let mut o = false; + for i in 0..5 { + let (ti, o0) = u64::overflowing_sub(t[i], n[i]); + let (ti, o1) = u64::overflowing_sub(ti, o as u64); + o = o0 | o1; + t[i] = ti; + } + } + assert!(t.into_iter().rev().lt(n.into_iter().rev())); + + t + } + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Zeroize)] + #[repr(transparent)] + pub struct Zp { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: crate::curve_446::Fr, + } + + impl fmt::Debug for Zp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Zp") + .field(&MontIntDisplay(&self.inner)) + .finish() + } + } + + impl Zp { + pub const ZERO: Self = Self { + inner: MontFp!("0"), + }; + + pub const ONE: Self = Self { + inner: MontFp!("1"), + }; + + pub fn from_bigint(n: [u64; 5]) -> Self { + Self { + inner: BigInt(n).into(), + } + } + + pub fn from_u64(n: u64) -> Self { + Self { + inner: BigInt([n, 0, 0, 0, 0]).into(), + } + } + + pub fn from_i64(n: i64) -> Self { + let n_abs = Self::from_u64(n.unsigned_abs()); + if n > 0 { + n_abs + } else { + -n_abs + } + } + + pub fn to_bytes(self) -> [u8; 5 * 8] { + let buf = [ + self.inner.0 .0[0].to_le_bytes(), + self.inner.0 .0[1].to_le_bytes(), + self.inner.0 .0[2].to_le_bytes(), + self.inner.0 .0[3].to_le_bytes(), + self.inner.0 .0[4].to_le_bytes(), + ]; + unsafe { core::mem::transmute(buf) } + } + + fn from_raw_u64x7(n: [u64; 7]) -> Self { + const MODULUS: BigInt<5> = BigInt!( + "645383785691237230677916041525710377746967055506026847120930304831624105190538527824412673" + ); + + const MODULUS_MONTGOMERY: u64 = 272467794636046335; + + let mut n = n; + // zero the 22 leading bits, so the result is <= MODULUS * 2^128 + n[6] &= (1 << 42) - 1; + Zp { + inner: Fp( + BigInt(redc(MODULUS.0, MODULUS_MONTGOMERY, n)), + core::marker::PhantomData, + ), + } + } + + pub fn rand(rng: &mut dyn rand::RngCore) -> Self { + use rand::Rng; + + Self::from_raw_u64x7([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + } + + pub fn hash(values: &mut [Zp], data: &[&[u8]]) { + use sha3::digest::{ExtendableOutput, Update, XofReader}; + + let mut hasher = sha3::Shake256::default(); + for data in data { + hasher.update(data); + } + let mut reader = hasher.finalize_xof(); + + for value in values { + let mut bytes = [0u8; 7 * 8]; + reader.read(&mut bytes); + let bytes: [[u8; 8]; 7] = unsafe { core::mem::transmute(bytes) }; + *value = Zp::from_raw_u64x7(bytes.map(u64::from_le_bytes)); + } + } + } + + impl Add for Zp { + type Output = Zp; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for Zp { + type Output = Zp; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner - rhs.inner, + } + } + } + + impl Mul for Zp { + type Output = Zp; + + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner * rhs.inner, + } + } + } + + impl Div for Zp { + type Output = Zp; + + #[inline] + fn div(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner / rhs.inner, + } + } + } + impl AddAssign for Zp { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for Zp { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl Neg for Zp { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } + + impl core::iter::Sum for Zp { + fn sum>(iter: I) -> Self { + iter.fold(Zp::ZERO, Add::add) + } + } +} + +pub use g1::G1; +pub use g2::G2; +pub use gt::Gt; +pub use zp::Zp; + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::SeedableRng; + use std::collections::HashMap; + + #[test] + fn test_g1() { + let x = G1::GENERATOR; + let y = x.mul_scalar(Zp::from_i64(-2) * Zp::from_u64(2)); + assert_eq!(x - x, G1::ZERO); + assert_eq!(x + y - x, y); + } + + #[test] + fn test_g2() { + let x = G2::GENERATOR; + let y = x.mul_scalar(Zp::from_i64(-2) * Zp::from_u64(2)); + assert_eq!(x - x, G2::ZERO); + assert_eq!(x + y - x, y); + } + + #[test] + fn test_g1_msm() { + let n = 1024; + let x = vec![G1::GENERATOR.mul_scalar(Zp::from_i64(-1)); n]; + let mut p = vec![Zp::ZERO; n]; + Zp::hash(&mut p, &[&[0]]); + + let result = G1::multi_mul_scalar(&x, &p); + let expected = x + .iter() + .zip(p.iter()) + .map(|(&x, &p)| x.mul_scalar(p)) + .sum::(); + assert_eq!(result, expected); + } + + #[test] + fn test_g2_msm() { + let n = 1024; + let x = vec![G2::GENERATOR.mul_scalar(Zp::from_i64(-1)); n]; + let mut p = vec![Zp::ZERO; n]; + Zp::hash(&mut p, &[&[0]]); + + let result = G2::multi_mul_scalar(&x, &p); + let expected = x + .iter() + .zip(p.iter()) + .map(|(&x, &p)| x.mul_scalar(p)) + .sum::(); + assert_eq!(result, expected); + } + + #[test] + fn test_pairing() { + let rng = &mut StdRng::seed_from_u64(0); + let p1 = Zp::rand(rng); + let p2 = Zp::rand(rng); + + let x1 = G1::GENERATOR.mul_scalar(p1); + let x2 = G2::GENERATOR.mul_scalar(p2); + + assert_eq!( + Gt::pairing(x1, x2), + Gt::pairing(G1::GENERATOR, G2::GENERATOR).mul_scalar(p1 * p2), + ); + } + + #[test] + fn test_distributivity() { + let a = Zp { + inner: MontFp!( + "20799633726231143268782044631117354647259165363910905818134484248029981143850" + ), + }; + let b = Zp { + inner: MontFp!( + "42333504039292951860879669847432876299949385605895551964353185488509497658948" + ), + }; + let c = Zp { + inner: MontFp!( + "6797004509292554067788526429737434481164547177696793280652530849910670196287" + ), + }; + + assert_eq!((((a - b) * c) - (a * c - b * c)).inner, Zp::ZERO.inner); + } + + #[test] + fn test_serialization() { + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let g_cur = G1::GENERATOR.mul_scalar(alpha); + let g_hat_cur = G2::GENERATOR.mul_scalar(alpha); + + let alpha2: Zp = serde_json::from_str(&serde_json::to_string(&alpha).unwrap()).unwrap(); + assert_eq!(alpha, alpha2); + + let g_cur2: G1 = serde_json::from_str(&serde_json::to_string(&g_cur).unwrap()).unwrap(); + assert_eq!(g_cur, g_cur2); + + let g_hat_cur2: G2 = + serde_json::from_str(&serde_json::to_string(&g_hat_cur).unwrap()).unwrap(); + assert_eq!(g_hat_cur, g_hat_cur2); + } + + #[test] + fn test_hasher_and_eq() { + // we need to make sure if the points are the same + // but the projective representations are different + // then they still hash into the same thing + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let a = G1::GENERATOR.mul_scalar(alpha); + + // serialization should convert the point to affine representation + // after deserializing it we should have the same point + // but with a different representation + let a_affine: G1 = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap(); + + // the internal elements should be different + assert_ne!(a.inner.x, a_affine.inner.x); + assert_ne!(a.inner.y, a_affine.inner.y); + assert_ne!(a.inner.z, a_affine.inner.z); + + // but equality and hasher should see the two as the same point + assert_eq!(a, a_affine); + let mut hm = HashMap::new(); + hm.insert(a, 1); + assert_eq!(hm.len(), 1); + hm.insert(a_affine, 2); + assert_eq!(hm.len(), 1); + } +} diff --git a/tfhe-zk-pok/src/lib.rs b/tfhe-zk-pok/src/lib.rs new file mode 100644 index 0000000000..b75cb224de --- /dev/null +++ b/tfhe-zk-pok/src/lib.rs @@ -0,0 +1,3 @@ +pub mod curve_446; +pub mod curve_api; +pub mod proofs; diff --git a/tfhe-zk-pok/src/proofs/binary.rs b/tfhe-zk-pok/src/proofs/binary.rs new file mode 100644 index 0000000000..aad5b9aa2a --- /dev/null +++ b/tfhe-zk-pok/src/proofs/binary.rs @@ -0,0 +1,219 @@ +use super::*; + +#[derive(Clone, Debug)] +pub struct PublicParams { + g_lists: GroupElements, +} + +impl PublicParams { + pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + Self { + g_lists: GroupElements::from_vec(g_list, g_hat_list), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub struct PrivateParams { + alpha: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct PublicCommit { + c_hat: G::G2, +} + +#[derive(Clone, Debug)] +pub struct PrivateCommit { + message: Vec, + gamma: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct Proof { + c_y: G::G1, + pi: G::G1, +} + +pub fn crs_gen( + message_len: usize, + rng: &mut dyn RngCore, +) -> (PublicParams, PrivateParams) { + let alpha = G::Zp::rand(rng); + ( + PublicParams { + g_lists: GroupElements::new(message_len, alpha), + }, + PrivateParams { alpha }, + ) +} + +pub fn commit( + message: &[u64], + public: &PublicParams, + rng: &mut dyn RngCore, +) -> (PublicCommit, PrivateCommit) { + let g_hat = G::G2::GENERATOR; + let n = message.len(); + + let gamma = G::Zp::rand(rng); + let x = OneBased::new_ref(message); + + let mut c_hat = g_hat.mul_scalar(gamma); + for j in 1..n + 1 { + let term = if x[j] != 0 { + public.g_lists.g_hat_list[j] + } else { + G::G2::ZERO + }; + c_hat += term; + } + + ( + PublicCommit { c_hat }, + PrivateCommit { + message: message.to_vec(), + gamma, + }, + ) +} + +pub fn prove( + public: (&PublicParams, &PublicCommit), + private_commit: &PrivateCommit, + rng: &mut dyn RngCore, +) -> Proof { + let n = private_commit.message.len(); + let g = G::G1::GENERATOR; + let x = OneBased::new_ref(&*private_commit.message); + let c_hat = public.1.c_hat; + let gamma = private_commit.gamma; + let gamma_y = G::Zp::rand(rng); + let g_list = &public.0.g_lists.g_list; + + let mut y = OneBased(vec![G::Zp::ZERO; n]); + G::Zp::hash(&mut y.0, &[c_hat.to_bytes().as_ref()]); + + let mut c_y = g.mul_scalar(gamma_y); + for j in 1..n + 1 { + c_y += (g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(x[j])); + } + + let y_bytes = &*(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(); + let mut t = OneBased(vec![G::Zp::ZERO; n]); + G::Zp::hash( + &mut t.0, + &[y_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + + let proof = { + let mut poly_0 = vec![G::Zp::ZERO; n + 1]; + let mut poly_1 = vec![G::Zp::ZERO; n + 1]; + let mut poly_2 = vec![G::Zp::ZERO; n + 1]; + let mut poly_3 = vec![G::Zp::ZERO; n + 1]; + + poly_0[0] = gamma_y * delta_y; + for i in 1..n + 1 { + poly_0[n + 1 - i] = + delta_y * (G::Zp::from_u64(x[i]) * y[i]) + (delta_eq * t[i] - delta_y) * y[i]; + } + + poly_1[0] = gamma; + for i in 1..n + 1 { + poly_1[i] = G::Zp::from_u64(x[i]); + } + + poly_2[0] = gamma_y; + for i in 1..n + 1 { + poly_2[n + 1 - i] = y[i] * G::Zp::from_u64(x[i]); + } + + for i in 1..n + 1 { + poly_3[i] = delta_eq * t[i]; + } + + let poly = G::Zp::poly_sub( + &G::Zp::poly_mul(&poly_0, &poly_1), + &G::Zp::poly_mul(&poly_2, &poly_3), + ); + + let mut proof = g.mul_scalar(poly[0]); + for i in 1..poly.len() { + proof += g_list[i].mul_scalar(poly[i]); + } + proof + }; + + Proof { pi: proof, c_y } +} + +#[allow(clippy::result_unit_err)] +pub fn verify( + proof: &Proof, + public: (&PublicParams, &PublicCommit), +) -> Result<(), ()> { + let e = G::Gt::pairing; + let c_hat = public.1.c_hat; + let g_hat = G::G2::GENERATOR; + let g_list = &public.0.g_lists.g_list; + let g_hat_list = &public.0.g_lists.g_hat_list; + let n = public.0.g_lists.message_len; + + let pi = proof.pi; + let c_y = proof.c_y; + + let mut y = OneBased(vec![G::Zp::ZERO; n]); + G::Zp::hash(&mut y.0, &[c_hat.to_bytes().as_ref()]); + + let y_bytes = &*(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(); + let mut t = OneBased(vec![G::Zp::ZERO; n]); + G::Zp::hash( + &mut t.0, + &[y_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + + let rhs = e(pi, g_hat); + let lhs = { + let numerator = { + let mut p = c_y.mul_scalar(delta_y); + for i in 1..n + 1 { + let gy = g_list[n + 1 - i].mul_scalar(y[i]); + p += gy.mul_scalar(delta_eq).mul_scalar(t[i]) - gy.mul_scalar(delta_y); + } + e(p, c_hat) + }; + let denominator = { + let mut q = G::G2::ZERO; + for i in 1..n + 1 { + q += g_hat_list[i].mul_scalar(delta_eq).mul_scalar(t[i]); + } + e(c_y, q) + }; + numerator - denominator + }; + + if lhs == rhs { + Ok(()) + } else { + Err(()) + } +} diff --git a/tfhe-zk-pok/src/proofs/index.rs b/tfhe-zk-pok/src/proofs/index.rs new file mode 100644 index 0000000000..a770f98b39 --- /dev/null +++ b/tfhe-zk-pok/src/proofs/index.rs @@ -0,0 +1,127 @@ +use super::*; + +#[derive(Clone, Debug)] +pub struct PublicParams { + g_lists: GroupElements, +} + +impl PublicParams { + pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + Self { + g_lists: GroupElements::from_vec(g_list, g_hat_list), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub struct PrivateParams { + alpha: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct PublicCommit { + c: G::G1, +} + +#[derive(Clone, Debug)] +pub struct PrivateCommit { + message: Vec, + gamma: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct Proof { + pi: G::G1, +} + +pub fn crs_gen( + message_len: usize, + rng: &mut dyn RngCore, +) -> (PublicParams, PrivateParams) { + let alpha = G::Zp::rand(rng); + ( + PublicParams { + g_lists: GroupElements::new(message_len, alpha), + }, + PrivateParams { alpha }, + ) +} + +pub fn commit( + message: &[u64], + public: &PublicParams, + rng: &mut dyn RngCore, +) -> (PublicCommit, PrivateCommit) { + let g = G::G1::GENERATOR; + let n = message.len(); + + let gamma = G::Zp::rand(rng); + let m = OneBased::new_ref(message); + + let mut c = g.mul_scalar(gamma); + for j in 1..n + 1 { + let term = public.g_lists.g_list[j].mul_scalar(G::Zp::from_u64(m[j])); + c += term; + } + + ( + PublicCommit { c }, + PrivateCommit { + message: message.to_vec(), + gamma, + }, + ) +} + +pub fn prove( + i: usize, + public: (&PublicParams, &PublicCommit), + private: (&PrivateParams, &PrivateCommit), + rng: &mut dyn RngCore, +) -> Proof { + let _ = rng; + let n = private.1.message.len(); + let m = OneBased::new_ref(&*private.1.message); + let gamma = private.1.gamma; + let g_list = &public.0.g_lists.g_list; + + let mut pi = g_list[n + 1 - i].mul_scalar(gamma); + for j in 1..n + 1 { + if i != j { + let term = if m[j] & 1 == 1 { + g_list[n + 1 - i + j] + } else { + G::G1::ZERO + }; + + pi += term; + } + } + + Proof { pi } +} + +#[allow(clippy::result_unit_err)] +pub fn verify( + proof: &Proof, + (index, mi): (usize, u64), + public: (&PublicParams, &PublicCommit), +) -> Result<(), ()> { + let e = G::Gt::pairing; + let c = public.1.c; + let g_hat = G::G2::GENERATOR; + let g_list = &public.0.g_lists.g_list; + let g_hat_list = &public.0.g_lists.g_hat_list; + let n = public.0.g_lists.message_len; + let i = index + 1; + + let lhs = e(c, g_hat_list[n + 1 - i]); + let rhs = e(proof.pi, g_hat) + (e(g_list[1], g_hat_list[n])).mul_scalar(G::Zp::from_u64(mi)); + + if lhs == rhs { + Ok(()) + } else { + Err(()) + } +} diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs new file mode 100644 index 0000000000..bc6f0cf1a0 --- /dev/null +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -0,0 +1,95 @@ +use crate::curve_api::{Curve, CurveGroupOps, FieldOps, PairingGroupOps}; + +use core::ops::{Index, IndexMut}; +use rand::RngCore; + +#[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize)] +#[repr(transparent)] +struct OneBased(T); + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ComputeLoad { + Proof, + Verify, +} + +impl OneBased { + pub fn new(inner: T) -> Self + where + T: Sized, + { + Self(inner) + } + + pub fn new_ref(inner: &T) -> &Self { + unsafe { &*(inner as *const T as *const Self) } + } +} + +impl> Index for OneBased { + type Output = T::Output; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.0[index - 1] + } +} + +impl> IndexMut for OneBased { + #[inline] + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.0[index - 1] + } +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +struct GroupElements { + g_list: OneBased>, + g_hat_list: OneBased>, + message_len: usize, +} + +impl GroupElements { + pub fn new(message_len: usize, alpha: G::Zp) -> Self { + let mut g_list = Vec::new(); + let mut g_hat_list = Vec::new(); + + let mut g_cur = G::G1::GENERATOR.mul_scalar(alpha); + + for i in 0..2 * message_len { + if i == message_len { + g_list.push(G::G1::ZERO); + } else { + g_list.push(g_cur); + } + g_cur = g_cur.mul_scalar(alpha); + } + + let mut g_hat_cur = G::G2::GENERATOR.mul_scalar(alpha); + for _ in 0..message_len { + g_hat_list.push(g_hat_cur); + g_hat_cur = (g_hat_cur).mul_scalar(alpha); + } + + Self { + g_list: OneBased::new(g_list), + g_hat_list: OneBased::new(g_hat_list), + message_len, + } + } + + pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + let message_len = g_hat_list.len(); + Self { + g_list: OneBased::new(g_list), + g_hat_list: OneBased::new(g_hat_list), + message_len, + } + } +} + +pub mod binary; +pub mod index; +pub mod pke; +pub mod range; +pub mod rlwe; diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs new file mode 100644 index 0000000000..b1cbaf59e8 --- /dev/null +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -0,0 +1,1043 @@ +use super::*; +use core::marker::PhantomData; +use rayon::prelude::*; + +fn bit_iter(x: u64, nbits: u32) -> impl Iterator { + (0..nbits).map(move |idx| ((x >> idx) & 1) != 0) +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct PublicParams { + g_lists: GroupElements, + big_d: usize, + pub n: usize, + pub d: usize, + pub k: usize, + pub b: u64, + pub b_r: u64, + pub q: u64, + pub t: u64, +} + +impl PublicParams { + #[allow(clippy::too_many_arguments)] + pub fn from_vec( + g_list: Vec, + g_hat_list: Vec, + big_d: usize, + n: usize, + d: usize, + k: usize, + b: u64, + b_r: u64, + q: u64, + t: u64, + ) -> Self { + Self { + g_lists: GroupElements::::from_vec(g_list, g_hat_list), + big_d, + n, + d, + k, + b, + b_r, + q, + t, + } + } + + pub fn exclusive_max_noise(&self) -> u64 { + self.b + } +} + +#[allow(dead_code)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct PrivateParams { + alpha: G::Zp, +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct Proof { + c_hat: G::G2, + c_y: G::G1, + pi: G::G1, + c_hat_t: Option, + c_h: Option, + pi_kzg: Option, +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct PublicCommit { + a: Vec, + b: Vec, + c1: Vec, + c2: Vec, + __marker: PhantomData, +} + +impl PublicCommit { + pub fn new(a: Vec, b: Vec, c1: Vec, c2: Vec) -> Self { + Self { + a, + b, + c1, + c2, + __marker: Default::default(), + } + } +} + +#[derive(Clone, Debug)] +pub struct PrivateCommit { + r: Vec, + e1: Vec, + m: Vec, + e2: Vec, + __marker: PhantomData, +} + +pub fn crs_gen( + d: usize, + k: usize, + b: u64, + q: u64, + t: u64, + rng: &mut dyn RngCore, +) -> (PublicParams, PrivateParams) { + let alpha = G::Zp::rand(rng); + let b_r = d as u64 / 2 + 1; + + let big_d = + d + k * t.ilog2() as usize + (d + k) * (2 + b.ilog2() as usize + b_r.ilog2() as usize); + let n = big_d + 1; + ( + PublicParams { + g_lists: GroupElements::::new(n, alpha), + big_d, + n, + d, + k, + b, + b_r, + q, + t, + }, + PrivateParams { alpha }, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn commit( + a: Vec, + b: Vec, + c1: Vec, + c2: Vec, + r: Vec, + e1: Vec, + m: Vec, + e2: Vec, + public: &PublicParams, + rng: &mut dyn RngCore, +) -> (PublicCommit, PrivateCommit) { + let _ = (public, rng); + ( + PublicCommit { + a, + b, + c1, + c2, + __marker: PhantomData, + }, + PrivateCommit { + r, + e1, + m, + e2, + __marker: PhantomData, + }, + ) +} + +pub fn prove( + public: (&PublicParams, &PublicCommit), + private_commit: &PrivateCommit, + load: ComputeLoad, + rng: &mut dyn RngCore, +) -> Proof { + let &PublicParams { + ref g_lists, + big_d, + n, + d, + b, + b_r, + q, + t, + k, + } = public.0; + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + + let b_i = b; + + let PublicCommit { a, b, c1, c2, .. } = public.1; + let PrivateCommit { r, e1, m, e2, .. } = private_commit; + + assert!(c2.len() <= k); + let k = k.min(c2.len()); + + // FIXME: div_round + let delta = { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + (q / t as i128) as u64 + }; + + let g = G::G1::GENERATOR; + let g_hat = G::G2::GENERATOR; + let gamma = G::Zp::rand(rng); + let gamma_y = G::Zp::rand(rng); + + // rot(a) phi(r) + phi(e1) - q phi(r1) = phi(c1) + // phi[d - i + 1](bar(b)).T phi(r) + delta m_i + e2_i - q r2_i = c2 + + // phi(r1) = (rot(a) phi(r) + phi(e1) - phi(c1)) / q + // r2_i = (phi[d - i + 1](bar(b)).T phi(r) + delta m_i + e2_i - c2) / q + + let mut r1 = e1 + .iter() + .zip(c1.iter()) + .map(|(&e1, &c1)| e1 as i128 - c1 as i128) + .collect::>(); + + for i in 0..d { + for j in 0..d { + if i + j < d { + r1[i + j] += a[i] as i128 * r[d - j - 1] as i128; + } else { + r1[i + j - d] -= a[i] as i128 * r[d - j - 1] as i128; + } + } + } + + { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + for r1 in &mut *r1 { + *r1 /= q; + } + } + + let mut r2 = m + .iter() + .zip(e2) + .zip(c2) + .map(|((&m, &e2), &c2)| delta as i128 * m as i128 + e2 as i128 - c2 as i128) + .collect::>(); + + { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + for (i, r2) in r2.iter_mut().enumerate() { + let mut dot = 0i128; + for j in 0..d { + let b = if i + j < d { + b[d - j - i - 1] + } else { + b[2 * d - j - i - 1].wrapping_neg() + }; + + dot += r[d - j - 1] as i128 * b as i128; + } + + *r2 += dot; + *r2 /= q; + } + } + + let r1 = r1 + .into_vec() + .into_iter() + .map(|r1| r1 as i64) + .collect::>(); + + let r2 = r2 + .into_vec() + .into_iter() + .map(|r2| r2 as i64) + .collect::>(); + + let mut w = vec![false; n]; + + let u64 = |x: i64| x as u64; + + w[..big_d] + .iter_mut() + .zip( + r.iter() + .rev() + .flat_map(|&r| bit_iter(u64(r), 1)) + .chain(m.iter().flat_map(|&m| bit_iter(u64(m), t.ilog2()))) + .chain(e1.iter().flat_map(|&e1| bit_iter(u64(e1), 1 + b_i.ilog2()))) + .chain(e2.iter().flat_map(|&e2| bit_iter(u64(e2), 1 + b_i.ilog2()))) + .chain(r1.iter().flat_map(|&r1| bit_iter(u64(r1), 1 + b_r.ilog2()))) + .chain(r2.iter().flat_map(|&r2| bit_iter(u64(r2), 1 + b_r.ilog2()))), + ) + .for_each(|(dst, src)| *dst = src); + + let w = OneBased(w); + + let mut c_hat = g_hat.mul_scalar(gamma); + for j in 1..big_d + 1 { + let term = if w[j] { g_hat_list[j] } else { G::G2::ZERO }; + c_hat += term; + } + + let x_bytes = &*[ + q.to_le_bytes().as_slice(), + d.to_le_bytes().as_slice(), + b_i.to_le_bytes().as_slice(), + t.to_le_bytes().as_slice(), + &*a.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*b.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*c1.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*c2.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + ] + .iter() + .copied() + .flatten() + .copied() + .collect::>(); + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash(&mut y, &[x_bytes, c_hat.to_bytes().as_ref()]); + let y = OneBased(y); + + let scalars = (n + 1 - big_d..n + 1) + .map(|j| (y[n + 1 - j] * G::Zp::from_u64(w[n + 1 - j] as u64))) + .collect::>(); + let c_y = g.mul_scalar(gamma_y) + G::G1::multi_mul_scalar(&g_list.0[n - big_d..n], &scalars); + + let mut theta = vec![G::Zp::ZERO; d + k + 1]; + G::Zp::hash( + &mut theta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + + let theta0 = &theta[..d + k]; + + let delta_theta = theta[d + k]; + + let mut a_theta = vec![G::Zp::ZERO; big_d]; + + compute_a_theta::(theta0, d, a, k, b, &mut a_theta, t, delta, b_i, b_r, q); + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + &(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(), + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + + let mut poly_0 = vec![G::Zp::ZERO; n + 1]; + let mut poly_1 = vec![G::Zp::ZERO; big_d + 1]; + let mut poly_2 = vec![G::Zp::ZERO; n + 1]; + let mut poly_3 = vec![G::Zp::ZERO; n + 1]; + + poly_0[0] = delta_y * gamma_y; + for i in 1..n + 1 { + poly_0[n + 1 - i] = + delta_y * (y[i] * G::Zp::from_u64(w[i] as u64)) + (delta_eq * t[i] - delta_y) * y[i]; + + if i < big_d + 1 { + poly_0[n + 1 - i] += delta_theta * a_theta[i - 1]; + } + } + + poly_1[0] = gamma; + for i in 1..big_d + 1 { + poly_1[i] = G::Zp::from_u64(w[i] as u64); + } + + poly_2[0] = gamma_y; + for i in 1..big_d + 1 { + poly_2[n + 1 - i] = y[i] * G::Zp::from_u64(w[i] as u64); + } + + for i in 1..n + 1 { + poly_3[i] = delta_eq * t[i]; + } + + let mut t_theta = G::Zp::ZERO; + for i in 0..d { + t_theta += theta0[i] * G::Zp::from_i64(c1[i]); + } + for i in 0..k { + t_theta += theta0[d + i] * G::Zp::from_i64(c2[i]); + } + + let mut poly = G::Zp::poly_sub( + &G::Zp::poly_mul(&poly_0, &poly_1), + &G::Zp::poly_mul(&poly_2, &poly_3), + ); + if poly.len() > n + 1 { + poly[n + 1] -= t_theta * delta_theta; + } + + let pi = + g.mul_scalar(poly[0]) + G::G1::multi_mul_scalar(&g_list.0[..poly.len() - 1], &poly[1..]); + + if load == ComputeLoad::Proof { + let c_hat_t = G::G2::multi_mul_scalar(&g_hat_list.0, &t.0); + let scalars = (1..n + 1) + .map(|i| { + let i = n + 1 - i; + (delta_eq * t[i] - delta_y) * y[i] + + if i < big_d + 1 { + delta_theta * a_theta[i - 1] + } else { + G::Zp::ZERO + } + }) + .collect::>(); + let c_h = G::G1::multi_mul_scalar(&g_list.0[..n], &scalars); + + let mut z = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut z), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + ], + ); + + let mut pow = z; + let mut p_t = G::Zp::ZERO; + let mut p_h = G::Zp::ZERO; + + for i in 1..n + 1 { + p_t += t[i] * pow; + if n - i < big_d { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i] + + delta_theta * a_theta[n - i]) + * pow; + } else { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]) * pow; + } + pow = pow * z; + } + + let mut w = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut w), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + z.to_bytes().as_ref(), + p_h.to_bytes().as_ref(), + p_t.to_bytes().as_ref(), + ], + ); + + let mut poly = vec![G::Zp::ZERO; n + 1]; + for i in 1..n + 1 { + poly[i] += w * t[i]; + if i < big_d + 1 { + poly[n + 1 - i] += + (delta_eq * t[i] - delta_y) * y[i] + delta_theta * a_theta[i - 1]; + } else { + poly[n + 1 - i] += (delta_eq * t[i] - delta_y) * y[i]; + } + } + + let mut q = vec![G::Zp::ZERO; n]; + for i in (0..n).rev() { + poly[i] = poly[i] + z * poly[i + 1]; + q[i] = poly[i + 1]; + poly[i + 1] = G::Zp::ZERO; + } + let pi_kzg = g.mul_scalar(q[0]) + G::G1::multi_mul_scalar(&g_list.0[..n - 1], &q[1..n]); + + Proof { + c_hat, + c_y, + pi, + c_hat_t: Some(c_hat_t), + c_h: Some(c_h), + pi_kzg: Some(pi_kzg), + } + } else { + Proof { + c_hat, + c_y, + pi, + c_hat_t: None, + c_h: None, + pi_kzg: None, + } + } +} + +#[allow(clippy::too_many_arguments)] +fn compute_a_theta( + theta0: &[G::Zp], + d: usize, + a: &[i64], + k: usize, + b: &[i64], + a_theta: &mut [G::Zp], + t: u64, + delta: u64, + b_i: u64, + b_r: u64, + q: u64, +) { + // a_theta = Ã.T theta0 + // = [ + // rot(a).T theta1 + phi[d](bar(b)) theta2_1 + ... + phi[d-k+1](bar(b)) theta2_k + // + // delta g[log t].T theta2_1 + // delta g[log t].T theta2_2 + // ... + // delta g[log t].T theta2_k + // + // G[1 + log B].T theta1 + // + // g[1 + log B].T theta2_1 + // g[1 + log B].T theta2_2 + // ... + // g[1 + log B].T theta2_k + // + // -q G[1 + log Br].T theta1 + // + // -q g[1 + log Br].T theta2_1 + // -q g[1 + log Br].T theta2_2 + // ... + // -q g[1 + log Br].T theta2_k + // ] + + let q = if q == 0 { + G::Zp::from_u128(1u128 << 64) + } else { + G::Zp::from_u64(q) + }; + + let theta1 = &theta0[..d]; + let theta2 = &theta0[d..]; + { + let a_theta = &mut a_theta[..d]; + a_theta + .par_iter_mut() + .enumerate() + .for_each(|(i, a_theta_i)| { + let mut dot = G::Zp::ZERO; + + for j in 0..d { + let a = if i <= j { + a[j - i] + } else { + a[d + j - i].wrapping_neg() + }; + + dot += G::Zp::from_i64(a) * theta1[j]; + } + + for j in 0..k { + let b = if i + j < d { + b[d - i - j - 1] + } else { + b[2 * d - i - j - 1].wrapping_neg() + }; + + dot += G::Zp::from_i64(b) * theta2[j]; + } + *a_theta_i = dot; + }); + } + let a_theta = &mut a_theta[d..]; + + let step = t.ilog2() as usize; + for i in 0..k { + for j in 0..step { + let pow2 = G::Zp::from_u64(delta) * G::Zp::from_u64(1 << j) * theta2[i]; + a_theta[step * i + j] = pow2; + } + } + let a_theta = &mut a_theta[k * step..]; + + let step = 1 + b_i.ilog2() as usize; + for i in 0..d { + for j in 0..step { + let pow2 = G::Zp::from_u64(1 << j) * theta1[i]; + a_theta[step * i + j] = if j == step - 1 { -pow2 } else { pow2 }; + } + } + let a_theta = &mut a_theta[d * step..]; + for i in 0..k { + for j in 0..step { + let pow2 = G::Zp::from_u64(1 << j) * theta2[i]; + a_theta[step * i + j] = if j == step - 1 { -pow2 } else { pow2 }; + } + } + let a_theta = &mut a_theta[k * step..]; + + let step = 1 + b_r.ilog2() as usize; + for i in 0..d { + for j in 0..step { + let pow2 = -q * G::Zp::from_u64(1 << j) * theta1[i]; + a_theta[step * i + j] = if j == step - 1 { -pow2 } else { pow2 }; + } + } + let a_theta = &mut a_theta[d * step..]; + for i in 0..k { + for j in 0..step { + let pow2 = -q * G::Zp::from_u64(1 << j) * theta2[i]; + a_theta[step * i + j] = if j == step - 1 { -pow2 } else { pow2 }; + } + } +} + +#[allow(clippy::result_unit_err)] +pub fn verify( + proof: &Proof, + public: (&PublicParams, &PublicCommit), +) -> Result<(), ()> { + let &Proof { + c_hat, + c_y, + pi, + c_hat_t, + c_h, + pi_kzg, + } = proof; + let e = G::Gt::pairing; + + let &PublicParams { + ref g_lists, + big_d, + n, + d, + b, + b_r, + q, + t, + k, + } = public.0; + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + + let b_i = b; + + // FIXME: div_round + let delta = { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + (q / t as i128) as u64 + }; + + let PublicCommit { a, b, c1, c2, .. } = public.1; + if c2.len() > k { + return Err(()); + } + let k = k.min(c2.len()); + + let x_bytes = &*[ + q.to_le_bytes().as_slice(), + d.to_le_bytes().as_slice(), + b_i.to_le_bytes().as_slice(), + t.to_le_bytes().as_slice(), + &*a.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*b.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*c1.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*c2.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + ] + .iter() + .copied() + .flatten() + .copied() + .collect::>(); + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash(&mut y, &[x_bytes, c_hat.to_bytes().as_ref()]); + let y = OneBased(y); + + let mut theta = vec![G::Zp::ZERO; d + k + 1]; + G::Zp::hash( + &mut theta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let theta0 = &theta[..d + k]; + let delta_theta = theta[d + k]; + + let mut a_theta = vec![G::Zp::ZERO; big_d]; + compute_a_theta::(theta0, d, a, k, b, &mut a_theta, t, delta, b_i, b_r, q); + + let mut t_theta = G::Zp::ZERO; + for i in 0..d { + t_theta += theta0[i] * G::Zp::from_i64(c1[i]); + } + for i in 0..k { + t_theta += theta0[d + i] * G::Zp::from_i64(c2[i]); + } + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + &(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(), + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + + if let (Some(pi_kzg), Some(c_hat_t), Some(c_h)) = (pi_kzg, c_hat_t, c_h) { + let mut z = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut z), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + ], + ); + + let mut pow = z; + let mut p_t = G::Zp::ZERO; + let mut p_h = G::Zp::ZERO; + + for i in 1..n + 1 { + p_t += t[i] * pow; + if n - i < big_d { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i] + + delta_theta * a_theta[n - i]) + * pow; + } else { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]) * pow; + } + pow = pow * z; + } + + if e(pi, G::G2::GENERATOR) + != e(c_y.mul_scalar(delta_y) + c_h, c_hat) + - e(c_y.mul_scalar(delta_eq), c_hat_t) + - e(g_list[1], g_hat_list[n]).mul_scalar(t_theta * delta_theta) + { + return Err(()); + } + + let mut w = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut w), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + z.to_bytes().as_ref(), + p_h.to_bytes().as_ref(), + p_t.to_bytes().as_ref(), + ], + ); + + if e(c_h - G::G1::GENERATOR.mul_scalar(p_h), G::G2::GENERATOR) + + e(G::G1::GENERATOR, c_hat_t - G::G2::GENERATOR.mul_scalar(p_t)).mul_scalar(w) + == e(pi_kzg, g_hat_list[1] - G::G2::GENERATOR.mul_scalar(z)) + { + Ok(()) + } else { + Err(()) + } + } else { + let (term0, term1) = rayon::join( + || { + let p = c_y.mul_scalar(delta_y) + + (1..n + 1) + .into_par_iter() + .map(|i| { + let mut factor = (delta_eq * t[i] - delta_y) * y[i]; + if i < big_d + 1 { + factor += delta_theta * a_theta[i - 1]; + } + g_list[n + 1 - i].mul_scalar(factor) + }) + .sum::(); + let q = c_hat; + e(p, q) + }, + || { + let p = c_y; + let q = (1..n + 1) + .into_par_iter() + .map(|i| g_hat_list[i].mul_scalar(delta_eq * t[i])) + .sum::(); + e(p, q) + }, + ); + let term2 = { + let p = g_list[1]; + let q = g_hat_list[n]; + e(p, q) + }; + + let lhs = e(pi, G::G2::GENERATOR); + let rhs = term0 - term1 - term2.mul_scalar(t_theta * delta_theta); + + if lhs == rhs { + Ok(()) + } else { + Err(()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + + #[test] + fn test_pke() { + let d = 2048; + let k = 320; + let b_i = 512; + let q = 0; + let t = 1024; + + let delta = { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + (q / t as i128) as u64 + }; + + let rng = &mut StdRng::seed_from_u64(0); + + let polymul_rev = |a: &[i64], b: &[i64]| -> Vec { + assert_eq!(a.len(), b.len()); + let d = a.len(); + let mut c = vec![0i64; d]; + + for i in 0..d { + for j in 0..d { + if i + j < d { + c[i + j] = c[i + j].wrapping_add(a[i].wrapping_mul(b[d - j - 1])); + } else { + c[i + j - d] = c[i + j - d].wrapping_sub(a[i].wrapping_mul(b[d - j - 1])); + } + } + } + + c + }; + + let a = (0..d).map(|_| rng.gen::()).collect::>(); + let s = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + let e = (0..d) + .map(|_| (rng.gen::() % (2 * b_i)) as i64 - b_i as i64) + .collect::>(); + let e1 = (0..d) + .map(|_| (rng.gen::() % (2 * b_i)) as i64 - b_i as i64) + .collect::>(); + let fake_e1 = (0..d) + .map(|_| (rng.gen::() % (2 * b_i)) as i64 - b_i as i64) + .collect::>(); + let e2 = (0..k) + .map(|_| (rng.gen::() % (2 * b_i)) as i64 - b_i as i64) + .collect::>(); + let fake_e2 = (0..k) + .map(|_| (rng.gen::() % (2 * b_i)) as i64 - b_i as i64) + .collect::>(); + + let r = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + let fake_r = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + + let m = (0..k) + .map(|_| (rng.gen::() % t) as i64) + .collect::>(); + let fake_m = (0..k) + .map(|_| (rng.gen::() % t) as i64) + .collect::>(); + + let b = polymul_rev(&a, &s) + .into_iter() + .zip(e.iter()) + .map(|(x, e)| x.wrapping_add(*e)) + .collect::>(); + let c1 = polymul_rev(&a, &r) + .into_iter() + .zip(e1.iter()) + .map(|(x, e1)| x.wrapping_add(*e1)) + .collect::>(); + + let mut c2 = vec![0i64; k]; + + for i in 0..k { + let mut dot = 0i64; + for j in 0..d { + let b = if i + j < d { + b[d - j - i - 1] + } else { + b[2 * d - j - i - 1].wrapping_neg() + }; + + dot = dot.wrapping_add(r[d - j - 1].wrapping_mul(b)); + } + + c2[i] = dot + .wrapping_add(e2[i]) + .wrapping_add((delta * m[i] as u64) as i64); + } + + let mut m_roundtrip = vec![0i64; k]; + for i in 0..k { + let mut dot = 0i128; + for j in 0..d { + let c = if i + j < d { + c1[d - j - i - 1] + } else { + c1[2 * d - j - i - 1].wrapping_neg() + }; + + dot += s[d - j - 1] as i128 * c as i128; + } + + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + let val = ((c2[i] as i128).wrapping_sub(dot)) * t as i128; + let div = val.div_euclid(q); + let rem = val.rem_euclid(q); + let result = div as i64 + (rem > (q / 2)) as i64; + let result = result.rem_euclid(t as i64); + m_roundtrip[i] = result; + } + + let (public_param, _private_param) = + crs_gen::(d, k, b_i, q, t, rng); + + for use_fake_e1 in [false, true] { + for use_fake_e2 in [false, true] { + for use_fake_m in [false, true] { + for use_fake_r in [false, true] { + let (public_commit, private_commit) = commit( + a.clone(), + b.clone(), + c1.clone(), + c2.clone(), + if use_fake_r { + fake_r.clone() + } else { + r.clone() + }, + if use_fake_e1 { + fake_e1.clone() + } else { + e1.clone() + }, + if use_fake_m { + fake_m.clone() + } else { + m.clone() + }, + if use_fake_e2 { + fake_e2.clone() + } else { + e2.clone() + }, + &public_param, + rng, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = + prove((&public_param, &public_commit), &private_commit, load, rng); + + assert_eq!( + verify(&proof, (&public_param, &public_commit)).is_err(), + use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m + ); + } + } + } + } + } + } +} diff --git a/tfhe-zk-pok/src/proofs/range.rs b/tfhe-zk-pok/src/proofs/range.rs new file mode 100644 index 0000000000..299ecf48fd --- /dev/null +++ b/tfhe-zk-pok/src/proofs/range.rs @@ -0,0 +1,360 @@ +use super::*; + +#[derive(Clone, Debug)] +pub struct PublicParams { + g_lists: GroupElements, +} + +impl PublicParams { + pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + Self { + g_lists: GroupElements::from_vec(g_list, g_hat_list), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub struct PrivateParams { + alpha: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct PublicCommit { + l: usize, + v_hat: G::G2, +} + +#[derive(Clone, Debug)] +pub struct PrivateCommit { + x: u64, + r: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct Proof { + c_y: G::G1, + c_hat: G::G2, + pi: G::G1, +} + +pub fn crs_gen( + max_nbits: usize, + rng: &mut dyn RngCore, +) -> (PublicParams, PrivateParams) { + let alpha = G::Zp::rand(rng); + ( + PublicParams { + g_lists: GroupElements::new(max_nbits, alpha), + }, + PrivateParams { alpha }, + ) +} + +pub fn commit( + x: u64, + l: usize, + public: &PublicParams, + rng: &mut dyn RngCore, +) -> (PublicCommit, PrivateCommit) { + let g_hat = G::G2::GENERATOR; + + let r = G::Zp::rand(rng); + let v_hat = g_hat.mul_scalar(r) + public.g_lists.g_hat_list[1].mul_scalar(G::Zp::from_u64(x)); + + (PublicCommit { l, v_hat }, PrivateCommit { x, r }) +} + +pub fn prove( + public: (&PublicParams, &PublicCommit), + private_commit: &PrivateCommit, + rng: &mut dyn RngCore, +) -> Proof { + let &PrivateCommit { x, r } = private_commit; + let &PublicCommit { l, v_hat } = public.1; + let PublicParams { g_lists } = public.0; + let n = g_lists.message_len; + + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + + let g = G::G1::GENERATOR; + let g_hat = G::G2::GENERATOR; + let gamma = G::Zp::rand(rng); + let gamma_y = G::Zp::rand(rng); + + let mut x_bits = vec![0u64; n]; + for (i, x_bits) in x_bits[0..l].iter_mut().enumerate() { + *x_bits = (x >> i) & 1; + } + let x_bits = OneBased(x_bits); + + let c_hat = { + let mut c = g_hat.mul_scalar(gamma); + for j in 1..l + 1 { + let term = if x_bits[j] != 0 { + g_hat_list[j] + } else { + G::G2::ZERO + }; + c += term; + } + c + }; + + let mut proof_x = -g_list[n].mul_scalar(r); + for i in 1..l + 1 { + let mut term = g_list[n + 1 - i].mul_scalar(gamma); + for j in 1..l + 1 { + if j != i { + let term_inner = if x_bits[j] != 0 { + g_list[n + 1 - i + j] + } else { + G::G1::ZERO + }; + term += term_inner; + } + } + + for _ in 1..i { + term = term.double(); + } + proof_x += term; + } + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut y, + &[v_hat.to_bytes().as_ref(), c_hat.to_bytes().as_ref()], + ); + let y = OneBased(y); + let mut c_y = g.mul_scalar(gamma_y); + for j in 1..l + 1 { + c_y += g_list[n + 1 - j].mul_scalar(y[j] * G::Zp::from_u64(x_bits[j])); + } + + let y_bytes = &*(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(); + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + y_bytes, + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut proof_eq = G::G1::ZERO; + for i in 1..n + 1 { + let mut numerator = g_list[n + 1 - i].mul_scalar(gamma); + for j in 1..n + 1 { + if j != i { + let term = if x_bits[j] != 0 { + g_list[n + 1 - i + j] + } else { + G::G1::ZERO + }; + numerator += term; + } + } + numerator = numerator.mul_scalar(t[i] * y[i]); + + let mut denominator = g_list[i].mul_scalar(gamma_y); + for j in 1..n + 1 { + if j != i { + denominator += g_list[n + 1 - j + i].mul_scalar(y[j] * G::Zp::from_u64(x_bits[j])); + } + } + denominator = denominator.mul_scalar(t[i]); + + proof_eq += numerator - denominator; + } + + let mut proof_y = g.mul_scalar(gamma_y); + for j in 1..n + 1 { + proof_y -= g_list[n + 1 - j].mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j])); + } + proof_y = proof_y.mul_scalar(gamma); + for i in 1..n + 1 { + let mut term = g_list[i].mul_scalar(gamma_y); + for j in 1..n + 1 { + if j != i { + term -= g_list[n + 1 - j + i].mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j])); + } + } + let term = if x_bits[i] != 0 { term } else { G::G1::ZERO }; + proof_y += term; + } + + let mut s = vec![G::Zp::ZERO; n]; + for (i, s) in s.iter_mut().enumerate() { + G::Zp::hash( + core::slice::from_mut(s), + &[ + &i.to_le_bytes(), + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + } + let s = OneBased(s); + + let mut proof_v = G::G1::ZERO; + for i in 2..n + 1 { + proof_v += G::G1::mul_scalar( + g_list[n + 1 - i].mul_scalar(r) + g_list[n + 2 - i].mul_scalar(G::Zp::from_u64(x)), + s[i], + ); + } + + let mut delta = [G::Zp::ZERO; 4]; + G::Zp::hash( + &mut delta, + &[ + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let [delta_x, delta_eq, delta_y, delta_v] = delta; + + let proof = proof_x.mul_scalar(delta_x) + + proof_eq.mul_scalar(delta_eq) + + proof_y.mul_scalar(delta_y) + + proof_v.mul_scalar(delta_v); + + Proof { + c_y, + c_hat, + pi: proof, + } +} + +#[allow(clippy::result_unit_err)] +pub fn verify( + proof: &Proof, + public: (&PublicParams, &PublicCommit), +) -> Result<(), ()> { + let e = G::Gt::pairing; + let &PublicCommit { l, v_hat } = public.1; + let PublicParams { g_lists } = public.0; + let n = g_lists.message_len; + + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + + let g_hat = G::G2::GENERATOR; + + let &Proof { c_y, c_hat, pi } = proof; + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut y, + &[v_hat.to_bytes().as_ref(), c_hat.to_bytes().as_ref()], + ); + let y = OneBased(y); + + let y_bytes = &*(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(); + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + y_bytes, + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut delta = [G::Zp::ZERO; 4]; + G::Zp::hash( + &mut delta, + &[ + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let [delta_x, delta_eq, delta_y, delta_v] = delta; + + let mut s = vec![G::Zp::ZERO; n]; + for (i, s) in s.iter_mut().enumerate() { + G::Zp::hash( + core::slice::from_mut(s), + &[ + &i.to_le_bytes(), + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + } + let s = OneBased(s); + + let rhs = e(pi, g_hat); + let lhs = { + let numerator = { + let mut p = c_y.mul_scalar(delta_y); + for i in 1..n + 1 { + let g = g_list[n + 1 - i]; + if i <= l { + p += g.mul_scalar(delta_x * G::Zp::from_u64(1 << (i - 1))); + } + p += g.mul_scalar((delta_eq * t[i] - delta_y) * y[i]); + } + e(p, c_hat) + }; + let denominator_0 = { + let mut p = g_list[n].mul_scalar(delta_x); + for i in 2..n + 1 { + p -= g_list[n + 1 - i].mul_scalar(delta_v * s[i]); + } + e(p, v_hat) + }; + let denominator_1 = { + let mut q = G::G2::ZERO; + for i in 1..n + 1 { + q += g_hat_list[i].mul_scalar(delta_eq * t[i]); + } + e(c_y, q) + }; + numerator - denominator_0 - denominator_1 + }; + + if lhs == rhs { + Ok(()) + } else { + Err(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + + #[test] + fn test_range() { + let rng = &mut StdRng::seed_from_u64(0); + + let max_nbits = 10; + let l = 6; + let x = rng.gen::() % (1 << l); + let (public_params, _) = crs_gen::(max_nbits, rng); + let (public_commit, private_commit) = commit(x, l, &public_params, rng); + let proof = prove((&public_params, &public_commit), &private_commit, rng); + let verify = verify(&proof, (&public_params, &public_commit)); + assert!(verify.is_ok()); + } +} diff --git a/tfhe-zk-pok/src/proofs/rlwe.rs b/tfhe-zk-pok/src/proofs/rlwe.rs new file mode 100644 index 0000000000..d8cbc96493 --- /dev/null +++ b/tfhe-zk-pok/src/proofs/rlwe.rs @@ -0,0 +1,936 @@ +use super::*; +use core::iter::zip; +use core::marker::PhantomData; +use rayon::prelude::*; + +fn bit_iter(x: u64, nbits: u32) -> impl Iterator { + (0..nbits).map(move |idx| ((x >> idx) & 1) == 1) +} + +#[derive(Clone, Debug)] +pub struct PublicParams { + g_lists: GroupElements, + d: usize, + big_n: usize, + big_m: usize, + b_i: u64, + q: u64, +} + +impl PublicParams { + pub fn from_vec( + g_list: Vec, + g_hat_list: Vec, + d: usize, + big_n: usize, + big_m: usize, + b_i: u64, + q: u64, + ) -> Self { + Self { + g_lists: GroupElements::from_vec(g_list, g_hat_list), + d, + big_n, + big_m, + b_i, + q, + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub struct PrivateParams { + alpha: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct PublicCommit { + a: Matrix, + c: Vector, + __marker: PhantomData, +} + +#[derive(Clone, Debug)] +pub struct PrivateCommit { + s: Vector, + __marker: PhantomData, +} + +#[derive(Clone, Debug)] +pub struct Proof { + c_hat: G::G2, + c_y: G::G1, + pi: G::G1, + c_hat_t: Option, + c_h: Option, + pi_kzg: Option, +} + +pub fn crs_gen( + d: usize, + big_n: usize, + big_m: usize, + b_i: u64, + q: u64, + rng: &mut dyn RngCore, +) -> (PublicParams, PrivateParams) { + let alpha = G::Zp::rand(rng); + let b_r = ((d * big_m) as u64 * b_i) / 2; + let big_d = d * (big_m * (1 + b_i.ilog2() as usize) + (big_n * (1 + b_r.ilog2() as usize))); + let n = big_d + 1; + ( + PublicParams { + g_lists: GroupElements::new(n, alpha), + d, + big_n, + big_m, + b_i, + q, + }, + PrivateParams { alpha }, + ) +} + +#[derive(Clone, Debug)] +pub struct Vector { + pub data: Vec, + pub polynomial_size: usize, + pub nrows: usize, +} +#[derive(Clone, Debug)] +pub struct Matrix { + pub data: Vec, + pub polynomial_size: usize, + pub nrows: usize, + pub ncols: usize, +} + +impl Matrix { + pub fn new(polynomial_size: usize, nrows: usize, ncols: usize, value: T) -> Self { + Self { + data: vec![value; polynomial_size * nrows * ncols], + polynomial_size, + nrows, + ncols, + } + } +} +impl Vector { + pub fn new(polynomial_size: usize, nrows: usize, value: T) -> Self { + Self { + data: vec![value; polynomial_size * nrows], + polynomial_size, + nrows, + } + } +} + +impl Index for Vector { + type Output = [T]; + + fn index(&self, row: usize) -> &Self::Output { + let row = row - 1; + &self.data[self.polynomial_size * row..][..self.polynomial_size] + } +} +impl IndexMut for Vector { + fn index_mut(&mut self, row: usize) -> &mut Self::Output { + let row = row - 1; + &mut self.data[self.polynomial_size * row..][..self.polynomial_size] + } +} + +impl Index<(usize, usize)> for Matrix { + type Output = [T]; + + fn index(&self, (row, col): (usize, usize)) -> &Self::Output { + let row = row - 1; + let col = col - 1; + &self.data[self.polynomial_size * (row * self.ncols + col)..][..self.polynomial_size] + } +} +impl IndexMut<(usize, usize)> for Matrix { + fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output { + let row = row - 1; + let col = col - 1; + &mut self.data[self.polynomial_size * (row * self.ncols + col)..][..self.polynomial_size] + } +} + +pub fn commit( + a: Matrix, + c: Vector, + s: Vector, + public: &PublicParams, + rng: &mut dyn RngCore, +) -> (PublicCommit, PrivateCommit) { + let _ = (public, rng); + ( + PublicCommit { + a, + c, + __marker: PhantomData, + }, + PrivateCommit { + s, + __marker: PhantomData, + }, + ) +} + +pub fn prove( + public: (&PublicParams, &PublicCommit), + private_commit: &PrivateCommit, + load: ComputeLoad, + rng: &mut dyn RngCore, +) -> Proof { + let &PublicParams { + ref g_lists, + d, + big_n, + big_m, + b_i, + q, + } = public.0; + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + let s = &private_commit.s; + let a = &public.1.a; + + let b_r = ((d * big_m) as u64 * b_i) / 2; + let big_d = d * (big_m * (1 + b_i.ilog2() as usize) + (big_n * (1 + b_r.ilog2() as usize))); + let n = big_d + 1; + + let g = G::G1::GENERATOR; + let g_hat = G::G2::GENERATOR; + let gamma = G::Zp::rand(rng); + let gamma_y = G::Zp::rand(rng); + + let mut c = Vector { + data: vec![0i64; d], + polynomial_size: d, + nrows: big_n, + }; + let mut r = Vector { + data: vec![0i64; d], + polynomial_size: d, + nrows: big_n, + }; + + for j in 1..big_n + 1 { + let c = &mut c[j]; + let r = &mut r[j]; + + let mut polymul = vec![0i128; d]; + for i in 1..big_m + 1 { + let si = &s[i]; + let aij = &a[(i, j)]; + + for ii in 0..d { + for jj in 0..d { + let p = (aij[ii] as i128) * si[jj] as i128; + if ii + jj < d { + polymul[ii + jj] += p; + } else { + polymul[ii + jj - d] -= p; + } + } + } + } + + for ((ck, rk), old_ck) in zip(zip(c, r), &polymul) { + let q = if q == 0 { q as i128 } else { 1i128 << 64 }; + let mut new_ck = old_ck.rem_euclid(q); + if new_ck >= q / 2 { + new_ck -= q; + } + assert!((old_ck - new_ck) % q == 0); + assert!((*rk).unsigned_abs() < b_r); + + *ck = new_ck as i64; + *rk = ((old_ck - new_ck) / q) as i64; + } + } + let w_tilde = Iterator::chain( + (1..big_m + 1).flat_map(|i| { + s[i].iter() + .copied() + .flat_map(|x| bit_iter(x as u64, b_i.ilog2() + 1)) + }), + (1..big_n + 1).flat_map(|i| { + r[i].iter() + .copied() + .flat_map(|x| bit_iter(x as u64, b_r.ilog2() + 1)) + }), + ) + .collect::>(); + let mut w = vec![false; n].into_boxed_slice(); + w[..big_d].copy_from_slice(&w_tilde); + let w = OneBased::new_ref(&*w); + + let mut c_hat = g_hat.mul_scalar(gamma); + for j in 1..big_d + 1 { + let term = if w[j] { g_hat_list[j] } else { G::G2::ZERO }; + c_hat += term; + } + + let x_bytes = &*[ + &q.to_le_bytes(), + &(d as u64).to_le_bytes(), + &(big_m as u64).to_le_bytes(), + &(big_n as u64).to_le_bytes(), + &b_i.to_le_bytes(), + &*(1..big_m + 1) + .flat_map(|i| { + (1..big_n + 1).flat_map(move |j| a[(i, j)].iter().flat_map(|ai| ai.to_le_bytes())) + }) + .collect::>(), + &(1..big_n + 1) + .flat_map(|j| c[j].iter().flat_map(|ci| ci.to_le_bytes())) + .collect::>(), + ] + .iter() + .copied() + .flatten() + .copied() + .collect::>(); + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash(&mut y, &[x_bytes, c_hat.to_bytes().as_ref()]); + let y = OneBased(y); + + let scalars = (n + 1 - big_d..n + 1) + .map(|j| (y[n + 1 - j] * G::Zp::from_u64(w[n + 1 - j] as u64))) + .collect::>(); + let c_y = g.mul_scalar(gamma_y) + G::G1::multi_mul_scalar(&g_list.0[n - big_d..n], &scalars); + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + &(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(), + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut theta_bar = vec![G::Zp::ZERO; big_n * d + 1]; + G::Zp::hash( + &mut theta_bar, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let theta = (0..big_n * d + 1).map(|k| theta_bar[k]).collect::>(); + let theta0 = theta[..big_n * d].to_vec().into_boxed_slice(); + let delta_theta = theta[big_n * d]; + + let mut t_theta = G::Zp::ZERO; + for j in 0..big_n { + let cj = &c[j + 1]; + let theta0j = &theta0[j * d..][..d]; + for k in 0..d { + t_theta += theta0j[k] * G::Zp::from_i64(cj[k]); + } + } + + let mut a_theta = vec![G::Zp::ZERO; big_d]; + let b_step = 1 + b_i.ilog2() as usize; + let step = d * b_step; + for i in 0..big_m { + // a_theta_i = A_tilde_{i + 1}.T × theta0 + let a_theta_i = &mut a_theta[step * i..][..step]; + + for j in 0..big_n { + let aij = &a[(i + 1, j + 1)]; + let theta0_j = &theta0[d * j..][..d]; + + let mut rot_aij_theta0_j = vec![G::Zp::ZERO; d]; + for p in 0..d { + let mut dot = G::Zp::ZERO; + + for q in 0..d { + let a = if p <= q { + G::Zp::from_i64(aij[q - p]) + } else { + -G::Zp::from_i64(aij[d + q - p]) + }; + dot += a * theta0_j[q]; + } + + rot_aij_theta0_j[p] = dot; + } + + for k in 0..b_step { + let a_theta_ik = &mut a_theta_i[k..]; + let mut c = G::Zp::from_u64(1 << k); + if k + 1 == b_step { + c = -c; + } + + for (dst, src) in zip(a_theta_ik.iter_mut().step_by(b_step), &rot_aij_theta0_j) { + *dst = c * *src; + } + } + } + } + + let offset_m = step * big_m; + let b_step = 1 + b_r.ilog2() as usize; + let step = d * b_step; + for j in 0..big_n { + // a_theta_j -= q G.T theta0_j + let a_theta_j = &mut a_theta[offset_m + step * j..][..step]; + let theta0_j = &theta0[d * j..][..d]; + + for k in 0..b_step { + let a_theta_jk = &mut a_theta_j[k..]; + let mut c = -G::Zp::from_u64(1 << k) * G::Zp::from_u64(q); + if k + 1 == b_step { + c = -c; + } + for (dst, src) in zip(a_theta_jk.iter_mut().step_by(b_step), theta0_j) { + *dst = c * *src; + } + } + } + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + let mut poly_0 = vec![G::Zp::ZERO; n + 1]; + let mut poly_1 = vec![G::Zp::ZERO; n + 1]; + let mut poly_2 = vec![G::Zp::ZERO; n + 1]; + let mut poly_3 = vec![G::Zp::ZERO; n + 1]; + + poly_0[0] = delta_y * gamma_y; + for i in 1..n + 1 { + poly_0[n + 1 - i] = + delta_y * (y[i] * G::Zp::from_u64(w[i] as u64)) + (delta_eq * t[i] - delta_y) * y[i]; + + if i < big_d + 1 { + poly_0[n + 1 - i] += delta_theta * a_theta[i - 1]; + } + } + + poly_1[0] = gamma; + for i in 1..big_d + 1 { + poly_1[i] = G::Zp::from_u64(w[i] as u64); + } + + poly_2[0] = gamma_y; + for i in 1..big_d + 1 { + poly_2[n + 1 - i] = y[i] * G::Zp::from_u64(w[i] as u64); + } + + for i in 1..n + 1 { + poly_3[i] = delta_eq * t[i]; + } + + let mut poly = G::Zp::poly_sub( + &G::Zp::poly_mul(&poly_0, &poly_1), + &G::Zp::poly_mul(&poly_2, &poly_3), + ); + + if poly.len() > n + 1 { + poly[n + 1] -= t_theta * delta_theta; + } + + let pi = + g.mul_scalar(poly[0]) + G::G1::multi_mul_scalar(&g_list.0[..poly.len() - 1], &poly[1..]); + + if load == ComputeLoad::Proof { + let c_hat_t = G::G2::multi_mul_scalar(&g_hat_list.0, &t.0); + let scalars = (1..n + 1) + .into_par_iter() + .map(|i| { + let i = n + 1 - i; + (delta_eq * t[i] - delta_y) * y[i] + + if i < big_d + 1 { + delta_theta * a_theta[i - 1] + } else { + G::Zp::ZERO + } + }) + .collect::>(); + let c_h = G::G1::multi_mul_scalar(&g_list.0[..n], &scalars); + + let mut z = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut z), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + ], + ); + + let mut pow = z; + let mut p_t = G::Zp::ZERO; + let mut p_h = G::Zp::ZERO; + + for i in 1..n + 1 { + p_t += t[i] * pow; + if n - i < big_d { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i] + + delta_theta * a_theta[n - i]) + * pow; + } else { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]) * pow; + } + pow = pow * z; + } + + let mut w = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut w), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + z.to_bytes().as_ref(), + p_h.to_bytes().as_ref(), + p_t.to_bytes().as_ref(), + ], + ); + + let mut poly = vec![G::Zp::ZERO; n + 1]; + for i in 1..n + 1 { + poly[i] += w * t[i]; + if i < big_d + 1 { + poly[n + 1 - i] += + (delta_eq * t[i] - delta_y) * y[i] + delta_theta * a_theta[i - 1]; + } else { + poly[n + 1 - i] += (delta_eq * t[i] - delta_y) * y[i]; + } + } + + let mut q = vec![G::Zp::ZERO; n]; + for i in (0..n).rev() { + poly[i] = poly[i] + z * poly[i + 1]; + q[i] = poly[i + 1]; + poly[i + 1] = G::Zp::ZERO; + } + let pi_kzg = g.mul_scalar(q[0]) + G::G1::multi_mul_scalar(&g_list.0[..n - 1], &q[1..n]); + + Proof { + c_hat, + c_y, + pi, + c_hat_t: Some(c_hat_t), + c_h: Some(c_h), + pi_kzg: Some(pi_kzg), + } + } else { + Proof { + c_hat, + c_y, + pi, + c_hat_t: None, + c_h: None, + pi_kzg: None, + } + } +} + +#[allow(clippy::result_unit_err)] +pub fn verify( + proof: &Proof, + public: (&PublicParams, &PublicCommit), +) -> Result<(), ()> { + let &Proof { + c_hat, + c_y, + pi, + c_hat_t, + c_h, + pi_kzg, + } = proof; + let e = G::Gt::pairing; + + let &PublicParams { + ref g_lists, + d, + big_n, + big_m, + b_i, + q, + } = public.0; + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + + let b_r = ((d * big_m) as u64 * b_i) / 2; + let big_d = d * (big_m * (1 + b_i.ilog2() as usize) + (big_n * (1 + b_r.ilog2() as usize))); + let n = big_d + 1; + + let a = &public.1.a; + let c = &public.1.c; + + let x_bytes = &*[ + &q.to_le_bytes(), + &(d as u64).to_le_bytes(), + &(big_m as u64).to_le_bytes(), + &(big_n as u64).to_le_bytes(), + &b_i.to_le_bytes(), + &*(1..big_m + 1) + .flat_map(|i| { + (1..big_n + 1).flat_map(move |j| a[(i, j)].iter().flat_map(|ai| ai.to_le_bytes())) + }) + .collect::>(), + &(1..big_n + 1) + .flat_map(|j| c[j].iter().flat_map(|ci| ci.to_le_bytes())) + .collect::>(), + ] + .iter() + .copied() + .flatten() + .copied() + .collect::>(); + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash(&mut y, &[x_bytes, c_hat.to_bytes().as_ref()]); + let y = OneBased(y); + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + &(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(), + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut theta_bar = vec![G::Zp::ZERO; big_n * d + 1]; + G::Zp::hash( + &mut theta_bar, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let theta = (0..big_n * d + 1).map(|k| theta_bar[k]).collect::>(); + let theta0 = theta[..big_n * d].to_vec().into_boxed_slice(); + let delta_theta = theta[big_n * d]; + + let mut t_theta = G::Zp::ZERO; + for j in 0..big_n { + let cj = &c[j + 1]; + let theta0j = &theta0[j * d..][..d]; + for k in 0..d { + t_theta += theta0j[k] * G::Zp::from_i64(cj[k]); + } + } + + let mut a_theta = vec![G::Zp::ZERO; big_d]; + let b_step = 1 + b_i.ilog2() as usize; + let step = d * b_step; + for i in 0..big_m { + // a_theta_i = A_tilde_{i + 1}.T × theta0 + let a_theta_i = &mut a_theta[step * i..][..step]; + + for j in 0..big_n { + let aij = &a[(i + 1, j + 1)]; + let theta0_j = &theta0[d * j..][..d]; + + let mut rot_aij_theta0_j = vec![G::Zp::ZERO; d]; + for p in 0..d { + let mut dot = G::Zp::ZERO; + + for q in 0..d { + let a = if p <= q { + G::Zp::from_i64(aij[q - p]) + } else { + -G::Zp::from_i64(aij[d + q - p]) + }; + dot += a * theta0_j[q]; + } + + rot_aij_theta0_j[p] = dot; + } + + for k in 0..b_step { + let a_theta_ik = &mut a_theta_i[k..]; + let mut c = G::Zp::from_u64(1 << k); + if k + 1 == b_step { + c = -c; + } + + for (dst, src) in zip(a_theta_ik.iter_mut().step_by(b_step), &rot_aij_theta0_j) { + *dst = c * *src; + } + } + } + } + + let offset_m = step * big_m; + let b_step = 1 + b_r.ilog2() as usize; + let step = d * b_step; + for j in 0..big_n { + // a_theta_j -= q G.T theta0_j + let a_theta_j = &mut a_theta[offset_m + step * j..][..step]; + let theta0_j = &theta0[d * j..][..d]; + + for k in 0..b_step { + let a_theta_jk = &mut a_theta_j[k..]; + let mut c = -G::Zp::from_u64(1 << k) * G::Zp::from_u64(q); + if k + 1 == b_step { + c = -c; + } + for (dst, src) in zip(a_theta_jk.iter_mut().step_by(b_step), theta0_j) { + *dst = c * *src; + } + } + } + + if let (Some(pi_kzg), Some(c_hat_t), Some(c_h)) = (pi_kzg, c_hat_t, c_h) { + let mut z = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut z), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + ], + ); + + let mut pow = z; + let mut p_t = G::Zp::ZERO; + let mut p_h = G::Zp::ZERO; + + for i in 1..n + 1 { + p_t += t[i] * pow; + if n - i < big_d { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i] + + delta_theta * a_theta[n - i]) + * pow; + } else { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]) * pow; + } + pow = pow * z; + } + + if e(pi, G::G2::GENERATOR) + != e(c_y.mul_scalar(delta_y) + c_h, c_hat) + - e(c_y.mul_scalar(delta_eq), c_hat_t) + - e(g_list[1], g_hat_list[n]).mul_scalar(t_theta * delta_theta) + { + return Err(()); + } + + let mut w = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut w), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + z.to_bytes().as_ref(), + p_h.to_bytes().as_ref(), + p_t.to_bytes().as_ref(), + ], + ); + + if e(c_h - G::G1::GENERATOR.mul_scalar(p_h), G::G2::GENERATOR) + + e(G::G1::GENERATOR, c_hat_t - G::G2::GENERATOR.mul_scalar(p_t)).mul_scalar(w) + == e(pi_kzg, g_hat_list[1] - G::G2::GENERATOR.mul_scalar(z)) + { + Ok(()) + } else { + Err(()) + } + } else { + let (term0, term1) = rayon::join( + || { + let p = c_y.mul_scalar(delta_y) + + (1..n + 1) + .into_par_iter() + .map(|i| { + let mut factor = (delta_eq * t[i] - delta_y) * y[i]; + if i < big_d + 1 { + factor += delta_theta * a_theta[i - 1]; + } + g_list[n + 1 - i].mul_scalar(factor) + }) + .sum::(); + let q = c_hat; + e(p, q) + }, + || { + let p = c_y; + let q = (1..n + 1) + .into_par_iter() + .map(|i| g_hat_list[i].mul_scalar(delta_eq * t[i])) + .sum::(); + e(p, q) + }, + ); + let term2 = { + let p = g_list[1]; + let q = g_hat_list[n]; + e(p, q) + }; + + let lhs = e(pi, G::G2::GENERATOR); + let rhs = term0 - term1 - term2.mul_scalar(t_theta * delta_theta); + + if lhs == rhs { + Ok(()) + } else { + Err(()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + + fn time(f: impl FnOnce() -> R) -> R { + let time = std::time::Instant::now(); + let r = f(); + println!("{:?}", time.elapsed()); + r + } + + #[test] + fn test_rlwe() { + let rng = &mut StdRng::seed_from_u64(0); + let d: usize = 2048; + let big_m: usize = 1; + let big_n: usize = 1; + + let q = 1217; + let b_i: u64 = 512; + + let mut a = Matrix::new(d, big_m, big_n, 0i64); + let mut c = Vector::new(d, big_n, 0i64); + let mut s = Vector::new(d, big_m, 0i64); + + for i in 0..big_m { + for k in 0..d { + s[i + 1][k] = (rng.gen::() % (2 * b_i)) as i64 - b_i as i64; + } + } + + for i in 0..big_m { + for j in 0..big_n { + for k in 0..d { + let mut x = (rng.gen::() % q) as i64; + if x >= q as i64 / 2 { + x -= q as i64; + } + a[(i + 1, j + 1)][k] = x; + } + } + } + + for j in 1..big_n + 1 { + let c = &mut c[j]; + + let mut polymul = vec![0i128; d]; + for i in 1..big_m + 1 { + let si = &s[i]; + let aij = &a[(i, j)]; + + for ii in 0..d { + for jj in 0..d { + let p = (aij[ii] as i128) * si[jj] as i128; + if ii + jj < d { + polymul[ii + jj] += p; + } else { + polymul[ii + jj - d] -= p; + } + } + } + } + + for (ck, old_ck) in core::iter::zip(c, &polymul) { + let q = if q == 0 { q as i128 } else { 1i128 << 64 }; + let mut new_ck = old_ck.rem_euclid(q); + if new_ck >= q / 2 { + new_ck -= q; + } + *ck = new_ck as i64; + } + } + + let (public_params, _private_params) = + crs_gen::(d, big_n, big_m, b_i, q, rng); + let (public_commit, private_commit) = commit(a, c, s, &public_params, rng); + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = + time(|| prove((&public_params, &public_commit), &private_commit, load, rng)); + let verify = time(|| verify(&proof, (&public_params, &public_commit))); + assert!(verify.is_ok()); + } + } +}