Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce TermSymbol.paramSymss. #383

Merged
merged 2 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ inThisBuild(Def.settings(
Developer("sjrd", "Sébastien Doeraene", "[email protected]", url("https://github.com/sjrd/")),
Developer("bishabosha", "Jamie Thompson", "[email protected]", url("https://github.com/bishabosha")),
),
versionPolicyIntention := Compatibility.BinaryAndSourceCompatible,
versionPolicyIntention := Compatibility.BinaryCompatible,
// Ignore dependencies to internal modules whose version is like `1.2.3+4...` (see https://github.com/scalacenter/sbt-version-policy#how-to-integrate-with-sbt-dynver)
versionPolicyIgnoredInternalDependencyVersions := Some("^\\d+\\.\\d+\\.\\d+\\+\\d+".r)
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
.withFlags(Method | flags, privateWithin = None)
.withDeclaredType(tpe)
.setAnnotations(Nil)
.autoFillParamSymss()
sym.checkCompleted()
sym
end createSpecialMethod
Expand Down Expand Up @@ -341,6 +342,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
mt => resultTypeParam.localRef
)
)
applyMethod.autoFillParamSymss()
applyMethod.setAnnotations(Nil)
applyMethod.checkCompleted()

Expand Down
22 changes: 22 additions & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Substituters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ private[tastyquery] object Substituters:
if from.isEmpty then tp
else new SubstLocalParamsMap(from, to).apply(tp)

def substLocalBoundParams(tp: TypeMappable, from: ParamRefBinder, to: List[Type]): tp.ThisTypeMappableType =
if to.isEmpty then tp
else new SubstLocalBoundParamsMap(from, to).apply(tp)

def substLocalThisClassTypeParams(
tp: TypeMappable,
from: List[ClassTypeParamSymbol],
Expand Down Expand Up @@ -142,6 +146,24 @@ private[tastyquery] object Substituters:
end transform
end SubstLocalParamsMap

private final class SubstLocalBoundParamsMap(from: ParamRefBinder, to: List[TypeOrWildcard]) extends TypeMap:
protected def transform(tp: TypeMappable): TypeMappable =
tp match
case tp: ParamRef =>
if tp.binder eq from then to(tp.paramNum) else tp
case tp: NamedType =>
tp.prefix match
case NoPrefix | _: PackageRef => tp
case prefix: Type => tp.derivedSelect(apply(prefix))
case _: ThisType =>
tp
case tp: AppliedType =>
tp.map(apply(_), apply(_))
case _ =>
mapOver(tp)
end transform
end SubstLocalBoundParamsMap

private final class SubstLocalThisClassTypeParamsMap(from: List[ClassTypeParamSymbol], to: List[Type])
extends TypeMap:
protected def transform(tp: TypeMappable): TypeMappable =
Expand Down
66 changes: 65 additions & 1 deletion tasty-query/shared/src/main/scala/tastyquery/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ object Symbols {
* This method is called by the various file readers after reading each
* file, for all the `Symbol`s created while reading that file.
*/
private[tastyquery] final def checkCompleted(): Unit =
private[tastyquery] final def checkCompleted(): this.type =
doCheckCompleted()
this

protected final def failNotCompleted(details: String): Nothing =
throw IllegalStateException(s"$this of class ${this.getClass().getName()} was not completed: $details")
Expand Down Expand Up @@ -426,12 +427,15 @@ object Symbols {
end matchingSymbol
end TermOrTypeSymbol

type ParamSymbolsClause = Either[List[TermSymbol], List[LocalTypeParamSymbol]]

final class TermSymbol private (val name: UnsignedTermName, owner: Symbol) extends TermOrTypeSymbol(owner):
type DefiningTreeType = ValOrDefDef | Bind
type MatchingSymbolType = TermSymbol

// Reference fields (checked in doCheckCompleted)
private var myDeclaredType: TypeOrMethodic | Null = null
private var myParamSymss: List[ParamSymbolsClause] | Null = null

// Cache fields
private var mySignature: Signature | Null = null
Expand All @@ -442,16 +446,76 @@ object Symbols {
super.doCheckCompleted()
if myDeclaredType == null then failNotCompleted("declaredType was not initialized")

if flags.is(Method) then
if myParamSymss == null then failNotCompleted("paramSymss was not initialized")
paramSymss.foreach(_.merge.foreach(_.checkCompleted()))
else if myParamSymss == null then myParamSymss = Nil // auto-complete for non-methods
else if myParamSymss != Nil then
throw IllegalArgumentException(s"illegal non-empty paramSymss $myParamSymss for $this")
end doCheckCompleted

private[tastyquery] final def withDeclaredType(tpe: TypeOrMethodic): this.type =
if myDeclaredType != null then throw new IllegalStateException(s"reassignment of declared type to $this")
myDeclaredType = tpe
this

/** You should not need this; it is a hack for patching Scala 2 constructors in `PickleReader`. */
private[tastyquery] final def overwriteDeclaredType(tpe: TypeOrMethodic): this.type =
myDeclaredType = tpe
this

def declaredType: TypeOrMethodic =
val local = myDeclaredType
if local != null then local
else throw new IllegalStateException(s"$this was not assigned a declared type")

private[tastyquery] final def setParamSymss(paramSymss: List[ParamSymbolsClause]): this.type =
if myParamSymss != null then throw IllegalStateException(s"reassignment of paramSymss to $this")
myParamSymss = paramSymss
this

private[tastyquery] final def autoFillParamSymss(): this.type =
setParamSymss(autoComputeParamSymss(declaredType))

private def autoComputeParamSymss(tpe: TypeOrMethodic): List[ParamSymbolsClause] = tpe match
case tpe: MethodType =>
/* For term params, we do not instantiate the paramTypes.
* We only use autoFillParamSymss for Java definitions, which do not
* support term param references at all, and from Definitions, which
* does not use that capability in the term param bounds.
*/
val paramSyms = tpe.paramNames.lazyZip(tpe.paramTypes).map { (name, paramType) =>
TermSymbol
.createNotDeclaration(name, this)
.withFlags(EmptyFlagSet, privateWithin = None)
.withDeclaredType(paramType)
.setAnnotations(Nil)
}
Left(paramSyms) :: autoComputeParamSymss(tpe.resultType)

case tpe: PolyType =>
val paramSyms = tpe.paramNames.map { name =>
LocalTypeParamSymbol
.create(name, this)
.withFlags(EmptyFlagSet, privateWithin = None)
.setAnnotations(Nil)
}
val paramSymRefs = paramSyms.map(_.localRef)
def subst(t: TypeOrMethodic): t.ThisTypeMappableType =
Substituters.substLocalBoundParams(t, tpe, paramSymRefs)
for (paramSym, paramTypeBounds) <- paramSyms.lazyZip(tpe.paramTypeBounds) do
paramSym.setDeclaredBounds(paramTypeBounds.mapBounds(subst(_)))
Right(paramSyms) :: autoComputeParamSymss(subst(tpe.resultType))

case tpe: Type =>
Nil
end autoComputeParamSymss

def paramSymss: List[ParamSymbolsClause] =
val local = myParamSymss
if local != null then local
else throw IllegalStateException(s"$this was not assigned its paramSymss")

/** Is this symbol a module val, i.e., the term of an `object`?
*
* @return true iff `kind == TermSymbolKind.Module`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ private[reader] object ClassfileParser {
else if sym.isMethod && javaFlags.isVarargsIfMethod then patchForVarargs(sym, parsedType)
else parsedType
sym.withDeclaredType(adaptedType)
sym.autoFillParamSymss()

// Verify after the fact that we don't mark signature-polymorphic methods that should not be
if sym.isSignaturePolymorphicMethod then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,21 +340,19 @@ private[pickles] class PickleReader {
TermSymbol.createNotDeclaration(name.toTermName, owner)
else TermSymbol.create(name.toTermName, owner)
storeResultInEntries(sym) // Store the symbol before reading its type, to avoid cycles
val tpe = readSymType()
val storedType = readSymType() match
case storedType: Type => storedType
case storedType => throw Scala2PickleFormatException(s"Type expected for $sym but found $storedType")
val unwrappedTpe: TypeOrMethodic =
if flags.is(Method) then
tpe match
case tpe: TypeOrMethodic => translateTempPolyForMethod(tpe)
case _ => throw Scala2PickleFormatException(s"Type or methodic type expected for $sym but found $tpe")
else
tpe match
case tpe: Type => tpe
case _ => throw Scala2PickleFormatException(s"Type expected for $sym but found $tpe")
end unwrappedTpe
val ctorPatchedTpe =
if flags.is(Method) && name == nme.Constructor then patchConstructorType(sym.owner.asClass, unwrappedTpe)
else unwrappedTpe
sym.withDeclaredType(ctorPatchedTpe)
if flags.is(Method) then translateTempMethodAndPolyForMethod(storedType)
else storedType
val paramSymss = paramSymssOf(storedType)
if flags.is(Method) && name == nme.Constructor then
sym.withDeclaredType(patchConstructorType(sym.owner.asClass, unwrappedTpe))
sym.setParamSymss(patchConstructorParamSymss(sym, paramSymss))
else
sym.withDeclaredType(unwrappedTpe)
sym.setParamSymss(paramSymss)
case MODULEsym =>
val sym = TermSymbol.create(name.toTermName, owner)
storeResultInEntries(sym)
Expand All @@ -371,6 +369,15 @@ private[pickles] class PickleReader {
sym
}

private def paramSymssOf(storedType: Type): List[ParamSymbolsClause] = storedType match
case TempMethodType(paramSyms, resType) =>
Left(paramSyms) :: paramSymssOf(resType.requireType)
case TempPolyType(paramSyms, resType) =>
Right(paramSyms.map(_.asInstanceOf[LocalTypeParamSymbol])) :: paramSymssOf(resType.requireType)
case _ =>
Nil
end paramSymssOf

private def patchConstructorType(cls: ClassSymbol, tpe: TypeOrMethodic)(using ReaderContext): TypeOrMethodic =
def resultToUnit(tpe: TypeOrMethodic): TypeOrMethodic =
tpe match
Expand All @@ -385,6 +392,41 @@ private[pickles] class PickleReader {
cls.makePolyConstructorType(tpe1)
end patchConstructorType

private def patchConstructorParamSymss(
ctor: TermSymbol,
paramSymss: List[ParamSymbolsClause]
): List[ParamSymbolsClause] =
val cls = ctor.owner.asClass
val clsTypeParams = cls.typeParams

if clsTypeParams.isEmpty then paramSymss
else
// Create the symbols; don't assign bounds yet
val ctorTypeParams = clsTypeParams.map { clsTypeParam =>
LocalTypeParamSymbol
.create(clsTypeParam.name, ctor)
.withFlags(EmptyFlagSet, privateWithin = None)
.setAnnotations(Nil)
}

val ctorTypeParamRefs = ctorTypeParams.map(_.localRef)
def subst(tpe: TypeMappable): tpe.ThisTypeMappableType =
Substituters.substLocalThisClassTypeParams(tpe, clsTypeParams, ctorTypeParamRefs)

// Assign the bounds; when they refer to each other we need to substitute for the new local refs
for (clsTypeParam, ctorTypeParam) <- clsTypeParams.lazyZip(ctorTypeParams) do
ctorTypeParam.setDeclaredBounds(subst(clsTypeParam.declaredBounds))

// Overwrite the types of the existing param syms to refer to the new local refs as well
for
case Left(paramSyms) <- paramSymss
paramSym <- paramSyms
do paramSym.overwriteDeclaredType(subst(paramSym.declaredType))

Right(ctorTypeParams) :: paramSymss
end if
end patchConstructorParamSymss

def readChildren()(using ReaderContext, PklStream, Entries, Index): Unit =
val tag = pkl.readByte()
assert(tag == CHILDREN)
Expand Down Expand Up @@ -665,18 +707,7 @@ private[pickles] class PickleReader {
case METHODtpe | IMPLICITMETHODtpe =>
val restpe = readTypeOrMethodicRef()
val params = pkl.until(end, () => readLocalSymbolRef().asTerm)
val maker = MethodType
/*val maker = MethodType.companion(
isImplicit = tag == IMPLICITMETHODtpe || params.nonEmpty && params.head.is(Implicit))*/
val result = maker.fromSymbols(params, restpe)
// result.resType match
// case restpe1: MethodType if restpe1 ne restpe =>
// val prevResParams = caches.paramsOfMethodType.remove(restpe)
// if prevResParams != null then
// caches.paramsOfMethodType.put(restpe1, prevResParams)
// case _ =>
// if params.nonEmpty then caches.paramsOfMethodType.put(result, params)
result
TempMethodType(params, restpe)
case POLYtpe =>
// create PolyType
// - PT => register at index
Expand Down Expand Up @@ -742,18 +773,25 @@ private[pickles] class PickleReader {
end translateTempPolyForTypeMember

/** Convert temp poly type to PolyType and leave other types alone. */
private def translateTempPolyForMethod(tp: TypeOrMethodic)(using ReaderContext): TypeOrMethodic = tp match
private def translateTempMethodAndPolyForMethod(tp: TypeOrMethodic)(using ReaderContext): TypeOrMethodic = tp match
case TempMethodType(paramSyms, resType) =>
resType match
case resType: TypeOrMethodic =>
MethodType.fromSymbols(paramSyms, translateTempMethodAndPolyForMethod(resType))
case _ =>
throw Scala2PickleFormatException(s"Invalid type for method: $tp")

case TempPolyType(tparams, restpe) =>
val localTParams = tparams.asInstanceOf[List[LocalTypeParamSymbol]] // no class type params in methods
restpe match
case restpe: TypeOrMethodic =>
PolyType.fromParams(localTParams, restpe)
PolyType.fromParams(localTParams, translateTempMethodAndPolyForMethod(restpe))
case _ =>
throw Scala2PickleFormatException(s"Invalid type for method: $tp")

case tp =>
tp
end translateTempPolyForMethod
end translateTempMethodAndPolyForMethod

private def noSuchTypeTag(tag: Int, end: Int): Nothing =
errorBadSignature("bad type tag: " + tag)
Expand Down Expand Up @@ -947,6 +985,9 @@ private[reader] object PickleReader {
private val Scala2Constructor: SimpleName = termName("this")
private val Scala2TraitConstructor: SimpleName = termName("$init$")

private[tastyquery] case class TempMethodType(paramSyms: List[TermSymbol], resType: TypeMappable)
extends CustomTransientGroundType

private[tastyquery] case class TempPolyType(paramSyms: List[TypeParamSymbol], resType: TypeMappable)
extends CustomTransientGroundType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -768,10 +768,15 @@ private[tasties] class TreeUnpickler private (
if name == nme.Constructor then normalizeCtorParamClauses(params)
else params
symbol.withDeclaredType(ParamsClause.makeDefDefType(normalizedParams, tpt))
symbol.setParamSymss(normalizedParams.map(paramsClauseToParamSymbolsClause(_)))
definingTree(symbol, DefDef(name, normalizedParams, tpt, rhs, symbol)(spn))
}
}

private def paramsClauseToParamSymbolsClause(clause: ParamsClause): ParamSymbolsClause = clause match
case Left(termParams) => Left(termParams.map(_.symbol))
case Right(typeParams) => Right(typeParams.map(_.symbol.asInstanceOf[LocalTypeParamSymbol]))

/** Normalizes the param clauses of a constructor definition.
*
* Make sure it has at least one non-implicit parameter list. This is done
Expand Down
Loading