From 87b9f0e5dbce99c254e618d032a7bd17f3fa1c34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Thu, 14 Nov 2024 11:24:26 +0100 Subject: [PATCH] feat(optimizer): add generic keyset info generation --- .../include/concretelang/Common/Keysets.h | 5 + .../lib/Bindings/Python/CompilerAPIModule.cpp | 86 +++++++++ .../compiler/lib/Common/Keysets.cpp | 94 ++++++++++ .../src/concrete-optimizer.rs | 30 ++++ .../src/cpp/concrete-optimizer.cpp | 72 ++++++++ .../src/cpp/concrete-optimizer.hpp | 17 ++ .../multi_parameters/generic_generation.rs | 169 ++++++++++++++++++ .../optimization/dag/multi_parameters/mod.rs | 1 + .../dag/multi_parameters/partitions.rs | 2 +- .../tests/compilation/test_restrictions.py | 32 +++- 10 files changed, 506 insertions(+), 2 deletions(-) create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h index 465ee76fd5..b6d47a20b6 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h @@ -6,6 +6,7 @@ #ifndef CONCRETELANG_COMMON_KEYSETS_H #define CONCRETELANG_COMMON_KEYSETS_H +#include "concrete-optimizer.hpp" #include "concrete-protocol.capnp.h" #include "concretelang/Common/Csprng.h" #include "concretelang/Common/Error.h" @@ -92,6 +93,10 @@ class KeysetCache { KeysetCache() = default; }; +Message generate_generic_keyset_info( + std::vector partitions, + bool generate_fks); + } // namespace keysets } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index c1183a8128..86ea598963 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -18,6 +18,7 @@ #include "concretelang/Support/Error.h" #include "concretelang/Support/V0Parameters.h" #include "concretelang/Support/logging.h" +#include #include #include #include @@ -645,6 +646,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( output.append(")"); return output; } + + bool operator==(LweSecretKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left == right; + } + + bool operator!=(LweSecretKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left != right; + } }; pybind11::class_(m, "LweSecretKeyParam") .def( @@ -659,6 +672,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](pybind11::object key) { return pybind11::hash(pybind11::repr(key)); }) + .def(pybind11::self == pybind11::self) + .def(pybind11::self != pybind11::self) .doc() = "Parameters of an LWE Secret Key."; // ------------------------------------------------------------------------------// @@ -689,6 +704,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( output.append(")"); return output; } + + bool operator==(BootstrapKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left == right; + } + + bool operator!=(BootstrapKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left != right; + } }; pybind11::class_(m, "BootstrapKeyParam") .def( @@ -745,6 +772,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](pybind11::object key) { return pybind11::hash(pybind11::repr(key)); }) + .def(pybind11::self == pybind11::self) + .def(pybind11::self != pybind11::self) .doc() = "Parameters of a Bootstrap key."; // ------------------------------------------------------------------------------// @@ -766,6 +795,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( output.append(")"); return output; } + + bool operator==(KeyswitchKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left == right; + } + + bool operator!=(KeyswitchKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left != right; + } }; pybind11::class_(m, "KeyswitchKeyParam") .def( @@ -804,6 +845,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](pybind11::object key) { return pybind11::hash(pybind11::repr(key)); }) + .def(pybind11::self == pybind11::self) + .def(pybind11::self != pybind11::self) .doc() = "Parameters of a keyswitch key."; // ------------------------------------------------------------------------------// @@ -834,6 +877,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( output.append(")"); return output; } + + bool operator==(PackingKeyswitchKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left == right; + } + + bool operator!=(PackingKeyswitchKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left != right; + } }; pybind11::class_(m, "PackingKeyswitchKeyParam") .def( @@ -892,13 +947,44 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](pybind11::object key) { return pybind11::hash(pybind11::repr(key)); }) + .def(pybind11::self == pybind11::self) + .def(pybind11::self != pybind11::self) .doc() = "Parameters of a packing keyswitch key."; + // ------------------------------------------------------------------------------// + // PARTITION DEFINITION // + // ------------------------------------------------------------------------------// + // + pybind11::class_( + m, "PartitionDefinition") + .def(init([](uint8_t precision, double norm2) + -> concrete_optimizer::utils::PartitionDefinition { + return concrete_optimizer::utils::PartitionDefinition{precision, + norm2}; + }), + arg("precision"), arg("norm2")) + .doc() = "Definition of a partition (in terms of precision in bits and " + "norm2 in value)."; + // ------------------------------------------------------------------------------// // KEYSET INFO // // ------------------------------------------------------------------------------// typedef Message KeysetInfo; pybind11::class_(m, "KeysetInfo") + .def_static( + "generate_generic", + [](std::vector + partitions, + bool generateFks) -> KeysetInfo { + if (partitions.size() < 2) { + throw std::runtime_error("Need at least two partition defs to " + "generate a generic keyset info."); + } + return ::concretelang::keysets::generate_generic_keyset_info( + partitions, generateFks); + }, + arg("partition_defs"), arg("generate_fks"), + "Generate a generic keyset info for a set of partition definitions") .def( "secret_keys", [](KeysetInfo &keysetInfo) { diff --git a/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp b/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp index 90d7f1d432..dabdcd052e 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp @@ -6,6 +6,7 @@ #include "concretelang/Common/Keysets.h" #include "capnp/message.h" #include "concrete-cpu.h" +#include "concrete-optimizer.hpp" #include "concrete-protocol.capnp.h" #include "concretelang/Common/Csprng.h" #include "concretelang/Common/Error.h" @@ -417,5 +418,98 @@ KeysetCache::getKeyset(const Message &keysetInfo, return std::move(keyset); } +Message generate_generic_keyset_info( + std::vector partitionDefs, + bool generateFks) { + auto output = Message{}; + rust::Vec rustPartitionDefs{}; + for (auto def : partitionDefs) { + rustPartitionDefs.push_back(def); + } + auto parameters = concrete_optimizer::utils::generate_generic_keyset_info( + rustPartitionDefs, generateFks); + + auto skLen = (int)parameters.secret_keys.size(); + auto skBuilder = output.asBuilder().initLweSecretKeys(skLen); + for (int i = 0; i < skLen; i++) { + auto output = Message(); + auto sk = parameters.secret_keys[i]; + output.asBuilder().setId(sk.identifier); + output.asBuilder().getParams().setIntegerPrecision(64); + output.asBuilder().getParams().setLweDimension(sk.polynomial_size * + sk.glwe_dimension); + output.asBuilder().getParams().setKeyType( + ::concreteprotocol::KeyType::BINARY); + skBuilder.setWithCaveats(i, output.asReader()); + } + + auto bskLen = (int)parameters.bootstrap_keys.size(); + auto bskBuilder = output.asBuilder().initLweBootstrapKeys(bskLen); + for (int i = 0; i < bskLen; i++) { + auto output = Message(); + auto bsk = parameters.bootstrap_keys[i]; + output.asBuilder().setId(bsk.identifier); + output.asBuilder().setInputId(bsk.input_key.identifier); + output.asBuilder().setOutputId(bsk.output_key.identifier); + output.asBuilder().getParams().setLevelCount( + bsk.br_decomposition_parameter.level); + output.asBuilder().getParams().setBaseLog( + bsk.br_decomposition_parameter.log2_base); + output.asBuilder().getParams().setGlweDimension( + bsk.output_key.glwe_dimension); + output.asBuilder().getParams().setPolynomialSize( + bsk.output_key.polynomial_size); + output.asBuilder().getParams().setInputLweDimension( + bsk.input_key.polynomial_size); + output.asBuilder().getParams().setIntegerPrecision(64); + output.asBuilder().getParams().setKeyType( + concreteprotocol::KeyType::BINARY); + bskBuilder.setWithCaveats(i, output.asReader()); + } + + auto kskLen = (int)parameters.keyswitch_keys.size(); + auto ckskLen = (int)parameters.conversion_keyswitch_keys.size(); + auto kskBuilder = output.asBuilder().initLweKeyswitchKeys(kskLen + ckskLen); + for (int i = 0; i < kskLen; i++) { + auto output = Message(); + auto ksk = parameters.keyswitch_keys[i]; + output.asBuilder().setId(ksk.identifier); + output.asBuilder().setInputId(ksk.input_key.identifier); + output.asBuilder().setOutputId(ksk.output_key.identifier); + output.asBuilder().getParams().setLevelCount( + ksk.ks_decomposition_parameter.level); + output.asBuilder().getParams().setBaseLog( + ksk.ks_decomposition_parameter.log2_base); + output.asBuilder().getParams().setIntegerPrecision(64); + output.asBuilder().getParams().setInputLweDimension( + ksk.input_key.glwe_dimension * ksk.input_key.polynomial_size); + output.asBuilder().getParams().setOutputLweDimension( + ksk.output_key.glwe_dimension * ksk.output_key.polynomial_size); + output.asBuilder().getParams().setKeyType( + concreteprotocol::KeyType::BINARY); + kskBuilder.setWithCaveats(i, output.asReader()); + } + for (int i = 0; i < ckskLen; i++) { + auto output = Message(); + auto ksk = parameters.conversion_keyswitch_keys[i]; + output.asBuilder().setId(ksk.identifier); + output.asBuilder().setInputId(ksk.input_key.identifier); + output.asBuilder().setOutputId(ksk.output_key.identifier); + output.asBuilder().getParams().setLevelCount( + ksk.ks_decomposition_parameter.level); + output.asBuilder().getParams().setBaseLog( + ksk.ks_decomposition_parameter.log2_base); + output.asBuilder().getParams().setIntegerPrecision(64); + output.asBuilder().getParams().setInputLweDimension( + ksk.input_key.glwe_dimension * ksk.input_key.polynomial_size); + output.asBuilder().getParams().setOutputLweDimension( + ksk.output_key.glwe_dimension * ksk.output_key.polynomial_size); + output.asBuilder().getParams().setKeyType( + concreteprotocol::KeyType::BINARY); + kskBuilder.setWithCaveats(i + kskLen, output.asReader()); + } + return output; +} + } // namespace keysets } // namespace concretelang diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index b195d12424..c76d31eab4 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -11,6 +11,7 @@ use concrete_optimizer::dag::operator::{ }; use concrete_optimizer::dag::unparametrized; use concrete_optimizer::optimization::config::{Config, SearchSpace}; +use concrete_optimizer::optimization::dag::multi_parameters::generic_generation::generate_generic_parameters; use concrete_optimizer::optimization::dag::multi_parameters::keys_spec::CircuitSolution; use concrete_optimizer::optimization::dag::multi_parameters::optimize::{ KeysetRestriction, MacroParameters, NoSearchSpaceRestriction, RangeRestriction, @@ -913,6 +914,22 @@ fn location_from_string(string: &str) -> Box { } } +fn generate_generic_keyset_info( + inputs: Vec, + generate_fks: bool, +) -> ffi::CircuitKeys { + generate_generic_parameters( + inputs + .into_iter() + .map( + |ffi::PartitionDefinition { precision, norm2 }| concrete_optimizer::optimization::dag::multi_parameters::generic_generation::PartitionDefinition { precision, norm2 }, + ) + .collect(), + generate_fks, + ) + .into() +} + pub struct Weights(operator::Weights); fn vector(weights: &[i64]) -> Box { @@ -981,6 +998,12 @@ mod ffi { #[namespace = "concrete_optimizer::utils"] fn location_from_string(string: &str) -> Box; + #[namespace = "concrete_optimizer::utils"] + fn generate_generic_keyset_info( + partitions: Vec, + generate_fks: bool, + ) -> CircuitKeys; + #[namespace = "concrete_optimizer::utils"] fn get_external_partition( name: String, @@ -1359,6 +1382,13 @@ mod ffi { pub struct KeysetRestriction { pub info: KeysetInfo, } + + #[namespace = "concrete_optimizer::utils"] + #[derive(Debug, Clone)] + pub struct PartitionDefinition { + pub precision: u8, + pub norm2: f64, + } } fn processing_unit(options: &ffi::Options) -> ProcessingUnit { diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 6c6c8075cb..fe34c2444c 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -929,6 +929,13 @@ struct operator_new { }; } // namespace detail +template +union ManuallyDrop { + T value; + ManuallyDrop(T &&value) : value(::std::move(value)) {} + ~ManuallyDrop() {} +}; + template union MaybeUninit { T value; @@ -974,6 +981,9 @@ namespace concrete_optimizer { struct KeysetInfo; struct KeysetRestriction; } + namespace utils { + struct PartitionDefinition; + } } namespace concrete_optimizer { @@ -1387,6 +1397,18 @@ struct KeysetRestriction final { #endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetRestriction } // namespace restriction +namespace utils { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +#define CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +struct PartitionDefinition final { + ::std::uint8_t precision; + double norm2; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +} // namespace utils + namespace v0 { extern "C" { ::concrete_optimizer::v0::Solution concrete_optimizer$v0$cxxbridge1$optimize_bootstrap(::std::uint64_t precision, double noise_factor, ::concrete_optimizer::Options const &options) noexcept; @@ -1418,6 +1440,8 @@ ::concrete_optimizer::Location *concrete_optimizer$utils$cxxbridge1$location_unk ::concrete_optimizer::Location *concrete_optimizer$utils$cxxbridge1$location_from_string(::rust::Str string) noexcept; +void concrete_optimizer$utils$cxxbridge1$generate_generic_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *partitions, bool generate_fks, ::CircuitKeys *return$) noexcept; + ::concrete_optimizer::ExternalPartition *concrete_optimizer$utils$cxxbridge1$get_external_partition(::rust::String *name, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t internal_dim, double max_variance, double variance) noexcept; double concrete_optimizer$utils$cxxbridge1$get_noise_br(::concrete_optimizer::Options const &options, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t lwe_dim, ::std::uint64_t pbs_level, ::std::uint64_t pbs_log2_base) noexcept; @@ -1560,6 +1584,13 @@ ::rust::Box<::concrete_optimizer::Location> location_from_string(::rust::Str str return ::rust::Box<::concrete_optimizer::Location>::from_raw(concrete_optimizer$utils$cxxbridge1$location_from_string(string)); } +::CircuitKeys generate_generic_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> partitions, bool generate_fks) noexcept { + ::rust::ManuallyDrop<::rust::Vec<::concrete_optimizer::utils::PartitionDefinition>> partitions$(::std::move(partitions)); + ::rust::MaybeUninit<::CircuitKeys> return$; + concrete_optimizer$utils$cxxbridge1$generate_generic_keyset_info(&partitions$.value, generate_fks, &return$.value); + return ::std::move(return$.value); +} + ::rust::Box<::concrete_optimizer::ExternalPartition> get_external_partition(::rust::String name, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t internal_dim, double max_variance, double variance) noexcept { return ::rust::Box<::concrete_optimizer::ExternalPartition>::from_raw(concrete_optimizer$utils$cxxbridge1$get_external_partition(&name, log2_polynomial_size, glwe_dimension, internal_dim, max_variance, variance)); } @@ -1713,6 +1744,15 @@ ::concrete_optimizer::Location *cxxbridge1$box$concrete_optimizer$Location$alloc void cxxbridge1$box$concrete_optimizer$Location$dealloc(::concrete_optimizer::Location *) noexcept; void cxxbridge1$box$concrete_optimizer$Location$drop(::rust::Box<::concrete_optimizer::Location> *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$new(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$drop(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$len(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> const *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$capacity(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> const *ptr) noexcept; +::concrete_optimizer::utils::PartitionDefinition const *cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$data(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$reserve_total(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *ptr, ::std::size_t new_cap) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$set_len(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *ptr, ::std::size_t len) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$truncate(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *ptr, ::std::size_t len) noexcept; + ::concrete_optimizer::ExternalPartition *cxxbridge1$box$concrete_optimizer$ExternalPartition$alloc() noexcept; void cxxbridge1$box$concrete_optimizer$ExternalPartition$dealloc(::concrete_optimizer::ExternalPartition *) noexcept; void cxxbridge1$box$concrete_optimizer$ExternalPartition$drop(::rust::Box<::concrete_optimizer::ExternalPartition> *ptr) noexcept; @@ -1884,6 +1924,38 @@ void Box<::concrete_optimizer::Location>::drop() noexcept { cxxbridge1$box$concrete_optimizer$Location$drop(this); } template <> +Vec<::concrete_optimizer::utils::PartitionDefinition>::Vec() noexcept { + cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$new(this); +} +template <> +void Vec<::concrete_optimizer::utils::PartitionDefinition>::drop() noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$drop(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::utils::PartitionDefinition>::size() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$len(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::utils::PartitionDefinition>::capacity() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$capacity(this); +} +template <> +::concrete_optimizer::utils::PartitionDefinition const *Vec<::concrete_optimizer::utils::PartitionDefinition>::data() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$data(this); +} +template <> +void Vec<::concrete_optimizer::utils::PartitionDefinition>::reserve_total(::std::size_t new_cap) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$reserve_total(this, new_cap); +} +template <> +void Vec<::concrete_optimizer::utils::PartitionDefinition>::set_len(::std::size_t len) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$set_len(this, len); +} +template <> +void Vec<::concrete_optimizer::utils::PartitionDefinition>::truncate(::std::size_t len) { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$truncate(this, len); +} +template <> ::concrete_optimizer::ExternalPartition *Box<::concrete_optimizer::ExternalPartition>::allocation::alloc() noexcept { return cxxbridge1$box$concrete_optimizer$ExternalPartition$alloc(); } diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 336faa20ad..662493a2c4 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -955,6 +955,9 @@ namespace concrete_optimizer { struct KeysetInfo; struct KeysetRestriction; } + namespace utils { + struct PartitionDefinition; + } } namespace concrete_optimizer { @@ -1368,6 +1371,18 @@ struct KeysetRestriction final { #endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetRestriction } // namespace restriction +namespace utils { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +#define CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +struct PartitionDefinition final { + ::std::uint8_t precision; + double norm2; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +} // namespace utils + namespace v0 { ::concrete_optimizer::v0::Solution optimize_bootstrap(::std::uint64_t precision, double noise_factor, ::concrete_optimizer::Options const &options) noexcept; } // namespace v0 @@ -1381,6 +1396,8 @@ ::rust::Box<::concrete_optimizer::Location> location_unknown() noexcept; ::rust::Box<::concrete_optimizer::Location> location_from_string(::rust::Str string) noexcept; +::CircuitKeys generate_generic_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> partitions, bool generate_fks) noexcept; + ::rust::Box<::concrete_optimizer::ExternalPartition> get_external_partition(::rust::String name, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t internal_dim, double max_variance, double variance) noexcept; double get_noise_br(::concrete_optimizer::Options const &options, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t lwe_dim, ::std::uint64_t pbs_level, ::std::uint64_t pbs_log2_base) noexcept; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs new file mode 100644 index 0000000000..0e0c4fc087 --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs @@ -0,0 +1,169 @@ +use crate::{ + computing_cost::cpu::CpuComplexity, + config::ProcessingUnit, + dag::{ + operator::{FunctionTable, LevelledComplexity, Precision, Shape}, + unparametrized, + }, + optimization::{ + config::{Config, SearchSpace}, + decomposition::{self}, + }, +}; + +use super::{ + keys_spec::{CircuitKeys, ExpandedCircuitKeys}, + optimize::{optimize, NoSearchSpaceRestriction}, + partition_cut::PartitionCut, + PartitionIndex, +}; + +const _4_SIGMA: f64 = 0.000_063_342_483_999_973; + +#[derive(Debug, Clone, PartialEq)] +pub struct PartitionDefinition { + pub precision: Precision, + pub norm2: f64, +} + +impl PartialOrd for PartitionDefinition { + fn partial_cmp(&self, other: &Self) -> Option { + match self.precision.cmp(&other.precision) { + std::cmp::Ordering::Equal => self.norm2.partial_cmp(&other.norm2), + ordering => Some(ordering), + } + } +} + +pub fn generate_generic_parameters( + partitions: Vec, + generate_fks: bool, +) -> CircuitKeys { + let mut dag = unparametrized::Dag::new(); + + for def_a in partitions.iter() { + for def_b in partitions.iter() { + if def_a == def_b { + continue; + } + let inp_a = dag.add_input(def_a.precision, Shape::number()); + let lut_a = dag.add_lut(inp_a, FunctionTable::UNKWOWN, def_a.precision); + let _weighted_a = dag.add_linear_noise( + [lut_a], + LevelledComplexity::ZERO, + [def_a.norm2.sqrt()], + Shape::number(), + "", + ); + + let inp_b = dag.add_input(def_b.precision, Shape::number()); + let lut_b = dag.add_lut(inp_b, FunctionTable::UNKWOWN, def_b.precision); + let weighted_b = dag.add_linear_noise( + [lut_b], + LevelledComplexity::ZERO, + [def_b.norm2.sqrt()], + Shape::number(), + "", + ); + + dag.add_composition(weighted_b, inp_a); + + if generate_fks && def_a > def_b { + let inp_a = dag.add_input(def_a.precision, Shape::number()); + let lut_a = dag.add_lut(inp_a, FunctionTable::UNKWOWN, def_a.precision); + let _weighted_a = dag.add_linear_noise( + [lut_a], + LevelledComplexity::ZERO, + [def_a.norm2.sqrt()], + Shape::number(), + "", + ); + + let inp_b = dag.add_input(def_b.precision, Shape::number()); + let lut_b = dag.add_lut(inp_b, FunctionTable::UNKWOWN, def_b.precision); + let _weighted_b = dag.add_linear_noise( + [lut_b], + LevelledComplexity::ZERO, + [def_b.norm2.sqrt()], + Shape::number(), + "", + ); + + let _ = dag.add_linear_noise( + [lut_a, lut_b], + LevelledComplexity::ZERO, + [0., 0.], + Shape::number(), + "", + ); + } + } + } + + let precisions: Vec<_> = partitions.iter().map(|def| def.precision).collect(); + let n_partitions = precisions.len(); + let p_cut = PartitionCut::maximal_partitionning(&dag); + let config = Config { + security_level: 128, + maximum_acceptable_error_probability: _4_SIGMA, + key_sharing: true, + ciphertext_modulus_log: 64, + fft_precision: 53, + complexity_model: &CpuComplexity::default(), + }; + let search_space = SearchSpace::default_cpu(); + let cache = decomposition::cache(128, ProcessingUnit::Cpu, None, true, 64, 53); + let parameters = optimize( + &dag, + config, + &search_space, + &NoSearchSpaceRestriction, + &cache, + &Some(p_cut), + PartitionIndex(0), + ) + .map_or(None, |v| Some(v.1)) + .unwrap(); + + for i in 0..n_partitions { + for j in 0..n_partitions { + assert!( + parameters.micro_params.ks[i][j].is_some(), + "Ksk[{i},{j}] missing." + ); + if i > j { + assert!( + parameters.micro_params.fks[i][j].is_some(), + "Fksk[{i},{j}] missing." + ); + } + } + } + ExpandedCircuitKeys::of(¶meters).compacted() +} + +#[cfg(test)] +mod test { + use super::{generate_generic_parameters, PartitionDefinition}; + + #[test] + fn test_generate_generic_parameters() { + let _ = generate_generic_parameters( + vec![ + PartitionDefinition { + precision: 3, + norm2: 1., + }, + PartitionDefinition { + precision: 3, + norm2: 100., + }, + PartitionDefinition { + precision: 3, + norm2: 1000., + }, + ], + true, + ); + } +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs index e0b66a59e7..7e88ea7bae 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs @@ -2,6 +2,7 @@ pub(crate) mod analyze; mod complexity; mod fast_keyswitch; mod feasible; +pub mod generic_generation; pub mod keys_spec; pub mod optimize; pub mod optimize_generic; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs index 18dd8050da..72a7e14b8d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs @@ -9,7 +9,7 @@ use crate::dag::operator::OperatorIndex; use super::partition_cut::PartitionCut; #[derive(Clone, Debug, PartialEq, Eq, Default, PartialOrd, Ord, Hash, Copy)] -pub struct PartitionIndex(pub(crate) usize); +pub struct PartitionIndex(pub usize); impl PartitionIndex { pub const FIRST: Self = Self(0); diff --git a/frontends/concrete-python/tests/compilation/test_restrictions.py b/frontends/concrete-python/tests/compilation/test_restrictions.py index de96cb85b1..ab42c5a520 100644 --- a/frontends/concrete-python/tests/compilation/test_restrictions.py +++ b/frontends/concrete-python/tests/compilation/test_restrictions.py @@ -4,7 +4,12 @@ import numpy as np import pytest -from mlir._mlir_libs._concretelang._compiler import KeysetRestriction, RangeRestriction +from mlir._mlir_libs._concretelang._compiler import ( + KeysetInfo, + KeysetRestriction, + PartitionDefinition, + RangeRestriction, +) from concrete import fhe @@ -96,3 +101,28 @@ def inc(x): restricted_keyset_info = restricted_module.keys.specs.program_info.get_keyset_info() assert big_keyset_info == restricted_keyset_info assert small_keyset_info != restricted_keyset_info + + +def test_generic_restriction(): + """ + Test that compiling a module works. + """ + + generic_keyset_info = KeysetInfo.generate_generic( + [PartitionDefinition(8, 10.0), PartitionDefinition(10, 10000.0)], True + ) + + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return (x + 1) % 200 + + inputset = [np.random.randint(1, 200, size=()) for _ in range(100)] + restricted_module = Module.compile( + {"inc": inputset}, + enable_unsafe_features=True, + keyset_restriction=generic_keyset_info.get_restriction(), + ) + compiled_keyset_info = restricted_module.keys.specs.program_info.get_keyset_info() + assert all([k in generic_keyset_info.secret_keys() for k in compiled_keyset_info.secret_keys()])