Skip to content

Commit

Permalink
mp-spdz-rs: benches: Add benchmarks for FHE operations
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Apr 4, 2024
1 parent 09cddd6 commit 7c96b7f
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 3 deletions.
12 changes: 12 additions & 0 deletions mp-spdz-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,19 @@ edition = "2021"
[features]
test-helpers = []

[[bench]]
name = "plaintext_ops"
harness = false
required-features = ["test-helpers"]

[[bench]]
name = "ciphertext_ops"
harness = false
required-features = ["test-helpers"]

[dependencies]
# === Arithmetic + Crypto === #
ark-bn254 = "0.4"
ark-ec = { version = "0.4", features = ["parallel"] }
ark-ff = { version = "0.4", features = ["parallel"] }
ark-mpc = { path = "../online-phase" }
Expand All @@ -22,4 +33,5 @@ pkg-config = "0.3"

[dev-dependencies]
ark-bn254 = "0.4"
criterion = { version = "0.5", features = ["async", "async_tokio"] }
rand = "0.8.4"
198 changes: 198 additions & 0 deletions mp-spdz-rs/benches/ciphertext_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
//! Benchmarks for ciphertext operations
use ark_mpc::algebra::Scalar;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use mp_spdz_rs::fhe::{keys::BGVKeypair, params::BGVParams, plaintext::Plaintext};
use mp_spdz_rs::TestCurve;

/// Benchmark the time to encrypt and decrypt a plaintext
fn bench_ciphertext_encrypt_decrypt(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let mut keypair = BGVKeypair::gen(&params);

let plaintext = Plaintext::new(&params);
let slots = plaintext.num_slots();

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("encrypt-decrypt", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let mut plaintext = Plaintext::new(&params);

for i in 0..(slots as usize) {
plaintext.set_element(i, Scalar::random(&mut rng));
}

let start = std::time::Instant::now();
let ciphertext = keypair.encrypt(&plaintext);
let _ = keypair.decrypt(&ciphertext);
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark addition between a ciphertext and a plaintext
///
/// This includes only the time to add the two values together
fn bench_ciphertext_plaintext_addition(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let keypair = BGVKeypair::gen(&params);

let plaintext = Plaintext::new(&params);
let slots = plaintext.num_slots();

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("plaintext-add", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);

for i in 0..(slots as usize) {
plaintext1.set_element(i, Scalar::random(&mut rng));
plaintext2.set_element(i, Scalar::random(&mut rng));
}

let ciphertext = keypair.encrypt(&plaintext1);

let start = std::time::Instant::now();
let _ = &ciphertext + &plaintext2;
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark multiplying a ciphertext by a plaintext
fn bench_ciphertext_plaintext_multiplication(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let keypair = BGVKeypair::gen(&params);

let plaintext = Plaintext::new(&params);
let slots = plaintext.num_slots();

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("plaintext-mul", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);

for i in 0..(slots as usize) {
plaintext1.set_element(i, Scalar::random(&mut rng));
plaintext2.set_element(i, Scalar::random(&mut rng));
}

let ciphertext = keypair.encrypt(&plaintext1);

let start = std::time::Instant::now();
let _ = &ciphertext * &plaintext2;
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark adding a ciphertext to another ciphertext
fn bench_ciphertext_addition(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let keypair = BGVKeypair::gen(&params);

let plaintext = Plaintext::new(&params);
let slots = plaintext.num_slots();

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("ciphertext-add", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);

for i in 0..(slots as usize) {
plaintext1.set_element(i, Scalar::random(&mut rng));
plaintext2.set_element(i, Scalar::random(&mut rng));
}

let ciphertext1 = keypair.encrypt(&plaintext1);
let ciphertext2 = keypair.encrypt(&plaintext2);

let start = std::time::Instant::now();
let _ = &ciphertext1 + &ciphertext2;
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark multiplying a ciphertext by another ciphertext
fn bench_ciphertext_multiplication(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new(1 /* n_mults */);
let keypair = BGVKeypair::gen(&params);

let plaintext = Plaintext::new(&params);
let slots = plaintext.num_slots();

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("ciphertext-mul", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);

for i in 0..(slots as usize) {
plaintext1.set_element(i, Scalar::random(&mut rng));
plaintext2.set_element(i, Scalar::random(&mut rng));
}

let ciphertext1 = keypair.encrypt(&plaintext1);
let ciphertext2 = keypair.encrypt(&plaintext2);

let start = std::time::Instant::now();
let _ = &ciphertext1.mul_ciphertext(&ciphertext2, &keypair.public_key);
total_time += start.elapsed();
}

total_time
})
});
}

criterion_group! {
name = ciphertext_ops;
config = Criterion::default().sample_size(10);
targets = bench_ciphertext_encrypt_decrypt,
bench_ciphertext_plaintext_addition,
bench_ciphertext_plaintext_multiplication,
bench_ciphertext_addition,
bench_ciphertext_multiplication,
}
criterion_main!(ciphertext_ops);
80 changes: 80 additions & 0 deletions mp-spdz-rs/benches/plaintext_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::time::{Duration, Instant};

use ark_mpc::algebra::Scalar;
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use mp_spdz_rs::fhe::{params::BGVParams, plaintext::Plaintext};
use mp_spdz_rs::TestCurve;
use rand::thread_rng;

/// Benchmark plaintext addition
fn benchmark_plaintext_addition(c: &mut Criterion) {
let mut group = c.benchmark_group("plaintext-ops");

let params = BGVParams::<TestCurve>::new_no_mults();
let plaintext = Plaintext::new(&params);
let slots = plaintext.num_slots();

group.throughput(criterion::Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("add", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = Duration::default();
let mut rng = thread_rng();

for _ in 0..n_iters {
let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);

for i in 0..(slots as usize) {
plaintext1.set_element(i, Scalar::random(&mut rng));
plaintext2.set_element(i, Scalar::random(&mut rng));
}

let start = Instant::now();
let _ = black_box(&plaintext1 + &plaintext2);
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark plaintext multiplication
fn benchmark_plaintext_multiplication(c: &mut Criterion) {
let mut group = c.benchmark_group("plaintext-ops");

let params = BGVParams::<TestCurve>::new_no_mults();
let plaintext = Plaintext::new(&params);
let slots = plaintext.num_slots();

group.throughput(criterion::Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("mul", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = Duration::default();
let mut rng = thread_rng();

for _ in 0..n_iters {
let mut plaintext1 = Plaintext::new(&params);
let mut plaintext2 = Plaintext::new(&params);

for i in 0..(slots as usize) {
plaintext1.set_element(i, Scalar::random(&mut rng));
plaintext2.set_element(i, Scalar::random(&mut rng));
}

let start = Instant::now();
let _ = black_box(&plaintext1 * &plaintext2);
total_time += start.elapsed();
}

total_time
})
});
}

criterion_group! {
name = plaintext_ops;
config = Criterion::default();
targets = benchmark_plaintext_addition, benchmark_plaintext_multiplication
}
criterion_main!(plaintext_ops);
1 change: 1 addition & 0 deletions mp-spdz-rs/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ mod ffi_inner {
// `Plaintext`
type Plaintext_mod_prime;
fn new_plaintext(params: &FHE_Params) -> UniquePtr<Plaintext_mod_prime>;
fn num_slots(self: &Plaintext_mod_prime) -> u32;
fn get_element_int(plaintext: &Plaintext_mod_prime, idx: usize) -> u32;
fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32);
fn get_element_bigint(plaintext: &Plaintext_mod_prime, idx: usize) -> UniquePtr<bigint>;
Expand Down
5 changes: 5 additions & 0 deletions mp-spdz-rs/src/fhe/plaintext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ impl<C: CurveGroup> Plaintext<C> {
Self { inner, _phantom: PhantomData }
}

/// Get the number of slots in the plaintext
pub fn num_slots(&self) -> u32 {
self.inner.num_slots()
}

/// Set the value of an element in the plaintext
pub fn set_element(&mut self, idx: usize, value: Scalar<C>) {
let val_bigint = scalar_to_ffi_bigint(value);
Expand Down
6 changes: 3 additions & 3 deletions mp-spdz-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ pub mod ffi;
pub mod fhe;

#[allow(clippy::items_after_test_module)]
#[cfg(test)]
#[cfg(any(test, feature = "test-helpers"))]
mod test_helpers {
/// The curve group to use for testing
pub type TestCurve = ark_bn254::G1Projective;
}
#[cfg(test)]
pub(crate) use test_helpers::*;
#[cfg(any(test, feature = "test-helpers"))]
pub use test_helpers::*;

0 comments on commit 7c96b7f

Please sign in to comment.