Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mp-spdz-rs: benches: Add benchmarks for FHE operations #57

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion mp-spdz-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,37 @@ version = "0.1.0"
edition = "2021"

[features]
test-helpers = []
test-helpers = ["dep:rand"]

[[bench]]
name = "plaintext_ops"
harness = false
required-features = ["test-helpers"]

[[bench]]
name = "ciphertext_ops"
harness = false
required-features = ["test-helpers"]

[dependencies]
# === Arithmetic + Crypto === #
ark-bn254 = "0.4"
ark-ec = { version = "0.4", features = ["parallel"] }
ark-ff = { version = "0.4", features = ["parallel"] }
ark-mpc = { path = "../online-phase" }

# === Bindings === #
cxx = "1.0"

# === Misc === #
rand = { version = "0.8.4", optional = true }

[build-dependencies]
cxx-build = "1.0"
itertools = "0.12.0"
pkg-config = "0.3"

[dev-dependencies]
ark-bn254 = "0.4"
criterion = { version = "0.5", features = ["async", "async_tokio"] }
rand = "0.8.4"
155 changes: 155 additions & 0 deletions mp-spdz-rs/benches/ciphertext_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
//! Benchmarks for ciphertext operations

use ark_mpc::algebra::Scalar;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use mp_spdz_rs::benchmark_helpers::random_plaintext;
use mp_spdz_rs::fhe::{keys::BGVKeypair, params::BGVParams, plaintext::Plaintext};
use mp_spdz_rs::TestCurve;

/// Benchmark the time to encrypt and decrypt a plaintext
fn bench_ciphertext_encrypt_decrypt(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();
let mut keypair = BGVKeypair::gen(&params);

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("encrypt-decrypt", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let plaintext = random_plaintext(&params);

let start = std::time::Instant::now();
let ciphertext = keypair.encrypt(&plaintext);
let _ = keypair.decrypt(&ciphertext);
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark addition between a ciphertext and a plaintext
///
/// This includes only the time to add the two values together
fn bench_ciphertext_plaintext_addition(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();
let keypair = BGVKeypair::gen(&params);

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("plaintext-add", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let mut plaintext = random_plaintext(&params);
let ciphertext = keypair.encrypt(&random_plaintext(&params));

let start = std::time::Instant::now();
let _ = &ciphertext + &plaintext;
joeykraut marked this conversation as resolved.
Show resolved Hide resolved
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark multiplying a ciphertext by a plaintext
fn bench_ciphertext_plaintext_multiplication(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();
let keypair = BGVKeypair::gen(&params);

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("plaintext-mul", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let mut plaintext = random_plaintext(&params);
let ciphertext = keypair.encrypt(&random_plaintext(&params));

let start = std::time::Instant::now();
let _ = &ciphertext * &plaintext;
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark adding a ciphertext to another ciphertext
fn bench_ciphertext_addition(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();
let keypair = BGVKeypair::gen(&params);

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("ciphertext-add", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let ciphertext1 = keypair.encrypt(&random_plaintext(&params));
let ciphertext2 = keypair.encrypt(&random_plaintext(&params));

let start = std::time::Instant::now();
let _ = &ciphertext1 + &ciphertext2;
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark multiplying a ciphertext by another ciphertext
fn bench_ciphertext_multiplication(c: &mut Criterion) {
let mut group = c.benchmark_group("ciphertext-ops");
let params = BGVParams::<TestCurve>::new(1 /* n_mults */);
let slots = params.plaintext_slots();
let keypair = BGVKeypair::gen(&params);

group.throughput(Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("ciphertext-mul", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = std::time::Duration::default();
let mut rng = rand::thread_rng();

for _ in 0..n_iters {
let ciphertext1 = keypair.encrypt(&random_plaintext(&params));
let ciphertext2 = keypair.encrypt(&random_plaintext(&params));

let start = std::time::Instant::now();
let _ = &ciphertext1.mul_ciphertext(&ciphertext2, &keypair.public_key);
total_time += start.elapsed();
}

total_time
})
});
}

criterion_group! {
name = ciphertext_ops;
config = Criterion::default().sample_size(10);
targets = bench_ciphertext_encrypt_decrypt,
bench_ciphertext_plaintext_addition,
bench_ciphertext_plaintext_multiplication,
bench_ciphertext_addition,
bench_ciphertext_multiplication,
}
criterion_main!(ciphertext_ops);
69 changes: 69 additions & 0 deletions mp-spdz-rs/benches/plaintext_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::time::{Duration, Instant};

use ark_mpc::algebra::Scalar;
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use mp_spdz_rs::benchmark_helpers::random_plaintext;
use mp_spdz_rs::fhe::{params::BGVParams, plaintext::Plaintext};
use mp_spdz_rs::TestCurve;
use rand::thread_rng;

/// Benchmark plaintext addition
fn benchmark_plaintext_addition(c: &mut Criterion) {
let mut group = c.benchmark_group("plaintext-ops");

let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();

group.throughput(criterion::Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("add", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = Duration::default();
let mut rng = thread_rng();

for _ in 0..n_iters {
let mut plaintext1 = random_plaintext(&params);
let mut plaintext2 = random_plaintext(&params);

let start = Instant::now();
let _ = black_box(&plaintext1 + &plaintext2);
total_time += start.elapsed();
}

total_time
})
});
}

/// Benchmark plaintext multiplication
fn benchmark_plaintext_multiplication(c: &mut Criterion) {
let mut group = c.benchmark_group("plaintext-ops");

let params = BGVParams::<TestCurve>::new_no_mults();
let slots = params.plaintext_slots();

group.throughput(criterion::Throughput::Elements(slots as u64));
group.bench_function(BenchmarkId::new("mul", ""), |b| {
b.iter_custom(|n_iters| {
let mut total_time = Duration::default();
let mut rng = thread_rng();

for _ in 0..n_iters {
let plaintext1 = random_plaintext(&params);
let plaintext2 = random_plaintext(&params);

let start = Instant::now();
let _ = black_box(&plaintext1 * &plaintext2);
total_time += start.elapsed();
}

total_time
})
});
}

criterion_group! {
name = plaintext_ops;
config = Criterion::default();
targets = benchmark_plaintext_addition, benchmark_plaintext_multiplication
}
criterion_main!(plaintext_ops);
2 changes: 2 additions & 0 deletions mp-spdz-rs/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod ffi_inner {
// `FHE_Params`
type FHE_Params;
fn new_fhe_params(n_mults: i32, drown_sec: i32) -> UniquePtr<FHE_Params>;
fn n_plaintext_slots(self: &FHE_Params) -> u32;
fn basic_generation_mod_prime(self: Pin<&mut FHE_Params>, plaintext_length: i32);
fn param_generation_with_modulus(self: Pin<&mut FHE_Params>, plaintext_modulus: &bigint);
fn get_plaintext_mod(params: &FHE_Params) -> UniquePtr<bigint>;
Expand All @@ -36,6 +37,7 @@ mod ffi_inner {
// `Plaintext`
type Plaintext_mod_prime;
fn new_plaintext(params: &FHE_Params) -> UniquePtr<Plaintext_mod_prime>;
fn num_slots(self: &Plaintext_mod_prime) -> u32;
fn get_element_int(plaintext: &Plaintext_mod_prime, idx: usize) -> u32;
fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32);
fn get_element_bigint(plaintext: &Plaintext_mod_prime, idx: usize) -> UniquePtr<bigint>;
Expand Down
5 changes: 5 additions & 0 deletions mp-spdz-rs/src/fhe/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,9 @@ impl<C: CurveGroup> BGVParams<C> {
pub fn new_no_mults() -> Self {
Self::new(0)
}

/// Get the number of plaintext slots the given parameters support
pub fn plaintext_slots(&self) -> u32 {
self.as_ref().n_plaintext_slots() as u32
}
}
5 changes: 5 additions & 0 deletions mp-spdz-rs/src/fhe/plaintext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ impl<C: CurveGroup> Plaintext<C> {
Self { inner, _phantom: PhantomData }
}

/// Get the number of slots in the plaintext
pub fn num_slots(&self) -> u32 {
self.inner.num_slots()
}

/// Set the value of an element in the plaintext
pub fn set_element(&mut self, idx: usize, value: Scalar<C>) {
let val_bigint = scalar_to_ffi_bigint(value);
Expand Down
28 changes: 25 additions & 3 deletions mp-spdz-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,32 @@ pub mod ffi;
pub mod fhe;

#[allow(clippy::items_after_test_module)]
#[cfg(test)]
#[cfg(any(test, feature = "test-helpers"))]
mod test_helpers {
/// The curve group to use for testing
pub type TestCurve = ark_bn254::G1Projective;
}
#[cfg(test)]
pub(crate) use test_helpers::*;
#[cfg(any(test, feature = "test-helpers"))]
pub use test_helpers::*;

#[cfg(feature = "test-helpers")]
pub mod benchmark_helpers {
use ark_ec::CurveGroup;
use ark_ff::UniformRand;
use ark_mpc::algebra::Scalar;
use rand::thread_rng;

use crate::fhe::{params::BGVParams, plaintext::Plaintext};

/// Get a random plaintext filled with random values
pub fn random_plaintext<C: CurveGroup>(params: &BGVParams<C>) -> Plaintext<C> {
let mut rng = thread_rng();
let mut pt = Plaintext::new(params);

for i in 0..pt.num_slots() as usize {
pt.set_element(i, Scalar::random(&mut rng));
}

pt
}
}
Loading