From 2b74d5f3c61654e7b354cb7956a84ba7e330f7c2 Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Thu, 19 Dec 2024 10:41:02 +0100 Subject: [PATCH] chore(gpu): add inputs to erc20 throughput bench with multiple GPUs --- tfhe/benches/high_level_api/erc20.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tfhe/benches/high_level_api/erc20.rs b/tfhe/benches/high_level_api/erc20.rs index a53d325f11..295a40c6ea 100644 --- a/tfhe/benches/high_level_api/erc20.rs +++ b/tfhe/benches/high_level_api/erc20.rs @@ -221,8 +221,12 @@ fn bench_transfer_throughput( F: for<'a> Fn(&'a FheType, &'a FheType, &'a FheType) -> (FheType, FheType) + Sync, { let mut rng = thread_rng(); + #[cfg(not(feature = "gpu"))] + let num_gpus = 1u64; + #[cfg(feature = "gpu")] + let num_gpus = unsafe { cuda_get_number_of_gpus() } as u64; - for num_elems in [10, 100, 500] { + for num_elems in [10 * num_gpus, 100 * num_gpus, 500 * num_gpus] { group.throughput(Throughput::Elements(num_elems)); let bench_id = format!("{bench_name}::{fn_name}::{type_name}::{num_elems}_elems"); group.bench_with_input(&bench_id, &num_elems, |b, &num_elems| { @@ -262,6 +266,8 @@ fn bench_transfer_throughput( #[cfg(feature = "pbs-stats")] use pbs_stats::print_transfer_pbs_counts; +#[cfg(feature = "gpu")] +use tfhe_cuda_backend::cuda_bind::cuda_get_number_of_gpus; fn main() { #[cfg(not(feature = "gpu"))]