From 856009f142b60c929f06afbb36aa797f44ab8e54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Fri, 2 Dec 2022 10:27:28 +0100 Subject: [PATCH 1/2] Provide some convenience methods on Annotations. --- .../main/scala/tastyquery/Annotations.scala | 99 +++++++++++++++++++ .../src/main/scala/tastyquery/Symbols.scala | 9 ++ .../test/scala/tastyquery/ReadTreeSuite.scala | 14 ++- .../src/test/scala/tastyquery/TypeSuite.scala | 41 ++++++++ .../main/scala/simple_trees/Annotations.scala | 2 +- 5 files changed, 163 insertions(+), 2 deletions(-) diff --git a/tasty-query/shared/src/main/scala/tastyquery/Annotations.scala b/tasty-query/shared/src/main/scala/tastyquery/Annotations.scala index a40d5c7e..1a351374 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Annotations.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Annotations.scala @@ -1,9 +1,108 @@ package tastyquery +import scala.annotation.tailrec + +import tastyquery.Constants.* +import tastyquery.Contexts.* +import tastyquery.Exceptions.* +import tastyquery.Names.* +import tastyquery.Symbols.* import tastyquery.Trees.* +import tastyquery.Types.* object Annotations: final class Annotation(val tree: TermTree): + private var mySymbol: ClassSymbol | Null = null + private var myArguments: List[TermTree] | Null = null + + /** The annotation class symbol. */ + def symbol(using Context): ClassSymbol = + val local = mySymbol + if local != null then local + else + val computed = computeAnnotSymbol(tree) + mySymbol = computed + computed + end symbol + + /** The symbol of the constructor used in the annotation. */ + def annotConstructor(using Context): TermSymbol = + computeAnnotConstructor(tree) + + /** All the term arguments to the annotation's constructor. + * + * If the constructor has several parameter lists, the arguments are + * flattened in a single list. + * + * `NamedArg`s are not visible with this method. They are replaced by + * their right-hand-side. + */ + def arguments(using Context): List[TermTree] = + val local = myArguments + if local != null then local + else + val computed = computeAnnotArguments(tree) + myArguments = computed + computed + end arguments + + def argCount(using Context): Int = arguments.size + + def argIfConstant(idx: Int)(using Context): Option[Constant] = + arguments(idx) match + case Literal(constant) => Some(constant) + case _ => None + override def toString(): String = s"Annotation($tree)" end Annotation + + private def computeAnnotSymbol(tree: TermTree)(using Context): ClassSymbol = + def invalid(): Nothing = + throw InvalidProgramStructureException(s"Cannot find annotation class in $tree") + + @tailrec + def loop(tree: TermTree): ClassSymbol = tree match + case Apply(fun, _) => loop(fun) + case New(tpt) => tpt.toType.classSymbol.getOrElse(invalid()) + case Select(qual, _) => loop(qual) + case TypeApply(fun, _) => loop(fun) + case Block(_, expr) => loop(expr) + case _ => invalid() + + loop(tree) + end computeAnnotSymbol + + private def computeAnnotConstructor(tree: TermTree)(using Context): TermSymbol = + def invalid(): Nothing = + throw InvalidProgramStructureException(s"Cannot find annotation constructor in $tree") + + @tailrec + def loop(tree: TermTree): TermSymbol = tree match + case Apply(fun, _) => loop(fun) + case tree @ Select(New(tpt), _) => tree.tpe.asInstanceOf[TermRef].symbol + case TypeApply(fun, _) => loop(fun) + case Block(_, expr) => loop(expr) + case _ => invalid() + + loop(tree) + end computeAnnotConstructor + + private def computeAnnotArguments(tree: TermTree)(using Context): List[TermTree] = + def invalid(): Nothing = + throw InvalidProgramStructureException(s"Cannot find annotation arguments in $tree") + + @tailrec + def loop(tree: TermTree, tail: List[TermTree]): List[TermTree] = tree match + case Apply(fun, args) => loop(fun, args ::: tail) + case Select(New(tpt), _) => tail + case TypeApply(fun, _) => loop(fun, tail) + case Block(_, expr) => loop(expr, tail) + case New(tpt) => tail // for some ancient TASTy with raw New's + case _ => invalid() + + loop(tree, Nil).map { + case NamedArg(_, arg) => arg + case arg => arg + } + end computeAnnotArguments end Annotations diff --git a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala index c134e905..d8ca4eaa 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala @@ -208,6 +208,15 @@ object Symbols { case scope: ClassSymbol => scope.hasOverloads(name) case _ => false + final def hasAnnotation(annotClass: ClassSymbol)(using Context): Boolean = + annotations.exists(_.symbol == annotClass) + + final def getAnnotations(annotClass: ClassSymbol)(using Context): List[Annotation] = + annotations.filter(_.symbol == annotClass) + + final def getAnnotation(annotClass: ClassSymbol)(using Context): Option[Annotation] = + annotations.find(_.symbol == annotClass) + override def toString: String = { val kind = this match case _: PackageSymbol => "package " diff --git a/tasty-query/shared/src/test/scala/tastyquery/ReadTreeSuite.scala b/tasty-query/shared/src/test/scala/tastyquery/ReadTreeSuite.scala index 7776f401..3d2c345e 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/ReadTreeSuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/ReadTreeSuite.scala @@ -2025,6 +2025,16 @@ class ReadTreeSuite extends RestrictedUnpicklingSuite { ) => } + def deprecatedAnnotBothNamedCheck(msg: String, since: String): StructureCheck = { + case Apply( + SimpleAnnotCtorNamed("deprecated"), + List( + NamedArg(SimpleName("message"), Literal(Constant(`msg`))), + NamedArg(SimpleName("since"), Literal(Constant(`since`))) + ) + ) => + } + def implicitNotFoundAnnotCheck(msg: String): StructureCheck = { case Apply(SimpleAnnotCtorNamed("implicitNotFound"), List(Literal(Constant(`msg`)))) => } @@ -2053,7 +2063,9 @@ class ReadTreeSuite extends RestrictedUnpicklingSuite { sym } assert(clue(deprecatedValSym.annotations).sizeIs == 1) - assert(containsSubtree(deprecatedAnnotNamedCheck("reason", "forever"))(clue(deprecatedValSym.annotations(0).tree))) + assert( + containsSubtree(deprecatedAnnotBothNamedCheck("reason", "forever"))(clue(deprecatedValSym.annotations(0).tree)) + ) val myTypeClassSym = findTree(tree) { case ClassDef(TypeName(SimpleName("MyTypeClass")), _, sym) => sym diff --git a/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala b/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala index a5926196..49cddbf8 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala @@ -2,6 +2,7 @@ package tastyquery import scala.collection.mutable +import tastyquery.Constants.* import tastyquery.Contexts.* import tastyquery.Flags.* import tastyquery.Names.* @@ -1530,4 +1531,44 @@ class TypeSuite extends UnrestrictedUnpicklingSuite { assert(innerIdentSym.is(ParamAccessor)) } + testWithContext("annotations") { + val AnnotationsClass = ctx.findTopLevelClass("simple_trees.Annotations") + val inlineClass = ctx.findTopLevelClass("scala.inline") + val deprecatedClass = ctx.findTopLevelClass("scala.deprecated") + + locally { + val inlineMethodSym = AnnotationsClass.findNonOverloadedDecl(termName("inlineMethod")) + val List(inlineAnnot) = inlineMethodSym.annotations + assert(clue(inlineAnnot.symbol) == inlineClass) + assert(clue(inlineAnnot.arguments).isEmpty) + + assert(inlineMethodSym.hasAnnotation(inlineClass)) + assert(!inlineMethodSym.hasAnnotation(deprecatedClass)) + + assert(inlineMethodSym.getAnnotations(inlineClass) == List(inlineAnnot)) + assert(inlineMethodSym.getAnnotations(deprecatedClass) == Nil) + + assert(inlineMethodSym.getAnnotation(inlineClass) == Some(inlineAnnot)) + assert(inlineMethodSym.getAnnotation(deprecatedClass) == None) + } + + locally { + val deprecatedValSym = AnnotationsClass.findNonOverloadedDecl(termName("deprecatedVal")) + val List(deprecatedAnnot) = deprecatedValSym.annotations + + assert(clue(deprecatedAnnot.symbol) == deprecatedClass) + assert(clue(deprecatedAnnot.annotConstructor) == deprecatedClass.findNonOverloadedDecl(nme.Constructor)) + assert(clue(deprecatedAnnot.argCount) == 2) + + deprecatedAnnot.arguments match + case List(Literal(Constant("reason")), Literal(Constant("forever"))) => + () // OK + case args => + fail("unexpected arguments", clues(args)) + + assert(clue(deprecatedAnnot.argIfConstant(0)) == Some(Constant("reason"))) + assert(clue(deprecatedAnnot.argIfConstant(1)) == Some(Constant("forever"))) + } + } + } diff --git a/test-sources/src/main/scala/simple_trees/Annotations.scala b/test-sources/src/main/scala/simple_trees/Annotations.scala index ee8a872b..14da048c 100644 --- a/test-sources/src/main/scala/simple_trees/Annotations.scala +++ b/test-sources/src/main/scala/simple_trees/Annotations.scala @@ -10,7 +10,7 @@ class Annotations: @deprecated("some reason", since = "1.0") def inlineDeprecatedMethod(): Unit = () - @deprecated("reason", since = "forever") + @deprecated(since = "forever", message = "reason") val deprecatedVal: Int = 5 @implicitNotFound("Cannot find implicit for MyTypeClass[${T}]") From 44635cdc33ca6e73727a27cda3fd6f95cd8744d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Fri, 2 Dec 2022 11:33:07 +0100 Subject: [PATCH 2/2] Fix #128: Support `@targetName`. --- .../main/scala/tastyquery/Definitions.scala | 5 +++ .../src/main/scala/tastyquery/Symbols.scala | 29 ++++++++++++-- .../scala/tastyquery/SignatureSuite.scala | 13 +++++++ .../test/scala/tastyquery/SymbolSuite.scala | 3 +- .../src/test/scala/tastyquery/TypeSuite.scala | 38 ++++++++++++------- .../scala/simple_trees/GenericMethod.scala | 5 +++ .../scala/simple_trees/OverloadedApply.scala | 9 ++++- 7 files changed, 82 insertions(+), 20 deletions(-) diff --git a/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala b/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala index f8656780..751a02bf 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala @@ -18,6 +18,8 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS private val javaPackage = RootPackage.getPackageDeclOrCreate(nme.javaPackageName) val javaLangPackage = javaPackage.getPackageDeclOrCreate(nme.langPackageName) + private val scalaAnnotationPackage = + scalaPackage.getPackageDeclOrCreate(termName("annotation")) private val scalaCollectionPackage = scalaPackage.getPackageDeclOrCreate(termName("collection")) private val scalaCollectionImmutablePackage = @@ -199,6 +201,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS extension (pkg: PackageSymbol) private def requiredClass(name: String): ClassSymbol = pkg.getDecl(typeName(name)).get.asClass + private def optionalClass(name: String): Option[ClassSymbol] = pkg.getDecl(typeName(name)).map(_.asClass) lazy val ObjectClass = javaLangPackage.requiredClass("Object") @@ -219,6 +222,8 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS lazy val StringClass = javaLangPackage.requiredClass("String") + private[tastyquery] lazy val targetNameAnnotClass = scalaAnnotationPackage.optionalClass("targetName") + def isPrimitiveValueClass(sym: ClassSymbol): Boolean = sym == IntClass || sym == LongClass || sym == FloatClass || sym == DoubleClass || sym == BooleanClass || sym == ByteClass || sym == ShortClass || sym == CharClass || diff --git a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala index d8ca4eaa..ed0b4d90 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala @@ -306,6 +306,7 @@ object Symbols { // Cache fields private var mySignature: Option[Signature] | Null = null + private var myTargetName: TermName | Null = null private var myParamRefss: List[Either[List[TermParamRef], List[TypeParamRef]]] | Null = null protected override def doCheckCompleted(): Unit = @@ -356,11 +357,29 @@ object Symbols { mySignature = sig sig + private[tastyquery] final def targetName(using Context): TermName = + val local = myTargetName + if local != null then local + else + val computed = computeTargetName() + myTargetName = computed + computed + end targetName + + private def computeTargetName()(using Context): TermName = + if annotations.isEmpty then name + else + defn.targetNameAnnotClass match + case None => name + case Some(targetNameAnnotClass) => + getAnnotation(targetNameAnnotClass) match + case None => name + case Some(annot) => termName(annot.argIfConstant(0).get.stringValue) + end computeTargetName + /** If this symbol has a `MethodicType`, returns a `SignedName`, otherwise a `Name`. */ final def signedName(using Context): Name = signature.fold(name) { sig => - val name = this.name.asSimpleName - val targetName = name // TODO We may have to take `@targetName` into account here, one day SignedName(name, sig, targetName) } @@ -679,8 +698,10 @@ object Symbols { myDeclarations.get(overloaded.underlying) match case Some(overloads) => overloads.find { - case decl: TermSymbol => decl.signature.exists(_ == overloaded.sig) - case _ => false + case decl: TermSymbol => + decl.signature.exists(_ == overloaded.sig) && decl.targetName == overloaded.target + case _ => + false } case None => None end distinguishOverloaded diff --git a/tasty-query/shared/src/test/scala/tastyquery/SignatureSuite.scala b/tasty-query/shared/src/test/scala/tastyquery/SignatureSuite.scala index 7230025f..fa2753d9 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/SignatureSuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/SignatureSuite.scala @@ -14,10 +14,16 @@ import TestUtils.* class SignatureSuite extends UnrestrictedUnpicklingSuite: def assertIsSignedName(actual: Name, simpleName: String, signature: String)(using Location): Unit = + assertIsSignedName(actual, simpleName, signature, simpleName) + + def assertIsSignedName(actual: Name, simpleName: String, signature: String, targetName: String)( + using Location + ): Unit = actual match case name: SignedName => assert(clue(name.underlying) == clue(termName(simpleName))) assert(clue(name.sig.toString) == clue(signature)) + assert(clue(name.target) == clue(termName(targetName))) case _ => fail("not a Signed name", clues(actual)) end assertIsSignedName @@ -58,6 +64,13 @@ class SignatureSuite extends UnrestrictedUnpicklingSuite: assertIsSignedName(identity.signedName, "identity", "(1,java.lang.Object):java.lang.Object") } + testWithContext("targetName") { + val GenericMethod = ctx.findTopLevelClass("simple_trees.GenericMethod") + + val identity = GenericMethod.findNonOverloadedDecl(name"otherIdentity") + assertIsSignedName(identity.signedName, "otherIdentity", "(1,java.lang.Object):java.lang.Object", "otherName") + } + testWithContext("JavaInnerClass") { val TreeMap = ctx.findTopLevelClass("java.util.TreeMap") diff --git a/tasty-query/shared/src/test/scala/tastyquery/SymbolSuite.scala b/tasty-query/shared/src/test/scala/tastyquery/SymbolSuite.scala index 1ee4c851..9346f4af 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/SymbolSuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/SymbolSuite.scala @@ -15,7 +15,8 @@ import TestUtils.* class SymbolSuite extends RestrictedUnpicklingSuite { /** Needed for correct resolving of ctor signatures */ - val fundamentalClasses: Seq[String] = Seq("java.lang.Object", "scala.Unit", "scala.AnyVal") + val fundamentalClasses: Seq[String] = + Seq("java.lang.Object", "scala.Unit", "scala.AnyVal", "scala.annotation.targetName") def testWithContext(name: String, rootSymbolPath: String, extraRootSymbolPaths: String*)(using munit.Location)( body: Context ?=> Unit diff --git a/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala b/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala index 49cddbf8..05b1daae 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala @@ -86,12 +86,11 @@ class TypeSuite extends UnrestrictedUnpicklingSuite { assert(clue(parentClasses) == List(defn.ObjectClass, ProductClass, SerializableClass)) } - def applyOverloadedTest(name: String)(callMethod: String, paramCls: Context ?=> Symbol)(using munit.Location): Unit = + def applyOverloadedTest(name: String)(callMethod: String, checkParamType: Context ?=> Type => Boolean): Unit = testWithContext(name) { val OverloadedApplyClass = ctx.findTopLevelClass("simple_trees.OverloadedApply") val callSym = OverloadedApplyClass.findDecl(termName(callMethod)) - val Acls = paramCls val Some(callTree @ _: DefDef) = callSym.tree: @unchecked @@ -103,31 +102,44 @@ class TypeSuite extends UnrestrictedUnpicklingSuite { callCount += 1 assert(app.tpe.isRef(defn.UnitClass), clue(app)) val fooSym = fooRef.tpe.asInstanceOf[TermRef].symbol - val List(Left(List(aRef)), _*) = fooSym.paramRefss: @unchecked - assert(aRef.isRef(Acls), clues(Acls.fullName, aRef)) + val mt = fooSym.declaredType.asInstanceOf[MethodType] + assert(clue(mt.resultType).isRef(defn.UnitClass)) + assert(checkParamType(clue(mt.paramTypes.head))) case _ => () } assert(callCount == 1) } - applyOverloadedTest("apply-overloaded-int")("callA", defn.IntClass) + applyOverloadedTest("apply-overloaded-int")("callA", _.isRef(defn.IntClass)) applyOverloadedTest("apply-overloaded-gen")( "callB", - ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Box") + _.isApplied( + _.isRef(ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Box")), + List(_.isRef(ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Num"))) + ) ) applyOverloadedTest("apply-overloaded-nestedObj")( "callC", - ctx - .findTopLevelClass("simple_trees.OverloadedApply") - .findDecl(moduleClassName("Foo")) - .asClass - .findDecl(termName("Bar")) + _.isRef( + ctx + .findTopLevelClass("simple_trees.OverloadedApply") + .findDecl(moduleClassName("Foo")) + .asClass + .findDecl(termName("Bar")) + ) ) - applyOverloadedTest("apply-overloaded-arrayObj")("callD", defn.ArrayClass) + applyOverloadedTest("apply-overloaded-arrayObj")("callD", _.isRef(defn.ArrayClass)) applyOverloadedTest("apply-overloaded-byName")( "callE", - ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Num") + _.isRef(ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Num")) + ) + applyOverloadedTest("apply-overloaded-gen-target-name")( + "callG", + _.isApplied( + _.isRef(ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Box")), + List(_.isRef(defn.IntClass)) + ) ) testWithContext("apply-overloaded-not-method") { diff --git a/test-sources/src/main/scala/simple_trees/GenericMethod.scala b/test-sources/src/main/scala/simple_trees/GenericMethod.scala index c246fd25..91e96e59 100644 --- a/test-sources/src/main/scala/simple_trees/GenericMethod.scala +++ b/test-sources/src/main/scala/simple_trees/GenericMethod.scala @@ -1,8 +1,13 @@ package simple_trees +import scala.annotation.targetName + class GenericMethod { def usesTypeParam[T](): Option[T] = None def usesTermParam(i: Int): Option[Int] = None def identity[T](x: T): T = x + + @targetName("otherName") + def otherIdentity[T](x: T): T = x } diff --git a/test-sources/src/main/scala/simple_trees/OverloadedApply.scala b/test-sources/src/main/scala/simple_trees/OverloadedApply.scala index 76e5f5bb..44cfdb31 100644 --- a/test-sources/src/main/scala/simple_trees/OverloadedApply.scala +++ b/test-sources/src/main/scala/simple_trees/OverloadedApply.scala @@ -1,12 +1,14 @@ package simple_trees +import scala.annotation.targetName + class OverloadedApply { object Foo { object Bar } - class Box[T] + class Box[T](val x: T) class Num @@ -17,11 +19,14 @@ class OverloadedApply { def foo(a: => Num): Unit = () def foo: String = "foo" + @targetName("bar") def foo(a: Box[Int]): Unit = () + def callA = foo(1) - def callB = foo(Box()) + def callB = foo(Box(new Num)) def callC = foo(Foo.Bar) def callD = foo(Array(Foo)) def callE = foo(Num()) def callF = foo + def callG = foo(Box(3)) }