Skip to content

Commit

Permalink
mp-spdz-rs: fhe: Plaintext and params interface
Browse files Browse the repository at this point in the history
Defines a high level interface for interacting with plaintexts and
params, rather than interacting with the FFI bindings directly
  • Loading branch information
joeykraut committed Apr 4, 2024
1 parent 11a5db9 commit 7b108f1
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 1 deletion.
10 changes: 10 additions & 0 deletions mp-spdz-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@ name = "mp-spdz-rs"
version = "0.1.0"
edition = "2021"

[features]
test-helpers = []

[dependencies]
# === Arithmetic + Crypto === #
ark-ec = { version = "0.4", features = ["parallel"] }
ark-ff = { version = "0.4", features = ["parallel"] }
ark-mpc = { path = "../online-phase" }

# === Bindings === #
cxx = "1.0"

[build-dependencies]
Expand All @@ -12,4 +21,5 @@ itertools = "0.12.0"
pkg-config = "0.3"

[dev-dependencies]
ark-bn254 = "0.4"
rand = "0.8.4"
1 change: 1 addition & 0 deletions mp-spdz-rs/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ fn main() {

// Build cache flags
println!("cargo:rerun-if-changed=src/include/MP-SPDZ");
println!("cargo:rerun-if-changed=src/ffi.rs");
}

/// Get the vendor of the current host
Expand Down
15 changes: 14 additions & 1 deletion mp-spdz-rs/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ mod ffi_inner {
fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32);
fn get_element_int(plaintext: &Plaintext_mod_prime, idx: usize) -> u32;

fn add_plaintexts(
x: &Plaintext_mod_prime,
y: &Plaintext_mod_prime,
) -> UniquePtr<Plaintext_mod_prime>;
fn sub_plaintexts(
x: &Plaintext_mod_prime,
y: &Plaintext_mod_prime,
) -> UniquePtr<Plaintext_mod_prime>;
fn mul_plaintexts(
x: &Plaintext_mod_prime,
y: &Plaintext_mod_prime,
) -> UniquePtr<Plaintext_mod_prime>;

// `Ciphertext`
type Ciphertext;
fn add_plaintext(c0: &Ciphertext, p1: &Plaintext_mod_prime) -> UniquePtr<Ciphertext>;
Expand Down Expand Up @@ -81,7 +94,7 @@ mod test {
keypair: &FHE_KeyPair,
params: &FHE_Params,
) -> UniquePtr<Ciphertext> {
let mut plaintext = plaintext_int(value, params);
let plaintext = plaintext_int(value, params);
encrypt(keypair, &plaintext)
}

Expand Down
6 changes: 6 additions & 0 deletions mp-spdz-rs/src/fhe/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//! FHE primitives exported from MP-SPDZ
//!
//! Implements the BGV cryptosystem
pub mod params;
pub mod plaintext;
44 changes: 44 additions & 0 deletions mp-spdz-rs/src/fhe/params.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//! FHE setup parameters
use ark_ec::CurveGroup;
use ark_mpc::algebra::Scalar;
use std::marker::PhantomData;

use cxx::UniquePtr;

use crate::ffi::{new_fhe_params, FHE_Params};

/// The default drowning security parameter
const DEFAULT_DROWN_SEC: i32 = 128;

/// A wrapper around the MP-SPDZ `FHE_Params` struct
pub struct BGVParams<C: CurveGroup> {
/// The wrapped MP-SPDZ `FHE_Params`
pub(crate) inner: UniquePtr<FHE_Params>,
/// Phantom
_phantom: PhantomData<C>,
}

impl<C: CurveGroup> AsRef<FHE_Params> for BGVParams<C> {
fn as_ref(&self) -> &FHE_Params {
self.inner.as_ref().unwrap()
}
}

impl<C: CurveGroup> BGVParams<C> {
/// Create a new set of FHE parameters
pub fn new(n_mults: u32) -> Self {
let mut inner = new_fhe_params(n_mults as i32, DEFAULT_DROWN_SEC);

// Generate the parameters
let bits = Scalar::<C>::bit_length() as i32;
inner.pin_mut().basic_generation_mod_prime(bits);

Self { inner, _phantom: PhantomData }
}

/// Create a new set of FHE parameters that supports zero multiplications
pub fn new_no_mults() -> Self {
Self::new(0)
}
}
148 changes: 148 additions & 0 deletions mp-spdz-rs/src/fhe/plaintext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
//! Wrapper around an MP-SPDZ plaintext that exports a rust-friendly interface
use std::{
marker::PhantomData,
ops::{Add, Mul, Sub},
};

use ark_ec::CurveGroup;
use cxx::UniquePtr;

use crate::ffi::{
add_plaintexts, get_element_int, mul_plaintexts, new_plaintext, set_element_int,
sub_plaintexts, Plaintext_mod_prime,
};

use super::params::BGVParams;

/// A plaintext in the BGV implementation
///
/// The plaintext is defined over the Scalar field of the curve group
pub struct Plaintext<C: CurveGroup> {
/// The wrapped MP-SPDZ `Plaintext_mod_prime`
inner: UniquePtr<Plaintext_mod_prime>,
/// Phantom
_phantom: PhantomData<C>,
}

impl<C: CurveGroup> AsRef<Plaintext_mod_prime> for Plaintext<C> {
fn as_ref(&self) -> &Plaintext_mod_prime {
self.inner.as_ref().unwrap()
}
}

impl<C: CurveGroup> Plaintext<C> {
/// Create a new plaintext
pub fn new(params: &BGVParams<C>) -> Self {
let inner = new_plaintext(params.as_ref());
Self { inner, _phantom: PhantomData }
}

/// Set the value of an element in the plaintext
pub fn set_element(&mut self, idx: usize, value: u32) {
set_element_int(self.inner.pin_mut(), idx, value)
}

/// Get the value of an element in the plaintext
pub fn get_element(&self, idx: usize) -> u32 {
get_element_int(self.as_ref(), idx)
}
}

impl<C: CurveGroup> From<UniquePtr<Plaintext_mod_prime>> for Plaintext<C> {
fn from(inner: UniquePtr<Plaintext_mod_prime>) -> Self {
Self { inner, _phantom: PhantomData }
}
}

// --------------
// | Arithmetic |
// --------------

impl<C: CurveGroup> Add for &Plaintext<C> {
type Output = Plaintext<C>;

fn add(self, rhs: Self) -> Self::Output {
add_plaintexts(self.as_ref(), rhs.as_ref()).into()
}
}
impl<C: CurveGroup> Sub for &Plaintext<C> {
type Output = Plaintext<C>;

fn sub(self, rhs: Self) -> Self::Output {
sub_plaintexts(self.as_ref(), rhs.as_ref()).into()
}
}

impl<C: CurveGroup> Mul<&Plaintext<C>> for &Plaintext<C> {
type Output = Plaintext<C>;

fn mul(self, rhs: &Plaintext<C>) -> Self::Output {
mul_plaintexts(self.as_ref(), rhs.as_ref()).into()
}
}

#[cfg(test)]
mod tests {
use rand::{thread_rng, Rng, RngCore};

use super::*;
use crate::TestCurve;

/// A helper to get parameters for the tests
fn get_params() -> BGVParams<TestCurve> {
BGVParams::new(1 /* n_mults */)
}

#[test]
fn test_add() {
let mut rng = thread_rng();
let params = get_params();
let val1 = rng.next_u32() / 2;
let val2 = rng.next_u32() / 2;

let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);
plaintext1.set_element(0, val1);
plaintext2.set_element(0, val2);

let expected = val1 + val2;
let result = &plaintext1 + &plaintext2;
assert_eq!(result.get_element(0), expected);
}

#[test]
fn test_sub() {
let mut rng = thread_rng();
let params = get_params();
let val1 = rng.next_u32();
let val2 = rng.gen_range(0..val1);

let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);
plaintext1.set_element(0, val1);
plaintext2.set_element(0, val2);

let expected = val1 - val2;
let result = &plaintext1 - &plaintext2;
assert_eq!(result.get_element(0), expected);
}

#[test]
fn test_mul() {
let mut rng = thread_rng();
let params = get_params();
let range = 0..(1u32 << 16);
let val1 = rng.gen_range(range.clone());
let val2 = rng.gen_range(range);

let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);
plaintext1.set_element(0, val1);
plaintext2.set_element(0, val2);

let expected = val1 * val2;
let result = &plaintext1 * &plaintext2;
assert_eq!(result.get_element(0), expected);
}
}
9 changes: 9 additions & 0 deletions mp-spdz-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,12 @@
//! and to internalize build and link procedure with the foreign ABI
pub mod ffi;
pub mod fhe;

#[cfg(test)]
mod test_helpers {
/// The curve group to use for testing
pub type TestCurve = ark_bn254::G1Projective;
}
#[cfg(test)]
pub(crate) use test_helpers::*;
5 changes: 5 additions & 0 deletions online-phase/src/algebra/scalar/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ impl<C: CurveGroup> Scalar<C> {
self.0
}

/// Get the bit length of the scalar
pub fn bit_length() -> usize {
C::ScalarField::MODULUS_BIT_SIZE as usize
}

/// Sample a random field element
pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
Self(C::ScalarField::rand(rng))
Expand Down

0 comments on commit 7b108f1

Please sign in to comment.