Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hvitved committed Nov 26, 2024
1 parent 5e7cd46 commit 8c11138
Showing 1 changed file with 45 additions and 38 deletions.
83 changes: 45 additions & 38 deletions rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ private predicate resolveExtendedCanonicalPath(Resolvable r, CrateOriginOption c
}

/**
* A reference contained in an object. For example a field in a struct.
* A path to a value contained in an object. For example a field name of a struct.
*/
abstract class Content extends TContent {
/** Gets a textual representation of this content. */
Expand Down Expand Up @@ -416,34 +416,34 @@ private class VariantCanonicalPath extends MkVariantCanonicalPath {
abstract class VariantContent extends Content { }

/** A tuple variant. */
private class TupleVariantContent extends VariantContent, TTupleVariantContent {
private class VariantPositionContent extends VariantContent, TVariantPositionContent {
private VariantCanonicalPath v;
private int pos_;

TupleVariantContent() { this = TTupleVariantContent(v, pos_) }
VariantPositionContent() { this = TVariantPositionContent(v, pos_) }

VariantCanonicalPath getVariantCanonicalPath(int pos) { result = v and pos = pos_ }

final override string toString() {
// only print indices when the arity is > 1
if exists(TTupleVariantContent(v, 1))
if exists(TVariantPositionContent(v, 1))
then result = v.toString() + "(" + pos_ + ")"
else result = v.toString()
}
}

/** A record variant. */
private class RecordVariantContent extends VariantContent, TRecordVariantContent {
private class VariantFieldContent extends VariantContent, TVariantFieldContent {
private VariantCanonicalPath v;
private string field_;

RecordVariantContent() { this = TRecordVariantContent(v, field_) }
VariantFieldContent() { this = TVariantFieldContent(v, field_) }

VariantCanonicalPath getVariantCanonicalPath(string field) { result = v and field = field_ }

final override string toString() {
// only print field when the arity is > 1
if strictcount(string f | exists(TRecordVariantContent(v, f))) > 1
if strictcount(string f | exists(TVariantFieldContent(v, f))) > 1
then result = v.toString() + "{" + field_ + "}"
else result = v.toString()
}
Expand All @@ -461,7 +461,7 @@ abstract class ContentSet extends TContentSet {
abstract Content getAReadContent();
}

private class SingletonContentSet extends ContentSet, TSingletonContentSet {
final private class SingletonContentSet extends ContentSet, TSingletonContentSet {
private Content c;

SingletonContentSet() { this = TSingletonContentSet(c) }
Expand Down Expand Up @@ -539,21 +539,18 @@ module RustDataFlow implements InputSig<Location> {
final class ReturnKind = ReturnKindAlias;

pragma[nomagic]
private predicate callResolveExtendedCanonicalPath(
CallExprBase call, CrateOriginOption crate, string path
) {
exists(Resolvable r | resolveExtendedCanonicalPath(r, crate, path) |
r = call.(MethodCallExpr)
or
r = call.(CallExpr).getExpr().(PathExpr).getPath()
)
private Resolvable getCallResolvable(CallExprBase call) {
result = call.(MethodCallExpr)
or
result = call.(CallExpr).getExpr().(PathExpr).getPath()
}

/** Gets a viable implementation of the target of the given `Call`. */
DataFlowCallable viableCallable(DataFlowCall call) {
exists(string path, CrateOriginOption crate |
exists(Resolvable r, string path, CrateOriginOption crate |
hasExtendedCanonicalPath(result.asCfgScope(), crate, path) and
callResolveExtendedCanonicalPath(call.asCallBaseExprCfgNode().getExpr(), crate, path)
r = getCallResolvable(call.asCallBaseExprCfgNode().getExpr()) and
resolveExtendedCanonicalPath(r, crate, path)
)
}

Expand Down Expand Up @@ -581,7 +578,7 @@ module RustDataFlow implements InputSig<Location> {

predicate forceHighPrecision(Content c) { none() }

final class ContentApprox = Content; // todo
final class ContentApprox = Content; // TODO: Implement if needed

ContentApprox getContentApprox(Content c) { result = c }

Expand Down Expand Up @@ -621,6 +618,10 @@ module RustDataFlow implements InputSig<Location> {
// TODO: Remove once library types are extracted
not p.hasQualifier() and
v = MkVariantCanonicalPath(_, "crate::std::option::Option", p.getPart().getNameRef().getText())
or
// TODO: Remove once library types are extracted
not p.hasQualifier() and
v = MkVariantCanonicalPath(_, "crate::std::result::Result", p.getPart().getNameRef().getText())
}

/** Holds if `p` destructs an enum variant `v`. */
Expand All @@ -642,22 +643,19 @@ module RustDataFlow implements InputSig<Location> {
*/
predicate readStep(Node node1, ContentSet cs, Node node2) {
exists(Content c | c = cs.(SingletonContentSet).getContent() |
node1.asPat() =
any(TupleStructPatCfgNode pat, int pos |
tupleVariantDestruction(pat.getPat(), c.(TupleVariantContent).getVariantCanonicalPath(pos)) and
node2.asPat() = pat.getField(pos)
|
pat
)
exists(TupleStructPatCfgNode pat, int pos |
pat = node1.asPat() and
tupleVariantDestruction(pat.getPat(),
c.(VariantPositionContent).getVariantCanonicalPath(pos)) and
node2.asPat() = pat.getField(pos)
)
or
node1.asPat() =
any(RecordPatCfgNode pat, string field |
recordVariantDestruction(pat.getPat(),
c.(RecordVariantContent).getVariantCanonicalPath(field)) and
node2.asPat() = pat.getFieldPat(field)
|
pat
)
exists(RecordPatCfgNode pat, string field |
pat = node1.asPat() and
recordVariantDestruction(pat.getPat(),
c.(VariantFieldContent).getVariantCanonicalPath(field)) and
node2.asPat() = pat.getFieldPat(field)
)
)
}

Expand All @@ -683,7 +681,7 @@ module RustDataFlow implements InputSig<Location> {
node2.asExpr() =
any(CallExprCfgNode call, int pos |
tupleVariantConstruction(call.getCallExpr(),
c.(TupleVariantContent).getVariantCanonicalPath(pos)) and
c.(VariantPositionContent).getVariantCanonicalPath(pos)) and
node1.asExpr() = call.getArgument(pos)
|
call
Expand All @@ -692,7 +690,7 @@ module RustDataFlow implements InputSig<Location> {
node2.asExpr() =
any(RecordExprCfgNode re, string field |
recordVariantConstruction(re.getRecordExpr(),
c.(RecordVariantContent).getVariantCanonicalPath(field)) and
c.(VariantFieldContent).getVariantCanonicalPath(field)) and
node1.asExpr() = re.getFieldExpr(field)
|
re
Expand Down Expand Up @@ -806,18 +804,27 @@ private module Cached {
crate.isNone() and
path = "crate::std::option::Option" and
name = "Some"
or
// TODO: Remove once library types are extracted
crate.isNone() and
path = "crate::std::result::Result" and
name = ["Ok", "Err"]
}

cached
newtype TContent =
TTupleVariantContent(VariantCanonicalPath v, int pos) {
TVariantPositionContent(VariantCanonicalPath v, int pos) {
pos in [0 .. v.getVariant().getFieldList().(TupleFieldList).getNumberOfFields() - 1]
or
// TODO: Remove once library types are extracted
v = MkVariantCanonicalPath(_, "crate::std::option::Option", "Some") and
pos = 0
or
// TODO: Remove once library types are extracted
v = MkVariantCanonicalPath(_, "crate::std::result::Result", ["Ok", "Err"]) and
pos = 0
} or
TRecordVariantContent(VariantCanonicalPath v, string field) {
TVariantFieldContent(VariantCanonicalPath v, string field) {
field = v.getVariant().getFieldList().(RecordFieldList).getAField().getName().getText()
}

Expand Down

0 comments on commit 8c11138

Please sign in to comment.