Skip to content

Commit

Permalink
feat(optimizer): add support for circuits and composition rules
Browse files Browse the repository at this point in the history
  • Loading branch information
aPere3 committed May 15, 2024
1 parent b96fece commit ac848e9
Show file tree
Hide file tree
Showing 36 changed files with 1,499 additions and 719 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ namespace optimizer {
std::unique_ptr<mlir::Pass> createDagPass(optimizer::Config config,
concrete_optimizer::Dag &dag);

void applyCompositionRules(optimizer::Config config,
concrete_optimizer::Dag &dag);

} // namespace optimizer
} // namespace concretelang
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,21 @@ std::unique_ptr<mlir::Pass> createDagPass(optimizer::Config config,
return std::make_unique<optimizer::DagPass>(config, dag);
}

// Adds the composition rules to the
void applyCompositionRules(optimizer::Config config,
concrete_optimizer::Dag &dag) {

if (config.composable) {
auto inputs = dag.get_input_indices();
auto outputs = dag.get_output_indices();
dag.add_compositions(
rust::Slice<const concrete_optimizer::dag::OperatorIndex>(
outputs.data(), outputs.size()),
rust::Slice<const concrete_optimizer::dag::OperatorIndex>(
inputs.data(), inputs.size()));
}
}

} // namespace optimizer
} // namespace concretelang
} // namespace mlir
3 changes: 3 additions & 0 deletions compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
if (pm.run(module.getOperation()).failed()) {
return StreamStringError() << "Failed to create concrete-optimizer dag\n";
}
optimizer::applyCompositionRules(config, *dag);

std::optional<optimizer::Description> description;

if (!constraint) {
description = std::nullopt;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ concrete_optimizer::Options options_from_config(optimizer::Config config) {
/* .encoding = */ config.encoding,
/* .cache_on_disk = */ config.cache_on_disk,
/* .ciphertext_modulus_log = */ config.ciphertext_modulus_log,
/* .fft_precision = */ config.fft_precision,
/* .composable = */ config.composable};
/* .fft_precision = */ config.fft_precision};
return options;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ func.func @main(%arg0: tensor<5x!FHE.eint<5>>) -> !FHE.eint<5> {
%weights = arith.constant dense<[-1, -1, -1, -1, -1]> : tensor<5xi6>
%tlu = arith.constant dense<[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<32xi64>
%0 = "FHELinalg.apply_lookup_table"(%arg0, %tlu) : (tensor<5x!FHE.eint<5>>, tensor<32xi64>) -> tensor<5x!FHE.eint<5>>
// CHECK: Dot { [[a:.*]], weights: ClearTensor { shape: Shape { dimensions_size: [5] }, values: [-1, -1, -1, -1, -1] } }
// CHECK: Dot { [[a:.*]], weights: ClearTensor { shape: Shape { dimensions_size: [5] }, values: [-1, -1, -1, -1, -1] }, kind: Tensor }
%1 = "FHELinalg.dot_eint_int"(%0, %weights) : (tensor<5x!FHE.eint<5>>, tensor<5xi6>) -> !FHE.eint<5>
return %1 : !FHE.eint<5>
}
Original file line number Diff line number Diff line change
Expand Up @@ -418,15 +418,16 @@ TEST(CompileNotComposable, not_composable_2) {
TestProgram circuit(options);
auto err = circuit.compile(R"XXX(
func.func @main(%arg0: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>) {
%cst_1 = arith.constant 1 : i4
%cst_1 = arith.constant 2 : i4
%cst_2 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
%1 = "FHE.add_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
%1 = "FHE.mul_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
%2 = "FHE.apply_lookup_table"(%1, %cst_2): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>)
return %1, %2: !FHE.eint<3>, !FHE.eint<3>
}
)XXX");
ASSERT_OUTCOME_HAS_FAILURE_WITH_ERRORMSG(
err, "Program can not be composed: Output 1 has variance 1σ²In[0].");
err, "Program can not be composed: Dag is not composable, because of "
"output 1: Partition 0 has input coefficient 4");
}

TEST(CompileComposable, composable_supported_dag_mono) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
ciphertext_modulus_log,
fft_precision,
complexity_model: &CpuComplexity::default(),
composable: false,
};

let cache = decomposition::cache(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
ciphertext_modulus_log,
fft_precision,
complexity_model: &CpuComplexity::default(),
composable: false,
};

let cache = decomposition::cache(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ fn optimize_bootstrap(precision: u64, noise_factor: f64, options: ffi::Options)
ciphertext_modulus_log: options.ciphertext_modulus_log,
fft_precision: options.fft_precision,
complexity_model: &CpuComplexity::default(),
composable: options.composable,
};

let sum_size = 1;
Expand Down Expand Up @@ -489,6 +488,20 @@ impl Dag {
self.0.viz_string()
}

fn get_input_indices(&self) -> Vec<ffi::OperatorIndex> {
self.0
.get_input_operators_iter()
.map(|n| ffi::OperatorIndex { index: n.id.0 })
.collect()
}

fn get_output_indices(&self) -> Vec<ffi::OperatorIndex> {
self.0
.get_output_operators_iter()
.map(|n| ffi::OperatorIndex { index: n.id.0 })
.collect()
}

fn optimize(&self, options: ffi::Options) -> ffi::DagSolution {
let processing_unit = processing_unit(options);
let config = Config {
Expand All @@ -498,13 +511,12 @@ impl Dag {
ciphertext_modulus_log: options.ciphertext_modulus_log,
fft_precision: options.fft_precision,
complexity_model: &CpuComplexity::default(),
composable: options.composable,
};

let search_space = SearchSpace::default(processing_unit);

let encoding = options.encoding.into();
if options.composable {
if self.0.is_composed() {
let circuit_sol =
concrete_optimizer::optimization::dag::multi_parameters::optimize_generic::optimize(
&self.0,
Expand Down Expand Up @@ -535,6 +547,18 @@ impl Dag {
self.0.get_circuit_count()
}

fn add_compositions(&mut self, froms: &[ffi::OperatorIndex], tos: &[ffi::OperatorIndex]) {
self.0.add_compositions(
froms
.iter()
.map(|a| OperatorIndex(a.index))
.collect::<Vec<_>>(),
tos.iter()
.map(|a| OperatorIndex(a.index))
.collect::<Vec<_>>(),
);
}

fn optimize_multi(&self, options: ffi::Options) -> ffi::CircuitSolution {
let processing_unit = processing_unit(options);
let config = Config {
Expand All @@ -544,7 +568,6 @@ impl Dag {
ciphertext_modulus_log: options.ciphertext_modulus_log,
fft_precision: options.fft_precision,
complexity_model: &CpuComplexity::default(),
composable: options.composable,
};
let search_space = SearchSpace::default(processing_unit);

Expand Down Expand Up @@ -763,6 +786,8 @@ mod ffi {

fn optimize(self: &Dag, options: Options) -> DagSolution;

fn add_compositions(self: &mut Dag, froms: &[OperatorIndex], tos: &[OperatorIndex]);

#[namespace = "concrete_optimizer::dag"]
fn dump(self: &CircuitSolution) -> String;

Expand All @@ -781,6 +806,10 @@ mod ffi {

fn optimize_multi(self: &Dag, options: Options) -> CircuitSolution;

fn get_input_indices(self: &Dag) -> Vec<OperatorIndex>;

fn get_output_indices(self: &Dag) -> Vec<OperatorIndex>;

fn NO_KEY_ID() -> u64;
}

Expand Down Expand Up @@ -857,7 +886,6 @@ mod ffi {
pub cache_on_disk: bool,
pub ciphertext_modulus_log: u32,
pub fft_precision: u32,
pub composable: bool,
}

#[namespace = "concrete_optimizer::dag"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -972,8 +972,11 @@ struct Dag final : public ::rust::Opaque {
::rust::Box<::concrete_optimizer::DagBuilder> builder(::rust::String circuit) noexcept;
::rust::String dump() const noexcept;
::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept;
void add_compositions(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept;
::std::size_t get_circuit_count() const noexcept;
::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept;
::rust::Vec<::concrete_optimizer::dag::OperatorIndex> get_input_indices() const noexcept;
::rust::Vec<::concrete_optimizer::dag::OperatorIndex> get_output_indices() const noexcept;
~Dag() = delete;

private:
Expand Down Expand Up @@ -1111,7 +1114,6 @@ struct Options final {
bool cache_on_disk;
::std::uint32_t ciphertext_modulus_log;
::std::uint32_t fft_precision;
bool composable;

using IsRelocatable = ::std::true_type;
};
Expand Down Expand Up @@ -1315,6 +1317,8 @@ ::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilde
void concrete_optimizer$cxxbridge1$DagBuilder$tag_operator_as_output(::concrete_optimizer::DagBuilder &self, ::concrete_optimizer::dag::OperatorIndex op) noexcept;

void concrete_optimizer$cxxbridge1$Dag$optimize(::concrete_optimizer::Dag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::DagSolution *return$) noexcept;

void concrete_optimizer$cxxbridge1$Dag$add_compositions(::concrete_optimizer::Dag &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept;
} // extern "C"

namespace dag {
Expand Down Expand Up @@ -1343,6 +1347,10 @@ ::std::size_t concrete_optimizer$cxxbridge1$Dag$get_circuit_count(::concrete_opt

void concrete_optimizer$cxxbridge1$Dag$optimize_multi(::concrete_optimizer::Dag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::CircuitSolution *return$) noexcept;

void concrete_optimizer$cxxbridge1$Dag$get_input_indices(::concrete_optimizer::Dag const &self, ::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *return$) noexcept;

void concrete_optimizer$cxxbridge1$Dag$get_output_indices(::concrete_optimizer::Dag const &self, ::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *return$) noexcept;

::std::uint64_t concrete_optimizer$cxxbridge1$NO_KEY_ID() noexcept;
} // extern "C"

Expand Down Expand Up @@ -1438,6 +1446,10 @@ ::concrete_optimizer::dag::DagSolution Dag::optimize(::concrete_optimizer::Optio
return ::std::move(return$.value);
}

void Dag::add_compositions(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept {
concrete_optimizer$cxxbridge1$Dag$add_compositions(*this, froms, tos);
}

namespace dag {
::rust::String CircuitSolution::dump() const noexcept {
::rust::MaybeUninit<::rust::String> return$;
Expand Down Expand Up @@ -1480,6 +1492,18 @@ ::concrete_optimizer::dag::CircuitSolution Dag::optimize_multi(::concrete_optimi
return ::std::move(return$.value);
}

::rust::Vec<::concrete_optimizer::dag::OperatorIndex> Dag::get_input_indices() const noexcept {
::rust::MaybeUninit<::rust::Vec<::concrete_optimizer::dag::OperatorIndex>> return$;
concrete_optimizer$cxxbridge1$Dag$get_input_indices(*this, &return$.value);
return ::std::move(return$.value);
}

::rust::Vec<::concrete_optimizer::dag::OperatorIndex> Dag::get_output_indices() const noexcept {
::rust::MaybeUninit<::rust::Vec<::concrete_optimizer::dag::OperatorIndex>> return$;
concrete_optimizer$cxxbridge1$Dag$get_output_indices(*this, &return$.value);
return ::std::move(return$.value);
}

::std::uint64_t NO_KEY_ID() noexcept {
return concrete_optimizer$cxxbridge1$NO_KEY_ID();
}
Expand All @@ -1498,6 +1522,15 @@ ::concrete_optimizer::Weights *cxxbridge1$box$concrete_optimizer$Weights$alloc()
void cxxbridge1$box$concrete_optimizer$Weights$dealloc(::concrete_optimizer::Weights *) noexcept;
void cxxbridge1$box$concrete_optimizer$Weights$drop(::rust::Box<::concrete_optimizer::Weights> *ptr) noexcept;

void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$new(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> const *ptr) noexcept;
void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$drop(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *ptr) noexcept;
::std::size_t cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$len(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> const *ptr) noexcept;
::std::size_t cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$capacity(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> const *ptr) noexcept;
::concrete_optimizer::dag::OperatorIndex const *cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$data(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> const *ptr) noexcept;
void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$reserve_total(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *ptr, ::std::size_t new_cap) noexcept;
void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$set_len(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *ptr, ::std::size_t len) noexcept;
void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$truncate(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *ptr, ::std::size_t len) noexcept;

void cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$new(::rust::Vec<::concrete_optimizer::dag::SecretLweKey> const *ptr) noexcept;
void cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$drop(::rust::Vec<::concrete_optimizer::dag::SecretLweKey> *ptr) noexcept;
::std::size_t cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$len(::rust::Vec<::concrete_optimizer::dag::SecretLweKey> const *ptr) noexcept;
Expand Down Expand Up @@ -1601,6 +1634,38 @@ void Box<::concrete_optimizer::Weights>::drop() noexcept {
cxxbridge1$box$concrete_optimizer$Weights$drop(this);
}
template <>
Vec<::concrete_optimizer::dag::OperatorIndex>::Vec() noexcept {
cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$new(this);
}
template <>
void Vec<::concrete_optimizer::dag::OperatorIndex>::drop() noexcept {
return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$drop(this);
}
template <>
::std::size_t Vec<::concrete_optimizer::dag::OperatorIndex>::size() const noexcept {
return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$len(this);
}
template <>
::std::size_t Vec<::concrete_optimizer::dag::OperatorIndex>::capacity() const noexcept {
return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$capacity(this);
}
template <>
::concrete_optimizer::dag::OperatorIndex const *Vec<::concrete_optimizer::dag::OperatorIndex>::data() const noexcept {
return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$data(this);
}
template <>
void Vec<::concrete_optimizer::dag::OperatorIndex>::reserve_total(::std::size_t new_cap) noexcept {
return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$reserve_total(this, new_cap);
}
template <>
void Vec<::concrete_optimizer::dag::OperatorIndex>::set_len(::std::size_t len) noexcept {
return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$set_len(this, len);
}
template <>
void Vec<::concrete_optimizer::dag::OperatorIndex>::truncate(::std::size_t len) {
return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$truncate(this, len);
}
template <>
Vec<::concrete_optimizer::dag::SecretLweKey>::Vec() noexcept {
cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$new(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -953,8 +953,11 @@ struct Dag final : public ::rust::Opaque {
::rust::Box<::concrete_optimizer::DagBuilder> builder(::rust::String circuit) noexcept;
::rust::String dump() const noexcept;
::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept;
void add_compositions(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept;
::std::size_t get_circuit_count() const noexcept;
::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept;
::rust::Vec<::concrete_optimizer::dag::OperatorIndex> get_input_indices() const noexcept;
::rust::Vec<::concrete_optimizer::dag::OperatorIndex> get_output_indices() const noexcept;
~Dag() = delete;

private:
Expand Down Expand Up @@ -1092,7 +1095,6 @@ struct Options final {
bool cache_on_disk;
::std::uint32_t ciphertext_modulus_log;
::std::uint32_t fft_precision;
bool composable;

using IsRelocatable = ::std::true_type;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@ concrete_optimizer::Options default_options() {
.cache_on_disk = true,
.ciphertext_modulus_log = CIPHERTEXT_MODULUS_LOG,
.fft_precision = 53,
.composable = false
};
}

#define TEST static void

TEST test_v0() {
auto options = default_options();
options.composable = true;
concrete_optimizer::v0::Solution solution =
concrete_optimizer::v0::optimize_bootstrap(
PRECISION_1B, NOISE_DEVIATION_COEFF, options);
Expand Down Expand Up @@ -261,7 +259,8 @@ TEST test_composable_dag_mono_fallback_on_dag_multi() {
assert(!solution1.use_wop_pbs);
assert(solution1.p_error < options.maximum_acceptable_error_probability);

options.composable = true;
std::vector<concrete_optimizer::dag::OperatorIndex> froms{id};
dag->add_compositions(slice(froms), slice(inputs));
auto solution2 = dag->optimize(options);
assert(!solution2.use_wop_pbs);
assert(solution2.p_error < options.maximum_acceptable_error_probability);
Expand Down Expand Up @@ -298,7 +297,8 @@ TEST test_non_composable_dag_mono_fallback_on_woppbs() {
assert(!solution1.use_wop_pbs);
assert(solution1.p_error < options.maximum_acceptable_error_probability);

options.composable = true;
std::vector<concrete_optimizer::dag::OperatorIndex> froms{id};
dag->add_compositions(slice(froms), slice(inputs));
auto solution2 = dag->optimize(options);
assert(solution2.p_error < options.maximum_acceptable_error_probability);
assert(solution1.complexity < solution2.complexity);
Expand Down
Loading

0 comments on commit ac848e9

Please sign in to comment.