Skip to content

Commit

Permalink
Merge pull request #1164 from zama-ai/fix_optimizer_bug_2
Browse files Browse the repository at this point in the history
fix(optimizer): fix performance regression
  • Loading branch information
BourgerieQuentin authored Nov 29, 2024
2 parents 0aac13f + 831e847 commit de14712
Show file tree
Hide file tree
Showing 7 changed files with 293 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ impl VariancedDag {
nb_constraints: out_shape.flat_size(),
safe_variance_bound: max_variance,
noise_expression: variance.clone(),
noise_evaluator: None,
location: op.location.clone(),
};
self.external_variance_constraints.push(constraint);
Expand Down Expand Up @@ -273,6 +274,7 @@ impl VariancedDag {
nb_constraints: out_shape.flat_size(),
safe_variance_bound: max_variance,
noise_expression: variance.clone(),
noise_evaluator: None,
location: dag_op.location.clone(),
};
self.external_variance_constraints.push(constraint);
Expand Down Expand Up @@ -646,6 +648,7 @@ fn variance_constraint(
safe_variance_bound,
nb_partitions,
noise_expression: noise,
noise_evaluator: None,
location,
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{fmt, ops::Add};

use super::{
partitions::PartitionIndex,
symbolic::{fast_keyswitch, keyswitch, Symbol, SymbolMap},
symbolic::{fast_keyswitch, keyswitch, Symbol, SymbolArray, SymbolMap, SymbolScheme},
};

/// A structure storing the number of times an fhe operation gets executed in a circuit.
Expand All @@ -29,36 +29,44 @@ impl fmt::Display for OperationsCount {

/// An ensemble of costs associated with fhe operation symbols.
#[derive(Clone, Debug)]
pub struct ComplexityValues(SymbolMap<f64>);
pub struct ComplexityValues(SymbolArray<f64>);

impl ComplexityValues {
/// Returns an empty set of cost values.
pub fn new() -> Self {
ComplexityValues(SymbolMap::new())
pub fn from_scheme(scheme: &SymbolScheme) -> ComplexityValues {
ComplexityValues(SymbolArray::from_scheme(scheme))
}

/// Sets the cost associated with an fhe operation symbol.
pub fn set_cost(&mut self, source: Symbol, value: f64) {
self.0.set(source, value);
self.0.set(&source, value);
}
}

/// A complexity expression is a sum of complexity terms associating operation
/// symbols with the number of time they gets executed in the circuit.
#[derive(Clone, Debug)]
pub struct ComplexityExpression(SymbolMap<usize>);
pub struct ComplexityEvaluator(SymbolArray<usize>);

impl ComplexityExpression {
impl ComplexityEvaluator {
/// Creates a complexity expression from a set of operation counts.
pub fn from(counts: &OperationsCount) -> Self {
Self(counts.0.clone())
pub fn from_scheme_and_counts(
scheme: &SymbolScheme,
counts: &OperationsCount,
) -> ComplexityEvaluator {
Self(SymbolArray::from_scheme_and_map(scheme, &counts.0))
}

pub fn scheme(&self) -> &SymbolScheme {
self.0.scheme()
}

/// Evaluates the total cost expression on a set of cost values.
pub fn evaluate_total_cost(&self, costs: &ComplexityValues) -> f64 {
self.0.iter().fold(0.0, |acc, (symbol, n_ops)| {
acc + (n_ops as f64) * costs.0.get(symbol)
})
self.0
.iter()
.zip(costs.0.iter())
.fold(0.0, |acc, (n_ops, cost)| acc + (*n_ops as f64) * *cost)
}

/// Evaluates the max ks cost expression on a set of cost values.
Expand All @@ -69,11 +77,11 @@ impl ComplexityExpression {
src_partition: PartitionIndex,
dst_partition: PartitionIndex,
) -> f64 {
let actual_ks_cost = costs.0.get(keyswitch(src_partition, dst_partition));
let ks_coeff = self.0.get(keyswitch(src_partition, dst_partition));
let actual_ks_cost = costs.0.get(&keyswitch(src_partition, dst_partition));
let ks_coeff = self.0.get(&keyswitch(src_partition, dst_partition));
let actual_complexity =
self.evaluate_total_cost(costs) - (ks_coeff as f64) * actual_ks_cost;
(complexity_cut - actual_complexity) / (ks_coeff as f64)
self.evaluate_total_cost(costs) - (*ks_coeff as f64) * actual_ks_cost;
(complexity_cut - actual_complexity) / (*ks_coeff as f64)
}

/// Evaluates the max fks cost expression on a set of cost values.
Expand All @@ -84,10 +92,10 @@ impl ComplexityExpression {
src_partition: PartitionIndex,
dst_partition: PartitionIndex,
) -> f64 {
let actual_fks_cost = costs.0.get(fast_keyswitch(src_partition, dst_partition));
let fks_coeff = self.0.get(fast_keyswitch(src_partition, dst_partition));
let actual_fks_cost = costs.0.get(&fast_keyswitch(src_partition, dst_partition));
let fks_coeff = self.0.get(&fast_keyswitch(src_partition, dst_partition));
let actual_complexity =
self.evaluate_total_cost(costs) - (fks_coeff as f64) * actual_fks_cost;
(complexity_cut - actual_complexity) / (fks_coeff as f64)
self.evaluate_total_cost(costs) - (*fks_coeff as f64) * actual_fks_cost;
(complexity_cut - actual_complexity) / (*fks_coeff as f64)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,18 @@ impl Feasible {

for constraint in &self.undominated_constraints {
let pbs_coeff = constraint
.noise_expression
.noise_evaluator
.as_ref()
.unwrap()
.coeff(bootstrap_noise(partition));
if pbs_coeff == 0.0 {
continue;
}
let actual_variance = constraint.noise_expression.evaluate(operations_variance)
let actual_variance = constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance)
- pbs_coeff * actual_pbs_variance;
let pbs_max_variance = (constraint.safe_variance_bound - actual_variance) / pbs_coeff;
smallest_pbs_max_variance = smallest_pbs_max_variance.min(pbs_max_variance);
Expand All @@ -75,12 +81,18 @@ impl Feasible {

for constraint in &self.undominated_constraints {
let ks_coeff = constraint
.noise_expression
.noise_evaluator
.as_ref()
.unwrap()
.coeff(keyswitch_noise(src_partition, dst_partition));
if ks_coeff == 0.0 {
continue;
}
let actual_variance = constraint.noise_expression.evaluate(operations_variance)
let actual_variance = constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance)
- ks_coeff * actual_ks_variance;
let ks_max_variance = (constraint.safe_variance_bound - actual_variance) / ks_coeff;
smallest_ks_max_variance = smallest_ks_max_variance.min(ks_max_variance);
Expand All @@ -102,12 +114,18 @@ impl Feasible {

for constraint in &self.undominated_constraints {
let fks_coeff = constraint
.noise_expression
.noise_evaluator
.as_ref()
.unwrap()
.coeff(fast_keyswitch_noise(src_partition, dst_partition));
if fks_coeff == 0.0 {
continue;
}
let actual_variance = constraint.noise_expression.evaluate(operations_variance)
let actual_variance = constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance)
- fks_coeff * actual_fks_variance;
let fks_max_variance = (constraint.safe_variance_bound - actual_variance) / fks_coeff;
smallest_fks_max_variance = smallest_fks_max_variance.min(fks_max_variance);
Expand All @@ -126,7 +144,11 @@ impl Feasible {

fn local_feasible(&self, operations_variance: &NoiseValues) -> bool {
for constraint in &self.undominated_constraints {
if constraint.noise_expression.evaluate(operations_variance)
if constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance)
> constraint.safe_variance_bound
{
return false;
Expand All @@ -148,7 +170,11 @@ impl Feasible {
let mut worst_relative_variance = 0.0;
let mut worst_variance = 0.0;
for constraint in &self.undominated_constraints {
let variance = constraint.noise_expression.evaluate(operations_variance);
let variance = constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance);
let relative_variance = variance / constraint.safe_variance_bound;
if relative_variance > worst_relative_variance {
worst_relative_variance = relative_variance;
Expand All @@ -167,7 +193,11 @@ impl Feasible {
fn global_p_error_with_cut(&self, operations_variance: &NoiseValues, cut: f64) -> Option<f64> {
let mut global_p_error = 0.0;
for constraint in &self.constraints {
let variance = constraint.noise_expression.evaluate(operations_variance);
let variance = constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance);
let relative_variance = variance / constraint.safe_variance_bound;
let p_error = p_error_from_relative_variance(relative_variance, self.kappa);
global_p_error = combine_errors(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,27 @@ use std::{

use super::{
partitions::PartitionIndex,
symbolic::{Symbol, SymbolMap},
symbolic::{Symbol, SymbolArray, SymbolMap, SymbolScheme},
};

/// An ensemble of noise values for fhe operations.
#[derive(Debug, Clone, PartialEq)]
pub struct NoiseValues(SymbolMap<f64>);
pub struct NoiseValues(SymbolArray<f64>);

impl NoiseValues {
/// Returns an empty set of noise values.
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
NoiseValues(SymbolMap::new())
pub fn from_scheme(scheme: &SymbolScheme) -> NoiseValues {
NoiseValues(SymbolArray::from_scheme(scheme))
}

/// Sets the noise variance associated with a noise source.
pub fn set_variance(&mut self, source: NoiseSource, value: f64) {
self.0.set(source.0, value);
self.0.set(&source.0, value);
}

/// Returns the variance associated with a noise source
pub fn variance(&self, source: NoiseSource) -> f64 {
self.0.get(source.0)
*self.0.get(&source.0)
}
}

Expand All @@ -36,10 +35,36 @@ impl Display for NoiseValues {
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct NoiseEvaluator(SymbolArray<f64>);

impl NoiseEvaluator {
/// Returns a zero noise expression
pub fn from_scheme_and_expression(
scheme: &SymbolScheme,
expr: &NoiseExpression,
) -> NoiseEvaluator {
NoiseEvaluator(SymbolArray::from_scheme_and_map(scheme, &expr.0))
}

/// Returns the coefficient associated with a noise source.
pub fn coeff(&self, source: NoiseSource) -> f64 {
*self.0.get(&source.0)
}

/// Evaluate the noise expression on a set of noise values.
pub fn evaluate(&self, values: &NoiseValues) -> f64 {
self.0
.iter()
.zip(values.0.iter())
.fold(0.0, |acc, (coef, var)| acc + coef * var)
}
}

/// A noise expression, i.e. a sum of noise terms associating a noise source,
/// with a multiplicative coefficient.
#[derive(Debug, Clone, PartialEq)]
pub struct NoiseExpression(SymbolMap<f64>);
pub struct NoiseExpression(pub SymbolMap<f64>);

impl NoiseExpression {
/// Returns a zero noise expression
Expand Down Expand Up @@ -70,12 +95,12 @@ impl NoiseExpression {
lhs
}

/// Evaluate the noise expression on a set of noise values.
pub fn evaluate(&self, values: &NoiseValues) -> f64 {
self.terms_iter().fold(0.0, |acc, term| {
acc + term.coefficient * values.variance(term.source)
})
}
// /// Evaluate the noise expression on a set of noise values.
// pub fn evaluate(&self, values: &NoiseValues) -> f64 {
// self.terms_iter().fold(0.0, |acc, term| {
// acc + term.coefficient * values.variance(term.source)
// })
// }
}

impl Display for NoiseExpression {
Expand Down Expand Up @@ -196,7 +221,7 @@ impl Mul<NoiseSource> for f64 {

/// A symbolic source of noise, or a noise source variable.
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub struct NoiseSource(Symbol);
pub struct NoiseSource(pub Symbol);

/// Returns an input noise source symbol.
pub fn input_noise(partition: PartitionIndex) -> NoiseSource {
Expand Down
Loading

0 comments on commit de14712

Please sign in to comment.