diff --git a/mp-spdz-rs/Cargo.toml b/mp-spdz-rs/Cargo.toml index 32a4faa..73489f2 100644 --- a/mp-spdz-rs/Cargo.toml +++ b/mp-spdz-rs/Cargo.toml @@ -6,8 +6,19 @@ edition = "2021" [features] test-helpers = [] +[[bench]] +name = "plaintext_ops" +harness = false +required-features = ["test-helpers"] + +[[bench]] +name = "ciphertext_ops" +harness = false +required-features = ["test-helpers"] + [dependencies] # === Arithmetic + Crypto === # +ark-bn254 = "0.4" ark-ec = { version = "0.4", features = ["parallel"] } ark-ff = { version = "0.4", features = ["parallel"] } ark-mpc = { path = "../online-phase" } @@ -22,4 +33,5 @@ pkg-config = "0.3" [dev-dependencies] ark-bn254 = "0.4" +criterion = { version = "0.5", features = ["async", "async_tokio"] } rand = "0.8.4" diff --git a/mp-spdz-rs/benches/ciphertext_ops.rs b/mp-spdz-rs/benches/ciphertext_ops.rs new file mode 100644 index 0000000..c79d252 --- /dev/null +++ b/mp-spdz-rs/benches/ciphertext_ops.rs @@ -0,0 +1,198 @@ +//! Benchmarks for ciphertext operations + +use ark_mpc::algebra::Scalar; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use mp_spdz_rs::fhe::{keys::BGVKeypair, params::BGVParams, plaintext::Plaintext}; +use mp_spdz_rs::TestCurve; + +/// Benchmark the time to encrypt and decrypt a plaintext +fn bench_ciphertext_encrypt_decrypt(c: &mut Criterion) { + let mut group = c.benchmark_group("ciphertext-ops"); + let params = BGVParams::::new_no_mults(); + let mut keypair = BGVKeypair::gen(¶ms); + + let plaintext = Plaintext::new(¶ms); + let slots = plaintext.num_slots(); + + group.throughput(Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("encrypt-decrypt", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = std::time::Duration::default(); + let mut rng = rand::thread_rng(); + + for _ in 0..n_iters { + let mut plaintext = Plaintext::new(¶ms); + + for i in 0..(slots as usize) { + plaintext.set_element(i, Scalar::random(&mut rng)); + } + + let start = std::time::Instant::now(); + let ciphertext = keypair.encrypt(&plaintext); + let _ = keypair.decrypt(&ciphertext); + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +/// Benchmark addition between a ciphertext and a plaintext +/// +/// This includes only the time to add the two values together +fn bench_ciphertext_plaintext_addition(c: &mut Criterion) { + let mut group = c.benchmark_group("ciphertext-ops"); + let params = BGVParams::::new_no_mults(); + let keypair = BGVKeypair::gen(¶ms); + + let plaintext = Plaintext::new(¶ms); + let slots = plaintext.num_slots(); + + group.throughput(Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("plaintext-add", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = std::time::Duration::default(); + let mut rng = rand::thread_rng(); + + for _ in 0..n_iters { + let mut plaintext1 = Plaintext::new(¶ms); + let mut plaintext2 = Plaintext::new(¶ms); + + for i in 0..(slots as usize) { + plaintext1.set_element(i, Scalar::random(&mut rng)); + plaintext2.set_element(i, Scalar::random(&mut rng)); + } + + let ciphertext = keypair.encrypt(&plaintext1); + + let start = std::time::Instant::now(); + let _ = &ciphertext + &plaintext2; + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +/// Benchmark multiplying a ciphertext by a plaintext +fn bench_ciphertext_plaintext_multiplication(c: &mut Criterion) { + let mut group = c.benchmark_group("ciphertext-ops"); + let params = BGVParams::::new_no_mults(); + let keypair = BGVKeypair::gen(¶ms); + + let plaintext = Plaintext::new(¶ms); + let slots = plaintext.num_slots(); + + group.throughput(Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("plaintext-mul", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = std::time::Duration::default(); + let mut rng = rand::thread_rng(); + + for _ in 0..n_iters { + let mut plaintext1 = Plaintext::new(¶ms); + let mut plaintext2 = Plaintext::new(¶ms); + + for i in 0..(slots as usize) { + plaintext1.set_element(i, Scalar::random(&mut rng)); + plaintext2.set_element(i, Scalar::random(&mut rng)); + } + + let ciphertext = keypair.encrypt(&plaintext1); + + let start = std::time::Instant::now(); + let _ = &ciphertext * &plaintext2; + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +/// Benchmark adding a ciphertext to another ciphertext +fn bench_ciphertext_addition(c: &mut Criterion) { + let mut group = c.benchmark_group("ciphertext-ops"); + let params = BGVParams::::new_no_mults(); + let keypair = BGVKeypair::gen(¶ms); + + let plaintext = Plaintext::new(¶ms); + let slots = plaintext.num_slots(); + + group.throughput(Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("ciphertext-add", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = std::time::Duration::default(); + let mut rng = rand::thread_rng(); + + for _ in 0..n_iters { + let mut plaintext1 = Plaintext::new(¶ms); + let mut plaintext2 = Plaintext::new(¶ms); + + for i in 0..(slots as usize) { + plaintext1.set_element(i, Scalar::random(&mut rng)); + plaintext2.set_element(i, Scalar::random(&mut rng)); + } + + let ciphertext1 = keypair.encrypt(&plaintext1); + let ciphertext2 = keypair.encrypt(&plaintext2); + + let start = std::time::Instant::now(); + let _ = &ciphertext1 + &ciphertext2; + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +/// Benchmark multiplying a ciphertext by another ciphertext +fn bench_ciphertext_multiplication(c: &mut Criterion) { + let mut group = c.benchmark_group("ciphertext-ops"); + let params = BGVParams::::new(1 /* n_mults */); + let keypair = BGVKeypair::gen(¶ms); + + let plaintext = Plaintext::new(¶ms); + let slots = plaintext.num_slots(); + + group.throughput(Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("ciphertext-mul", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = std::time::Duration::default(); + let mut rng = rand::thread_rng(); + + for _ in 0..n_iters { + let mut plaintext1 = Plaintext::new(¶ms); + let mut plaintext2 = Plaintext::new(¶ms); + + for i in 0..(slots as usize) { + plaintext1.set_element(i, Scalar::random(&mut rng)); + plaintext2.set_element(i, Scalar::random(&mut rng)); + } + + let ciphertext1 = keypair.encrypt(&plaintext1); + let ciphertext2 = keypair.encrypt(&plaintext2); + + let start = std::time::Instant::now(); + let _ = &ciphertext1.mul_ciphertext(&ciphertext2, &keypair.public_key); + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +criterion_group! { + name = ciphertext_ops; + config = Criterion::default().sample_size(10); + targets = bench_ciphertext_encrypt_decrypt, + bench_ciphertext_plaintext_addition, + bench_ciphertext_plaintext_multiplication, + bench_ciphertext_addition, + bench_ciphertext_multiplication, +} +criterion_main!(ciphertext_ops); diff --git a/mp-spdz-rs/benches/plaintext_ops.rs b/mp-spdz-rs/benches/plaintext_ops.rs new file mode 100644 index 0000000..a756d73 --- /dev/null +++ b/mp-spdz-rs/benches/plaintext_ops.rs @@ -0,0 +1,80 @@ +use std::time::{Duration, Instant}; + +use ark_mpc::algebra::Scalar; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use mp_spdz_rs::fhe::{params::BGVParams, plaintext::Plaintext}; +use mp_spdz_rs::TestCurve; +use rand::thread_rng; + +/// Benchmark plaintext addition +fn benchmark_plaintext_addition(c: &mut Criterion) { + let mut group = c.benchmark_group("plaintext-ops"); + + let params = BGVParams::::new_no_mults(); + let plaintext = Plaintext::new(¶ms); + let slots = plaintext.num_slots(); + + group.throughput(criterion::Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("add", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = Duration::default(); + let mut rng = thread_rng(); + + for _ in 0..n_iters { + let mut plaintext1 = Plaintext::new(¶ms); + let mut plaintext2 = Plaintext::new(¶ms); + + for i in 0..(slots as usize) { + plaintext1.set_element(i, Scalar::random(&mut rng)); + plaintext2.set_element(i, Scalar::random(&mut rng)); + } + + let start = Instant::now(); + let _ = black_box(&plaintext1 + &plaintext2); + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +/// Benchmark plaintext multiplication +fn benchmark_plaintext_multiplication(c: &mut Criterion) { + let mut group = c.benchmark_group("plaintext-ops"); + + let params = BGVParams::::new_no_mults(); + let plaintext = Plaintext::new(¶ms); + let slots = plaintext.num_slots(); + + group.throughput(criterion::Throughput::Elements(slots as u64)); + group.bench_function(BenchmarkId::new("mul", ""), |b| { + b.iter_custom(|n_iters| { + let mut total_time = Duration::default(); + let mut rng = thread_rng(); + + for _ in 0..n_iters { + let mut plaintext1 = Plaintext::new(¶ms); + let mut plaintext2 = Plaintext::new(¶ms); + + for i in 0..(slots as usize) { + plaintext1.set_element(i, Scalar::random(&mut rng)); + plaintext2.set_element(i, Scalar::random(&mut rng)); + } + + let start = Instant::now(); + let _ = black_box(&plaintext1 * &plaintext2); + total_time += start.elapsed(); + } + + total_time + }) + }); +} + +criterion_group! { + name = plaintext_ops; + config = Criterion::default(); + targets = benchmark_plaintext_addition, benchmark_plaintext_multiplication +} +criterion_main!(plaintext_ops); diff --git a/mp-spdz-rs/src/ffi.rs b/mp-spdz-rs/src/ffi.rs index 8828749..2b70ee0 100644 --- a/mp-spdz-rs/src/ffi.rs +++ b/mp-spdz-rs/src/ffi.rs @@ -36,6 +36,7 @@ mod ffi_inner { // `Plaintext` type Plaintext_mod_prime; fn new_plaintext(params: &FHE_Params) -> UniquePtr; + fn num_slots(self: &Plaintext_mod_prime) -> u32; fn get_element_int(plaintext: &Plaintext_mod_prime, idx: usize) -> u32; fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32); fn get_element_bigint(plaintext: &Plaintext_mod_prime, idx: usize) -> UniquePtr; diff --git a/mp-spdz-rs/src/fhe/plaintext.rs b/mp-spdz-rs/src/fhe/plaintext.rs index 52173cb..9f0e668 100644 --- a/mp-spdz-rs/src/fhe/plaintext.rs +++ b/mp-spdz-rs/src/fhe/plaintext.rs @@ -39,6 +39,11 @@ impl Plaintext { Self { inner, _phantom: PhantomData } } + /// Get the number of slots in the plaintext + pub fn num_slots(&self) -> u32 { + self.inner.num_slots() + } + /// Set the value of an element in the plaintext pub fn set_element(&mut self, idx: usize, value: Scalar) { let val_bigint = scalar_to_ffi_bigint(value); diff --git a/mp-spdz-rs/src/lib.rs b/mp-spdz-rs/src/lib.rs index 52a6e3a..a31c8cc 100644 --- a/mp-spdz-rs/src/lib.rs +++ b/mp-spdz-rs/src/lib.rs @@ -8,10 +8,10 @@ pub mod ffi; pub mod fhe; #[allow(clippy::items_after_test_module)] -#[cfg(test)] +#[cfg(any(test, feature = "test-helpers"))] mod test_helpers { /// The curve group to use for testing pub type TestCurve = ark_bn254::G1Projective; } -#[cfg(test)] -pub(crate) use test_helpers::*; +#[cfg(any(test, feature = "test-helpers"))] +pub use test_helpers::*;