Skip to content

Commit

Permalink
Translate Seq[T] @Repeated to <repeated>[T] in Scala 3 `MethodTyp…
Browse files Browse the repository at this point in the history
…e`s.

This aligns how repeated parameters are represented in `MethodType`s
for Scala 3, Scala 2 and Java. It also aligns those types with the
call-site type of the `Typed` nodes.

The `ValDef` for the parameter is still typed as `Seq[T] @Repeated`,
since inside the method it must be treated as a `Seq`.
  • Loading branch information
sjrd committed Mar 31, 2023
1 parent c78ef0e commit 09ef210
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 58 deletions.
32 changes: 29 additions & 3 deletions tasty-query/shared/src/main/scala/tastyquery/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ object Annotations:
case Literal(constant) => Some(constant)
case _ => None

/** Tests whether this annotation points to `defn.internalRepeatedAnnotClass` without resolving anything. */
private[tastyquery] def safeIsInternalRepeatedAnnot(using Context): Boolean =
defn.internalRepeatedAnnotClass match
case None =>
false
case Some(repeatedAnnotClass) =>
val tpt = findNewAnnotTypeTree(tree)
tpt match
// It is compiler-synthetic by definition, so it can only be a TypeWrapper
case TypeWrapper(tpe: TypeRef) =>
if tpe.name != tpnme.internalRepeatedAnnot then false
else
tpe.prefix match
case pkg: PackageRef => pkg.symbol == defn.scalaAnnotationInternalPackage
case _ => false
case _ =>
false
end safeIsInternalRepeatedAnnot

override def toString(): String = s"Annotation($tree)"
end Annotation

Expand Down Expand Up @@ -88,20 +107,27 @@ object Annotations:
end Annotation

private def computeAnnotSymbol(tree: TermTree)(using Context): ClassSymbol =
val tpt = findNewAnnotTypeTree(tree)
tpt.toType.classSymbol.getOrElse {
throw InvalidProgramStructureException(s"Illegal annotation class type $tpt in $tree")
}
end computeAnnotSymbol

private def findNewAnnotTypeTree(tree: TermTree): TypeTree =
def invalid(): Nothing =
throw InvalidProgramStructureException(s"Cannot find annotation class in $tree")

@tailrec
def loop(tree: TermTree): ClassSymbol = tree match
def loop(tree: TermTree): TypeTree = tree match
case Apply(fun, _) => loop(fun)
case New(tpt) => tpt.toType.classSymbol.getOrElse(invalid())
case New(tpt) => tpt
case Select(qual, _) => loop(qual)
case TypeApply(fun, _) => loop(fun)
case Block(_, expr) => loop(expr)
case _ => invalid()

loop(tree)
end computeAnnotSymbol
end findNewAnnotTypeTree

private def computeAnnotConstructor(tree: TermTree)(using Context): TermSymbol =
def invalid(): Nothing =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS

private val scalaAnnotationPackage =
scalaPackage.getPackageDeclOrCreate(termName("annotation"))
private val scalaAnnotationInternalPackage =
private[tastyquery] val scalaAnnotationInternalPackage =
scalaAnnotationPackage.getPackageDeclOrCreate(termName("internal"))
private val scalaCollectionPackage =
scalaPackage.getPackageDeclOrCreate(termName("collection"))
Expand Down
2 changes: 2 additions & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Names.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ object Names {
val RepeatedParamClassMagic: TypeName = typeName("<repeated>")

val scala2PackageObjectClass: TypeName = termName("package").withObjectSuffix.toTypeName

private[tastyquery] val internalRepeatedAnnot: TypeName = typeName("Repeated")
}

/** Create a type name from a string */
Expand Down
16 changes: 14 additions & 2 deletions tasty-query/shared/src/main/scala/tastyquery/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1289,12 +1289,24 @@ object Types {
* - add @inlineParam to inline parameters
*/
private[tastyquery] def fromSymbols(params: List[TermSymbol], resultType: Type)(using Context): MethodType = {
def annotatedToRepeated(tpe: Type): Type = tpe match
case tpe: AnnotatedType if tpe.annotation.safeIsInternalRepeatedAnnot =>
tpe.typ match
case applied: AppliedType if applied.args.sizeIs == 1 =>
// We're going to assume that `tycon` is indeed `Seq`, here, because we cannot afford to resolve it
defn.RepeatedTypeOf(applied.args.head)
case _ =>
throw TastyFormatException(s"in $params, $tpe is declared repeated but is not a Seq type")
case _ =>
tpe
end annotatedToRepeated

// def translateInline(tp: Type): Type =
// AnnotatedType(tp, Annotation(defn.InlineParamAnnot))
// def translateErased(tp: Type): Type =
// AnnotatedType(tp, Annotation(defn.ErasedParamAnnot))
def paramInfo(param: TermSymbol) = {
var paramType = param.declaredType //.annotatedToRepeated
def paramInfo(param: TermSymbol): Type = {
var paramType = annotatedToRepeated(param.declaredType)
// if (param.is(Inline)) paramType = translateInline(paramType)
// if (param.is(Erased)) paramType = translateErased(paramType)
paramType
Expand Down
69 changes: 17 additions & 52 deletions tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1149,77 +1149,42 @@ class TypeSuite extends UnrestrictedUnpicklingSuite {
def assertSeqOfInt(tpe: Type): Unit =
assert(clue(tpe).isApplied(t => t.isRef(defn.SeqClass) || t.isRef(scalaSeq), List(_.isRef(defn.IntClass))))

def assertAnnotated(tpe: Type)(assertInner: Type => Unit): Unit = tpe match
def assertAnnotatedSeqOfInt(tpe: Type): Unit = tpe match
case tpe: AnnotatedType =>
assertInner(tpe.typ)
assertSeqOfInt(tpe.typ)
assert(clue(tpe.annotation.symbol) == defn.internalRepeatedAnnotClass.get)
case _ =>
fail("unexpected parameter type", clues(tpe))
end assertAnnotated
end assertAnnotatedSeqOfInt

def assertRepeatedOfInt(tpe: Type): Unit =
assert(clue(tpe).isApplied(_.isRef(defn.RepeatedParamClass), List(_.isRef(defn.IntClass))))

locally {
val dd = getDefOf("takesVarargs")
val List(Left(List(paramValDef))) = dd.paramLists: @unchecked
val (paramType, resultType) = extractParamAndResultType(dd.symbol.declaredType)

assertAnnotated(paramType)(assertSeqOfInt(_))
assertAnnotatedSeqOfInt(paramValDef.symbol.declaredType)
assertRepeatedOfInt(paramType)
}

locally {
val dd = getDefOf("givesVarargs")
val (formal, typed, actual) = extractFormalTypedActualParamTypes(dd.rhs.get)

assertAnnotated(formal)(assertSeqOfInt(_))
assertRepeatedOfInt(typed)
assertSeqOfInt(actual.widen)
}

locally {
val dd = getDefOf("givesSeqLiteral")
val (formal, typed, actual) = extractFormalTypedActualParamTypes(dd.rhs.get)

assertAnnotated(formal)(assertSeqOfInt(_))
assertRepeatedOfInt(typed)
assertSeqOfInt(actual.widen)
}

locally {
val dd = getDefOf("givesSeqToJava")
val (formal, typed, actual) = extractFormalTypedActualParamTypes(dd.rhs.get)

assertRepeatedOfInt(formal)
assertRepeatedOfInt(typed)
assertSeqOfInt(actual.widen)
}

locally {
val dd = getDefOf("givesSeqLiteralToJava")
val (formal, typed, actual) = extractFormalTypedActualParamTypes(dd.rhs.get)

assertRepeatedOfInt(formal)
assertRepeatedOfInt(typed)
assertSeqOfInt(actual.widen)
}

locally {
val dd = getDefOf("givesSeqToScala2")
val (formal, typed, actual) = extractFormalTypedActualParamTypes(dd.rhs.get)

assertRepeatedOfInt(formal)
assertRepeatedOfInt(typed)
assertSeqOfInt(actual.widen)
}

locally {
val dd = getDefOf("givesSeqLiteralToScala2")
val testMethodNames = List(
"givesVarargs",
"givesSeqLiteral",
"givesSeqToJava",
"givesSeqLiteralToJava",
"givesSeqToScala2",
"givesSeqLiteralToScala2"
)
for testMethodName <- testMethodNames do
val dd = getDefOf(testMethodName)
val (formal, typed, actual) = extractFormalTypedActualParamTypes(dd.rhs.get)

assertRepeatedOfInt(formal)
assertRepeatedOfInt(typed)
assertSeqOfInt(actual.widen)
}
end for
}

testWithContext("scala2-class-type-param-ref") {
Expand Down

0 comments on commit 09ef210

Please sign in to comment.