TOSA compatibility for the TFHE-rs HL emitter #1272

Open · wants to merge 2 commits into base: main

Changes from all commits
47 changes: 30 additions & 17 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
@@ -118,6 +118,26 @@ struct ConvertExtUIOp : public OpConversionPattern<mlir::arith::ExtUIOp> {
}
};

struct ConvertExtSIOp : public OpConversionPattern<mlir::arith::ExtSIOp> {
ConvertExtSIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ExtSIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::ExtSIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto outType = convertArithToCGGIType(
cast<IntegerType>(op.getResult().getType()), op->getContext());
auto castOp = b.create<cggi::CastOp>(op.getLoc(), outType, adaptor.getIn());

rewriter.replaceOp(op, castOp);
return success();
}
};
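
For intuition (editor's sketch, not part of the diff): a plain-integer model of the arith.extsi semantics this pattern must preserve. It assumes cggi::CastOp performs the sign-preserving width change on the ciphertext, which the lowering relies on but this diff does not show.

```cpp
// Plain-integer sketch of arith.extsi: sign extension replicates the sign
// bit into the new high bits. The cggi::CastOp above is assumed to do the
// equivalent width change on encrypted values (illustrative only).
#include <cassert>
#include <cstdint>

int16_t extsi_i8_to_i16(int8_t x) { return static_cast<int16_t>(x); }

int main() {
  assert(extsi_i8_to_i16(0x7F) == 0x007F);  // positive values are unchanged
  assert(static_cast<uint16_t>(extsi_i8_to_i16(-128)) == 0xFF80);  // sign bit
  return 0;
}
```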

struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
ConvertShRUIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ShRUIOp>(context) {}
@@ -139,12 +159,10 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
.getSExtValue();

    auto inputValue =
-        mlir::IntegerAttr::get(rewriter.getI8Type(), (int8_t)shiftAmount);
-    auto cteOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getI8Type(), inputValue);
+        mlir::IntegerAttr::get(rewriter.getIndexType(), (int8_t)shiftAmount);

-    auto shiftOp =
-        b.create<cggi::ShiftRightOp>(outputType, adaptor.getLhs(), cteOp);
+    auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
+        outputType, adaptor.getLhs(), inputValue);
rewriter.replaceOp(op, shiftOp);

return success();
@@ -157,14 +175,12 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
auto shiftAmount =
cast<IntegerAttr>(cteShiftSizeOp.getValue()).getValue().getSExtValue();

-    auto inputValue = mlir::IntegerAttr::get(rewriter.getI8Type(), shiftAmount);
-    auto cteOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getI8Type(), inputValue);
+    auto inputValue =
+        mlir::IntegerAttr::get(rewriter.getIndexType(), shiftAmount);

-    auto shiftOp =
-        b.create<cggi::ShiftRightOp>(outputType, adaptor.getLhs(), cteOp);
+    auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
+        outputType, adaptor.getLhs(), inputValue);
    rewriter.replaceOp(op, shiftOp);
-    rewriter.replaceOp(op.getLhs().getDefiningOp(), cteOp);

return success();
}
@@ -184,10 +200,7 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
target.addDynamicallyLegalOp<mlir::arith::ConstantOp>(
[](mlir::arith::ConstantOp op) {
          // Allow use of constant if it is used to denote the size of a shift
-          bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) {
-            return isa<cggi::ShiftRightOp>(user);
-          });
-          return (isa<IndexType>(op.getValue().getType()) || (usedByShift));
+          return (isa<IndexType>(op.getValue().getType()));
});

target.addDynamicallyLegalOp<
@@ -199,8 +212,8 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
});

patterns.add<
-        ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertShRUIOp,
-        ConvertBinOp<mlir::arith::AddIOp, cggi::AddOp>,
+        ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp,
+        ConvertShRUIOp, ConvertBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
221 changes: 174 additions & 47 deletions lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp
@@ -1,9 +1,5 @@
#include "lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h"

-#include <mlir/IR/MLIRContext.h>
-
-#include <cstdint>
-
#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
#include "lib/Dialect/CGGI/IR/CGGIOps.h"
#include "lib/Dialect/LWE/IR/LWEOps.h"
@@ -15,7 +11,9 @@
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

namespace mlir::heir::arith {

@@ -94,7 +92,7 @@ class ArithToCGGIQuartTypeConverter : public TypeConverter {
};

static Value createTrivialOpMaxWidth(ImplicitLocOpBuilder b, int value) {
-  auto maxWideIntType = IntegerType::get(b.getContext(), maxIntWidth >> 1);
+  auto maxWideIntType = IntegerType::get(b.getContext(), maxIntWidth);
auto intAttr = b.getIntegerAttr(maxWideIntType, value);

auto encoding =
@@ -153,19 +151,16 @@ static SmallVector<Value> extractLastDimHalves(

static Value createScalarOrSplatConstant(OpBuilder &builder, Location loc,
Type type, int64_t value) {
-  unsigned elementBitWidth = 0;
-  if (auto lweTy = dyn_cast<lwe::LWECiphertextType>(type))
-    elementBitWidth =
-        cast<lwe::UnspecifiedBitFieldEncodingAttr>(lweTy.getEncoding())
-            .getCleartextBitwidth();
-  else
-    elementBitWidth = maxIntWidth;
+  // unsigned elementBitWidth = 0;
+  // if (auto lweTy = dyn_cast<lwe::LWECiphertextType>(type))
+  //   elementBitWidth =
+  //       cast<lwe::UnspecifiedBitFieldEncodingAttr>(lweTy.getEncoding())
+  //           .getCleartextBitwidth();
+  //   else
+  //     elementBitWidth = maxIntWidth;

-  auto apValue = APInt(elementBitWidth, value);
-
-  auto maxWideIntType =
-      IntegerType::get(builder.getContext(), maxIntWidth >> 1);
-  auto intAttr = builder.getIntegerAttr(maxWideIntType, value);
+  auto intAttr = builder.getIntegerAttr(
+      IntegerType::get(builder.getContext(), maxIntWidth), value);

return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
}
@@ -249,6 +244,40 @@ struct ConvertQuartConstantOp
}
};

struct ConvertQuartTruncIOp
: public OpConversionPattern<mlir::arith::TruncIOp> {
ConvertQuartTruncIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::TruncIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::TruncIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto newResultTy = getTypeConverter()->convertType<RankedTensorType>(
op.getResult().getType());
auto newInTy =
getTypeConverter()->convertType<RankedTensorType>(op.getIn().getType());

SmallVector<OpFoldResult> offsets(newResultTy.getShape().size(),
rewriter.getIndexAttr(0));
offsets.back() = rewriter.getIndexAttr(newInTy.getShape().back() -
newResultTy.getShape().back());
SmallVector<OpFoldResult> sizes(newResultTy.getShape().size());
sizes.back() = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> strides(newResultTy.getShape().size(),
rewriter.getIndexAttr(1));

auto resOp = rewriter.create<tensor::ExtractSliceOp>(
op->getLoc(), adaptor.getIn(), offsets, sizes, strides);
rewriter.replaceOp(op, resOp);

return success();
}
};
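
A limb-level sketch of what truncation means under this encoding (editor's illustration, not HEIR code). It assumes the converted tensor stores chunks least-significant first; the diff does not state the limb order, so the slice direction below is an assumption.

```cpp
// Truncation at the limb level: keep only the least-significant chunks,
// mirroring what the tensor::ExtractSliceOp above is meant to select.
// Assumes an LSB-first limb layout (illustrative assumption).
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<uint8_t> truncLimbs(const std::vector<uint8_t> &in,
                                size_t outLimbs) {
  return {in.begin(), in.begin() + outLimbs};
}

int main() {
  // 0x0A0B0C0D as 4 x 8-bit limbs, least-significant first.
  std::vector<uint8_t> limbs = {0x0D, 0x0C, 0x0B, 0x0A};
  auto low = truncLimbs(limbs, 2);  // 32b -> 16b keeps 0x0C0D
  assert(low.size() == 2 && low[0] == 0x0D && low[1] == 0x0C);
  return 0;
}
```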

template <typename ArithExtOp>
struct ConvertQuartExt final : OpConversionPattern<ArithExtOp> {
using OpConversionPattern<ArithExtOp>::OpConversionPattern;
@@ -274,23 +303,21 @@ struct ConvertQuartExt final : OpConversionPattern<ArithExtOp> {
auto resultChunks = newResultTy.getShape().back();
auto inChunks = newInTy.getShape().back();

-    if (resultChunks > inChunks) {
-      auto paddingFactor = resultChunks - inChunks;
-
-      SmallVector<OpFoldResult, 1> low, high;
-      low.push_back(rewriter.getIndexAttr(0));
-      high.push_back(rewriter.getIndexAttr(paddingFactor));
-
-      auto padValue = createTrivialOpMaxWidth(b, 0);
-
-      auto resultVec = b.create<tensor::PadOp>(newResultTy, adaptor.getIn(),
-                                               low, high, padValue,
-                                               /*nofold=*/true);
-
-      rewriter.replaceOp(op, resultVec);
-      return success();
-    }
-    return failure();
+    // By the definition of ExtOp, paddingFactor is always positive
+    auto paddingFactor = resultChunks - inChunks;
+
+    SmallVector<OpFoldResult, 1> low, high;
+    low.push_back(rewriter.getIndexAttr(0));
+    high.push_back(rewriter.getIndexAttr(paddingFactor));
+
+    auto padValue = createTrivialOpMaxWidth(b, 0);
+
+    auto resultVec = b.create<tensor::PadOp>(newResultTy, adaptor.getIn(), low,
+                                             high, padValue,
+                                             /*nofold=*/true);
+
+    rewriter.replaceOp(op, resultVec);
+    return success();
}
};
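
The pad-based path above realizes extension by appending paddingFactor trivial-zero limbs; tensor::PadOp with low = 0 and high = paddingFactor grows only the high end. A plain-integer sketch of the same idea (editor's illustration, again assuming an LSB-first limb layout). Note the pattern is registered for both ExtUIOp and ExtSIOp; the zero-pad sketch models the unsigned case.

```cpp
// Zero extension at the limb level: appending zero limbs at the high end
// leaves the value unchanged, which is what padding with an encrypted
// trivial zero (createTrivialOpMaxWidth(b, 0)) achieves. Illustrative only.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<uint8_t> zextLimbs(std::vector<uint8_t> limbs, size_t outLimbs) {
  limbs.resize(outLimbs, 0);  // high = outLimbs - limbs.size() zero limbs
  return limbs;
}

int main() {
  std::vector<uint8_t> in = {0x34, 0x12};  // 0x1234, least-significant first
  auto out = zextLimbs(in, 4);             // 16b -> 32b gives 0x00001234
  assert(out.size() == 4 && out[2] == 0 && out[3] == 0);
  return 0;
}
```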

@@ -318,14 +345,15 @@ struct ConvertQuartAddI final : OpConversionPattern<mlir::arith::AddIOp> {

// Actual type of the underlying elements; we use half the width.
// Create Constant
-    auto intAttr = IntegerAttr::get(rewriter.getI8Type(), maxIntWidth >> 1);
+    auto shiftAttr =
+        IntegerAttr::get(rewriter.getIndexType(), maxIntWidth >> 1);

auto elemType = convertArithToCGGIType(
IntegerType::get(op->getContext(), maxIntWidth), op->getContext());
auto realTy = convertArithToCGGIType(
IntegerType::get(op->getContext(), maxIntWidth >> 1), op->getContext());

-    auto constantOp = b.create<mlir::arith::ConstantOp>(intAttr);
+    // auto constantOp = b.create<mlir::arith::ConstantOp>(intAttr);

SmallVector<Value> carries;
SmallVector<Value> outputs;
@@ -338,7 +366,8 @@ struct ConvertQuartAddI final : OpConversionPattern<mlir::arith::AddIOp> {

      // Now all the outputs are 16b elements; we want a presentation of 4x8b
if (i != splitLhs.size() - 1) {
-        auto carry = b.create<cggi::ShiftRightOp>(elemType, lowSum, constantOp);
+        auto carry =
+            b.create<cggi::ScalarShiftRightOp>(elemType, lowSum, shiftAttr);
carries.push_back(carry);
}

@@ -356,6 +385,103 @@
}
};
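
Why the shift by maxIntWidth >> 1 yields the carry: each 8-bit limb is held in a 16-bit cleartext, so a limb sum occupies at most 9 bits and its high byte is exactly the carry into the next limb. A plain-integer model follows (editor's sketch; the real pattern stores the carries and adds them into the next limb in the elided code, so the in-loop carry here just shows the same arithmetic compactly).

```cpp
// Limb-wise addition with carry extraction, modeling the AddOp +
// ScalarShiftRightOp(8) sequence above on plain integers. Illustrative only.
#include <array>
#include <cassert>
#include <cstdint>

std::array<uint16_t, 4> addLimbs(const std::array<uint16_t, 4> &a,
                                 const std::array<uint16_t, 4> &b) {
  std::array<uint16_t, 4> out{};
  uint16_t carry = 0;
  for (int i = 0; i < 4; ++i) {
    uint16_t sum = a[i] + b[i] + carry;  // cggi::AddOp (plus incoming carry)
    carry = sum >> 8;                    // cggi::ScalarShiftRightOp by 8
    out[i] = sum & 0xFF;                 // low byte: the 4x8b presentation
  }
  return out;
}

int main() {
  // 0x00FF + 0x0001 = 0x0100: the carry must ripple into the second limb.
  auto r = addLimbs({{0xFF, 0, 0, 0}}, {{0x01, 0, 0, 0}});
  assert(r[0] == 0x00 && r[1] == 0x01);
  return 0;
}
```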

// Implemented using the Karatsuba algorithm
// https://en.wikipedia.org/wiki/Karatsuba_algorithm#Algorithm
struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::MulIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
ImplicitLocOpBuilder b(loc, rewriter);

auto newTy =
getTypeConverter()->convertType<RankedTensorType>(op.getType());
if (!newTy)
return rewriter.notifyMatchFailure(
loc, llvm::formatv("unsupported type: {0}", op.getType()));
if (newTy.getShape().back() != 4)
return rewriter.notifyMatchFailure(
loc, llvm::formatv("Mul only support 4 split elements. Shape: {0}",
newTy));

auto elemTy = convertArithToCGGIType(
IntegerType::get(op->getContext(), maxIntWidth), op->getContext());
auto realTy = convertArithToCGGIType(
IntegerType::get(op->getContext(), maxIntWidth >> 1), op->getContext());

// Create Constant
auto shiftAttr =
rewriter.getIntegerAttr(b.getIndexType(), maxIntWidth >> 1);

SmallVector<Value> splitLhs =
extractLastDimHalves(rewriter, loc, adaptor.getLhs());
SmallVector<Value> splitRhs =
extractLastDimHalves(rewriter, loc, adaptor.getRhs());

// TODO: Implement the real Karatsuba algorithm for 4x4 multiplication.
// First part of Karatsuba algorithm
auto z00 = b.create<cggi::MulOp>(splitLhs[0], splitRhs[0]);
auto z02 = b.create<cggi::MulOp>(splitLhs[1], splitRhs[1]);
auto z01_p1 = b.create<cggi::AddOp>(splitLhs[0], splitLhs[1]);
auto z01_p2 = b.create<cggi::AddOp>(splitRhs[0], splitRhs[1]);
auto z01_m = b.create<cggi::MulOp>(z01_p1, z01_p2);
auto z01_s = b.create<cggi::SubOp>(z01_m, z00);
auto z01 = b.create<cggi::SubOp>(z01_s, z02);

// Second part I of Karatsuba algorithm
auto z1a0 = b.create<cggi::MulOp>(splitLhs[0], splitRhs[2]);
auto z1a2 = b.create<cggi::MulOp>(splitLhs[1], splitRhs[3]);
auto z1a1_p1 = b.create<cggi::AddOp>(splitLhs[0], splitLhs[1]);
auto z1a1_p2 = b.create<cggi::AddOp>(splitRhs[2], splitRhs[3]);
auto z1a1_m = b.create<cggi::MulOp>(z1a1_p1, z1a1_p2);
auto z1a1_s = b.create<cggi::SubOp>(z1a1_m, z1a0);
auto z1a1 = b.create<cggi::SubOp>(z1a1_s, z1a2);

// Second part II of Karatsuba algorithm
auto z1b0 = b.create<cggi::MulOp>(splitLhs[2], splitRhs[0]);
auto z1b2 = b.create<cggi::MulOp>(splitLhs[3], splitRhs[1]);
auto z1b1_p1 = b.create<cggi::AddOp>(splitLhs[2], splitLhs[3]);
auto z1b1_p2 = b.create<cggi::AddOp>(splitRhs[0], splitRhs[1]);
auto z1b1_m = b.create<cggi::MulOp>(z1b1_p1, z1b1_p2);
auto z1b1_s = b.create<cggi::SubOp>(z1b1_m, z1b0);
auto z1b1 = b.create<cggi::SubOp>(z1b1_s, z1b2);

auto out2Kara = b.create<cggi::AddOp>(z1a0, z1b0);
auto out2Carry = b.create<cggi::AddOp>(out2Kara, z02);
auto out3Carry = b.create<cggi::AddOp>(z1a1, z1b1);

    // Outputs are now all 16b elements; we want a presentation of 4x8b
auto output0Lsb = b.create<cggi::CastOp>(realTy, z00);
auto output0LsbHigh = b.create<cggi::CastOp>(elemTy, output0Lsb);
auto output0Msb =
b.create<cggi::ScalarShiftRightOp>(elemTy, z00, shiftAttr);

auto output1Lsb = b.create<cggi::CastOp>(realTy, z01);
auto output1LsbHigh = b.create<cggi::CastOp>(elemTy, output1Lsb);
auto output1Msb =
b.create<cggi::ScalarShiftRightOp>(elemTy, z01, shiftAttr);

auto output2Lsb = b.create<cggi::CastOp>(realTy, out2Carry);
auto output2LsbHigh = b.create<cggi::CastOp>(elemTy, output2Lsb);
auto output2Msb =
b.create<cggi::ScalarShiftRightOp>(elemTy, out2Carry, shiftAttr);

auto output3Lsb = b.create<cggi::CastOp>(realTy, out3Carry);
auto output3LsbHigh = b.create<cggi::CastOp>(elemTy, output3Lsb);

auto output1 = b.create<cggi::AddOp>(output1LsbHigh, output0Msb);
auto output2 = b.create<cggi::AddOp>(output2LsbHigh, output1Msb);
auto output3 = b.create<cggi::AddOp>(output3LsbHigh, output2Msb);

Value resultVec = constructResultTensor(
rewriter, loc, newTy, {output0LsbHigh, output1, output2, output3});
rewriter.replaceOp(op, resultVec);
return success();
}
};
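
The z00/z01/z02 triple above is the textbook two-limb Karatsuba identity, applied once per 16-bit half; as the TODO notes, a full 4-limb Karatsuba is left for later. A quick plain-integer check of the identity (editor's sketch):

```cpp
// Two-limb Karatsuba: (a0 + B*a1)(b0 + B*b1) = z00 + B*z01 + B^2*z02 with
// z01 = (a0 + a1)(b0 + b1) - z00 - z02, trading one multiply for three adds.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t B = 1u << 8;  // limb base: maxIntWidth >> 1 = 8 bits
  uint32_t a0 = 0x34, a1 = 0x12, b0 = 0x78, b1 = 0x56;  // 0x1234 * 0x5678

  uint32_t z00 = a0 * b0;                            // cggi::MulOp
  uint32_t z02 = a1 * b1;                            // cggi::MulOp
  uint32_t z01 = (a0 + a1) * (b0 + b1) - z00 - z02;  // z01_m - z00 - z02

  assert(z00 + B * z01 + B * B * z02 == 0x1234u * 0x5678u);
  return 0;
}
```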

struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase<ArithToCGGIQuart> {
void runOnOperation() override {
MLIRContext *context = &getContext();
@@ -386,28 +512,29 @@ struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase<ArithToCGGIQuart> {

target.addDynamicallyLegalOp<mlir::arith::ConstantOp>(
[](mlir::arith::ConstantOp op) {
          // Allow use of constant if it is used to denote the size of a shift
-          bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) {
-            return isa<cggi::ShiftRightOp>(user);
-          });
-          return (isa<IndexType>(op.getValue().getType()) || (usedByShift));
+          return isa<IndexType>(op.getValue().getType());
});

-    patterns.add<
-        ConvertQuartConstantOp, ConvertQuartExt<mlir::arith::ExtUIOp>,
-        ConvertQuartExt<mlir::arith::ExtSIOp>, ConvertQuartAddI,
-        ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
-        ConvertAny<memref::DeallocOp>, ConvertAny<memref::StoreOp>,
-        ConvertAny<memref::SubViewOp>, ConvertAny<memref::CopyOp>,
-        ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::ExtractOp>,
-        ConvertAny<affine::AffineStoreOp>, ConvertAny<affine::AffineLoadOp>>(
-        typeConverter, context);
+    patterns
+        .add<ConvertQuartConstantOp, ConvertQuartExt<mlir::arith::ExtUIOp>,
+             ConvertQuartExt<mlir::arith::ExtSIOp>, ConvertQuartAddI,
+             ConvertQuartMulI, ConvertAny<memref::LoadOp>,
+             ConvertAny<memref::AllocOp>, ConvertAny<memref::DeallocOp>,
+             ConvertAny<memref::StoreOp>, ConvertAny<memref::SubViewOp>,
+             ConvertAny<memref::CopyOp>, ConvertAny<tensor::FromElementsOp>,
+             ConvertAny<tensor::ExtractOp>, ConvertAny<affine::AffineStoreOp>,
+             ConvertAny<affine::AffineLoadOp>>(typeConverter, context);

addStructuralConversionPatterns(typeConverter, patterns, target);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
return signalPassFailure();
}

    // Remove the unnecessary tensor ops between each converted arith operation.
OpPassManager pipeline("builtin.module");
pipeline.addPass(createCSEPass());
(void)runPipeline(pipeline, getOperation());
}
};
