Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #405: Completely overhaul erasure of value classes. #408

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ lazy val tastyQuery =
import com.typesafe.tools.mima.core.*
Seq(
// private, not an issue
ProblemFilters.exclude[MissingClassProblem]("tastyquery.Erasure$ErasedValueClass"),
ProblemFilters.exclude[MissingClassProblem]("tastyquery.Erasure$ErasedValueClass$"),
ProblemFilters.exclude[MissingClassProblem]("tastyquery.TypeOps$TypeFold"),
// private[tastyquery], not an issue
ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.Signatures#Signature.toSigName"),
// Everything in tastyquery.reader is private[tastyquery] at most
ProblemFilters.exclude[Problem]("tastyquery.reader.*"),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,15 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
lazy val CharClass = scalaPackage.requiredClass("Char")
lazy val UnitClass = scalaPackage.requiredClass("Unit")

private[tastyquery] lazy val BoxedBooleanClass = javaLangPackage.requiredClass("Boolean")
private[tastyquery] lazy val BoxedCharClass = javaLangPackage.requiredClass("Character")
private[tastyquery] lazy val BoxedByteClass = javaLangPackage.requiredClass("Byte")
private[tastyquery] lazy val BoxedShortClass = javaLangPackage.requiredClass("Short")
private[tastyquery] lazy val BoxedIntClass = javaLangPackage.requiredClass("Integer")
private[tastyquery] lazy val BoxedLongClass = javaLangPackage.requiredClass("Long")
private[tastyquery] lazy val BoxedFloatClass = javaLangPackage.requiredClass("Float")
private[tastyquery] lazy val BoxedDoubleClass = javaLangPackage.requiredClass("Double")

lazy val StringClass = javaLangPackage.requiredClass("String")

lazy val ProductClass = scalaPackage.requiredClass("Product")
Expand Down
244 changes: 184 additions & 60 deletions tasty-query/shared/src/main/scala/tastyquery/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ import tastyquery.Types.*
import tastyquery.Types.ErasedTypeRef.*

private[tastyquery] object Erasure:
// TODO: improve this to match dotty:
// - use correct type erasure algorithm from Scala 3, with specialisations
// for Java types and Scala 2 types (i.e. varargs, value-classes)

@deprecated("use the overload that takes an explicit SourceLanguage", since = "0.7.1")
def erase(tpe: Type)(using Context): ErasedTypeRef =
erase(tpe, SourceLanguage.Scala3)
Expand All @@ -27,44 +23,41 @@ private[tastyquery] object Erasure:
finishErase(preErase(tpe, keepUnit))
end erase

/** First pass of erasure, where some special types are preserved as is.
private[tastyquery] def eraseForSigName(tpe: Type, language: SourceLanguage, keepUnit: Boolean)(
using Context
): ErasedTypeRef =
given SourceLanguage = language

val patchedPreErased = preErase(tpe, keepUnit) match
case ArrayTypeRef(ClassRef(cls), dimensions) if cls.isDerivedValueClass =>
// Hack! dotc's `sigName` does *not* correspond to erasure in this case!
val patchedBase =
if cls.typeParams.isEmpty then preEraseMonoValueClass(cls)
else preErasePolyValueClass(cls, cls.typeParams.map(_.localRef))
patchedBase.underlying.multiArrayOf(dimensions)
case typeRef =>
typeRef

finishErase(patchedPreErased)
end eraseForSigName

private final case class ErasedValueClass(valueClass: ClassSymbol, underlying: ErasedTypeRef)

private type PreErasedTypeRef = ErasedTypeRef | ErasedValueClass

/** First pass of erasure, where some special types are preserved as is,
* and where value classes become `ErasedValueClass`es.
*
* In particular, `Any` is preserved as `Any`, instead of becoming
* `java.lang.Object`.
*/
private def preErase(tpe: Type, keepUnit: Boolean)(using Context, SourceLanguage): ErasedTypeRef =
def arrayOfBounds(bounds: TypeBounds): ErasedTypeRef =
preErase(bounds.high, keepUnit = false) match
case ClassRef(cls) if cls.isAny || cls.isAnyVal =>
ClassRef(defn.ObjectClass)
case typeRef =>
typeRef.arrayOf()

def arrayOf(tpe: TypeOrWildcard): ErasedTypeRef = tpe match
case tpe: AppliedType =>
tpe.tycon match
case TypeRef.OfClass(cls) =>
if cls.isArray then
val List(targ) = tpe.args: @unchecked
arrayOf(targ).arrayOf()
else ClassRef(cls).arrayOf()
case _ =>
arrayOf(tpe.translucentSuperType)
case TypeRef.OfClass(cls) =>
if cls.isUnit then ClassRef(defn.ErasedBoxedUnitClass).arrayOf()
else ClassRef(cls).arrayOf()
case tpe: TypeRef =>
tpe.optSymbol match
case Some(sym: TypeMemberSymbol) if sym.isOpaqueTypeAlias =>
arrayOf(tpe.translucentSuperType)
case _ =>
tpe.bounds match
case bounds: AbstractTypeBounds => arrayOfBounds(bounds)
case TypeAlias(alias) => arrayOf(alias)
case tpe: TypeParamRef => arrayOfBounds(tpe.bounds)
case tpe: Type => preErase(tpe, keepUnit = false).arrayOf()
case tpe: WildcardTypeArg => arrayOfBounds(tpe.bounds)
end arrayOf
private def preErase(tpe: Type, keepUnit: Boolean)(using Context, SourceLanguage): PreErasedTypeRef =
def arrayOf(tpe: TypeOrWildcard): ErasedTypeRef =
if isGenericArrayElement(tpe) then ClassRef(defn.ObjectClass)
else
preErase(tpe.highIfWildcard, keepUnit = false) match
case base: ErasedTypeRef => base.arrayOf()
case ErasedValueClass(valueClass, _) => ClassRef(valueClass).arrayOf()

tpe match
case tpe: AppliedType =>
Expand All @@ -73,11 +66,13 @@ private[tastyquery] object Erasure:
if cls.isArray then
val List(targ) = tpe.args: @unchecked
arrayOf(targ)
else if cls.isDerivedValueClass then preErasePolyValueClass(cls, tpe.args)
else ClassRef(cls)
case _ =>
preErase(tpe.translucentSuperType, keepUnit)
case TypeRef.OfClass(cls) =>
if !keepUnit && cls.isUnit then ClassRef(defn.ErasedBoxedUnitClass)
else if cls.isDerivedValueClass then preEraseMonoValueClass(cls)
else ClassRef(cls)
case tpe: TypeRef =>
preErase(tpe.translucentSuperType, keepUnit)
Expand All @@ -90,7 +85,10 @@ private[tastyquery] object Erasure:
case Some(reduced) => preErase(reduced, keepUnit)
case None => preErase(tpe.bound, keepUnit)
case tpe: OrType =>
erasedLub(preErase(tpe.first, keepUnit = false), preErase(tpe.second, keepUnit = false))
erasedLub(
finishErase(preErase(tpe.first, keepUnit = false)),
finishErase(preErase(tpe.second, keepUnit = false))
)
case tpe: AndType =>
summon[SourceLanguage] match
case SourceLanguage.Java =>
Expand Down Expand Up @@ -120,29 +118,157 @@ private[tastyquery] object Erasure:
throw IllegalArgumentException(s"Unexpected type in erasure: $tpe")
end preErase

private def finishErase(typeRef: ErasedTypeRef)(using Context): ErasedTypeRef =
private def finishErase(typeRef: PreErasedTypeRef)(using Context, SourceLanguage): ErasedTypeRef =
typeRef match
case ClassRef(cls) =>
if cls.isDerivedValueClass then finishEraseValueClass(cls)
else cls.erasure
case ArrayTypeRef(ClassRef(cls), dimensions) =>
ArrayTypeRef(cls.erasure, dimensions)
case ClassRef(cls) => cls.erasure
case ArrayTypeRef(ClassRef(cls), dimensions) => ArrayTypeRef(cls.erasure, dimensions)
case ErasedValueClass(_, underlying) => finishErase(underlying)
end finishErase

private def finishEraseValueClass(cls: ClassSymbol)(using Context): ErasedTypeRef =
private def preEraseMonoValueClass(cls: ClassSymbol)(using Context, SourceLanguage): ErasedValueClass =
val ctor = cls.findNonOverloadedDecl(nme.Constructor)

val underlying = ctor.declaredType match
case tpe: MethodType if tpe.paramNames.sizeIs == 1 =>
tpe.paramTypes.head
case _ =>
throw InvalidProgramStructureException(s"Illegal value class constructor type ${ctor.declaredType.showBasic}")

// The underlying of value classes are never value classes themselves (by language spec)
val erasedUnderlying = preErase(underlying, keepUnit = false).asInstanceOf[ErasedTypeRef]

ErasedValueClass(cls, erasedUnderlying)
end preEraseMonoValueClass

private def preErasePolyValueClass(cls: ClassSymbol, targs: List[TypeOrWildcard])(
using Context,
SourceLanguage
): ErasedValueClass =
val ctor = cls.findNonOverloadedDecl(nme.Constructor)

def illegalConstructorType(): Nothing =
throw InvalidProgramStructureException(s"Illegal value class constructor type ${ctor.declaredType.showBasic}")

def ctorParamType(tpe: TypeOrMethodic): Type = tpe match
case tpe: MethodType if tpe.paramTypes.sizeIs == 1 => tpe.paramTypes.head
case tpe: MethodType => illegalConstructorType()
case tpe: PolyType => ctorParamType(tpe.resultType)
case tpe: Type => illegalConstructorType()
case _ => illegalConstructorType()

val ctorPolyType = ctor.declaredType match
case tpe: PolyType => tpe
case _ => illegalConstructorType()

val genericUnderlying = ctorParamType(ctorPolyType.resultType)
val specializedUnderlying = ctorParamType(ctorPolyType.instantiate(targs))

// The underlying of value classes are never value classes themselves (by language spec)
val erasedGenericUnderlying = preErase(genericUnderlying, keepUnit = false).asInstanceOf[ErasedTypeRef]
val erasedSpecializedUnderlying = preErase(specializedUnderlying, keepUnit = false).asInstanceOf[ErasedTypeRef]

erase(ctorParamType(ctor.declaredType), ctor.sourceLanguage)
end finishEraseValueClass
def isPrimitive(typeRef: ErasedTypeRef): Boolean = typeRef match
case ClassRef(cls) => cls.isPrimitiveValueClass
case _: ArrayTypeRef => false

/* Ideally, we would just use `erasedSpecializedUnderlying` as the erasure of `tp`.
* However, there are two special cases for polymorphic value classes, which
* historically come from Scala 2:
*
* - Given `class Foo[A](x: A) extends AnyVal`, `Foo[X]` should erase like
* `X`, except if its a primitive in which case it erases to the boxed
* version of this primitive.
* - Given `class Bar[A](x: Array[A]) extends AnyVal`, `Bar[X]` will be
* erased like `Array[A]` as seen from its definition site, no matter
* the `X` (same if `A` is bounded).
*/
val erasedValueClass =
if isPrimitive(erasedSpecializedUnderlying) && !isPrimitive(erasedGenericUnderlying) then
ClassRef(erasedSpecializedUnderlying.asInstanceOf[ClassRef].cls.boxedClass)
else if genericUnderlying.baseType(defn.ArrayClass).isDefined then erasedGenericUnderlying
else erasedSpecializedUnderlying

ErasedValueClass(cls, erasedValueClass)
end preErasePolyValueClass

/** Is `Array[tp]` a generic Array that needs to be erased to `Object`?
* This is true if among the subtypes of `Array[tp]` there is either:
* - both a reference array type and a primitive array type
* (e.g. `Array[_ <: Int | String]`, `Array[_ <: Any]`)
* - or two different primitive array types (e.g. `Array[_ <: Int | Double]`)
* In both cases the erased lub of those array types on the JVM is `Object`.
*
* In addition, if `isScala2` is true, we mimic the Scala 2 erasure rules and
* also return true for element types upper-bounded by a non-reference type
* such as in `Array[_ <: Int]` or `Array[_ <: UniversalTrait]`.
*/
private def isGenericArrayElement(tp: TypeOrWildcard)(using Context, SourceLanguage): Boolean =
/** A symbol that represents the sort of JVM array that values of type `tp` can be stored in:
* - If we can always store such values in a reference array, return `j.l.Object`.
* - If we can always store them in a specific primitive array, return the corresponding primitive class.
* - Otherwise, return `None`.
*/
def arrayUpperBound(tp: Type): Option[ClassSymbol] = tp.dealias match
case TypeRef.OfClass(cls) =>
def isScala2SpecialCase: Boolean =
summon[SourceLanguage] == SourceLanguage.Scala2
&& !cls.isNull
&& !cls.isSubClass(defn.ObjectClass)

// Only a few classes have both primitives and references as subclasses.
if cls.isAny || cls.isAnyVal || cls.isMatchable || cls.isSingleton || isScala2SpecialCase then None
else if cls.isPrimitiveValueClass then Some(cls)
else
// Derived value classes in arrays are always boxed, so they end up here as well
Some(defn.ObjectClass)

case tp: TypeProxy =>
arrayUpperBound(tp.translucentSuperType)
case tp: AndType =>
arrayUpperBound(tp.first).orElse(arrayUpperBound(tp.second))
case tp: OrType =>
val firstBound = arrayUpperBound(tp.first)
val secondBound = arrayUpperBound(tp.first)
if firstBound == secondBound then firstBound
else None
case _: NothingType | _: AnyKindType | _: TypeLambda =>
None
case tp: CustomTransientGroundType =>
throw IllegalArgumentException(s"Unexpected transient type: $tp")
end arrayUpperBound

/** Can one of the JVM Array type store all possible values of type `tp`? */
def fitsInJVMArray(tp: Type): Boolean = arrayUpperBound(tp).isDefined

tp match
case tp: WildcardTypeArg =>
!fitsInJVMArray(tp.bounds.high)

case tp: Type =>
tp.dealias match
case tp: TypeRef =>
tp.optSymbol match
case Some(cls: ClassSymbol) =>
false
case Some(sym: TypeMemberSymbol) if sym.isOpaqueTypeAlias =>
isGenericArrayElement(tp.translucentSuperType)
case _ =>
tp.bounds match
case TypeAlias(alias) => isGenericArrayElement(alias)
case AbstractTypeBounds(_, high) => !fitsInJVMArray(high)
case tp: TypeParamRef =>
!fitsInJVMArray(tp)
case tp: MatchType =>
val cases = tp.cases
cases.nonEmpty && !fitsInJVMArray(cases.map(_.result).reduce(OrType(_, _)))
case tp: TypeProxy =>
isGenericArrayElement(tp.translucentSuperType)
case tp: AndType =>
isGenericArrayElement(tp.first) && isGenericArrayElement(tp.second)
case tp: OrType =>
isGenericArrayElement(tp.first) || isGenericArrayElement(tp.second)
case _: NothingType | _: AnyKindType | _: TypeLambda =>
false
case tp: CustomTransientGroundType =>
throw IllegalArgumentException(s"Unexpected transient type: $tp")
end isGenericArrayElement

/** The erased least upper bound of two erased types is computed as follows.
*
Expand Down Expand Up @@ -224,7 +350,7 @@ private[tastyquery] object Erasure:
* - Associativity and commutativity, because this method acts as the minimum
* of the total order induced by `compareErasedGlb`.
*/
private def erasedGlb(tp1: ErasedTypeRef, tp2: ErasedTypeRef)(using Context): ErasedTypeRef =
private def erasedGlb(tp1: PreErasedTypeRef, tp2: PreErasedTypeRef)(using Context): PreErasedTypeRef =
if compareErasedGlb(tp1, tp2) <= 0 then tp1
else tp2

Expand All @@ -248,7 +374,7 @@ private[tastyquery] object Erasure:
*
* @see erasedGlb
*/
private def compareErasedGlb(tp1: ErasedTypeRef, tp2: ErasedTypeRef)(using Context): Int =
private def compareErasedGlb(tp1: PreErasedTypeRef, tp2: PreErasedTypeRef)(using Context): Int =
def compareClasses(cls1: ClassSymbol, cls2: ClassSymbol): Int =
if cls1.isSubClass(cls2) then -1
else if cls2.isSubClass(cls1) then 1
Expand All @@ -260,13 +386,11 @@ private[tastyquery] object Erasure:
// fast path
0

case (ClassRef(cls1), _) if cls1.isDerivedValueClass =>
tp2 match
case ClassRef(cls2) if cls2.isDerivedValueClass =>
compareClasses(cls1, cls2)
case _ =>
-1
case (_, ClassRef(cls2)) if cls2.isDerivedValueClass =>
case (ErasedValueClass(cls1, _), ErasedValueClass(cls2, _)) =>
compareClasses(cls1, cls2)
case (ErasedValueClass(cls1, _), _) =>
-1
case (_, ErasedValueClass(cls2, _)) =>
1

case (tp1: ArrayTypeRef, tp2: ArrayTypeRef) =>
Expand Down
Loading