Skip to content

Commit

Permalink
offline-phase: lowgear: triplets: Implement initial triplet gen phase
Browse files Browse the repository at this point in the history
This involves each party generating local shares of the triplet,
encrypting their `a` value, generating a ciphertext PoK, and sending it
to the counterparty.

The counterparty then verifies the PoK.
  • Loading branch information
joeykraut committed Apr 9, 2024
1 parent ca4f890 commit c59c6e8
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 4 deletions.
4 changes: 4 additions & 0 deletions mp-spdz-rs/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ mod ffi_inner {

// `PlaintextVector`
type PlaintextVector;
fn new_empty_plaintext_vector() -> UniquePtr<PlaintextVector>;
fn new_plaintext_vector(size: usize, params: &FHE_Params) -> UniquePtr<PlaintextVector>;
fn new_plaintext_vector_single(
plaintext: &Plaintext_mod_prime,
) -> UniquePtr<PlaintextVector>;
fn random_plaintext_vector(size: usize, params: &FHE_Params) -> UniquePtr<PlaintextVector>;
fn get_plaintext_vector_element(
vector: &PlaintextVector,
index: usize,
Expand Down Expand Up @@ -150,8 +152,10 @@ unsafe impl Send for FHE_Params {}
unsafe impl Send for FHE_KeyPair {}
unsafe impl Send for FHE_PK {}
unsafe impl Send for Ciphertext {}
unsafe impl Send for CiphertextVector {}
unsafe impl Send for CiphertextWithProof {}
unsafe impl Send for Plaintext_mod_prime {}
unsafe impl Send for PlaintextVector {}

#[cfg(test)]
mod test {
Expand Down
5 changes: 5 additions & 0 deletions mp-spdz-rs/src/fhe/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ impl<C: CurveGroup> BGVParams<C> {
pub fn plaintext_slots(&self) -> u32 {
self.as_ref().n_plaintext_slots()
}

/// Get the number of ciphertexts that may be proven together
pub fn ciphertext_pok_batch_size(&self) -> usize {
(self.plaintext_slots() as usize) * (DEFAULT_DROWN_SEC as usize)
}
}

impl<C: CurveGroup> Serialize for BGVParams<C> {
Expand Down
84 changes: 81 additions & 3 deletions mp-spdz-rs/src/fhe/plaintext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ use cxx::UniquePtr;

use crate::{ffi, FromBytesWithParams, ToBytes};

use super::{ffi_bigint_to_scalar, params::BGVParams, scalar_to_ffi_bigint};
use super::{
ffi_bigint_to_scalar,
params::{BGVParams, DEFAULT_DROWN_SEC},
scalar_to_ffi_bigint,
};

/// A plaintext in the BGV implementation
///
Expand Down Expand Up @@ -142,20 +146,51 @@ impl<C: CurveGroup> From<UniquePtr<ffi::PlaintextVector>> for PlaintextVector<C>
impl<C: CurveGroup> PlaintextVector<C> {
/// Create a new `PlaintextVector` with a specified size
pub fn new(size: usize, params: &BGVParams<C>) -> Self {
let inner = crate::ffi::new_plaintext_vector(size, params.as_ref());
let inner = ffi::new_plaintext_vector(size, params.as_ref());
Self { inner, _phantom: PhantomData }
}

/// Create a new empty `PlaintextVector`
pub fn empty() -> Self {
Self { inner: ffi::new_empty_plaintext_vector(), _phantom: PhantomData }
}

/// Generate a random `PlaintextVector` with a specified size
pub fn random(size: usize, params: &BGVParams<C>) -> Self {
let inner = ffi::random_plaintext_vector(size, params.as_ref());
Self { inner, _phantom: PhantomData }
}

/// Get the total number of slots in the `PlaintextVector`
pub fn total_slots(&self) -> usize {
if self.is_empty() {
0
} else {
self.get(0).num_slots() as usize * self.len()
}
}

/// Generate a random `PlaintextVector` of size equal to the batching width
/// of the plaintext PoK proof system
pub fn random_pok_batch(params: &BGVParams<C>) -> Self {
Self::random(DEFAULT_DROWN_SEC as usize, params)
}

/// Get a pinned mutable reference to the inner `PlaintextVector`
pub fn pin_mut(&mut self) -> Pin<&mut ffi::PlaintextVector> {
self.inner.pin_mut()
}

/// Get the size of the `PlaintextVector`
pub fn size(&self) -> usize {
pub fn len(&self) -> usize {
ffi::plaintext_vector_size(self.inner.as_ref().unwrap())
}

/// Whether the vector is empty
pub fn is_empty(&self) -> bool {
self.len() == 0
}

/// Add a `Plaintext` to the end of the `PlaintextVector`
pub fn push(&mut self, plaintext: &Plaintext<C>) {
ffi::push_plaintext_vector(self.inner.pin_mut(), plaintext.as_ref());
Expand All @@ -178,6 +213,49 @@ impl<C: CurveGroup> PlaintextVector<C> {
}
}

// -------------------------------
// | Plaintext Vector Arithmetic |
// -------------------------------

impl<C: CurveGroup> Add for &PlaintextVector<C> {
type Output = PlaintextVector<C>;

fn add(self, other: Self) -> Self::Output {
let mut result = PlaintextVector::empty();
for i in 0..self.len() {
let element = &self.get(i) + &other.get(i);
result.push(&element);
}
result
}
}

impl<C: CurveGroup> Sub for &PlaintextVector<C> {
type Output = PlaintextVector<C>;

fn sub(self, other: Self) -> Self::Output {
let mut result = PlaintextVector::empty();
for i in 0..self.len() {
let element = &self.get(i) - &other.get(i);
result.push(&element);
}
result
}
}

impl<C: CurveGroup> Mul for &PlaintextVector<C> {
type Output = PlaintextVector<C>;

fn mul(self, other: Self) -> Self::Output {
let mut result = PlaintextVector::empty();
for i in 0..self.len() {
let element = &self.get(i) * &other.get(i);
result.push(&element);
}
result
}
}

#[cfg(test)]
mod tests {
use rand::thread_rng;
Expand Down
59 changes: 58 additions & 1 deletion offline-phase/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ pub(crate) mod test_helpers {
};
use futures::Future;
use mp_spdz_rs::fhe::{
ciphertext::Ciphertext, keys::BGVPublicKey, params::BGVParams, plaintext::Plaintext,
ciphertext::Ciphertext,
keys::{BGVKeypair, BGVPublicKey},
params::BGVParams,
plaintext::Plaintext,
};
use rand::thread_rng;

use crate::lowgear::LowGear;

Expand Down Expand Up @@ -70,6 +74,59 @@ pub(crate) mod test_helpers {
let lowgear1 = LowGear::new(net1);
let lowgear2 = LowGear::new(net2);

run_mock_lowgear(f, lowgear1, lowgear2).await
}

/// Run a two-party method with a `LowGear` instance, having run keygen and
/// setup
pub async fn mock_lowgear_with_keys<F, S, T>(mut f: F) -> (T, T)
where
T: Send + 'static,
S: Future<Output = T> + Send + 'static,
F: FnMut(LowGear<TestCurve, MockNetwork<TestCurve>>) -> S,
{
let mut rng = thread_rng();
let (stream1, stream2) = UnboundedDuplexStream::new_duplex_pair();
let net1 = MockNetwork::new(PARTY0, stream1);
let net2 = MockNetwork::new(PARTY1, stream2);

let mut lowgear1 = LowGear::new(net1);
let mut lowgear2 = LowGear::new(net2);

// Setup the lowgear instances
let params = BGVParams::new_no_mults();
let keypair1 = BGVKeypair::gen(&params);
let keypair2 = BGVKeypair::gen(&params);

let mac_share1 = Scalar::random(&mut rng);
let mac_share2 = Scalar::random(&mut rng);

// Set the local keypairs and mac shares
lowgear1.local_keypair = keypair1.clone();
lowgear1.mac_share = mac_share1.clone();
lowgear2.local_keypair = keypair2.clone();
lowgear2.mac_share = mac_share2.clone();

// Set the exchanged values
lowgear1.other_pk = Some(keypair2.public_key());
lowgear1.other_mac_enc = Some(encrypt_val(mac_share2, &keypair2.public_key(), &params));
lowgear2.other_pk = Some(keypair1.public_key());
lowgear2.other_mac_enc = Some(encrypt_val(mac_share1, &keypair1.public_key(), &params));

run_mock_lowgear(f, lowgear1, lowgear2).await
}

/// Run a two-party protocol using the given `LowGear` instances
pub async fn run_mock_lowgear<F, S, T>(
mut f: F,
lowgear1: LowGear<TestCurve, MockNetwork<TestCurve>>,
lowgear2: LowGear<TestCurve, MockNetwork<TestCurve>>,
) -> (T, T)
where
T: Send + 'static,
S: Future<Output = T> + Send + 'static,
F: FnMut(LowGear<TestCurve, MockNetwork<TestCurve>>) -> S,
{
let task1 = tokio::spawn(f(lowgear1));
let task2 = tokio::spawn(f(lowgear2));
let party0_out = task1.await.unwrap();
Expand Down
1 change: 1 addition & 0 deletions offline-phase/src/lowgear/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//! keys, authenticating inputs, etc
pub mod setup;
pub mod triplets;

use ark_ec::CurveGroup;
use ark_mpc::{
Expand Down
50 changes: 50 additions & 0 deletions offline-phase/src/lowgear/triplets.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//! Defines the logic for generating shared triples (a, b, c) which satisfy the
//! identity:
//! a * b = c
//!
//! These triples are used to define single-round multiplication in the SPDZ
//! protocol
use ark_ec::CurveGroup;
use ark_mpc::network::MpcNetwork;
use mp_spdz_rs::fhe::{ciphertext::CiphertextPoK, plaintext::PlaintextVector};

use crate::error::LowGearError;

use super::LowGear;

impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {
/// Generate a single batch of shared triples
pub async fn generate_triples(&mut self) -> Result<(), LowGearError> {
// First step; generate random values a and b
let mut a = PlaintextVector::random_pok_batch(&self.params);
let b = PlaintextVector::random_pok_batch(&self.params);

// Compute a plaintext multiplication
let c = &a * &b;

// Encrypt `a` and send it to the counterparty
let my_proof = self.local_keypair.encrypt_and_prove_vector(&mut a);
self.send_message(my_proof).await?;
let mut other_proof: CiphertextPoK<C> = self.receive_message().await?;

let other_pk = self.other_pk.as_ref().expect("setup not run");
let other_a_enc = other_pk.verify_proof(&mut other_proof);

Ok(())
}
}

#[cfg(test)]
mod test {
use crate::test_helpers::mock_lowgear_with_keys;

/// Tests the basic triplet generation flow
#[tokio::test]
async fn test_triplet_gen() {
mock_lowgear_with_keys(|mut lowgear| async move {
lowgear.generate_triples().await.unwrap();
})
.await;
}
}

0 comments on commit c59c6e8

Please sign in to comment.