From 2e33183fdb22bf8db8918c86d7a4f18fff62829f Mon Sep 17 00:00:00 2001 From: TB Schardl Date: Sun, 3 Dec 2023 11:20:29 -0500 Subject: [PATCH] [SemaStmt] Simplify synthesized loop-bound computation for simple cilk_for loops with unit stride. --- clang/lib/Sema/SemaStmt.cpp | 62 ++++++++----------- clang/test/Cilk/cilkfor-bounds.cpp | 56 +++++------------ .../Cilk/cilkfor-detach-unwind-rewrite.cpp | 29 ++++++++- 3 files changed, 70 insertions(+), 77 deletions(-) diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp index e5c548bf0bdd..6bbf3476e740 100644 --- a/clang/lib/Sema/SemaStmt.cpp +++ b/clang/lib/Sema/SemaStmt.cpp @@ -3568,29 +3568,27 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc, Expr *LimitExpr = nullptr; if (DeclUseInLHS) LimitExpr = Cond->getRHS(); - else // if (DeclUseInRHS) + else // DeclUseInRHS LimitExpr = Cond->getLHS(); if (!LimitExpr) return StmtEmpty(); // Get the loop stride. + if (!Increment) + return StmtEmpty(); Expr *Stride = nullptr; + bool StrideIsUnit = false; bool StrideIsNegative = false; if (const UnaryOperator *UO = dyn_cast_or_null<UnaryOperator>(Increment)) { - if (UO->isIncrementOp()) - Stride = ActOnIntegerConstant(Increment->getExprLoc(), 1).get(); - else if (UO->isDecrementOp()) { - Stride = ActOnIntegerConstant(Increment->getExprLoc(), 1).get(); + StrideIsUnit = true; + if (UO->isDecrementOp()) StrideIsNegative = true; - } } else { auto StrideWithSign = GetCilkForStride(*this, Decls, Increment); StrideIsNegative = StrideWithSign.second; Stride = StrideWithSign.first; } - if (!Stride) - return StmtEmpty(); // Determine the type of comparison. // @@ -3624,8 +3622,7 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc, // evaluated just once. SourceLocation InitLoc = LoopVarInit->getBeginLoc(); // Add declaration to store the old loop var initialization. 
- VarDecl *InitVar = BuildForRangeVarDecl(*this, InitLoc, - LoopVarTy, "__init"); + VarDecl *InitVar = BuildForRangeVarDecl(*this, InitLoc, LoopVarTy, "__init"); AddInitializerToDecl(InitVar, LoopVarInit, /*DirectInit=*/false); FinalizeDeclaration(InitVar); CurContext->addHiddenDecl(InitVar); @@ -3633,8 +3630,8 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc, // Create a declaration for the limit of this loop, to ensure its evaluated // just once. SourceLocation LimitLoc = LimitExpr->getBeginLoc(); - VarDecl *LimitVar = BuildForRangeVarDecl(*this, LimitLoc, - LoopVarTy, "__limit"); + VarDecl *LimitVar = + BuildForRangeVarDecl(*this, LimitLoc, LoopVarTy, "__limit"); AddInitializerToDecl(LimitVar, LimitExpr, /*DirectInit=*/false); FinalizeDeclaration(LimitVar); CurContext->addHiddenDecl(LimitVar); @@ -3651,17 +3648,11 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc, if (LimitDecl.isInvalid()) return StmtError(); - ExprResult InitRef = BuildDeclRefExpr(InitVar, LoopVarTy, VK_LValue, - InitLoc); - ExprResult LimitRef = BuildDeclRefExpr(LimitVar, LimitVar->getType(), - VK_LValue, LimitLoc); - - // LimitVar should have the correct type, because it's derived from the - // original condition. Hence we only need to cast InitRef. - ExprResult CastInit = ImplicitCastExpr::Create( - Context, LimitVar->getType(), CK_IntegralCast, InitRef.get(), nullptr, - VK_XValue, FPOptionsOverride()); + ExprResult InitRef = BuildDeclRefExpr(InitVar, LoopVarTy, VK_LValue, InitLoc); + ExprResult LimitRef = + BuildDeclRefExpr(LimitVar, LoopVarTy, VK_LValue, LimitLoc); + ExprResult CastInit = InitRef; // Compute a check that this _Cilk_for loop executes at all. SourceLocation CondLoc = Cond->getExprLoc(); ExprResult InitCond; @@ -3691,15 +3682,17 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc, // Now rewrite the loop control. // If the comparison is not inclusive, reduce the Range by 1. 
- if (!CompareInclusive) + if (!CompareInclusive && !StrideIsUnit) Range = BuildBinOp(S, CondLoc, BO_Sub, Range.get(), ActOnIntegerConstant(CilkForLoc, 1).get()); - // Build Range/Stride. - ExprResult NewLimit = BuildBinOp(S, CondLoc, BO_Div, Range.get(), Stride); + ExprResult NewLimit = Range; + if (!StrideIsUnit) + // Build Range/Stride. + NewLimit = BuildBinOp(S, CondLoc, BO_Div, Range.get(), Stride); // If the comparison is not an equality, build Range/Stride + 1 - if (!CompareInclusive) + if (!CompareInclusive && !StrideIsUnit) NewLimit = BuildBinOp(S, CondLoc, BO_Add, NewLimit.get(), ActOnIntegerConstant(CilkForLoc, 1).get()); @@ -3709,8 +3702,7 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc, // Create new declarations for replacement loop control variables. // Declaration for new beginning loop control variable. - VarDecl *BeginVar = BuildForRangeVarDecl(*this, CondLoc, CountTy, - "__begin"); + VarDecl *BeginVar = BuildForRangeVarDecl(*this, CondLoc, CountTy, "__begin"); AddInitializerToDecl(BeginVar, ActOnIntegerConstant(CondLoc, 0).get(), /*DirectInit=*/false); FinalizeDeclaration(BeginVar); @@ -3767,20 +3759,18 @@ StmtResult Sema::HandleSimpleCilkForStmt(SourceLocation CilkForLoc, } } else if (const UnaryOperator *UO = dyn_cast_or_null<UnaryOperator>(Increment)) { - if (UO->isIncrementOp()) - NewInc = BuildUnaryOp(S, IncLoc, UO_PreInc, BeginRef.get()); - else if (UO->isDecrementOp()) - NewInc = BuildUnaryOp(S, IncLoc, UO_PreInc, BeginRef.get()); + NewInc = BuildUnaryOp(S, IncLoc, UO_PreInc, BeginRef.get()); } if (NewInc.isInvalid()) return StmtError(); // Return a new statement for initializing the old loop variable. SourceLocation LoopVarLoc = LoopVar->getBeginLoc(); - ExprResult NewLoopVarInit = - BuildBinOp(S, LoopVarLoc, StrideIsNegative ? BO_Sub : BO_Add, InitRef.get(), - BuildBinOp(S, LoopVarLoc, BO_Mul, - BeginRef.get(), Stride).get()); + ExprResult NewLoopVarInit = BuildBinOp( + S, LoopVarLoc, StrideIsNegative ? 
BO_Sub : BO_Add, InitRef.get(), + StrideIsUnit + ? BeginRef.get() + : BuildBinOp(S, LoopVarLoc, BO_Mul, BeginRef.get(), Stride).get()); if (!NewLoopVarInit.isInvalid()) AddInitializerToDecl(LoopVar, NewLoopVarInit.get(), /*DirectInit=*/false); diff --git a/clang/test/Cilk/cilkfor-bounds.cpp b/clang/test/Cilk/cilkfor-bounds.cpp index 99c832d5a302..43eaa0121d99 100644 --- a/clang/test/Cilk/cilkfor-bounds.cpp +++ b/clang/test/Cilk/cilkfor-bounds.cpp @@ -26,15 +26,11 @@ void up(size_t start, size_t end) { // CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]] // CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDLIMIT]], %[[ENDINIT]] -// CHECK-NEXT: %[[ENDSUB1:.+]] = sub i64 %[[ENDSUB]], 1 -// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB1]], 1 -// CHECK-NEXT: %[[ENDADD:.+]] = add i64 %[[ENDDIV]], 1 -// CHECK-NEXT: store i64 %[[ENDADD]], ptr %[[END:.+]], align 8 +// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8 // CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]] -// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1 -// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[ITERMUL]] +// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[BEGINITER]] // CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]] // CHECK: [[DETACHED]]: @@ -74,13 +70,11 @@ void up_leq(size_t start, size_t end) { // CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]] // CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDLIMIT]], %[[ENDINIT]] -// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB]], 1 -// CHECK-NEXT: store i64 %[[ENDDIV]], ptr %[[END:.+]], align 8 +// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8 // CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]] -// CHECK-NEXT: %[[ITERMUL:.+]] = 
mul i64 %[[BEGINITER]], 1 -// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[ITERMUL]] +// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[BEGINITER]] // CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]] // CHECK: [[DETACHED]]: @@ -120,15 +114,11 @@ void up_flip(size_t start, size_t end) { // CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]] // CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDLIMIT]], %[[ENDINIT]] -// CHECK-NEXT: %[[ENDSUB1:.+]] = sub i64 %[[ENDSUB]], 1 -// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB1]], 1 -// CHECK-NEXT: %[[ENDADD:.+]] = add i64 %[[ENDDIV]], 1 -// CHECK-NEXT: store i64 %[[ENDADD]], ptr %[[END:.+]], align 8 +// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8 // CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]] -// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1 -// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[ITERMUL]] +// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[BEGINITER]] // CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]] // CHECK: [[DETACHED]]: @@ -168,13 +158,11 @@ void up_flip_geq(size_t start, size_t end) { // CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]] // CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDLIMIT]], %[[ENDINIT]] -// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB]], 1 -// CHECK-NEXT: store i64 %[[ENDDIV]], ptr %[[END:.+]], align 8 +// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8 // CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]] -// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1 -// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], %[[ITERMUL]] +// CHECK-NEXT: %[[ITERADD:.+]] = add i64 %[[INITITER]], 
%[[BEGINITER]] // CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]] // CHECK: [[DETACHED]]: @@ -506,15 +494,11 @@ void down(size_t start, size_t end) { // CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]] // CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDINIT]], %[[ENDLIMIT]] -// CHECK-NEXT: %[[ENDSUB1:.+]] = sub i64 %[[ENDSUB]], 1 -// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB1]], 1 -// CHECK-NEXT: %[[ENDADD:.+]] = add i64 %[[ENDDIV]], 1 -// CHECK-NEXT: store i64 %[[ENDADD]], ptr %[[END:.+]], align 8 +// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8 // CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]] -// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1 -// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[ITERMUL]] +// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[BEGINITER]] // CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]] // CHECK: [[DETACHED]]: @@ -554,13 +538,11 @@ void down_geq(size_t start, size_t end) { // CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]] // CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDINIT]], %[[ENDLIMIT]] -// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB]], 1 -// CHECK-NEXT: store i64 %[[ENDDIV]], ptr %[[END:.+]], align 8 +// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8 // CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]] -// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1 -// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[ITERMUL]] +// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[BEGINITER]] // CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]] // CHECK: [[DETACHED]]: @@ -600,15 +582,11 @@ void 
down_flip(size_t start, size_t end) { // CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]] // CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDINIT]], %[[ENDLIMIT]] -// CHECK-NEXT: %[[ENDSUB1:.+]] = sub i64 %[[ENDSUB]], 1 -// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB1]], 1 -// CHECK-NEXT: %[[ENDADD:.+]] = add i64 %[[ENDDIV]], 1 -// CHECK-NEXT: store i64 %[[ENDADD]], ptr %[[END:.+]], align 8 +// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8 // CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]] -// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1 -// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[ITERMUL]] +// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[BEGINITER]] // CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]] // CHECK: [[DETACHED]]: @@ -648,13 +626,11 @@ void down_flip_leq(size_t start, size_t end) { // CHECK-NEXT: %[[ENDINIT:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[ENDLIMIT:.+]] = load i64, ptr %[[LIMIT]] // CHECK-NEXT: %[[ENDSUB:.+]] = sub i64 %[[ENDINIT]], %[[ENDLIMIT]] -// CHECK-NEXT: %[[ENDDIV:.+]] = udiv i64 %[[ENDSUB]], 1 -// CHECK-NEXT: store i64 %[[ENDDIV]], ptr %[[END:.+]], align 8 +// CHECK-NEXT: store i64 %[[ENDSUB]], ptr %[[END:.+]], align 8 // CHECK: %[[INITITER:.+]] = load i64, ptr %[[INIT]] // CHECK-NEXT: %[[BEGINITER:.+]] = load i64, ptr %[[BEGIN]] -// CHECK-NEXT: %[[ITERMUL:.+]] = mul i64 %[[BEGINITER]], 1 -// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[ITERMUL]] +// CHECK-NEXT: %[[ITERADD:.+]] = sub i64 %[[INITITER]], %[[BEGINITER]] // CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[DETACHED:.+]], label %[[PFORINC:.+]] // CHECK: [[DETACHED]]: diff --git a/clang/test/Cilk/cilkfor-detach-unwind-rewrite.cpp b/clang/test/Cilk/cilkfor-detach-unwind-rewrite.cpp index f2d71f0bc196..1857da64958f 100644 --- 
a/clang/test/Cilk/cilkfor-detach-unwind-rewrite.cpp +++ b/clang/test/Cilk/cilkfor-detach-unwind-rewrite.cpp @@ -19,7 +19,34 @@ int main() { return total_red.get(); } -// CHECK: define {{.*}}i32 @main() +int stride_test() { + Reducer_sum total_red(__cilkrts_get_nworkers()); + _Cilk_for(long i = 0; i < 100; i += 5) { total_red.add(5); } + return total_red.get(); +} + +// CHECK-LABEL: define {{.*}}i32 @main() +// CHECK: br i1 %{{.*}}, label %[[PFOR_PH:.+]], label %[[PFOR_END:[a-z0-9._]+]] + +// CHECK: [[PFOR_PH]]: +// CHECK: call void @__ubsan_handle_sub_overflow + +// Check contents of the detach block +// CHECK: load i64, ptr %[[INIT:.+]] +// CHECK: load i64, ptr %[[BEGIN:.+]] +// CHECK: call { i64, i1 } @llvm.sadd.with.overflow.i64( +// CHECK: br i1 %{{.*}}, label %[[CONT5:.+]], label %[[HANDLE_ADD_OVERFLOW:[a-z0-9._]+]], + +// CHECK: [[HANDLE_ADD_OVERFLOW]]: +// CHECK: call void @__ubsan_handle_add_overflow + +// Check that the detach ends up after the loop-variable init expression. + +// CHECK: [[CONT5]]: +// CHECK-NEXT: detach within %[[SYNCREG:.+]], label %[[PFOR_BODY_ENTRY:.+]], label %[[PFOR_INC:.+]] unwind label %[[LPAD9:.+]] + + +// CHECK-LABEL: define {{.*}}i32 @_Z11stride_testv() // CHECK: br i1 %{{.*}}, label %[[PFOR_PH:.+]], label %[[PFOR_END:[a-z0-9._]+]] // CHECK: [[PFOR_PH]]: