fix(gpu): fix equal
agnesLeroy committed Dec 12, 2024 (1 parent: c1f05cb, commit: 25f4e5f)
Showing 2 changed files with 40 additions and 25 deletions.
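
In short: are_all_comparisons_block_true previously computed the number of chunks as remaining_blocks / chunk_length with chunk_length = min(max_value, remaining_blocks), an integer division that silently drops a shorter trailing chunk whenever remaining_blocks is not a multiple of max_value; it also heap-allocated and cached one LUT per distinct chunk length in is_equal_to_lut_map. The fix switches to a ceiling division that tracks each chunk's actual length, and replaces the map with a single pre-allocated two-entry LUT, is_max_value: entry 0 tests x == max_value for full chunks, entry 1 is regenerated on demand for a shorter trailing chunk, and per-chunk LUT indexes route each block to the right entry. Short illustrative sketches follow the hunks below.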
@@ -3063,7 +3063,7 @@ template <typename Torus> struct int_are_all_block_true_buffer {
   // This map store LUTs that checks the equality between some input and values
   // of interest in are_all_block_true(), as with max_value (the maximum message
   // value).
-  std::unordered_map<int, int_radix_lut<Torus> *> is_equal_to_lut_map;
+  int_radix_lut<Torus> *is_max_value;
 
   int_are_all_block_true_buffer(cudaStream_t const *streams,
                                 uint32_t const *gpu_indexes, uint32_t gpu_count,
@@ -3084,16 +3084,26 @@ template <typename Torus> struct int_are_all_block_true_buffer {
       tmp_out = (Torus *)cuda_malloc_async((params.big_lwe_dimension + 1) *
                                                num_radix_blocks * sizeof(Torus),
                                            streams[0], gpu_indexes[0]);
+      is_max_value =
+          new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 2,
+                                   num_radix_blocks, allocate_gpu_memory);
+      auto is_max_value_f = [max_value](Torus x) -> Torus {
+        return x == max_value;
+      };
+
+      generate_device_accumulator<Torus>(
+          streams[0], gpu_indexes[0], is_max_value->get_lut(0, 0),
+          params.glwe_dimension, params.polynomial_size, params.message_modulus,
+          params.carry_modulus, is_max_value_f);
+
+      is_max_value->broadcast_lut(streams, gpu_indexes, 0);
     }
   }
 
   void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
                uint32_t gpu_count) {
-    for (auto &lut : is_equal_to_lut_map) {
-      lut.second->release(streams, gpu_indexes, gpu_count);
-      delete (lut.second);
-    }
-    is_equal_to_lut_map.clear();
+    is_max_value->release(streams, gpu_indexes, gpu_count);
+    delete (is_max_value);
 
     cuda_drop_async(tmp_block_accumulated, streams[0], gpu_indexes[0]);
     cuda_drop_async(tmp_out, streams[0], gpu_indexes[0]);
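The buffer now owns one int_radix_lut with two entries instead of a map of per-length LUTs, so the comparison loop below no longer allocates on the fly. A minimal host-side sketch of the idea, assuming a hypothetical plain-C++ stand-in (TwoEntryLut and its members are illustration names, not the CUDA backend's API):

#include <array>
#include <cstdint>
#include <functional>

using Torus = uint64_t;

// Hypothetical stand-in for the two-entry is_max_value LUT: entry 0 is
// fixed at construction, entry 1 is a scratch slot regenerated whenever
// a pass ends with a chunk shorter than max_value.
struct TwoEntryLut {
  std::array<std::function<Torus(Torus)>, 2> entries;

  explicit TwoEntryLut(Torus max_value) {
    // Entry 0: a full chunk of max_value blocks, each encrypting 0 or 1,
    // is "all true" exactly when the accumulated sum equals max_value.
    entries[0] = [max_value](Torus x) -> Torus { return x == max_value; };
    // Entry 1: unused until a shorter trailing chunk needs it.
    entries[1] = [](Torus) -> Torus { return 0; };
  }

  // Mirrors regenerating get_lut(0, 1) for the trailing chunk's length.
  void set_trailing_length(Torus chunk_length) {
    entries[1] = [chunk_length](Torus x) -> Torus { return x == chunk_length; };
  }
};

Allocating both entries once in the constructor also simplifies release(), which now frees a single object instead of iterating a map.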
backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh (43 changes: 24 additions & 19 deletions)
@@ -85,16 +85,19 @@ __host__ void are_all_comparisons_block_true(
 
   while (remaining_blocks > 0) {
     // Split in max_value chunks
-    uint32_t chunk_length = std::min(max_value, remaining_blocks);
-    int num_chunks = remaining_blocks / chunk_length;
+    int num_chunks = (remaining_blocks + max_value - 1) / max_value;
 
     // Since all blocks encrypt either 0 or 1, we can sum max_value of them
     // as in the worst case we will be adding `max_value` ones
     auto input_blocks = tmp_out;
     auto accumulator = are_all_block_true_buffer->tmp_block_accumulated;
-    auto is_equal_to_num_blocks_map =
-        &are_all_block_true_buffer->is_equal_to_lut_map;
+    auto is_max_value_lut = are_all_block_true_buffer->is_max_value;
+    uint32_t chunk_lengths[num_chunks];
+    auto begin_remaining_blocks = remaining_blocks;
     for (int i = 0; i < num_chunks; i++) {
+      uint32_t chunk_length =
+          std::min(max_value, begin_remaining_blocks - i * max_value);
+      chunk_lengths[i] = chunk_length;
       accumulate_all_blocks<Torus>(streams[0], gpu_indexes[0], accumulator,
                                    input_blocks, big_lwe_dimension,
                                    chunk_length);
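
The rewritten chunk split above is the heart of the fix. A standalone sketch of the arithmetic (plain C++ for illustration; split_in_chunks is not a function in the codebase):

#include <algorithm>
#include <cstdint>
#include <vector>

// Old code: chunk_length = min(max_value, remaining_blocks) and
// num_chunks = remaining_blocks / chunk_length, an integer division that
// loses the trailing chunk when remaining_blocks % max_value != 0.
std::vector<uint32_t> split_in_chunks(uint32_t remaining_blocks,
                                      uint32_t max_value) {
  // Ceiling division: a partially filled last chunk now counts.
  uint32_t num_chunks = (remaining_blocks + max_value - 1) / max_value;
  std::vector<uint32_t> chunk_lengths(num_chunks);
  for (uint32_t i = 0; i < num_chunks; i++)
    chunk_lengths[i] = std::min(max_value, remaining_blocks - i * max_value);
  return chunk_lengths;
}

For example, remaining_blocks = 7 and max_value = 3 now gives chunks {3, 3, 1}; the old computation produced 7 / 3 = 2 chunks, so the last block never entered the accumulation.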
@@ -111,29 +114,31 @@
       // is_non_zero_lut_buffer LUT
       lut = mem_ptr->eq_buffer->is_non_zero_lut;
     } else {
-      if ((*is_equal_to_num_blocks_map).find(chunk_length) !=
-          (*is_equal_to_num_blocks_map).end()) {
-        // The LUT is already computed
-        lut = (*is_equal_to_num_blocks_map)[chunk_length];
-      } else {
+      if (chunk_lengths[num_chunks - 1] != max_value) {
         // LUT needs to be computed
-        auto new_lut =
-            new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params,
-                                     max_value, num_radix_blocks, true);
-
+        uint32_t chunk_length = chunk_lengths[num_chunks - 1];
         auto is_equal_to_num_blocks_lut_f = [chunk_length](Torus x) -> Torus {
           return x == chunk_length;
         };
         generate_device_accumulator<Torus>(
-            streams[0], gpu_indexes[0], new_lut->get_lut(0, 0), glwe_dimension,
-            polynomial_size, message_modulus, carry_modulus,
+            streams[0], gpu_indexes[0], is_max_value_lut->get_lut(0, 1),
+            glwe_dimension, polynomial_size, message_modulus, carry_modulus,
             is_equal_to_num_blocks_lut_f);
 
-        new_lut->broadcast_lut(streams, gpu_indexes, 0);
-
-        (*is_equal_to_num_blocks_map)[chunk_length] = new_lut;
-        lut = new_lut;
+        Torus *h_lut_indexes = (Torus *)malloc(num_chunks * sizeof(Torus));
+        for (int index = 0; index < num_chunks; index++) {
+          if (index == num_chunks - 1) {
+            h_lut_indexes[index] = 1;
+          } else {
+            h_lut_indexes[index] = 0;
+          }
+        }
+        cuda_memcpy_async_to_gpu(is_max_value_lut->get_lut_indexes(0, 0),
+                                 h_lut_indexes, num_chunks * sizeof(Torus),
+                                 streams[0], gpu_indexes[0]);
+        is_max_value_lut->broadcast_lut(streams, gpu_indexes, 0);
       }
+      lut = is_max_value_lut;
     }
 
     // Applies the LUT
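
The h_lut_indexes fill above routes every chunk to entry 0 (x == max_value) except the last, which uses the freshly generated entry 1 when it is shorter. The same mapping as a host-side sketch (build_lut_indexes is an illustration name):

#include <cstdint>
#include <vector>

// Entry 0 serves full chunks; entry 1 serves a shorter trailing chunk.
std::vector<uint64_t> build_lut_indexes(int num_chunks, bool last_chunk_short) {
  std::vector<uint64_t> indexes(num_chunks, 0); // entry 0: x == max_value
  if (last_chunk_short && num_chunks > 0)
    indexes[num_chunks - 1] = 1;                // entry 1: x == last length
  return indexes;
}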