Skip to content

Commit

Permalink
refactor(zk): handle compression without canonical serialize
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarlin-zama committed Sep 26, 2024
1 parent 6666908 commit 6146701
Show file tree
Hide file tree
Showing 12 changed files with 1,000 additions and 300 deletions.
1 change: 1 addition & 0 deletions tfhe-zk-pok/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ tfhe-versionable = { version = "0.3.0", path = "../utils/tfhe-versionable" }
[dev-dependencies]
serde_json = "~1.0"
itertools = "0.11.0"
bincode = "1.3.3"
28 changes: 23 additions & 5 deletions tfhe-zk-pok/src/backward_compatibility/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use tfhe_versionable::VersionsDispatch;

use crate::curve_api::Curve;
use crate::proofs::pke::Proof as PKEv1Proof;
use crate::proofs::pke_v2::Proof as PKEv2Proof;
use crate::curve_api::{Compressible, Curve};
use crate::proofs::pke::{CompressedProof as PKEv1CompressedProof, Proof as PKEv1Proof};
use crate::proofs::pke_v2::{CompressedProof as PKEv2CompressedProof, Proof as PKEv2Proof};
use crate::proofs::GroupElements;
use crate::serialization::{
SerializableAffine, SerializableCubicExtField, SerializableFp, SerializableFp2,
Expand Down Expand Up @@ -34,14 +34,32 @@ pub type SerializableG1AffineVersions = SerializableAffineVersions<SerializableF
pub type SerializableG2AffineVersions = SerializableAffineVersions<SerializableFp2>;
pub type SerializableFp12Versions = SerializableQuadExtFieldVersions<SerializableFp6>;

#[derive(VersionsDispatch)]
pub enum PKEv1ProofVersions<G: Curve> {
V0(PKEv1Proof<G>),
}

#[derive(VersionsDispatch)]
pub enum PKEv2ProofVersions<G: Curve> {
V0(PKEv2Proof<G>),
}

#[derive(VersionsDispatch)]
pub enum PKEv1ProofVersions<G: Curve> {
V0(PKEv1Proof<G>),
pub enum PKEv1CompressedProofVersions<G: Curve>
where
G::G1: Compressible,
G::G2: Compressible,
{
V0(PKEv1CompressedProof<G>),
}

#[derive(VersionsDispatch)]
pub enum PKEv2CompressedProofVersions<G: Curve>
where
G::G1: Compressible,
G::G2: Compressible,
{
V0(PKEv2CompressedProof<G>),
}

#[derive(VersionsDispatch)]
Expand Down
19 changes: 13 additions & 6 deletions tfhe-zk-pok/src/curve_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use ark_ec::short_weierstrass::Affine;
use ark_ec::{AdditiveGroup as Group, CurveGroup, VariableBaseMSM};
use ark_ff::{BigInt, Field, MontFp, Zero};
use ark_poly::univariate::DensePolynomial;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use core::fmt;
use core::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -108,9 +107,7 @@ pub trait CurveGroupOps<Zp>:
+ Sync
+ core::fmt::Debug
+ serde::Serialize
+ for<'de> serde::Deserialize<'de>
+ CanonicalSerialize
+ CanonicalDeserialize;
+ for<'de> serde::Deserialize<'de>;

fn projective(affine: Self::Affine) -> Self;

Expand All @@ -121,6 +118,16 @@ pub trait CurveGroupOps<Zp>:
fn normalize(self) -> Self::Affine;
}

/// Mark that an element can be compressed, by storing only the 'x' coordinates of the affine
/// representation and getting the 'y' from the curve.
pub trait Compressible: Sized {
type Compressed;
type UncompressError;

fn compress(&self) -> Self::Compressed;
fn uncompress(compressed: Self::Compressed) -> Result<Self, Self::UncompressError>;
}

pub trait PairingGroupOps<Zp, G1, G2>:
Copy
+ Send
Expand All @@ -139,8 +146,8 @@ pub trait PairingGroupOps<Zp, G1, G2>:

pub trait Curve: Clone {
type Zp: FieldOps;
type G1: CurveGroupOps<Self::Zp> + CanonicalSerialize + CanonicalDeserialize;
type G2: CurveGroupOps<Self::Zp> + CanonicalSerialize + CanonicalDeserialize;
type G1: CurveGroupOps<Self::Zp>;
type G2: CurveGroupOps<Self::Zp>;
type Gt: PairingGroupOps<Self::Zp, Self::G1, Self::G2>;
}

Expand Down
142 changes: 85 additions & 57 deletions tfhe-zk-pok/src/curve_api/bls12_381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,7 @@ mod g1 {

use super::*;

#[derive(
Copy,
Clone,
Debug,
PartialEq,
Eq,
Serialize,
Deserialize,
CanonicalSerialize,
CanonicalDeserialize,
Hash,
Versionize,
)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")]
#[versionize(
SerializableG1AffineVersions,
Expand All @@ -67,7 +55,7 @@ mod g1 {

impl From<G1Affine> for SerializableAffine<SerializableFp> {
fn from(value: G1Affine) -> Self {
SerializableAffine::compressed(value.inner)
SerializableAffine::uncompressed(value.inner)
}
}

Expand All @@ -81,6 +69,19 @@ mod g1 {
}
}

impl Compressible for G1Affine {
type Compressed = SerializableG1Affine;
type UncompressError = InvalidSerializedAffineError;

fn compress(&self) -> SerializableG1Affine {
SerializableAffine::compressed(self.inner)
}

fn uncompress(compressed: Self::Compressed) -> Result<Self, Self::UncompressError> {
compressed.try_into()
}
}

impl G1Affine {
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 {
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
Expand All @@ -96,18 +97,7 @@ mod g1 {
}
}

#[derive(
Copy,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
Versionize,
)]
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")]
#[versionize(
SerializableG1AffineVersions,
Expand All @@ -121,11 +111,11 @@ mod g1 {

impl From<G1> for SerializableAffine<SerializableFp> {
fn from(value: G1) -> Self {
SerializableAffine::compressed(value.inner.into_affine())
SerializableAffine::uncompressed(value.inner.into_affine())
}
}

impl TryFrom<SerializableAffine<SerializableFp>> for G1 {
impl TryFrom<SerializableG1Affine> for G1 {
type Error = InvalidSerializedAffineError;

fn try_from(value: SerializableAffine<SerializableFp>) -> Result<Self, Self::Error> {
Expand All @@ -135,6 +125,19 @@ mod g1 {
}
}

impl Compressible for G1 {
type Compressed = SerializableG1Affine;
type UncompressError = InvalidSerializedAffineError;

fn compress(&self) -> SerializableG1Affine {
SerializableAffine::compressed(self.inner.into_affine())
}

fn uncompress(compressed: Self::Compressed) -> Result<Self, Self::UncompressError> {
compressed.try_into()
}
}

impl fmt::Debug for G1 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("G1")
Expand Down Expand Up @@ -266,19 +269,7 @@ mod g2 {

use super::*;

#[derive(
Copy,
Clone,
Debug,
PartialEq,
Eq,
Serialize,
Deserialize,
CanonicalSerialize,
CanonicalDeserialize,
Hash,
Versionize,
)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")]
#[versionize(
SerializableG2AffineVersions,
Expand All @@ -292,7 +283,7 @@ mod g2 {

impl From<G2Affine> for SerializableAffine<SerializableFp2> {
fn from(value: G2Affine) -> Self {
SerializableAffine::compressed(value.inner)
SerializableAffine::uncompressed(value.inner)
}
}

Expand All @@ -306,6 +297,20 @@ mod g2 {
}
}

impl Compressible for G2Affine {
type Compressed = SerializableG2Affine;

type UncompressError = InvalidSerializedAffineError;

fn compress(&self) -> SerializableAffine<SerializableFp2> {
SerializableAffine::compressed(self.inner)
}

fn uncompress(compressed: Self::Compressed) -> Result<Self, Self::UncompressError> {
compressed.try_into()
}
}

impl G2Affine {
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 {
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
Expand All @@ -321,18 +326,7 @@ mod g2 {
}
}

#[derive(
Copy,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
CanonicalSerialize,
CanonicalDeserialize,
Hash,
Versionize,
)]
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")]
#[versionize(
SerializableG2AffineVersions,
Expand All @@ -344,13 +338,13 @@ mod g2 {
pub(crate) inner: ark_bls12_381::G2Projective,
}

impl From<G2> for SerializableAffine<SerializableFp2> {
impl From<G2> for SerializableG2Affine {
fn from(value: G2) -> Self {
SerializableAffine::compressed(value.inner.into_affine())
SerializableAffine::uncompressed(value.inner.into_affine())
}
}

impl TryFrom<SerializableAffine<SerializableFp2>> for G2 {
impl TryFrom<SerializableG2Affine> for G2 {
type Error = InvalidSerializedAffineError;

fn try_from(value: SerializableAffine<SerializableFp2>) -> Result<Self, Self::Error> {
Expand All @@ -360,6 +354,20 @@ mod g2 {
}
}

impl Compressible for G2 {
type Compressed = SerializableG2Affine;

type UncompressError = InvalidSerializedAffineError;

fn compress(&self) -> SerializableAffine<SerializableFp2> {
SerializableAffine::compressed(self.inner.into_affine())
}

fn uncompress(compressed: Self::Compressed) -> Result<Self, Self::UncompressError> {
compressed.try_into()
}
}

impl fmt::Debug for G2 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
#[allow(dead_code)]
Expand Down Expand Up @@ -998,6 +1006,26 @@ mod tests {
assert_eq!(g_hat_cur, g_hat_cur2);
}

#[test]
fn test_compressed_serialization() {
let rng = &mut StdRng::seed_from_u64(0);
let alpha = Zp::rand(rng);
let g_cur = G1::GENERATOR.mul_scalar(alpha);
let g_hat_cur = G2::GENERATOR.mul_scalar(alpha);

let g_cur2 = G1::uncompress(
serde_json::from_str(&serde_json::to_string(&g_cur.compress()).unwrap()).unwrap(),
)
.unwrap();
assert_eq!(g_cur, g_cur2);

let g_hat_cur2 = G2::uncompress(
serde_json::from_str(&serde_json::to_string(&g_hat_cur.compress()).unwrap()).unwrap(),
)
.unwrap();
assert_eq!(g_hat_cur, g_hat_cur2);
}

#[test]
fn test_hasher_and_eq() {
// we need to make sure if the points are the same
Expand Down
Loading

0 comments on commit 6146701

Please sign in to comment.