Skip to content

Commit

Permalink
fix(gpu): general fixes and improvements to PBS
Browse files Browse the repository at this point in the history
- update pbs test parameters to match tfhe-rs' integer tests
- refactor mul_ggsw_glwe to make it easier to read
- fix the way we accumulate the external product result on multi-bit PBS
  • Loading branch information
pdroalves authored and agnesLeroy committed Nov 13, 2024
1 parent eac3002 commit b041608
Show file tree
Hide file tree
Showing 15 changed files with 413 additions and 568 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
uint32_t lwe_chunk_size;
double2 *keybundle_fft;
Torus *global_accumulator;
double2 *global_accumulator_fft;
double2 *global_join_buffer;

PBS_VARIANT pbs_variant;

Expand Down Expand Up @@ -225,10 +225,12 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
num_blocks_keybundle * (polynomial_size / 2) * sizeof(double2),
stream, gpu_index);
global_accumulator = (Torus *)cuda_malloc_async(
num_blocks_acc_step_one * polynomial_size * sizeof(Torus), stream,
gpu_index);
global_accumulator_fft = (double2 *)cuda_malloc_async(
num_blocks_acc_step_one * (polynomial_size / 2) * sizeof(double2),
input_lwe_ciphertext_count * (glwe_dimension + 1) * polynomial_size *
sizeof(Torus),
stream, gpu_index);
global_join_buffer = (double2 *)cuda_malloc_async(
level_count * (glwe_dimension + 1) * input_lwe_ciphertext_count *
(polynomial_size / 2) * sizeof(double2),
stream, gpu_index);
}
}
Expand Down Expand Up @@ -260,7 +262,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {

cuda_drop_async(keybundle_fft, stream, gpu_index);
cuda_drop_async(global_accumulator, stream, gpu_index);
cuda_drop_async(global_accumulator_fft, stream, gpu_index);
cuda_drop_async(global_join_buffer, stream, gpu_index);
}
};

Expand Down
10 changes: 5 additions & 5 deletions backends/tfhe-cuda-backend/cuda/include/pbs/pbs_utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::CLASSICAL> {
int8_t *d_mem;

Torus *global_accumulator;
double2 *global_accumulator_fft;
double2 *global_join_buffer;

PBS_VARIANT pbs_variant;

Expand Down Expand Up @@ -114,7 +114,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::CLASSICAL> {
// Otherwise, both kernels run all in shared memory
d_mem = (int8_t *)cuda_malloc_async(device_mem, stream, gpu_index);

global_accumulator_fft = (double2 *)cuda_malloc_async(
global_join_buffer = (double2 *)cuda_malloc_async(
(glwe_dimension + 1) * level_count * input_lwe_ciphertext_count *
(polynomial_size / 2) * sizeof(double2),
stream, gpu_index);
Expand Down Expand Up @@ -147,7 +147,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::CLASSICAL> {
// Otherwise, both kernels run all in shared memory
d_mem = (int8_t *)cuda_malloc_async(device_mem, stream, gpu_index);

global_accumulator_fft = (double2 *)cuda_malloc_async(
global_join_buffer = (double2 *)cuda_malloc_async(
(glwe_dimension + 1) * level_count * input_lwe_ciphertext_count *
polynomial_size / 2 * sizeof(double2),
stream, gpu_index);
Expand Down Expand Up @@ -194,7 +194,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::CLASSICAL> {
// Otherwise, both kernels run all in shared memory
d_mem = (int8_t *)cuda_malloc_async(device_mem, stream, gpu_index);

global_accumulator_fft = (double2 *)cuda_malloc_async(
global_join_buffer = (double2 *)cuda_malloc_async(
(glwe_dimension + 1) * level_count * input_lwe_ciphertext_count *
polynomial_size / 2 * sizeof(double2),
stream, gpu_index);
Expand All @@ -208,7 +208,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::CLASSICAL> {

void release(cudaStream_t stream, uint32_t gpu_index) {
cuda_drop_async(d_mem, stream, gpu_index);
cuda_drop_async(global_accumulator_fft, stream, gpu_index);
cuda_drop_async(global_join_buffer, stream, gpu_index);

if (pbs_variant == DEFAULT)
cuda_drop_async(global_accumulator, stream, gpu_index);
Expand Down
9 changes: 3 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/crypto/gadget.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef CNCRT_CRYPTO_CUH
#define CNCRT_CRPYTO_CUH

#include "crypto/torus.cuh"
#include "device.h"
#include <cstdint>

Expand All @@ -21,7 +22,6 @@ private:
uint32_t base_log;
uint32_t mask;
uint32_t num_poly;
int current_level;
T mask_mod_b;
T *state;

Expand All @@ -32,7 +32,6 @@ public:
state(state) {

mask_mod_b = (1ll << base_log) - 1ll;
current_level = level_count;
int tid = threadIdx.x;
for (int i = 0; i < num_poly * params::opt; i++) {
state[tid] >>= (sizeof(T) * 8 - base_log * level_count);
Expand All @@ -52,8 +51,6 @@ public:
// Decomposes a single polynomial
__device__ void decompose_and_compress_next_polynomial(double2 *result,
int j) {
if (j == 0)
current_level -= 1;

int tid = threadIdx.x;
auto state_slice = state + j * params::degree;
Expand All @@ -72,8 +69,8 @@ public:
res_re -= carry_re << base_log;
res_im -= carry_im << base_log;

result[tid].x = (int32_t)res_re;
result[tid].y = (int32_t)res_im;
typecast_torus_to_double(res_re, result[tid].x);
typecast_torus_to_double(res_im, result[tid].y);

tid += params::degree / params::opt;
}
Expand Down
16 changes: 16 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef CNCRT_TORUS_CUH
#define CNCRT_TORUS_CUH

#include "device.h"
#include "polynomial/parameters.cuh"
#include "types/int128.cuh"
#include "utils/kernel_dimensions.cuh"
Expand Down Expand Up @@ -43,6 +44,21 @@ __device__ inline void typecast_double_round_to_torus(double x, T &r) {
typecast_double_to_torus(round(frac), r);
}

template <typename T>
__device__ inline void typecast_torus_to_double(T x, double &r);

template <>
__device__ inline void typecast_torus_to_double<uint32_t>(uint32_t x,
double &r) {
r = __int2double_rn(x);
}

template <>
__device__ inline void typecast_torus_to_double<uint64_t>(uint64_t x,
double &r) {
r = __ll2double_rn(x);
}

template <typename T>
__device__ inline T round_to_closest_multiple(T x, uint32_t base_log,
uint32_t level_count) {
Expand Down
95 changes: 28 additions & 67 deletions backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "fft/bnsmfft.cuh"
#include "helper_multi_gpu.h"
#include "pbs/programmable_bootstrap_multibit.h"
#include "polynomial/polynomial_math.cuh"

using namespace cooperative_groups;
namespace cg = cooperative_groups;
Expand All @@ -20,100 +21,60 @@ get_join_buffer_element(int level_id, int glwe_id, G &group,
double2 *global_memory_buffer, uint32_t polynomial_size,
uint32_t glwe_dimension, bool support_dsm);

template <typename Torus, typename G, class params>
/** Perform the matrix multiplication between the GGSW and the GLWE,
* each block operating on a single level for mask and body.
* Both operands should be at fourier domain
*
* This function assumes:
* - Thread blocks at dimension x relates to the decomposition level.
* - Thread blocks at dimension y relates to the glwe dimension.
* - polynomial_size / params::opt threads are available per block
*/
template <typename G, class params>
__device__ void
mul_ggsw_glwe(Torus *accumulator, double2 *fft, double2 *join_buffer,
const double2 *__restrict__ bootstrapping_key,
int polynomial_size, uint32_t glwe_dimension, int level_count,
int iteration, G &group, bool support_dsm = false) {

// Switch to the FFT space
NSMFFT_direct<HalfDegree<params>>(fft);
synchronize_threads_in_block();

// Get the pieces of the bootstrapping key that will be needed for the
// external product; blockIdx.x is the ID of the block that's executing
// this function, so we end up getting the lines of the bootstrapping key
// needed to perform the external product in this block (corresponding to
// the same decomposition level)
auto bsk_slice = get_ith_mask_kth_block(
bootstrapping_key, iteration, blockIdx.y, blockIdx.x, polynomial_size,
glwe_dimension, level_count);

// Perform the matrix multiplication between the GGSW and the GLWE,
// each block operating on a single level for mask and body
mul_ggsw_glwe_in_fourier_domain(double2 *fft, double2 *join_buffer,
const double2 *__restrict__ bootstrapping_key,
int iteration, G &group,
bool support_dsm = false) {
const uint32_t polynomial_size = params::degree;
const uint32_t glwe_dimension = gridDim.y - 1;
const uint32_t level_count = gridDim.x;

// The first product is used to initialize level_join_buffer
auto bsk_poly = bsk_slice + blockIdx.y * params::degree / 2;
auto this_block_rank = get_this_block_rank<G>(group, support_dsm);
auto buffer_slice =
get_join_buffer_element<G>(blockIdx.x, blockIdx.y, group, join_buffer,
polynomial_size, glwe_dimension, support_dsm);

int tid = threadIdx.x;
for (int i = 0; i < params::opt / 2; i++) {
buffer_slice[tid] = fft[tid] * bsk_poly[tid];
tid += params::degree / params::opt;
}

group.sync();

// Continues multiplying fft by every polynomial in that particular bsk level
// Each y-block accumulates in a different polynomial at each iteration
for (int j = 1; j < (glwe_dimension + 1); j++) {
auto bsk_slice = get_ith_mask_kth_block(
bootstrapping_key, iteration, blockIdx.y, blockIdx.x, polynomial_size,
glwe_dimension, level_count);
for (int j = 0; j < glwe_dimension + 1; j++) {
int idx = (j + this_block_rank) % (glwe_dimension + 1);

auto bsk_poly = bsk_slice + idx * params::degree / 2;
auto bsk_poly = bsk_slice + idx * polynomial_size / 2;
auto buffer_slice = get_join_buffer_element<G>(blockIdx.x, idx, group,
join_buffer, polynomial_size,
glwe_dimension, support_dsm);

int tid = threadIdx.x;
for (int i = 0; i < params::opt / 2; i++) {
buffer_slice[tid] += fft[tid] * bsk_poly[tid];
tid += params::degree / params::opt;
}
polynomial_product_accumulate_in_fourier_domain<params, double2>(
buffer_slice, fft, bsk_poly, j == 0);
group.sync();
}

// -----------------------------------------------------------------
// All blocks are synchronized here; after this sync, level_join_buffer has
// the values needed from every other block

auto src_acc =
get_join_buffer_element<G>(0, blockIdx.y, group, join_buffer,
polynomial_size, glwe_dimension, support_dsm);

// copy first product into fft buffer
tid = threadIdx.x;
for (int i = 0; i < params::opt / 2; i++) {
fft[tid] = src_acc[tid];
tid += params::degree / params::opt;
}
synchronize_threads_in_block();

// accumulate rest of the products into fft buffer
for (int l = 1; l < gridDim.x; l++) {
for (int l = 0; l < level_count; l++) {
auto cur_src_acc = get_join_buffer_element<G>(l, blockIdx.y, group,
join_buffer, polynomial_size,
glwe_dimension, support_dsm);
tid = threadIdx.x;
for (int i = 0; i < params::opt / 2; i++) {
fft[tid] += cur_src_acc[tid];
tid += params::degree / params::opt;
}
}

synchronize_threads_in_block();
polynomial_accumulate_in_fourier_domain<params>(fft, cur_src_acc, l == 0);
}

// Perform the inverse FFT on the result of the GGSW x GLWE and add to the
// accumulator
NSMFFT_inverse<HalfDegree<params>>(fft);
synchronize_threads_in_block();

add_to_torus<Torus, params>(fft, accumulator);

__syncthreads();
}

template <typename Torus>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,59 +129,59 @@ __global__ void device_programmable_bootstrap_cg(
GadgetMatrix<Torus, params> gadget_acc(base_log, level_count,
accumulator_rotated);
gadget_acc.decompose_and_compress_level(accumulator_fft, blockIdx.x);

// We are using the same memory space for accumulator_fft and
// accumulator_rotated, so we need to synchronize here to make sure they
// don't modify the same memory space at the same time
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
synchronize_threads_in_block();

// Perform G^-1(ACC) * GGSW -> GLWE
mul_ggsw_glwe<Torus, grid_group, params>(
accumulator, accumulator_fft, block_join_buffer, bootstrapping_key,
polynomial_size, glwe_dimension, level_count, i, grid);

mul_ggsw_glwe_in_fourier_domain<grid_group, params>(
accumulator_fft, block_join_buffer, bootstrapping_key, i, grid);
NSMFFT_inverse<HalfDegree<params>>(accumulator_fft);
synchronize_threads_in_block();

add_to_torus<Torus, params>(accumulator_fft, accumulator);
}

auto block_lwe_array_out =
&lwe_array_out[lwe_output_indexes[blockIdx.z] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

if (blockIdx.x == 0 && blockIdx.y < glwe_dimension) {
// Perform a sample extract. At this point, all blocks have the result, but
// we do the computation at block 0 to avoid waiting for extra blocks, in
// case they're not synchronized
sample_extract_mask<Torus, params>(block_lwe_array_out, accumulator);
if (lut_count > 1) {
for (int i = 1; i < lut_count; i++) {
auto next_lwe_array_out =
lwe_array_out +
(i * gridDim.z * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
&next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

sample_extract_mask<Torus, params>(next_block_lwe_array_out,
accumulator, 1, i * lut_stride);
if (blockIdx.x == 0) {
if (blockIdx.y < glwe_dimension) {
// Perform a sample extract. At this point, all blocks have the result,
// but we do the computation at block 0 to avoid waiting for extra blocks,
// in case they're not synchronized
sample_extract_mask<Torus, params>(block_lwe_array_out, accumulator);
if (lut_count > 1) {
for (int i = 1; i < lut_count; i++) {
auto next_lwe_array_out =
lwe_array_out +
(i * gridDim.z * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
&next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

sample_extract_mask<Torus, params>(next_block_lwe_array_out,
accumulator, 1, i * lut_stride);
}
}
}
} else if (blockIdx.x == 0 && blockIdx.y == glwe_dimension) {
sample_extract_body<Torus, params>(block_lwe_array_out, accumulator, 0);
if (lut_count > 1) {
for (int i = 1; i < lut_count; i++) {

auto next_lwe_array_out =
lwe_array_out +
(i * gridDim.z * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
&next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

sample_extract_body<Torus, params>(next_block_lwe_array_out,
accumulator, 0, i * lut_stride);
} else if (blockIdx.y == glwe_dimension) {
sample_extract_body<Torus, params>(block_lwe_array_out, accumulator, 0);
if (lut_count > 1) {
for (int i = 1; i < lut_count; i++) {

auto next_lwe_array_out =
lwe_array_out +
(i * gridDim.z * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
&next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

sample_extract_body<Torus, params>(next_block_lwe_array_out,
accumulator, 0, i * lut_stride);
}
}
}
}
Expand Down Expand Up @@ -254,7 +254,7 @@ __host__ void host_programmable_bootstrap_cg(
uint64_t partial_dm = full_dm - partial_sm;

int8_t *d_mem = buffer->d_mem;
double2 *buffer_fft = buffer->global_accumulator_fft;
double2 *buffer_fft = buffer->global_join_buffer;

int thds = polynomial_size / params::opt;
dim3 grid(level_count, glwe_dimension + 1, input_lwe_ciphertext_count);
Expand Down
Loading

0 comments on commit b041608

Please sign in to comment.