Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(shortint): refactor the shortint keyswitching code #1423

Merged
merged 3 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions tfhe/src/integer/key_switching_key/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ use crate::integer::{
IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey,
};
use crate::shortint::parameters::compact_public_key_only::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::key_switching::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::key_switching::{
PARAM_KEYSWITCH_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_KEYSWITCH_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
};
use crate::shortint::parameters::{
ShortintKeySwitchingParameters, PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
ClassicPBSParameters, CompactPublicKeyEncryptionParameters, ShortintKeySwitchingParameters,
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
};
use crate::shortint::prelude::{PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS};

Expand Down Expand Up @@ -160,12 +164,11 @@ fn gen_multi_keys_test_integer_to_integer_ci_run_filter() {
assert_eq!(clear, 228);
}

#[test]
fn test_cpk_encrypt_cast_compute_ci_run_filter() {
let param_pke_only = PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
let param_fhe = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
let param_ksk = PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;

fn test_case_cpk_encrypt_cast_compute(
param_pke_only: CompactPublicKeyEncryptionParameters,
param_fhe: ClassicPBSParameters,
param_ksk: ShortintKeySwitchingParameters,
) {
let num_block = 4usize;

assert_eq!(param_pke_only.message_modulus, param_fhe.message_modulus);
Expand Down Expand Up @@ -225,3 +228,21 @@ fn test_cpk_encrypt_cast_compute_ci_run_filter() {
let clear = cks_fhe.decrypt_radix::<u64>(&ct1_extracted_and_cast) % modulus;
assert_eq!(clear, (input_msg * multiplier) % modulus);
}

#[test]
fn test_cpk_encrypt_cast_to_small_compute_big_ci_run_filter() {
test_case_cpk_encrypt_cast_compute(
PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_KEYSWITCH_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
)
}

#[test]
fn test_cpk_encrypt_cast_to_big_compute_big_ci_run_filter() {
test_case_cpk_encrypt_cast_compute(
PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_KEYSWITCH_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
)
}
138 changes: 81 additions & 57 deletions tfhe/src/shortint/key_switching_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,36 +500,17 @@ impl<'keys> KeySwitchingKeyView<'keys> {
.dest_server_key
.unchecked_create_trivial_with_lwe_size(0, output_lwe_size);

// TODO: We are outside the standard AP, if we chain keyswitches, we will refresh, which is
// safer for now. We can likely add an additional flag in shortint to indicate if we
// want to refresh or not, for now refresh anyways.
keyswitched.set_noise_level(NoiseLevel::UNKNOWN);

let cast_rshift = self.key_switching_key_material.cast_rshift;

match cast_rshift.cmp(&0) {
// Same bit size: only key switch
Ordering::Equal => {
keyswitch_lwe_ciphertext(
self.key_switching_key_material.key_switching_key,
&input_ct.ct,
&mut keyswitched.ct,
);
keyswitched.degree = input_ct.degree;
// We don't really know where we stand in terms of noise here
keyswitched.set_noise_level(NoiseLevel::UNKNOWN);
}
// Cast to bigger bit length: keyswitch, then right shift
Ordering::Greater => {
keyswitch_lwe_ciphertext(
self.key_switching_key_material.key_switching_key,
&input_ct.ct,
&mut keyswitched.ct,
);
// First pre process
let tmp_preprocessed: Ciphertext;

let acc = self
.dest_server_key
.generate_lookup_table(|n| n >> cast_rshift);
self.dest_server_key
.apply_lookup_table_assign(&mut keyswitched, &acc);
// degree updated by the apply lookup table
keyswitched.set_noise_level(NoiseLevel::NOMINAL);
}
let pre_processed = match cast_rshift.cmp(&0) {
// Cast to smaller bit length: left shift, then keyswitch
Ordering::Less => {
let src_server_key = self.src_server_key.as_ref().expect(
Expand All @@ -541,38 +522,32 @@ impl<'keys> KeySwitchingKeyView<'keys> {
(n << -cast_rshift)
% (input_ct.carry_modulus.0 * input_ct.message_modulus.0) as u64
});
let shifted_cipher = src_server_key.apply_lookup_table(input_ct, &acc);

keyswitch_lwe_ciphertext(
self.key_switching_key_material.key_switching_key,
&shifted_cipher.ct,
&mut keyswitched.ct,
);
// The degree is high in the source plaintext modulus, but smaller in the arriving
// one.
//
// src 4 bits:
// 0 | XX | 11
// shifted:
// 0 | 11 | 00 -> Applied lut will have max degree 1100 = 12
// dst 2 bits :
// 0 | 11 -> 11 = 3
keyswitched.degree = Degree::new(shifted_cipher.degree.get() >> -cast_rshift);
// We don't really know where we stand in terms of noise here
keyswitched.set_noise_level(NoiseLevel::UNKNOWN);
tmp_preprocessed = src_server_key.apply_lookup_table(input_ct, &acc);
&tmp_preprocessed
}
}
// No pre-processing
Ordering::Equal | Ordering::Greater => input_ct,
};

// The keyswitch
keyswitch_lwe_ciphertext(
self.key_switching_key_material.key_switching_key,
&pre_processed.ct,
&mut keyswitched.ct,
);
keyswitched.degree = pre_processed.degree;

let ret = {
// Manage the destination key adjustment
let mut res = {
let destination_pbs_order: PBSOrder =
self.key_switching_key_material.destination_key.into();
if destination_pbs_order == self.dest_server_key.pbs_order {
keyswitched
} else {
let wrong_key_ct = keyswitched;
let mut output = self.dest_server_key.create_trivial(0);
output.degree = wrong_key_ct.degree;
output.set_noise_level(wrong_key_ct.noise_level());
let mut correct_key_ct = self.dest_server_key.create_trivial(0);
correct_key_ct.degree = wrong_key_ct.degree;
correct_key_ct.set_noise_level(wrong_key_ct.noise_level());

// We are arriving under the wrong key for the dest_server_key
match self.key_switching_key_material.destination_key {
Expand All @@ -581,9 +556,8 @@ impl<'keys> KeySwitchingKeyView<'keys> {
keyswitch_lwe_ciphertext(
&self.dest_server_key.key_switching_key,
&wrong_key_ct.ct,
&mut output.ct,
&mut correct_key_ct.ct,
);
// TODO refresh ?
}
// Small to Big == PBS
EncryptionKeyChoice::Small => {
Expand All @@ -593,20 +567,70 @@ impl<'keys> KeySwitchingKeyView<'keys> {
apply_programmable_bootstrap(
&self.dest_server_key.bootstrapping_key,
&wrong_key_ct.ct,
&mut output.ct,
&mut correct_key_ct.ct,
&acc.acc,
buffers,
);
});
output.set_noise_level(NoiseLevel::NOMINAL);
// Degree does not need to be updated as we apply an Identity LUT and we
// apply only the bootstrap directly on the underlying ciphertext, we have
// to update the noise however.
correct_key_ct.set_noise_level(NoiseLevel::NOMINAL);
}
}

output
correct_key_ct
}
};

ret
let degree_after_keyswitch = res.degree;
match cast_rshift.cmp(&0) {
// Same bit size: only key switch
Ordering::Equal => {
// Refresh if we haven't applied a PBS yet
if res.noise_level() == NoiseLevel::UNKNOWN {
let acc = self.dest_server_key.generate_lookup_table(|x| x);
self.dest_server_key
.apply_lookup_table_assign(&mut res, &acc);
// We apply an Identity LUT so we know a tighter bound than the worst case LUT
// value
res.degree = degree_after_keyswitch;
}
}
// Cast to bigger bit length: keyswitch, then right shift
Ordering::Greater => {
let acc = self
.dest_server_key
.generate_lookup_table(|n| n >> cast_rshift);
self.dest_server_key
.apply_lookup_table_assign(&mut res, &acc);
// degree and noise are updated by the apply lookup table
}
// Cast to smaller bit length: left shift, then keyswitch
Ordering::Less => {
// The degree is high in the source plaintext modulus, but smaller in the arriving
// one.
//
// src 4 bits:
// 0 | XX | 11
// shifted:
// 0 | 11 | 00 -> Applied lut will have max degree 1100 = 12
// dst 2 bits :
// 0 | 11 -> 11 = 3
let new_degree = Degree::new(degree_after_keyswitch.get() >> -cast_rshift);
// Refresh if we haven't applied a PBS yet
if res.noise_level() == NoiseLevel::UNKNOWN {
let acc = self.dest_server_key.generate_lookup_table(|x| x);
self.dest_server_key
.apply_lookup_table_assign(&mut res, &acc);
}
// Apply the degree correction, even if we bootstrapped as the Identity LUT would
// not change this correction
res.degree = new_degree;
}
}

res
}
}

Expand Down
5 changes: 5 additions & 0 deletions tfhe/src/shortint/key_switching_key/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() {
// Message 0 Carry 0
let cipher = ck1.unchecked_encrypt(0);
let output_of_cast = ksk.cast(&cipher);
assert_eq!(output_of_cast.degree.get(), 3);
let clear = ck2.decrypt(&output_of_cast);
assert_eq!(clear, 0);
let ct_carry = sk2.carry_extract(&output_of_cast);
Expand All @@ -190,6 +191,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() {
// Message 1 Carry 0
let cipher = ck1.unchecked_encrypt(1);
let output_of_cast = ksk.cast(&cipher);
assert_eq!(output_of_cast.degree.get(), 3);
let clear = ck2.decrypt(&output_of_cast);
assert_eq!(clear, 1);
let ct_carry = sk2.carry_extract(&output_of_cast);
Expand All @@ -199,6 +201,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() {
// Message 0 Carry 1
let cipher = ck1.unchecked_encrypt(2);
let output_of_cast = ksk.cast(&cipher);
assert_eq!(output_of_cast.degree.get(), 3);
let clear = ck2.decrypt(&output_of_cast);
assert_eq!(clear, 0);
let ct_carry = sk2.carry_extract(&output_of_cast);
Expand All @@ -208,6 +211,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() {
// Message 1 Carry 1
let cipher = ck1.unchecked_encrypt(3);
let output_of_cast = ksk.cast(&cipher);
assert_eq!(output_of_cast.degree.get(), 3);
let clear = ck2.decrypt(&output_of_cast);
assert_eq!(clear, 1);
let ct_carry = sk2.carry_extract(&output_of_cast);
Expand All @@ -222,6 +226,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() {
assert_eq!((clear, carry), (0, 3));

let output_of_cast = ksk.cast(&cipher);
assert_eq!(output_of_cast.degree.get(), 3);
let clear = ck2.decrypt(&output_of_cast);
assert_eq!(clear, 0);
let ct_carry = sk2.carry_extract(&output_of_cast);
Expand Down
22 changes: 22 additions & 0 deletions tfhe/src/shortint/parameters/key_switching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,31 @@ pub const PARAM_KEYSWITCH_1_1_KS_PBS_TO_2_2_KS_PBS: ShortintKeySwitchingParamete
destination_key: EncryptionKeyChoice::Big,
};

// The level and base log correspond to the level and base log of the 2_2 TUniform parameters, so
// these parameters allow to keyswitch from one set of keys of the 2_2 TUniform parameters to
// another set of keys. The ciphertext will be under the small key and a PBS with the destination
// keys will be applied to finish the keyswitch.
pub const PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64: ShortintKeySwitchingParameters =
ShortintKeySwitchingParameters {
ks_level: DecompositionLevelCount(5),
ks_base_log: DecompositionBaseLog(3),
destination_key: EncryptionKeyChoice::Small,
};

// Parameters to keyswitch from input PKE 2_2 TUniform parameters to 2_2 KS_PBS compute parameters
// arriving under the small key, requires a PBS to get to the big key
pub const PARAM_KEYSWITCH_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64:
ShortintKeySwitchingParameters = ShortintKeySwitchingParameters {
ks_level: DecompositionLevelCount(5),
ks_base_log: DecompositionBaseLog(3),
destination_key: EncryptionKeyChoice::Small,
};

// Parameters to keyswitch from input PKE 2_2 TUniform parameters to 2_2 KS_PBS compute parameters
// arriving under the big key, requires a PBS to get to the big key
pub const PARAM_KEYSWITCH_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64:
ShortintKeySwitchingParameters = ShortintKeySwitchingParameters {
ks_level: DecompositionLevelCount(1),
ks_base_log: DecompositionBaseLog(27),
destination_key: EncryptionKeyChoice::Big,
};
Loading