diff --git a/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index c442ee175c1d..ae4b2be59820 100644 --- a/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -29,6 +29,11 @@ include "mlir/SPIRV/SPIRVBase.td" #endif // SPIRV_BASE +#ifdef MLIR_CALLINTERFACES +#else +include "mlir/Analysis/CallInterfaces.td" +#endif // MLIR_CALLINTERFACES + // ----- def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> { @@ -151,7 +156,8 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> { // ----- -def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [InFunctionScope]> { +def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [ + InFunctionScope, DeclareOpInterfaceMethods]> { let summary = "Call a function."; let description = [{ diff --git a/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index ba7c61b3556d..d9e37879db2f 100644 --- a/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -264,7 +264,8 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> { } def SPV_ModuleOp : SPV_Op<"module", - [SingleBlockImplicitTerminator<"ModuleEndOp">, + [IsolatedFromAbove, + SingleBlockImplicitTerminator<"ModuleEndOp">, NativeOpTrait<"SymbolTable">]> { let summary = "The top-level op that defines a SPIR-V module"; diff --git a/lib/Dialect/SPIRV/SPIRVDialect.cpp b/lib/Dialect/SPIRV/SPIRVDialect.cpp index af50c8e4b1a7..96777b18cc90 100644 --- a/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/StandardTypes.h" #include "mlir/Parser.h" #include "mlir/Support/StringExtras.h" +#include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" @@ -34,6 +35,67 @@ namespace spirv { using namespace mlir; using namespace mlir::spirv; +//===----------------------------------------------------------------------===// +// InlinerInterface +//===----------------------------------------------------------------------===// + +/// Returns true if the given region contains spv.Return or spv.ReturnValue ops. +static inline bool containsReturn(Region ®ion) { + return llvm::any_of(region, [](Block &block) { + Operation *terminator = block.getTerminator(); + return isa(terminator) || + isa(terminator); + }); +} + +namespace { +/// This class defines the interface for inlining within the SPIR-V dialect. +struct SPIRVInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + /// Returns true if the given region 'src' can be inlined into the region + /// 'dest' that is attached to an operation registered to the current dialect. + bool isLegalToInline(Operation *op, Region *dest, + BlockAndValueMapping &) const final { + // TODO(antiagainst): Enable inlining structured control flows with return. + if ((isa(op) || isa(op)) && + containsReturn(op->getRegion(0))) + return false; + // TODO(antiagainst): we need to filter OpKill here to avoid inlining it to + // a loop continue construct: + // https://github.com/KhronosGroup/SPIRV-Headers/issues/86 + // However OpKill is fragment shader specific and we don't support it yet. + return true; + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final { + if (auto returnOp = dyn_cast(op)) { + OpBuilder(op).create(op->getLoc(), newDest); + op->erase(); + } else if (auto retValOp = dyn_cast(op)) { + llvm_unreachable("unimplemented spv.ReturnValue in inliner"); + } + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, + ArrayRef valuesToRepl) const final { + // Only spv.ReturnValue needs to be handled here. + auto retValOp = dyn_cast(op); + if (!retValOp) + return; + + // Replace the values directly with the return operands. + assert(valuesToRepl.size() == 1 && + "spv.ReturnValue expected to only handle one result"); + valuesToRepl.front()->replaceAllUsesWith(retValOp.value()); + } +}; +} // namespace + //===----------------------------------------------------------------------===// // SPIR-V Dialect //===----------------------------------------------------------------------===// @@ -48,6 +110,8 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context) #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc" >(); + addInterfaces(); + // Allow unknown operations because SPIR-V is extensible. allowUnknownOperations(); } diff --git a/lib/Dialect/SPIRV/SPIRVOps.cpp b/lib/Dialect/SPIRV/SPIRVOps.cpp index 9662e056cc8e..7e51cc0bbd7c 100644 --- a/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Analysis/CallInterfaces.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" @@ -1199,6 +1200,14 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { return success(); } +CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() { + return getAttrOfType(kCallee); +} + +Operation::operand_range spirv::FunctionCallOp::getArgOperands() { + return arguments(); +} + //===----------------------------------------------------------------------===// // spv.globalVariable //===----------------------------------------------------------------------===// diff --git a/test/Dialect/SPIRV/Transforms/inlining.mlir b/test/Dialect/SPIRV/Transforms/inlining.mlir new file mode 100644 index 000000000000..9837d7babb7b --- /dev/null +++ b/test/Dialect/SPIRV/Transforms/inlining.mlir @@ -0,0 +1,182 @@ +// RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline)' -mlir-disable-inline-simplify | FileCheck %s + +spv.module "Logical" "GLSL450" { + func @callee() { + spv.Return + } + + // CHECK-LABEL: func @calling_single_block_ret_func + func @calling_single_block_ret_func() { + // CHECK-NEXT: spv.Return + spv.FunctionCall @callee() : () -> () + spv.Return + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + func @callee() -> i32 { + %0 = spv.constant 42 : i32 + spv.ReturnValue %0 : i32 + } + + // CHECK-LABEL: func @calling_single_block_retval_func + func @calling_single_block_retval_func() -> i32 { + // CHECK-NEXT: %[[CST:.*]] = spv.constant 42 + %0 = spv.FunctionCall @callee() : () -> (i32) + // CHECK-NEXT: spv.ReturnValue %[[CST]] + spv.ReturnValue %0 : i32 + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + spv.globalVariable @data bind(0, 0) : !spv.ptr [0]>, StorageBuffer> + func @callee() { + %0 = spv._address_of @data : !spv.ptr [0]>, StorageBuffer> + %1 = spv.constant 0: i32 + %2 = spv.AccessChain %0[%1, %1] : !spv.ptr [0]>, StorageBuffer> + spv.Branch ^next + + ^next: + %3 = spv.constant 42: i32 + spv.Store "StorageBuffer" %2, %3 : i32 + spv.Return + } + + // CHECK-LABEL: func @calling_multi_block_ret_func + func @calling_multi_block_ret_func() { + // CHECK-NEXT: spv._address_of + // CHECK-NEXT: spv.constant 0 + // CHECK-NEXT: spv.AccessChain + // CHECK-NEXT: spv.Branch ^bb1 + // CHECK-NEXT: ^bb1: + // CHECK-NEXT: spv.constant + // CHECK-NEXT: spv.Store + // CHECK-NEXT: spv.Branch ^bb2 + spv.FunctionCall @callee() : () -> () + // CHECK-NEXT: ^bb2: + // CHECK-NEXT: spv.Return + spv.Return + } +} + +// TODO: calling_multi_block_retval_func + +// ----- + +spv.module "Logical" "GLSL450" { + func @callee(%cond : i1) -> () { + spv.selection { + spv.BranchConditional %cond, ^then, ^merge + ^then: + spv.Return + ^merge: + spv._merge + } + spv.Return + } + + // CHECK-LABEL: calling_selection_ret_func + func @calling_selection_ret_func() { + %0 = spv.constant true + // CHECK: spv.FunctionCall + spv.FunctionCall @callee(%0) : (i1) -> () + spv.Return + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + func @callee(%cond : i1) -> () { + spv.selection { + spv.BranchConditional %cond, ^then, ^merge + ^then: + spv.Branch ^merge + ^merge: + spv._merge + } + spv.Return + } + + // CHECK-LABEL: calling_selection_no_ret_func + func @calling_selection_no_ret_func() { + // CHECK-NEXT: %[[TRUE:.*]] = spv.constant true + %0 = spv.constant true + // CHECK-NEXT: spv.selection + // CHECK-NEXT: spv.BranchConditional %[[TRUE]], ^bb1, ^bb2 + // CHECK-NEXT: ^bb1: + // CHECK-NEXT: spv.Branch ^bb2 + // CHECK-NEXT: ^bb2: + // CHECK-NEXT: spv._merge + spv.FunctionCall @callee(%0) : (i1) -> () + spv.Return + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + func @callee(%cond : i1) -> () { + spv.loop { + spv.Branch ^header + ^header: + spv.BranchConditional %cond, ^body, ^merge + ^body: + spv.Return + ^continue: + spv.Branch ^header + ^merge: + spv._merge + } + spv.Return + } + + // CHECK-LABEL: calling_loop_ret_func + func @calling_loop_ret_func() { + %0 = spv.constant true + // CHECK: spv.FunctionCall + spv.FunctionCall @callee(%0) : (i1) -> () + spv.Return + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + func @callee(%cond : i1) -> () { + spv.loop { + spv.Branch ^header + ^header: + spv.BranchConditional %cond, ^body, ^merge + ^body: + spv.Branch ^continue + ^continue: + spv.Branch ^header + ^merge: + spv._merge + } + spv.Return + } + + // CHECK-LABEL: calling_loop_no_ret_func + func @calling_loop_no_ret_func() { + // CHECK-NEXT: %[[TRUE:.*]] = spv.constant true + %0 = spv.constant true + // CHECK-NEXT: spv.loop + // CHECK-NEXT: spv.Branch ^bb1 + // CHECK-NEXT: ^bb1: + // CHECK-NEXT: spv.BranchConditional %[[TRUE]], ^bb2, ^bb4 + // CHECK-NEXT: ^bb2: + // CHECK-NEXT: spv.Branch ^bb3 + // CHECK-NEXT: ^bb3: + // CHECK-NEXT: spv.Branch ^bb1 + // CHECK-NEXT: ^bb4: + // CHECK-NEXT: spv._merge + spv.FunctionCall @callee(%0) : (i1) -> () + spv.Return + } +}