Skip to content

Commit

Permalink
chore(gpu): simplify 4090 bench workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Jul 26, 2024
1 parent d3f2ecd commit 3ea3e2a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Run all benchmarks on an RTX 4090 machine and return parsed results to Slab CI bot.
name: TFHE Cuda Backend - 4090 full benchmarks
# Run benchmarks on an RTX 4090 machine and return parsed results to Slab CI bot.
name: TFHE Cuda Backend - 4090 benchmarks

env:
CARGO_TERM_COLOR: always
Expand All @@ -11,6 +11,7 @@ env:
SLACK_ICON: https://pbs.twimg.com/profile_images/1274014582265298945/OjBKP9kn_400x400.png
SLACK_USERNAME: ${{ secrets.BOT_USERNAME }}
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
FAST_BENCH: TRUE

on:
# Allows you to run this workflow manually from the Actions tab as an alternative.
Expand All @@ -23,7 +24,7 @@ on:

jobs:
cuda-integer-benchmarks:
name: Cuda integer benchmarks for all operations flavor (RTX 4090)
name: Cuda integer benchmarks (RTX 4090)
if: ${{ github.event_name == 'workflow_dispatch' ||
github.event_name == 'schedule' && github.repository == 'zama-ai/tfhe-rs' ||
contains(github.event.label.name, '4090_bench') }}
Expand All @@ -35,9 +36,6 @@ jobs:
strategy:
fail-fast: false
max-parallel: 1
matrix:
command: [integer, integer_multi_bit]
op_flavor: [default, unchecked]

steps:
- name: Checkout tfhe-rs
Expand All @@ -52,6 +50,7 @@ jobs:
echo "COMMIT_DATE=$(git --no-pager show -s --format=%cd --date=iso8601-strict ${{ github.sha }})";
echo "COMMIT_HASH=$(git describe --tags --dirty)";
} >> "${GITHUB_ENV}"
echo "FAST_BENCH=TRUE" >> "${GITHUB_ENV}"
- name: Install rust
uses: dtolnay/rust-toolchain@21dc36fb71dd22e3317045c0c31a3f4249868b17
Expand All @@ -67,7 +66,7 @@ jobs:

- name: Run integer benchmarks
run: |
make BENCH_OP_FLAVOR=${{ matrix.op_flavor }} bench_${{ matrix.command }}_gpu
make BENCH_OP_FLAVOR=default bench_integer_multi_bit_gpu
- name: Parse results
run: |
Expand All @@ -85,7 +84,7 @@ jobs:
- name: Upload parsed results artifact
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b
with:
name: ${{ github.sha }}_${{ matrix.command }}_${{ matrix.op_flavor }}
name: ${{ github.sha }}_integer_multi_bit_gpu_default
path: ${{ env.RESULTS_FILENAME }}

- name: Send data to Slab
Expand Down Expand Up @@ -146,7 +145,7 @@ jobs:
path: slab
token: ${{ secrets.FHE_ACTIONS_TOKEN }}

- name: Run integer benchmarks
- name: Run core crypto benchmarks
run: |
make bench_pbs_gpu
make bench_ks_gpu
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ bench_pbs128: install_rs_check_toolchain

.PHONY: bench_pbs_gpu # Run benchmarks for PBS on GPU backend
bench_pbs_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_FAST_BENCH=$(FAST_BENCH) cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench pbs-bench \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,gpu,internal-keycache,nightly-avx512 -p $(TFHE_SPEC)

Expand Down
24 changes: 14 additions & 10 deletions tfhe/benches/core_crypto/pbs_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ fn pbs_throughput<Scalar: UnsignedTorus + CastInto<usize> + Sync + Send + Serial
#[cfg(feature = "gpu")]
mod cuda {
use super::{multi_bit_benchmark_parameters_64bits, throughput_benchmark_parameters_64bits};
use crate::utilities::{write_to_json, CryptoParametersRecord, OperatorType};
use crate::utilities::{write_to_json, CryptoParametersRecord, EnvConfig, OperatorType};
use criterion::{black_box, Criterion};
use serde::Serialize;
use tfhe::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
Expand Down Expand Up @@ -1181,13 +1181,17 @@ mod cuda {
&stream,
);

const NUM_CTS: usize = 8192;
let mut num_cts: usize = 8192;
let env_config = EnvConfig::new();
if env_config.is_fast_bench {
num_cts = 1024;
}

let plaintext_list = PlaintextList::new(Scalar::ZERO, PlaintextCount(NUM_CTS));
let plaintext_list = PlaintextList::new(Scalar::ZERO, PlaintextCount(num_cts));
let mut lwe_list = LweCiphertextList::new(
Scalar::ZERO,
params.lwe_dimension.unwrap().to_lwe_size(),
LweCiphertextCount(NUM_CTS),
LweCiphertextCount(num_cts),
params.ciphertext_modulus.unwrap(),
);
encrypt_lwe_ciphertext_list(
Expand All @@ -1208,7 +1212,7 @@ mod cuda {
let output_lwe_list = LweCiphertextList::new(
Scalar::ZERO,
big_lwe_dimension.to_lwe_size(),
LweCiphertextCount(NUM_CTS),
LweCiphertextCount(num_cts),
params.ciphertext_modulus.unwrap(),
);
let lwe_ciphertext_in_gpu =
Expand All @@ -1225,8 +1229,8 @@ mod cuda {

let mut out_pbs_ct_gpu =
CudaLweCiphertextList::from_lwe_ciphertext_list(&output_lwe_list, &stream);
let mut h_indexes: [Scalar; NUM_CTS] = [Scalar::ZERO; NUM_CTS];
let mut d_lut_indexes = unsafe { CudaVec::<Scalar>::new_async(NUM_CTS, &stream, 0) };
let mut h_indexes: Vec<Scalar> = vec![Scalar::ZERO; num_cts];
let mut d_lut_indexes = unsafe { CudaVec::<Scalar>::new_async(num_cts, &stream, 0) };
unsafe {
d_lut_indexes.copy_from_cpu_async(h_indexes.as_ref(), &stream, 0);
}
Expand All @@ -1235,15 +1239,15 @@ mod cuda {
*index = Scalar::cast_from(i);
}
stream.synchronize();
let mut d_input_indexes = unsafe { CudaVec::<Scalar>::new_async(NUM_CTS, &stream, 0) };
let mut d_output_indexes = unsafe { CudaVec::<Scalar>::new_async(NUM_CTS, &stream, 0) };
let mut d_input_indexes = unsafe { CudaVec::<Scalar>::new_async(num_cts, &stream, 0) };
let mut d_output_indexes = unsafe { CudaVec::<Scalar>::new_async(num_cts, &stream, 0) };
unsafe {
d_input_indexes.copy_from_cpu_async(h_indexes.as_ref(), &stream, 0);
d_output_indexes.copy_from_cpu_async(h_indexes.as_ref(), &stream, 0);
}
stream.synchronize();

let id = format!("{bench_name}::{name}::{NUM_CTS}chunk");
let id = format!("{bench_name}::{name}::{num_cts}chunk");
bench_group.bench_function(&id, |b| {
b.iter(|| {
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext(
Expand Down

0 comments on commit 3ea3e2a

Please sign in to comment.