Skip to content

Commit

Permalink
[TapirUtils] Perform basic updates to TaskInfo analysis when serializ…
Browse files Browse the repository at this point in the history
…ing detaches, to support serializing nested tasks.
  • Loading branch information
neboat committed Nov 19, 2024
1 parent ddca2cf commit 8074d9d
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 28 deletions.
8 changes: 5 additions & 3 deletions llvm/include/llvm/Transforms/Utils/TapirUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ bool MoveStaticAllocasInBlock(BasicBlock *Entry, BasicBlock *Block,

/// Inline any taskframe.resume markers associated with the given taskframe. If
/// \p DT is provided, then it will be updated to reflect the CFG changes.
void InlineTaskFrameResumes(Value *TaskFrame, DominatorTree *DT = nullptr);
void InlineTaskFrameResumes(Value *TaskFrame, DominatorTree *DT = nullptr,
TaskInfo *TI = nullptr);

/// Clone exception-handling blocks EHBlocksToClone, with predecessors
/// EHBlockPreds in a given task. Updates EHBlockPreds to point at the cloned
Expand Down Expand Up @@ -131,7 +132,8 @@ void SerializeDetach(DetachInst *DI, BasicBlock *ParentEntry,
SmallPtrSetImpl<LandingPadInst *> *InlinedLPads,
SmallVectorImpl<Instruction *> *DetachedRethrows,
bool ReplaceWithTaskFrame = false,
DominatorTree *DT = nullptr, LoopInfo *LI = nullptr);
DominatorTree *DT = nullptr, TaskInfo *TI = nullptr,
LoopInfo *LI = nullptr);

/// Analyze a task T for serialization. Gets the reattaches, landing pads, and
/// detached rethrows that need special handling during serialization.
Expand All @@ -145,7 +147,7 @@ void AnalyzeTaskForSerialization(
/// Serialize the detach DI that spawns task T. If \p DT is provided, then it
/// will be updated to reflect the CFG changes.
void SerializeDetach(DetachInst *DI, Task *T, bool ReplaceWithTaskFrame = false,
DominatorTree *DT = nullptr);
DominatorTree *DT = nullptr, TaskInfo *TI = nullptr);

/// Get the entry basic block to the detached context that contains
/// the specified block.
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/Transforms/Utils/TaskSimplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TaskSimplifyPass : public PassInfoMixin<TaskSimplifyPass> {
bool simplifySyncs(Task *T, MaybeParallelTasks &MPTasks);

/// Simplify the specified task T.
bool simplifyTask(Task *T);
bool simplifyTask(Task *T, TaskInfo &TI, DominatorTree &DT);

/// Simplify the taskframes analyzed by TapirTaskInfo TI.
bool simplifyTaskFrames(TaskInfo &TI, DominatorTree &DT);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Tapir/LoopStripMine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ Loop *llvm::StripMineLoop(Loop *L, unsigned Count, bool AllowExpensiveTripCount,
SerializeDetach(ClonedDI, ParentEntry, EHCont, EHContLPadVal,
ClonedReattaches, &ClonedEHBlocks, &ClonedEHBlockPreds,
&ClonedInlinedLPads, &ClonedDetachedRethrows,
NeedToInsertTaskFrame, DT, LI);
NeedToInsertTaskFrame, DT, nullptr, LI);
}

// Detach the stripmined loop.
Expand Down
51 changes: 34 additions & 17 deletions llvm/lib/Transforms/Utils/TapirUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,11 +387,16 @@ class LandingPadInliningInfo {

/// Dominator tree to update.
DominatorTree *DT = nullptr;

/// TaskInfo to update.
TaskInfo *TI = nullptr;

public:
LandingPadInliningInfo(DetachInst *DI, BasicBlock *EHContinue,
Value *LPadValInEHContinue,
DominatorTree *DT = nullptr)
: OuterResumeDest(EHContinue), SpawnerLPad(LPadValInEHContinue), DT(DT) {
DominatorTree *DT = nullptr, TaskInfo *TI = nullptr)
: OuterResumeDest(EHContinue), SpawnerLPad(LPadValInEHContinue), DT(DT),
TI(TI) {
// Find the predecessor block of OuterResumeDest.
BasicBlock *DetachBB = DI->getParent();
BasicBlock *DetachUnwind = DI->getUnwindDest();
Expand All @@ -414,9 +419,9 @@ class LandingPadInliningInfo {
}

LandingPadInliningInfo(InvokeInst *TaskFrameResume,
DominatorTree *DT = nullptr)
DominatorTree *DT = nullptr, TaskInfo *TI = nullptr)
: OuterResumeDest(TaskFrameResume->getUnwindDest()),
SpawnerLPad(TaskFrameResume->getLandingPadInst()), DT(DT) {
SpawnerLPad(TaskFrameResume->getLandingPadInst()), DT(DT), TI(TI) {
// If there are PHI nodes in the unwind destination block, we need to keep
// track of which values came into them from the detach before removing the
// edge from this block.
Expand Down Expand Up @@ -484,6 +489,8 @@ BasicBlock *LandingPadInliningInfo::getInnerResumeDest() {
for (DomTreeNode *I : Children)
DT->changeImmediateDominator(I, NewNode);
}
if (TI)
TI->addBlockToSpindle(*InnerResumeDest, TI->getSpindleFor(OuterResumeDest));

// The number of incoming edges we expect to the inner landing pad.
const unsigned PHICapacity = 2;
Expand Down Expand Up @@ -571,11 +578,15 @@ void LandingPadInliningInfo::forwardTaskResume(InvokeInst *TR) {
if (NormalDest) {
for (BasicBlock *Succ : successors(NormalDest))
maybeRemovePredecessor(Succ, NormalDest);
if (TI)
TI->removeBlock(*NormalDest);
NormalDest->eraseFromParent();
}
if (UnwindDest) {
for (BasicBlock *Succ : successors(UnwindDest))
maybeRemovePredecessor(Succ, UnwindDest);
if (TI)
TI->removeBlock(*UnwindDest);
UnwindDest->eraseFromParent();
}
}
Expand All @@ -584,8 +595,8 @@ static void handleDetachedLandingPads(
DetachInst *DI, BasicBlock *EHContinue, Value *LPadValInEHContinue,
SmallPtrSetImpl<LandingPadInst *> &InlinedLPads,
SmallVectorImpl<Instruction *> &DetachedRethrows,
DominatorTree *DT = nullptr) {
LandingPadInliningInfo DetUnwind(DI, EHContinue, LPadValInEHContinue, DT);
DominatorTree *DT = nullptr, TaskInfo *TI = nullptr) {
LandingPadInliningInfo DetUnwind(DI, EHContinue, LPadValInEHContinue, DT, TI);

// Append the clauses from the outer landing pad instruction into the inlined
// landing pad instructions.
Expand Down Expand Up @@ -815,13 +826,14 @@ static void getTaskFrameLandingPads(
// Helper method to handle a given taskframe.resume.
static void handleTaskFrameResume(Value *TaskFrame,
Instruction *TaskFrameResume,
DominatorTree *DT = nullptr) {
DominatorTree *DT = nullptr,
TaskInfo *TI = nullptr) {
// Get landingpads to inline.
SmallPtrSet<LandingPadInst *, 1> InlinedLPads;
getTaskFrameLandingPads(TaskFrame, TaskFrameResume, InlinedLPads);

InvokeInst *TFR = cast<InvokeInst>(TaskFrameResume);
LandingPadInliningInfo TFResumeDest(TFR, DT);
LandingPadInliningInfo TFResumeDest(TFR, DT, TI);

// Append the clauses from the outer landing pad instruction into the inlined
// landing pad instructions.
Expand All @@ -839,7 +851,8 @@ static void handleTaskFrameResume(Value *TaskFrame,
TFResumeDest.forwardTaskResume(TFR);
}

void llvm::InlineTaskFrameResumes(Value *TaskFrame, DominatorTree *DT) {
void llvm::InlineTaskFrameResumes(Value *TaskFrame, DominatorTree *DT,
TaskInfo *TI) {
SmallVector<Instruction *, 1> TaskFrameResumes;
// Record all taskframe.resume markers that use TaskFrame.
for (User *U : TaskFrame->users())
Expand All @@ -849,20 +862,20 @@ void llvm::InlineTaskFrameResumes(Value *TaskFrame, DominatorTree *DT) {

// Handle all taskframe.resume markers.
for (Instruction *TFR : TaskFrameResumes)
handleTaskFrameResume(TaskFrame, TFR, DT);
handleTaskFrameResume(TaskFrame, TFR, DT, TI);
}

static void startSerializingTaskFrame(Value *TaskFrame,
SmallVectorImpl<Instruction *> &ToErase,
DominatorTree *DT,
DominatorTree *DT, TaskInfo *TI,
bool PreserveTaskFrame) {
for (User *U : TaskFrame->users())
if (Instruction *UI = dyn_cast<Instruction>(U))
if (isTapirIntrinsic(Intrinsic::taskframe_use, UI))
ToErase.push_back(UI);

if (!PreserveTaskFrame)
InlineTaskFrameResumes(TaskFrame, DT);
InlineTaskFrameResumes(TaskFrame, DT, TI);
}

void llvm::SerializeDetach(DetachInst *DI, BasicBlock *ParentEntry,
Expand All @@ -873,7 +886,9 @@ void llvm::SerializeDetach(DetachInst *DI, BasicBlock *ParentEntry,
SmallPtrSetImpl<LandingPadInst *> *InlinedLPads,
SmallVectorImpl<Instruction *> *DetachedRethrows,
bool ReplaceWithTaskFrame, DominatorTree *DT,
LoopInfo *LI) {
TaskInfo *TI, LoopInfo *LI) {
LLVM_DEBUG(dbgs() << "Serializing detach " << *DI << "\n");

BasicBlock *Spawner = DI->getParent();
BasicBlock *TaskEntry = DI->getDetached();
BasicBlock *Continue = DI->getContinue();
Expand All @@ -885,7 +900,7 @@ void llvm::SerializeDetach(DetachInst *DI, BasicBlock *ParentEntry,
SmallVector<Instruction *, 8> ToErase;
Value *TaskFrame = getTaskFrameUsed(TaskEntry);
if (TaskFrame)
startSerializingTaskFrame(TaskFrame, ToErase, DT, ReplaceWithTaskFrame);
startSerializingTaskFrame(TaskFrame, ToErase, DT, TI, ReplaceWithTaskFrame);

// Clone any EH blocks that need cloning.
if (EHBlocksToClone) {
Expand Down Expand Up @@ -952,7 +967,7 @@ void llvm::SerializeDetach(DetachInst *DI, BasicBlock *ParentEntry,
} else {
// Otherwise, "inline" the detached landingpads.
handleDetachedLandingPads(DI, EHContinue, LPadValInEHContinue,
*InlinedLPads, *DetachedRethrows, DT);
*InlinedLPads, *DetachedRethrows, DT, TI);
}
}

Expand Down Expand Up @@ -1059,7 +1074,7 @@ void llvm::AnalyzeTaskForSerialization(
/// Serialize the detach DI that spawns task T. If provided, the dominator tree
/// DT will be updated to reflect the serialization.
void llvm::SerializeDetach(DetachInst *DI, Task *T, bool ReplaceWithTaskFrame,
DominatorTree *DT) {
DominatorTree *DT, TaskInfo *TI) {
assert(DI && "SerializeDetach given nullptr for detach.");
assert(DI == T->getDetach() && "Task and detach arguments do not match.");
SmallVector<BasicBlock *, 4> EHBlocksToClone;
Expand All @@ -1078,7 +1093,9 @@ void llvm::SerializeDetach(DetachInst *DI, Task *T, bool ReplaceWithTaskFrame,
}
SerializeDetach(DI, T->getParentTask()->getEntry(), EHContinue, LPadVal,
Reattaches, &EHBlocksToClone, &EHBlockPreds, &InlinedLPads,
&DetachedRethrows, ReplaceWithTaskFrame, DT);
&DetachedRethrows, ReplaceWithTaskFrame, DT, TI);
if (TI)
TI->moveSpindlesToParent(T);
}

static bool isCanonicalTaskFrameEnd(const Instruction *TFEnd) {
Expand Down
18 changes: 12 additions & 6 deletions llvm/lib/Transforms/Utils/TaskSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,11 @@ static bool detachImmediatelySyncs(DetachInst *DI) {
return isa<SyncInst>(I);
}

bool llvm::simplifyTask(Task *T) {
bool llvm::simplifyTask(Task *T, TaskInfo &TI, DominatorTree &DT) {
if (T->isRootTask())
return false;

DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
LLVM_DEBUG(dbgs() << "Simplifying task @ " << T->getEntry()->getName()
<< "\n");

Expand All @@ -254,7 +255,8 @@ bool llvm::simplifyTask(Task *T) {
// destination from T's detach.
if (DI->hasUnwindDest()) {
if (!taskCanThrow(T)) {
removeUnwindEdge(DI->getParent());
LLVM_DEBUG(dbgs() << "Removing unwind edge of " << *DI << "\n");
removeUnwindEdge(DI->getParent(), &DTU);
// removeUnwindEdge will invalidate the DI pointer. Get the new DI
// pointer.
DI = T->getDetach();
Expand All @@ -263,13 +265,17 @@ bool llvm::simplifyTask(Task *T) {
}

if (!taskCanReachContinuation(T)) {
LLVM_DEBUG(dbgs() << "Task cannot reach continuation. Serializing " << *DI
<< "\n");
// This optimization assumes that if a task cannot reach its continuation
// then we shouldn't bother spawning it. The task might perform code that
// can reach the unwind destination, however.
SerializeDetach(DI, T, NestedSync);
SerializeDetach(DI, T, NestedSync, &DT, &TI);
Changed = true;
} else if (!PreserveAllSpawns && detachImmediatelySyncs(DI)) {
SerializeDetach(DI, T, NestedSync);
LLVM_DEBUG(dbgs() << "Detach immediately syncs. Serializing " << *DI
<< "\n");
SerializeDetach(DI, T, NestedSync, &DT, &TI);
Changed = true;
}

Expand Down Expand Up @@ -651,7 +657,7 @@ bool TaskSimplify::runOnFunction(Function &F) {

// Simplify each task in the function.
for (Task *T : post_order(TI.getRootTask()))
Changed |= simplifyTask(T);
Changed |= simplifyTask(T, TI, DT);

if (PostCleanupCFG && (Changed | SplitBlocks))
Changed |= simplifyFunctionCFG(F, TTI, nullptr, Options);
Expand Down Expand Up @@ -729,7 +735,7 @@ PreservedAnalyses TaskSimplifyPass::run(Function &F,

// Simplify each task in the function.
for (Task *T : post_order(TI.getRootTask()))
Changed |= simplifyTask(T);
Changed |= simplifyTask(T, TI, DT);

if (PostCleanupCFG && (Changed | SplitBlocks))
Changed |= simplifyFunctionCFG(F, TTI, nullptr, Options);
Expand Down
98 changes: 98 additions & 0 deletions llvm/test/Transforms/Tapir/nested-serialize-detach.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
; Check that nested detaches can be serialized.
;
; RUN: opt < %s -passes="function<eager-inv>(task-simplify)" -S | FileCheck %s
target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
target triple = "arm64-apple-macosx15.0.0"

; Function Attrs: nounwind willreturn memory(argmem: readwrite)
declare token @llvm.syncregion.start() #0

; Function Attrs: willreturn memory(argmem: readwrite)
declare void @llvm.sync.unwind(token) #1

; Function Attrs: willreturn memory(argmem: readwrite)
declare void @llvm.detached.rethrow.sl_p0i32s(token, { ptr, i32 }) #1

; Function Attrs: nounwind willreturn memory(argmem: readwrite)
declare token @llvm.taskframe.create() #0

; CHECK: define void @_ZNK5Graph17pbfs_walk_PennantEP7PennantIiERH3BagIiEjPj()
; CHECK-NEXT: entry:
; CHECK-NOT: detach within
; CHECK: unreachable

define void @_ZNK5Graph17pbfs_walk_PennantEP7PennantIiERH3BagIiEjPj() personality ptr null {
entry:
%syncreg = tail call token @llvm.syncregion.start()
%syncreg45 = tail call token @llvm.syncregion.start()
%0 = tail call token @llvm.tapir.runtime.start()
detach within %syncreg45, label %pfor.body.entry.tf, label %pfor.inc unwind label %lpad59.loopexit

pfor.body.entry.tf: ; preds = %entry
%tf.i = tail call token @llvm.taskframe.create()
%syncreg.i = tail call token @llvm.syncregion.start()
detach within %syncreg.i, label %pfor.cond.i.strpm.detachloop.entry, label %pfor.cond.cleanup.i unwind label %lpad4924.loopexit.split-lp

pfor.cond.i.strpm.detachloop.entry: ; preds = %pfor.body.entry.tf
%syncreg.i.strpm.detachloop = tail call token @llvm.syncregion.start()
detach within none, label %pfor.body.entry.i.strpm.outer.1, label %pfor.inc.i.strpm.outer.1 unwind label %lpad4924.loopexit.strpm

pfor.body.entry.i.strpm.outer.1: ; preds = %pfor.cond.i.strpm.detachloop.entry
invoke void @llvm.detached.rethrow.sl_p0i32s(token none, { ptr, i32 } zeroinitializer)
to label %lpad4924.unreachable unwind label %lpad4924.loopexit.strpm

pfor.inc.i.strpm.outer.1: ; preds = %pfor.cond.i.strpm.detachloop.entry
sync within none, label %pfor.cond.i.strpm.detachloop.reattach.split

pfor.cond.i.strpm.detachloop.reattach.split: ; preds = %pfor.inc.i.strpm.outer.1
reattach within %syncreg.i, label %pfor.cond.cleanup.i

pfor.cond.cleanup.i: ; preds = %pfor.cond.i.strpm.detachloop.reattach.split, %pfor.body.entry.tf
sync within %syncreg.i, label %sync.continue.i

sync.continue.i: ; preds = %pfor.cond.cleanup.i
invoke void @llvm.sync.unwind(token none)
to label %pfor.preattach unwind label %lpad4924.loopexit.split-lp

lpad4924.loopexit.strpm: ; preds = %pfor.body.entry.i.strpm.outer.1, %pfor.cond.i.strpm.detachloop.entry
%lpad.strpm = landingpad { ptr, i32 }
cleanup
invoke void @llvm.detached.rethrow.sl_p0i32s(token %syncreg.i, { ptr, i32 } zeroinitializer)
to label %lpad4924.loopexit.strpm.unreachable unwind label %lpad4924.loopexit.split-lp

lpad4924.loopexit.strpm.unreachable: ; preds = %lpad4924.loopexit.strpm
unreachable

lpad4924.loopexit.split-lp: ; preds = %lpad4924.loopexit.strpm, %sync.continue.i, %pfor.body.entry.tf
%lpad.loopexit.split-lp = landingpad { ptr, i32 }
cleanup
call void @llvm.detached.rethrow.sl_p0i32s(token none, { ptr, i32 } zeroinitializer)
unreachable

lpad4924.unreachable: ; preds = %pfor.body.entry.i.strpm.outer.1
unreachable

pfor.preattach: ; preds = %sync.continue.i
reattach within %syncreg45, label %pfor.inc

pfor.inc: ; preds = %pfor.preattach, %entry
ret void

lpad59.loopexit: ; preds = %entry
%lpad.loopexit28 = landingpad { ptr, i32 }
cleanup
tail call void @llvm.tapir.runtime.end(token %0)
resume { ptr, i32 } zeroinitializer
}

; Function Attrs: nounwind willreturn memory(argmem: readwrite)
declare token @llvm.tapir.runtime.start() #0

; Function Attrs: nounwind willreturn memory(argmem: readwrite)
declare void @llvm.tapir.runtime.end(token) #0

; uselistorder directives
uselistorder ptr null, { 1, 2, 0 }

attributes #0 = { nounwind willreturn memory(argmem: readwrite) }
attributes #1 = { willreturn memory(argmem: readwrite) }

0 comments on commit 8074d9d

Please sign in to comment.