Skip to content

Commit

Permalink
Introduce a more restrictive ReaderContext inside reader.*.
Browse files Browse the repository at this point in the history
It statically makes sure that we never access anything that
requires reading another file.

This change surfaced a couple places where there could potentially
be external reading, and which are now a bit more complicated to
fit the static restriction. The most prominent example is
`TreeUnpickler.readWithin`, which also suggested that we decouple
reading the `privateWithin` from the flags (to only do it after
`createSymbols()` is over).
  • Loading branch information
sjrd committed Sep 22, 2023
1 parent f03c643 commit b2d3b05
Show file tree
Hide file tree
Showing 15 changed files with 338 additions and 194 deletions.
25 changes: 8 additions & 17 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,14 @@ lazy val tastyQuery =
mimaBinaryIssueFilters ++= {
import com.typesafe.tools.mima.core.*
Seq(
// private, so this is fine
ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.reader.tasties.TreeUnpickler#Caches.refinedTypeTreeCache"),
ProblemFilters.exclude[MissingClassProblem]("tastyquery.reader.tasties.TreeUnpickler$LocalContext"),

// private[reader], so this is fine
ProblemFilters.exclude[Problem]("tastyquery.reader.tasties.TastyUnpickler#*"),

// private[tastyquery], so this is fine
ProblemFilters.exclude[MissingClassProblem]("tastyquery.Types$TypeParamInfo"),

/* We removed TypeParamInfo from the parents of ClassTypeParam.
* Since TypeParamInfo was `private[tastyquery]`, there is little chance it leaked.
*/
ProblemFilters.exclude[MissingTypesProblem]("tastyquery.Symbols$ClassTypeParamSymbol"),

// new abstract method in fully sealed trait, so this is fine
ProblemFilters.exclude[ReversedMissingMethodProblem]("tastyquery.Types#TermLambdaType.paramTypes"),
// Everything in tastyquery.reader is private[tastyquery] at most
ProblemFilters.exclude[Problem]("tastyquery.reader.*"),

// private[tastyquery], not an issue
ProblemFilters.exclude[IncompatibleMethTypeProblem]("tastyquery.Symbols#ClassSymbol.createRefinedClassSymbol"),
ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.Types#PolyType.fromParamsSymbols"),
ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.Types#TypeLambda.fromParamsSymbols"),
ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.Types#TypeLambdaTypeCompanion.fromParamsSymbols"),
)
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
def isTupleNClass(sym: ClassSymbol): Boolean =
sym.owner == scalaPackage && TupleNClasses.contains(sym)

lazy val hasGenericTuples = scalaPackage.getDecl(tpnme.TupleCons).isDefined
lazy val hasGenericTuples = ctx.classloader.hasGenericTuples

lazy val uninitializedMethod: Option[TermSymbol] =
scalaCompiletimePackage.getDecl(moduleClassName("package$package")).flatMap { packageObjectClass =>
Expand Down
51 changes: 38 additions & 13 deletions tasty-query/shared/src/main/scala/tastyquery/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ object Symbols {
private var isFlagsInitialized = false
private var myFlags: FlagSet = Flags.EmptyFlagSet
private var myTree: Option[DefiningTreeType] = None
private var myPrivateWithin: Option[DeclaringSymbol] = None
private var myPrivateWithin: Option[DeclaringSymbol] | Null = null
private var myAnnotations: List[Annotation] | Null = null

/** Checks that this `Symbol` has been completely initialized.
Expand All @@ -96,6 +96,7 @@ object Symbols {
*/
protected def doCheckCompleted(): Unit =
if !isFlagsInitialized then throw failNotCompleted("flags were not initialized")
if myPrivateWithin == null then throw failNotCompleted("privateWithin was not initialized")
if myAnnotations == null then throw failNotCompleted("annotations were not initialized")

private[tastyquery] def withTree(t: DefiningTreeType): this.type =
Expand All @@ -107,12 +108,24 @@ object Symbols {
myTree

private[tastyquery] final def withFlags(flags: FlagSet, privateWithin: Option[DeclaringSymbol]): this.type =
if isFlagsInitialized then throw IllegalStateException(s"reassignment of flags to $this")
setFlags(flags)
setPrivateWithin(privateWithin)

private[tastyquery] final def setFlags(flags: FlagSet): this.type =
if isFlagsInitialized || myPrivateWithin != null then
throw IllegalStateException(s"reassignment of flags to $this")
else
isFlagsInitialized = true
myFlags = flags
this
end setFlags

private[tastyquery] final def setPrivateWithin(privateWithin: Option[DeclaringSymbol]): this.type =
if myPrivateWithin != null then throw IllegalStateException(s"reassignment of privateWithin to $this")
else
myPrivateWithin = privateWithin
this
end setPrivateWithin

private[tastyquery] final def setAnnotations(annots: List[Annotation]): this.type =
if myAnnotations != null then throw IllegalStateException(s"reassignment of annotations to $this")
Expand All @@ -126,8 +139,9 @@ object Symbols {
else throw IllegalStateException(s"annotations of $this have not been initialized")

protected final def privateWithin: Option[DeclaringSymbol] =
if isFlagsInitialized then myPrivateWithin
else throw IllegalStateException(s"flags of $this have not been initialized")
val local = myPrivateWithin
if local != null then local
else throw IllegalStateException(s"privateWithin of $this has not been initialized")

protected final def flags: FlagSet =
if isFlagsInitialized then myFlags
Expand Down Expand Up @@ -1187,13 +1201,16 @@ object Symbols {
end distinguishOverloaded

final def getDecl(name: TypeName)(using Context): Option[TypeSymbol] =
getDeclImpl(name)

private[tastyquery] final def getDeclImpl(name: TypeName): Option[TypeSymbol] =
myDeclarations.get(name) match
case Some(decls) =>
assert(decls.sizeIs == 1, decls)
Some(decls.head.asType)
case None =>
None
end getDecl
end getDeclImpl

final def getDecl(name: TermName)(using Context): Option[TermSymbol] =
getDecl(name: Name).map(_.asTerm)
Expand Down Expand Up @@ -1273,6 +1290,9 @@ object Symbols {
end findNonOverloadedDecl

final def declarations(using Context): List[TermOrTypeSymbol] =
declarationsOfClass

private[tastyquery] final def declarationsOfClass: List[TermOrTypeSymbol] =
myDeclarations.values.toList.flatten

// Member lookup, including inherited members
Expand Down Expand Up @@ -1537,9 +1557,7 @@ object Symbols {
*
* This is only used by the Scala 2 unpickler.
*/
private[tastyquery] def setScala2SealedChildren(children: List[Symbol | Scala2ExternalSymRef])(
using Context
): Unit =
private[tastyquery] def setScala2SealedChildren(children: List[Symbol | Scala2ExternalSymRef]): Unit =
if !flags.is(Scala2Defined) then
throw IllegalArgumentException(s"Illegal $this.setScala2SealedChildren($children) for non-Scala 2 class")
if myScala2SealedChildren.isDefined then
Expand Down Expand Up @@ -1617,17 +1635,20 @@ object Symbols {
private[tastyquery] def createNotDeclaration(name: TypeName, owner: Symbol): ClassSymbol =
ClassSymbol(name, owner)

private[tastyquery] def createRefinedClassSymbol(owner: Symbol, flags: FlagSet, pos: SourcePosition)(
using Context
private[tastyquery] def createRefinedClassSymbol(
owner: Symbol,
objectType: TypeRef,
flags: FlagSet,
pos: SourcePosition
): ClassSymbol =
// TODO Store the `pos`
createRefinedClassSymbol(owner, flags)
createRefinedClassSymbol(owner, objectType, flags)

private[tastyquery] def createRefinedClassSymbol(owner: Symbol, flags: FlagSet)(using Context): ClassSymbol =
private[tastyquery] def createRefinedClassSymbol(owner: Symbol, objectType: TypeRef, flags: FlagSet): ClassSymbol =
val cls = ClassSymbol(tpnme.RefinedClassMagic, owner) // by-pass `owner.addDeclIfDeclaringSym`
cls
.withTypeParams(Nil)
.withParentsDirect(defn.ObjectType :: Nil)
.withParentsDirect(objectType :: Nil)
.withGivenSelfType(None)
.withFlags(flags, None)
.setAnnotations(Nil)
Expand Down Expand Up @@ -1665,6 +1686,10 @@ object Symbols {
/** Is this the root package? */
final def isRootPackage: Boolean = owner == null

/** Is this the scala package? */
private[tastyquery] def isScalaPackage: Boolean =
name == nme.scalaPackageName && owner != null && owner.isRootPackage

/** Gets the subpackage with the specified `name`, if it exists.
*
* If this package contains a subpackage with the name `name`, returns
Expand Down
12 changes: 7 additions & 5 deletions tasty-query/shared/src/main/scala/tastyquery/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -693,10 +693,10 @@ object Types {
None // TODO

/** Is this type exactly Nothing (no vars, aliases, refinements etc allowed)? */
private[tastyquery] final def isExactlyNothing(using Context): Boolean = this match
private[tastyquery] final def isExactlyNothing: Boolean = this match
case tp: TypeRef if tp.name == tpnme.Nothing =>
tp.prefix.match
case prefix: PackageRef => prefix.symbol == defn.scalaPackage
case prefix: PackageRef => prefix.symbol.isScalaPackage
case _ => false
case _ =>
false
Expand Down Expand Up @@ -894,6 +894,10 @@ object Types {
private[tastyquery] final def isLocalRef(sym: Symbol): Boolean =
prefix == NoPrefix && (designator eq sym)

private[tastyquery] final def localSymbol: ThisSymbolType =
require(prefix == NoPrefix, prefix)
designator.asInstanceOf[ThisSymbolType]

private[tastyquery] final def isSomeClassTypeParamRef: Boolean =
designator.isInstanceOf[ClassTypeParamSymbol]

Expand Down Expand Up @@ -1585,9 +1589,7 @@ object Types {
sealed abstract class TypeLambdaTypeCompanion[RT <: TypeOrMethodic, LT <: TypeLambdaType]
extends LambdaTypeCompanion[TypeName, TypeBounds, RT, LT] {
@targetName("fromParamsSymbols")
private[tastyquery] final def fromParams(params: List[LocalTypeParamSymbol], resultType: RT)(
using Context
): LT | RT =
private[tastyquery] final def fromParams(params: List[LocalTypeParamSymbol], resultType: RT): LT | RT =
if params.isEmpty then resultType
else
val paramNames = params.map(_.name)
Expand Down
18 changes: 13 additions & 5 deletions tasty-query/shared/src/main/scala/tastyquery/reader/Loaders.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import tastyquery.Names.*
import tastyquery.Symbols.*
import tastyquery.Trees.*

import tastyquery.reader.ReaderContext.rctx
import tastyquery.reader.classfiles.ClassfileParser
import tastyquery.reader.classfiles.ClassfileParser.{ClassKind, InnerClassDecl, Resolver}
import tastyquery.reader.tasties.TastyUnpickler
Expand All @@ -26,7 +27,7 @@ private[tastyquery] object Loaders {
pkg.fullName.select(rootName)
end Loader

class Loader(val classpath: Classpath) { loader =>
class Loader(val classpath: Classpath) {

given Resolver = Resolver()

Expand All @@ -39,6 +40,7 @@ private[tastyquery] object Loaders {

private var searched = false
private var packages: Map[PackageSymbol, IArray[PackageData]] = compiletime.uninitialized
private var _hasGenericTuples: Boolean = compiletime.uninitialized
private var byEntry: ByEntryMap | Null = null
private val roots: mutable.Map[PackageSymbol, mutable.Map[SimpleName, Entry]] = mutable.HashMap.empty
private var topLevelTastys: Map[Loader.Root, List[Tree]] = Map.empty
Expand All @@ -65,11 +67,13 @@ private[tastyquery] object Loaders {
* In any case, no new declarations can ever be found for the given root
* after this method.
*/
private def completeRoot(root: Loader.Root, entry: Entry)(using Context): Unit =
private def completeRoot(root: Loader.Root, entry: Entry)(using ctx: Context): Unit =
doCompleteRoot(root, entry)(using ReaderContext(ctx))

private def doCompleteRoot(root: Loader.Root, entry: Entry)(using ReaderContext): Unit =
def innerClassLookup(nested: IArray[ClassData]): Map[SimpleName, ClassData] =
val mkBinaryName: String => SimpleName =
if root.pkg == defn.EmptyPackage then termName(_)
if root.pkg == rctx.EmptyPackage then termName(_)
else
val pre = root.pkg.fullName.path.mkString("/")
bin => termName(s"$pre/$bin")
Expand Down Expand Up @@ -140,7 +144,7 @@ private[tastyquery] object Loaders {
case entry: Entry.TastyOnly =>
// Tested in `SymbolSuite`, `ReadTreeSuite`, these do not need to see class files.
enterTasty(root, entry.tastyData)
end completeRoot
end doCompleteRoot

/** Loads all the roots of the given `pkg`. */
private[tastyquery] def loadAllRoots(pkg: PackageSymbol)(using Context): Unit =
Expand Down Expand Up @@ -272,9 +276,13 @@ private[tastyquery] object Loaders {

if !searched then
searched = true
loader.packages = loadPackages().groupMap((pkg, _) => pkg)((_, data) => data)
packages = loadPackages().groupMap((pkg, _) => pkg)((_, data) => data)
_hasGenericTuples =
packages.get(defn.scalaPackage).exists(_.exists(_.tastys.exists(_.binaryName == "$times$colon")))
end initPackages

def hasGenericTuples: Boolean = _hasGenericTuples

private def computeByEntry()(using Context): ByEntryMap =
require(searched)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package tastyquery.reader

import tastyquery.Contexts.*
import tastyquery.Names.*
import tastyquery.SourceFile
import tastyquery.Symbols.*
import tastyquery.Types.*

/** A restricted Context that is safe to use from the readers.
*
* It does not give access to anything that might require reading other files.
*/
private[reader] final class ReaderContext(underlying: Context):
def RootPackage: PackageSymbol = underlying.defn.RootPackage
def EmptyPackage: PackageSymbol = underlying.defn.EmptyPackage
def javaLangPackage: PackageSymbol = underlying.defn.javaLangPackage
def scalaPackage: PackageSymbol = underlying.defn.scalaPackage

def NothingType: NothingType = underlying.defn.NothingType
def AnyType: TypeRef = underlying.defn.AnyType
def MatchableType: TypeRef = underlying.defn.MatchableType
def ObjectType: TypeRef = underlying.defn.ObjectType
def FromJavaObjectType: TypeRef = underlying.defn.FromJavaObjectType

def IntType: TypeRef = underlying.defn.IntType
def LongType: TypeRef = underlying.defn.LongType
def FloatType: TypeRef = underlying.defn.FloatType
def DoubleType: TypeRef = underlying.defn.DoubleType
def BooleanType: TypeRef = underlying.defn.BooleanType
def ByteType: TypeRef = underlying.defn.ByteType
def ShortType: TypeRef = underlying.defn.ShortType
def CharType: TypeRef = underlying.defn.CharType
def UnitType: TypeRef = underlying.defn.UnitType

def ArrayTypeOf(tpe: TypeOrWildcard): AppliedType = underlying.defn.ArrayTypeOf(tpe)
def RepeatedTypeOf(tpe: TypeOrWildcard): AppliedType = underlying.defn.RepeatedTypeOf(tpe)

def GenericTupleTypeOf(elementTypes: List[TypeOrWildcard]): Type = underlying.defn.GenericTupleTypeOf(elementTypes)

def NothingAnyBounds: RealTypeBounds = underlying.defn.NothingAnyBounds

def uninitializedMethodTermRef: TermRef = underlying.defn.uninitializedMethodTermRef

def findPackageFromRootOrCreate(fullyQualifiedName: FullyQualifiedName): PackageSymbol =
underlying.findPackageFromRootOrCreate(fullyQualifiedName)

/** Reads a package reference, with a fallback on faked term references.
*
* In a full, correct classpath, `createPackageSelection()` will always
* return a `PackageRef`. However, in an incomplete or incorrect classpath,
* this method may return a `TermRef` if the target package does not exist.
*
* An alternative would be to create missing packages on the fly, but that
* would not be consistent with `Trees.Select.tpe` and
* `Trees.TermRefTypeTree.toType`.
*/
def createPackageSelection(path: List[TermName]): TermReferenceType =
path.foldLeft[TermReferenceType](RootPackage.packageRef) { (prefix, name) =>
NamedType.possibleSelFromPackage(prefix, name)
}
end createPackageSelection

def getSourceFile(path: String): SourceFile =
underlying.getSourceFile(path)

def hasGenericTuples: Boolean = underlying.classloader.hasGenericTuples

def createObjectMagicMethods(cls: ClassSymbol): Unit =
underlying.defn.createObjectMagicMethods(cls)

def createStringMagicMethods(cls: ClassSymbol): Unit =
underlying.defn.createStringMagicMethods(cls)
end ReaderContext

private[reader] object ReaderContext:
def rctx(using context: ReaderContext): context.type = context
end ReaderContext
Loading

0 comments on commit b2d3b05

Please sign in to comment.