Skip to content

Commit

Permalink
offline-phase: lowgear: triplets: Implement triplet authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Apr 11, 2024
1 parent 0a8a390 commit a6e9482
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 26 deletions.
4 changes: 4 additions & 0 deletions mp-spdz-rs/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,13 @@ pub use ffi_inner::*;
unsafe impl Send for FHE_Params {}
unsafe impl Send for FHE_KeyPair {}
unsafe impl Send for FHE_PK {}
unsafe impl Sync for FHE_PK {}
unsafe impl Send for Ciphertext {}
unsafe impl Sync for Ciphertext {}
unsafe impl Send for CiphertextVector {}
unsafe impl Sync for CiphertextVector {}
unsafe impl Send for CiphertextWithProof {}
unsafe impl Sync for CiphertextWithProof {}
unsafe impl Send for Plaintext_mod_prime {}
unsafe impl Send for PlaintextVector {}
unsafe impl Sync for PlaintextVector {}
Expand Down
3 changes: 3 additions & 0 deletions mp-spdz-rs/src/fhe/plaintext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ impl<C: CurveGroup> Add for &PlaintextVector<C> {
type Output = PlaintextVector<C>;

fn add(self, other: Self) -> Self::Output {
assert_eq!(self.len(), other.len(), "Vectors must be the same length");
let mut result = PlaintextVector::empty();
for i in 0..self.len() {
let element = &self.get(i) + &other.get(i);
Expand All @@ -244,6 +245,7 @@ impl<C: CurveGroup> Sub for &PlaintextVector<C> {
type Output = PlaintextVector<C>;

fn sub(self, other: Self) -> Self::Output {
assert_eq!(self.len(), other.len(), "Vectors must be the same length");
let mut result = PlaintextVector::empty();
for i in 0..self.len() {
let element = &self.get(i) - &other.get(i);
Expand All @@ -257,6 +259,7 @@ impl<C: CurveGroup> Mul for &PlaintextVector<C> {
type Output = PlaintextVector<C>;

fn mul(self, other: Self) -> Self::Output {
assert_eq!(self.len(), other.len(), "Vectors must be the same length");
let mut result = PlaintextVector::empty();
for i in 0..self.len() {
let element = &self.get(i) * &other.get(i);
Expand Down
2 changes: 1 addition & 1 deletion mp-spdz-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#![feature(inherent_associated_types)]
#![feature(stmt_expr_attributes)]

pub mod ffi;
mod ffi;
pub mod fhe;

/// A trait for serializing to bytes
Expand Down
22 changes: 20 additions & 2 deletions offline-phase/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ pub(crate) mod test_helpers {
pt
}

/// Get a plaintext with a single value in all slots
pub fn plaintext_all<C: CurveGroup>(val: Scalar<C>, params: &BGVParams<C>) -> Plaintext<C> {
let mut pt = Plaintext::new(params);
pt.set_all(val);

pt
}

/// Encrypt a single value using the BGV cryptosystem
///
/// Places the element in the zeroth slot of the plaintext
Expand All @@ -60,6 +68,16 @@ pub(crate) mod test_helpers {
key.encrypt(&pt)
}

/// Encrypt a single value in all slots of a plaintext
pub fn encrypt_all<C: CurveGroup>(
val: Scalar<C>,
key: &BGVPublicKey<C>,
params: &BGVParams<C>,
) -> Ciphertext<C> {
let pt = plaintext_all(val, params);
key.encrypt(&pt)
}

/// Run a two-party method with a `LowGear` instance setup and in scope
pub async fn mock_lowgear<F, S, T>(f: F) -> (T, T)
where
Expand Down Expand Up @@ -108,9 +126,9 @@ pub(crate) mod test_helpers {

// Set the exchanged values
lowgear1.other_pk = Some(keypair2.public_key());
lowgear1.other_mac_enc = Some(encrypt_val(mac_share2, &keypair2.public_key(), &params));
lowgear1.other_mac_enc = Some(encrypt_all(mac_share2, &keypair2.public_key(), &params));
lowgear2.other_pk = Some(keypair1.public_key());
lowgear2.other_mac_enc = Some(encrypt_val(mac_share1, &keypair1.public_key(), &params));
lowgear2.other_mac_enc = Some(encrypt_all(mac_share1, &keypair1.public_key(), &params));

run_mock_lowgear(f, lowgear1, lowgear2).await
}
Expand Down
44 changes: 40 additions & 4 deletions offline-phase/src/lowgear/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ use ark_mpc::{
use futures::{SinkExt, StreamExt};
use mp_spdz_rs::{
fhe::{
ciphertext::Ciphertext,
ciphertext::{Ciphertext, CiphertextVector},
keys::{BGVKeypair, BGVPublicKey},
params::BGVParams,
plaintext::{Plaintext, PlaintextVector},
},
FromBytesWithParams, ToBytes,
};
Expand All @@ -37,6 +38,8 @@ pub struct LowGear<C: CurveGroup, N: MpcNetwork<C>> {
pub other_mac_enc: Option<Ciphertext<C>>,
/// The Beaver triples generated during the offline phase
pub triples: Vec<(Scalar<C>, Scalar<C>, Scalar<C>)>,
/// The mac values for the triples generated during the offline phase
pub triple_macs: Vec<(Scalar<C>, Scalar<C>, Scalar<C>)>,
/// A reference to the underlying network connection
pub network: N,
}
Expand All @@ -57,6 +60,7 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {
other_pk: None,
other_mac_enc: None,
triples: vec![],
triple_macs: vec![],
network,
}
}
Expand All @@ -72,8 +76,40 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {
})
}

/// Get a plaintext with the local mac share in all slots
pub fn get_mac_plaintext(&self) -> Plaintext<C> {
let mut pt = Plaintext::new(&self.params);
pt.set_all(self.mac_share);

pt
}

/// Get a plaintext vector wherein each plaintext created with the local mac
/// share in all slots
pub fn get_mac_plaintext_vector(&self, n: usize) -> PlaintextVector<C> {
let mut vec = PlaintextVector::new(n, &self.params);
let pt = self.get_mac_plaintext();
for i in 0..n {
vec.set(i, &pt);
}

vec
}

/// Get a ciphertext vector wherein each ciphertext is an encryption of the
/// counterparty's mac share
pub fn get_other_mac_enc(&self, n: usize) -> CiphertextVector<C> {
let mut vec = CiphertextVector::new(n, &self.params);
let ct = self.other_mac_enc.as_ref().unwrap();
for i in 0..n {
vec.set(i, ct);
}

vec
}

/// Send a message to the counterparty
pub async fn send_message<T: ToBytes>(&mut self, message: T) -> Result<(), LowGearError> {
pub async fn send_message<T: ToBytes>(&mut self, message: &T) -> Result<(), LowGearError> {
let payload = NetworkPayload::<C>::Bytes(message.to_bytes());
let msg = NetworkOutbound { result_id: 0, payload };

Expand Down Expand Up @@ -145,12 +181,12 @@ mod test {
let party = lowgear.network.party_id();
if party == PARTY0 {
let msg = TestMessage(MSG1.to_string());
lowgear.send_message(msg).await.unwrap();
lowgear.send_message(&msg).await.unwrap();
lowgear.receive_message::<TestMessage>().await.unwrap()
} else {
let msg = TestMessage(MSG2.to_string());
let recv = lowgear.receive_message::<TestMessage>().await.unwrap();
lowgear.send_message(msg).await.unwrap();
lowgear.send_message(&msg).await.unwrap();

recv
}
Expand Down
8 changes: 4 additions & 4 deletions offline-phase/src/lowgear/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {
/// Exchange BGV public keys and mac shares with the counterparty
pub async fn run_key_exchange(&mut self) -> Result<(), LowGearError> {
// First, share the public key
self.send_message(self.local_keypair.public_key()).await?;
self.send_message(&self.local_keypair.public_key()).await?;
let counterparty_pk: BGVPublicKey<C> = self.receive_message().await?;

// Encrypt my mac share under my public key
Expand All @@ -21,7 +21,7 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {
let ct = self.local_keypair.encrypt_and_prove(&pt);

// Send and receive
self.send_message(ct).await?;
self.send_message(&ct).await?;
let mut counterparty_mac_pok: CiphertextPoK<C> = self.receive_message().await?;
let counterparty_mac_enc = counterparty_pk.verify_proof(&mut counterparty_mac_pok);

Expand Down Expand Up @@ -60,7 +60,7 @@ mod test {
let encrypted_val =
encrypt_val(my_val, lowgear.other_pk.as_ref().unwrap(), &lowgear.params);

lowgear.send_message(encrypted_val).await.unwrap();
lowgear.send_message(&encrypted_val).await.unwrap();
let received_val: Ciphertext<TestCurve> = lowgear.receive_message().await.unwrap();

let decrypted_val = lowgear.local_keypair.decrypt(&received_val);
Expand All @@ -72,7 +72,7 @@ mod test {
let ct = lowgear.other_mac_enc.as_ref().unwrap() * &pt;

// Send the result to the other party
lowgear.send_message(ct).await.unwrap();
lowgear.send_message(&ct).await.unwrap();
let received_val: Ciphertext<TestCurve> = lowgear.receive_message().await.unwrap();

let decrypted_val = lowgear.local_keypair.decrypt(&received_val);
Expand Down
Loading

0 comments on commit a6e9482

Please sign in to comment.