From 8b261844d80f5b56b90d23b2dc50476d1ed48863 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Tue, 7 May 2024 10:24:32 +0200 Subject: [PATCH] Improve mapping and pickling of annotated types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Annotation.mapWith` maps an `Annotation` with a type map `tm`. As an optimization, this function first checks if `tm` would result in any change (by traversing the annotation’s argument trees with a `TreeAccumulator`) before applying `tm` to the whole annotation tree. This optimization had two problems: 1. it didn’t include type parameters, and 2. it used `frozen_=:=` to compare types, which didn’t work as expected with `NoType`. This commit fixes these issues. Additionally, positions of trees that appear only inside `AnnotatedType` were not pickled. This commit also fixes this. [Cherry-picked ac76938c1e596efcaffa0f5d64e224bdc8be72f0] --- .../src/dotty/tools/dotc/ast/TreeInfo.scala | 10 +++++- .../dotty/tools/dotc/core/Annotations.scala | 11 ++++--- .../dotc/core/tasty/PositionPickler.scala | 4 +++ .../tools/dotc/core/tasty/TreePickler.scala | 7 ++++ .../tools/dotc/quoted/PickledQuotes.scala | 2 +- .../dotty/tools/dotc/transform/Pickler.scala | 2 +- tests/pos/annot-17939b.scala | 10 ++++++ tests/pos/annot-18064.scala | 9 +++++ tests/pos/annot-5789.scala | 10 ++++++ tests/printing/annot-18064.check | 16 +++++++++ tests/printing/annot-18064.scala | 9 +++++ tests/printing/annot-19846b.check | 33 +++++++++++++++++++ tests/printing/annot-19846b.scala | 7 ++++ 13 files changed, 123 insertions(+), 7 deletions(-) create mode 100644 tests/pos/annot-17939b.scala create mode 100644 tests/pos/annot-18064.scala create mode 100644 tests/pos/annot-5789.scala create mode 100644 tests/printing/annot-18064.check create mode 100644 tests/printing/annot-18064.scala create mode 100644 tests/printing/annot-19846b.check create mode 100644 tests/printing/annot-19846b.scala diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index bf5b83507401..465fa2b765f8 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -135,9 +135,17 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] => loop(tree, Nil) /** All term arguments of an application in a single flattened list */ + def allTermArguments(tree: Tree): List[Tree] = unsplice(tree) match { + case Apply(fn, args) => allArguments(fn) ::: args + case TypeApply(fn, args) => allArguments(fn) + case Block(_, expr) => allArguments(expr) + case _ => Nil + } + + /** All type and term arguments of an application in a single flattened list */ def allArguments(tree: Tree): List[Tree] = unsplice(tree) match { case Apply(fn, args) => allArguments(fn) ::: args - case TypeApply(fn, _) => allArguments(fn) + case TypeApply(fn, args) => allArguments(fn) ::: args case Block(_, expr) => allArguments(expr) case _ => Nil } diff --git a/compiler/src/dotty/tools/dotc/core/Annotations.scala b/compiler/src/dotty/tools/dotc/core/Annotations.scala index 137c084d03a6..8f9f37530dcb 100644 --- a/compiler/src/dotty/tools/dotc/core/Annotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Annotations.scala @@ -30,8 +30,8 @@ object Annotations { def derivedAnnotation(tree: Tree)(using Context): Annotation = if (tree eq this.tree) this else Annotation(tree) - /** All arguments to this annotation in a single flat list */ - def arguments(using Context): List[Tree] = tpd.allArguments(tree) + /** All term arguments of this annotation in a single flat list */ + def arguments(using Context): List[Tree] = tpd.allTermArguments(tree) def argument(i: Int)(using Context): Option[Tree] = { val args = arguments @@ -54,15 +54,18 @@ object Annotations { * type, since ranges cannot be types of trees. */ def mapWith(tm: TypeMap)(using Context) = - val args = arguments + val args = tpd.allArguments(tree) if args.isEmpty then this else + // Checks if `tm` would result in any change by applying it to types + // inside the annotations' arguments and checking if the resulting types + // are different. val findDiff = new TreeAccumulator[Type]: def apply(x: Type, tree: Tree)(using Context): Type = if tm.isRange(x) then x else val tp1 = tm(tree.tpe) - foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree) + foldOver(if !tp1.exists || (tp1 frozen_=:= tree.tpe) then x else tp1, tree) val diff = findDiff(NoType, args) if tm.isRange(diff) then EmptyAnnotation else if diff.exists then derivedAnnotation(tm.mapOver(tree)) diff --git a/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala index 86076517021a..3d8080e72a29 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala @@ -33,6 +33,7 @@ object PositionPickler: pickler: TastyPickler, addrOfTree: TreeToAddr, treeAnnots: untpd.MemberDef => List[tpd.Tree], + typeAnnots: List[tpd.Tree], relativePathReference: String, source: SourceFile, roots: List[Tree], @@ -136,6 +137,9 @@ object PositionPickler: } for (root <- roots) traverse(root, NoSource) + + for annotTree <- typeAnnots do + traverse(annotTree, NoSource) end picklePositions end PositionPickler diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala index 4a6904c22f8a..75fcaaf544a3 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala @@ -39,6 +39,10 @@ class TreePickler(pickler: TastyPickler) { */ private val annotTrees = util.EqHashMap[untpd.MemberDef, mutable.ListBuffer[Tree]]() + /** A set of annotation trees appearing in annotated types. + */ + private val annotatedTypeTrees = mutable.ListBuffer[Tree]() + /** A map from member definitions to their doc comments, so that later * parallel comment pickling does not need to access symbols of trees (which * would involve accessing symbols of named types and possibly changing phases @@ -52,6 +56,8 @@ class TreePickler(pickler: TastyPickler) { val ts = annotTrees.lookup(tree) if ts == null then Nil else ts.toList + def typeAnnots: List[Tree] = annotatedTypeTrees.toList + def docString(tree: untpd.MemberDef): Option[Comment] = Option(docStrings.lookup(tree)) @@ -262,6 +268,7 @@ class TreePickler(pickler: TastyPickler) { case tpe: AnnotatedType => writeByte(ANNOTATEDtype) withLength { pickleType(tpe.parent, richTypes); pickleTree(tpe.annot.tree) } + annotatedTypeTrees += tpe.annot.tree case tpe: AndType => writeByte(ANDtype) withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) } diff --git a/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala b/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala index a9b66fc056e2..45712908ce0e 100644 --- a/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala +++ b/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala @@ -223,7 +223,7 @@ object PickledQuotes { if tree.span.exists then val positionWarnings = new mutable.ListBuffer[Message]() val reference = ctx.settings.sourceroot.value - PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference, + PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference, ctx.compilationUnit.source, tree :: Nil, positionWarnings) positionWarnings.foreach(report.warning(_)) diff --git a/compiler/src/dotty/tools/dotc/transform/Pickler.scala b/compiler/src/dotty/tools/dotc/transform/Pickler.scala index 5ab289e510aa..61781dde13db 100644 --- a/compiler/src/dotty/tools/dotc/transform/Pickler.scala +++ b/compiler/src/dotty/tools/dotc/transform/Pickler.scala @@ -99,7 +99,7 @@ class Pickler extends Phase { if tree.span.exists then val reference = ctx.settings.sourceroot.value PositionPickler.picklePositions( - pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference, + pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference, unit.source, tree :: Nil, positionWarnings, scratch.positionBuffer, scratch.pickledIndices) diff --git a/tests/pos/annot-17939b.scala b/tests/pos/annot-17939b.scala new file mode 100644 index 000000000000..a48f4690d0b2 --- /dev/null +++ b/tests/pos/annot-17939b.scala @@ -0,0 +1,10 @@ +import scala.annotation.Annotation +class myRefined(f: ? => Boolean) extends Annotation + +def test(axes: Int) = true + +trait Tensor: + def mean(axes: Int): Int @myRefined(_ => test(axes)) + +class TensorImpl() extends Tensor: + def mean(axes: Int) = ??? diff --git a/tests/pos/annot-18064.scala b/tests/pos/annot-18064.scala new file mode 100644 index 000000000000..b6a67ea9ebe7 --- /dev/null +++ b/tests/pos/annot-18064.scala @@ -0,0 +1,9 @@ +//> using options "-Xprint:typer" + +class myAnnot[T]() extends annotation.Annotation + +trait Tensor[T]: + def add: Tensor[T] @myAnnot[T]() + +class TensorImpl[A]() extends Tensor[A]: + def add /* : Tensor[A] @myAnnot[A] */ = this diff --git a/tests/pos/annot-5789.scala b/tests/pos/annot-5789.scala new file mode 100644 index 000000000000..bdf4438c9d5d --- /dev/null +++ b/tests/pos/annot-5789.scala @@ -0,0 +1,10 @@ +class Annot[T] extends scala.annotation.Annotation + +class D[T](val f: Int@Annot[T]) + +object A{ + def main(a:Array[String]) = { + val c = new D[Int](1) + c.f + } +} diff --git a/tests/printing/annot-18064.check b/tests/printing/annot-18064.check new file mode 100644 index 000000000000..d93ddb95afee --- /dev/null +++ b/tests/printing/annot-18064.check @@ -0,0 +1,16 @@ +[[syntax trees at end of typer]] // tests/printing/annot-18064.scala +package { + class myAnnot[T >: Nothing <: Any]() extends annotation.Annotation() { + T + } + trait Tensor[T >: Nothing <: Any]() extends Object { + T + def add: Tensor[Tensor.this.T] @myAnnot[T] + } + class TensorImpl[A >: Nothing <: Any]() extends Object(), Tensor[ + TensorImpl.this.A] { + A + def add: Tensor[A] @myAnnot[A] = this + } +} + diff --git a/tests/printing/annot-18064.scala b/tests/printing/annot-18064.scala new file mode 100644 index 000000000000..b6a67ea9ebe7 --- /dev/null +++ b/tests/printing/annot-18064.scala @@ -0,0 +1,9 @@ +//> using options "-Xprint:typer" + +class myAnnot[T]() extends annotation.Annotation + +trait Tensor[T]: + def add: Tensor[T] @myAnnot[T]() + +class TensorImpl[A]() extends Tensor[A]: + def add /* : Tensor[A] @myAnnot[A] */ = this diff --git a/tests/printing/annot-19846b.check b/tests/printing/annot-19846b.check new file mode 100644 index 000000000000..3f63a46c4286 --- /dev/null +++ b/tests/printing/annot-19846b.check @@ -0,0 +1,33 @@ +[[syntax trees at end of typer]] // tests/printing/annot-19846b.scala +package { + class lambdaAnnot(g: () => Int) extends scala.annotation.Annotation(), + annotation.StaticAnnotation { + private[this] val g: () => Int + } + final lazy module val Test: Test = new Test() + final module class Test() extends Object() { this: Test.type => + val y: Int = ??? + val z: + Int @lambdaAnnot( + { + def $anonfun(): Int = Test.y + closure($anonfun) + } + ) + = f(Test.y) + } + final lazy module val annot-19846b$package: annot-19846b$package = + new annot-19846b$package() + final module class annot-19846b$package() extends Object() { + this: annot-19846b$package.type => + def f(x: Int): + Int @lambdaAnnot( + { + def $anonfun(): Int = x + closure($anonfun) + } + ) + = x + } +} + diff --git a/tests/printing/annot-19846b.scala b/tests/printing/annot-19846b.scala new file mode 100644 index 000000000000..951a3c8116ff --- /dev/null +++ b/tests/printing/annot-19846b.scala @@ -0,0 +1,7 @@ +class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation + +def f(x: Int): Int @lambdaAnnot(() => x) = x + +object Test: + val y: Int = ??? + val z /* : Int @lambdaAnnot(() => y) */ = f(y)