Skip to content

Commit

Permalink
mp-spdz-rs: fhe: Add Ciphertext and keypair bindings + arithmetic
Browse files Browse the repository at this point in the history
This provides a high level interface over the the BGV implementation in
MP-SPDZ that avoids calling directly into the ffi.
  • Loading branch information
joeykraut committed Apr 4, 2024
1 parent 4604eaf commit 230a2db
Show file tree
Hide file tree
Showing 5 changed files with 325 additions and 23 deletions.
46 changes: 23 additions & 23 deletions mp-spdz-rs/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,9 @@ mod ffi_inner {
fn new_keypair(params: &FHE_Params) -> UniquePtr<FHE_KeyPair>;
fn get_pk(keypair: &FHE_KeyPair) -> UniquePtr<FHE_PK>;
fn get_sk(keypair: &FHE_KeyPair) -> UniquePtr<FHE_SK>;
fn encrypt(keypair: &FHE_KeyPair, plaintext: &Plaintext_mod_prime)
-> UniquePtr<Ciphertext>;
fn decrypt(
keypair: Pin<&mut FHE_KeyPair>,
ciphertext: &Ciphertext,
) -> UniquePtr<Plaintext_mod_prime>;
fn encrypt(pk: &FHE_PK, plaintext: &Plaintext_mod_prime) -> UniquePtr<Ciphertext>;
fn decrypt(sk: Pin<&mut FHE_SK>, ciphertext: &Ciphertext)
-> UniquePtr<Plaintext_mod_prime>;

// `Plaintext`
type Plaintext_mod_prime;
Expand Down Expand Up @@ -95,14 +92,21 @@ mod test {
params: &FHE_Params,
) -> UniquePtr<Ciphertext> {
let plaintext = plaintext_int(value, params);
encrypt(keypair, &plaintext)
encrypt(&get_pk(keypair), &plaintext)
}

/// Decrypt a ciphertext and return the plaintext element in the zero'th
/// slot
fn decrypt_int(keypair: &FHE_KeyPair, ciphertext: &Ciphertext) -> u32 {
let plaintext = decrypt(get_sk(keypair).pin_mut(), ciphertext);
get_element_int(&plaintext, 0)
}

/// Tests addition of a plaintext to a ciphertext
#[test]
fn test_plaintext_addition() {
let mut rng = thread_rng();
let (params, mut keypair) = setup_fhe(0, 254);
let (params, keypair) = setup_fhe(0, 254);

// Add a plaintext to a ciphertext
let val1 = rng.next_u32() / 2;
Expand All @@ -114,18 +118,17 @@ mod test {
let sum = add_plaintext(ciphertext.as_ref().unwrap(), plaintext.as_ref().unwrap());

// Decrypt the sum
let plaintext_res = decrypt(keypair.pin_mut(), &sum);
let pt_u32 = get_element_int(&plaintext_res, 0);
let plaintext_res = decrypt_int(keypair.as_ref().unwrap(), &sum);
let expected = val1 + val2;

assert_eq!(pt_u32, expected);
assert_eq!(plaintext_res, expected);
}

/// Tests multiplication of a plaintext to a ciphertext
#[test]
fn test_plaintext_multiplication() {
let mut rng = thread_rng();
let (params, mut keypair) = setup_fhe(1, 254);
let (params, keypair) = setup_fhe(1, 254);

// Multiply a plaintext to a ciphertext
let range = 0..(2u32.pow(16));
Expand All @@ -138,18 +141,17 @@ mod test {
let product = mul_plaintext(ciphertext.as_ref().unwrap(), plaintext.as_ref().unwrap());

// Decrypt the product
let plaintext_res = decrypt(keypair.pin_mut(), &product);
let pt_u32 = get_element_int(&plaintext_res, 0);
let plaintext_res = decrypt_int(keypair.as_ref().unwrap(), &product);
let expected = val1 * val2;

assert_eq!(pt_u32, expected);
assert_eq!(plaintext_res, expected);
}

/// Tests addition of two encrypted values
#[test]
fn test_encrypted_addition() {
let mut rng = thread_rng();
let (params, mut keypair) = setup_fhe(0, 254);
let (params, keypair) = setup_fhe(0, 254);

// Add two ciphertexts, divide by two to avoid overflow
let val1 = rng.next_u32() / 2;
Expand All @@ -161,18 +163,17 @@ mod test {
let sum = add_ciphertexts(cipher1.as_ref().unwrap(), cipher2.as_ref().unwrap());

// Decrypt the sum
let plaintext_res = decrypt(keypair.pin_mut(), &sum);
let pt_u32 = get_element_int(&plaintext_res, 0);
let plaintext_res = decrypt_int(keypair.as_ref().unwrap(), &sum);
let expected = val1 + val2;

assert_eq!(pt_u32, expected);
assert_eq!(plaintext_res, expected);
}

/// Tests multiplication of two encrypted values
#[test]
fn test_encrypted_multiplication() {
let mut rng = thread_rng();
let (params, mut keypair) = setup_fhe(1, 254);
let (params, keypair) = setup_fhe(1, 254);

// Multiply two ciphertexts; capped bit length to avoid overflow
let range = 0..(2u32.pow(16));
Expand All @@ -190,10 +191,9 @@ mod test {
);

// Decrypt the product
let plaintext_res = decrypt(keypair.pin_mut(), &product);
let pt_u32 = get_element_int(&plaintext_res, 0);
let plaintext_res = decrypt_int(keypair.as_ref().unwrap(), &product);
let expected = val1 * val2;

assert_eq!(pt_u32, expected);
assert_eq!(plaintext_res, expected);
}
}
200 changes: 200 additions & 0 deletions mp-spdz-rs/src/fhe/ciphertext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
//! Ciphertext wrapper around the MP-SPDZ `Ciphertext` struct
use std::{
marker::PhantomData,
ops::{Add, Mul},
};

use ark_ec::CurveGroup;
use cxx::UniquePtr;

use crate::ffi::{
add_ciphertexts as ffi_add_cipher, add_plaintext as ffi_add_plaintext,
mul_ciphertexts as ffi_mul_ciphertext, mul_plaintext as ffi_mul_plaintext,
Ciphertext as FfiCiphertext,
};

use super::{keys::BGVPublicKey, plaintext::Plaintext};

/// A ciphertext in the BGV implementation
///
/// The ciphertext is defined over the Scalar field of the curve group
pub struct Ciphertext<C: CurveGroup> {
/// The wrapped MP-SPDZ `Ciphertext`
pub(crate) inner: UniquePtr<FfiCiphertext>,
/// Phantom
_phantom: PhantomData<C>,
}

impl<C: CurveGroup> Ciphertext<C> {
/// Multiply two ciphertexts
pub fn mul_ciphertext(&self, other: &Self, pk: &BGVPublicKey<C>) -> Self {
ffi_mul_ciphertext(self.as_ref(), other.as_ref(), pk.as_ref()).into()
}
}

impl<C: CurveGroup> From<UniquePtr<FfiCiphertext>> for Ciphertext<C> {
fn from(inner: UniquePtr<FfiCiphertext>) -> Self {
Self { inner, _phantom: PhantomData }
}
}

impl<C: CurveGroup> AsRef<FfiCiphertext> for Ciphertext<C> {
fn as_ref(&self) -> &FfiCiphertext {
self.inner.as_ref().unwrap()
}
}

impl<C: CurveGroup> Add<&Plaintext<C>> for &Ciphertext<C> {
type Output = Ciphertext<C>;

fn add(self, rhs: &Plaintext<C>) -> Self::Output {
ffi_add_plaintext(self.as_ref(), rhs.as_ref()).into()
}
}

impl<C: CurveGroup> Add for &Ciphertext<C> {
type Output = Ciphertext<C>;

fn add(self, rhs: Self) -> Self::Output {
ffi_add_cipher(self.as_ref(), rhs.as_ref()).into()
}
}

impl<C: CurveGroup> Mul<&Plaintext<C>> for &Ciphertext<C> {
type Output = Ciphertext<C>;

fn mul(self, rhs: &Plaintext<C>) -> Self::Output {
ffi_mul_plaintext(self.as_ref(), rhs.as_ref()).into()
}
}

#[cfg(test)]
mod test {
use rand::{thread_rng, Rng, RngCore};

use crate::fhe::{keys::BGVKeypair, params::BGVParams, plaintext::Plaintext};
use crate::TestCurve;

use super::Ciphertext;

/// Setup the FHE scheme
fn setup_fhe() -> (BGVParams<TestCurve>, BGVKeypair<TestCurve>) {
let params = BGVParams::new(1 /* n_mults */);
let keypair = BGVKeypair::gen(&params);

(params, keypair)
}

/// Get a plaintext with the given value in the first slot
fn plaintext_int(val: u32, params: &BGVParams<TestCurve>) -> Plaintext<TestCurve> {
let mut plaintext = Plaintext::new(params);
plaintext.set_element(0, val);

plaintext
}

/// Get the ciphertext with the given value in the first slot
fn encrypt_int(
value: u32,
keypair: &BGVKeypair<TestCurve>,
params: &BGVParams<TestCurve>,
) -> Ciphertext<TestCurve> {
let plaintext = plaintext_int(value, params);
keypair.encrypt(&plaintext)
}

/// Tests addition of a ciphertext with a plaintext
#[test]
fn test_ciphertext_plaintext_addition() {
let mut rng = thread_rng();
let (params, mut keypair) = setup_fhe();

// Add a ciphertext with a plaintext
let val1 = rng.next_u32() / 2;
let val2 = rng.next_u32() / 2;

let plaintext = plaintext_int(val2, &params);
let ciphertext = encrypt_int(val1, &keypair, &params);

let sum = &ciphertext + &plaintext;

// Decrypt the sum
let plaintext_res = keypair.decrypt(&sum);
let res = plaintext_res.get_element(0);
let expected = val1 + val2;

assert_eq!(res, expected);
}

/// Tests multiplication of a ciphertext with a plaintext
#[test]
fn test_ciphertext_plaintext_multiplication() {
let mut rng = thread_rng();
let (params, mut keypair) = setup_fhe();

// Multiply a ciphertext with a plaintext
let range = 0..(1u32 << 16);
let val1 = rng.gen_range(range.clone());
let val2 = rng.gen_range(range);

let plaintext = plaintext_int(val2, &params);
let ciphertext = encrypt_int(val1, &keypair, &params);

let product = &ciphertext * &plaintext;

// Decrypt the product
let plaintext_res = keypair.decrypt(&product);
let res = plaintext_res.get_element(0);
let expected = val1 * val2;

assert_eq!(res, expected);
}

/// Tests addition of two ciphertexts
#[test]
fn test_ciphertext_ciphertext_addition() {
let mut rng = thread_rng();
let (params, mut keypair) = setup_fhe();

// Add two ciphertexts
let val1 = rng.next_u32() / 2;
let val2 = rng.next_u32() / 2;

let ciphertext1 = encrypt_int(val1, &keypair, &params);
let ciphertext2 = encrypt_int(val2, &keypair, &params);

let sum = &ciphertext1 + &ciphertext2;

// Decrypt the sum
let plaintext_res = keypair.decrypt(&sum);
let res = plaintext_res.get_element(0);
let expected = val1 + val2;

assert_eq!(res, expected);
}

/// Tests multiplication of two ciphertexts
#[test]
fn test_ciphertext_ciphertext_multiplication() {
let mut rng = thread_rng();
let (params, mut keypair) = setup_fhe();

// Multiply two ciphertexts
let range = 0..(1u32 << 16);
let val1 = rng.gen_range(range.clone());
let val2 = rng.gen_range(range);

let ciphertext1 = encrypt_int(val1, &keypair, &params);
let ciphertext2 = encrypt_int(val2, &keypair, &params);

let product = ciphertext1.mul_ciphertext(&ciphertext2, &keypair.public_key);

// Decrypt the product
let plaintext_res = keypair.decrypt(&product);
let res = plaintext_res.get_element(0);
let expected = val1 * val2;

assert_eq!(res, expected);
}
}
Loading

0 comments on commit 230a2db

Please sign in to comment.