Skip to content

Commit

Permalink
feat(gpu): implement unchecked scalar eq
Browse files Browse the repository at this point in the history
  • Loading branch information
pdroalves committed Feb 1, 2024
1 parent 92a2817 commit 0992716
Show file tree
Hide file tree
Showing 7 changed files with 313 additions and 48 deletions.
81 changes: 50 additions & 31 deletions backends/tfhe-cuda-backend/cuda/include/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,12 +484,12 @@ template <typename Torus> struct int_sc_prop_memory {
};

// create lut objects
luts_array = new int_radix_lut<Torus>(
stream, params, 2, num_radix_blocks, allocate_gpu_memory);
luts_array = new int_radix_lut<Torus>(stream, params, 2, num_radix_blocks,
allocate_gpu_memory);
luts_carry_propagation_sum = new int_radix_lut<Torus>(
stream, params, 1, num_radix_blocks, luts_array);
message_acc = new int_radix_lut<Torus>(
stream, params, 1, num_radix_blocks, luts_array);
message_acc = new int_radix_lut<Torus>(stream, params, 1, num_radix_blocks,
luts_array);

auto lut_does_block_generate_carry = luts_array->get_lut(0);
auto lut_does_block_generate_or_propagate = luts_array->get_lut(1);
Expand Down Expand Up @@ -945,6 +945,9 @@ template <typename Torus> struct int_are_all_block_true_buffer {
}
};

// Associates a scalar block value (first) with a pointer to the
// int_radix_lut generated for that value (second). The pointer is a raw,
// non-owning-by-type pointer: whoever stores the pair is responsible for
// releasing the LUT.
template <typename Torus>
using scalar_eq_lut_pair_t = std::pair<Torus, int_radix_lut<Torus> *>;

template <typename Torus> struct int_comparison_eq_buffer {
int_radix_params params;
COMPARISON_TYPE op;
Expand All @@ -954,6 +957,10 @@ template <typename Torus> struct int_comparison_eq_buffer {

int_are_all_block_true_buffer<Torus> *are_all_block_true_buffer;

std::vector<scalar_eq_lut_pair_t<Torus>> *scalar_comparison_luts;

Torus *h_scalar_blocks;

int_comparison_eq_buffer(cuda_stream_t *stream, COMPARISON_TYPE op,
int_radix_params params, uint32_t num_radix_blocks,
bool allocate_gpu_memory) {
Expand All @@ -962,6 +969,8 @@ template <typename Torus> struct int_comparison_eq_buffer {

if (allocate_gpu_memory) {

h_scalar_blocks = (Torus *)malloc(num_radix_blocks * sizeof(Torus));

are_all_block_true_buffer = new int_are_all_block_true_buffer<Torus>(
stream, op, params, num_radix_blocks, allocate_gpu_memory);

Expand Down Expand Up @@ -996,6 +1005,8 @@ template <typename Torus> struct int_comparison_eq_buffer {
stream, is_non_zero_lut->lut, params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
is_non_zero_lut_f);

scalar_comparison_luts = new std::vector<scalar_eq_lut_pair_t<Torus>>();
}
}

Expand All @@ -1005,8 +1016,16 @@ template <typename Torus> struct int_comparison_eq_buffer {
is_non_zero_lut->release(stream);
delete is_non_zero_lut;

free(h_scalar_blocks);

are_all_block_true_buffer->release(stream);
delete are_all_block_true_buffer;

for (auto &pair : *scalar_comparison_luts) {
int_radix_lut<Torus> *lut = pair.second;
lut->release(stream);
}
scalar_comparison_luts->clear();
}
};

Expand Down Expand Up @@ -1083,14 +1102,8 @@ template <typename Torus> struct int_comparison_diff_buffer {

std::function<Torus(Torus)> operator_f;

int_radix_lut<Torus> *is_zero_lut;

int_tree_sign_reduction_buffer<Torus> *tree_buffer;

// Used for scalar comparisons
cuda_stream_t *lsb_stream;
cuda_stream_t *msb_stream;

int_comparison_diff_buffer(cuda_stream_t *stream, COMPARISON_TYPE op,
int_radix_params params, uint32_t num_radix_blocks,
bool allocate_gpu_memory) {
Expand All @@ -1114,8 +1127,6 @@ template <typename Torus> struct int_comparison_diff_buffer {
};

if (allocate_gpu_memory) {
lsb_stream = cuda_create_stream(stream->gpu_index);
msb_stream = cuda_create_stream(stream->gpu_index);

Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus);

Expand All @@ -1125,36 +1136,17 @@ template <typename Torus> struct int_comparison_diff_buffer {
tmp_packed_right =
(Torus *)cuda_malloc_async(big_size * (num_radix_blocks / 2), stream);

// LUTs
uint32_t total_modulus = params.message_modulus * params.carry_modulus;
auto is_zero_f = [total_modulus](Torus x) -> Torus {
return (x % total_modulus) == 0;
};

is_zero_lut = new int_radix_lut<Torus>(
stream, params, 1, num_radix_blocks, allocate_gpu_memory);

generate_device_accumulator<Torus>(
stream, is_zero_lut->lut, params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
is_zero_f);

tree_buffer = new int_tree_sign_reduction_buffer<Torus>(
stream, operator_f, params, num_radix_blocks, allocate_gpu_memory);
}
}

void release(cuda_stream_t *stream) {
is_zero_lut->release(stream);
delete is_zero_lut;
tree_buffer->release(stream);
delete tree_buffer;

cuda_drop_async(tmp_packed_left, stream);
cuda_drop_async(tmp_packed_right, stream);

cuda_destroy_stream(lsb_stream);
cuda_destroy_stream(msb_stream);
}
};

Expand All @@ -1167,6 +1159,8 @@ template <typename Torus> struct int_comparison_buffer {
int_radix_lut<Torus> *cleaning_lut;
std::function<Torus(Torus)> cleaning_lut_f;

int_radix_lut<Torus> *is_zero_lut;

int_comparison_eq_buffer<Torus> *eq_buffer;
int_comparison_diff_buffer<Torus> *diff_buffer;

Expand All @@ -1176,6 +1170,10 @@ template <typename Torus> struct int_comparison_buffer {
Torus *tmp_lwe_array_out;
int_cmux_buffer<Torus> *cmux_buffer;

// Used for scalar comparisons
cuda_stream_t *lsb_stream;
cuda_stream_t *msb_stream;

int_comparison_buffer(cuda_stream_t *stream, COMPARISON_TYPE op,
int_radix_params params, uint32_t num_radix_blocks,
bool allocate_gpu_memory) {
Expand All @@ -1185,6 +1183,9 @@ template <typename Torus> struct int_comparison_buffer {
cleaning_lut_f = [](Torus x) -> Torus { return x; };

if (allocate_gpu_memory) {
lsb_stream = cuda_create_stream(stream->gpu_index);
msb_stream = cuda_create_stream(stream->gpu_index);

tmp_lwe_array_out = (Torus *)cuda_malloc_async(
(params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus),
stream);
Expand All @@ -1203,6 +1204,19 @@ template <typename Torus> struct int_comparison_buffer {
params.polynomial_size, params.message_modulus, params.carry_modulus,
cleaning_lut_f);

uint32_t total_modulus = params.message_modulus * params.carry_modulus;
auto is_zero_f = [total_modulus](Torus x) -> Torus {
return (x % total_modulus) == 0;
};

is_zero_lut = new int_radix_lut<Torus>(
stream, params, 1, num_radix_blocks, allocate_gpu_memory);

generate_device_accumulator<Torus>(
stream, is_zero_lut->lut, params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
is_zero_f);

switch (op) {
case COMPARISON_TYPE::MAX:
case COMPARISON_TYPE::MIN:
Expand Down Expand Up @@ -1246,8 +1260,13 @@ template <typename Torus> struct int_comparison_buffer {
break;
}
cleaning_lut->release(stream);
is_zero_lut->release(stream);
delete is_zero_lut;
cuda_drop_async(tmp_lwe_array_out, stream);
cuda_drop_async(tmp_block_comparisons, stream);

cuda_destroy_stream(lsb_stream);
cuda_destroy_stream(msb_stream);
}
};

Expand Down
3 changes: 1 addition & 2 deletions backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ __host__ void host_compare_with_zero_equality(
num_sum_blocks = 1;
} else {
uint32_t remainder_blocks = num_radix_blocks;

auto sum_i = sum;
auto chunk = lwe_array_in;
while (remainder_blocks > 1) {
Expand All @@ -194,7 +193,7 @@ __host__ void host_compare_with_zero_equality(
}
}

auto is_equal_to_zero_lut = mem_ptr->diff_buffer->is_zero_lut;
auto is_equal_to_zero_lut = mem_ptr->is_zero_lut;
integer_radix_apply_univariate_lookup_table_kb<Torus>(
stream, sum, sum, bsk, ksk, num_sum_blocks, is_equal_to_zero_lut);
are_all_comparisons_block_true(stream, lwe_array_out, sum, mem_ptr, bsk, ksk,
Expand Down
49 changes: 49 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -671,4 +671,53 @@ create_trivial_radix(cuda_stream_t *stream, Torus *lwe_array_out,
check_cuda_error(cudaGetLastError());
}

// Create new luts for scalar eq/ne if needed and update the related vector.
//
// For each of the first `num_scalar_blocks` values in `h_scalar_blocks`, the
// function searches `luts` for a pair whose key equals that value. When no
// such pair exists, a new int_radix_lut is allocated with `new`, its
// accumulator is generated for the univariate function
// x -> operator_f(scalar, x), and the pair {scalar, lut} is appended to
// `luts`. Values that already have an entry are left untouched, so repeated
// calls only pay for scalar values not seen before.
//
// Parameters:
//   stream            - stream handed to the LUT constructor and to
//                       generate_device_accumulator.
//   luts              - cache of (scalar value, LUT) pairs; extended in place.
//                       Must not be null.
//   h_scalar_blocks   - scalar block values, dereferenced directly by this
//                       function (must hold at least `num_scalar_blocks`
//                       readable entries).
//   num_radix_blocks  - forwarded to the int_radix_lut constructor.
//   num_scalar_blocks - number of entries of `h_scalar_blocks` to inspect.
//   params            - radix parameters used to size/generate each LUT.
//   operator_f        - bivariate predicate; called as operator_f(scalar, x).
//
// Ownership: the LUTs created here are stored as raw pointers inside `luts`
// and are not freed by this function; the owner of `luts` must release them.
template <typename Torus>
void update_scalar_comparison_luts(
    cuda_stream_t *stream, std::vector<scalar_eq_lut_pair_t<Torus>> *luts,
    Torus *h_scalar_blocks, uint32_t num_radix_blocks,
    uint32_t num_scalar_blocks, int_radix_params params,
    std::function<Torus(Torus, Torus)> operator_f) {
  // One lut per distinct scalar block value, and only for values that are
  // actually present in h_scalar_blocks.
  // The index is uint32_t to match num_scalar_blocks (avoids a signed/unsigned
  // comparison).
  for (uint32_t i = 0; i < num_scalar_blocks; i++) {
    const Torus scalar = h_scalar_blocks[i];

    auto existing =
        std::find_if(luts->begin(), luts->end(),
                     [scalar](const scalar_eq_lut_pair_t<Torus> &pair) {
                       return pair.first == scalar;
                     });

    // The LUT for this scalar has already been generated
    if (existing != luts->end())
      continue;

    // Generate a new lut
    int_radix_lut<Torus> *new_lut;
    if (!luts->empty()) {
      // Use the first as base to reduce memory consumption
      int_radix_lut<Torus> *base_lut = luts->front().second;
      new_lut = new int_radix_lut<Torus>(stream, params, 1, num_radix_blocks,
                                         base_lut);
    } else {
      // First entry of the cache: construct it with the boolean flag set to
      // true (the allocate_gpu_memory argument in the other constructor
      // calls of this code base).
      new_lut =
          new int_radix_lut<Torus>(stream, params, 1, num_radix_blocks, true);
    }

    // Fix the scalar operand so the bivariate operator becomes a univariate
    // function of the encrypted block value.
    auto lut_f = [scalar, operator_f](Torus x) -> Torus {
      return operator_f(scalar, x);
    };

    generate_device_accumulator<Torus>(
        stream, new_lut->lut, params.glwe_dimension, params.polynomial_size,
        params.message_modulus, params.carry_modulus, lut_f);

    luts->push_back({scalar, new_lut});
  }
}

#endif // TFHE_RS_INTERNAL_INTEGER_CUH
19 changes: 8 additions & 11 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,14 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
int_comparison_buffer<uint64_t> *buffer =
(int_comparison_buffer<uint64_t> *)mem_ptr;
switch (buffer->op) {
// case EQ:
// case NE:
// host_integer_radix_equality_check_kb<uint64_t>(
// stream, static_cast<uint64_t *>(lwe_array_out),
// static_cast<uint64_t *>(lwe_array_1),
// static_cast<uint64_t *>(lwe_array_2), buffer, bsk,
// static_cast<uint64_t *>(ksk), glwe_dimension, polynomial_size,
// big_lwe_dimension, small_lwe_dimension, ks_level, ks_base_log,
// pbs_level, pbs_base_log, grouping_factor, lwe_ciphertext_count,
// message_modulus, carry_modulus);
// break;
case EQ:
case NE:
host_integer_radix_scalar_equality_check_kb<uint64_t>(
stream, static_cast<uint64_t *>(lwe_array_out),
static_cast<uint64_t *>(lwe_array_in),
static_cast<uint64_t *>(scalar_blocks), buffer, bsk,
static_cast<uint64_t *>(ksk), lwe_ciphertext_count, num_scalar_blocks);
break;
case GT:
case GE:
case LT:
Expand Down
Loading

0 comments on commit 0992716

Please sign in to comment.