From 5dcc3917144a133e7b5810e5020d2abd631de136 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Tue, 2 Apr 2024 18:46:13 +0200 Subject: [PATCH 1/2] feat(tfhe): add zk-pok code base - integration of work done by Sarah in the repo Co-authored-by: sarah el kazdadi --- .github/workflows/aws_tfhe_fast_tests.yml | 4 + .github/workflows/aws_tfhe_tests.yml | 4 + .github/workflows/m1_tests.yml | 4 + Cargo.toml | 1 + Makefile | 12 +- tfhe-zk-pok/.gitignore | 1 + tfhe-zk-pok/Cargo.toml | 21 + tfhe-zk-pok/src/curve_446/mod.rs | 772 +++++++++++++++ tfhe-zk-pok/src/curve_api.rs | 341 +++++++ tfhe-zk-pok/src/curve_api/bls12_381.rs | 776 +++++++++++++++ tfhe-zk-pok/src/curve_api/bls12_446.rs | 970 +++++++++++++++++++ tfhe-zk-pok/src/lib.rs | 3 + tfhe-zk-pok/src/proofs/binary.rs | 213 +++++ tfhe-zk-pok/src/proofs/index.rs | 121 +++ tfhe-zk-pok/src/proofs/mod.rs | 95 ++ tfhe-zk-pok/src/proofs/pke.rs | 1043 +++++++++++++++++++++ tfhe-zk-pok/src/proofs/range.rs | 354 +++++++ tfhe-zk-pok/src/proofs/rlwe.rs | 932 ++++++++++++++++++ 18 files changed, 5666 insertions(+), 1 deletion(-) create mode 100644 tfhe-zk-pok/.gitignore create mode 100644 tfhe-zk-pok/Cargo.toml create mode 100644 tfhe-zk-pok/src/curve_446/mod.rs create mode 100644 tfhe-zk-pok/src/curve_api.rs create mode 100644 tfhe-zk-pok/src/curve_api/bls12_381.rs create mode 100644 tfhe-zk-pok/src/curve_api/bls12_446.rs create mode 100644 tfhe-zk-pok/src/lib.rs create mode 100644 tfhe-zk-pok/src/proofs/binary.rs create mode 100644 tfhe-zk-pok/src/proofs/index.rs create mode 100644 tfhe-zk-pok/src/proofs/mod.rs create mode 100644 tfhe-zk-pok/src/proofs/pke.rs create mode 100644 tfhe-zk-pok/src/proofs/range.rs create mode 100644 tfhe-zk-pok/src/proofs/rlwe.rs diff --git a/.github/workflows/aws_tfhe_fast_tests.yml b/.github/workflows/aws_tfhe_fast_tests.yml index 30c63e0806..2c434b448b 100644 --- a/.github/workflows/aws_tfhe_fast_tests.yml +++ b/.github/workflows/aws_tfhe_fast_tests.yml @@ -60,6 +60,10 @@ jobs: run: | make test_concrete_csprng + - name: Run tfhe-zk-pok tests + run: | + make test_zk_pok + - name: Run core tests run: | AVX512_SUPPORT=ON make test_core_crypto diff --git a/.github/workflows/aws_tfhe_tests.yml b/.github/workflows/aws_tfhe_tests.yml index 86107425c8..cc0108e658 100644 --- a/.github/workflows/aws_tfhe_tests.yml +++ b/.github/workflows/aws_tfhe_tests.yml @@ -61,6 +61,10 @@ jobs: run: | make test_concrete_csprng + - name: Run tfhe-zk-pok tests + run: | + make test_zk_pok + - name: Run core tests run: | AVX512_SUPPORT=ON make test_core_crypto diff --git a/.github/workflows/m1_tests.yml b/.github/workflows/m1_tests.yml index 95be745d0c..ebbba24ad7 100644 --- a/.github/workflows/m1_tests.yml +++ b/.github/workflows/m1_tests.yml @@ -74,6 +74,10 @@ jobs: run: | make test_concrete_csprng + - name: Run tfhe-zk-pok tests + run: | + make test_zk_pok + - name: Run core tests run: | make test_core_crypto diff --git a/Cargo.toml b/Cargo.toml index b569ef857e..1fdc61b275 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ resolver = "2" members = [ "tfhe", + "tfhe-zk-pok", "tasks", "apps/trivium", "concrete-csprng", diff --git a/Makefile b/Makefile index b5e78cf001..9976c821fa 100644 --- a/Makefile +++ b/Makefile @@ -283,9 +283,14 @@ clippy_concrete_csprng: --features=$(TARGET_ARCH_FEATURE) \ -p concrete-csprng -- --no-deps -D warnings +.PHONY: clippy_zk_pok # Run clippy lints on tfhe-zk-pok +clippy_zk_pok: + RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \ + -p tfhe-zk-pok -- --no-deps -D warnings + .PHONY: clippy_all # Run all clippy targets clippy_all: clippy clippy_boolean clippy_shortint clippy_integer clippy_all_targets clippy_c_api \ -clippy_js_wasm_api clippy_tasks clippy_core clippy_concrete_csprng clippy_trivium +clippy_js_wasm_api clippy_tasks clippy_core clippy_concrete_csprng clippy_zk_pok clippy_trivium .PHONY: clippy_fast # Run main clippy targets clippy_fast: clippy clippy_all_targets clippy_c_api clippy_js_wasm_api clippy_tasks clippy_core \ @@ -628,6 +633,11 @@ test_concrete_csprng: RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ --features=$(TARGET_ARCH_FEATURE) -p concrete-csprng +.PHONY: test_zk_pok # Run tfhe-zk-pok tests +test_zk_pok: + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ + -p tfhe-zk-pok + .PHONY: doc # Build rust doc doc: install_rs_check_toolchain @# Even though we are not in docs.rs, this allows to "just" build the doc diff --git a/tfhe-zk-pok/.gitignore b/tfhe-zk-pok/.gitignore new file mode 100644 index 0000000000..ea8c4bf7f3 --- /dev/null +++ b/tfhe-zk-pok/.gitignore @@ -0,0 +1 @@ +/target diff --git a/tfhe-zk-pok/Cargo.toml b/tfhe-zk-pok/Cargo.toml new file mode 100644 index 0000000000..b1b0abc9cb --- /dev/null +++ b/tfhe-zk-pok/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tfhe-zk-pok" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ark-bls12-381 = "0.4.0" +ark-ec = "0.4.2" +ark-ff = "0.4.2" +ark-poly = "0.4.2" +rand = "0.8.5" +rayon = "1.8.0" +sha3 = "0.10.8" +serde = { version = "~1.0", features = ["derive"] } +ark-serialize = { version = "0.4.2" } +zeroize = "1.7.0" + +[dev-dependencies] +serde_json = "~1.0" diff --git a/tfhe-zk-pok/src/curve_446/mod.rs b/tfhe-zk-pok/src/curve_446/mod.rs new file mode 100644 index 0000000000..96fcecb9d0 --- /dev/null +++ b/tfhe-zk-pok/src/curve_446/mod.rs @@ -0,0 +1,772 @@ +use ark_ec::bls12::{Bls12, Bls12Config, TwistType}; +use ark_ff::fields::*; +use ark_ff::MontFp; + +#[derive(MontConfig)] +#[modulus = "645383785691237230677916041525710377746967055506026847120930304831624105190538527824412673"] +#[generator = "7"] +#[small_subgroup_base = "3"] +#[small_subgroup_power = "1"] +pub struct FrConfig; +pub type Fr = Fp320>; + +#[derive(MontConfig)] +#[modulus = "172824703542857155980071276579495962243492693522789898437834836356385656662277472896902502740297183690175962001546428467344062165330603"] +#[generator = "2"] +#[small_subgroup_base = "3"] +#[small_subgroup_power = "1"] +pub struct FqConfig; +pub type Fq = Fp448>; + +pub type Fq2 = Fp2; + +pub struct Fq2Config; + +impl Fp2Config for Fq2Config { + type Fp = Fq; + + /// NONRESIDUE = -1 + const NONRESIDUE: Fq = MontFp!("-1"); + + /// Coefficients for the Frobenius automorphism. + const FROBENIUS_COEFF_FP2_C1: &'static [Fq] = &[ + // Fq(-1)**(((q^0) - 1) / 2) + Fq::ONE, + // Fq(-1)**(((q^1) - 1) / 2) + MontFp!("-1"), + ]; + + #[inline(always)] + fn mul_fp_by_nonresidue_in_place(fp: &mut Self::Fp) -> &mut Self::Fp { + fp.neg_in_place() + } + + #[inline(always)] + fn sub_and_mul_fp_by_nonresidue(y: &mut Self::Fp, x: &Self::Fp) { + *y += x; + } + + #[inline(always)] + fn mul_fp_by_nonresidue_plus_one_and_add(y: &mut Self::Fp, x: &Self::Fp) { + *y = *x; + } + + fn mul_fp_by_nonresidue_and_add(y: &mut Self::Fp, x: &Self::Fp) { + y.neg_in_place(); + *y += x; + } +} + +pub type Fq6 = Fp6; + +#[derive(Clone, Copy)] +pub struct Fq6Config; + +impl Fp6Config for Fq6Config { + type Fp2Config = Fq2Config; + + /// NONRESIDUE = (U + 1) + const NONRESIDUE: Fq2 = Fq2::new(Fq::ONE, Fq::ONE); + + const FROBENIUS_COEFF_FP6_C1: &'static [Fq2] = &[ + // Fp2::NONRESIDUE^(((q^0) - 1) / 3) + Fq2::new( + Fq::ONE, + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^1) - 1) / 3) + Fq2::new( + Fq::ZERO, + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + ), + // Fp2::NONRESIDUE^(((q^2) - 1) / 3) + Fq2::new( + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^3) - 1) / 3) + Fq2::new( + Fq::ZERO, + Fq::ONE, + ), + // Fp2::NONRESIDUE^(((q^4) - 1) / 3) + Fq2::new( + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^5) - 1) / 3) + Fq2::new( + Fq::ZERO, + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + ), + ]; + + const FROBENIUS_COEFF_FP6_C2: &'static [Fq2] = &[ + // Fq2(u + 1)**(((2q^0) - 2) / 3) + Fq2::new( + Fq::ONE, + Fq::ZERO, + ), + // Fq2(u + 1)**(((2q^1) - 2) / 3) + Fq2::new( + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + Fq::ZERO, + ), + // Fq2(u + 1)**(((2q^2) - 2) / 3) + Fq2::new( + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + Fq::ZERO, + ), + // Fq2(u + 1)**(((2q^3) - 2) / 3) + Fq2::new( + MontFp!("-1"), + Fq::ZERO, + ), + // Fq2(u + 1)**(((2q^4) - 2) / 3) + Fq2::new( + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + Fq::ZERO, + ), + // Fq2(u + 1)**(((2q^5) - 2) / 3) + Fq2::new( + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + Fq::ZERO, + ), + ]; + + /// Multiply this element by the quadratic nonresidue 1 + u. + /// Make this generic. + fn mul_fp2_by_nonresidue_in_place(fe: &mut Fq2) -> &mut Fq2 { + let t0 = fe.c0; + fe.c0 -= &fe.c1; + fe.c1 += &t0; + fe + } +} + +pub type Fq12 = Fp12; + +#[derive(Clone, Copy)] +pub struct Fq12Config; + +impl Fp12Config for Fq12Config { + type Fp6Config = Fq6Config; + + const NONRESIDUE: Fq6 = Fq6::new(Fq2::ZERO, Fq2::ONE, Fq2::ZERO); + + const FROBENIUS_COEFF_FP12_C1: &'static [Fq2] = &[ + // Fp2::NONRESIDUE^(((q^0) - 1) / 6) + Fq2::new( + Fq::ONE, + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^1) - 1) / 6) + Fq2::new( + MontFp!("22118644822122453894295732432166425368368980329889476319266915965514828099635526724748286229964921634997234117686841299669336163301597"), + MontFp!("-22118644822122453894295732432166425368368980329889476319266915965514828099635526724748286229964921634997234117686841299669336163301597"), + ), + // Fp2::NONRESIDUE^(((q^2) - 1) / 6) + Fq2::new( + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^3) - 1) / 6) + Fq2::new( + MontFp!("-84459159508829117195668840503504856816171858703899096210464197465513610215112549935889502423482516188066933947513637464187184810836060"), + MontFp!("84459159508829117195668840503504856816171858703899096210464197465513610215112549935889502423482516188066933947513637464187184810836060"), + ), + // Fp2::NONRESIDUE^(((q^4) - 1) / 6) + Fq2::new( + MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^5) - 1) / 6) + Fq2::new( + MontFp!("66246899211905584890106703643824680058951854489001325908103722925357218347529396236264714086849745867111793936345949703487541191192946"), + MontFp!("-66246899211905584890106703643824680058951854489001325908103722925357218347529396236264714086849745867111793936345949703487541191192946"), + ), + // Fp2::NONRESIDUE^(((q^6) - 1) / 6) + Fq2::new( + MontFp!("-1"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^7) - 1) / 6) + Fq2::new( + MontFp!("-22118644822122453894295732432166425368368980329889476319266915965514828099635526724748286229964921634997234117686841299669336163301597"), + MontFp!("22118644822122453894295732432166425368368980329889476319266915965514828099635526724748286229964921634997234117686841299669336163301597"), + ), + // Fp2::NONRESIDUE^(((q^8) - 1) / 6) + Fq2::new( + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051013"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^9) - 1) / 6) + Fq2::new( + MontFp!("84459159508829117195668840503504856816171858703899096210464197465513610215112549935889502423482516188066933947513637464187184810836060"), + MontFp!("-84459159508829117195668840503504856816171858703899096210464197465513610215112549935889502423482516188066933947513637464187184810836060"), + ), + // Fp2::NONRESIDUE^(((q^10) - 1) / 6) + Fq2::new( + MontFp!("-18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"), + Fq::ZERO, + ), + // Fp2::NONRESIDUE^(((q^11) - 1) / 6) + Fq2::new( + MontFp!("-66246899211905584890106703643824680058951854489001325908103722925357218347529396236264714086849745867111793936345949703487541191192946"), + MontFp!("66246899211905584890106703643824680058951854489001325908103722925357218347529396236264714086849745867111793936345949703487541191192946"), + ), + ]; +} + +pub type Bls12_446 = Bls12; +use g1::G1Affine; +use g2::G2Affine; + +pub struct Config; + +impl Bls12Config for Config { + const X: &'static [u64] = &[0x8204000000020001, 0x600]; + const X_IS_NEGATIVE: bool = true; + const TWIST_TYPE: TwistType = TwistType::M; + type Fp = Fq; + type Fp2Config = Fq2Config; + type Fp6Config = Fq6Config; + type Fp12Config = Fq12Config; + type G1Config = g1::Config; + type G2Config = g2::Config; +} + +pub mod util { + use ark_ec::short_weierstrass::Affine; + use ark_ec::AffineRepr; + use ark_ff::{BigInteger448, PrimeField}; + use ark_serialize::SerializationError; + + use super::g1::Config as G1Config; + use super::g2::Config as G2Config; + use super::{Fq, Fq2, G1Affine, G2Affine}; + + pub const G1_SERIALIZED_SIZE: usize = 57; + pub const G2_SERIALIZED_SIZE: usize = 114; + + pub struct EncodingFlags { + pub is_compressed: bool, + pub is_infinity: bool, + pub is_lexographically_largest: bool, + } + + impl EncodingFlags { + pub fn get_flags(bytes: &[u8]) -> Self { + let compression_flag_set = (bytes[0] >> 7) & 1; + let infinity_flag_set = (bytes[0] >> 6) & 1; + let sort_flag_set = (bytes[0] >> 5) & 1; + + Self { + is_compressed: compression_flag_set == 1, + is_infinity: infinity_flag_set == 1, + is_lexographically_largest: sort_flag_set == 1, + } + } + pub fn encode_flags(&self, bytes: &mut [u8]) { + if self.is_compressed { + bytes[0] |= 1 << 7; + } + + if self.is_infinity { + bytes[0] |= 1 << 6; + } + + if self.is_compressed && !self.is_infinity && self.is_lexographically_largest { + bytes[0] |= 1 << 5; + } + } + } + + pub(crate) fn deserialize_fq(bytes: [u8; 56]) -> Option { + let mut tmp = BigInteger448::new([0, 0, 0, 0, 0, 0, 0]); + + // Note: The following unwraps are if the compiler cannot convert + // the byte slice into [u8;8], we know this is infallible since we + // are providing the indices at compile time and bytes has a fixed size + tmp.0[6] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()); + tmp.0[5] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()); + tmp.0[4] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()); + tmp.0[3] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()); + tmp.0[2] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()); + tmp.0[1] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()); + tmp.0[0] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()); + + Fq::from_bigint(tmp) + } + + pub(crate) fn serialize_fq(field: Fq) -> [u8; 56] { + let mut result = [0u8; 56]; + + let rep = field.into_bigint(); + + result[0..8].copy_from_slice(&rep.0[6].to_be_bytes()); + result[8..16].copy_from_slice(&rep.0[5].to_be_bytes()); + result[16..24].copy_from_slice(&rep.0[4].to_be_bytes()); + result[24..32].copy_from_slice(&rep.0[3].to_be_bytes()); + result[32..40].copy_from_slice(&rep.0[2].to_be_bytes()); + result[40..48].copy_from_slice(&rep.0[1].to_be_bytes()); + result[48..56].copy_from_slice(&rep.0[0].to_be_bytes()); + + result + } + + pub(crate) fn read_fq_with_offset( + bytes: &[u8], + offset: usize, + ) -> Result { + let mut tmp = [0; G1_SERIALIZED_SIZE - 1]; + // read `G1_SERIALIZED_SIZE` bytes + tmp.copy_from_slice( + &bytes[offset * G1_SERIALIZED_SIZE + 1..G1_SERIALIZED_SIZE * (offset + 1)], + ); + + deserialize_fq(tmp).ok_or(SerializationError::InvalidData) + } + + pub(crate) fn read_g1_compressed( + mut reader: R, + ) -> Result, ark_serialize::SerializationError> { + let mut bytes = [0u8; G1_SERIALIZED_SIZE]; + reader + .read_exact(&mut bytes) + .ok() + .ok_or(SerializationError::InvalidData)?; + + // Obtain the three flags from the start of the byte sequence + let flags = EncodingFlags::get_flags(&bytes[..]); + + // we expect to be deserializing a compressed point + if !flags.is_compressed { + return Err(SerializationError::UnexpectedFlags); + } + + if flags.is_infinity { + return Ok(G1Affine::zero()); + } + + // Attempt to obtain the x-coordinate + let x = read_fq_with_offset(&bytes, 0)?; + + let p = G1Affine::get_point_from_x_unchecked(x, flags.is_lexographically_largest) + .ok_or(SerializationError::InvalidData)?; + + Ok(p) + } + + pub(crate) fn read_g1_uncompressed( + mut reader: R, + ) -> Result, ark_serialize::SerializationError> { + let mut bytes = [0u8; 2 * G1_SERIALIZED_SIZE]; + reader + .read_exact(&mut bytes) + .map_err(|_| SerializationError::InvalidData)?; + + // Obtain the three flags from the start of the byte sequence + let flags = EncodingFlags::get_flags(&bytes[..]); + + // we expect to be deserializing an uncompressed point + if flags.is_compressed { + return Err(SerializationError::UnexpectedFlags); + } + + if flags.is_infinity { + return Ok(G1Affine::zero()); + } + + // Attempt to obtain the x-coordinate + let x = read_fq_with_offset(&bytes, 0)?; + // Attempt to obtain the y-coordinate + let y = read_fq_with_offset(&bytes, 1)?; + + let p = G1Affine::new_unchecked(x, y); + + Ok(p) + } + + pub(crate) fn read_g2_compressed( + mut reader: R, + ) -> Result, ark_serialize::SerializationError> { + let mut bytes = [0u8; G2_SERIALIZED_SIZE]; + reader + .read_exact(&mut bytes) + .map_err(|_| SerializationError::InvalidData)?; + + // Obtain the three flags from the start of the byte sequence + let flags = EncodingFlags::get_flags(&bytes); + + // we expect to be deserializing a compressed point + if !flags.is_compressed { + return Err(SerializationError::UnexpectedFlags); + } + + if flags.is_infinity { + return Ok(G2Affine::zero()); + } + + // Attempt to obtain the x-coordinate + let xc1 = read_fq_with_offset(&bytes, 0)?; + let xc0 = read_fq_with_offset(&bytes, 1)?; + + let x = Fq2::new(xc0, xc1); + + let p = G2Affine::get_point_from_x_unchecked(x, flags.is_lexographically_largest) + .ok_or(SerializationError::InvalidData)?; + + Ok(p) + } + + pub(crate) fn read_g2_uncompressed( + mut reader: R, + ) -> Result, ark_serialize::SerializationError> { + let mut bytes = [0u8; 2 * G2_SERIALIZED_SIZE]; + reader + .read_exact(&mut bytes) + .map_err(|_| SerializationError::InvalidData)?; + + // Obtain the three flags from the start of the byte sequence + let flags = EncodingFlags::get_flags(&bytes); + + // we expect to be deserializing an uncompressed point + if flags.is_compressed { + return Err(SerializationError::UnexpectedFlags); + } + + if flags.is_infinity { + return Ok(G2Affine::zero()); + } + + // Attempt to obtain the x-coordinate + let xc1 = read_fq_with_offset(&bytes, 0)?; + let xc0 = read_fq_with_offset(&bytes, 1)?; + let x = Fq2::new(xc0, xc1); + + // Attempt to obtain the y-coordinate + let yc1 = read_fq_with_offset(&bytes, 2)?; + let yc0 = read_fq_with_offset(&bytes, 3)?; + let y = Fq2::new(yc0, yc1); + + let p = G2Affine::new_unchecked(x, y); + + Ok(p) + } +} + +pub mod g1 { + use super::util::{ + read_g1_compressed, read_g1_uncompressed, serialize_fq, EncodingFlags, G1_SERIALIZED_SIZE, + }; + use super::{Fq, Fr}; + use ark_ec::bls12::Bls12Config; + use ark_ec::models::CurveConfig; + use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; + use ark_ec::{bls12, AffineRepr, Group}; + use ark_ff::{Field, MontFp, One, PrimeField, Zero}; + use ark_serialize::{Compress, SerializationError}; + use core::ops::Neg; + + #[derive(Clone, Default, PartialEq, Eq)] + pub struct Config; + + impl CurveConfig for Config { + type BaseField = Fq; + type ScalarField = Fr; + + /// COFACTOR = (x - 1)^2 / 3 = 267785939737784928360481681640896166738700972 + const COFACTOR: &'static [u64] = &[0xad5aaaac0002aaac, 0x2602b0055d560ab0, 0xc0208]; + + /// COFACTOR_INV = COFACTOR^{-1} mod r + /// = 645383785691237230677779421366207365261112665008071669867241543525136277620937226389553150 + const COFACTOR_INV: Fr = + MontFp!("645383785691237230677779421366207365261112665008071669867241543525136277620937226389553150"); + } + + pub type G1Affine = bls12::G1Affine; + pub type G1Projective = bls12::G1Projective; + + impl SWCurveConfig for Config { + /// COEFF_A = 0 + const COEFF_A: Fq = Fq::ZERO; + + /// COEFF_B = 1 + const COEFF_B: Fq = MontFp!("1"); + + /// AFFINE_GENERATOR_COEFFS = (G1_GENERATOR_X, G1_GENERATOR_Y) + const GENERATOR: G1Affine = G1Affine::new_unchecked(G1_GENERATOR_X, G1_GENERATOR_Y); + + #[inline(always)] + fn mul_by_a(_: Self::BaseField) -> Self::BaseField { + Self::BaseField::zero() + } + + #[inline] + fn is_in_correct_subgroup_assuming_on_curve(p: &G1Affine) -> bool { + // Algorithm from Section 6 of https://eprint.iacr.org/2021/1130. + // + // Check that endomorphism_p(P) == -[X^2]P + + // An early-out optimization described in Section 6. + // If uP == P but P != point of infinity, then the point is not in the right + // subgroup. + let x_times_p = p.mul_bigint(super::Config::X); + if x_times_p.eq(p) && !p.infinity { + return false; + } + + let minus_x_squared_times_p = x_times_p.mul_bigint(super::Config::X).neg(); + let endomorphism_p = endomorphism(p); + minus_x_squared_times_p.eq(&endomorphism_p) + } + + #[inline] + fn clear_cofactor(p: &G1Affine) -> G1Affine { + // Using the effective cofactor, as explained in + // Section 5 of https://eprint.iacr.org/2019/403.pdf. + // + // It is enough to multiply by (1 - x), instead of (x - 1)^2 / 3 + let h_eff = one_minus_x().into_bigint(); + Config::mul_affine(p, h_eff.as_ref()).into() + } + + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result, ark_serialize::SerializationError> { + let p = if compress == ark_serialize::Compress::Yes { + read_g1_compressed(&mut reader)? + } else { + read_g1_uncompressed(&mut reader)? + }; + + if validate == ark_serialize::Validate::Yes + && !p.is_in_correct_subgroup_assuming_on_curve() + { + return Err(SerializationError::InvalidData); + } + Ok(p) + } + + fn serialize_with_mode( + item: &Affine, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), SerializationError> { + let encoding = EncodingFlags { + is_compressed: compress == ark_serialize::Compress::Yes, + is_infinity: item.is_zero(), + is_lexographically_largest: item.y > -item.y, + }; + let mut p = *item; + if encoding.is_infinity { + p = G1Affine::zero(); + } + // need to access the field struct `x` directly, otherwise we get None from xy() + // method + let x_bytes = serialize_fq(p.x); + if encoding.is_compressed { + let mut bytes = [0u8; G1_SERIALIZED_SIZE]; + bytes[1..].copy_from_slice(&x_bytes); + + encoding.encode_flags(&mut bytes); + writer.write_all(&bytes)?; + } else { + let mut bytes = [0u8; 2 * G1_SERIALIZED_SIZE]; + bytes[1..1 + G1_SERIALIZED_SIZE].copy_from_slice(&x_bytes[..]); + bytes[2 + G1_SERIALIZED_SIZE..].copy_from_slice(&serialize_fq(p.y)[..]); + + encoding.encode_flags(&mut bytes); + writer.write_all(&bytes)?; + }; + + Ok(()) + } + + fn serialized_size(compress: Compress) -> usize { + if compress == Compress::Yes { + G1_SERIALIZED_SIZE + } else { + G1_SERIALIZED_SIZE * 2 + } + } + } + + fn one_minus_x() -> Fr { + const X: Fr = Fr::from_sign_and_limbs(!super::Config::X_IS_NEGATIVE, super::Config::X); + Fr::one() - X + } + + /// G1_GENERATOR_X = + /// 143189966182216199425404656824735381247272236095050141599848381692039676741476615087722874458136990266833440576646963466074693171606778 + pub const G1_GENERATOR_X: Fq = MontFp!("143189966182216199425404656824735381247272236095050141599848381692039676741476615087722874458136990266833440576646963466074693171606778"); + + /// G1_GENERATOR_Y = + /// 75202396197342917254523279069469674666303680671605970245803554133573745859131002231546341942288521574682619325841484506619191207488304 + pub const G1_GENERATOR_Y: Fq = MontFp!("75202396197342917254523279069469674666303680671605970245803554133573745859131002231546341942288521574682619325841484506619191207488304"); + + /// BETA is a non-trivial cubic root of unity in Fq. + pub const BETA: Fq = MontFp!("18292478899820133222385880210918854254706405831091403105831645830694649873798259945392135397923436410689931051012"); + + pub fn endomorphism(p: &Affine) -> Affine { + // Endomorphism of the points on the curve. + // endomorphism_p(x,y) = (BETA * x, y) + // where BETA is a non-trivial cubic root of unity in Fq. + let mut res = *p; + res.x *= BETA; + res + } +} + +pub mod g2 { + use super::util::{ + read_g2_compressed, read_g2_uncompressed, serialize_fq, EncodingFlags, G2_SERIALIZED_SIZE, + }; + use super::*; + use ark_ec::models::CurveConfig; + use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; + use ark_ec::{bls12, AffineRepr}; + use ark_ff::{MontFp, Zero}; + use ark_serialize::{Compress, SerializationError}; + + pub type G2Affine = bls12::G2Affine; + pub type G2Projective = bls12::G2Projective; + + #[derive(Clone, Default, PartialEq, Eq)] + pub struct Config; + + impl CurveConfig for Config { + type BaseField = Fq2; + type ScalarField = Fr; + + /// COFACTOR = (x^8 - 4 x^7 + 5 x^6) - (4 x^4 + 6 x^3 - 4 x^2 - 4 x + 13) // + /// 9 + /// = 46280025648128091779281203587029183771098593081950199160533444883894201638329761721685747232785203763275581499269893683911356926248942802726857101798724933488377584092259436345573 + const COFACTOR: &'static [u64] = &[ + 0xce555594000638e5, + 0xa75088593e6a92ef, + 0xc81e026dd55b51d6, + 0x47f8e24b79369c54, + 0x74c3560ced298d51, + 0x7cefe5c3dd2555cb, + 0x657742bf55690156, + 0x5780484639bf731d, + 0x3988a06f1bb3444d, + 0x2daee, + ]; + + /// COFACTOR_INV = COFACTOR^{-1} mod r + /// 420747440395392227734782296805460539842466911252881029283882861015362447833828968293150382 + const COFACTOR_INV: Fr = MontFp!( + "420747440395392227734782296805460539842466911252881029283882861015362447833828968293150382" + ); + } + + impl SWCurveConfig for Config { + /// COEFF_A = [0, 0] + const COEFF_A: Fq2 = Fq2::new(g1::Config::COEFF_A, g1::Config::COEFF_A); + + /// COEFF_B = [1, 1] + const COEFF_B: Fq2 = Fq2::new(g1::Config::COEFF_B, g1::Config::COEFF_B); + + /// AFFINE_GENERATOR_COEFFS = (G2_GENERATOR_X, G2_GENERATOR_Y) + const GENERATOR: G2Affine = G2Affine::new_unchecked(G2_GENERATOR_X, G2_GENERATOR_Y); + + #[inline(always)] + fn mul_by_a(_: Self::BaseField) -> Self::BaseField { + Self::BaseField::zero() + } + + fn deserialize_with_mode( + mut reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result, ark_serialize::SerializationError> { + let p = if compress == ark_serialize::Compress::Yes { + read_g2_compressed(&mut reader)? + } else { + read_g2_uncompressed(&mut reader)? + }; + + if validate == ark_serialize::Validate::Yes + && !p.is_in_correct_subgroup_assuming_on_curve() + { + return Err(SerializationError::InvalidData); + } + Ok(p) + } + + fn serialize_with_mode( + item: &Affine, + mut writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), SerializationError> { + let encoding = EncodingFlags { + is_compressed: compress == ark_serialize::Compress::Yes, + is_infinity: item.is_zero(), + is_lexographically_largest: item.y > -item.y, + }; + let mut p = *item; + if encoding.is_infinity { + p = G2Affine::zero(); + } + + let mut x_bytes = [0u8; G2_SERIALIZED_SIZE]; + let c1_bytes = serialize_fq(p.x.c1); + let c0_bytes = serialize_fq(p.x.c0); + x_bytes[1..56 + 1].copy_from_slice(&c1_bytes[..]); + x_bytes[56 + 2..114].copy_from_slice(&c0_bytes[..]); + if encoding.is_compressed { + let mut bytes: [u8; G2_SERIALIZED_SIZE] = x_bytes; + + encoding.encode_flags(&mut bytes); + writer.write_all(&bytes)?; + } else { + let mut bytes = [0u8; 2 * G2_SERIALIZED_SIZE]; + + let mut y_bytes = [0u8; G2_SERIALIZED_SIZE]; + let c1_bytes = serialize_fq(p.y.c1); + let c0_bytes = serialize_fq(p.y.c0); + y_bytes[1..56 + 1].copy_from_slice(&c1_bytes[..]); + y_bytes[56 + 2..114].copy_from_slice(&c0_bytes[..]); + bytes[0..G2_SERIALIZED_SIZE].copy_from_slice(&x_bytes); + bytes[G2_SERIALIZED_SIZE..].copy_from_slice(&y_bytes); + + encoding.encode_flags(&mut bytes); + writer.write_all(&bytes)?; + }; + + Ok(()) + } + + fn serialized_size(compress: ark_serialize::Compress) -> usize { + if compress == Compress::Yes { + G2_SERIALIZED_SIZE + } else { + 2 * G2_SERIALIZED_SIZE + } + } + } + + pub const G2_GENERATOR_X: Fq2 = Fq2::new(G2_GENERATOR_X_C0, G2_GENERATOR_X_C1); + pub const G2_GENERATOR_Y: Fq2 = Fq2::new(G2_GENERATOR_Y_C0, G2_GENERATOR_Y_C1); + + /// G2_GENERATOR_X_C0 = + /// 96453755443802578867745476081903764610578492683850270111202389209355548711427786327510993588141991264564812146530214503491136289085725 + pub const G2_GENERATOR_X_C0: Fq = MontFp!("96453755443802578867745476081903764610578492683850270111202389209355548711427786327510993588141991264564812146530214503491136289085725"); + + /// G2_GENERATOR_X_C1 = + /// 85346509177292795277012009839788781950274202400882571466460158277083221521663169974265433098009350061415973662678938824527658049065530 + pub const G2_GENERATOR_X_C1: Fq = MontFp!("85346509177292795277012009839788781950274202400882571466460158277083221521663169974265433098009350061415973662678938824527658049065530"); + + /// G2_GENERATOR_Y_C0 = + /// 49316184343270950587272132771103279293158283984999436491292404103501221698714795975575879957605051223501287444864258801515822358837529 + pub const G2_GENERATOR_Y_C0: Fq = MontFp!("49316184343270950587272132771103279293158283984999436491292404103501221698714795975575879957605051223501287444864258801515822358837529"); + + /// G2_GENERATOR_Y_C1 = + /// 107680854723992552431070996218129928499826544031468382031848626814251381379173928074140221537929995580031433096217223703806029068859074 + pub const G2_GENERATOR_Y_C1: Fq = MontFp!("107680854723992552431070996218129928499826544031468382031848626814251381379173928074140221537929995580031433096217223703806029068859074"); +} diff --git a/tfhe-zk-pok/src/curve_api.rs b/tfhe-zk-pok/src/curve_api.rs new file mode 100644 index 0000000000..88a9ac47ef --- /dev/null +++ b/tfhe-zk-pok/src/curve_api.rs @@ -0,0 +1,341 @@ +use ark_ec::{CurveGroup, Group, VariableBaseMSM}; +use ark_ff::{BigInt, Field, MontFp, Zero}; +use ark_poly::univariate::DensePolynomial; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; +use core::fmt; +use core::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign}; +use serde::{Deserialize, Serialize}; + +fn ark_se(a: &A, s: S) -> Result +where + S: serde::Serializer, +{ + let mut bytes = vec![]; + a.serialize_with_mode(&mut bytes, Compress::Yes) + .map_err(serde::ser::Error::custom)?; + s.serialize_bytes(&bytes) +} + +fn ark_de<'de, D, A: CanonicalDeserialize>(data: D) -> Result +where + D: serde::de::Deserializer<'de>, +{ + let s: Vec = serde::de::Deserialize::deserialize(data)?; + let a = A::deserialize_with_mode(s.as_slice(), Compress::Yes, Validate::Yes); + a.map_err(serde::de::Error::custom) +} + +struct MontIntDisplay<'a, T>(&'a T); + +impl fmt::Debug for MontIntDisplay<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if *self.0 == T::ZERO { + f.write_str("0") + } else { + f.write_fmt(format_args!("{}", self.0)) + } + } +} + +pub mod bls12_381; +pub mod bls12_446; + +pub trait FieldOps: + Copy + + Send + + Sync + + core::ops::AddAssign + + core::ops::SubAssign + + core::ops::Add + + core::ops::Sub + + core::ops::Mul + + core::ops::Div + + core::ops::Neg + + core::iter::Sum +{ + const ZERO: Self; + const ONE: Self; + + fn from_u128(n: u128) -> Self; + fn from_u64(n: u64) -> Self; + fn from_i64(n: i64) -> Self; + fn to_bytes(self) -> impl AsRef<[u8]>; + fn rand(rng: &mut dyn rand::RngCore) -> Self; + fn hash(values: &mut [Self], data: &[&[u8]]); + fn poly_mul(p: &[Self], q: &[Self]) -> Vec; + fn poly_sub(p: &[Self], q: &[Self]) -> Vec { + use core::iter::zip; + let mut out = vec![Self::ZERO; Ord::max(p.len(), q.len())]; + + for (out, (p, q)) in zip( + &mut out, + zip( + p.iter().copied().chain(core::iter::repeat(Self::ZERO)), + q.iter().copied().chain(core::iter::repeat(Self::ZERO)), + ), + ) { + *out = p - q; + } + + out + } + fn poly_add(p: &[Self], q: &[Self]) -> Vec { + use core::iter::zip; + let mut out = vec![Self::ZERO; Ord::max(p.len(), q.len())]; + + for (out, (p, q)) in zip( + &mut out, + zip( + p.iter().copied().chain(core::iter::repeat(Self::ZERO)), + q.iter().copied().chain(core::iter::repeat(Self::ZERO)), + ), + ) { + *out = p + q; + } + + out + } +} + +pub trait CurveGroupOps: + Copy + + Send + + Sync + + core::fmt::Debug + + core::ops::AddAssign + + core::ops::SubAssign + + core::ops::Add + + core::ops::Sub + + core::ops::Neg + + core::iter::Sum +{ + const ZERO: Self; + const GENERATOR: Self; + const BYTE_SIZE: usize; + + fn mul_scalar(self, scalar: Zp) -> Self; + fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self; + fn to_bytes(self) -> impl AsRef<[u8]>; + fn double(self) -> Self; +} + +pub trait PairingGroupOps: + Copy + + Send + + Sync + + PartialEq + + core::fmt::Debug + + core::ops::AddAssign + + core::ops::SubAssign + + core::ops::Add + + core::ops::Sub + + core::ops::Neg +{ + fn mul_scalar(self, scalar: Zp) -> Self; + fn pairing(x: G1, y: G2) -> Self; +} + +pub trait Curve { + type Zp: FieldOps; + type G1: CurveGroupOps + serde::Serialize + for<'de> serde::Deserialize<'de>; + type G2: CurveGroupOps + serde::Serialize + for<'de> serde::Deserialize<'de>; + type Gt: PairingGroupOps; +} + +impl FieldOps for bls12_381::Zp { + const ZERO: Self = Self::ZERO; + const ONE: Self = Self::ONE; + + fn from_u128(n: u128) -> Self { + Self::from_bigint([n as u64, (n >> 64) as u64, 0, 0]) + } + fn from_u64(n: u64) -> Self { + Self::from_u64(n) + } + fn from_i64(n: i64) -> Self { + Self::from_i64(n) + } + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + fn rand(rng: &mut dyn rand::RngCore) -> Self { + Self::rand(rng) + } + fn hash(values: &mut [Self], data: &[&[u8]]) { + Self::hash(values, data) + } + + fn poly_mul(p: &[Self], q: &[Self]) -> Vec { + let p = p.iter().map(|x| x.inner).collect(); + let q = q.iter().map(|x| x.inner).collect(); + let p = DensePolynomial { coeffs: p }; + let q = DensePolynomial { coeffs: q }; + (&p * &q) + .coeffs + .into_iter() + .map(|inner| bls12_381::Zp { inner }) + .collect() + } +} + +impl CurveGroupOps for bls12_381::G1 { + const ZERO: Self = Self::ZERO; + const GENERATOR: Self = Self::GENERATOR; + const BYTE_SIZE: usize = Self::BYTE_SIZE; + + fn mul_scalar(self, scalar: bls12_381::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_381::Zp]) -> Self { + Self::multi_mul_scalar(bases, scalars) + } + + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + + fn double(self) -> Self { + self.double() + } +} + +impl CurveGroupOps for bls12_381::G2 { + const ZERO: Self = Self::ZERO; + const GENERATOR: Self = Self::GENERATOR; + const BYTE_SIZE: usize = Self::BYTE_SIZE; + + fn mul_scalar(self, scalar: bls12_381::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_381::Zp]) -> Self { + Self::multi_mul_scalar(bases, scalars) + } + + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + + fn double(self) -> Self { + self.double() + } +} + +impl PairingGroupOps for bls12_381::Gt { + fn mul_scalar(self, scalar: bls12_381::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn pairing(x: bls12_381::G1, y: bls12_381::G2) -> Self { + Self::pairing(x, y) + } +} + +impl FieldOps for bls12_446::Zp { + const ZERO: Self = Self::ZERO; + const ONE: Self = Self::ONE; + + fn from_u128(n: u128) -> Self { + Self::from_bigint([n as u64, (n >> 64) as u64, 0, 0, 0]) + } + fn from_u64(n: u64) -> Self { + Self::from_u64(n) + } + fn from_i64(n: i64) -> Self { + Self::from_i64(n) + } + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + fn rand(rng: &mut dyn rand::RngCore) -> Self { + Self::rand(rng) + } + fn hash(values: &mut [Self], data: &[&[u8]]) { + Self::hash(values, data) + } + + fn poly_mul(p: &[Self], q: &[Self]) -> Vec { + let p = p.iter().map(|x| x.inner).collect(); + let q = q.iter().map(|x| x.inner).collect(); + let p = DensePolynomial { coeffs: p }; + let q = DensePolynomial { coeffs: q }; + (&p * &q) + .coeffs + .into_iter() + .map(|inner| bls12_446::Zp { inner }) + .collect() + } +} + +impl CurveGroupOps for bls12_446::G1 { + const ZERO: Self = Self::ZERO; + const GENERATOR: Self = Self::GENERATOR; + const BYTE_SIZE: usize = Self::BYTE_SIZE; + + fn mul_scalar(self, scalar: bls12_446::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_446::Zp]) -> Self { + Self::multi_mul_scalar(bases, scalars) + } + + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + + fn double(self) -> Self { + self.double() + } +} + +impl CurveGroupOps for bls12_446::G2 { + const ZERO: Self = Self::ZERO; + const GENERATOR: Self = Self::GENERATOR; + const BYTE_SIZE: usize = Self::BYTE_SIZE; + + fn mul_scalar(self, scalar: bls12_446::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_446::Zp]) -> Self { + Self::multi_mul_scalar(bases, scalars) + } + + fn to_bytes(self) -> impl AsRef<[u8]> { + self.to_bytes() + } + + fn double(self) -> Self { + self.double() + } +} + +impl PairingGroupOps for bls12_446::Gt { + fn mul_scalar(self, scalar: bls12_446::Zp) -> Self { + self.mul_scalar(scalar) + } + + fn pairing(x: bls12_446::G1, y: bls12_446::G2) -> Self { + Self::pairing(x, y) + } +} + +#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)] +pub struct Bls12_381; +#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)] +pub struct Bls12_446; + +impl Curve for Bls12_381 { + type Zp = bls12_381::Zp; + type G1 = bls12_381::G1; + type G2 = bls12_381::G2; + type Gt = bls12_381::Gt; +} +impl Curve for Bls12_446 { + type Zp = bls12_446::Zp; + type G1 = bls12_446::G1; + type G2 = bls12_446::G2; + type Gt = bls12_446::Gt; +} diff --git a/tfhe-zk-pok/src/curve_api/bls12_381.rs b/tfhe-zk-pok/src/curve_api/bls12_381.rs new file mode 100644 index 0000000000..a4bb50dc05 --- /dev/null +++ b/tfhe-zk-pok/src/curve_api/bls12_381.rs @@ -0,0 +1,776 @@ +use super::*; + +/// multiply EC point with scalar (= exponentiation in multiplicative notation) +fn mul_zp + Group>(x: T, scalar: Zp) -> T { + let zero = T::zero(); + let n: BigInt<4> = scalar.inner.into(); + + if n == BigInt([0; 4]) { + return zero; + } + + let mut y = zero; + let mut x = x; + + let n = n.0; + for word in n { + for idx in 0..64 { + let bit = (word >> idx) & 1; + if bit == 1 { + y += x; + } + x.double_in_place(); + } + } + y +} + +fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] { + let mut buf = [0u8; 6 * 8]; + for (i, &xi) in x.iter().enumerate() { + buf[i * 8..][..8].copy_from_slice(&xi.to_le_bytes()); + } + buf +} + +mod g1 { + use super::*; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct G1 { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: ark_bls12_381::G1Projective, + } + + impl fmt::Debug for G1 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("G1") + .field("x", &MontIntDisplay(&self.inner.x)) + .field("y", &MontIntDisplay(&self.inner.y)) + .field("z", &MontIntDisplay(&self.inner.z)) + .finish() + } + } + + impl G1 { + pub const ZERO: Self = Self { + inner: ark_bls12_381::G1Projective { + x: MontFp!("1"), + y: MontFp!("1"), + z: MontFp!("0"), + }, + }; + + // https://github.com/zcash/librustzcash/blob/6e0364cd42a2b3d2b958a54771ef51a8db79dd29/pairing/src/bls12_381/README.md#g1 + pub const GENERATOR: Self = Self { + inner: ark_bls12_381::G1Projective { + x: MontFp!("3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507"), + y: MontFp!("1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569"), + z: MontFp!("1"), + }, + }; + + // Size in number of bytes when the [to_bytes] + // function is called. + // This is not the size after serialization! + pub const BYTE_SIZE: usize = 2 * 6 * 8 + 1; + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { + use rayon::prelude::*; + let n_threads = rayon::current_num_threads(); + let chunk_size = bases.len().div_ceil(n_threads); + bases + .par_iter() + .map(|&x| x.inner.into_affine()) + .chunks(chunk_size) + .zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size)) + .map(|(bases, scalars)| Self { + inner: ark_bls12_381::G1Projective::msm(&bases, &scalars).unwrap(), + }) + .sum::() + } + + pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + let g = self.inner.into_affine(); + let x = bigint_to_bytes(g.x.0 .0); + let y = bigint_to_bytes(g.y.0 .0); + let mut buf = [0u8; 2 * 6 * 8 + 1]; + buf[..6 * 8].copy_from_slice(&x); + buf[6 * 8..][..6 * 8].copy_from_slice(&y); + buf[2 * 6 * 8] = g.infinity as u8; + buf + } + + pub fn double(self) -> Self { + Self { + inner: self.inner.double(), + } + } + } + + impl Add for G1 { + type Output = G1; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + G1 { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for G1 { + type Output = G1; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + G1 { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for G1 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for G1 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl core::iter::Sum for G1 { + fn sum>(iter: I) -> Self { + iter.fold(G1::ZERO, Add::add) + } + } + + impl Neg for G1 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod g2 { + use super::*; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct G2 { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(super) inner: ark_bls12_381::G2Projective, + } + + impl fmt::Debug for G2 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[allow(dead_code)] + #[derive(Debug)] + struct QuadExtField { + c0: T, + c1: T, + } + + f.debug_struct("G2") + .field( + "x", + &QuadExtField { + c0: MontIntDisplay(&self.inner.x.c0), + c1: MontIntDisplay(&self.inner.x.c1), + }, + ) + .field( + "y", + &QuadExtField { + c0: MontIntDisplay(&self.inner.y.c0), + c1: MontIntDisplay(&self.inner.y.c1), + }, + ) + .field( + "z", + &QuadExtField { + c0: MontIntDisplay(&self.inner.z.c0), + c1: MontIntDisplay(&self.inner.z.c1), + }, + ) + .finish() + } + } + + impl G2 { + pub const ZERO: Self = Self { + inner: ark_bls12_381::G2Projective { + x: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0"), + }, + y: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0"), + }, + z: ark_ff::QuadExtField { + c0: MontFp!("0"), + c1: MontFp!("0"), + }, + }, + }; + + // https://github.com/zcash/librustzcash/blob/6e0364cd42a2b3d2b958a54771ef51a8db79dd29/pairing/src/bls12_381/README.md#g2 + pub const GENERATOR: Self = Self { + inner: ark_bls12_381::G2Projective { + x: ark_ff::QuadExtField { + c0: MontFp!("352701069587466618187139116011060144890029952792775240219908644239793785735715026873347600343865175952761926303160"), + c1: MontFp!("3059144344244213709971259814753781636986470325476647558659373206291635324768958432433509563104347017837885763365758"), + }, + y: ark_ff::QuadExtField { + c0: MontFp!("1985150602287291935568054521177171638300868978215655730859378665066344726373823718423869104263333984641494340347905"), + c1: MontFp!("927553665492332455747201965776037880757740193453592970025027978793976877002675564980949289727957565575433344219582"), + }, + z: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0") , + }, + }, + }; + + // Size in number of bytes when the [to_bytes] + // function is called. + // This is not the size after serialization! + pub const BYTE_SIZE: usize = 4 * 6 * 8 + 1; + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { + use rayon::prelude::*; + let n_threads = rayon::current_num_threads(); + let chunk_size = bases.len().div_ceil(n_threads); + bases + .par_iter() + .map(|&x| x.inner.into_affine()) + .chunks(chunk_size) + .zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size)) + .map(|(bases, scalars)| Self { + inner: ark_bls12_381::G2Projective::msm(&bases, &scalars).unwrap(), + }) + .sum::() + } + + pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + let g = self.inner.into_affine(); + let xc0 = bigint_to_bytes(g.x.c0.0 .0); + let xc1 = bigint_to_bytes(g.x.c1.0 .0); + let yc0 = bigint_to_bytes(g.y.c0.0 .0); + let yc1 = bigint_to_bytes(g.y.c1.0 .0); + let mut buf = [0u8; 4 * 6 * 8 + 1]; + buf[..6 * 8].copy_from_slice(&xc0); + buf[6 * 8..][..6 * 8].copy_from_slice(&xc1); + buf[2 * 6 * 8..][..6 * 8].copy_from_slice(&yc0); + buf[3 * 6 * 8..][..6 * 8].copy_from_slice(&yc1); + buf[4 * 6 * 8] = g.infinity as u8; + buf + } + + pub fn double(self) -> Self { + Self { + inner: self.inner.double(), + } + } + } + + impl Add for G2 { + type Output = G2; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + G2 { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for G2 { + type Output = G2; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + G2 { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for G2 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for G2 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl core::iter::Sum for G2 { + fn sum>(iter: I) -> Self { + iter.fold(G2::ZERO, Add::add) + } + } + + impl Neg for G2 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod gt { + use super::*; + use ark_ec::pairing::Pairing; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct Gt { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + inner: ark_ec::pairing::PairingOutput, + } + + impl fmt::Debug for Gt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[allow(dead_code)] + #[derive(Debug)] + struct QuadExtField { + c0: T, + c1: T, + } + + #[allow(dead_code)] + #[derive(Debug)] + struct CubicExtField { + c0: T, + c1: T, + c2: T, + } + + #[allow(dead_code)] + #[derive(Debug)] + pub struct Gt { + inner: T, + } + + f.debug_struct("Gt") + .field( + "inner", + &Gt { + inner: QuadExtField { + c0: CubicExtField { + c0: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c0.c0), + c1: MontIntDisplay(&self.inner.0.c0.c0.c1), + }, + c1: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c1.c0), + c1: MontIntDisplay(&self.inner.0.c0.c1.c1), + }, + c2: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c2.c0), + c1: MontIntDisplay(&self.inner.0.c0.c2.c1), + }, + }, + c1: CubicExtField { + c0: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c0.c0), + c1: MontIntDisplay(&self.inner.0.c1.c0.c1), + }, + c1: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c1.c0), + c1: MontIntDisplay(&self.inner.0.c1.c1.c1), + }, + c2: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c2.c0), + c1: MontIntDisplay(&self.inner.0.c1.c2.c1), + }, + }, + }, + }, + ) + .finish() + } + } + + impl Gt { + pub fn pairing(g1: G1, g2: G2) -> Self { + Self { + inner: ark_bls12_381::Bls12_381::pairing(g1.inner, g2.inner), + } + } + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + } + + impl Add for Gt { + type Output = Gt; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Gt { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for Gt { + type Output = Gt; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Gt { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for Gt { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for Gt { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl Neg for Gt { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod zp { + use super::*; + use ark_ff::Fp; + use zeroize::Zeroize; + + fn redc(n: [u64; 4], nprime: u64, mut t: [u64; 6]) -> [u64; 4] { + for i in 0..2 { + let mut c = 0u64; + let m = u64::wrapping_mul(t[i], nprime); + + for j in 0..4 { + let x = t[i + j] as u128 + m as u128 * n[j] as u128 + c as u128; + t[i + j] = x as u64; + c = (x >> 64) as u64; + } + + for j in 4..6 - i { + let x = t[i + j] as u128 + c as u128; + t[i + j] = x as u64; + c = (x >> 64) as u64; + } + } + + let mut t = [t[2], t[3], t[4], t[5]]; + + if t.into_iter().rev().ge(n.into_iter().rev()) { + let mut o = false; + for i in 0..4 { + let (ti, o0) = u64::overflowing_sub(t[i], n[i]); + let (ti, o1) = u64::overflowing_sub(ti, o as u64); + o = o0 | o1; + t[i] = ti; + } + } + assert!(t.into_iter().rev().lt(n.into_iter().rev())); + + t + } + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Zeroize)] + #[repr(transparent)] + pub struct Zp { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: ark_bls12_381::Fr, + } + + impl fmt::Debug for Zp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Zp") + .field(&MontIntDisplay(&self.inner)) + .finish() + } + } + + impl Zp { + pub const ZERO: Self = Self { + inner: MontFp!("0"), + }; + + pub const ONE: Self = Self { + inner: MontFp!("1"), + }; + + pub fn from_bigint(n: [u64; 4]) -> Self { + Self { + inner: BigInt(n).into(), + } + } + + pub fn from_u64(n: u64) -> Self { + Self { + inner: BigInt([n, 0, 0, 0]).into(), + } + } + + pub fn from_i64(n: i64) -> Self { + let n_abs = Self::from_u64(n.unsigned_abs()); + if n > 0 { + n_abs + } else { + -n_abs + } + } + + pub fn to_bytes(self) -> [u8; 4 * 8] { + let buf = [ + self.inner.0 .0[0].to_le_bytes(), + self.inner.0 .0[1].to_le_bytes(), + self.inner.0 .0[2].to_le_bytes(), + self.inner.0 .0[3].to_le_bytes(), + ]; + unsafe { core::mem::transmute(buf) } + } + + fn from_raw_u64x6(n: [u64; 6]) -> Self { + const MODULUS: BigInt<4> = BigInt!( + "52435875175126190479447740508185965837690552500527637822603658699938581184513" + ); + + const MODULUS_MONTGOMERY: u64 = 18446744069414584319; + + let mut n = n; + // zero the two leading bits, so the result is <= MODULUS * 2^128 + n[5] &= (1 << 62) - 1; + Zp { + inner: Fp( + BigInt(redc(MODULUS.0, MODULUS_MONTGOMERY, n)), + core::marker::PhantomData, + ), + } + } + + pub fn rand(rng: &mut dyn rand::RngCore) -> Self { + use rand::Rng; + + Self::from_raw_u64x6([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + } + + pub fn hash(values: &mut [Zp], data: &[&[u8]]) { + use sha3::digest::{ExtendableOutput, Update, XofReader}; + + let mut hasher = sha3::Shake256::default(); + for data in data { + hasher.update(data); + } + let mut reader = hasher.finalize_xof(); + + for value in values { + let mut bytes = [0u8; 6 * 8]; + reader.read(&mut bytes); + let bytes: [[u8; 8]; 6] = unsafe { core::mem::transmute(bytes) }; + *value = Zp::from_raw_u64x6(bytes.map(u64::from_le_bytes)); + } + } + } + + impl Add for Zp { + type Output = Zp; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for Zp { + type Output = Zp; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner - rhs.inner, + } + } + } + + impl Mul for Zp { + type Output = Zp; + + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner * rhs.inner, + } + } + } + + impl Div for Zp { + type Output = Zp; + + #[inline] + fn div(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner / rhs.inner, + } + } + } + impl AddAssign for Zp { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for Zp { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl Neg for Zp { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } + + impl core::iter::Sum for Zp { + fn sum>(iter: I) -> Self { + iter.fold(Zp::ZERO, Add::add) + } + } +} + +pub use g1::G1; +pub use g2::G2; +pub use gt::Gt; +pub use zp::Zp; + +#[cfg(test)] +mod tests { + use rand::rngs::StdRng; + use rand::SeedableRng; + use std::collections::HashMap; + + use super::*; + + #[test] + fn test_distributivity() { + let a = Zp { + inner: MontFp!( + "20799633726231143268782044631117354647259165363910905818134484248029981143850" + ), + }; + let b = Zp { + inner: MontFp!( + "42333504039292951860879669847432876299949385605895551964353185488509497658948" + ), + }; + let c = Zp { + inner: MontFp!( + "6797004509292554067788526429737434481164547177696793280652530849910670196287" + ), + }; + + assert_eq!((((a - b) * c) - (a * c - b * c)).inner, Zp::ZERO.inner); + } + + #[test] + fn test_serialization() { + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let g_cur = G1::GENERATOR.mul_scalar(alpha); + let g_hat_cur = G2::GENERATOR.mul_scalar(alpha); + + let alpha2: Zp = serde_json::from_str(&serde_json::to_string(&alpha).unwrap()).unwrap(); + assert_eq!(alpha, alpha2); + + let g_cur2: G1 = serde_json::from_str(&serde_json::to_string(&g_cur).unwrap()).unwrap(); + assert_eq!(g_cur, g_cur2); + + let g_hat_cur2: G2 = + serde_json::from_str(&serde_json::to_string(&g_hat_cur).unwrap()).unwrap(); + assert_eq!(g_hat_cur, g_hat_cur2); + } + + #[test] + fn test_hasher_and_eq() { + // we need to make sure if the points are the same + // but the projective representations are different + // then they still hash into the same thing + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let a = G1::GENERATOR.mul_scalar(alpha); + + // serialization should convert the point to affine representation + // after deserializing it we should have the same point + // but with a different representation + let a_affine: G1 = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap(); + + // the internal elements should be different + assert_ne!(a.inner.x, a_affine.inner.x); + assert_ne!(a.inner.y, a_affine.inner.y); + assert_ne!(a.inner.z, a_affine.inner.z); + + // but equality and hasher should see the two as the same point + assert_eq!(a, a_affine); + let mut hm = HashMap::new(); + hm.insert(a, 1); + assert_eq!(hm.len(), 1); + hm.insert(a_affine, 2); + assert_eq!(hm.len(), 1); + } +} diff --git a/tfhe-zk-pok/src/curve_api/bls12_446.rs b/tfhe-zk-pok/src/curve_api/bls12_446.rs new file mode 100644 index 0000000000..4769acb67b --- /dev/null +++ b/tfhe-zk-pok/src/curve_api/bls12_446.rs @@ -0,0 +1,970 @@ +use super::*; + +/// multiply EC point with scalar (= exponentiation in multiplicative notation) +fn mul_zp + Group>(x: T, scalar: Zp) -> T { + let zero = T::zero(); + let n: BigInt<5> = scalar.inner.into(); + + if n == BigInt([0; 5]) { + return zero; + } + + let mut y = zero; + let mut x = x; + + let n = n.0; + for word in n { + for idx in 0..64 { + let bit = (word >> idx) & 1; + if bit == 1 { + y += x; + } + x.double_in_place(); + } + } + y +} + +fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] { + let mut buf = [0u8; 7 * 8]; + for (i, &xi) in x.iter().enumerate() { + buf[i * 8..][..8].copy_from_slice(&xi.to_le_bytes()); + } + buf +} + +mod g1 { + use super::*; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct G1 { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: crate::curve_446::g1::G1Projective, + } + + impl fmt::Debug for G1 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("G1") + .field("x", &MontIntDisplay(&self.inner.x)) + .field("y", &MontIntDisplay(&self.inner.y)) + .field("z", &MontIntDisplay(&self.inner.z)) + .finish() + } + } + + impl G1 { + pub const ZERO: Self = Self { + inner: crate::curve_446::g1::G1Projective { + x: MontFp!("1"), + y: MontFp!("1"), + z: MontFp!("0"), + }, + }; + + pub const GENERATOR: Self = Self { + inner: crate::curve_446::g1::G1Projective { + x: MontFp!("143189966182216199425404656824735381247272236095050141599848381692039676741476615087722874458136990266833440576646963466074693171606778"), + y: MontFp!("75202396197342917254523279069469674666303680671605970245803554133573745859131002231546341942288521574682619325841484506619191207488304"), + z: MontFp!("1"), + }, + }; + + // Size in number of bytes when the [to_bytes] + // function is called. + // This is not the size after serialization! + pub const BYTE_SIZE: usize = 2 * 7 * 8 + 1; + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { + use rayon::prelude::*; + let n_threads = rayon::current_num_threads(); + let chunk_size = bases.len().div_ceil(n_threads); + bases + .par_iter() + .map(|&x| x.inner.into_affine()) + .chunks(chunk_size) + .zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size)) + .map(|(bases, scalars)| Self { + inner: crate::curve_446::g1::G1Projective::msm(&bases, &scalars).unwrap(), + }) + .sum::() + } + + pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + let g = self.inner.into_affine(); + let x = bigint_to_bytes(g.x.0 .0); + let y = bigint_to_bytes(g.y.0 .0); + let mut buf = [0u8; 2 * 7 * 8 + 1]; + buf[..7 * 8].copy_from_slice(&x); + buf[7 * 8..][..7 * 8].copy_from_slice(&y); + buf[2 * 7 * 8] = g.infinity as u8; + buf + } + + pub fn double(self) -> Self { + Self { + inner: self.inner.double(), + } + } + } + + impl Add for G1 { + type Output = G1; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + G1 { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for G1 { + type Output = G1; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + G1 { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for G1 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for G1 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl core::iter::Sum for G1 { + fn sum>(iter: I) -> Self { + iter.fold(G1::ZERO, Add::add) + } + } + + impl Neg for G1 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod g2 { + use super::*; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct G2 { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(super) inner: crate::curve_446::g2::G2Projective, + } + + impl fmt::Debug for G2 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[allow(dead_code)] + #[derive(Debug)] + struct QuadExtField { + c0: T, + c1: T, + } + + f.debug_struct("G2") + .field( + "x", + &QuadExtField { + c0: MontIntDisplay(&self.inner.x.c0), + c1: MontIntDisplay(&self.inner.x.c1), + }, + ) + .field( + "y", + &QuadExtField { + c0: MontIntDisplay(&self.inner.y.c0), + c1: MontIntDisplay(&self.inner.y.c1), + }, + ) + .field( + "z", + &QuadExtField { + c0: MontIntDisplay(&self.inner.z.c0), + c1: MontIntDisplay(&self.inner.z.c1), + }, + ) + .finish() + } + } + + impl G2 { + pub const ZERO: Self = Self { + inner: crate::curve_446::g2::G2Projective { + x: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0"), + }, + y: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0"), + }, + z: ark_ff::QuadExtField { + c0: MontFp!("0"), + c1: MontFp!("0"), + }, + }, + }; + + pub const GENERATOR: Self = Self { + inner: crate::curve_446::g2::G2Projective { + x: ark_ff::QuadExtField { + c0: MontFp!("96453755443802578867745476081903764610578492683850270111202389209355548711427786327510993588141991264564812146530214503491136289085725"), + c1: MontFp!("85346509177292795277012009839788781950274202400882571466460158277083221521663169974265433098009350061415973662678938824527658049065530"), + }, + y: ark_ff::QuadExtField { + c0: MontFp!("49316184343270950587272132771103279293158283984999436491292404103501221698714795975575879957605051223501287444864258801515822358837529"), + c1: MontFp!("107680854723992552431070996218129928499826544031468382031848626814251381379173928074140221537929995580031433096217223703806029068859074"), + }, + z: ark_ff::QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0") , + }, + }, + }; + + // Size in number of bytes when the [to_bytes] + // function is called. + // This is not the size after serialization! + pub const BYTE_SIZE: usize = 4 * 7 * 8 + 1; + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { + use rayon::prelude::*; + let n_threads = rayon::current_num_threads(); + let chunk_size = bases.len().div_ceil(n_threads); + bases + .par_iter() + .map(|&x| x.inner.into_affine()) + .chunks(chunk_size) + .zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size)) + .map(|(bases, scalars)| Self { + inner: crate::curve_446::g2::G2Projective::msm(&bases, &scalars).unwrap(), + }) + .sum::() + } + + pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + let g = self.inner.into_affine(); + let xc0 = bigint_to_bytes(g.x.c0.0 .0); + let xc1 = bigint_to_bytes(g.x.c1.0 .0); + let yc0 = bigint_to_bytes(g.y.c0.0 .0); + let yc1 = bigint_to_bytes(g.y.c1.0 .0); + let mut buf = [0u8; 4 * 7 * 8 + 1]; + buf[..7 * 8].copy_from_slice(&xc0); + buf[7 * 8..][..7 * 8].copy_from_slice(&xc1); + buf[2 * 7 * 8..][..7 * 8].copy_from_slice(&yc0); + buf[3 * 7 * 8..][..7 * 8].copy_from_slice(&yc1); + buf[4 * 7 * 8] = g.infinity as u8; + buf + } + + pub fn double(self) -> Self { + Self { + inner: self.inner.double(), + } + } + } + + impl Add for G2 { + type Output = G2; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + G2 { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for G2 { + type Output = G2; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + G2 { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for G2 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for G2 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl core::iter::Sum for G2 { + fn sum>(iter: I) -> Self { + iter.fold(G2::ZERO, Add::add) + } + } + + impl Neg for G2 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod gt { + use super::*; + use ark_ec::bls12::Bls12Config; + use ark_ec::pairing::{MillerLoopOutput, Pairing}; + use ark_ff::{CubicExtField, Fp12, Fp2, QuadExtField}; + + type Bls = crate::curve_446::Bls12_446; + type Config = crate::curve_446::Config; + + const ONE: Fp2<::Fp2Config> = QuadExtField { + c0: MontFp!("1"), + c1: MontFp!("0"), + }; + const ZERO: Fp2<::Fp2Config> = QuadExtField { + c0: MontFp!("0"), + c1: MontFp!("0"), + }; + + const U1: Fp12<::Fp12Config> = QuadExtField { + c0: CubicExtField { + c0: ZERO, + c1: ZERO, + c2: ZERO, + }, + c1: CubicExtField { + c0: ONE, + c1: ZERO, + c2: ZERO, + }, + }; + const U3: Fp12<::Fp12Config> = QuadExtField { + c0: CubicExtField { + c0: ZERO, + c1: ZERO, + c2: ZERO, + }, + c1: CubicExtField { + c0: ZERO, + c1: ONE, + c2: ZERO, + }, + }; + + const fn fp2_to_fp12( + x: Fp2<::Fp2Config>, + ) -> Fp12<::Fp12Config> { + QuadExtField { + c0: CubicExtField { + c0: x, + c1: ZERO, + c2: ZERO, + }, + c1: CubicExtField { + c0: ZERO, + c1: ZERO, + c2: ZERO, + }, + } + } + + const fn fp_to_fp12( + x: ::Fp, + ) -> Fp12<::Fp12Config> { + fp2_to_fp12(QuadExtField { + c0: x, + c1: MontFp!("0"), + }) + } + + fn ate_tangent_ev(qt: G2, evpt: G1) -> Fp12<::Fp12Config> { + let qt = qt.inner.into_affine(); + let evpt = evpt.inner.into_affine(); + + let (xt, yt) = (qt.x, qt.y); + let (xe, ye) = (evpt.x, evpt.y); + + let xt = fp2_to_fp12(xt); + let yt = fp2_to_fp12(yt); + let xe = fp_to_fp12(xe); + let ye = fp_to_fp12(ye); + + let three = fp_to_fp12(MontFp!("3")); + let two = fp_to_fp12(MontFp!("2")); + + let l = three * xt.square() / (two * yt); + ye - (l * xe) / U1 + (l * xt - yt) / U3 + } + + fn ate_line_ev(q1: G2, q2: G2, evpt: G1) -> Fp12<::Fp12Config> { + let q1 = q1.inner.into_affine(); + let q2 = q2.inner.into_affine(); + let evpt = evpt.inner.into_affine(); + + let (x1, y1) = (q1.x, q1.y); + let (x2, y2) = (q2.x, q2.y); + let (xe, ye) = (evpt.x, evpt.y); + + let x1 = fp2_to_fp12(x1); + let y1 = fp2_to_fp12(y1); + let x2 = fp2_to_fp12(x2); + let y2 = fp2_to_fp12(y2); + let xe = fp_to_fp12(xe); + let ye = fp_to_fp12(ye); + + let l = (y2 - y1) / (x2 - x1); + ye - (l * xe) / U1 + (l * x1 - y1) / U3 + } + + #[allow(clippy::needless_range_loop)] + fn ate_pairing(p: G1, q: G2) -> Gt { + let t_log2 = 75; + let t_bits = b"110000000001000001000000100000000000000000000000000000000100000000000000001"; + + let mut fk = fp_to_fp12(MontFp!("1")); + let mut qk = q; + + for k in 1..t_log2 { + let lkk = ate_tangent_ev(qk, p); + qk = qk + qk; + fk = fk.square() * lkk; + + if t_bits[k] == b'1' { + assert_ne!(q, qk); + let lkp1 = if q != -qk { + ate_line_ev(q, qk, p) + } else { + fp_to_fp12(MontFp!("1")) + }; + qk += q; + fk *= lkp1; + } + } + let mlo = MillerLoopOutput(fk); + Gt { + inner: Bls::final_exponentiation(mlo).unwrap(), + } + } + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[repr(transparent)] + pub struct Gt { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: ark_ec::pairing::PairingOutput, + } + + impl fmt::Debug for Gt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[allow(dead_code)] + #[derive(Debug)] + struct QuadExtField { + c0: T, + c1: T, + } + + #[allow(dead_code)] + #[derive(Debug)] + struct CubicExtField { + c0: T, + c1: T, + c2: T, + } + + #[allow(dead_code)] + #[derive(Debug)] + pub struct Gt { + inner: T, + } + + f.debug_struct("Gt") + .field( + "inner", + &Gt { + inner: QuadExtField { + c0: CubicExtField { + c0: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c0.c0), + c1: MontIntDisplay(&self.inner.0.c0.c0.c1), + }, + c1: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c1.c0), + c1: MontIntDisplay(&self.inner.0.c0.c1.c1), + }, + c2: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c0.c2.c0), + c1: MontIntDisplay(&self.inner.0.c0.c2.c1), + }, + }, + c1: CubicExtField { + c0: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c0.c0), + c1: MontIntDisplay(&self.inner.0.c1.c0.c1), + }, + c1: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c1.c0), + c1: MontIntDisplay(&self.inner.0.c1.c1.c1), + }, + c2: QuadExtField { + c0: MontIntDisplay(&self.inner.0.c1.c2.c0), + c1: MontIntDisplay(&self.inner.0.c1.c2.c1), + }, + }, + }, + }, + ) + .finish() + } + } + + impl Gt { + pub fn pairing(g1: G1, g2: G2) -> Self { + ate_pairing(g1, -g2) + } + + pub fn mul_scalar(self, scalar: Zp) -> Self { + Self { + inner: mul_zp(self.inner, scalar), + } + } + } + + impl Add for Gt { + type Output = Gt; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Gt { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for Gt { + type Output = Gt; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Gt { + inner: self.inner - rhs.inner, + } + } + } + + impl AddAssign for Gt { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for Gt { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl Neg for Gt { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } +} + +mod zp { + use super::*; + use ark_ff::Fp; + use zeroize::Zeroize; + + fn redc(n: [u64; 5], nprime: u64, mut t: [u64; 7]) -> [u64; 5] { + for i in 0..2 { + let mut c = 0u64; + let m = u64::wrapping_mul(t[i], nprime); + + for j in 0..5 { + let x = t[i + j] as u128 + m as u128 * n[j] as u128 + c as u128; + t[i + j] = x as u64; + c = (x >> 64) as u64; + } + + for j in 5..7 - i { + let x = t[i + j] as u128 + c as u128; + t[i + j] = x as u64; + c = (x >> 64) as u64; + } + } + + let mut t = [t[2], t[3], t[4], t[5], t[6]]; + + if t.into_iter().rev().ge(n.into_iter().rev()) { + let mut o = false; + for i in 0..5 { + let (ti, o0) = u64::overflowing_sub(t[i], n[i]); + let (ti, o1) = u64::overflowing_sub(ti, o as u64); + o = o0 | o1; + t[i] = ti; + } + } + assert!(t.into_iter().rev().lt(n.into_iter().rev())); + + t + } + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Zeroize)] + #[repr(transparent)] + pub struct Zp { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: crate::curve_446::Fr, + } + + impl fmt::Debug for Zp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Zp") + .field(&MontIntDisplay(&self.inner)) + .finish() + } + } + + impl Zp { + pub const ZERO: Self = Self { + inner: MontFp!("0"), + }; + + pub const ONE: Self = Self { + inner: MontFp!("1"), + }; + + pub fn from_bigint(n: [u64; 5]) -> Self { + Self { + inner: BigInt(n).into(), + } + } + + pub fn from_u64(n: u64) -> Self { + Self { + inner: BigInt([n, 0, 0, 0, 0]).into(), + } + } + + pub fn from_i64(n: i64) -> Self { + let n_abs = Self::from_u64(n.unsigned_abs()); + if n > 0 { + n_abs + } else { + -n_abs + } + } + + pub fn to_bytes(self) -> [u8; 5 * 8] { + let buf = [ + self.inner.0 .0[0].to_le_bytes(), + self.inner.0 .0[1].to_le_bytes(), + self.inner.0 .0[2].to_le_bytes(), + self.inner.0 .0[3].to_le_bytes(), + self.inner.0 .0[4].to_le_bytes(), + ]; + unsafe { core::mem::transmute(buf) } + } + + fn from_raw_u64x7(n: [u64; 7]) -> Self { + const MODULUS: BigInt<5> = BigInt!( + "645383785691237230677916041525710377746967055506026847120930304831624105190538527824412673" + ); + + const MODULUS_MONTGOMERY: u64 = 272467794636046335; + + let mut n = n; + // zero the 22 leading bits, so the result is <= MODULUS * 2^128 + n[6] &= (1 << 42) - 1; + Zp { + inner: Fp( + BigInt(redc(MODULUS.0, MODULUS_MONTGOMERY, n)), + core::marker::PhantomData, + ), + } + } + + pub fn rand(rng: &mut dyn rand::RngCore) -> Self { + use rand::Rng; + + Self::from_raw_u64x7([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + } + + pub fn hash(values: &mut [Zp], data: &[&[u8]]) { + use sha3::digest::{ExtendableOutput, Update, XofReader}; + + let mut hasher = sha3::Shake256::default(); + for data in data { + hasher.update(data); + } + let mut reader = hasher.finalize_xof(); + + for value in values { + let mut bytes = [0u8; 7 * 8]; + reader.read(&mut bytes); + let bytes: [[u8; 8]; 7] = unsafe { core::mem::transmute(bytes) }; + *value = Zp::from_raw_u64x7(bytes.map(u64::from_le_bytes)); + } + } + } + + impl Add for Zp { + type Output = Zp; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner + rhs.inner, + } + } + } + + impl Sub for Zp { + type Output = Zp; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner - rhs.inner, + } + } + } + + impl Mul for Zp { + type Output = Zp; + + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner * rhs.inner, + } + } + } + + impl Div for Zp { + type Output = Zp; + + #[inline] + fn div(self, rhs: Self) -> Self::Output { + Zp { + inner: self.inner / rhs.inner, + } + } + } + impl AddAssign for Zp { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.inner += rhs.inner + } + } + + impl SubAssign for Zp { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.inner -= rhs.inner + } + } + + impl Neg for Zp { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { inner: -self.inner } + } + } + + impl core::iter::Sum for Zp { + fn sum>(iter: I) -> Self { + iter.fold(Zp::ZERO, Add::add) + } + } +} + +pub use g1::G1; +pub use g2::G2; +pub use gt::Gt; +pub use zp::Zp; + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::SeedableRng; + use std::collections::HashMap; + + #[test] + fn test_g1() { + let x = G1::GENERATOR; + let y = x.mul_scalar(Zp::from_i64(-2) * Zp::from_u64(2)); + assert_eq!(x - x, G1::ZERO); + assert_eq!(x + y - x, y); + } + + #[test] + fn test_g2() { + let x = G2::GENERATOR; + let y = x.mul_scalar(Zp::from_i64(-2) * Zp::from_u64(2)); + assert_eq!(x - x, G2::ZERO); + assert_eq!(x + y - x, y); + } + + #[test] + fn test_g1_msm() { + let n = 1024; + let x = vec![G1::GENERATOR.mul_scalar(Zp::from_i64(-1)); n]; + let mut p = vec![Zp::ZERO; n]; + Zp::hash(&mut p, &[&[0]]); + + let result = G1::multi_mul_scalar(&x, &p); + let expected = x + .iter() + .zip(p.iter()) + .map(|(&x, &p)| x.mul_scalar(p)) + .sum::(); + assert_eq!(result, expected); + } + + #[test] + fn test_g2_msm() { + let n = 1024; + let x = vec![G2::GENERATOR.mul_scalar(Zp::from_i64(-1)); n]; + let mut p = vec![Zp::ZERO; n]; + Zp::hash(&mut p, &[&[0]]); + + let result = G2::multi_mul_scalar(&x, &p); + let expected = x + .iter() + .zip(p.iter()) + .map(|(&x, &p)| x.mul_scalar(p)) + .sum::(); + assert_eq!(result, expected); + } + + #[test] + fn test_pairing() { + let rng = &mut StdRng::seed_from_u64(0); + let p1 = Zp::rand(rng); + let p2 = Zp::rand(rng); + + let x1 = G1::GENERATOR.mul_scalar(p1); + let x2 = G2::GENERATOR.mul_scalar(p2); + + assert_eq!( + Gt::pairing(x1, x2), + Gt::pairing(G1::GENERATOR, G2::GENERATOR).mul_scalar(p1 * p2), + ); + } + + #[test] + fn test_distributivity() { + let a = Zp { + inner: MontFp!( + "20799633726231143268782044631117354647259165363910905818134484248029981143850" + ), + }; + let b = Zp { + inner: MontFp!( + "42333504039292951860879669847432876299949385605895551964353185488509497658948" + ), + }; + let c = Zp { + inner: MontFp!( + "6797004509292554067788526429737434481164547177696793280652530849910670196287" + ), + }; + + assert_eq!((((a - b) * c) - (a * c - b * c)).inner, Zp::ZERO.inner); + } + + #[test] + fn test_serialization() { + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let g_cur = G1::GENERATOR.mul_scalar(alpha); + let g_hat_cur = G2::GENERATOR.mul_scalar(alpha); + + let alpha2: Zp = serde_json::from_str(&serde_json::to_string(&alpha).unwrap()).unwrap(); + assert_eq!(alpha, alpha2); + + let g_cur2: G1 = serde_json::from_str(&serde_json::to_string(&g_cur).unwrap()).unwrap(); + assert_eq!(g_cur, g_cur2); + + let g_hat_cur2: G2 = + serde_json::from_str(&serde_json::to_string(&g_hat_cur).unwrap()).unwrap(); + assert_eq!(g_hat_cur, g_hat_cur2); + } + + #[test] + fn test_hasher_and_eq() { + // we need to make sure if the points are the same + // but the projective representations are different + // then they still hash into the same thing + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let a = G1::GENERATOR.mul_scalar(alpha); + + // serialization should convert the point to affine representation + // after deserializing it we should have the same point + // but with a different representation + let a_affine: G1 = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap(); + + // the internal elements should be different + assert_ne!(a.inner.x, a_affine.inner.x); + assert_ne!(a.inner.y, a_affine.inner.y); + assert_ne!(a.inner.z, a_affine.inner.z); + + // but equality and hasher should see the two as the same point + assert_eq!(a, a_affine); + let mut hm = HashMap::new(); + hm.insert(a, 1); + assert_eq!(hm.len(), 1); + hm.insert(a_affine, 2); + assert_eq!(hm.len(), 1); + } +} diff --git a/tfhe-zk-pok/src/lib.rs b/tfhe-zk-pok/src/lib.rs new file mode 100644 index 0000000000..b75cb224de --- /dev/null +++ b/tfhe-zk-pok/src/lib.rs @@ -0,0 +1,3 @@ +pub mod curve_446; +pub mod curve_api; +pub mod proofs; diff --git a/tfhe-zk-pok/src/proofs/binary.rs b/tfhe-zk-pok/src/proofs/binary.rs new file mode 100644 index 0000000000..ea7a1d8bed --- /dev/null +++ b/tfhe-zk-pok/src/proofs/binary.rs @@ -0,0 +1,213 @@ +use super::*; + +#[derive(Clone, Debug)] +pub struct PublicParams { + g_lists: GroupElements, +} + +impl PublicParams { + pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + Self { + g_lists: GroupElements::from_vec(g_list, g_hat_list), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub struct PrivateParams { + alpha: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct PublicCommit { + c_hat: G::G2, +} + +#[derive(Clone, Debug)] +pub struct PrivateCommit { + message: Vec, + gamma: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct Proof { + c_y: G::G1, + pi: G::G1, +} + +pub fn crs_gen(message_len: usize, rng: &mut dyn RngCore) -> PublicParams { + let alpha = G::Zp::rand(rng); + PublicParams { + g_lists: GroupElements::new(message_len, alpha), + } +} + +pub fn commit( + message: &[u64], + public: &PublicParams, + rng: &mut dyn RngCore, +) -> (PublicCommit, PrivateCommit) { + let g_hat = G::G2::GENERATOR; + let n = message.len(); + + let gamma = G::Zp::rand(rng); + let x = OneBased::new_ref(message); + + let mut c_hat = g_hat.mul_scalar(gamma); + for j in 1..n + 1 { + let term = if x[j] != 0 { + public.g_lists.g_hat_list[j] + } else { + G::G2::ZERO + }; + c_hat += term; + } + + ( + PublicCommit { c_hat }, + PrivateCommit { + message: message.to_vec(), + gamma, + }, + ) +} + +pub fn prove( + public: (&PublicParams, &PublicCommit), + private_commit: &PrivateCommit, + rng: &mut dyn RngCore, +) -> Proof { + let n = private_commit.message.len(); + let g = G::G1::GENERATOR; + let x = OneBased::new_ref(&*private_commit.message); + let c_hat = public.1.c_hat; + let gamma = private_commit.gamma; + let gamma_y = G::Zp::rand(rng); + let g_list = &public.0.g_lists.g_list; + + let mut y = OneBased(vec![G::Zp::ZERO; n]); + G::Zp::hash(&mut y.0, &[c_hat.to_bytes().as_ref()]); + + let mut c_y = g.mul_scalar(gamma_y); + for j in 1..n + 1 { + c_y += (g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(x[j])); + } + + let y_bytes = &*(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(); + let mut t = OneBased(vec![G::Zp::ZERO; n]); + G::Zp::hash( + &mut t.0, + &[y_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + + let proof = { + let mut poly_0 = vec![G::Zp::ZERO; n + 1]; + let mut poly_1 = vec![G::Zp::ZERO; n + 1]; + let mut poly_2 = vec![G::Zp::ZERO; n + 1]; + let mut poly_3 = vec![G::Zp::ZERO; n + 1]; + + poly_0[0] = gamma_y * delta_y; + for i in 1..n + 1 { + poly_0[n + 1 - i] = + delta_y * (G::Zp::from_u64(x[i]) * y[i]) + (delta_eq * t[i] - delta_y) * y[i]; + } + + poly_1[0] = gamma; + for i in 1..n + 1 { + poly_1[i] = G::Zp::from_u64(x[i]); + } + + poly_2[0] = gamma_y; + for i in 1..n + 1 { + poly_2[n + 1 - i] = y[i] * G::Zp::from_u64(x[i]); + } + + for i in 1..n + 1 { + poly_3[i] = delta_eq * t[i]; + } + + let poly = G::Zp::poly_sub( + &G::Zp::poly_mul(&poly_0, &poly_1), + &G::Zp::poly_mul(&poly_2, &poly_3), + ); + + let mut proof = g.mul_scalar(poly[0]); + for i in 1..poly.len() { + proof += g_list[i].mul_scalar(poly[i]); + } + proof + }; + + Proof { pi: proof, c_y } +} + +#[allow(clippy::result_unit_err)] +pub fn verify( + proof: &Proof, + public: (&PublicParams, &PublicCommit), +) -> Result<(), ()> { + let e = G::Gt::pairing; + let c_hat = public.1.c_hat; + let g_hat = G::G2::GENERATOR; + let g_list = &public.0.g_lists.g_list; + let g_hat_list = &public.0.g_lists.g_hat_list; + let n = public.0.g_lists.message_len; + + let pi = proof.pi; + let c_y = proof.c_y; + + let mut y = OneBased(vec![G::Zp::ZERO; n]); + G::Zp::hash(&mut y.0, &[c_hat.to_bytes().as_ref()]); + + let y_bytes = &*(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(); + let mut t = OneBased(vec![G::Zp::ZERO; n]); + G::Zp::hash( + &mut t.0, + &[y_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + + let rhs = e(pi, g_hat); + let lhs = { + let numerator = { + let mut p = c_y.mul_scalar(delta_y); + for i in 1..n + 1 { + let gy = g_list[n + 1 - i].mul_scalar(y[i]); + p += gy.mul_scalar(delta_eq).mul_scalar(t[i]) - gy.mul_scalar(delta_y); + } + e(p, c_hat) + }; + let denominator = { + let mut q = G::G2::ZERO; + for i in 1..n + 1 { + q += g_hat_list[i].mul_scalar(delta_eq).mul_scalar(t[i]); + } + e(c_y, q) + }; + numerator - denominator + }; + + if lhs == rhs { + Ok(()) + } else { + Err(()) + } +} diff --git a/tfhe-zk-pok/src/proofs/index.rs b/tfhe-zk-pok/src/proofs/index.rs new file mode 100644 index 0000000000..641c451404 --- /dev/null +++ b/tfhe-zk-pok/src/proofs/index.rs @@ -0,0 +1,121 @@ +use super::*; + +#[derive(Clone, Debug)] +pub struct PublicParams { + g_lists: GroupElements, +} + +impl PublicParams { + pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + Self { + g_lists: GroupElements::from_vec(g_list, g_hat_list), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub struct PrivateParams { + alpha: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct PublicCommit { + c: G::G1, +} + +#[derive(Clone, Debug)] +pub struct PrivateCommit { + message: Vec, + gamma: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct Proof { + pi: G::G1, +} + +pub fn crs_gen(message_len: usize, rng: &mut dyn RngCore) -> PublicParams { + let alpha = G::Zp::rand(rng); + PublicParams { + g_lists: GroupElements::new(message_len, alpha), + } +} + +pub fn commit( + message: &[u64], + public: &PublicParams, + rng: &mut dyn RngCore, +) -> (PublicCommit, PrivateCommit) { + let g = G::G1::GENERATOR; + let n = message.len(); + + let gamma = G::Zp::rand(rng); + let m = OneBased::new_ref(message); + + let mut c = g.mul_scalar(gamma); + for j in 1..n + 1 { + let term = public.g_lists.g_list[j].mul_scalar(G::Zp::from_u64(m[j])); + c += term; + } + + ( + PublicCommit { c }, + PrivateCommit { + message: message.to_vec(), + gamma, + }, + ) +} + +pub fn prove( + i: usize, + public: (&PublicParams, &PublicCommit), + private: &PrivateCommit, + rng: &mut dyn RngCore, +) -> Proof { + let _ = rng; + let n = private.message.len(); + let m = OneBased::new_ref(&*private.message); + let gamma = private.gamma; + let g_list = &public.0.g_lists.g_list; + + let mut pi = g_list[n + 1 - i].mul_scalar(gamma); + for j in 1..n + 1 { + if i != j { + let term = if m[j] & 1 == 1 { + g_list[n + 1 - i + j] + } else { + G::G1::ZERO + }; + + pi += term; + } + } + + Proof { pi } +} + +#[allow(clippy::result_unit_err)] +pub fn verify( + proof: &Proof, + (index, mi): (usize, u64), + public: (&PublicParams, &PublicCommit), +) -> Result<(), ()> { + let e = G::Gt::pairing; + let c = public.1.c; + let g_hat = G::G2::GENERATOR; + let g_list = &public.0.g_lists.g_list; + let g_hat_list = &public.0.g_lists.g_hat_list; + let n = public.0.g_lists.message_len; + let i = index + 1; + + let lhs = e(c, g_hat_list[n + 1 - i]); + let rhs = e(proof.pi, g_hat) + (e(g_list[1], g_hat_list[n])).mul_scalar(G::Zp::from_u64(mi)); + + if lhs == rhs { + Ok(()) + } else { + Err(()) + } +} diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs new file mode 100644 index 0000000000..bc6f0cf1a0 --- /dev/null +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -0,0 +1,95 @@ +use crate::curve_api::{Curve, CurveGroupOps, FieldOps, PairingGroupOps}; + +use core::ops::{Index, IndexMut}; +use rand::RngCore; + +#[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize)] +#[repr(transparent)] +struct OneBased(T); + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ComputeLoad { + Proof, + Verify, +} + +impl OneBased { + pub fn new(inner: T) -> Self + where + T: Sized, + { + Self(inner) + } + + pub fn new_ref(inner: &T) -> &Self { + unsafe { &*(inner as *const T as *const Self) } + } +} + +impl> Index for OneBased { + type Output = T::Output; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.0[index - 1] + } +} + +impl> IndexMut for OneBased { + #[inline] + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.0[index - 1] + } +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +struct GroupElements { + g_list: OneBased>, + g_hat_list: OneBased>, + message_len: usize, +} + +impl GroupElements { + pub fn new(message_len: usize, alpha: G::Zp) -> Self { + let mut g_list = Vec::new(); + let mut g_hat_list = Vec::new(); + + let mut g_cur = G::G1::GENERATOR.mul_scalar(alpha); + + for i in 0..2 * message_len { + if i == message_len { + g_list.push(G::G1::ZERO); + } else { + g_list.push(g_cur); + } + g_cur = g_cur.mul_scalar(alpha); + } + + let mut g_hat_cur = G::G2::GENERATOR.mul_scalar(alpha); + for _ in 0..message_len { + g_hat_list.push(g_hat_cur); + g_hat_cur = (g_hat_cur).mul_scalar(alpha); + } + + Self { + g_list: OneBased::new(g_list), + g_hat_list: OneBased::new(g_hat_list), + message_len, + } + } + + pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + let message_len = g_hat_list.len(); + Self { + g_list: OneBased::new(g_list), + g_hat_list: OneBased::new(g_hat_list), + message_len, + } + } +} + +pub mod binary; +pub mod index; +pub mod pke; +pub mod range; +pub mod rlwe; diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs new file mode 100644 index 0000000000..b1cbaf59e8 --- /dev/null +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -0,0 +1,1043 @@ +use super::*; +use core::marker::PhantomData; +use rayon::prelude::*; + +fn bit_iter(x: u64, nbits: u32) -> impl Iterator { + (0..nbits).map(move |idx| ((x >> idx) & 1) != 0) +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct PublicParams { + g_lists: GroupElements, + big_d: usize, + pub n: usize, + pub d: usize, + pub k: usize, + pub b: u64, + pub b_r: u64, + pub q: u64, + pub t: u64, +} + +impl PublicParams { + #[allow(clippy::too_many_arguments)] + pub fn from_vec( + g_list: Vec, + g_hat_list: Vec, + big_d: usize, + n: usize, + d: usize, + k: usize, + b: u64, + b_r: u64, + q: u64, + t: u64, + ) -> Self { + Self { + g_lists: GroupElements::::from_vec(g_list, g_hat_list), + big_d, + n, + d, + k, + b, + b_r, + q, + t, + } + } + + pub fn exclusive_max_noise(&self) -> u64 { + self.b + } +} + +#[allow(dead_code)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct PrivateParams { + alpha: G::Zp, +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct Proof { + c_hat: G::G2, + c_y: G::G1, + pi: G::G1, + c_hat_t: Option, + c_h: Option, + pi_kzg: Option, +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct PublicCommit { + a: Vec, + b: Vec, + c1: Vec, + c2: Vec, + __marker: PhantomData, +} + +impl PublicCommit { + pub fn new(a: Vec, b: Vec, c1: Vec, c2: Vec) -> Self { + Self { + a, + b, + c1, + c2, + __marker: Default::default(), + } + } +} + +#[derive(Clone, Debug)] +pub struct PrivateCommit { + r: Vec, + e1: Vec, + m: Vec, + e2: Vec, + __marker: PhantomData, +} + +pub fn crs_gen( + d: usize, + k: usize, + b: u64, + q: u64, + t: u64, + rng: &mut dyn RngCore, +) -> (PublicParams, PrivateParams) { + let alpha = G::Zp::rand(rng); + let b_r = d as u64 / 2 + 1; + + let big_d = + d + k * t.ilog2() as usize + (d + k) * (2 + b.ilog2() as usize + b_r.ilog2() as usize); + let n = big_d + 1; + ( + PublicParams { + g_lists: GroupElements::::new(n, alpha), + big_d, + n, + d, + k, + b, + b_r, + q, + t, + }, + PrivateParams { alpha }, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn commit( + a: Vec, + b: Vec, + c1: Vec, + c2: Vec, + r: Vec, + e1: Vec, + m: Vec, + e2: Vec, + public: &PublicParams, + rng: &mut dyn RngCore, +) -> (PublicCommit, PrivateCommit) { + let _ = (public, rng); + ( + PublicCommit { + a, + b, + c1, + c2, + __marker: PhantomData, + }, + PrivateCommit { + r, + e1, + m, + e2, + __marker: PhantomData, + }, + ) +} + +pub fn prove( + public: (&PublicParams, &PublicCommit), + private_commit: &PrivateCommit, + load: ComputeLoad, + rng: &mut dyn RngCore, +) -> Proof { + let &PublicParams { + ref g_lists, + big_d, + n, + d, + b, + b_r, + q, + t, + k, + } = public.0; + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + + let b_i = b; + + let PublicCommit { a, b, c1, c2, .. } = public.1; + let PrivateCommit { r, e1, m, e2, .. } = private_commit; + + assert!(c2.len() <= k); + let k = k.min(c2.len()); + + // FIXME: div_round + let delta = { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + (q / t as i128) as u64 + }; + + let g = G::G1::GENERATOR; + let g_hat = G::G2::GENERATOR; + let gamma = G::Zp::rand(rng); + let gamma_y = G::Zp::rand(rng); + + // rot(a) phi(r) + phi(e1) - q phi(r1) = phi(c1) + // phi[d - i + 1](bar(b)).T phi(r) + delta m_i + e2_i - q r2_i = c2 + + // phi(r1) = (rot(a) phi(r) + phi(e1) - phi(c1)) / q + // r2_i = (phi[d - i + 1](bar(b)).T phi(r) + delta m_i + e2_i - c2) / q + + let mut r1 = e1 + .iter() + .zip(c1.iter()) + .map(|(&e1, &c1)| e1 as i128 - c1 as i128) + .collect::>(); + + for i in 0..d { + for j in 0..d { + if i + j < d { + r1[i + j] += a[i] as i128 * r[d - j - 1] as i128; + } else { + r1[i + j - d] -= a[i] as i128 * r[d - j - 1] as i128; + } + } + } + + { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + for r1 in &mut *r1 { + *r1 /= q; + } + } + + let mut r2 = m + .iter() + .zip(e2) + .zip(c2) + .map(|((&m, &e2), &c2)| delta as i128 * m as i128 + e2 as i128 - c2 as i128) + .collect::>(); + + { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + for (i, r2) in r2.iter_mut().enumerate() { + let mut dot = 0i128; + for j in 0..d { + let b = if i + j < d { + b[d - j - i - 1] + } else { + b[2 * d - j - i - 1].wrapping_neg() + }; + + dot += r[d - j - 1] as i128 * b as i128; + } + + *r2 += dot; + *r2 /= q; + } + } + + let r1 = r1 + .into_vec() + .into_iter() + .map(|r1| r1 as i64) + .collect::>(); + + let r2 = r2 + .into_vec() + .into_iter() + .map(|r2| r2 as i64) + .collect::>(); + + let mut w = vec![false; n]; + + let u64 = |x: i64| x as u64; + + w[..big_d] + .iter_mut() + .zip( + r.iter() + .rev() + .flat_map(|&r| bit_iter(u64(r), 1)) + .chain(m.iter().flat_map(|&m| bit_iter(u64(m), t.ilog2()))) + .chain(e1.iter().flat_map(|&e1| bit_iter(u64(e1), 1 + b_i.ilog2()))) + .chain(e2.iter().flat_map(|&e2| bit_iter(u64(e2), 1 + b_i.ilog2()))) + .chain(r1.iter().flat_map(|&r1| bit_iter(u64(r1), 1 + b_r.ilog2()))) + .chain(r2.iter().flat_map(|&r2| bit_iter(u64(r2), 1 + b_r.ilog2()))), + ) + .for_each(|(dst, src)| *dst = src); + + let w = OneBased(w); + + let mut c_hat = g_hat.mul_scalar(gamma); + for j in 1..big_d + 1 { + let term = if w[j] { g_hat_list[j] } else { G::G2::ZERO }; + c_hat += term; + } + + let x_bytes = &*[ + q.to_le_bytes().as_slice(), + d.to_le_bytes().as_slice(), + b_i.to_le_bytes().as_slice(), + t.to_le_bytes().as_slice(), + &*a.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*b.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*c1.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*c2.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + ] + .iter() + .copied() + .flatten() + .copied() + .collect::>(); + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash(&mut y, &[x_bytes, c_hat.to_bytes().as_ref()]); + let y = OneBased(y); + + let scalars = (n + 1 - big_d..n + 1) + .map(|j| (y[n + 1 - j] * G::Zp::from_u64(w[n + 1 - j] as u64))) + .collect::>(); + let c_y = g.mul_scalar(gamma_y) + G::G1::multi_mul_scalar(&g_list.0[n - big_d..n], &scalars); + + let mut theta = vec![G::Zp::ZERO; d + k + 1]; + G::Zp::hash( + &mut theta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + + let theta0 = &theta[..d + k]; + + let delta_theta = theta[d + k]; + + let mut a_theta = vec![G::Zp::ZERO; big_d]; + + compute_a_theta::(theta0, d, a, k, b, &mut a_theta, t, delta, b_i, b_r, q); + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + &(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(), + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + + let mut poly_0 = vec![G::Zp::ZERO; n + 1]; + let mut poly_1 = vec![G::Zp::ZERO; big_d + 1]; + let mut poly_2 = vec![G::Zp::ZERO; n + 1]; + let mut poly_3 = vec![G::Zp::ZERO; n + 1]; + + poly_0[0] = delta_y * gamma_y; + for i in 1..n + 1 { + poly_0[n + 1 - i] = + delta_y * (y[i] * G::Zp::from_u64(w[i] as u64)) + (delta_eq * t[i] - delta_y) * y[i]; + + if i < big_d + 1 { + poly_0[n + 1 - i] += delta_theta * a_theta[i - 1]; + } + } + + poly_1[0] = gamma; + for i in 1..big_d + 1 { + poly_1[i] = G::Zp::from_u64(w[i] as u64); + } + + poly_2[0] = gamma_y; + for i in 1..big_d + 1 { + poly_2[n + 1 - i] = y[i] * G::Zp::from_u64(w[i] as u64); + } + + for i in 1..n + 1 { + poly_3[i] = delta_eq * t[i]; + } + + let mut t_theta = G::Zp::ZERO; + for i in 0..d { + t_theta += theta0[i] * G::Zp::from_i64(c1[i]); + } + for i in 0..k { + t_theta += theta0[d + i] * G::Zp::from_i64(c2[i]); + } + + let mut poly = G::Zp::poly_sub( + &G::Zp::poly_mul(&poly_0, &poly_1), + &G::Zp::poly_mul(&poly_2, &poly_3), + ); + if poly.len() > n + 1 { + poly[n + 1] -= t_theta * delta_theta; + } + + let pi = + g.mul_scalar(poly[0]) + G::G1::multi_mul_scalar(&g_list.0[..poly.len() - 1], &poly[1..]); + + if load == ComputeLoad::Proof { + let c_hat_t = G::G2::multi_mul_scalar(&g_hat_list.0, &t.0); + let scalars = (1..n + 1) + .map(|i| { + let i = n + 1 - i; + (delta_eq * t[i] - delta_y) * y[i] + + if i < big_d + 1 { + delta_theta * a_theta[i - 1] + } else { + G::Zp::ZERO + } + }) + .collect::>(); + let c_h = G::G1::multi_mul_scalar(&g_list.0[..n], &scalars); + + let mut z = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut z), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + ], + ); + + let mut pow = z; + let mut p_t = G::Zp::ZERO; + let mut p_h = G::Zp::ZERO; + + for i in 1..n + 1 { + p_t += t[i] * pow; + if n - i < big_d { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i] + + delta_theta * a_theta[n - i]) + * pow; + } else { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]) * pow; + } + pow = pow * z; + } + + let mut w = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut w), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + z.to_bytes().as_ref(), + p_h.to_bytes().as_ref(), + p_t.to_bytes().as_ref(), + ], + ); + + let mut poly = vec![G::Zp::ZERO; n + 1]; + for i in 1..n + 1 { + poly[i] += w * t[i]; + if i < big_d + 1 { + poly[n + 1 - i] += + (delta_eq * t[i] - delta_y) * y[i] + delta_theta * a_theta[i - 1]; + } else { + poly[n + 1 - i] += (delta_eq * t[i] - delta_y) * y[i]; + } + } + + let mut q = vec![G::Zp::ZERO; n]; + for i in (0..n).rev() { + poly[i] = poly[i] + z * poly[i + 1]; + q[i] = poly[i + 1]; + poly[i + 1] = G::Zp::ZERO; + } + let pi_kzg = g.mul_scalar(q[0]) + G::G1::multi_mul_scalar(&g_list.0[..n - 1], &q[1..n]); + + Proof { + c_hat, + c_y, + pi, + c_hat_t: Some(c_hat_t), + c_h: Some(c_h), + pi_kzg: Some(pi_kzg), + } + } else { + Proof { + c_hat, + c_y, + pi, + c_hat_t: None, + c_h: None, + pi_kzg: None, + } + } +} + +#[allow(clippy::too_many_arguments)] +fn compute_a_theta( + theta0: &[G::Zp], + d: usize, + a: &[i64], + k: usize, + b: &[i64], + a_theta: &mut [G::Zp], + t: u64, + delta: u64, + b_i: u64, + b_r: u64, + q: u64, +) { + // a_theta = Ã.T theta0 + // = [ + // rot(a).T theta1 + phi[d](bar(b)) theta2_1 + ... + phi[d-k+1](bar(b)) theta2_k + // + // delta g[log t].T theta2_1 + // delta g[log t].T theta2_2 + // ... + // delta g[log t].T theta2_k + // + // G[1 + log B].T theta1 + // + // g[1 + log B].T theta2_1 + // g[1 + log B].T theta2_2 + // ... + // g[1 + log B].T theta2_k + // + // -q G[1 + log Br].T theta1 + // + // -q g[1 + log Br].T theta2_1 + // -q g[1 + log Br].T theta2_2 + // ... + // -q g[1 + log Br].T theta2_k + // ] + + let q = if q == 0 { + G::Zp::from_u128(1u128 << 64) + } else { + G::Zp::from_u64(q) + }; + + let theta1 = &theta0[..d]; + let theta2 = &theta0[d..]; + { + let a_theta = &mut a_theta[..d]; + a_theta + .par_iter_mut() + .enumerate() + .for_each(|(i, a_theta_i)| { + let mut dot = G::Zp::ZERO; + + for j in 0..d { + let a = if i <= j { + a[j - i] + } else { + a[d + j - i].wrapping_neg() + }; + + dot += G::Zp::from_i64(a) * theta1[j]; + } + + for j in 0..k { + let b = if i + j < d { + b[d - i - j - 1] + } else { + b[2 * d - i - j - 1].wrapping_neg() + }; + + dot += G::Zp::from_i64(b) * theta2[j]; + } + *a_theta_i = dot; + }); + } + let a_theta = &mut a_theta[d..]; + + let step = t.ilog2() as usize; + for i in 0..k { + for j in 0..step { + let pow2 = G::Zp::from_u64(delta) * G::Zp::from_u64(1 << j) * theta2[i]; + a_theta[step * i + j] = pow2; + } + } + let a_theta = &mut a_theta[k * step..]; + + let step = 1 + b_i.ilog2() as usize; + for i in 0..d { + for j in 0..step { + let pow2 = G::Zp::from_u64(1 << j) * theta1[i]; + a_theta[step * i + j] = if j == step - 1 { -pow2 } else { pow2 }; + } + } + let a_theta = &mut a_theta[d * step..]; + for i in 0..k { + for j in 0..step { + let pow2 = G::Zp::from_u64(1 << j) * theta2[i]; + a_theta[step * i + j] = if j == step - 1 { -pow2 } else { pow2 }; + } + } + let a_theta = &mut a_theta[k * step..]; + + let step = 1 + b_r.ilog2() as usize; + for i in 0..d { + for j in 0..step { + let pow2 = -q * G::Zp::from_u64(1 << j) * theta1[i]; + a_theta[step * i + j] = if j == step - 1 { -pow2 } else { pow2 }; + } + } + let a_theta = &mut a_theta[d * step..]; + for i in 0..k { + for j in 0..step { + let pow2 = -q * G::Zp::from_u64(1 << j) * theta2[i]; + a_theta[step * i + j] = if j == step - 1 { -pow2 } else { pow2 }; + } + } +} + +#[allow(clippy::result_unit_err)] +pub fn verify( + proof: &Proof, + public: (&PublicParams, &PublicCommit), +) -> Result<(), ()> { + let &Proof { + c_hat, + c_y, + pi, + c_hat_t, + c_h, + pi_kzg, + } = proof; + let e = G::Gt::pairing; + + let &PublicParams { + ref g_lists, + big_d, + n, + d, + b, + b_r, + q, + t, + k, + } = public.0; + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + + let b_i = b; + + // FIXME: div_round + let delta = { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + (q / t as i128) as u64 + }; + + let PublicCommit { a, b, c1, c2, .. } = public.1; + if c2.len() > k { + return Err(()); + } + let k = k.min(c2.len()); + + let x_bytes = &*[ + q.to_le_bytes().as_slice(), + d.to_le_bytes().as_slice(), + b_i.to_le_bytes().as_slice(), + t.to_le_bytes().as_slice(), + &*a.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*b.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*c1.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + &*c2.iter().flat_map(|&x| x.to_le_bytes()).collect::>(), + ] + .iter() + .copied() + .flatten() + .copied() + .collect::>(); + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash(&mut y, &[x_bytes, c_hat.to_bytes().as_ref()]); + let y = OneBased(y); + + let mut theta = vec![G::Zp::ZERO; d + k + 1]; + G::Zp::hash( + &mut theta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let theta0 = &theta[..d + k]; + let delta_theta = theta[d + k]; + + let mut a_theta = vec![G::Zp::ZERO; big_d]; + compute_a_theta::(theta0, d, a, k, b, &mut a_theta, t, delta, b_i, b_r, q); + + let mut t_theta = G::Zp::ZERO; + for i in 0..d { + t_theta += theta0[i] * G::Zp::from_i64(c1[i]); + } + for i in 0..k { + t_theta += theta0[d + i] * G::Zp::from_i64(c2[i]); + } + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + &(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(), + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + + if let (Some(pi_kzg), Some(c_hat_t), Some(c_h)) = (pi_kzg, c_hat_t, c_h) { + let mut z = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut z), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + ], + ); + + let mut pow = z; + let mut p_t = G::Zp::ZERO; + let mut p_h = G::Zp::ZERO; + + for i in 1..n + 1 { + p_t += t[i] * pow; + if n - i < big_d { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i] + + delta_theta * a_theta[n - i]) + * pow; + } else { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]) * pow; + } + pow = pow * z; + } + + if e(pi, G::G2::GENERATOR) + != e(c_y.mul_scalar(delta_y) + c_h, c_hat) + - e(c_y.mul_scalar(delta_eq), c_hat_t) + - e(g_list[1], g_hat_list[n]).mul_scalar(t_theta * delta_theta) + { + return Err(()); + } + + let mut w = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut w), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + z.to_bytes().as_ref(), + p_h.to_bytes().as_ref(), + p_t.to_bytes().as_ref(), + ], + ); + + if e(c_h - G::G1::GENERATOR.mul_scalar(p_h), G::G2::GENERATOR) + + e(G::G1::GENERATOR, c_hat_t - G::G2::GENERATOR.mul_scalar(p_t)).mul_scalar(w) + == e(pi_kzg, g_hat_list[1] - G::G2::GENERATOR.mul_scalar(z)) + { + Ok(()) + } else { + Err(()) + } + } else { + let (term0, term1) = rayon::join( + || { + let p = c_y.mul_scalar(delta_y) + + (1..n + 1) + .into_par_iter() + .map(|i| { + let mut factor = (delta_eq * t[i] - delta_y) * y[i]; + if i < big_d + 1 { + factor += delta_theta * a_theta[i - 1]; + } + g_list[n + 1 - i].mul_scalar(factor) + }) + .sum::(); + let q = c_hat; + e(p, q) + }, + || { + let p = c_y; + let q = (1..n + 1) + .into_par_iter() + .map(|i| g_hat_list[i].mul_scalar(delta_eq * t[i])) + .sum::(); + e(p, q) + }, + ); + let term2 = { + let p = g_list[1]; + let q = g_hat_list[n]; + e(p, q) + }; + + let lhs = e(pi, G::G2::GENERATOR); + let rhs = term0 - term1 - term2.mul_scalar(t_theta * delta_theta); + + if lhs == rhs { + Ok(()) + } else { + Err(()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + + #[test] + fn test_pke() { + let d = 2048; + let k = 320; + let b_i = 512; + let q = 0; + let t = 1024; + + let delta = { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + (q / t as i128) as u64 + }; + + let rng = &mut StdRng::seed_from_u64(0); + + let polymul_rev = |a: &[i64], b: &[i64]| -> Vec { + assert_eq!(a.len(), b.len()); + let d = a.len(); + let mut c = vec![0i64; d]; + + for i in 0..d { + for j in 0..d { + if i + j < d { + c[i + j] = c[i + j].wrapping_add(a[i].wrapping_mul(b[d - j - 1])); + } else { + c[i + j - d] = c[i + j - d].wrapping_sub(a[i].wrapping_mul(b[d - j - 1])); + } + } + } + + c + }; + + let a = (0..d).map(|_| rng.gen::()).collect::>(); + let s = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + let e = (0..d) + .map(|_| (rng.gen::() % (2 * b_i)) as i64 - b_i as i64) + .collect::>(); + let e1 = (0..d) + .map(|_| (rng.gen::() % (2 * b_i)) as i64 - b_i as i64) + .collect::>(); + let fake_e1 = (0..d) + .map(|_| (rng.gen::() % (2 * b_i)) as i64 - b_i as i64) + .collect::>(); + let e2 = (0..k) + .map(|_| (rng.gen::() % (2 * b_i)) as i64 - b_i as i64) + .collect::>(); + let fake_e2 = (0..k) + .map(|_| (rng.gen::() % (2 * b_i)) as i64 - b_i as i64) + .collect::>(); + + let r = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + let fake_r = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + + let m = (0..k) + .map(|_| (rng.gen::() % t) as i64) + .collect::>(); + let fake_m = (0..k) + .map(|_| (rng.gen::() % t) as i64) + .collect::>(); + + let b = polymul_rev(&a, &s) + .into_iter() + .zip(e.iter()) + .map(|(x, e)| x.wrapping_add(*e)) + .collect::>(); + let c1 = polymul_rev(&a, &r) + .into_iter() + .zip(e1.iter()) + .map(|(x, e1)| x.wrapping_add(*e1)) + .collect::>(); + + let mut c2 = vec![0i64; k]; + + for i in 0..k { + let mut dot = 0i64; + for j in 0..d { + let b = if i + j < d { + b[d - j - i - 1] + } else { + b[2 * d - j - i - 1].wrapping_neg() + }; + + dot = dot.wrapping_add(r[d - j - 1].wrapping_mul(b)); + } + + c2[i] = dot + .wrapping_add(e2[i]) + .wrapping_add((delta * m[i] as u64) as i64); + } + + let mut m_roundtrip = vec![0i64; k]; + for i in 0..k { + let mut dot = 0i128; + for j in 0..d { + let c = if i + j < d { + c1[d - j - i - 1] + } else { + c1[2 * d - j - i - 1].wrapping_neg() + }; + + dot += s[d - j - 1] as i128 * c as i128; + } + + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + let val = ((c2[i] as i128).wrapping_sub(dot)) * t as i128; + let div = val.div_euclid(q); + let rem = val.rem_euclid(q); + let result = div as i64 + (rem > (q / 2)) as i64; + let result = result.rem_euclid(t as i64); + m_roundtrip[i] = result; + } + + let (public_param, _private_param) = + crs_gen::(d, k, b_i, q, t, rng); + + for use_fake_e1 in [false, true] { + for use_fake_e2 in [false, true] { + for use_fake_m in [false, true] { + for use_fake_r in [false, true] { + let (public_commit, private_commit) = commit( + a.clone(), + b.clone(), + c1.clone(), + c2.clone(), + if use_fake_r { + fake_r.clone() + } else { + r.clone() + }, + if use_fake_e1 { + fake_e1.clone() + } else { + e1.clone() + }, + if use_fake_m { + fake_m.clone() + } else { + m.clone() + }, + if use_fake_e2 { + fake_e2.clone() + } else { + e2.clone() + }, + &public_param, + rng, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = + prove((&public_param, &public_commit), &private_commit, load, rng); + + assert_eq!( + verify(&proof, (&public_param, &public_commit)).is_err(), + use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m + ); + } + } + } + } + } + } +} diff --git a/tfhe-zk-pok/src/proofs/range.rs b/tfhe-zk-pok/src/proofs/range.rs new file mode 100644 index 0000000000..499378f32c --- /dev/null +++ b/tfhe-zk-pok/src/proofs/range.rs @@ -0,0 +1,354 @@ +use super::*; + +#[derive(Clone, Debug)] +pub struct PublicParams { + g_lists: GroupElements, +} + +impl PublicParams { + pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + Self { + g_lists: GroupElements::from_vec(g_list, g_hat_list), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub struct PrivateParams { + alpha: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct PublicCommit { + l: usize, + v_hat: G::G2, +} + +#[derive(Clone, Debug)] +pub struct PrivateCommit { + x: u64, + r: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct Proof { + c_y: G::G1, + c_hat: G::G2, + pi: G::G1, +} + +pub fn crs_gen(max_nbits: usize, rng: &mut dyn RngCore) -> PublicParams { + let alpha = G::Zp::rand(rng); + PublicParams { + g_lists: GroupElements::new(max_nbits, alpha), + } +} + +pub fn commit( + x: u64, + l: usize, + public: &PublicParams, + rng: &mut dyn RngCore, +) -> (PublicCommit, PrivateCommit) { + let g_hat = G::G2::GENERATOR; + + let r = G::Zp::rand(rng); + let v_hat = g_hat.mul_scalar(r) + public.g_lists.g_hat_list[1].mul_scalar(G::Zp::from_u64(x)); + + (PublicCommit { l, v_hat }, PrivateCommit { x, r }) +} + +pub fn prove( + public: (&PublicParams, &PublicCommit), + private_commit: &PrivateCommit, + rng: &mut dyn RngCore, +) -> Proof { + let &PrivateCommit { x, r } = private_commit; + let &PublicCommit { l, v_hat } = public.1; + let PublicParams { g_lists } = public.0; + let n = g_lists.message_len; + + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + + let g = G::G1::GENERATOR; + let g_hat = G::G2::GENERATOR; + let gamma = G::Zp::rand(rng); + let gamma_y = G::Zp::rand(rng); + + let mut x_bits = vec![0u64; n]; + for (i, x_bits) in x_bits[0..l].iter_mut().enumerate() { + *x_bits = (x >> i) & 1; + } + let x_bits = OneBased(x_bits); + + let c_hat = { + let mut c = g_hat.mul_scalar(gamma); + for j in 1..l + 1 { + let term = if x_bits[j] != 0 { + g_hat_list[j] + } else { + G::G2::ZERO + }; + c += term; + } + c + }; + + let mut proof_x = -g_list[n].mul_scalar(r); + for i in 1..l + 1 { + let mut term = g_list[n + 1 - i].mul_scalar(gamma); + for j in 1..l + 1 { + if j != i { + let term_inner = if x_bits[j] != 0 { + g_list[n + 1 - i + j] + } else { + G::G1::ZERO + }; + term += term_inner; + } + } + + for _ in 1..i { + term = term.double(); + } + proof_x += term; + } + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut y, + &[v_hat.to_bytes().as_ref(), c_hat.to_bytes().as_ref()], + ); + let y = OneBased(y); + let mut c_y = g.mul_scalar(gamma_y); + for j in 1..l + 1 { + c_y += g_list[n + 1 - j].mul_scalar(y[j] * G::Zp::from_u64(x_bits[j])); + } + + let y_bytes = &*(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(); + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + y_bytes, + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut proof_eq = G::G1::ZERO; + for i in 1..n + 1 { + let mut numerator = g_list[n + 1 - i].mul_scalar(gamma); + for j in 1..n + 1 { + if j != i { + let term = if x_bits[j] != 0 { + g_list[n + 1 - i + j] + } else { + G::G1::ZERO + }; + numerator += term; + } + } + numerator = numerator.mul_scalar(t[i] * y[i]); + + let mut denominator = g_list[i].mul_scalar(gamma_y); + for j in 1..n + 1 { + if j != i { + denominator += g_list[n + 1 - j + i].mul_scalar(y[j] * G::Zp::from_u64(x_bits[j])); + } + } + denominator = denominator.mul_scalar(t[i]); + + proof_eq += numerator - denominator; + } + + let mut proof_y = g.mul_scalar(gamma_y); + for j in 1..n + 1 { + proof_y -= g_list[n + 1 - j].mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j])); + } + proof_y = proof_y.mul_scalar(gamma); + for i in 1..n + 1 { + let mut term = g_list[i].mul_scalar(gamma_y); + for j in 1..n + 1 { + if j != i { + term -= g_list[n + 1 - j + i].mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j])); + } + } + let term = if x_bits[i] != 0 { term } else { G::G1::ZERO }; + proof_y += term; + } + + let mut s = vec![G::Zp::ZERO; n]; + for (i, s) in s.iter_mut().enumerate() { + G::Zp::hash( + core::slice::from_mut(s), + &[ + &i.to_le_bytes(), + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + } + let s = OneBased(s); + + let mut proof_v = G::G1::ZERO; + for i in 2..n + 1 { + proof_v += G::G1::mul_scalar( + g_list[n + 1 - i].mul_scalar(r) + g_list[n + 2 - i].mul_scalar(G::Zp::from_u64(x)), + s[i], + ); + } + + let mut delta = [G::Zp::ZERO; 4]; + G::Zp::hash( + &mut delta, + &[ + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let [delta_x, delta_eq, delta_y, delta_v] = delta; + + let proof = proof_x.mul_scalar(delta_x) + + proof_eq.mul_scalar(delta_eq) + + proof_y.mul_scalar(delta_y) + + proof_v.mul_scalar(delta_v); + + Proof { + c_y, + c_hat, + pi: proof, + } +} + +#[allow(clippy::result_unit_err)] +pub fn verify( + proof: &Proof, + public: (&PublicParams, &PublicCommit), +) -> Result<(), ()> { + let e = G::Gt::pairing; + let &PublicCommit { l, v_hat } = public.1; + let PublicParams { g_lists } = public.0; + let n = g_lists.message_len; + + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + + let g_hat = G::G2::GENERATOR; + + let &Proof { c_y, c_hat, pi } = proof; + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut y, + &[v_hat.to_bytes().as_ref(), c_hat.to_bytes().as_ref()], + ); + let y = OneBased(y); + + let y_bytes = &*(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(); + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + y_bytes, + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut delta = [G::Zp::ZERO; 4]; + G::Zp::hash( + &mut delta, + &[ + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let [delta_x, delta_eq, delta_y, delta_v] = delta; + + let mut s = vec![G::Zp::ZERO; n]; + for (i, s) in s.iter_mut().enumerate() { + G::Zp::hash( + core::slice::from_mut(s), + &[ + &i.to_le_bytes(), + v_hat.to_bytes().as_ref(), + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + } + let s = OneBased(s); + + let rhs = e(pi, g_hat); + let lhs = { + let numerator = { + let mut p = c_y.mul_scalar(delta_y); + for i in 1..n + 1 { + let g = g_list[n + 1 - i]; + if i <= l { + p += g.mul_scalar(delta_x * G::Zp::from_u64(1 << (i - 1))); + } + p += g.mul_scalar((delta_eq * t[i] - delta_y) * y[i]); + } + e(p, c_hat) + }; + let denominator_0 = { + let mut p = g_list[n].mul_scalar(delta_x); + for i in 2..n + 1 { + p -= g_list[n + 1 - i].mul_scalar(delta_v * s[i]); + } + e(p, v_hat) + }; + let denominator_1 = { + let mut q = G::G2::ZERO; + for i in 1..n + 1 { + q += g_hat_list[i].mul_scalar(delta_eq * t[i]); + } + e(c_y, q) + }; + numerator - denominator_0 - denominator_1 + }; + + if lhs == rhs { + Ok(()) + } else { + Err(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + + #[test] + fn test_range() { + let rng = &mut StdRng::seed_from_u64(0); + + let max_nbits = 10; + let l = 6; + let x = rng.gen::() % (1 << l); + let public_params = crs_gen::(max_nbits, rng); + let (public_commit, private_commit) = commit(x, l, &public_params, rng); + let proof = prove((&public_params, &public_commit), &private_commit, rng); + let verify = verify(&proof, (&public_params, &public_commit)); + assert!(verify.is_ok()); + } +} diff --git a/tfhe-zk-pok/src/proofs/rlwe.rs b/tfhe-zk-pok/src/proofs/rlwe.rs new file mode 100644 index 0000000000..5dc2eb345b --- /dev/null +++ b/tfhe-zk-pok/src/proofs/rlwe.rs @@ -0,0 +1,932 @@ +use super::*; +use core::iter::zip; +use core::marker::PhantomData; +use rayon::prelude::*; + +fn bit_iter(x: u64, nbits: u32) -> impl Iterator { + (0..nbits).map(move |idx| ((x >> idx) & 1) == 1) +} + +#[derive(Clone, Debug)] +pub struct PublicParams { + g_lists: GroupElements, + d: usize, + big_n: usize, + big_m: usize, + b_i: u64, + q: u64, +} + +impl PublicParams { + pub fn from_vec( + g_list: Vec, + g_hat_list: Vec, + d: usize, + big_n: usize, + big_m: usize, + b_i: u64, + q: u64, + ) -> Self { + Self { + g_lists: GroupElements::from_vec(g_list, g_hat_list), + d, + big_n, + big_m, + b_i, + q, + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub struct PrivateParams { + alpha: G::Zp, +} + +#[derive(Clone, Debug)] +pub struct PublicCommit { + a: Matrix, + c: Vector, + __marker: PhantomData, +} + +#[derive(Clone, Debug)] +pub struct PrivateCommit { + s: Vector, + __marker: PhantomData, +} + +#[derive(Clone, Debug)] +pub struct Proof { + c_hat: G::G2, + c_y: G::G1, + pi: G::G1, + c_hat_t: Option, + c_h: Option, + pi_kzg: Option, +} + +pub fn crs_gen( + d: usize, + big_n: usize, + big_m: usize, + b_i: u64, + q: u64, + rng: &mut dyn RngCore, +) -> PublicParams { + let alpha = G::Zp::rand(rng); + let b_r = ((d * big_m) as u64 * b_i) / 2; + let big_d = d * (big_m * (1 + b_i.ilog2() as usize) + (big_n * (1 + b_r.ilog2() as usize))); + let n = big_d + 1; + PublicParams { + g_lists: GroupElements::new(n, alpha), + d, + big_n, + big_m, + b_i, + q, + } +} + +#[derive(Clone, Debug)] +pub struct Vector { + pub data: Vec, + pub polynomial_size: usize, + pub nrows: usize, +} +#[derive(Clone, Debug)] +pub struct Matrix { + pub data: Vec, + pub polynomial_size: usize, + pub nrows: usize, + pub ncols: usize, +} + +impl Matrix { + pub fn new(polynomial_size: usize, nrows: usize, ncols: usize, value: T) -> Self { + Self { + data: vec![value; polynomial_size * nrows * ncols], + polynomial_size, + nrows, + ncols, + } + } +} +impl Vector { + pub fn new(polynomial_size: usize, nrows: usize, value: T) -> Self { + Self { + data: vec![value; polynomial_size * nrows], + polynomial_size, + nrows, + } + } +} + +impl Index for Vector { + type Output = [T]; + + fn index(&self, row: usize) -> &Self::Output { + let row = row - 1; + &self.data[self.polynomial_size * row..][..self.polynomial_size] + } +} +impl IndexMut for Vector { + fn index_mut(&mut self, row: usize) -> &mut Self::Output { + let row = row - 1; + &mut self.data[self.polynomial_size * row..][..self.polynomial_size] + } +} + +impl Index<(usize, usize)> for Matrix { + type Output = [T]; + + fn index(&self, (row, col): (usize, usize)) -> &Self::Output { + let row = row - 1; + let col = col - 1; + &self.data[self.polynomial_size * (row * self.ncols + col)..][..self.polynomial_size] + } +} +impl IndexMut<(usize, usize)> for Matrix { + fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output { + let row = row - 1; + let col = col - 1; + &mut self.data[self.polynomial_size * (row * self.ncols + col)..][..self.polynomial_size] + } +} + +pub fn commit( + a: Matrix, + c: Vector, + s: Vector, + public: &PublicParams, + rng: &mut dyn RngCore, +) -> (PublicCommit, PrivateCommit) { + let _ = (public, rng); + ( + PublicCommit { + a, + c, + __marker: PhantomData, + }, + PrivateCommit { + s, + __marker: PhantomData, + }, + ) +} + +pub fn prove( + public: (&PublicParams, &PublicCommit), + private_commit: &PrivateCommit, + load: ComputeLoad, + rng: &mut dyn RngCore, +) -> Proof { + let &PublicParams { + ref g_lists, + d, + big_n, + big_m, + b_i, + q, + } = public.0; + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + let s = &private_commit.s; + let a = &public.1.a; + + let b_r = ((d * big_m) as u64 * b_i) / 2; + let big_d = d * (big_m * (1 + b_i.ilog2() as usize) + (big_n * (1 + b_r.ilog2() as usize))); + let n = big_d + 1; + + let g = G::G1::GENERATOR; + let g_hat = G::G2::GENERATOR; + let gamma = G::Zp::rand(rng); + let gamma_y = G::Zp::rand(rng); + + let mut c = Vector { + data: vec![0i64; d], + polynomial_size: d, + nrows: big_n, + }; + let mut r = Vector { + data: vec![0i64; d], + polynomial_size: d, + nrows: big_n, + }; + + for j in 1..big_n + 1 { + let c = &mut c[j]; + let r = &mut r[j]; + + let mut polymul = vec![0i128; d]; + for i in 1..big_m + 1 { + let si = &s[i]; + let aij = &a[(i, j)]; + + for ii in 0..d { + for jj in 0..d { + let p = (aij[ii] as i128) * si[jj] as i128; + if ii + jj < d { + polymul[ii + jj] += p; + } else { + polymul[ii + jj - d] -= p; + } + } + } + } + + for ((ck, rk), old_ck) in zip(zip(c, r), &polymul) { + let q = if q == 0 { q as i128 } else { 1i128 << 64 }; + let mut new_ck = old_ck.rem_euclid(q); + if new_ck >= q / 2 { + new_ck -= q; + } + assert!((old_ck - new_ck) % q == 0); + assert!((*rk).unsigned_abs() < b_r); + + *ck = new_ck as i64; + *rk = ((old_ck - new_ck) / q) as i64; + } + } + let w_tilde = Iterator::chain( + (1..big_m + 1).flat_map(|i| { + s[i].iter() + .copied() + .flat_map(|x| bit_iter(x as u64, b_i.ilog2() + 1)) + }), + (1..big_n + 1).flat_map(|i| { + r[i].iter() + .copied() + .flat_map(|x| bit_iter(x as u64, b_r.ilog2() + 1)) + }), + ) + .collect::>(); + let mut w = vec![false; n].into_boxed_slice(); + w[..big_d].copy_from_slice(&w_tilde); + let w = OneBased::new_ref(&*w); + + let mut c_hat = g_hat.mul_scalar(gamma); + for j in 1..big_d + 1 { + let term = if w[j] { g_hat_list[j] } else { G::G2::ZERO }; + c_hat += term; + } + + let x_bytes = &*[ + &q.to_le_bytes(), + &(d as u64).to_le_bytes(), + &(big_m as u64).to_le_bytes(), + &(big_n as u64).to_le_bytes(), + &b_i.to_le_bytes(), + &*(1..big_m + 1) + .flat_map(|i| { + (1..big_n + 1).flat_map(move |j| a[(i, j)].iter().flat_map(|ai| ai.to_le_bytes())) + }) + .collect::>(), + &(1..big_n + 1) + .flat_map(|j| c[j].iter().flat_map(|ci| ci.to_le_bytes())) + .collect::>(), + ] + .iter() + .copied() + .flatten() + .copied() + .collect::>(); + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash(&mut y, &[x_bytes, c_hat.to_bytes().as_ref()]); + let y = OneBased(y); + + let scalars = (n + 1 - big_d..n + 1) + .map(|j| (y[n + 1 - j] * G::Zp::from_u64(w[n + 1 - j] as u64))) + .collect::>(); + let c_y = g.mul_scalar(gamma_y) + G::G1::multi_mul_scalar(&g_list.0[n - big_d..n], &scalars); + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + &(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(), + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut theta_bar = vec![G::Zp::ZERO; big_n * d + 1]; + G::Zp::hash( + &mut theta_bar, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let theta = (0..big_n * d + 1).map(|k| theta_bar[k]).collect::>(); + let theta0 = theta[..big_n * d].to_vec().into_boxed_slice(); + let delta_theta = theta[big_n * d]; + + let mut t_theta = G::Zp::ZERO; + for j in 0..big_n { + let cj = &c[j + 1]; + let theta0j = &theta0[j * d..][..d]; + for k in 0..d { + t_theta += theta0j[k] * G::Zp::from_i64(cj[k]); + } + } + + let mut a_theta = vec![G::Zp::ZERO; big_d]; + let b_step = 1 + b_i.ilog2() as usize; + let step = d * b_step; + for i in 0..big_m { + // a_theta_i = A_tilde_{i + 1}.T × theta0 + let a_theta_i = &mut a_theta[step * i..][..step]; + + for j in 0..big_n { + let aij = &a[(i + 1, j + 1)]; + let theta0_j = &theta0[d * j..][..d]; + + let mut rot_aij_theta0_j = vec![G::Zp::ZERO; d]; + for p in 0..d { + let mut dot = G::Zp::ZERO; + + for q in 0..d { + let a = if p <= q { + G::Zp::from_i64(aij[q - p]) + } else { + -G::Zp::from_i64(aij[d + q - p]) + }; + dot += a * theta0_j[q]; + } + + rot_aij_theta0_j[p] = dot; + } + + for k in 0..b_step { + let a_theta_ik = &mut a_theta_i[k..]; + let mut c = G::Zp::from_u64(1 << k); + if k + 1 == b_step { + c = -c; + } + + for (dst, src) in zip(a_theta_ik.iter_mut().step_by(b_step), &rot_aij_theta0_j) { + *dst = c * *src; + } + } + } + } + + let offset_m = step * big_m; + let b_step = 1 + b_r.ilog2() as usize; + let step = d * b_step; + for j in 0..big_n { + // a_theta_j -= q G.T theta0_j + let a_theta_j = &mut a_theta[offset_m + step * j..][..step]; + let theta0_j = &theta0[d * j..][..d]; + + for k in 0..b_step { + let a_theta_jk = &mut a_theta_j[k..]; + let mut c = -G::Zp::from_u64(1 << k) * G::Zp::from_u64(q); + if k + 1 == b_step { + c = -c; + } + for (dst, src) in zip(a_theta_jk.iter_mut().step_by(b_step), theta0_j) { + *dst = c * *src; + } + } + } + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + let mut poly_0 = vec![G::Zp::ZERO; n + 1]; + let mut poly_1 = vec![G::Zp::ZERO; n + 1]; + let mut poly_2 = vec![G::Zp::ZERO; n + 1]; + let mut poly_3 = vec![G::Zp::ZERO; n + 1]; + + poly_0[0] = delta_y * gamma_y; + for i in 1..n + 1 { + poly_0[n + 1 - i] = + delta_y * (y[i] * G::Zp::from_u64(w[i] as u64)) + (delta_eq * t[i] - delta_y) * y[i]; + + if i < big_d + 1 { + poly_0[n + 1 - i] += delta_theta * a_theta[i - 1]; + } + } + + poly_1[0] = gamma; + for i in 1..big_d + 1 { + poly_1[i] = G::Zp::from_u64(w[i] as u64); + } + + poly_2[0] = gamma_y; + for i in 1..big_d + 1 { + poly_2[n + 1 - i] = y[i] * G::Zp::from_u64(w[i] as u64); + } + + for i in 1..n + 1 { + poly_3[i] = delta_eq * t[i]; + } + + let mut poly = G::Zp::poly_sub( + &G::Zp::poly_mul(&poly_0, &poly_1), + &G::Zp::poly_mul(&poly_2, &poly_3), + ); + + if poly.len() > n + 1 { + poly[n + 1] -= t_theta * delta_theta; + } + + let pi = + g.mul_scalar(poly[0]) + G::G1::multi_mul_scalar(&g_list.0[..poly.len() - 1], &poly[1..]); + + if load == ComputeLoad::Proof { + let c_hat_t = G::G2::multi_mul_scalar(&g_hat_list.0, &t.0); + let scalars = (1..n + 1) + .into_par_iter() + .map(|i| { + let i = n + 1 - i; + (delta_eq * t[i] - delta_y) * y[i] + + if i < big_d + 1 { + delta_theta * a_theta[i - 1] + } else { + G::Zp::ZERO + } + }) + .collect::>(); + let c_h = G::G1::multi_mul_scalar(&g_list.0[..n], &scalars); + + let mut z = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut z), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + ], + ); + + let mut pow = z; + let mut p_t = G::Zp::ZERO; + let mut p_h = G::Zp::ZERO; + + for i in 1..n + 1 { + p_t += t[i] * pow; + if n - i < big_d { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i] + + delta_theta * a_theta[n - i]) + * pow; + } else { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]) * pow; + } + pow = pow * z; + } + + let mut w = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut w), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + z.to_bytes().as_ref(), + p_h.to_bytes().as_ref(), + p_t.to_bytes().as_ref(), + ], + ); + + let mut poly = vec![G::Zp::ZERO; n + 1]; + for i in 1..n + 1 { + poly[i] += w * t[i]; + if i < big_d + 1 { + poly[n + 1 - i] += + (delta_eq * t[i] - delta_y) * y[i] + delta_theta * a_theta[i - 1]; + } else { + poly[n + 1 - i] += (delta_eq * t[i] - delta_y) * y[i]; + } + } + + let mut q = vec![G::Zp::ZERO; n]; + for i in (0..n).rev() { + poly[i] = poly[i] + z * poly[i + 1]; + q[i] = poly[i + 1]; + poly[i + 1] = G::Zp::ZERO; + } + let pi_kzg = g.mul_scalar(q[0]) + G::G1::multi_mul_scalar(&g_list.0[..n - 1], &q[1..n]); + + Proof { + c_hat, + c_y, + pi, + c_hat_t: Some(c_hat_t), + c_h: Some(c_h), + pi_kzg: Some(pi_kzg), + } + } else { + Proof { + c_hat, + c_y, + pi, + c_hat_t: None, + c_h: None, + pi_kzg: None, + } + } +} + +#[allow(clippy::result_unit_err)] +pub fn verify( + proof: &Proof, + public: (&PublicParams, &PublicCommit), +) -> Result<(), ()> { + let &Proof { + c_hat, + c_y, + pi, + c_hat_t, + c_h, + pi_kzg, + } = proof; + let e = G::Gt::pairing; + + let &PublicParams { + ref g_lists, + d, + big_n, + big_m, + b_i, + q, + } = public.0; + let g_list = &g_lists.g_list; + let g_hat_list = &g_lists.g_hat_list; + + let b_r = ((d * big_m) as u64 * b_i) / 2; + let big_d = d * (big_m * (1 + b_i.ilog2() as usize) + (big_n * (1 + b_r.ilog2() as usize))); + let n = big_d + 1; + + let a = &public.1.a; + let c = &public.1.c; + + let x_bytes = &*[ + &q.to_le_bytes(), + &(d as u64).to_le_bytes(), + &(big_m as u64).to_le_bytes(), + &(big_n as u64).to_le_bytes(), + &b_i.to_le_bytes(), + &*(1..big_m + 1) + .flat_map(|i| { + (1..big_n + 1).flat_map(move |j| a[(i, j)].iter().flat_map(|ai| ai.to_le_bytes())) + }) + .collect::>(), + &(1..big_n + 1) + .flat_map(|j| c[j].iter().flat_map(|ci| ci.to_le_bytes())) + .collect::>(), + ] + .iter() + .copied() + .flatten() + .copied() + .collect::>(); + + let mut delta = [G::Zp::ZERO; 2]; + G::Zp::hash( + &mut delta, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let [delta_eq, delta_y] = delta; + + let mut y = vec![G::Zp::ZERO; n]; + G::Zp::hash(&mut y, &[x_bytes, c_hat.to_bytes().as_ref()]); + let y = OneBased(y); + + let mut t = vec![G::Zp::ZERO; n]; + G::Zp::hash( + &mut t, + &[ + &(1..n + 1) + .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .collect::>(), + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + ], + ); + let t = OneBased(t); + + let mut theta_bar = vec![G::Zp::ZERO; big_n * d + 1]; + G::Zp::hash( + &mut theta_bar, + &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], + ); + let theta = (0..big_n * d + 1).map(|k| theta_bar[k]).collect::>(); + let theta0 = theta[..big_n * d].to_vec().into_boxed_slice(); + let delta_theta = theta[big_n * d]; + + let mut t_theta = G::Zp::ZERO; + for j in 0..big_n { + let cj = &c[j + 1]; + let theta0j = &theta0[j * d..][..d]; + for k in 0..d { + t_theta += theta0j[k] * G::Zp::from_i64(cj[k]); + } + } + + let mut a_theta = vec![G::Zp::ZERO; big_d]; + let b_step = 1 + b_i.ilog2() as usize; + let step = d * b_step; + for i in 0..big_m { + // a_theta_i = A_tilde_{i + 1}.T × theta0 + let a_theta_i = &mut a_theta[step * i..][..step]; + + for j in 0..big_n { + let aij = &a[(i + 1, j + 1)]; + let theta0_j = &theta0[d * j..][..d]; + + let mut rot_aij_theta0_j = vec![G::Zp::ZERO; d]; + for p in 0..d { + let mut dot = G::Zp::ZERO; + + for q in 0..d { + let a = if p <= q { + G::Zp::from_i64(aij[q - p]) + } else { + -G::Zp::from_i64(aij[d + q - p]) + }; + dot += a * theta0_j[q]; + } + + rot_aij_theta0_j[p] = dot; + } + + for k in 0..b_step { + let a_theta_ik = &mut a_theta_i[k..]; + let mut c = G::Zp::from_u64(1 << k); + if k + 1 == b_step { + c = -c; + } + + for (dst, src) in zip(a_theta_ik.iter_mut().step_by(b_step), &rot_aij_theta0_j) { + *dst = c * *src; + } + } + } + } + + let offset_m = step * big_m; + let b_step = 1 + b_r.ilog2() as usize; + let step = d * b_step; + for j in 0..big_n { + // a_theta_j -= q G.T theta0_j + let a_theta_j = &mut a_theta[offset_m + step * j..][..step]; + let theta0_j = &theta0[d * j..][..d]; + + for k in 0..b_step { + let a_theta_jk = &mut a_theta_j[k..]; + let mut c = -G::Zp::from_u64(1 << k) * G::Zp::from_u64(q); + if k + 1 == b_step { + c = -c; + } + for (dst, src) in zip(a_theta_jk.iter_mut().step_by(b_step), theta0_j) { + *dst = c * *src; + } + } + } + + if let (Some(pi_kzg), Some(c_hat_t), Some(c_h)) = (pi_kzg, c_hat_t, c_h) { + let mut z = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut z), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + ], + ); + + let mut pow = z; + let mut p_t = G::Zp::ZERO; + let mut p_h = G::Zp::ZERO; + + for i in 1..n + 1 { + p_t += t[i] * pow; + if n - i < big_d { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i] + + delta_theta * a_theta[n - i]) + * pow; + } else { + p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]) * pow; + } + pow = pow * z; + } + + if e(pi, G::G2::GENERATOR) + != e(c_y.mul_scalar(delta_y) + c_h, c_hat) + - e(c_y.mul_scalar(delta_eq), c_hat_t) + - e(g_list[1], g_hat_list[n]).mul_scalar(t_theta * delta_theta) + { + return Err(()); + } + + let mut w = G::Zp::ZERO; + G::Zp::hash( + core::array::from_mut(&mut w), + &[ + x_bytes, + c_hat.to_bytes().as_ref(), + c_y.to_bytes().as_ref(), + pi.to_bytes().as_ref(), + c_h.to_bytes().as_ref(), + c_hat_t.to_bytes().as_ref(), + &y.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &t.0.iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + &delta + .iter() + .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .collect::>(), + z.to_bytes().as_ref(), + p_h.to_bytes().as_ref(), + p_t.to_bytes().as_ref(), + ], + ); + + if e(c_h - G::G1::GENERATOR.mul_scalar(p_h), G::G2::GENERATOR) + + e(G::G1::GENERATOR, c_hat_t - G::G2::GENERATOR.mul_scalar(p_t)).mul_scalar(w) + == e(pi_kzg, g_hat_list[1] - G::G2::GENERATOR.mul_scalar(z)) + { + Ok(()) + } else { + Err(()) + } + } else { + let (term0, term1) = rayon::join( + || { + let p = c_y.mul_scalar(delta_y) + + (1..n + 1) + .into_par_iter() + .map(|i| { + let mut factor = (delta_eq * t[i] - delta_y) * y[i]; + if i < big_d + 1 { + factor += delta_theta * a_theta[i - 1]; + } + g_list[n + 1 - i].mul_scalar(factor) + }) + .sum::(); + let q = c_hat; + e(p, q) + }, + || { + let p = c_y; + let q = (1..n + 1) + .into_par_iter() + .map(|i| g_hat_list[i].mul_scalar(delta_eq * t[i])) + .sum::(); + e(p, q) + }, + ); + let term2 = { + let p = g_list[1]; + let q = g_hat_list[n]; + e(p, q) + }; + + let lhs = e(pi, G::G2::GENERATOR); + let rhs = term0 - term1 - term2.mul_scalar(t_theta * delta_theta); + + if lhs == rhs { + Ok(()) + } else { + Err(()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + + fn time(f: impl FnOnce() -> R) -> R { + let time = std::time::Instant::now(); + let r = f(); + println!("{:?}", time.elapsed()); + r + } + + #[test] + fn test_rlwe() { + let rng = &mut StdRng::seed_from_u64(0); + let d: usize = 2048; + let big_m: usize = 1; + let big_n: usize = 1; + + let q = 1217; + let b_i: u64 = 512; + + let mut a = Matrix::new(d, big_m, big_n, 0i64); + let mut c = Vector::new(d, big_n, 0i64); + let mut s = Vector::new(d, big_m, 0i64); + + for i in 0..big_m { + for k in 0..d { + s[i + 1][k] = (rng.gen::() % (2 * b_i)) as i64 - b_i as i64; + } + } + + for i in 0..big_m { + for j in 0..big_n { + for k in 0..d { + let mut x = (rng.gen::() % q) as i64; + if x >= q as i64 / 2 { + x -= q as i64; + } + a[(i + 1, j + 1)][k] = x; + } + } + } + + for j in 1..big_n + 1 { + let c = &mut c[j]; + + let mut polymul = vec![0i128; d]; + for i in 1..big_m + 1 { + let si = &s[i]; + let aij = &a[(i, j)]; + + for ii in 0..d { + for jj in 0..d { + let p = (aij[ii] as i128) * si[jj] as i128; + if ii + jj < d { + polymul[ii + jj] += p; + } else { + polymul[ii + jj - d] -= p; + } + } + } + } + + for (ck, old_ck) in core::iter::zip(c, &polymul) { + let q = if q == 0 { q as i128 } else { 1i128 << 64 }; + let mut new_ck = old_ck.rem_euclid(q); + if new_ck >= q / 2 { + new_ck -= q; + } + *ck = new_ck as i64; + } + } + + let public_params = crs_gen::(d, big_n, big_m, b_i, q, rng); + let (public_commit, private_commit) = commit(a, c, s, &public_params, rng); + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = + time(|| prove((&public_params, &public_commit), &private_commit, load, rng)); + let verify = time(|| verify(&proof, (&public_params, &public_commit))); + assert!(verify.is_ok()); + } + } +} From 71515501218116faa5f1bb666017640ec2ca3956 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Mon, 8 Apr 2024 10:00:36 +0200 Subject: [PATCH 2/2] feat(tfhe): plug zk-pok into all layers --- Makefile | 32 +- scripts/integer-tests.sh | 2 +- scripts/shortint-tests.sh | 2 +- tfhe-zk-pok/src/proofs/pke.rs | 36 +- tfhe/Cargo.toml | 3 + tfhe/build.rs | 2 + tfhe/c_api_tests/test_high_level_zk.c | 109 ++ tfhe/docs/guides/public_key.md | 2 + tfhe/docs/guides/zk-pok.md | 69 + tfhe/js_on_wasm_tests/test-hlapi-unsigned.js | 56 +- tfhe/src/c_api/high_level_api/booleans.rs | 119 ++ tfhe/src/c_api/high_level_api/integers.rs | 131 +- tfhe/src/c_api/high_level_api/mod.rs | 2 + tfhe/src/c_api/high_level_api/zk.rs | 59 + .../core_crypto/algorithms/lwe_encryption.rs | 1272 ++++++++++++++--- .../lwe_zero_knowledge_verification.rs | 102 ++ tfhe/src/core_crypto/algorithms/mod.rs | 4 + .../algorithms/test/lwe_encryption.rs | 355 ++++- .../commons/math/random/gaussian.rs | 1 - .../commons/math/random/generator.rs | 25 + .../core_crypto/commons/math/random/mod.rs | 78 + .../commons/math/random/t_uniform.rs | 1 - .../src/core_crypto/commons/numeric/signed.rs | 9 + .../core_crypto/commons/numeric/unsigned.rs | 6 + tfhe/src/core_crypto/entities/mod.rs | 2 + tfhe/src/error.rs | 59 + tfhe/src/high_level_api/booleans/compact.rs | 6 +- .../src/high_level_api/booleans/compressed.rs | 2 +- tfhe/src/high_level_api/booleans/encrypt.rs | 10 +- tfhe/src/high_level_api/booleans/mod.rs | 4 + tfhe/src/high_level_api/booleans/tests.rs | 52 +- tfhe/src/high_level_api/booleans/zk.rs | 134 ++ tfhe/src/high_level_api/config.rs | 2 +- tfhe/src/high_level_api/errors.rs | 40 - .../high_level_api/integers/signed/compact.rs | 4 +- .../integers/signed/compressed.rs | 2 +- .../high_level_api/integers/signed/encrypt.rs | 10 +- .../src/high_level_api/integers/signed/mod.rs | 2 + .../high_level_api/integers/signed/static_.rs | 9 + .../high_level_api/integers/signed/tests.rs | 47 + tfhe/src/high_level_api/integers/signed/zk.rs | 140 ++ .../integers/unsigned/compact.rs | 4 +- .../integers/unsigned/compressed.rs | 2 +- .../integers/unsigned/encrypt.rs | 10 +- .../high_level_api/integers/unsigned/mod.rs | 2 + .../integers/unsigned/static_.rs | 8 + .../integers/unsigned/tests/cpu.rs | 48 +- .../high_level_api/integers/unsigned/zk.rs | 140 ++ tfhe/src/high_level_api/mod.rs | 15 +- tfhe/src/high_level_api/tests/mod.rs | 6 +- tfhe/src/high_level_api/zk.rs | 11 + tfhe/src/integer/block_decomposition.rs | 30 +- tfhe/src/integer/mod.rs | 9 +- .../server_key/radix_parallel/ilog2.rs | 1 + tfhe/src/integer/zk.rs | 139 ++ .../js_high_level_api/integers.rs | 453 +++++- .../js_on_wasm_api/js_high_level_api/keys.rs | 4 +- .../js_on_wasm_api/js_high_level_api/mod.rs | 2 + .../js_on_wasm_api/js_high_level_api/zk.rs | 76 + tfhe/src/js_on_wasm_api/shortint.rs | 134 +- tfhe/src/lib.rs | 10 +- tfhe/src/shortint/ciphertext/mod.rs | 5 + tfhe/src/shortint/ciphertext/zk.rs | 194 +++ tfhe/src/shortint/engine/mod.rs | 6 + tfhe/src/shortint/public_key/compact.rs | 161 ++- tfhe/src/test_user_docs.rs | 1 + tfhe/src/zk.rs | 140 ++ tfhe/web_wasm_parallel_tests/index.html | 13 +- tfhe/web_wasm_parallel_tests/index.js | 2 + tfhe/web_wasm_parallel_tests/jest.config.js | 2 +- tfhe/web_wasm_parallel_tests/test/common.mjs | 6 +- .../test/compact-public-key.test.js | 8 + tfhe/web_wasm_parallel_tests/worker.js | 172 ++- 73 files changed, 4485 insertions(+), 331 deletions(-) create mode 100644 tfhe/c_api_tests/test_high_level_zk.c create mode 100644 tfhe/docs/guides/zk-pok.md create mode 100644 tfhe/src/c_api/high_level_api/zk.rs create mode 100644 tfhe/src/core_crypto/algorithms/lwe_zero_knowledge_verification.rs create mode 100644 tfhe/src/error.rs create mode 100644 tfhe/src/high_level_api/booleans/zk.rs create mode 100644 tfhe/src/high_level_api/integers/signed/zk.rs create mode 100644 tfhe/src/high_level_api/integers/unsigned/zk.rs create mode 100644 tfhe/src/high_level_api/zk.rs create mode 100644 tfhe/src/integer/zk.rs create mode 100644 tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs create mode 100644 tfhe/src/shortint/ciphertext/zk.rs create mode 100644 tfhe/src/zk.rs diff --git a/Makefile b/Makefile index 9976c821fa..85ecf47181 100644 --- a/Makefile +++ b/Makefile @@ -175,10 +175,18 @@ fmt_gpu: install_rs_check_toolchain cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" fmt cd "$(TFHECUDA_SRC)" && ./format_tfhe_cuda_backend.sh +.PHONY: fmt_c_tests # Format c tests +fmt_c_tests: + find tfhe/c_api_tests/ -regex '.*\.\(cpp\|hpp\|cu\|c\|h\)' -exec clang-format -style=file -i {} \; + .PHONY: check_fmt # Check rust code format check_fmt: install_rs_check_toolchain cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" fmt --check +.PHONY: check_fmt_c_tests # Check C tests format +check_fmt_c_tests: + find tfhe/c_api_tests/ -regex '.*\.\(cpp\|hpp\|cu\|c\|h\)' -exec clang-format --dry-run --Werror -style=file {} \; + .PHONY: check_fmt_gpu # Check rust and cuda code format check_fmt_gpu: install_rs_check_toolchain cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" fmt --check @@ -274,7 +282,7 @@ clippy_trivium: install_rs_check_toolchain .PHONY: clippy_all_targets # Run clippy lints on all targets (benches, examples, etc.) clippy_all_targets: RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache,zk-pok-experimental \ -p $(TFHE_SPEC) -- --no-deps -D warnings .PHONY: clippy_concrete_csprng # Run clippy lints on concrete-csprng @@ -353,14 +361,14 @@ symlink_c_libs_without_fingerprint: .PHONY: build_c_api # Build the C API for boolean, shortint and integer build_c_api: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) build --profile $(CARGO_PROFILE) \ - --features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api,$(FORWARD_COMPAT_FEATURE) \ + --features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api,zk-pok-experimental,$(FORWARD_COMPAT_FEATURE) \ -p $(TFHE_SPEC) @"$(MAKE)" symlink_c_libs_without_fingerprint .PHONY: build_c_api_gpu # Build the C API for boolean, shortint and integer build_c_api_gpu: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) build --profile $(CARGO_PROFILE) \ - --features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api,gpu \ + --features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api,zk-pok-experimental,gpu \ -p $(TFHE_SPEC) @"$(MAKE)" symlink_c_libs_without_fingerprint @@ -376,7 +384,7 @@ build_web_js_api: install_rs_build_toolchain install_wasm_pack cd tfhe && \ RUSTFLAGS="$(WASM_RUSTFLAGS)" rustup run "$(RS_BUILD_TOOLCHAIN)" \ wasm-pack build --release --target=web \ - -- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api + -- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,zk-pok-experimental .PHONY: build_web_js_api_parallel # Build the js API targeting the web browser with parallelism support build_web_js_api_parallel: install_rs_check_toolchain install_wasm_pack @@ -384,7 +392,7 @@ build_web_js_api_parallel: install_rs_check_toolchain install_wasm_pack rustup component add rust-src --toolchain $(RS_CHECK_TOOLCHAIN) && \ RUSTFLAGS="$(WASM_RUSTFLAGS) -C target-feature=+atomics,+bulk-memory,+mutable-globals" rustup run $(RS_CHECK_TOOLCHAIN) \ wasm-pack build --release --target=web \ - -- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,parallel-wasm-api \ + -- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,parallel-wasm-api,zk-pok-experimental \ -Z build-std=panic_abort,std .PHONY: build_node_js_api # Build the js API targeting nodejs @@ -392,7 +400,7 @@ build_node_js_api: install_rs_build_toolchain install_wasm_pack cd tfhe && \ RUSTFLAGS="$(WASM_RUSTFLAGS)" rustup run "$(RS_BUILD_TOOLCHAIN)" \ wasm-pack build --release --target=nodejs \ - -- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api + -- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,zk-pok-experimental .PHONY: build_concrete_csprng # Build concrete_csprng build_concrete_csprng: install_rs_build_toolchain @@ -402,10 +410,10 @@ build_concrete_csprng: install_rs_build_toolchain .PHONY: test_core_crypto # Run the tests of the core_crypto module including experimental ones test_core_crypto: install_rs_build_toolchain install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ - --features=$(TARGET_ARCH_FEATURE),experimental -p $(TFHE_SPEC) -- core_crypto:: + --features=$(TARGET_ARCH_FEATURE),experimental,zk-pok-experimental -p $(TFHE_SPEC) -- core_crypto:: @if [[ "$(AVX512_SUPPORT)" == "ON" ]]; then \ RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ - --features=$(TARGET_ARCH_FEATURE),experimental,$(AVX512_FEATURE) -p $(TFHE_SPEC) -- core_crypto::; \ + --features=$(TARGET_ARCH_FEATURE),experimental,zk-pok-experimental,$(AVX512_FEATURE) -p $(TFHE_SPEC) -- core_crypto::; \ fi .PHONY: test_core_crypto_cov # Run the tests of the core_crypto module with code coverage @@ -576,7 +584,7 @@ test_integer_cov: install_rs_check_toolchain install_tarpaulin .PHONY: test_high_level_api # Run all the tests for high_level_api test_high_level_api: install_rs_build_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache -p $(TFHE_SPEC) \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache,zk-pok-experimental -p $(TFHE_SPEC) \ -- high_level_api:: test_high_level_api_gpu: install_rs_build_toolchain install_cargo_nextest @@ -587,14 +595,14 @@ test_high_level_api_gpu: install_rs_build_toolchain install_cargo_nextest .PHONY: test_user_doc # Run tests from the .md documentation test_user_doc: install_rs_build_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) --doc \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache,pbs-stats \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache,pbs-stats,zk-pok-experimental \ -p $(TFHE_SPEC) \ -- test_user_docs:: .PHONY: test_user_doc_gpu # Run tests for GPU from the .md documentation test_user_doc_gpu: install_rs_build_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) --doc \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache,gpu -p $(TFHE_SPEC) \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache,gpu,zk-pok-experimental -p $(TFHE_SPEC) \ -- test_user_docs:: .PHONY: test_fhe_strings # Run tests for fhe_strings example @@ -633,7 +641,7 @@ test_concrete_csprng: RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ --features=$(TARGET_ARCH_FEATURE) -p concrete-csprng -.PHONY: test_zk_pok # Run tfhe-zk-pok tests +.PHONY: test_zk_pok # Run tfhe-zk-pok-experimental tests test_zk_pok: RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ -p tfhe-zk-pok diff --git a/scripts/integer-tests.sh b/scripts/integer-tests.sh index 22d15efce9..512b5990af 100755 --- a/scripts/integer-tests.sh +++ b/scripts/integer-tests.sh @@ -162,7 +162,7 @@ cargo "${RUST_TOOLCHAIN}" nextest run \ --cargo-profile "${cargo_profile}" \ --package "${tfhe_package}" \ --profile ci \ - --features="${ARCH_FEATURE}",integer,internal-keycache,"${avx512_feature}" \ + --features="${ARCH_FEATURE}",integer,internal-keycache,zk-pok-experimental,"${avx512_feature}" \ --test-threads "${test_threads}" \ -E "$filter_expression" diff --git a/scripts/shortint-tests.sh b/scripts/shortint-tests.sh index d1b7236250..1a4421882a 100755 --- a/scripts/shortint-tests.sh +++ b/scripts/shortint-tests.sh @@ -120,7 +120,7 @@ and not test(~smart_add_and_mul)""" # This test is too slow --cargo-profile "${cargo_profile}" \ --package "${tfhe_package}" \ --profile ci \ - --features="${ARCH_FEATURE}",shortint,internal-keycache \ + --features="${ARCH_FEATURE}",shortint,internal-keycache,zk-pok-experimental \ --test-threads "${n_threads_small}" \ -E "${filter_expression_small_params}" diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs index b1cbaf59e8..0c4f62b813 100644 --- a/tfhe-zk-pok/src/proofs/pke.rs +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -51,12 +51,6 @@ impl PublicParams { } } -#[allow(dead_code)] -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct PrivateParams { - alpha: G::Zp, -} - #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct Proof { c_hat: G::G2, @@ -104,27 +98,24 @@ pub fn crs_gen( q: u64, t: u64, rng: &mut dyn RngCore, -) -> (PublicParams, PrivateParams) { +) -> PublicParams { let alpha = G::Zp::rand(rng); let b_r = d as u64 / 2 + 1; let big_d = d + k * t.ilog2() as usize + (d + k) * (2 + b.ilog2() as usize + b_r.ilog2() as usize); let n = big_d + 1; - ( - PublicParams { - g_lists: GroupElements::::new(n, alpha), - big_d, - n, - d, - k, - b, - b_r, - q, - t, - }, - PrivateParams { alpha }, - ) + PublicParams { + g_lists: GroupElements::::new(n, alpha), + big_d, + n, + d, + k, + b, + b_r, + q, + t, + } } #[allow(clippy::too_many_arguments)] @@ -990,8 +981,7 @@ mod tests { m_roundtrip[i] = result; } - let (public_param, _private_param) = - crs_gen::(d, k, b_i, q, t, rng); + let public_param = crs_gen::(d, k, b_i, q, t, rng); for use_fake_e1 in [false, true] { for use_fake_e2 in [false, true] { diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml index a74d8bfd4c..682be0ae6e 100644 --- a/tfhe/Cargo.toml +++ b/tfhe/Cargo.toml @@ -69,6 +69,8 @@ paste = "1.0.7" fs2 = { version = "0.4.3", optional = true } # While we wait for repeat_n in rust standard library itertools = "0.11.0" +rand_core = { version = "0.6.4", features = ["std"] } +tfhe-zk-pok = { version = "0.1.0", path = "../tfhe-zk-pok", optional = true } # wasm deps wasm-bindgen = { version = "0.2.86", features = [ @@ -87,6 +89,7 @@ shortint = [] integer = ["shortint"] internal-keycache = ["dep:lazy_static", "dep:fs2"] gpu = ["tfhe-cuda-backend"] +zk-pok-experimental = ["dep:tfhe-zk-pok"] pbs-stats = [] diff --git a/tfhe/build.rs b/tfhe/build.rs index 54deb47e58..ae11ff3670 100644 --- a/tfhe/build.rs +++ b/tfhe/build.rs @@ -57,6 +57,8 @@ fn gen_c_api() { "integer", #[cfg(feature = "gpu")] "gpu", + #[cfg(feature = "zk-pok-experimental")] + "zk-pok-experimental", ]; let parse_expand_vec = if parse_expand_features_vec.is_empty() { diff --git a/tfhe/c_api_tests/test_high_level_zk.c b/tfhe/c_api_tests/test_high_level_zk.c new file mode 100644 index 0000000000..076e407f9b --- /dev/null +++ b/tfhe/c_api_tests/test_high_level_zk.c @@ -0,0 +1,109 @@ +#include "tfhe.h" +#include +#include +#include + +int main(void) { + // We want to use zk-proof, which requires bounded random distributions + // tfhe-rs has the `TUniform` as an available bounded distribution. + + // Note that simply changing parameters like this does not yield secure parameters + // Its only done for the example / tests + ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2_KS_PBS; + params.glwe_noise_distribution = new_t_uniform(9); + assert(params.encryption_key_choice == ShortintEncryptionKeyChoiceBig); + + int status; + + ConfigBuilder *builder; + status = config_builder_default(&builder); + assert(status == 0); + status = config_builder_use_custom_parameters(&builder, params); + assert(status == 0); + + Config *config; + status = config_builder_build(builder, &config); + assert(status == 0); + + // Compute the CRS + // Note that we do that before generating the client key + // as client_key_generate thakes ownership of the config + CompactPkeCrs *crs; + size_t max_num_bits = 32; + status = compact_pke_crs_from_config(config, max_num_bits, &crs); + assert(status == 0); + + CompactPkePublicParams *public_params; + status = compact_pke_crs_public_params(crs, &public_params); + assert(status == 0); + + ClientKey *client_key; + status = client_key_generate(config, &client_key); + assert(status == 0); + + // zk proofs of encryption works only using the CompactPublicKey + CompactPublicKey *pk; + status = compact_public_key_new(client_key, &pk); + assert(status == 0); + + // Demo of ProvenCompactFheUint32 + { + uint32_t msg = 8328937; + ProvenCompactFheUint32 *proven_fhe_uint; + status = proven_compact_fhe_uint32_try_encrypt(msg, public_params, pk, ZkComputeLoadProof, + &proven_fhe_uint); + assert(status == 0); + + FheUint32 *fhe_uint; + // This function does not take ownership of the proven fhe uint, so we have to cleanup later + status = + proven_compact_fhe_uint32_verify_and_expand(proven_fhe_uint, public_params, pk, &fhe_uint); + assert(status == 0); + + uint32_t decrypted; + status = fhe_uint32_decrypt(fhe_uint, client_key, &decrypted); + assert(status == 0); + + assert(decrypted == msg); + fhe_uint32_destroy(fhe_uint); + proven_compact_fhe_uint32_destroy(proven_fhe_uint); + } + + // Demo of ProvenCompactFheUint32List + { + uint32_t msgs[4] = {8328937, 217521191, 2753219039, 91099540}; + ProvenCompactFheUint32List *proven_fhe_list; + status = proven_compact_fhe_uint32_list_try_encrypt(msgs, 4, public_params, pk, + ZkComputeLoadProof, &proven_fhe_list); + assert(status == 0); + + size_t list_len; + status = proven_compact_fhe_uint32_list_len(proven_fhe_list, &list_len); + assert(status == 0); + assert(list_len == 4); + + FheUint32 *fhe_uints[4]; + // This function does not take ownership of the proven fhe uint, so we have to cleanup later + status = proven_compact_fhe_uint32_list_verify_and_expand(proven_fhe_list, public_params, pk, + &fhe_uints[0], 4); + assert(status == 0); + + for (size_t i = 0; i < 4; ++i) { + uint32_t decrypted; + status = fhe_uint32_decrypt(fhe_uints[i], client_key, &decrypted); + assert(status == 0); + + assert(decrypted == msgs[i]); + fhe_uint32_destroy(fhe_uints[i]); + } + + proven_compact_fhe_uint32_list_destroy(proven_fhe_list); + } + + compact_pke_public_params_destroy(public_params); + compact_pke_crs_destroy(crs); + compact_public_key_destroy(pk); + client_key_destroy(client_key); + + return EXIT_SUCCESS; +} diff --git a/tfhe/docs/guides/public_key.md b/tfhe/docs/guides/public_key.md index cd61286733..2feca4d96f 100644 --- a/tfhe/docs/guides/public_key.md +++ b/tfhe/docs/guides/public_key.md @@ -28,6 +28,8 @@ fn main() { This example shows how to use compact public keys. The main difference is in the ConfigBuilder, where the parameter set has been changed. +See [the guide on ZK proofs](zk-pok.md) to see how to encrypt data using compact public keys and generate a zero knowledge proof of correct encryption at the same time. + ```rust use tfhe::prelude::*; use tfhe::{ConfigBuilder, generate_keys, set_server_key, FheUint8, CompactPublicKey}; diff --git a/tfhe/docs/guides/zk-pok.md b/tfhe/docs/guides/zk-pok.md new file mode 100644 index 0000000000..539334f11a --- /dev/null +++ b/tfhe/docs/guides/zk-pok.md @@ -0,0 +1,69 @@ +# Zero Knowledge proof for Compact Public Key encryption + +TFHE-rs enables the generation of a zero-knowledge proof to verify that a compact public key encryption process has been correctly performed. In other words, the creation of a proof reveals nothing about the encrypted message, except for its already known range. This technique is derived from [Libert’s work](https://eprint.iacr.org/2023/800). + +{% hint style="info" %} +You can enable this feature using the flag: `--features=zk-pok-experimental` when building TFHE-rs. +{% endhint %} + + +Deploying this feature is straightforward: the client generates the proof at the time of encryption, while the server verifies it before proceeding with homomorphic computations. Below is an example demonstrating how a client can encrypt and prove a ciphertext, and how a server can verify the ciphertext and carry out computations on it: + + +```rust +use rand::prelude::*; +use tfhe::prelude::FheDecrypt; +use tfhe::shortint::parameters::DynamicDistribution; +use tfhe::set_server_key; +use tfhe::zk::{CompactPkeCrs, ZkComputeLoad}; + +pub fn main() -> Result<(), Box> { + let mut rng = thread_rng(); + + let max_num_message = 1; + + let mut params = tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS; + params.glwe_noise_distribution = DynamicDistribution::new_t_uniform(9); + + let client_key = tfhe::ClientKey::generate(tfhe::ConfigBuilder::with_custom_parameters(params, None)); + // This is done in an offline phase and the CRS is shared to all clients and the server + let crs = CompactPkeCrs::from_shortint_params(params, max_num_message).unwrap(); + let public_zk_params = crs.public_params(); + let server_key = tfhe::ServerKey::new(&client_key); + let public_key = tfhe::CompactPublicKey::try_new(&client_key).unwrap(); + + let clear_a = rng.gen::(); + let clear_b = rng.gen::(); + + let a = tfhe::ProvenCompactFheUint64::try_encrypt( + clear_a, + public_zk_params, + &public_key, + ZkComputeLoad::Proof, + )?; + let b = tfhe::ProvenCompactFheUint64::try_encrypt( + clear_b, + public_zk_params, + &public_key, + ZkComputeLoad::Proof, + )?; + + // Server side + let result = { + set_server_key(server_key); + + // Verify the ciphertexts + let a = a.verify_and_expand(&public_zk_params, &public_key)?; + let b = b.verify_and_expand(&public_zk_params, &public_key)?; + + a + b + }; + + // Back on the client side + let a_plus_b: u64 = result.decrypt(&client_key); + assert_eq!(a_plus_b, clear_a.wrapping_add(clear_b)); + + Ok(()) +} +``` +Encrypting and proving a CompactFheUint64 takes 6.9 s on a Dell XPS 15 9500, simulating a client machine, the verification on an hpc7a.96xlarge available on AWS takes 123 ms. diff --git a/tfhe/js_on_wasm_tests/test-hlapi-unsigned.js b/tfhe/js_on_wasm_tests/test-hlapi-unsigned.js index 50cea53ee4..d74fcf81d1 100644 --- a/tfhe/js_on_wasm_tests/test-hlapi-unsigned.js +++ b/tfhe/js_on_wasm_tests/test-hlapi-unsigned.js @@ -3,6 +3,7 @@ const assert = require('node:assert').strict; const {performance} = require('perf_hooks'); const { init_panic_hook, + Shortint, ShortintParametersName, ShortintParameters, TfheClientKey, @@ -21,9 +22,15 @@ const { CompressedFheUint256, CompactFheUint256, CompactFheUint256List, + ProvenCompactFheUint64, + ProvenCompactFheUint64List, + CompactPkeCrs, + ZkComputeLoad, FheUint256 } = require("../pkg/tfhe.js"); - +const { + randomBytes, +} = require('node:crypto'); const U256_MAX = BigInt("115792089237316195423570985008687907853269984665640564039457584007913129639935"); const U128_MAX = BigInt("340282366920938463463374607431768211455"); @@ -639,3 +646,50 @@ test('hlapi_compact_public_key_encrypt_decrypt_uint256_big_list_compact', (t) => hlapi_compact_public_key_encrypt_decrypt_uint256_list_compact(config); }); + +function generateRandomBigInt(bitLength) { + const bytesNeeded = Math.ceil(bitLength / 8); + const randomBytesBuffer = randomBytes(bytesNeeded); + + // Convert random bytes to BigInt + const randomBigInt = BigInt(`0x${randomBytesBuffer.toString('hex')}`); + + return randomBigInt; +} + +test('hlapi_compact_public_key_encrypt_and_prove_compact_uint256', (t) => { + let block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS); + block_params.set_lwe_noise_distribution(Shortint.try_new_t_uniform(9)); + + let config = TfheConfigBuilder.default() + .use_custom_parameters(block_params) + .build(); + + let clientKey = TfheClientKey.generate(config); + let publicKey = TfheCompactPublicKey.new(clientKey); + + let crs = CompactPkeCrs.from_parameters(block_params, 128); + let public_params = crs.public_params(); + + { + let input = generateRandomBigInt(64) + let encrypted = ProvenCompactFheUint64.encrypt_with_compact_public_key( + input, public_params, publicKey, ZkComputeLoad.Proof); + assert.deepStrictEqual(encrypted.verifies(public_params, publicKey), true); + let expanded = encrypted.verify_and_expand(public_params, publicKey); + let decrypted = expanded.decrypt(clientKey); + assert.deepStrictEqual(decrypted, input); + } + + { + let inputs = [generateRandomBigInt(64), generateRandomBigInt(64), generateRandomBigInt(64), generateRandomBigInt(64)]; + let encrypted = ProvenCompactFheUint64List.encrypt_with_compact_public_key( + inputs, public_params, publicKey, ZkComputeLoad.Proof); + assert.deepStrictEqual(encrypted.verifies(public_params, publicKey), true); + let expanded_list = encrypted.verify_and_expand(public_params, publicKey); + for (let i = 0; i < inputs.length; i++) { + let decrypted = expanded_list[i].decrypt(clientKey); + assert.deepStrictEqual(decrypted, inputs[i]); + } + } +}); diff --git a/tfhe/src/c_api/high_level_api/booleans.rs b/tfhe/src/c_api/high_level_api/booleans.rs index 7b3f7f2088..0174502ff5 100644 --- a/tfhe/src/c_api/high_level_api/booleans.rs +++ b/tfhe/src/c_api/high_level_api/booleans.rs @@ -139,3 +139,122 @@ pub unsafe extern "C" fn compact_fhe_bool_list_expand( } }) } + +#[cfg(feature = "zk-pok-experimental")] +mod zk { + use crate::c_api::high_level_api::utils::{ + impl_clone_on_type, impl_destroy_on_type, impl_safe_serialize_on_type, + impl_serialize_deserialize_on_type, + }; + use std::ffi::c_int; + + pub struct ProvenCompactFheBool(crate::high_level_api::ProvenCompactFheBool); + + impl_destroy_on_type!(ProvenCompactFheBool); + impl_clone_on_type!(ProvenCompactFheBool); + impl_serialize_deserialize_on_type!(ProvenCompactFheBool); + impl_safe_serialize_on_type!(ProvenCompactFheBool); + + #[no_mangle] + pub unsafe extern "C" fn proven_compact_fhe_bool_try_encrypt( + message: bool, + public_params: &crate::c_api::high_level_api::zk::CompactPkePublicParams, + pk: &crate::c_api::high_level_api::keys::CompactPublicKey, + compute_load: crate::c_api::high_level_api::zk::ZkComputeLoad, + out_result: *mut *mut ProvenCompactFheBool, + ) -> c_int { + crate::c_api::utils::catch_panic(|| { + let result = crate::high_level_api::ProvenCompactFheBool::try_encrypt( + message, + &public_params.0, + &pk.0, + compute_load.into(), + ) + .unwrap(); + + *out_result = Box::into_raw(Box::new(ProvenCompactFheBool(result))); + }) + } + + #[no_mangle] + pub unsafe extern "C" fn proven_compact_fhe_bool_verify_and_expand( + ct: *const ProvenCompactFheBool, + public_params: &crate::c_api::high_level_api::zk::CompactPkePublicParams, + pk: &crate::c_api::high_level_api::keys::CompactPublicKey, + out_result: *mut *mut super::FheBool, + ) -> c_int { + crate::c_api::utils::catch_panic(|| { + let ct = crate::c_api::utils::get_ref_checked(ct).unwrap(); + + let result = + ct.0.clone() + .verify_and_expand(&public_params.0, &pk.0) + .unwrap(); + + *out_result = Box::into_raw(Box::new(super::FheBool(result))); + }) + } + + pub struct ProvenCompactFheBoolList(crate::high_level_api::ProvenCompactFheBoolList); + + impl_destroy_on_type!(ProvenCompactFheBoolList); + impl_clone_on_type!(ProvenCompactFheBoolList); + impl_serialize_deserialize_on_type!(ProvenCompactFheBoolList); + impl_safe_serialize_on_type!(ProvenCompactFheBoolList); + + #[no_mangle] + pub unsafe extern "C" fn proven_compact_fhe_bool_list_try_encrypt( + input: *const bool, + input_len: usize, + public_params: &crate::c_api::high_level_api::zk::CompactPkePublicParams, + pk: &crate::c_api::high_level_api::keys::CompactPublicKey, + compute_load: crate::c_api::high_level_api::zk::ZkComputeLoad, + out_result: *mut *mut ProvenCompactFheBoolList, + ) -> ::std::os::raw::c_int { + crate::c_api::utils::catch_panic(|| { + let messages = std::slice::from_raw_parts(input, input_len); + + let result = crate::high_level_api::ProvenCompactFheBoolList::try_encrypt( + messages, + &public_params.0, + &pk.0, + compute_load.into(), + ) + .unwrap(); + + *out_result = Box::into_raw(Box::new(ProvenCompactFheBoolList(result))); + }) + } + + #[no_mangle] + pub unsafe extern "C" fn proven_compact_fhe_bool_list_len( + sself: *const ProvenCompactFheBoolList, + result: *mut usize, + ) -> ::std::os::raw::c_int { + crate::c_api::utils::catch_panic(|| { + let list = crate::c_api::utils::get_ref_checked(sself).unwrap(); + + *result = list.0.len(); + }) + } + + #[no_mangle] + pub unsafe extern "C" fn proven_compact_fhe_bool_list_verify_and_expand( + list: &ProvenCompactFheBoolList, + public_params: &crate::c_api::high_level_api::zk::CompactPkePublicParams, + pk: &crate::c_api::high_level_api::keys::CompactPublicKey, + output: *mut *mut super::FheBool, + output_len: usize, + ) -> ::std::os::raw::c_int { + crate::c_api::utils::catch_panic(|| { + let expanded = list.0.verify_and_expand(&public_params.0, &pk.0).unwrap(); + + let num_to_take = output_len.max(list.0.len()); + let iter = expanded.into_iter().take(num_to_take).enumerate(); + for (i, fhe_uint) in iter { + let ptr = output.wrapping_add(i); + *ptr = Box::into_raw(Box::new(super::FheBool(fhe_uint))); + } + }) + } +} diff --git a/tfhe/src/c_api/high_level_api/integers.rs b/tfhe/src/c_api/high_level_api/integers.rs index 9ddc0b9fda..5d0cef2388 100644 --- a/tfhe/src/c_api/high_level_api/integers.rs +++ b/tfhe/src/c_api/high_level_api/integers.rs @@ -388,6 +388,7 @@ macro_rules! create_integer_wrapper_type { }) } } + // The compact list version of the ciphertext type ::paste::paste! { pub struct []($crate::high_level_api::[]); @@ -434,6 +435,135 @@ macro_rules! create_integer_wrapper_type { }) } } + + + // The zk compact proven version of the compact ciphertext type + #[cfg(feature = "zk-pok-experimental")] + ::paste::paste! { + pub struct []($crate::high_level_api::[]); + + impl_destroy_on_type!([]); + + impl_clone_on_type!([]); + + impl_serialize_deserialize_on_type!([]); + + impl_safe_serialize_on_type!([]); + + #[no_mangle] + pub unsafe extern "C" fn []( + message: $clear_scalar_type, + public_params: &$crate::c_api::high_level_api::zk::CompactPkePublicParams, + pk: &$crate::c_api::high_level_api::keys::CompactPublicKey, + compute_load: $crate::c_api::high_level_api::zk::ZkComputeLoad, + out_result: *mut *mut [], + ) -> c_int { + $crate::c_api::utils::catch_panic(|| { + let message = <$clear_scalar_type as $crate::c_api::high_level_api::utils::CApiIntegerType>::to_rust(message); + + let result = $crate::high_level_api::[]::try_encrypt( + message, + &public_params.0, + &pk.0, + compute_load.into() + ).unwrap(); + + *out_result = Box::into_raw(Box::new([](result))); + }) + } + + #[no_mangle] + pub unsafe extern "C" fn []( + ct: *const [], + public_params: &$crate::c_api::high_level_api::zk::CompactPkePublicParams, + pk: &$crate::c_api::high_level_api::keys::CompactPublicKey, + out_result: *mut *mut $name, + ) -> c_int { + $crate::c_api::utils::catch_panic(|| { + let ct = $crate::c_api::utils::get_ref_checked(ct).unwrap(); + + let result = ct.0.clone().verify_and_expand(&public_params.0, &pk.0).unwrap(); + + *out_result = Box::into_raw(Box::new($name(result))); + }) + } + } + + // The zk compact proven version of the compact ciphertext list type + #[cfg(feature = "zk-pok-experimental")] + ::paste::paste! { + pub struct []($crate::high_level_api::[]); + + impl_destroy_on_type!([]); + + impl_clone_on_type!([]); + + impl_serialize_deserialize_on_type!([]); + + impl_safe_serialize_on_type!([]); + + + #[no_mangle] + pub unsafe extern "C" fn []( + input: *const $clear_scalar_type, + input_len: usize, + public_params: &$crate::c_api::high_level_api::zk::CompactPkePublicParams, + pk: &$crate::c_api::high_level_api::keys::CompactPublicKey, + compute_load: $crate::c_api::high_level_api::zk::ZkComputeLoad, + out_result: *mut *mut [], + ) -> ::std::os::raw::c_int { + $crate::c_api::utils::catch_panic(|| { + let messages = std::slice::from_raw_parts(input, input_len) + .iter() + .copied() + .map(|value| { + <$clear_scalar_type as $crate::c_api::high_level_api::utils::CApiIntegerType>::to_rust(value) + }) + .collect::>(); + + let result = $crate::high_level_api::[]::try_encrypt( + &messages, + &public_params.0, + &pk.0, + compute_load.into() + ).unwrap(); + + *out_result = Box::into_raw(Box::new([](result))); + }) + } + + #[no_mangle] + pub unsafe extern "C" fn []( + sself: *const [], + result: *mut usize, + ) -> ::std::os::raw::c_int { + $crate::c_api::utils::catch_panic(|| { + let list = $crate::c_api::utils::get_ref_checked(sself).unwrap(); + + *result = list.0.len(); + }) + } + + #[no_mangle] + pub unsafe extern "C" fn []( + list: &[], + public_params: &$crate::c_api::high_level_api::zk::CompactPkePublicParams, + pk: &$crate::c_api::high_level_api::keys::CompactPublicKey, + output: *mut *mut $name, + output_len: usize + ) -> ::std::os::raw::c_int { + $crate::c_api::utils::catch_panic(|| { + let expanded = list.0.verify_and_expand(&public_params.0, &pk.0).unwrap(); + + let num_to_take = output_len.max(list.0.len()); + let iter = expanded.into_iter().take(num_to_take).enumerate(); + for (i, fhe_uint) in iter { + let ptr = output.wrapping_add(i); + *ptr = Box::into_raw(Box::new($name(fhe_uint))); + } + }) + } + } }; // This entry point is meant for unsigned types @@ -803,7 +933,6 @@ macro_rules! impl_oprf_for_int { crate::high_level_api::SignedRandomizationSpec::FullSigned, ); *out_result = Box::into_raw(Box::new($name(result))); - }) } } diff --git a/tfhe/src/c_api/high_level_api/mod.rs b/tfhe/src/c_api/high_level_api/mod.rs index b90e95825f..a6780de192 100644 --- a/tfhe/src/c_api/high_level_api/mod.rs +++ b/tfhe/src/c_api/high_level_api/mod.rs @@ -9,3 +9,5 @@ mod threading; pub mod u128; pub mod u256; mod utils; +#[cfg(feature = "zk-pok-experimental")] +mod zk; diff --git a/tfhe/src/c_api/high_level_api/zk.rs b/tfhe/src/c_api/high_level_api/zk.rs new file mode 100644 index 0000000000..37d02d06ba --- /dev/null +++ b/tfhe/src/c_api/high_level_api/zk.rs @@ -0,0 +1,59 @@ +use super::utils::*; +use crate::c_api::high_level_api::config::Config; +use crate::c_api::utils::get_ref_checked; +use std::ffi::c_int; + +#[repr(C)] +#[derive(Copy, Clone)] +pub enum ZkComputeLoad { + ZkComputeLoadProof, + ZkComputeLoadVerify, +} + +impl From for crate::zk::ZkComputeLoad { + fn from(value: ZkComputeLoad) -> Self { + match value { + ZkComputeLoad::ZkComputeLoadProof => Self::Proof, + ZkComputeLoad::ZkComputeLoadVerify => Self::Verify, + } + } +} + +pub struct CompactPkePublicParams(pub(crate) crate::core_crypto::entities::CompactPkePublicParams); +impl_destroy_on_type!(CompactPkePublicParams); +impl_serialize_deserialize_on_type!(CompactPkePublicParams); + +pub struct CompactPkeCrs(pub(crate) crate::core_crypto::entities::CompactPkeCrs); + +impl_destroy_on_type!(CompactPkeCrs); +impl_serialize_deserialize_on_type!(CompactPkeCrs); + +#[no_mangle] +pub unsafe extern "C" fn compact_pke_crs_from_config( + config: *const Config, + max_num_bits: usize, + out_result: *mut *mut CompactPkeCrs, +) -> c_int { + crate::c_api::utils::catch_panic(|| { + let config = get_ref_checked(config).unwrap(); + + let crs = crate::core_crypto::entities::CompactPkeCrs::from_config(config.0, max_num_bits) + .unwrap(); + + *out_result = Box::into_raw(Box::new(CompactPkeCrs(crs))); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn compact_pke_crs_public_params( + crs: *const CompactPkeCrs, + out_public_params: *mut *mut CompactPkePublicParams, +) -> c_int { + crate::c_api::utils::catch_panic(|| { + let crs = get_ref_checked(crs).unwrap(); + + *out_public_params = Box::into_raw(Box::new(CompactPkePublicParams( + crs.0.public_params().clone(), + ))); + }) +} diff --git a/tfhe/src/core_crypto/algorithms/lwe_encryption.rs b/tfhe/src/core_crypto/algorithms/lwe_encryption.rs index c1b5387567..46bec73020 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_encryption.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_encryption.rs @@ -5,6 +5,8 @@ use crate::core_crypto::algorithms::slice_algorithms::*; use crate::core_crypto::algorithms::*; use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind; use crate::core_crypto::commons::generators::{EncryptionRandomGenerator, SecretRandomGenerator}; +#[cfg(feature = "zk-pok-experimental")] +use crate::core_crypto::commons::math::random::BoundedDistribution; use crate::core_crypto::commons::math::random::{ ActivatedRandomGenerator, Distribution, RandomGenerable, RandomGenerator, Uniform, UniformBinary, @@ -13,6 +15,8 @@ use crate::core_crypto::commons::parameters::*; use crate::core_crypto::commons::traits::*; use crate::core_crypto::entities::*; use rayon::prelude::*; +#[cfg(feature = "zk-pok-experimental")] +use tfhe_zk_pok::proofs::pke::{commit, prove}; /// Convenience function to share the core logic of the LWE encryption between all functions needing /// it. @@ -1683,6 +1687,202 @@ where seeded_ct } +/// This struct stores random vectors that were generated during +/// the encryption of a lwe ciphertext or lwe compact ciphertext list. +/// +/// These are needed by the zero-knowledge proof +struct CompactPublicKeyRandomVectors { + // This is 'r' + #[cfg_attr(not(feature = "zk-pok-experimental"), allow(unused))] + binary_random_vector: Vec, + // This is e1 + #[cfg_attr(not(feature = "zk-pok-experimental"), allow(unused))] + mask_noise: Vec, + // This is e2 + #[cfg_attr(not(feature = "zk-pok-experimental"), allow(unused))] + body_noise: Vec, +} + +#[cfg(feature = "zk-pok-experimental")] +fn verify_zero_knowledge_preconditions( + lwe_compact_public_key: &LweCompactPublicKey, + ciphertext_count: LweCiphertextCount, + ciphertext_modulus: CiphertextModulus, + delta: Scalar, + mask_noise_distribution: MaskDistribution, + body_noise_distribution: BodyDistribution, + public_params: &CompactPkePublicParams, +) -> crate::Result<()> +where + Scalar: UnsignedInteger + CastFrom, + Scalar::Signed: CastFrom, + i64: CastFrom, + u64: CastFrom + CastInto, + MaskDistribution: BoundedDistribution, + BodyDistribution: BoundedDistribution, + KeyCont: Container, +{ + let exclusive_max = public_params.exclusive_max_noise(); + if Scalar::BITS < 64 && (1u64 << Scalar::BITS) >= exclusive_max { + return Err( + "The given random distribution would create random values out \ + of the expected bounds of given to the CRS" + .into(), + ); + } + + if mask_noise_distribution.contains(exclusive_max.cast_into()) { + // The proof expect noise bound between [-b, b) (aka -b..b) + return Err( + "The given random distribution would create random values out \ + of the expected bounds of given to the CRS" + .into(), + ); + } + if body_noise_distribution.contains(exclusive_max.cast_into()) { + // The proof expect noise bound between [-b, b) (aka -b..b) + return Err( + "The given random distribution would create random values out \ + of the expected bounds of given to the CRS" + .into(), + ); + } + + if !ciphertext_modulus.is_native_modulus() { + return Err("This operation only supports native modulus".into()); + } + + if Scalar::BITS > 64 { + return Err("Zero knowledge proof do not support moduli greater than 2**64".into()); + } + + let expected_q = if Scalar::BITS == 64 { + 0u64 + } else { + 164 << Scalar::BITS + }; + + if expected_q != public_params.q { + return Err("Mismatched modulus between CRS and ciphertexts".into()); + } + + if ciphertext_count.0 > public_params.k { + return Err(format!( + "CRS allows at most {} ciphertexts to be proven at once, {} contained in the list", + public_params.k, ciphertext_count.0 + ) + .into()); + } + + if lwe_compact_public_key.lwe_dimension().0 > public_params.d { + return Err(format!( + "CRS allows a LweDimension of at most {}, current dimension: {}", + public_params.d, + lwe_compact_public_key.lwe_dimension().0 + ) + .into()); + } + + // 2**64 /delta == ((2**63) / delta) *2 + let plaintext_modulus = ((1u64 << (u64::BITS - 1) as usize) / u64::cast_from(delta)) * 2; + if plaintext_modulus != public_params.t { + return Err(format!( + "Mismatched plaintext modulus: CRS expects {}, requested modulus: {plaintext_modulus:?}", + public_params.t + ).into()); + } + + Ok(()) +} + +fn encrypt_lwe_ciphertext_with_compact_public_key_impl< + Scalar, + KeyCont, + OutputCont, + MaskDistribution, + NoiseDistribution, + SecretGen, + EncryptionGen, +>( + lwe_compact_public_key: &LweCompactPublicKey, + output: &mut LweCiphertext, + encoded: Plaintext, + mask_noise_distribution: MaskDistribution, + body_noise_distribution: NoiseDistribution, + secret_generator: &mut SecretRandomGenerator, + encryption_generator: &mut EncryptionRandomGenerator, +) -> CompactPublicKeyRandomVectors +where + Scalar: Encryptable + RandomGenerable, + KeyCont: Container, + OutputCont: ContainerMut, + MaskDistribution: Distribution, + NoiseDistribution: Distribution, + SecretGen: ByteRandomGenerator, + EncryptionGen: ByteRandomGenerator, +{ + assert!( + output.lwe_size().to_lwe_dimension() == lwe_compact_public_key.lwe_dimension(), + "Mismatch between LweDimension of output ciphertext and input public key. \ + Got {:?} in output, and {:?} in public key.", + output.lwe_size().to_lwe_dimension(), + lwe_compact_public_key.lwe_dimension() + ); + + assert!( + lwe_compact_public_key.ciphertext_modulus() == output.ciphertext_modulus(), + "Mismatch between CiphertextModulus of output ciphertext and input public key. \ + Got {:?} in output, and {:?} in public key.", + output.ciphertext_modulus(), + lwe_compact_public_key.ciphertext_modulus() + ); + + assert!( + output.ciphertext_modulus().is_native_modulus(), + "This operation only supports native moduli" + ); + + let mut binary_random_vector = vec![Scalar::ZERO; lwe_compact_public_key.lwe_dimension().0]; + secret_generator.fill_slice_with_random_uniform_binary(&mut binary_random_vector); + + let mut mask_noise = vec![Scalar::ZERO; lwe_compact_public_key.lwe_dimension().0]; + encryption_generator + .fill_slice_with_random_noise_from_distribution(&mut mask_noise, mask_noise_distribution); + + let body_noise = vec![Scalar::ZERO; 1]; + encryption_generator + .fill_slice_with_random_noise_from_distribution(&mut mask_noise, body_noise_distribution); + + { + let (mut ct_mask, ct_body) = output.get_mut_mask_and_body(); + let (pk_mask, pk_body) = lwe_compact_public_key.get_mask_and_body(); + + { + slice_semi_reverse_negacyclic_convolution( + ct_mask.as_mut(), + pk_mask.as_ref(), + &binary_random_vector, + ); + + // Noise from Chi_1 for the mask part of the encryption + slice_wrapping_add_assign(ct_mask.as_mut(), mask_noise.as_slice()); + } + + { + *ct_body.data = slice_wrapping_dot_product(pk_body.as_ref(), &binary_random_vector); + // Noise from Chi_2 for the body part of the encryption + *ct_body.data = (*ct_body.data).wrapping_add(body_noise[0]); + *ct_body.data = (*ct_body.data).wrapping_add(encoded.0); + } + } + + CompactPublicKeyRandomVectors { + binary_random_vector, + mask_noise, + body_noise, + } +} + /// Encrypt an input plaintext in an output [`LWE ciphertext`](`LweCiphertext`) using an /// [`LWE compact public key`](`LweCompactPublicKey`). The ciphertext can be decrypted using the /// [`LWE secret key`](`LweSecretKey`) that was used to generate the public key. @@ -1775,71 +1975,39 @@ pub fn encrypt_lwe_ciphertext_with_compact_public_key< SecretGen: ByteRandomGenerator, EncryptionGen: ByteRandomGenerator, { - assert!( - output.lwe_size().to_lwe_dimension() == lwe_compact_public_key.lwe_dimension(), - "Mismatch between LweDimension of output ciphertext and input public key. \ - Got {:?} in output, and {:?} in public key.", - output.lwe_size().to_lwe_dimension(), - lwe_compact_public_key.lwe_dimension() - ); - - assert!( - lwe_compact_public_key.ciphertext_modulus() == output.ciphertext_modulus(), - "Mismatch between CiphertextModulus of output ciphertext and input public key. \ - Got {:?} in output, and {:?} in public key.", - output.ciphertext_modulus(), - lwe_compact_public_key.ciphertext_modulus() - ); - - assert!( - output.ciphertext_modulus().is_native_modulus(), - "This operation only supports native moduli" - ); - - let mut binary_random_vector = vec![Scalar::ZERO; lwe_compact_public_key.lwe_dimension().0]; - - secret_generator.fill_slice_with_random_uniform_binary(&mut binary_random_vector); - - let (mut ct_mask, ct_body) = output.get_mut_mask_and_body(); - let (pk_mask, pk_body) = lwe_compact_public_key.get_mask_and_body(); - - slice_semi_reverse_negacyclic_convolution( - ct_mask.as_mut(), - pk_mask.as_ref(), - &binary_random_vector, - ); - - // Noise from Chi_1 for the mask part of the encryption - encryption_generator.unsigned_integer_slice_wrapping_add_random_noise_from_distribution_assign( - ct_mask.as_mut(), + let _ = encrypt_lwe_ciphertext_with_compact_public_key_impl( + lwe_compact_public_key, + output, + encoded, mask_noise_distribution, + body_noise_distribution, + secret_generator, + encryption_generator, ); - - *ct_body.data = slice_wrapping_dot_product(pk_body.as_ref(), &binary_random_vector); - // Noise from Chi_2 for the body part of the encryption - *ct_body.data = (*ct_body.data) - .wrapping_add(encryption_generator.random_noise_from_distribution(body_noise_distribution)); - *ct_body.data = (*ct_body.data).wrapping_add(encoded.0); } -/// Encrypt an input plaintext list in an output [`LWE compact ciphertext -/// list`](`LweCompactCiphertextList`) using an [`LWE compact public key`](`LweCompactPublicKey`). -/// The expanded ciphertext list can be decrypted using the [`LWE secret key`](`LweSecretKey`) that -/// was used to generate the public key. +/// Encrypt and generates a zero-knowledge proof of an input cleartext +/// in an output [`LWE ciphertext`](`LweCiphertext`) using an +/// [`LWE compact public key`](`LweCompactPublicKey`). The ciphertext can be decrypted using the +/// [`LWE secret key`](`LweSecretKey`) that was used to generate the public key. +/// +/// /// /// # Example /// /// ```rust +/// use tfhe::core_crypto::commons::math::random::RandomGenerator; /// use tfhe::core_crypto::prelude::*; /// /// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct /// // computations /// // Define parameters for LweCiphertext creation /// let lwe_dimension = LweDimension(2048); -/// let lwe_ciphertext_count = LweCiphertextCount(lwe_dimension.0 * 4); -/// let glwe_noise_distribution = -/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0); +/// let glwe_noise_distribution = TUniform::new(9); /// let ciphertext_modulus = CiphertextModulus::new_native(); +/// let delta_log = 60; +/// let delta = 1u64 << delta_log; +/// let plaintext_modulus = 1u64 << (64 - delta_log); /// /// // Create the PRNG /// let mut seeder = new_seeder(); @@ -1848,6 +2016,7 @@ pub fn encrypt_lwe_ciphertext_with_compact_public_key< /// EncryptionRandomGenerator::::new(seeder.seed(), seeder); /// let mut secret_generator = /// SecretRandomGenerator::::new(seeder.seed()); +/// let mut random_generator = RandomGenerator::::new(seeder.seed()); /// /// // Create the LweSecretKey /// let lwe_secret_key = @@ -1860,61 +2029,178 @@ pub fn encrypt_lwe_ciphertext_with_compact_public_key< /// &mut encryption_generator, /// ); /// -/// let mut input_plaintext_list = PlaintextList::new(0u64, PlaintextCount(lwe_ciphertext_count.0)); -/// input_plaintext_list -/// .iter_mut() -/// .enumerate() -/// .for_each(|(idx, x)| { -/// *x.0 = (idx as u64 % 16) << 60; -/// }); -/// -/// // Create a new LweCompactCiphertextList -/// let mut output_compact_ct_list = LweCompactCiphertextList::new( -/// 0u64, -/// lwe_dimension.to_lwe_size(), -/// lwe_ciphertext_count, +/// let crs = CompactPkeCrs::new( +/// lwe_dimension, +/// 1, +/// glwe_noise_distribution, /// ciphertext_modulus, -/// ); +/// plaintext_modulus, +/// &mut random_generator, +/// ) +/// .unwrap(); /// -/// encrypt_lwe_compact_ciphertext_list_with_compact_public_key( +/// // Create the plaintext +/// let msg = Cleartext(3u64); +/// +/// // Create a new LweCiphertext +/// let mut lwe = LweCiphertext::new(0u64, lwe_dimension.to_lwe_size(), ciphertext_modulus); +/// +/// let proof = encrypt_and_prove_lwe_ciphertext_with_compact_public_key( /// &lwe_compact_public_key, -/// &mut output_compact_ct_list, -/// &input_plaintext_list, +/// &mut lwe, +/// msg, +/// delta, /// glwe_noise_distribution, /// glwe_noise_distribution, /// &mut secret_generator, /// &mut encryption_generator, +/// &mut random_generator, +/// crs.public_params(), +/// ZkComputeLoad::Proof, +/// ) +/// .unwrap(); +/// +/// // verify the ciphertext list with the proof +/// assert!( +/// verify_lwe_ciphertext(&lwe, &lwe_compact_public_key, &proof, crs.public_params(),) +/// .is_valid() /// ); /// -/// let mut output_plaintext_list = input_plaintext_list.clone(); -/// output_plaintext_list.as_mut().fill(0u64); -/// -/// let lwe_ciphertext_list = output_compact_ct_list.expand_into_lwe_ciphertext_list(); +/// let decrypted_plaintext = decrypt_lwe_ciphertext(&lwe_secret_key, &lwe); /// -/// decrypt_lwe_ciphertext_list( -/// &lwe_secret_key, -/// &lwe_ciphertext_list, -/// &mut output_plaintext_list, -/// ); +/// // Round and remove encoding +/// // First create a decomposer working on the high 4 bits corresponding to our encoding. +/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); /// -/// let signed_decomposer = -/// SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// let rounded = decomposer.closest_representable(decrypted_plaintext.0); /// -/// // Round the plaintexts -/// output_plaintext_list -/// .iter_mut() -/// .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0)); +/// // Remove the encoding +/// let cleartext = rounded >> 60; /// -/// // Check we recovered the original messages -/// assert_eq!(input_plaintext_list, output_plaintext_list); +/// // Check we recovered the original message +/// assert_eq!(cleartext, msg.0); /// ``` -pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key< +#[cfg(feature = "zk-pok-experimental")] +#[allow(clippy::too_many_arguments)] +pub fn encrypt_and_prove_lwe_ciphertext_with_compact_public_key< Scalar, + KeyCont, + OutputCont, MaskDistribution, NoiseDistribution, + SecretGen, + EncryptionGen, + G, +>( + lwe_compact_public_key: &LweCompactPublicKey, + output: &mut LweCiphertext, + message: Cleartext, + delta: Scalar, + mask_noise_distribution: MaskDistribution, + body_noise_distribution: NoiseDistribution, + secret_generator: &mut SecretRandomGenerator, + encryption_generator: &mut EncryptionRandomGenerator, + random_generator: &mut RandomGenerator, + public_params: &CompactPkePublicParams, + load: ZkComputeLoad, +) -> crate::Result +where + Scalar: Encryptable + + RandomGenerable + + CastFrom, + Scalar::Signed: CastFrom, + i64: CastFrom, + u64: CastFrom + CastInto, + KeyCont: Container, + OutputCont: ContainerMut, + MaskDistribution: BoundedDistribution, + NoiseDistribution: BoundedDistribution, + SecretGen: ByteRandomGenerator, + EncryptionGen: ByteRandomGenerator, + G: ByteRandomGenerator, +{ + verify_zero_knowledge_preconditions( + lwe_compact_public_key, + LweCiphertextCount(1), + output.ciphertext_modulus(), + delta, + mask_noise_distribution, + body_noise_distribution, + public_params, + )?; + + let CompactPublicKeyRandomVectors { + binary_random_vector, + mask_noise, + body_noise, + } = encrypt_lwe_ciphertext_with_compact_public_key_impl( + lwe_compact_public_key, + output, + Plaintext(message.0 * delta), + mask_noise_distribution, + body_noise_distribution, + secret_generator, + encryption_generator, + ); + + let (c1, c2) = output.get_mask_and_body(); + + let (public_commit, private_commit) = commit( + lwe_compact_public_key + .get_mask() + .as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + lwe_compact_public_key + .get_body() + .as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + c1.as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + vec![i64::cast_from(*c2.data)], + binary_random_vector + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + mask_noise + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + vec![i64::cast_from(message.0)], + body_noise + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + public_params, + random_generator, + ); + + Ok(prove( + (public_params, &public_commit), + &private_commit, + load, + random_generator, + )) +} + +fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key_impl< + Scalar, KeyCont, InputCont, OutputCont, + MaskDistribution, + NoiseDistribution, SecretGen, EncryptionGen, >( @@ -1925,13 +2211,14 @@ pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key< body_noise_distribution: NoiseDistribution, secret_generator: &mut SecretRandomGenerator, encryption_generator: &mut EncryptionRandomGenerator, -) where +) -> CompactPublicKeyRandomVectors +where Scalar: Encryptable + RandomGenerable, - MaskDistribution: Distribution, - NoiseDistribution: Distribution, KeyCont: Container, InputCont: Container, OutputCont: ContainerMut, + MaskDistribution: Distribution, + NoiseDistribution: Distribution, SecretGen: ByteRandomGenerator, EncryptionGen: ByteRandomGenerator, { @@ -1964,22 +2251,21 @@ pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key< "This operation only supports native moduli" ); - let (mut output_mask_list, mut output_body_list) = output.get_mut_mask_and_body_list(); let (pk_mask, pk_body) = lwe_compact_public_key.get_mask_and_body(); - - let lwe_mask_count = output_mask_list.lwe_mask_count(); - let lwe_dimension = output_mask_list.lwe_dimension(); + let (mut output_mask_list, mut output_body_list) = output.get_mut_mask_and_body_list(); let mut binary_random_vector = vec![Scalar::ZERO; output_mask_list.lwe_mask_list_size()]; secret_generator.fill_slice_with_random_uniform_binary(&mut binary_random_vector); - let max_ciphertext_per_bin = lwe_dimension.0; + let mut mask_noise = vec![Scalar::ZERO; output_mask_list.lwe_mask_list_size()]; + encryption_generator + .fill_slice_with_random_noise_from_distribution(&mut mask_noise, mask_noise_distribution); - let gen_iter = encryption_generator - .fork_lwe_compact_ciphertext_list_to_bin::(lwe_mask_count, lwe_dimension) - .expect("Failed to split generator into lwe compact ciphertext bins"); + let mut body_noise = vec![Scalar::ZERO; encoded.plaintext_count().0]; + encryption_generator + .fill_slice_with_random_noise_from_distribution(&mut body_noise, body_noise_distribution); - // Loop over the ciphertext "bins" + let max_ciphertext_per_bin = lwe_compact_public_key.lwe_dimension().0; output_mask_list .iter_mut() .zip( @@ -1987,19 +2273,23 @@ pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key< .chunks_mut(max_ciphertext_per_bin) .zip(encoded.chunks(max_ciphertext_per_bin)) .zip(binary_random_vector.chunks(max_ciphertext_per_bin)) - .zip(gen_iter), + .zip(mask_noise.as_slice().chunks(max_ciphertext_per_bin)) + .zip(body_noise.as_slice().chunks(max_ciphertext_per_bin)), ) .for_each( |( mut output_mask, ( - ((mut output_body_chunk, input_plaintext_chunk), binary_random_slice), - mut loop_generator, + ( + ((mut output_body_chunk, input_plaintext_chunk), binary_random_slice), + mask_noise, + ), + body_noise, ), )| { - // output_body_chunk may not be able to fit the full convolution result so we create - // a temp buffer to compute the full convolution - let mut pk_body_convolved = vec![Scalar::ZERO; lwe_dimension.0]; + // output_body_chunk may not be able to fit the full convolution result so we + // create a temp buffer to compute the full convolution + let mut pk_body_convolved = vec![Scalar::ZERO; max_ciphertext_per_bin]; slice_semi_reverse_negacyclic_convolution( output_mask.as_mut(), @@ -2014,17 +2304,11 @@ pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key< binary_random_slice, ); - // Noise from Chi_1 for the mask part of the encryption - loop_generator - .unsigned_integer_slice_wrapping_add_random_noise_from_distribution_assign( - output_mask.as_mut(), - mask_noise_distribution, - ); + slice_wrapping_add_assign(output_mask.as_mut(), mask_noise); - // Fill the body chunk afterwards manually as it most likely will be smaller than + // Fill the body chunk afterward manually as it most likely will be smaller than // the full convolution result. rev(b convolved r) + Delta * m + e2 // taking noise from Chi_2 for the body part of the encryption - // The reverse is to make the first element product match the single ciphertext case output_body_chunk .iter_mut() .zip( @@ -2033,22 +2317,23 @@ pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key< .rev() .zip(input_plaintext_chunk.iter()), ) - .for_each(|(dst, (&src, plaintext))| { - *dst.data = src - .wrapping_add( - loop_generator - .random_noise_from_distribution(body_noise_distribution), - ) - .wrapping_add(*plaintext.0); + .zip(body_noise) + .for_each(|((dst, (&src, plaintext)), body_noise)| { + *dst.data = src.wrapping_add(*body_noise).wrapping_add(*plaintext.0); }); }, ); + CompactPublicKeyRandomVectors { + binary_random_vector, + mask_noise, + body_noise, + } } -/// Parallel variant of [`encrypt_lwe_compact_ciphertext_list_with_compact_public_key`]. Encrypt an -/// input plaintext list in an output [`LWE compact ciphertext list`](`LweCompactCiphertextList`) -/// using an [`LWE compact public key`](`LweCompactPublicKey`). The expanded ciphertext list can be -/// decrypted using the [`LWE secret key`](`LweSecretKey`) that was used to generate the public key. +/// Encrypt an input plaintext list in an output [`LWE compact ciphertext +/// list`](`LweCompactCiphertextList`) using an [`LWE compact public key`](`LweCompactPublicKey`). +/// The expanded ciphertext list can be decrypted using the [`LWE secret key`](`LweSecretKey`) that +/// was used to generate the public key. /// /// # Example /// @@ -2099,7 +2384,7 @@ pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key< /// ciphertext_modulus, /// ); /// -/// par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key( +/// encrypt_lwe_compact_ciphertext_list_with_compact_public_key( /// &lwe_compact_public_key, /// &mut output_compact_ct_list, /// &input_plaintext_list, @@ -2112,7 +2397,7 @@ pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key< /// let mut output_plaintext_list = input_plaintext_list.clone(); /// output_plaintext_list.as_mut().fill(0u64); /// -/// let lwe_ciphertext_list = output_compact_ct_list.par_expand_into_lwe_ciphertext_list(); +/// let lwe_ciphertext_list = output_compact_ct_list.expand_into_lwe_ciphertext_list(); /// /// decrypt_lwe_ciphertext_list( /// &lwe_secret_key, @@ -2131,7 +2416,7 @@ pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key< /// // Check we recovered the original messages /// assert_eq!(input_plaintext_list, output_plaintext_list); /// ``` -pub fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key< +pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key< Scalar, MaskDistribution, NoiseDistribution, @@ -2149,17 +2434,303 @@ pub fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key< secret_generator: &mut SecretRandomGenerator, encryption_generator: &mut EncryptionRandomGenerator, ) where - Scalar: Encryptable - + RandomGenerable - + Sync - + Send, - MaskDistribution: Distribution + Sync, - NoiseDistribution: Distribution + Sync, + Scalar: Encryptable + RandomGenerable, + MaskDistribution: Distribution, + NoiseDistribution: Distribution, KeyCont: Container, InputCont: Container, OutputCont: ContainerMut, SecretGen: ByteRandomGenerator, - EncryptionGen: ParallelByteRandomGenerator, + EncryptionGen: ByteRandomGenerator, +{ + let _ = encrypt_lwe_compact_ciphertext_list_with_compact_public_key_impl( + lwe_compact_public_key, + output, + encoded, + mask_noise_distribution, + body_noise_distribution, + secret_generator, + encryption_generator, + ); +} + +/// Encrypt and generates a zero-knowledge proof of an input cleartext list in an output +/// [`LWE compact ciphertext list`](`LweCompactCiphertextList`) +/// using an [`LWE compact public key`](`LweCompactPublicKey`). +/// +/// The expanded ciphertext list can be decrypted using the [`LWE secret key`](`LweSecretKey`) that +/// was used to generate the public key. +/// +/// - The input cleartext list must have a length smaller or equal the maximum number of message +/// authorized by the CRS. +/// +/// - The noise distributions must be bounded +/// +/// +/// # Example +/// +/// ```rust +/// use tfhe::core_crypto::commons::math::random::RandomGenerator; +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for LweCiphertext creation +/// let lwe_dimension = LweDimension(2048); +/// let lwe_ciphertext_count = LweCiphertextCount(4); +/// let glwe_noise_distribution = TUniform::new(9); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// let delta_log = 60; +/// let delta = 1u64 << delta_log; +/// let plaintext_modulus = 1u64 << (64 - delta_log); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// let mut random_generator = RandomGenerator::::new(seeder.seed()); +/// +/// let crs = CompactPkeCrs::new( +/// lwe_dimension, +/// lwe_ciphertext_count.0, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// plaintext_modulus, +/// &mut random_generator, +/// ) +/// .unwrap(); +/// +/// // Create the LweSecretKey +/// let lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator); +/// +/// let lwe_compact_public_key = allocate_and_generate_new_lwe_compact_public_key( +/// &lwe_secret_key, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// let cleartexts = (0..lwe_ciphertext_count.0 as u64).collect::>(); +/// +/// // Create a new LweCompactCiphertextList +/// let mut output_compact_ct_list = LweCompactCiphertextList::new( +/// 0u64, +/// lwe_dimension.to_lwe_size(), +/// lwe_ciphertext_count, +/// ciphertext_modulus, +/// ); +/// +/// let proof = encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key( +/// &lwe_compact_public_key, +/// &mut output_compact_ct_list, +/// &cleartexts, +/// delta, +/// glwe_noise_distribution, +/// glwe_noise_distribution, +/// &mut secret_generator, +/// &mut encryption_generator, +/// &mut random_generator, +/// crs.public_params(), +/// ZkComputeLoad::Proof, +/// ) +/// .unwrap(); +/// +/// // verify the ciphertext list with the proof +/// assert!(verify_lwe_compact_ciphertext_list( +/// &output_compact_ct_list, +/// &lwe_compact_public_key, +/// &proof, +/// crs.public_params(), +/// ) +/// .is_valid()); +/// +/// let mut output_plaintext_list = +/// PlaintextList::new(0u64, PlaintextCount(lwe_ciphertext_count.0)); +/// +/// let lwe_ciphertext_list = output_compact_ct_list.expand_into_lwe_ciphertext_list(); +/// +/// decrypt_lwe_ciphertext_list( +/// &lwe_secret_key, +/// &lwe_ciphertext_list, +/// &mut output_plaintext_list, +/// ); +/// +/// let signed_decomposer = +/// SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// +/// // Round the plaintexts +/// output_plaintext_list +/// .iter_mut() +/// .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0) >> 60); +/// +/// // Check we recovered the original messages +/// assert_eq!(&cleartexts, output_plaintext_list.as_ref()); +/// ``` +#[cfg(feature = "zk-pok-experimental")] +#[allow(clippy::too_many_arguments)] +pub fn encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key< + Scalar, + KeyCont, + InputCont, + OutputCont, + MaskDistribution, + NoiseDistribution, + SecretGen, + EncryptionGen, + G, +>( + lwe_compact_public_key: &LweCompactPublicKey, + output: &mut LweCompactCiphertextList, + messages: &InputCont, + delta: Scalar, + mask_noise_distribution: MaskDistribution, + body_noise_distribution: NoiseDistribution, + secret_generator: &mut SecretRandomGenerator, + encryption_generator: &mut EncryptionRandomGenerator, + random_generator: &mut RandomGenerator, + public_params: &CompactPkePublicParams, + load: ZkComputeLoad, +) -> crate::Result +where + Scalar: Encryptable + + RandomGenerable + + CastFrom, + Scalar::Signed: CastFrom, + i64: CastFrom, + u64: CastFrom + CastInto, + MaskDistribution: BoundedDistribution, + NoiseDistribution: BoundedDistribution, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, + SecretGen: ByteRandomGenerator, + EncryptionGen: ByteRandomGenerator, + G: ByteRandomGenerator, +{ + verify_zero_knowledge_preconditions( + lwe_compact_public_key, + output.lwe_ciphertext_count(), + output.ciphertext_modulus(), + delta, + mask_noise_distribution, + body_noise_distribution, + public_params, + )?; + + let encoded = PlaintextList::from_container( + messages + .as_ref() + .iter() + .copied() + .map(|m| m * delta) + .collect::>(), + ); + + let CompactPublicKeyRandomVectors { + binary_random_vector, + mask_noise, + body_noise, + } = encrypt_lwe_compact_ciphertext_list_with_compact_public_key_impl( + lwe_compact_public_key, + output, + &encoded, + mask_noise_distribution, + body_noise_distribution, + secret_generator, + encryption_generator, + ); + + let (c1, c2) = output.get_mask_and_body_list(); + + let (public_commit, private_commit) = commit( + lwe_compact_public_key + .get_mask() + .as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + lwe_compact_public_key + .get_body() + .as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + c1.as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + c2.as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + binary_random_vector + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + mask_noise + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + messages + .as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + body_noise + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + public_params, + random_generator, + ); + + Ok(prove( + (public_params, &public_commit), + &private_commit, + load, + random_generator, + )) +} + +fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key_impl< + Scalar, + KeyCont, + InputCont, + OutputCont, + MaskDistribution, + NoiseDistribution, + SecretGen, + EncryptionGen, +>( + lwe_compact_public_key: &LweCompactPublicKey, + output: &mut LweCompactCiphertextList, + encoded: &PlaintextList, + mask_noise_distribution: MaskDistribution, + body_noise_distribution: NoiseDistribution, + secret_generator: &mut SecretRandomGenerator, + encryption_generator: &mut EncryptionRandomGenerator, +) -> CompactPublicKeyRandomVectors +where + Scalar: Encryptable + RandomGenerable, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, + MaskDistribution: Distribution, + NoiseDistribution: Distribution, + SecretGen: ByteRandomGenerator, + EncryptionGen: ByteRandomGenerator, { assert!( output.lwe_size().to_lwe_dimension() == lwe_compact_public_key.lwe_dimension(), @@ -2190,22 +2761,21 @@ pub fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key< "This operation only supports native moduli" ); - let (mut output_mask_list, mut output_body_list) = output.get_mut_mask_and_body_list(); let (pk_mask, pk_body) = lwe_compact_public_key.get_mask_and_body(); - - let lwe_mask_count = output_mask_list.lwe_mask_count(); - let lwe_dimension = output_mask_list.lwe_dimension(); + let (mut output_mask_list, mut output_body_list) = output.get_mut_mask_and_body_list(); let mut binary_random_vector = vec![Scalar::ZERO; output_mask_list.lwe_mask_list_size()]; secret_generator.fill_slice_with_random_uniform_binary(&mut binary_random_vector); - let max_ciphertext_per_bin = lwe_dimension.0; + let mut mask_noise = vec![Scalar::ZERO; output_mask_list.lwe_mask_list_size()]; + encryption_generator + .fill_slice_with_random_noise_from_distribution(&mut mask_noise, mask_noise_distribution); - let gen_iter = encryption_generator - .par_fork_lwe_compact_ciphertext_list_to_bin::(lwe_mask_count, lwe_dimension) - .expect("Failed to split generator into lwe compact ciphertext bins"); + let mut body_noise = vec![Scalar::ZERO; encoded.plaintext_count().0]; + encryption_generator + .fill_slice_with_random_noise_from_distribution(&mut body_noise, body_noise_distribution); - // Loop over the ciphertext "bins" + let max_ciphertext_per_bin = lwe_compact_public_key.lwe_dimension().0; output_mask_list .par_iter_mut() .zip( @@ -2213,19 +2783,23 @@ pub fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key< .par_chunks_mut(max_ciphertext_per_bin) .zip(encoded.par_chunks(max_ciphertext_per_bin)) .zip(binary_random_vector.par_chunks(max_ciphertext_per_bin)) - .zip(gen_iter), + .zip(mask_noise.as_slice().par_chunks(max_ciphertext_per_bin)) + .zip(body_noise.as_slice().par_chunks(max_ciphertext_per_bin)), ) .for_each( |( mut output_mask, ( - ((mut output_body_chunk, input_plaintext_chunk), binary_random_slice), - mut loop_generator, + ( + ((mut output_body_chunk, input_plaintext_chunk), binary_random_slice), + mask_noise, + ), + body_noise, ), )| { - // output_body_chunk may not be able to fit the full convolution result so we create - // a temp buffer to compute the full convolution - let mut pk_body_convolved = vec![Scalar::ZERO; lwe_dimension.0]; + // output_body_chunk may not be able to fit the full convolution result so we + // create a temp buffer to compute the full convolution + let mut pk_body_convolved = vec![Scalar::ZERO; max_ciphertext_per_bin]; rayon::join( || { @@ -2245,17 +2819,11 @@ pub fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key< }, ); - // Noise from Chi_1 for the mask part of the encryption - loop_generator - .unsigned_integer_slice_wrapping_add_random_noise_from_distribution_assign( - output_mask.as_mut(), - mask_noise_distribution, - ); + slice_wrapping_add_assign(output_mask.as_mut(), mask_noise); - // Fill the body chunk afterwards manually as it most likely will be smaller than + // Fill the body chunk afterward manually as it most likely will be smaller than // the full convolution result. rev(b convolved r) + Delta * m + e2 // taking noise from Chi_2 for the body part of the encryption - // The reverse is to make the first element product match the single ciphertext case output_body_chunk .iter_mut() .zip( @@ -2264,16 +2832,396 @@ pub fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key< .rev() .zip(input_plaintext_chunk.iter()), ) - .for_each(|(dst, (&src, plaintext))| { - *dst.data = src - .wrapping_add( - loop_generator - .random_noise_from_distribution(body_noise_distribution), - ) - .wrapping_add(*plaintext.0); + .zip(body_noise) + .for_each(|((dst, (&src, plaintext)), body_noise)| { + *dst.data = src.wrapping_add(*body_noise).wrapping_add(*plaintext.0); }); }, ); + CompactPublicKeyRandomVectors { + binary_random_vector, + mask_noise, + body_noise, + } +} + +/// Parallel variant of [`encrypt_lwe_compact_ciphertext_list_with_compact_public_key`]. Encrypt an +/// input plaintext list in an output [`LWE compact ciphertext list`](`LweCompactCiphertextList`) +/// using an [`LWE compact public key`](`LweCompactPublicKey`). The expanded ciphertext list can be +/// decrypted using the [`LWE secret key`](`LweSecretKey`) that was used to generate the public key. +/// +/// # Example +/// +/// ```rust +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for LweCiphertext creation +/// let lwe_dimension = LweDimension(2048); +/// let lwe_ciphertext_count = LweCiphertextCount(lwe_dimension.0 * 4); +/// let glwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator); +/// +/// let lwe_compact_public_key = allocate_and_generate_new_lwe_compact_public_key( +/// &lwe_secret_key, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// let mut input_plaintext_list = PlaintextList::new(0u64, PlaintextCount(lwe_ciphertext_count.0)); +/// input_plaintext_list +/// .iter_mut() +/// .enumerate() +/// .for_each(|(idx, x)| { +/// *x.0 = (idx as u64 % 16) << 60; +/// }); +/// +/// // Create a new LweCompactCiphertextList +/// let mut output_compact_ct_list = LweCompactCiphertextList::new( +/// 0u64, +/// lwe_dimension.to_lwe_size(), +/// lwe_ciphertext_count, +/// ciphertext_modulus, +/// ); +/// +/// par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key( +/// &lwe_compact_public_key, +/// &mut output_compact_ct_list, +/// &input_plaintext_list, +/// glwe_noise_distribution, +/// glwe_noise_distribution, +/// &mut secret_generator, +/// &mut encryption_generator, +/// ); +/// +/// let mut output_plaintext_list = input_plaintext_list.clone(); +/// output_plaintext_list.as_mut().fill(0u64); +/// +/// let lwe_ciphertext_list = output_compact_ct_list.par_expand_into_lwe_ciphertext_list(); +/// +/// decrypt_lwe_ciphertext_list( +/// &lwe_secret_key, +/// &lwe_ciphertext_list, +/// &mut output_plaintext_list, +/// ); +/// +/// let signed_decomposer = +/// SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// +/// // Round the plaintexts +/// output_plaintext_list +/// .iter_mut() +/// .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0)); +/// +/// // Check we recovered the original messages +/// assert_eq!(input_plaintext_list, output_plaintext_list); +/// ``` +pub fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key< + Scalar, + MaskDistribution, + NoiseDistribution, + KeyCont, + InputCont, + OutputCont, + SecretGen, + EncryptionGen, +>( + lwe_compact_public_key: &LweCompactPublicKey, + output: &mut LweCompactCiphertextList, + encoded: &PlaintextList, + mask_noise_distribution: MaskDistribution, + body_noise_distribution: NoiseDistribution, + secret_generator: &mut SecretRandomGenerator, + encryption_generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable + + RandomGenerable + + Sync + + Send, + MaskDistribution: Distribution + Sync, + NoiseDistribution: Distribution + Sync, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, + SecretGen: ByteRandomGenerator, + EncryptionGen: ParallelByteRandomGenerator, +{ + let _ = par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key_impl( + lwe_compact_public_key, + output, + encoded, + mask_noise_distribution, + body_noise_distribution, + secret_generator, + encryption_generator, + ); +} + +/// Parallel variant of [`encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key`]. +/// Encrypt and generates a zero-knowledge proof of an input cleartext list in an output +/// [`LWE compact ciphertext list`](`LweCompactCiphertextList`) +/// using an [`LWE compact public key`](`LweCompactPublicKey`). +/// +/// The expanded ciphertext list can be decrypted using the [`LWE secret key`](`LweSecretKey`) that +/// was used to generate the public key. +/// +/// - The input cleartext list must have a length smaller or equal the maximum number of message +/// authorized by the CRS. +/// +/// - The noise distributions must be bounded +/// +/// +/// # Example +/// +/// ```rust +/// use tfhe::core_crypto::commons::math::random::RandomGenerator; +/// use tfhe::core_crypto::prelude::*; +/// use tfhe::zk::ZkComputeLoad; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for LweCiphertext creation +/// let lwe_dimension = LweDimension(2048); +/// let lwe_ciphertext_count = LweCiphertextCount(4); +/// let glwe_noise_distribution = TUniform::new(9); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// let delta_log = 60; +/// let delta = 1u64 << delta_log; +/// let plaintext_modulus = 1u64 << (64 - delta_log); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// let mut random_generator = RandomGenerator::::new(seeder.seed()); +/// +/// let crs = CompactPkeCrs::new( +/// lwe_dimension, +/// lwe_ciphertext_count.0, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// plaintext_modulus, +/// &mut random_generator, +/// ) +/// .unwrap(); +/// +/// // Create the LweSecretKey +/// let lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator); +/// +/// let lwe_compact_public_key = allocate_and_generate_new_lwe_compact_public_key( +/// &lwe_secret_key, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// let cleartexts = (0..lwe_ciphertext_count.0 as u64).collect::>(); +/// +/// // Create a new LweCompactCiphertextList +/// let mut output_compact_ct_list = LweCompactCiphertextList::new( +/// 0u64, +/// lwe_dimension.to_lwe_size(), +/// lwe_ciphertext_count, +/// ciphertext_modulus, +/// ); +/// +/// let proof = par_encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key( +/// &lwe_compact_public_key, +/// &mut output_compact_ct_list, +/// &cleartexts, +/// delta, +/// glwe_noise_distribution, +/// glwe_noise_distribution, +/// &mut secret_generator, +/// &mut encryption_generator, +/// &mut random_generator, +/// crs.public_params(), +/// ZkComputeLoad::Proof, +/// ) +/// .unwrap(); +/// +/// // verify the ciphertext list with the proof +/// assert!(verify_lwe_compact_ciphertext_list( +/// &output_compact_ct_list, +/// &lwe_compact_public_key, +/// &proof, +/// crs.public_params(), +/// ) +/// .is_valid()); +/// +/// let mut output_plaintext_list = +/// PlaintextList::new(0u64, PlaintextCount(lwe_ciphertext_count.0)); +/// +/// let lwe_ciphertext_list = output_compact_ct_list.expand_into_lwe_ciphertext_list(); +/// +/// decrypt_lwe_ciphertext_list( +/// &lwe_secret_key, +/// &lwe_ciphertext_list, +/// &mut output_plaintext_list, +/// ); +/// +/// let signed_decomposer = +/// SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// +/// // Round the plaintexts +/// output_plaintext_list +/// .iter_mut() +/// .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0) >> 60); +/// +/// // Check we recovered the original messages +/// assert_eq!(&cleartexts, output_plaintext_list.as_ref()); +/// ``` +#[cfg(feature = "zk-pok-experimental")] +#[allow(clippy::too_many_arguments)] +pub fn par_encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key< + Scalar, + KeyCont, + InputCont, + OutputCont, + MaskDistribution, + NoiseDistribution, + SecretGen, + EncryptionGen, + G, +>( + lwe_compact_public_key: &LweCompactPublicKey, + output: &mut LweCompactCiphertextList, + messages: &InputCont, + delta: Scalar, + mask_noise_distribution: MaskDistribution, + body_noise_distribution: NoiseDistribution, + secret_generator: &mut SecretRandomGenerator, + encryption_generator: &mut EncryptionRandomGenerator, + random_generator: &mut RandomGenerator, + public_params: &CompactPkePublicParams, + load: ZkComputeLoad, +) -> crate::Result +where + Scalar: Encryptable + + RandomGenerable + + CastFrom, + Scalar::Signed: CastFrom, + i64: CastFrom, + u64: CastFrom + CastInto, + MaskDistribution: BoundedDistribution, + NoiseDistribution: BoundedDistribution, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, + SecretGen: ByteRandomGenerator, + EncryptionGen: ByteRandomGenerator, + G: ByteRandomGenerator, +{ + verify_zero_knowledge_preconditions( + lwe_compact_public_key, + output.lwe_ciphertext_count(), + output.ciphertext_modulus(), + delta, + mask_noise_distribution, + body_noise_distribution, + public_params, + )?; + + let encoded = PlaintextList::from_container( + messages + .as_ref() + .iter() + .copied() + .map(|m| m * delta) + .collect::>(), + ); + + let CompactPublicKeyRandomVectors { + binary_random_vector, + mask_noise, + body_noise, + } = encrypt_lwe_compact_ciphertext_list_with_compact_public_key_impl( + lwe_compact_public_key, + output, + &encoded, + mask_noise_distribution, + body_noise_distribution, + secret_generator, + encryption_generator, + ); + + let (c1, c2) = output.get_mask_and_body_list(); + + let (public_commit, private_commit) = commit( + lwe_compact_public_key + .get_mask() + .as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + lwe_compact_public_key + .get_body() + .as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + c1.as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + c2.as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + binary_random_vector + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + mask_noise + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + messages + .as_ref() + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + body_noise + .iter() + .copied() + .map(CastFrom::cast_from) + .collect::>(), + public_params, + random_generator, + ); + + Ok(prove( + (public_params, &public_commit), + &private_commit, + load, + random_generator, + )) } #[cfg(test)] diff --git a/tfhe/src/core_crypto/algorithms/lwe_zero_knowledge_verification.rs b/tfhe/src/core_crypto/algorithms/lwe_zero_knowledge_verification.rs new file mode 100644 index 0000000000..5e3ecd1dd5 --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/lwe_zero_knowledge_verification.rs @@ -0,0 +1,102 @@ +use crate::core_crypto::entities::{LweCompactCiphertextList, LweCompactPublicKey}; +use crate::core_crypto::prelude::{CastFrom, Container, LweCiphertext, UnsignedInteger}; +use crate::zk::{CompactPkeProof, CompactPkePublicParams, ZkVerificationOutCome}; +use tfhe_zk_pok::proofs::pke::{verify, PublicCommit}; + +/// Verifies with the given proof that a [`LweCompactCiphertextList`](LweCompactCiphertextList) +/// is valid. +pub fn verify_lwe_compact_ciphertext_list( + lwe_compact_list: &LweCompactCiphertextList, + compact_public_key: &LweCompactPublicKey, + proof: &CompactPkeProof, + public_params: &CompactPkePublicParams, +) -> ZkVerificationOutCome +where + Scalar: UnsignedInteger, + i64: CastFrom, + ListCont: Container, + KeyCont: Container, +{ + if Scalar::BITS > 64 { + return ZkVerificationOutCome::Invalid; + } + let public_commit = PublicCommit::new( + compact_public_key + .get_mask() + .as_ref() + .iter() + .copied() + .map(|x| i64::cast_from(x)) + .collect(), + compact_public_key + .get_body() + .as_ref() + .iter() + .copied() + .map(|x| i64::cast_from(x)) + .collect(), + lwe_compact_list + .get_mask_list() + .as_ref() + .iter() + .copied() + .map(|x| i64::cast_from(x)) + .collect(), + lwe_compact_list + .get_body_list() + .as_ref() + .iter() + .copied() + .map(|x| i64::cast_from(x)) + .collect(), + ); + match verify(proof, (public_params, &public_commit)) { + Ok(_) => ZkVerificationOutCome::Valid, + Err(_) => ZkVerificationOutCome::Invalid, + } +} + +pub fn verify_lwe_ciphertext( + lwe_ciphertext: &LweCiphertext, + compact_public_key: &LweCompactPublicKey, + proof: &CompactPkeProof, + public_params: &CompactPkePublicParams, +) -> ZkVerificationOutCome +where + Scalar: UnsignedInteger, + i64: CastFrom, + Cont: Container, + KeyCont: Container, +{ + if Scalar::BITS > 64 { + return ZkVerificationOutCome::Invalid; + } + let public_commit = PublicCommit::new( + compact_public_key + .get_mask() + .as_ref() + .iter() + .copied() + .map(|x| i64::cast_from(x)) + .collect(), + compact_public_key + .get_body() + .as_ref() + .iter() + .copied() + .map(|x| i64::cast_from(x)) + .collect(), + lwe_ciphertext + .get_mask() + .as_ref() + .iter() + .copied() + .map(|x| i64::cast_from(x)) + .collect(), + vec![i64::cast_from(*lwe_ciphertext.get_body().data); 1], + ); + match verify(proof, (public_params, &public_commit)) { + Ok(_) => ZkVerificationOutCome::Valid, + Err(_) => ZkVerificationOutCome::Invalid, + } +} diff --git a/tfhe/src/core_crypto/algorithms/mod.rs b/tfhe/src/core_crypto/algorithms/mod.rs index 1d0436e4f0..6bed3630a2 100644 --- a/tfhe/src/core_crypto/algorithms/mod.rs +++ b/tfhe/src/core_crypto/algorithms/mod.rs @@ -27,6 +27,8 @@ pub mod lwe_programmable_bootstrapping; pub mod lwe_public_key_generation; pub mod lwe_secret_key_generation; pub mod lwe_wopbs; +#[cfg(feature = "zk-pok-experimental")] +pub mod lwe_zero_knowledge_verification; pub mod misc; pub mod polynomial_algorithms; pub mod seeded_ggsw_ciphertext_decompression; @@ -73,6 +75,8 @@ pub use lwe_programmable_bootstrapping::*; pub use lwe_public_key_generation::*; pub use lwe_secret_key_generation::*; pub use lwe_wopbs::*; +#[cfg(feature = "zk-pok-experimental")] +pub use lwe_zero_knowledge_verification::*; pub use seeded_ggsw_ciphertext_decompression::*; pub use seeded_ggsw_ciphertext_list_decompression::*; pub use seeded_glwe_ciphertext_decompression::*; diff --git a/tfhe/src/core_crypto/algorithms/test/lwe_encryption.rs b/tfhe/src/core_crypto/algorithms/test/lwe_encryption.rs index 715ca80a4e..4ee7f140d3 100644 --- a/tfhe/src/core_crypto/algorithms/test/lwe_encryption.rs +++ b/tfhe/src/core_crypto/algorithms/test/lwe_encryption.rs @@ -1,6 +1,10 @@ use super::*; use crate::core_crypto::commons::generators::DeterministicSeeder; +#[cfg(feature = "zk-pok-experimental")] +use crate::core_crypto::commons::math::random::RandomGenerator; use crate::core_crypto::commons::test_tools; +#[cfg(feature = "zk-pok-experimental")] +use rand::Rng; #[cfg(not(tarpaulin))] const NB_TESTS: usize = 10; @@ -182,7 +186,7 @@ fn lwe_encrypt_decrypt_custom_mod(params: ClassicTestPara assert!(check_encrypted_content_respects_mod( &ct, - ciphertext_modulus + ciphertext_modulus, )); let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct); @@ -237,7 +241,7 @@ fn lwe_allocate_encrypt_decrypt_custom_mod( assert!(check_encrypted_content_respects_mod( &ct, - ciphertext_modulus + ciphertext_modulus, )); let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct); @@ -290,7 +294,7 @@ fn lwe_trivial_encrypt_decrypt_custom_mod( assert!(check_encrypted_content_respects_mod( &ct, - ciphertext_modulus + ciphertext_modulus, )); let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct); @@ -340,7 +344,7 @@ fn lwe_allocate_trivial_encrypt_decrypt_custom_mod( assert!(check_encrypted_content_respects_mod( &ct, - ciphertext_modulus + ciphertext_modulus, )); let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct); @@ -403,7 +407,7 @@ fn lwe_list_encrypt_decrypt_custom_mod(params: ClassicTes assert!(check_encrypted_content_respects_mod( &list, - ciphertext_modulus + ciphertext_modulus, )); let mut plaintext_list = @@ -474,7 +478,7 @@ fn lwe_list_par_encrypt_decrypt_custom_mod( assert!(check_encrypted_content_respects_mod( &list, - ciphertext_modulus + ciphertext_modulus, )); let mut plaintext_list = @@ -548,7 +552,7 @@ fn lwe_public_encrypt_decrypt_custom_mod(params: ClassicT assert!(check_encrypted_content_respects_mod( &ct, - ciphertext_modulus + ciphertext_modulus, )); let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct); @@ -623,7 +627,7 @@ fn lwe_seeded_public_encrypt_decrypt_custom_mod( assert!(check_encrypted_content_respects_mod( &ct, - ciphertext_modulus + ciphertext_modulus, )); let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct); @@ -689,7 +693,7 @@ fn lwe_seeded_list_par_encrypt_decrypt_custom_mod(params: ClassicT assert!(check_encrypted_content_respects_mod( &std::slice::from_ref(seeded_ct.get_body().data), - ciphertext_modulus + ciphertext_modulus, )); let ct = seeded_ct.decompress_into_lwe_ciphertext(); assert!(check_encrypted_content_respects_mod( &ct, - ciphertext_modulus + ciphertext_modulus, )); let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct); @@ -826,14 +830,14 @@ fn lwe_seeded_allocate_encrypt_decrypt_custom_mod( assert!(check_encrypted_content_respects_mod( &std::slice::from_ref(seeded_ct.get_body().data), - ciphertext_modulus + ciphertext_modulus, )); let ct = seeded_ct.decompress_into_lwe_ciphertext(); assert!(check_encrypted_content_respects_mod( &ct, - ciphertext_modulus + ciphertext_modulus, )); let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct); @@ -971,7 +975,7 @@ fn lwe_compact_public_encrypt_decrypt_custom_mod( assert!(check_encrypted_content_respects_mod( &ct, - ciphertext_modulus + ciphertext_modulus, )); let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct); @@ -991,3 +995,324 @@ fn lwe_compact_public_encrypt_decrypt_custom_mod( create_parametrized_test!(lwe_compact_public_encrypt_decrypt_custom_mod { TEST_PARAMS_4_BITS_NATIVE_U64 }); + +#[cfg(feature = "zk-pok-experimental")] +fn lwe_compact_public_encrypt_prove_verify_decrypt_custom_mod( + params: ClassicTestParams, +) where + Scalar: UnsignedTorus + CastFrom, + Scalar::Signed: CastFrom, + i64: CastFrom, + u64: CastFrom + CastInto, + rand_distr::Standard: rand_distr::Distribution, +{ + let lwe_dimension = LweDimension(params.polynomial_size.0); + let glwe_noise_distribution = TUniform::new(9); + let ciphertext_modulus = params.ciphertext_modulus; + let message_modulus_log = params.message_modulus_log; + let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus); + + let mut rsc = TestResources::new(); + let mut random_generator = RandomGenerator::::new(rsc.seeder.seed()); + + let msg_modulus = Scalar::ONE.shl(message_modulus_log.0); + let mut msg = msg_modulus; + let delta: Scalar = encoding_with_padding / msg_modulus; + + let crs = CompactPkeCrs::new( + lwe_dimension, + 1, + glwe_noise_distribution, + ciphertext_modulus, + msg_modulus * Scalar::TWO, + &mut random_generator, + ) + .unwrap(); + + while msg != Scalar::ZERO { + msg = msg.wrapping_sub(Scalar::ONE); + for _ in 0..NB_TESTS { + let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key( + lwe_dimension, + &mut rsc.secret_random_generator, + ); + + let pk = allocate_and_generate_new_lwe_compact_public_key( + &lwe_sk, + glwe_noise_distribution, + ciphertext_modulus, + &mut rsc.encryption_random_generator, + ); + + let mut ct = LweCiphertext::new( + Scalar::ZERO, + lwe_dimension.to_lwe_size(), + ciphertext_modulus, + ); + + let proof = encrypt_and_prove_lwe_ciphertext_with_compact_public_key( + &pk, + &mut ct, + Cleartext(msg), + delta, + glwe_noise_distribution, + glwe_noise_distribution, + &mut rsc.secret_random_generator, + &mut rsc.encryption_random_generator, + &mut random_generator, + crs.public_params(), + ZkComputeLoad::Proof, + ) + .unwrap(); + + assert!(check_encrypted_content_respects_mod( + &ct, + ciphertext_modulus, + )); + + let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct); + + let decoded = round_decode(decrypted.0, delta) % msg_modulus; + + assert_eq!(msg, decoded); + + // Verify the proof + assert!(verify_lwe_ciphertext(&ct, &pk, &proof, crs.public_params()).is_valid()); + + // verify proof with invalid ciphertext + let index = random_generator.gen::() % ct.as_ref().len(); + let value_to_add = random_generator.gen::(); + ct.as_mut()[index] = ct.as_mut()[index].wrapping_add(value_to_add); + assert!(verify_lwe_ciphertext(&ct, &pk, &proof, crs.public_params()).is_invalid()); + } + + // In coverage, we break after one while loop iteration, changing message values does not + // yield higher coverage + #[cfg(tarpaulin)] + break; + } +} + +#[cfg(feature = "zk-pok-experimental")] +create_parametrized_test!(lwe_compact_public_encrypt_prove_verify_decrypt_custom_mod { + TEST_PARAMS_4_BITS_NATIVE_U64 +}); + +#[cfg(feature = "zk-pok-experimental")] +#[test] +fn test_par_compact_lwe_list_public_key_encryption_and_proof() { + use rand::Rng; + + let lwe_dimension = LweDimension(2048); + let glwe_noise_distribution = TUniform::new(9); + let ciphertext_modulus = CiphertextModulus::new_native(); + + let delta_log = 59; + let delta = 1u64 << delta_log; + let message_modulus = 1u64 << (64 - (delta_log + 1)); + let plaintext_modulus = 1u64 << (64 - delta_log); + let mut thread_rng = rand::thread_rng(); + + let max_num_body = 512; + let crs = CompactPkeCrs::new( + lwe_dimension, + max_num_body, + glwe_noise_distribution, + ciphertext_modulus, + plaintext_modulus, + &mut thread_rng, + ) + .unwrap(); + + for _ in 0..4 { + let ct_count = thread_rng.gen_range(1..=max_num_body); + let lwe_ciphertext_count = LweCiphertextCount(ct_count); + + println!("{lwe_dimension:?} {ct_count:?}"); + + let seed = test_tools::random_seed(); + let cleartexts = (0..ct_count) + .map(|_| thread_rng.gen::() % message_modulus) + .collect::>(); + + let par_lwe_ct_list = { + let mut deterministic_seeder = + DeterministicSeeder::::new(seed); + let mut random_generator = + RandomGenerator::::new(deterministic_seeder.seed()); + let mut secret_random_generator = + SecretRandomGenerator::::new(deterministic_seeder.seed()); + let mut encryption_random_generator = + EncryptionRandomGenerator::::new( + deterministic_seeder.seed(), + &mut deterministic_seeder, + ); + + let lwe_sk = + LweSecretKey::generate_new_binary(lwe_dimension, &mut secret_random_generator); + + let mut compact_lwe_pk = + LweCompactPublicKey::new(0u64, lwe_dimension, ciphertext_modulus); + + generate_lwe_compact_public_key( + &lwe_sk, + &mut compact_lwe_pk, + glwe_noise_distribution, + &mut encryption_random_generator, + ); + + let mut output_compact_ct_list = LweCompactCiphertextList::new( + 0u64, + lwe_dimension.to_lwe_size(), + lwe_ciphertext_count, + ciphertext_modulus, + ); + + let proof = par_encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key( + &compact_lwe_pk, + &mut output_compact_ct_list, + &cleartexts, + delta, + glwe_noise_distribution, + glwe_noise_distribution, + &mut secret_random_generator, + &mut encryption_random_generator, + &mut random_generator, + crs.public_params(), + ZkComputeLoad::Proof, + ) + .unwrap(); + + assert!(verify_lwe_compact_ciphertext_list( + &output_compact_ct_list, + &compact_lwe_pk, + &proof, + crs.public_params() + ) + .is_valid()); + + let mut output_plaintext_list = PlaintextList::new(0u64, PlaintextCount(ct_count)); + + let lwe_ciphertext_list = output_compact_ct_list + .clone() + .par_expand_into_lwe_ciphertext_list(); + + decrypt_lwe_ciphertext_list(&lwe_sk, &lwe_ciphertext_list, &mut output_plaintext_list); + + let signed_decomposer = + SignedDecomposer::new(DecompositionBaseLog(5), DecompositionLevelCount(1)); + + output_plaintext_list + .iter_mut() + .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0) >> delta_log); + + assert_eq!(cleartexts.as_slice(), output_plaintext_list.as_ref()); + + // verify proof with invalid ciphertext + let index = random_generator.gen::() % output_compact_ct_list.as_ref().len(); + let value_to_add = random_generator.gen(); + output_compact_ct_list.as_mut()[index] = + output_compact_ct_list.as_mut()[index].wrapping_add(value_to_add); + assert!(verify_lwe_compact_ciphertext_list( + &output_compact_ct_list, + &compact_lwe_pk, + &proof, + crs.public_params() + ) + .is_invalid()); + + lwe_ciphertext_list + }; + + let ser_lwe_ct_list = { + let mut deterministic_seeder = + DeterministicSeeder::::new(seed); + let mut random_generator = + RandomGenerator::::new(deterministic_seeder.seed()); + let mut secret_random_generator = + SecretRandomGenerator::::new(deterministic_seeder.seed()); + let mut encryption_random_generator = + EncryptionRandomGenerator::::new( + deterministic_seeder.seed(), + &mut deterministic_seeder, + ); + + let lwe_sk = + LweSecretKey::generate_new_binary(lwe_dimension, &mut secret_random_generator); + + let mut compact_lwe_pk = + LweCompactPublicKey::new(0u64, lwe_dimension, ciphertext_modulus); + + generate_lwe_compact_public_key( + &lwe_sk, + &mut compact_lwe_pk, + glwe_noise_distribution, + &mut encryption_random_generator, + ); + + let mut output_compact_ct_list = LweCompactCiphertextList::new( + 0u64, + lwe_dimension.to_lwe_size(), + lwe_ciphertext_count, + ciphertext_modulus, + ); + + let proof = par_encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key( + &compact_lwe_pk, + &mut output_compact_ct_list, + &cleartexts, + delta, + glwe_noise_distribution, + glwe_noise_distribution, + &mut secret_random_generator, + &mut encryption_random_generator, + &mut random_generator, + crs.public_params(), + ZkComputeLoad::Proof, + ) + .unwrap(); + + assert!(verify_lwe_compact_ciphertext_list( + &output_compact_ct_list, + &compact_lwe_pk, + &proof, + crs.public_params() + ) + .is_valid()); + + let mut output_plaintext_list = PlaintextList::new(0u64, PlaintextCount(ct_count)); + + let lwe_ciphertext_list = output_compact_ct_list + .clone() + .expand_into_lwe_ciphertext_list(); + + decrypt_lwe_ciphertext_list(&lwe_sk, &lwe_ciphertext_list, &mut output_plaintext_list); + + let signed_decomposer = + SignedDecomposer::new(DecompositionBaseLog(5), DecompositionLevelCount(1)); + + output_plaintext_list + .iter_mut() + .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0) >> delta_log); + + assert_eq!(cleartexts.as_slice(), output_plaintext_list.as_ref()); + + // verify proof with invalid ciphertext + let index = random_generator.gen::() % output_compact_ct_list.as_ref().len(); + let value_to_add = random_generator.gen(); + output_compact_ct_list.as_mut()[index] = + output_compact_ct_list.as_mut()[index].wrapping_add(value_to_add); + assert!(verify_lwe_compact_ciphertext_list( + &output_compact_ct_list, + &compact_lwe_pk, + &proof, + crs.public_params() + ) + .is_invalid()); + + lwe_ciphertext_list + }; + + assert_eq!(ser_lwe_ct_list, par_lwe_ct_list); + } +} diff --git a/tfhe/src/core_crypto/commons/math/random/gaussian.rs b/tfhe/src/core_crypto/commons/math/random/gaussian.rs index 29c9c2b1d1..9094e0adbc 100644 --- a/tfhe/src/core_crypto/commons/math/random/gaussian.rs +++ b/tfhe/src/core_crypto/commons/math/random/gaussian.rs @@ -1,6 +1,5 @@ use super::*; use crate::core_crypto::commons::math::torus::FromTorus; -use crate::core_crypto::commons::numeric::{CastInto, Numeric}; use serde::{Deserialize, Serialize}; // Clippy false positive, does not repro with smaller code diff --git a/tfhe/src/core_crypto/commons/math/random/generator.rs b/tfhe/src/core_crypto/commons/math/random/generator.rs index bee97c39de..0688216c15 100644 --- a/tfhe/src/core_crypto/commons/math/random/generator.rs +++ b/tfhe/src/core_crypto/commons/math/random/generator.rs @@ -796,3 +796,28 @@ impl RandomGenerator { .map(|iter| iter.map(Self)) } } + +impl rand_core::RngCore for RandomGenerator { + fn next_u32(&mut self) -> u32 { + >::generate_one(self, Uniform) + } + + fn next_u64(&mut self) -> u64 { + >::generate_one(self, Uniform) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + dest.iter_mut().for_each(|b| *b = self.generate_next()); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + if let Some(limit) = self.remaining_bytes() { + if limit < dest.len() { + return Err(rand_core::Error::new(format!("The random generator is bounded and cannot fill the slice {} bytes requested, {limit} possible", dest.len())) + ); + } + } + self.fill_bytes(dest); + Ok(()) + } +} diff --git a/tfhe/src/core_crypto/commons/math/random/mod.rs b/tfhe/src/core_crypto/commons/math/random/mod.rs index 342d729783..5fa9f7857d 100644 --- a/tfhe/src/core_crypto/commons/math/random/mod.rs +++ b/tfhe/src/core_crypto/commons/math/random/mod.rs @@ -15,7 +15,9 @@ //! [`RandomGenerator`] instead. use crate::core_crypto::commons::dispersion::{DispersionParameter, StandardDev, Variance}; use crate::core_crypto::commons::numeric::{FloatingPoint, UnsignedInteger}; +use std::ops::Bound; +use crate::core_crypto::prelude::{CastInto, Numeric}; /// Convenience alias for the most efficient CSPRNG implementation available. pub use activated_random_generator::ActivatedRandomGenerator; pub use gaussian::*; @@ -102,6 +104,82 @@ impl Distribution for UniformTernary {} impl Distribution for Gaussian {} impl Distribution for TUniform {} +pub trait BoundedDistribution: Distribution { + fn low_bound(&self) -> Bound; + fn high_bound(&self) -> Bound; + + fn contains(self, value: T) -> bool + where + T: Numeric, + { + { + match self.low_bound() { + Bound::Included(inclusive_low) => { + if value < inclusive_low { + return false; + } + } + Bound::Excluded(exclusive_low) => { + if value <= exclusive_low { + return false; + } + } + Bound::Unbounded => {} + } + } + + { + match self.high_bound() { + Bound::Included(inclusive_high) => { + if value > inclusive_high { + return false; + } + } + Bound::Excluded(exclusive_high) => { + if value >= exclusive_high { + return false; + } + } + Bound::Unbounded => {} + } + } + + true + } +} + +impl BoundedDistribution for TUniform +where + T: UnsignedInteger, +{ + fn low_bound(&self) -> Bound { + Bound::Included(self.min_value_inclusive()) + } + + fn high_bound(&self) -> Bound { + Bound::Included(self.max_value_inclusive()) + } +} + +impl BoundedDistribution for DynamicDistribution +where + T: UnsignedInteger, +{ + fn low_bound(&self) -> Bound { + match self { + Self::Gaussian(_) => Bound::Unbounded, + Self::TUniform(tu) => tu.low_bound(), + } + } + + fn high_bound(&self) -> Bound { + match self { + Self::Gaussian(_) => Bound::Unbounded, + Self::TUniform(tu) => tu.high_bound(), + } + } +} + #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] pub enum DynamicDistribution { Gaussian(Gaussian), diff --git a/tfhe/src/core_crypto/commons/math/random/t_uniform.rs b/tfhe/src/core_crypto/commons/math/random/t_uniform.rs index 14eeff1edf..3b3e318ef1 100644 --- a/tfhe/src/core_crypto/commons/math/random/t_uniform.rs +++ b/tfhe/src/core_crypto/commons/math/random/t_uniform.rs @@ -1,5 +1,4 @@ use super::*; -use crate::core_crypto::commons::numeric::Numeric; use serde::{Deserialize, Serialize}; /// The distribution $TUniform(1, -2^b, 2^b)$ is defined as follows, any value in the interval diff --git a/tfhe/src/core_crypto/commons/numeric/signed.rs b/tfhe/src/core_crypto/commons/numeric/signed.rs index c6f6f94b8c..50e7848cec 100644 --- a/tfhe/src/core_crypto/commons/numeric/signed.rs +++ b/tfhe/src/core_crypto/commons/numeric/signed.rs @@ -41,6 +41,10 @@ pub trait SignedInteger: /// Return a bit representation of the integer, where blocks of length `block_length` are /// separated by whitespaces to increase the readability. fn to_bits_string(&self, block_length: usize) -> String; + + /// Return the absoluted balue + #[must_use] + fn wrapping_abs(self) -> Self; } macro_rules! implement { @@ -77,6 +81,11 @@ macro_rules! implement { } strn } + + #[inline] + fn wrapping_abs(self) -> Self { + self.wrapping_abs() + } } }; } diff --git a/tfhe/src/core_crypto/commons/numeric/unsigned.rs b/tfhe/src/core_crypto/commons/numeric/unsigned.rs index cdbf843b4a..ed730257e6 100644 --- a/tfhe/src/core_crypto/commons/numeric/unsigned.rs +++ b/tfhe/src/core_crypto/commons/numeric/unsigned.rs @@ -87,6 +87,8 @@ pub trait UnsignedInteger: #[must_use] fn is_power_of_two(self) -> bool; #[must_use] + fn next_power_of_two(self) -> Self; + #[must_use] fn ilog2(self) -> u32; #[must_use] fn ceil_ilog2(self) -> u32 { @@ -240,6 +242,10 @@ macro_rules! implement { self.is_power_of_two() } #[inline] + fn next_power_of_two(self) -> Self { + self.next_power_of_two() + } + #[inline] fn ilog2(self) -> u32 { self.ilog2() } diff --git a/tfhe/src/core_crypto/entities/mod.rs b/tfhe/src/core_crypto/entities/mod.rs index 9f846d73cc..1bee3d0992 100644 --- a/tfhe/src/core_crypto/entities/mod.rs +++ b/tfhe/src/core_crypto/entities/mod.rs @@ -53,6 +53,8 @@ pub use crate::core_crypto::fft_impl::fft64::crypto::ggsw::{ FourierGgswCiphertext, FourierGgswCiphertextList, FourierGgswLevelMatrix, FourierGgswLevelRow, }; pub use crate::core_crypto::fft_impl::fft64::math::polynomial::FourierPolynomial; +#[cfg(feature = "zk-pok-experimental")] +pub use crate::zk::*; pub use cleartext::*; pub use ggsw_ciphertext::*; pub use ggsw_ciphertext_list::*; diff --git a/tfhe/src/error.rs b/tfhe/src/error.rs new file mode 100644 index 0000000000..b2559f0f31 --- /dev/null +++ b/tfhe/src/error.rs @@ -0,0 +1,59 @@ +use std::fmt::{Debug, Display, Formatter}; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum ErrorKind { + Message(String), + /// The zero knowledge proof and the content it is supposed to prove + /// failed to correctly prove + #[cfg(feature = "zk-pok-experimental")] + InvalidZkProof, +} + +#[derive(Debug, Clone)] +pub struct Error { + kind: ErrorKind, +} + +impl Error { + pub(crate) fn new(message: String) -> Self { + Self::from(ErrorKind::Message(message)) + } + + pub fn kind(&self) -> &ErrorKind { + &self.kind + } +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.kind() { + ErrorKind::Message(msg) => { + write!(f, "{msg}") + } + #[cfg(feature = "zk-pok-experimental")] + ErrorKind::InvalidZkProof => { + write!(f, "The zero knowledge proof and the content it is supposed to prove were not valid") + } + } + } +} + +impl From for Error { + fn from(kind: ErrorKind) -> Self { + Self { kind } + } +} + +impl<'a> From<&'a str> for Error { + fn from(message: &'a str) -> Self { + Self::new(message.to_string()) + } +} + +impl From for Error { + fn from(message: String) -> Self { + Self::new(message) + } +} + +impl std::error::Error for Error {} diff --git a/tfhe/src/high_level_api/booleans/compact.rs b/tfhe/src/high_level_api/booleans/compact.rs index 573c2d5e80..9dffe880e9 100644 --- a/tfhe/src/high_level_api/booleans/compact.rs +++ b/tfhe/src/high_level_api/booleans/compact.rs @@ -36,7 +36,7 @@ use crate::{CompactPublicKey, FheBoolConformanceParams, ServerKey}; #[cfg_attr(all(doc, not(doctest)), doc(cfg(feature = "integer")))] #[derive(Clone, serde::Deserialize, serde::Serialize)] pub struct CompactFheBool { - list: CompactCiphertextList, + pub(in crate::high_level_api) list: CompactCiphertextList, } impl CompactFheBool { @@ -55,7 +55,7 @@ impl CompactFheBool { } impl FheTryEncrypt for CompactFheBool { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: bool, key: &CompactPublicKey) -> Result { let mut ciphertext = key.key.try_encrypt_compact(&[u8::from(value)], 1); @@ -145,7 +145,7 @@ impl CompactFheBoolList { } impl<'a> FheTryEncrypt<&'a [bool], CompactPublicKey> for CompactFheBoolList { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; /// Encrypts a slice of bool /// diff --git a/tfhe/src/high_level_api/booleans/compressed.rs b/tfhe/src/high_level_api/booleans/compressed.rs index a1ae79c5a2..2bc0861d0a 100644 --- a/tfhe/src/high_level_api/booleans/compressed.rs +++ b/tfhe/src/high_level_api/booleans/compressed.rs @@ -58,7 +58,7 @@ impl CompressedFheBool { } impl FheTryEncrypt for CompressedFheBool { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; /// Creates a compressed encryption of a boolean value fn try_encrypt(value: bool, key: &ClientKey) -> Result { diff --git a/tfhe/src/high_level_api/booleans/encrypt.rs b/tfhe/src/high_level_api/booleans/encrypt.rs index ac348bf012..bab0b9bc6b 100644 --- a/tfhe/src/high_level_api/booleans/encrypt.rs +++ b/tfhe/src/high_level_api/booleans/encrypt.rs @@ -12,7 +12,7 @@ use crate::shortint::ciphertext::Degree; use crate::{ClientKey, CompactPublicKey, CompressedPublicKey, PublicKey}; impl FheTryEncrypt for FheBool { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: bool, key: &ClientKey) -> Result { let integer_client_key = &key.key.key; @@ -23,7 +23,7 @@ impl FheTryEncrypt for FheBool { } impl FheTryEncrypt for FheBool { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: bool, key: &CompactPublicKey) -> Result { let mut ciphertext = key.key.key.encrypt_radix(value as u8, 1); @@ -66,7 +66,7 @@ impl FheTrivialEncrypt for FheBool { } impl FheTryEncrypt for FheBool { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: bool, key: &CompressedPublicKey) -> Result { let key = &key.key; @@ -77,7 +77,7 @@ impl FheTryEncrypt for FheBool { } impl FheTryEncrypt for FheBool { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: bool, key: &PublicKey) -> Result { let key = &key.key; @@ -95,7 +95,7 @@ impl FheDecrypt for FheBool { } impl FheTryTrivialEncrypt for FheBool { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt_trivial(value: bool) -> Result { let ciphertext = global_state::with_internal_keys(|key| match key { diff --git a/tfhe/src/high_level_api/booleans/mod.rs b/tfhe/src/high_level_api/booleans/mod.rs index 5e6184fdaa..b95a0dc356 100644 --- a/tfhe/src/high_level_api/booleans/mod.rs +++ b/tfhe/src/high_level_api/booleans/mod.rs @@ -1,6 +1,8 @@ pub use base::{FheBool, FheBoolConformanceParams}; pub use compact::{CompactFheBool, CompactFheBoolList, CompactFheBoolListConformanceParams}; pub use compressed::CompressedFheBool; +#[cfg(feature = "zk-pok-experimental")] +pub use zk::{ProvenCompactFheBool, ProvenCompactFheBoolList}; mod base; mod compact; @@ -9,3 +11,5 @@ mod encrypt; mod inner; #[cfg(test)] mod tests; +#[cfg(feature = "zk-pok-experimental")] +mod zk; diff --git a/tfhe/src/high_level_api/booleans/tests.rs b/tfhe/src/high_level_api/booleans/tests.rs index 60c300e13d..314e610054 100644 --- a/tfhe/src/high_level_api/booleans/tests.rs +++ b/tfhe/src/high_level_api/booleans/tests.rs @@ -825,7 +825,57 @@ mod cpu { let expanded_list = deserialized_list.expand(); for (fhe_uint, expected) in expanded_list.iter().zip(clears.into_iter()) { let decrypted: bool = fhe_uint.decrypt(&client_key); - assert_eq!(decrypted, expected); + assert_eq!(decrypted, expected) + } + } + + #[cfg(feature = "zk-pok-experimental")] + #[test] + fn test_fhe_bool_zk() { + use crate::core_crypto::prelude::DynamicDistribution; + use crate::zk::{CompactPkeCrs, ZkComputeLoad}; + + let mut params = crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + params.glwe_noise_distribution = DynamicDistribution::new_t_uniform(9); + + let config = ConfigBuilder::with_custom_parameters(params, None).build(); + let crs = CompactPkeCrs::from_config(config, 2).unwrap(); + let ck = ClientKey::generate(config); + let pk = CompactPublicKey::new(&ck); + + for msg in [true, false] { + let proven_compact_fhe_bool = crate::ProvenCompactFheBool::try_encrypt( + msg, + crs.public_params(), + &pk, + ZkComputeLoad::Proof, + ) + .unwrap(); + let fhe_bool = proven_compact_fhe_bool + .verify_and_expand(crs.public_params(), &pk) + .unwrap(); + let decrypted = fhe_bool.decrypt(&ck); + assert_eq!(decrypted, msg); + assert_degree_is_ok(&fhe_bool); + } + + let proven_compact_fhe_bool_list = crate::ProvenCompactFheBoolList::try_encrypt( + &[true, false], + crs.public_params(), + &pk, + ZkComputeLoad::Proof, + ) + .unwrap(); + let fhe_bools = proven_compact_fhe_bool_list + .verify_and_expand(crs.public_params(), &pk) + .unwrap(); + let decrypted = fhe_bools + .iter() + .map(|fb| fb.decrypt(&ck)) + .collect::>(); + assert_eq!(decrypted.as_slice(), &[true, false]); + for fhe_bool in fhe_bools { + assert_degree_is_ok(&fhe_bool); } } } diff --git a/tfhe/src/high_level_api/booleans/zk.rs b/tfhe/src/high_level_api/booleans/zk.rs new file mode 100644 index 0000000000..964ed7641f --- /dev/null +++ b/tfhe/src/high_level_api/booleans/zk.rs @@ -0,0 +1,134 @@ +use crate::integer::{BooleanBlock, ProvenCompactCiphertextList, RadixCiphertext}; +use crate::named::Named; +use crate::shortint::ciphertext::Degree; +use crate::zk::{CompactPkePublicParams, ZkComputeLoad, ZkVerificationOutCome}; +use crate::{CompactPublicKey, FheBool}; +use serde::{Deserialize, Serialize}; + +/// A `CompactFheBool` tied to a Zero-Knowledge proof +/// +/// The zero-knowledge proof allows to verify that the ciphertext is correctly +/// encrypted. +#[derive(Clone, Serialize, Deserialize)] +pub struct ProvenCompactFheBool { + inner: ProvenCompactCiphertextList, +} + +impl Named for ProvenCompactFheBool { + const NAME: &'static str = "high_level_api::ProvenCompactFheBool"; +} + +impl ProvenCompactFheBool { + /// Encrypts the message while also generating the zero-knowledge proof + pub fn try_encrypt( + value: bool, + public_params: &CompactPkePublicParams, + key: &CompactPublicKey, + load: ZkComputeLoad, + ) -> crate::Result { + let value = value as u8; + let inner = key.key.key.encrypt_and_prove_radix_compact( + &[value], + 1, /* num blocks */ + public_params, + load, + )?; + Ok(Self { inner }) + } + + /// Verifies the ciphertext and the proof + /// + /// If the proof and ciphertext are valid, it returns an `Ok` with + /// the underlying `FheBool`. + pub fn verify_and_expand( + self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> crate::Result { + let mut radix = self + .inner + .verify_and_expand_one::(public_params, &public_key.key.key)?; + assert_eq!(radix.blocks.len(), 1); + radix.blocks[0].degree = Degree::new(1); + Ok(FheBool::new(BooleanBlock::new_unchecked( + radix.blocks.pop().unwrap(), + ))) + } + + pub fn verify( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> ZkVerificationOutCome { + self.inner.verify(public_params, &public_key.key.key) + } +} + +/// A `CompactFheBoolList` tied to a Zero-Knowledge proof +/// +/// The zero-knowledge proof allows to verify that the ciphertext list is correctly +/// encrypted. +#[derive(Clone, Serialize, Deserialize)] +pub struct ProvenCompactFheBoolList { + inner: ProvenCompactCiphertextList, +} + +impl Named for ProvenCompactFheBoolList { + const NAME: &'static str = "high_level_api::ProvenCompactFheBoolList"; +} + +impl ProvenCompactFheBoolList { + /// Encrypts the message while also generating the zero-knowledge proof + pub fn try_encrypt( + values: &[bool], + public_params: &CompactPkePublicParams, + key: &CompactPublicKey, + load: ZkComputeLoad, + ) -> crate::Result { + let values = values.iter().copied().map(u8::from).collect::>(); + let inner = key.key.key.encrypt_and_prove_radix_compact( + &values, + 1, /* num_blocks */ + public_params, + load, + )?; + Ok(Self { inner }) + } + + pub fn len(&self) -> usize { + self.inner.ciphertext_count() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Verifies the ciphertext and the proof + /// + /// If the proof and ciphertext are valid, it returns an `Ok` with + /// the underlying `FheBool`s. + pub fn verify_and_expand( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> crate::Result> { + Ok(self + .inner + .verify_and_expand::(public_params, &public_key.key.key)? + .into_iter() + .map(|mut radix| { + assert_eq!(radix.blocks.len(), 1); + radix.blocks[0].degree = Degree::new(1); + FheBool::new(BooleanBlock::new_unchecked(radix.blocks.pop().unwrap())) + }) + .collect()) + } + + pub fn verify( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> ZkVerificationOutCome { + self.inner.verify(public_params, &public_key.key.key) + } +} diff --git a/tfhe/src/high_level_api/config.rs b/tfhe/src/high_level_api/config.rs index cca6cf56a6..cce4aec994 100644 --- a/tfhe/src/high_level_api/config.rs +++ b/tfhe/src/high_level_api/config.rs @@ -1,7 +1,7 @@ use crate::high_level_api::keys::IntegerConfig; /// The config type -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Copy, Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct Config { pub(crate) inner: IntegerConfig, } diff --git a/tfhe/src/high_level_api/errors.rs b/tfhe/src/high_level_api/errors.rs index 83a1ecb5bc..19c6406a40 100644 --- a/tfhe/src/high_level_api/errors.rs +++ b/tfhe/src/high_level_api/errors.rs @@ -38,43 +38,3 @@ impl Display for UninitializedServerKey { } impl std::error::Error for UninitializedServerKey {} - -/// Error when trying to create a short integer from a value that was too big to be represented -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub struct OutOfRangeError; - -impl Display for OutOfRangeError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "Value is out of range") - } -} - -impl std::error::Error for OutOfRangeError {} - -#[non_exhaustive] -#[derive(Debug, Eq, PartialEq)] -pub enum Error { - OutOfRange, - UninitializedServerKey, -} - -impl From for Error { - fn from(_: OutOfRangeError) -> Self { - Self::OutOfRange - } -} - -impl Display for Error { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::OutOfRange => { - write!(f, "{OutOfRangeError}") - } - Self::UninitializedServerKey => { - write!(f, "{UninitializedServerKey}") - } - } - } -} - -impl std::error::Error for Error {} diff --git a/tfhe/src/high_level_api/integers/signed/compact.rs b/tfhe/src/high_level_api/integers/signed/compact.rs index b325ac2b5e..63264c683c 100644 --- a/tfhe/src/high_level_api/integers/signed/compact.rs +++ b/tfhe/src/high_level_api/integers/signed/compact.rs @@ -66,7 +66,7 @@ where T: crate::integer::block_decomposition::DecomposableInto, Id: FheIntId, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &CompactPublicKey) -> Result { let id = Id::default(); @@ -169,7 +169,7 @@ where T: crate::integer::block_decomposition::DecomposableInto, Id: FheIntId, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(values: &'a [T], key: &CompactPublicKey) -> Result { let id = Id::default(); diff --git a/tfhe/src/high_level_api/integers/signed/compressed.rs b/tfhe/src/high_level_api/integers/signed/compressed.rs index aaf3f1650d..1153bc4edd 100644 --- a/tfhe/src/high_level_api/integers/signed/compressed.rs +++ b/tfhe/src/high_level_api/integers/signed/compressed.rs @@ -88,7 +88,7 @@ where Id: FheIntId, T: DecomposableInto + SignedNumeric, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &ClientKey) -> Result { let integer_client_key = &key.key.key; diff --git a/tfhe/src/high_level_api/integers/signed/encrypt.rs b/tfhe/src/high_level_api/integers/signed/encrypt.rs index d6747ba670..8d60710d72 100644 --- a/tfhe/src/high_level_api/integers/signed/encrypt.rs +++ b/tfhe/src/high_level_api/integers/signed/encrypt.rs @@ -43,7 +43,7 @@ where Id: FheIntId, T: DecomposableInto + SignedNumeric, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &ClientKey) -> Result { let ciphertext = key @@ -59,7 +59,7 @@ where Id: FheIntId, T: DecomposableInto + SignedNumeric, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &PublicKey) -> Result { let ciphertext = key @@ -74,7 +74,7 @@ where Id: FheIntId, T: DecomposableInto + SignedNumeric, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &CompressedPublicKey) -> Result { let ciphertext = key @@ -89,7 +89,7 @@ where Id: FheIntId, T: DecomposableInto + SignedNumeric, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &CompactPublicKey) -> Result { let ciphertext = key @@ -105,7 +105,7 @@ where T: DecomposableInto, Id: FheIntId, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; /// Creates a trivial encryption of a signed integer. /// diff --git a/tfhe/src/high_level_api/integers/signed/mod.rs b/tfhe/src/high_level_api/integers/signed/mod.rs index cb0ffdeb69..cc74f2c651 100644 --- a/tfhe/src/high_level_api/integers/signed/mod.rs +++ b/tfhe/src/high_level_api/integers/signed/mod.rs @@ -9,6 +9,8 @@ mod scalar_ops; mod static_; #[cfg(test)] mod tests; +#[cfg(feature = "zk-pok-experimental")] +mod zk; pub use base::{FheInt, FheIntId}; pub use compact::{CompactFheInt, CompactFheIntList}; diff --git a/tfhe/src/high_level_api/integers/signed/static_.rs b/tfhe/src/high_level_api/integers/signed/static_.rs index 13f0edf779..1bb144d981 100644 --- a/tfhe/src/high_level_api/integers/signed/static_.rs +++ b/tfhe/src/high_level_api/integers/signed/static_.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "zk-pok-experimental")] +use super::zk::{ProvenCompactFheInt, ProvenCompactFheIntList}; use crate::high_level_api::integers::signed::base::{FheInt, FheIntConformanceParams, FheIntId}; use crate::high_level_api::integers::signed::compact::{ CompactFheInt, CompactFheIntList, CompactFheIntListConformanceParams, @@ -57,6 +59,13 @@ macro_rules! static_int_type { #[cfg_attr(all(doc, not(doctest)), cfg(feature = "integer"))] pub type [] = CompactFheIntListConformanceParams<[]>; + + // Zero-knowledge Stuff + #[cfg(feature = "zk-pok-experimental")] + pub type [] = ProvenCompactFheInt<[]>; + + #[cfg(feature = "zk-pok-experimental")] + pub type [] = ProvenCompactFheIntList<[]>; } }; } diff --git a/tfhe/src/high_level_api/integers/signed/tests.rs b/tfhe/src/high_level_api/integers/signed/tests.rs index a07355fe4a..fe564ce9a6 100644 --- a/tfhe/src/high_level_api/integers/signed/tests.rs +++ b/tfhe/src/high_level_api/integers/signed/tests.rs @@ -801,3 +801,50 @@ fn test_safe_deserialize_conformant_compact_fhe_int32_list() { )); assert!(deserialized_list.is_conformant(¶ms)); } + +#[cfg(feature = "zk-pok-experimental")] +#[test] +fn test_fhe_int_zk() { + use crate::core_crypto::prelude::DynamicDistribution; + use crate::zk::{CompactPkeCrs, ZkComputeLoad}; + + let mut params = crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + params.glwe_noise_distribution = DynamicDistribution::new_t_uniform(9); + + let config = ConfigBuilder::with_custom_parameters(params, None).build(); + let crs = CompactPkeCrs::from_config(config, 32).unwrap(); + let ck = ClientKey::generate(config); + let pk = CompactPublicKey::new(&ck); + + let msg = random::(); + + let proven_compact_fhe_uint = crate::ProvenCompactFheInt32::try_encrypt( + msg, + crs.public_params(), + &pk, + ZkComputeLoad::Proof, + ) + .unwrap(); + let fhe_uint = proven_compact_fhe_uint + .verify_and_expand(crs.public_params(), &pk) + .unwrap(); + let decrypted: i32 = fhe_uint.decrypt(&ck); + assert_eq!(decrypted, msg); + + let messages = (0..4).map(|_| random()).collect::>(); + let proven_compact_fhe_uint_list = crate::ProvenCompactFheInt32List::try_encrypt( + &messages, + crs.public_params(), + &pk, + ZkComputeLoad::Proof, + ) + .unwrap(); + let fhe_uints = proven_compact_fhe_uint_list + .verify_and_expand(crs.public_params(), &pk) + .unwrap(); + let decrypted = fhe_uints + .iter() + .map(|fb| fb.decrypt(&ck)) + .collect::>(); + assert_eq!(decrypted.as_slice(), &messages); +} diff --git a/tfhe/src/high_level_api/integers/signed/zk.rs b/tfhe/src/high_level_api/integers/signed/zk.rs new file mode 100644 index 0000000000..402fccac1b --- /dev/null +++ b/tfhe/src/high_level_api/integers/signed/zk.rs @@ -0,0 +1,140 @@ +use crate::core_crypto::commons::math::random::{Deserialize, Serialize}; +use crate::core_crypto::prelude::SignedNumeric; +use crate::high_level_api::integers::FheIntId; +use crate::integer::block_decomposition::DecomposableInto; +use crate::integer::{ProvenCompactCiphertextList, SignedRadixCiphertext}; +use crate::named::Named; +use crate::zk::{CompactPkePublicParams, ZkComputeLoad, ZkVerificationOutCome}; +use crate::{CompactPublicKey, FheInt}; + +/// A `CompactFheInt` tied to a Zero-Knowledge proof +/// +/// The zero-knowledge proof allows to verify that the ciphertext is correctly +/// encrypted. +#[derive(Clone, Serialize, Deserialize)] +pub struct ProvenCompactFheInt { + inner: ProvenCompactCiphertextList, + _id: Id, +} + +impl Named for ProvenCompactFheInt { + const NAME: &'static str = "high_level_api::ProvenCompactFheUintList"; +} + +impl ProvenCompactFheInt +where + Id: FheIntId, +{ + /// Encrypts the message while also generating the zero-knowledge proof + pub fn try_encrypt( + value: Clear, + public_params: &CompactPkePublicParams, + key: &CompactPublicKey, + load: ZkComputeLoad, + ) -> crate::Result + where + Clear: DecomposableInto + SignedNumeric, + { + let inner = key.key.key.encrypt_and_prove_radix_compact( + &[value], + Id::num_blocks(key.key.key.key.parameters.message_modulus()), + public_params, + load, + )?; + Ok(Self { + inner, + _id: Id::default(), + }) + } + + /// Verifies the ciphertext and the proof + /// + /// If the proof and ciphertext are valid, it returns an `Ok` with + /// the underlying `FheInt` + pub fn verify_and_expand( + self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> crate::Result> { + let expanded_inner = self + .inner + .verify_and_expand_one::(public_params, &public_key.key.key)?; + Ok(FheInt::new(expanded_inner)) + } + + pub fn verify( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> ZkVerificationOutCome { + self.inner.verify(public_params, &public_key.key.key) + } +} + +/// A `CompactFheIntList` tied to a Zero-Knowledge proof +/// +/// The zero-knowledge proof allows to verify that the ciphertext list is correctly +/// encrypted. +#[derive(Clone, Serialize, Deserialize)] +pub struct ProvenCompactFheIntList { + inner: ProvenCompactCiphertextList, + _id: Id, +} + +impl Named for ProvenCompactFheIntList { + const NAME: &'static str = "high_level_api::ProvenCompactFheIntList"; +} + +impl ProvenCompactFheIntList +where + Id: FheIntId, +{ + /// Encrypts the message while also generating the zero-knowledge proof + pub fn try_encrypt( + values: &[Clear], + public_params: &CompactPkePublicParams, + key: &CompactPublicKey, + load: ZkComputeLoad, + ) -> crate::Result + where + Clear: DecomposableInto + SignedNumeric, + { + let inner = key.key.key.encrypt_and_prove_radix_compact( + values, + Id::num_blocks(key.key.key.key.parameters.message_modulus()), + public_params, + load, + )?; + Ok(Self { + inner, + _id: Id::default(), + }) + } + + pub fn len(&self) -> usize { + self.inner.ciphertext_count() + } + + /// Verifies the ciphertext and the proof + /// + /// If the proof and ciphertext are valid, it returns an `Ok` with + /// the underlying `FheInt`s. + pub fn verify_and_expand( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> crate::Result>> { + let expanded_inners = self + .inner + .verify_and_expand::(public_params, &public_key.key.key)?; + Ok(expanded_inners.into_iter().map(FheInt::new).collect()) + } + + pub fn verify( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> ZkVerificationOutCome { + self.inner.verify(public_params, &public_key.key.key) + } +} diff --git a/tfhe/src/high_level_api/integers/unsigned/compact.rs b/tfhe/src/high_level_api/integers/unsigned/compact.rs index 35eff5fb47..31b5de508b 100644 --- a/tfhe/src/high_level_api/integers/unsigned/compact.rs +++ b/tfhe/src/high_level_api/integers/unsigned/compact.rs @@ -69,7 +69,7 @@ where T: crate::integer::block_decomposition::DecomposableInto, Id: FheUintId, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &CompactPublicKey) -> Result { let ciphertext = key @@ -175,7 +175,7 @@ where T: crate::integer::block_decomposition::DecomposableInto, Id: FheUintId, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(values: &'a [T], key: &CompactPublicKey) -> Result { let ciphertext = key diff --git a/tfhe/src/high_level_api/integers/unsigned/compressed.rs b/tfhe/src/high_level_api/integers/unsigned/compressed.rs index 551d298fa0..c872b600ca 100644 --- a/tfhe/src/high_level_api/integers/unsigned/compressed.rs +++ b/tfhe/src/high_level_api/integers/unsigned/compressed.rs @@ -91,7 +91,7 @@ where Id: FheUintId, T: DecomposableInto + UnsignedNumeric, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &ClientKey) -> Result { let inner = key diff --git a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs index 83b69f7347..bb3ef27e07 100644 --- a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs +++ b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs @@ -45,7 +45,7 @@ where Id: FheUintId, T: DecomposableInto + UnsignedNumeric, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &ClientKey) -> Result { let cpu_ciphertext = key @@ -65,7 +65,7 @@ where Id: FheUintId, T: DecomposableInto + UnsignedNumeric, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &PublicKey) -> Result { let cpu_ciphertext = key @@ -84,7 +84,7 @@ where Id: FheUintId, T: DecomposableInto + UnsignedNumeric, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &CompressedPublicKey) -> Result { let cpu_ciphertext = key @@ -102,7 +102,7 @@ where Id: FheUintId, T: DecomposableInto + UnsignedNumeric, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt(value: T, key: &CompactPublicKey) -> Result { let cpu_ciphertext = key @@ -121,7 +121,7 @@ where T: DecomposableInto + UnsignedNumeric, Id: FheUintId, { - type Error = crate::high_level_api::errors::Error; + type Error = crate::Error; fn try_encrypt_trivial(value: T) -> Result { global_state::with_internal_keys(|key| match key { diff --git a/tfhe/src/high_level_api/integers/unsigned/mod.rs b/tfhe/src/high_level_api/integers/unsigned/mod.rs index 0d70967a54..1f03d7fb71 100644 --- a/tfhe/src/high_level_api/integers/unsigned/mod.rs +++ b/tfhe/src/high_level_api/integers/unsigned/mod.rs @@ -23,3 +23,5 @@ mod overflowing_ops; mod scalar_ops; #[cfg(test)] mod tests; +#[cfg(feature = "zk-pok-experimental")] +mod zk; diff --git a/tfhe/src/high_level_api/integers/unsigned/static_.rs b/tfhe/src/high_level_api/integers/unsigned/static_.rs index b436e2a52a..59816e51f5 100644 --- a/tfhe/src/high_level_api/integers/unsigned/static_.rs +++ b/tfhe/src/high_level_api/integers/unsigned/static_.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "zk-pok-experimental")] +use super::zk::{ProvenCompactFheUint, ProvenCompactFheUintList}; use crate::high_level_api::integers::unsigned::base::{ FheUint, FheUintConformanceParams, FheUintId, }; @@ -56,6 +58,12 @@ macro_rules! static_int_type { #[cfg_attr(all(doc, not(doctest)), cfg(feature = "integer"))] pub type [] = CompactFheUintListConformanceParams<[]>; + + #[cfg(feature = "zk-pok-experimental")] + pub type [] = ProvenCompactFheUint<[]>; + + #[cfg(feature = "zk-pok-experimental")] + pub type [] = ProvenCompactFheUintList<[]>; } }; } diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs index de3dddaaeb..46da9237b0 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs @@ -10,7 +10,7 @@ use crate::{ CompressedPublicKey, Config, FheInt16, FheInt32, FheInt8, FheUint128, FheUint16, FheUint256, FheUint32, FheUint32ConformanceParams, }; -use rand::{random, Rng}; +use rand::prelude::*; fn setup_cpu(params: Option>) -> ClientKey { let config = params @@ -543,3 +543,49 @@ fn test_safe_deserialize_conformant_compact_fhe_uint32_list() { assert_eq!(decrypted, expected); } } + +#[cfg(feature = "zk-pok-experimental")] +#[test] +fn test_fhe_uint_zk() { + use crate::zk::{CompactPkeCrs, ZkComputeLoad}; + + let mut params = PARAM_MESSAGE_2_CARRY_2_KS_PBS; + params.glwe_noise_distribution = DynamicDistribution::new_t_uniform(9); + + let config = ConfigBuilder::with_custom_parameters(params, None).build(); + let crs = CompactPkeCrs::from_config(config, 32).unwrap(); + let ck = ClientKey::generate(config); + let pk = CompactPublicKey::new(&ck); + + let msg = random::(); + + let proven_compact_fhe_uint = crate::ProvenCompactFheUint32::try_encrypt( + msg, + crs.public_params(), + &pk, + ZkComputeLoad::Proof, + ) + .unwrap(); + let fhe_uint = proven_compact_fhe_uint + .verify_and_expand(crs.public_params(), &pk) + .unwrap(); + let decrypted: u32 = fhe_uint.decrypt(&ck); + assert_eq!(decrypted, msg); + + let messages = (0..4).map(|_| random()).collect::>(); + let proven_compact_fhe_uint_list = crate::ProvenCompactFheUint32List::try_encrypt( + &messages, + crs.public_params(), + &pk, + ZkComputeLoad::Proof, + ) + .unwrap(); + let fhe_uints = proven_compact_fhe_uint_list + .verify_and_expand(crs.public_params(), &pk) + .unwrap(); + let decrypted = fhe_uints + .iter() + .map(|fb| fb.decrypt(&ck)) + .collect::>(); + assert_eq!(decrypted.as_slice(), &messages); +} diff --git a/tfhe/src/high_level_api/integers/unsigned/zk.rs b/tfhe/src/high_level_api/integers/unsigned/zk.rs new file mode 100644 index 0000000000..3f0818f7c2 --- /dev/null +++ b/tfhe/src/high_level_api/integers/unsigned/zk.rs @@ -0,0 +1,140 @@ +use super::FheUintId; +use crate::core_crypto::commons::math::random::{Deserialize, Serialize}; +use crate::core_crypto::prelude::UnsignedNumeric; +use crate::integer::block_decomposition::DecomposableInto; +use crate::integer::{ProvenCompactCiphertextList, RadixCiphertext}; +use crate::named::Named; +use crate::zk::{CompactPkePublicParams, ZkComputeLoad, ZkVerificationOutCome}; +use crate::{CompactPublicKey, FheUint}; + +/// A `CompactFheUint` tied to a Zero-Knowledge proof +/// +/// The zero-knowledge proof allows to verify that the ciphertext is correctly +/// encrypted. +#[derive(Clone, Serialize, Deserialize)] +pub struct ProvenCompactFheUint { + inner: ProvenCompactCiphertextList, + _id: Id, +} + +impl Named for ProvenCompactFheUint { + const NAME: &'static str = "high_level_api::ProvenCompactFheUint"; +} + +impl ProvenCompactFheUint +where + Id: FheUintId, +{ + /// Encrypts the message while also generating the zero-knowledge proof + pub fn try_encrypt( + value: Clear, + public_params: &CompactPkePublicParams, + key: &CompactPublicKey, + load: ZkComputeLoad, + ) -> crate::Result + where + Clear: DecomposableInto + UnsignedNumeric, + { + let inner = key.key.key.encrypt_and_prove_radix_compact( + &[value], + Id::num_blocks(key.key.key.key.parameters.message_modulus()), + public_params, + load, + )?; + Ok(Self { + inner, + _id: Id::default(), + }) + } + + /// Verifies the ciphertext and the proof + /// + /// If the proof and ciphertext are valid, it returns an `Ok` with + /// the underlying `FheUint` + pub fn verify_and_expand( + self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> crate::Result> { + let expanded_inner = self + .inner + .verify_and_expand_one::(public_params, &public_key.key.key)?; + Ok(FheUint::new(expanded_inner)) + } + + pub fn verify( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> ZkVerificationOutCome { + self.inner.verify(public_params, &public_key.key.key) + } +} + +/// A `CompactFheUintList` tied to a Zero-Knowledge proof +/// +/// The zero-knowledge proof allows to verify that the ciphertext list is correctly +/// encrypted. +#[derive(Clone, Serialize, Deserialize)] +pub struct ProvenCompactFheUintList { + inner: ProvenCompactCiphertextList, + _id: Id, +} + +impl Named for ProvenCompactFheUintList { + const NAME: &'static str = "high_level_api::ProvenCompactFheUintList"; +} + +impl ProvenCompactFheUintList +where + Id: FheUintId, +{ + /// Encrypts the message while also generating the zero-knowledge proof + pub fn try_encrypt( + values: &[Clear], + public_params: &CompactPkePublicParams, + key: &CompactPublicKey, + load: ZkComputeLoad, + ) -> crate::Result + where + Clear: DecomposableInto + UnsignedNumeric, + { + let inner = key.key.key.encrypt_and_prove_radix_compact( + values, + Id::num_blocks(key.key.key.key.parameters.message_modulus()), + public_params, + load, + )?; + Ok(Self { + inner, + _id: Id::default(), + }) + } + + pub fn len(&self) -> usize { + self.inner.ciphertext_count() + } + + /// Verifies the ciphertext and the proof + /// + /// If the proof and ciphertext are valid, it returns an `Ok` with + /// the underlying `FheUint`s. + pub fn verify_and_expand( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> crate::Result>> { + let expanded_inners = self + .inner + .verify_and_expand::(public_params, &public_key.key.key)?; + Ok(expanded_inners.into_iter().map(FheUint::new).collect()) + } + + pub fn verify( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> ZkVerificationOutCome { + self.inner.verify(public_params, &public_key.key.key) + } +} diff --git a/tfhe/src/high_level_api/mod.rs b/tfhe/src/high_level_api/mod.rs index ebb1457230..cb152de5fd 100644 --- a/tfhe/src/high_level_api/mod.rs +++ b/tfhe/src/high_level_api/mod.rs @@ -23,6 +23,13 @@ macro_rules! expand_pub_use_fhe_type( )* }; + #[cfg(feature = "zk-pok-experimental")] + pub use $module_path::{ + $( + [], + [], + )* + }; } } ); @@ -30,7 +37,6 @@ macro_rules! expand_pub_use_fhe_type( pub use crate::core_crypto::commons::math::random::Seed; pub use crate::integer::oprf::SignedRandomizationSpec; pub use config::{Config, ConfigBuilder}; -pub use errors::{Error, OutOfRangeError}; pub use global_state::{set_server_key, unset_server_key, with_server_key_as_context}; pub use integers::{ @@ -51,6 +57,8 @@ pub use crate::high_level_api::booleans::{ CompactFheBool, CompactFheBoolList, CompactFheBoolListConformanceParams, CompressedFheBool, FheBool, FheBoolConformanceParams, }; +#[cfg(feature = "zk-pok-experimental")] +pub use crate::high_level_api::booleans::{ProvenCompactFheBool, ProvenCompactFheBoolList}; expand_pub_use_fhe_type!( pub use crate::high_level_api::integers{ FheUint2, FheUint4, FheUint6, FheUint8, FheUint10, FheUint12, FheUint14, FheUint16, @@ -60,6 +68,7 @@ expand_pub_use_fhe_type!( FheInt32, FheInt64, FheInt128, FheInt160, FheInt256 }; ); + pub use safe_serialize::safe_serialize; mod config; @@ -68,12 +77,14 @@ mod keys; mod traits; mod booleans; -pub mod errors; +mod errors; mod integers; pub(in crate::high_level_api) mod details; /// The tfhe prelude. pub mod prelude; +#[cfg(feature = "zk-pok-experimental")] +mod zk; /// Devices supported by tfhe-rs #[derive(Copy, Clone, PartialEq, Eq, Debug)] diff --git a/tfhe/src/high_level_api/tests/mod.rs b/tfhe/src/high_level_api/tests/mod.rs index e81864a8fd..a4ea1d7b1e 100644 --- a/tfhe/src/high_level_api/tests/mod.rs +++ b/tfhe/src/high_level_api/tests/mod.rs @@ -99,9 +99,9 @@ fn test_with_seed() { let builder = ConfigBuilder::default(); let config = builder.build(); - let cks1 = ClientKey::generate_with_seed(config.clone(), Seed(125)); - let cks2 = ClientKey::generate(config.clone()); - let cks3 = ClientKey::generate_with_seed(config.clone(), Seed(125)); + let cks1 = ClientKey::generate_with_seed(config, Seed(125)); + let cks2 = ClientKey::generate(config); + let cks3 = ClientKey::generate_with_seed(config, Seed(125)); let cks4 = ClientKey::generate_with_seed(config, Seed(127)); let cks1_serialized = bincode::serialize(&cks1).unwrap(); diff --git a/tfhe/src/high_level_api/zk.rs b/tfhe/src/high_level_api/zk.rs new file mode 100644 index 0000000000..8b11d6c73b --- /dev/null +++ b/tfhe/src/high_level_api/zk.rs @@ -0,0 +1,11 @@ +use crate::zk::CompactPkeCrs; +use crate::Config; + +impl CompactPkeCrs { + pub fn from_config(config: Config, max_bit_size: usize) -> crate::Result { + let max_num_message = + max_bit_size / config.inner.block_parameters.message_modulus().0.ilog2() as usize; + let crs = Self::from_shortint_params(config.inner.block_parameters, max_num_message)?; + Ok(crs) + } +} diff --git a/tfhe/src/integer/block_decomposition.rs b/tfhe/src/integer/block_decomposition.rs index d3e1a95f6f..ea4de647ce 100644 --- a/tfhe/src/integer/block_decomposition.rs +++ b/tfhe/src/integer/block_decomposition.rs @@ -27,6 +27,11 @@ pub trait Recomposable: + Shl + Sub { + // TODO: need for wrapping arithmetic traits + // This is a wrapping add but to avoid conflicts with other parts of the code using external + // wrapping traits definition we change the name here + #[must_use] + fn recomposable_wrapping_add(self, other: Self) -> Self; } // Convenience traits have simpler bounds @@ -39,7 +44,12 @@ macro_rules! impl_recomposable_decomposable { ) => { $( impl Decomposable for $type { } - impl Recomposable for $type { } + impl Recomposable for $type { + #[inline] + fn recomposable_wrapping_add(self, other: Self) -> Self { + self.wrapping_add(other) + } + } impl RecomposableFrom for $type { } impl DecomposableInto for $type { } impl RecomposableFrom for $type { } @@ -51,14 +61,26 @@ macro_rules! impl_recomposable_decomposable { impl_recomposable_decomposable!(u8, u16, u32, u64, u128, i8, i16, i32, i64, i128,); impl Decomposable for StaticSignedBigInt {} -impl Recomposable for StaticSignedBigInt {} +impl Recomposable for StaticSignedBigInt { + #[inline] + fn recomposable_wrapping_add(mut self, other: Self) -> Self { + self.add_assign(other); + self + } +} impl RecomposableFrom for StaticSignedBigInt {} impl RecomposableFrom for StaticSignedBigInt {} impl DecomposableInto for StaticSignedBigInt {} impl DecomposableInto for StaticSignedBigInt {} impl Decomposable for StaticUnsignedBigInt {} -impl Recomposable for StaticUnsignedBigInt {} +impl Recomposable for StaticUnsignedBigInt { + #[inline] + fn recomposable_wrapping_add(mut self, other: Self) -> Self { + self.add_assign(other); + self + } +} impl RecomposableFrom for StaticUnsignedBigInt {} impl RecomposableFrom for StaticUnsignedBigInt {} impl DecomposableInto for StaticUnsignedBigInt {} @@ -258,7 +280,7 @@ where } block <<= self.bit_pos; - self.data += block; + self.data = self.data.recomposable_wrapping_add(block); self.bit_pos += self.num_bits_in_block; true diff --git a/tfhe/src/integer/mod.rs b/tfhe/src/integer/mod.rs index 30d0c832a7..d57cfc607f 100755 --- a/tfhe/src/integer/mod.rs +++ b/tfhe/src/integer/mod.rs @@ -65,6 +65,11 @@ pub mod wopbs; #[cfg(feature = "gpu")] pub mod gpu; +#[cfg(feature = "zk-pok-experimental")] +mod zk; + +#[cfg(feature = "zk-pok-experimental")] +pub use zk::ProvenCompactCiphertextList; pub use bigint::i256::I256; pub use bigint::i512::I512; @@ -76,7 +81,9 @@ pub use ciphertext::{ SignedRadixCiphertext, }; pub use client_key::{ClientKey, CrtClientKey, RadixClientKey}; -pub use public_key::{CompressedCompactPublicKey, CompressedPublicKey, PublicKey}; +pub use public_key::{ + CompactPublicKey, CompressedCompactPublicKey, CompressedPublicKey, PublicKey, +}; pub use server_key::{CheckError, CompressedServerKey, ServerKey}; /// Enum to indicate which kind of computations the [`ServerKey`] will be performing, this changes diff --git a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs index 98bba139b3..79b528d4ef 100644 --- a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs +++ b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs @@ -241,6 +241,7 @@ impl ServerKey { let counter_num_blocks = ((num_bits_in_ciphertext - 1).ilog2() + 1 + 1) .div_ceil(self.message_modulus().0.ilog2()) as usize; + // 11111000 // x.ilog2() = (x.num_bit() - 1) - x.leading_zeros() // - (x.num_bit() - 1) is trivially known // - we can get leading zeros via a sum diff --git a/tfhe/src/integer/zk.rs b/tfhe/src/integer/zk.rs new file mode 100644 index 0000000000..a7ca5f5879 --- /dev/null +++ b/tfhe/src/integer/zk.rs @@ -0,0 +1,139 @@ +use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; +use crate::integer::encryption::KnowsMessageModulus; +use crate::integer::public_key::CompactPublicKey; +use crate::integer::IntegerRadixCiphertext; +use crate::zk::{CompactPkePublicParams, ZkComputeLoad, ZkVerificationOutCome}; +use serde::{Deserialize, Serialize}; + +impl CompactPublicKey { + pub fn encrypt_and_prove_radix_compact>( + &self, + messages: &[T], + num_blocks_per_integer: usize, + public_params: &CompactPkePublicParams, + load: ZkComputeLoad, + ) -> crate::Result { + let messages = messages + .iter() + .copied() + .flat_map(|message| { + BlockDecomposer::new(message, self.key.message_modulus().0.ilog2()) + .iter_as::() + .take(num_blocks_per_integer) + }) + .collect::>(); + + let proved_list = self + .key + .encrypt_and_prove_slice(&messages, public_params, load)?; + + Ok(ProvenCompactCiphertextList { + proved_list, + num_blocks_per_integer, + }) + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct ProvenCompactCiphertextList { + pub(crate) proved_list: crate::shortint::ciphertext::ProvenCompactCiphertextList, + // Keep track of the num_blocks, as we allow + // storing many integer that have the same num_blocks + // into ct_list + pub(crate) num_blocks_per_integer: usize, +} + +impl ProvenCompactCiphertextList { + pub fn verify_and_expand_one( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> crate::Result { + let blocks = self + .proved_list + .verify_and_expand(public_params, &public_key.key)?; + assert_eq!(blocks.len(), self.num_blocks_per_integer); + + Ok(T::from_blocks(blocks)) + } + + pub fn ciphertext_count(&self) -> usize { + self.proved_list.ciphertext_count() / self.num_blocks_per_integer + } + + pub fn verify_and_expand( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> crate::Result> { + let blocks = self + .proved_list + .verify_and_expand(public_params, &public_key.key)?; + + let mut integers = Vec::with_capacity(self.ciphertext_count()); + let mut blocks_iter = blocks.into_iter(); + for _ in 0..self.ciphertext_count() { + let radix_blocks = blocks_iter + .by_ref() + .take(self.num_blocks_per_integer) + .collect::>(); + integers.push(T::from_blocks(radix_blocks)); + } + Ok(integers) + } + + pub fn verify( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> ZkVerificationOutCome { + self.proved_list.verify(public_params, &public_key.key) + } +} + +#[cfg(test)] +mod tests { + use crate::integer::{ClientKey, CompactPublicKey}; + use crate::shortint::parameters::DynamicDistribution; + use crate::shortint::prelude::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + use crate::zk::{CompactPkeCrs, ZkComputeLoad}; + use rand::random; + + #[test] + fn test_zk_compact_ciphertext_list_encryption_ci_run_filter() { + let mut params = PARAM_MESSAGE_2_CARRY_2_KS_PBS; + params.glwe_noise_distribution = DynamicDistribution::new_t_uniform(9); + + let num_blocks = 4usize; + let modulus = (params.message_modulus.0 as u64) + .checked_pow(num_blocks as u32) + .unwrap(); + + let crs = CompactPkeCrs::from_shortint_params(params, 512).unwrap(); + let cks = ClientKey::new(params); + let pk = CompactPublicKey::new(&cks); + + let msgs = (0..512) + .map(|_| random::() % modulus) + .collect::>(); + + let proven_ct = pk + .encrypt_and_prove_radix_compact( + &msgs, + num_blocks, + crs.public_params(), + ZkComputeLoad::Proof, + ) + .unwrap(); + assert!(proven_ct.verify(crs.public_params(), &pk).is_valid()); + + let expanded = proven_ct + .verify_and_expand(crs.public_params(), &pk) + .unwrap(); + let decrypted = expanded + .iter() + .map(|ciphertext| cks.decrypt_radix::(ciphertext)) + .collect::>(); + assert_eq!(msgs, decrypted); + } +} diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs index 4f1ae35cd7..58bc7e73a4 100644 --- a/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs @@ -91,6 +91,9 @@ macro_rules! create_wrapper_type_non_native_type ( compressed_type_name: $compressed_type_name:ident, compact_type_name: $compact_type_name:ident, compact_list_type_name: $compact_list_type_name:ident, + proven_type: $proven_type:ident, + proven_compact_type_name: $proven_compact_type_name:ident, + proven_compact_list_type_name: $proven_compact_list_type_name:ident, rust_type: $rust_type:ty $(,)? } ) => { @@ -376,6 +379,146 @@ macro_rules! create_wrapper_type_non_native_type ( }) } } + + #[cfg(feature = "zk-pok-experimental")] + #[wasm_bindgen] + pub struct $proven_compact_type_name(pub(crate) crate::high_level_api::$proven_compact_type_name); + + #[cfg(feature = "zk-pok-experimental")] + #[wasm_bindgen] + impl $proven_compact_type_name { + #[wasm_bindgen] + pub fn encrypt_with_compact_public_key( + value: JsValue, + public_params: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey, + compute_load: crate::js_on_wasm_api::js_high_level_api::zk::ZkComputeLoad, + ) -> Result<$proven_compact_type_name, JsError> { + catch_panic_result(|| { + let value = <$rust_type>::try_from(value) + .map_err(|_| JsError::new(&format!("Failed to convert the value to a {}", stringify!($rust_type))))?; + crate::high_level_api::$proven_compact_type_name::try_encrypt( + value, + &public_params.0, + &public_key.0, + compute_load.into() + ).map($proven_compact_type_name) + .map_err(into_js_error) + }) + } + + #[wasm_bindgen] + pub fn verifies( + &self, + public_parameters: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey + ) -> bool { + self.0.verify(&public_parameters.0, &public_key.0).is_valid() + } + + #[wasm_bindgen] + pub fn verify_and_expand( + &self, + public_parameters: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey + ) -> Result<$type_name, JsError> { + catch_panic(||{ + self.0 + .clone() + .verify_and_expand(&public_parameters.0, &public_key.0) + .map($type_name) + .unwrap() + }) + } + + #[wasm_bindgen] + pub fn serialize(&self) -> Result, JsError> { + catch_panic_result(|| bincode::serialize(&self.0).map_err(into_js_error)) + } + + #[wasm_bindgen] + pub fn deserialize(buffer: &[u8]) -> Result<$proven_compact_type_name, JsError> { + catch_panic_result(|| { + bincode::deserialize(buffer) + .map($proven_compact_type_name) + .map_err(into_js_error) + }) + } + } + + #[cfg(feature = "zk-pok-experimental")] + #[wasm_bindgen] + pub struct $proven_compact_list_type_name(pub(crate) crate::high_level_api::$proven_compact_list_type_name); + + #[cfg(feature = "zk-pok-experimental")] + #[wasm_bindgen] + impl $proven_compact_list_type_name { + #[wasm_bindgen] + pub fn encrypt_with_compact_public_key( + values: Vec, + public_params: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey, + compute_load: crate::js_on_wasm_api::js_high_level_api::zk::ZkComputeLoad, + ) -> Result<$proven_compact_list_type_name, JsError> { + catch_panic_result(|| { + let values = values + .into_iter() + .map(|value| { + <$rust_type>::try_from(value) + .map_err(|_| { + JsError::new(&format!("Failed to convert the value to a {}", stringify!($rust_type))) + }) + }) + .collect::, _>>()?; + crate::high_level_api::$proven_compact_list_type_name::try_encrypt( + &values, + &public_params.0, + &public_key.0, + compute_load.into() + ).map($proven_compact_list_type_name) + .map_err(into_js_error) + }) + } + + #[wasm_bindgen] + pub fn verifies( + &self, + public_parameters: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey + ) -> bool { + self.0.verify(&public_parameters.0, &public_key.0).is_valid() + } + + #[wasm_bindgen] + pub fn verify_and_expand( + &self, + public_parameters: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey + ) -> Result, JsError> { + catch_panic(||{ + self.0 + .clone() + .verify_and_expand(&public_parameters.0, &public_key.0) + .map(|vec| vec.into_iter().map($type_name).collect::>()) + .unwrap() + }) + } + + #[wasm_bindgen] + pub fn serialize(&self) -> Result, JsError> { + catch_panic_result(|| bincode::serialize(&self.0).map_err(into_js_error)) + } + + #[wasm_bindgen] + pub fn deserialize(buffer: &[u8]) -> Result<$proven_compact_list_type_name, JsError> { + catch_panic_result(|| { + bincode::deserialize(buffer) + .map($proven_compact_list_type_name) + .map_err(into_js_error) + }) + } + } + }; ( @@ -385,6 +528,9 @@ macro_rules! create_wrapper_type_non_native_type ( compressed_type_name: $compressed_type_name:ident, compact_type_name: $compact_type_name:ident, compact_list_type_name: $compact_list_type_name:ident, + proven_type: $proven_type:ident, + proven_compact_type_name: $proven_compact_type_name:ident, + proven_compact_list_type_name: $proven_compact_list_type_name:ident, rust_type: $rust_type:ty $(,)? } ),* @@ -397,6 +543,9 @@ macro_rules! create_wrapper_type_non_native_type ( compressed_type_name: $compressed_type_name, compact_type_name: $compact_type_name, compact_list_type_name: $compact_list_type_name, + proven_type: $proven_type, + proven_compact_type_name: $proven_compact_type_name, + proven_compact_list_type_name: $proven_compact_list_type_name, rust_type: $rust_type } ); @@ -405,25 +554,34 @@ macro_rules! create_wrapper_type_non_native_type ( ); create_wrapper_type_non_native_type!( - { - type_name: FheUint160, - compressed_type_name: CompressedFheUint160, - compact_type_name: CompactFheUint160, - compact_list_type_name: CompactFheUint160List, - rust_type: U256, - }, { type_name: FheUint128, compressed_type_name: CompressedFheUint128, compact_type_name: CompactFheUint128, compact_list_type_name: CompactFheUint128List, + proven_type: ProvenFheUint128, + proven_compact_type_name: ProvenCompactFheUint128, + proven_compact_list_type_name: ProvenCompactFheUint128List, rust_type: u128, }, + { + type_name: FheUint160, + compressed_type_name: CompressedFheUint160, + compact_type_name: CompactFheUint160, + compact_list_type_name: CompactFheUint160List, + proven_type: ProvenFheUint160, + proven_compact_type_name: ProvenCompactFheUint160, + proven_compact_list_type_name: ProvenCompactFheUint160List, + rust_type: U256, + }, { type_name: FheUint256, compressed_type_name: CompressedFheUint256, compact_type_name: CompactFheUint256, compact_list_type_name: CompactFheUint256List, + proven_type: ProvenFheUint256, + proven_compact_type_name: ProvenCompactFheUint256, + proven_compact_list_type_name: ProvenCompactFheUint256List, rust_type: U256, }, // Signed @@ -432,6 +590,9 @@ create_wrapper_type_non_native_type!( compressed_type_name: CompressedFheInt128, compact_type_name: CompactFheInt128, compact_list_type_name: CompactFheInt128List, + proven_type: ProvenFheInt128, + proven_compact_type_name: ProvenCompactFheInt128, + proven_compact_list_type_name: ProvenCompactFheInt128List, rust_type: i128, }, { @@ -439,6 +600,9 @@ create_wrapper_type_non_native_type!( compressed_type_name: CompressedFheInt160, compact_type_name: CompactFheInt160, compact_list_type_name: CompactFheInt160List, + proven_type: ProvenFheInt160, + proven_compact_type_name: ProvenCompactFheInt160, + proven_compact_list_type_name: ProvenCompactFheInt160List, rust_type: I256, }, { @@ -446,6 +610,9 @@ create_wrapper_type_non_native_type!( compressed_type_name: CompressedFheInt256, compact_type_name: CompactFheInt256, compact_list_type_name: CompactFheInt256List, + proven_type: ProvenFheInt256, + proven_compact_type_name: ProvenCompactFheInt256, + proven_compact_list_type_name: ProvenCompactFheInt256List, rust_type: I256, }, ); @@ -460,6 +627,9 @@ macro_rules! create_wrapper_type_that_has_native_type ( compressed_type_name: $compressed_type_name:ident, compact_type_name: $compact_type_name:ident, compact_list_type_name: $compact_list_type_name:ident, + proven_type: $proven_type:ident, + proven_compact_type_name: $proven_compact_type_name:ident, + proven_compact_list_type_name: $proven_compact_list_type_name:ident, native_type: $native_type:ty $(,)? } ) => { @@ -708,6 +878,116 @@ macro_rules! create_wrapper_type_that_has_native_type ( } } + #[cfg(feature = "zk-pok-experimental")] + #[wasm_bindgen] + pub struct $proven_compact_type_name(pub(crate) crate::high_level_api::$proven_compact_type_name); + + #[cfg(feature = "zk-pok-experimental")] + #[wasm_bindgen] + impl $proven_compact_type_name { + #[wasm_bindgen] + pub fn encrypt_with_compact_public_key( + value: $native_type, + public_params: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey, + compute_load: crate::js_on_wasm_api::js_high_level_api::zk::ZkComputeLoad, + ) -> Result<$proven_compact_type_name, JsError> { + catch_panic_result(|| { + crate::high_level_api::$proven_compact_type_name::try_encrypt( + value, + &public_params.0, + &public_key.0, + compute_load.into() + ).map($proven_compact_type_name) + .map_err(into_js_error) + }) + } + + #[wasm_bindgen] + pub fn verifies( + &self, + public_parameters: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey + ) -> bool { + self.0.verify(&public_parameters.0, &public_key.0).is_valid() + } + + #[wasm_bindgen] + pub fn verify_and_expand( + &self, + public_parameters: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey + ) -> Result<$type_name, JsError> { + catch_panic(||{ + self.0 + .clone() + .verify_and_expand(&public_parameters.0, &public_key.0) + .map($type_name) + .unwrap() + }) + } + + #[wasm_bindgen] + pub fn serialize(&self) -> Result, JsError> { + catch_panic_result(|| bincode::serialize(&self.0).map_err(into_js_error)) + } + + #[wasm_bindgen] + pub fn deserialize(buffer: &[u8]) -> Result<$proven_compact_type_name, JsError> { + catch_panic_result(|| { + bincode::deserialize(buffer) + .map($proven_compact_type_name) + .map_err(into_js_error) + }) + } + } + + #[cfg(feature = "zk-pok-experimental")] + #[wasm_bindgen] + pub struct $proven_compact_list_type_name(pub(crate) crate::high_level_api::$proven_compact_list_type_name); + + #[cfg(feature = "zk-pok-experimental")] + #[wasm_bindgen] + impl $proven_compact_list_type_name { + #[wasm_bindgen] + pub fn verifies( + &self, + public_parameters: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey + ) -> bool { + self.0.verify(&public_parameters.0, &public_key.0).is_valid() + } + + #[wasm_bindgen] + pub fn verify_and_expand( + &self, + public_parameters: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey + ) -> Result, JsError> { + catch_panic(||{ + self.0 + .clone() + .verify_and_expand(&public_parameters.0, &public_key.0) + .map(|vec| vec.into_iter().map($type_name).collect::>()) + .unwrap() + }) + } + + #[wasm_bindgen] + pub fn serialize(&self) -> Result, JsError> { + catch_panic_result(|| bincode::serialize(&self.0).map_err(into_js_error)) + } + + #[wasm_bindgen] + pub fn deserialize(buffer: &[u8]) -> Result<$proven_compact_list_type_name, JsError> { + catch_panic_result(|| { + bincode::deserialize(buffer) + .map($proven_compact_list_type_name) + .map_err(into_js_error) + }) + } + } + }; ( $( @@ -716,6 +996,9 @@ macro_rules! create_wrapper_type_that_has_native_type ( compressed_type_name: $compressed_type_name:ident, compact_type_name: $compact_type_name:ident, compact_list_type_name: $compact_list_type_name:ident, + proven_type: $proven_type:ident, + proven_compact_type_name: $proven_compact_type_name:ident, + proven_compact_list_type_name: $proven_compact_list_type_name:ident, native_type: $native_type:ty $(,)? } ),* @@ -728,6 +1011,9 @@ macro_rules! create_wrapper_type_that_has_native_type ( compressed_type_name: $compressed_type_name, compact_type_name: $compact_type_name, compact_list_type_name: $compact_list_type_name, + proven_type: $proven_type, + proven_compact_type_name: $proven_compact_type_name, + proven_compact_list_type_name: $proven_compact_list_type_name, native_type: $native_type } ); @@ -741,6 +1027,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheBool, compact_type_name: CompactFheBool, compact_list_type_name: CompactFheBoolList, + proven_type: ProvenFheBool, + proven_compact_type_name: ProvenCompactFheBool, + proven_compact_list_type_name: ProvenCompactFheBoolList, native_type: bool, }, { @@ -748,6 +1037,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheUint2, compact_type_name: CompactFheUint2, compact_list_type_name: CompactFheUint2List, + proven_type: ProvenFheUint2, + proven_compact_type_name: ProvenCompactFheUint2, + proven_compact_list_type_name: ProvenCompactFheUint2List, native_type: u8, }, { @@ -755,6 +1047,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheUint4, compact_type_name: CompactFheUint4, compact_list_type_name: CompactFheUint4List, + proven_type: ProvenFheUint4, + proven_compact_type_name: ProvenCompactFheUint4, + proven_compact_list_type_name: ProvenCompactFheUint4List, native_type: u8, }, { @@ -762,6 +1057,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheUint6, compact_type_name: CompactFheUint6, compact_list_type_name: CompactFheUint6List, + proven_type: ProvenFheUint6, + proven_compact_type_name: ProvenCompactFheUint6, + proven_compact_list_type_name: ProvenCompactFheUint6List, native_type: u8, }, { @@ -769,6 +1067,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheUint8, compact_type_name: CompactFheUint8, compact_list_type_name: CompactFheUint8List, + proven_type: ProvenFheUint8, + proven_compact_type_name: ProvenCompactFheUint8, + proven_compact_list_type_name: ProvenCompactFheUint8List, native_type: u8, }, { @@ -776,6 +1077,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheUint10, compact_type_name: CompactFheUint10, compact_list_type_name: CompactFheUint10List, + proven_type: ProvenFheUint10, + proven_compact_type_name: ProvenCompactFheUint10, + proven_compact_list_type_name: ProvenCompactFheUint10List, native_type: u16, }, { @@ -783,6 +1087,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheUint12, compact_type_name: CompactFheUint12, compact_list_type_name: CompactFheUint12List, + proven_type: ProvenFheUint12, + proven_compact_type_name: ProvenCompactFheUint12, + proven_compact_list_type_name: ProvenCompactFheUint12List, native_type: u16, }, { @@ -790,6 +1097,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheUint14, compact_type_name: CompactFheUint14, compact_list_type_name: CompactFheUint14List, + proven_type: ProvenFheUint14, + proven_compact_type_name: ProvenCompactFheUint14, + proven_compact_list_type_name: ProvenCompactFheUint14List, native_type: u16, }, { @@ -797,6 +1107,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheUint16, compact_type_name: CompactFheUint16, compact_list_type_name: CompactFheUint16List, + proven_type: ProvenFheUint16, + proven_compact_type_name: ProvenCompactFheUint16, + proven_compact_list_type_name: ProvenCompactFheUint16List, native_type: u16, }, { @@ -804,6 +1117,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheUint32, compact_type_name: CompactFheUint32, compact_list_type_name: CompactFheUint32List, + proven_type: ProvenFheUint32, + proven_compact_type_name: ProvenCompactFheUint32, + proven_compact_list_type_name: ProvenCompactFheUint32List, native_type: u32, }, { @@ -811,6 +1127,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheUint64, compact_type_name: CompactFheUint64, compact_list_type_name: CompactFheUint64List, + proven_type: ProvenFheUint64, + proven_compact_type_name: ProvenCompactFheUint64, + proven_compact_list_type_name: ProvenCompactFheUint64List, native_type: u64, }, // Signed @@ -819,6 +1138,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheInt2, compact_type_name: CompactFheInt2, compact_list_type_name: CompactFheInt2List, + proven_type: ProvenFheInt2, + proven_compact_type_name: ProvenCompactFheInt2, + proven_compact_list_type_name: ProvenCompactFheInt2List, native_type: i8, }, { @@ -826,6 +1148,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheInt4, compact_type_name: CompactFheInt4, compact_list_type_name: CompactFheInt4List, + proven_type: ProvenFheInt4, + proven_compact_type_name: ProvenCompactFheInt4, + proven_compact_list_type_name: ProvenCompactFheInt4List, native_type: i8, }, { @@ -833,6 +1158,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheInt6, compact_type_name: CompactFheInt6, compact_list_type_name: CompactFheInt6List, + proven_type: ProvenFheInt6, + proven_compact_type_name: ProvenCompactFheInt6, + proven_compact_list_type_name: ProvenCompactFheInt6List, native_type: i8, }, { @@ -840,6 +1168,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheInt8, compact_type_name: CompactFheInt8, compact_list_type_name: CompactFheInt8List, + proven_type: ProvenFheInt8, + proven_compact_type_name: ProvenCompactFheInt8, + proven_compact_list_type_name: ProvenCompactFheInt8List, native_type: i8, }, { @@ -847,6 +1178,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheInt10, compact_type_name: CompactFheInt10, compact_list_type_name: CompactFheInt10List, + proven_type: ProvenFheInt10, + proven_compact_type_name: ProvenCompactFheInt10, + proven_compact_list_type_name: ProvenCompactFheInt10List, native_type: i16, }, { @@ -854,6 +1188,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheInt12, compact_type_name: CompactFheInt12, compact_list_type_name: CompactFheInt12List, + proven_type: ProvenFheInt12, + proven_compact_type_name: ProvenCompactFheInt12, + proven_compact_list_type_name: ProvenCompactFheInt12List, native_type: i16, }, { @@ -861,6 +1198,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheInt14, compact_type_name: CompactFheInt14, compact_list_type_name: CompactFheInt14List, + proven_type: ProvenFheInt14, + proven_compact_type_name: ProvenCompactFheInt14, + proven_compact_list_type_name: ProvenCompactFheInt14List, native_type: i16, }, { @@ -868,6 +1208,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheInt16, compact_type_name: CompactFheInt16, compact_list_type_name: CompactFheInt16List, + proven_type: ProvenFheInt16, + proven_compact_type_name: ProvenCompactFheInt16, + proven_compact_list_type_name: ProvenCompactFheInt16List, native_type: i16, }, { @@ -875,6 +1218,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheInt32, compact_type_name: CompactFheInt32, compact_list_type_name: CompactFheInt32List, + proven_type: ProvenFheInt32, + proven_compact_type_name: ProvenCompactFheInt32, + proven_compact_list_type_name: ProvenCompactFheInt32List, native_type: i32, }, { @@ -882,6 +1228,9 @@ create_wrapper_type_that_has_native_type!( compressed_type_name: CompressedFheInt64, compact_type_name: CompactFheInt64, compact_list_type_name: CompactFheInt64List, + proven_type: ProvenFheInt64, + proven_compact_type_name: ProvenCompactFheInt64, + proven_compact_list_type_name: ProvenCompactFheInt64List, native_type: i64, }, ); @@ -967,3 +1316,93 @@ impl CompactFheBoolList { }) } } + +#[cfg(feature = "zk-pok-experimental")] +macro_rules! define_prove_and_encrypt_list_with_compact_public_key { + ( + $( + {$proven_compact_list_type_name:ident, $native_type:ty} + ),* + $(,)? + ) => { + $( + #[wasm_bindgen] + impl $proven_compact_list_type_name { + + #[wasm_bindgen] + pub fn encrypt_with_compact_public_key( + values: Vec<$native_type>, + public_params: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey, + compute_load: crate::js_on_wasm_api::js_high_level_api::zk::ZkComputeLoad, + ) -> Result<$proven_compact_list_type_name, JsError> { + catch_panic_result(|| { + $crate::high_level_api::$proven_compact_list_type_name::try_encrypt( + &values, + &public_params.0, + &public_key.0, + compute_load.into(), + ).map($proven_compact_list_type_name) + .map_err(into_js_error) + }) + } + } + )* + }; +} + +#[cfg(feature = "zk-pok-experimental")] +define_prove_and_encrypt_list_with_compact_public_key!( + {ProvenCompactFheUint2List, u8}, + {ProvenCompactFheUint4List, u8}, + {ProvenCompactFheUint6List, u8}, + {ProvenCompactFheUint8List, u8}, + {ProvenCompactFheUint12List, u16}, + {ProvenCompactFheUint14List, u16}, + {ProvenCompactFheUint16List, u16}, + {ProvenCompactFheUint32List, u32}, + {ProvenCompactFheUint64List, u64}, + // Signed + {ProvenCompactFheInt2List, i8}, + {ProvenCompactFheInt4List, i8}, + {ProvenCompactFheInt6List, i8}, + {ProvenCompactFheInt8List, i8}, + {ProvenCompactFheInt12List, i16}, + {ProvenCompactFheInt14List, i16}, + {ProvenCompactFheInt16List, i16}, + {ProvenCompactFheInt32List, i32}, + {ProvenCompactFheInt64List, i64}, +); + +#[cfg(feature = "zk-pok-experimental")] +#[allow(clippy::use_self)] +#[allow(clippy::needless_pass_by_value)] +#[wasm_bindgen] +impl ProvenCompactFheBoolList { + #[wasm_bindgen] + pub fn encrypt_with_compact_public_key( + values: Vec, + public_params: &crate::js_on_wasm_api::js_high_level_api::zk::CompactPkePublicParams, + public_key: &crate::js_on_wasm_api::js_high_level_api::keys::TfheCompactPublicKey, + compute_load: crate::js_on_wasm_api::js_high_level_api::zk::ZkComputeLoad, + ) -> Result { + catch_panic_result(|| { + let booleans = values + .iter() + .map(|jsvalue| { + jsvalue + .as_bool() + .ok_or_else(|| JsError::new("Value is not a boolean")) + }) + .collect::, JsError>>()?; + crate::high_level_api::ProvenCompactFheBoolList::try_encrypt( + &booleans, + &public_params.0, + &public_key.0, + compute_load.into(), + ) + .map(ProvenCompactFheBoolList) + .map_err(into_js_error) + }) + } +} diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/keys.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/keys.rs index 920c8eb8e5..8b1fc3bf3a 100644 --- a/tfhe/src/js_on_wasm_api/js_high_level_api/keys.rs +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/keys.rs @@ -15,7 +15,7 @@ pub struct TfheClientKey(pub(crate) hlapi::ClientKey); impl TfheClientKey { #[wasm_bindgen] pub fn generate(config: &TfheConfig) -> Result { - catch_panic(|| Self(hlapi::ClientKey::generate(config.0.clone()))) + catch_panic(|| Self(hlapi::ClientKey::generate(config.0))) } #[wasm_bindgen] @@ -26,7 +26,7 @@ impl TfheClientKey { catch_panic_result(|| { let seed = u128::try_from(seed).map_err(|_| JsError::new("Value does not fit in a u128"))?; - let key = hlapi::ClientKey::generate_with_seed(config.0.clone(), crate::Seed(seed)); + let key = hlapi::ClientKey::generate_with_seed(config.0, crate::Seed(seed)); Ok(Self(key)) }) } diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/mod.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/mod.rs index 5f1106d110..208c9e1cf6 100644 --- a/tfhe/src/js_on_wasm_api/js_high_level_api/mod.rs +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/mod.rs @@ -5,6 +5,8 @@ pub(crate) mod integers; // using Self does not work well with #[wasm_bindgen] macro #[allow(clippy::use_self)] pub(crate) mod keys; +#[cfg(feature = "zk-pok-experimental")] +mod zk; pub(crate) fn into_js_error(e: E) -> wasm_bindgen::JsError { wasm_bindgen::JsError::new(format!("{e:?}").as_str()) diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs new file mode 100644 index 0000000000..0fcf327f51 --- /dev/null +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs @@ -0,0 +1,76 @@ +use wasm_bindgen::prelude::*; + +use crate::js_on_wasm_api::js_high_level_api::config::TfheConfig; +use crate::js_on_wasm_api::js_high_level_api::{catch_panic_result, into_js_error}; +use crate::js_on_wasm_api::shortint::ShortintParameters; + +#[derive(Copy, Clone, Eq, PartialEq)] +#[wasm_bindgen] +pub enum ZkComputeLoad { + Proof, + Verify, +} + +impl Into for ZkComputeLoad { + fn into(self) -> crate::zk::ZkComputeLoad { + match self { + Self::Proof => crate::zk::ZkComputeLoad::Proof, + Self::Verify => crate::zk::ZkComputeLoad::Verify, + } + } +} + +#[wasm_bindgen] +pub struct CompactPkeCrs(pub(crate) crate::core_crypto::entities::CompactPkeCrs); + +#[wasm_bindgen] +pub struct CompactPkePublicParams(pub(crate) crate::zk::CompactPkePublicParams); + +#[wasm_bindgen] +impl CompactPkePublicParams { + #[wasm_bindgen] + pub fn serialize(&self) -> Result, JsError> { + catch_panic_result(|| bincode::serialize(&self.0).map_err(into_js_error)) + } + + #[wasm_bindgen] + pub fn deserialize(buffer: &[u8]) -> Result { + catch_panic_result(|| { + bincode::deserialize(buffer) + .map(CompactPkePublicParams) + .map_err(into_js_error) + }) + } +} + +#[wasm_bindgen] +impl CompactPkeCrs { + #[wasm_bindgen] + pub fn from_parameters( + parameters: ShortintParameters, + max_num_message: usize, + ) -> Result { + catch_panic_result(|| { + crate::core_crypto::entities::CompactPkeCrs::from_shortint_params( + parameters.0, + max_num_message, + ) + .map(CompactPkeCrs) + .map_err(into_js_error) + }) + } + + #[wasm_bindgen] + pub fn from_config(config: &TfheConfig, max_num_bits: usize) -> Result { + catch_panic_result(|| { + crate::core_crypto::entities::CompactPkeCrs::from_config(config.0, max_num_bits) + .map(CompactPkeCrs) + .map_err(into_js_error) + }) + } + + #[wasm_bindgen] + pub fn public_params(&self) -> CompactPkePublicParams { + CompactPkePublicParams(self.0.public_params().clone()) + } +} diff --git a/tfhe/src/js_on_wasm_api/shortint.rs b/tfhe/src/js_on_wasm_api/shortint.rs index 89fb69ffe7..21908319cf 100644 --- a/tfhe/src/js_on_wasm_api/shortint.rs +++ b/tfhe/src/js_on_wasm_api/shortint.rs @@ -30,13 +30,136 @@ pub struct Shortint {} #[wasm_bindgen] pub struct ShortintParameters(pub(crate) crate::shortint::ClassicPBSParameters); +#[wasm_bindgen] +impl ShortintParameters { + #[wasm_bindgen] + pub fn lwe_dimension(&self) -> usize { + self.0.lwe_dimension.0 + } + + #[wasm_bindgen] + pub fn set_lwe_dimension(&mut self, new_value: usize) { + self.0.lwe_dimension.0 = new_value; + } + + #[wasm_bindgen] + pub fn glwe_dimension(&self) -> usize { + self.0.glwe_dimension.0 + } + + #[wasm_bindgen] + pub fn set_glwe_dimension(&mut self, new_value: usize) { + self.0.glwe_dimension.0 = new_value; + } + + #[wasm_bindgen] + pub fn polynomial_size(&self) -> usize { + self.0.polynomial_size.0 + } + + #[wasm_bindgen] + pub fn set_polynomial_size(&mut self, new_value: usize) { + self.0.polynomial_size.0 = new_value; + } + + #[wasm_bindgen] + pub fn lwe_noise_distribution(&self) -> ShortintNoiseDistribution { + ShortintNoiseDistribution(self.0.lwe_noise_distribution) + } + + #[wasm_bindgen] + pub fn set_lwe_noise_distribution(&mut self, new_value: &ShortintNoiseDistribution) { + self.0.lwe_noise_distribution = new_value.0; + } + + #[wasm_bindgen] + pub fn glwe_noise_distribution(&self) -> ShortintNoiseDistribution { + ShortintNoiseDistribution(self.0.lwe_noise_distribution) + } + + #[wasm_bindgen] + pub fn set_glwe_noise_distribution(&mut self, new_value: &ShortintNoiseDistribution) { + self.0.glwe_noise_distribution = new_value.0; + } + + #[wasm_bindgen] + pub fn pbs_base_log(&self) -> usize { + self.0.pbs_base_log.0 + } + + #[wasm_bindgen] + pub fn set_pbs_base_log(&mut self, new_value: usize) { + self.0.pbs_base_log.0 = new_value; + } + + #[wasm_bindgen] + pub fn pbs_level(&self) -> usize { + self.0.pbs_level.0 + } + + #[wasm_bindgen] + pub fn set_pbs_level(&mut self, new_value: usize) { + self.0.pbs_level.0 = new_value; + } + + #[wasm_bindgen] + pub fn ks_base_log(&self) -> usize { + self.0.ks_base_log.0 + } + + #[wasm_bindgen] + pub fn set_ks_base_log(&mut self, new_value: usize) { + self.0.ks_base_log.0 = new_value; + } + + #[wasm_bindgen] + pub fn ks_level(&self) -> usize { + self.0.ks_level.0 + } + + #[wasm_bindgen] + pub fn set_ks_level(&mut self, new_value: usize) { + self.0.ks_level.0 = new_value; + } + + #[wasm_bindgen] + pub fn message_modulus(&self) -> usize { + self.0.message_modulus.0 + } + + #[wasm_bindgen] + pub fn set_message_modulus(&mut self, new_value: usize) { + self.0.message_modulus.0 = new_value; + } + + #[wasm_bindgen] + pub fn carry_modulus(&self) -> usize { + self.0.carry_modulus.0 + } + + #[wasm_bindgen] + pub fn set_carry_modulus(&mut self, new_value: usize) { + self.0.carry_modulus.0 = new_value; + } + + #[wasm_bindgen] + pub fn encryption_key_choice(&self) -> ShortintEncryptionKeyChoice { + self.0.encryption_key_choice.into() + } + + #[wasm_bindgen] + pub fn set_encryption_key_choice(&mut self, new_value: ShortintEncryptionKeyChoice) { + self.0.encryption_key_choice = new_value.into(); + } +} + #[wasm_bindgen] pub enum ShortintEncryptionKeyChoice { Big, Small, } -impl From for crate::shortint::parameters::EncryptionKeyChoice { +impl From for EncryptionKeyChoice { fn from(value: ShortintEncryptionKeyChoice) -> Self { match value { ShortintEncryptionKeyChoice::Big => Self::Big, @@ -45,6 +168,15 @@ impl From for crate::shortint::parameters::Encrypti } } +impl From for ShortintEncryptionKeyChoice { + fn from(value: EncryptionKeyChoice) -> Self { + match value { + EncryptionKeyChoice::Big => Self::Big, + EncryptionKeyChoice::Small => Self::Small, + } + } +} + #[wasm_bindgen] pub struct ShortintNoiseDistribution( pub(crate) crate::core_crypto::commons::math::random::DynamicDistribution, diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs index 57f4a9ee2d..6570816b0b 100644 --- a/tfhe/src/lib.rs +++ b/tfhe/src/lib.rs @@ -109,7 +109,8 @@ mod js_on_wasm_api; doctest, feature = "shortint", feature = "boolean", - feature = "integer" + feature = "integer", + feature = "zk-pok-experimental" ))] mod test_user_docs; @@ -129,3 +130,10 @@ pub mod safe_deserialization; pub mod conformance; pub mod named; + +pub mod error; +#[cfg(feature = "zk-pok-experimental")] +pub mod zk; + +pub use error::{Error, ErrorKind}; +pub type Result = std::result::Result; diff --git a/tfhe/src/shortint/ciphertext/mod.rs b/tfhe/src/shortint/ciphertext/mod.rs index bd11653188..b39060e84b 100644 --- a/tfhe/src/shortint/ciphertext/mod.rs +++ b/tfhe/src/shortint/ciphertext/mod.rs @@ -9,3 +9,8 @@ pub use compact_list::*; pub use compressed::*; pub use compressed_modulus_switched_ciphertext::*; pub use standard::*; +#[cfg(feature = "zk-pok-experimental")] +pub use zk::*; + +#[cfg(feature = "zk-pok-experimental")] +mod zk; diff --git a/tfhe/src/shortint/ciphertext/zk.rs b/tfhe/src/shortint/ciphertext/zk.rs new file mode 100644 index 0000000000..da1de3fde8 --- /dev/null +++ b/tfhe/src/shortint/ciphertext/zk.rs @@ -0,0 +1,194 @@ +use crate::core_crypto::algorithms::verify_lwe_compact_ciphertext_list; +use crate::core_crypto::prelude::verify_lwe_ciphertext; +use crate::shortint::ciphertext::CompactCiphertextList; +use crate::shortint::{Ciphertext, CompactPublicKey, EncryptionKeyChoice}; +use crate::zk::{CompactPkeCrs, CompactPkeProof, CompactPkePublicParams, ZkVerificationOutCome}; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; + +impl CompactPkeCrs { + /// Construct the CRS that corresponds to the given parameters + /// + /// max_num_message is how many message a single proof can prove + pub fn from_shortint_params( + params: impl Into, + max_num_message: usize, + ) -> crate::Result { + let params = params.into(); + let (size, noise_distribution) = match params.encryption_key_choice() { + EncryptionKeyChoice::Big => { + let size = params + .glwe_dimension() + .to_equivalent_lwe_dimension(params.polynomial_size()); + (size, params.glwe_noise_distribution()) + } + EncryptionKeyChoice::Small => (params.lwe_dimension(), params.lwe_noise_distribution()), + }; + + let mut plaintext_modulus = (params.message_modulus().0 * params.carry_modulus().0) as u64; + // Our plaintext modulus does not take into account the bit of padding + plaintext_modulus *= 2; + + crate::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| { + Self::new( + size, + max_num_message, + noise_distribution, + params.ciphertext_modulus(), + plaintext_modulus, + &mut engine.random_generator, + ) + }) + } +} + +/// A Ciphertext tied to a zero-knowledge proof +/// +/// The proof can only be generated during the encryption with a [CompactPublicKey] +pub struct ProvenCiphertext { + pub(crate) ciphertext: Ciphertext, + pub(crate) proof: CompactPkeProof, +} + +impl ProvenCiphertext { + pub fn ciphertext(&self) -> &Ciphertext { + &self.ciphertext + } + + pub fn verify( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> ZkVerificationOutCome { + verify_lwe_ciphertext( + &self.ciphertext.ct, + &public_key.key, + &self.proof, + public_params, + ) + } +} + +/// A List of CompactCiphertext with their zero-knowledge proofs +/// +/// The proofs can only be generated during the encryption with a [CompactPublicKey] +#[derive(Clone, Serialize, Deserialize)] +pub struct ProvenCompactCiphertextList { + pub(crate) proved_lists: Vec<(CompactCiphertextList, CompactPkeProof)>, +} + +impl ProvenCompactCiphertextList { + pub fn ciphertext_count(&self) -> usize { + self.proved_lists + .iter() + .map(|(list, _)| list.ct_list.lwe_ciphertext_count().0) + .sum() + } + + pub fn verify_and_expand( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> crate::Result> { + let not_all_valid = self.proved_lists.par_iter().any(|(ct_list, proof)| { + verify_lwe_compact_ciphertext_list( + &ct_list.ct_list, + &public_key.key, + proof, + public_params, + ) + .is_invalid() + }); + + if not_all_valid { + return Err(crate::ErrorKind::InvalidZkProof.into()); + } + + let expanded = self + .proved_lists + .iter() + .flat_map(|(ct_list, _proof)| ct_list.expand()) + .collect(); + + Ok(expanded) + } + + pub fn verify( + &self, + public_params: &CompactPkePublicParams, + public_key: &CompactPublicKey, + ) -> ZkVerificationOutCome { + let all_valid = self.proved_lists.par_iter().all(|(ct_list, proof)| { + verify_lwe_compact_ciphertext_list( + &ct_list.ct_list, + &public_key.key, + proof, + public_params, + ) + .is_valid() + }); + + if all_valid { + ZkVerificationOutCome::Valid + } else { + ZkVerificationOutCome::Invalid + } + } +} + +#[cfg(test)] +mod tests { + use crate::shortint::parameters::DynamicDistribution; + use crate::shortint::prelude::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + use crate::shortint::{ClientKey, CompactPublicKey}; + use crate::zk::{CompactPkeCrs, ZkComputeLoad}; + use rand::random; + + #[test] + fn test_zk_ciphertext_encryption_ci_run_filter() { + let mut params = PARAM_MESSAGE_2_CARRY_2_KS_PBS; + params.glwe_noise_distribution = DynamicDistribution::new_t_uniform(9); + + let crs = CompactPkeCrs::from_shortint_params(params, 4).unwrap(); + let cks = ClientKey::new(params); + let pk = CompactPublicKey::new(&cks); + + let msg = random::() % params.message_modulus.0 as u64; + + let proven_ct = pk + .encrypt_and_prove(msg, crs.public_params(), ZkComputeLoad::Proof) + .unwrap(); + assert!(proven_ct.verify(crs.public_params(), &pk).is_valid()); + + let decrypted = cks.decrypt(proven_ct.ciphertext()); + assert_eq!(msg, decrypted); + } + + #[test] + fn test_zk_compact_ciphertext_list_encryption_ci_run_filter() { + let mut params = PARAM_MESSAGE_2_CARRY_2_KS_PBS; + params.glwe_noise_distribution = DynamicDistribution::new_t_uniform(9); + + let crs = CompactPkeCrs::from_shortint_params(params, 512).unwrap(); + let cks = ClientKey::new(params); + let pk = CompactPublicKey::new(&cks); + + let msgs = (0..512) + .map(|_| random::() % params.message_modulus.0 as u64) + .collect::>(); + + let proven_ct = pk + .encrypt_and_prove_slice(&msgs, crs.public_params(), ZkComputeLoad::Proof) + .unwrap(); + assert!(proven_ct.verify(crs.public_params(), &pk).is_valid()); + + let expanded = proven_ct + .verify_and_expand(crs.public_params(), &pk) + .unwrap(); + let decrypted = expanded + .iter() + .map(|ciphertext| cks.decrypt(ciphertext)) + .collect::>(); + assert_eq!(msgs, decrypted); + } +} diff --git a/tfhe/src/shortint/engine/mod.rs b/tfhe/src/shortint/engine/mod.rs index 4cce4bc86f..196b1848f3 100644 --- a/tfhe/src/shortint/engine/mod.rs +++ b/tfhe/src/shortint/engine/mod.rs @@ -7,6 +7,8 @@ use crate::core_crypto::commons::computation_buffers::ComputationBuffers; use crate::core_crypto::commons::generators::{ DeterministicSeeder, EncryptionRandomGenerator, SecretRandomGenerator, }; +#[cfg(feature = "zk-pok-experimental")] +use crate::core_crypto::commons::math::random::RandomGenerator; use crate::core_crypto::commons::math::random::{ActivatedRandomGenerator, Seeder}; use crate::core_crypto::entities::*; use crate::core_crypto::prelude::ContainerMut; @@ -284,6 +286,8 @@ pub struct ShortintEngine { /// A seeder that can be called to generate 128 bits seeds, useful to create new /// [`EncryptionRandomGenerator`] to encrypt seeded types. pub(crate) seeder: DeterministicSeeder, + #[cfg(feature = "zk-pok-experimental")] + pub(crate) random_generator: RandomGenerator, pub(crate) computation_buffers: ComputationBuffers, ciphertext_buffers: Memory, } @@ -327,6 +331,8 @@ impl ShortintEngine { deterministic_seeder.seed(), &mut deterministic_seeder, ), + #[cfg(feature = "zk-pok-experimental")] + random_generator: RandomGenerator::new(deterministic_seeder.seed()), seeder: deterministic_seeder, computation_buffers: ComputationBuffers::default(), ciphertext_buffers: Memory::default(), diff --git a/tfhe/src/shortint/public_key/compact.rs b/tfhe/src/shortint/public_key/compact.rs index cfacad37a3..57125195d0 100644 --- a/tfhe/src/shortint/public_key/compact.rs +++ b/tfhe/src/shortint/public_key/compact.rs @@ -1,3 +1,7 @@ +#[cfg(feature = "zk-pok-experimental")] +use crate::core_crypto::algorithms::encrypt_and_prove_lwe_ciphertext_with_compact_public_key; +#[cfg(feature = "zk-pok-experimental")] +use crate::core_crypto::entities::Cleartext; use crate::core_crypto::prelude::{ allocate_and_generate_new_seeded_lwe_compact_public_key, encrypt_lwe_ciphertext_with_compact_public_key, generate_lwe_compact_public_key, @@ -5,8 +9,12 @@ use crate::core_crypto::prelude::{ LweCompactPublicKeyOwned, Plaintext, PlaintextList, SeededLweCompactPublicKeyOwned, }; use crate::shortint::ciphertext::{CompactCiphertextList, Degree, NoiseLevel}; +#[cfg(feature = "zk-pok-experimental")] +use crate::shortint::ciphertext::{ProvenCiphertext, ProvenCompactCiphertextList}; use crate::shortint::engine::ShortintEngine; use crate::shortint::{Ciphertext, ClientKey, PBSOrder, ShortintParameterSet}; +#[cfg(feature = "zk-pok-experimental")] +use crate::zk::{CompactPkePublicParams, ZkComputeLoad}; use serde::{Deserialize, Serialize}; use std::iter::once; @@ -162,12 +170,8 @@ impl CompactPublicKey { ); let encryption_noise_distribution = match self.pbs_order { - crate::shortint::PBSOrder::KeyswitchBootstrap => { - self.parameters.glwe_noise_distribution() - } - crate::shortint::PBSOrder::BootstrapKeyswitch => { - self.parameters.lwe_noise_distribution() - } + PBSOrder::KeyswitchBootstrap => self.parameters.glwe_noise_distribution(), + PBSOrder::BootstrapKeyswitch => self.parameters.lwe_noise_distribution(), }; ShortintEngine::with_thread_local_mut(|engine| { @@ -193,6 +197,58 @@ impl CompactPublicKey { ) } + #[cfg(feature = "zk-pok-experimental")] + pub fn encrypt_and_prove( + &self, + message: u64, + public_params: &CompactPkePublicParams, + load: ZkComputeLoad, + ) -> crate::Result { + // This allocates the required ct + let mut encrypted_ct = LweCiphertextOwned::new( + 0u64, + self.key.lwe_dimension().to_lwe_size(), + self.parameters.ciphertext_modulus(), + ); + + let encryption_noise_distribution = match self.pbs_order { + PBSOrder::KeyswitchBootstrap => self.parameters.glwe_noise_distribution(), + PBSOrder::BootstrapKeyswitch => self.parameters.lwe_noise_distribution(), + }; + + let plaintext_modulus = + (self.parameters.message_modulus().0 * self.parameters.carry_modulus().0) as u64; + let delta = (1u64 << 63) / plaintext_modulus; + + let proof = ShortintEngine::with_thread_local_mut(|engine| { + encrypt_and_prove_lwe_ciphertext_with_compact_public_key( + &self.key, + &mut encrypted_ct, + Cleartext(message), + delta, + encryption_noise_distribution, + encryption_noise_distribution, + &mut engine.secret_generator, + &mut engine.encryption_generator, + &mut engine.random_generator, + public_params, + load, + ) + })?; + + let message_modulus = self.parameters.message_modulus(); + let ciphertext = Ciphertext::new( + encrypted_ct, + Degree::new(message_modulus.0 - 1), + NoiseLevel::NOMINAL, + message_modulus, + self.parameters.carry_modulus(), + self.pbs_order, + ); + + Ok(ProvenCiphertext { ciphertext, proof }) + } + pub fn encrypt_slice(&self, messages: &[u64]) -> CompactCiphertextList { self.encrypt_iter(messages.iter().copied()) } @@ -211,12 +267,8 @@ impl CompactPublicKey { ); let encryption_noise_distribution = match self.pbs_order { - crate::shortint::PBSOrder::KeyswitchBootstrap => { - self.parameters.glwe_noise_distribution() - } - crate::shortint::PBSOrder::BootstrapKeyswitch => { - self.parameters.lwe_noise_distribution() - } + PBSOrder::KeyswitchBootstrap => self.parameters.glwe_noise_distribution(), + PBSOrder::BootstrapKeyswitch => self.parameters.lwe_noise_distribution(), }; // No parallelism allowed @@ -264,6 +316,91 @@ impl CompactPublicKey { } } + #[cfg(feature = "zk-pok-experimental")] + pub fn encrypt_and_prove_slice( + &self, + messages: &[u64], + public_params: &CompactPkePublicParams, + load: ZkComputeLoad, + ) -> crate::Result { + let plaintext_modulus = + (self.parameters.message_modulus().0 * self.parameters.carry_modulus().0) as u64; + let delta = (1u64 << 63) / plaintext_modulus; + + let max_num_message = public_params.k; + let num_lists = messages.len().div_ceil(max_num_message); + let mut proved_lists = Vec::with_capacity(num_lists); + for message_chunk in messages.chunks(max_num_message) { + let mut ct_list = LweCompactCiphertextListOwned::new( + 0u64, + self.key.lwe_dimension().to_lwe_size(), + LweCiphertextCount(message_chunk.len()), + self.parameters.ciphertext_modulus(), + ); + + let encryption_noise_distribution = match self.pbs_order { + PBSOrder::KeyswitchBootstrap => self.parameters.glwe_noise_distribution(), + PBSOrder::BootstrapKeyswitch => self.parameters.lwe_noise_distribution(), + }; + + // No parallelism allowed + #[cfg(all(feature = "__wasm_api", not(feature = "parallel-wasm-api")))] + let proof = { + use crate::core_crypto::prelude::encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key; + ShortintEngine::with_thread_local_mut(|engine| { + encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key( + &self.key, + &mut ct_list, + &message_chunk, + delta, + encryption_noise_distribution, + encryption_noise_distribution, + &mut engine.secret_generator, + &mut engine.encryption_generator, + &mut engine.random_generator, + public_params, + load, + ) + }) + }?; + + // Parallelism allowed / + #[cfg(any(not(feature = "__wasm_api"), feature = "parallel-wasm-api"))] + let proof = { + use crate::core_crypto::prelude::par_encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key; + ShortintEngine::with_thread_local_mut(|engine| { + par_encrypt_and_prove_lwe_compact_ciphertext_list_with_compact_public_key( + &self.key, + &mut ct_list, + &message_chunk, + delta, + encryption_noise_distribution, + encryption_noise_distribution, + &mut engine.secret_generator, + &mut engine.encryption_generator, + &mut engine.random_generator, + public_params, + load, + ) + }) + }?; + + let message_modulus = self.parameters.message_modulus(); + let ciphertext = CompactCiphertextList { + ct_list, + degree: Degree::new(message_modulus.0 - 1), + message_modulus, + carry_modulus: self.parameters.carry_modulus(), + pbs_order: self.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }; + + proved_lists.push((ciphertext, proof)); + } + + Ok(ProvenCompactCiphertextList { proved_lists }) + } + pub fn size_elements(&self) -> usize { self.key.size_elements() } diff --git a/tfhe/src/test_user_docs.rs b/tfhe/src/test_user_docs.rs index 052bd00a3b..c0baee0427 100644 --- a/tfhe/src/test_user_docs.rs +++ b/tfhe/src/test_user_docs.rs @@ -61,6 +61,7 @@ mod test_cpu_doc { "../docs/guides/trivial_ciphertext.md", guides_trivial_ciphertext ); + doctest!("../docs/guides/zk-pok.md", guides_zk_pok); // REFERENCES diff --git a/tfhe/src/zk.rs b/tfhe/src/zk.rs new file mode 100644 index 0000000000..b22a8b49d2 --- /dev/null +++ b/tfhe/src/zk.rs @@ -0,0 +1,140 @@ +use crate::core_crypto::commons::math::random::{BoundedDistribution, Deserialize, Serialize}; +use crate::core_crypto::prelude::*; +use rand_core::RngCore; +use std::cmp::Ordering; +use std::collections::Bound; +use std::fmt::Debug; +use tfhe_zk_pok::proofs::pke::crs_gen; + +pub use tfhe_zk_pok::proofs::ComputeLoad as ZkComputeLoad; +type Curve = tfhe_zk_pok::curve_api::Bls12_446; +pub type CompactPkeProof = tfhe_zk_pok::proofs::pke::Proof; +pub type CompactPkePublicParams = tfhe_zk_pok::proofs::pke::PublicParams; + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum ZkVerificationOutCome { + /// The proof ands its entity were valid + Valid, + /// The proof ands its entity were not + Invalid, +} + +impl ZkVerificationOutCome { + pub fn is_valid(self) -> bool { + self == Self::Valid + } + + pub fn is_invalid(self) -> bool { + self == Self::Invalid + } +} + +#[derive(Serialize, Deserialize)] +pub struct CompactPkeCrs { + public_params: CompactPkePublicParams, +} + +impl CompactPkeCrs { + pub fn new( + lwe_dim: LweDimension, + max_num_cleartext: usize, + noise_distribution: NoiseDistribution, + ciphertext_modulus: CiphertextModulus, + plaintext_modulus: Scalar, + rng: &mut impl RngCore, + ) -> crate::Result + where + Scalar: UnsignedInteger + CastInto + Debug, + NoiseDistribution: BoundedDistribution, + { + // The bound for the crs has to be a power of two, + // it is [-b, b) (non-inclusive for the high bound) + // so we may have to give a bound that is bigger than + // what the distribution generates + let high_bound = match noise_distribution.high_bound() { + Bound::Included(high_b) => { + let high_b = high_b.wrapping_abs().into_unsigned(); + if high_b.is_power_of_two() { + high_b * Scalar::TWO + } else { + high_b.next_power_of_two() + } + } + Bound::Excluded(high_b) => { + let high_b = high_b.wrapping_abs().into_unsigned(); + if high_b.is_power_of_two() { + high_b + } else { + high_b.next_power_of_two() + } + } + Bound::Unbounded => { + return Err("requires bounded distribution".into()); + } + }; + + let abs_low_bound = match noise_distribution.low_bound() { + Bound::Included(low_b) => { + let low_b = low_b.wrapping_abs().into_unsigned(); + if low_b.is_power_of_two() { + low_b * Scalar::TWO + } else { + low_b.next_power_of_two() + } + } + Bound::Excluded(low_b) => { + let low_b = low_b.wrapping_abs().into_unsigned(); + if low_b.is_power_of_two() { + low_b + } else { + low_b.next_power_of_two() + } + } + Bound::Unbounded => { + return Err("requires bounded distribution".into()); + } + }; + + let noise_bound = abs_low_bound.max(high_bound); + + if Scalar::BITS > 64 && noise_bound >= (Scalar::ONE << 64usize) { + return Err("noise bounds exceeds 64 bits modulus".into()); + } + + if Scalar::BITS > 64 && plaintext_modulus >= (Scalar::ONE << 64usize) { + return Err("Plaintext modulus exceeds 64 bits modulus".into()); + } + + let q = if ciphertext_modulus.is_native_modulus() { + match Scalar::BITS.cmp(&64) { + Ordering::Greater => Err( + "Zero Knowledge proof do not support ciphertext modulus > 64 bits".to_string(), + ), + Ordering::Equal => Ok(0u64), + Ordering::Less => Ok(1u64 << Scalar::BITS), + } + } else { + let custom_modulus = ciphertext_modulus.get_custom_modulus(); + if custom_modulus > (u64::MAX) as u128 { + Err("Zero Knowledge proof do not support ciphertext modulus > 64 bits".to_string()) + } else { + Ok(custom_modulus as u64) + } + }?; + + let public_params = crs_gen( + lwe_dim.0, + max_num_cleartext, + noise_bound.cast_into(), + q, + plaintext_modulus.cast_into(), + rng, + ); + + Ok(Self { public_params }) + } + + pub fn public_params(&self) -> &CompactPkePublicParams { + &self.public_params + } +} diff --git a/tfhe/web_wasm_parallel_tests/index.html b/tfhe/web_wasm_parallel_tests/index.html index 1792a2fb95..8f624f4fcb 100644 --- a/tfhe/web_wasm_parallel_tests/index.html +++ b/tfhe/web_wasm_parallel_tests/index.html @@ -68,7 +68,12 @@ value="Compressed Compact Public Key Test 256 Bits Big" disabled /> - +
@@ -110,6 +115,12 @@ value="Compressed Server Key Bench 2_2" disabled /> +
diff --git a/tfhe/web_wasm_parallel_tests/index.js b/tfhe/web_wasm_parallel_tests/index.js index 3a951ba64b..9d189d4e31 100644 --- a/tfhe/web_wasm_parallel_tests/index.js +++ b/tfhe/web_wasm_parallel_tests/index.js @@ -31,12 +31,14 @@ async function setup() { "compactPublicKeyTest256BitSmall", "compressedCompactPublicKeyTest256BitBig", "compressedCompactPublicKeyTest256BitSmall", + "compactPublicKeyZeroKnowledge", "compactPublicKeyBench32BitBig", "compactPublicKeyBench32BitSmall", "compactPublicKeyBench256BitBig", "compactPublicKeyBench256BitSmall", "compressedServerKeyBenchMessage1Carry1", "compressedServerKeyBenchMessage2Carry2", + "compactPublicKeyZeroKnowledgeBench", ]; function setupBtn(id) { diff --git a/tfhe/web_wasm_parallel_tests/jest.config.js b/tfhe/web_wasm_parallel_tests/jest.config.js index 55398dd306..f86081c954 100644 --- a/tfhe/web_wasm_parallel_tests/jest.config.js +++ b/tfhe/web_wasm_parallel_tests/jest.config.js @@ -1,4 +1,4 @@ -const secs = 60; +const secs = 1200; // 20 Minutes const config = { verbose: true, diff --git a/tfhe/web_wasm_parallel_tests/test/common.mjs b/tfhe/web_wasm_parallel_tests/test/common.mjs index f242e628b5..ccceff96f5 100644 --- a/tfhe/web_wasm_parallel_tests/test/common.mjs +++ b/tfhe/web_wasm_parallel_tests/test/common.mjs @@ -50,22 +50,26 @@ async function runActualTest(page, buttonId) { } } +const TWENTY_MINUTES = 1200 * 1000; + async function runTestAttachedToButton(buttonId) { let browser; if (isRoot()) { browser = await puppeteer.launch({ headless: "new", args: ["--no-sandbox"], + protocolTimeout: TWENTY_MINUTES, }); } else { browser = await puppeteer.launch({ headless: "new", + protocolTimeout: TWENTY_MINUTES, }); } let page = await browser.newPage(); - await page.setDefaultTimeout(300000); // Five minutes timeout + await page.setDefaultTimeout(TWENTY_MINUTES); await page.goto("http://localhost:3000"); page.on("console", (msg) => console.log("PAGE LOG:", msg.text())); diff --git a/tfhe/web_wasm_parallel_tests/test/compact-public-key.test.js b/tfhe/web_wasm_parallel_tests/test/compact-public-key.test.js index 46a63b85c4..c575e5f6df 100644 --- a/tfhe/web_wasm_parallel_tests/test/compact-public-key.test.js +++ b/tfhe/web_wasm_parallel_tests/test/compact-public-key.test.js @@ -23,3 +23,11 @@ it("Compressed Compact Public Key Test Small 256 Bit", async () => { it("Compressed Compact Public Key Test Big 256 Bit", async () => { await runTestAttachedToButton("compressedCompactPublicKeyTest256BitBig"); }); + +it( + "Compact Public Key Test Big 64 Bit With Zero Knowledge", + async () => { + await runTestAttachedToButton("compactPublicKeyZeroKnowledge"); + }, + 1200 * 1000, +); // 20 minutes timeout diff --git a/tfhe/web_wasm_parallel_tests/worker.js b/tfhe/web_wasm_parallel_tests/worker.js index 763f6dfabc..d8d17374e8 100644 --- a/tfhe/web_wasm_parallel_tests/worker.js +++ b/tfhe/web_wasm_parallel_tests/worker.js @@ -11,17 +11,15 @@ import init, { TfheCompressedCompactPublicKey, TfheCompactPublicKey, TfheConfigBuilder, - CompressedFheUint8, FheUint8, - FheUint32, - CompactFheUint32, CompactFheUint32List, - CompressedFheUint128, - FheUint128, - CompressedFheUint256, - FheUint256, - CompactFheUint256, CompactFheUint256List, + ZkComputeLoad, + ProvenCompactFheUint64, + ProvenCompactFheUint64List, + CompactPkeCrs, + Shortint, + CompactFheUint64, } from "./pkg/tfhe.js"; function assert(cond, text) { @@ -74,7 +72,7 @@ async function compressedPublicKeyTest() { } async function publicKeyTest() { - let config = TfheConfigBuilder.default().use_small_encryption().build(); + let config = TfheConfigBuilder.default_with_small_encryption().build(); console.time("ClientKey Gen"); let clientKey = TfheClientKey.generate(config); @@ -379,6 +377,106 @@ async function compressedCompactPublicKeyTest256BitOnConfig(config) { } } +function generateRandomBigInt(bitLen) { + let result = BigInt(0); + for (let i = 0; i < bitLen; i++) { + result << 1n; + result |= BigInt(Math.random() < 0.5); + } + return result; +} + +async function compactPublicKeyZeroKnowledge() { + let block_params = new ShortintParameters( + ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + block_params.set_glwe_noise_distribution(Shortint.try_new_t_uniform(9)); + + let config = TfheConfigBuilder.default() + .use_custom_parameters(block_params) + .build(); + + let clientKey = TfheClientKey.generate(config); + let publicKey = TfheCompactPublicKey.new(clientKey); + + console.log("Start CRS generation"); + console.time("CRS generation"); + let crs = CompactPkeCrs.from_config(config, 4 * 64); + console.timeEnd("CRS generation"); + let public_params = crs.public_params(); + + { + let input = generateRandomBigInt(64); + let start = performance.now(); + let encrypted = ProvenCompactFheUint64.encrypt_with_compact_public_key( + input, + public_params, + publicKey, + ZkComputeLoad.Proof, + ); + let end = performance.now(); + console.log( + "Time to encrypt + prove CompactFheUint64: ", + end - start, + " ms", + ); + + let bytes = encrypted.serialize(); + console.log("ProvenCompactFheUint64 size:", bytes.length); + + assert_eq(encrypted.verifies(public_params, publicKey), true); + + start = performance.now(); + let expanded = encrypted.verify_and_expand(public_params, publicKey); + end = performance.now(); + console.log( + "Time to verify + expand CompactFheUint64: ", + end - start, + " ms", + ); + + let decrypted = expanded.decrypt(clientKey); + assert_eq(decrypted, input); + } + + { + let inputs = [ + generateRandomBigInt(64), + generateRandomBigInt(64), + generateRandomBigInt(64), + generateRandomBigInt(64), + ]; + let start = performance.now(); + let encrypted = ProvenCompactFheUint64List.encrypt_with_compact_public_key( + inputs, + public_params, + publicKey, + ZkComputeLoad.Proof, + ); + let end = performance.now(); + console.log( + "Time to encrypt + prove CompactFheUint64List of 4: ", + end - start, + " ms", + ); + assert_eq(encrypted.verifies(public_params, publicKey), true); + + start = performance.now(); + let expanded_list = encrypted.verify_and_expand(public_params, publicKey); + end = performance.now(); + console.log( + "Time to verify + expand CompactFheUint64: ", + end - start, + " ms", + ); + + for (let i = 0; i < inputs.length; i++) { + let decrypted = expanded_list[i].decrypt(clientKey); + assert_eq(decrypted, inputs[i]); + } + } +} + async function compressedCompactPublicKeyTest256BitBig() { const block_params = new ShortintParameters( ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS, @@ -550,6 +648,60 @@ async function compressedServerKeyBenchMessage2Carry2() { ); } +async function compactPublicKeyZeroKnowledgeBench() { + let block_params = new ShortintParameters( + ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS, + ); + block_params.set_lwe_noise_distribution(Shortint.try_new_t_uniform(9)); + + let config = TfheConfigBuilder.default() + .use_custom_parameters(block_params) + .build(); + + let clientKey = TfheClientKey.generate(config); + let publicKey = TfheCompactPublicKey.new(clientKey); + + console.log("Start CRS generation"); + console.time("CRS generation"); + let crs = CompactPkeCrs.from_config(config, 4 * 64); + console.timeEnd("CRS generation"); + let public_params = crs.public_params(); + + const bench_loops = 4; // The computation is expensive + let bench_results = {}; + let load_choices = [ZkComputeLoad.Proof, ZkComputeLoad.Verify]; + const load_to_str = { + [ZkComputeLoad.Proof]: "compute_load_proof", + [ZkComputeLoad.Verify]: "compute_load_verify", + }; + for (const loadChoice of load_choices) { + let timing = 0; + for (let i = 0; i < bench_loops; i++) { + let input = generateRandomBigInt(64); + + const start = performance.now(); + let _ = ProvenCompactFheUint64.encrypt_with_compact_public_key( + input, + public_params, + publicKey, + loadChoice, + ); + const end = performance.now(); + timing += end - start; + } + const mean = timing / bench_loops; + + const bench_str = + "compact_fhe_uint64_proven_encryption_" + + load_to_str[loadChoice] + + "_mean"; + console.log(bench_str, ": ", mean, " ms"); + bench_results["compact_fhe_uint64_proven_encryption_"] = mean; + } + + return bench_results; +} + async function main() { await init(); await initThreadPool(navigator.hardwareConcurrency); @@ -564,12 +716,14 @@ async function main() { compactPublicKeyTest256BitBig, compressedCompactPublicKeyTest256BitSmall, compressedCompactPublicKeyTest256BitBig, + compactPublicKeyZeroKnowledge, compactPublicKeyBench32BitBig, compactPublicKeyBench32BitSmall, compactPublicKeyBench256BitBig, compactPublicKeyBench256BitSmall, compressedServerKeyBenchMessage1Carry1, compressedServerKeyBenchMessage2Carry2, + compactPublicKeyZeroKnowledgeBench, }); }