Skip to content

Commit

Permalink
Merge pull request #234 from sjrd/selftypes
Browse files Browse the repository at this point in the history
Handle self types.
  • Loading branch information
bishabosha authored Dec 22, 2022
2 parents fef56ee + 0f2e402 commit 502e2db
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
val cls = ClassSymbol.create(name, scalaPackage)
cls.withTypeParams(Nil)
cls.withParentsDirect(parents)
cls.withGivenSelfType(None)
cls.withFlags(flags, None)
cls.setAnnotations(Nil)
cls.checkCompleted()
Expand Down
62 changes: 56 additions & 6 deletions tasty-query/shared/src/main/scala/tastyquery/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ object Symbols {
private var myTypeParams: List[ClassTypeParamSymbol] | Null = null
private var myParentsInit: (() => List[Type]) | Null = null
private var myParents: List[Type] | Null = null
private var myGivenSelfType: Option[Type] | Null = null

// Optional reference fields
private var mySpecialErasure: Option[() => ErasedTypeRef] = None
Expand All @@ -583,12 +584,15 @@ object Symbols {
mutable.HashMap[Name, mutable.HashSet[TermOrTypeSymbol]]()

// Cache fields
private var myAppliedRef: Type | Null = null
private var mySelfType: Type | Null = null
private var myLinearization: List[ClassSymbol] | Null = null

protected override def doCheckCompleted(): Unit =
super.doCheckCompleted()
if myTypeParams == null then failNotCompleted("typeParams not initialized")
if myParents == null && myParentsInit == null then failNotCompleted("parents not initialized")
if myGivenSelfType == null then failNotCompleted("givenSelfType not initialized")

private[tastyquery] def isValueClass(using Context): Boolean =
parents.nonEmpty && parents.head.classSymbol.exists(_ == defn.AnyValClass)
Expand Down Expand Up @@ -659,6 +663,39 @@ object Symbols {
}
)

private[tastyquery] final def withGivenSelfType(givenSelfType: Option[Type]): this.type =
if myGivenSelfType != null then throw new IllegalStateException(s"reassignment of givenSelfType for $this")
myGivenSelfType = givenSelfType
this

final def givenSelfType(using Context): Option[Type] =
val local = myGivenSelfType
if local == null then throw new IllegalStateException(s"givenSelfType not initialized for $this")
else local

final def appliedRef(using Context): Type =
val local = myAppliedRef
if local != null then local
else
val computed = typeRef.appliedTo(typeParams.map(_.typeRef))
myAppliedRef = computed
computed
end appliedRef

final def selfType(using Context): Type =
val local = mySelfType
if local != null then local
else
val computed = givenSelfType match
case None =>
appliedRef
case Some(givenSelf) =>
if is(Module) then givenSelf
else AndType(givenSelf, appliedRef)
mySelfType = computed
computed
end selfType

final def linearization(using Context): List[ClassSymbol] =
val local = myLinearization
if local != null then local
Expand Down Expand Up @@ -825,13 +862,13 @@ object Symbols {

/** Compute tp.baseType(this) */
private[tastyquery] final def baseTypeOf(tp: Type)(using Context): Option[Type] =
def combineGlb(bt1: Option[Type], bt2: Option[Type]): Option[Type] =
if bt1.isEmpty then bt2
else if bt2.isEmpty then bt1
else Some(bt1.get & bt2.get)

def recur(tp: Type): Option[Type] = tp match
case tp: TypeRef =>
def combineGlb(bt1: Option[Type], bt2: Option[Type]): Option[Type] =
if bt1.isEmpty then bt2
else if bt2.isEmpty then bt1
else Some(bt1.get & bt2.get)

def foldGlb(bt: Option[Type], ps: List[Type]): Option[Type] =
ps.foldLeft(bt)((bt, p) => combineGlb(bt, recur(p)))

Expand Down Expand Up @@ -869,8 +906,20 @@ object Symbols {
case tp: TypeProxy =>
recur(tp.superType)

case tp: AndType =>
val tp1 = tp.first
val tp2 = tp.second
// TODO? Opt when this.isStatic && tp.derivesFrom(this) && this.typeParams.isEmpty then this.typeRef
val combined = combineGlb(recur(tp1), recur(tp2))
combined match
case Some(combined: AndType) if (combined.first eq tp1) && (combined.second eq tp2) =>
// Return `tp` itself to allow `Subtyping.level3WithBaseType` to cut off infinite recursions
Some(tp)
case _ =>
combined

case _ =>
// TODO Handle AndType and OrType, and JavaArrayType
// TODO Handle OrType and JavaArrayType
None
end recur

Expand Down Expand Up @@ -988,6 +1037,7 @@ object Symbols {
cls
.withTypeParams(Nil)
.withParentsDirect(defn.ObjectType :: Nil)
.withGivenSelfType(None)
.withFlags(EmptyFlagSet, None)
.setAnnotations(Nil)
cls.checkCompleted()
Expand Down
13 changes: 12 additions & 1 deletion tasty-query/shared/src/main/scala/tastyquery/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,19 @@ object Types {
end TypeRef

final class ThisType(val tref: TypeRef) extends PathType with SingletonType {
private var myUnderlying: Type | Null = null

override def underlying(using Context): Type =
tref // TODO This is probably wrong
val local = myUnderlying
if local != null then local
else
val cls = this.cls
val computed =
if cls.isStatic then cls.selfType
else cls.selfType.asSeenFrom(tref.prefix, cls)
myUnderlying = computed
computed
end underlying

final def cls(using Context): ClassSymbol = tref.symbol.asClass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ private[reader] object ClassfileParser {
.withFlags(clsFlags | Flags.ModuleClassCreationFlags, clsPrivateWithin)
.setAnnotations(Nil)
.withParentsDirect(defn.ObjectType :: Nil)
.withGivenSelfType(None)
allRegisteredSymbols += moduleClass

val module = TermSymbol
Expand Down Expand Up @@ -216,6 +217,7 @@ private[reader] object ClassfileParser {
cls.withParentsDirect(parents)
end initParents

cls.withGivenSelfType(None)
cls.withFlags(clsFlags, clsPrivateWithin)
cls.setAnnotations(Nil) // TODO Read Java annotations on classes
initParents()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,10 @@ private[pickles] class PickleReader {
case tpe: TempClassInfoType => tpe.parentTypes
case tpe =>
throw AssertionError(s"unexpected type $tpe for $cls, owner is $owner")
val givenSelfType = if atEnd then None else Some(readTypeRef())
cls.withParentsDirect(parentTypes)
cls.withTypeParams(typeParams)
cls.withGivenSelfType(givenSelfType)
cls
case VALsym =>
val sym = TermSymbol.create(name.toTermName, owner)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ private[tasties] class TreeUnpickler(
}
}
val self = readSelf
cls.withGivenSelfType(self.map(_.tpt.toType))
// The first entry is the constructor
val cstr = readStat.asInstanceOf[DefDef]
val body = readStats(end)
Expand Down
79 changes: 79 additions & 0 deletions tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1939,4 +1939,83 @@ class TypeSuite extends UnrestrictedUnpicklingSuite {
end if
end for
}

testWithContext("applied-ref") {
val FooClass = ctx.findStaticClass("simple_trees.SelfTypes.Foo")
val BarClass = ctx.findStaticClass("simple_trees.SelfTypes.Bar")
val FooBarClass = ctx.findStaticClass("simple_trees.SelfTypes.FooBar")

val fooTArg = FooClass.typeParams.head
val List(barTArg1, barTArg2) = BarClass.typeParams: @unchecked

assert(clue(FooClass.appliedRef).isApplied(_.isRef(FooClass), List(_.isRef(fooTArg))))
assert(clue(BarClass.appliedRef).isApplied(_.isRef(BarClass), List(_.isRef(barTArg1), _.isRef(barTArg2))))
assert(clue(FooBarClass.appliedRef).isRef(FooBarClass))
}

testWithContext("self-types") {
val FooClass = ctx.findStaticClass("simple_trees.SelfTypes.Foo")
val BarClass = ctx.findStaticClass("simple_trees.SelfTypes.Bar")
val FooBarClass = ctx.findStaticClass("simple_trees.SelfTypes.FooBar")

val fooTArg = FooClass.typeParams.head
val List(barTArg1, barTArg2) = BarClass.typeParams: @unchecked

val expectedGivenSelfType: Type => Boolean =
tpe => tpe.isApplied(_.isRef(BarClass), List(_.isRef(fooTArg), _.isRef(defn.IntClass)))

assert(clue(FooClass.givenSelfType).exists(expectedGivenSelfType))
assert(
clue(FooClass.selfType).isIntersectionOf(
expectedGivenSelfType,
_.isApplied(_.isRef(FooClass), List(_.isRef(FooClass.typeParams.head)))
)
)

assert(clue(BarClass.givenSelfType).isEmpty)
assert(clue(BarClass.selfType).isApplied(_.isRef(BarClass), List(_.isRef(barTArg1), _.isRef(barTArg2))))

assert(clue(FooBarClass.givenSelfType).isEmpty)
assert(clue(FooBarClass.selfType).isRef(FooBarClass))
}

testWithContext("scala2-self-types") {
val ClassManifestAlias = ctx.findStaticType("scala.reflect.package.ClassManifest")
val ClassManifestDeprecatedApisClass = ctx.findTopLevelClass("scala.reflect.ClassManifestDeprecatedApis")

val cmDeprecatedApisTArg = ClassManifestDeprecatedApisClass.typeParams.head

val expectedGivenSelfType: Type => Boolean =
tpe => tpe.isApplied(_.isRef(ClassManifestAlias), List(_.isRef(cmDeprecatedApisTArg)))

assert(clue(ClassManifestDeprecatedApisClass.givenSelfType).exists(expectedGivenSelfType))
assert(
clue(ClassManifestDeprecatedApisClass.selfType).isIntersectionOf(
expectedGivenSelfType,
_.isApplied(_.isRef(ClassManifestDeprecatedApisClass), List(_.isRef(cmDeprecatedApisTArg)))
)
)
}

testWithContext("selections-with-self-types") {
val FooClass = ctx.findStaticClass("simple_trees.SelfTypes.Foo")
val BarClass = ctx.findStaticClass("simple_trees.SelfTypes.Bar")
val PairClass = ctx.findStaticClass("simple_trees.SelfTypes.Pair")

val fooTArg = FooClass.typeParams.head
val List(barTArg1, barTArg2) = BarClass.typeParams: @unchecked

val targetMethod = BarClass.findNonOverloadedDecl(termName("bar"))

for testMethodName <- List("throughSelf", "throughThis", "bare") do
val DefDef(_, _, _, Some(body), _) = FooClass.findNonOverloadedDecl(termName(testMethodName)).tree.get: @unchecked
val Apply(sel @ Select(ths: This, SignedName(SimpleName("bar"), _, _)), Nil) = body: @unchecked

assert(clue(ths.tpe).isInstanceOf[ThisType])
assert(clue(ths.tpe.asInstanceOf[ThisType].cls) == FooClass)
assert(clue(sel.tpe).isRef(targetMethod))

assert(clue(body.tpe).isApplied(_.isRef(PairClass), List(_.isRef(fooTArg), _.isRef(defn.IntClass))))
end for
}
}
23 changes: 23 additions & 0 deletions test-sources/src/main/scala/simple_trees/SelfTypes.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package simple_trees

object SelfTypes:
trait Foo[T]:
self: Bar[T, Int] =>

def throughSelf: Pair[T, Int] = self.bar()
def throughThis: MyPair = this.bar()
def bare: Pair[T, Int] = bar()
end Foo

trait Bar[A, B]:
def bar(): Pair[A, B]

type MyPair = Pair[A, B]
end Bar

class FooBar extends Foo[String] with Bar[String, Int]:
def bar(): Pair[String, Int] = Pair("foo", 4)
end FooBar

final class Pair[+A, +B](val a: A, val B: B)
end SelfTypes

0 comments on commit 502e2db

Please sign in to comment.