Skip to content

Commit

Permalink
fix: working integer test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Dec 17, 2024
1 parent 6121771 commit 7867800
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,8 @@ __host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension_in,
uint32_t polynomial_size,
uint32_t level_count,
uint32_t glwe_dimension) {
return level_count == 1; // &&
// glwe_dimension == 1;

/// lwe_dimension_in % BLOCK_SIZE_GEMM == 0 &&
// num_lwe % BLOCK_SIZE_GEMM == 0 &&
// TODO: Generalize to level_count > 1 by transposing the KSK
return level_count == 1;
}

template <typename Torus, typename TorusVec>
Expand Down Expand Up @@ -252,20 +249,26 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes) {

printf("FAST PATH PKS\n");
// Optimization of packing keyswitch when packing many LWEs

if (level_count > 1) {
PANIC("Fast path PKS only supports level_count==1");
}

cudaSetDevice(gpu_index);
check_cuda_error(cudaGetLastError());

int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;
int memory_unit =
glwe_accumulator_size; // > lwe_dimension_in ? glwe_accumulator_size :
// lwe_dimension_in;

if (lwe_dimension_in > glwe_accumulator_size) {
printf("PKS with lwe_dimension_in > glwe_accumulator_size\n");
}
// The fast path of PKS uses the scratch buffer (d_mem) differently than the
// old path: it needs to store the decomposed masks in the first half of this
// buffer and the keyswitched GLWEs in the second half of the buffer. Thus the
// scratch buffer for the fast path must determine the half-size of the
// scratch buffer as the max between the size of the GLWE and the size of the
// LWE-mask
int memory_unit = glwe_accumulator_size > lwe_dimension_in
? glwe_accumulator_size
: lwe_dimension_in;

// ping pong the buffer between successive calls
// split the buffer in two parts of this size
Expand Down Expand Up @@ -298,27 +301,29 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
CEIL_DIV(num_lwes, BLOCK_SIZE_GEMM));
dim3 threads_gemm(BLOCK_SIZE_GEMM * THREADS_GEMM);

printf("GEMM BLOCKS (%d, %d)\n", grid_gemm.x, grid_gemm.y);

auto stride_KSK_buffer = level_count * glwe_accumulator_size;
auto stride_KSK_buffer = glwe_accumulator_size;

uint32_t sharedMemSize = BLOCK_SIZE_GEMM * THREADS_GEMM * 2 * sizeof(Torus);
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, sharedMemSize, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension_in, d_mem_0, fp_ksk_array,
stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());

for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension_in, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());
/*
TODO: transpose key to generalize to level_count > 1
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, sharedMemSize, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension_in, d_mem_0,
fp_ksk_array + li * glwe_accumulator_size, stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
}
for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension_in, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, sharedMemSize,
stream>>>( num_lwes, glwe_accumulator_size, lwe_dimension_in, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
}
*/

// should we include the mask in the rotation ??
dim3 grid_rotate(CEIL_DIV(num_lwes, BLOCK_SIZE_DECOMP),
Expand Down
19 changes: 15 additions & 4 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ __host__ void scratch_packing_keyswitch_lwe_list_to_glwe(

int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;

int memory_unit =
glwe_accumulator_size; // > lwe_dimension ? glwe_accumulator_size :
// lwe_dimension;
int memory_unit = glwe_accumulator_size > lwe_dimension
? glwe_accumulator_size
: lwe_dimension;

if (allocate_gpu_memory) {
*fp_ks_buffer = (int8_t *)cuda_malloc_async(
Expand Down Expand Up @@ -245,6 +245,7 @@ __global__ void packing_keyswitch_lwe_list_to_glwe(
auto lwe_in = lwe_array_in + input_id * lwe_size;
auto ks_glwe_out = d_mem + input_id * glwe_accumulator_size;
auto glwe_out = glwe_array_out + input_id * glwe_accumulator_size;

// KS LWE to GLWE
packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext<Torus>(
ks_glwe_out, lwe_in, fp_ksk, lwe_dimension_in, glwe_dimension,
Expand Down Expand Up @@ -297,8 +298,18 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
dim3 grid(num_blocks, num_lwes);
dim3 threads(num_threads);

// The fast path of PKS uses the scratch buffer (d_mem) differently:
// it needs to store the decomposed masks in the first half of this buffer
// and the keyswitched GLWEs in the second half of the buffer. Thus the
// scratch buffer for the fast path must determine the half-size of the
// scratch buffer as the max between the size of the GLWE and the size of the
// LWE-mask
int memory_unit = glwe_accumulator_size > lwe_dimension_in
? glwe_accumulator_size
: lwe_dimension_in;

auto d_mem = (Torus *)fp_ks_buffer;
auto d_tmp_glwe_array_out = d_mem + num_lwes * glwe_accumulator_size;
auto d_tmp_glwe_array_out = d_mem + num_lwes * memory_unit;

// individually keyswitch each lwe
packing_keyswitch_lwe_list_to_glwe<Torus><<<grid, threads, 0, stream>>>(
Expand Down
20 changes: 13 additions & 7 deletions tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -730,25 +730,29 @@ mod tests {

#[test]
fn test_gpu_ciphertext_compression_fast_path() {
/// Implement a test only for the storage of ciphertexts
/// using a custom parameter set which is supported by a fast-path
/// packing keyswitch (only for level_count==1)
const COMP_PARAM_CUSTOM_FAST_PATH: CompressionParameters = CompressionParameters {
br_level: DecompositionLevelCount(1),
br_base_log: DecompositionBaseLog(23),
br_base_log: DecompositionBaseLog(21),
packing_ks_level: DecompositionLevelCount(1),
packing_ks_base_log: DecompositionBaseLog(21),
packing_ks_base_log: DecompositionBaseLog(19),
packing_ks_polynomial_size: PolynomialSize(2048),
packing_ks_glwe_dimension: GlweDimension(1),
lwe_per_glwe: LweCiphertextCount(2048),
storage_log_modulus: CiphertextModulusLog(19),
storage_log_modulus: CiphertextModulusLog(55),
packing_ks_key_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(
StandardDev(0.0),
StandardDev(2.845267479601915e-15),
),
};

const NUM_BLOCKS: usize = 32;

let streams = CudaStreams::new_multi_gpu();

let (radix_cks, _sks) = gen_keys_radix_gpu(
let (radix_cks, sks) = gen_keys_radix_gpu(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
NUM_BLOCKS,
&streams,
Expand Down Expand Up @@ -787,7 +791,8 @@ mod tests {
let ct = radix_cks.encrypt(message);
let d_ct =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams);
builder.push(d_ct, &streams);
let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
builder.push(d_and_ct, &streams);
messages.push(MessageType::Unsigned(message));
}
1 => {
Expand All @@ -797,7 +802,8 @@ mod tests {
let ct = radix_cks.encrypt_signed(message);
let d_ct =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct, &streams);
builder.push(d_ct, &streams);
let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
builder.push(d_and_ct, &streams);
messages.push(MessageType::Signed(message));
}
_ => {
Expand Down

0 comments on commit 7867800

Please sign in to comment.