feat(gpu): signed and unsigned scalar mul
+ remove small scalar mul
+ move around signed tests_cases
agnesLeroy committed Mar 29, 2024
1 parent 80836c5 commit cb1110f
Showing 50 changed files with 2,966 additions and 2,783 deletions.
177 changes: 167 additions & 10 deletions backends/tfhe-cuda-backend/cuda/include/integer.h
@@ -85,14 +85,6 @@ void cuda_scalar_addition_integer_radix_ciphertext_64_inplace(
uint32_t lwe_dimension, uint32_t lwe_ciphertext_count,
uint32_t message_modulus, uint32_t carry_modulus);

void cuda_small_scalar_multiplication_integer_radix_ciphertext_64(
cuda_stream_t *stream, void *output_lwe_array, void *input_lwe_array,
uint64_t scalar, uint32_t lwe_dimension, uint32_t lwe_ciphertext_count);

void cuda_small_scalar_multiplication_integer_radix_ciphertext_64_inplace(
cuda_stream_t *stream, void *lwe_array, uint64_t scalar,
uint32_t lwe_dimension, uint32_t lwe_ciphertext_count);

void scratch_cuda_integer_radix_logical_scalar_shift_kb_64(
cuda_stream_t *stream, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
@@ -269,8 +261,36 @@ void cuda_integer_radix_overflowing_sub_kb_64(

void cleanup_cuda_integer_radix_overflowing_sub(cuda_stream_t *stream,
int8_t **mem_ptr_void);
} // extern "C"

void scratch_cuda_integer_scalar_mul_kb_64(
cuda_stream_t *stream, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t num_blocks, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory);

void cuda_scalar_multiplication_integer_radix_ciphertext_64_inplace(
cuda_stream_t *stream, void *lwe_array, uint64_t *decomposed_scalar,
uint64_t *has_at_least_one_set, int8_t *mem_ptr, void *bsk, void *ksk,
uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t message_modulus,
uint32_t num_blocks, uint32_t num_scalars);

void cleanup_cuda_integer_radix_scalar_mul(cuda_stream_t *stream,
int8_t **mem_ptr_void);
}
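The new entry points follow the same scratch → apply → cleanup pattern as the other integer operations in this header. A minimal call-sequence sketch, assuming a cuda_stream_t, keys (bsk, ksk), and the usual radix parameters are already set up; the variable values, the CLASSICAL PBS type, and the host-side scalar decomposition are illustrative assumptions, not part of this commit:

// Hypothetical usage sketch (not from this commit).
int8_t *scalar_mul_mem = nullptr;
scratch_cuda_integer_scalar_mul_kb_64(
    stream, &scalar_mul_mem, glwe_dimension, polynomial_size, lwe_dimension,
    ks_level, ks_base_log, pbs_level, pbs_base_log, grouping_factor,
    num_blocks, message_modulus, carry_modulus, CLASSICAL,
    /*allocate_gpu_memory=*/true);

// decomposed_scalar is assumed to hold the num_scalars radix digits of the
// clear scalar, decomposed on the host; has_at_least_one_set is assumed to
// flag which digit positions contribute.
cuda_scalar_multiplication_integer_radix_ciphertext_64_inplace(
    stream, lwe_array, decomposed_scalar, has_at_least_one_set,
    scalar_mul_mem, bsk, ksk, lwe_dimension, polynomial_size,
    message_modulus, num_blocks, num_scalars);

cleanup_cuda_integer_radix_scalar_mul(stream, &scalar_mul_mem);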

template <typename Torus>
__global__ void radix_blocks_rotate_right(Torus *dst, Torus *src,
uint32_t value, uint32_t blocks_count,
uint32_t lwe_size);
void generate_ids_update_degrees(int *terms_degree, size_t *h_lwe_idx_in,
size_t *h_lwe_idx_out,
int32_t *h_smart_copy_in,
int32_t *h_smart_copy_out, size_t ch_amount,
uint32_t num_radix, uint32_t num_blocks,
size_t chunk_size, size_t message_max,
size_t &total_count, size_t &message_count,
size_t &carry_count, size_t &sm_copy_count);
/*
* generate bivariate accumulator (lut) for device pointer
* v_stream - cuda stream
@@ -1225,6 +1245,8 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {

Torus *tmp_rotated;

bool reuse_memory = false;

int_logical_scalar_shift_buffer(cuda_stream_t *stream,
SHIFT_OR_ROTATE_TYPE shift_type,
int_radix_params params,
@@ -1310,14 +1332,98 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {
}
}

int_logical_scalar_shift_buffer(cuda_stream_t *stream,
SHIFT_OR_ROTATE_TYPE shift_type,
int_radix_params params,
uint32_t num_radix_blocks,
bool allocate_gpu_memory,
Torus *pre_allocated_buffer) {
this->shift_type = shift_type;
this->params = params;
tmp_rotated = pre_allocated_buffer;
reuse_memory = true;

uint32_t max_amount_of_pbs = num_radix_blocks;
uint32_t big_lwe_size = params.big_lwe_dimension + 1;
uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);
cuda_memset_async(tmp_rotated, 0,
(max_amount_of_pbs + 2) * big_lwe_size_bytes, stream);
if (allocate_gpu_memory) {

uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus);

// LUT
// Pregenerate the lut vector and indexes.
// For the left shift we generate 'num_bits_in_block' luts,
// one for each 'shift_within_block' = 'shift' % 'num_bits_in_block'.
// Even though lut_left contains 'num_bits_in_block' luts,
// lut_indexes holds indexes for a single lut only, and those indexes
// are all 0: this means the pbs selects the corresponding lut and is
// passed lut_indexes filled with zeros.

// Calculate a bivariate lut for each 'shift_within_block', so that an
// application that calls scratch only once for a whole circuit can
// reuse the memory for different shift values.
for (int s_w_b = 1; s_w_b < num_bits_in_block; s_w_b++) {
auto cur_lut_bivariate = new int_radix_lut<Torus>(
stream, params, 1, num_radix_blocks, allocate_gpu_memory);

uint32_t shift_within_block = s_w_b;

std::function<Torus(Torus, Torus)> shift_lut_f;

if (shift_type == LEFT_SHIFT) {
shift_lut_f = [shift_within_block,
params](Torus current_block,
Torus previous_block) -> Torus {
current_block = current_block << shift_within_block;
previous_block = previous_block << shift_within_block;

Torus message_of_current_block =
current_block % params.message_modulus;
Torus carry_of_previous_block =
previous_block / params.message_modulus;
return message_of_current_block + carry_of_previous_block;
};
} else {
shift_lut_f = [num_bits_in_block, shift_within_block, params](
Torus current_block, Torus next_block) -> Torus {
// left shift so as not to lose
// bits when shifting right afterwards
next_block <<= num_bits_in_block;
next_block >>= shift_within_block;

// The way of getting carry / message is reversed compared
// to the usual way, but it's expected here:
// the message is in the upper bits, the carry in the lower bits
Torus message_of_current_block =
current_block >> shift_within_block;
Torus carry_of_previous_block = next_block % params.message_modulus;

return message_of_current_block + carry_of_previous_block;
};
}

// generate the lut for this shift value
generate_device_accumulator_bivariate<Torus>(
stream, cur_lut_bivariate->lut, params.glwe_dimension,
params.polynomial_size, params.message_modulus,
params.carry_modulus, shift_lut_f);

lut_buffers_bivariate.push_back(cur_lut_bivariate);
}
}
}
void release(cuda_stream_t *stream) {
for (auto &buffer : lut_buffers_bivariate) {
buffer->release(stream);
delete buffer;
}
lut_buffers_bivariate.clear();

cuda_drop_async(tmp_rotated, stream);
// tmp_rotated is only owned by this buffer when it was not provided
// by the caller
if (!reuse_memory)
cuda_drop_async(tmp_rotated, stream);
}
};
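A sketch of the two construction modes (variable names are hypothetical; the pre-allocated variant is the one int_scalar_mul_buffer below relies on):

// Owning mode: the buffer allocates tmp_rotated itself and frees it in
// release().
auto shift_own = new int_logical_scalar_shift_buffer<uint64_t>(
    stream, LEFT_SHIFT, params, num_radix_blocks, true);

// Reusing mode: tmp_rotated aliases caller-owned memory, which release()
// leaves alone because reuse_memory is set.
auto shift_reuse = new int_logical_scalar_shift_buffer<uint64_t>(
    stream, LEFT_SHIFT, params, num_radix_blocks, true, preallocated);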

@@ -2048,4 +2154,55 @@ template <typename Torus> struct int_bitop_buffer {
}
};

template <typename Torus> struct int_scalar_mul_buffer {
int_radix_params params;
int_logical_scalar_shift_buffer<Torus> *logical_scalar_shift_buffer;
int_sum_ciphertexts_vec_memory<Torus> *sum_ciphertexts_vec_mem;
Torus *preshifted_buffer;
Torus *all_shifted_buffer;

int_scalar_mul_buffer(cuda_stream_t *stream, int_radix_params params,
uint32_t num_radix_blocks, bool allocate_gpu_memory) {
this->params = params;

if (allocate_gpu_memory) {
uint32_t msg_bits = (uint32_t)std::log2(params.message_modulus);
uint32_t lwe_size = params.big_lwe_dimension + 1;
uint32_t lwe_size_bytes = lwe_size * sizeof(Torus);
size_t num_ciphertext_bits = msg_bits * num_radix_blocks;

// Contains all shifted values of lhs for shifts in range (0..msg_bits)
// The idea is that with these we can create all other shifts in
// range (0..total_bits) for free, via block rotation
preshifted_buffer = (Torus *)cuda_malloc_async(
num_ciphertext_bits * lwe_size_bytes, stream);

all_shifted_buffer = (Torus *)cuda_malloc_async(
num_ciphertext_bits * num_radix_blocks * lwe_size_bytes, stream);

cuda_memset_async(preshifted_buffer, 0,
num_ciphertext_bits * lwe_size_bytes, stream);

cuda_memset_async(all_shifted_buffer, 0,
num_ciphertext_bits * num_radix_blocks * lwe_size_bytes,
stream);

logical_scalar_shift_buffer = new int_logical_scalar_shift_buffer<Torus>(
stream, LEFT_SHIFT, params, num_radix_blocks, allocate_gpu_memory,
all_shifted_buffer);

sum_ciphertexts_vec_mem = new int_sum_ciphertexts_vec_memory<Torus>(
stream, params, num_radix_blocks, num_ciphertext_bits,
allocate_gpu_memory);
}
}

void release(cuda_stream_t *stream) {
logical_scalar_shift_buffer->release(stream);
sum_ciphertexts_vec_mem->release(stream);
cuda_drop_async(preshifted_buffer, stream);
cuda_drop_async(all_shifted_buffer, stream);
}
};
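A worked sizing example under assumed parameters (not taken from this commit): with message_modulus = 4, so msg_bits = 2, and num_radix_blocks = 4 (an 8-bit radix integer), num_ciphertext_bits = 2 * 4 = 8. preshifted_buffer then holds the 2 preshifted copies of the operand (bit-shifts 0 and 1) of 4 blocks each, while all_shifted_buffer holds all 8 shifts in (0..total_bits) of 4 blocks each, i.e. 32 LWE ciphertexts. A shift by 5 bits, for instance, is the preshifted copy for 5 % 2 = 1 rotated by 5 / 2 = 2 blocks, which is the free block rotation the comments above refer to.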

#endif // CUDA_INTEGER_H
142 changes: 61 additions & 81 deletions backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu
@@ -1,5 +1,66 @@
#include "integer/multiplication.cuh"

/*
* When adding chunk_size terms together, there may be some blocks where
* no addition has happened or the degree is zero; in that case we don't
* need to apply a lookup table. We therefore find the indexes of the
* blocks where addition happened and store them in h_lwe_idx_in. From the
* same block both a message and a carry (if it is not the last block) may
* be extracted, so one block id may have two output ids, which we store
* in h_lwe_idx_out. Blocks that do not require a lookup table may be
* copied on both the message and carry side, or be replaced with zero
* ciphertexts; the indexes of such blocks are stored in h_smart_copy_in
* as input ids and in h_smart_copy_out as output ids. A -1 value as an
* input id means that a zero ciphertext will be copied to the output
* index.
*/
void generate_ids_update_degrees(int *terms_degree, size_t *h_lwe_idx_in,
size_t *h_lwe_idx_out,
int32_t *h_smart_copy_in,
int32_t *h_smart_copy_out, size_t ch_amount,
uint32_t num_radix, uint32_t num_blocks,
size_t chunk_size, size_t message_max,
size_t &total_count, size_t &message_count,
size_t &carry_count, size_t &sm_copy_count) {
for (size_t c_id = 0; c_id < ch_amount; c_id++) {
auto cur_chunk = &terms_degree[c_id * chunk_size * num_blocks];
for (size_t r_id = 0; r_id < num_blocks; r_id++) {
size_t new_degree = 0;
for (size_t chunk_id = 0; chunk_id < chunk_size; chunk_id++) {
new_degree += cur_chunk[chunk_id * num_blocks + r_id];
}

if (new_degree > message_max) {
h_lwe_idx_in[message_count] = c_id * num_blocks + r_id;
h_lwe_idx_out[message_count] = c_id * num_blocks + r_id;
message_count++;
} else {
h_smart_copy_in[sm_copy_count] = c_id * num_blocks + r_id;
h_smart_copy_out[sm_copy_count] = c_id * num_blocks + r_id;
sm_copy_count++;
}
}
}
for (size_t i = 0; i < sm_copy_count; i++) {
h_smart_copy_in[i] = -1;
h_smart_copy_out[i] = h_smart_copy_out[i] + ch_amount * num_blocks + 1;
}

for (size_t i = 0; i < message_count; i++) {
if (h_lwe_idx_in[i] % num_blocks != num_blocks - 1) {
h_lwe_idx_in[message_count + carry_count] = h_lwe_idx_in[i];
h_lwe_idx_out[message_count + carry_count] =
ch_amount * num_blocks + h_lwe_idx_in[i] + 1;
carry_count++;
} else {
h_smart_copy_in[sm_copy_count] = -1;
h_smart_copy_out[sm_copy_count] =
h_lwe_idx_in[i] - (num_blocks - 1) + ch_amount * num_blocks;
sm_copy_count++;
}
}

total_count = message_count + carry_count;
}
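A small worked trace under assumed inputs (not from this commit): take ch_amount = 1, num_blocks = 2, chunk_size = 2, message_max = 3, and per-term degrees {3, 1} and {2, 1}. Block 0 accumulates degree 3 + 2 = 5 > 3 and block 1 degree 1 + 1 = 2 <= 3, so the function schedules two lookup-table applications for block 0 (message to output index 0 and, since it is not the last block, carry to output index ch_amount * num_blocks + 0 + 1 = 3) and one smart copy for block 1 (input id -1, output index 1 + ch_amount * num_blocks + 1 = 4, i.e. it is replaced by a zero ciphertext). total_count ends up as 1 + 1 = 2.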
/*
* This scratch function allocates the necessary amount of data on the GPU for
* the integer radix multiplication in keyswitch->bootstrap order.
@@ -89,25 +150,6 @@ void cleanup_cuda_integer_mult(cuda_stream_t *stream, int8_t **mem_ptr_void) {
mem_ptr->release(stream);
}

void cuda_small_scalar_multiplication_integer_radix_ciphertext_64_inplace(
cuda_stream_t *stream, void *lwe_array, uint64_t scalar,
uint32_t lwe_dimension, uint32_t lwe_ciphertext_count) {

cuda_small_scalar_multiplication_integer_radix_ciphertext_64(
stream, lwe_array, lwe_array, scalar, lwe_dimension,
lwe_ciphertext_count);
}

void cuda_small_scalar_multiplication_integer_radix_ciphertext_64(
cuda_stream_t *stream, void *output_lwe_array, void *input_lwe_array,
uint64_t scalar, uint32_t lwe_dimension, uint32_t lwe_ciphertext_count) {

host_integer_small_scalar_mult_radix(
stream, static_cast<uint64_t *>(output_lwe_array),
static_cast<uint64_t *>(input_lwe_array), scalar, lwe_dimension,
lwe_ciphertext_count);
}

void scratch_cuda_integer_radix_sum_ciphertexts_vec_kb_64(
cuda_stream_t *stream, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
@@ -197,65 +239,3 @@ void cleanup_cuda_integer_radix_sum_ciphertexts_vec(cuda_stream_t *stream,

mem_ptr->release(stream);
}

/*
* when adding chunk_size times terms together, there might be some blocks
* where addition have not happened or degree is zero, in that case we don't
* need to apply lookup table, so we find the indexes of the blocks where
* addition happened and store them inside h_lwe_idx_in, from same block
* might be extracted message and carry(if it is not the last block), so
* one block id might have two output id and we store them in h_lwe_idx_out
* blocks that do not require applying lookup table might be copied on both
* message and carry side or be replaced with zero ciphertexts, indexes of such
* blocks are stored inside h_smart_copy_in as input ids and h_smart_copy_out
* as output ids, -1 value as an input id means that zero ciphertext will be
* copied on output index.
*/
void generate_ids_update_degrees(int *terms_degree, size_t *h_lwe_idx_in,
size_t *h_lwe_idx_out,
int32_t *h_smart_copy_in,
int32_t *h_smart_copy_out, size_t ch_amount,
uint32_t num_radix, uint32_t num_blocks,
size_t chunk_size, size_t message_max,
size_t &total_count, size_t &message_count,
size_t &carry_count, size_t &sm_copy_count) {
for (size_t c_id = 0; c_id < ch_amount; c_id++) {
auto cur_chunk = &terms_degree[c_id * chunk_size * num_blocks];
for (size_t r_id = 0; r_id < num_blocks; r_id++) {
size_t new_degree = 0;
for (size_t chunk_id = 0; chunk_id < chunk_size; chunk_id++) {
new_degree += cur_chunk[chunk_id * num_blocks + r_id];
}

if (new_degree > message_max) {
h_lwe_idx_in[message_count] = c_id * num_blocks + r_id;
h_lwe_idx_out[message_count] = c_id * num_blocks + r_id;
message_count++;
} else {
h_smart_copy_in[sm_copy_count] = c_id * num_blocks + r_id;
h_smart_copy_out[sm_copy_count] = c_id * num_blocks + r_id;
sm_copy_count++;
}
}
}
for (size_t i = 0; i < sm_copy_count; i++) {
h_smart_copy_in[i] = -1;
h_smart_copy_out[i] = h_smart_copy_out[i] + ch_amount * num_blocks + 1;
}

for (size_t i = 0; i < message_count; i++) {
if (h_lwe_idx_in[i] % num_blocks != num_blocks - 1) {
h_lwe_idx_in[message_count + carry_count] = h_lwe_idx_in[i];
h_lwe_idx_out[message_count + carry_count] =
ch_amount * num_blocks + h_lwe_idx_in[i] + 1;
carry_count++;
} else {
h_smart_copy_in[sm_copy_count] = -1;
h_smart_copy_out[sm_copy_count] =
h_lwe_idx_in[i] - (num_blocks - 1) + ch_amount * num_blocks;
sm_copy_count++;
}
}

total_count = message_count + carry_count;
}
