Skip to content

Commit

Permalink
Fix #213: Support refinement types.
Browse files Browse the repository at this point in the history
* Actually create RefinedTypes from RefinedTypeTree.toType.
* Implement `RefinedType.resolveMember` to take the refinements
  into account.
  • Loading branch information
sjrd committed Feb 17, 2023
1 parent 576692a commit d202eb4
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 18 deletions.
38 changes: 37 additions & 1 deletion tasty-query/shared/src/main/scala/tastyquery/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,22 @@ object Trees {

type ParamsClause = Either[List[ValDef], List[TypeParam]]

private[tastyquery] object ParamsClause:
def makeDefDefType(paramLists: List[ParamsClause], resultTpt: TypeTree)(using Context): Type =
def rec(paramLists: List[ParamsClause]): Type =
paramLists match
case Left(params) :: rest =>
val paramSymbols = params.map(_.symbol)
MethodType.fromSymbols(paramSymbols, rec(rest))
case Right(tparams) :: rest =>
PolyType.fromParams(tparams, rec(rest))
case Nil =>
resultTpt.toType

rec(paramLists)
end makeDefDefType
end ParamsClause

/** mods def name[tparams](vparams_1)...(vparams_n): tpt = rhs */
final case class DefDef(
name: TermName,
Expand Down Expand Up @@ -638,7 +654,27 @@ object Trees {
extends TypeTree(span) {

override protected def calculateType(using Context): Type =
underlying.toType // TODO Actually take the refinements into account
val base = underlying.toType
refinements.foldLeft(base) { (parent, refinement) =>
refinement match
case TypeMember(name, rhs, _) =>
val refinedBounds = rhs match
case TypeAliasDefinitionTree(tpt) =>
TypeAlias(tpt.toType)
case rhs: TypeBoundsTree =>
rhs.toTypeBounds
case _ =>
throw InvalidProgramStructureException(s"Unexpected rhs for type refinement '$name': $rhs")
end refinedBounds
TypeRefinement(parent, name, refinedBounds)

case ValDef(name, tpt, _, _) =>
TermRefinement(parent, name, tpt.toType)

case DefDef(name, paramClauses, resultType, _, _) =>
TermRefinement(parent, name, ParamsClause.makeDefDefType(paramClauses, resultType))
}
end calculateType

override final def withSpan(span: Span): RefinedTypeTree =
RefinedTypeTree(underlying, refinements, refinedCls)(span)
Expand Down
17 changes: 17 additions & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,17 @@ object Types {
val refinedName: Name

override final def underlying(using Context): Type = parent

private[tastyquery] override def resolveMember(name: Name, pre: Type)(using Context): ResolveMemberResult =
val parentMember = parent.resolveMember(name, pre)

if name != refinedName then parentMember
else
val myResult = makeResolveMemberResult(pre)
ResolveMemberResult.merge(parentMember, myResult)
end resolveMember

protected def makeResolveMemberResult(pre: Type)(using Context): ResolveMemberResult
end RefinedType

/** A type refinement `parent { type refinedName <:> refinedBounds }`.
Expand All @@ -1397,6 +1408,9 @@ object Types {
*/
final class TypeRefinement(val parent: Type, val refinedName: TypeName, val refinedBounds: TypeBounds)
extends RefinedType:
protected def makeResolveMemberResult(pre: Type)(using Context): ResolveMemberResult =
ResolveMemberResult.TypeMember(Nil, refinedBounds)

private[tastyquery] final def derivedTypeRefinement(
parent: Type,
refinedName: TypeName,
Expand All @@ -1415,6 +1429,9 @@ object Types {
* @param refinedType The refined type for the given term member
*/
final class TermRefinement(val parent: Type, val refinedName: TermName, val refinedType: Type) extends RefinedType:
protected def makeResolveMemberResult(pre: Type)(using Context): ResolveMemberResult =
ResolveMemberResult.TermMember(Nil, refinedType)

private[tastyquery] final def derivedTermRefinement(parent: Type, refinedName: TermName, refinedType: Type): Type =
if ((parent eq this.parent) && (refinedName eq this.refinedName) && (refinedType eq this.refinedType)) this
else TermRefinement(parent, refinedName, refinedType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ private[tasties] class TreeUnpickler(
val normalizedParams =
if name == nme.Constructor then normalizeCtorParamClauses(params)
else params
symbol.withDeclaredType(makeDefDefType(normalizedParams, tpt))
symbol.withDeclaredType(ParamsClause.makeDefDefType(normalizedParams, tpt))
definingTree(symbol, DefDef(name, normalizedParams, tpt, rhs, symbol)(spn))
}
}
Expand Down Expand Up @@ -697,20 +697,6 @@ private[tasties] class TreeUnpickler(
else paramLists :+ Left(Nil) // add `()` at the end
end normalizeCtorParamClauses

private def makeDefDefType(paramLists: List[ParamsClause], resultTpt: TypeTree): Type =
def rec(paramLists: List[ParamsClause]): Type =
paramLists match
case Left(params) :: rest =>
val paramSymbols = params.map(_.symbol)
MethodType.fromSymbols(paramSymbols, rec(rest))
case Right(tparams) :: rest =>
PolyType.fromParams(tparams, rec(rest))
case Nil =>
resultTpt.toType

rec(paramLists)
end makeDefDefType

private def readTerms(end: Addr)(using LocalContext): List[TermTree] =
reader.until(end)(readTerm)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,99 @@ class SubtypingSuite extends UnrestrictedUnpicklingSuite:
.withRef[refyAlias.AliasOfAbstractType, JString]
}

testWithContext("simple-paths-in-subclasses") {
testWithContext("paths-and-refinements") {
import subtyping.paths.{A, B, C, SimplePaths, ConcreteSimplePathsChild}

val paths = "subtyping.paths"
val AClass = ctx.findTopLevelClass(s"$paths.A")
val BClass = ctx.findTopLevelClass(s"$paths.B")
val CClass = ctx.findTopLevelClass(s"$paths.C")
val SimplePathsClass = ctx.findTopLevelClass(s"$paths.SimplePaths")
val ConcreteSimplePathsChildClass = ctx.findTopLevelClass(s"$paths.ConcreteSimplePathsChild")

val setupMethod = ctx.findTopLevelModuleClass(s"$paths.Setup").findNonOverloadedDecl(name"refinements")
val setupMethodDef = setupMethod.tree.get.asInstanceOf[DefDef]
val Left(valDefs) = setupMethodDef.paramLists.head: @unchecked
val List(x, y) = valDefs.map(valDef => TermRef(NoPrefix, valDef.symbol))
val yAsStringRefine = TermRef(NoPrefix, findLocalValDef(setupMethodDef.rhs.get, name"yAsStringRefine"))
val zAsIntRefine = TermRef(NoPrefix, findLocalValDef(setupMethodDef.rhs.get, name"zAsIntRefine"))

type StringRefine = SimplePaths {
type AbstractType = String
type AbstractTypeWithBounds <: B
type ConcreteOnlyMember = Boolean
}
type IntRefine = SimplePaths { type AbstractType = Int }

val refx: SimplePaths = new SimplePaths
val refy: ConcreteSimplePathsChild = new ConcreteSimplePathsChild
val refyAsStringRefine: StringRefine = refy
val refzAsIntRefine: IntRefine = new SimplePaths {
type AbstractType = Int
}

assertStrictSubtype(x, SimplePathsClass.typeRef)
assertStrictSubtype(y, ConcreteSimplePathsChildClass.typeRef)
assertStrictSubtype(yAsStringRefine, SimplePathsClass.typeRef)
assertStrictSubtype(zAsIntRefine, SimplePathsClass.typeRef)

assertNeitherSubtype(yAsStringRefine, ConcreteSimplePathsChildClass.typeRef)
assertNeitherSubtype(zAsIntRefine, ConcreteSimplePathsChildClass.typeRef)

assertNeitherSubtype(x.select(tname"AbstractType"), defn.StringType).withRef[refx.AbstractType, JString]
assertEquiv(y.select(tname"AbstractType"), defn.StringType).withRef[refy.AbstractType, JString]

assertEquiv(yAsStringRefine.select(tname"AbstractType"), defn.StringType)
.withRef[refyAsStringRefine.AbstractType, JString]
assertEquiv(yAsStringRefine.select(tname"AbstractType"), y.select(tname"AbstractType"))
.withRef[refyAsStringRefine.AbstractType, refy.AbstractType]
assertNeitherSubtype(yAsStringRefine.select(tname"AbstractType"), x.select(tname"AbstractType"))
.withRef[refyAsStringRefine.AbstractType, refx.AbstractType]

assertEquiv(yAsStringRefine.select(tname"ConcreteOnlyMember"), defn.BooleanType)
.withRef[refyAsStringRefine.ConcreteOnlyMember, Boolean]
assertEquiv(yAsStringRefine.select(tname"ConcreteOnlyMember"), y.select(tname"ConcreteOnlyMember"))
.withRef[refyAsStringRefine.ConcreteOnlyMember, refy.ConcreteOnlyMember]
assertNeitherSubtype(yAsStringRefine.select(tname"ConcreteOnlyMember"), x.select(tname"AbstractType"))
.withRef[refyAsStringRefine.ConcreteOnlyMember, refx.AbstractType]

assertEquiv(yAsStringRefine.select(tname"AliasOfAbstractType"), y.select(tname"AbstractType"))
.withRef[refyAsStringRefine.AliasOfAbstractType, refy.AbstractType]
assertEquiv(yAsStringRefine.select(tname"AliasOfAbstractType"), defn.StringType)
.withRef[refyAsStringRefine.AliasOfAbstractType, JString]
assertNeitherSubtype(
yAsStringRefine.select(tname"AliasOfAbstractType"),
zAsIntRefine.select(tname"AliasOfAbstractType")
)
.withRef[refyAsStringRefine.AliasOfAbstractType, refzAsIntRefine.AliasOfAbstractType]

assertStrictSubtype(yAsStringRefine.select(tname"AbstractTypeWithBounds"), BClass.typeRef)
.withRef[refyAsStringRefine.AbstractTypeWithBounds, B]
assertStrictSubtype(yAsStringRefine.select(tname"AbstractTypeWithBounds"), AClass.typeRef)
.withRef[refyAsStringRefine.AbstractTypeWithBounds, A]
assertStrictSubtype(CClass.typeRef, yAsStringRefine.select(tname"AbstractTypeWithBounds"))
.withRef[C, refyAsStringRefine.AbstractTypeWithBounds]

assertStrictSubtype(yAsStringRefine.select(tname"AliasOfAbstractTypeWithBounds"), BClass.typeRef)
.withRef[refyAsStringRefine.AliasOfAbstractTypeWithBounds, B]
assertStrictSubtype(
yAsStringRefine.select(tname"AliasOfAbstractTypeWithBounds"),
y.select(tname"AliasOfAbstractTypeWithBounds")
)
.withRef[refyAsStringRefine.AliasOfAbstractTypeWithBounds, refy.AliasOfAbstractTypeWithBounds]
assertStrictSubtype(yAsStringRefine.select(tname"AliasOfAbstractTypeWithBounds"), AClass.typeRef)
.withRef[refyAsStringRefine.AliasOfAbstractTypeWithBounds, A]
assertStrictSubtype(CClass.typeRef, yAsStringRefine.select(tname"AliasOfAbstractTypeWithBounds"))
.withRef[C, refyAsStringRefine.AliasOfAbstractTypeWithBounds]

assertNeitherSubtype(
yAsStringRefine.select(tname"AliasOfAbstractTypeWithBounds"),
zAsIntRefine.select(tname"AliasOfAbstractTypeWithBounds")
)
.withRef[refyAsStringRefine.AliasOfAbstractTypeWithBounds, refzAsIntRefine.AliasOfAbstractTypeWithBounds]
}

testWithContext("simple-paths-in-nested-classes") {
import subtyping.paths.NestedClasses

val paths = "subtyping.paths"
Expand Down
19 changes: 18 additions & 1 deletion test-sources/src/main/scala/subtyping/paths/Paths.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,29 @@ object Setup:
def subclassPaths(x: SimplePaths, y: ConcreteSimplePathsChild, z: ConcreteSimplePathsChild): Unit =
val yAlias: y.type = y
end subclassPaths

def refinements(x: SimplePaths, y: ConcreteSimplePathsChild): Unit =
type StringRefine = SimplePaths {
type AbstractType = String
type AbstractTypeWithBounds <: B
type ConcreteOnlyMember = Boolean
}
type IntRefine = SimplePaths { type AbstractType = Int }

val yAsStringRefine: StringRefine = y
val zAsIntRefine: IntRefine = new SimplePaths {
type AbstractType = Int
type FooBar = Int
}
end refinements
end Setup

trait A
trait B extends A
class C extends B
class D extends B

class SimplePaths:
open class SimplePaths:
type AbstractType
type AbstractTypeWithBounds >: C <: A

Expand All @@ -32,4 +47,6 @@ end OtherSimplePaths
class ConcreteSimplePathsChild extends SimplePaths:
type AbstractType = String
type AbstractTypeWithBounds = B

type ConcreteOnlyMember = Boolean
end ConcreteSimplePathsChild

0 comments on commit d202eb4

Please sign in to comment.