Skip to content

Commit

Permalink
Merge pull request #1145 from zama-ai/alex/optimizer_keyset_generation
Browse files Browse the repository at this point in the history
feat(optimizer): add generic keyset info generation
  • Loading branch information
aPere3 authored Dec 2, 2024
2 parents fce0550 + af46cf4 commit 819e6da
Show file tree
Hide file tree
Showing 15 changed files with 738 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef CONCRETELANG_COMMON_KEYSETS_H
#define CONCRETELANG_COMMON_KEYSETS_H

#include "concrete-optimizer.hpp"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Error.h"
Expand Down Expand Up @@ -92,6 +93,10 @@ class KeysetCache {
KeysetCache() = default;
};

Message<concreteprotocol::KeysetInfo> keysetInfoFromVirtualCircuit(
std::vector<concrete_optimizer::utils::PartitionDefinition> partitions,
bool generate_fks, std::optional<concrete_optimizer::Options> options);

} // namespace keysets
} // namespace concretelang

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,11 @@
// Exceptions. See
// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_COMMON_SECURITY_H
#define CONCRETELANG_COMMON_SECURITY_H

#ifndef CONCRETELANG_SUPPORT_V0CURVES_H_
#define CONCRETELANG_SUPPORT_V0CURVES_H_

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

namespace concrete {
namespace concretelang {
namespace security {

enum KeyFormat {
BINARY,
Expand Down Expand Up @@ -42,31 +37,16 @@ struct SecurityCurve {
/// @param polynomialSize The size of the polynom of the glwe
/// @param logQ The log of q
/// @return The secure encryption variances
double getVariance(int glweDimension, int polynomialSize, int logQ) {
auto size = glweDimension * polynomialSize;
if (size < minimalLweDimension) {
return NAN;
}
auto a = std::pow(2, (slope * size + bias) * 2);
auto b = std::pow(2, -2 * (logQ - 2));
return a > b ? a : b;
}
double getVariance(int glweDimension, int polynomialSize, int logQ);
};

#include "curves.gen.h"

/// @brief Return the security curve for a given level and a key format.
/// @param bitsOfSecurity The number of bits of security
/// @param keyFormat The format of the key
/// @return The security curve or nullptr if the curve is not found.
SecurityCurve *getSecurityCurve(int bitsOfSecurity, KeyFormat keyFormat) {
for (size_t i = 0; i < curvesLen; i++) {
if (curves[i].bits == bitsOfSecurity && curves[i].keyFormat == keyFormat)
return &curves[i];
}
return nullptr;
}
SecurityCurve *getSecurityCurve(int bitsOfSecurity, KeyFormat keyFormat);

} // namespace concrete
} // namespace security
} // namespace concretelang

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,122 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.doc() = "Allow to restrict the optimizer search space to be compatible "
"with a keyset.";

// ------------------------------------------------------------------------------//
// OPTIMIZER OPTIONS //
// ------------------------------------------------------------------------------//
pybind11::class_<concrete_optimizer::Options>(m, "OptimizerOptions")
.def(
"set_security_level",
[](concrete_optimizer::Options &options, uint64_t security_level) {
options.security_level = security_level;
},
"Set option for security level.", arg("security_level"))
.def(
"set_maximum_acceptable_error_probability",
[](concrete_optimizer::Options &options,
double maximum_acceptable_error_probability) {
options.maximum_acceptable_error_probability =
maximum_acceptable_error_probability;
},
"Set option for maximum acceptable error probability.",
arg("maximum_acceptable_error_probability"))
.def(
"set_key_sharing",
[](concrete_optimizer::Options &options, bool key_sharing) {
options.key_sharing = key_sharing;
},
"Set option for key sharing.", arg("key_sharing"))
.def(
"set_multi_param_strategy_to_by_precision",
[](concrete_optimizer::Options &options) {
options.multi_param_strategy =
concrete_optimizer::MultiParamStrategy::ByPrecision;
},
"Set option for multi param strategy to by-precision.")
.def(
"set_multi_param_strategy_to_by_precision_and_norm_2",
[](concrete_optimizer::Options &options) {
options.multi_param_strategy =
concrete_optimizer::MultiParamStrategy::ByPrecisionAndNorm2;
},
"Set option for multi param strategy to by-precision-and-norm2.")
.def(
"set_default_log_norm2_woppbs",
[](concrete_optimizer::Options &options,
double default_log_norm2_woppbs) {
options.default_log_norm2_woppbs = default_log_norm2_woppbs;
},
"Set option for default log norm2 woppbs.",
arg("default_log_norm2_woppbs"))
.def(
"set_use_gpu_constraints",
[](concrete_optimizer::Options &options, bool use_gpu_constraints) {
options.use_gpu_constraints = use_gpu_constraints;
},
"Set option for use gpu constrints.", arg("use_gpu_constraints"))
.def(
"set_encoding_to_auto",
[](concrete_optimizer::Options &options) {
options.encoding = concrete_optimizer::Encoding::Auto;
},
"Set option for encoding to auto.")
.def(
"set_encoding_to_crt",
[](concrete_optimizer::Options &options) {
options.encoding = concrete_optimizer::Encoding::Crt;
},
"Set option for encoding to crt.")
.def(
"set_encoding_to_native",
[](concrete_optimizer::Options &options) {
options.encoding = concrete_optimizer::Encoding::Native;
},
"Set option for encoding to native.")
.def(
"set_cache_on_disk",
[](concrete_optimizer::Options &options, bool cache_on_disk) {
options.cache_on_disk = cache_on_disk;
},
"Set option for cache on disk.", arg("cache_on_disk"))
.def(
"set_ciphertext_modulus_log",
[](concrete_optimizer::Options &options,
uint32_t ciphertext_modulus_log) {
options.ciphertext_modulus_log = ciphertext_modulus_log;
},
"Set option for ciphertext modulus log.",
arg("ciphertext_modulus_log"))
.def(
"set_fft_precision",
[](concrete_optimizer::Options &options, uint32_t fft_precision) {
options.fft_precision = fft_precision;
},
"Set option for fft precision.", arg("fft_precision"))
.def(
"set_fft_precision",
[](concrete_optimizer::Options &options, uint32_t fft_precision) {
options.fft_precision = fft_precision;
},
"Set option for fft precision.", arg("fft_precision"))
.def(
"set_range_restriction",
[](concrete_optimizer::Options &options,
concrete_optimizer::restriction::RangeRestriction restriction) {
options.range_restriction = std::make_shared<
concrete_optimizer::restriction::RangeRestriction>(restriction);
},
"Set option for range restriction", arg("restriction"))
.def(
"set_keyset_restriction",
[](concrete_optimizer::Options &options,
concrete_optimizer::restriction::KeysetRestriction restriction) {
options.keyset_restriction = std::make_shared<
concrete_optimizer::restriction::KeysetRestriction>(
restriction);
},
"Set option for keyset restriction", arg("restriction"))
.doc() = "Options for the optimizer.";

// ------------------------------------------------------------------------------//
// COMPILATION OPTIONS //
// ------------------------------------------------------------------------------//
Expand Down Expand Up @@ -645,6 +761,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(LweSecretKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(LweSecretKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<LweSecretKeyParam>(m, "LweSecretKeyParam")
.def(
Expand All @@ -659,6 +787,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of an LWE Secret Key.";

// ------------------------------------------------------------------------------//
Expand Down Expand Up @@ -689,6 +819,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(BootstrapKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(BootstrapKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<BootstrapKeyParam>(m, "BootstrapKeyParam")
.def(
Expand Down Expand Up @@ -745,6 +887,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of a Bootstrap key.";

// ------------------------------------------------------------------------------//
Expand All @@ -766,6 +910,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(KeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(KeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<KeyswitchKeyParam>(m, "KeyswitchKeyParam")
.def(
Expand Down Expand Up @@ -804,6 +960,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of a keyswitch key.";

// ------------------------------------------------------------------------------//
Expand Down Expand Up @@ -834,6 +992,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(PackingKeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(PackingKeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<PackingKeyswitchKeyParam>(m, "PackingKeyswitchKeyParam")
.def(
Expand Down Expand Up @@ -892,13 +1062,46 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of a packing keyswitch key.";

// ------------------------------------------------------------------------------//
// PARTITION DEFINITION //
// ------------------------------------------------------------------------------//
//
pybind11::class_<concrete_optimizer::utils::PartitionDefinition>(
m, "PartitionDefinition")
.def(init([](uint8_t precision, double norm2)
-> concrete_optimizer::utils::PartitionDefinition {
return concrete_optimizer::utils::PartitionDefinition{precision,
norm2};
}),
arg("precision"), arg("norm2"))
.doc() = "Definition of a partition (in terms of precision in bits and "
"norm2 in value).";

// ------------------------------------------------------------------------------//
// KEYSET INFO //
// ------------------------------------------------------------------------------//
typedef Message<concreteprotocol::KeysetInfo> KeysetInfo;
pybind11::class_<KeysetInfo>(m, "KeysetInfo")
.def_static(
"generate_virtual",
[](std::vector<concrete_optimizer::utils::PartitionDefinition>
partitions,
bool generateFks,
std::optional<concrete_optimizer::Options> options) -> KeysetInfo {
if (partitions.size() < 2) {
throw std::runtime_error("Need at least two partition defs to "
"generate a virtual keyset info.");
}
return ::concretelang::keysets::keysetInfoFromVirtualCircuit(
partitions, generateFks, options);
},
arg("partition_defs"), arg("generate_fks"),
arg("options") = std::nullopt,
"Generate a generic keyset info for a set of partition definitions")
.def(
"secret_keys",
[](KeysetInfo &keysetInfo) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_library(
Keys.cpp
Keysets.cpp
Transformers.cpp
Security.cpp
Values.cpp
DEPENDS
concrete-protocol
Expand Down
Loading

0 comments on commit 819e6da

Please sign in to comment.