feat(compiler): fancy assignment
umut-sahin committed May 3, 2024
1 parent c162dcf commit a74ec9f
Showing 10 changed files with 953 additions and 17 deletions.
@@ -1462,5 +1462,60 @@ def FHELinalg_FancyIndexOp : FHELinalg_Op<"fancy_index", [Pure]> {
let hasVerifier = 1;
}

def FHELinalg_FancyAssignOp : FHELinalg_Op<"fancy_assign", [Pure]> {
let summary = "Assigns a tensor of values into another tensor at the positions given by a tensor of indices.";

let description = [{
Examples:

```mlir
"FHELinalg.fancy_assign"(%t, %i, %a) : (tensor<5x!FHE.eint<16>>, tensor<3xindex>, tensor<3x!FHE.eint<16>>) -> tensor<5x!FHE.eint<16>>
//
// fancy_assign([10, 20, 30, 40, 50], [3, 1, 2], [1000, 2000, 3000]) = [10, 2000, 3000, 1000, 50]
//
```

```mlir
"FHELinalg.fancy_assign"(%t, %i, %a) : (tensor<5x!FHE.eint<16>>, tensor<2x2xindex>, tensor<2x2x!FHE.eint<16>>) -> tensor<5x!FHE.eint<16>>
//
// fancy_assign([10, 20, 30, 40, 50], [[3, 1], [2, 0]], [[1000, 2000], [3000, 4000]]) = [4000, 2000, 3000, 1000, 50]
//
```

```mlir
"FHELinalg.fancy_assign"(%t, %i, %a) : (tensor<2x3x!FHE.eint<16>>, tensor<3x2xindex>, tensor<3x!FHE.eint<16>>) -> tensor<2x3x!FHE.eint<16>>
//
// fancy_assign([[11, 12, 13], [21, 22, 23]], [[1, 0], [0, 2], [0, 0]], [1000, 2000, 3000]) = [[3000, 12, 2000], [1000, 22, 23]]
//
```

```mlir
"FHELinalg.fancy_assign"(%t, %i, %a) : (tensor<3x3x!FHE.eint<16>>, tensor<2x3x2xindex>, tensor<2x3x!FHE.eint<16>>) -> tensor<3x3x!FHE.eint<16>>
//
// fancy_assign(
// [[11, 12, 13], [21, 22, 23], [31, 32, 33]],
// [[[1, 0], [0, 2], [0, 0]], [[2, 0], [1, 1], [2, 1]]],
// [[1000, 2000, 3000], [4000, 5000, 6000]]
// ) = [[3000, 12, 2000], [1000, 5000, 23], [4000, 6000, 33]]
//
```

Notes:
- Assigning multiple values to the same output position results in undefined behavior.
}];

let arguments = (ins
Type<And<[TensorOf<[AnyType]>.predicate, HasStaticShapePred]>>:$input,
Type<And<[TensorOf<[Index]>.predicate, HasStaticShapePred]>>:$indices,
Type<And<[TensorOf<[AnyType]>.predicate, HasStaticShapePred]>>:$values
);

let results = (outs
Type<And<[TensorOf<[AnyType]>.predicate, HasStaticShapePred]>>:$output
);

let hasVerifier = 1;
}


#endif
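For intuition, the op's semantics mirror NumPy's fancy assignment, applied out of place. A minimal reference sketch of the behavior described above (the helper is illustrative, not part of the compiler):

```python
import numpy as np

def fancy_assign(tensor, indices, values):
    # Out-of-place reference semantics: out[indices[k]] = values[k] for each k.
    # For inputs of rank > 1, the last axis of `indices` holds one coordinate
    # per input dimension.
    out = np.array(tensor)
    indices = np.asarray(indices)
    values = np.asarray(values)
    if out.ndim == 1:
        out[indices] = values
    else:
        coords = indices.reshape(-1, out.ndim)
        out[tuple(coords.T)] = values.reshape(-1)
    return out

# First example from the description:
print(fancy_assign([10, 20, 30, 40, 50], [3, 1, 2], [1000, 2000, 3000]))
# [  10 2000 3000 1000   50]
```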
@@ -2110,6 +2110,124 @@ struct FancyIndexToTensorGenerate
};
};

/// This rewrite pattern transforms any instance of the
/// `FHELinalg.fancy_assign` operator into an instance of `scf.forall`.
///
/// Example:
///
/// %output = "FHELinalg.fancy_assign"(%input, %indices, %values) :
/// (tensor<5x!FHE.eint<6>>, tensor<3xindex>, tensor<3x!FHE.eint<6>>) ->
/// tensor<5x!FHE.eint<6>>
///
/// becomes:
///
/// %0 = scf.forall (%i) in (3) shared_outs(%output = %input)
/// -> (tensor<5x!FHE.eint<6>>) {
/// %index = tensor.extract %indices[%i] : tensor<3xindex>
/// %value = tensor.extract %values[%i] : tensor<3x!FHE.eint<6>>
/// %value_slice = tensor.from_elements %value : tensor<1x!FHE.eint<6>>
/// scf.forall.in_parallel {
/// tensor.parallel_insert_slice
/// %value_slice into %output[%index][1][1]
/// : tensor<1x!FHE.eint<6>> into tensor<5x!FHE.eint<6>>
/// }
/// }
///
struct FancyAssignToScfForall
: public mlir::OpRewritePattern<FHELinalg::FancyAssignOp> {
FancyAssignToScfForall(mlir::MLIRContext *context)
: mlir::OpRewritePattern<FHELinalg::FancyAssignOp>(
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

mlir::LogicalResult
matchAndRewrite(FHELinalg::FancyAssignOp fancyAssignOp,
mlir::PatternRewriter &rewriter) const override {

auto input = fancyAssignOp.getInput();
auto indices = fancyAssignOp.getIndices();
auto values = fancyAssignOp.getValues();

    auto inputType = input.getType().cast<mlir::RankedTensorType>();
    auto valuesType = values.getType().cast<mlir::RankedTensorType>();

auto inputShape = inputType.getShape();
auto inputDimensions = inputShape.size();
auto inputIsVector = inputDimensions == 1;
auto inputElementType = inputType.getElementType();

auto upperBounds = llvm::SmallVector<mlir::OpFoldResult>();
for (auto dimension : valuesType.getShape()) {
upperBounds.push_back(
mlir::OpFoldResult(rewriter.getIndexAttr(dimension)));
}

auto body = [=](mlir::OpBuilder &builder, mlir::Location location,
mlir::ValueRange args) {
      // The last body argument is the shared output; the leading arguments
      // are the loop induction variables.
      auto output = args[args.size() - 1];
auto loopArgs = args.take_front(args.size() - 1);

std::vector<mlir::Value> index;
mlir::Value element;

      if (inputIsVector) {
        // Rank-1 input: `indices` holds plain positions.
        index.push_back(builder.create<tensor::ExtractOp>(
            location, builder.getIndexType(), indices, loopArgs));
      } else {
        // Rank > 1 input: the last axis of `indices` holds one coordinate
        // per input dimension.
        auto baseArgs =
            std::vector<mlir::Value>(loopArgs.begin(), loopArgs.end());

        for (size_t i = 0; i < inputDimensions; i++) {
          baseArgs.push_back(builder.create<arith::ConstantOp>(
              location, builder.getIndexType(), builder.getIndexAttr(i)));
          index.push_back(builder.create<tensor::ExtractOp>(
              location, builder.getIndexType(), indices, baseArgs));
          baseArgs.pop_back();
        }
      }

      // The value to insert is extracted the same way in both cases.
      element = builder
                    .create<tensor::ExtractOp>(location, inputElementType,
                                               values, loopArgs)
                    .getResult();

      // Wrap extracted scalars into a rank-1 tensor so they can be inserted
      // with tensor.parallel_insert_slice below.
      if (!element.getType().isa<mlir::TensorType>()) {
element =
builder.create<mlir::tensor::FromElementsOp>(location, element)
.getResult();
}

auto offsets = std::vector<mlir::OpFoldResult>();
auto sizes = std::vector<mlir::OpFoldResult>();
auto strides = std::vector<mlir::OpFoldResult>();

for (size_t i = 0; i < index.size(); i++) {
offsets.push_back(mlir::OpFoldResult(index[i]));
sizes.push_back(mlir::OpFoldResult(builder.getIndexAttr(1)));
strides.push_back(mlir::OpFoldResult(builder.getIndexAttr(1)));
}

auto inParallelOp = builder.create<mlir::scf::InParallelOp>(location);
builder.setInsertionPointToStart(inParallelOp.getBody());

builder.create<mlir::tensor::ParallelInsertSliceOp>(
location, element, output, offsets, sizes, strides);
};

    auto forallOp = rewriter.create<mlir::scf::ForallOp>(
        fancyAssignOp.getLoc(), upperBounds, mlir::ValueRange{input},
        std::nullopt, body);

rewriter.replaceOp(fancyAssignOp, forallOp->getResult(0));
return mlir::success();
};
};
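Read as loops: each `scf.forall` iteration handles one element of `values`, gathers its target coordinates from `indices`, and inserts the value into the shared output. A Python sketch of that schedule (names are illustrative; assumes the same semantics as the reference sketch above):

```python
import numpy as np

def fancy_assign_as_loops(input_tensor, indices, values):
    # Mirrors the generated scf.forall: the shared output starts as a copy of
    # the input, and each (parallelizable) iteration performs one
    # single-element parallel_insert_slice.
    output = np.array(input_tensor)
    indices = np.asarray(indices)
    values = np.asarray(values)
    for loop_args in np.ndindex(values.shape):  # one iteration per value
        if output.ndim == 1:
            target = (int(indices[loop_args]),)
        else:
            # One coordinate per input dimension, like the inner loop above.
            target = tuple(int(indices[loop_args + (d,)])
                           for d in range(output.ndim))
        output[target] = values[loop_args]
    return output
```

Iterations writing to distinct positions commute, which is what makes `scf.forall` a valid target; overlapping targets are exactly the undefined-behavior case called out in the op description.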

namespace {
struct FHETensorOpsToLinalg
: public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
@@ -2130,6 +2248,9 @@ void FHETensorOpsToLinalg::runOnOperation() {
target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();

target.addLegalOp<mlir::scf::ForallOp>();
target.addLegalOp<mlir::scf::InParallelOp>();

target.addDynamicallyLegalOp<
mlir::concretelang::Optimizer::PartitionFrontierOp>(
[&](mlir::concretelang::Optimizer::PartitionFrontierOp op) {
@@ -2292,6 +2413,7 @@ void FHETensorOpsToLinalg::runOnOperation() {
patterns.insert<FromElementToTensorFromElements>(&getContext());
patterns.insert<TensorPartitionFrontierOpToLinalgGeneric>(&getContext());
patterns.insert<FancyIndexToTensorGenerate>(&getContext());
patterns.insert<FancyAssignToScfForall>(&getContext());

if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())
@@ -684,6 +684,16 @@ getSqMANP(mlir::concretelang::FHELinalg::FancyIndexOp op,
return operandMANPs[0]->getValue().getMANP().value();
}

/// Each element of the result of a fancy assignment is taken verbatim either
/// from the input tensor or from the values tensor, so the squared MANP of
/// the result is the maximum of the squared MANPs of those two operands.
static std::optional<llvm::APInt>
getSqMANP(mlir::concretelang::FHELinalg::FancyAssignOp op,
          llvm::ArrayRef<const MANPLattice *> operandMANPs) {
auto inputManp = operandMANPs[0]->getValue().getMANP().value();
auto valuesManp = operandMANPs[2]->getValue().getMANP().value();
return inputManp.getLimitedValue() >= valuesManp.getLimitedValue()
? inputManp
: valuesManp;
}

static llvm::APInt
sqMANP_conv2d(llvm::APInt inputNorm, mlir::RankedTensorType weightTy,
std::optional<mlir::detail::ElementsAttrRange<
@@ -828,6 +838,10 @@ class MANPAnalysis
llvm::dyn_cast<mlir::concretelang::FHELinalg::FancyIndexOp>(
op)) {
norm2SqEquiv = getSqMANP(fancyIndexOp, operands);
} else if (auto fancyAssignOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::FancyAssignOp>(
op)) {
norm2SqEquiv = getSqMANP(fancyAssignOp, operands);
}
// Tensor Operators
// ExtractOp
@@ -1515,6 +1515,101 @@ mlir::LogicalResult FancyIndexOp::verify() {
return mlir::success();
}

mlir::LogicalResult FancyAssignOp::verify() {
auto inputType =
this->getInput().getType().dyn_cast_or_null<mlir::RankedTensorType>();
auto indicesType =
this->getIndices().getType().dyn_cast_or_null<mlir::RankedTensorType>();
auto valuesType =
this->getValues().getType().dyn_cast_or_null<mlir::RankedTensorType>();
auto outputType =
this->getOutput().getType().dyn_cast_or_null<mlir::RankedTensorType>();

auto inputElementType = inputType.getElementType();
auto valuesElementType = valuesType.getElementType();
auto outputElementType = outputType.getElementType();

if (valuesElementType != inputElementType) {
this->emitOpError() << "values element type " << valuesElementType
<< " doesn't match input element type "
<< inputElementType;
return mlir::failure();
}
if (outputElementType != inputElementType) {
this->emitOpError() << "output element type " << outputElementType
<< " doesn't match input element type "
<< inputElementType;
return mlir::failure();
}

auto inputShape = inputType.getShape();
auto indicesShape = indicesType.getShape();
auto valuesShape = valuesType.getShape();
auto outputShape = outputType.getShape();

auto inputIsVector = inputShape.size() == 1;
if (!inputIsVector) {
if (indicesShape[indicesShape.size() - 1] != (int64_t)inputShape.size()) {
this->emitOpError()
<< "size of the last dimension of indices '"
<< indicesShape[indicesShape.size() - 1]
<< "' doesn't match the number of dimensions of input '"
<< inputShape.size() << "'";
return mlir::failure();
}
}

auto expectedValuesShape =
inputIsVector ? indicesShape
: indicesShape.slice(0, indicesShape.size() - 1);
if (valuesShape != expectedValuesShape) {
auto stream = this->emitOpError();

stream << "values shape '<";
if (!valuesShape.empty()) {
stream << valuesShape[0];
for (size_t i = 1; i < valuesShape.size(); i++) {
stream << "x" << valuesShape[i];
}
}
stream << ">' doesn't match the expected values shape '<";
if (!expectedValuesShape.empty()) {
stream << expectedValuesShape[0];
for (size_t i = 1; i < expectedValuesShape.size(); i++) {
stream << "x" << expectedValuesShape[i];
}
}
stream << ">'";

return mlir::failure();
}

auto expectedOutputShape = inputShape;
if (outputShape != expectedOutputShape) {
auto stream = this->emitOpError();

stream << "output shape '<";
if (!outputShape.empty()) {
stream << outputShape[0];
for (size_t i = 1; i < outputShape.size(); i++) {
stream << "x" << outputShape[i];
}
}
stream << ">' doesn't match the expected output shape '<";
if (!expectedOutputShape.empty()) {
stream << expectedOutputShape[0];
for (size_t i = 1; i < expectedOutputShape.size(); i++) {
stream << "x" << expectedOutputShape[i];
}
}
stream << ">'";

return mlir::failure();
}

return mlir::success();
}
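The shape contract the verifier enforces, restated compactly as a Python sketch (shapes as tuples; the function name is illustrative):

```python
def check_fancy_assign_shapes(input_shape, indices_shape,
                              values_shape, output_shape):
    # Rank-1 input: `indices` holds plain positions, so `values` must have
    # exactly the shape of `indices`. Higher ranks: the last axis of
    # `indices` holds one coordinate per input dimension.
    if len(input_shape) == 1:
        expected_values = indices_shape
    else:
        assert indices_shape[-1] == len(input_shape), \
            "last dimension of indices must equal the rank of input"
        expected_values = indices_shape[:-1]
    assert values_shape == expected_values, "values shape mismatch"
    assert output_shape == input_shape, "output must keep the input's shape"

# e.g. the third example above: input <2x3>, indices <3x2>, values <3>
check_fancy_assign_shapes((2, 3), (3, 2), (3,), (2, 3))
```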

/// Avoid addition with constant tensor of 0s
OpFoldResult AddEintIntOp::fold(FoldAdaptor operands) {
auto toAdd = operands.getRhs().dyn_cast_or_null<mlir::DenseIntElementsAttr>();