diff --git a/mp-spdz-rs/Cargo.toml b/mp-spdz-rs/Cargo.toml index 93dd495..32a4faa 100644 --- a/mp-spdz-rs/Cargo.toml +++ b/mp-spdz-rs/Cargo.toml @@ -3,7 +3,16 @@ name = "mp-spdz-rs" version = "0.1.0" edition = "2021" +[features] +test-helpers = [] + [dependencies] +# === Arithmetic + Crypto === # +ark-ec = { version = "0.4", features = ["parallel"] } +ark-ff = { version = "0.4", features = ["parallel"] } +ark-mpc = { path = "../online-phase" } + +# === Bindings === # cxx = "1.0" [build-dependencies] @@ -12,4 +21,5 @@ itertools = "0.12.0" pkg-config = "0.3" [dev-dependencies] +ark-bn254 = "0.4" rand = "0.8.4" diff --git a/mp-spdz-rs/build.rs b/mp-spdz-rs/build.rs index 999bbfa..83cf34d 100644 --- a/mp-spdz-rs/build.rs +++ b/mp-spdz-rs/build.rs @@ -64,6 +64,7 @@ fn main() { // Build cache flags println!("cargo:rerun-if-changed=src/include/MP-SPDZ"); + println!("cargo:rerun-if-changed=src/ffi.rs"); } /// Get the vendor of the current host diff --git a/mp-spdz-rs/src/ffi.rs b/mp-spdz-rs/src/ffi.rs index b5caef1..3120727 100644 --- a/mp-spdz-rs/src/ffi.rs +++ b/mp-spdz-rs/src/ffi.rs @@ -38,6 +38,19 @@ mod ffi_inner { fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32); fn get_element_int(plaintext: &Plaintext_mod_prime, idx: usize) -> u32; + fn add_plaintexts( + x: &Plaintext_mod_prime, + y: &Plaintext_mod_prime, + ) -> UniquePtr; + fn sub_plaintexts( + x: &Plaintext_mod_prime, + y: &Plaintext_mod_prime, + ) -> UniquePtr; + fn mul_plaintexts( + x: &Plaintext_mod_prime, + y: &Plaintext_mod_prime, + ) -> UniquePtr; + // `Ciphertext` type Ciphertext; fn add_plaintext(c0: &Ciphertext, p1: &Plaintext_mod_prime) -> UniquePtr; @@ -81,7 +94,7 @@ mod test { keypair: &FHE_KeyPair, params: &FHE_Params, ) -> UniquePtr { - let mut plaintext = plaintext_int(value, params); + let plaintext = plaintext_int(value, params); encrypt(keypair, &plaintext) } diff --git a/mp-spdz-rs/src/fhe/mod.rs b/mp-spdz-rs/src/fhe/mod.rs new file mode 100644 index 0000000..6b74b20 --- /dev/null +++ b/mp-spdz-rs/src/fhe/mod.rs @@ -0,0 +1,6 @@ +//! FHE primitives exported from MP-SPDZ +//! +//! Implements the BGV cryptosystem + +pub mod params; +pub mod plaintext; diff --git a/mp-spdz-rs/src/fhe/params.rs b/mp-spdz-rs/src/fhe/params.rs new file mode 100644 index 0000000..a3afe68 --- /dev/null +++ b/mp-spdz-rs/src/fhe/params.rs @@ -0,0 +1,44 @@ +//! FHE setup parameters + +use ark_ec::CurveGroup; +use ark_mpc::algebra::Scalar; +use std::marker::PhantomData; + +use cxx::UniquePtr; + +use crate::ffi::{new_fhe_params, FHE_Params}; + +/// The default drowning security parameter +const DEFAULT_DROWN_SEC: i32 = 128; + +/// A wrapper around the MP-SPDZ `FHE_Params` struct +pub struct BGVParams { + /// The wrapped MP-SPDZ `FHE_Params` + pub(crate) inner: UniquePtr, + /// Phantom + _phantom: PhantomData, +} + +impl AsRef for BGVParams { + fn as_ref(&self) -> &FHE_Params { + self.inner.as_ref().unwrap() + } +} + +impl BGVParams { + /// Create a new set of FHE parameters + pub fn new(n_mults: u32) -> Self { + let mut inner = new_fhe_params(n_mults as i32, DEFAULT_DROWN_SEC); + + // Generate the parameters + let bits = Scalar::::bit_length() as i32; + inner.pin_mut().basic_generation_mod_prime(bits); + + Self { inner, _phantom: PhantomData } + } + + /// Create a new set of FHE parameters that supports zero multiplications + pub fn new_no_mults() -> Self { + Self::new(0) + } +} diff --git a/mp-spdz-rs/src/fhe/plaintext.rs b/mp-spdz-rs/src/fhe/plaintext.rs new file mode 100644 index 0000000..5d2d7aa --- /dev/null +++ b/mp-spdz-rs/src/fhe/plaintext.rs @@ -0,0 +1,148 @@ +//! Wrapper around an MP-SPDZ plaintext that exports a rust-friendly interface + +use std::{ + marker::PhantomData, + ops::{Add, Mul, Sub}, +}; + +use ark_ec::CurveGroup; +use cxx::UniquePtr; + +use crate::ffi::{ + add_plaintexts, get_element_int, mul_plaintexts, new_plaintext, set_element_int, + sub_plaintexts, Plaintext_mod_prime, +}; + +use super::params::BGVParams; + +/// A plaintext in the BGV implementation +/// +/// The plaintext is defined over the Scalar field of the curve group +pub struct Plaintext { + /// The wrapped MP-SPDZ `Plaintext_mod_prime` + inner: UniquePtr, + /// Phantom + _phantom: PhantomData, +} + +impl AsRef for Plaintext { + fn as_ref(&self) -> &Plaintext_mod_prime { + self.inner.as_ref().unwrap() + } +} + +impl Plaintext { + /// Create a new plaintext + pub fn new(params: &BGVParams) -> Self { + let inner = new_plaintext(params.as_ref()); + Self { inner, _phantom: PhantomData } + } + + /// Set the value of an element in the plaintext + pub fn set_element(&mut self, idx: usize, value: u32) { + set_element_int(self.inner.pin_mut(), idx, value) + } + + /// Get the value of an element in the plaintext + pub fn get_element(&self, idx: usize) -> u32 { + get_element_int(self.as_ref(), idx) + } +} + +impl From> for Plaintext { + fn from(inner: UniquePtr) -> Self { + Self { inner, _phantom: PhantomData } + } +} + +// -------------- +// | Arithmetic | +// -------------- + +impl Add for &Plaintext { + type Output = Plaintext; + + fn add(self, rhs: Self) -> Self::Output { + add_plaintexts(self.as_ref(), rhs.as_ref()).into() + } +} +impl Sub for &Plaintext { + type Output = Plaintext; + + fn sub(self, rhs: Self) -> Self::Output { + sub_plaintexts(self.as_ref(), rhs.as_ref()).into() + } +} + +impl Mul<&Plaintext> for &Plaintext { + type Output = Plaintext; + + fn mul(self, rhs: &Plaintext) -> Self::Output { + mul_plaintexts(self.as_ref(), rhs.as_ref()).into() + } +} + +#[cfg(test)] +mod tests { + use rand::{thread_rng, Rng, RngCore}; + + use super::*; + use crate::TestCurve; + + /// A helper to get parameters for the tests + fn get_params() -> BGVParams { + BGVParams::new(1 /* n_mults */) + } + + #[test] + fn test_add() { + let mut rng = thread_rng(); + let params = get_params(); + let val1 = rng.next_u32() / 2; + let val2 = rng.next_u32() / 2; + + let mut plaintext1 = Plaintext::new(¶ms); + let mut plaintext2 = Plaintext::new(¶ms); + plaintext1.set_element(0, val1); + plaintext2.set_element(0, val2); + + let expected = val1 + val2; + let result = &plaintext1 + &plaintext2; + assert_eq!(result.get_element(0), expected); + } + + #[test] + fn test_sub() { + let mut rng = thread_rng(); + let params = get_params(); + let val1 = rng.next_u32(); + let val2 = rng.gen_range(0..val1); + + let mut plaintext1 = Plaintext::new(¶ms); + let mut plaintext2 = Plaintext::new(¶ms); + plaintext1.set_element(0, val1); + plaintext2.set_element(0, val2); + + let expected = val1 - val2; + let result = &plaintext1 - &plaintext2; + assert_eq!(result.get_element(0), expected); + } + + #[test] + fn test_mul() { + let mut rng = thread_rng(); + let params = get_params(); + let range = 0..(1u32 << 16); + let val1 = rng.gen_range(range.clone()); + let val2 = rng.gen_range(range); + + let mut plaintext1 = Plaintext::new(¶ms); + let mut plaintext2 = Plaintext::new(¶ms); + plaintext1.set_element(0, val1); + plaintext2.set_element(0, val2); + + let expected = val1 * val2; + let result = &plaintext1 * &plaintext2; + assert_eq!(result.get_element(0), expected); + } +} diff --git a/mp-spdz-rs/src/lib.rs b/mp-spdz-rs/src/lib.rs index 8ebb0c3..bd29b85 100644 --- a/mp-spdz-rs/src/lib.rs +++ b/mp-spdz-rs/src/lib.rs @@ -5,3 +5,12 @@ //! and to internalize build and link procedure with the foreign ABI pub mod ffi; +pub mod fhe; + +#[cfg(test)] +mod test_helpers { + /// The curve group to use for testing + pub type TestCurve = ark_bn254::G1Projective; +} +#[cfg(test)] +pub(crate) use test_helpers::*; diff --git a/online-phase/src/algebra/scalar/scalar.rs b/online-phase/src/algebra/scalar/scalar.rs index ccc53c5..e390c70 100644 --- a/online-phase/src/algebra/scalar/scalar.rs +++ b/online-phase/src/algebra/scalar/scalar.rs @@ -68,6 +68,11 @@ impl Scalar { self.0 } + /// Get the bit length of the scalar + pub fn bit_length() -> usize { + C::ScalarField::MODULUS_BIT_SIZE as usize + } + /// Sample a random field element pub fn random(rng: &mut R) -> Self { Self(C::ScalarField::rand(rng))