Skip to content

Commit

Permalink
Merge pull request #1244 from ZenithalHourlyRate:openfhe-plaintext-mo…
Browse files Browse the repository at this point in the history
…dulus

PiperOrigin-RevId: 714288970
  • Loading branch information
copybara-github committed Jan 11, 2025
2 parents 337468d + 3389eb5 commit a8f17da
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 8 deletions.
13 changes: 13 additions & 0 deletions lib/Dialect/Openfhe/IR/OpenfheOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ class Openfhe_BinaryOp<string mnemonic, list<Trait> traits = []>
}

def GenParamsOp : Openfhe_Op<"gen_params"> {
let description = [{
Generates the parameters for the OpenFHE scheme.

`mulDepth` is the depth of the multiplication circuit,
including the bootstrapping depth.

`plainMod` is the modulus of the plaintext space. If we
are using CKKS, this is 0.

`insecure` is a flag that determines whether the parameters
are generated securely or not. This is mainly used for
testing purposes.
}];
let arguments = (ins
I64Attr:$mulDepth,
I64Attr:$plainMod,
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Openfhe/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ cc_library(
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/Openfhe/IR:Dialect",
"@heir//lib/Dialect/RNS/IR:Dialect",
"@llvm-project//llvm:Support",
Expand Down
20 changes: 18 additions & 2 deletions lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <set>
#include <string>

#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Dialect/Openfhe/IR/OpenfheOps.h"
#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h"
#include "lib/Dialect/RNS/IR/RNSTypes.h"
Expand All @@ -18,6 +20,7 @@
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
Expand Down Expand Up @@ -84,8 +87,21 @@ LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName,
auto genFuncOp = builder.create<func::FuncOp>(genFuncName, genFuncType);
builder.setInsertionPointToEnd(genFuncOp.addEntryBlock());

// TODO(#661) : Calculate the appropriate values by analyzing the function
int64_t plainMod = 4295294977;
// get plaintext modulus from function argument ciphertext type
// for CKKS, plainMod is 0
int64_t plainMod = 0;
for (auto arg : op.getArguments()) {
if (auto argType = dyn_cast<lwe::NewLWECiphertextType>(
getElementTypeOrSelf(arg.getType()))) {
if (auto modArithType = dyn_cast<mod_arith::ModArithType>(
argType.getPlaintextSpace().getRing().getCoefficientType())) {
plainMod = modArithType.getModulus().getInt();
// implicitly assume arguments have the same plaintext modulus
break;
}
}
}

Type openfheParamsType = openfhe::CCParamsType::get(builder.getContext());
Value ccParams = builder.create<openfhe::GenParamsOp>(
openfheParamsType, mulDepth, plainMod, insecure);
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,11 @@ class SecretToBGVTypeConverter : public TypeWithAttrTypeConverter {
auto dimension = mgmtAttr.getDimension();

auto *ctx = type.getContext();
// TODO(#661) : Calculate the appropriate values by analyzing the function
auto plaintextRing = ::mlir::heir::polynomial::RingAttr::get(
type.getContext(),
mod_arith::ModArithType::get(
ctx, IntegerAttr::get(IntegerType::get(ctx, 64), 65537)),
ctx, IntegerAttr::get(IntegerType::get(ctx, 64), 4295294977)),
ring.getPolynomialModulus());

SmallVector<IntegerAttr, 6> moduliChain;
Expand Down
8 changes: 4 additions & 4 deletions lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ class SecretToCKKSTypeConverter : public TypeWithAttrTypeConverter {

Type valueTy = type.getValueType();

// Note that slot number for CKKS is always half of the ring dimension.
// so ring_.getPolynomialModulus() is not useful here
// TODO(#1191): use packing information to get the correct slot number
auto plaintextRing = ::mlir::heir::polynomial::RingAttr::get(
type.getContext(),
mod_arith::ModArithType::get(
ctx, IntegerAttr::get(IntegerType::get(ctx, 64), 65537)),
ring_.getPolynomialModulus());
type.getContext(), Float64Type::get(ctx), ring_.getPolynomialModulus());

SmallVector<IntegerAttr, 6> moduliChain;
for (auto modArithType :
Expand Down
4 changes: 3 additions & 1 deletion lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,9 @@ LogicalResult OpenFhePkeEmitter::printOperation(GenParamsOp op) {

os << "CCParamsT " << paramsName << ";\n";
os << paramsName << ".SetMultiplicativeDepth(" << mulDepth << ");\n";
os << paramsName << ".SetPlaintextModulus(" << plainMod << ");\n";
if (plainMod != 0) {
os << paramsName << ".SetPlaintextModulus(" << plainMod << ");\n";
}
if (op.getInsecure()) {
os << paramsName << ".SetSecurityLevel(lbcrypto::HEStd_NotSet);\n";
os << paramsName << ".SetRingDim(128);\n";
Expand Down
10 changes: 10 additions & 0 deletions tests/Dialect/Openfhe/Emitters/emit_openfhe_pke.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,13 @@ func.func @test_constant() -> tensor<2xf32> {
%cst_2d = arith.constant dense<[[1.5, 2.5]]> : tensor<1x2xf64>
return %splat : tensor<2xf32>
}

// -----

// CHECK-LABEL: test_ckks_no_plaintext_modulus
// CHECK-NOT: SetPlaintextModulus
func.func @test_ckks_no_plaintext_modulus() -> !openfhe.crypto_context {
%0 = openfhe.gen_params {insecure = false, mulDepth = 2 : i64, plainMod = 0 : i64} : () -> !openfhe.cc_params
%1 = openfhe.gen_context %0 {supportFHE = false} : (!openfhe.cc_params) -> !openfhe.crypto_context
return %1 : !openfhe.crypto_context
}

0 comments on commit a8f17da

Please sign in to comment.