Skip to content

Commit

Permalink
LICM pass zero trip count loop handling; zero trip count loop removal in
Browse files Browse the repository at this point in the history
-simplify-affine-structures

- addresses tensorflow#194

- change name of tablegen auto-generated internal helper for op
  interfaces to avoid potential conflicts with methods of same
  name in dialect namespaces. (addresses tensorflow#197)

Signed-off-by: Uday Bondhugula <[email protected]>
  • Loading branch information
bondhugula committed Oct 26, 2019
1 parent 705a743 commit 95526b1
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 39 deletions.
2 changes: 2 additions & 0 deletions include/mlir/Dialect/LoopOps/LoopOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ void ensureLoopTerminator(Region &region, Builder &builder, Location loc);
/// not an induction variable, then return nullptr.
ForOp getForInductionVarOwner(Value *val);

/// Returns the trip count of the loop if it's a constant, None otherwise.
Optional<uint64_t> getConstantTripCount(ForOp forOp);
} // end namespace loop
} // end namespace mlir
#endif // MLIR_LOOPOPS_OPS_H_
1 change: 1 addition & 0 deletions include/mlir/Transforms/LoopLikeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#ifndef MLIR_TRANSFORMS_LOOPLIKEINTERFACE_H_
#define MLIR_TRANSFORMS_LOOPLIKEINTERFACE_H_

#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
Expand Down
4 changes: 4 additions & 0 deletions include/mlir/Transforms/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
}],
"LogicalResult", "moveOutOfLoop", (ins "ArrayRef<Operation *>":$ops)
>,
InterfaceMethod<"Get the trip count if it is a constant.",
"llvm::Optional<uint64_t>", "getConstantTripCount", (ins), [{
return getConstantTripCount(op);
}]>,
];
}

Expand Down
27 changes: 26 additions & 1 deletion lib/Dialect/LoopOps/LoopOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ bool ForOp::isDefinedOutsideOfLoop(Value *value) {

LogicalResult ForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
for (auto *op : ops)
op->moveBefore(this->getOperation());
op->moveBefore(*this);
return success();
}

Expand All @@ -153,6 +153,31 @@ ForOp mlir::loop::getForInductionVarOwner(Value *val) {
return dyn_cast_or_null<ForOp>(containingInst);
}

Optional<uint64_t> mlir::loop::getConstantTripCount(ForOp forOp) {
Value *lb = forOp.lowerBound();
Value *ub = forOp.upperBound();

if (lb == ub)
return 0;

IntegerAttr lbCst, ubCst, step;
if (!matchPattern(lb, m_Constant(&lbCst)) ||
!matchPattern(ub, m_Constant(&ubCst)))
return llvm::None;

int64_t lbConst = lbCst.getValue().getSExtValue();
int64_t ubConst = ubCst.getValue().getSExtValue();
if (ubConst - lbConst <= 0)
return 0;

if (!matchPattern(forOp.step(), m_Constant(&step)))
return llvm::None;

// Step is guaranteed to be positive.
int64_t stepConst = step.getValue().getSExtValue();
return llvm::divideCeil(ubConst - lbConst, stepConst);
}

//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 5 additions & 5 deletions lib/Transforms/AffineLoopInvariantCodeMotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "licm"
#define DEBUG_TYPE "affine-licm"

using namespace mlir;

namespace {

/// Loop invariant code motion (LICM) pass.
/// Affine loop invariant code motion (LICM) pass.
/// TODO: This pass should be removed once the new LICM pass can handle its
/// uses.
/// TODO(asabne) : The pass is missing zero-trip tests.
/// TODO(asabne) : Check for the presence of side effects before hoisting.
/// TODO: This code should be removed once the new LICM pass can handle its
/// uses.
struct LoopInvariantCodeMotion : public FunctionPass<LoopInvariantCodeMotion> {
void runOnFunction() override;
void runOnAffineForOp(AffineForOp forOp);
Expand Down Expand Up @@ -245,4 +245,4 @@ mlir::createAffineLoopInvariantCodeMotionPass() {

static PassRegistration<LoopInvariantCodeMotion>
pass("affine-loop-invariant-code-motion",
"Hoist loop invariant instructions outside of the loop");
"Hoist loop invariant operations outside of the loop");
26 changes: 14 additions & 12 deletions lib/Transforms/LoopInvariantCodeMotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,15 @@ static bool canBeHoisted(Operation *op,
auto thisOpIsSideEffecting = sideEffecting;
if (thisOpIsSideEffecting != SideEffecting::Never) {
thisOpIsSideEffecting = interface.isSideEffecting(op);
// If the op always has sideeffects, we cannot hoist.
// If the op always has side effects, we cannot hoist.
if (thisOpIsSideEffecting == SideEffecting::Always)
return false;
}
// Recurse into the regions for this op and check whether the contained ops
// can be hoisted.
for (auto &region : op->getRegions()) {
for (auto &block : region.getBlocks()) {
for (auto &innerOp : block) {
if (innerOp.isKnownTerminator())
continue;
for (auto &innerOp : block.without_terminator()) {
if (!canBeHoisted(&innerOp, definedOutside, thisOpIsSideEffecting,
interface))
return false;
Expand Down Expand Up @@ -112,7 +110,7 @@ static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike,
}
}

// For all instructions that we found to be invariant, move outside of the
// For all operations that we found to be invariant, move outside of the
// loop.
auto result = looplike.moveOutOfLoop(opsToMove);
LLVM_DEBUG(looplike.print(llvm::dbgs() << "Modified loop\n"));
Expand All @@ -126,12 +124,16 @@ void LoopInvariantCodeMotion::runOnOperation() {
// Walk through all loops in a function in innermost-loop-first order. This
// way, we first LICM from the inner loop, and place the ops in
// the outer loop, which in turn can be further LICM'ed.
getOperation()->walk([&](Operation *op) {
if (auto looplike = dyn_cast<LoopLikeOpInterface>(op)) {
LLVM_DEBUG(op->print(llvm::dbgs() << "\nOriginal loop\n"));
if (failed(moveLoopInvariantCode(looplike, interface)))
signalPassFailure();
}
getOperation()->walk([&](LoopLikeOpInterface loopLikeOp) {
// Skip zero trip count loops. For unknown trip counts, we still move
// invariant code since it is side-effect free, and in general profitable.
// TODO: when necessary, we could only move when the trip count is
// guaranteed to be at least one.
if (loopLikeOp.getConstantTripCount() == uint64_t(0))
return;
LLVM_DEBUG(loopLikeOp.print(llvm::dbgs() << "\nOriginal loop\n"));
if (failed(moveLoopInvariantCode(loopLikeOp, interface)))
signalPassFailure();
});
}

Expand All @@ -146,4 +148,4 @@ std::unique_ptr<Pass> mlir::createLoopInvariantCodeMotionPass() {

static PassRegistration<LoopInvariantCodeMotion>
pass("loop-invariant-code-motion",
"Hoist loop invariant instructions outside of the loop");
"Hoist loop invariant operations outside of the loop");
21 changes: 15 additions & 6 deletions lib/Transforms/SimplifyAffineStructures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopLikeInterface.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"

Expand Down Expand Up @@ -93,12 +94,12 @@ std::unique_ptr<OpPassBase<FuncOp>> mlir::createSimplifyAffineStructuresPass() {
void SimplifyAffineStructures::runOnFunction() {
auto func = getFunction();
simplifiedAttributes.clear();
func.walk([&](Operation *opInst) {
for (auto attr : opInst->getAttrs()) {
func.walk([&](Operation *op) {
for (auto attr : op->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
simplifyAndUpdateAttribute(opInst, attr.first, mapAttr);
simplifyAndUpdateAttribute(op, attr.first, mapAttr);
else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>())
simplifyAndUpdateAttribute(opInst, attr.first, setAttr);
simplifyAndUpdateAttribute(op, attr.first, setAttr);
}
});

Expand All @@ -110,8 +111,16 @@ void SimplifyAffineStructures::runOnFunction() {
for (auto allocOp : allocOps) {
normalizeMemRef(allocOp);
}

// Remove zero trip count loops.
// TODO: this could be moved to a more appropriate place.
func.walk([&](LoopLikeOpInterface loopOp) {
if (loopOp.getConstantTripCount() == uint64_t(0))
loopOp.erase();
});
}

static PassRegistration<SimplifyAffineStructures>
pass("simplify-affine-structures",
"Simplify affine expressions in maps/sets and normalize memrefs");
pass("simplify-affine-structures", "Simplify expressions in afine map/set "
"attributes, normalize memrefs, remove "
"zero trip-count loops");
14 changes: 5 additions & 9 deletions test/Transforms/affine-loop-invariant-code-motion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,23 @@ func @nested_loops_both_having_invariant_code() {
// CHECK-NEXT: %cst_0 = constant 8.000000e+00 : f32
// CHECK-NEXT: %1 = addf %cst, %cst_0 : f32
// CHECK-NEXT: affine.for %arg0 = 0 to 10 {
// CHECK-NEXT: affine.store %1, %0[%arg0] : memref<10xf32>
// CHECK-NEXT: affine.store %1, %0[%arg0] : memref<10xf32>

return
}

// The store-load forwarding can see through affine apply's since it relies on
// dependence information.
// CHECK-LABEL: func @store_affine_apply
func @store_affine_apply() -> memref<10xf32> {
// CHECK-LABEL: func @store_affine_for
func @store_affine_for() -> memref<10xf32> {
%cf7 = constant 7.0 : f32
%m = alloc() : memref<10xf32>
affine.for %arg0 = 0 to 10 {
%t0 = affine.apply (d1) -> (d1 + 1)(%arg0)
affine.store %cf7, %m[%t0] : memref<10xf32>
affine.store %cf7, %m[%arg0 + 1] : memref<10xf32>
}
return %m : memref<10xf32>
// CHECK: %cst = constant 7.000000e+00 : f32
// CHECK-NEXT: %0 = alloc() : memref<10xf32>
// CHECK-NEXT: affine.for %arg0 = 0 to 10 {
// CHECK-NEXT: %1 = affine.apply #map3(%arg0)
// CHECK-NEXT: affine.store %cst, %0[%1] : memref<10xf32>
// CHECK-NEXT: affine.store %cst, %0[%arg0 + 1] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %0 : memref<10xf32>
}
Expand Down
50 changes: 47 additions & 3 deletions test/Transforms/loop-invariant-code-motion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,54 @@ func @invariant_affine_nested_if_else() {
// CHECK-NEXT: }
// CHECK-NEXT: }

return
}

// CHECK-LABEL: func @zero_trip_count_affine
func @zero_trip_count_affine() {
%m = alloc() : memref<10xf32>
%cf7 = constant 7.0 : f32
%N = constant 0 : index

affine.for %arg0 = 0 to %N {
affine.for %arg1 = 0 to 10 {
%v0 = addf %cf7, %cf7 : f32
}
}
// CHECK: alloc() : memref<10xf32>
// CHECK-NEXT: %cst = constant 7.000000e+00 : f32
// CHECK-NEXT: %c0 = constant 0 : index
// CHECK-NEXT: affine.for
// CHECK-NEXT: addf
// CHECK-NEXT: affine.for

return
}

// CHECK-LABEL: func @zero_trip_count_loop
func @zero_trip_count_loop(%N : index) {
%m = alloc() : memref<10xf32>
%cf7 = constant 7.0 : f32
%c1 = constant 1 : index
%c5 = constant 5 : index

loop.for %i = %N to %N step %c1 {
loop.for %j = %c5 to %c5 step %c1 {
addf %cf7, %cf7 : f32
}
}
// CHECK: alloc() : memref<10xf32>
// CHECK-NEXT: %cst = constant 7.000000e+00 : f32
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %c5 = constant 5 : index
// CHECK-NEXT: loop.for
// CHECK-NEXT: loop.for
// CHECK-NEXT: addf

return
}

// CHECK-LABEL: func @invariant_loop_dialect
func @invariant_loop_dialect() {
%ci0 = constant 0 : index
%ci10 = constant 10 : index
Expand All @@ -211,7 +255,7 @@ func @invariant_loop_dialect() {
%cf7 = constant 7.0 : f32
%cf8 = constant 8.0 : f32
loop.for %arg0 = %ci0 to %ci10 step %ci1 {
loop.for %arg1 = %ci0 to %ci10 step %ci1 {
loop.for %arg1 = %ci0 to %ci1 step %ci10 {
%v0 = addf %cf7, %cf8 : f32
}
}
Expand All @@ -237,8 +281,8 @@ func @variant_loop_dialect() {

// CHECK: %0 = alloc() : memref<10xf32>
// CHECK-NEXT: loop.for
// CHECK-NEXT: loop.for
// CHECK-NEXT: addi
// CHECK-NEXT: loop.for
// CHECK-NEXT: addi

return
}
19 changes: 19 additions & 0 deletions test/Transforms/simplify-affine-structures.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,22 @@ func @test_empty_set(%N : index) {

return
}

// CHECK-LABEL: func @zero_trip_count_loops
func @zero_trip_count_loops(%N : index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c-1 = constant -1 : index
%M = affine.apply (d0) -> ((2*d0 + 4) mod 2)(%N)
affine.for %i = 0 to %M {
}
affine.for %i = 0 to -1 {
}
loop.for %i = %M to %M step %c1 {
}
loop.for %i = %c0 to %c-1 step %N {
}
// All loops above should disappear.
// CHECK-NOT: loop.for
return
}
11 changes: 8 additions & 3 deletions tools/mlir-tblgen/OpInterfacesGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ using mlir::tblgen::OpInterfaceMethod;
// beginning of the argument list.
static void emitMethodNameAndArgs(const OpInterfaceMethod &method,
raw_ostream &os, bool addOperationArg) {
os << method.getName() << '(';
// Whenever an operation argument is added, suffix helper method name with an
// underscore to avoid conflicts with free functions of same name on the
// concrete ops using this interface.
os << method.getName() << (addOperationArg ? "_(" : "(");
if (addOperationArg)
os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", ");
interleaveComma(method.getArguments(), os,
Expand All @@ -64,9 +67,11 @@ static void emitInterfaceDef(OpInterface &interface, raw_ostream &os) {
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);

// Forward to the method on the concrete operation type.
os << " {\n return getImpl()->" << method.getName() << '(';
os << " {\n return getImpl()->" << method.getName();
if (!method.isStatic())
os << "getOperation()" << (method.arg_empty() ? "" : ", ");
os << "_(getOperation()" << (method.arg_empty() ? "" : ", ");
else
os << "(";
interleaveComma(
method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) { os << arg.name; });
Expand Down

0 comments on commit 95526b1

Please sign in to comment.