diff --git a/go/ql/lib/semmle/go/Decls.qll b/go/ql/lib/semmle/go/Decls.qll index 8e3df22cc800..6c66b085575b 100644 --- a/go/ql/lib/semmle/go/Decls.qll +++ b/go/ql/lib/semmle/go/Decls.qll @@ -212,10 +212,7 @@ class MethodDecl extends FuncDecl { * * is `Rectangle`. */ - NamedType getReceiverBaseType() { - result = this.getReceiverType() or - result = this.getReceiverType().(PointerType).getBaseType() - } + NamedType getReceiverBaseType() { result = lookThroughPointerType(this.getReceiverType()) } /** * Gets the receiver variable of this method. diff --git a/go/ql/lib/semmle/go/Scopes.qll b/go/ql/lib/semmle/go/Scopes.qll index 191534759ea6..f9b9e3a26b9a 100644 --- a/go/ql/lib/semmle/go/Scopes.qll +++ b/go/ql/lib/semmle/go/Scopes.qll @@ -519,13 +519,7 @@ class Method extends Function { * Gets the receiver base type of this method, that is, either the base type of the receiver type * if it is a pointer type, or the receiver type itself if it is not a pointer type. */ - Type getReceiverBaseType() { - exists(Type recv | recv = this.getReceiverType() | - if recv instanceof PointerType - then result = recv.(PointerType).getBaseType() - else result = recv - ) - } + Type getReceiverBaseType() { result = lookThroughPointerType(this.getReceiverType()) } /** Holds if this method has name `m` and belongs to the method set of type `tp` or `*tp`. */ private predicate isIn(NamedType tp, string m) { diff --git a/go/ql/lib/semmle/go/Types.qll b/go/ql/lib/semmle/go/Types.qll index 4818db2f774d..1b09ea466cc4 100644 --- a/go/ql/lib/semmle/go/Types.qll +++ b/go/ql/lib/semmle/go/Types.qll @@ -446,11 +446,7 @@ class StructType extends @structtype, CompositeType { if n = "" then ( isEmbedded = true and - ( - name = tp.(NamedType).getName() - or - name = tp.(PointerType).getBaseType().(NamedType).getName() - ) + name = lookThroughPointerType(tp).(NamedType).getName() ) else ( isEmbedded = false and name = n @@ -518,9 +514,7 @@ class StructType extends @structtype, CompositeType { this.hasFieldCand(_, embeddedParent, depth - 1, true) and result.getName() = name and ( - result.getReceiverBaseType() = embeddedParent.getType() - or - result.getReceiverBaseType() = embeddedParent.getType().(PointerType).getBaseType() + result.getReceiverBaseType() = lookThroughPointerType(embeddedParent.getType()) or methodhosts(result, embeddedParent.getType()) ) @@ -644,6 +638,16 @@ class PointerType extends @pointertype, CompositeType { override string toString() { result = "pointer type" } } +/** + * Gets the base type if `t` is a pointer type, otherwise `t` itself. + */ +Type lookThroughPointerType(Type t) { + not t instanceof PointerType and + result = t + or + result = t.(PointerType).getBaseType() +} + private newtype TTypeSetTerm = MkTypeSetTerm(TypeSetLiteralType tslit, int index) { component_types(tslit, index, _, _) } diff --git a/go/ql/lib/semmle/go/controlflow/IR.qll b/go/ql/lib/semmle/go/controlflow/IR.qll index b036ddf6d0f5..addd85a36c4c 100644 --- a/go/ql/lib/semmle/go/controlflow/IR.qll +++ b/go/ql/lib/semmle/go/controlflow/IR.qll @@ -358,11 +358,7 @@ module IR { override predicate reads(ValueEntity v) { v = field } - override Type getResultType() { - if field.getType() instanceof PointerType - then result = field.getType().(PointerType).getBaseType() - else result = field.getType() - } + override Type getResultType() { result = lookThroughPointerType(field.getType()) } override ControlFlow::Root getRoot() { result.isRootOf(e) } diff --git a/go/ql/src/InconsistentCode/LengthComparisonOffByOne.ql b/go/ql/src/InconsistentCode/LengthComparisonOffByOne.ql index 39c951f0150b..fbb1965c5f62 100644 --- a/go/ql/src/InconsistentCode/LengthComparisonOffByOne.ql +++ b/go/ql/src/InconsistentCode/LengthComparisonOffByOne.ql @@ -73,7 +73,7 @@ predicate isRegexpMethodCall(DataFlow::MethodCallNode c) { exists(NamedType regexp, Type recvtp | regexp.getName() = "Regexp" and recvtp = c.getReceiver().getType() | - recvtp = regexp or recvtp.(PointerType).getBaseType() = regexp + lookThroughPointerType(recvtp) = regexp ) }