Skip to content

Commit

Permalink
refactor(gpu): refactor sample extract and modulus switch to match CP…
Browse files Browse the repository at this point in the history
…U's version
  • Loading branch information
pdroalves authored and agnesLeroy committed Jul 25, 2024
1 parent 95d5036 commit 19dc0f0
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 76 deletions.
37 changes: 10 additions & 27 deletions backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,36 +39,19 @@ __device__ inline T round_to_closest_multiple(T x, uint32_t base_log,
}

template <typename T>
__device__ __forceinline__ void rescale_torus_element(T element, T &output,
uint32_t log_shift) {
output =
round((double)element / (double(std::numeric_limits<T>::max()) + 1.0) *
(double)log_shift);
}
__device__ __forceinline__ void modulus_switch(T input, T &output,
uint32_t log_modulus) {
constexpr uint32_t BITS = sizeof(T) * 8;

template <typename T>
__device__ __forceinline__ T rescale_torus_element(T element,
uint32_t log_shift) {
return round((double)element / (double(std::numeric_limits<T>::max()) + 1.0) *
(double)log_shift);
output = input + (((T)1) << (BITS - log_modulus - 1));
output >>= (BITS - log_modulus);
}

template <>
__device__ __forceinline__ void
rescale_torus_element<uint32_t>(uint32_t element, uint32_t &output,
uint32_t log_shift) {
output =
round(__uint2double_rn(element) /
(__uint2double_rn(std::numeric_limits<uint32_t>::max()) + 1.0) *
__uint2double_rn(log_shift));
template <typename T>
__device__ __forceinline__ T modulus_switch(T input, uint32_t log_modulus) {
T output;
modulus_switch(input, output, log_modulus);
return output;
}

template <>
__device__ __forceinline__ void
rescale_torus_element<uint64_t>(uint64_t element, uint64_t &output,
uint32_t log_shift) {
output = round(__ull2double_rn(element) /
(__ull2double_rn(std::numeric_limits<uint64_t>::max()) + 1.0) *
__uint2double_rn(log_shift));
}
#endif // CNCRT_TORUS_H
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ __global__ void device_programmable_bootstrap_amortized(

// Put "b", the body, in [0, 2N[
Torus b_hat = 0;
rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
2 * params::degree); // 2 * params::log2_degree + 1);
modulus_switch(block_lwe_array_in[lwe_dimension], b_hat,
params::log2_degree + 1);

divide_by_monomial_negacyclic_inplace<Torus, params::opt,
params::degree / params::opt>(
Expand All @@ -105,8 +105,8 @@ __global__ void device_programmable_bootstrap_amortized(

// Put "a" in [0, 2N[ instead of Zq
Torus a_hat = 0;
rescale_torus_element(block_lwe_array_in[iteration], a_hat,
2 * params::degree); // 2 * params::log2_degree + 1);
modulus_switch(block_lwe_array_in[iteration], a_hat,
params::log2_degree + 1);

// Perform ACC * (X^ä - 1)
multiply_by_monomial_negacyclic_and_sub_polynomial<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ __global__ void device_programmable_bootstrap_cg(

// Put "b" in [0, 2N[
Torus b_hat = 0;
rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
2 * params::degree);
modulus_switch(block_lwe_array_in[lwe_dimension], b_hat,
params::log2_degree + 1);

divide_by_monomial_negacyclic_inplace<Torus, params::opt,
params::degree / params::opt>(
Expand All @@ -106,8 +106,7 @@ __global__ void device_programmable_bootstrap_cg(

// Put "a" in [0, 2N[
Torus a_hat = 0;
rescale_torus_element(block_lwe_array_in[i], a_hat,
2 * params::degree); // 2 * params::log2_degree + 1);
modulus_switch(block_lwe_array_in[i], a_hat, params::log2_degree + 1);

// Perform ACC * (X^ä - 1)
multiply_by_monomial_negacyclic_and_sub_polynomial<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ __global__ void device_multi_bit_programmable_bootstrap_cg_accumulate(
if (lwe_offset == 0) {
// Put "b" in [0, 2N[
Torus b_hat = 0;
rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
2 * params::degree);
modulus_switch(block_lwe_array_in[lwe_dimension], b_hat,
params::log2_degree + 1);

divide_by_monomial_negacyclic_inplace<Torus, params::opt,
params::degree / params::opt>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ __global__ void device_programmable_bootstrap_step_one(
// First iteration
// Put "b" in [0, 2N[
Torus b_hat = 0;
rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
2 * params::degree);
modulus_switch(block_lwe_array_in[lwe_dimension], b_hat,
params::log2_degree + 1);
// The y-dimension is used to select the element of the GLWE this block will
// compute
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
Expand All @@ -93,8 +93,8 @@ __global__ void device_programmable_bootstrap_step_one(

// Put "a" in [0, 2N[
Torus a_hat = 0;
rescale_torus_element(block_lwe_array_in[lwe_iteration], a_hat,
2 * params::degree); // 2 * params::log2_degree + 1);
modulus_switch(block_lwe_array_in[lwe_iteration], a_hat,
params::log2_degree + 1); // 2 * params::log2_degree + 1);

synchronize_threads_in_block();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ __device__ Torus calculates_monomial_degree(const Torus *lwe_array_group,
x += selection_bit * lwe_array_group[i];
}

return rescale_torus_element(
x, 2 * params::degree); // 2 * params::log2_degree + 1);
return modulus_switch(x, params::log2_degree + 1);
}

template <typename Torus, class params, sharedMemDegree SMD>
Expand Down Expand Up @@ -204,8 +203,8 @@ __global__ void device_multi_bit_programmable_bootstrap_accumulate_step_one(
// Initializes the accumulator with the body of LWE
// Put "b" in [0, 2N[
Torus b_hat = 0;
rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
2 * params::degree);
modulus_switch(block_lwe_array_in[lwe_dimension], b_hat,
params::log2_degree + 1);

divide_by_monomial_negacyclic_inplace<Torus, params::opt,
params::degree / params::opt>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ __global__ void device_programmable_bootstrap_tbc(

// Put "b" in [0, 2N[
Torus b_hat = 0;
rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
2 * params::degree);
modulus_switch(block_lwe_array_in[lwe_dimension], b_hat,
params::log2_degree + 1);

divide_by_monomial_negacyclic_inplace<Torus, params::opt,
params::degree / params::opt>(
Expand All @@ -109,8 +109,8 @@ __global__ void device_programmable_bootstrap_tbc(

// Put "a" in [0, 2N[
Torus a_hat = 0;
rescale_torus_element(block_lwe_array_in[i], a_hat,
2 * params::degree); // 2 * params::log2_degree + 1);
modulus_switch(block_lwe_array_in[i], a_hat,
params::log2_degree + 1); // 2 * params::log2_degree + 1);

// Perform ACC * (X^ä - 1)
multiply_by_monomial_negacyclic_and_sub_polynomial<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ __global__ void device_multi_bit_programmable_bootstrap_tbc_accumulate(
if (lwe_offset == 0) {
// Put "b" in [0, 2N[
Torus b_hat = 0;
rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
2 * params::degree);
modulus_switch(block_lwe_array_in[lwe_dimension], b_hat,
params::log2_degree + 1);

divide_by_monomial_negacyclic_inplace<Torus, params::opt,
params::degree / params::opt>(
Expand Down
38 changes: 13 additions & 25 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -189,48 +189,42 @@ __device__ void add_to_torus(double2 *m_values, Torus *result,
}
}

// Extracts the body of a GLWE.
// k is the offset to find the body element / polynomial in the lwe_array_out /
// accumulator
// Extracts the body of the nth-LWE in a GLWE.
template <typename Torus, class params>
__device__ void sample_extract_body(Torus *lwe_array_out, Torus *accumulator,
uint32_t k) {
uint32_t glwe_dimension, uint32_t nth = 0) {
// Set first coefficient of the accumulator as the body of the LWE sample
lwe_array_out[k * params::degree] = accumulator[k * params::degree];
lwe_array_out[glwe_dimension * params::degree] =
accumulator[glwe_dimension * params::degree + nth];
}

// Extracts the mask from num_poly polynomials individually
// Extracts the mask from the nth-LWE in a GLWE.
template <typename Torus, class params>
__device__ void sample_extract_mask(Torus *lwe_array_out, Torus *accumulator,
uint32_t num_poly = 1) {
uint32_t num_poly = 1, uint32_t nth = 0) {
for (int z = 0; z < num_poly; z++) {
Torus *lwe_array_out_slice =
(Torus *)lwe_array_out + (ptrdiff_t)(z * params::degree);
Torus *accumulator_slice =
(Torus *)accumulator + (ptrdiff_t)(z * params::degree);

// Set ACC = -ACC
int tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
accumulator_slice[tid] = -accumulator_slice[tid];
tid = tid + params::degree / params::opt;
}
synchronize_threads_in_block();

// Reverse the accumulator
tid = threadIdx.x;
int tid = threadIdx.x;
Torus result[params::opt];
#pragma unroll
for (int i = 0; i < params::opt; i++) {
result[i] = accumulator_slice[params::degree - tid - 1];
tid = tid + params::degree / params::opt;
}
synchronize_threads_in_block();

// Set ACC = -ACC
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
accumulator_slice[tid] = result[i];
accumulator_slice[tid] =
SEL(-result[i], result[i], tid >= params::degree - nth);
tid = tid + params::degree / params::opt;
}
synchronize_threads_in_block();
Expand All @@ -244,23 +238,17 @@ __device__ void sample_extract_mask(Torus *lwe_array_out, Torus *accumulator,
// result[i] = -accumulator_slice[tid - 1 + params::degree];
// else
// result[i] = accumulator_slice[tid - 1];
int x = tid - 1 + SEL(0, params::degree, tid < 1);
int x = tid - 1 + SEL(0, params::degree - nth, tid < 1);
result[i] = SEL(1, -1, tid < 1) * accumulator_slice[x];
tid += params::degree / params::opt;
}
synchronize_threads_in_block();
tid = threadIdx.x;
for (int i = 0; i < params::opt; i++) {
accumulator_slice[tid] = result[i];
tid += params::degree / params::opt;
}
synchronize_threads_in_block();

// Copy to the mask of the LWE sample
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
lwe_array_out_slice[tid] = accumulator_slice[tid];
lwe_array_out_slice[tid] = result[i];
tid = tid + params::degree / params::opt;
}
}
Expand Down

0 comments on commit 19dc0f0

Please sign in to comment.