feat(gpu): Implements a classical PBS variant that uses thread block cluster and distributed shared memory
pdroalves authored and agnesLeroy committed Apr 26, 2024
1 parent 05527c9 commit 20e11ea
Showing 12 changed files with 1,101 additions and 121 deletions.
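For readers who have not used them before, the two Hopper (sm_90) features this commit relies on are thread block clusters and distributed shared memory (DSM): blocks in a cluster can synchronize with each other and read one another's shared memory through cooperative_groups::cluster_group::map_shared_rank. The following minimal sketch is not taken from this commit; the kernel name, sizes, and launch parameters are illustrative only, and it must be built with -arch=sm_90 or newer.

#include <cooperative_groups.h>
#include <cuda_runtime.h>

namespace cg = cooperative_groups;

// Each block publishes its rank in its own dynamic shared memory, then reads
// its neighbour's value through the neighbour's shared memory (DSM).
__global__ void dsm_demo(int *out) {
  extern __shared__ int smem[];
  cg::cluster_group cluster = cg::this_cluster();
  unsigned int rank = cluster.block_rank();

  if (threadIdx.x == 0)
    smem[0] = (int)rank;
  cluster.sync(); // make every block's shared memory visible to the cluster

  // Map the neighbour block's shared memory into this block's address space
  int *neighbour =
      cluster.map_shared_rank(smem, (rank + 1) % cluster.num_blocks());
  if (threadIdx.x == 0)
    out[rank] = neighbour[0];
  cluster.sync(); // keep remote shared memory alive until all reads finish
}

int main() {
  int *d_out;
  cudaMalloc(&d_out, 4 * sizeof(int));

  cudaLaunchConfig_t config = {};
  config.gridDim = dim3(4);
  config.blockDim = dim3(128);
  config.dynamicSmemBytes = sizeof(int);

  cudaLaunchAttribute attr;
  attr.id = cudaLaunchAttributeClusterDimension;
  attr.val.clusterDim.x = 4; // 4 blocks per cluster
  attr.val.clusterDim.y = 1;
  attr.val.clusterDim.z = 1;
  config.attrs = &attr;
  config.numAttrs = 1;

  cudaLaunchKernelEx(&config, dsm_demo, d_out);
  cudaDeviceSynchronize();
  cudaFree(d_out);
  return 0;
}

The new TBC (thread block cluster) PBS variant below follows this pattern: when DSM is supported, each block keeps its slice of the join buffer in shared memory and the other blocks in the cluster read it via map_shared_rank; otherwise the slices fall back to a global-memory buffer.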
1 change: 0 additions & 1 deletion backends/tfhe-cuda-backend/cuda/include/device.h
@@ -8,7 +8,6 @@
#include <cuda_runtime.h>

#define synchronize_threads_in_block() __syncthreads()

extern "C" {

#define check_cuda_error(ans) \
105 changes: 100 additions & 5 deletions backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap.h
@@ -111,6 +111,28 @@ get_buffer_size_partial_sm_programmable_bootstrap(uint32_t polynomial_size) {
return sizeof(double2) * polynomial_size / 2; // accumulator fft
}

template <typename Torus>
__host__ __device__ uint64_t
get_buffer_size_full_sm_programmable_bootstrap_tbc(uint32_t polynomial_size) {
return sizeof(Torus) * polynomial_size + // accumulator_rotated
sizeof(Torus) * polynomial_size + // accumulator
sizeof(double2) * polynomial_size / 2; // accumulator fft
}

template <typename Torus>
__host__ __device__ uint64_t
get_buffer_size_partial_sm_programmable_bootstrap_tbc(
uint32_t polynomial_size) {
return sizeof(double2) * polynomial_size / 2; // accumulator fft mask & body
}

template <typename Torus>
__host__ __device__ uint64_t
get_buffer_size_sm_dsm_plus_tbc_classic_programmable_bootstrap(
uint32_t polynomial_size) {
return sizeof(double2) * polynomial_size / 2; // tbc
}
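As a concrete illustration (these numbers are not part of the diff): for Torus = uint64_t and polynomial_size = 2048, with sizeof(uint64_t) = 8 and sizeof(double2) = 16, the helpers above evaluate to

full_sm_tbc    = 8*2048 + 8*2048 + 16*1024 = 49152 bytes (48 KiB)
partial_sm_tbc = 16*1024                   = 16384 bytes (16 KiB)
minimum DSM    = 16*1024                   = 16384 bytes (16 KiB)

i.e. the rotated accumulator and the accumulator cost sizeof(Torus) * N each, and every FFT buffer of N/2 complex coefficients costs 8*N bytes.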

template <typename Torus>
__host__ __device__ uint64_t
get_buffer_size_full_sm_programmable_bootstrap_cg(uint32_t polynomial_size) {
@@ -125,6 +147,11 @@ get_buffer_size_partial_sm_programmable_bootstrap_cg(uint32_t polynomial_size) {
return sizeof(double2) * polynomial_size / 2; // accumulator fft mask & body
}

template <typename Torus>
__host__ bool
supports_distributed_shared_memory_on_classic_programmable_bootstrap(
uint32_t polynomial_size, uint32_t max_shared_memory);
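The definition of this predicate is not visible in this hunk. A plausible shape, purely illustrative and under the assumption that it only needs to verify that the DSM slice fits in shared memory and that the device can launch clusters, might be:

// Illustrative sketch only -- the real check in the backend may differ.
template <typename Torus>
__host__ bool
supports_distributed_shared_memory_on_classic_programmable_bootstrap(
    uint32_t polynomial_size, uint32_t max_shared_memory) {
  uint64_t minimum_sm =
      get_buffer_size_sm_dsm_plus_tbc_classic_programmable_bootstrap<Torus>(
          polynomial_size);
  if (max_shared_memory < minimum_sm)
    return false; // not enough shared memory per block for the DSM slice

  int cluster_launch = 0; // 1 on Hopper-class (sm_90) devices; device 0 for brevity
  cudaDeviceGetAttribute(&cluster_launch, cudaDevAttrClusterLaunch, 0);
  return cluster_launch != 0;
}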

template <typename Torus, PBS_TYPE pbs_type> struct pbs_buffer;

template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::CLASSICAL> {
@@ -213,6 +240,54 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::CLASSICAL> {
polynomial_size / 2 * sizeof(double2),
stream);
} break;
#if CUDA_ARCH >= 900
case PBS_VARIANT::TBC: {

bool supports_dsm =
supports_distributed_shared_memory_on_classic_programmable_bootstrap<
Torus>(polynomial_size, max_shared_memory);

uint64_t full_sm =
get_buffer_size_full_sm_programmable_bootstrap_tbc<Torus>(
polynomial_size);
uint64_t partial_sm =
get_buffer_size_partial_sm_programmable_bootstrap_tbc<Torus>(
polynomial_size);
uint64_t minimum_sm_tbc = 0;
if (supports_dsm)
minimum_sm_tbc =
get_buffer_size_sm_dsm_plus_tbc_classic_programmable_bootstrap<
Torus>(polynomial_size);

uint64_t partial_dm = full_sm - partial_sm;
uint64_t full_dm = full_sm;
uint64_t device_mem = 0;

// There is a minimum amount of shared memory needed to run the TBC PBS:
// minimum_sm_tbc. Those bytes are known to be available, because otherwise
// the support check would have redirected computation to another variant.
// If, beyond that minimum, fewer than partial_sm extra bytes are available,
// the TBC PBS runs in NOSM mode. If partial_sm but not full_sm extra bytes
// are available, it runs in PARTIALSM mode. Otherwise it runs in FULLSM
// mode.
//
// Note that even NOSM mode requires minimum_sm_tbc bytes of shared memory.
if (max_shared_memory < partial_sm + minimum_sm_tbc) {
device_mem = full_dm * input_lwe_ciphertext_count * level_count *
(glwe_dimension + 1);
} else if (max_shared_memory < full_sm + minimum_sm_tbc) {
device_mem = partial_dm * input_lwe_ciphertext_count * level_count *
(glwe_dimension + 1);
}

// Otherwise (FULLSM), everything runs in shared memory and device_mem stays 0
d_mem = (int8_t *)cuda_malloc_async(device_mem, stream);

global_accumulator_fft = (double2 *)cuda_malloc_async(
(glwe_dimension + 1) * level_count * input_lwe_ciphertext_count *
polynomial_size / 2 * sizeof(double2),
stream);
} break;
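To make the thresholds concrete, reusing the polynomial_size = 2048 sizes computed above (illustrative only): partial_sm + minimum_sm_tbc = 32 KiB and full_sm + minimum_sm_tbc = 64 KiB. A device reporting max_shared_memory = 48 KiB clears the first threshold but not the second, so it takes the PARTIALSM branch and allocates partial_dm = full_sm - partial_sm = 32 KiB of device scratch per (sample, level, GLWE polynomial) combination; a device reporting 100 KiB or more clears both thresholds, leaves device_mem at 0, and runs FULLSM.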
#endif
default:
PANIC("Cuda error (PBS): unsupported implementation variant.")
}
@@ -281,6 +356,25 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector(
uint32_t level_count, uint32_t num_samples, uint32_t num_luts,
uint32_t lwe_idx, uint32_t max_shared_memory);

#if (CUDA_ARCH >= 900)
template <typename Torus>
void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector(
cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_output_indexes,
Torus *lut_vector, Torus *lut_vector_indexes, Torus *lwe_array_in,
Torus *lwe_input_indexes, double2 *bootstrapping_key,
pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, uint32_t num_luts,
uint32_t lwe_idx, uint32_t max_shared_memory);

template <typename Torus, typename STorus>
void scratch_cuda_programmable_bootstrap_tbc(
cuda_stream_t *stream, pbs_buffer<Torus, CLASSICAL> **pbs_buffer,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, uint32_t max_shared_memory,
bool allocate_gpu_memory);
#endif

template <typename Torus, typename STorus>
void scratch_cuda_programmable_bootstrap_cg(
cuda_stream_t *stream, pbs_buffer<Torus, CLASSICAL> **pbs_buffer,
@@ -295,11 +389,12 @@ void scratch_cuda_programmable_bootstrap(
uint32_t input_lwe_ciphertext_count, uint32_t max_shared_memory,
bool allocate_gpu_memory);

template <typename G>
__device__ int get_this_block_rank(G &group, bool support_dsm);
template <typename G, class params>
__device__ double2 *get_join_buffer_element(int i, G &group, bool support_dsm,
double2 *global_memory_buffer);
template <typename Torus>
bool has_support_to_cuda_programmable_bootstrap_tbc(uint32_t num_samples,
uint32_t glwe_dimension,
uint32_t polynomial_size,
uint32_t level_count,
uint32_t max_shared_memory);
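A sketch of how a caller could combine these declarations to prefer the TBC variant and fall back to the existing entry point when thread block clusters are not usable. The dispatcher name is hypothetical, and the fallback cuda_programmable_bootstrap_lwe_ciphertext_vector is assumed to take the same parameter list as the TBC entry point; the backend's actual selection logic may differ.

// Hypothetical dispatcher -- illustrative only.
template <typename Torus>
void bootstrap_prefer_tbc(
    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_output_indexes,
    Torus *lut_vector, Torus *lut_vector_indexes, Torus *lwe_array_in,
    Torus *lwe_input_indexes, double2 *bootstrapping_key,
    pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension,
    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
    uint32_t level_count, uint32_t num_samples, uint32_t num_luts,
    uint32_t lwe_idx, uint32_t max_shared_memory) {
#if CUDA_ARCH >= 900
  if (has_support_to_cuda_programmable_bootstrap_tbc<Torus>(
          num_samples, glwe_dimension, polynomial_size, level_count,
          max_shared_memory)) {
    cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector<Torus>(
        stream, lwe_array_out, lwe_output_indexes, lut_vector,
        lut_vector_indexes, lwe_array_in, lwe_input_indexes,
        bootstrapping_key, buffer, lwe_dimension, glwe_dimension,
        polynomial_size, base_log, level_count, num_samples, num_luts,
        lwe_idx, max_shared_memory);
    return;
  }
#endif
  // Fallback: the non-TBC entry point declared above (parameter list assumed
  // to match).
  cuda_programmable_bootstrap_lwe_ciphertext_vector<Torus>(
      stream, lwe_array_out, lwe_output_indexes, lut_vector,
      lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key,
      buffer, lwe_dimension, glwe_dimension, polynomial_size, base_log,
      level_count, num_samples, num_luts, lwe_idx, max_shared_memory);
}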

#ifdef __CUDACC__
__device__ inline int get_start_ith_ggsw(int i, uint32_t polynomial_size,
@@ -58,8 +58,10 @@ supports_distributed_shared_memory_on_multibit_programmable_bootstrap(

template <typename Torus>
bool has_support_to_cuda_programmable_bootstrap_tbc_multi_bit(
uint32_t polynomial_size, uint32_t max_shared_memory);
uint32_t num_samples, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t level_count, uint32_t max_shared_memory);

#if CUDA_ARCH >= 900
template <typename Torus, typename STorus>
void scratch_cuda_tbc_multi_bit_programmable_bootstrap(
cuda_stream_t *stream, pbs_buffer<Torus, MULTI_BIT> **buffer,
@@ -78,6 +80,7 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector(
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_luts, uint32_t lwe_idx, uint32_t max_shared_memory,
uint32_t lwe_chunk_size);
#endif

template <typename Torus, typename STorus>
void scratch_cuda_cg_multi_bit_programmable_bootstrap(
@@ -205,19 +208,6 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
auto num_blocks_acc_cg =
level_count * (glwe_dimension + 1) * input_lwe_ciphertext_count;

#if CUDA_ARCH >= 900
uint64_t full_sm_tbc_accumulate =
get_buffer_size_full_sm_tbc_multibit_programmable_bootstrap<Torus>(
polynomial_size);
uint64_t partial_sm_tbc_accumulate =
get_buffer_size_partial_sm_tbc_multibit_programmable_bootstrap<Torus>(
polynomial_size);
uint64_t minimum_sm_tbc =
get_buffer_size_sm_dsm_plus_tbc_multibit_programmable_bootstrap<Torus>(
polynomial_size);
auto num_blocks_acc_tbc = num_blocks_acc_cg;
#endif

if (allocate_gpu_memory) {
// Keybundle
if (max_shared_memory < full_sm_keybundle)
@@ -250,6 +240,28 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
break;
#if CUDA_ARCH >= 900
case TBC:

uint64_t full_sm_tbc_accumulate =
get_buffer_size_full_sm_tbc_multibit_programmable_bootstrap<Torus>(
polynomial_size);
uint64_t partial_sm_tbc_accumulate =
get_buffer_size_partial_sm_tbc_multibit_programmable_bootstrap<
Torus>(polynomial_size);
uint64_t minimum_sm_tbc =
get_buffer_size_sm_dsm_plus_tbc_multibit_programmable_bootstrap<
Torus>(polynomial_size);
auto num_blocks_acc_tbc = num_blocks_acc_cg;

// There is a minimum amount of shared memory needed to run the TBC PBS:
// minimum_sm_tbc. Those bytes are known to be available, because otherwise
// the support check would have redirected computation to another variant.
// If, beyond that minimum, fewer than partial_sm_tbc_accumulate extra bytes
// are available, the TBC PBS runs in NOSM mode. If partial_sm_tbc_accumulate
// but not full_sm_tbc_accumulate extra bytes are available, it runs in
// PARTIALSM mode. Otherwise it runs in FULLSM mode.
//
// Note that even NOSM mode requires minimum_sm_tbc bytes of shared memory.

// Accumulator TBC
if (max_shared_memory < partial_sm_tbc_accumulate + minimum_sm_tbc)
d_mem_acc_tbc = (int8_t *)cuda_malloc_async(
@@ -306,7 +318,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
};

template <typename Torus, class params>
__host__ uint32_t get_lwe_chunk_size(int gpu_index, uint32_t max_num_pbs,
__host__ uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
uint32_t polynomial_size,
uint32_t max_shared_memory);

52 changes: 28 additions & 24 deletions backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cu
@@ -1,40 +1,44 @@
#include "programmable_bootstrap.cuh"


template <> __device__ int get_this_block_rank(grid_group &group, bool support_dsm) {
template <>
__device__ int get_this_block_rank(grid_group &group, bool support_dsm) {
return blockIdx.y;
}

template <> __device__ int get_this_block_rank(cluster_group &cluster, bool support_dsm) {
template <>
__device__ double2 *
get_join_buffer_element(int level_id, int glwe_id, grid_group &group,
double2 *global_memory_buffer, uint32_t polynomial_size,
uint32_t glwe_dimension, bool support_dsm) {
double2 *buffer_slice =
global_memory_buffer +
(glwe_id + level_id * (glwe_dimension + 1)) * polynomial_size / 2;
return buffer_slice;
}

#if CUDA_ARCH >= 900
template <>
__device__ int get_this_block_rank(cluster_group &cluster, bool support_dsm) {
if (support_dsm)
return cluster.block_rank();
else
return blockIdx.y;
}

template<> __device__ double2 *get_join_buffer_element(int i, grid_group &group,
bool support_dsm,
double2 *global_memory_buffer, uint32_t
polynomial_size) {
double2 *buffer_slice = global_memory_buffer + i * polynomial_size / 2;
return buffer_slice;
}

template<> __device__ double2 *get_join_buffer_element(int i, cluster_group &cluster,
bool support_dsm,
double2 *global_memory_buffer, uint32_t
polynomial_size) {
#if CUDA_ARCH < 900
double2 *buffer_slice =
global_memory_buffer + blockIdx.y * polynomial_size / 2;
#else
template <>
__device__ double2 *
get_join_buffer_element(int level_id, int glwe_id, cluster_group &cluster,
double2 *global_memory_buffer, uint32_t polynomial_size,
uint32_t glwe_dimension, bool support_dsm) {
double2 *buffer_slice;
if (support_dsm) {
extern __shared__ double2 smem[];
buffer_slice = cluster.map_shared_rank(smem, i);
buffer_slice = cluster.map_shared_rank(
smem, glwe_id + level_id * (glwe_dimension + 1));
} else {
buffer_slice = global_memory_buffer + i * polynomial_size / 2;
buffer_slice =
global_memory_buffer +
(glwe_id + level_id * (glwe_dimension + 1)) * polynomial_size / 2;
}
#endif
return buffer_slice;
}
}
#endif
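In both specializations the (level_id, glwe_id) pair is flattened as glwe_id + level_id * (glwe_dimension + 1): this value is the rank passed to cluster.map_shared_rank when distributed shared memory is used, and the index of the double2 slice of length polynomial_size / 2 in the global-memory fallback. As a small illustration (values chosen for the example only): with glwe_dimension = 1 and level_count = 2 there are four (level, glwe) pairs, mapping (0,0) -> 0, (0,1) -> 1, (1,0) -> 2, (1,1) -> 3, so get_join_buffer_element(1, 0, ...) returns either the shared memory of the block mapped to rank 2 or global_memory_buffer + 2 * polynomial_size / 2.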
31 changes: 17 additions & 14 deletions backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh
@@ -15,8 +15,10 @@ template <typename G>
__device__ int get_this_block_rank(G &group, bool support_dsm);

template <typename G>
__device__ double2 *get_join_buffer_element(int i, G &group, bool support_dsm,
double2 *global_memory_buffer, uint32_t polynomial_size);
__device__ double2 *
get_join_buffer_element(int level_id, int glwe_id, G &group,
double2 *global_memory_buffer, uint32_t polynomial_size,
uint32_t glwe_dimension, bool support_dsm);

template <typename Torus, typename G, class params>
__device__ void mul_ggsw_glwe(Torus *accumulator, double2 *fft,
@@ -38,18 +40,15 @@ __device__ void mul_ggsw_glwe(Torus *accumulator, double2 *fft,
bootstrapping_key, iteration, blockIdx.y, blockIdx.x, polynomial_size,
glwe_dimension, level_count);

// Selects all GLWEs in a particular decomposition level
auto level_join_buffer =
join_buffer + blockIdx.x * (glwe_dimension + 1) * params::degree / 2;

// Perform the matrix multiplication between the GGSW and the GLWE,
// each block operating on a single level for mask and body

// The first product is used to initialize level_join_buffer
auto bsk_poly = bsk_slice + blockIdx.y * params::degree / 2;
auto this_block_rank = get_this_block_rank<G>(group, support_dsm);
auto buffer_slice = get_join_buffer_element<G>(
this_block_rank, group, support_dsm, level_join_buffer, polynomial_size);
auto buffer_slice =
get_join_buffer_element<G>(blockIdx.x, blockIdx.y, group, join_buffer,
polynomial_size, glwe_dimension, support_dsm);

int tid = threadIdx.x;
for (int i = 0; i < params::opt / 2; i++) {
@@ -65,8 +64,9 @@ __device__ void mul_ggsw_glwe(Torus *accumulator, double2 *fft,
int idx = (j + this_block_rank) % (glwe_dimension + 1);

auto bsk_poly = bsk_slice + idx * params::degree / 2;
auto buffer_slice = get_join_buffer_element<G>(
idx, group, support_dsm, level_join_buffer, polynomial_size);
auto buffer_slice = get_join_buffer_element<G>(blockIdx.x, idx, group,
join_buffer, polynomial_size,
glwe_dimension, support_dsm);

int tid = threadIdx.x;
for (int i = 0; i < params::opt / 2; i++) {
@@ -80,8 +80,9 @@ __device__ void mul_ggsw_glwe(Torus *accumulator, double2 *fft,
// All blocks are synchronized here; after this sync, join_buffer has
// the values needed from every other block

auto src_acc = get_join_buffer_element<G>(blockIdx.y, group,
support_dsm, join_buffer, polynomial_size);
auto src_acc =
get_join_buffer_element<G>(0, blockIdx.y, group, join_buffer,
polynomial_size, glwe_dimension, support_dsm);

// copy first product into fft buffer
tid = threadIdx.x;
@@ -93,7 +94,9 @@ __device__ void mul_ggsw_glwe(Torus *accumulator, double2 *fft,

// accumulate rest of the products into fft buffer
for (int l = 1; l < gridDim.x; l++) {
auto cur_src_acc = &src_acc[l * (glwe_dimension + 1) * params::degree / 2];
auto cur_src_acc = get_join_buffer_element<G>(l, blockIdx.y, group,
join_buffer, polynomial_size,
glwe_dimension, support_dsm);
tid = threadIdx.x;
for (int i = 0; i < params::opt / 2; i++) {
fft[tid] += cur_src_acc[tid];
@@ -222,4 +225,4 @@ void execute_scratch_pbs(cuda_stream_t *stream, int8_t **pbs_buffer,
}
}

#endif
#endif