chore(gpu): refactor tree_add_chunks
agnesLeroy committed Aug 6, 2024
1 parent a9bb6ea · commit 62d2e30
Showing 3 changed files with 29 additions and 46 deletions.
44 changes: 24 additions & 20 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
@@ -22,18 +22,19 @@ template <typename Torus>
__global__ void radix_blocks_rotate_right(Torus *dst, Torus *src,
                                          uint32_t value, uint32_t blocks_count,
                                          uint32_t lwe_size) {
-  value %= blocks_count;
-
  size_t tid = threadIdx.x;
-  size_t src_block_id = blockIdx.x;
-  size_t dst_block_id = (src_block_id + value) % blocks_count;
-  size_t stride = blockDim.x;
+  if (tid < lwe_size) {
+    value %= blocks_count;
+    size_t src_block_id = blockIdx.x;
+    size_t dst_block_id = (src_block_id + value) % blocks_count;
+    size_t stride = blockDim.x;

-  auto cur_src_block = &src[src_block_id * lwe_size];
-  auto cur_dst_block = &dst[dst_block_id * lwe_size];
+    auto cur_src_block = &src[src_block_id * lwe_size];
+    auto cur_dst_block = &dst[dst_block_id * lwe_size];

-  for (size_t i = tid; i < lwe_size; i += stride) {
-    cur_dst_block[i] = cur_src_block[i];
+    for (size_t i = tid; i < lwe_size; i += stride) {
+      cur_dst_block[i] = cur_src_block[i];
+    }
  }
}
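radix_blocks_rotate_right keeps the same block mapping after this change; the body is now simply guarded by tid < lwe_size. Destination block (j + value) % blocks_count receives source block j, and each block is lwe_size coefficients long. A host-side reference of that mapping, given purely as an illustration (it is not part of the commit):

#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side reference of the rotation: destination block (j + value) % blocks_count
// receives source block j; each block holds lwe_size Torus coefficients.
template <typename Torus>
void reference_rotate_right(std::vector<Torus> &dst, const std::vector<Torus> &src,
                            uint32_t value, uint32_t blocks_count,
                            uint32_t lwe_size) {
  value %= blocks_count;
  for (std::size_t j = 0; j < blocks_count; j++) {
    std::size_t dst_block = (j + value) % blocks_count;
    for (std::size_t i = 0; i < lwe_size; i++)
      dst[dst_block * lwe_size + i] = src[j * lwe_size + i];
  }
}

On the GPU the same copy is split so that blockIdx.x picks the source block and the threads of that block stride over its lwe_size coefficients.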

@@ -44,25 +44,28 @@ template <typename Torus>
__global__ void radix_blocks_rotate_left(Torus *dst, Torus *src, uint32_t value,
                                         uint32_t blocks_count,
                                         uint32_t lwe_size) {
-  value %= blocks_count;
-  size_t src_block_id = blockIdx.x;
-
  size_t tid = threadIdx.x;
-  size_t dst_block_id = (src_block_id >= value)
-                            ? src_block_id - value
-                            : src_block_id - value + blocks_count;
-  size_t stride = blockDim.x;
+  if (tid < lwe_size) {
+    value %= blocks_count;
+    size_t src_block_id = blockIdx.x;
+
-  auto cur_src_block = &src[src_block_id * lwe_size];
-  auto cur_dst_block = &dst[dst_block_id * lwe_size];
+    size_t dst_block_id = (src_block_id >= value)
+                              ? src_block_id - value
+                              : src_block_id - value + blocks_count;
+    size_t stride = blockDim.x;

-  for (size_t i = tid; i < lwe_size; i += stride) {
-    cur_dst_block[i] = cur_src_block[i];
+    auto cur_src_block = &src[src_block_id * lwe_size];
+    auto cur_dst_block = &dst[dst_block_id * lwe_size];
+
+    for (size_t i = tid; i < lwe_size; i += stride) {
+      cur_dst_block[i] = cur_src_block[i];
+    }
  }
}

// rotate radix ciphertext right with specific value
// calculation is not inplace, so `dst` and `src` must not be the same
// one block is responsible to process single lwe ciphertext
template <typename Torus>
__host__ void
host_radix_blocks_rotate_right(cudaStream_t *streams, uint32_t *gpu_indexes,
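Per the comment above, one CUDA block processes a single LWE ciphertext block, so rotating a radix ciphertext made of blocks_count blocks takes a grid of blocks_count thread blocks whose threads stride over the lwe_size coefficients of their block. A hypothetical wrapper around the kernel shown earlier, for illustration only (the helper name, the 1024-thread cap and the explicit stream argument are assumptions, not code from this repository):

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical launcher: one thread block per LWE ciphertext block of the radix,
// each thread block striding over that block's lwe_size coefficients.
template <typename Torus>
void launch_radix_blocks_rotate_right(cudaStream_t stream, Torus *dst, Torus *src,
                                      uint32_t value, uint32_t blocks_count,
                                      uint32_t lwe_size) {
  uint32_t threads = std::min<uint32_t>(lwe_size, 1024u);
  radix_blocks_rotate_right<Torus>
      <<<blocks_count, threads, 0, stream>>>(dst, src, value, blocks_count, lwe_size);
}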
26 changes: 5 additions & 21 deletions backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh
@@ -93,26 +93,19 @@ all_shifted_lhs_rhs(Torus *radix_lwe_left, Torus *lsb_ciphertext,
  }
}

-template <typename Torus, sharedMemDegree SMD>
+template <typename Torus>
__global__ void tree_add_chunks(Torus *result_blocks, Torus *input_blocks,
                                uint32_t chunk_size, uint32_t block_size,
                                uint32_t num_blocks) {

-  extern __shared__ int8_t sharedmem[];
-
-  Torus *result = (Torus *)sharedmem;
-
  size_t stride = blockDim.x;
  size_t chunk_id = blockIdx.x;
  size_t chunk_elem_size = chunk_size * num_blocks * block_size;
  size_t radix_elem_size = num_blocks * block_size;
  auto src_chunk = &input_blocks[chunk_id * chunk_elem_size];
  auto dst_radix = &result_blocks[chunk_id * radix_elem_size];
  size_t block_stride = blockIdx.y * block_size;
-  auto dst_block = &dst_radix[block_stride];
-
-  if constexpr (SMD == NOSM)
-    result = dst_block;
+  auto result = &dst_radix[block_stride];

  // init shared mem with first radix of chunk
  size_t tid = threadIdx.x;
@@ -127,11 +120,6 @@ __global__ void tree_add_chunks(Torus *result_blocks, Torus *input_blocks,
      result[i] += cur_src_radix[block_stride + i];
    }
  }
-
-  // put result from shared mem to global mem
-  if constexpr (SMD == FULLSM)
-    for (int i = tid; i < block_size; i += stride)
-      dst_block[i] = result[i];
}

template <typename Torus, class params>
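With the shared-memory path gone, tree_add_chunks accumulates directly in global memory: blockIdx.x selects one chunk of chunk_size radix ciphertexts, blockIdx.y selects one block of block_size coefficients inside that chunk's output radix, and the threads add the matching block of every radix in the chunk into it. A host-side sketch of the same accumulation, written only to make the indexing explicit (num_chunks stands in for the grid's x dimension and is an assumption of this sketch, not a parameter of the kernel):

#include <cstddef>
#include <cstdint>

// Reference accumulation: chunk c of `chunk_size` radix ciphertexts, each of
// num_blocks * block_size Torus coefficients, is summed element-wise into
// radix c of the output.
template <typename Torus>
void reference_tree_add_chunks(Torus *result_blocks, const Torus *input_blocks,
                               uint32_t chunk_size, uint32_t block_size,
                               uint32_t num_blocks, uint32_t num_chunks) {
  std::size_t radix_elem_size = std::size_t(num_blocks) * block_size;
  std::size_t chunk_elem_size = std::size_t(chunk_size) * radix_elem_size;
  for (std::size_t c = 0; c < num_chunks; c++) {
    const Torus *src_chunk = &input_blocks[c * chunk_elem_size];
    Torus *dst_radix = &result_blocks[c * radix_elem_size];
    for (std::size_t i = 0; i < radix_elem_size; i++) {
      Torus acc = src_chunk[i]; // first radix of the chunk
      for (std::size_t r = 1; r < chunk_size; r++)
        acc += src_chunk[r * radix_elem_size + i];
      dst_radix[i] = acc;
    }
  }
}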
@@ -281,12 +269,8 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
    size_t sm_size = big_lwe_size * sizeof(Torus);

    cudaSetDevice(gpu_indexes[0]);
-    if (sm_size < max_shared_memory)
-      tree_add_chunks<Torus, FULLSM><<<add_grid, 512, sm_size, streams[0]>>>(
-          new_blocks, old_blocks, min(r, chunk_size), big_lwe_size, num_blocks);
-    else
-      tree_add_chunks<Torus, NOSM><<<add_grid, 512, 0, streams[0]>>>(
-          new_blocks, old_blocks, min(r, chunk_size), big_lwe_size, num_blocks);
+    tree_add_chunks<Torus><<<add_grid, 512, 0, streams[0]>>>(
+        new_blocks, old_blocks, min(r, chunk_size), big_lwe_size, num_blocks);

    check_cuda_error(cudaGetLastError());

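The launch now uses a single template instantiation and passes 0 as the dynamic shared-memory size, since the kernel writes its sums straight to new_blocks. The kernel reads the chunk index from blockIdx.x and the block index from blockIdx.y, so the grid presumably has one x-slot per chunk and one y-slot per block; add_grid itself is defined outside the shown hunk. A hypothetical wrapper showing such a launch (the helper and its num_chunks parameter are assumptions for illustration):

#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical wrapper: grid x indexes the chunk, grid y indexes the block
// inside a radix ciphertext, and 512 threads stride over that block.
template <typename Torus>
void launch_tree_add_chunks(cudaStream_t stream, Torus *result_blocks,
                            Torus *input_blocks, uint32_t num_chunks,
                            uint32_t chunk_size, uint32_t block_size,
                            uint32_t num_blocks) {
  dim3 add_grid(num_chunks, num_blocks, 1);
  tree_add_chunks<Torus><<<add_grid, 512, 0, stream>>>(
      result_blocks, input_blocks, chunk_size, block_size, num_blocks);
}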
@@ -299,7 +283,7 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
      terms_degree, h_lwe_idx_in, h_lwe_idx_out, h_smart_copy_in,
      h_smart_copy_out, ch_amount, r, num_blocks, chunk_size, message_max,
      total_count, message_count, carry_count, sm_copy_count);
-
+    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
    auto lwe_indexes_in = luts_message_carry->lwe_indexes_in;
    auto lwe_indexes_out = luts_message_carry->lwe_indexes_out;
    luts_message_carry->set_lwe_indexes(streams[0], gpu_indexes[0],
5 changes: 0 additions & 5 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh
@@ -52,11 +52,6 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
  Torus *full_rotated_buffer = mem->tmp_rotated;
  Torus *rotated_buffer = &full_rotated_buffer[big_lwe_size];

-  // rotate right all the blocks in radix ciphertext
-  // copy result in new buffer
-  // 1024 threads are used in every block
-  // block_count blocks will be used in the grid
-  // one block is responsible to process single lwe ciphertext
  if (mem->shift_type == LEFT_SHIFT) {
    // rotate right as the blocks are from LSB to MSB
    host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count,

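The comment kept at the call site gives the key fact: radix blocks are stored from least to most significant, so shifting the integer left by whole blocks moves each block to a higher index, which is the dst = (src + value) % blocks_count mapping implemented by radix_blocks_rotate_right. A tiny stand-alone illustration of that index mapping (not part of the commit; a full logical shift must additionally clear the blocks that wrap around, which this hunk does not show):

#include <cstdio>

// Block-index view of a left shift by `shift` blocks on an LSB-first radix
// integer with n blocks: source block j lands at index (j + shift) % n,
// i.e. the rotate-right mapping used above.
int main() {
  const unsigned n = 4, shift = 1;
  for (unsigned j = 0; j < n; j++)
    std::printf("src block %u -> dst block %u\n", j, (j + shift) % n);
  return 0;
}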