-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
BGV: Count Add/Keyswitch in one level for OpenFHE
- Loading branch information
1 parent
a8f17da
commit 6a581df
Showing
11 changed files
with
367 additions
and
2 deletions.
There are no files selected for viewing
155 changes: 155 additions & 0 deletions
155
lib/Analysis/AddAndKeySwitchCountAnalysis/AddAndKeySwitchCountAnalysis.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
#include "lib/Analysis/AddAndKeySwitchCountAnalysis/AddAndKeySwitchCountAnalysis.h" | ||
|
||
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" | ||
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h" | ||
#include "lib/Dialect/Mgmt/IR/MgmtOps.h" | ||
#include "lib/Dialect/Secret/IR/SecretOps.h" | ||
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h" | ||
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project | ||
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
LogicalResult CountAnalysis::visitOperation( | ||
Operation *op, ArrayRef<const CountLattice *> operands, | ||
ArrayRef<CountLattice *> results) { | ||
auto propagate = [&](Value value, const CountState &state) { | ||
auto *lattice = getLatticeElement(value); | ||
ChangeResult changed = lattice->join(state); | ||
propagateIfChanged(lattice, changed); | ||
}; | ||
|
||
llvm::TypeSwitch<Operation &>(*op) | ||
.Case<secret::GenericOp>([&](auto genericOp) { | ||
Block *body = genericOp.getBody(); | ||
for (auto i = 0; i != body->getNumArguments(); ++i) { | ||
auto blockArg = body->getArgument(i); | ||
// one Vfresh | ||
propagate(blockArg, CountState(1, 0)); | ||
} | ||
}) | ||
.Case<arith::AddIOp, arith::SubIOp>([&](auto &op) { | ||
// condition on result secretness | ||
SmallVector<OpResult> secretResults; | ||
getSecretResults(op, secretResults); | ||
if (secretResults.empty()) { | ||
return; | ||
} | ||
|
||
CountState zeroState(0, 0); | ||
SmallVector<OpOperand *> secretOperands; | ||
getSecretOperands(op, secretOperands); | ||
for (auto *operand : secretOperands) { | ||
auto countState = operands[operand->getOperandNumber()]->getValue(); | ||
zeroState = zeroState + countState; | ||
} | ||
|
||
for (auto result : secretResults) { | ||
propagate(result, zeroState); | ||
} | ||
}) | ||
.Case<arith::MulIOp>([&](auto &op) { | ||
SmallVector<OpResult> secretResults; | ||
getSecretResults(op, secretResults); | ||
if (secretResults.empty()) { | ||
return; | ||
} | ||
|
||
// now noise is Vmult | ||
// TODO(#1168): we can actually do a more fine grained analysis here | ||
// distinguishing ct-ct and ct-pt | ||
propagate(op.getResult(), CountState(1, 0)); | ||
}) | ||
.Case<mgmt::RelinearizeOp, tensor_ext::RotateOp>([&](auto &op) { | ||
auto secretness = ensureSecretness(op, op->getOperand(0)); | ||
if (!secretness) { | ||
return; | ||
} | ||
|
||
auto state = operands[0]->getValue(); | ||
if (!state.isInitialized()) { | ||
return; | ||
} | ||
|
||
propagate(op.getResult(), state.keySwitch()); | ||
}) | ||
// TODO(#1174): in BGV tensor::ExtractOp is assumed to be always | ||
// mul+const | ||
.Case<tensor::ExtractOp>([&](auto &op) { | ||
auto secretness = ensureSecretness(op, op.getResult()); | ||
if (!secretness) { | ||
return; | ||
} | ||
|
||
// now noise is Vmult + one Vks | ||
propagate(op.getResult(), CountState(1, 1)); | ||
}) | ||
.Case<mgmt::ModReduceOp>([&](auto modReduceOp) { | ||
// implicitly ensure that the operand is secret | ||
|
||
propagate(modReduceOp.getResult(), CountState(0, 0)); | ||
}); | ||
// should not propagate through mgmt::ModReduceOp | ||
return success(); | ||
} | ||
|
||
void annotateCount(Operation *top, DataFlowSolver *solver) { | ||
auto getIntegerAttr = [&](int level) { | ||
return IntegerAttr::get(IntegerType::get(top->getContext(), 64), level); | ||
}; | ||
|
||
auto maxAddCount = 0; | ||
auto maxKeySwitchCount = 0; | ||
|
||
auto getCount = [&](Value value) { | ||
auto state = solver->lookupState<CountLattice>(value)->getValue(); | ||
// update the max | ||
maxAddCount = std::max(maxAddCount, state.getAddCount()); | ||
maxKeySwitchCount = std::max(maxKeySwitchCount, state.getKeySwitchCount()); | ||
return std::make_tuple(state.getAddCount(), state.getKeySwitchCount()); | ||
}; | ||
|
||
top->walk<WalkOrder::PreOrder>([&](secret::GenericOp genericOp) { | ||
for (auto i = 0; i != genericOp.getBody()->getNumArguments(); ++i) { | ||
auto blockArg = genericOp.getBody()->getArgument(i); | ||
auto [addCount, keySwitchCount] = getCount(blockArg); | ||
if (addCount != 0) { | ||
genericOp.setArgAttr(i, "addCount", getIntegerAttr(addCount)); | ||
} | ||
if (keySwitchCount != 0) { | ||
genericOp.setArgAttr(i, "keySwitchCount", | ||
getIntegerAttr(keySwitchCount)); | ||
} | ||
} | ||
|
||
genericOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) { | ||
if (op->getNumResults() == 0) { | ||
return; | ||
} | ||
auto [addCount, keySwitchCount] = getCount(op->getResult(0)); | ||
if (addCount != 0) { | ||
op->setAttr("addCount", getIntegerAttr(addCount)); | ||
} | ||
if (keySwitchCount != 0) { | ||
op->setAttr("keySwitchCount", getIntegerAttr(keySwitchCount)); | ||
} | ||
}); | ||
|
||
// annotate mgmt::OpenfheParamsAttr to func::FuncOp containing the genericOp | ||
auto *funcOp = genericOp->getParentOp(); | ||
auto openfheParamAttr = mgmt::OpenfheParamsAttr::get( | ||
funcOp->getContext(), maxAddCount, maxKeySwitchCount); | ||
funcOp->setAttr(mgmt::MgmtDialect::kArgOpenfheParamsAttrName, | ||
openfheParamAttr); | ||
}); | ||
} | ||
|
||
} // namespace heir | ||
} // namespace mlir |
123 changes: 123 additions & 0 deletions
123
lib/Analysis/AddAndKeySwitchCountAnalysis/AddAndKeySwitchCountAnalysis.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
#ifndef LIB_ANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_H_ | ||
#define LIB_ANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_H_ | ||
|
||
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" | ||
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
// This analysis should be used after --secret-with-mgmt-bgv | ||
// but before --secret-distribute-generic | ||
// where a whole secret::GenericOp is assumed | ||
// | ||
// this follows strictly the strategy of modulus switch | ||
// right before multiplication including the first. | ||
// namely FLEXIBLEAUTOEXT in OpenFHE | ||
// | ||
// OpenFHE only supports setting EvalAddCount/EvalKeySwitchCount | ||
// for BGV/BFV, so HEIR only supports BGV here | ||
// | ||
// For not include-the-first, the addCount for the L-th level | ||
// might be overestimated, as we can not distinguish between | ||
// Vmult and Vfresh | ||
|
||
class CountState { | ||
public: | ||
CountState() : initialized(false), addCount(0), keySwitchCount(0) {} | ||
explicit CountState(int addCount, int keySwitchCount) | ||
: initialized(true), addCount(addCount), keySwitchCount(keySwitchCount) {} | ||
~CountState() = default; | ||
|
||
int getAddCount() const { | ||
assert(isInitialized()); | ||
return addCount; | ||
} | ||
|
||
int getKeySwitchCount() const { | ||
assert(isInitialized()); | ||
return keySwitchCount; | ||
} | ||
|
||
bool operator==(const CountState &rhs) const { | ||
return initialized == rhs.initialized && addCount == rhs.addCount && | ||
keySwitchCount == rhs.keySwitchCount; | ||
} | ||
|
||
bool isInitialized() const { return initialized; } | ||
|
||
CountState operator+(const CountState &rhs) const { | ||
assert(isInitialized() && rhs.isInitialized()); | ||
return CountState{addCount + rhs.addCount, | ||
keySwitchCount + rhs.keySwitchCount}; | ||
} | ||
|
||
CountState keySwitch() const { | ||
assert(isInitialized()); | ||
return CountState{addCount, keySwitchCount + 1}; | ||
} | ||
|
||
CountState max(const CountState &rhs) const { | ||
assert(isInitialized() && rhs.isInitialized()); | ||
return CountState{std::max(addCount, rhs.addCount), | ||
std::max(keySwitchCount, rhs.keySwitchCount)}; | ||
} | ||
|
||
static CountState join(const CountState &lhs, const CountState &rhs) { | ||
if (!lhs.isInitialized()) return rhs; | ||
if (!rhs.isInitialized()) return lhs; | ||
|
||
return lhs.max(rhs); | ||
} | ||
|
||
void print(llvm::raw_ostream &os) const { | ||
if (isInitialized()) { | ||
os << "CountState(" << addCount << ", " << keySwitchCount << ")"; | ||
} else { | ||
os << "CountState(uninitialized)"; | ||
} | ||
} | ||
|
||
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, | ||
const CountState &state) { | ||
state.print(os); | ||
return os; | ||
} | ||
|
||
private: | ||
bool initialized; | ||
int addCount; // how many Vmult or Vfresh (before first multiplication) | ||
// encountered | ||
int keySwitchCount; | ||
}; | ||
|
||
class CountLattice : public dataflow::Lattice<CountState> { | ||
public: | ||
using Lattice::Lattice; | ||
}; | ||
|
||
class CountAnalysis | ||
: public dataflow::SparseForwardDataFlowAnalysis<CountLattice>, | ||
public SecretnessAnalysisDependent<CountAnalysis> { | ||
public: | ||
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; | ||
friend class SecretnessAnalysisDependent<CountAnalysis>; | ||
|
||
void setToEntryState(CountLattice *lattice) override { | ||
propagateIfChanged(lattice, lattice->join(CountState())); | ||
} | ||
|
||
LogicalResult visitOperation(Operation *op, | ||
ArrayRef<const CountLattice *> operands, | ||
ArrayRef<CountLattice *> results) override; | ||
}; | ||
|
||
void annotateCount(Operation *top, DataFlowSolver *solver); | ||
|
||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // LIB_ANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_ADDANDKEYSWITCHCOUNTANALYSISANALYSIS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "AddAndKeySwitchCountAnalysis", | ||
srcs = ["AddAndKeySwitchCountAnalysis.cpp"], | ||
hdrs = ["AddAndKeySwitchCountAnalysis.h"], | ||
deps = [ | ||
"@heir//lib/Analysis/SecretnessAnalysis", | ||
"@heir//lib/Dialect/Mgmt/IR:Dialect", | ||
"@heir//lib/Dialect/Secret/IR:Dialect", | ||
"@heir//lib/Dialect/TensorExt/IR:Dialect", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:Analysis", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Support", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.