Skip to content

Commit

Permalink
feat(integer): add count_ones/zeros
Browse files Browse the repository at this point in the history
The non naive version made for 2_2 parameters
only bring slight (10-15%) for some small sizes like (64, 128, 256 bits)
but reduces number of PBS. The place where it brings the best
improvements it for very large numbers (e.g 6400 blocks 1.8s for naive,
1.1 sec for non-naive)
  • Loading branch information
tmontaigu committed Sep 2, 2024
1 parent aa2b274 commit 1e2d237
Show file tree
Hide file tree
Showing 11 changed files with 1,047 additions and 7 deletions.
4 changes: 4 additions & 0 deletions tfhe/benches/integer/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,8 @@ define_server_key_bench_unary_default_fn!(method_name: leading_ones_parallelized
define_server_key_bench_unary_default_fn!(method_name: trailing_zeros_parallelized, display_name: trailing_zeros);
define_server_key_bench_unary_default_fn!(method_name: trailing_ones_parallelized, display_name: trailing_ones);
define_server_key_bench_unary_default_fn!(method_name: ilog2_parallelized, display_name: ilog2);
define_server_key_bench_unary_default_fn!(method_name: count_ones_parallelized, display_name: count_ones);
define_server_key_bench_unary_default_fn!(method_name: count_zeros_parallelized, display_name: count_zeros);
define_server_key_bench_unary_default_fn!(method_name: checked_ilog2_parallelized, display_name: checked_ilog2);

define_server_key_bench_unary_default_fn!(method_name: unchecked_abs_parallelized, display_name: abs);
Expand Down Expand Up @@ -2227,6 +2229,8 @@ criterion_group!(
trailing_ones_parallelized,
ilog2_parallelized,
checked_ilog2_parallelized,
count_zeros_parallelized,
count_ones_parallelized,
);

criterion_group!(
Expand Down
4 changes: 4 additions & 0 deletions tfhe/benches/integer/signed_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ define_server_key_bench_unary_signed_clean_input_fn!(method_name: leading_ones_p
define_server_key_bench_unary_signed_clean_input_fn!(method_name: trailing_zeros_parallelized, display_name: trailing_zeros);
define_server_key_bench_unary_signed_clean_input_fn!(method_name: trailing_ones_parallelized, display_name: trailing_ones);
define_server_key_bench_unary_signed_clean_input_fn!(method_name: ilog2_parallelized, display_name: ilog2);
define_server_key_bench_unary_signed_clean_input_fn!(method_name: count_zeros_parallelized, display_name: count_zeros);
define_server_key_bench_unary_signed_clean_input_fn!(method_name: count_ones_parallelized, display_name: count_ones);
define_server_key_bench_unary_signed_clean_input_fn!(method_name: checked_ilog2_parallelized, display_name: checked_ilog2);

define_server_key_bench_binary_signed_clean_inputs_fn!(
Expand Down Expand Up @@ -448,6 +450,8 @@ criterion_group!(
trailing_ones_parallelized,
ilog2_parallelized,
checked_ilog2_parallelized,
count_ones_parallelized,
count_zeros_parallelized,
);

criterion_group!(
Expand Down
74 changes: 74 additions & 0 deletions tfhe/src/high_level_api/integers/signed/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,80 @@ where
})
}

/// Returns the number of ones in the binary representation of self.
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheBool, FheInt16};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let clear_a = 0b0000000_0110111i16;
/// let a = FheInt16::encrypt(clear_a, &client_key);
///
/// let result = a.count_ones();
/// let decrypted: u32 = result.decrypt(&client_key);
/// assert_eq!(decrypted, clear_a.count_ones());
/// ```
pub fn count_ones(&self) -> crate::FheUint32 {
global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(cpu_key) => {
let result = cpu_key
.pbs_key()
.count_ones_parallelized(&*self.ciphertext.on_cpu());
let result = cpu_key.pbs_key().cast_to_unsigned(
result,
crate::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()),
);
crate::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support count_ones yet");
}
})
}

/// Returns the number of zeros in the binary representation of self.
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheBool, FheInt16};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let clear_a = 0b0000000_0110111i16;
/// let a = FheInt16::encrypt(clear_a, &client_key);
///
/// let result = a.count_zeros();
/// let decrypted: u32 = result.decrypt(&client_key);
/// assert_eq!(decrypted, clear_a.count_zeros());
/// ```
pub fn count_zeros(&self) -> crate::FheUint32 {
global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(cpu_key) => {
let result = cpu_key
.pbs_key()
.count_zeros_parallelized(&*self.ciphertext.on_cpu());
let result = cpu_key.pbs_key().cast_to_unsigned(
result,
crate::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()),
);
crate::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support count_zeros yet");
}
})
}

/// Returns the base 2 logarithm of the number, rounded down.
///
/// Result has no meaning if self encrypts a value <= 0. See [Self::checked_ilog2]
Expand Down
74 changes: 74 additions & 0 deletions tfhe/src/high_level_api/integers/unsigned/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,80 @@ where
})
}

/// Returns the number of ones in the binary representation of self.
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheBool, FheUint16};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let clear_a = 0b0000000_0110111u16;
/// let a = FheUint16::encrypt(clear_a, &client_key);
///
/// let result = a.count_ones();
/// let decrypted: u32 = result.decrypt(&client_key);
/// assert_eq!(decrypted, clear_a.count_ones());
/// ```
pub fn count_ones(&self) -> super::FheUint32 {
global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(cpu_key) => {
let result = cpu_key
.pbs_key()
.count_ones_parallelized(&*self.ciphertext.on_cpu());
let result = cpu_key.pbs_key().cast_to_unsigned(
result,
super::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()),
);
super::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support count_ones yet");
}
})
}

/// Returns the number of zeros in the binary representation of self.
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheBool, FheUint16};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let clear_a = 0b0000000_0110111u16;
/// let a = FheUint16::encrypt(clear_a, &client_key);
///
/// let result = a.count_zeros();
/// let decrypted: u32 = result.decrypt(&client_key);
/// assert_eq!(decrypted, clear_a.count_zeros());
/// ```
pub fn count_zeros(&self) -> super::FheUint32 {
global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(cpu_key) => {
let result = cpu_key
.pbs_key()
.count_zeros_parallelized(&*self.ciphertext.on_cpu());
let result = cpu_key.pbs_key().cast_to_unsigned(
result,
super::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()),
);
super::FheUint32::new(result)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support count_zeros yet");
}
})
}

/// Returns the base 2 logarithm of the number, rounded down.
///
/// Result has no meaning if self encrypts 0. See [Self::checked_ilog2]
Expand Down
20 changes: 13 additions & 7 deletions tfhe/src/integer/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use super::backward_compatibility::server_key::{CompressedServerKeyVersions, Ser
#[derive(Serialize, Deserialize, Clone, Versionize)]
#[versionize(ServerKeyVersions)]
pub struct ServerKey {
pub(crate) key: crate::shortint::ServerKey,
pub key: crate::shortint::ServerKey,
}

impl From<ServerKey> for crate::shortint::ServerKey {
Expand Down Expand Up @@ -216,19 +216,25 @@ impl ServerKey {
self.key.carry_modulus
}

/// Returns how many blocks a radix ciphertext should have to
/// be able to represent the given unsigned integer
pub fn num_blocks_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
pub fn num_bits_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
where
Clear: UnsignedInteger,
{
let num_bits_in_message = self.message_modulus().0.ilog2();
let num_bits_to_represent_output_value = if clear == Clear::MAX {
if clear == Clear::MAX {
Clear::BITS
} else {
(clear + Clear::ONE).ceil_ilog2() as usize
};
}
}

/// Returns how many blocks a radix ciphertext should have to
/// be able to represent the given unsigned integer
pub fn num_blocks_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
where
Clear: UnsignedInteger,
{
let num_bits_to_represent_output_value = self.num_bits_to_represent_unsigned_value(clear);
let num_bits_in_message = self.message_modulus().0.ilog2();
num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize)
}
}
Expand Down
Loading

0 comments on commit 1e2d237

Please sign in to comment.