From 25d366640a1c2260551466fac41b7dbd2e25e507 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Fri, 27 Sep 2024 17:21:06 +0200 Subject: [PATCH] fix(optimizer): add zero noise and max noise ops --- .../concretelang/Dialect/FHE/IR/FHEOps.td | 8 +- .../Dialect/FHE/Interfaces/FHEInterfaces.td | 8 + .../Dialect/FHELinalg/IR/FHELinalgOps.td | 20 +- .../FHE/Analysis/ConcreteOptimizer.cpp | 75 ++-- .../FHE/Interfaces/FHEInterfacesInstances.cpp | 3 + .../src/concrete-optimizer.rs | 52 ++- .../src/cpp/concrete-optimizer.cpp | 22 +- .../src/cpp/concrete-optimizer.hpp | 4 +- .../src/dag/operator/operator.rs | 36 +- .../src/dag/rewrite/regen.rs | 6 +- .../src/dag/unparametrized.rs | 81 +++- .../dag/multi_parameters/analyze.rs | 363 ++++++++---------- .../dag/multi_parameters/complexity.rs | 165 +++----- .../dag/multi_parameters/feasible.rs | 122 ++---- .../optimization/dag/multi_parameters/mod.rs | 5 +- .../dag/multi_parameters/noise_expression.rs | 249 ++++++++++++ .../dag/multi_parameters/operations_value.rs | 306 --------------- .../dag/multi_parameters/optimize/mod.rs | 219 +++++++---- .../dag/multi_parameters/optimize/tests.rs | 8 +- .../dag/multi_parameters/partition_cut.rs | 8 +- .../dag/multi_parameters/partitionning.rs | 22 +- .../dag/multi_parameters/symbolic.rs | 136 +++++++ .../dag/multi_parameters/symbolic_variance.rs | 264 ------------- .../multi_parameters/variance_constraint.rs | 50 +-- .../src/optimization/dag/solo_key/analyze.rs | 30 +- .../src/optimization/dag/solo_key/optimize.rs | 4 +- .../concrete-optimizer/src/utils/viz.rs | 8 +- docs/explanations/FHEDialect.md | 8 +- docs/explanations/FHELinalgDialect.md | 20 +- 29 files changed, 1133 insertions(+), 1169 deletions(-) create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/noise_expression.rs delete mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs create mode 
100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic.rs delete mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index 1c1391ae71..7f6c4f5329 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -53,7 +53,7 @@ def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure, ZeroNoise]> { let results = (outs Type.predicate, HasStaticShapePred]>>:$tensor); } -def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure, BinaryEintInt, AdditiveNoise, DeclareOpInterfaceMethods]> { +def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure, BinaryEintInt, MaxNoise, DeclareOpInterfaceMethods]> { let summary = "Adds an encrypted integer and a clear integer"; let description = [{ @@ -118,7 +118,7 @@ def FHE_AddEintOp : FHE_Op<"add_eint", [Pure, BinaryEint, AdditiveNoise, Declare let hasVerifier = 1; } -def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint, AdditiveNoise]> { +def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint, MaxNoise]> { let summary = "Subtract an encrypted integer from a clear integer"; let description = [{ @@ -150,7 +150,7 @@ def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint, AdditiveNois let hasVerifier = 1; } -def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure, BinaryEintInt, AdditiveNoise, DeclareOpInterfaceMethods]> { +def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure, BinaryEintInt, MaxNoise, DeclareOpInterfaceMethods]> { let summary = "Subtract a clear integer from an encrypted integer"; let description = [{ @@ -215,7 +215,7 @@ def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure, BinaryEint, AdditiveNoise, 
Declare let hasVerifier = 1; } -def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure, UnaryEint, AdditiveNoise, DeclareOpInterfaceMethods]> { +def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure, UnaryEint, MaxNoise, DeclareOpInterfaceMethods]> { let summary = "Negates an encrypted integer"; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td index e60be26b8b..b47d299f89 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td @@ -53,6 +53,14 @@ def AdditiveNoise : OpInterface<"AdditiveNoise"> { let cppNamespace = "mlir::concretelang::FHE"; } +def MaxNoise : OpInterface<"MaxNoise"> { + let description = [{ + An n-ary operation whose output noise is the max of all input noises. + }]; + + let cppNamespace = "mlir::concretelang::FHE"; +} + def UnaryEint : OpInterface<"UnaryEint"> { let description = [{ A unary operation on scalars, with the operand encrypted. 
diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index 679f7b5677..f66ce511b7 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -28,7 +28,7 @@ def TensorBinaryEint : NativeOpTrait<"TensorBinaryEint">; def TensorUnaryEint : NativeOpTrait<"TensorUnaryEint">; -def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, DeclareOpInterfaceMethods]> { +def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, MaxNoise, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers."; let description = [{ @@ -136,7 +136,7 @@ def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [Pure, ]; } -def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [Pure, TensorBroadcastingRules, TensorBinaryIntEint, BinaryIntEint, DeclareOpInterfaceMethods]> { +def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [Pure, TensorBroadcastingRules, TensorBinaryIntEint, BinaryIntEint, MaxNoise, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers and a tensor of encrypted integers."; let description = [{ @@ -189,7 +189,7 @@ def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [Pure, TensorBroadcast ]; } -def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, DeclareOpInterfaceMethods]> { +def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, 
MaxNoise, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers from a tensor of encrypted integers."; let description = [{ @@ -297,7 +297,7 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [Pure, TensorBroadcastingRule ]; } -def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [Pure, TensorUnaryEint, UnaryEint, DeclareOpInterfaceMethods]> { +def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [Pure, TensorUnaryEint, UnaryEint, MaxNoise, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the negation of a tensor of encrypted integers."; let description = [{ @@ -1128,7 +1128,7 @@ def FHELinalg_SumOp : FHELinalg_Op<"sum", [Pure, TensorUnaryEint]> { let hasVerifier = 1; } -def FHELinalg_ConcatOp : FHELinalg_Op<"concat", [Pure]> { +def FHELinalg_ConcatOp : FHELinalg_Op<"concat", [Pure, MaxNoise]> { let summary = "Concatenates a sequence of tensors along an existing axis."; let description = [{ @@ -1201,7 +1201,7 @@ def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", [UnaryEint, DeclareOpInter let hasVerifier = 1; } -def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", [Pure, UnaryEint, DeclareOpInterfaceMethods]> { +def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", [Pure, UnaryEint, MaxNoise, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the transposition of the input tensor."; let description = [{ @@ -1240,7 +1240,7 @@ def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", [Pure, UnaryEint, DeclareO let hasVerifier = 1; } -def FHELinalg_FromElementOp : FHELinalg_Op<"from_element", [Pure]> { +def FHELinalg_FromElementOp : FHELinalg_Op<"from_element", [Pure, MaxNoise]> { let summary = "Creates a tensor with a single element."; let description = [{ @@ -1415,7 +1415,7 @@ def FHELinalg_ReinterpretPrecisionEintOp: FHELinalg_Op<"reinterpret_precision", let hasVerifier = 1; } -def FHELinalg_FancyIndexOp : FHELinalg_Op<"fancy_index", [Pure]> { +def 
FHELinalg_FancyIndexOp : FHELinalg_Op<"fancy_index", [Pure, MaxNoise]> { let summary = "Index into a tensor using a tensor of indices."; let description = [{ @@ -1462,7 +1462,7 @@ def FHELinalg_FancyIndexOp : FHELinalg_Op<"fancy_index", [Pure]> { let hasVerifier = 1; } -def FHELinalg_FancyAssignOp : FHELinalg_Op<"fancy_assign", [Pure]> { +def FHELinalg_FancyAssignOp : FHELinalg_Op<"fancy_assign", [Pure, MaxNoise]> { let summary = "Assigns a tensor into another tensor at a tensor of indices."; let description = [{ @@ -1517,7 +1517,7 @@ def FHELinalg_FancyAssignOp : FHELinalg_Op<"fancy_assign", [Pure]> { let hasVerifier = 1; } -def FHELinalg_BroadcastOp: FHELinalg_Op<"broadcast", [Pure, ConstantNoise]> { +def FHELinalg_BroadcastOp: FHELinalg_Op<"broadcast", [Pure, ConstantNoise, MaxNoise]> { let summary = "Broadcasts a tensor to a shape."; let description = [{ Broadcasting is used for expanding certain dimensions of a tensor diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index 1a7e6799b9..fe18267479 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/PassManager.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Pass.h" @@ -201,10 +202,11 @@ struct FunctionToDag { addEncMatMulTensor(matmulEintEint, encrypted_inputs, precision); return; } else if (auto zero = asZeroNoise(op)) { - // special case as zero are rewritten in several optimizer nodes index = addZeroNoise(zero); } else if (auto additive = asAdditiveNoise(op)) { index = addAdditiveNoise(additive, encrypted_inputs); + } else if 
(isMaxNoise(op)) { + index = addMaxNoise(op, encrypted_inputs); } else { index = addLevelledOp(op, encrypted_inputs); } @@ -336,22 +338,8 @@ struct FunctionToDag { auto val = op->getOpResult(0); auto outShape = getShape(val); auto loc = loc_to_location(op.getLoc()); - - // Trivial encrypted constants encoding - // There are converted to input + levelledop auto precision = fhe::utils::getEintPrecision(val); - auto opI = dagBuilder.add_input(precision, slice(outShape), *loc); - auto inputs = Inputs{opI}; - - // Default complexity is negligible - double const fixedCost = NEGLIGIBLE_COMPLEXITY; - double const lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; - auto comment = std::string(op->getName().getStringRef()) + " " + - loc_to_string(op.getLoc()); - auto weights = std::vector{1.}; - index[val] = dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, - fixedCost, slice(weights), - slice(outShape), comment, *loc); + index[val] = dagBuilder.add_zero_noise(precision, slice(outShape), *loc); return index[val]; } @@ -366,9 +354,9 @@ struct FunctionToDag { loc_to_string(op.getLoc()); auto loc = loc_to_location(op.getLoc()); auto weights = std::vector(inputs.size(), 1.); - index[val] = dagBuilder.add_levelled_op(slice(inputs), lwe_dim_cost_factor, - fixed_cost, slice(weights), - slice(out_shape), comment, *loc); + index[val] = dagBuilder.add_linear_noise(slice(inputs), lwe_dim_cost_factor, + fixed_cost, slice(weights), + slice(out_shape), comment, *loc); return index[val]; } @@ -376,6 +364,19 @@ struct FunctionToDag { loc_to_location(mlir::Location location) { return location_from_string(loc_to_string(location)); } + + concrete_optimizer::dag::OperatorIndex addMaxNoise(mlir::Operation &op, + Inputs &inputs) { + auto val = op.getResult(0); + auto out_shape = getShape(val); + auto loc = loc_to_location(op.getLoc()); + assert(!inputs.empty()); + + index[val] = + dagBuilder.add_max_noise(slice(inputs), slice(out_shape), *loc); + return index[val]; + } + 
concrete_optimizer::dag::OperatorIndex addLevelledOp(mlir::Operation &op, Inputs &inputs) { auto val = op.getResult(0); @@ -425,9 +426,9 @@ struct FunctionToDag { assert(!std::isnan(weight)); } auto weights = std::vector(n_inputs, weight); - index[val] = dagBuilder.add_levelled_op(slice(inputs), lwe_dim_cost_factor, - fixed_cost, slice(weights), - slice(out_shape), comment, *loc); + index[val] = dagBuilder.add_linear_noise(slice(inputs), lwe_dim_cost_factor, + fixed_cost, slice(weights), + slice(out_shape), comment, *loc); return index[val]; } @@ -497,7 +498,7 @@ struct FunctionToDag { // tlu(x + y) auto addWeights = std::vector{1, 1}; - auto addNode = dagBuilder.add_levelled_op( + auto addNode = dagBuilder.add_linear_noise( slice(inputs), lweDimCostFactor, fixedCost, slice(addWeights), slice(resultShape), comment, *loc); @@ -517,7 +518,7 @@ struct FunctionToDag { // tlu(x - y) auto subWeights = std::vector{1, 1}; - auto subNode = dagBuilder.add_levelled_op( + auto subNode = dagBuilder.add_linear_noise( slice(inputs), lweDimCostFactor, fixedCost, slice(subWeights), slice(resultShape), comment, *loc); @@ -535,7 +536,7 @@ struct FunctionToDag { auto resultWeights = std::vector{1, 1}; const std::vector subInputs = { lhsTluNode, rhsTluNode}; - auto resultNode = dagBuilder.add_levelled_op( + auto resultNode = dagBuilder.add_linear_noise( slice(subInputs), lweDimCostFactor, fixedCost, slice(resultWeights), slice(resultShape), comment, *loc); @@ -661,7 +662,7 @@ struct FunctionToDag { // tlu(x + y) auto addWeights = std::vector{1, 1}; - auto addNode = dagBuilder.add_levelled_op( + auto addNode = dagBuilder.add_linear_noise( slice(inputs), lweDimCostFactor, fixedCost, slice(addWeights), slice(pairMatrixShape), comment, *loc); @@ -681,7 +682,7 @@ struct FunctionToDag { // tlu(x - y) auto subWeights = std::vector{1, 1}; - auto subNode = dagBuilder.add_levelled_op( + auto subNode = dagBuilder.add_linear_noise( slice(inputs), lweDimCostFactor, fixedCost, slice(subWeights), 
slice(pairMatrixShape), comment, *loc); @@ -699,7 +700,7 @@ struct FunctionToDag { auto resultWeights = std::vector{1, 1}; const std::vector subInputs = { lhsTluNode, rhsTluNode}; - auto resultNode = dagBuilder.add_levelled_op( + auto resultNode = dagBuilder.add_linear_noise( slice(subInputs), lweDimCostFactor, fixedCost, slice(resultWeights), slice(pairMatrixShape), comment, *loc); @@ -723,7 +724,7 @@ struct FunctionToDag { // TODO: use APIFloat.sqrt when it's available double manp = sqrt(smanp_int.getValue().roundToDouble()); auto weights = std::vector(sumOperands.size(), manp / tluSubManp); - index[result] = dagBuilder.add_levelled_op( + index[result] = dagBuilder.add_linear_noise( slice(sumOperands), lwe_dim_cost_factor, fixed_cost, slice(weights), slice(resultShape), comment, *loc); @@ -774,7 +775,7 @@ struct FunctionToDag { loc_to_string(maxOp.getLoc()); auto subWeights = std::vector{1, 1}; - auto subNode = dagBuilder.add_levelled_op( + auto subNode = dagBuilder.add_linear_noise( slice(inputs), lweDimCostFactor, fixedCost, slice(subWeights), slice(resultShape), comment, *loc); @@ -785,7 +786,7 @@ struct FunctionToDag { const std::vector addInputs = { tluNode, inputs[1]}; auto addWeights = std::vector{1, 1}; - auto resultNode = dagBuilder.add_levelled_op( + auto resultNode = dagBuilder.add_linear_noise( slice(addInputs), lweDimCostFactor, fixedCost, slice(addWeights), slice(resultShape), comment, *loc); @@ -837,9 +838,9 @@ struct FunctionToDag { auto subWeights = std::vector( inputs.size(), subManp / sqrt(inputSmanp.roundToDouble())); - auto subNode = dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, - fixedCost, slice(subWeights), - slice(fakeShape), comment, *loc); + auto subNode = dagBuilder.add_linear_noise(slice(inputs), lweDimCostFactor, + fixedCost, slice(subWeights), + slice(fakeShape), comment, *loc); const std::vector unknownFunction; auto tluNode = @@ -851,7 +852,7 @@ struct FunctionToDag { auto resultWeights = std::vector( 
addInputs.size(), addManp / sqrt(inputSmanp.roundToDouble())); - auto resultNode = dagBuilder.add_levelled_op( + auto resultNode = dagBuilder.add_linear_noise( slice(addInputs), lweDimCostFactor, fixedCost, slice(resultWeights), slice(resultShape), comment, *loc); @@ -1002,6 +1003,10 @@ struct FunctionToDag { return value.isa(); } + bool isMaxNoise(mlir::Operation &op) { + return llvm::isa(op); + } + std::optional> resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) { std::vector values; diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfacesInstances.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfacesInstances.cpp index f0c88d285f..3b3927af63 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfacesInstances.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfacesInstances.cpp @@ -17,6 +17,9 @@ using namespace mlir::tensor; void registerFheInterfacesExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { ExtractOp::attachInterface(*ctx); + InsertSliceOp::attachInterface(*ctx); + InsertOp::attachInterface(*ctx); + ParallelInsertSliceOp::attachInterface(*ctx); }); } } // namespace FHE diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index fba2d58b6d..25db907193 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -690,6 +690,21 @@ impl<'dag> DagBuilder<'dag> { .into() } + fn add_zero_noise( + &mut self, + out_precision: Precision, + out_shape: &[u64], + location: &Location, + ) -> ffi::OperatorIndex { + let out_shape = Shape { + dimensions_size: out_shape.to_owned(), + }; + + self.0 + .add_zero_noise(out_precision, out_shape, 
location.0.clone()) + .into() + } + fn add_lut( &mut self, input: ffi::OperatorIndex, @@ -718,7 +733,7 @@ impl<'dag> DagBuilder<'dag> { self.0.add_dot(inputs, weights.0, location.0.clone()).into() } - fn add_levelled_op( + fn add_linear_noise( &mut self, inputs: &[ffi::OperatorIndex], lwe_dim_cost_factor: f64, @@ -741,7 +756,7 @@ impl<'dag> DagBuilder<'dag> { }; self.0 - .add_levelled_op( + .add_linear_noise( inputs, complexity, weights, @@ -752,6 +767,23 @@ impl<'dag> DagBuilder<'dag> { .into() } + fn add_max_noise( + &mut self, + inputs: &[ffi::OperatorIndex], + out_shape: &[u64], + location: &Location, + ) -> ffi::OperatorIndex { + let inputs: Vec = inputs.iter().copied().map(Into::into).collect(); + + let out_shape = Shape { + dimensions_size: out_shape.to_owned(), + }; + + self.0 + .add_max_noise(inputs, out_shape, location.0.clone()) + .into() + } + fn add_round_op( &mut self, input: ffi::OperatorIndex, @@ -943,6 +975,13 @@ mod ffi { location: &Location, ) -> OperatorIndex; + unsafe fn add_zero_noise( + self: &mut DagBuilder<'_>, + out_precision: u8, + out_shape: &[u64], + location: &Location, + ) -> OperatorIndex; + unsafe fn add_lut( self: &mut DagBuilder<'_>, input: OperatorIndex, @@ -958,7 +997,7 @@ mod ffi { location: &Location, ) -> OperatorIndex; - unsafe fn add_levelled_op( + unsafe fn add_linear_noise( self: &mut DagBuilder<'_>, inputs: &[OperatorIndex], lwe_dim_cost_factor: f64, @@ -969,6 +1008,13 @@ mod ffi { location: &Location, ) -> OperatorIndex; + unsafe fn add_max_noise( + self: &mut DagBuilder<'_>, + inputs: &[OperatorIndex], + out_shape: &[u64], + location: &Location, + ) -> OperatorIndex; + unsafe fn add_round_op( self: &mut DagBuilder<'_>, input: OperatorIndex, diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 38a1871287..4936973333 100644 --- 
a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -996,9 +996,11 @@ struct Dag final : public ::rust::Opaque { struct DagBuilder final : public ::rust::Opaque { ::rust::String dump() const noexcept; ::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape, ::concrete_optimizer::Location const &location) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_zero_noise(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Box<::concrete_optimizer::Weights> weights, ::concrete_optimizer::Location const &location) noexcept; - ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment, ::concrete_optimizer::Location const &location) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_linear_noise(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment, ::concrete_optimizer::Location const &location) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_max_noise(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Slice<::std::uint64_t const> out_shape, 
::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex add_round_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex add_change_partition_with_src(::concrete_optimizer::dag::OperatorIndex input, ::concrete_optimizer::ExternalPartition const &src_partition, ::concrete_optimizer::Location const &location) noexcept; @@ -1353,11 +1355,15 @@ void concrete_optimizer$cxxbridge1$DagBuilder$dump(::concrete_optimizer::DagBuil ::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_input(::concrete_optimizer::DagBuilder &self, ::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape, ::concrete_optimizer::Location const &location) noexcept; +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_zero_noise(::concrete_optimizer::DagBuilder &self, ::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape, ::concrete_optimizer::Location const &location) noexcept; + ::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_lut(::concrete_optimizer::DagBuilder &self, ::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_dot(::concrete_optimizer::DagBuilder &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::concrete_optimizer::Weights *weights, ::concrete_optimizer::Location const &location) noexcept; -::concrete_optimizer::dag::OperatorIndex 
concrete_optimizer$cxxbridge1$DagBuilder$add_levelled_op(::concrete_optimizer::DagBuilder &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment, ::concrete_optimizer::Location const &location) noexcept; +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_linear_noise(::concrete_optimizer::DagBuilder &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment, ::concrete_optimizer::Location const &location) noexcept; + +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_max_noise(::concrete_optimizer::DagBuilder &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Slice<::std::uint64_t const> out_shape, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_round_op(::concrete_optimizer::DagBuilder &self, ::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision, ::concrete_optimizer::Location const &location) noexcept; @@ -1505,6 +1511,10 @@ ::concrete_optimizer::dag::OperatorIndex DagBuilder::add_input(::std::uint8_t ou return concrete_optimizer$cxxbridge1$DagBuilder$add_input(*this, out_precision, out_shape, location); } +::concrete_optimizer::dag::OperatorIndex DagBuilder::add_zero_noise(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape, ::concrete_optimizer::Location const &location) noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$add_zero_noise(*this, out_precision, out_shape, location); +} + ::concrete_optimizer::dag::OperatorIndex DagBuilder::add_lut(::concrete_optimizer::dag::OperatorIndex input, 
::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision, ::concrete_optimizer::Location const &location) noexcept { return concrete_optimizer$cxxbridge1$DagBuilder$add_lut(*this, input, table, out_precision, location); } @@ -1513,8 +1523,12 @@ ::concrete_optimizer::dag::OperatorIndex DagBuilder::add_dot(::rust::Slice<::con return concrete_optimizer$cxxbridge1$DagBuilder$add_dot(*this, inputs, weights.into_raw(), location); } -::concrete_optimizer::dag::OperatorIndex DagBuilder::add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment, ::concrete_optimizer::Location const &location) noexcept { - return concrete_optimizer$cxxbridge1$DagBuilder$add_levelled_op(*this, inputs, lwe_dim_cost_factor, fixed_cost, weights, out_shape, comment, location); +::concrete_optimizer::dag::OperatorIndex DagBuilder::add_linear_noise(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment, ::concrete_optimizer::Location const &location) noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$add_linear_noise(*this, inputs, lwe_dim_cost_factor, fixed_cost, weights, out_shape, comment, location); +} + +::concrete_optimizer::dag::OperatorIndex DagBuilder::add_max_noise(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Slice<::std::uint64_t const> out_shape, ::concrete_optimizer::Location const &location) noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$add_max_noise(*this, inputs, out_shape, location); } ::concrete_optimizer::dag::OperatorIndex DagBuilder::add_round_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision, ::concrete_optimizer::Location const &location) noexcept { diff 
--git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 646139ea2d..9f993b9f92 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -977,9 +977,11 @@ struct Dag final : public ::rust::Opaque { struct DagBuilder final : public ::rust::Opaque { ::rust::String dump() const noexcept; ::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape, ::concrete_optimizer::Location const &location) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_zero_noise(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Box<::concrete_optimizer::Weights> weights, ::concrete_optimizer::Location const &location) noexcept; - ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment, ::concrete_optimizer::Location const &location) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_linear_noise(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment, 
::concrete_optimizer::Location const &location) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_max_noise(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Slice<::std::uint64_t const> out_shape, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex add_round_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision, ::concrete_optimizer::Location const &location) noexcept; ::concrete_optimizer::dag::OperatorIndex add_change_partition_with_src(::concrete_optimizer::dag::OperatorIndex input, ::concrete_optimizer::ExternalPartition const &src_partition, ::concrete_optimizer::Location const &location) noexcept; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs index 367f2a21ad..c038bdf1b3 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs @@ -77,6 +77,10 @@ pub enum Operator { out_precision: Precision, out_shape: Shape, }, + ZeroNoise { + out_precision: Precision, + out_shape: Shape, + }, Lut { input: OperatorIndex, table: FunctionTable, @@ -87,13 +91,17 @@ pub enum Operator { weights: Weights, kind: DotKind, }, - LevelledOp { + LinearNoise { inputs: Vec, complexity: LevelledComplexity, weights: Vec, out_shape: Shape, comment: String, }, + MaxNoise { + inputs: Vec, + out_shape: Shape, + }, // Used to reduced or increase precision when the ciphertext is compatible with different precision // This is done without any checking UnsafeCast { @@ -116,8 +124,10 @@ impl Operator { // Returns an iterator on the indices of the operator 
inputs. pub(crate) fn get_inputs_iter(&self) -> Box + '_> { match self { - Self::Input { .. } => Box::new(empty()), - Self::LevelledOp { inputs, .. } | Self::Dot { inputs, .. } => Box::new(inputs.iter()), + Self::Input { .. } | Self::ZeroNoise { .. } => Box::new(empty()), + Self::LinearNoise { inputs, .. } + | Self::Dot { inputs, .. } + | Self::MaxNoise { inputs, .. } => Box::new(inputs.iter()), Self::UnsafeCast { input, .. } | Self::Lut { input, .. } | Self::Round { input, .. } @@ -153,6 +163,12 @@ impl fmt::Display for Operator { } => { write!(f, "Input : u{out_precision} x {out_shape:?}")?; } + Self::ZeroNoise { + out_precision, + out_shape, + } => { + write!(f, "Zero : u{out_precision} x {out_shape:?}")?; + } Self::Dot { inputs, weights, .. } => { @@ -176,7 +192,7 @@ impl fmt::Display for Operator { } => { write!(f, "LUT[%{}] : u{out_precision}", input.0)?; } - Self::LevelledOp { + Self::LinearNoise { inputs, weights, out_shape, @@ -191,6 +207,18 @@ impl fmt::Display for Operator { } write!(f, "] : weights={weights:?}, out_shape={out_shape:?}")?; } + Self::MaxNoise { + inputs, out_shape, .. + } => { + write!(f, "MAX[")?; + for (i, input) in inputs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "%{}", input.0)?; + } + write!(f, "] : out_shape={out_shape:?}")?; + } Self::Round { input, out_precision, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs index 24c38a08c9..ca064fbca0 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs @@ -5,12 +5,14 @@ use crate::dag::unparametrized::{Dag, DagBuilder, DagOperator}; fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator { let mut op = op.clone(); match &mut op { - Operator::Input { .. } => (), + Operator::Input { .. } | Operator::ZeroNoise { .. 
} => (), Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } | Operator::Round { input, .. } | Operator::ChangePartition { input, .. } => input.0 = old_index_to_new[input.0], - Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => { + Operator::Dot { inputs, .. } + | Operator::LinearNoise { inputs, .. } + | Operator::MaxNoise { inputs, .. } => { for input in inputs { input.0 = old_index_to_new[input.0]; } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index a59e0bd4c7..26cd246285 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -73,7 +73,10 @@ pub struct DagOperator<'dag> { impl<'dag> DagOperator<'dag> { /// Returns if the operator is an input. pub fn is_input(&self) -> bool { - matches!(self.operator, Operator::Input { .. }) + matches!( + self.operator, + Operator::Input { .. } | Operator::ZeroNoise { .. } + ) } /// Returns if the operator is an output. 
@@ -163,6 +166,22 @@ impl<'dag> DagBuilder<'dag> { OperatorIndex(i) } + pub fn add_zero_noise( + &mut self, + out_precision: Precision, + out_shape: impl Into, + location: Location, + ) -> OperatorIndex { + let out_shape = out_shape.into(); + self.add_operator( + Operator::ZeroNoise { + out_precision, + out_shape, + }, + location, + ) + } + pub fn add_input( &mut self, out_precision: Precision, @@ -217,7 +236,7 @@ impl<'dag> DagBuilder<'dag> { ) } - pub fn add_levelled_op( + pub fn add_linear_noise( &mut self, inputs: impl Into>, complexity: LevelledComplexity, @@ -231,7 +250,7 @@ impl<'dag> DagBuilder<'dag> { let comment = comment.into(); let weights = weights.into(); assert_eq!(weights.len(), inputs.len()); - let op = Operator::LevelledOp { + let op = Operator::LinearNoise { inputs, complexity, weights, @@ -241,6 +260,18 @@ impl<'dag> DagBuilder<'dag> { self.add_operator(op, location) } + pub fn add_max_noise( + &mut self, + inputs: impl Into>, + out_shape: impl Into, + location: Location, + ) -> OperatorIndex { + let inputs = inputs.into(); + let out_shape = out_shape.into(); + let op = Operator::MaxNoise { inputs, out_shape }; + self.add_operator(op, location) + } + pub fn add_unsafe_cast( &mut self, input: OperatorIndex, @@ -433,9 +464,10 @@ impl<'dag> DagBuilder<'dag> { fn infer_out_shape(&self, op: &Operator) -> Shape { match op { - Operator::Input { out_shape, .. } | Operator::LevelledOp { out_shape, .. } => { - out_shape.clone() - } + Operator::Input { out_shape, .. } + | Operator::LinearNoise { out_shape, .. } + | Operator::ZeroNoise { out_shape, .. } + | Operator::MaxNoise { out_shape, .. } => out_shape.clone(), Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } | Operator::Round { input, .. } @@ -468,12 +500,13 @@ impl<'dag> DagBuilder<'dag> { fn infer_out_precision(&self, op: &Operator) -> Precision { match op { Operator::Input { out_precision, .. } + | Operator::ZeroNoise { out_precision, .. } | Operator::Lut { out_precision, .. 
} | Operator::UnsafeCast { out_precision, .. } | Operator::Round { out_precision, .. } => *out_precision, - Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => { - self.dag.out_precisions[inputs[0].0] - } + Operator::Dot { inputs, .. } + | Operator::LinearNoise { inputs, .. } + | Operator::MaxNoise { inputs, .. } => self.dag.out_precisions[inputs[0].0], Operator::ChangePartition { input, .. } => self.dag.out_precisions[input.0], } } @@ -591,6 +624,15 @@ impl Dag { .add_input(out_precision, out_shape, Location::Unknown) } + pub fn add_zero_noise( + &mut self, + out_precision: Precision, + out_shape: impl Into, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_zero_noise(out_precision, out_shape, Location::Unknown) + } + pub fn add_lut( &mut self, input: OperatorIndex, @@ -624,7 +666,7 @@ impl Dag { .add_dot(inputs, weights, Location::Unknown) } - pub fn add_levelled_op( + pub fn add_linear_noise( &mut self, inputs: impl Into>, complexity: LevelledComplexity, @@ -632,7 +674,7 @@ impl Dag { out_shape: impl Into, comment: impl Into, ) -> OperatorIndex { - self.builder(DEFAULT_CIRCUIT).add_levelled_op( + self.builder(DEFAULT_CIRCUIT).add_linear_noise( inputs, complexity, weights, @@ -642,6 +684,15 @@ impl Dag { ) } + pub fn add_max_noise( + &mut self, + inputs: impl Into>, + out_shape: impl Into, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_max_noise(inputs, out_shape, Location::Unknown) + } + pub fn add_unsafe_cast( &mut self, input: OperatorIndex, @@ -932,7 +983,7 @@ mod tests { let input2 = builder.add_input(2, Shape::number(), Location::Unknown); let cpx_add = LevelledComplexity::ADDITION; - let sum1 = builder.add_levelled_op( + let sum1 = builder.add_linear_noise( [input1, input2], cpx_add, [1.0, 1.0], @@ -943,7 +994,7 @@ mod tests { let lut1 = builder.add_lut(sum1, FunctionTable::UNKWOWN, 1, Location::Unknown); - let concat = builder.add_levelled_op( + let concat = builder.add_linear_noise( [input1, lut1], cpx_add, 
[1.0, 1.0], @@ -981,7 +1032,7 @@ mod tests { out_precision: 2, out_shape: Shape::number(), }, - Operator::LevelledOp { + Operator::LinearNoise { inputs: vec![input1, input2], complexity: cpx_add, weights: vec![1.0, 1.0], @@ -993,7 +1044,7 @@ mod tests { table: FunctionTable::UNKWOWN, out_precision: 1, }, - Operator::LevelledOp { + Operator::LinearNoise { inputs: vec![input1, lut1], complexity: cpx_add, weights: vec![1.0, 1.0], diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index 94d271e1ee..40e80b9487 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -1,4 +1,4 @@ -use std::ops::{Deref, Index, IndexMut}; +use std::ops::{Add, Deref, Index, IndexMut}; use crate::dag::operator::{DotKind, LevelledComplexity, Operator, OperatorIndex, Precision}; use crate::dag::rewrite::round::expand_round_and_index_map; @@ -9,14 +9,17 @@ use crate::optimization::dag::multi_parameters::partitionning::partitionning_wit use crate::optimization::dag::multi_parameters::partitions::{ InstructionPartition, PartitionIndex, Transition, }; -use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance; use crate::optimization::dag::solo_key::analyze::safe_noise_bound; use crate::optimization::{Err, Result}; use super::complexity::OperationsCount; use super::keys_spec; -use super::operations_value::OperationsValue; +use super::noise_expression::{ + bootstrap_noise, fast_keyswitch_noise, input_noise, keyswitch_noise, modulus_switching_noise, + NoiseExpression, +}; use super::partitions::Partitions; +use super::symbolic::{bootstrap, fast_keyswitch, keyswitch, SymbolMap}; use super::variance_constraint::VarianceConstraint; use crate::utils::square; @@ 
-33,16 +36,14 @@ impl PartitionedDag { let vars = self .dag .get_operators_iter() - .map(|op| { - if op.is_input() { - let mut var = OperatorVariance::nan(self.partitions.nb_partitions); - let partition = self.partitions[op.id].instruction_partition; - var[partition] = - SymbolicVariance::input(self.partitions.nb_partitions, partition); - var - } else { - OperatorVariance::nan(self.partitions.nb_partitions) + .map(|op| match op.operator { + Operator::Input { .. } => { + let mut output = OperatorVariance::zero(self.partitions.nb_partitions); + let op_partition = self.partitions[op.id].instruction_partition; + output[op_partition] += 1.0 * input_noise(op_partition); + output } + _ => OperatorVariance::zero(self.partitions.nb_partitions), }) .collect(); Variances { vars } @@ -192,7 +193,6 @@ impl VariancedDag { if let Operator::Input { .. } = op { let partition_index = self.partitions.instrs_partition[i].instruction_partition; if p_cut.is_external_partition(&partition_index) { - let partitions = self.partitions.clone(); let external_partition = &p_cut.external_partitions[p_cut.external_partition_index(partition_index)]; let max_variance = external_partition.max_variance; @@ -200,11 +200,8 @@ impl VariancedDag { let mut input = self.get_operator_mut(OperatorIndex(i)); let mut variances = input.variance().clone(); - variances.vars[partition_index.0] = SymbolicVariance::from_external_partition( - partitions.nb_partitions, - partition_index, - max_variance / variance, - ); + variances.vars[partition_index.0] = NoiseExpression::zero() + + (max_variance / variance) * bootstrap_noise(partition_index); *(input.variance_mut()) = variances; } } @@ -230,17 +227,14 @@ impl VariancedDag { .max_variance; let variances = &self.get_operator(op.id).variance().vars.clone(); - for (i, variance) in variances.iter().enumerate() { - if variance.coeffs.is_nan() { - assert!(i != partition_index.0); - continue; - } + for variance in variances.iter() { let constraint = VarianceConstraint { 
precision: *out_precision, + nb_partitions: self.partitions.nb_partitions, partition: partition_index, nb_constraints: out_shape.flat_size(), safe_variance_bound: max_variance, - variance: variance.clone(), + noise_expression: variance.clone(), location: op.location.clone(), }; self.external_variance_constraints.push(constraint); @@ -271,17 +265,14 @@ impl VariancedDag { .max_variance; let variances = &self.get_operator(op_index).variance().vars.clone(); - for (i, variance) in variances.iter().enumerate() { - if variance.coeffs.is_nan() { - assert!(i != partition_index.0); - continue; - } + for variance in variances.iter() { let constraint = VarianceConstraint { precision: *out_precision, + nb_partitions: self.partitions.nb_partitions, partition: partition_index, nb_constraints: out_shape.flat_size(), safe_variance_bound: max_variance, - variance: variance.clone(), + noise_expression: variance.clone(), location: dag_op.location.clone(), }; self.external_variance_constraints.push(constraint); @@ -302,52 +293,34 @@ impl VariancedDag { if operator.operator().is_input() { continue; } + // Operator variance will be used to override the noise - let mut operator_variance = OperatorVariance::nan(nb_partitions); + let mut operator_variance = OperatorVariance::zero(nb_partitions); + let operator_partition = operator.partition().instruction_partition; + // We first compute the noise in the partition of the operator operator_variance[operator.partition().instruction_partition] = match operator .operator() .operator { - Operator::Input { .. } => unreachable!(), - Operator::Lut { .. } => SymbolicVariance::after_pbs( - nb_partitions, - operator.partition().instruction_partition, - ), - Operator::LevelledOp { weights, .. } => operator - .get_inputs_iter() - .zip(weights) - .fold(SymbolicVariance::ZERO, |acc, (inp, &weight)| { - acc + inp.variance()[operator.partition().instruction_partition].clone() - * square(weight) - }), - Operator::Dot { - kind: DotKind::CompatibleTensor { .. 
}, - .. - } => todo!("TODO"), - Operator::Dot { - kind: DotKind::Unsupported { .. }, - .. - } => panic!("Unsupported"), - Operator::Dot { - inputs, - weights, - kind: DotKind::Simple | DotKind::Tensor | DotKind::Broadcast { .. }, - } if inputs.len() == 1 => { - let var = operator + Operator::Lut { .. } => { + NoiseExpression::zero() + 1.0 * bootstrap_noise(operator_partition) + } + Operator::MaxNoise { .. } => { + operator .get_inputs_iter() - .next() - .unwrap() - .variance() - .clone(); - weights - .values - .iter() - .fold(SymbolicVariance::ZERO, |acc, weight| { - acc + var[operator.partition().instruction_partition].clone() - * square(*weight as f64) + .fold(NoiseExpression::zero(), |acc, inp| { + let inp_noise = inp.variance()[operator_partition].clone(); + NoiseExpression::max(&acc, &inp_noise) }) } + Operator::LinearNoise { weights, .. } => operator + .get_inputs_iter() + .zip(weights) + .fold(NoiseExpression::zero(), |acc, (inp, &weight)| { + let inp_noise = inp.variance()[operator_partition].clone(); + acc + inp_noise * square(weight) + }), Operator::Dot { weights, kind: DotKind::Simple | DotKind::Tensor | DotKind::Broadcast { .. }, @@ -356,15 +329,24 @@ impl VariancedDag { .values .iter() .zip(operator.get_inputs_iter().map(|n| n.variance().clone())) - .fold(SymbolicVariance::ZERO, |acc, (weight, var)| { - acc + var[operator.partition().instruction_partition].clone() - * square(*weight as f64) + .fold(NoiseExpression::zero(), |acc, (weight, var)| { + let inp_var = var[operator_partition].clone(); + acc + inp_var * square(*weight as f64) }), Operator::UnsafeCast { .. } | Operator::ChangePartition { .. } => { - operator.get_inputs_iter().next().unwrap().variance() - [operator.partition().instruction_partition] + operator.get_inputs_iter().next().unwrap().variance()[operator_partition] .clone() } + Operator::Input { .. } | Operator::ZeroNoise { .. } => unreachable!(), + + Operator::Dot { + kind: DotKind::CompatibleTensor { .. }, + .. 
+ } => todo!("TODO"), + Operator::Dot { + kind: DotKind::Unsupported { .. }, + .. + } => panic!("Unsupported"), Operator::Round { .. } => { unreachable!("Round should have been either expanded or integrated to a lut") } @@ -375,12 +357,9 @@ impl VariancedDag { .alternative_output_representation .iter() .for_each(|index| { - operator_variance[*index] = operator_variance - [operator.partition().instruction_partition] - .after_partition_keyswitch_to_big( - operator.partition().instruction_partition, - *index, - ); + let noise_in_operator_partition = operator_variance[operator_partition].clone(); + operator_variance[*index] = noise_in_operator_partition + + 1.0 * fast_keyswitch_noise(operator_partition, *index); }); // We override the noise *operator.variance_mut() = operator_variance; @@ -457,7 +436,11 @@ pub fn analyze( let undominated_variance_constraints = VarianceConstraint::remove_dominated(&variance_constraints); let operations_count_per_instrs = collect_operations_count(&varianced_dag); - let operations_count = sum_operations_count(&operations_count_per_instrs); + let operations_count = operations_count_per_instrs + .clone() + .into_iter() + .reduce(Add::add) + .unwrap(); Ok(AnalyzedDag { operators: varianced_dag.dag.operators, instruction_rewrite_index, @@ -548,11 +531,11 @@ pub fn original_instrs_partition( #[derive(PartialEq, Debug, Clone)] pub struct OperatorVariance { - pub(crate) vars: Vec, + pub(crate) vars: Vec, } impl Index for OperatorVariance { - type Output = SymbolicVariance; + type Output = NoiseExpression; fn index(&self, index: PartitionIndex) -> &Self::Output { &self.vars[index.0] @@ -566,7 +549,7 @@ impl IndexMut for OperatorVariance { } impl Deref for OperatorVariance { - type Target = [SymbolicVariance]; + type Target = [NoiseExpression]; fn deref(&self) -> &Self::Target { &self.vars @@ -574,20 +557,24 @@ impl Deref for OperatorVariance { } impl OperatorVariance { - pub fn nan(nb_partitions: usize) -> Self { + pub fn zero(nb_partitions: 
usize) -> Self { Self { vars: (0..nb_partitions) - .map(|_| SymbolicVariance::nan(nb_partitions)) + .map(|_| NoiseExpression::zero()) .collect(), } } + pub fn nb_partitions(&self) -> usize { + self.vars.len() + } + pub fn partition_wise_max(&self, other: &Self) -> Self { let vars = self .vars .iter() .zip(other.vars.iter()) - .map(|(s, o)| s.max(o)) + .map(|(s, o)| NoiseExpression::max(s, o)) .collect(); Self { vars } } @@ -597,11 +584,11 @@ impl OperatorVariance { .iter() .enumerate() .flat_map(|(var_i, var)| { - PartitionIndex::range(0, var.nb_partitions()) - .map(move |part_i| (var_i, part_i, var.coeff_input(part_i))) + PartitionIndex::range(0, self.nb_partitions()) + .map(move |part_i| (var_i, part_i, var.coeff(input_noise(part_i)))) }) .try_for_each(|(var, partition, coeff)| { - if !coeff.is_nan() && coeff > 1.0 { + if coeff > 1.0 { Result::Err(Err::NotComposable(format!( "The noise of the node {var} is contaminated by noise coming straight from the input (partition: {partition}, coeff: {coeff:.2})" ))) @@ -643,10 +630,11 @@ impl Deref for Variances { fn variance_constraint( dag: &Dag, noise_config: &NoiseBoundConfig, + nb_partitions: usize, partition: PartitionIndex, op_i: usize, precision: Precision, - variance: SymbolicVariance, + noise: NoiseExpression, ) -> VarianceConstraint { let nb_constraints = dag.out_shapes[op_i].flat_size(); let safe_variance_bound = safe_noise_bound(precision, noise_config); @@ -656,7 +644,8 @@ fn variance_constraint( partition, nb_constraints, safe_variance_bound, - variance, + nb_partitions, + noise_expression: noise, location, } } @@ -690,19 +679,19 @@ fn collect_all_variance_constraints( assert!(src_partition != dst_partition); let variance = &variances[*input][dst_partition]; assert!( - variance.coeff_partition_keyswitch_to_big(src_partition, dst_partition) - == 1.0 + variance.coeff(fast_keyswitch_noise(src_partition, dst_partition)) == 1.0 ); dst_partition } }; - let variance = &variances[*input][src_partition].clone(); + 
let variance = variances[*input][src_partition].clone(); let variance = variance - .after_partition_keyswitch_to_small(src_partition, dst_partition) - .after_modulus_switching(partition); + + 1.0 * keyswitch_noise(src_partition, dst_partition) + + 1.0 * modulus_switching_noise(partition); constraints.push(variance_constraint( dag, noise_config, + partitions.nb_partitions, partition, op.id.0, precision, @@ -715,6 +704,7 @@ fn collect_all_variance_constraints( constraints.push(variance_constraint( dag, noise_config, + partitions.nb_partitions, partition, op.id.0, precision, @@ -733,21 +723,21 @@ fn operations_counts( nb_partitions: usize, instr_partition: &InstructionPartition, ) -> OperationsCount { - let mut counts = OperationsValue::zero(nb_partitions); + let mut counts = SymbolMap::new(); if let Operator::Lut { input, .. } = op { let partition = instr_partition.instruction_partition; - let nb_lut = dag.out_shapes[input.0].flat_size() as f64; + let nb_lut = dag.out_shapes[input.0].flat_size() as usize; let src_partition = match instr_partition.inputs_transition[0] { Some(Transition::Internal { src_partition }) => src_partition, Some(Transition::Additional { .. 
}) | None => partition, }; - *counts.ks(src_partition, partition) += nb_lut; - *counts.pbs(partition) += nb_lut; + counts.update(keyswitch(src_partition, partition), |a| a + nb_lut); + counts.update(bootstrap(partition), |a| a + nb_lut); for &conv_partition in &instr_partition.alternative_output_representation { - *counts.fks(partition, conv_partition) += nb_lut; + counts.update(fast_keyswitch(partition, conv_partition), |a| a + nb_lut); } } - OperationsCount { counts } + OperationsCount(counts) } #[allow(unused)] @@ -767,15 +757,6 @@ fn collect_operations_count(dag: &VariancedDag) -> Vec { .collect() } -#[allow(unused)] -fn sum_operations_count(all_counts: &[OperationsCount]) -> OperationsCount { - let mut sum_counts = OperationsValue::zero(all_counts[0].counts.nb_partitions()); - for OperationsCount { counts } in all_counts { - sum_counts += counts; - } - OperationsCount { counts: sum_counts } -} - #[cfg(test)] pub mod tests { use super::*; @@ -807,11 +788,7 @@ pub mod tests { ) { for symbolic_variance_partition in [LOW_PRECISION_PARTITION, HIGH_PRECISION_PARTITION] { let sb = dag.instrs_variances[op_i][partition].clone(); - let coeff = if sb == SymbolicVariance::ZERO { - 0.0 - } else { - sb.coeff_input(symbolic_variance_partition) - }; + let coeff = sb.coeff(input_noise(symbolic_variance_partition)); if symbolic_variance_partition == partition { assert!( coeff == expected_coeff, @@ -845,11 +822,7 @@ pub mod tests { let sb = dag.instrs_variances[op_i][partition].clone(); eprintln!("{:?}", dag.instrs_variances[op_i]); eprintln!("{:?}", dag.instrs_variances[op_i][partition]); - let coeff = if sb == SymbolicVariance::ZERO { - 0.0 - } else { - sb.coeff_pbs(symbolic_variance_partition) - }; + let coeff = sb.coeff(bootstrap_noise(symbolic_variance_partition)); if symbolic_variance_partition == partition { assert!( coeff == expected_coeff, @@ -877,7 +850,7 @@ pub mod tests { fn test_decreasing_panics() { let mut dag = unparametrized::Dag::new(); let inp = 
dag.add_input(1, Shape::number()); - let oup = dag.add_levelled_op( + let oup = dag.add_linear_noise( [inp], LevelledComplexity::ZERO, [0.5], @@ -893,7 +866,7 @@ pub mod tests { fn test_composition_with_nongrowing_inputs_only() { let mut dag = unparametrized::Dag::new(); let inp = dag.add_input(1, Shape::number()); - let oup = dag.add_levelled_op( + let oup = dag.add_linear_noise( [inp], LevelledComplexity::ZERO, [1.0], @@ -917,7 +890,7 @@ pub mod tests { fn test_composition_with_growing_inputs_panics() { let mut dag = unparametrized::Dag::new(); let inp = dag.add_input(1, Shape::number()); - let oup = dag.add_levelled_op( + let oup = dag.add_linear_noise( [inp], LevelledComplexity::ZERO, [1.1], @@ -969,7 +942,7 @@ pub mod tests { let expected_constraint_strings = vec![ "1σ²Br[0] + 1σ²K[0] + 1σ²M[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", "1σ²Br[0] + 1σ²K[0→1] + 1σ²M[1] < (2²)**-10 (6bits partition:1 count:1, dom=20)", - "1σ²Br[0] + 1σ²Br[1] + 1σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", + "1σ²Br[0] + 1σ²Br[1] + 1σ²K[0] + 1σ²FK[1→0] + 1σ²M[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", "1σ²Br[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", ]; assert_eq!(actual_constraint_strings, expected_constraint_strings); @@ -1010,11 +983,11 @@ pub mod tests { .map(ToString::to_string) .collect::>(); let expected_constraint_strings = vec![ - "1σ²Br[0] + 1σ²FK[0→1] + 1σ²Br[2] + 1σ²FK[2→1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", + "1σ²Br[0] + 1σ²Br[2] + 1σ²K[1→0] + 1σ²FK[0→1] + 1σ²FK[2→1] + 1σ²M[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", "1σ²Br[0] + 1σ²K[0→1] + 1σ²M[1] < (2²)**-10 (6bits partition:1 count:1, dom=20)", - "1σ²Br[0] + 1σ²FK[0→1] + 1σ²Br[1] + 1σ²Br[2] + 1σ²FK[2→1] + 1σ²K[1→2] + 1σ²M[2] < (2²)**-17 (13bits partition:2 count:1, dom=34)", + "1σ²Br[0] + 1σ²Br[1] + 1σ²Br[2] + 1σ²K[1→2] + 1σ²FK[0→1] + 1σ²FK[2→1] + 1σ²M[2] < (2²)**-17 (13bits partition:2 count:1, dom=34)", "1σ²Br[2] 
< (2²)**-7 (3bits partition:2 count:1, dom=14)", - "1σ²Br[0] + 1σ²FK[0→1] + 1σ²Br[1] + 1σ²Br[2] + 1σ²FK[2→1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", + "1σ²Br[0] + 1σ²Br[1] + 1σ²Br[2] + 1σ²K[1→0] + 1σ²FK[0→1] + 1σ²FK[2→1] + 1σ²M[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", "1σ²Br[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", ]; assert_eq!(actual_constraint_strings, expected_constraint_strings); @@ -1080,7 +1053,7 @@ pub mod tests { let input1 = dag.add_input(8, Shape::number()); let input2 = dag.add_input(8, Shape::number()); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8); - let _levelled = dag.add_levelled_op( + let _levelled = dag.add_linear_noise( [lut1, input2], LevelledComplexity::ZERO, [manp, manp], @@ -1091,10 +1064,6 @@ pub mod tests { assert!(dag.nb_partitions == 1); } - fn nan_symbolic_variance(sb: &SymbolicVariance) -> bool { - sb.coeffs[0].is_nan() - } - #[allow(clippy::float_cmp)] #[test] fn test_rounded_v3_first_layer_and_second_layer() { @@ -1113,22 +1082,19 @@ pub mod tests { for op_i in input1.0..lut1.0 { let p = LOW_PRECISION_PARTITION; let sb = &dag.instrs_variances[op_i][p]; - assert!(sb.coeff_input(p) >= 1.0 || sb.coeff_pbs(p) >= 1.0); - assert!(nan_symbolic_variance( - &dag.instrs_variances[op_i][HIGH_PRECISION_PARTITION] - )); + assert!(sb.coeff(input_noise(p)) >= 1.0 || sb.coeff(bootstrap_noise(p)) >= 1.0); } // First lut is HIGH_PRECISION_PARTITION and immedialtely converted to LOW_PRECISION_PARTITION let p = HIGH_PRECISION_PARTITION; let sb = &dag.instrs_variances[lut1.0][p]; - assert!(sb.coeff_input(p) == 0.0); - assert!(sb.coeff_pbs(p) == 1.0); + assert!(sb.coeff(input_noise(p)) == 0.0); + assert!(sb.coeff(bootstrap_noise(p)) == 1.0); let sb_after_fast_ks = &dag.instrs_variances[lut1.0][LOW_PRECISION_PARTITION]; assert!( - sb_after_fast_ks.coeff_partition_keyswitch_to_big( + sb_after_fast_ks.coeff(fast_keyswitch_noise( HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION - ) == 1.0 + )) 
== 1.0 ); // The next rounded is on LOW_PRECISION_PARTITION but base noise can comes from HIGH_PRECISION_PARTITION + FKS for op_i in (lut1.0 + 1)..lut2.0 { @@ -1136,31 +1102,28 @@ pub mod tests { let p = LOW_PRECISION_PARTITION; let sb = &dag.instrs_variances[op_i][p]; // The base noise is either from the other partition and shifted or from the current partition and 1 - assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); - assert!(sb.coeff_input(HIGH_PRECISION_PARTITION) == 0.0); - if sb.coeff_pbs(HIGH_PRECISION_PARTITION) >= 1.0 { + assert!(sb.coeff(input_noise(LOW_PRECISION_PARTITION)) == 0.0); + assert!(sb.coeff(input_noise(HIGH_PRECISION_PARTITION)) == 0.0); + if sb.coeff(bootstrap_noise(HIGH_PRECISION_PARTITION)) >= 1.0 { assert!( - sb.coeff_pbs(HIGH_PRECISION_PARTITION) - == sb.coeff_partition_keyswitch_to_big( + sb.coeff(bootstrap_noise(HIGH_PRECISION_PARTITION)) + == sb.coeff(fast_keyswitch_noise( HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION - ) + )) ); } else { - assert!(sb.coeff_pbs(LOW_PRECISION_PARTITION) == 1.0); + assert!(sb.coeff(bootstrap_noise(LOW_PRECISION_PARTITION)) == 1.0); assert!( - sb.coeff_partition_keyswitch_to_big( + sb.coeff(fast_keyswitch_noise( HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION - ) == 0.0 + )) == 0.0 ); } } - assert!(nan_symbolic_variance( - &dag.instrs_variances[lut2.0][LOW_PRECISION_PARTITION] - )); let sb = &dag.instrs_variances[lut2.0][HIGH_PRECISION_PARTITION]; - assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) >= 1.0); + assert!(sb.coeff(bootstrap_noise(HIGH_PRECISION_PARTITION)) >= 1.0); } #[allow(clippy::float_cmp, clippy::cognitive_complexity)] @@ -1179,24 +1142,28 @@ pub mod tests { // First layer is fully HIGH_PRECISION_PARTITION assert!( dag.instrs_variances[free_input1.0][HIGH_PRECISION_PARTITION] - .coeff_input(HIGH_PRECISION_PARTITION) + .coeff(input_noise(HIGH_PRECISION_PARTITION)) == 1.0 ); // First layer tlu let sb = &dag.instrs_variances[input1.0][HIGH_PRECISION_PARTITION]; - 
assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); - assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) == 1.0); + assert!(sb.coeff(input_noise(LOW_PRECISION_PARTITION)) == 0.0); + assert!(sb.coeff(bootstrap_noise(HIGH_PRECISION_PARTITION)) == 1.0); assert!( - sb.coeff_partition_keyswitch_to_big(HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION) - == 0.0 + sb.coeff(fast_keyswitch_noise( + HIGH_PRECISION_PARTITION, + LOW_PRECISION_PARTITION + )) == 0.0 ); // The same cyphertext exists in another partition with additional noise due to fast keyswitch let sb = &dag.instrs_variances[input1.0][LOW_PRECISION_PARTITION]; - assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); - assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) == 1.0); + assert!(sb.coeff(input_noise(LOW_PRECISION_PARTITION)) == 0.0); + assert!(sb.coeff(bootstrap_noise(HIGH_PRECISION_PARTITION)) == 1.0); assert!( - sb.coeff_partition_keyswitch_to_big(HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION) - == 1.0 + sb.coeff(fast_keyswitch_noise( + HIGH_PRECISION_PARTITION, + LOW_PRECISION_PARTITION + )) == 1.0 ); // Second layer @@ -1212,12 +1179,14 @@ pub mod tests { let bit_erase = weights.values == [1, -1]; let first_bit_erase = bit_erase && !first_bit_erase_verified; let input0_sb = &dag.instrs_variances[inputs[0].0][LOW_PRECISION_PARTITION]; - let input0_coeff_pbs_high = input0_sb.coeff_pbs(HIGH_PRECISION_PARTITION); - let input0_coeff_pbs_low = input0_sb.coeff_pbs(LOW_PRECISION_PARTITION); - let input0_coeff_fks = input0_sb.coeff_partition_keyswitch_to_big( + let input0_coeff_pbs_high = + input0_sb.coeff(bootstrap_noise(HIGH_PRECISION_PARTITION)); + let input0_coeff_pbs_low = + input0_sb.coeff(bootstrap_noise(LOW_PRECISION_PARTITION)); + let input0_coeff_fks = input0_sb.coeff(fast_keyswitch_noise( HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION, - ); + )); if bit_extract { first_bit_extract_verified |= first_bit_extract; assert!(input0_coeff_pbs_high >= 1.0); @@ -1230,12 +1199,14 @@ pub mod tests { } 
else if bit_erase { first_bit_erase_verified |= first_bit_erase; let input1_sb = &dag.instrs_variances[inputs[1].0][LOW_PRECISION_PARTITION]; - let input1_coeff_pbs_high = input1_sb.coeff_pbs(HIGH_PRECISION_PARTITION); - let input1_coeff_pbs_low = input1_sb.coeff_pbs(LOW_PRECISION_PARTITION); - let input1_coeff_fks = input1_sb.coeff_partition_keyswitch_to_big( + let input1_coeff_pbs_high = + input1_sb.coeff(bootstrap_noise(HIGH_PRECISION_PARTITION)); + let input1_coeff_pbs_low = + input1_sb.coeff(bootstrap_noise(LOW_PRECISION_PARTITION)); + let input1_coeff_fks = input1_sb.coeff(fast_keyswitch_noise( HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION, - ); + )); if first_bit_erase { assert!(input0_coeff_pbs_low == 0.0); } else { @@ -1274,13 +1245,13 @@ pub mod tests { // First lut to force partition HIGH_PRECISION_PARTITION "1σ²In[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)", // 16384(shift) = (2**7)², for Br[1] - "16384σ²Br[1] + 16384σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=22)", + "16384σ²Br[1] + 1σ²K[0] + 16384σ²FK[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=22)", // 4096(shift) = (2**6)², 1(due to 1 erase bit) for Br[0] and 1 for Br[1] - "4096σ²Br[0] + 4096σ²Br[1] + 4096σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=20)", + "4096σ²Br[0] + 4096σ²Br[1] + 1σ²K[0] + 4096σ²FK[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=20)", // 1024(shift) = (2**5)², 2(due to 2 erase bit for Br[0] and 1 for Br[1] - "2048σ²Br[0] + 1024σ²Br[1] + 1024σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=19)", + "2048σ²Br[0] + 1024σ²Br[1] + 1σ²K[0] + 1024σ²FK[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=19)", // 3(erase bit) Br[0] and 1 initial Br[1] - "3σ²Br[0] + 1σ²Br[1] + 1σ²FK[1→0] + 1σ²K[0→1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=18)", + "3σ²Br[0] + 1σ²Br[1] + 1σ²K[0→1] + 1σ²FK[1→0] + 1σ²M[1] < (2²)**-8 (4bits partition:1 
count:1, dom=18)", // Last lut to close the cycle "1σ²Br[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)", ]; @@ -1337,10 +1308,10 @@ pub mod tests { // 16384(shift) = (2**7)², for Br[1] "16384σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=22)", // 4096(shift) = (2**6)², 1(due to 1 erase bit) for Br[0] and 1 for Br[1] - "4096σ²Br[0] + 4096σ²FK[0→1] + 4096σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=20)", + "4096σ²Br[0] + 4096σ²Br[1] + 1σ²K[1→0] + 4096σ²FK[0→1] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=20)", // 1024(shift) = (2**5)², 2(due to 2 erase bit for Br[0] and 1 for Br[1] - "2048σ²Br[0] + 2048σ²FK[0→1] + 1024σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=19)", - "3σ²Br[0] + 3σ²FK[0→1] + 1σ²Br[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=18)", + "2048σ²Br[0] + 1024σ²Br[1] + 1σ²K[1→0] + 2048σ²FK[0→1] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=19)", + "3σ²Br[0] + 1σ²Br[1] + 1σ²K[1] + 3σ²FK[0→1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=18)", ]; for (c, ec) in constraints.iter().zip(expected_constraints) { assert!( @@ -1390,24 +1361,24 @@ pub mod tests { .collect(); #[rustfmt::skip] // nighlty and stable are inconsitent here let expected_counts = [ - "ZERO x ¢", // free_input1 - "1¢K[1] + 1¢Br[1] + 1¢FK[1→0]", // input1 - "ZERO x ¢", // shift - "ZERO x ¢", // cast - "1¢K[0] + 1¢Br[0]", // extract (lut) - "ZERO x ¢", // erase (dot) - "ZERO x ¢", // cast - "ZERO x ¢", // shift - "ZERO x ¢", // cast - "1¢K[0] + 1¢Br[0]", // extract (lut) - "ZERO x ¢", // erase (dot) - "ZERO x ¢", // cast - "ZERO x ¢", // shift - "ZERO x ¢", // cast - "1¢K[0] + 1¢Br[0]", // extract (lut) - "ZERO x ¢", // erase (dot) - "ZERO x ¢", // cast - "1¢K[0→1] + 1¢Br[1]", // _lut1 + "∅", // free_input1 + "1¢Br[1] + 1¢K[1] + 1¢FK[1→0]", // input1 + "∅", // shift + "∅", // cast + "1¢Br[0] + 1¢K[0]", // extract (lut) + "∅", // erase (dot) + "∅", // cast + "∅", // 
shift + "∅", // cast + "1¢Br[0] + 1¢K[0]", // extract (lut) + "∅", // erase (dot) + "∅", // cast + "∅", // shift + "∅", // cast + "1¢Br[0] + 1¢K[0]", // extract (lut) + "∅", // erase (dot) + "∅", // cast + "1¢Br[1] + 1¢K[0→1]", // _lut1 ]; for ((c, ec), op) in instrs_counts.iter().zip(expected_counts).zip(dag.operators) { assert!( @@ -1418,7 +1389,7 @@ pub mod tests { eprintln!("{}", dag.operations_count); assert!( format!("{}", dag.operations_count) - == "3¢K[0] + 1¢K[0→1] + 1¢K[1] + 3¢Br[0] + 2¢Br[1] + 1¢FK[1→0]" + == "3¢Br[0] + 2¢Br[1] + 3¢K[0] + 1¢K[0→1] + 1¢K[1] + 1¢FK[1→0]" ); } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs index c8a787712a..a050e4817a 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs @@ -1,134 +1,93 @@ -use std::fmt; +use std::{fmt, ops::Add}; -use crate::utils::f64::f64_dot; - -use super::{operations_value::OperationsValue, partitions::PartitionIndex}; +use super::{ + partitions::PartitionIndex, + symbolic::{fast_keyswitch, keyswitch, Symbol, SymbolMap}, +}; +/// A structure storing the number of times an fhe operation gets executed in a circuit. 
#[derive(Clone, Debug)] -pub struct OperationsCount { - pub counts: OperationsValue, -} +pub struct OperationsCount(pub(super) SymbolMap); -#[derive(Clone, Debug)] -#[allow(dead_code)] -pub struct OperationsCost { - pub costs: OperationsValue, -} +impl Add for OperationsCount { + type Output = OperationsCount; -#[derive(Clone, Debug)] -pub struct Complexity { - pub counts: OperationsValue, + fn add(self, rhs: OperationsCount) -> Self::Output { + let mut output = self; + for (s, v) in rhs.0.into_iter() { + output.0.update(s, |a| a + v); + } + output + } } impl fmt::Display for OperationsCount { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut add_plus = ""; - let counts = &self.counts; - let nb_partitions = counts.nb_partitions(); - let index = &counts.index; - for src_partition in PartitionIndex::range(0, nb_partitions) { - for dst_partition in PartitionIndex::range(0, nb_partitions) { - let coeff = counts.values[index.keyswitch_to_small(src_partition, dst_partition)]; - if coeff != 0.0 { - if src_partition == dst_partition { - write!(f, "{add_plus}{coeff}¢K[{src_partition}]")?; - } else { - write!(f, "{add_plus}{coeff}¢K[{src_partition}→{dst_partition}]")?; - } - add_plus = " + "; - } - } - } - for src_partition in PartitionIndex::range(0, nb_partitions) { - assert!(counts.values[index.input(src_partition)] == 0.0); - let coeff = counts.values[index.pbs(src_partition)]; - if coeff != 0.0 { - write!(f, "{add_plus}{coeff}¢Br[{src_partition}]")?; - add_plus = " + "; - } - for dst_partition in PartitionIndex::range(0, nb_partitions) { - let coeff = counts.values[index.keyswitch_to_big(src_partition, dst_partition)]; - if coeff != 0.0 { - write!(f, "{add_plus}{coeff}¢FK[{src_partition}→{dst_partition}]")?; - add_plus = " + "; - } - } - } + self.0.fmt_with(f, "+", "¢") + } +} - for partition in PartitionIndex::range(0, nb_partitions) { - assert!(counts.values[index.modulus_switching(partition)] == 0.0); - } - if add_plus.is_empty() { - write!(f, "ZERO x 
¢")?; - } - Ok(()) +/// An ensemble of costs associated with fhe operation symbols. +#[derive(Clone, Debug)] +pub struct ComplexityValues(SymbolMap); + +impl ComplexityValues { + /// Returns an empty set of cost values. + pub fn new() -> Self { + ComplexityValues(SymbolMap::new()) + } + + /// Sets the cost associated with an fhe operation symbol. + pub fn set_cost(&mut self, source: Symbol, value: f64) { + self.0.set(source, value); } } -impl Complexity { - pub fn of(counts: &OperationsCount) -> Self { - Self { - counts: counts.counts.clone(), - } +/// A complexity expression is a sum of complexity terms associating operation +/// symbols with the number of time they gets executed in the circuit. +#[derive(Clone, Debug)] +pub struct ComplexityExpression(SymbolMap); + +impl ComplexityExpression { + /// Creates a complexity expression from a set of operation counts. + pub fn from(counts: &OperationsCount) -> Self { + Self(counts.0.clone()) } - pub fn complexity(&self, costs: &OperationsValue) -> f64 { - f64_dot(&self.counts, costs) + /// Evaluates the total cost expression on a set of cost values. + pub fn evaluate_total_cost(&self, costs: &ComplexityValues) -> f64 { + self.0.iter().fold(0.0, |acc, (symbol, n_ops)| { + acc + (n_ops as f64) * costs.0.get(symbol) + }) } - pub fn ks_max_cost( + /// Evaluates the max ks cost expression on a set of cost values. 
+ pub fn evaluate_ks_max_cost( &self, complexity_cut: f64, - costs: &OperationsValue, + costs: &ComplexityValues, src_partition: PartitionIndex, dst_partition: PartitionIndex, ) -> f64 { - let ks_index = costs.index.keyswitch_to_small(src_partition, dst_partition); - let actual_ks_cost = costs.values[ks_index]; - let ks_coeff = self.counts[self - .counts - .index - .keyswitch_to_small(src_partition, dst_partition)]; - let actual_complexity = self.complexity(costs) - ks_coeff * actual_ks_cost; - - (complexity_cut - actual_complexity) / ks_coeff + let actual_ks_cost = costs.0.get(keyswitch(src_partition, dst_partition)); + let ks_coeff = self.0.get(keyswitch(src_partition, dst_partition)); + let actual_complexity = + self.evaluate_total_cost(costs) - (ks_coeff as f64) * actual_ks_cost; + (complexity_cut - actual_complexity) / (ks_coeff as f64) } - pub fn fks_max_cost( + /// Evaluates the max fks cost expression on a set of cost values. + pub fn evaluate_fks_max_cost( &self, complexity_cut: f64, - costs: &OperationsValue, + costs: &ComplexityValues, src_partition: PartitionIndex, dst_partition: PartitionIndex, ) -> f64 { - let fks_index = costs.index.keyswitch_to_big(src_partition, dst_partition); - let actual_fks_cost = costs.values[fks_index]; - let fks_coeff = self.counts[self - .counts - .index - .keyswitch_to_big(src_partition, dst_partition)]; - let actual_complexity = self.complexity(costs) - fks_coeff * actual_fks_cost; - - (complexity_cut - actual_complexity) / fks_coeff - } - - pub fn compressed(self) -> Self { - let mut detect_used: Vec = vec![false; self.counts.len()]; - for (i, &count) in self.counts.iter().enumerate() { - if count > 0.0 { - detect_used[i] = true; - } - } - Self { - counts: self.counts.compress(&detect_used), - } - } - - pub fn zero_cost(&self) -> OperationsValue { - if self.counts.index.is_compressed() { - OperationsValue::zero_compressed(&self.counts.index) - } else { - OperationsValue::zero(self.counts.nb_partitions()) - } + let 
actual_fks_cost = costs.0.get(fast_keyswitch(src_partition, dst_partition)); + let fks_coeff = self.0.get(fast_keyswitch(src_partition, dst_partition)); + let actual_complexity = + self.evaluate_total_cost(costs) - (fks_coeff as f64) * actual_fks_cost; + (complexity_cut - actual_complexity) / (fks_coeff as f64) } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs index 8dc8a81a8e..7f93c883c3 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs @@ -1,9 +1,10 @@ use crate::noise_estimator::p_error::{combine_errors, repeat_p_error}; use crate::optimization::dag::multi_parameters::variance_constraint::VarianceConstraint; use crate::optimization::dag::solo_key::analyze::p_error_from_relative_variance; -use crate::utils::f64::f64_dot; -use super::operations_value::OperationsValue; +use super::noise_expression::{ + bootstrap_noise, fast_keyswitch_noise, keyswitch_noise, NoiseValues, +}; use super::partitions::PartitionIndex; #[derive(Debug, Clone)] @@ -21,7 +22,6 @@ impl Feasibility { #[derive(Clone)] pub struct Feasible { - // TODO: move kappa here pub constraints: Vec, pub undominated_constraints: Vec, pub kappa: f64, // to convert variance to local probabilities @@ -41,20 +41,20 @@ impl Feasible { pub fn pbs_max_feasible_variance( &self, - operations_variance: &OperationsValue, + operations_variance: &NoiseValues, partition: PartitionIndex, ) -> f64 { - let pbs_index = operations_variance.index.pbs(partition); - let actual_pbs_variance = operations_variance.values[pbs_index]; - + let actual_pbs_variance = operations_variance.variance(bootstrap_noise(partition)); let mut smallest_pbs_max_variance = f64::MAX; for constraint in 
&self.undominated_constraints { - let pbs_coeff = constraint.variance.coeff_pbs(partition); + let pbs_coeff = constraint + .noise_expression + .coeff(bootstrap_noise(partition)); if pbs_coeff == 0.0 { continue; } - let actual_variance = f64_dot(operations_variance, &constraint.variance.coeffs) + let actual_variance = constraint.noise_expression.evaluate(operations_variance) - pbs_coeff * actual_pbs_variance; let pbs_max_variance = (constraint.safe_variance_bound - actual_variance) / pbs_coeff; smallest_pbs_max_variance = smallest_pbs_max_variance.min(pbs_max_variance); @@ -64,25 +64,23 @@ impl Feasible { pub fn ks_max_feasible_variance( &self, - operations_variance: &OperationsValue, + operations_variance: &NoiseValues, src_partition: PartitionIndex, dst_partition: PartitionIndex, ) -> f64 { - let ks_index = operations_variance - .index - .keyswitch_to_small(src_partition, dst_partition); - let actual_ks_variance = operations_variance.values[ks_index]; + let actual_ks_variance = + operations_variance.variance(keyswitch_noise(src_partition, dst_partition)); let mut smallest_ks_max_variance = f64::MAX; for constraint in &self.undominated_constraints { let ks_coeff = constraint - .variance - .coeff_keyswitch_to_small(src_partition, dst_partition); + .noise_expression + .coeff(keyswitch_noise(src_partition, dst_partition)); if ks_coeff == 0.0 { continue; } - let actual_variance = f64_dot(operations_variance, &constraint.variance.coeffs) + let actual_variance = constraint.noise_expression.evaluate(operations_variance) - ks_coeff * actual_ks_variance; let ks_max_variance = (constraint.safe_variance_bound - actual_variance) / ks_coeff; smallest_ks_max_variance = smallest_ks_max_variance.min(ks_max_variance); @@ -93,25 +91,23 @@ impl Feasible { pub fn fks_max_feasible_variance( &self, - operations_variance: &OperationsValue, + operations_variance: &NoiseValues, src_partition: PartitionIndex, dst_partition: PartitionIndex, ) -> f64 { - let fks_index = operations_variance - 
.index - .keyswitch_to_big(src_partition, dst_partition); - let actual_fks_variance = operations_variance.values[fks_index]; + let actual_fks_variance = + operations_variance.variance(fast_keyswitch_noise(src_partition, dst_partition)); let mut smallest_fks_max_variance = f64::MAX; for constraint in &self.undominated_constraints { let fks_coeff = constraint - .variance - .coeff_partition_keyswitch_to_big(src_partition, dst_partition); + .noise_expression + .coeff(fast_keyswitch_noise(src_partition, dst_partition)); if fks_coeff == 0.0 { continue; } - let actual_variance = f64_dot(operations_variance, &constraint.variance.coeffs) + let actual_variance = constraint.noise_expression.evaluate(operations_variance) - fks_coeff * actual_fks_variance; let fks_max_variance = (constraint.safe_variance_bound - actual_variance) / fks_coeff; smallest_fks_max_variance = smallest_fks_max_variance.min(fks_max_variance); @@ -120,7 +116,7 @@ impl Feasible { smallest_fks_max_variance } - pub fn feasible(&self, operations_variance: &OperationsValue) -> bool { + pub fn feasible(&self, operations_variance: &NoiseValues) -> bool { if self.global_p_error.is_none() { self.local_feasible(operations_variance) } else { @@ -128,9 +124,9 @@ impl Feasible { } } - fn local_feasible(&self, operations_variance: &OperationsValue) -> bool { + fn local_feasible(&self, operations_variance: &NoiseValues) -> bool { for constraint in &self.undominated_constraints { - if f64_dot(operations_variance, &constraint.variance.coeffs) + if constraint.noise_expression.evaluate(operations_variance) > constraint.safe_variance_bound { return false; @@ -139,20 +135,20 @@ impl Feasible { true } - fn global_feasible(&self, operations_variance: &OperationsValue) -> bool { + fn global_feasible(&self, operations_variance: &NoiseValues) -> bool { self.global_p_error_with_cut(operations_variance, self.global_p_error.unwrap_or(1.0)) .is_some() } pub fn worst_constraint( &self, - operations_variance: &OperationsValue, + 
operations_variance: &NoiseValues, ) -> (f64, f64, &VarianceConstraint) { let mut worst_constraint = &self.undominated_constraints[0]; let mut worst_relative_variance = 0.0; let mut worst_variance = 0.0; for constraint in &self.undominated_constraints { - let variance = f64_dot(operations_variance, &constraint.variance.coeffs); + let variance = constraint.noise_expression.evaluate(operations_variance); let relative_variance = variance / constraint.safe_variance_bound; if relative_variance > worst_relative_variance { worst_relative_variance = relative_variance; @@ -163,19 +159,15 @@ impl Feasible { (worst_variance, worst_relative_variance, worst_constraint) } - pub fn p_error(&self, operations_variance: &OperationsValue) -> f64 { + pub fn p_error(&self, operations_variance: &NoiseValues) -> f64 { let (_, relative_variance, _) = self.worst_constraint(operations_variance); p_error_from_relative_variance(relative_variance, self.kappa) } - fn global_p_error_with_cut( - &self, - operations_variance: &OperationsValue, - cut: f64, - ) -> Option { + fn global_p_error_with_cut(&self, operations_variance: &NoiseValues, cut: f64) -> Option { let mut global_p_error = 0.0; for constraint in &self.constraints { - let variance = f64_dot(operations_variance, &constraint.variance.coeffs); + let variance = constraint.noise_expression.evaluate(operations_variance); let relative_variance = variance / constraint.safe_variance_bound; let p_error = p_error_from_relative_variance(relative_variance, self.kappa); global_p_error = combine_errors( @@ -189,19 +181,19 @@ impl Feasible { Some(global_p_error) } - pub fn global_p_error(&self, operations_variance: &OperationsValue) -> f64 { + pub fn global_p_error(&self, operations_variance: &NoiseValues) -> f64 { self.global_p_error_with_cut(operations_variance, 1.0) .unwrap_or(1.0) } pub fn filter_constraints(&self, partition: PartitionIndex) -> Self { - let nb_partitions = self.constraints[0].variance.nb_partitions(); + let nb_partitions = 
self.constraints[0].nb_partitions; let touch_any_ks = |constraint: &VarianceConstraint, i| { - let variance = &constraint.variance; - variance.coeff_keyswitch_to_small(partition, i) > 0.0 - || variance.coeff_keyswitch_to_small(i, partition) > 0.0 - || variance.coeff_partition_keyswitch_to_big(partition, i) > 0.0 - || variance.coeff_partition_keyswitch_to_big(i, partition) > 0.0 + let variance = &constraint.noise_expression; + variance.coeff(keyswitch_noise(partition, i)) > 0.0 + || variance.coeff(keyswitch_noise(i, partition)) > 0.0 + || variance.coeff(fast_keyswitch_noise(partition, i)) > 0.0 + || variance.coeff(fast_keyswitch_noise(i, partition)) > 0.0 }; let partition_constraints: Vec<_> = self .constraints @@ -214,44 +206,4 @@ impl Feasible { .collect(); Self::of(&partition_constraints, self.kappa, self.global_p_error) } - - pub fn compressed(self) -> Self { - let mut detect_used: Vec = vec![false; self.constraints[0].variance.coeffs.len()]; - for constraint in &self.constraints { - for (i, &coeff) in constraint.variance.coeffs.iter().enumerate() { - if coeff > 0.0 { - detect_used[i] = true; - } - } - } - let compress = |c: &VarianceConstraint| VarianceConstraint { - variance: c.variance.compress(&detect_used), - ..(c.to_owned()) - }; - let constraints = self.constraints.iter().map(compress).collect(); - let undominated_constraints = self.undominated_constraints.iter().map(compress).collect(); - Self { - constraints, - undominated_constraints, - ..self - } - } - - pub fn zero_variance(&self) -> OperationsValue { - if self.undominated_constraints[0] - .variance - .coeffs - .index - .is_compressed() - { - OperationsValue::zero_compressed(&self.undominated_constraints[0].variance.coeffs.index) - } else { - OperationsValue::zero( - self.undominated_constraints[0] - .variance - .coeffs - .nb_partitions(), - ) - } - } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs 
b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs index 169472cb64..024e12e949 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs @@ -3,12 +3,13 @@ mod complexity; mod fast_keyswitch; mod feasible; pub mod keys_spec; -mod operations_value; pub mod optimize; pub mod optimize_generic; pub mod partition_cut; mod partitionning; mod partitions; -mod symbolic_variance; mod union_find; pub(crate) mod variance_constraint; + +mod noise_expression; +mod symbolic; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/noise_expression.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/noise_expression.rs new file mode 100644 index 0000000000..ef0c2e40a8 --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/noise_expression.rs @@ -0,0 +1,249 @@ +use std::{ + fmt::Display, + ops::{Add, AddAssign, Mul, MulAssign}, +}; + +use super::{ + partitions::PartitionIndex, + symbolic::{Symbol, SymbolMap}, +}; + +/// An ensemble of noise values for fhe operations. +#[derive(Debug, Clone, PartialEq)] +pub struct NoiseValues(SymbolMap); + +impl NoiseValues { + /// Returns an empty set of noise values. + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + NoiseValues(SymbolMap::new()) + } + + /// Sets the noise variance associated with a noise source. 
+ pub fn set_variance(&mut self, source: NoiseSource, value: f64) { + self.0.set(source.0, value); + } + + /// Returns the variance associated with a noise source + pub fn variance(&self, source: NoiseSource) -> f64 { + self.0.get(source.0) + } +} + +impl Display for NoiseValues { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt_with(f, ";", ":=") + } +} + +/// A noise expression, i.e. a sum of noise terms associating a noise source, +/// with a multiplicative coefficient. +#[derive(Debug, Clone, PartialEq)] +pub struct NoiseExpression(SymbolMap); + +impl NoiseExpression { + /// Returns a zero noise expression + pub fn zero() -> Self { + NoiseExpression(SymbolMap::new()) + } + + /// Returns an iterator over noise terms. + pub fn terms_iter(&self) -> impl Iterator + '_ { + self.0.iter().map(|(s, c)| NoiseTerm { + source: NoiseSource(s), + coefficient: c, + }) + } + + /// Returns the coefficient associated with a noise source. + pub fn coeff(&self, source: NoiseSource) -> f64 { + self.0.get(source.0) + } + + /// Builds a noise expression with the largest coefficients of the two expressions. + pub fn max(lhs: &Self, rhs: &Self) -> Self { + let mut lhs = lhs.to_owned(); + for (k, v) in rhs.0.iter() { + let coef = f64::max(lhs.0.get(k), v); + lhs.0.set(k, coef); + } + lhs + } + + /// Evaluate the noise expression on a set of noise values. + pub fn evaluate(&self, values: &NoiseValues) -> f64 { + self.terms_iter().fold(0.0, |acc, term| { + acc + term.coefficient * values.variance(term.source) + }) + } +} + +impl Display for NoiseExpression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt_with(f, "+", "σ²") + } +} + +impl From for NoiseExpression { + fn from(v: NoiseTerm) -> NoiseExpression { + NoiseExpression::zero() + v + } +} + +impl MulAssign for NoiseExpression { + fn mul_assign(&mut self, rhs: f64) { + if rhs == 0. 
{ + self.0.clear(); + } + self.0 + .iter() + .for_each(|(sym, coef)| self.0.set(sym, coef * rhs)); + } +} + +impl Mul for NoiseExpression { + type Output = NoiseExpression; + + fn mul(mut self, rhs: f64) -> Self::Output { + self *= rhs; + self + } +} + +impl Mul for f64 { + type Output = NoiseExpression; + + fn mul(self, mut rhs: NoiseExpression) -> Self::Output { + rhs *= self; + rhs + } +} + +impl AddAssign for NoiseExpression { + fn add_assign(&mut self, rhs: NoiseTerm) { + self.0.update(rhs.source.0, |a| a + rhs.coefficient); + } +} + +impl Add for NoiseExpression { + type Output = NoiseExpression; + + fn add(mut self, rhs: NoiseTerm) -> Self::Output { + self += rhs; + self + } +} + +impl Add for NoiseTerm { + type Output = NoiseExpression; + + fn add(self, mut rhs: NoiseExpression) -> Self::Output { + rhs += self; + rhs + } +} + +impl Add for NoiseExpression { + type Output = NoiseExpression; + + fn add(mut self, rhs: NoiseExpression) -> Self::Output { + for term in rhs.terms_iter() { + self += term; + } + self + } +} + +impl Add for NoiseTerm { + type Output = NoiseExpression; + + fn add(self, rhs: NoiseTerm) -> Self::Output { + let mut output = NoiseExpression::zero(); + output += self; + output += rhs; + output + } +} + +/// A symbolic noise term, or a multiplicative coefficient associated with a noise source. +#[derive(Debug)] +pub struct NoiseTerm { + pub source: NoiseSource, + pub coefficient: f64, +} + +impl Mul for NoiseSource { + type Output = NoiseTerm; + + fn mul(self, rhs: f64) -> Self::Output { + NoiseTerm { + source: self, + coefficient: rhs, + } + } +} + +impl Mul for f64 { + type Output = NoiseTerm; + + fn mul(self, rhs: NoiseSource) -> Self::Output { + NoiseTerm { + source: rhs, + coefficient: self, + } + } +} + +/// A symbolic source of noise, or a noise source variable. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub struct NoiseSource(Symbol); + +/// Returns an input noise source symbol. 
+pub fn input_noise(partition: PartitionIndex) -> NoiseSource { + NoiseSource(Symbol::Input(partition)) +} + +/// Returns a keyswitch noise source symbol. +pub fn keyswitch_noise(from: PartitionIndex, to: PartitionIndex) -> NoiseSource { + NoiseSource(Symbol::Keyswitch(from, to)) +} + +/// Returns a fast keyswitch noise source symbol. +pub fn fast_keyswitch_noise(from: PartitionIndex, to: PartitionIndex) -> NoiseSource { + NoiseSource(Symbol::FastKeyswitch(from, to)) +} + +/// Returns a pbs noise source symbol. +pub fn bootstrap_noise(partition: PartitionIndex) -> NoiseSource { + NoiseSource(Symbol::Bootstrap(partition)) +} + +/// Returns a modulus switching noise source symbol. +pub fn modulus_switching_noise(partition: PartitionIndex) -> NoiseSource { + NoiseSource(Symbol::ModulusSwitch(partition)) +} + +#[cfg(test)] +mod test { + + use super::*; + + #[test] + fn test_noise_expression() { + let mut expr = 1. * input_noise(PartitionIndex(0)) + + bootstrap_noise(PartitionIndex(0)) * 2.0 + + 5. * (3. 
* input_noise(PartitionIndex(1)) + bootstrap_noise(PartitionIndex(1)) * 4.0); + assert_eq!(expr.coeff(input_noise(PartitionIndex(0))), 1.0); + assert_eq!(expr.coeff(bootstrap_noise(PartitionIndex(0))), 2.0); + assert_eq!(expr.coeff(input_noise(PartitionIndex(1))), 15.0); + assert_eq!(expr.coeff(bootstrap_noise(PartitionIndex(1))), 20.0); + expr *= 4.; + println!("{expr}"); + assert_eq!(expr.coeff(input_noise(PartitionIndex(0))), 4.0); + assert_eq!(expr.coeff(bootstrap_noise(PartitionIndex(0))), 8.0); + assert_eq!(expr.coeff(input_noise(PartitionIndex(1))), 60.0); + assert_eq!(expr.coeff(bootstrap_noise(PartitionIndex(1))), 80.0); + expr *= 0.; + assert_eq!(expr.0.len(), 0); + } +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs deleted file mode 100644 index 7b0b90b268..0000000000 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs +++ /dev/null @@ -1,306 +0,0 @@ -use std::ops::{Deref, DerefMut}; - -use super::partitions::PartitionIndex; - -/** - * Index actual operations (input, ks, pbs, fks, modulus switching, etc). - */ -#[derive(Clone, Debug, Eq, PartialEq, PartialOrd)] -pub struct Indexing { - /* Values order - [ - // Partition 1 - // related only to the partition - fresh, pbs, modulus, - // Keyswitchs to small, from any partition to 1 - ks from 1, ks from 2, ... - // Keyswitch to big, from any partition to 1 - ks from 1, ks from 2, ... 
- - // Partition 2 - // same - ] - */ - pub nb_partitions: usize, - pub compressed_index: Vec, -} - -const VALUE_INDEX_FRESH: usize = 0; -const VALUE_INDEX_PBS: usize = 1; -const VALUE_INDEX_MODULUS: usize = 2; -// number of value always present for a partition -const STABLE_NB_VALUES_BY_PARTITION: usize = 3; - -pub const COMPRESSED_0_INDEX: usize = 0; // all 0.0 value are indexed here -pub const COMPRESSED_FIRST_FREE_INDEX: usize = 1; - -impl Indexing { - fn uncompressed(nb_partitions: usize) -> Self { - Self { - nb_partitions, - compressed_index: vec![], - } - } - - fn compress(&self, used: &[bool]) -> Self { - assert!(!self.is_compressed()); - let mut compressed_index = vec![COMPRESSED_0_INDEX; self.nb_coeff()]; - let mut index = COMPRESSED_FIRST_FREE_INDEX; - for (i, &is_used) in used.iter().enumerate() { - if is_used { - compressed_index[i] = index; - index += 1; - } - } - Self { - compressed_index, - ..(*self) - } - } - - pub fn is_compressed(&self) -> bool { - !self.compressed_index.is_empty() - } - - fn nb_keyswitchs_per_partition(&self) -> usize { - self.nb_partitions - } - - pub fn maybe_compressed(&self, i: usize) -> usize { - if self.is_compressed() { - self.compressed_index[i] - } else { - i - } - } - - pub fn nb_coeff_per_partition(&self) -> usize { - STABLE_NB_VALUES_BY_PARTITION + 2 * self.nb_partitions - } - - pub fn nb_coeff(&self) -> usize { - self.nb_partitions * (STABLE_NB_VALUES_BY_PARTITION + 2 * self.nb_partitions) - } - - pub fn input(&self, partition: PartitionIndex) -> usize { - assert!(partition.0 < self.nb_partitions); - self.maybe_compressed(partition.0 * self.nb_coeff_per_partition() + VALUE_INDEX_FRESH) - } - - pub fn pbs(&self, partition: PartitionIndex) -> usize { - assert!(partition.0 < self.nb_partitions); - self.maybe_compressed(partition.0 * self.nb_coeff_per_partition() + VALUE_INDEX_PBS) - } - - pub fn modulus_switching(&self, partition: PartitionIndex) -> usize { - assert!(partition.0 < self.nb_partitions); - 
self.maybe_compressed(partition.0 * self.nb_coeff_per_partition() + VALUE_INDEX_MODULUS) - } - - pub fn keyswitch_to_small( - &self, - src_partition: PartitionIndex, - dst_partition: PartitionIndex, - ) -> usize { - assert!(src_partition.0 < self.nb_partitions); - assert!(dst_partition.0 < self.nb_partitions); - self.maybe_compressed( - // Skip other partition - dst_partition.0 * self.nb_coeff_per_partition() - // Skip non keyswitchs - + STABLE_NB_VALUES_BY_PARTITION - // Select the right keyswicth to small - + src_partition.0, - ) - } - - pub fn keyswitch_to_big( - &self, - src_partition: PartitionIndex, - dst_partition: PartitionIndex, - ) -> usize { - assert!(src_partition.0 < self.nb_partitions); - assert!(dst_partition.0 < self.nb_partitions); - self.maybe_compressed( - // Skip other partition - dst_partition.0 * self.nb_coeff_per_partition() - // Skip non keyswitchs - + STABLE_NB_VALUES_BY_PARTITION - // Skip keyswitch to small - + self.nb_keyswitchs_per_partition() - // Select the right keyswicth to big - + src_partition.0, - ) - } - - pub fn compressed_size(&self) -> usize { - self.compressed_index.iter().copied().max().unwrap_or(0) + 1 - } -} - -/** - * Represent any values indexed by actual operations (input, pbs, modulus switching, ks, fks, , etc) variance, - */ -#[derive(Clone, Debug, PartialOrd)] -pub struct OperationsValue { - pub index: Indexing, - pub values: Vec, -} - -impl PartialEq for OperationsValue { - fn eq(&self, other: &Self) -> bool { - self.index == other.index - && self - .values - .iter() - .zip(other.values.iter()) - .all(|(a, b)| a.is_nan() && b.is_nan() || *a == *b) - } -} - -impl OperationsValue { - pub const ZERO: Self = Self { - index: Indexing { - nb_partitions: 0, - compressed_index: vec![], - }, - values: vec![], - }; - - pub fn zero(nb_partitions: usize) -> Self { - let index = Indexing::uncompressed(nb_partitions); - let nb_coeff = index.nb_coeff(); - Self { - index, - values: vec![0.0; nb_coeff], - } - } - - pub fn 
zero_compressed(index: &Indexing) -> Self { - assert!(index.is_compressed()); - Self { - index: index.clone(), - values: vec![0.0; index.compressed_size()], - } - } - - pub fn nan(nb_partitions: usize) -> Self { - let index = Indexing::uncompressed(nb_partitions); - let nb_coeff = index.nb_coeff(); - Self { - index, - values: vec![f64::NAN; nb_coeff], - } - } - - pub fn is_nan(&self) -> bool { - for val in self.values.iter() { - if !val.is_nan() { - return false; - } - } - true - } - - pub fn input(&mut self, partition: PartitionIndex) -> &mut f64 { - &mut self.values[self.index.input(partition)] - } - - pub fn pbs(&mut self, partition: PartitionIndex) -> &mut f64 { - &mut self.values[self.index.pbs(partition)] - } - - pub fn ks(&mut self, src_partition: PartitionIndex, dst_partition: PartitionIndex) -> &mut f64 { - &mut self.values[self.index.keyswitch_to_small(src_partition, dst_partition)] - } - - pub fn fks( - &mut self, - src_partition: PartitionIndex, - dst_partition: PartitionIndex, - ) -> &mut f64 { - &mut self.values[self.index.keyswitch_to_big(src_partition, dst_partition)] - } - - pub fn modulus_switching(&mut self, partition: PartitionIndex) -> &mut f64 { - &mut self.values[self.index.modulus_switching(partition)] - } - - pub fn nb_partitions(&self) -> usize { - self.index.nb_partitions - } - - pub fn compress(&self, used: &[bool]) -> Self { - self.compress_with(self.index.compress(used)) - } - - pub fn compress_like(&self, other: Self) -> Self { - self.compress_with(other.index) - } - - fn compress_with(&self, index: Indexing) -> Self { - assert!(!index.compressed_index.is_empty()); - assert!(self.index.compressed_index.is_empty()); - let mut values = vec![0.0; index.compressed_size()]; - for (i, &value) in self.values.iter().enumerate() { - #[allow(clippy::option_if_let_else)] - let j = index.compressed_index[i]; - if j == COMPRESSED_0_INDEX { - assert!(value == 0.0, "Cannot compress non null value"); - } else { - values[j] = value; - } - } - 
assert!(values[COMPRESSED_0_INDEX] == 0.0); - assert!(!index.compressed_index.is_empty()); - Self { index, values } - } -} - -impl Deref for OperationsValue { - type Target = [f64]; - - fn deref(&self) -> &Self::Target { - &self.values - } -} - -impl DerefMut for OperationsValue { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.values - } -} - -impl std::ops::AddAssign for OperationsValue { - fn add_assign(&mut self, rhs: Self) { - if self.values.is_empty() { - *self = rhs; - } else { - for i in 0..self.values.len() { - self.values[i] += rhs.values[i]; - } - } - } -} - -impl std::ops::AddAssign<&Self> for OperationsValue { - fn add_assign(&mut self, rhs: &Self) { - if self.values.is_empty() { - *self = rhs.clone(); - } else { - for i in 0..self.values.len() { - self.values[i] += rhs.values[i]; - } - } - } -} - -impl std::ops::Mul for OperationsValue { - type Output = Self; - fn mul(self, sq_weight: f64) -> Self { - Self { - values: self.values.iter().map(|v| v * sq_weight).collect(), - ..self - } - } -} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs index 2217fbfc5e..413e2a8447 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs @@ -8,7 +8,6 @@ use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace}; use crate::optimization::dag::multi_parameters::analyze::{analyze, AnalyzedDag}; use crate::optimization::dag::multi_parameters::fast_keyswitch; use crate::optimization::dag::multi_parameters::fast_keyswitch::FksComplexityNoise; -use crate::optimization::dag::multi_parameters::operations_value::OperationsValue; use crate::optimization::dag::solo_key::analyze::lut_count_from_dag; use 
crate::optimization::dag::solo_key::optimize::optimize as optimize_mono; use crate::optimization::decomposition::cmux::CmuxComplexityNoise; @@ -16,14 +15,20 @@ use crate::optimization::decomposition::keyswitch::KsComplexityNoise; use crate::optimization::decomposition::{cmux, keyswitch, DecompCaches, PersistDecompCaches}; use crate::parameters::GlweParameters; -use crate::optimization::dag::multi_parameters::complexity::Complexity; +use crate::optimization::dag::multi_parameters::complexity::ComplexityExpression; use crate::optimization::dag::multi_parameters::feasible::Feasible; use crate::optimization::dag::multi_parameters::partition_cut::PartitionCut; use crate::optimization::dag::multi_parameters::partitions::PartitionIndex; use crate::optimization::dag::multi_parameters::{analyze, keys_spec}; +use super::complexity::ComplexityValues; use super::feasible::Feasibility; use super::keys_spec::InstructionKeys; +use super::noise_expression::{ + bootstrap_noise, fast_keyswitch_noise, input_noise, keyswitch_noise, modulus_switching_noise, + NoiseValues, +}; +use super::symbolic::{bootstrap, fast_keyswitch, keyswitch}; const DEBUG: bool = false; @@ -64,8 +69,8 @@ pub struct Parameters { #[derive(Debug, Clone)] struct OperationsCV { - variance: OperationsValue, - cost: OperationsValue, + variance: NoiseValues, + cost: ComplexityValues, } type KsSrc = PartitionIndex; @@ -80,12 +85,13 @@ fn optimize_1_ks( ks_pareto: &[KsComplexityNoise], operations: &mut OperationsCV, feasible: &Feasible, - complexity: &Complexity, + complexity: &ComplexityExpression, cut_complexity: f64, ) -> Option { // find the first feasible (and less complex) let ks_max_variance = feasible.ks_max_feasible_variance(&operations.variance, ks_src, ks_dst); - let ks_max_cost = complexity.ks_max_cost(cut_complexity, &operations.cost, ks_src, ks_dst); + let ks_max_cost = + complexity.evaluate_ks_max_cost(cut_complexity, &operations.cost, ks_src, ks_dst); for &ks_quantity in ks_pareto { // variance is 
decreasing, complexity is increasing let ks_cost = ks_quantity.complexity(ks_input_lwe_dim); @@ -94,8 +100,10 @@ fn optimize_1_ks( return None; } if ks_variance <= ks_max_variance { - *operations.variance.ks(ks_src, ks_dst) = ks_variance; - *operations.cost.ks(ks_src, ks_dst) = ks_cost; + operations + .variance + .set_variance(keyswitch_noise(ks_src, ks_dst), ks_variance); + operations.cost.set_cost(keyswitch(ks_src, ks_dst), ks_cost); return Some(ks_quantity); } } @@ -109,7 +117,7 @@ fn optimize_many_independant_ks( ks_used: &[Vec], operations: &OperationsCV, feasible: &Feasible, - complexity: &Complexity, + complexity: &ComplexityExpression, caches: &mut keyswitch::Cache, cut_complexity: f64, ) -> Option<(Vec<(KsDst, KsComplexityNoise)>, OperationsCV)> { @@ -120,7 +128,7 @@ fn optimize_many_independant_ks( // we know there a feasible solution and a better complexity solution // we just need to check if both properties at the same time occur debug_assert!(feasible.feasible(&operations.variance)); - debug_assert!(complexity.complexity(&operations.cost) <= cut_complexity); + debug_assert!(complexity.evaluate_total_cost(&operations.cost) <= cut_complexity); let mut operations = operations.clone(); let mut ks_bests = Vec::with_capacity(macro_parameters.len()); for (ks_dst, macro_dst) in macro_parameters.iter().enumerate() { @@ -158,7 +166,7 @@ fn optimize_1_fks_and_all_compatible_ks( fks_dst: PartitionIndex, operations: &OperationsCV, feasible: &Feasible, - complexity: &Complexity, + complexity: &ComplexityExpression, caches: &mut keyswitch::Cache, cut_complexity: f64, ciphertext_modulus_log: u32, @@ -184,7 +192,7 @@ fn optimize_1_fks_and_all_compatible_ks( let fks_max_variance = feasible.fks_max_feasible_variance(&operations.variance, fks_src, fks_dst); let mut fks_max_cost = - complexity.fks_max_cost(cut_complexity, &operations.cost, fks_src, fks_dst); + complexity.evaluate_fks_max_cost(cut_complexity, &operations.cost, fks_src, fks_dst); for &ks_quantity in 
&ks_pareto { // OPT: add a pareto cache for fks let fks_quantity = if same_dim { @@ -234,8 +242,12 @@ fn optimize_1_fks_and_all_compatible_ks( continue; } - *operations.cost.fks(fks_src, fks_dst) = fks_quantity.complexity; - *operations.variance.fks(fks_src, fks_dst) = fks_quantity.noise; + operations + .cost + .set_cost(fast_keyswitch(fks_src, fks_dst), fks_quantity.complexity); + operations + .variance + .set_variance(fast_keyswitch_noise(fks_src, fks_dst), fks_quantity.noise); let sol = optimize_many_independant_ks( macro_parameters, @@ -252,12 +264,13 @@ fn optimize_1_fks_and_all_compatible_ks( continue; } let (best_many_ks, operations) = sol.unwrap(); - let cost = complexity.complexity(&operations.cost); + let cost = complexity.evaluate_total_cost(&operations.cost); if cost > cut_complexity { continue; } cut_complexity = cost; - fks_max_cost = complexity.fks_max_cost(cut_complexity, &operations.cost, fks_src, fks_dst); + fks_max_cost = + complexity.evaluate_fks_max_cost(cut_complexity, &operations.cost, fks_src, fks_dst); // COULD: handle complexity tie let bests = Best1FksAndManyKs { fks: Some((fks_src, fks_quantity)), @@ -277,7 +290,7 @@ fn optimize_dst_exclusive_fks_subset_and_all_ks( ks_used: &[Vec], operations: &OperationsCV, feasible: &Feasible, - complexity: &Complexity, + complexity: &ComplexityExpression, caches: &mut keyswitch::Cache, cut_complexity: f64, ciphertext_modulus_log: u32, @@ -337,7 +350,7 @@ fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks( ks_used: &[Vec], operations: &OperationsCV, feasible: &Feasible, - complexity: &Complexity, + complexity: &ComplexityExpression, caches: &mut keyswitch::Cache, cut_complexity: f64, best_p_error: f64, @@ -357,8 +370,8 @@ fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks( // Lower bounds cuts let pbs_cost = cmux_quantity.complexity_br(internal_dim); - *operations.cost.pbs(partition) = pbs_cost; - let lower_cost = complexity.complexity(&operations.cost); + 
operations.cost.set_cost(bootstrap(partition), pbs_cost); + let lower_cost = complexity.evaluate_total_cost(&operations.cost); if lower_cost > best_sol_complexity { continue; } @@ -368,7 +381,9 @@ fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks( continue; } - *operations.variance.pbs(partition) = pbs_variance; + operations + .variance + .set_variance(bootstrap_noise(partition), pbs_variance); let sol = optimize_dst_exclusive_fks_subset_and_all_ks( macro_parameters, fks_paretos, @@ -386,7 +401,7 @@ fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks( } let (best_fks_ks, operations) = sol.unwrap(); - let cost = complexity.complexity(&operations.cost); + let cost = complexity.evaluate_total_cost(&operations.cost); if cost > best_sol_complexity { continue; }; @@ -440,8 +455,14 @@ fn apply_all_ks_lower_bound( let out_internal_dim = macro_parameters[dst.0].internal_dim; let ks_pareto = caches.pareto_quantities(out_internal_dim); let in_lwe_dim = in_glwe_params.sample_extract_lwe_dimension(); - *operations.variance.ks(src, dst) = keyswitch::lowest_noise_ks(ks_pareto, in_lwe_dim); - *operations.cost.ks(src, dst) = keyswitch::lowest_complexity_ks(ks_pareto, in_lwe_dim); + operations.variance.set_variance( + keyswitch_noise(src, dst), + keyswitch::lowest_noise_ks(ks_pareto, in_lwe_dim), + ); + operations.cost.set_cost( + keyswitch(src, dst), + keyswitch::lowest_complexity_ks(ks_pareto, in_lwe_dim), + ); } } @@ -463,18 +484,25 @@ fn apply_fks_variance_and_cost_or_lower_bound( let input_glwe = ¯o_parameters[src.0].glwe_params; let output_glwe = ¯o_parameters[dst.0].glwe_params; if input_glwe == output_glwe { - *operations.variance.fks(src, dst) = 0.0; - *operations.cost.fks(src, dst) = 0.0; + operations + .variance + .set_variance(fast_keyswitch_noise(src, dst), 0.0); + operations.cost.set_cost(fast_keyswitch(src, dst), 0.0); continue; } // if an optimized fks is applicable and is not to be optimized // we use the already optimized fks instead of a lower 
bound - if let Some(fks) = initial_fks[src.0][dst.0] { + if let Some(this_fks) = initial_fks[src.0][dst.0] { let to_be_optimized = fks_to_optimize[src.0].map_or(false, |fdst| dst == fdst); if !to_be_optimized { - if input_glwe == &fks.src_glwe_param && output_glwe == &fks.dst_glwe_param { - *operations.variance.fks(src, dst) = fks.noise; - *operations.cost.fks(src, dst) = fks.complexity; + if input_glwe == &this_fks.src_glwe_param && output_glwe == &this_fks.dst_glwe_param + { + operations + .variance + .set_variance(fast_keyswitch_noise(src, dst), this_fks.noise); + operations + .cost + .set_cost(fast_keyswitch(src, dst), this_fks.complexity); } continue; } @@ -488,7 +516,7 @@ fn apply_fks_variance_and_cost_or_lower_bound( } else { keyswitch::lowest_complexity_ks(ks_pareto, input_glwe.sample_extract_lwe_dimension()) }; - *operations.cost.fks(src, dst) = cost; + operations.cost.set_cost(fast_keyswitch(src, dst), cost); let mut variance_min = f64::INFINITY; // TODO: use a pareto front to avoid that loop if use_fast_ks { @@ -506,7 +534,9 @@ fn apply_fks_variance_and_cost_or_lower_bound( variance_min = keyswitch::lowest_noise_ks(ks_pareto, input_glwe.sample_extract_lwe_dimension()); } - *operations.variance.fks(src, dst) = variance_min; + operations + .variance + .set_variance(fast_keyswitch_noise(src, dst), variance_min); } } @@ -535,8 +565,12 @@ fn apply_partitions_input_and_modulus_variance_and_cost( ); (input_variance, variance_modulus_switching) }; - *operations.variance.input(i) = input_variance; - *operations.variance.modulus_switching(i) = variance_modulus_switching; + operations + .variance + .set_variance(input_noise(i), input_variance); + operations + .variance + .set_variance(modulus_switching_noise(i), variance_modulus_switching); } } @@ -548,21 +582,27 @@ fn apply_pbs_variance_and_cost_or_lower_bounds( operations: &mut OperationsCV, ) { // setting already chosen pbs and lower bounds - for (i, pbs) in initial_pbs.iter().enumerate() { + for (i, this_pbs) 
in initial_pbs.iter().enumerate() { let i = PartitionIndex(i); - let pbs = if i == partition { &None } else { pbs }; - if let Some(pbs) = pbs { + let this_pbs = if i == partition { &None } else { this_pbs }; + if let Some(this_pbs) = this_pbs { let internal_dim = macro_parameters[i.0].internal_dim; - *operations.variance.pbs(i) = pbs.noise_br(internal_dim); - *operations.cost.pbs(i) = pbs.complexity_br(internal_dim); + operations + .variance + .set_variance(bootstrap_noise(i), this_pbs.noise_br(internal_dim)); + operations + .cost + .set_cost(bootstrap(i), this_pbs.complexity_br(internal_dim)); } else { // OPT: Most values could be shared on first optimize_macro let in_internal_dim = macro_parameters[i.0].internal_dim; let out_glwe_params = macro_parameters[i.0].glwe_params; let variance_min = cmux::lowest_noise_br(caches.pareto_quantities(out_glwe_params), in_internal_dim); - *operations.variance.pbs(i) = variance_min; - *operations.cost.pbs(i) = 0.0; + operations + .variance + .set_variance(bootstrap_noise(i), variance_min); + operations.cost.set_cost(bootstrap(i), 0.0); } } } @@ -627,7 +667,7 @@ fn optimize_macro( used_tlu_keyswitch: &[Vec], used_conversion_keyswitch: &[Vec], feasible: &Feasible, - complexity: &Complexity, + complexity: &ComplexityExpression, caches: &mut DecompCaches, init_parameters: &Parameters, best_complexity: f64, @@ -651,8 +691,8 @@ fn optimize_macro( let fks_to_optimize = fks_to_optimize(nb_partitions, used_conversion_keyswitch, partition); let operations = OperationsCV { - variance: feasible.zero_variance(), - cost: complexity.zero_cost(), + variance: NoiseValues::new(), + cost: ComplexityValues::new(), }; let partition_feasible = feasible.filter_constraints(partition); @@ -717,7 +757,7 @@ fn optimize_macro( break; } - if complexity.complexity(&operations.cost) > best_complexity { + if complexity.evaluate_total_cost(&operations.cost) > best_complexity { continue; } @@ -757,7 +797,7 @@ fn optimize_macro( continue; } - if 
complexity.complexity(&operations.cost) > best_complexity { + if complexity.evaluate_total_cost(&operations.cost) > best_complexity { continue; } @@ -803,7 +843,7 @@ fn optimize_macro( continue; } - if complexity.complexity(&operations.cost) > best_complexity { + if complexity.evaluate_total_cost(&operations.cost) > best_complexity { continue; } @@ -933,8 +973,8 @@ pub fn optimize( let mut caches = persistent_caches.caches(); - let feasible = Feasible::of(&dag.variance_constraints, kappa, None).compressed(); - let complexity = Complexity::of(&dag.operations_count).compressed(); + let feasible = Feasible::of(&dag.variance_constraints, kappa, None); + let complexity = ComplexityExpression::from(&dag.operations_count); let used_tlu_keyswitch = used_tlu_keyswitch(&dag); let used_conversion_keyswitch = used_conversion_keyswitch(&dag); @@ -1076,8 +1116,8 @@ fn used_tlu_keyswitch(dag: &AnalyzedDag) -> Vec> { for (src_partition, dst_partition) in cross_partition(dag.nb_partitions) { for constraint in &dag.variance_constraints { if constraint - .variance - .coeff_keyswitch_to_small(src_partition, dst_partition) + .noise_expression + .coeff(keyswitch_noise(src_partition, dst_partition)) != 0.0 { result[src_partition.0][dst_partition.0] = true; @@ -1093,8 +1133,8 @@ fn used_conversion_keyswitch(dag: &AnalyzedDag) -> Vec> { for (src_partition, dst_partition) in cross_partition(dag.nb_partitions) { for constraint in &dag.variance_constraints { if constraint - .variance - .coeff_partition_keyswitch_to_big(src_partition, dst_partition) + .noise_expression + .coeff(fast_keyswitch_noise(src_partition, dst_partition)) != 0.0 { result[src_partition.0][dst_partition.0] = true; @@ -1113,7 +1153,7 @@ fn sanity_check( ciphertext_modulus_log: u32, security_level: u64, feasible: &Feasible, - complexity: &Complexity, + complexity: &ComplexityExpression, ) { assert!(params.is_feasible.is_feasible()); assert!( @@ -1122,8 +1162,8 @@ fn sanity_check( ); let nb_partitions = 
params.macro_params.len(); let mut operations = OperationsCV { - variance: feasible.zero_variance(), - cost: complexity.zero_cost(), + variance: NoiseValues::new(), + cost: ComplexityValues::new(), }; let micro_params = ¶ms.micro_params; for partition in PartitionIndex::range(0, nb_partitions) { @@ -1136,48 +1176,79 @@ fn sanity_check( glwe_param.log2_polynomial_size, ciphertext_modulus_log, ); - *operations.variance.input(partition) = input_variance; - *operations.variance.modulus_switching(partition) = variance_modulus_switching; - if let Some(pbs) = micro_params.pbs[partition.0] { - *operations.variance.pbs(partition) = pbs.noise_br(internal_dim); - *operations.cost.pbs(partition) = pbs.complexity_br(internal_dim); + operations + .variance + .set_variance(input_noise(partition), input_variance); + operations.variance.set_variance( + modulus_switching_noise(partition), + variance_modulus_switching, + ); + if let Some(this_pbs) = micro_params.pbs[partition.0] { + operations + .variance + .set_variance(bootstrap_noise(partition), this_pbs.noise_br(internal_dim)); + operations + .cost + .set_cost(bootstrap(partition), this_pbs.complexity_br(internal_dim)); } else { - *operations.variance.pbs(partition) = f64::MAX; - *operations.cost.pbs(partition) = f64::MAX; + operations + .variance + .set_variance(bootstrap_noise(partition), f64::MAX); + operations.cost.set_cost(bootstrap(partition), f64::MAX); } for src_partition in PartitionIndex::range(0, nb_partitions) { let src_partition_macro = params.macro_params[src_partition.0].unwrap(); let src_glwe_param = src_partition_macro.glwe_params; let src_lwe_dim = src_glwe_param.sample_extract_lwe_dimension(); - if let Some(ks) = micro_params.ks[src_partition.0][partition.0] { + if let Some(this_ks) = micro_params.ks[src_partition.0][partition.0] { assert!( used_tlu_keyswitch[src_partition.0][partition.0], "Superflous ks[{src_partition}->{partition}]" ); - *operations.variance.ks(src_partition, partition) = 
ks.noise(src_lwe_dim); - *operations.cost.ks(src_partition, partition) = ks.complexity(src_lwe_dim); + operations.variance.set_variance( + keyswitch_noise(src_partition, partition), + this_ks.noise(src_lwe_dim), + ); + operations.cost.set_cost( + keyswitch(src_partition, partition), + this_ks.complexity(src_lwe_dim), + ); } else { assert!( !used_tlu_keyswitch[src_partition.0][partition.0], "Missing ks[{src_partition}->{partition}]" ); - *operations.variance.ks(src_partition, partition) = f64::MAX; - *operations.cost.ks(src_partition, partition) = f64::MAX; + operations + .variance + .set_variance(keyswitch_noise(src_partition, partition), f64::MAX); + operations + .cost + .set_cost(keyswitch(src_partition, partition), f64::MAX); } - if let Some(fks) = micro_params.fks[src_partition.0][partition.0] { + if let Some(this_fks) = micro_params.fks[src_partition.0][partition.0] { assert!( used_conversion_keyswitch[src_partition.0][partition.0], "Superflous fks[{src_partition}->{partition}]" ); - *operations.variance.fks(src_partition, partition) = fks.noise; - *operations.cost.fks(src_partition, partition) = fks.complexity; + operations.variance.set_variance( + fast_keyswitch_noise(src_partition, partition), + this_fks.noise, + ); + operations.cost.set_cost( + fast_keyswitch(src_partition, partition), + this_fks.complexity, + ); } else { assert!( !used_conversion_keyswitch[src_partition.0][partition.0], "Missing fks[{src_partition}->{partition}]" ); - *operations.variance.fks(src_partition, partition) = f64::MAX; - *operations.cost.fks(src_partition, partition) = f64::MAX; + operations + .variance + .set_variance(fast_keyswitch_noise(src_partition, partition), f64::MAX); + operations + .cost + .set_cost(fast_keyswitch(src_partition, partition), f64::MAX); } } } @@ -1185,7 +1256,7 @@ fn sanity_check( { assert!(feasible.feasible(&operations.variance)); assert!(params.p_error == feasible.p_error(&operations.variance)); - assert!(params.complexity == 
complexity.complexity(&operations.cost)); + assert!(params.complexity == complexity.evaluate_total_cost(&operations.cost)); assert!(params.global_p_error == feasible.global_p_error(&operations.variance)); } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs index ab2c4d8f43..2d357affb8 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -420,7 +420,7 @@ fn optimize_sign_extract() { let free_small_input1 = dag.add_input(precision, Shape::number()); let small_input1 = dag.add_lut(free_small_input1, FunctionTable::UNKWOWN, precision); let small_input1 = dag.add_lut(small_input1, FunctionTable::UNKWOWN, high_precision); - let input1 = dag.add_levelled_op( + let input1 = dag.add_linear_noise( [small_input1], complexity, [1.0], @@ -571,7 +571,7 @@ fn test_chained_partitions_non_feasible_single_params() { let mut lut_input = dag.add_input(precisions[0], Shape::number()); for out_precision in precisions { let noise_factor = MAX_WEIGHT[*dag.out_precisions.last().unwrap() as usize] as f64; - lut_input = dag.add_levelled_op( + lut_input = dag.add_linear_noise( [lut_input], LevelledComplexity::ZERO, [noise_factor], @@ -831,8 +831,8 @@ fn test_bug_with_zero_noise() { let out_shape = Shape::number(); let mut dag = unparametrized::Dag::new(); let v0 = dag.add_input(2, &out_shape); - let v1 = dag.add_levelled_op([v0], complexity, [0.0], &out_shape, "comment"); - let v2 = dag.add_levelled_op([v1], complexity, [1.0], &out_shape, "comment"); + let v1 = dag.add_linear_noise([v0], complexity, [0.0], &out_shape, "comment"); + let v2 = dag.add_linear_noise([v1], complexity, [1.0], &out_shape, "comment"); let v3 = dag.add_unsafe_cast(v2, 1); 
let _ = dag.add_lut(v3, FunctionTable { values: vec![] }, 1); let sol = optimize(&dag, &None, PartitionIndex(0)); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs index dbd5edc22f..6640487cb8 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs @@ -232,7 +232,9 @@ impl PartitionCut { for (op_i, op) in dag.operators.iter().enumerate() { match op { // propagate - Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => { + Operator::Dot { inputs, .. } + | Operator::LinearNoise { inputs, .. } + | Operator::MaxNoise { inputs, .. } => { let mut origins = HashSet::default(); for input in inputs { origins.extend(&noise_origins[input.0]); @@ -251,6 +253,10 @@ impl PartitionCut { max_output_norm2[op_i] = 1.0; // initial value that can be maxed noise_origins[op_i] = std::iter::once(op_i).collect(); } + Operator::ZeroNoise { .. 
} => { + max_output_norm2[op_i] = 0.0; // initial value that can be maxed + noise_origins[op_i] = std::iter::once(op_i).collect(); + } Operator::ChangePartition { src_partition: Some(partition), dst_partition: None, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs index e01648db12..4cc5c8b801 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs @@ -51,14 +51,16 @@ fn extract_levelled_block(dag: &unparametrized::Dag) -> Blocks { for (op_i, op) in dag.operators.iter().enumerate() { match op { // Block entry point - Operator::Input { .. } => (), + Operator::Input { .. } | Operator::ZeroNoise { .. } => (), // Block entry point and pre-exit point Op::Lut { .. } => (), // Connectors Op::UnsafeCast { input, .. } | Op::ChangePartition { input, .. } => { uf.union(input.0, op_i); } - Op::LevelledOp { inputs, .. } | Op::Dot { inputs, .. } => { + Op::LinearNoise { inputs, .. } + | Op::Dot { inputs, .. } + | Op::MaxNoise { inputs, .. } => { for input in inputs { uf.union(input.0, op_i); } @@ -119,13 +121,15 @@ fn only_1_partition(dag: &unparametrized::Dag) -> Partitions { vec![InstructionPartition::new(PartitionIndex::FIRST); dag.operators.len()]; for (op_i, op) in dag.operators.iter().enumerate() { match op { - Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => { + Op::Dot { inputs, .. } + | Op::LinearNoise { inputs, .. } + | Op::MaxNoise { inputs, .. } => { instrs_partition[op_i].inputs_transition = vec![None; inputs.len()]; } Op::Lut { .. } | Op::UnsafeCast { .. } | Operator::ChangePartition { .. } => { instrs_partition[op_i].inputs_transition = vec![None]; } - Op::Input { .. } => (), + Op::Input { .. } | Op::ZeroNoise { .. 
} => (), Op::Round { .. } => unreachable!(), } } @@ -216,7 +220,9 @@ fn resolve_by_levelled_block( HashSet::from([group_partition]); } } - Op::LevelledOp { inputs, .. } | Op::Dot { inputs, .. } => { + Op::LinearNoise { inputs, .. } + | Op::Dot { inputs, .. } + | Op::MaxNoise { inputs, .. } => { instrs_p[op_i].instruction_partition = group_partition; instrs_p[op_i].inputs_transition = vec![None; inputs.len()]; for (i, input) in inputs.iter().enumerate() { @@ -239,7 +245,9 @@ fn resolve_by_levelled_block( })] } } - Operator::Input { .. } => instrs_p[op_i].instruction_partition = group_partition, + Operator::Input { .. } | Operator::ZeroNoise { .. } => { + instrs_p[op_i].instruction_partition = group_partition + } Op::Round { .. } => unreachable!("Round should have been expanded"), } } @@ -479,7 +487,7 @@ pub mod tests { let external_input = dag.add_input(16, Shape::number()); let other_input = dag.add_input(16, Shape::number()); let other_lut = dag.add_lut(other_input, FunctionTable::UNKWOWN, 16); - let mix_add = dag.add_levelled_op( + let mix_add = dag.add_linear_noise( [external_input, other_lut], LevelledComplexity::ADDITION, [1.0, 1.0], diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic.rs new file mode 100644 index 0000000000..c5ae790571 --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic.rs @@ -0,0 +1,136 @@ +use super::partitions::PartitionIndex; +use std::{collections::HashMap, fmt::Display}; + +/// A map associating symbols with values. +/// +/// By default all symbols are assumed to be associated with the default value +/// of the type T. In practice, only associations with non-default values are +/// stored in the map. +#[derive(Clone, Debug, PartialEq)] +pub struct SymbolMap(HashMap); + +impl SymbolMap { + /// Returns an empty symbol map. 
+ pub fn new() -> Self { + Self(HashMap::new()) + } + + /// Update a symbol's value. + pub fn update T>(&mut self, sym: Symbol, f: F) { + let val = f(self.get(sym)); + if val != T::default() { + let _ = self.0.insert(sym, val); + } else { + let _ = self.0.remove(&sym); + } + } + + /// Sets a symbol's value. + pub fn set(&mut self, sym: Symbol, val: T) { + self.update(sym, |_| val) + } + + /// Returns the value associated with the symbol. + pub fn get(&self, sym: Symbol) -> T { + self.0.get(&sym).cloned().unwrap_or_default() + } + + /// Returns an iterator over associations with non-default values. + pub fn iter(&self) -> impl Iterator { + self.0 + .iter() + .map(|(k, v)| (*k, v.clone())) + .collect::>() + .into_iter() + } + + /// Consumes the symbol map and return an iterator. + pub fn into_iter(self) -> impl Iterator { + self.0.into_iter() + } + + /// Reset all associations to the default value. + pub fn clear(&mut self) { + self.0.clear(); + } + + #[allow(unused)] + pub(crate) fn len(&self) -> usize { + self.0.len() + } +} + +impl SymbolMap { + /// Formats the symbol map with a given separator and symbol prefix. + pub fn fmt_with( + &self, + f: &mut std::fmt::Formatter<'_>, + separator: &str, + sym_prefix: &str, + ) -> std::fmt::Result { + let mut terms = self.iter().collect::>(); + terms.sort_by_key(|t| t.0); + let mut terms = terms.into_iter(); + match terms.next() { + Some((sym, val)) => write!(f, "{val}{sym_prefix}{sym}")?, + None => return write!(f, "∅"), + } + for (sym, val) in terms { + write!(f, " {separator} {val}{sym_prefix}{sym}")?; + } + Ok(()) + } +} + +/// A symbol related to an fhe operation. 
+#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum Symbol { + Input(PartitionIndex), + Bootstrap(PartitionIndex), + Keyswitch(PartitionIndex, PartitionIndex), + FastKeyswitch(PartitionIndex, PartitionIndex), + ModulusSwitch(PartitionIndex), +} + +impl Display for Symbol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Symbol::Keyswitch(from, to) if from == to => write!(f, "K[{from}]"), + Symbol::Keyswitch(from, to) => write!(f, "K[{from}→{to}]"), + Symbol::FastKeyswitch(from, to) => write!(f, "FK[{from}→{to}]"), + Symbol::Bootstrap(p) => write!(f, "Br[{p}]"), + Symbol::Input(p) => write!(f, "In[{p}]"), + Symbol::ModulusSwitch(p) => write!(f, "M[{p}]"), + } + } +} + +/// Returns an input symbol. +#[allow(unused)] +pub fn input(partition: PartitionIndex) -> Symbol { + Symbol::Input(partition) +} + +/// Returns a keyswitch symbol. +#[allow(unused)] +pub fn keyswitch(from: PartitionIndex, to: PartitionIndex) -> Symbol { + Symbol::Keyswitch(from, to) +} + +/// Returns a fast keyswitch symbol. +#[allow(unused)] +pub fn fast_keyswitch(from: PartitionIndex, to: PartitionIndex) -> Symbol { + Symbol::FastKeyswitch(from, to) +} + +/// Returns a pbs symbol. +#[allow(unused)] +pub fn bootstrap(partition: PartitionIndex) -> Symbol { + Symbol::Bootstrap(partition) +} + +/// Returns a modulus switch symbol. 
+#[allow(unused)] +pub fn modulus_switching(partition: PartitionIndex) -> Symbol { + Symbol::ModulusSwitch(partition) +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs deleted file mode 100644 index 7b120a48a6..0000000000 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs +++ /dev/null @@ -1,264 +0,0 @@ -use std::fmt; - -use crate::optimization::dag::multi_parameters::operations_value::OperationsValue; - -use super::partitions::PartitionIndex; - -/** - * A variance that is represented as a linear combination of base variances. - * Only the linear coefficient are known. - * The base variances are unknown. - * - * Possible base variances: - * - fresh, - * - lut output, - * - keyswitch, - * - partition keyswitch, - * - modulus switching - * - * We only known that the fresh <= lut output in the same partition. - * Each linear coefficient is a variance factor. - * There are homogeneous to squared weight (or summed square weights or squared norm2). 
- */ -#[derive(Clone, Debug, PartialEq, PartialOrd)] -pub struct SymbolicVariance { - pub partition: PartitionIndex, - pub coeffs: OperationsValue, -} - -impl SymbolicVariance { - // To be used as a initial accumulator - pub const ZERO: Self = Self { - partition: PartitionIndex::FIRST, - coeffs: OperationsValue::ZERO, - }; - - pub fn nb_partitions(&self) -> usize { - self.coeffs.nb_partitions() - } - - pub fn nan(nb_partitions: usize) -> Self { - Self { - partition: PartitionIndex::INVALID, - coeffs: OperationsValue::nan(nb_partitions), - } - } - - pub fn input(nb_partitions: usize, partition: PartitionIndex) -> Self { - let mut r = Self { - partition, - coeffs: OperationsValue::zero(nb_partitions), - }; - // rust ..., offset cannot be inlined - *r.coeffs.input(partition) = 1.0; - r - } - - pub fn from_external_partition( - nb_partitions: usize, - partition: PartitionIndex, - max_variance: f64, - ) -> Self { - let mut r = Self { - partition, - coeffs: OperationsValue::zero(nb_partitions), - }; - // rust ..., offset cannot be inlined - *r.coeffs.pbs(partition) = max_variance; - r - } - - pub fn coeff_input(&self, partition: PartitionIndex) -> f64 { - self.coeffs[self.coeffs.index.input(partition)] - } - - pub fn after_pbs(nb_partitions: usize, partition: PartitionIndex) -> Self { - let mut r = Self { - partition, - coeffs: OperationsValue::zero(nb_partitions), - }; - *r.coeffs.pbs(partition) = 1.0; - r - } - - pub fn coeff_pbs(&self, partition: PartitionIndex) -> f64 { - self.coeffs[self.coeffs.index.pbs(partition)] - } - - pub fn coeff_modulus_switching(&self, partition: PartitionIndex) -> f64 { - self.coeffs[self.coeffs.index.modulus_switching(partition)] - } - - pub fn after_modulus_switching(&self, partition: PartitionIndex) -> Self { - let mut new = self.clone(); - let index = self.coeffs.index.modulus_switching(partition); - assert!(new.coeffs[index] == 0.0); - new.coeffs[index] = 1.0; - new - } - - pub fn coeff_keyswitch_to_small( - &self, - src_partition: 
PartitionIndex, - dst_partition: PartitionIndex, - ) -> f64 { - self.coeffs[self - .coeffs - .index - .keyswitch_to_small(src_partition, dst_partition)] - } - - pub fn after_partition_keyswitch_to_small( - &self, - src_partition: PartitionIndex, - dst_partition: PartitionIndex, - ) -> Self { - let index = self - .coeffs - .index - .keyswitch_to_small(src_partition, dst_partition); - self.after_partition_keyswitch(src_partition, dst_partition, index) - } - - pub fn coeff_partition_keyswitch_to_big( - &self, - src_partition: PartitionIndex, - dst_partition: PartitionIndex, - ) -> f64 { - self.coeffs[self - .coeffs - .index - .keyswitch_to_big(src_partition, dst_partition)] - } - - pub fn after_partition_keyswitch_to_big( - &self, - src_partition: PartitionIndex, - dst_partition: PartitionIndex, - ) -> Self { - let index = self - .coeffs - .index - .keyswitch_to_big(src_partition, dst_partition); - self.after_partition_keyswitch(src_partition, dst_partition, index) - } - - pub fn after_partition_keyswitch( - &self, - src_partition: PartitionIndex, - dst_partition: PartitionIndex, - index: usize, - ) -> Self { - assert!(src_partition.0 < self.nb_partitions()); - assert!(dst_partition.0 < self.nb_partitions()); - assert!(src_partition == self.partition); - let mut new = self.clone(); - new.partition = dst_partition; - new.coeffs[index] = 1.0; - new - } - - pub fn max(&self, other: &Self) -> Self { - let mut coeffs = self.coeffs.clone(); - for (i, coeff) in coeffs.iter_mut().enumerate() { - *coeff = coeff.max(other.coeffs[i]); - } - Self { coeffs, ..*self } - } - - pub fn compress(&self, detect_used: &[bool]) -> Self { - Self { - coeffs: self.coeffs.compress(detect_used), - ..(*self) - } - } -} - -impl fmt::Display for SymbolicVariance { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if self == &Self::ZERO { - return write!(f, "ZERO x σ²"); - } - if self.coeffs[0].is_nan() { - return write!(f, "NAN x σ²"); - } - let mut add_plus = ""; - for src_partition in 
PartitionIndex::range(0, self.nb_partitions()) { - let coeff = self.coeff_input(src_partition); - if coeff != 0.0 { - write!(f, "{add_plus}{coeff}σ²In[{src_partition}]")?; - add_plus = " + "; - } - let coeff = self.coeff_pbs(src_partition); - if coeff != 0.0 { - write!(f, "{add_plus}{coeff}σ²Br[{src_partition}]")?; - add_plus = " + "; - } - for dst_partition in PartitionIndex::range(0, self.nb_partitions()) { - let coeff = self.coeff_partition_keyswitch_to_big(src_partition, dst_partition); - if coeff != 0.0 { - write!(f, "{add_plus}{coeff}σ²FK[{src_partition}→{dst_partition}]")?; - add_plus = " + "; - } - } - } - for src_partition in PartitionIndex::range(0, self.nb_partitions()) { - for dst_partition in PartitionIndex::range(0, self.nb_partitions()) { - let coeff = self.coeff_keyswitch_to_small(src_partition, dst_partition); - if coeff != 0.0 { - if src_partition == dst_partition { - write!(f, "{add_plus}{coeff}σ²K[{src_partition}]")?; - } else { - write!(f, "{add_plus}{coeff}σ²K[{src_partition}→{dst_partition}]")?; - } - add_plus = " + "; - } - } - } - for partition in PartitionIndex::range(0, self.nb_partitions()) { - let coeff = self.coeff_modulus_switching(partition); - if coeff != 0.0 { - write!(f, "{add_plus}{coeff}σ²M[{partition}]")?; - add_plus = " + "; - } - } - Ok(()) - } -} - -impl std::ops::Add for SymbolicVariance { - type Output = Self; - - fn add(mut self, rhs: Self) -> Self::Output { - if self.coeffs.is_empty() { - self = rhs; - } else { - for i in 0..self.coeffs.len() { - self.coeffs[i] += rhs.coeffs[i]; - } - }; - self - } -} - -impl std::ops::AddAssign for SymbolicVariance { - fn add_assign(&mut self, rhs: Self) { - if self.coeffs.is_empty() { - *self = rhs; - } else { - for i in 0..self.coeffs.len() { - self.coeffs[i] += rhs.coeffs[i]; - } - } - } -} - -impl std::ops::Mul for SymbolicVariance { - type Output = Self; - fn mul(self, sq_weight: f64) -> Self { - Self { - coeffs: self.coeffs * sq_weight, - ..self - } - } -} diff --git 
a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs index c03b0f063f..67e9c40bc3 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs @@ -1,16 +1,20 @@ -use std::fmt; - use crate::dag::operator::{Location, Precision}; use crate::optimization::dag::multi_parameters::partitions::PartitionIndex; -use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance; +use std::fmt; + +use super::noise_expression::{ + bootstrap_noise, fast_keyswitch_noise, input_noise, keyswitch_noise, modulus_switching_noise, + NoiseExpression, +}; #[derive(Clone, Debug, PartialEq)] pub struct VarianceConstraint { pub precision: Precision, pub partition: PartitionIndex, + pub nb_partitions: usize, pub nb_constraints: u64, pub safe_variance_bound: f64, - pub variance: SymbolicVariance, + pub noise_expression: NoiseExpression, pub location: Location, } @@ -19,7 +23,7 @@ impl fmt::Display for VarianceConstraint { write!( f, "{} < (2²)**{} ({}bits partition:{} count:{}, dom={})", - self.variance, + self.noise_expression, self.safe_variance_bound.log2().round() / 2.0, self.precision, self.partition, @@ -34,27 +38,27 @@ impl VarianceConstraint { #[allow(clippy::cast_sign_loss)] fn dominance_index(&self) -> u64 { let max_coeff = self - .variance - .coeffs - .iter() - .copied() - .reduce(f64::max) - .unwrap(); + .noise_expression + .terms_iter() + .map(|t| t.coefficient) + .fold(0.0, f64::max); (max_coeff / self.safe_variance_bound).log2().ceil() as u64 } fn dominate_or_equal(&self, other: &Self) -> bool { // With BR > Fresh - let self_var = &self.variance; - let other_var = &other.variance; + let self_var = &self.noise_expression; 
+ let other_var = &other.noise_expression; let self_renorm = other.safe_variance_bound / self.safe_variance_bound; let rel_diff = - |f: &dyn Fn(&SymbolicVariance) -> f64| self_renorm * f(self_var) - f(other_var); - for partition in PartitionIndex::range(0, self.variance.nb_partitions()) { + |f: &dyn Fn(&NoiseExpression) -> f64| self_renorm * f(self_var) - f(other_var); + for partition in PartitionIndex::range(0, self.nb_partitions) { let diffs = [ - rel_diff(&|var| var.coeff_pbs(partition)), - rel_diff(&|var| var.coeff_pbs(partition) + var.coeff_input(partition)), - rel_diff(&|var| var.coeff_modulus_switching(partition)), + rel_diff(&|expr| expr.coeff(bootstrap_noise(partition))), + rel_diff(&|expr| { + expr.coeff(bootstrap_noise(partition)) + expr.coeff(input_noise(partition)) + }), + rel_diff(&|expr| expr.coeff(modulus_switching_noise(partition))), ]; for diff in diffs { if diff < 0.0 { @@ -62,12 +66,12 @@ impl VarianceConstraint { } } } - for src_partition in PartitionIndex::range(0, self.variance.nb_partitions()) { - for dst_partition in PartitionIndex::range(0, self.variance.nb_partitions()) { + for src_partition in PartitionIndex::range(0, self.nb_partitions) { + for dst_partition in PartitionIndex::range(0, self.nb_partitions) { let diffs = [ - rel_diff(&|var| var.coeff_keyswitch_to_small(src_partition, dst_partition)), - rel_diff(&|var| { - var.coeff_partition_keyswitch_to_big(src_partition, dst_partition) + rel_diff(&|expr| expr.coeff(keyswitch_noise(src_partition, dst_partition))), + rel_diff(&|expr| { + expr.coeff(fast_keyswitch_noise(src_partition, dst_partition)) }), ]; for diff in diffs { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 484338b123..96b04008a6 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ 
b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -25,7 +25,7 @@ fn assert_all_same( } fn assert_inputs_uniform_precisions(op: &Operator, out_precisions: &[Precision]) { - if let Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } = op { + if let Operator::Dot { inputs, .. } | Operator::LinearNoise { inputs, .. } = op { assert_all_same(inputs, out_precisions); } } @@ -37,21 +37,21 @@ fn assert_dot_uniform_inputs_shape(op: &Operator, out_shapes: &[Shape]) { } fn assert_non_empty_inputs(op: &Operator) { - if let Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } = op { + if let Operator::Dot { inputs, .. } | Operator::LinearNoise { inputs, .. } = op { assert!(!inputs.is_empty()); } } fn assert_inputs_index(op: &Operator, first_bad_index: usize) { let valid = match op { - Operator::Input { .. } => true, + Operator::Input { .. } | Operator::ZeroNoise { .. } => true, Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } | Operator::Round { input, .. } | Operator::ChangePartition { input, .. } => input.0 < first_bad_index, - Operator::LevelledOp { inputs, .. } | Operator::Dot { inputs, .. } => { - inputs.iter().all(|input| input.0 < first_bad_index) - } + Operator::LinearNoise { inputs, .. } + | Operator::Dot { inputs, .. } + | Operator::MaxNoise { inputs, .. } => inputs.iter().all(|input| input.0 < first_bad_index), }; assert!(valid, "Invalid dag, bad index in op: {op:?}"); } @@ -148,8 +148,9 @@ fn out_variance( // TODO: track each elements instead of container match op { Operator::Input { .. } => SymbolicVariance::INPUT, + Operator::ZeroNoise { .. } => SymbolicVariance::ZERO, Operator::Lut { .. } => SymbolicVariance::LUT, - Operator::LevelledOp { + Operator::LinearNoise { inputs, weights, .. } => inputs .iter() @@ -158,6 +159,15 @@ fn out_variance( .fold(SymbolicVariance::ZERO, |acc, (var, &weight)| { acc + var * square(weight) }), + Operator::MaxNoise { inputs, .. 
} => { + inputs + .iter() + .map(|i| out_variances[i.0]) + .fold(SymbolicVariance::ZERO, |acc, var| SymbolicVariance { + lut_coeff: acc.lut_coeff.max(var.lut_coeff), + input_coeff: acc.input_coeff.max(var.input_coeff), + }) + } Operator::Dot { kind: DotKind::CompatibleTensor { .. }, .. @@ -251,8 +261,10 @@ fn op_levelled_complexity(op: &Operator, out_shapes: &[Shape]) -> LevelledComple * out_shapes[inputs[0].0].flat_size() } - Operator::LevelledOp { complexity, .. } => *complexity, + Operator::LinearNoise { complexity, .. } => *complexity, Operator::Input { .. } + | Operator::ZeroNoise { .. } + | Operator::MaxNoise { .. } | Operator::Lut { .. } | Operator::UnsafeCast { .. } | Operator::ChangePartition { .. } => LevelledComplexity::ZERO, @@ -710,7 +722,7 @@ pub mod tests { let weights = Weights::vector([1, 2]); #[allow(clippy::imprecise_flops)] let dot = - graph.add_levelled_op([input1, input1], cpx_dot, [1., 2.], Shape::number(), "dot"); + graph.add_linear_noise([input1, input1], cpx_dot, [1., 2.], Shape::number(), "dot"); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index 01bcb4e53d..265c45b688 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -390,9 +390,9 @@ pub fn add_v0_dag(dag: &mut unparametrized::Dag, sum_size: u64, precision: u64, let comment = "dot"; let precision = precision as crate::dag::operator::Precision; let input1 = dag.add_input(precision, out_shape); - let dot1 = dag.add_levelled_op([input1], complexity, [1.0], out_shape, comment); + let dot1 = dag.add_linear_noise([input1], complexity, [1.0], out_shape, comment); let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, 
precision); - let dot2 = dag.add_levelled_op([lut1], complexity, [manp], out_shape, comment); + let dot2 = dag.add_linear_noise([lut1], complexity, [manp], out_shape, comment); let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs index a54068ede6..c42f9ad303 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs @@ -92,15 +92,21 @@ impl<'dag> Viz for crate::dag::unparametrized::DagOperator<'dag> { Operator::Input { out_precision, .. } => { format!("{index} [label =\"{{%{index} = Input({input_string}) |{{out_precision:|{out_precision:?}}} | {{loc:|{location}}}}}\" fillcolor={color}];") } + Operator::ZeroNoise { out_precision, .. } => { + format!("{index} [label =\"{{%{index} = Zero({input_string}) |{{out_precision:|{out_precision:?}}} | {{loc:|{location}}}}}\" fillcolor={color}];") + } Operator::Lut { out_precision, .. } => { format!("{index} [label = \"{{%{index} = Lut({input_string}) |{{out_precision:|{out_precision:?}}}| {{loc:|{location}}}}}\" fillcolor={color}];") } Operator::Dot { .. } => { format!("{index} [label = \"{{%{index} = Dot({input_string})| {{loc:|{location}}}}}\" fillcolor={color}];") } - Operator::LevelledOp { weights, .. } => { + Operator::LinearNoise { weights, .. } => { format!("{index} [label = \"{{%{index} = LevelledOp({input_string}) |{{weights:|{weights:?}}}| {{loc:|{location}}}}}\" fillcolor={color}];") } + Operator::MaxNoise { .. } => { + format!("{index} [label = \"{{%{index} = Max({input_string}) | {{loc:|{location}}}}}\" fillcolor={color}];") + } Operator::UnsafeCast { out_precision, .. 
} => format!( "{index} [label = \"{{%{index} = UnsafeCast({input_string}) |{{out_precision:|{out_precision:?}}}| {{loc:|{location}}}}}\" fillcolor={color}];" ), diff --git a/docs/explanations/FHEDialect.md b/docs/explanations/FHEDialect.md index fcb657181e..d9316cdc55 100644 --- a/docs/explanations/FHEDialect.md +++ b/docs/explanations/FHEDialect.md @@ -29,7 +29,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: AdditiveNoise, Binary, BinaryEintInt, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: Binary, BinaryEintInt, ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -581,7 +581,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: AdditiveNoise, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), UnaryEint +Interfaces: ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface), UnaryEint Effects: MemoryEffects::Effect{} @@ -698,7 +698,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: AdditiveNoise, Binary, BinaryEintInt, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: Binary, BinaryEintInt, ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -774,7 +774,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: AdditiveNoise, Binary, BinaryIntEint, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: Binary, BinaryIntEint, ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} diff --git a/docs/explanations/FHELinalgDialect.md b/docs/explanations/FHELinalgDialect.md index 59c6f4dcea..51e332d938 100644 --- a/docs/explanations/FHELinalgDialect.md +++ b/docs/explanations/FHELinalgDialect.md @@ -48,7 +48,7 @@ Examples: Traits: AlwaysSpeculatableImplTrait, TensorBinaryEintInt, TensorBroadcastingRules -Interfaces: Binary, 
BinaryEintInt, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: Binary, BinaryEintInt, ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -350,7 +350,7 @@ Examples: Traits: AlwaysSpeculatableImplTrait -Interfaces: ConditionallySpeculatable, ConstantNoise, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: ConditionallySpeculatable, ConstantNoise, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -397,7 +397,7 @@ Examples: Traits: AlwaysSpeculatableImplTrait -Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -560,7 +560,7 @@ Notes: Traits: AlwaysSpeculatableImplTrait -Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -614,7 +614,7 @@ Examples: Traits: AlwaysSpeculatableImplTrait -Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -643,7 +643,7 @@ Creates a tensor with a single element. 
Traits: AlwaysSpeculatableImplTrait -Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -1300,7 +1300,7 @@ Examples: Traits: AlwaysSpeculatableImplTrait, TensorUnaryEint -Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), UnaryEint +Interfaces: ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface), UnaryEint Effects: MemoryEffects::Effect{} @@ -1435,7 +1435,7 @@ Examples: Traits: AlwaysSpeculatableImplTrait, TensorBinaryEintInt, TensorBroadcastingRules -Interfaces: Binary, BinaryEintInt, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: Binary, BinaryEintInt, ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -1548,7 +1548,7 @@ Examples: Traits: AlwaysSpeculatableImplTrait, TensorBinaryIntEint, TensorBroadcastingRules -Interfaces: Binary, BinaryIntEint, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: Binary, BinaryIntEint, ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -1761,7 +1761,7 @@ Examples: Traits: AlwaysSpeculatableImplTrait -Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), UnaryEint +Interfaces: ConditionallySpeculatable, MaxNoise, NoMemoryEffect (MemoryEffectOpInterface), UnaryEint Effects: MemoryEffects::Effect{}