Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
pdroalves committed Mar 15, 2024
1 parent 991ec0b commit 8238126
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 25 deletions.
38 changes: 37 additions & 1 deletion backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::CLASSICAL> {
polynomial_size * sizeof(Torus),
stream);
} break;
case PBS_VARIANT::TBC:
case PBS_VARIANT::CG: {
uint64_t full_sm =
get_buffer_size_full_sm_programmable_bootstrap_cg<Torus>(
Expand Down Expand Up @@ -239,6 +238,43 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::CLASSICAL> {
polynomial_size / 2 * sizeof(double2),
stream);
} break;
case PBS_VARIANT::TBC:{

bool supports_dsm =
supports_distributed_shared_memory_on_classic_programmable_bootstrap<
Torus>(polynomial_size, max_shared_memory);

uint64_t full_sm =
get_buffer_size_full_sm_programmable_bootstrap_tbc<Torus>(
polynomial_size);
uint64_t partial_sm =
get_buffer_size_partial_sm_programmable_bootstrap_tbc<Torus>(
polynomial_size);
uint64_t minimum_sm_tbc = 0;
if (supports_dsm)
minimum_sm_tbc = get_buffer_size_sm_dsm_plus_tbc_classic_programmable_bootstrap<Torus>(
polynomial_size);

uint64_t partial_dm = full_sm - partial_sm;
uint64_t full_dm = full_sm;
uint64_t device_mem = 0;

if (max_shared_memory < partial_sm + minimum_sm_tbc) {
device_mem = full_dm * input_lwe_ciphertext_count * level_count *
(glwe_dimension + 1);
} else if (max_shared_memory < full_sm + minimum_sm_tbc) {
device_mem = partial_dm * input_lwe_ciphertext_count * level_count *
(glwe_dimension + 1);
}

// Otherwise, both kernels run entirely in shared memory
d_mem = (int8_t *)cuda_malloc_async(device_mem, stream);

global_accumulator_fft = (double2 *)cuda_malloc_async(
(glwe_dimension + 1) * level_count * input_lwe_ciphertext_count *
polynomial_size / 2 * sizeof(double2),
stream);
} break;
default:
PANIC("Cuda error (PBS): unsupported implementation variant.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,12 @@ template<> __device__ double2 *get_join_buffer_element(int i, cluster_group &clu
bool support_dsm,
double2 *global_memory_buffer, uint32_t
polynomial_size) {
#if CUDA_ARCH < 900
double2 *buffer_slice =
global_memory_buffer + blockIdx.y * polynomial_size / 2;
#else
double2 *buffer_slice;
if (support_dsm) {
extern __shared__ double2 smem[];
buffer_slice = cluster.map_shared_rank(smem, i);
} else {
buffer_slice = global_memory_buffer + i * polynomial_size / 2;
}
#endif
return buffer_slice;
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ __device__ void mul_ggsw_glwe(Torus *accumulator, double2 *fft,
// Continues multiplying fft by every polynomial in that particular bsk level
// Each y-block accumulates in a different polynomial at each iteration
for (int j = 1; j < (glwe_dimension + 1); j++) {
int idx = (j + this_block_rank) % (glwe_dimension + 1);
int idx = (j + blockIdx.y) % (glwe_dimension + 1);

auto bsk_poly = bsk_slice + idx * params::degree / 2;
auto buffer_slice = get_join_buffer_element<G>(
idx, group, support_dsm, level_join_buffer, polynomial_size);
idx * level_count, group, support_dsm, join_buffer, polynomial_size);

int tid = threadIdx.x;
for (int i = 0; i < params::opt / 2; i++) {
Expand All @@ -80,7 +80,7 @@ __device__ void mul_ggsw_glwe(Torus *accumulator, double2 *fft,
// All blocks are synchronized here; after this sync, join_buffer holds
// the values needed from every other block

auto src_acc = get_join_buffer_element<G>(blockIdx.y, group,
auto src_acc = get_join_buffer_element<G>(this_block_rank, group,
support_dsm, join_buffer, polynomial_size);

// copy first product into fft buffer
Expand All @@ -93,7 +93,8 @@ __device__ void mul_ggsw_glwe(Torus *accumulator, double2 *fft,

// accumulate rest of the products into fft buffer
for (int l = 1; l < gridDim.x; l++) {
auto cur_src_acc = &src_acc[l * (glwe_dimension + 1) * params::degree / 2];
auto cur_src_acc = get_join_buffer_element<G>(
blockIdx.y * level_count + l, group, support_dsm, join_buffer, polynomial_size);
tid = threadIdx.x;
for (int i = 0; i < params::opt / 2; i++) {
fft[tid] += cur_src_acc[tid];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ __host__ void scratch_programmable_bootstrap_cg(

*buffer = new pbs_buffer<Torus, CLASSICAL>(
stream, glwe_dimension, polynomial_size, level_count,
input_lwe_ciphertext_count, PBS_VARIANT::CG, allocate_gpu_memory);
input_lwe_ciphertext_count, PBS_VARIANT::TBC, allocate_gpu_memory);
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,17 @@ device_programmable_bootstrap_tbc(
selected_memory = &device_mem[block_index * device_memory_size_per_block];
}

// We always compute the pointer with most restrictive alignment to avoid
// alignment issues
double2 *accumulator_fft = (double2 *)selected_memory;
Torus *accumulator =
(Torus *)accumulator_fft +
(ptrdiff_t)(sizeof(double2) * polynomial_size / 2 / sizeof(Torus));
Torus *accumulator = (Torus*)selected_memory;
Torus *accumulator_rotated =
(Torus *)accumulator + (ptrdiff_t)polynomial_size;
double2 *accumulator_fft = (double2*)accumulator_rotated +
(ptrdiff_t)(sizeof(Torus) * polynomial_size / sizeof(double2));


if constexpr (SMD == PARTIALSM){
accumulator_fft = (double2 *)sharedmem;
if(support_dsm)
accumulator_fft += sizeof(double2) * (polynomial_size/2);
accumulator_fft += (ptrdiff_t)(polynomial_size/2);
}

// The third dimension of the block is used to determine on which ciphertext
Expand Down Expand Up @@ -175,7 +173,6 @@ __host__ void scratch_programmable_bootstrap_tbc(
uint64_t partial_sm =
get_buffer_size_partial_sm_programmable_bootstrap_tbc<Torus>(
polynomial_size);

uint64_t minimum_sm_tbc = 0;
if (supports_dsm)
minimum_sm_tbc = get_buffer_size_sm_dsm_plus_tbc_classic_programmable_bootstrap<Torus>(
Expand Down Expand Up @@ -263,7 +260,7 @@ __host__ void host_programmable_bootstrap_tbc(

cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeClusterDimension;
attribute[0].val.clusterDim.x = 1; // Cluster size in X-dimension
attribute[0].val.clusterDim.x = level_count; // Cluster size in X-dimension
attribute[0].val.clusterDim.y = (glwe_dimension + 1);
attribute[0].val.clusterDim.z = 1;
config.attrs = attribute;
Expand All @@ -279,10 +276,10 @@ __host__ void host_programmable_bootstrap_tbc(
lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer_fft,lwe_dimension,
polynomial_size, base_log, level_count, d_mem, full_dm, supports_dsm));
} else if (max_shared_memory < full_sm + minimum_sm_tbc) {
config.dynamicSmemBytes = partial_dm +minimum_sm_tbc;
config.dynamicSmemBytes = partial_sm +minimum_sm_tbc;

check_cuda_error(cudaLaunchKernelEx(
&config, device_programmable_bootstrap_tbc<Torus, params, PARTIALSM>,
&config,device_programmable_bootstrap_tbc<Torus, params, PARTIALSM>,
lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer_fft,lwe_dimension,
polynomial_size, base_log, level_count, d_mem, partial_dm, supports_dsm));
Expand Down Expand Up @@ -352,7 +349,7 @@ supports_distributed_shared_memory_on_classic_programmable_bootstrap(
get_buffer_size_sm_dsm_plus_tbc_classic_programmable_bootstrap<Torus>(
polynomial_size);

if (max_shared_memory <= minimum_sm) {
if (max_shared_memory < minimum_sm) {
// If a single polynomial does not fit in a block's shared memory, we
// cannot use TBC
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ TEST_P(ClassicalProgrammableBootstrapTestPrimitives_u64, bootstrap) {
uint64_t decrypted = 0;
core_crypto_lwe_decrypt(&decrypted, result, lwe_sk_out,
glwe_dimension * polynomial_size);
EXPECT_NE(decrypted, plaintext);
ASSERT_NE(decrypted, plaintext);
// let err = (decrypted >= plaintext) ? decrypted - plaintext :
// plaintext
// - decrypted;
Expand All @@ -216,7 +216,7 @@ TEST_P(ClassicalProgrammableBootstrapTestPrimitives_u64, bootstrap) {
// Compute the rounding bit
uint64_t rounding = (decrypted & rounding_bit) << 1;
uint64_t decoded = (decrypted + rounding) / delta;
EXPECT_EQ(decoded, plaintext / delta);
ASSERT_EQ(decoded, plaintext / delta);
}
}
}
Expand Down

0 comments on commit 8238126

Please sign in to comment.