Skip to content

Commit

Permalink
Merge pull request #1104 from zama-ai/alex/optimizer_max_2
Browse files Browse the repository at this point in the history
fix(optimizer): add zero noise and max noise ops
  • Loading branch information
aPere3 authored Oct 21, 2024
2 parents cd8208e + 5205a2f commit 7f72c2b
Show file tree
Hide file tree
Showing 35 changed files with 1,207 additions and 1,261 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure, ZeroNoise]> {
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$tensor);
}

def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure, BinaryEintInt, AdditiveNoise, DeclareOpInterfaceMethods<Binary>]> {
def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure, BinaryEintInt, MaxNoise, DeclareOpInterfaceMethods<Binary>]> {
let summary = "Adds an encrypted integer and a clear integer";

let description = [{
Expand Down Expand Up @@ -118,7 +118,7 @@ def FHE_AddEintOp : FHE_Op<"add_eint", [Pure, BinaryEint, AdditiveNoise, Declare
let hasVerifier = 1;
}

def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint, AdditiveNoise]> {
def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint, MaxNoise]> {
let summary = "Subtract an encrypted integer from a clear integer";

let description = [{
Expand Down Expand Up @@ -150,7 +150,7 @@ def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint, AdditiveNois
let hasVerifier = 1;
}

def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure, BinaryEintInt, AdditiveNoise, DeclareOpInterfaceMethods<Binary>]> {
def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure, BinaryEintInt, MaxNoise, DeclareOpInterfaceMethods<Binary>]> {
let summary = "Subtract a clear integer from an encrypted integer";

let description = [{
Expand Down Expand Up @@ -215,7 +215,7 @@ def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure, BinaryEint, AdditiveNoise, Declare
let hasVerifier = 1;
}

def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure, UnaryEint, AdditiveNoise, DeclareOpInterfaceMethods<UnaryEint>]> {
def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure, UnaryEint, MaxNoise, DeclareOpInterfaceMethods<UnaryEint>]> {

let summary = "Negates an encrypted integer";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def AdditiveNoise : OpInterface<"AdditiveNoise"> {
let cppNamespace = "mlir::concretelang::FHE";
}

def MaxNoise : OpInterface<"MaxNoise"> {
let description = [{
An n-ary operation whose output noise is the max of all input noises.
}];

let cppNamespace = "mlir::concretelang::FHE";
}

def UnaryEint : OpInterface<"UnaryEint"> {
let description = [{
A unary operation on scalars, with the operand encrypted.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def TensorBinaryEint : NativeOpTrait<"TensorBinaryEint">;
def TensorUnaryEint : NativeOpTrait<"TensorUnaryEint">;


def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, DeclareOpInterfaceMethods<Binary>]> {
def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, MaxNoise, DeclareOpInterfaceMethods<Binary>]> {
let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers.";

let description = [{
Expand Down Expand Up @@ -136,7 +136,7 @@ def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [Pure,
];
}

def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [Pure, TensorBroadcastingRules, TensorBinaryIntEint, BinaryIntEint, DeclareOpInterfaceMethods<Binary>]> {
def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [Pure, TensorBroadcastingRules, TensorBinaryIntEint, BinaryIntEint, MaxNoise, DeclareOpInterfaceMethods<Binary>]> {
let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers and a tensor of encrypted integers.";

let description = [{
Expand Down Expand Up @@ -189,7 +189,7 @@ def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [Pure, TensorBroadcast
];
}

def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, DeclareOpInterfaceMethods<Binary>]> {
def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, MaxNoise, DeclareOpInterfaceMethods<Binary>]> {
let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers from a tensor of encrypted integers.";

let description = [{
Expand Down Expand Up @@ -297,7 +297,7 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [Pure, TensorBroadcastingRule
];
}

def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [Pure, TensorUnaryEint, UnaryEint, DeclareOpInterfaceMethods<UnaryEint>]> {
def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [Pure, TensorUnaryEint, UnaryEint, MaxNoise, DeclareOpInterfaceMethods<UnaryEint>]> {
let summary = "Returns a tensor that contains the negation of a tensor of encrypted integers.";

let description = [{
Expand Down Expand Up @@ -1128,7 +1128,7 @@ def FHELinalg_SumOp : FHELinalg_Op<"sum", [Pure, TensorUnaryEint]> {
let hasVerifier = 1;
}

def FHELinalg_ConcatOp : FHELinalg_Op<"concat", [Pure]> {
def FHELinalg_ConcatOp : FHELinalg_Op<"concat", [Pure, MaxNoise]> {
let summary = "Concatenates a sequence of tensors along an existing axis.";

let description = [{
Expand Down Expand Up @@ -1201,7 +1201,7 @@ def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", [UnaryEint, DeclareOpInter
let hasVerifier = 1;
}

def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", [Pure, UnaryEint, DeclareOpInterfaceMethods<UnaryEint>]> {
def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", [Pure, UnaryEint, MaxNoise, DeclareOpInterfaceMethods<UnaryEint>]> {
let summary = "Returns a tensor that contains the transposition of the input tensor.";

let description = [{
Expand Down Expand Up @@ -1240,7 +1240,7 @@ def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", [Pure, UnaryEint, DeclareO
let hasVerifier = 1;
}

def FHELinalg_FromElementOp : FHELinalg_Op<"from_element", [Pure]> {
def FHELinalg_FromElementOp : FHELinalg_Op<"from_element", [Pure, MaxNoise]> {
let summary = "Creates a tensor with a single element.";

let description = [{
Expand Down Expand Up @@ -1415,7 +1415,7 @@ def FHELinalg_ReinterpretPrecisionEintOp: FHELinalg_Op<"reinterpret_precision",
let hasVerifier = 1;
}

def FHELinalg_FancyIndexOp : FHELinalg_Op<"fancy_index", [Pure]> {
def FHELinalg_FancyIndexOp : FHELinalg_Op<"fancy_index", [Pure, MaxNoise]> {
let summary = "Index into a tensor using a tensor of indices.";

let description = [{
Expand Down Expand Up @@ -1462,7 +1462,7 @@ def FHELinalg_FancyIndexOp : FHELinalg_Op<"fancy_index", [Pure]> {
let hasVerifier = 1;
}

def FHELinalg_FancyAssignOp : FHELinalg_Op<"fancy_assign", [Pure]> {
def FHELinalg_FancyAssignOp : FHELinalg_Op<"fancy_assign", [Pure, MaxNoise]> {
let summary = "Assigns a tensor into another tensor at a tensor of indices.";

let description = [{
Expand Down Expand Up @@ -1517,7 +1517,7 @@ def FHELinalg_FancyAssignOp : FHELinalg_Op<"fancy_assign", [Pure]> {
let hasVerifier = 1;
}

def FHELinalg_BroadcastOp: FHELinalg_Op<"broadcast", [Pure, ConstantNoise]> {
def FHELinalg_BroadcastOp: FHELinalg_Op<"broadcast", [Pure, ConstantNoise, MaxNoise]> {
let summary = "Broadcasts a tensor to a shape.";
let description = [{
Broadcasting is used for expanding certain dimensions of a tensor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Pass.h"
Expand Down Expand Up @@ -201,10 +202,11 @@ struct FunctionToDag {
addEncMatMulTensor(matmulEintEint, encrypted_inputs, precision);
return;
} else if (auto zero = asZeroNoise(op)) {
// special case as zero are rewritten in several optimizer nodes
index = addZeroNoise(zero);
} else if (auto additive = asAdditiveNoise(op)) {
index = addAdditiveNoise(additive, encrypted_inputs);
} else if (isMaxNoise(op)) {
index = addMaxNoise(op, encrypted_inputs);
} else {
index = addLevelledOp(op, encrypted_inputs);
}
Expand Down Expand Up @@ -336,22 +338,8 @@ struct FunctionToDag {
auto val = op->getOpResult(0);
auto outShape = getShape(val);
auto loc = loc_to_location(op.getLoc());

// Trivial encrypted constants encoding
// There are converted to input + levelledop
auto precision = fhe::utils::getEintPrecision(val);
auto opI = dagBuilder.add_input(precision, slice(outShape), *loc);
auto inputs = Inputs{opI};

// Default complexity is negligible
double const fixedCost = NEGLIGIBLE_COMPLEXITY;
double const lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
auto comment = std::string(op->getName().getStringRef()) + " " +
loc_to_string(op.getLoc());
auto weights = std::vector<double>{1.};
index[val] = dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor,
fixedCost, slice(weights),
slice(outShape), comment, *loc);
index[val] = dagBuilder.add_zero_noise(precision, slice(outShape), *loc);
return index[val];
}

Expand All @@ -366,16 +354,29 @@ struct FunctionToDag {
loc_to_string(op.getLoc());
auto loc = loc_to_location(op.getLoc());
auto weights = std::vector<double>(inputs.size(), 1.);
index[val] = dagBuilder.add_levelled_op(slice(inputs), lwe_dim_cost_factor,
fixed_cost, slice(weights),
slice(out_shape), comment, *loc);
index[val] = dagBuilder.add_linear_noise(slice(inputs), lwe_dim_cost_factor,
fixed_cost, slice(weights),
slice(out_shape), comment, *loc);
return index[val];
}

rust::Box<concrete_optimizer::Location>
loc_to_location(mlir::Location location) {
return location_from_string(loc_to_string(location));
}

concrete_optimizer::dag::OperatorIndex addMaxNoise(mlir::Operation &op,
Inputs &inputs) {
auto val = op.getResult(0);
auto out_shape = getShape(val);
auto loc = loc_to_location(op.getLoc());
assert(!inputs.empty());

index[val] =
dagBuilder.add_max_noise(slice(inputs), slice(out_shape), *loc);
return index[val];
}

concrete_optimizer::dag::OperatorIndex addLevelledOp(mlir::Operation &op,
Inputs &inputs) {
auto val = op.getResult(0);
Expand Down Expand Up @@ -425,9 +426,9 @@ struct FunctionToDag {
assert(!std::isnan(weight));
}
auto weights = std::vector<double>(n_inputs, weight);
index[val] = dagBuilder.add_levelled_op(slice(inputs), lwe_dim_cost_factor,
fixed_cost, slice(weights),
slice(out_shape), comment, *loc);
index[val] = dagBuilder.add_linear_noise(slice(inputs), lwe_dim_cost_factor,
fixed_cost, slice(weights),
slice(out_shape), comment, *loc);
return index[val];
}

Expand Down Expand Up @@ -497,7 +498,7 @@ struct FunctionToDag {
// tlu(x + y)

auto addWeights = std::vector<double>{1, 1};
auto addNode = dagBuilder.add_levelled_op(
auto addNode = dagBuilder.add_linear_noise(
slice(inputs), lweDimCostFactor, fixedCost, slice(addWeights),
slice(resultShape), comment, *loc);

Expand All @@ -517,7 +518,7 @@ struct FunctionToDag {

// tlu(x - y)
auto subWeights = std::vector<double>{1, 1};
auto subNode = dagBuilder.add_levelled_op(
auto subNode = dagBuilder.add_linear_noise(
slice(inputs), lweDimCostFactor, fixedCost, slice(subWeights),
slice(resultShape), comment, *loc);

Expand All @@ -535,7 +536,7 @@ struct FunctionToDag {
auto resultWeights = std::vector<double>{1, 1};
const std::vector<concrete_optimizer::dag::OperatorIndex> subInputs = {
lhsTluNode, rhsTluNode};
auto resultNode = dagBuilder.add_levelled_op(
auto resultNode = dagBuilder.add_linear_noise(
slice(subInputs), lweDimCostFactor, fixedCost, slice(resultWeights),
slice(resultShape), comment, *loc);

Expand Down Expand Up @@ -661,7 +662,7 @@ struct FunctionToDag {

// tlu(x + y)
auto addWeights = std::vector<double>{1, 1};
auto addNode = dagBuilder.add_levelled_op(
auto addNode = dagBuilder.add_linear_noise(
slice(inputs), lweDimCostFactor, fixedCost, slice(addWeights),
slice(pairMatrixShape), comment, *loc);

Expand All @@ -681,7 +682,7 @@ struct FunctionToDag {

// tlu(x - y)
auto subWeights = std::vector<double>{1, 1};
auto subNode = dagBuilder.add_levelled_op(
auto subNode = dagBuilder.add_linear_noise(
slice(inputs), lweDimCostFactor, fixedCost, slice(subWeights),
slice(pairMatrixShape), comment, *loc);

Expand All @@ -699,7 +700,7 @@ struct FunctionToDag {
auto resultWeights = std::vector<double>{1, 1};
const std::vector<concrete_optimizer::dag::OperatorIndex> subInputs = {
lhsTluNode, rhsTluNode};
auto resultNode = dagBuilder.add_levelled_op(
auto resultNode = dagBuilder.add_linear_noise(
slice(subInputs), lweDimCostFactor, fixedCost, slice(resultWeights),
slice(pairMatrixShape), comment, *loc);

Expand All @@ -723,7 +724,7 @@ struct FunctionToDag {
// TODO: use APIFloat.sqrt when it's available
double manp = sqrt(smanp_int.getValue().roundToDouble());
auto weights = std::vector<double>(sumOperands.size(), manp / tluSubManp);
index[result] = dagBuilder.add_levelled_op(
index[result] = dagBuilder.add_linear_noise(
slice(sumOperands), lwe_dim_cost_factor, fixed_cost, slice(weights),
slice(resultShape), comment, *loc);

Expand Down Expand Up @@ -774,7 +775,7 @@ struct FunctionToDag {
loc_to_string(maxOp.getLoc());

auto subWeights = std::vector<double>{1, 1};
auto subNode = dagBuilder.add_levelled_op(
auto subNode = dagBuilder.add_linear_noise(
slice(inputs), lweDimCostFactor, fixedCost, slice(subWeights),
slice(resultShape), comment, *loc);

Expand All @@ -785,7 +786,7 @@ struct FunctionToDag {
const std::vector<concrete_optimizer::dag::OperatorIndex> addInputs = {
tluNode, inputs[1]};
auto addWeights = std::vector<double>{1, 1};
auto resultNode = dagBuilder.add_levelled_op(
auto resultNode = dagBuilder.add_linear_noise(
slice(addInputs), lweDimCostFactor, fixedCost, slice(addWeights),
slice(resultShape), comment, *loc);

Expand Down Expand Up @@ -837,9 +838,9 @@ struct FunctionToDag {

auto subWeights = std::vector<double>(
inputs.size(), subManp / sqrt(inputSmanp.roundToDouble()));
auto subNode = dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor,
fixedCost, slice(subWeights),
slice(fakeShape), comment, *loc);
auto subNode = dagBuilder.add_linear_noise(slice(inputs), lweDimCostFactor,
fixedCost, slice(subWeights),
slice(fakeShape), comment, *loc);

const std::vector<std::uint64_t> unknownFunction;
auto tluNode =
Expand All @@ -851,7 +852,7 @@ struct FunctionToDag {

auto resultWeights = std::vector<double>(
addInputs.size(), addManp / sqrt(inputSmanp.roundToDouble()));
auto resultNode = dagBuilder.add_levelled_op(
auto resultNode = dagBuilder.add_linear_noise(
slice(addInputs), lweDimCostFactor, fixedCost, slice(resultWeights),
slice(resultShape), comment, *loc);

Expand Down Expand Up @@ -1002,6 +1003,10 @@ struct FunctionToDag {
return value.isa<mlir::BlockArgument>();
}

bool isMaxNoise(mlir::Operation &op) {
return llvm::isa<mlir::concretelang::FHE::MaxNoise>(op);
}

std::optional<std::vector<std::int64_t>>
resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) {
std::vector<std::int64_t> values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ using namespace mlir::tensor;
void registerFheInterfacesExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
ExtractOp::attachInterface<UnaryEint>(*ctx);
InsertSliceOp::attachInterface<MaxNoise>(*ctx);
InsertOp::attachInterface<MaxNoise>(*ctx);
ParallelInsertSliceOp::attachInterface<MaxNoise>(*ctx);
});
}
} // namespace FHE
Expand Down
33 changes: 15 additions & 18 deletions compilers/concrete-optimizer/brute-force-optimizer/src/cggi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,24 +342,21 @@ pub fn write_to_file(
intem,
} in res.iter()
{
match intem {
Some((solution, cost)) => {
writeln!(
writer,
" {:2}, {:2}, {:2}, {:2}, {:4}, {:2}, {:2}, {:2}, {:2}, {:6}",
precision,
log_norm,
solution.glwe_dim,
solution.log_poly_size,
solution.small_lwe_dim,
solution.level_pbs,
solution.base_log_pbs,
solution.level_ks,
solution.base_log_ks,
cost
)?;
}
None => {}
if let Some((solution, cost)) = intem {
writeln!(
writer,
" {:2}, {:2}, {:2}, {:2}, {:4}, {:2}, {:2}, {:2}, {:2}, {:6}",
precision,
log_norm,
solution.glwe_dim,
solution.log_poly_size,
solution.small_lwe_dim,
solution.level_pbs,
solution.base_log_pbs,
solution.level_ks,
solution.base_log_ks,
cost
)?;
}
}
Ok(())
Expand Down
Loading

0 comments on commit 7f72c2b

Please sign in to comment.