Skip to content

Commit

Permalink
[Sema] Fix semantic analysis of simple cilk_for loops to get rid of e…
Browse files Browse the repository at this point in the history
…rroneous unused-comparison warnings.
  • Loading branch information
neboat committed Jun 14, 2024
1 parent 99ab793 commit 5dfee39
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 27 deletions.
42 changes: 23 additions & 19 deletions clang/include/clang/AST/StmtCilk.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ class CilkForStmt : public Stmt {
public:
CilkForStmt(Stmt *Init, DeclStmt *Limit, Expr *InitCond, DeclStmt *Begin,
DeclStmt *End, Expr *Cond, Expr *Inc, DeclStmt *LoopVar,
Stmt *Body, Stmt *OgCond, Stmt *OgInc, SourceLocation CFL,
Stmt *Body, Expr *OgCond, Expr *OgInc, SourceLocation CFL,
SourceLocation LP, SourceLocation RP);

/// Build an empty _Cilk_for statement.
explicit CilkForStmt(EmptyShell Empty) : Stmt(CilkForStmtClass, Empty) { }
explicit CilkForStmt(EmptyShell Empty) : Stmt(CilkForStmtClass, Empty) {}

Stmt *getInit() { return SubExprs[INIT]; }

Expand All @@ -130,24 +130,22 @@ class CilkForStmt : public Stmt {
// return reinterpret_cast<DeclStmt*>(SubExprs[CONDVAR]);
// }

DeclStmt *getLimitStmt() {
return cast_or_null<DeclStmt>(SubExprs[LIMIT]);
}
DeclStmt *getLimitStmt() { return cast_or_null<DeclStmt>(SubExprs[LIMIT]); }
Expr *getInitCond() { return cast_or_null<Expr>(SubExprs[INITCOND]); }
DeclStmt *getBeginStmt() {
return cast_or_null<DeclStmt>(SubExprs[BEGINSTMT]);
}
DeclStmt *getEndStmt() { return cast_or_null<DeclStmt>(SubExprs[ENDSTMT]); }
Expr *getCond() { return reinterpret_cast<Expr*>(SubExprs[COND]); }
Expr *getInc() { return reinterpret_cast<Expr*>(SubExprs[INC]); }
Expr *getCond() { return reinterpret_cast<Expr *>(SubExprs[COND]); }
Expr *getInc() { return reinterpret_cast<Expr *>(SubExprs[INC]); }
DeclStmt *getLoopVarStmt() {
return cast_or_null<DeclStmt>(SubExprs[LOOPVAR]);
}
Stmt *getBody() { return SubExprs[BODY]; }

Stmt *getOriginalInit();
Stmt *getOriginalCond() { return SubExprs[OGCOND]; }
Stmt *getOriginalInc() { return SubExprs[OGINC]; }
Expr *getOriginalCond() { return reinterpret_cast<Expr *>(SubExprs[OGCOND]); }
Expr *getOriginalInc() { return reinterpret_cast<Expr *>(SubExprs[OGINC]); }

const Stmt *getInit() const { return SubExprs[INIT]; }
const VarDecl *getLoopVariable() const;
Expand All @@ -163,24 +161,32 @@ class CilkForStmt : public Stmt {
const DeclStmt *getEndStmt() const {
return cast_or_null<DeclStmt>(SubExprs[ENDSTMT]);
}
const Expr *getCond() const { return reinterpret_cast<Expr*>(SubExprs[COND]);}
const Expr *getInc() const { return reinterpret_cast<Expr*>(SubExprs[INC]); }
const Expr *getCond() const {
return reinterpret_cast<Expr *>(SubExprs[COND]);
}
const Expr *getInc() const { return reinterpret_cast<Expr *>(SubExprs[INC]); }
const DeclStmt *getLoopVarStmt() const {
return cast_or_null<DeclStmt>(SubExprs[LOOPVAR]);
}
const Stmt *getBody() const { return SubExprs[BODY]; }

const Stmt *getOriginalInit() const;
const Stmt *getOriginalCond() const { return SubExprs[OGCOND]; }
const Stmt *getOriginalInc() const { return SubExprs[OGINC]; }
const Expr *getOriginalCond() const {
return reinterpret_cast<Expr *>(SubExprs[OGCOND]);
}
const Expr *getOriginalInc() const {
return reinterpret_cast<Expr *>(SubExprs[OGINC]);
}

void setInit(Stmt *S) { SubExprs[INIT] = S; }
void setLimitStmt(Stmt *S) { SubExprs[LIMIT] = S; }
void setInitCond(Expr *E) { SubExprs[INITCOND] = reinterpret_cast<Stmt*>(E); }
void setInitCond(Expr *E) {
SubExprs[INITCOND] = reinterpret_cast<Stmt *>(E);
}
void setBeginStmt(Stmt *S) { SubExprs[BEGINSTMT] = S; }
void setEndStmt(Stmt *S) { SubExprs[ENDSTMT] = S; }
void setCond(Expr *E) { SubExprs[COND] = reinterpret_cast<Stmt*>(E); }
void setInc(Expr *E) { SubExprs[INC] = reinterpret_cast<Stmt*>(E); }
void setCond(Expr *E) { SubExprs[COND] = reinterpret_cast<Stmt *>(E); }
void setInc(Expr *E) { SubExprs[INC] = reinterpret_cast<Stmt *>(E); }
void setLoopVarStmt(Stmt *S) { SubExprs[LOOPVAR] = S; }
void setBody(Stmt *S) { SubExprs[BODY] = S; }

Expand All @@ -204,9 +210,7 @@ class CilkForStmt : public Stmt {
}

// Iterators
child_range children() {
return child_range(&SubExprs[0], &SubExprs[END]);
}
child_range children() { return child_range(&SubExprs[0], &SubExprs[END]); }

const_child_range children() const {
return const_child_range(&SubExprs[0], &SubExprs[END]);
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -5264,7 +5264,7 @@ class Sema final {
ConditionResult second, FullExprArg third,
SourceLocation RParenLoc, Stmt *Body,
DeclStmt *LoopVar = nullptr,
Stmt *OgCond = nullptr, Stmt *OgInc = nullptr);
Expr *OgCond = nullptr, Expr *OgInc = nullptr);

StmtResult BuildCilkForStmt(SourceLocation CilkForLoc,
SourceLocation LParenLoc,
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/AST/Stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1443,8 +1443,8 @@ bool CapturedStmt::capturesVariable(const VarDecl *Var) const {
// CilkForStmt
CilkForStmt::CilkForStmt(Stmt *Init, DeclStmt *Limit, Expr *InitCond,
DeclStmt *BeginStmt, DeclStmt *EndStmt, Expr *Cond,
Expr *Inc, DeclStmt *LoopVar, Stmt *Body, Stmt *OgCond,
Stmt *OgInc, SourceLocation CFL, SourceLocation LP,
Expr *Inc, DeclStmt *LoopVar, Stmt *Body, Expr *OgCond,
Expr *OgInc, SourceLocation CFL, SourceLocation LP,
SourceLocation RP)
: Stmt(CilkForStmtClass), CilkForLoc(CFL), LParenLoc(LP), RParenLoc(RP) {
SubExprs[INIT] = Init;
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/Sema/SemaStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3931,7 +3931,7 @@ Sema::ActOnCilkForStmt(SourceLocation CilkForLoc, SourceLocation LParenLoc,
Stmt *First, DeclStmt *Limit, ConditionResult InitCond,
DeclStmt *Begin, DeclStmt *End, ConditionResult Second,
FullExprArg Third, SourceLocation RParenLoc, Stmt *Body,
DeclStmt *LoopVar, Stmt *OgCond, Stmt *OgInc) {
DeclStmt *LoopVar, Expr *OgCond, Expr *OgInc) {
if (CheckCilkForInit(*this, CilkForLoc, First))
return StmtResult();

Expand Down
8 changes: 4 additions & 4 deletions clang/lib/Sema/TreeTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1464,8 +1464,8 @@ class TreeTransform {
Sema::ConditionResult InitCond, Stmt *Begin,
Stmt *End, Sema::ConditionResult Cond,
Sema::FullExprArg Inc, SourceLocation RParenLoc,
Stmt *LoopVar, Stmt *Body, Stmt *OgCond,
Stmt *OgInc) {
Stmt *LoopVar, Stmt *Body, Expr *OgCond,
Expr *OgInc) {
return getSema().ActOnCilkForStmt(
ForLoc, LParenLoc, Init, cast_or_null<DeclStmt>(Limit), InitCond,
cast_or_null<DeclStmt>(Begin), cast_or_null<DeclStmt>(End), Cond, Inc,
Expand Down Expand Up @@ -15537,10 +15537,10 @@ TreeTransform<Derived>::TransformCilkForStmt(CilkForStmt *S) {
}

// Transform the original init, condition, and increment statements
StmtResult OgCond = getDerived().TransformStmt(S->getOriginalCond());
ExprResult OgCond = getDerived().TransformExpr(S->getOriginalCond());
if (OgCond.isInvalid())
return StmtError();
StmtResult OgInc = getDerived().TransformStmt(S->getOriginalInc());
ExprResult OgInc = getDerived().TransformExpr(S->getOriginalInc());
if (OgInc.isInvalid())
return StmtError();

Expand Down
32 changes: 32 additions & 0 deletions clang/test/Cilk/cilkfor-unused-expr-test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: %clang_cc1 %s -fopencilk -verify -fsyntax-only
// expected-no-diagnostics

typedef long long int64_t;

template <bool transpose>
__attribute__((always_inline)) static int64_t a_index(int64_t ii, int64_t m,
int64_t jj, int64_t n) {
return transpose ? ((jj * m) + ii) : ((ii * n) + jj);
}

#define ARG_INDEX(arg, ii, m, jj, n, transpose) \
(arg[a_index<transpose>(ii, m, jj, n)])

template <typename F>
void matmul_ploops(F *__restrict__ out, const F *__restrict__ lhs,
const F *__restrict__ rhs, int64_t m, int64_t n,
int64_t k) {
_Cilk_for(int64_t i = 0; i < m; ++i) {
_Cilk_for(int64_t j = 0; j < n; ++j) {
out[j * m + i] = 0.0;
for (int64_t l = 0; l < k; ++l)
out[j * m + i] += ARG_INDEX(lhs, l, k, i, m, true) *
ARG_INDEX(rhs, j, n, l, k, true);
}
}
}

template void matmul_ploops<float>(float *__restrict__ out,
const float *__restrict__ lhs,
const float *__restrict__ rhs, int64_t m,
int64_t n, int64_t k);

0 comments on commit 5dfee39

Please sign in to comment.