Skip to content

Commit

Permalink
BGV: Count Add/Keyswitch in one level for OpenFHE
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Jan 11, 2025
1 parent a8f17da commit 6a581df
Show file tree
Hide file tree
Showing 11 changed files with 367 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#include "lib/Analysis/AddAndKeySwitchCountAnalysis/AddAndKeySwitchCountAnalysis.h"

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {

LogicalResult CountAnalysis::visitOperation(
Operation *op, ArrayRef<const CountLattice *> operands,
ArrayRef<CountLattice *> results) {
auto propagate = [&](Value value, const CountState &state) {
auto *lattice = getLatticeElement(value);
ChangeResult changed = lattice->join(state);
propagateIfChanged(lattice, changed);
};

llvm::TypeSwitch<Operation &>(*op)
.Case<secret::GenericOp>([&](auto genericOp) {
Block *body = genericOp.getBody();
for (auto i = 0; i != body->getNumArguments(); ++i) {
auto blockArg = body->getArgument(i);
// one Vfresh
propagate(blockArg, CountState(1, 0));
}
})
.Case<arith::AddIOp, arith::SubIOp>([&](auto &op) {
// condition on result secretness
SmallVector<OpResult> secretResults;
getSecretResults(op, secretResults);
if (secretResults.empty()) {
return;
}

CountState zeroState(0, 0);
SmallVector<OpOperand *> secretOperands;
getSecretOperands(op, secretOperands);
for (auto *operand : secretOperands) {
auto countState = operands[operand->getOperandNumber()]->getValue();
zeroState = zeroState + countState;
}

for (auto result : secretResults) {
propagate(result, zeroState);
}
})
.Case<arith::MulIOp>([&](auto &op) {
SmallVector<OpResult> secretResults;
getSecretResults(op, secretResults);
if (secretResults.empty()) {
return;
}

// now noise is Vmult
// TODO(#1168): we can actually do a more fine grained analysis here
// distinguishing ct-ct and ct-pt
propagate(op.getResult(), CountState(1, 0));
})
.Case<mgmt::RelinearizeOp, tensor_ext::RotateOp>([&](auto &op) {
auto secretness = ensureSecretness(op, op->getOperand(0));
if (!secretness) {
return;
}

auto state = operands[0]->getValue();
if (!state.isInitialized()) {
return;
}

propagate(op.getResult(), state.keySwitch());
})
// TODO(#1174): in BGV tensor::ExtractOp is assumed to be always
// mul+const
.Case<tensor::ExtractOp>([&](auto &op) {
auto secretness = ensureSecretness(op, op.getResult());
if (!secretness) {
return;
}

// now noise is Vmult + one Vks
propagate(op.getResult(), CountState(1, 1));
})
.Case<mgmt::ModReduceOp>([&](auto modReduceOp) {
// implicitly ensure that the operand is secret

propagate(modReduceOp.getResult(), CountState(0, 0));
});
// should not propagate through mgmt::ModReduceOp
return success();
}

void annotateCount(Operation *top, DataFlowSolver *solver) {
auto getIntegerAttr = [&](int level) {
return IntegerAttr::get(IntegerType::get(top->getContext(), 64), level);
};

auto maxAddCount = 0;
auto maxKeySwitchCount = 0;

auto getCount = [&](Value value) {
auto state = solver->lookupState<CountLattice>(value)->getValue();
// update the max
maxAddCount = std::max(maxAddCount, state.getAddCount());
maxKeySwitchCount = std::max(maxKeySwitchCount, state.getKeySwitchCount());
return std::make_tuple(state.getAddCount(), state.getKeySwitchCount());
};

top->walk<WalkOrder::PreOrder>([&](secret::GenericOp genericOp) {
for (auto i = 0; i != genericOp.getBody()->getNumArguments(); ++i) {
auto blockArg = genericOp.getBody()->getArgument(i);
auto [addCount, keySwitchCount] = getCount(blockArg);
if (addCount != 0) {
genericOp.setArgAttr(i, "addCount", getIntegerAttr(addCount));
}
if (keySwitchCount != 0) {
genericOp.setArgAttr(i, "keySwitchCount",
getIntegerAttr(keySwitchCount));
}
}

genericOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op->getNumResults() == 0) {
return;
}
auto [addCount, keySwitchCount] = getCount(op->getResult(0));
if (addCount != 0) {
op->setAttr("addCount", getIntegerAttr(addCount));
}
if (keySwitchCount != 0) {
op->setAttr("keySwitchCount", getIntegerAttr(keySwitchCount));
}
});

// annotate mgmt::OpenfheParamsAttr to func::FuncOp containing the genericOp
auto *funcOp = genericOp->getParentOp();
auto openfheParamAttr = mgmt::OpenfheParamsAttr::get(
funcOp->getContext(), maxAddCount, maxKeySwitchCount);
funcOp->setAttr(mgmt::MgmtDialect::kArgOpenfheParamsAttrName,
openfheParamAttr);
});
}

} // namespace heir
} // namespace mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#ifndef LIB_ANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_H_
#define LIB_ANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_H_

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project

namespace mlir {
namespace heir {

// This analysis should be used after --secret-with-mgmt-bgv
// but before --secret-distribute-generic
// where a whole secret::GenericOp is assumed
//
// this follows strictly the strategy of modulus switch
// right before multiplication including the first.
// namely FLEXIBLEAUTOEXT in OpenFHE
//
// OpenFHE only supports setting EvalAddCount/EvalKeySwitchCount
// for BGV/BFV, so HEIR only supports BGV here
//
// For not include-the-first, the addCount for the L-th level
// might be overestimated, as we can not distinguish between
// Vmult and Vfresh

class CountState {
public:
CountState() : initialized(false), addCount(0), keySwitchCount(0) {}
explicit CountState(int addCount, int keySwitchCount)
: initialized(true), addCount(addCount), keySwitchCount(keySwitchCount) {}
~CountState() = default;

int getAddCount() const {
assert(isInitialized());
return addCount;
}

int getKeySwitchCount() const {
assert(isInitialized());
return keySwitchCount;
}

bool operator==(const CountState &rhs) const {
return initialized == rhs.initialized && addCount == rhs.addCount &&
keySwitchCount == rhs.keySwitchCount;
}

bool isInitialized() const { return initialized; }

CountState operator+(const CountState &rhs) const {
assert(isInitialized() && rhs.isInitialized());
return CountState{addCount + rhs.addCount,
keySwitchCount + rhs.keySwitchCount};
}

CountState keySwitch() const {
assert(isInitialized());
return CountState{addCount, keySwitchCount + 1};
}

CountState max(const CountState &rhs) const {
assert(isInitialized() && rhs.isInitialized());
return CountState{std::max(addCount, rhs.addCount),
std::max(keySwitchCount, rhs.keySwitchCount)};
}

static CountState join(const CountState &lhs, const CountState &rhs) {
if (!lhs.isInitialized()) return rhs;
if (!rhs.isInitialized()) return lhs;

return lhs.max(rhs);
}

void print(llvm::raw_ostream &os) const {
if (isInitialized()) {
os << "CountState(" << addCount << ", " << keySwitchCount << ")";
} else {
os << "CountState(uninitialized)";
}
}

friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const CountState &state) {
state.print(os);
return os;
}

private:
bool initialized;
int addCount; // how many Vmult or Vfresh (before first multiplication)
// encountered
int keySwitchCount;
};

class CountLattice : public dataflow::Lattice<CountState> {
public:
using Lattice::Lattice;
};

class CountAnalysis
: public dataflow::SparseForwardDataFlowAnalysis<CountLattice>,
public SecretnessAnalysisDependent<CountAnalysis> {
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
friend class SecretnessAnalysisDependent<CountAnalysis>;

void setToEntryState(CountLattice *lattice) override {
propagateIfChanged(lattice, lattice->join(CountState()));
}

LogicalResult visitOperation(Operation *op,
ArrayRef<const CountLattice *> operands,
ArrayRef<CountLattice *> results) override;
};

void annotateCount(Operation *top, DataFlowSolver *solver);

} // namespace heir
} // namespace mlir

#endif // LIB_ANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_H_
20 changes: 20 additions & 0 deletions lib/Analysis/AddAndKeySwitchCountAnalysis/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "AddAndKeySwitchCountAnalysis",
srcs = ["AddAndKeySwitchCountAnalysis.cpp"],
hdrs = ["AddAndKeySwitchCountAnalysis.h"],
deps = [
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
28 changes: 28 additions & 0 deletions lib/Dialect/Mgmt/IR/MgmtAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,33 @@ def Mgmt_MgmtAttr : Mgmt_Attr<"Mgmt", "mgmt"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def Mgmt_OpenfheParamsAttr : Mgmt_Attr<"OpenfheParams", "openfhe_params"> {
let summary = "Container attribute for some OpenFHE-specific parameters";
let description = [{
This attribute is used to store some OpenFHE-specific parameters.

The attribute is a struct with the following fields:
- `evalAddCount` : param for OpenFHE SetEvalAddCount
- `keySwitchCount` : param for OpenFHE SetKeySwitchCount

When this attribute presents, the lowering of openfhe pass
will use these parameters to set the corresponding OpenFHE
parameters.

It should be populated by --secret-with-mgmt-bgv before
going through the secret-to-bgv bgv-to-openfhe pass.

Example:
```
#openfhe_params = #mgmt.openfhe_params<evalAddCount = 1, keySwitchCount = 1>
```
}];

let parameters = (ins
"int": $evalAddCount,
"int": $keySwitchCount
);
let assemblyFormat = "`<` struct(params) `>`";
}

#endif // LIB_DIALECT_MGMT_IR_MGMTATTRIBUTES_TD_
4 changes: 4 additions & 0 deletions lib/Dialect/Mgmt/IR/MgmtDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def Mgmt_Dialect : Dialect {
/// Name of the attribute holding MgmtAttr
constexpr const static ::llvm::StringLiteral
kArgMgmtAttrName = "mgmt.mgmt";

/// Name of the attribute holding OpenfheParamsAttr
constexpr const static ::llvm::StringLiteral
kArgOpenfheParamsAttrName = "mgmt.openfhe_params";
}];


Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/Openfhe/IR/OpenfheOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def GenParamsOp : Openfhe_Op<"gen_params"> {
let arguments = (ins
I64Attr:$mulDepth,
I64Attr:$plainMod,
BoolAttr:$insecure
BoolAttr:$insecure,
I64Attr:$evalAddCount,
I64Attr:$KeySwitchCount
);
let results = (outs Openfhe_CCParams:$params);
}
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Openfhe/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ cc_library(
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/Openfhe/IR:Dialect",
"@heir//lib/Dialect/RNS/IR:Dialect",
Expand Down
15 changes: 14 additions & 1 deletion lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <string>

#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Dialect/Openfhe/IR/OpenfheOps.h"
#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h"
Expand Down Expand Up @@ -102,9 +103,21 @@ LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName,
}
}

// get evalAddCount/KeySwitchCount from func attribute, if present
int64_t evalAddCount = 0;
int64_t keySwitchCount = 0;
if (auto openfheParamAttr = mlir::dyn_cast<mgmt::OpenfheParamsAttr>(
op->getAttr(mgmt::MgmtDialect::kArgOpenfheParamsAttrName))) {
evalAddCount = openfheParamAttr.getEvalAddCount();
keySwitchCount = openfheParamAttr.getKeySwitchCount();
// remove the attribute after reading
op->removeAttr(mgmt::MgmtDialect::kArgOpenfheParamsAttrName);
}

Type openfheParamsType = openfhe::CCParamsType::get(builder.getContext());
Value ccParams = builder.create<openfhe::GenParamsOp>(
openfheParamsType, mulDepth, plainMod, insecure);
openfheParamsType, mulDepth, plainMod, insecure, evalAddCount,
keySwitchCount);
Value cryptoContext = builder.create<openfhe::GenContextOp>(
openfheContextType, ccParams,
BoolAttr::get(builder.getContext(), hasBootstrapOp));
Expand Down
Loading

0 comments on commit 6a581df

Please sign in to comment.