Skip to content

Commit

Permalink
feat(compiler): Allow concat with only one operand
Browse files Browse the repository at this point in the history
  • Loading branch information
BourgerieQuentin committed Apr 25, 2024
1 parent e238067 commit dcf9181
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,7 @@ def FHELinalg_ConcatOp : FHELinalg_Op<"concat", [Pure]> {
);

let hasVerifier = 1;
let hasFolder = 1;
}

def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", [Pure, BinaryEintInt, DeclareOpInterfaceMethods<Binary, ["sqMANP"]>]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -573,10 +573,18 @@ static bool sameShapeExceptAxis(llvm::ArrayRef<int64_t> shape1,
return true;
}

/// Avoid multiplication with constant tensor of 1s
OpFoldResult ConcatOp::fold(FoldAdaptor operands) {
if (this->getNumOperands() == 1) {
return this->getOperand(0);
}
return nullptr;
}

mlir::LogicalResult ConcatOp::verify() {
unsigned numOperands = this->getNumOperands();
if (numOperands < 2) {
this->emitOpError() << "should have at least 2 inputs";
if (numOperands < 1) {
this->emitOpError() << "should have at least 1 inputs";
return mlir::failure();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,13 @@
// -----

func.func @main() -> tensor<0x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op should have at least 2 inputs}}
// expected-error @+1 {{'FHELinalg.concat' op should have at least 1 inputs}}
%0 = "FHELinalg.concat"() : () -> tensor<0x!FHE.eint<7>>
return %0 : tensor<0x!FHE.eint<7>>
}

// -----

func.func @main(%x: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op should have at least 2 inputs}}
%0 = "FHELinalg.concat"(%x) : (tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>>
return %0 : tensor<4x!FHE.eint<7>>
}

// -----

func.func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>> {
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equal}}
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,10 @@ func.func @to_unsigned_zero() -> tensor<4x!FHE.eint<7>> {
%1 = "FHELinalg.to_unsigned"(%0) : (tensor<4x!FHE.esint<7>>) -> tensor<4x!FHE.eint<7>>
return %1 : tensor<4x!FHE.eint<7>>
}

// CHECK: func.func @concat_1_operand(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> {
// CHECK-NEXT: return %[[a0]]
func.func @concat_1_operand(%x: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x) : (tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>>
return %0 : tensor<4x!FHE.eint<7>>
}

0 comments on commit dcf9181

Please sign in to comment.