Skip to content

Commit

Permalink
mp-spdz-rs: fhe: Use Scalars as input to FHE interface
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Apr 4, 2024
1 parent 230a2db commit 193aac0
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 28 deletions.
21 changes: 20 additions & 1 deletion mp-spdz-rs/src/ffi.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! The FFI bindings for the MP-SPDZ library
#[allow(clippy::missing_safety_doc)]
#[cxx::bridge]
mod ffi_inner {
unsafe extern "C++" {
Expand All @@ -11,6 +12,8 @@ mod ffi_inner {
// `bigint`
type bigint;
fn print(self: &bigint);
unsafe fn bigint_from_be_bytes(data: *mut u8, size: usize) -> UniquePtr<bigint>;
fn bigint_to_be_bytes(x: &bigint) -> Vec<u8>;

// `FHE_Params`
type FHE_Params;
Expand All @@ -32,8 +35,10 @@ mod ffi_inner {
// `Plaintext`
type Plaintext_mod_prime;
fn new_plaintext(params: &FHE_Params) -> UniquePtr<Plaintext_mod_prime>;
fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32);
fn get_element_int(plaintext: &Plaintext_mod_prime, idx: usize) -> u32;
fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32);
fn get_element_bigint(plaintext: &Plaintext_mod_prime, idx: usize) -> UniquePtr<bigint>;
fn set_element_bigint(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: &bigint);

fn add_plaintexts(
x: &Plaintext_mod_prime,
Expand Down Expand Up @@ -102,6 +107,20 @@ mod test {
get_element_int(&plaintext, 0)
}

/// Tests converting bytes to and from a bigint
#[test]
fn test_bigint_to_from_bytes() {
const N_BYTES: usize = 32;
let mut rng = thread_rng();
let data = rng.gen::<[u8; N_BYTES]>();

// Convert the data to a bigint
let bigint = unsafe { bigint_from_be_bytes(data.as_ptr() as *mut u8, N_BYTES) };
let res = bigint_to_be_bytes(&bigint);

assert_eq!(data.to_vec(), res);
}

/// Tests addition of a plaintext to a ciphertext
#[test]
fn test_plaintext_addition() {
Expand Down
28 changes: 15 additions & 13 deletions mp-spdz-rs/src/fhe/ciphertext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ impl<C: CurveGroup> Mul<&Plaintext<C>> for &Ciphertext<C> {

#[cfg(test)]
mod test {
use rand::{thread_rng, Rng, RngCore};
use ark_mpc::algebra::Scalar;
use rand::{thread_rng, RngCore};

use crate::fhe::{keys::BGVKeypair, params::BGVParams, plaintext::Plaintext};
use crate::TestCurve;
Expand All @@ -87,7 +88,10 @@ mod test {
}

/// Get a plaintext with the given value in the first slot
fn plaintext_int(val: u32, params: &BGVParams<TestCurve>) -> Plaintext<TestCurve> {
fn plaintext_int(
val: Scalar<TestCurve>,
params: &BGVParams<TestCurve>,
) -> Plaintext<TestCurve> {
let mut plaintext = Plaintext::new(params);
plaintext.set_element(0, val);

Expand All @@ -96,7 +100,7 @@ mod test {

/// Get the ciphertext with the given value in the first slot
fn encrypt_int(
value: u32,
value: Scalar<TestCurve>,
keypair: &BGVKeypair<TestCurve>,
params: &BGVParams<TestCurve>,
) -> Ciphertext<TestCurve> {
Expand All @@ -111,8 +115,8 @@ mod test {
let (params, mut keypair) = setup_fhe();

// Add a ciphertext with a plaintext
let val1 = rng.next_u32() / 2;
let val2 = rng.next_u32() / 2;
let val1 = rng.next_u64().into();
let val2 = rng.next_u64().into();

let plaintext = plaintext_int(val2, &params);
let ciphertext = encrypt_int(val1, &keypair, &params);
Expand All @@ -134,9 +138,8 @@ mod test {
let (params, mut keypair) = setup_fhe();

// Multiply a ciphertext with a plaintext
let range = 0..(1u32 << 16);
let val1 = rng.gen_range(range.clone());
let val2 = rng.gen_range(range);
let val1 = rng.next_u64().into();
let val2 = rng.next_u64().into();

let plaintext = plaintext_int(val2, &params);
let ciphertext = encrypt_int(val1, &keypair, &params);
Expand All @@ -158,8 +161,8 @@ mod test {
let (params, mut keypair) = setup_fhe();

// Add two ciphertexts
let val1 = rng.next_u32() / 2;
let val2 = rng.next_u32() / 2;
let val1 = rng.next_u64().into();
let val2 = rng.next_u64().into();

let ciphertext1 = encrypt_int(val1, &keypair, &params);
let ciphertext2 = encrypt_int(val2, &keypair, &params);
Expand All @@ -181,9 +184,8 @@ mod test {
let (params, mut keypair) = setup_fhe();

// Multiply two ciphertexts
let range = 0..(1u32 << 16);
let val1 = rng.gen_range(range.clone());
let val2 = rng.gen_range(range);
let val1 = rng.next_u64().into();
let val2 = rng.next_u64().into();

let ciphertext1 = encrypt_int(val1, &keypair, &params);
let ciphertext2 = encrypt_int(val2, &keypair, &params);
Expand Down
19 changes: 19 additions & 0 deletions mp-spdz-rs/src/fhe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,26 @@
//!
//! Implements the BGV cryptosystem
use ark_ec::CurveGroup;
use ark_mpc::algebra::Scalar;
use cxx::UniquePtr;

use crate::ffi::{bigint, bigint_from_be_bytes, bigint_to_be_bytes};

pub mod ciphertext;
pub mod keys;
pub mod params;
pub mod plaintext;

/// A helper method to convert a `Scalar` to a `bigint`
pub fn scalar_to_ffi_bigint<C: CurveGroup>(x: Scalar<C>) -> UniquePtr<bigint> {
let mut bytes = x.to_bytes_be();
unsafe { bigint_from_be_bytes(bytes.as_mut_ptr(), bytes.len()) }
}

/// A helper method to convert a `bigint` to a `Scalar`
///
/// Reduces modulo the scalar field's modulus
pub fn ffi_bigint_to_scalar<C: CurveGroup>(x: &bigint) -> Scalar<C> {
Scalar::from_be_bytes_mod_order(&bigint_to_be_bytes(x))
}
30 changes: 16 additions & 14 deletions mp-spdz-rs/src/fhe/plaintext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ use std::{
};

use ark_ec::CurveGroup;
use ark_mpc::algebra::Scalar;
use cxx::UniquePtr;

use crate::ffi::{
add_plaintexts, get_element_int, mul_plaintexts, new_plaintext, set_element_int,
add_plaintexts, get_element_bigint, mul_plaintexts, new_plaintext, set_element_bigint,
sub_plaintexts, Plaintext_mod_prime,
};

use super::params::BGVParams;
use super::{ffi_bigint_to_scalar, params::BGVParams, scalar_to_ffi_bigint};

/// A plaintext in the BGV implementation
///
Expand All @@ -39,13 +40,15 @@ impl<C: CurveGroup> Plaintext<C> {
}

/// Set the value of an element in the plaintext
pub fn set_element(&mut self, idx: usize, value: u32) {
set_element_int(self.inner.pin_mut(), idx, value)
pub fn set_element(&mut self, idx: usize, value: Scalar<C>) {
let val_bigint = scalar_to_ffi_bigint(value);
set_element_bigint(self.inner.pin_mut(), idx, val_bigint.as_ref().unwrap());
}

/// Get the value of an element in the plaintext
pub fn get_element(&self, idx: usize) -> u32 {
get_element_int(self.as_ref(), idx)
pub fn get_element(&self, idx: usize) -> Scalar<C> {
let val_bigint = get_element_bigint(self.as_ref(), idx);
ffi_bigint_to_scalar(val_bigint.as_ref().unwrap())
}
}

Expand Down Expand Up @@ -84,7 +87,7 @@ impl<C: CurveGroup> Mul<&Plaintext<C>> for &Plaintext<C> {

#[cfg(test)]
mod tests {
use rand::{thread_rng, Rng, RngCore};
use rand::{thread_rng, RngCore};

use super::*;
use crate::TestCurve;
Expand All @@ -98,8 +101,8 @@ mod tests {
fn test_add() {
let mut rng = thread_rng();
let params = get_params();
let val1 = rng.next_u32() / 2;
let val2 = rng.next_u32() / 2;
let val1: Scalar<TestCurve> = rng.next_u64().into();
let val2: Scalar<TestCurve> = rng.next_u32().into();

let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);
Expand All @@ -115,8 +118,8 @@ mod tests {
fn test_sub() {
let mut rng = thread_rng();
let params = get_params();
let val1 = rng.next_u32();
let val2 = rng.gen_range(0..val1);
let val1: Scalar<TestCurve> = rng.next_u64().into();
let val2: Scalar<TestCurve> = rng.next_u32().into();

let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);
Expand All @@ -132,9 +135,8 @@ mod tests {
fn test_mul() {
let mut rng = thread_rng();
let params = get_params();
let range = 0..(1u32 << 16);
let val1 = rng.gen_range(range.clone());
let val2 = rng.gen_range(range);
let val1: Scalar<TestCurve> = rng.next_u64().into();
let val2: Scalar<TestCurve> = rng.next_u64().into();

let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);
Expand Down

0 comments on commit 193aac0

Please sign in to comment.