Skip to content

Commit

Permalink
feat(hlapi): bind cuda's trailing/leading_ones/zeros, ilog2
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Sep 2, 2024
1 parent aa2b274 commit 7503b30
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 36 deletions.
82 changes: 64 additions & 18 deletions tfhe/src/high_level_api/integers/signed/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,17 @@ where
crate::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support leading_zeros yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let result = cuda_key
.key
.leading_zeros(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
crate::FheUint32::new(result)
}),
})
}

Expand Down Expand Up @@ -296,9 +304,17 @@ where
crate::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support leading_ones yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let result = cuda_key
.key
.leading_ones(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
crate::FheUint32::new(result)
}),
})
}

Expand Down Expand Up @@ -332,9 +348,17 @@ where
crate::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support trailing_zeros yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let result = cuda_key
.key
.trailing_zeros(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
crate::FheUint32::new(result)
}),
})
}

Expand Down Expand Up @@ -368,9 +392,17 @@ where
crate::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support trailing_ones yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let result = cuda_key
.key
.trailing_ones(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
crate::FheUint32::new(result)
}),
})
}

Expand Down Expand Up @@ -406,9 +438,15 @@ where
crate::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support ilog2 yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let result = cuda_key.key.ilog2(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
crate::FheUint32::new(result)
}),
})
}

Expand Down Expand Up @@ -448,9 +486,17 @@ where
(crate::FheUint32::new(result), FheBool::new(is_ok))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support checked_ilog2 yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let (result, is_ok) = cuda_key
.key
.checked_ilog2(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
(crate::FheUint32::new(result), FheBool::new(is_ok))
}),
})
}

Expand Down
82 changes: 64 additions & 18 deletions tfhe/src/high_level_api/integers/unsigned/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,17 @@ where
super::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support leading_zeros yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let result = cuda_key
.key
.leading_zeros(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
super::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
super::FheUint32::new(result)
}),
})
}

Expand Down Expand Up @@ -398,9 +406,17 @@ where
super::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support leading_ones yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let result = cuda_key
.key
.leading_ones(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
super::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
super::FheUint32::new(result)
}),
})
}

Expand Down Expand Up @@ -434,9 +450,17 @@ where
super::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support trailing_zeros yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let result = cuda_key
.key
.trailing_zeros(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
super::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
super::FheUint32::new(result)
}),
})
}

Expand Down Expand Up @@ -470,9 +494,17 @@ where
super::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support trailing_ones yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let result = cuda_key
.key
.trailing_ones(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
super::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
super::FheUint32::new(result)
}),
})
}

Expand Down Expand Up @@ -508,9 +540,15 @@ where
super::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support ilog2 yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let result = cuda_key.key.ilog2(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
super::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
super::FheUint32::new(result)
}),
})
}

Expand Down Expand Up @@ -550,9 +588,17 @@ where
(super::FheUint32::new(result), FheBool::new(is_ok))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support checked_ilog2 yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let (result, is_ok) = cuda_key
.key
.checked_ilog2(&*self.ciphertext.on_gpu(), streams);
let result = cuda_key.key.cast_to_unsigned(
result,
super::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
streams,
);
(super::FheUint32::new(result), FheBool::new(is_ok))
}),
})
}

Expand Down
24 changes: 24 additions & 0 deletions tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,27 @@ fn test_is_even_is_odd_gpu_multibit() {
let client_key = setup_gpu(Some(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
super::test_case_is_even_is_odd(&client_key);
}

/// Runs the shared `test_case_leading_trailing_zeros_ones` scenario from the
/// parent module against the client key returned by `setup_default_gpu`.
#[test]
fn test_leading_trailing_zeros_ones_gpu() {
    // The key is only needed for the duration of the shared test case.
    super::test_case_leading_trailing_zeros_ones(&setup_default_gpu());
}

/// Same scenario as the default-parameter variant, but the client key is
/// built by `setup_gpu` from the multi-bit
/// `PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS` parameter set.
#[test]
fn test_leading_trailing_zeros_ones_gpu_multibit() {
    let multibit_params = PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;
    let multibit_key = setup_gpu(Some(multibit_params));
    super::test_case_leading_trailing_zeros_ones(&multibit_key);
}

/// Runs the shared `test_case_ilog2` scenario from the parent module against
/// the client key returned by `setup_default_gpu`.
#[test]
fn test_ilog2_gpu() {
    // The key is only needed for the duration of the shared test case.
    super::test_case_ilog2(&setup_default_gpu());
}

/// Runs the shared `test_case_ilog2` scenario with a client key built by
/// `setup_gpu` from the multi-bit
/// `PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS` parameter set.
// NOTE(review): the other multi-bit tests in this file are named
// `*_gpu_multibit`; this one lacks the `_gpu` infix. Consider renaming it if
// no test filter relies on the current name.
#[test]
fn test_ilog2_multibit() {
    let multibit_params = PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;
    let multibit_key = setup_gpu(Some(multibit_params));
    super::test_case_ilog2(&multibit_key);
}

0 comments on commit 7503b30

Please sign in to comment.