diff --git a/mp-spdz-rs/Cargo.toml b/mp-spdz-rs/Cargo.toml index 32a4faa..22aa803 100644 --- a/mp-spdz-rs/Cargo.toml +++ b/mp-spdz-rs/Cargo.toml @@ -4,10 +4,21 @@ version = "0.1.0" edition = "2021" [features] -test-helpers = [] +test-helpers = ["dep:rand"] + +[[bench]] +name = "plaintext_ops" +harness = false +required-features = ["test-helpers"] + +[[bench]] +name = "ciphertext_ops" +harness = false +required-features = ["test-helpers"] [dependencies] # === Arithmetic + Crypto === # +ark-bn254 = "0.4" ark-ec = { version = "0.4", features = ["parallel"] } ark-ff = { version = "0.4", features = ["parallel"] } ark-mpc = { path = "../online-phase" } @@ -15,6 +26,9 @@ ark-mpc = { path = "../online-phase" } # === Bindings === # cxx = "1.0" +# === Misc === # +rand = { version = "0.8.4", optional = true } + [build-dependencies] cxx-build = "1.0" itertools = "0.12.0" @@ -22,4 +36,5 @@ pkg-config = "0.3" [dev-dependencies] ark-bn254 = "0.4" +criterion = { version = "0.5", features = ["async", "async_tokio"] } rand = "0.8.4" diff --git a/mp-spdz-rs/benches/ciphertext_ops.rs b/mp-spdz-rs/benches/ciphertext_ops.rs new file mode 100644 index 0000000..325e152 --- /dev/null +++ b/mp-spdz-rs/benches/ciphertext_ops.rs @@ -0,0 +1,155 @@ +//! Benchmarks for ciphertext operations + +use ark_mpc::algebra::Scalar; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use mp_spdz_rs::benchmark_helpers::random_plaintext; +use mp_spdz_rs::fhe::{keys::BGVKeypair, params::BGVParams, plaintext::Plaintext}; +use mp_spdz_rs::TestCurve; + +/// Benchmark the time to encrypt and decrypt a plaintext +fn bench_ciphertext_encrypt_decrypt(c: &mut Criterion) { + let mut group = c.benchmark_group("ciphertext-ops"); + let params = BGVParams::::new_no_mults(); + let slots = params.plaintext_slots(); + let mut keypair = BGVKeypair::gen(¶ms); + + group.throughput(Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("encrypt-decrypt", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = std::time::Duration::default(); + let mut rng = rand::thread_rng(); + + for _ in 0..n_iters { + let plaintext = random_plaintext(¶ms); + + let start = std::time::Instant::now(); + let ciphertext = keypair.encrypt(&plaintext); + let _ = keypair.decrypt(&ciphertext); + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +/// Benchmark addition between a ciphertext and a plaintext +/// +/// This includes only the time to add the two values together +fn bench_ciphertext_plaintext_addition(c: &mut Criterion) { + let mut group = c.benchmark_group("ciphertext-ops"); + let params = BGVParams::::new_no_mults(); + let slots = params.plaintext_slots(); + let keypair = BGVKeypair::gen(¶ms); + + group.throughput(Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("plaintext-add", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = std::time::Duration::default(); + let mut rng = rand::thread_rng(); + + for _ in 0..n_iters { + let mut plaintext = random_plaintext(¶ms); + let ciphertext = keypair.encrypt(&random_plaintext(¶ms)); + + let start = std::time::Instant::now(); + let _ = &ciphertext + &plaintext; + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +/// Benchmark multiplying a ciphertext by a plaintext +fn bench_ciphertext_plaintext_multiplication(c: &mut Criterion) { + let mut group = c.benchmark_group("ciphertext-ops"); + let params = BGVParams::::new_no_mults(); + let slots = params.plaintext_slots(); + let keypair = BGVKeypair::gen(¶ms); + + group.throughput(Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("plaintext-mul", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = std::time::Duration::default(); + let mut rng = rand::thread_rng(); + + for _ in 0..n_iters { + let mut plaintext = random_plaintext(¶ms); + let ciphertext = keypair.encrypt(&random_plaintext(¶ms)); + + let start = std::time::Instant::now(); + let _ = &ciphertext * &plaintext; + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +/// Benchmark adding a ciphertext to another ciphertext +fn bench_ciphertext_addition(c: &mut Criterion) { + let mut group = c.benchmark_group("ciphertext-ops"); + let params = BGVParams::::new_no_mults(); + let slots = params.plaintext_slots(); + let keypair = BGVKeypair::gen(¶ms); + + group.throughput(Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("ciphertext-add", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = std::time::Duration::default(); + let mut rng = rand::thread_rng(); + + for _ in 0..n_iters { + let ciphertext1 = keypair.encrypt(&random_plaintext(¶ms)); + let ciphertext2 = keypair.encrypt(&random_plaintext(¶ms)); + + let start = std::time::Instant::now(); + let _ = &ciphertext1 + &ciphertext2; + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +/// Benchmark multiplying a ciphertext by another ciphertext +fn bench_ciphertext_multiplication(c: &mut Criterion) { + let mut group = c.benchmark_group("ciphertext-ops"); + let params = BGVParams::::new(1 /* n_mults */); + let slots = params.plaintext_slots(); + let keypair = BGVKeypair::gen(¶ms); + + group.throughput(Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("ciphertext-mul", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = std::time::Duration::default(); + let mut rng = rand::thread_rng(); + + for _ in 0..n_iters { + let ciphertext1 = keypair.encrypt(&random_plaintext(¶ms)); + let ciphertext2 = keypair.encrypt(&random_plaintext(¶ms)); + + let start = std::time::Instant::now(); + let _ = &ciphertext1.mul_ciphertext(&ciphertext2, &keypair.public_key); + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +criterion_group! { + name = ciphertext_ops; + config = Criterion::default().sample_size(10); + targets = bench_ciphertext_encrypt_decrypt, + bench_ciphertext_plaintext_addition, + bench_ciphertext_plaintext_multiplication, + bench_ciphertext_addition, + bench_ciphertext_multiplication, +} +criterion_main!(ciphertext_ops); diff --git a/mp-spdz-rs/benches/plaintext_ops.rs b/mp-spdz-rs/benches/plaintext_ops.rs new file mode 100644 index 0000000..d6d62f9 --- /dev/null +++ b/mp-spdz-rs/benches/plaintext_ops.rs @@ -0,0 +1,69 @@ +use std::time::{Duration, Instant}; + +use ark_mpc::algebra::Scalar; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use mp_spdz_rs::benchmark_helpers::random_plaintext; +use mp_spdz_rs::fhe::{params::BGVParams, plaintext::Plaintext}; +use mp_spdz_rs::TestCurve; +use rand::thread_rng; + +/// Benchmark plaintext addition +fn benchmark_plaintext_addition(c: &mut Criterion) { + let mut group = c.benchmark_group("plaintext-ops"); + + let params = BGVParams::::new_no_mults(); + let slots = params.plaintext_slots(); + + group.throughput(criterion::Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("add", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = Duration::default(); + let mut rng = thread_rng(); + + for _ in 0..n_iters { + let mut plaintext1 = random_plaintext(¶ms); + let mut plaintext2 = random_plaintext(¶ms); + + let start = Instant::now(); + let _ = black_box(&plaintext1 + &plaintext2); + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +/// Benchmark plaintext multiplication +fn benchmark_plaintext_multiplication(c: &mut Criterion) { + let mut group = c.benchmark_group("plaintext-ops"); + + let params = BGVParams::::new_no_mults(); + let slots = params.plaintext_slots(); + + group.throughput(criterion::Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("mul", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = Duration::default(); + let mut rng = thread_rng(); + + for _ in 0..n_iters { + let plaintext1 = random_plaintext(¶ms); + let plaintext2 = random_plaintext(¶ms); + + let start = Instant::now(); + let _ = black_box(&plaintext1 * &plaintext2); + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +criterion_group! { + name = plaintext_ops; + config = Criterion::default(); + targets = benchmark_plaintext_addition, benchmark_plaintext_multiplication +} +criterion_main!(plaintext_ops); diff --git a/mp-spdz-rs/src/ffi.rs b/mp-spdz-rs/src/ffi.rs index 8828749..3963426 100644 --- a/mp-spdz-rs/src/ffi.rs +++ b/mp-spdz-rs/src/ffi.rs @@ -18,6 +18,7 @@ mod ffi_inner { // `FHE_Params` type FHE_Params; fn new_fhe_params(n_mults: i32, drown_sec: i32) -> UniquePtr; + fn n_plaintext_slots(self: &FHE_Params) -> u32; fn basic_generation_mod_prime(self: Pin<&mut FHE_Params>, plaintext_length: i32); fn param_generation_with_modulus(self: Pin<&mut FHE_Params>, plaintext_modulus: &bigint); fn get_plaintext_mod(params: &FHE_Params) -> UniquePtr; @@ -36,6 +37,7 @@ mod ffi_inner { // `Plaintext` type Plaintext_mod_prime; fn new_plaintext(params: &FHE_Params) -> UniquePtr; + fn num_slots(self: &Plaintext_mod_prime) -> u32; fn get_element_int(plaintext: &Plaintext_mod_prime, idx: usize) -> u32; fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32); fn get_element_bigint(plaintext: &Plaintext_mod_prime, idx: usize) -> UniquePtr; diff --git a/mp-spdz-rs/src/fhe/params.rs b/mp-spdz-rs/src/fhe/params.rs index 2a70a05..5c0d261 100644 --- a/mp-spdz-rs/src/fhe/params.rs +++ b/mp-spdz-rs/src/fhe/params.rs @@ -42,4 +42,9 @@ impl BGVParams { pub fn new_no_mults() -> Self { Self::new(0) } + + /// Get the number of plaintext slots the given parameters support + pub fn plaintext_slots(&self) -> u32 { + self.as_ref().n_plaintext_slots() as u32 + } } diff --git a/mp-spdz-rs/src/fhe/plaintext.rs b/mp-spdz-rs/src/fhe/plaintext.rs index 52173cb..9f0e668 100644 --- a/mp-spdz-rs/src/fhe/plaintext.rs +++ b/mp-spdz-rs/src/fhe/plaintext.rs @@ -39,6 +39,11 @@ impl Plaintext { Self { inner, _phantom: PhantomData } } + /// Get the number of slots in the plaintext + pub fn num_slots(&self) -> u32 { + self.inner.num_slots() + } + /// Set the value of an element in the plaintext pub fn set_element(&mut self, idx: usize, value: Scalar) { let val_bigint = scalar_to_ffi_bigint(value); diff --git a/mp-spdz-rs/src/lib.rs b/mp-spdz-rs/src/lib.rs index 52a6e3a..a2f4703 100644 --- a/mp-spdz-rs/src/lib.rs +++ b/mp-spdz-rs/src/lib.rs @@ -8,10 +8,32 @@ pub mod ffi; pub mod fhe; #[allow(clippy::items_after_test_module)] -#[cfg(test)] +#[cfg(any(test, feature = "test-helpers"))] mod test_helpers { /// The curve group to use for testing pub type TestCurve = ark_bn254::G1Projective; } -#[cfg(test)] -pub(crate) use test_helpers::*; +#[cfg(any(test, feature = "test-helpers"))] +pub use test_helpers::*; + +#[cfg(feature = "test-helpers")] +pub mod benchmark_helpers { + use ark_ec::CurveGroup; + use ark_ff::UniformRand; + use ark_mpc::algebra::Scalar; + use rand::thread_rng; + + use crate::fhe::{params::BGVParams, plaintext::Plaintext}; + + /// Get a random plaintext filled with random values + pub fn random_plaintext(params: &BGVParams) -> Plaintext { + let mut rng = thread_rng(); + let mut pt = Plaintext::new(params); + + for i in 0..pt.num_slots() as usize { + pt.set_element(i, Scalar::random(&mut rng)); + } + + pt + } +}