Skip to content

Commit

Permalink
[TailRecursionElimination] Update the set of a return blocks before w…
Browse files Browse the repository at this point in the history
…hich to insert a sync if TRE eliminates that block.
  • Loading branch information
neboat committed Jan 23, 2024
1 parent 14f9bbe commit ef5a15f
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
15 changes: 12 additions & 3 deletions llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ class TailRecursionEliminator {
Instruction *AccumulatorRecursionInstr = nullptr;

// Map from sync region to return blocks to sync for that sync region.
DenseMap<Value *, SmallPtrSet<BasicBlock *, 4>> ReturnBlocksToSync;
DenseMap<Value *, SmallPtrSet<BasicBlock *, 2>> ReturnBlocksToSync;

TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
Expand All @@ -442,6 +442,8 @@ class TailRecursionEliminator {

bool eliminateCall(CallInst *CI);

void RemoveReturnBlockToSync(BasicBlock *RetBlock);

void InsertSyncsIntoReturnBlocks();

void cleanupAndFinalize();
Expand Down Expand Up @@ -847,6 +849,11 @@ getReturnBlocksToSync(BasicBlock *Entry, SyncInst *Sync,
}
}

void TailRecursionEliminator::RemoveReturnBlockToSync(BasicBlock *RetBlock) {
for (auto &ReturnsToSync : ReturnBlocksToSync)
ReturnsToSync.second.erase(RetBlock);
}

static bool hasPrecedingSync(SyncInst *SI) {
// TODO: Save the results from previous calls to hasPrecedingSync, in order to
// speed up multiple calls to this routine for different sync instructions.
Expand Down Expand Up @@ -941,8 +948,10 @@ bool TailRecursionEliminator::processBlock(BasicBlock &BB) {
// because the ret instruction in there is still using a value which
// eliminateCall will attempt to remove. This block can only contain
// instructions that can't have uses, therefore it is safe to remove.
if (pred_empty(Succ))
if (pred_empty(Succ)) {
RemoveReturnBlockToSync(Succ);
DTU.deleteBB(Succ);
}

eliminateCall(CI);
return true;
Expand Down Expand Up @@ -1065,7 +1074,7 @@ bool TailRecursionEliminator::processBlock(BasicBlock &BB) {
// We defer the restoration of syncs at relevant return blocks until after
// all blocks are processed. This approach simplifies the logic for
// eliminating multiple tail calls that are only separated from the return
// by a sync, since the CFG won't be perturbed unnecessarily.
// by a sync, since the CFG won't be changed unnecessarily.
} else {
// Restore the sync that was eliminated.
BasicBlock *RetBlock = Ret->getParent();
Expand Down
53 changes: 53 additions & 0 deletions llvm/test/Transforms/Tapir/tre-remove-return.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
; Check that tail-call elimination handles deletion of return blocks
; when it attempts to insert sync instructions before returns.
;
; RUN: opt < %s -passes="cgscc(devirt<4>(function<eager-inv;no-rerun>(tailcallelim)))" -S | FileCheck %s
target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
target triple = "arm64-apple-macosx14.0.0"

define void @_Z3dacPfPKfS1_xxxxxx() personality ptr null {
entry:
%syncreg = call token @llvm.syncregion.start()
%0 = call token @llvm.tapir.runtime.start()
br i1 false, label %if.end10, label %if.then7

if.then7: ; preds = %entry
call void @_Z3dacPfPKfS1_xxxxxx()
sync within %syncreg, label %cleanup

if.end10: ; preds = %entry
call void @_Z3dacPfPKfS1_xxxxxx()
br label %cleanup

cleanup: ; preds = %if.end10, %if.then7
ret void
}

; CHECK: define void @_Z3dacPfPKfS1_xxxxxx()

; CHECK: entry:
; CHECK-NEXT: %syncreg = {{.*}}call token @llvm.syncregion.start()
; CHECK-NEXT: br label %[[TAILRECURSE:.+]]

; CHECK: [[TAILRECURSE]]:
; CHECK: br i1 false, label %if.end10, label %if.then7

; CHECK: if.then7:
; CHECK-NOT: sync
; CHECK-NEXT: br label %[[TAILRECURSE]]

; CHECK: if.end10:
; CHECK-NEXT: br label %[[TAILRECURSE]]

; CHECK-NOT: ret void

; Function Attrs: nounwind willreturn memory(argmem: readwrite)
declare token @llvm.tapir.runtime.start() #0

; Function Attrs: nounwind willreturn memory(argmem: readwrite)
declare token @llvm.syncregion.start() #0

; uselistorder directives
uselistorder ptr null, { 1, 2, 0 }

attributes #0 = { nounwind willreturn memory(argmem: readwrite) }

0 comments on commit ef5a15f

Please sign in to comment.