Skip to content

Commit

Permalink
offline-phase: lowgear: triplets: Complete triplet generation protocol
Browse files Browse the repository at this point in the history
Implements the latter half of the triplet generation protocol, computing
cross terms and summing them into shares of the product `c = a * b`.

Left to do are triplet sacrifice and authentication.
  • Loading branch information
joeykraut committed Apr 10, 2024
1 parent c59c6e8 commit ad10787
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 9 deletions.
19 changes: 19 additions & 0 deletions mp-spdz-rs/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ mod ffi_inner {
// `Plaintext`
type Plaintext_mod_prime;
fn new_plaintext(params: &FHE_Params) -> UniquePtr<Plaintext_mod_prime>;
fn randomize_plaintext(plaintext: Pin<&mut Plaintext_mod_prime>);
fn clone(self: &Plaintext_mod_prime) -> UniquePtr<Plaintext_mod_prime>;
fn to_rust_bytes(self: &Plaintext_mod_prime) -> Vec<u8>;
fn plaintext_from_rust_bytes(
Expand Down Expand Up @@ -97,6 +98,11 @@ mod ffi_inner {
vector: &PlaintextVector,
index: usize,
) -> UniquePtr<Plaintext_mod_prime>;
fn set_plaintext_vector_element(
vector: Pin<&mut PlaintextVector>,
index: usize,
plaintext: &Plaintext_mod_prime,
);
fn randomize_plaintext_vector(vector: Pin<&mut PlaintextVector>);
fn push_plaintext_vector(
vector: Pin<&mut PlaintextVector>,
Expand All @@ -109,6 +115,7 @@ mod ffi_inner {
type Ciphertext;
fn clone(self: &Ciphertext) -> UniquePtr<Ciphertext>;
fn to_rust_bytes(self: &Ciphertext) -> Vec<u8>;
fn rerandomize(self: Pin<&mut Ciphertext>, pk: &FHE_PK);
fn ciphertext_from_rust_bytes(data: &[u8], params: &FHE_Params) -> UniquePtr<Ciphertext>;

fn add_plaintext(c0: &Ciphertext, p1: &Plaintext_mod_prime) -> UniquePtr<Ciphertext>;
Expand All @@ -120,10 +127,21 @@ mod ffi_inner {
type CiphertextVector;
fn new_ciphertext_vector(size: usize, params: &FHE_Params) -> UniquePtr<CiphertextVector>;
fn new_ciphertext_vector_single(ciphertext: &Ciphertext) -> UniquePtr<CiphertextVector>;
fn ciphertext_vector_to_rust_bytes(vector: &CiphertextVector) -> Vec<u8>;
fn ciphertext_vector_from_rust_bytes(
data: &[u8],
params: &FHE_Params,
) -> UniquePtr<CiphertextVector>;

fn get_ciphertext_vector_element(
vector: &CiphertextVector,
index: usize,
) -> UniquePtr<Ciphertext>;
fn set_ciphertext_vector_element(
vector: Pin<&mut CiphertextVector>,
index: usize,
ciphertext: &Ciphertext,
);
fn push_ciphertext_vector(vector: Pin<&mut CiphertextVector>, ciphertext: &Ciphertext);
fn pop_ciphertext_vector(vector: Pin<&mut CiphertextVector>);
fn ciphertext_vector_size(vector: &CiphertextVector) -> usize;
Expand Down Expand Up @@ -156,6 +174,7 @@ unsafe impl Send for CiphertextVector {}
unsafe impl Send for CiphertextWithProof {}
unsafe impl Send for Plaintext_mod_prime {}
unsafe impl Send for PlaintextVector {}
unsafe impl Sync for PlaintextVector {}

#[cfg(test)]
mod test {
Expand Down
35 changes: 34 additions & 1 deletion mp-spdz-rs/src/fhe/ciphertext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ pub struct Ciphertext<C: CurveGroup> {
}

impl<C: CurveGroup> Ciphertext<C> {
/// Rerandomize the ciphertext
pub fn rerandomize(&mut self, pk: &BGVPublicKey<C>) {
self.inner.pin_mut().rerandomize(pk.as_ref());
}

/// Multiply two ciphertexts
pub fn mul_ciphertext(&self, other: &Self, pk: &BGVPublicKey<C>) -> Self {
ffi::mul_ciphertexts(self.as_ref(), other.as_ref(), pk.as_ref()).into()
Expand Down Expand Up @@ -111,6 +116,24 @@ impl<C: CurveGroup> From<UniquePtr<ffi::CiphertextVector>> for CiphertextVector<
}
}

impl<C: CurveGroup> AsRef<ffi::CiphertextVector> for CiphertextVector<C> {
fn as_ref(&self) -> &ffi::CiphertextVector {
self.inner.as_ref().unwrap()
}
}

impl<C: CurveGroup> ToBytes for CiphertextVector<C> {
fn to_bytes(&self) -> Vec<u8> {
ffi::ciphertext_vector_to_rust_bytes(self.as_ref())
}
}

impl<C: CurveGroup> FromBytesWithParams<C> for CiphertextVector<C> {
fn from_bytes(data: &[u8], params: &BGVParams<C>) -> Self {
ffi::ciphertext_vector_from_rust_bytes(data, params.as_ref()).into()
}
}

impl<C: CurveGroup> CiphertextVector<C> {
/// Create a new `CiphertextVector` with a specified size
pub fn new(size: usize, params: &BGVParams<C>) -> Self {
Expand All @@ -124,10 +147,15 @@ impl<C: CurveGroup> CiphertextVector<C> {
}

/// Get the size of the `CiphertextVector`
pub fn size(&self) -> usize {
pub fn len(&self) -> usize {
ffi::ciphertext_vector_size(self.inner.as_ref().unwrap())
}

/// Whether the vector is empty
pub fn is_empty(&self) -> bool {
self.len() == 0
}

/// Add a `Ciphertext` to the end of the `CiphertextVector`
pub fn push(&mut self, ciphertext: &Ciphertext<C>) {
ffi::push_ciphertext_vector(self.inner.pin_mut(), ciphertext.as_ref());
Expand All @@ -143,6 +171,11 @@ impl<C: CurveGroup> CiphertextVector<C> {
let ciphertext = ffi::get_ciphertext_vector_element(self.inner.as_ref().unwrap(), index);
Ciphertext::from(ciphertext)
}

/// Set a `Ciphertext` at a specific index in the `CiphertextVector`
pub fn set(&mut self, index: usize, ciphertext: &Ciphertext<C>) {
ffi::set_ciphertext_vector_element(self.inner.pin_mut(), index, ciphertext.as_ref());
}
}

// -----------------
Expand Down
10 changes: 10 additions & 0 deletions mp-spdz-rs/src/fhe/plaintext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ impl<C: CurveGroup> Plaintext<C> {
Self { inner, _phantom: PhantomData }
}

/// Randomize the plaintext
pub fn randomize(&mut self) {
ffi::randomize_plaintext(self.inner.pin_mut());
}

/// Get the number of slots in the plaintext
pub fn num_slots(&self) -> u32 {
self.inner.num_slots()
Expand Down Expand Up @@ -211,6 +216,11 @@ impl<C: CurveGroup> PlaintextVector<C> {
let plaintext = ffi::get_plaintext_vector_element(self.inner.as_ref().unwrap(), index);
Plaintext::from(plaintext)
}

/// Set a `Plaintext` at a specific index in the `PlaintextVector`
pub fn set(&mut self, index: usize, plaintext: &Plaintext<C>) {
ffi::set_plaintext_vector_element(self.inner.pin_mut(), index, plaintext.as_ref());
}
}

// -------------------------------
Expand Down
11 changes: 5 additions & 6 deletions offline-phase/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub(crate) mod test_helpers {
}

/// Run a two-party method with a `LowGear` instance setup and in scope
pub async fn mock_lowgear<F, S, T>(mut f: F) -> (T, T)
pub async fn mock_lowgear<F, S, T>(f: F) -> (T, T)
where
T: Send + 'static,
S: Future<Output = T> + Send + 'static,
Expand All @@ -77,9 +77,8 @@ pub(crate) mod test_helpers {
run_mock_lowgear(f, lowgear1, lowgear2).await
}

/// Run a two-party method with a `LowGear` instance, having run keygen and
/// setup
pub async fn mock_lowgear_with_keys<F, S, T>(mut f: F) -> (T, T)
/// Run a two-party method with a `LowGear` instance, mocking keygen setup
pub async fn mock_lowgear_with_keys<F, S, T>(f: F) -> (T, T)
where
T: Send + 'static,
S: Future<Output = T> + Send + 'static,
Expand All @@ -103,9 +102,9 @@ pub(crate) mod test_helpers {

// Set the local keypairs and mac shares
lowgear1.local_keypair = keypair1.clone();
lowgear1.mac_share = mac_share1.clone();
lowgear1.mac_share = mac_share1;
lowgear2.local_keypair = keypair2.clone();
lowgear2.mac_share = mac_share2.clone();
lowgear2.mac_share = mac_share2;

// Set the exchanged values
lowgear1.other_pk = Some(keypair2.public_key());
Expand Down
108 changes: 106 additions & 2 deletions offline-phase/src/lowgear/triplets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
//!
//! These triples are used to define single-round multiplication in the SPDZ
//! protocol
//!
//! Follows the protocol detailed in https://eprint.iacr.org/2017/1230.pdf (Figure 7)
use ark_ec::CurveGroup;
use ark_mpc::network::MpcNetwork;
use mp_spdz_rs::fhe::{ciphertext::CiphertextPoK, plaintext::PlaintextVector};
use mp_spdz_rs::fhe::{
ciphertext::{CiphertextPoK, CiphertextVector},
plaintext::{Plaintext, PlaintextVector},
};

use crate::error::LowGearError;

Expand All @@ -24,13 +29,112 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {
let c = &a * &b;

// Encrypt `a` and send it to the counterparty
let my_proof = self.local_keypair.encrypt_and_prove_vector(&mut a);
let other_a_enc = self.exchange_a_values(&mut a).await?;

// Generate shares of the product and exchange
let c_shares = self.share_product(other_a_enc, &b, c).await?;

Ok(())
}

/// Exchange encryptions of the `a` value
///
/// Returns the counterparty's encryption of `a`
async fn exchange_a_values(
&mut self,
a: &mut PlaintextVector<C>,
) -> Result<CiphertextVector<C>, LowGearError> {
// Encrypt `a` and send it to the counterparty
let my_proof = self.local_keypair.encrypt_and_prove_vector(a);
self.send_message(my_proof).await?;
let mut other_proof: CiphertextPoK<C> = self.receive_message().await?;

let other_pk = self.other_pk.as_ref().expect("setup not run");
let other_a_enc = other_pk.verify_proof(&mut other_proof);

Ok(other_a_enc)
}

/// Create shares of the product `c` by exchanging homomorphically evaluated
/// encryptions of `my_b * other_a`
async fn share_product(
&mut self,
other_enc_a: CiphertextVector<C>,
my_b_share: &PlaintextVector<C>,
my_c_share: PlaintextVector<C>,
) -> Result<PlaintextVector<C>, LowGearError> {
let mut c_res = my_c_share;

// Compute the cross products then share them with the counterparty and compute
// local shares of `c`
let cross_products =
self.compute_triplet_cross_products(&other_enc_a, my_b_share, &mut c_res);
self.exchange_cross_products(cross_products, &mut c_res).await?;

Ok(c_res)
}

/// Compute the cross products in the triplet generation
fn compute_triplet_cross_products(
&mut self,
other_a: &CiphertextVector<C>,
my_b: &PlaintextVector<C>,
my_c: &mut PlaintextVector<C>,
) -> CiphertextVector<C> {
let n = other_a.len();
let mut cross_products = CiphertextVector::new(n, &self.params);

// Compute the cross products of the local party's `b` share and the encryption
// of the counterparty's `a` share
for i in 0..n {
let a_enc = other_a.get(i);
let b = my_b.get(i);
let c = my_c.get(i);

// Compute the product of `my_b` and `other_enc_a`
let mut product = &a_enc * &b;

// Rerandomize the product to add drowning noise and mask it with a random value
product.rerandomize(self.other_pk.as_ref().unwrap());
let mut mask = Plaintext::new(&self.params);
mask.randomize();

let masked_product = &product + &mask;

// Subtract the masked product from our share
let my_share = &c - &mask;
my_c.set(i, &my_share);
cross_products.set(i, &masked_product);
}

cross_products
}

/// Exchange cross products and compute final shares of `c`
async fn exchange_cross_products(
&mut self,
cross_products: CiphertextVector<C>,
my_c_share: &mut PlaintextVector<C>,
) -> Result<(), LowGearError> {
let n = cross_products.len();

// Send and receive cross products to/from the counterparty
self.send_message(cross_products).await?;
let other_cross_products: CiphertextVector<C> = self.receive_message().await?;

// Add each cross product to the local party's share of `c`
for i in 0..n {
let cross_product = other_cross_products.get(i);
let c = my_c_share.get(i);

// Decrypt the term
let cross_product = self.local_keypair.decrypt(&cross_product);

// Add the cross product to the local party's share of `c`
let my_share = &c + &cross_product;
my_c_share.set(i, &my_share);
}

Ok(())
}
}
Expand Down

0 comments on commit ad10787

Please sign in to comment.