Skip to content

Commit

Permalink
offline-phase: lowgear: triplets: Complete + test triplet generation
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Apr 11, 2024
1 parent ad10787 commit 2021c15
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 3 deletions.
1 change: 1 addition & 0 deletions offline-phase/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ rand = "0.8"
[dev-dependencies]
ark-bn254 = "0.4"
ark-mpc = { path = "../online-phase", features = ["test_helpers"] }
itertools = "0.10"
tokio = { version = "1", features = ["full"] }
32 changes: 31 additions & 1 deletion offline-phase/src/lowgear/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ pub struct LowGear<C: CurveGroup, N: MpcNetwork<C>> {
pub other_pk: Option<BGVPublicKey<C>>,
/// An encryption of the counterparty's MAC key share under their public key
pub other_mac_enc: Option<Ciphertext<C>>,
/// The Beaver triples generated during the offline phase
pub triples: Vec<(Scalar<C>, Scalar<C>, Scalar<C>)>,
/// A reference to the underlying network connection
pub network: N,
}
Expand All @@ -47,7 +49,15 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {
let local_keypair = BGVKeypair::gen(&params);
let mac_share = Scalar::random(&mut rng);

Self { params, local_keypair, mac_share, other_pk: None, other_mac_enc: None, network }
Self {
params,
local_keypair,
mac_share,
other_pk: None,
other_mac_enc: None,
triples: vec![],
network,
}
}

/// Get the setup parameters from the offline phase
Expand All @@ -70,6 +80,18 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {
Ok(())
}

/// Send a message to the counterparty that can directly be converted to a
/// network payload
pub async fn send_network_payload<T: Into<NetworkPayload<C>>>(
&mut self,
payload: T,
) -> Result<(), LowGearError> {
let msg = NetworkOutbound { result_id: 0, payload: payload.into() };

self.network.send(msg).await.map_err(|e| LowGearError::SendMessage(e.to_string()))?;
Ok(())
}

/// Receive a message from the counterparty
pub async fn receive_message<T: FromBytesWithParams<C>>(&mut self) -> Result<T, LowGearError> {
let msg = self.network.next().await.unwrap().unwrap();
Expand All @@ -80,6 +102,14 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {

Ok(T::from_bytes(&payload, &self.params))
}

/// Receive a network payload from the counterparty
pub async fn receive_network_payload<T: From<NetworkPayload<C>>>(
&mut self,
) -> Result<T, LowGearError> {
let msg = self.network.next().await.unwrap().unwrap();
Ok(msg.payload.into())
}
}

#[cfg(test)]
Expand Down
57 changes: 55 additions & 2 deletions offline-phase/src/lowgear/triplets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {
// Generate shares of the product and exchange
let c_shares = self.share_product(other_a_enc, &b, c).await?;

// Increase the size of self.triples by self.params.ciphertext_pok_batch_size
self.triples.reserve(self.params.ciphertext_pok_batch_size());
for pt_idx in 0..a.len() {
let plaintext_a = a.get(pt_idx);
let plaintext_b = b.get(pt_idx);
let plaintext_c = c_shares.get(pt_idx);

for slot_idx in 0..plaintext_a.num_slots() as usize {
let a = plaintext_a.get_element(slot_idx);
let b = plaintext_b.get_element(slot_idx);
let c = plaintext_c.get_element(slot_idx);

self.triples.push((a, b, c));
}
}

Ok(())
}

Expand Down Expand Up @@ -141,13 +157,50 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {

#[cfg(test)]
mod test {
use crate::test_helpers::mock_lowgear_with_keys;
use ark_mpc::algebra::Scalar;
use itertools::izip;

use crate::test_helpers::{mock_lowgear_with_keys, TestCurve};

/// Tests the basic triplet generation flow
#[tokio::test]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_triplet_gen() {
// The number of triplets to test
mock_lowgear_with_keys(|mut lowgear| async move {
lowgear.generate_triples().await.unwrap();

assert_eq!(lowgear.triples.len(), lowgear.params.ciphertext_pok_batch_size());

// Exchange triples
let (mut my_a, mut my_b, mut my_c) = (vec![], vec![], vec![]);
for (a, b, c) in lowgear.triples.iter() {
my_a.push(*a);
my_b.push(*b);
my_c.push(*c);
}

lowgear.send_network_payload(my_a.clone()).await.unwrap();
lowgear.send_network_payload(my_b.clone()).await.unwrap();
lowgear.send_network_payload(my_c.clone()).await.unwrap();
let their_a: Vec<Scalar<TestCurve>> = lowgear.receive_network_payload().await.unwrap();
let their_b: Vec<Scalar<TestCurve>> = lowgear.receive_network_payload().await.unwrap();
let their_c: Vec<Scalar<TestCurve>> = lowgear.receive_network_payload().await.unwrap();

// Add together all the shares to get the final values
for (a_1, a_2, b_1, b_2, c_1, c_2) in izip!(
my_a.iter(),
their_a.iter(),
my_b.iter(),
their_b.iter(),
my_c.iter(),
their_c.iter()
) {
let a = a_1 + a_2;
let b = b_1 + b_2;
let c = c_1 + c_2;

assert_eq!(a * b, c);
}
})
.await;
}
Expand Down
45 changes: 45 additions & 0 deletions online-phase/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,51 @@ impl<C: CurveGroup> From<Vec<CurvePoint<C>>> for NetworkPayload<C> {
}
}

impl<C: CurveGroup> From<NetworkPayload<C>> for Vec<u8> {
fn from(payload: NetworkPayload<C>) -> Self {
match payload {
NetworkPayload::Bytes(bytes) => bytes,
_ => panic!("Expected NetworkPayload::Bytes"),
}
}
}

impl<C: CurveGroup> From<NetworkPayload<C>> for Scalar<C> {
fn from(payload: NetworkPayload<C>) -> Self {
match payload {
NetworkPayload::Scalar(scalar) => scalar,
_ => panic!("Expected NetworkPayload::Scalar"),
}
}
}

impl<C: CurveGroup> From<NetworkPayload<C>> for Vec<Scalar<C>> {
fn from(payload: NetworkPayload<C>) -> Self {
match payload {
NetworkPayload::ScalarBatch(scalars) => scalars,
_ => panic!("Expected NetworkPayload::ScalarBatch"),
}
}
}

impl<C: CurveGroup> From<NetworkPayload<C>> for CurvePoint<C> {
fn from(payload: NetworkPayload<C>) -> Self {
match payload {
NetworkPayload::Point(point) => point,
_ => panic!("Expected NetworkPayload::Point"),
}
}
}

impl<C: CurveGroup> From<NetworkPayload<C>> for Vec<CurvePoint<C>> {
fn from(payload: NetworkPayload<C>) -> Self {
match payload {
NetworkPayload::PointBatch(points) => points,
_ => panic!("Expected NetworkPayload::PointBatch"),
}
}
}

/// The `MpcNetwork` trait defines shared functionality for a network
/// implementing a connection between two parties in a 2PC
///
Expand Down

0 comments on commit 2021c15

Please sign in to comment.