Skip to content

Commit

Permalink
Merge pull request #281 from sjrd/consistent-varargs-representation
Browse files Browse the repository at this point in the history
Align the representation of varargs for all source languages
  • Loading branch information
sjrd authored Mar 31, 2023
2 parents 67fcd15 + 09ef210 commit 1086459
Show file tree
Hide file tree
Showing 17 changed files with 159 additions and 72 deletions.
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ lazy val tastyQuery =
import com.typesafe.tools.mima.core.*
Seq(
// private[tastyquery], so this is fine
ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.Signatures#Signature.fromType"),
ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.Symbols#ClassSymbol.createRefinedClassSymbol"),
ProblemFilters.exclude[FinalClassProblem]("tastyquery.TypeOps$AsSeenFromMap"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("tastyquery.TypeOps#AsSeenFromMap.this"),
)
Expand Down
32 changes: 29 additions & 3 deletions tasty-query/shared/src/main/scala/tastyquery/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ object Annotations:
case Literal(constant) => Some(constant)
case _ => None

/** Tests whether this annotation points to `defn.internalRepeatedAnnotClass` without resolving anything. */
private[tastyquery] def safeIsInternalRepeatedAnnot(using Context): Boolean =
defn.internalRepeatedAnnotClass match
case None =>
false
case Some(repeatedAnnotClass) =>
val tpt = findNewAnnotTypeTree(tree)
tpt match
// It is compiler-synthetic by definition, so it can only be a TypeWrapper
case TypeWrapper(tpe: TypeRef) =>
if tpe.name != tpnme.internalRepeatedAnnot then false
else
tpe.prefix match
case pkg: PackageRef => pkg.symbol == defn.scalaAnnotationInternalPackage
case _ => false
case _ =>
false
end safeIsInternalRepeatedAnnot

override def toString(): String = s"Annotation($tree)"
end Annotation

Expand Down Expand Up @@ -88,20 +107,27 @@ object Annotations:
end Annotation

private def computeAnnotSymbol(tree: TermTree)(using Context): ClassSymbol =
val tpt = findNewAnnotTypeTree(tree)
tpt.toType.classSymbol.getOrElse {
throw InvalidProgramStructureException(s"Illegal annotation class type $tpt in $tree")
}
end computeAnnotSymbol

private def findNewAnnotTypeTree(tree: TermTree): TypeTree =
def invalid(): Nothing =
throw InvalidProgramStructureException(s"Cannot find annotation class in $tree")

@tailrec
def loop(tree: TermTree): ClassSymbol = tree match
def loop(tree: TermTree): TypeTree = tree match
case Apply(fun, _) => loop(fun)
case New(tpt) => tpt.toType.classSymbol.getOrElse(invalid())
case New(tpt) => tpt
case Select(qual, _) => loop(qual)
case TypeApply(fun, _) => loop(fun)
case Block(_, expr) => loop(expr)
case _ => invalid()

loop(tree)
end computeAnnotSymbol
end findNewAnnotTypeTree

private def computeAnnotConstructor(tree: TermTree)(using Context): TermSymbol =
def invalid(): Nothing =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS

private val scalaAnnotationPackage =
scalaPackage.getPackageDeclOrCreate(termName("annotation"))
private val scalaAnnotationInternalPackage =
private[tastyquery] val scalaAnnotationInternalPackage =
scalaAnnotationPackage.getPackageDeclOrCreate(termName("internal"))
private val scalaCollectionPackage =
scalaPackage.getPackageDeclOrCreate(termName("collection"))
Expand Down Expand Up @@ -48,6 +48,9 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
val SeqTypeUnapplied: TypeRef = TypeRef(scalaCollectionImmutablePackage.packageRef, typeName("Seq"))
def SeqTypeOf(tpe: Type): AppliedType = AppliedType(SeqTypeUnapplied, List(tpe))

val RepeatedTypeUnapplied: TypeRef = TypeRef(scalaPackage.packageRef, tpnme.RepeatedParamClassMagic)
def RepeatedTypeOf(tpe: Type): AppliedType = AppliedType(RepeatedTypeUnapplied, List(tpe))

val IntType: TypeRef = TypeRef(scalaPackage.packageRef, typeName("Int"))
val LongType: TypeRef = TypeRef(scalaPackage.packageRef, typeName("Long"))
val FloatType: TypeRef = TypeRef(scalaPackage.packageRef, typeName("Float"))
Expand Down
16 changes: 12 additions & 4 deletions tasty-query/shared/src/main/scala/tastyquery/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ private[tastyquery] object Erasure:
// - use correct type erasure algorithm from Scala 3, with specialisations
// for Java types and Scala 2 types (i.e. varargs, value-classes)

@deprecated("use the overload that takes an explicit SourceLanguage", since = "0.7.1")
def erase(tpe: Type)(using Context): ErasedTypeRef =
erase(tpe, SourceLanguage.Scala3)

def erase(tpe: Type, language: SourceLanguage)(using Context): ErasedTypeRef =
given SourceLanguage = language
tpe match
case _: ByNameType => ClassRef(defn.Function0Class)
case _ => finishErase(preErase(tpe))
Expand All @@ -25,7 +30,10 @@ private[tastyquery] object Erasure:
* In particular, `Any` is preserved as `Any`, instead of becoming
* `java.lang.Object`.
*/
private def preErase(tpe: Type)(using Context): ErasedTypeRef =
private def preErase(tpe: Type)(using Context, SourceLanguage): ErasedTypeRef =
def hasArrayErasure(cls: ClassSymbol): Boolean =
cls == defn.ArrayClass || (cls == defn.RepeatedParamClass && summon[SourceLanguage] == SourceLanguage.Java)

def arrayOfBounds(bounds: TypeBounds): ErasedTypeRef =
preErase(bounds.high) match
case ClassRef(cls) if cls == defn.AnyClass || cls == defn.AnyValClass =>
Expand All @@ -37,7 +45,7 @@ private[tastyquery] object Erasure:
case tpe: AppliedType =>
tpe.tycon match
case TypeRef.OfClass(cls) =>
if cls == defn.ArrayClass then
if hasArrayErasure(cls) then
val List(targ) = tpe.args: @unchecked
arrayOf(targ).arrayOf()
else ClassRef(cls).arrayOf()
Expand All @@ -64,7 +72,7 @@ private[tastyquery] object Erasure:
case tpe: AppliedType =>
tpe.tycon match
case TypeRef.OfClass(cls) =>
if cls == defn.ArrayClass then
if hasArrayErasure(cls) then
val List(targ) = tpe.args: @unchecked
arrayOf(targ)
else ClassRef(cls)
Expand Down Expand Up @@ -101,7 +109,7 @@ private[tastyquery] object Erasure:
val ctor = cls.findNonOverloadedDecl(nme.Constructor)
val List(Left(List(paramRef))) = ctor.paramRefss.dropWhile(_.isRight): @unchecked
val paramType = paramRef.underlying
erase(paramType)
erase(paramType, ctor.sourceLanguage)

typeRef match
case ClassRef(cls) =>
Expand Down
2 changes: 2 additions & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ object Flags:
val Infix: Flag = newFlag("Infix")
val Inline: Flag = newFlag("Inline")
val InlineProxy: Flag = newFlag("InlineProxy")
val JavaDefined: Flag = newFlag("JavaDefined")
val Lazy: Flag = newFlag("Lazy")
val Local: Flag = newFlag("Local")
val Macro: Flag = newFlag("Macro")
Expand All @@ -75,6 +76,7 @@ object Flags:
val ParamAccessor: Flag = newFlag("ParamAccessor")
val Private: Flag = newFlag("Private")
val Protected: Flag = newFlag("Protected")
val Scala2Defined: Flag = newFlag("Scala2Defined")
val Sealed: Flag = newFlag("Sealed")
val SuperParamAlias: Flag = newFlag("SuperParamAlias")
val Static: Flag = newFlag("Static")
Expand Down
2 changes: 2 additions & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Names.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ object Names {
val RepeatedParamClassMagic: TypeName = typeName("<repeated>")

val scala2PackageObjectClass: TypeName = termName("package").withObjectSuffix.toTypeName

private[tastyquery] val internalRepeatedAnnot: TypeName = typeName("Repeated")
}

/** Create a type name from a string */
Expand Down
8 changes: 5 additions & 3 deletions tasty-query/shared/src/main/scala/tastyquery/Signatures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@ object Signatures:
end Signature

object Signature {
private[tastyquery] def fromType(info: Type, optCtorReturn: Option[ClassSymbol])(using Context): Signature =
private[tastyquery] def fromType(info: Type, language: SourceLanguage, optCtorReturn: Option[ClassSymbol])(
using Context
): Signature =
def rec(info: Type, acc: List[ParamSig]): Signature =
info match {
case info: MethodType =>
val erased = info.paramTypes.map(tpe => ParamSig.Term(ErasedTypeRef.erase(tpe).toSigFullName))
val erased = info.paramTypes.map(tpe => ParamSig.Term(ErasedTypeRef.erase(tpe, language).toSigFullName))
rec(info.resultType, acc ::: erased)
case info: PolyType =>
rec(info.resultType, acc ::: ParamSig.TypeLen(info.paramTypeBounds.length) :: Nil)
case tpe =>
val retType = optCtorReturn.map(_.typeRef).getOrElse(tpe)
Signature(acc, ErasedTypeRef.erase(retType).toSigFullName)
Signature(acc, ErasedTypeRef.erase(retType, language).toSigFullName)
}

rec(info, Nil)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package tastyquery

/** Source language of a symbol. */
enum SourceLanguage:
case Java, Scala2, Scala3
end SourceLanguage
22 changes: 17 additions & 5 deletions tasty-query/shared/src/main/scala/tastyquery/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,16 @@ object Symbols {
}

sealed abstract class TermOrTypeSymbol(override val owner: Symbol) extends Symbol(owner):
/** The source language in which this symbol was defined.
*
* The source language of a symbol may have an influence on how it is
* erased, and therefore on how its signature is computed.
*/
final def sourceLanguage: SourceLanguage =
if flags.is(JavaDefined) then SourceLanguage.Java
else if flags.is(Scala2Defined) then SourceLanguage.Scala2
else SourceLanguage.Scala3

// Overriding relationships

/** The non-private symbol whose name and type matches the type of this symbol in the given class.
Expand Down Expand Up @@ -387,7 +397,7 @@ object Symbols {
val local = mySignature
if local != null then local
else
val sig = Signature.fromType(declaredType, Option.when(isConstructor)(owner.asClass))
val sig = Signature.fromType(declaredType, sourceLanguage, Option.when(isConstructor)(owner.asClass))
mySignature = sig
sig
end signature
Expand Down Expand Up @@ -1198,17 +1208,19 @@ object Symbols {
private[tastyquery] def create(name: TypeName, owner: Symbol): ClassSymbol =
owner.addDeclIfDeclaringSym(ClassSymbol(name, owner))

private[tastyquery] def createRefinedClassSymbol(owner: Symbol, span: Span)(using Context): ClassSymbol =
private[tastyquery] def createRefinedClassSymbol(owner: Symbol, flags: FlagSet, span: Span)(
using Context
): ClassSymbol =
// TODO Store the `span`
createRefinedClassSymbol(owner)
createRefinedClassSymbol(owner, flags)

private[tastyquery] def createRefinedClassSymbol(owner: Symbol)(using Context): ClassSymbol =
private[tastyquery] def createRefinedClassSymbol(owner: Symbol, flags: FlagSet)(using Context): ClassSymbol =
val cls = ClassSymbol(tpnme.RefinedClassMagic, owner) // by-pass `owner.addDeclIfDeclaringSym`
cls
.withTypeParams(Nil)
.withParentsDirect(defn.ObjectType :: Nil)
.withGivenSelfType(None)
.withFlags(EmptyFlagSet, None)
.withFlags(flags, None)
.setAnnotations(Nil)
cls.checkCompleted()
cls
Expand Down
22 changes: 19 additions & 3 deletions tasty-query/shared/src/main/scala/tastyquery/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ object Types {
end ErasedTypeRef

object ErasedTypeRef:
@deprecated("use the overload that takes an explicit SourceLanguage", since = "0.7.1")
def erase(tpe: Type)(using Context): ErasedTypeRef =
Erasure.erase(tpe)
erase(tpe, SourceLanguage.Scala3)

def erase(tpe: Type, language: SourceLanguage)(using Context): ErasedTypeRef =
Erasure.erase(tpe, language)
end ErasedTypeRef

private[tastyquery] enum ResolveMemberResult:
Expand Down Expand Up @@ -1285,12 +1289,24 @@ object Types {
* - add @inlineParam to inline parameters
*/
private[tastyquery] def fromSymbols(params: List[TermSymbol], resultType: Type)(using Context): MethodType = {
def annotatedToRepeated(tpe: Type): Type = tpe match
case tpe: AnnotatedType if tpe.annotation.safeIsInternalRepeatedAnnot =>
tpe.typ match
case applied: AppliedType if applied.args.sizeIs == 1 =>
// We're going to assume that `tycon` is indeed `Seq`, here, because we cannot afford to resolve it
defn.RepeatedTypeOf(applied.args.head)
case _ =>
throw TastyFormatException(s"in $params, $tpe is declared repeated but is not a Seq type")
case _ =>
tpe
end annotatedToRepeated

// def translateInline(tp: Type): Type =
// AnnotatedType(tp, Annotation(defn.InlineParamAnnot))
// def translateErased(tp: Type): Type =
// AnnotatedType(tp, Annotation(defn.ErasedParamAnnot))
def paramInfo(param: TermSymbol) = {
var paramType = param.declaredType //.annotatedToRepeated
def paramInfo(param: TermSymbol): Type = {
var paramType = annotatedToRepeated(param.declaredType)
// if (param.is(Inline)) paramType = translateInline(paramType)
// if (param.is(Erased)) paramType = translateErased(paramType)
paramType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ private[reader] object ClassfileParser {
def privateWithin(access: AccessFlags): Option[Symbol] =
if access.isPackagePrivate then Some(pkg) else None

val clsFlags = structure.access.toFlags
val clsFlags = structure.access.toFlags | JavaDefined
val clsPrivateWithin = privateWithin(structure.access)

val moduleClass = ClassSymbol
Expand All @@ -178,7 +178,7 @@ private[reader] object ClassfileParser {
given InnerClasses = innerClassesStrict

def createMember(name: SimpleName, baseFlags: FlagSet, access: AccessFlags): TermSymbol =
val flags = baseFlags | access.toFlags
val flags = baseFlags | access.toFlags | JavaDefined
val owner = if flags.is(Flags.Static) then moduleClass else cls
val sym = TermSymbol.create(name, owner).withFlags(flags, privateWithin(access))
sym.setAnnotations(Nil) // TODO Read Java annotations on fields and methods
Expand Down Expand Up @@ -237,27 +237,46 @@ private[reader] object ClassfileParser {
else parsedType
sym.withDeclaredType(adaptedType)

for sym <- allRegisteredSymbols do sym.checkCompleted()
for sym <- allRegisteredSymbols do
sym.checkCompleted()
assert(sym.flags.is(JavaDefined), s"$sym of ${sym.getClass()}")

innerClasses.declarations
}

private def patchForVarargs(sym: TermSymbol, tpe: Type)(using Context): Type =
tpe match
case tpe: MethodType if tpe.paramNames.sizeIs >= 1 =>
defn.internalRepeatedAnnotClass match
case Some(annotClass) =>
val patchedLast = AnnotatedType(tpe.paramTypes.last, TQAnnotation(annotClass))
tpe.derivedLambdaType(tpe.paramNames, tpe.paramTypes.init :+ patchedLast, tpe.resultType)
case None =>
// Warn here? How?
tpe
val patchedLast = tpe.paramTypes.last match
case ArrayTypeExtractor(lastElemType) =>
defn.RepeatedTypeOf(lastElemType)
case _ =>
throw ClassfileFormatException(s"Found ACC_VARARGS on $sym but its last param type was not an array: $tpe")
tpe.derivedLambdaType(tpe.paramNames, tpe.paramTypes.init :+ patchedLast, tpe.resultType)
case tpe: PolyType =>
tpe.derivedLambdaType(tpe.paramNames, tpe.paramTypeBounds, patchForVarargs(sym, tpe.resultType))
case _ =>
throw ClassfileFormatException(s"Found ACC_VARARGS on $sym but its type was not a MethodType: $tpe")
end patchForVarargs

/** Extracts `elemType` from `AppliedType(scala.Array, List(elemType))`.
*
* This works for array types created by `defn.ArrayTypeOf(elemType)`, but
* is not otherwise guaranteed to work in all situations.
*/
private object ArrayTypeExtractor:
def unapply(tpe: AppliedType)(using Context): Option[Type] =
tpe.tycon match
case tycon: TypeRef if tycon.name == tpnme.Array && tpe.args.sizeIs == 1 =>
tycon.prefix match
case prefix: PackageRef if prefix.symbol == defn.scalaPackage =>
Some(tpe.args.head)
case _ =>
None
case _ =>
None
end ArrayTypeExtractor

private def parse(classRoot: ClassData, structure: Structure)(using Context): ClassKind = {
import structure.{reader, given}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ private[classfiles] object JavaSignatures:
val tparams = tparamNames.map { tname =>
val paramSym = ClassTypeParamSymbol.create(tname, cls)
allRegisteredSymbols += paramSym
paramSym.withFlags(ClassTypeParam, None).setAnnotations(Nil)
paramSym.withFlags(ClassTypeParam | JavaDefined, None).setAnnotations(Nil)
paramSym
}
val lookup = tparamNames.lazyZip(tparams).toMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ private[pickles] class PickleReader {
case CLASSsym =>
val tname = name.toTypeName
val cls =
if tname == tpnme.RefinedClassMagic then ClassSymbol.createRefinedClassSymbol(owner)
if tname == tpnme.RefinedClassMagic then ClassSymbol.createRefinedClassSymbol(owner, Scala2Defined)
else ClassSymbol.create(name.toTypeName, owner)
storeResultInEntries(cls)
val tpe = readSymType()
Expand Down Expand Up @@ -285,7 +285,7 @@ private[pickles] class PickleReader {
PickleFlagSet(pkl.readLongNat(), isType)

private def pickleFlagsToFlags(pickleFlags: PickleFlagSet): FlagSet = {
var flags: FlagSet = EmptyFlagSet
var flags: FlagSet = Scala2Defined

if pickleFlags.isProtected then flags |= Protected
if pickleFlags.isOverride then flags |= Override
Expand Down
Loading

0 comments on commit 1086459

Please sign in to comment.