Skip to content

Commit

Permalink
Merge pull request #15958 from MathiasVP/ir-guards-from-switch-statem…
Browse files Browse the repository at this point in the history
…ents-2

C++: Implement guards logic for switch statements
  • Loading branch information
MathiasVP authored Mar 19, 2024
2 parents df18453 + d7afd7b commit 597f008
Show file tree
Hide file tree
Showing 8 changed files with 411 additions and 66 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
category: feature
---
* Added a predicate `GuardCondition.comparesEq/4` to query whether an expression is compared to a constant.
* Added a predicate `GuardCondition.ensuresEq/4` to query whether a basic block is guarded by an expression being equal to a constant.
229 changes: 199 additions & 30 deletions cpp/ql/lib/semmle/code/cpp/controlflow/IRGuards.qll
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,17 @@ class GuardCondition extends Expr {
*/
cached
predicate ensuresEq(Expr left, Expr right, int k, BasicBlock block, boolean areEqual) { none() }

/** Holds if (determined by this guard) `e == k` evaluates to `areEqual` if this expression evaluates to `testIsTrue`. */
cached
predicate comparesEq(Expr e, int k, boolean areEqual, boolean testIsTrue) { none() }

/**
* Holds if (determined by this guard) `e == k` must be `areEqual` in `block`.
* If `areEqual = false` then this implies `e != k`.
*/
cached
predicate ensuresEq(Expr e, int k, BasicBlock block, boolean areEqual) { none() }
}

/**
Expand Down Expand Up @@ -184,6 +195,20 @@ private class GuardConditionFromBinaryLogicalOperator extends GuardCondition {
this.comparesEq(left, right, k, areEqual, testIsTrue) and this.controls(block, testIsTrue)
)
}

override predicate comparesEq(Expr e, int k, boolean areEqual, boolean testIsTrue) {
exists(boolean partIsTrue, GuardCondition part |
this.(BinaryLogicalOperation).impliesValue(part, partIsTrue, testIsTrue)
|
part.comparesEq(e, k, areEqual, partIsTrue)
)
}

override predicate ensuresEq(Expr e, int k, BasicBlock block, boolean areEqual) {
exists(boolean testIsTrue |
this.comparesEq(e, k, areEqual, testIsTrue) and this.controls(block, testIsTrue)
)
}
}

/**
Expand Down Expand Up @@ -245,6 +270,21 @@ private class GuardConditionFromIR extends GuardCondition {
)
}

override predicate comparesEq(Expr e, int k, boolean areEqual, boolean testIsTrue) {
exists(Instruction i |
i.getUnconvertedResultExpression() = e and
ir.comparesEq(i.getAUse(), k, areEqual, testIsTrue)
)
}

override predicate ensuresEq(Expr e, int k, BasicBlock block, boolean areEqual) {
exists(Instruction i, boolean testIsTrue |
i.getUnconvertedResultExpression() = e and
ir.comparesEq(i.getAUse(), k, areEqual, testIsTrue) and
this.controls(block, testIsTrue)
)
}

/**
* Holds if this condition controls `block`, meaning that `block` is only
* entered if the value of this condition is `v`. This helper
Expand Down Expand Up @@ -446,7 +486,25 @@ class IRGuardCondition extends Instruction {
/** Holds if (determined by this guard) `left == right + k` evaluates to `areEqual` if this expression evaluates to `testIsTrue`. */
cached
predicate comparesEq(Operand left, Operand right, int k, boolean areEqual, boolean testIsTrue) {
compares_eq(this, left, right, k, areEqual, testIsTrue)
exists(BooleanValue value |
compares_eq(this, left, right, k, areEqual, value) and
value.getValue() = testIsTrue
)
}

/** Holds if (determined by this guard) `op == k` evaluates to `areEqual` if this expression evaluates to `testIsTrue`. */
cached
predicate comparesEq(Operand op, int k, boolean areEqual, boolean testIsTrue) {
exists(MatchValue mv |
compares_eq(this, op, k, areEqual, mv) and
// A match value cannot be dualized, so `testIsTrue` is always true
testIsTrue = true
)
or
exists(BooleanValue bv |
compares_eq(this, op, k, areEqual, bv) and
bv.getValue() = testIsTrue
)
}

/**
Expand All @@ -455,8 +513,19 @@ class IRGuardCondition extends Instruction {
*/
cached
predicate ensuresEq(Operand left, Operand right, int k, IRBlock block, boolean areEqual) {
exists(boolean testIsTrue |
compares_eq(this, left, right, k, areEqual, testIsTrue) and this.controls(block, testIsTrue)
exists(AbstractValue value |
compares_eq(this, left, right, k, areEqual, value) and this.valueControls(block, value)
)
}

/**
* Holds if (determined by this guard) `op == k` must be `areEqual` in `block`.
* If `areEqual = false` then this implies `op != k`.
*/
cached
predicate ensuresEq(Operand op, int k, IRBlock block, boolean areEqual) {
exists(AbstractValue value |
compares_eq(this, op, k, areEqual, value) and this.valueControls(block, value)
)
}

Expand All @@ -468,9 +537,21 @@ class IRGuardCondition extends Instruction {
predicate ensuresEqEdge(
Operand left, Operand right, int k, IRBlock pred, IRBlock succ, boolean areEqual
) {
exists(boolean testIsTrue |
compares_eq(this, left, right, k, areEqual, testIsTrue) and
this.controlsEdge(pred, succ, testIsTrue)
exists(AbstractValue value |
compares_eq(this, left, right, k, areEqual, value) and
this.valueControlsEdge(pred, succ, value)
)
}

/**
* Holds if (determined by this guard) `op == k` must be `areEqual` on the edge from
* `pred` to `succ`. If `areEqual = false` then this implies `op != k`.
*/
cached
predicate ensuresEqEdge(Operand op, int k, IRBlock pred, IRBlock succ, boolean areEqual) {
exists(AbstractValue value |
compares_eq(this, op, k, areEqual, value) and
this.valueControlsEdge(pred, succ, value)
)
}

Expand Down Expand Up @@ -572,52 +653,98 @@ private Instruction getBranchForCondition(Instruction guard) {
* Beware making mistaken logical implications here relating `areEqual` and `testIsTrue`.
*/
private predicate compares_eq(
Instruction test, Operand left, Operand right, int k, boolean areEqual, boolean testIsTrue
Instruction test, Operand left, Operand right, int k, boolean areEqual, AbstractValue value
) {
/* The simple case where the test *is* the comparison so areEqual = testIsTrue xor eq. */
exists(boolean eq | simple_comparison_eq(test, left, right, k, eq) |
areEqual = true and testIsTrue = eq
exists(AbstractValue v | simple_comparison_eq(test, left, right, k, v) |
areEqual = true and value = v
or
areEqual = false and testIsTrue = eq.booleanNot()
areEqual = false and value = v.getDualValue()
)
or
// I think this is handled by forwarding in controlsBlock.
//or
//logical_comparison_eq(test, left, right, k, areEqual, testIsTrue)
/* a == b + k => b == a - k */
exists(int mk | k = -mk | compares_eq(test, right, left, mk, areEqual, testIsTrue))
exists(int mk | k = -mk | compares_eq(test, right, left, mk, areEqual, value))
or
complex_eq(test, left, right, k, areEqual, testIsTrue)
complex_eq(test, left, right, k, areEqual, value)
or
/* (x is true => (left == right + k)) => (!x is false => (left == right + k)) */
exists(boolean isFalse | testIsTrue = isFalse.booleanNot() |
compares_eq(test.(LogicalNotInstruction).getUnary(), left, right, k, areEqual, isFalse)
exists(AbstractValue dual | value = dual.getDualValue() |
compares_eq(test.(LogicalNotInstruction).getUnary(), left, right, k, areEqual, dual)
)
}

/** Holds if `op == k` is `areEqual` given that `test` is equal to `value`. */
private predicate compares_eq(
Instruction test, Operand op, int k, boolean areEqual, AbstractValue value
) {
/* The simple case where the test *is* the comparison so areEqual = testIsTrue xor eq. */
exists(AbstractValue v | simple_comparison_eq(test, op, k, v) |
areEqual = true and value = v
or
areEqual = false and value = v.getDualValue()
)
or
complex_eq(test, op, k, areEqual, value)
or
/* (x is true => (op == k)) => (!x is false => (op == k)) */
exists(AbstractValue dual | value = dual.getDualValue() |
compares_eq(test.(LogicalNotInstruction).getUnary(), op, k, areEqual, dual)
)
or
// ((test is `areEqual` => op == const + k2) and const == `k1`) =>
// test is `areEqual` => op == k1 + k2
exists(int k1, int k2, ConstantInstruction const |
compares_eq(test, op, const.getAUse(), k2, areEqual, value) and
int_value(const) = k1 and
k = k1 + k2
)
}

/** Rearrange various simple comparisons into `left == right + k` form. */
private predicate simple_comparison_eq(
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual
CompareInstruction cmp, Operand left, Operand right, int k, AbstractValue value
) {
left = cmp.getLeftOperand() and
cmp instanceof CompareEQInstruction and
right = cmp.getRightOperand() and
k = 0 and
areEqual = true
value.(BooleanValue).getValue() = true
or
left = cmp.getLeftOperand() and
cmp instanceof CompareNEInstruction and
right = cmp.getRightOperand() and
k = 0 and
areEqual = false
value.(BooleanValue).getValue() = false
}

/** Rearrange various simple comparisons into `op == k` form. */
private predicate simple_comparison_eq(Instruction test, Operand op, int k, AbstractValue value) {
exists(SwitchInstruction switch, CaseEdge case |
test = switch.getExpression() and
op.getDef() = test and
case = value.(MatchValue).getCase() and
exists(switch.getSuccessor(case)) and
case.getValue().toInt() = k
)
}

private predicate complex_eq(
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, AbstractValue value
) {
sub_eq(cmp, left, right, k, areEqual, value)
or
add_eq(cmp, left, right, k, areEqual, value)
}

private predicate complex_eq(
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, boolean testIsTrue
Instruction test, Operand op, int k, boolean areEqual, AbstractValue value
) {
sub_eq(cmp, left, right, k, areEqual, testIsTrue)
sub_eq(test, op, k, areEqual, value)
or
add_eq(cmp, left, right, k, areEqual, testIsTrue)
add_eq(test, op, k, areEqual, value)
}

/*
Expand Down Expand Up @@ -768,44 +895,61 @@ private predicate add_lt(
// left - x == right + c => left == right + (c+x)
// left == (right - x) + c => left == right + (c-x)
private predicate sub_eq(
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, boolean testIsTrue
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, AbstractValue value
) {
exists(SubInstruction lhs, int c, int x |
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, testIsTrue) and
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, value) and
left = lhs.getLeftOperand() and
x = int_value(lhs.getRight()) and
k = c + x
)
or
exists(SubInstruction rhs, int c, int x |
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, testIsTrue) and
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, value) and
right = rhs.getLeftOperand() and
x = int_value(rhs.getRight()) and
k = c - x
)
or
exists(PointerSubInstruction lhs, int c, int x |
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, testIsTrue) and
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, value) and
left = lhs.getLeftOperand() and
x = int_value(lhs.getRight()) and
k = c + x
)
or
exists(PointerSubInstruction rhs, int c, int x |
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, testIsTrue) and
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, value) and
right = rhs.getLeftOperand() and
x = int_value(rhs.getRight()) and
k = c - x
)
}

// op - x == c => op == (c+x)
private predicate sub_eq(Instruction test, Operand op, int k, boolean areEqual, AbstractValue value) {
exists(SubInstruction sub, int c, int x |
compares_eq(test, sub.getAUse(), c, areEqual, value) and
op = sub.getLeftOperand() and
x = int_value(sub.getRight()) and
k = c + x
)
or
exists(PointerSubInstruction sub, int c, int x |
compares_eq(test, sub.getAUse(), c, areEqual, value) and
op = sub.getLeftOperand() and
x = int_value(sub.getRight()) and
k = c + x
)
}

// left + x == right + c => left == right + (c-x)
// left == (right + x) + c => left == right + (c+x)
private predicate add_eq(
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, boolean testIsTrue
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, AbstractValue value
) {
exists(AddInstruction lhs, int c, int x |
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, testIsTrue) and
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, value) and
(
left = lhs.getLeftOperand() and x = int_value(lhs.getRight())
or
Expand All @@ -815,7 +959,7 @@ private predicate add_eq(
)
or
exists(AddInstruction rhs, int c, int x |
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, testIsTrue) and
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, value) and
(
right = rhs.getLeftOperand() and x = int_value(rhs.getRight())
or
Expand All @@ -825,7 +969,7 @@ private predicate add_eq(
)
or
exists(PointerAddInstruction lhs, int c, int x |
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, testIsTrue) and
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, value) and
(
left = lhs.getLeftOperand() and x = int_value(lhs.getRight())
or
Expand All @@ -835,7 +979,7 @@ private predicate add_eq(
)
or
exists(PointerAddInstruction rhs, int c, int x |
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, testIsTrue) and
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, value) and
(
right = rhs.getLeftOperand() and x = int_value(rhs.getRight())
or
Expand All @@ -845,5 +989,30 @@ private predicate add_eq(
)
}

// left + x == right + c => left == right + (c-x)
private predicate add_eq(
Instruction test, Operand left, int k, boolean areEqual, AbstractValue value
) {
exists(AddInstruction lhs, int c, int x |
compares_eq(test, lhs.getAUse(), c, areEqual, value) and
(
left = lhs.getLeftOperand() and x = int_value(lhs.getRight())
or
left = lhs.getRightOperand() and x = int_value(lhs.getLeft())
) and
k = c - x
)
or
exists(PointerAddInstruction lhs, int c, int x |
compares_eq(test, lhs.getAUse(), c, areEqual, value) and
(
left = lhs.getLeftOperand() and x = int_value(lhs.getRight())
or
left = lhs.getRightOperand() and x = int_value(lhs.getLeft())
) and
k = c - x
)
}

/** The int value of integer constant expression. */
private int int_value(Instruction i) { result = i.(IntegerConstantInstruction).getValue().toInt() }
9 changes: 9 additions & 0 deletions cpp/ql/lib/semmle/code/cpp/ir/implementation/EdgeKind.qll
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ class CaseEdge extends EdgeKind, TCaseEdge {
* Gets the largest value of the switch expression for which control will flow along this edge.
*/
final string getMaxValue() { result = maxValue }

/**
* Gets the unique value of the switch expression for which control will
* flow along this edge, if any.
*/
final string getValue() {
minValue = maxValue and
result = minValue
}
}

/**
Expand Down
Loading

0 comments on commit 597f008

Please sign in to comment.