Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: use strong types for outputs of DispersionParameters trait fns #1845

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn lwe_encrypt_decrypt_noise_distribution_custom_mod<Scalar: UnsignedTorus + Cas
let message_modulus_log = params.message_modulus_log;
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);

let expected_variance = Variance(lwe_noise_distribution.gaussian_std_dev().get_variance());
let expected_variance = lwe_noise_distribution.gaussian_std_dev().get_variance();

let mut rsc = TestResources::new();

Expand Down Expand Up @@ -93,7 +93,7 @@ fn lwe_compact_public_key_encryption_expected_variance(
lwe_dimension: LweDimension,
) -> Variance {
let input_variance = input_noise.get_variance();
Variance(input_variance * (lwe_dimension.to_lwe_size().0 as f64))
Variance(input_variance.0 * (lwe_dimension.to_lwe_size().0 as f64))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have a Mul<Rhs = usize> for 'Variance'

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not want to update the usability, just that we return strong types, this can be done in a follow up

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

#[test]
Expand All @@ -104,7 +104,8 @@ fn test_variance_increase_cpk_formula() {
);

assert!(
(predicted_variance.get_standard_dev().log2() - 44.000704097196405f64).abs() < f64::EPSILON
(predicted_variance.get_standard_dev().0.log2() - 44.000704097196405f64).abs()
< f64::EPSILON
);
}

Expand All @@ -119,7 +120,7 @@ fn lwe_compact_public_encrypt_noise_distribution_custom_mod<
let message_modulus_log = params.message_modulus_log;
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);

let glwe_variance = Variance(glwe_noise_distribution.gaussian_std_dev().get_variance());
let glwe_variance = glwe_noise_distribution.gaussian_std_dev().get_variance();

let expected_variance =
lwe_compact_public_key_encryption_expected_variance(glwe_variance, lwe_dimension);
Expand Down Expand Up @@ -208,7 +209,7 @@ fn random_noise_roundtrip<Scalar: UnsignedTorus + CastInto<usize>>(

assert!(matches!(noise, DynamicDistribution::Gaussian(_)));

let expected_variance = Variance(noise.gaussian_std_dev().get_variance());
let expected_variance = noise.gaussian_std_dev().get_variance();

let num_outputs = 100_000;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn lwe_encrypt_ks_decrypt_noise_distribution_custom_mod<Scalar: UnsignedTorus +
ciphertext_modulus.get_custom_modulus() as f64
};

let encryption_variance = Variance(glwe_noise_distribution.gaussian_std_dev().get_variance());
let encryption_variance = glwe_noise_distribution.gaussian_std_dev().get_variance();
let expected_variance = Variance(
encryption_variance.0
+ keyswitch_additive_variance_132_bits_security_gaussian(
Expand Down
180 changes: 119 additions & 61 deletions tfhe/src/core_crypto/commons/dispersion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,24 @@ use crate::core_crypto::backward_compatibility::commons::dispersion::StandardDev
// Clone because f64 is itself Copy and stored in register.
pub trait DispersionParameter: Copy {
/// Return the standard deviation of the distribution, i.e. $\sigma = 2^p$.
fn get_standard_dev(&self) -> f64;
fn get_standard_dev(&self) -> StandardDev;
/// Return the variance of the distribution, i.e. $\sigma^2 = 2^{2p}$.
fn get_variance(&self) -> f64;
fn get_variance(&self) -> Variance;
/// Return base 2 logarithm of the standard deviation of the distribution, i.e.
/// $\log\_2(\sigma)=p$
fn get_log_standard_dev(&self) -> f64;
fn get_log_standard_dev(&self) -> LogStandardDev;
/// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{q-p}$.
fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64;
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev;

/// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{2(q-p)}$.
fn get_modular_variance(&self, log2_modulus: u32) -> f64;
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance;

/// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $q-p$.
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64;
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev;
}

fn log2_modulus_to_modulus(log2_modulus: u32) -> f64 {
2.0f64.powi(log2_modulus as i32)
}

/// A distribution parameter that uses the base-2 logarithm of the standard deviation as
Expand All @@ -49,22 +53,31 @@ pub trait DispersionParameter: Copy {
/// ```rust
/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, LogStandardDev};
/// let params = LogStandardDev::from_log_standard_dev(-25.);
/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev(), -25.);
/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2));
/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.));
/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.);
/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
/// params.get_modular_variance(32),
/// params.get_modular_standard_dev(32).value,
/// 2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.);
/// assert_eq!(
/// params.get_modular_variance(32).value,
/// 2_f64.powf(32. - 25.).powi(2)
/// );
///
/// let modular_params = LogStandardDev::from_modular_log_standard_dev(22., 32);
/// assert_eq!(modular_params.get_standard_dev(), 2_f64.powf(-10.));
/// assert_eq!(modular_params.get_standard_dev().0, 2_f64.powf(-10.));
/// ```
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct LogStandardDev(pub f64);

#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct ModularLogStandardDev {
pub value: f64,
pub modulus: f64,
}

impl LogStandardDev {
pub fn from_log_standard_dev(log_std: f64) -> Self {
Self(log_std)
Expand All @@ -76,23 +89,32 @@ impl LogStandardDev {
}

impl DispersionParameter for LogStandardDev {
fn get_standard_dev(&self) -> f64 {
f64::powf(2., self.0)
fn get_standard_dev(&self) -> StandardDev {
StandardDev(f64::powf(2., self.0))
}
fn get_variance(&self) -> f64 {
f64::powf(2., self.0 * 2.)
fn get_variance(&self) -> Variance {
Variance(f64::powf(2., self.0 * 2.))
}
fn get_log_standard_dev(&self) -> f64 {
self.0
fn get_log_standard_dev(&self) -> Self {
Self(self.0)
}
fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 {
f64::powf(2., log2_modulus as f64 + self.0)
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev {
ModularStandardDev {
value: f64::powf(2., log2_modulus as f64 + self.0),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_variance(&self, log2_modulus: u32) -> f64 {
f64::powf(2., (log2_modulus as f64 + self.0) * 2.)
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance {
ModularVariance {
value: f64::powf(2., (log2_modulus as f64 + self.0) * 2.),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 {
log2_modulus as f64 + self.0
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev {
ModularLogStandardDev {
value: log2_modulus as f64 + self.0,
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
}

Expand All @@ -103,20 +125,29 @@ impl DispersionParameter for LogStandardDev {
/// ```rust
/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, StandardDev};
/// let params = StandardDev::from_standard_dev(2_f64.powf(-25.));
/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev(), -25.);
/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2));
/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.));
/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.);
/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
/// params.get_modular_standard_dev(32).value,
/// 2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.);
/// assert_eq!(
/// params.get_modular_variance(32),
/// params.get_modular_variance(32).value,
/// 2_f64.powf(32. - 25.).powi(2)
/// );
/// ```
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Serialize, Deserialize, Versionize)]
#[versionize(StandardDevVersions)]
pub struct StandardDev(pub f64);

#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct ModularStandardDev {
pub value: f64,
pub modulus: f64,
}

impl StandardDev {
pub fn from_standard_dev(std: f64) -> Self {
Self(std)
Expand All @@ -128,23 +159,32 @@ impl StandardDev {
}

impl DispersionParameter for StandardDev {
fn get_standard_dev(&self) -> f64 {
self.0
fn get_standard_dev(&self) -> Self {
Self(self.0)
}
fn get_variance(&self) -> f64 {
self.0.powi(2)
fn get_variance(&self) -> Variance {
Variance(self.0.powi(2))
}
fn get_log_standard_dev(&self) -> f64 {
self.0.log2()
fn get_log_standard_dev(&self) -> LogStandardDev {
LogStandardDev(self.0.log2())
}
fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 {
2_f64.powf(log2_modulus as f64 + self.0.log2())
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev {
ModularStandardDev {
value: 2_f64.powf(log2_modulus as f64 + self.0.log2()),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_variance(&self, log2_modulus: u32) -> f64 {
2_f64.powf(2. * (log2_modulus as f64 + self.0.log2()))
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance {
ModularVariance {
value: 2_f64.powf(2. * (log2_modulus as f64 + self.0.log2())),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 {
log2_modulus as f64 + self.0.log2()
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev {
ModularLogStandardDev {
value: log2_modulus as f64 + self.0.log2(),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
}

Expand All @@ -155,19 +195,28 @@ impl DispersionParameter for StandardDev {
/// ```rust
/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, Variance};
/// let params = Variance::from_variance(2_f64.powi(-50));
/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev(), -25.);
/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2));
/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.));
/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.);
/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
/// params.get_modular_standard_dev(32).value,
/// 2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.);
/// assert_eq!(
/// params.get_modular_variance(32),
/// params.get_modular_variance(32).value,
/// 2_f64.powf(32. - 25.).powi(2)
/// );
/// ```
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct Variance(pub f64);

#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct ModularVariance {
pub value: f64,
pub modulus: f64,
}

impl Variance {
pub fn from_variance(var: f64) -> Self {
Self(var)
Expand All @@ -179,22 +228,31 @@ impl Variance {
}

impl DispersionParameter for Variance {
fn get_standard_dev(&self) -> f64 {
self.0.sqrt()
fn get_standard_dev(&self) -> StandardDev {
StandardDev(self.0.sqrt())
}
fn get_variance(&self) -> f64 {
self.0
fn get_variance(&self) -> Self {
Self(self.0)
}
fn get_log_standard_dev(&self) -> f64 {
self.0.sqrt().log2()
fn get_log_standard_dev(&self) -> LogStandardDev {
LogStandardDev(self.0.sqrt().log2())
}
fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 {
2_f64.powf(log2_modulus as f64 + self.0.sqrt().log2())
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev {
ModularStandardDev {
value: 2_f64.powf(log2_modulus as f64 + self.0.sqrt().log2()),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_variance(&self, log2_modulus: u32) -> f64 {
2_f64.powf(2. * (log2_modulus as f64 + self.0.sqrt().log2()))
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance {
ModularVariance {
value: 2_f64.powf(2. * (log2_modulus as f64 + self.0.sqrt().log2())),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 {
log2_modulus as f64 + self.0.sqrt().log2()
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev {
ModularLogStandardDev {
value: log2_modulus as f64 + self.0.sqrt().log2(),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
}
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/commons/math/random/gaussian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl Gaussian<f64> {

pub fn from_dispersion_parameter(dispersion: impl DispersionParameter, mean: f64) -> Self {
Self {
std: dispersion.get_standard_dev(),
std: dispersion.get_standard_dev().0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could remove the generic and make std a StandardDev

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really as it would mean making StandardDev a generic, so I did not touch the design, just updated the functions to be more type safe

mean,
}
}
Expand Down
4 changes: 1 addition & 3 deletions tfhe/src/core_crypto/commons/math/random/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,7 @@ impl<T: UnsignedInteger> DynamicDistribution<T> {
#[track_caller]
pub fn gaussian_variance(&self) -> Variance {
match self {
Self::Gaussian(gaussian) => {
Variance(StandardDev::from_standard_dev(gaussian.std).get_variance())
}
Self::Gaussian(gaussian) => StandardDev::from_standard_dev(gaussian.std).get_variance(),
Self::TUniform(_) => {
panic!("Tried to get gaussian variance from a non gaussian distribution")
}
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/core_crypto/commons/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ pub mod test_tools {
{
for (x, y) in first.as_ref().iter().zip(second.as_ref().iter()) {
println!("{:?}, {:?}", *x, *y);
println!("{}", dist.get_standard_dev());
println!("{:?}", dist.get_standard_dev());
let distance: f64 = modular_distance(*x, *y).cast_into();
let torus_distance = distance / 2_f64.powi(Element::BITS as i32);
assert!(
torus_distance <= 5. * dist.get_standard_dev(),
torus_distance <= 5. * dist.get_standard_dev().0,
"{x} != {y} "
);
}
Expand Down
Loading