Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BGV: Count Add/Keyswitch in one level for OpenFHE #1254

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#include "lib/Analysis/AddAndKeySwitchCountAnalysis/AddAndKeySwitchCountAnalysis.h"

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {

LogicalResult CountAnalysis::visitOperation(
Operation *op, ArrayRef<const CountLattice *> operands,
ArrayRef<CountLattice *> results) {
auto propagate = [&](Value value, const CountState &state) {
auto *lattice = getLatticeElement(value);
ChangeResult changed = lattice->join(state);
propagateIfChanged(lattice, changed);
};

llvm::TypeSwitch<Operation &>(*op)
.Case<secret::GenericOp>([&](auto genericOp) {
Block *body = genericOp.getBody();
for (auto i = 0; i != body->getNumArguments(); ++i) {
auto blockArg = body->getArgument(i);
// one Vfresh
propagate(blockArg, CountState(1, 0));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to my other comment, naively I would expect zero adds here, not 1, and below (line 53) it appears you're adding the countState once per operand, rather than per op.

}
})
.Case<arith::AddIOp, arith::SubIOp, arith::AddFOp, arith::SubFOp>(
[&](auto &op) {
// condition on result secretness
SmallVector<OpResult> secretResults;
getSecretResults(op, secretResults);
if (secretResults.empty()) {
return;
}

CountState zeroState(0, 0);
SmallVector<OpOperand *> secretOperands;
getSecretOperands(op, secretOperands);
for (auto *operand : secretOperands) {
auto countState =
operands[operand->getOperandNumber()]->getValue();
zeroState = zeroState + countState;
}

for (auto result : secretResults) {
propagate(result, zeroState);
}
})
.Case<arith::MulIOp, arith::MulFOp>([&](auto &op) {
SmallVector<OpResult> secretResults;
getSecretResults(op, secretResults);
if (secretResults.empty()) {
return;
}

// now noise is Vmult
// TODO(#1168): we can actually do a more fine grained analysis here
// distinguishing ct-ct and ct-pt
propagate(op.getResult(), CountState(1, 0));
})
.Case<mgmt::RelinearizeOp, tensor_ext::RotateOp>([&](auto &op) {
auto secretness = ensureSecretness(op, op->getOperand(0));
if (!secretness) {
return;
}

auto state = operands[0]->getValue();
if (!state.isInitialized()) {
return;
}

propagate(op.getResult(), state.keySwitch());
})
// TODO(#1174): in BGV tensor::ExtractOp is assumed to be always
// mul+const
.Case<tensor::ExtractOp>([&](auto &op) {
auto secretness = ensureSecretness(op, op.getResult());
if (!secretness) {
return;
}

// now noise is Vmult + one Vks
propagate(op.getResult(), CountState(1, 1));
})
.Case<mgmt::ModReduceOp>([&](auto modReduceOp) {
// implicitly ensure that the operand is secret

propagate(modReduceOp.getResult(), CountState(0, 0));
});
// should not propagate through mgmt::ModReduceOp
return success();
}

void annotateCount(Operation *top, DataFlowSolver *solver) {
auto getIntegerAttr = [&](int level) {
return IntegerAttr::get(IntegerType::get(top->getContext(), 64), level);
};

auto maxAddCount = 0;
auto maxKeySwitchCount = 0;

auto getCount = [&](Value value) {
auto state = solver->lookupState<CountLattice>(value)->getValue();
// update the max
maxAddCount = std::max(maxAddCount, state.getAddCount());
maxKeySwitchCount = std::max(maxKeySwitchCount, state.getKeySwitchCount());
return std::make_tuple(state.getAddCount(), state.getKeySwitchCount());
};

top->walk<WalkOrder::PreOrder>([&](secret::GenericOp genericOp) {
for (auto i = 0; i != genericOp.getBody()->getNumArguments(); ++i) {
auto blockArg = genericOp.getBody()->getArgument(i);
auto [addCount, keySwitchCount] = getCount(blockArg);
if (addCount != 0) {
genericOp.setArgAttr(i, "addCount", getIntegerAttr(addCount));
}
if (keySwitchCount != 0) {
genericOp.setArgAttr(i, "keySwitchCount",
getIntegerAttr(keySwitchCount));
}
}

genericOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op->getNumResults() == 0) {
return;
}
auto [addCount, keySwitchCount] = getCount(op->getResult(0));
if (addCount != 0) {
op->setAttr("addCount", getIntegerAttr(addCount));
}
if (keySwitchCount != 0) {
op->setAttr("keySwitchCount", getIntegerAttr(keySwitchCount));
}
});

// annotate mgmt::OpenfheParamsAttr to func::FuncOp containing the genericOp
auto *funcOp = genericOp->getParentOp();
auto openfheParamAttr = mgmt::OpenfheParamsAttr::get(
funcOp->getContext(), maxAddCount, maxKeySwitchCount);
funcOp->setAttr(mgmt::MgmtDialect::kArgOpenfheParamsAttrName,
openfheParamAttr);
});
}

} // namespace heir
} // namespace mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#ifndef LIB_ANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_H_
#define LIB_ANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_H_

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project

namespace mlir {
namespace heir {

// This analysis should be used after --secret-with-mgmt-bgv
// but before --secret-distribute-generic
// where a whole secret::GenericOp is assumed
//
// this follows strictly the strategy of modulus switch
// right before multiplication including the first.
// namely FLEXIBLEAUTOEXT in OpenFHE
//
// OpenFHE only supports setting EvalAddCount/EvalKeySwitchCount
// for BGV/BFV, so HEIR only supports BGV here
//
// For not include-the-first, the addCount for the L-th level
// might be overestimated, as we can not distinguish between
// Vmult and Vfresh

class CountState {
public:
CountState() : initialized(false), addCount(0), keySwitchCount(0) {}
explicit CountState(int addCount, int keySwitchCount)
: initialized(true), addCount(addCount), keySwitchCount(keySwitchCount) {}
~CountState() = default;

int getAddCount() const {
assert(isInitialized());
return addCount;
}

int getKeySwitchCount() const {
assert(isInitialized());
return keySwitchCount;
}

bool operator==(const CountState &rhs) const {
return initialized == rhs.initialized && addCount == rhs.addCount &&
keySwitchCount == rhs.keySwitchCount;
}

bool isInitialized() const { return initialized; }

CountState operator+(const CountState &rhs) const {
assert(isInitialized() && rhs.isInitialized());
return CountState{addCount + rhs.addCount,
keySwitchCount + rhs.keySwitchCount};
}

CountState keySwitch() const {
assert(isInitialized());
return CountState{addCount, keySwitchCount + 1};
}

CountState max(const CountState &rhs) const {
assert(isInitialized() && rhs.isInitialized());
return CountState{std::max(addCount, rhs.addCount),
std::max(keySwitchCount, rhs.keySwitchCount)};
}

static CountState join(const CountState &lhs, const CountState &rhs) {
if (!lhs.isInitialized()) return rhs;
if (!rhs.isInitialized()) return lhs;

return lhs.max(rhs);
}

void print(llvm::raw_ostream &os) const {
if (isInitialized()) {
os << "CountState(" << addCount << ", " << keySwitchCount << ")";
} else {
os << "CountState(uninitialized)";
}
}

friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const CountState &state) {
state.print(os);
return os;
}

private:
bool initialized;
int addCount; // how many Vmult or Vfresh (before first multiplication)
// encountered
int keySwitchCount;
};

class CountLattice : public dataflow::Lattice<CountState> {
public:
using Lattice::Lattice;
};

class CountAnalysis
: public dataflow::SparseForwardDataFlowAnalysis<CountLattice>,
public SecretnessAnalysisDependent<CountAnalysis> {
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
friend class SecretnessAnalysisDependent<CountAnalysis>;

void setToEntryState(CountLattice *lattice) override {
propagateIfChanged(lattice, lattice->join(CountState()));
}

LogicalResult visitOperation(Operation *op,
ArrayRef<const CountLattice *> operands,
ArrayRef<CountLattice *> results) override;
};

void annotateCount(Operation *top, DataFlowSolver *solver);

} // namespace heir
} // namespace mlir

#endif // LIB_ANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_H_
20 changes: 20 additions & 0 deletions lib/Analysis/AddAndKeySwitchCountAnalysis/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "AddAndKeySwitchCountAnalysis",
srcs = ["AddAndKeySwitchCountAnalysis.cpp"],
hdrs = ["AddAndKeySwitchCountAnalysis.h"],
deps = [
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
28 changes: 28 additions & 0 deletions lib/Dialect/Mgmt/IR/MgmtAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,33 @@ def Mgmt_MgmtAttr : Mgmt_Attr<"Mgmt", "mgmt"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def Mgmt_OpenfheParamsAttr : Mgmt_Attr<"OpenfheParams", "openfhe_params"> {
let summary = "Container attribute for some OpenFHE-specific parameters";
let description = [{
This attribute is used to store some OpenFHE-specific parameters.

The attribute is a struct with the following fields:
- `evalAddCount` : param for OpenFHE SetEvalAddCount
- `keySwitchCount` : param for OpenFHE SetKeySwitchCount

When this attribute presents, the lowering of openfhe pass
will use these parameters to set the corresponding OpenFHE
parameters.

It should be populated by --secret-with-mgmt-bgv before
going through the secret-to-bgv bgv-to-openfhe pass.

Example:
```
#openfhe_params = #mgmt.openfhe_params<evalAddCount = 1, keySwitchCount = 1>
```
}];

let parameters = (ins
"int": $evalAddCount,
"int": $keySwitchCount
);
let assemblyFormat = "`<` struct(params) `>`";
}

#endif // LIB_DIALECT_MGMT_IR_MGMTATTRIBUTES_TD_
4 changes: 4 additions & 0 deletions lib/Dialect/Mgmt/IR/MgmtDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def Mgmt_Dialect : Dialect {
/// Name of the attribute holding MgmtAttr
constexpr const static ::llvm::StringLiteral
kArgMgmtAttrName = "mgmt.mgmt";

/// Name of the attribute holding OpenfheParamsAttr
constexpr const static ::llvm::StringLiteral
kArgOpenfheParamsAttrName = "mgmt.openfhe_params";
}];


Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/Openfhe/IR/OpenfheOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def GenParamsOp : Openfhe_Op<"gen_params"> {
let arguments = (ins
I64Attr:$mulDepth,
I64Attr:$plainMod,
BoolAttr:$insecure
BoolAttr:$insecure,
I64Attr:$evalAddCount,
I64Attr:$keySwitchCount
);
let results = (outs Openfhe_CCParams:$params);
}
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Openfhe/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ cc_library(
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/Openfhe/IR:Dialect",
"@heir//lib/Dialect/RNS/IR:Dialect",
Expand Down
15 changes: 14 additions & 1 deletion lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <string>

#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Dialect/Openfhe/IR/OpenfheOps.h"
#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h"
Expand Down Expand Up @@ -102,9 +103,21 @@ LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName,
}
}

// get evalAddCount/KeySwitchCount from func attribute, if present
int64_t evalAddCount = 0;
int64_t keySwitchCount = 0;
if (auto openfheParamsAttr = op->getAttrOfType<mgmt::OpenfheParamsAttr>(
mgmt::MgmtDialect::kArgOpenfheParamsAttrName)) {
evalAddCount = openfheParamsAttr.getEvalAddCount();
keySwitchCount = openfheParamsAttr.getKeySwitchCount();
// remove the attribute after reading
op->removeAttr(mgmt::MgmtDialect::kArgOpenfheParamsAttrName);
}

Type openfheParamsType = openfhe::CCParamsType::get(builder.getContext());
Value ccParams = builder.create<openfhe::GenParamsOp>(
openfheParamsType, mulDepth, plainMod, insecure);
openfheParamsType, mulDepth, plainMod, insecure, evalAddCount,
keySwitchCount);
Value cryptoContext = builder.create<openfhe::GenContextOp>(
openfheContextType, ccParams,
BoolAttr::get(builder.getContext(), hasBootstrapOp));
Expand Down
Loading
Loading