Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
skatrak committed Oct 25, 2024
1 parent 0443d95 commit 36f00b9
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 119 deletions.
38 changes: 31 additions & 7 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,13 +717,12 @@ class OpenMPIRBuilder {
Value *TripCount,
const Twine &Name = "loop");

/// Generator for the control flow structure of an OpenMP canonical loop.
/// Calculate the trip count of a canonical loop.
///
/// Instead of a logical iteration space, this allows specifying user-defined
/// loop counter values using increment, upper- and lower bounds. To
/// disambiguate the terminology when counting downwards, instead of lower
/// bounds we use \p Start for the loop counter value in the first body
/// iteration.
/// This allows specifying user-defined loop counter values using increment,
/// upper- and lower bounds. To disambiguate the terminology when counting
/// downwards, instead of lower bounds we use \p Start for the loop counter
/// value in the first body iteration.
///
/// Consider the following limitations:
///
Expand All @@ -747,7 +746,32 @@ class OpenMPIRBuilder {
///
/// for (int i = 0; i < 42; i -= 1u)
///
//
/// \param Loc The insert and source location description.
/// \param Start Value of the loop counter for the first iterations.
/// \param Stop Loop counter values past this will stop the loop.
/// \param Step Loop counter increment after each iteration; negative
/// means counting down.
/// \param IsSigned Whether Start, Stop and Step are signed integers.
/// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
/// counter.
/// \param Name Base name used to derive instruction names.
///
/// \returns The value holding the calculated trip count.
Value *calculateCanonicalLoopTripCount(const LocationDescription &Loc,
Value *Start, Value *Stop, Value *Step,
bool IsSigned, bool InclusiveStop,
const Twine &Name = "loop");

/// Generator for the control flow structure of an OpenMP canonical loop.
///
/// Instead of a logical iteration space, this allows specifying user-defined
/// loop counter values using increment, upper- and lower bounds. To
/// disambiguate the terminology when counting downwards, instead of lower
/// bounds we use \p Start for the loop counter value in the first body
///
/// It calls \see calculateCanonicalLoopTripCount for trip count calculations,
/// so limitations of that method apply here as well.
///
/// \param Loc The insert and source location description.
/// \param BodyGenCB Callback that will generate the loop body code.
/// \param Start Value of the loop counter for the first iterations.
Expand Down
29 changes: 18 additions & 11 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3968,26 +3968,21 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
return CL;
}

CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
InsertPointTy ComputeIP, const Twine &Name) {

Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
bool IsSigned, bool InclusiveStop, const Twine &Name) {
// Consider the following difficulties (assuming 8-bit signed integers):
// * Adding \p Step to the loop counter which passes \p Stop may overflow:
// DO I = 1, 100, 50
/// * A \p Step of INT_MIN cannot not be normalized to a positive direction:
// DO I = 100, 0, -128
updateToLocation(Loc);

// Start, Stop and Step must be of the same integer type.
auto *IndVarTy = cast<IntegerType>(Start->getType());
assert(IndVarTy == Stop->getType() && "Stop type mismatch");
assert(IndVarTy == Step->getType() && "Step type mismatch");

LocationDescription ComputeLoc =
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
updateToLocation(ComputeLoc);

ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
ConstantInt *One = ConstantInt::get(IndVarTy, 1);

Expand Down Expand Up @@ -4026,8 +4021,20 @@ CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
}
Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
"omp_" + Name + ".tripcount");

return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
"omp_" + Name + ".tripcount");
}

CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
InsertPointTy ComputeIP, const Twine &Name) {
LocationDescription ComputeLoc =
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;

Value *TripCount = calculateCanonicalLoopTripCount(
ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);

auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
Builder.restoreIP(CodeGenIP);
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,10 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
Dialect *ompDialect = (*this)->getDialect();
Operation *capturedOp = nullptr;

// Process in pre-order to check operations from outermost to innermost,
// ensuring we only enter the region of an operation if it meets the criteria
// for being captured. We stop the exploration of nested operations as soon as
// we process a region with no operation to be captured.
walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op == *this)
return WalkResult::advance();
Expand All @@ -1792,6 +1796,7 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
return WalkResult::interrupt();

// Don't continue capturing nested operations if we reach an omp.loop_nest.
// Otherwise, process the contents of this operation.
capturedOp = op;
return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
: WalkResult::advance();
Expand Down
152 changes: 51 additions & 101 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3636,31 +3636,40 @@ static uint64_t getTeamsReductionDataSize(mlir::omp::TeamsOp &teamsOp) {
return getReductionDataSize<mlir::omp::TeamsOp>(teamsOp);
}

/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
/// operation and populate output variables with their corresponding host value
/// (i.e. operand evaluated outside of the target region), based on their uses
/// inside of the target region.
///
/// Loop bounds and steps are only optionally populated, if output vectors are
/// provided.
static void
extractHostEvalClauses(OperandRange ops, llvm::ArrayRef<BlockArgument> args,
Value &numThreads, Value &numTeamsLower,
Value &numTeamsUpper, Value &threadLimit,
extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
Value &numTeamsLower, Value &numTeamsUpper,
Value &threadLimit,
llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
llvm::SmallVectorImpl<Value> *steps = nullptr) {
for (auto item : llvm::zip_equal(ops, args)) {
Value op = std::get<0>(item), arg = std::get<1>(item);
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
blockArgIface.getHostEvalBlockArgs())) {
Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);

for (Operation *user : arg.getUsers()) {
for (Operation *user : blockArg.getUsers()) {
llvm::TypeSwitch<Operation *>(user)
.Case([&](omp::TeamsOp teamsOp) {
if (teamsOp.getNumTeamsLower() == arg)
numTeamsLower = op;
else if (teamsOp.getNumTeamsUpper() == arg)
numTeamsUpper = op;
else if (teamsOp.getThreadLimit() == arg)
threadLimit = op;
if (teamsOp.getNumTeamsLower() == blockArg)
numTeamsLower = hostEvalVar;
else if (teamsOp.getNumTeamsUpper() == blockArg)
numTeamsUpper = hostEvalVar;
else if (teamsOp.getThreadLimit() == blockArg)
threadLimit = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
if (parallelOp.getNumThreads() == arg)
numThreads = op;
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
})
Expand All @@ -3670,10 +3679,10 @@ extractHostEvalClauses(OperandRange ops, llvm::ArrayRef<BlockArgument> args,
llvm::SmallVectorImpl<Value> *outBounds) -> bool {
bool found = false;
for (auto [i, lb] : llvm::enumerate(opBounds)) {
if (lb == arg) {
if (lb == blockArg) {
found = true;
if (outBounds)
(*outBounds)[i] = op;
(*outBounds)[i] = hostEvalVar;
}
}
return found;
Expand Down Expand Up @@ -3704,11 +3713,8 @@ static void initTargetDefaultBounds(
llvm::OpenMPIRBuilder::TargetKernelDefaultBounds &bounds,
bool isTargetDevice, bool isGPU) {
Value hostNumThreads, hostNumTeamsLower, hostNumTeamsUpper, hostThreadLimit;
extractHostEvalClauses(targetOp.getHostEvalVars(),
llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp)
.getHostEvalBlockArgs(),
hostNumThreads, hostNumTeamsLower, hostNumTeamsUpper,
hostThreadLimit);
extractHostEvalClauses(targetOp, hostNumThreads, hostNumTeamsLower,
hostNumTeamsUpper, hostThreadLimit);

// TODO Handle constant IF clauses
Operation *innermostCapturedOmpOp = targetOp.getInnermostCapturedOmpOp();
Expand Down Expand Up @@ -3829,10 +3835,7 @@ static void initTargetRuntimeBounds(
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
steps(numLoops);
extractHostEvalClauses(targetOp.getHostEvalVars(),
llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp)
.getHostEvalBlockArgs(),
numThreads, numTeamsLower, numTeamsUpper,
extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

// TODO Handle IF clauses.
Expand All @@ -3856,86 +3859,33 @@ static void initTargetRuntimeBounds(
if (numThreads)
bounds.MaxThreads = moduleTranslation.lookupValue(numThreads);

// To calculate the trip count, we first create a placeholder set of canonical
// loops which we then collapse. Then, we skip execution of the collapsed loop
// and remove the basic blocks that originally defined it. The trip count
// remains available in the entry block.
// TODO: Improve implementation by extracting the logic to calculate the trip
// count of the collapsed loop nest based on the bounds and steps, rather than
// creating and then removing the loop itself.
if (targetOp.isTargetSPMDLoop()) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
llvm::OpenMPIRBuilder::InsertPointTy bodyLastInsertPoint;
auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
llvm::Value *iv) {
bodyLastInsertPoint = ip;
builder.restoreIP(ip);
};

for (unsigned i = 0, e = numLoops; i < e; ++i) {
llvm::Value *lowerBound = moduleTranslation.lookupValue(lowerBounds[i]);
llvm::Value *upperBound = moduleTranslation.lookupValue(upperBounds[i]);
llvm::Value *step = moduleTranslation.lookupValue(steps[i]);

llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
if (i != 0) {
loc = llvm::OpenMPIRBuilder::LocationDescription(bodyLastInsertPoint);
computeIP = loopInfos.front()->getPreheaderIP();
}
loopInfos.push_back(ompBuilder->createCanonicalLoop(
loc, bodyGen, lowerBound, upperBound, step,
/*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP));
}

// Collapse loops. Store the insertion point because LoopInfos may get
// invalidated.
llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
llvm::CanonicalLoopInfo *loopInfo =
ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});

bounds.LoopTripCount = loopInfo->getTripCount();

auto removeLoop = [](llvm::BasicBlock &preheader) {
assert(preheader.hasNPredecessors(0) &&
"loop entry block expected to be unreachable");

// Collect blocks reachable from the loop entry.
llvm::df_iterator_default_set<llvm::BasicBlock *> reachable;
for (llvm::BasicBlock *block :
llvm::depth_first_ext(&preheader, reachable))
(void)block;

// Mark as dead all blocks that are only reachable from the loop entry.
std::vector<llvm::BasicBlock *> deadBlocks;
for (llvm::BasicBlock *block : reachable) {
auto predecessors = llvm::predecessors(block);
bool dead =
llvm::find_if(predecessors, [&reachable](llvm::BasicBlock *pred) {
return !reachable.count(pred);
}) == predecessors.end();

if (dead)
deadBlocks.push_back(block);
bounds.LoopTripCount = nullptr;

// To calculate the trip count, we multiply together the trip counts of
// every collapsed canonical loop. We don't need to create the loop nests
// here, since we're only interested in the trip count.
for (auto [loopLower, loopUpper, loopStep] :
llvm::zip_equal(lowerBounds, upperBounds, steps)) {
llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
llvm::Value *step = moduleTranslation.lookupValue(loopStep);

llvm::OpenMPIRBuilder::LocationDescription loc(builder);
llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
loc, lowerBound, upperBound, step, /*IsSigned=*/true,
loopOp.getLoopInclusive());

if (!bounds.LoopTripCount) {
bounds.LoopTripCount = tripCount;
continue;
}

// Delete the dead blocks.
llvm::DeleteDeadBlocks(deadBlocks);
};

// Skip execution of the canonical loop.
llvm::Instruction *terminator =
loopInfo->getPreheader()->getSinglePredecessor()->getTerminator();
builder.SetInsertPoint(terminator);
builder.CreateBr(afterIP.getBlock());
terminator->eraseFromParent();
builder.restoreIP(afterIP);

// Delete blocks associated to the loop.
removeLoop(*loopInfo->getPreheader());
// TODO: Enable UndefinedSanitizer to diagnose an overflow here.
bounds.LoopTripCount = builder.CreateMul(bounds.LoopTripCount, tripCount,
{}, /*HasNUW=*/true);
}
}
}

Expand Down

0 comments on commit 36f00b9

Please sign in to comment.