Skip to content

Commit

Permalink
refactor(gpu): add restrict keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
guillermo-oyarzun authored and agnesLeroy committed Jul 19, 2024
1 parent ffb8b4f commit c1fcd95
Show file tree
Hide file tree
Showing 14 changed files with 190 additions and 121 deletions.
10 changes: 8 additions & 2 deletions backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,12 @@ __device__ inline int get_start_ith_ggsw(int i, uint32_t polynomial_size,
int glwe_dimension,
uint32_t level_count);

// Read-only overload of get_ith_mask_kth_block: computes an offset into `ptr`
// from the GGSW index `i`, the block index `k` and the decomposition `level`
// (see the definition in bootstrapping_key.cuh) and returns a const pointer,
// so callers holding `const T *` (e.g. __restrict__ bootstrapping keys) can
// use it without a cast.
template <typename T>
__device__ const T *get_ith_mask_kth_block(const T *ptr, int i, int k,
int level, uint32_t polynomial_size,
int glwe_dimension,
uint32_t level_count);

template <typename T>
__device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
uint32_t polynomial_size,
Expand All @@ -422,8 +428,8 @@ __device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
int glwe_dimension, uint32_t level_count);

template <typename T>
__device__ T *get_multi_bit_ith_lwe_gth_group_kth_block(
T *ptr, int g, int i, int k, int level, uint32_t grouping_factor,
__device__ const T *get_multi_bit_ith_lwe_gth_group_kth_block(
const T *ptr, int g, int i, int k, int level, uint32_t grouping_factor,
uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t level_count);

#endif
Expand Down
12 changes: 7 additions & 5 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
// threads in y are used to parallelize the lwe_dimension_in loop.
// shared memory is used to store intermediate results of the reduction.
template <typename Torus>
__global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_output_indexes,
Torus *lwe_array_in, Torus *lwe_input_indexes,
Torus *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, int gpu_offset) {
__global__ void
keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
const Torus *__restrict__ lwe_array_in,
const Torus *__restrict__ lwe_input_indexes,
const Torus *__restrict__ ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
int gpu_offset) {
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int shmem_index = threadIdx.x + threadIdx.y * blockDim.x;

Expand Down
27 changes: 21 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/pbs/bootstraping_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ void cuda_convert_lwe_multi_bit_programmable_bootstrap_key_64(
}

// Explicit template instantiations: the templated definitions live in a header
// included by other translation units, so these lines force the compiler to
// emit code for each concrete element type used by the backend.
template __device__ const uint64_t *
get_ith_mask_kth_block(const uint64_t *ptr, int i, int k, int level,
uint32_t polynomial_size, int glwe_dimension,
uint32_t level_count);
template __device__ const uint32_t *
get_ith_mask_kth_block(const uint32_t *ptr, int i, int k, int level,
uint32_t polynomial_size, int glwe_dimension,
uint32_t level_count);
template __device__ const double2 *
get_ith_mask_kth_block(const double2 *ptr, int i, int k, int level,
uint32_t polynomial_size, int glwe_dimension,
uint32_t level_count);
template __device__ uint64_t *get_ith_mask_kth_block(uint64_t *ptr, int i,
int k, int level,
uint32_t polynomial_size,
Expand All @@ -51,6 +63,7 @@ template __device__ double2 *get_ith_mask_kth_block(double2 *ptr, int i, int k,
uint32_t polynomial_size,
int glwe_dimension,
uint32_t level_count);

template __device__ uint64_t *get_ith_body_kth_block(uint64_t *ptr, int i,
int k, int level,
uint32_t polynomial_size,
Expand All @@ -67,10 +80,12 @@ template __device__ double2 *get_ith_body_kth_block(double2 *ptr, int i, int k,
int glwe_dimension,
uint32_t level_count);

template __device__ uint64_t *get_multi_bit_ith_lwe_gth_group_kth_block(
uint64_t *ptr, int g, int i, int k, int level, uint32_t grouping_factor,
uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t level_count);
template __device__ const uint64_t *get_multi_bit_ith_lwe_gth_group_kth_block(
const uint64_t *ptr, int g, int i, int k, int level,
uint32_t grouping_factor, uint32_t polynomial_size, uint32_t glwe_dimension,
uint32_t level_count);

template __device__ double2 *get_multi_bit_ith_lwe_gth_group_kth_block(
double2 *ptr, int g, int i, int k, int level, uint32_t grouping_factor,
uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t level_count);
template __device__ const double2 *get_multi_bit_ith_lwe_gth_group_kth_block(
const double2 *ptr, int g, int i, int k, int level,
uint32_t grouping_factor, uint32_t polynomial_size, uint32_t glwe_dimension,
uint32_t level_count);
23 changes: 18 additions & 5 deletions backends/tfhe-cuda-backend/cuda/src/pbs/bootstrapping_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ __device__ inline int get_start_ith_ggsw(int i, uint32_t polynomial_size,
}

////////////////////////////////////////////////
// Read-only overload: returns a const pointer to the k-th block of the given
// decomposition level inside the i-th GGSW ciphertext stored at `ptr`.
// Layout assumed here: each level spans (glwe_dimension + 1)^2 blocks and each
// block spans polynomial_size / 2 elements — TODO confirm against the key
// conversion code.
template <typename T>
__device__ const T *get_ith_mask_kth_block(const T *ptr, int i, int k,
                                           int level, uint32_t polynomial_size,
                                           int glwe_dimension,
                                           uint32_t level_count) {
  // Start of the i-th GGSW ciphertext.
  int ggsw_start =
      get_start_ith_ggsw(i, polynomial_size, glwe_dimension, level_count);
  // Offset of the requested decomposition level within that ciphertext.
  uint32_t level_offset = level * polynomial_size / 2 * (glwe_dimension + 1) *
                          (glwe_dimension + 1);
  // Offset of the requested block within that level.
  uint32_t block_offset = k * polynomial_size / 2 * (glwe_dimension + 1);
  return ptr + ggsw_start + level_offset + block_offset;
}

template <typename T>
__device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
uint32_t polynomial_size,
Expand All @@ -27,7 +39,6 @@ __device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
(glwe_dimension + 1) +
k * polynomial_size / 2 * (glwe_dimension + 1)];
}

template <typename T>
__device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
uint32_t polynomial_size,
Expand All @@ -50,14 +61,16 @@ __device__ inline int get_start_ith_lwe(uint32_t i, uint32_t grouping_factor,
}

template <typename T>
__device__ T *get_multi_bit_ith_lwe_gth_group_kth_block(
T *ptr, int g, int i, int k, int level, uint32_t grouping_factor,
__device__ const T *get_multi_bit_ith_lwe_gth_group_kth_block(
const T *ptr, int g, int i, int k, int level, uint32_t grouping_factor,
uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t level_count) {
T *ptr_group = ptr + get_start_ith_lwe(i, grouping_factor, polynomial_size,
glwe_dimension, level_count);
const T *ptr_group =
ptr + get_start_ith_lwe(i, grouping_factor, polynomial_size,
glwe_dimension, level_count);
return get_ith_mask_kth_block(ptr_group, g, k, level, polynomial_size,
glwe_dimension, level_count);
}

////////////////////////////////////////////////
template <typename T, typename ST>
void cuda_convert_lwe_programmable_bootstrap_key(cudaStream_t stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ get_join_buffer_element(int level_id, int glwe_id, G &group,
uint32_t glwe_dimension, bool support_dsm);

template <typename Torus, typename G, class params>
__device__ void mul_ggsw_glwe(Torus *accumulator, double2 *fft,
double2 *join_buffer, double2 *bootstrapping_key,
int polynomial_size, uint32_t glwe_dimension,
int level_count, int iteration, G &group,
bool support_dsm = false) {
__device__ void
mul_ggsw_glwe(Torus *accumulator, double2 *fft, double2 *join_buffer,
const double2 *__restrict__ bootstrapping_key,
int polynomial_size, uint32_t glwe_dimension, int level_count,
int iteration, G &group, bool support_dsm = false) {

// Switch to the FFT space
NSMFFT_direct<HalfDegree<params>>(fft);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ template <typename Torus, class params, sharedMemDegree SMD>
* is not FULLSM
*/
__global__ void device_programmable_bootstrap_amortized(
Torus *lwe_array_out, Torus *lwe_output_indexes, Torus *lut_vector,
Torus *lut_vector_indexes, Torus *lwe_array_in, Torus *lwe_input_indexes,
double2 *bootstrapping_key, int8_t *device_mem, uint32_t glwe_dimension,
uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t lwe_idx,
Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
const Torus *__restrict__ lut_vector,
const Torus *__restrict__ lut_vector_indexes,
const Torus *__restrict__ lwe_array_in,
const Torus *__restrict__ lwe_input_indexes,
const double2 *__restrict__ bootstrapping_key, int8_t *device_mem,
uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t lwe_idx,
size_t device_memory_size_per_sample, uint32_t gpu_offset) {
// We use shared memory for the polynomials that are used often during the
// bootstrap, since shared memory is kept in L1 cache and accessing it is
Expand Down Expand Up @@ -81,7 +84,7 @@ __global__ void device_programmable_bootstrap_amortized(
auto block_lwe_array_in =
&lwe_array_in[lwe_input_indexes[blockIdx.x + gpu_offset] *
(lwe_dimension + 1)];
Torus *block_lut_vector =
const Torus *block_lut_vector =
&lut_vector[lut_vector_indexes[lwe_idx + blockIdx.x] * params::degree *
(glwe_dimension + 1)];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@ namespace cg = cooperative_groups;
*/
template <typename Torus, class params, sharedMemDegree SMD>
__global__ void device_programmable_bootstrap_cg(
Torus *lwe_array_out, Torus *lwe_output_indexes, Torus *lut_vector,
Torus *lut_vector_indexes, Torus *lwe_array_in, Torus *lwe_input_indexes,
double2 *bootstrapping_key, double2 *join_buffer, uint32_t lwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
int8_t *device_mem, uint64_t device_memory_size_per_block,
uint32_t gpu_offset) {
Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
const Torus *__restrict__ lut_vector,
const Torus *__restrict__ lut_vector_indexes,
const Torus *__restrict__ lwe_array_in,
const Torus *__restrict__ lwe_input_indexes,
const double2 *__restrict__ bootstrapping_key, double2 *join_buffer,
uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, int8_t *device_mem,
uint64_t device_memory_size_per_block, uint32_t gpu_offset) {

grid_group grid = this_grid();

Expand Down Expand Up @@ -74,12 +77,13 @@ __global__ void device_programmable_bootstrap_cg(

// The third dimension of the block is used to determine on which ciphertext
// this block is operating, in the case of batch bootstraps
Torus *block_lwe_array_in =
const Torus *block_lwe_array_in =
&lwe_array_in[lwe_input_indexes[blockIdx.z + gpu_offset] *
(lwe_dimension + 1)];

Torus *block_lut_vector = &lut_vector[lut_vector_indexes[blockIdx.z] *
params::degree * (glwe_dimension + 1)];
const Torus *block_lut_vector =
&lut_vector[lut_vector_indexes[blockIdx.z] * params::degree *
(glwe_dimension + 1)];

double2 *block_join_buffer =
&join_buffer[blockIdx.z * level_count * (glwe_dimension + 1) *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@

template <typename Torus, class params, sharedMemDegree SMD>
__global__ void device_multi_bit_programmable_bootstrap_cg_accumulate(
Torus *lwe_array_out, Torus *lwe_output_indexes, Torus *lut_vector,
Torus *lut_vector_indexes, Torus *lwe_array_in, Torus *lwe_input_indexes,
double2 *keybundle_array, double2 *join_buffer, Torus *global_accumulator,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t grouping_factor,
uint32_t lwe_offset, uint32_t lwe_chunk_size,
Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
const Torus *__restrict__ lut_vector,
const Torus *__restrict__ lut_vector_indexes,
const Torus *__restrict__ lwe_array_in,
const Torus *__restrict__ lwe_input_indexes,
const double2 *__restrict__ keybundle_array, double2 *join_buffer,
Torus *global_accumulator, uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t grouping_factor, uint32_t lwe_offset, uint32_t lwe_chunk_size,
uint32_t keybundle_size_per_input, int8_t *device_mem,
uint64_t device_memory_size_per_block, uint32_t gpu_offset) {

Expand Down Expand Up @@ -54,12 +57,13 @@ __global__ void device_multi_bit_programmable_bootstrap_cg_accumulate(

// The third dimension of the block is used to determine on which ciphertext
// this block is operating, in the case of batch bootstraps
Torus *block_lwe_array_in =
const Torus *block_lwe_array_in =
&lwe_array_in[lwe_input_indexes[blockIdx.z + gpu_offset] *
(lwe_dimension + 1)];

Torus *block_lut_vector = &lut_vector[lut_vector_indexes[blockIdx.z] *
params::degree * (glwe_dimension + 1)];
const Torus *block_lut_vector =
&lut_vector[lut_vector_indexes[blockIdx.z] * params::degree *
(glwe_dimension + 1)];

double2 *block_join_buffer =
&join_buffer[blockIdx.z * level_count * (glwe_dimension + 1) *
Expand All @@ -69,9 +73,9 @@ __global__ void device_multi_bit_programmable_bootstrap_cg_accumulate(
global_accumulator +
(blockIdx.y + blockIdx.z * (glwe_dimension + 1)) * params::degree;

double2 *keybundle = keybundle_array +
// select the input
blockIdx.z * keybundle_size_per_input;
const double2 *keybundle = keybundle_array +
// select the input
blockIdx.z * keybundle_size_per_input;

if (lwe_offset == 0) {
// Put "b" in [0, 2N[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@

template <typename Torus, class params, sharedMemDegree SMD>
__global__ void device_programmable_bootstrap_step_one(
Torus *lut_vector, Torus *lut_vector_indexes, Torus *lwe_array_in,
Torus *lwe_input_indexes, double2 *bootstrapping_key,
Torus *global_accumulator, double2 *global_accumulator_fft,
uint32_t lwe_iteration, uint32_t lwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, int8_t *device_mem,
const Torus *__restrict__ lut_vector,
const Torus *__restrict__ lut_vector_indexes,
const Torus *__restrict__ lwe_array_in,
const Torus *__restrict__ lwe_input_indexes,
const double2 *__restrict__ bootstrapping_key, Torus *global_accumulator,
double2 *global_accumulator_fft, uint32_t lwe_iteration,
uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, int8_t *device_mem,
uint64_t device_memory_size_per_block, uint32_t gpu_offset) {

// We use shared memory for the polynomials that are used often during the
Expand Down Expand Up @@ -50,12 +53,13 @@ __global__ void device_programmable_bootstrap_step_one(

// The third dimension of the block is used to determine on which ciphertext
// this block is operating, in the case of batch bootstraps
Torus *block_lwe_array_in =
const Torus *block_lwe_array_in =
&lwe_array_in[lwe_input_indexes[blockIdx.z + gpu_offset] *
(lwe_dimension + 1)];

Torus *block_lut_vector = &lut_vector[lut_vector_indexes[blockIdx.z] *
params::degree * (glwe_dimension + 1)];
const Torus *block_lut_vector =
&lut_vector[lut_vector_indexes[blockIdx.z] * params::degree *
(glwe_dimension + 1)];

Torus *global_slice =
global_accumulator +
Expand Down Expand Up @@ -129,11 +133,13 @@ __global__ void device_programmable_bootstrap_step_one(

template <typename Torus, class params, sharedMemDegree SMD>
__global__ void device_programmable_bootstrap_step_two(
Torus *lwe_array_out, Torus *lwe_output_indexes, Torus *lut_vector,
Torus *lut_vector_indexes, double2 *bootstrapping_key,
Torus *global_accumulator, double2 *global_accumulator_fft,
uint32_t lwe_iteration, uint32_t lwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, int8_t *device_mem,
Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
const Torus *__restrict__ lut_vector,
const Torus *__restrict__ lut_vector_indexes,
const double2 *__restrict__ bootstrapping_key, Torus *global_accumulator,
double2 *global_accumulator_fft, uint32_t lwe_iteration,
uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, int8_t *device_mem,
uint64_t device_memory_size_per_block, uint32_t gpu_offset) {

// We use shared memory for the polynomials that are used often during the
Expand Down
Loading

0 comments on commit c1fcd95

Please sign in to comment.