diff --git a/rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll b/rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll index 808b410c8d00..79b18b9cd0db 100644 --- a/rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll +++ b/rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll @@ -386,7 +386,7 @@ private predicate resolveExtendedCanonicalPath(Resolvable r, CrateOriginOption c } /** - * A reference contained in an object. For example a field in a struct. + * A path to a value contained in an object. For example a field name of a struct. */ abstract class Content extends TContent { /** Gets a textual representation of this content. */ @@ -416,34 +416,34 @@ private class VariantCanonicalPath extends MkVariantCanonicalPath { abstract class VariantContent extends Content { } /** A tuple variant. */ -private class TupleVariantContent extends VariantContent, TTupleVariantContent { +private class VariantPositionContent extends VariantContent, TVariantPositionContent { private VariantCanonicalPath v; private int pos_; - TupleVariantContent() { this = TTupleVariantContent(v, pos_) } + VariantPositionContent() { this = TVariantPositionContent(v, pos_) } VariantCanonicalPath getVariantCanonicalPath(int pos) { result = v and pos = pos_ } final override string toString() { // only print indices when the arity is > 1 - if exists(TTupleVariantContent(v, 1)) + if exists(TVariantPositionContent(v, 1)) then result = v.toString() + "(" + pos_ + ")" else result = v.toString() } } /** A record variant. */ -private class RecordVariantContent extends VariantContent, TRecordVariantContent { +private class VariantFieldContent extends VariantContent, TVariantFieldContent { private VariantCanonicalPath v; private string field_; - RecordVariantContent() { this = TRecordVariantContent(v, field_) } + VariantFieldContent() { this = TVariantFieldContent(v, field_) } VariantCanonicalPath getVariantCanonicalPath(string field) { result = v and field = field_ } final override string toString() { // only print field when the arity is > 1 - if strictcount(string f | exists(TRecordVariantContent(v, f))) > 1 + if strictcount(string f | exists(TVariantFieldContent(v, f))) > 1 then result = v.toString() + "{" + field_ + "}" else result = v.toString() } @@ -461,7 +461,7 @@ abstract class ContentSet extends TContentSet { abstract Content getAReadContent(); } -private class SingletonContentSet extends ContentSet, TSingletonContentSet { +final private class SingletonContentSet extends ContentSet, TSingletonContentSet { private Content c; SingletonContentSet() { this = TSingletonContentSet(c) } @@ -539,21 +539,18 @@ module RustDataFlow implements InputSig { final class ReturnKind = ReturnKindAlias; pragma[nomagic] - private predicate callResolveExtendedCanonicalPath( - CallExprBase call, CrateOriginOption crate, string path - ) { - exists(Resolvable r | resolveExtendedCanonicalPath(r, crate, path) | - r = call.(MethodCallExpr) - or - r = call.(CallExpr).getExpr().(PathExpr).getPath() - ) + private Resolvable getCallResolvable(CallExprBase call) { + result = call.(MethodCallExpr) + or + result = call.(CallExpr).getExpr().(PathExpr).getPath() } /** Gets a viable implementation of the target of the given `Call`. */ DataFlowCallable viableCallable(DataFlowCall call) { - exists(string path, CrateOriginOption crate | + exists(Resolvable r, string path, CrateOriginOption crate | hasExtendedCanonicalPath(result.asCfgScope(), crate, path) and - callResolveExtendedCanonicalPath(call.asCallBaseExprCfgNode().getExpr(), crate, path) + r = getCallResolvable(call.asCallBaseExprCfgNode().getExpr()) and + resolveExtendedCanonicalPath(r, crate, path) ) } @@ -581,7 +578,7 @@ module RustDataFlow implements InputSig { predicate forceHighPrecision(Content c) { none() } - final class ContentApprox = Content; // todo + final class ContentApprox = Content; // TODO: Implement if needed ContentApprox getContentApprox(Content c) { result = c } @@ -621,6 +618,10 @@ module RustDataFlow implements InputSig { // TODO: Remove once library types are extracted not p.hasQualifier() and v = MkVariantCanonicalPath(_, "crate::std::option::Option", p.getPart().getNameRef().getText()) + or + // TODO: Remove once library types are extracted + not p.hasQualifier() and + v = MkVariantCanonicalPath(_, "crate::std::result::Result", p.getPart().getNameRef().getText()) } /** Holds if `p` destructs an enum variant `v`. */ @@ -642,22 +643,19 @@ module RustDataFlow implements InputSig { */ predicate readStep(Node node1, ContentSet cs, Node node2) { exists(Content c | c = cs.(SingletonContentSet).getContent() | - node1.asPat() = - any(TupleStructPatCfgNode pat, int pos | - tupleVariantDestruction(pat.getPat(), c.(TupleVariantContent).getVariantCanonicalPath(pos)) and - node2.asPat() = pat.getField(pos) - | - pat - ) + exists(TupleStructPatCfgNode pat, int pos | + pat = node1.asPat() and + tupleVariantDestruction(pat.getPat(), + c.(VariantPositionContent).getVariantCanonicalPath(pos)) and + node2.asPat() = pat.getField(pos) + ) or - node1.asPat() = - any(RecordPatCfgNode pat, string field | - recordVariantDestruction(pat.getPat(), - c.(RecordVariantContent).getVariantCanonicalPath(field)) and - node2.asPat() = pat.getFieldPat(field) - | - pat - ) + exists(RecordPatCfgNode pat, string field | + pat = node1.asPat() and + recordVariantDestruction(pat.getPat(), + c.(VariantFieldContent).getVariantCanonicalPath(field)) and + node2.asPat() = pat.getFieldPat(field) + ) ) } @@ -683,7 +681,7 @@ module RustDataFlow implements InputSig { node2.asExpr() = any(CallExprCfgNode call, int pos | tupleVariantConstruction(call.getCallExpr(), - c.(TupleVariantContent).getVariantCanonicalPath(pos)) and + c.(VariantPositionContent).getVariantCanonicalPath(pos)) and node1.asExpr() = call.getArgument(pos) | call @@ -692,7 +690,7 @@ module RustDataFlow implements InputSig { node2.asExpr() = any(RecordExprCfgNode re, string field | recordVariantConstruction(re.getRecordExpr(), - c.(RecordVariantContent).getVariantCanonicalPath(field)) and + c.(VariantFieldContent).getVariantCanonicalPath(field)) and node1.asExpr() = re.getFieldExpr(field) | re @@ -806,18 +804,27 @@ private module Cached { crate.isNone() and path = "crate::std::option::Option" and name = "Some" + or + // TODO: Remove once library types are extracted + crate.isNone() and + path = "crate::std::result::Result" and + name = ["Ok", "Err"] } cached newtype TContent = - TTupleVariantContent(VariantCanonicalPath v, int pos) { + TVariantPositionContent(VariantCanonicalPath v, int pos) { pos in [0 .. v.getVariant().getFieldList().(TupleFieldList).getNumberOfFields() - 1] or // TODO: Remove once library types are extracted v = MkVariantCanonicalPath(_, "crate::std::option::Option", "Some") and pos = 0 + or + // TODO: Remove once library types are extracted + v = MkVariantCanonicalPath(_, "crate::std::result::Result", ["Ok", "Err"]) and + pos = 0 } or - TRecordVariantContent(VariantCanonicalPath v, string field) { + TVariantFieldContent(VariantCanonicalPath v, string field) { field = v.getVariant().getFieldList().(RecordFieldList).getAField().getName().getText() }