Skip to content

Commit

Permalink
[SemaStmt] Simplify synthesized loop-bound computation for simple cil…
Browse files Browse the repository at this point in the history
…k_for loops with unit stride.
  • Loading branch information
neboat committed Dec 6, 2023
1 parent 262b7b2 commit 2e33183
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 77 deletions.
62 changes: 26 additions & 36 deletions clang/lib/Sema/SemaStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3568,29 +3568,27 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc,
Expr *LimitExpr = nullptr;
if (DeclUseInLHS)
LimitExpr = Cond->getRHS();
else // if (DeclUseInRHS)
else // DeclUseInRHS
LimitExpr = Cond->getLHS();
if (!LimitExpr)
return StmtEmpty();

// Get the loop stride.
if (!Increment)
return StmtEmpty();
Expr *Stride = nullptr;
bool StrideIsUnit = false;
bool StrideIsNegative = false;
if (const UnaryOperator *UO =
dyn_cast_or_null<UnaryOperator>(Increment)) {
if (UO->isIncrementOp())
Stride = ActOnIntegerConstant(Increment->getExprLoc(), 1).get();
else if (UO->isDecrementOp()) {
Stride = ActOnIntegerConstant(Increment->getExprLoc(), 1).get();
StrideIsUnit = true;
if (UO->isDecrementOp())
StrideIsNegative = true;
}
} else {
auto StrideWithSign = GetCilkForStride(*this, Decls, Increment);
StrideIsNegative = StrideWithSign.second;
Stride = StrideWithSign.first;
}
if (!Stride)
return StmtEmpty();

// Determine the type of comparison.
//
Expand Down Expand Up @@ -3624,17 +3622,16 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc,
// evaluated just once.
SourceLocation InitLoc = LoopVarInit->getBeginLoc();
// Add declaration to store the old loop var initialization.
VarDecl *InitVar = BuildForRangeVarDecl(*this, InitLoc,
LoopVarTy, "__init");
VarDecl *InitVar = BuildForRangeVarDecl(*this, InitLoc, LoopVarTy, "__init");
AddInitializerToDecl(InitVar, LoopVarInit, /*DirectInit=*/false);
FinalizeDeclaration(InitVar);
CurContext->addHiddenDecl(InitVar);

// Create a declaration for the limit of this loop, to ensure it's evaluated
// just once.
SourceLocation LimitLoc = LimitExpr->getBeginLoc();
VarDecl *LimitVar = BuildForRangeVarDecl(*this, LimitLoc,
LoopVarTy, "__limit");
VarDecl *LimitVar =
BuildForRangeVarDecl(*this, LimitLoc, LoopVarTy, "__limit");
AddInitializerToDecl(LimitVar, LimitExpr, /*DirectInit=*/false);
FinalizeDeclaration(LimitVar);
CurContext->addHiddenDecl(LimitVar);
Expand All @@ -3651,17 +3648,11 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc,
if (LimitDecl.isInvalid())
return StmtError();

ExprResult InitRef = BuildDeclRefExpr(InitVar, LoopVarTy, VK_LValue,
InitLoc);
ExprResult LimitRef = BuildDeclRefExpr(LimitVar, LimitVar->getType(),
VK_LValue, LimitLoc);

// LimitVar should have the correct type, because it's derived from the
// original condition. Hence we only need to cast InitRef.
ExprResult CastInit = ImplicitCastExpr::Create(
Context, LimitVar->getType(), CK_IntegralCast, InitRef.get(), nullptr,
VK_XValue, FPOptionsOverride());
ExprResult InitRef = BuildDeclRefExpr(InitVar, LoopVarTy, VK_LValue, InitLoc);
ExprResult LimitRef =
BuildDeclRefExpr(LimitVar, LoopVarTy, VK_LValue, LimitLoc);

ExprResult CastInit = InitRef;
// Compute a check that this _Cilk_for loop executes at all.
SourceLocation CondLoc = Cond->getExprLoc();
ExprResult InitCond;
Expand Down Expand Up @@ -3691,15 +3682,17 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc,
// Now rewrite the loop control.

// If the comparison is not inclusive, reduce the Range by 1.
if (!CompareInclusive)
if (!CompareInclusive && !StrideIsUnit)
Range = BuildBinOp(S, CondLoc, BO_Sub, Range.get(),
ActOnIntegerConstant(CilkForLoc, 1).get());

// Build Range/Stride.
ExprResult NewLimit = BuildBinOp(S, CondLoc, BO_Div, Range.get(), Stride);
ExprResult NewLimit = Range;
if (!StrideIsUnit)
// Build Range/Stride.
NewLimit = BuildBinOp(S, CondLoc, BO_Div, Range.get(), Stride);

// If the comparison is not inclusive, build Range/Stride + 1.
if (!CompareInclusive)
if (!CompareInclusive && !StrideIsUnit)
NewLimit = BuildBinOp(S, CondLoc, BO_Add, NewLimit.get(),
ActOnIntegerConstant(CilkForLoc, 1).get());

Expand All @@ -3709,8 +3702,7 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc,

// Create new declarations for replacement loop control variables.
// Declaration for new beginning loop control variable.
VarDecl *BeginVar = BuildForRangeVarDecl(*this, CondLoc, CountTy,
"__begin");
VarDecl *BeginVar = BuildForRangeVarDecl(*this, CondLoc, CountTy, "__begin");
AddInitializerToDecl(BeginVar, ActOnIntegerConstant(CondLoc, 0).get(),
/*DirectInit=*/false);
FinalizeDeclaration(BeginVar);
Expand Down Expand Up @@ -3767,20 +3759,18 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc,
}
} else if (const UnaryOperator *UO =
dyn_cast_or_null<UnaryOperator>(Increment)) {
if (UO->isIncrementOp())
NewInc = BuildUnaryOp(S, IncLoc, UO_PreInc, BeginRef.get());
else if (UO->isDecrementOp())
NewInc = BuildUnaryOp(S, IncLoc, UO_PreInc, BeginRef.get());
NewInc = BuildUnaryOp(S, IncLoc, UO_PreInc, BeginRef.get());
}
if (NewInc.isInvalid())
return StmtError();

// Return a new statement for initializing the old loop variable.
SourceLocation LoopVarLoc = LoopVar->getBeginLoc();
ExprResult NewLoopVarInit =
BuildBinOp(S, LoopVarLoc, StrideIsNegative ? BO_Sub : BO_Add, InitRef.get(),
BuildBinOp(S, LoopVarLoc, BO_Mul,
BeginRef.get(), Stride).get());
ExprResult NewLoopVarInit = BuildBinOp(
S, LoopVarLoc, StrideIsNegative ? BO_Sub : BO_Add, InitRef.get(),
StrideIsUnit
? BeginRef.get()
: BuildBinOp(S, LoopVarLoc, BO_Mul, BeginRef.get(), Stride).get());
if (!NewLoopVarInit.isInvalid())
AddInitializerToDecl(LoopVar, NewLoopVarInit.get(), /*DirectInit=*/false);

Expand Down
56 changes: 16 additions & 40 deletions clang/test/Cilk/cilkfor-bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,11 @@ void up(size_t start, size_t end) {
// CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]]
// CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDLIMIT]], %[[ENDINIT]]
// CHECK-NEXT: %[[ENDSUB1:.+]] = sub i64 %[[ENDSUB]], 1
// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB1]], 1
// CHECK-NEXT: %[[ENDADD:.+]] = add i64 %[[ENDDIV]], 1
// CHECK-NEXT: store i64 %[[ENDADD]], ptr %[[END:.+]], align 8
// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8

// CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]]
// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1
// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[ITERMUL]]
// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[BEGINITER]]
// CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]]

// CHECK: [[DETACHED]]:
Expand Down Expand Up @@ -74,13 +70,11 @@ void up_leq(size_t start, size_t end) {
// CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]]
// CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDLIMIT]], %[[ENDINIT]]
// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB]], 1
// CHECK-NEXT: store i64 %[[ENDDIV]], ptr %[[END:.+]], align 8
// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8

// CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]]
// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1
// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[ITERMUL]]
// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[BEGINITER]]
// CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]]

// CHECK: [[DETACHED]]:
Expand Down Expand Up @@ -120,15 +114,11 @@ void up_flip(size_t start, size_t end) {
// CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]]
// CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDLIMIT]], %[[ENDINIT]]
// CHECK-NEXT: %[[ENDSUB1:.+]] = sub i64 %[[ENDSUB]], 1
// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB1]], 1
// CHECK-NEXT: %[[ENDADD:.+]] = add i64 %[[ENDDIV]], 1
// CHECK-NEXT: store i64 %[[ENDADD]], ptr %[[END:.+]], align 8
// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8

// CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]]
// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1
// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[ITERMUL]]
// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[BEGINITER]]
// CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]]

// CHECK: [[DETACHED]]:
Expand Down Expand Up @@ -168,13 +158,11 @@ void up_flip_geq(size_t start, size_t end) {
// CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]]
// CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDLIMIT]], %[[ENDINIT]]
// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB]], 1
// CHECK-NEXT: store i64 %[[ENDDIV]], ptr %[[END:.+]], align 8
// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8

// CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]]
// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1
// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[ITERMUL]]
// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[BEGINITER]]
// CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]]

// CHECK: [[DETACHED]]:
Expand Down Expand Up @@ -506,15 +494,11 @@ void down(size_t start, size_t end) {
// CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]]
// CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDINIT]], %[[ENDLIMIT]]
// CHECK-NEXT: %[[ENDSUB1:.+]] = sub i64 %[[ENDSUB]], 1
// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB1]], 1
// CHECK-NEXT: %[[ENDADD:.+]] = add i64 %[[ENDDIV]], 1
// CHECK-NEXT: store i64 %[[ENDADD]], ptr %[[END:.+]], align 8
// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8

// CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]]
// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1
// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[ITERMUL]]
// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[BEGINITER]]
// CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]]

// CHECK: [[DETACHED]]:
Expand Down Expand Up @@ -554,13 +538,11 @@ void down_geq(size_t start, size_t end) {
// CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]]
// CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDINIT]], %[[ENDLIMIT]]
// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB]], 1
// CHECK-NEXT: store i64 %[[ENDDIV]], ptr %[[END:.+]], align 8
// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8

// CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]]
// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1
// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[ITERMUL]]
// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[BEGINITER]]
// CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]]

// CHECK: [[DETACHED]]:
Expand Down Expand Up @@ -600,15 +582,11 @@ void down_flip(size_t start, size_t end) {
// CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]]
// CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDINIT]], %[[ENDLIMIT]]
// CHECK-NEXT: %[[ENDSUB1:.+]] = sub i64 %[[ENDSUB]], 1
// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB1]], 1
// CHECK-NEXT: %[[ENDADD:.+]] = add i64 %[[ENDDIV]], 1
// CHECK-NEXT: store i64 %[[ENDADD]], ptr %[[END:.+]], align 8
// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8

// CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]]
// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1
// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[ITERMUL]]
// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[BEGINITER]]
// CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]]

// CHECK: [[DETACHED]]:
Expand Down Expand Up @@ -648,13 +626,11 @@ void down_flip_leq(size_t start, size_t end) {
// CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]]
// CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDINIT]], %[[ENDLIMIT]]
// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB]], 1
// CHECK-NEXT: store i64 %[[ENDDIV]], ptr %[[END:.+]], align 8
// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8

// CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]]
// CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]]
// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1
// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[ITERMUL]]
// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[BEGINITER]]
// CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]]

// CHECK: [[DETACHED]]:
Expand Down
29 changes: 28 additions & 1 deletion clang/test/Cilk/cilkfor-detach-unwind-rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,34 @@ int main() {
return total_red.get();
}

// CHECK: define {{.*}}i32 @main()
// Test driver for a _Cilk_for loop whose increment is a compound assignment
// with a non-unit step (i += 5) rather than ++/--. The loop runs i from 0
// while i < 100, and each iteration adds 5 to a Reducer_sum<long>; the
// function returns the reducer's final value.
// NOTE(review): presumably paired with main() above to contrast the non-unit
// stride lowering with the unit-stride one -- confirm against the full test.
int stride_test() {
// The reducer is constructed with the value of __cilkrts_get_nworkers().
Reducer_sum<long> total_red(__cilkrts_get_nworkers());
// Non-unit stride: the loop variable advances by 5 on every iteration.
_Cilk_for(long i = 0; i < 100; i += 5) { total_red.add(5); }
return total_red.get();
}

// CHECK-LABEL: define {{.*}}i32 @main()
// CHECK: br i1 %{{.*}}, label %[[PFOR_PH:.+]], label %[[PFOR_END:[a-z0-9._]+]]

// CHECK: [[PFOR_PH]]:
// CHECK: call void @__ubsan_handle_sub_overflow

// Check contents of the detach block
// CHECK: load i64, ptr %[[INIT:.+]]
// CHECK: load i64, ptr %[[BEGIN:.+]]
// CHECK: call { i64, i1 } @llvm.sadd.with.overflow.i64(
// CHECK: br i1 %{{.*}}, label %[[CONT5:.+]], label %[[HANDLE_ADD_OVERFLOW:[a-z0-9._]+]],

// CHECK: [[HANDLE_ADD_OVERFLOW]]:
// CHECK: call void @__ubsan_handle_add_overflow

// Check that the detach ends up after the loop-variable init expression.

// CHECK: [[CONT5]]:
// CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[PFOR_BODY_ENTRY:.+]], label %[[PFOR_INC:.+]] unwind label %[[LPAD9:.+]]


// CHECK-LABEL: define {{.*}}i32 @_Z11stride_testv()
// CHECK: br i1 %{{.*}}, label %[[PFOR_PH:.+]], label %[[PFOR_END:[a-z0-9._]+]]

// CHECK: [[PFOR_PH]]:
Expand Down

0 comments on commit 2e33183

Please sign in to comment.