Skip to content

Commit

Permalink
refactor(gpu): avoid synchronizations in the keybundle
Browse files Browse the repository at this point in the history
  • Loading branch information
guillermo-oyarzun committed Sep 5, 2024
1 parent 4c707e7 commit d651f68
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
#include <vector>

template <typename Torus, class params>
__device__ Torus calculates_monomial_degree(const Torus *lwe_array_group,
uint32_t ggsw_idx,
uint32_t grouping_factor) {
__device__ uint32_t calculates_monomial_degree(const Torus *lwe_array_group,
uint32_t ggsw_idx,
uint32_t grouping_factor) {
Torus x = 0;
for (int i = 0; i < grouping_factor; i++) {
uint32_t mask_position = grouping_factor - (i + 1);
Expand All @@ -31,6 +31,13 @@ __device__ Torus calculates_monomial_degree(const Torus *lwe_array_group,
return modulus_switch(x, params::log2_degree + 1);
}

__device__ __forceinline__ int
get_start_ith_ggsw_offset(uint32_t polynomial_size, int glwe_dimension,
uint32_t level_count) {
return polynomial_size * (glwe_dimension + 1) * (glwe_dimension + 1) *
level_count;
}

template <typename Torus, class params, sharedMemDegree SMD>
__global__ void device_multi_bit_programmable_bootstrap_keybundle(
const Torus *__restrict__ lwe_array_in,
Expand Down Expand Up @@ -60,8 +67,6 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
uint32_t input_idx = blockIdx.x / lwe_chunk_size;

if (lwe_iteration < (lwe_dimension / grouping_factor)) {
//
Torus *accumulator = (Torus *)selected_memory;

const Torus *block_lwe_array_in =
&lwe_array_in[lwe_input_indexes[input_idx] * (lwe_dimension + 1)];
Expand All @@ -81,57 +86,52 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
const Torus *bsk_slice = get_multi_bit_ith_lwe_gth_group_kth_block(
bootstrapping_key, 0, rev_lwe_iteration, glwe_id, level_id,
grouping_factor, 2 * polynomial_size, glwe_dimension, level_count);
const Torus *bsk_poly = bsk_slice + poly_id * params::degree;
const Torus *bsk_poly_ini = bsk_slice + poly_id * params::degree;

copy_polynomial<Torus, params::opt, params::degree / params::opt>(
bsk_poly, accumulator);
Torus reg_acc[params::opt];

// Accumulate the other terms
for (int g = 1; g < (1 << grouping_factor); g++) {
copy_polynomial_in_regs<Torus, params::opt, params::degree / params::opt>(
bsk_poly_ini, reg_acc);

const Torus *bsk_slice = get_multi_bit_ith_lwe_gth_group_kth_block(
bootstrapping_key, g, rev_lwe_iteration, glwe_id, level_id,
grouping_factor, 2 * polynomial_size, glwe_dimension, level_count);
const Torus *bsk_poly = bsk_slice + poly_id * params::degree;
int offset =
get_start_ith_ggsw_offset(polynomial_size, glwe_dimension, level_count);

// Calculates the monomial degree
// Precalculate the monomial degrees and store them in shared memory
uint32_t *monomial_degrees = (uint32_t *)selected_memory;
if (threadIdx.x < (1 << grouping_factor)) {
const Torus *lwe_array_group =
block_lwe_array_in + rev_lwe_iteration * grouping_factor;
uint32_t monomial_degree = calculates_monomial_degree<Torus, params>(
lwe_array_group, g, grouping_factor);

synchronize_threads_in_block();
// Multiply by the bsk element
polynomial_accumulate_monic_monomial_mul<Torus>(
accumulator, bsk_poly, monomial_degree, threadIdx.x, params::degree,
params::opt, false);
monomial_degrees[threadIdx.x] = calculates_monomial_degree<Torus, params>(
lwe_array_group, threadIdx.x, grouping_factor);
}

synchronize_threads_in_block();

// Move accumulator to local memory
double2 temp[params::opt / 2];
int tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
temp[i].x = __ll2double_rn((int64_t)accumulator[tid]);
temp[i].y =
__ll2double_rn((int64_t)accumulator[tid + params::degree / 2]);
temp[i].x /= (double)std::numeric_limits<Torus>::max();
temp[i].y /= (double)std::numeric_limits<Torus>::max();
tid += params::degree / params::opt;
// Accumulate the other terms
for (int g = 1; g < (1 << grouping_factor); g++) {

uint32_t monomial_degree = monomial_degrees[g];

const Torus *bsk_poly = bsk_poly_ini + g * offset;
// Multiply by the bsk element
polynomial_product_accumulate_by_monomial_nosync<Torus, params>(
reg_acc, bsk_poly, monomial_degree);
}
synchronize_threads_in_block(); // needed because we are going to reuse the
// shared memory for the fft

synchronize_threads_in_block();
// Move from local memory back to shared memory but as complex
tid = threadIdx.x;
int tid = threadIdx.x;
double2 *fft = (double2 *)selected_memory;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
fft[tid] = temp[i];
fft[tid] =
make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
(double)std::numeric_limits<Torus>::max(),
__ll2double_rn((int64_t)reg_acc[i + params::opt / 2]) /
(double)std::numeric_limits<Torus>::max());
tid += params::degree / params::opt;
}
synchronize_threads_in_block();

NSMFFT_direct<HalfDegree<params>>(fft);

// lwe iteration
Expand Down
7 changes: 7 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ __device__ void copy_polynomial(const T *__restrict__ source, T *dst) {
tid = tid + block_size;
}
}
template <typename T, int elems_per_thread, int block_size>
__device__ void copy_polynomial_in_regs(const T *__restrict__ source, T *dst) {
#pragma unroll
for (int i = 0; i < elems_per_thread; i++) {
dst[i] = source[threadIdx.x + i * block_size];
}
}

/*
* Receives num_poly concatenated polynomials of type T. For each:
Expand Down
25 changes: 25 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/polynomial_math.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,29 @@ __device__ void polynomial_accumulate_monic_monomial_mul(
}
}

template <typename T, class params>
__device__ void polynomial_product_accumulate_by_monomial_nosync(
T *result, const T *__restrict__ poly, uint32_t monomial_degree) {
// monomial_degree \in [0, 2 * params::degree)
int full_cycles_count = monomial_degree / params::degree;
int remainder_degrees = monomial_degree % params::degree;

// Every thread has a fixed position to track instead of "chasing" the
// position
#pragma unroll
for (int i = 0; i < params::opt; i++) {
int pos =
(threadIdx.x + i * (params::degree / params::opt) - monomial_degree) &
(params::degree - 1);

T element = poly[pos];
T x = SEL(element, -element, full_cycles_count % 2);
x = SEL(-x, x,
threadIdx.x + i * (params::degree / params::opt) >=
remainder_degrees);

result[i] += x;
}
}

#endif // CNCRT_POLYNOMIAL_MATH_H

0 comments on commit d651f68

Please sign in to comment.