diff --git a/mp-spdz-rs/src/ffi.rs b/mp-spdz-rs/src/ffi.rs index 87b7d69..ff188e3 100644 --- a/mp-spdz-rs/src/ffi.rs +++ b/mp-spdz-rs/src/ffi.rs @@ -1,5 +1,6 @@ //! The FFI bindings for the MP-SPDZ library +#[allow(clippy::missing_safety_doc)] #[cxx::bridge] mod ffi_inner { unsafe extern "C++" { @@ -11,6 +12,8 @@ mod ffi_inner { // `bigint` type bigint; fn print(self: &bigint); + unsafe fn bigint_from_be_bytes(data: *mut u8, size: usize) -> UniquePtr; + fn bigint_to_be_bytes(x: &bigint) -> Vec; // `FHE_Params` type FHE_Params; @@ -32,8 +35,10 @@ mod ffi_inner { // `Plaintext` type Plaintext_mod_prime; fn new_plaintext(params: &FHE_Params) -> UniquePtr; - fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32); fn get_element_int(plaintext: &Plaintext_mod_prime, idx: usize) -> u32; + fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32); + fn get_element_bigint(plaintext: &Plaintext_mod_prime, idx: usize) -> UniquePtr; + fn set_element_bigint(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: &bigint); fn add_plaintexts( x: &Plaintext_mod_prime, @@ -102,6 +107,20 @@ mod test { get_element_int(&plaintext, 0) } + /// Tests converting bytes to and from a bigint + #[test] + fn test_bigint_to_from_bytes() { + const N_BYTES: usize = 32; + let mut rng = thread_rng(); + let data = rng.gen::<[u8; N_BYTES]>(); + + // Convert the data to a bigint + let bigint = unsafe { bigint_from_be_bytes(data.as_ptr() as *mut u8, N_BYTES) }; + let res = bigint_to_be_bytes(&bigint); + + assert_eq!(data.to_vec(), res); + } + /// Tests addition of a plaintext to a ciphertext #[test] fn test_plaintext_addition() { diff --git a/mp-spdz-rs/src/fhe/ciphertext.rs b/mp-spdz-rs/src/fhe/ciphertext.rs index 7cd169c..7ff6995 100644 --- a/mp-spdz-rs/src/fhe/ciphertext.rs +++ b/mp-spdz-rs/src/fhe/ciphertext.rs @@ -71,7 +71,8 @@ impl Mul<&Plaintext> for &Ciphertext { #[cfg(test)] mod test { - use rand::{thread_rng, Rng, RngCore}; + use ark_mpc::algebra::Scalar; + use rand::{thread_rng, RngCore}; use crate::fhe::{keys::BGVKeypair, params::BGVParams, plaintext::Plaintext}; use crate::TestCurve; @@ -87,7 +88,10 @@ mod test { } /// Get a plaintext with the given value in the first slot - fn plaintext_int(val: u32, params: &BGVParams) -> Plaintext { + fn plaintext_int( + val: Scalar, + params: &BGVParams, + ) -> Plaintext { let mut plaintext = Plaintext::new(params); plaintext.set_element(0, val); @@ -96,7 +100,7 @@ mod test { /// Get the ciphertext with the given value in the first slot fn encrypt_int( - value: u32, + value: Scalar, keypair: &BGVKeypair, params: &BGVParams, ) -> Ciphertext { @@ -111,8 +115,8 @@ mod test { let (params, mut keypair) = setup_fhe(); // Add a ciphertext with a plaintext - let val1 = rng.next_u32() / 2; - let val2 = rng.next_u32() / 2; + let val1 = rng.next_u64().into(); + let val2 = rng.next_u64().into(); let plaintext = plaintext_int(val2, ¶ms); let ciphertext = encrypt_int(val1, &keypair, ¶ms); @@ -134,9 +138,8 @@ mod test { let (params, mut keypair) = setup_fhe(); // Multiply a ciphertext with a plaintext - let range = 0..(1u32 << 16); - let val1 = rng.gen_range(range.clone()); - let val2 = rng.gen_range(range); + let val1 = rng.next_u64().into(); + let val2 = rng.next_u64().into(); let plaintext = plaintext_int(val2, ¶ms); let ciphertext = encrypt_int(val1, &keypair, ¶ms); @@ -158,8 +161,8 @@ mod test { let (params, mut keypair) = setup_fhe(); // Add two ciphertexts - let val1 = rng.next_u32() / 2; - let val2 = rng.next_u32() / 2; + let val1 = rng.next_u64().into(); + let val2 = rng.next_u64().into(); let ciphertext1 = encrypt_int(val1, &keypair, ¶ms); let ciphertext2 = encrypt_int(val2, &keypair, ¶ms); @@ -181,9 +184,8 @@ mod test { let (params, mut keypair) = setup_fhe(); // Multiply two ciphertexts - let range = 0..(1u32 << 16); - let val1 = rng.gen_range(range.clone()); - let val2 = rng.gen_range(range); + let val1 = rng.next_u64().into(); + let val2 = rng.next_u64().into(); let ciphertext1 = encrypt_int(val1, &keypair, ¶ms); let ciphertext2 = encrypt_int(val2, &keypair, ¶ms); diff --git a/mp-spdz-rs/src/fhe/mod.rs b/mp-spdz-rs/src/fhe/mod.rs index 03b2ae0..7bf3aa2 100644 --- a/mp-spdz-rs/src/fhe/mod.rs +++ b/mp-spdz-rs/src/fhe/mod.rs @@ -2,7 +2,26 @@ //! //! Implements the BGV cryptosystem +use ark_ec::CurveGroup; +use ark_mpc::algebra::Scalar; +use cxx::UniquePtr; + +use crate::ffi::{bigint, bigint_from_be_bytes, bigint_to_be_bytes}; + pub mod ciphertext; pub mod keys; pub mod params; pub mod plaintext; + +/// A helper method to convert a `Scalar` to a `bigint` +pub fn scalar_to_ffi_bigint(x: Scalar) -> UniquePtr { + let mut bytes = x.to_bytes_be(); + unsafe { bigint_from_be_bytes(bytes.as_mut_ptr(), bytes.len()) } +} + +/// A helper method to convert a `bigint` to a `Scalar` +/// +/// Reduces modulo the scalar field's modulus +pub fn ffi_bigint_to_scalar(x: &bigint) -> Scalar { + Scalar::from_be_bytes_mod_order(&bigint_to_be_bytes(x)) +} diff --git a/mp-spdz-rs/src/fhe/plaintext.rs b/mp-spdz-rs/src/fhe/plaintext.rs index 5d2d7aa..e2d1810 100644 --- a/mp-spdz-rs/src/fhe/plaintext.rs +++ b/mp-spdz-rs/src/fhe/plaintext.rs @@ -6,14 +6,15 @@ use std::{ }; use ark_ec::CurveGroup; +use ark_mpc::algebra::Scalar; use cxx::UniquePtr; use crate::ffi::{ - add_plaintexts, get_element_int, mul_plaintexts, new_plaintext, set_element_int, + add_plaintexts, get_element_bigint, mul_plaintexts, new_plaintext, set_element_bigint, sub_plaintexts, Plaintext_mod_prime, }; -use super::params::BGVParams; +use super::{ffi_bigint_to_scalar, params::BGVParams, scalar_to_ffi_bigint}; /// A plaintext in the BGV implementation /// @@ -39,13 +40,15 @@ impl Plaintext { } /// Set the value of an element in the plaintext - pub fn set_element(&mut self, idx: usize, value: u32) { - set_element_int(self.inner.pin_mut(), idx, value) + pub fn set_element(&mut self, idx: usize, value: Scalar) { + let val_bigint = scalar_to_ffi_bigint(value); + set_element_bigint(self.inner.pin_mut(), idx, val_bigint.as_ref().unwrap()); } /// Get the value of an element in the plaintext - pub fn get_element(&self, idx: usize) -> u32 { - get_element_int(self.as_ref(), idx) + pub fn get_element(&self, idx: usize) -> Scalar { + let val_bigint = get_element_bigint(self.as_ref(), idx); + ffi_bigint_to_scalar(val_bigint.as_ref().unwrap()) } } @@ -84,7 +87,7 @@ impl Mul<&Plaintext> for &Plaintext { #[cfg(test)] mod tests { - use rand::{thread_rng, Rng, RngCore}; + use rand::{thread_rng, RngCore}; use super::*; use crate::TestCurve; @@ -98,8 +101,8 @@ mod tests { fn test_add() { let mut rng = thread_rng(); let params = get_params(); - let val1 = rng.next_u32() / 2; - let val2 = rng.next_u32() / 2; + let val1: Scalar = rng.next_u64().into(); + let val2: Scalar = rng.next_u32().into(); let mut plaintext1 = Plaintext::new(¶ms); let mut plaintext2 = Plaintext::new(¶ms); @@ -115,8 +118,8 @@ mod tests { fn test_sub() { let mut rng = thread_rng(); let params = get_params(); - let val1 = rng.next_u32(); - let val2 = rng.gen_range(0..val1); + let val1: Scalar = rng.next_u64().into(); + let val2: Scalar = rng.next_u32().into(); let mut plaintext1 = Plaintext::new(¶ms); let mut plaintext2 = Plaintext::new(¶ms); @@ -132,9 +135,8 @@ mod tests { fn test_mul() { let mut rng = thread_rng(); let params = get_params(); - let range = 0..(1u32 << 16); - let val1 = rng.gen_range(range.clone()); - let val2 = rng.gen_range(range); + let val1: Scalar = rng.next_u64().into(); + let val2: Scalar = rng.next_u64().into(); let mut plaintext1 = Plaintext::new(¶ms); let mut plaintext2 = Plaintext::new(¶ms);