Skip to content

Commit

Permalink
mp-spdz-rs: benches: Add benchmarks for FHE operations
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Apr 5, 2024
1 parent 09cddd6 commit e7981cc
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 4 deletions.
17 changes: 16 additions & 1 deletion mp-spdz-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,37 @@ version = "0.1.0"
edition = "2021"

[features]
test-helpers = []
test-helpers = ["dep:rand"]

[[bench]]
name = "plaintext_ops"
harness = false
required-features = ["test-helpers"]

[[bench]]
name = "ciphertext_ops"
harness = false
required-features = ["test-helpers"]

[dependencies]
# === Arithmetic + Crypto === #
ark-bn254 = "0.4"
ark-ec = { version = "0.4", features = ["parallel"] }
ark-ff = { version = "0.4", features = ["parallel"] }
ark-mpc = { path = "../online-phase" }

# === Bindings === #
cxx = "1.0"

# === Misc === #
rand = { version = "0.8.4", optional = true }

[build-dependencies]
cxx-build = "1.0"
itertools = "0.12.0"
pkg-config = "0.3"

[dev-dependencies]
ark-bn254 = "0.4"
criterion = { version = "0.5", features = ["async", "async_tokio"] }
rand = "0.8.4"
155 changes: 155 additions & 0 deletions mp-spdz-rs/benches/ciphertext_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
//! Benchmarks for ciphertext operations
use ark_mpc::algebra::Scalar;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use mp_spdz_rs::benchmark_helpers::random_plaintext;
use mp_spdz_rs::fhe::{keys::BGVKeypair, params::BGVParams, plaintext::Plaintext};
use mp_spdz_rs::TestCurve;

/// Benchmark the time to encrypt and decrypt a plaintext
fn bench_ciphertext_encrypt_decrypt(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();
let mut keypair = BGVKeypair::gen(&params);

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("encrypt-decrypt", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let plaintext = random_plaintext(&params);

let start = std::time::Instant::now();
let ciphertext = keypair.encrypt(&plaintext);
let _ = keypair.decrypt(&ciphertext);
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark addition between a ciphertext and a plaintext
///
/// This includes only the time to add the two values together
fn bench_ciphertext_plaintext_addition(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();
let keypair = BGVKeypair::gen(&params);

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("plaintext-add", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let mut plaintext = random_plaintext(&params);
let ciphertext = keypair.encrypt(&random_plaintext(&params));

let start = std::time::Instant::now();
let _ = &ciphertext + &plaintext;
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark multiplying a ciphertext by a plaintext
fn bench_ciphertext_plaintext_multiplication(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();
let keypair = BGVKeypair::gen(&params);

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("plaintext-mul", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let mut plaintext = random_plaintext(&params);
let ciphertext = keypair.encrypt(&random_plaintext(&params));

let start = std::time::Instant::now();
let _ = &ciphertext * &plaintext;
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark adding a ciphertext to another ciphertext
fn bench_ciphertext_addition(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();
let keypair = BGVKeypair::gen(&params);

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("ciphertext-add", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let ciphertext1 = keypair.encrypt(&random_plaintext(&params));
let ciphertext2 = keypair.encrypt(&random_plaintext(&params));

let start = std::time::Instant::now();
let _ = &ciphertext1 + &ciphertext2;
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark multiplying a ciphertext by another ciphertext
fn bench_ciphertext_multiplication(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new(1 /* n_mults */);
let slots = params.plaintext_slots();
let keypair = BGVKeypair::gen(&params);

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("ciphertext-mul", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let ciphertext1 = keypair.encrypt(&random_plaintext(&params));
let ciphertext2 = keypair.encrypt(&random_plaintext(&params));

let start = std::time::Instant::now();
let _ = &ciphertext1.mul_ciphertext(&ciphertext2, &keypair.public_key);
total_time += start.elapsed();
}

total_time
})
});
}

criterion_group! {
name = ciphertext_ops;
config = Criterion::default().sample_size(10);
targets = bench_ciphertext_encrypt_decrypt,
bench_ciphertext_plaintext_addition,
bench_ciphertext_plaintext_multiplication,
bench_ciphertext_addition,
bench_ciphertext_multiplication,
}
criterion_main!(ciphertext_ops);
69 changes: 69 additions & 0 deletions mp-spdz-rs/benches/plaintext_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::time::{Duration, Instant};

use ark_mpc::algebra::Scalar;
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use mp_spdz_rs::benchmark_helpers::random_plaintext;
use mp_spdz_rs::fhe::{params::BGVParams, plaintext::Plaintext};
use mp_spdz_rs::TestCurve;
use rand::thread_rng;

/// Benchmark plaintext addition
fn benchmark_plaintext_addition(c: &mut Criterion) {
let mut group = c.benchmark_group("plaintext-ops");

let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();

group.throughput(criterion::Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("add", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = Duration::default();
let mut rng = thread_rng();

for _ in 0..n_iters {
let mut plaintext1 = random_plaintext(&params);
let mut plaintext2 = random_plaintext(&params);

let start = Instant::now();
let _ = black_box(&plaintext1 + &plaintext2);
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark plaintext multiplication
fn benchmark_plaintext_multiplication(c: &mut Criterion) {
let mut group = c.benchmark_group("plaintext-ops");

let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();

group.throughput(criterion::Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("mul", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = Duration::default();
let mut rng = thread_rng();

for _ in 0..n_iters {
let plaintext1 = random_plaintext(&params);
let plaintext2 = random_plaintext(&params);

let start = Instant::now();
let _ = black_box(&plaintext1 * &plaintext2);
total_time += start.elapsed();
}

total_time
})
});
}

criterion_group! {
name = plaintext_ops;
config = Criterion::default();
targets = benchmark_plaintext_addition, benchmark_plaintext_multiplication
}
criterion_main!(plaintext_ops);
2 changes: 2 additions & 0 deletions mp-spdz-rs/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod ffi_inner {
// `FHE_Params`
type FHE_Params;
fn new_fhe_params(n_mults: i32, drown_sec: i32) -> UniquePtr<FHE_Params>;
fn n_plaintext_slots(self: &FHE_Params) -> u32;
fn basic_generation_mod_prime(self: Pin<&mut FHE_Params>, plaintext_length: i32);
fn param_generation_with_modulus(self: Pin<&mut FHE_Params>, plaintext_modulus: &bigint);
fn get_plaintext_mod(params: &FHE_Params) -> UniquePtr<bigint>;
Expand All @@ -36,6 +37,7 @@ mod ffi_inner {
// `Plaintext`
type Plaintext_mod_prime;
fn new_plaintext(params: &FHE_Params) -> UniquePtr<Plaintext_mod_prime>;
fn num_slots(self: &Plaintext_mod_prime) -> u32;
fn get_element_int(plaintext: &Plaintext_mod_prime, idx: usize) -> u32;
fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32);
fn get_element_bigint(plaintext: &Plaintext_mod_prime, idx: usize) -> UniquePtr<bigint>;
Expand Down
5 changes: 5 additions & 0 deletions mp-spdz-rs/src/fhe/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,9 @@ impl<C: CurveGroup> BGVParams<C> {
pub fn new_no_mults() -> Self {
Self::new(0)
}

/// Get the number of plaintext slots the given parameters support
pub fn plaintext_slots(&self) -> u32 {
self.as_ref().n_plaintext_slots() as u32
}
}
5 changes: 5 additions & 0 deletions mp-spdz-rs/src/fhe/plaintext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ impl<C: CurveGroup> Plaintext<C> {
Self { inner, _phantom: PhantomData }
}

/// Get the number of slots in the plaintext
pub fn num_slots(&self) -> u32 {
self.inner.num_slots()
}

/// Set the value of an element in the plaintext
pub fn set_element(&mut self, idx: usize, value: Scalar<C>) {
let val_bigint = scalar_to_ffi_bigint(value);
Expand Down
28 changes: 25 additions & 3 deletions mp-spdz-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,32 @@ pub mod ffi;
pub mod fhe;

#[allow(clippy::items_after_test_module)]
#[cfg(test)]
#[cfg(any(test, feature = "test-helpers"))]
mod test_helpers {
/// The curve group to use for testing
pub type TestCurve = ark_bn254::G1Projective;
}
#[cfg(test)]
pub(crate) use test_helpers::*;
#[cfg(any(test, feature = "test-helpers"))]
pub use test_helpers::*;

#[cfg(feature = "test-helpers")]
pub mod benchmark_helpers {
use ark_ec::CurveGroup;
use ark_ff::UniformRand;
use ark_mpc::algebra::Scalar;
use rand::thread_rng;

use crate::fhe::{params::BGVParams, plaintext::Plaintext};

/// Get a random plaintext filled with random values
pub fn random_plaintext<C: CurveGroup>(params: &BGVParams<C>) -> Plaintext<C> {
let mut rng = thread_rng();
let mut pt = Plaintext::new(params);

for i in 0..pt.num_slots() as usize {
pt.set_element(i, Scalar::random(&mut rng));
}

pt
}
}

0 comments on commit e7981cc

Please sign in to comment.