diff --git a/build.sbt b/build.sbt index 62474e57..c863baf2 100644 --- a/build.sbt +++ b/build.sbt @@ -127,8 +127,13 @@ lazy val tastyQuery = import com.typesafe.tools.mima.core.* Seq( // private[tastyquery], not an issue + ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.Contexts#Context.classloader"), ProblemFilters.exclude[MissingClassProblem]("tastyquery.Utils"), ProblemFilters.exclude[MissingClassProblem]("tastyquery.Utils$"), + // private, not an issue + ProblemFilters.exclude[MissingClassProblem]("tastyquery.Types$TermRef$Resolved"), + ProblemFilters.exclude[MissingClassProblem]("tastyquery.Types$TypeRef$Resolved"), + // Everything in tastyquery.reader is private[tastyquery] at most ProblemFilters.exclude[Problem]("tastyquery.reader.*"), ) diff --git a/tasty-query/js/src/main/scala/tastyquery/nodejs/ClasspathLoaders.scala b/tasty-query/js/src/main/scala/tastyquery/nodejs/ClasspathLoaders.scala index a7863295..d4e27f99 100644 --- a/tasty-query/js/src/main/scala/tastyquery/nodejs/ClasspathLoaders.scala +++ b/tasty-query/js/src/main/scala/tastyquery/nodejs/ClasspathLoaders.scala @@ -31,6 +31,9 @@ object ClasspathLoaders: * to create a [[Contexts.Context]]. The latter gives semantic access to all * the definitions on the classpath. * + * The entries of the resulting [[Classpaths.Classpath]] can be considered + * thread-safe, since the JavaScript environment is always single-threaded. + * * @note the resulting [[Classpaths.ClasspathEntry ClasspathEntry]] entries of * the returned [[Classpaths.Classpath]] correspond to the elements of `classpath`. */ diff --git a/tasty-query/jvm/src/main/scala/tastyquery/jdk/ClasspathLoaders.scala b/tasty-query/jvm/src/main/scala/tastyquery/jdk/ClasspathLoaders.scala index 417c7163..22bf9e8e 100644 --- a/tasty-query/jvm/src/main/scala/tastyquery/jdk/ClasspathLoaders.scala +++ b/tasty-query/jvm/src/main/scala/tastyquery/jdk/ClasspathLoaders.scala @@ -40,6 +40,9 @@ object ClasspathLoaders { * to create a [[Contexts.Context]]. The latter gives semantic access to all * the definitions on the classpath. * + * The entries of the resulting [[Classpaths.Classpath]] are all guaranteed + * to be thread-safe. + * * @note the resulting [[Classpaths.ClasspathEntry ClasspathEntry]] entries of * the returned [[Classpaths.Classpath]] correspond to the elements of `classpath`. */ diff --git a/tasty-query/shared/src/main/scala/tastyquery/Annotations.scala b/tasty-query/shared/src/main/scala/tastyquery/Annotations.scala index 06620f73..966be130 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Annotations.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Annotations.scala @@ -13,20 +13,16 @@ import tastyquery.Utils.* object Annotations: final class Annotation(val tree: TermTree): - private var mySymbol: ClassSymbol | Null = null - private var mySafeSymbol: Option[ClassSymbol] | Null = null - private var myArguments: List[TermTree] | Null = null + private val mySymbol: Memo[ClassSymbol] = uninitializedMemo + private val mySafeSymbol: Memo[Option[ClassSymbol]] = uninitializedMemo + private val myArguments: Memo[List[TermTree]] = uninitializedMemo /** The annotation class symbol. */ def symbol(using Context): ClassSymbol = - memoized( - mySymbol, - { computed => - mySymbol = computed - mySafeSymbol = Some(computed) - } - ) { + memoized2(mySymbol) { computeAnnotSymbol(tree) + } { computed => + initializeMemo(mySafeSymbol, Some(computed)) } end symbol @@ -35,14 +31,10 @@ object Annotations: * If the class of this annotation cannot be successfully resolved, returns `false`. */ private[tastyquery] def safeHasSymbol(cls: ClassSymbol)(using Context): Boolean = - val safeSymbol = memoized( - mySafeSymbol, - { computed => - computed.foreach(mySymbol = _) - mySafeSymbol = computed - } - ) { + val safeSymbol = memoized2(mySafeSymbol) { computeSafeAnnotSymbol(tree) + } { computed => + computed.foreach(sym => initializeMemo(mySymbol, sym)) } safeSymbol.contains(cls) @@ -64,7 +56,7 @@ object Annotations: * `NamedArg`s are not visible with this method. They are replaced by * their right-hand-side. */ - def arguments: List[TermTree] = memoized(myArguments, myArguments = _) { + def arguments: List[TermTree] = memoized(myArguments) { computeAnnotArguments(tree) } diff --git a/tasty-query/shared/src/main/scala/tastyquery/Classpaths.scala b/tasty-query/shared/src/main/scala/tastyquery/Classpaths.scala index b13e33e8..f5d12d47 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Classpaths.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Classpaths.scala @@ -23,6 +23,11 @@ object Classpaths: * All the methods of `ClasspathEntry` and its components may throw * `java.io.IOException`s to indicate I/O errors. * + * A `ClasspathEntry` is encouraged to be thread-safe, along with all its + * components, but it is not a strong requirement. Implementations that are + * thread-safe should be documented as such. [[Contexts.Context]]s created + * only from thread-safe `ClasspathEntry`s are thread-safe themselves. + * * Implementations of this class are encouraged to define a `toString()` * method that helps identifying the entry for debugging purposes. */ @@ -99,16 +104,21 @@ object Classpaths: def readClassFileBytes(): IArray[Byte] end ClassData - /** In-memory representation of classpath entries. */ + /** In-memory representation of classpath entries. + * + * In-memory classpath entries are thread-safe. + */ object InMemory: import Classpaths as generic + /** A thread-safe, immutable classpath entry. */ final class ClasspathEntry(debugString: String, val packages: List[PackageData]) extends generic.ClasspathEntry: override def toString(): String = debugString def listAllPackages(): List[generic.PackageData] = packages end ClasspathEntry + /** A thread-safe, immutable package information within a classpath entry. */ final class PackageData(debugString: String, val dotSeparatedName: String, val classes: List[ClassData]) extends generic.PackageData: private lazy val byBinaryName = classes.map(c => c.binaryName -> c).toMap @@ -120,6 +130,7 @@ object Classpaths: def getClassDataByBinaryName(binaryName: String): Option[generic.ClassData] = byBinaryName.get(binaryName) end PackageData + /** A thread-safe, immutable class information within a classpath entry. */ final class ClassData( debugString: String, val binaryName: String, diff --git a/tasty-query/shared/src/main/scala/tastyquery/Contexts.scala b/tasty-query/shared/src/main/scala/tastyquery/Contexts.scala index eb225f2d..18b25be7 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Contexts.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Contexts.scala @@ -27,11 +27,22 @@ object Contexts { /** Factory methods for [[Context]]. */ object Context: - /** Creates a new [[Context]] for the given [[Classpaths.Classpath]]. */ + /** Creates a new [[Context]] for the given [[Classpaths.Classpath]]. + * + * If all the [[Classpaths.ClasspathEntry ClasspathEntries]] in the classpath + * are thread-safe, then the resulting [[Context]] is thread-safe. + */ def initialize(classpath: Classpath): Context = val classloader = Loader(classpath) val ctx = new Context(classloader) classloader.initPackages()(using ctx) + + /* Exploit the portable releaseFence() call inside the `::` constructor, + * in order to publish all the mutations that were done during the + * above initialization to other threads. + */ + new ::(Nil, Nil) + ctx end initialize end Context @@ -58,15 +69,19 @@ object Contexts { * The same instance of [[Classpaths.Classpath]] can be reused to create * several [[Context]]s, if necessary. */ - final class Context private[Contexts] (private[tastyquery] val classloader: Loader) { + final class Context private[Contexts] (classloader: Loader) { private given Context = this private val sourceFiles = mutable.HashMap.empty[String, SourceFile] private val (RootPackage @ _, EmptyPackage @ _) = PackageSymbol.createRoots() + private[tastyquery] def hasGenericTuples: Boolean = classloader.hasGenericTuples + val defn: Definitions = Definitions(this: @unchecked, RootPackage, EmptyPackage) + private[tastyquery] def internalClasspathForTestsOnly: Classpath = classloader.classpath + private[tastyquery] def getSourceFile(path: String): SourceFile = sourceFiles.getOrElseUpdate(path, new SourceFile(path)) diff --git a/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala b/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala index efa21c9c..28a298f2 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala @@ -7,8 +7,19 @@ import tastyquery.Names.* import tastyquery.Symbols.* import tastyquery.Types.* -final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageSymbol, emptyPackage: PackageSymbol): - private given Context = ctx +final class Definitions private[tastyquery] ( + ctxRestricted: Context, + rootPackage: PackageSymbol, + emptyPackage: PackageSymbol +): + /** Use the restricted context for an op. + * + * !!! ONLY use from the initialization code of `lazy val`s. + * + * Well ... `def FunctionNClass` also uses it, for compatibility reasons, but it's fine. + */ + private inline def withRestrictedContext[A](op: Context ?=> A): A = + op(using ctxRestricted) // Core packages @@ -284,15 +295,17 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS createSpecialMethod(cls, nme.m_synchronized, synchronizedTpe) end createObjectMagicMethods - lazy val Object_eq: TermSymbol = ObjectClass.findNonOverloadedDecl(nme.m_eq) - lazy val Object_ne: TermSymbol = ObjectClass.findNonOverloadedDecl(nme.m_ne) - lazy val Object_synchronized: TermSymbol = ObjectClass.findNonOverloadedDecl(nme.m_synchronized) + lazy val Object_eq: TermSymbol = withRestrictedContext(ObjectClass.findNonOverloadedDecl(nme.m_eq)) + lazy val Object_ne: TermSymbol = withRestrictedContext(ObjectClass.findNonOverloadedDecl(nme.m_ne)) + + lazy val Object_synchronized: TermSymbol = + withRestrictedContext(ObjectClass.findNonOverloadedDecl(nme.m_synchronized)) private[tastyquery] def createStringMagicMethods(cls: ClassSymbol): Unit = createSpecialMethod(cls, nme.m_+, stringConcatMethodType, Final) end createStringMagicMethods - lazy val String_+ : TermSymbol = StringClass.findNonOverloadedDecl(nme.m_+) + lazy val String_+ : TermSymbol = withRestrictedContext(StringClass.findNonOverloadedDecl(nme.m_+)) private[tastyquery] def createEnumMagicMethods(cls: ClassSymbol): Unit = val ctorTpe = PolyType(List(typeName("E")), List(NothingAnyBounds), MethodType(Nil, Nil, UnitType)) @@ -412,8 +425,12 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS // Derived symbols, found on the classpath extension (pkg: PackageSymbol) - private def requiredClass(name: String): ClassSymbol = pkg.getDecl(typeName(name)).get.asClass - private def optionalClass(name: String): Option[ClassSymbol] = pkg.getDecl(typeName(name)).map(_.asClass) + private def requiredClass(name: String): ClassSymbol = + withRestrictedContext(pkg.getDecl(typeName(name)).get.asClass) + + private def optionalClass(name: String): Option[ClassSymbol] = + withRestrictedContext(pkg.getDecl(typeName(name)).map(_.asClass)) + end extension lazy val ObjectClass = javaLangPackage.requiredClass("Object") @@ -423,7 +440,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS lazy val Function0Class = scalaPackage.requiredClass("Function0") def FunctionNClass(n: Int): ClassSymbol = - scalaPackage.requiredClass(s"Function$n") + withRestrictedContext(scalaPackage.findDecl(typeName(s"Function$n")).asClass) lazy val IntClass = scalaPackage.requiredClass("Int") lazy val LongClass = scalaPackage.requiredClass("Long") @@ -465,15 +482,15 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS private[tastyquery] lazy val PolyFunctionClass = scalaPackage.optionalClass("PolyFunction") - private[tastyquery] def isPolyFunctionSub(tpe: Type): Boolean = + private[tastyquery] def isPolyFunctionSub(tpe: Type)(using Context): Boolean = PolyFunctionClass.exists(cls => tpe.baseType(cls).isDefined) - private[tastyquery] def isPolyFunctionSub(prefix: Prefix): Boolean = prefix match + private[tastyquery] def isPolyFunctionSub(prefix: Prefix)(using Context): Boolean = prefix match case tpe: Type => isPolyFunctionSub(tpe) case _ => false private[tastyquery] object PolyFunctionType: - def unapply(tpe: TermRefinement): Option[MethodicType] = + def unapply(tpe: TermRefinement)(using Context): Option[MethodicType] = PolyFunctionClass match case None => None @@ -488,7 +505,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS None end unapply - private[tastyquery] def functionClassOf(mt: MethodicType): ClassSymbol = mt match + private[tastyquery] def functionClassOf(mt: MethodicType)(using Context): ClassSymbol = mt match case mt: PolyType => mt.resultType match case resultType: MethodicType => @@ -500,15 +517,17 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS end functionClassOf end PolyFunctionType - lazy val hasGenericTuples = ctx.classloader.hasGenericTuples + lazy val hasGenericTuples = withRestrictedContext(ctx.hasGenericTuples) lazy val uninitializedMethod: Option[TermSymbol] = - scalaCompiletimePackage.getDecl(moduleClassName("package$package")).flatMap { packageObjectClass => - packageObjectClass.asClass.getDecl(termName("uninitialized")) + withRestrictedContext { + scalaCompiletimePackage.getDecl(moduleClassName("package$package")).flatMap { packageObjectClass => + packageObjectClass.asClass.getDecl(termName("uninitialized")) + } } end uninitializedMethod private[tastyquery] lazy val uninitializedMethodTermRef: TermRef = - TermRef(TermRef(defn.scalaCompiletimePackage.packageRef, termName("package$package")), termName("uninitialized")) + TermRef(TermRef(scalaCompiletimePackage.packageRef, termName("package$package")), termName("uninitialized")) end Definitions diff --git a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala index 722a82bd..13aab272 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala @@ -2,6 +2,9 @@ package tastyquery import scala.annotation.{switch, tailrec} +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} + import scala.collection.mutable import tastyquery.Annotations.* @@ -16,8 +19,6 @@ import tastyquery.Trees.* import tastyquery.Types.* import tastyquery.Utils.* -import tastyquery.reader.Loaders.Loader - /** Symbols for all kinds of definitions in Scala programs. * * Every definition, like `class`es, `def`s, `type`s and type parameters, is @@ -75,8 +76,8 @@ object Symbols { private var isFlagsInitialized = false private var myFlags: FlagSet = Flags.EmptyFlagSet private var myTree: Option[DefiningTreeType] = None - private var myPrivateWithin: Option[DeclaringSymbol] | Null = null - private var myAnnotations: List[Annotation] | Null = null + private var myPrivateWithin: SingleAssign[Option[DeclaringSymbol]] = uninitializedSingleAssign + private var myAnnotations: SingleAssign[List[Annotation]] = uninitializedSingleAssign /** Checks that this `Symbol` has been completely initialized. * @@ -97,8 +98,8 @@ object Symbols { */ protected def doCheckCompleted(): Unit = if !isFlagsInitialized then failNotCompleted("flags were not initialized") - if myPrivateWithin == null then failNotCompleted("privateWithin was not initialized") - if myAnnotations == null then failNotCompleted("annotations were not initialized") + if !myPrivateWithin.isInitialized then failNotCompleted("privateWithin was not initialized") + if !myAnnotations.isInitialized then failNotCompleted("annotations were not initialized") private[tastyquery] def setTree(t: DefiningTreeType): this.type = require(!isPackage, s"Multiple trees correspond to one package, a single tree cannot be assigned") @@ -113,7 +114,7 @@ object Symbols { setPrivateWithin(privateWithin) private[tastyquery] final def setFlags(flags: FlagSet): this.type = - if isFlagsInitialized || myPrivateWithin != null then + if isFlagsInitialized || myPrivateWithin.isInitialized then throw IllegalStateException(s"reassignment of flags to $this") else isFlagsInitialized = true @@ -122,11 +123,11 @@ object Symbols { end setFlags private[tastyquery] final def setPrivateWithin(privateWithin: Option[DeclaringSymbol]): this.type = - assignOnce(myPrivateWithin, (myPrivateWithin = privateWithin))(s"reassignment of privateWithin to $this") + assignOnce(myPrivateWithin, myPrivateWithin = _, privateWithin)(s"reassignment of privateWithin to $this") this private[tastyquery] final def setAnnotations(annots: List[Annotation]): this.type = - assignOnce(myAnnotations, (myAnnotations = annots))(s"reassignment of annotations to $this") + assignOnce(myAnnotations, myAnnotations = _, annots)(s"reassignment of annotations to $this") this final def annotations: List[Annotation] = @@ -216,7 +217,7 @@ object Symbols { sealed abstract class TermOrTypeSymbol(override val owner: Symbol) extends Symbol(owner): type MatchingSymbolType >: this.type <: TermOrTypeSymbol - private var myLocalRef: NamedType | Null = null + private val myLocalRef: Memo[NamedType] = uninitializedMemo /** A reference to this symbol that is valid within its declaring scope. * @@ -225,7 +226,7 @@ object Symbols { */ def localRef: NamedType = // overridden in subclasses to provide a better-known result type - memoized(myLocalRef, myLocalRef = _) { + memoized(myLocalRef) { val pre = this match case self: ClassSymbol if self.isRefinementClass => /* Refinement classes are not declarations of their owner. @@ -425,33 +426,35 @@ object Symbols { type MatchingSymbolType = TermSymbol // Reference fields (checked in doCheckCompleted) - private var myDeclaredType: TypeOrMethodic | Null = null - private var myParamSymss: List[ParamSymbolsClause] | Null = null + private var myDeclaredType: SingleAssign[TypeOrMethodic] = uninitializedSingleAssign + private var myParamSymss: SingleAssign[List[ParamSymbolsClause]] = uninitializedSingleAssign // Cache fields - private var mySignature: Signature | Null = null - private var myTargetName: UnsignedTermName | Null = null - private var mySignedName: TermName | Null = null + private val mySignature: Memo[Signature] = uninitializedMemo + private val myTargetName: Memo[UnsignedTermName] = uninitializedMemo + private val mySignedName: Memo[TermName] = uninitializedMemo protected override def doCheckCompleted(): Unit = super.doCheckCompleted() - if myDeclaredType == null then failNotCompleted("declaredType was not initialized") + if !myDeclaredType.isInitialized then failNotCompleted("declaredType was not initialized") if flags.is(Method) then - if myParamSymss == null then failNotCompleted("paramSymss was not initialized") + if !myParamSymss.isInitialized then failNotCompleted("paramSymss was not initialized") paramSymss.foreach(_.merge.foreach(_.checkCompleted())) - else if myParamSymss == null then myParamSymss = Nil // auto-complete for non-methods - else if myParamSymss != Nil then + else if !myParamSymss.isInitialized then + // auto-complete for non-methods + assignOnce(myParamSymss, myParamSymss = _, Nil)("unreachable") + else if getAssignedOnce(myParamSymss)("unreachable") != Nil then throw IllegalArgumentException(s"illegal non-empty paramSymss $myParamSymss for $this") end doCheckCompleted private[tastyquery] final def setDeclaredType(tpe: TypeOrMethodic): this.type = - assignOnce(myDeclaredType, (myDeclaredType = tpe))(s"reassignment of declared type to $this") + assignOnce(myDeclaredType, myDeclaredType = _, tpe)(s"reassignment of declared type to $this") this /** You should not need this; it is a hack for patching Scala 2 constructors in `PickleReader`. */ private[tastyquery] final def overwriteDeclaredType(tpe: TypeOrMethodic): this.type = - myDeclaredType = tpe + overwriteSingleAssign[TypeOrMethodic](myDeclaredType = _, tpe) this def declaredType: TypeOrMethodic = @@ -460,7 +463,7 @@ object Symbols { private lazy val isPrefixDependent: Boolean = TypeOps.isPrefixDependent(declaredType) private[tastyquery] final def setParamSymss(paramSymss: List[ParamSymbolsClause]): this.type = - assignOnce(myParamSymss, (myParamSymss = paramSymss))(s"reassignment of paramSymss to $this") + assignOnce(myParamSymss, myParamSymss = _, paramSymss)(s"reassignment of paramSymss to $this") this private[tastyquery] final def autoFillParamSymss(): this.type = @@ -607,11 +610,11 @@ object Symbols { private[tastyquery] final def needsSignature: Boolean = declaredType.isInstanceOf[MethodicType] - final def signature(using Context): Signature = memoized(mySignature, mySignature = _) { + final def signature(using Context): Signature = memoized(mySignature) { Signature.fromType(declaredType, sourceLanguage, Option.when(isConstructor)(owner.asClass)) } - final def targetName(using Context): UnsignedTermName = memoized(myTargetName, myTargetName = _) { + final def targetName(using Context): UnsignedTermName = memoized(myTargetName) { if annotations.isEmpty then name else defn.targetNameAnnotClass match @@ -630,7 +633,7 @@ object Symbols { * If the `owner` of this symbol is a `DeclaringSymbol`, then `owner.getDecl(signedName)` * will return this symbol. This is not always the case with `name`. */ - final def signedName(using Context): TermName = memoized(mySignedName, mySignedName = _) { + final def signedName(using Context): TermName = memoized(mySignedName) { if needsSignature then SignedName(name, signature, targetName) else name } @@ -748,14 +751,14 @@ object Symbols { type DefiningTreeType >: TypeParam <: TypeParam | TypeTreeBind // Reference fields (checked in doCheckCompleted) - private var myDeclaredBounds: TypeBounds | Null = null + private var myDeclaredBounds: SingleAssign[TypeBounds] = uninitializedSingleAssign protected override def doCheckCompleted(): Unit = super.doCheckCompleted() - if myDeclaredBounds == null then failNotCompleted("bounds are not initialized") + if !myDeclaredBounds.isInitialized then failNotCompleted("bounds are not initialized") private[tastyquery] final def setDeclaredBounds(bounds: TypeBounds): this.type = - assignOnce(myDeclaredBounds, (myDeclaredBounds = bounds))(s"Trying to re-set the bounds of $this") + assignOnce(myDeclaredBounds, myDeclaredBounds = _, bounds)(s"Trying to re-set the bounds of $this") this final def declaredBounds: TypeBounds = @@ -834,14 +837,14 @@ object Symbols { type DefiningTreeType = TypeMember // Reference fields (checked in doCheckCompleted) - private var myDefinition: TypeMemberDefinition | Null = null + private var myDefinition: SingleAssign[TypeMemberDefinition] = uninitializedSingleAssign protected override def doCheckCompleted(): Unit = super.doCheckCompleted() - if myDefinition == null then failNotCompleted("type member definition not initialized") + if !myDefinition.isInitialized then failNotCompleted("type member definition not initialized") private[tastyquery] final def setDefinition(definition: TypeMemberDefinition): this.type = - assignOnce(myDefinition, (myDefinition = definition))(s"Reassignment of the definition of $this") + assignOnce(myDefinition, myDefinition = _, definition)(s"Reassignment of the definition of $this") this final def typeDef: TypeMemberDefinition = @@ -902,30 +905,31 @@ object Symbols { private val specialKind: SpecialKind = computeSpecialKind(name, owner) // Reference fields (checked in doCheckCompleted) - private var myTypeParams: List[ClassTypeParamSymbol] | Null = null - private var myParents: List[Type] | Null = null - private var myGivenSelfType: Option[Type] | Null = null + private var myTypeParams: SingleAssign[List[ClassTypeParamSymbol]] = uninitializedSingleAssign + private val myParents: Memo[List[Type]] = uninitializedMemo + private var myGivenSelfType: SingleAssign[Option[Type]] = uninitializedSingleAssign // Optional reference fields private var myScala2SealedChildren: Option[List[Symbol | Scala2ExternalSymRef]] = None + private var myTopLevelTasty: List[TopLevelTree] = Nil // DeclaringSymbol-related fields private val myDeclarations: mutable.HashMap[UnsignedName, mutable.HashSet[TermOrTypeSymbol]] = mutable.HashMap[UnsignedName, mutable.HashSet[TermOrTypeSymbol]]() // Cache fields - private var mySignatureName: SignatureName | Null = null - private var myAppliedRef: Type | Null = null - private var mySelfType: Type | Null = null - private var myLinearization: List[ClassSymbol] | Null = null - private var myErasure: ErasedTypeRef.ClassRef | Null = null - private var mySealedChildren: List[SealedChild] | Null = null + private val mySignatureName: Memo[SignatureName] = uninitializedMemo + private val myAppliedRef: Memo[Type] = uninitializedMemo + private val mySelfType: Memo[Type] = uninitializedMemo + private val myLinearization: Memo[List[ClassSymbol]] = uninitializedMemo + private val myErasure: Memo[ErasedTypeRef.ClassRef] = uninitializedMemo + private val mySealedChildren: Memo[List[SealedChild]] = uninitializedMemo protected override def doCheckCompleted(): Unit = super.doCheckCompleted() - if myTypeParams == null then failNotCompleted("typeParams not initialized") - if myParents == null && tree.isEmpty then failNotCompleted("parents not initialized") - if myGivenSelfType == null then failNotCompleted("givenSelfType not initialized") + if !myTypeParams.isInitialized then failNotCompleted("typeParams not initialized") + if !myParents.isInitialized && tree.isEmpty then failNotCompleted("parents not initialized") + if !myGivenSelfType.isInitialized then failNotCompleted("givenSelfType not initialized") /** The open level of this class (open, closed, sealed or final). */ final def openLevel: OpenLevel = @@ -1036,23 +1040,23 @@ object Symbols { computeErasedName(owner.owner, filledName) end computeErasedName - memoized(mySignatureName, mySignatureName = _) { + memoized(mySignatureName) { computeErasedName(owner, name.toTermName.asInstanceOf[SignatureNameItem]) } end signatureName private[tastyquery] final def setTypeParams(tparams: List[ClassTypeParamSymbol]): this.type = - assignOnce(myTypeParams, (myTypeParams = tparams))(s"reassignment of type parameters to $this") + assignOnce(myTypeParams, myTypeParams = _, tparams)(s"reassignment of type parameters to $this") this final def typeParams: List[ClassTypeParamSymbol] = getAssignedOnce(myTypeParams)(s"type params not initialized for $this") private[tastyquery] final def setParentsDirect(parents: List[Type]): this.type = - assignOnce(myParents, (myParents = parents))(s"reassignment of parents of $this") + assignOnceMemo(myParents, parents)(s"reassignment of parents of $this") this - final def parents(using Context): List[Type] = memoized(myParents, myParents = _) { + final def parents(using Context): List[Type] = memoized(myParents) { val tree = this.tree.getOrElse { throw IllegalStateException(s"$this was not assigned parents") } @@ -1074,18 +1078,18 @@ object Symbols { ) private[tastyquery] final def setGivenSelfType(givenSelfType: Option[Type]): this.type = - assignOnce(myGivenSelfType, (myGivenSelfType = givenSelfType))(s"reassignment of givenSelfType for $this") + assignOnce(myGivenSelfType, myGivenSelfType = _, givenSelfType)(s"reassignment of givenSelfType for $this") this final def givenSelfType: Option[Type] = getAssignedOnce(myGivenSelfType)(s"givenSelfType not initialized for $this") - final def appliedRefInsideThis: Type = memoized(myAppliedRef, myAppliedRef = _) { + final def appliedRefInsideThis: Type = memoized(myAppliedRef) { if typeParams.isEmpty then localRef else AppliedType(localRef, typeParams.map(_.localRef)) } - final def selfType: Type = memoized(mySelfType, mySelfType = _) { + final def selfType: Type = memoized(mySelfType) { givenSelfType match case None => appliedRefInsideThis @@ -1094,7 +1098,23 @@ object Symbols { else AndType(givenSelf, appliedRefInsideThis) } - final def linearization(using Context): List[ClassSymbol] = memoized(myLinearization, myLinearization = _) { + private[tastyquery] final def setTopLevelTasty(trees: List[TopLevelTree]): this.type = + require(owner.isPackage, "cannot set topLevelTasty to a non-top-level class") + require(!flags.isAnyOf(Scala2Defined | JavaDefined), "cannot set topLevelTasty to a non-Scala 3 class") + myTopLevelTasty = trees + this + end setTopLevelTasty + + private[tastyquery] final def topLevelTasty: List[TopLevelTree] = + require(owner.isPackage, s"illegal call to topLevelTasty on the non-top-level class $this") + require( + !flags.isAnyOf(Scala2Defined | JavaDefined), + s"illegal call to topLevelTasty on the non-Scala 3 class $this" + ) + myTopLevelTasty + end topLevelTasty + + final def linearization(using Context): List[ClassSymbol] = memoized(myLinearization) { val parentsLin = parentClasses.foldLeft[List[ClassSymbol]](Nil) { (lin, parent) => parent.linearization.filter(c => !lin.contains(c)) ::: lin } @@ -1105,7 +1125,7 @@ object Symbols { linearization.contains(that) /** The erasure of this class; nonsensical for `scala.Array`. */ - private[tastyquery] final def erasure(using Context): ErasedTypeRef.ClassRef = memoized(myErasure, myErasure = _) { + private[tastyquery] final def erasure(using Context): ErasedTypeRef.ClassRef = memoized(myErasure) { (specialKind: @switch) match case SpecialKind.Any | SpecialKind.AnyVal | SpecialKind.Matchable | SpecialKind.Singleton => defn.ObjectClass.erasure @@ -1272,9 +1292,6 @@ object Symbols { // Type-related things - private[tastyquery] final def initParents: Boolean = - myTypeParams != null - // Partial internal guarantee that we follow the right shape private type BaseType = TypeRef | AppliedType @@ -1283,7 +1300,7 @@ object Symbols { case _ => throw AssertionError(s"baseType internally produced an invalid shape: $tpe") end asBaseType - private val baseTypeForClassCache = mutable.AnyRefMap.empty[ClassSymbol, Option[BaseType]] + private val baseTypeForClassCache = new ConcurrentHashMap[ClassSymbol, Option[BaseType]]() /** Cached core lookup of `this.baseTypeOf(clsOwner.this.cls)`. * @@ -1295,13 +1312,19 @@ object Symbols { def foldGlb(bt: Option[BaseType], ps: List[Type]): Option[BaseType] = ps.foldLeft(bt)((bt, p) => baseTypeCombine(bt, baseTypeOf(p), meet = true)) - baseTypeForClassCache.getOrElseUpdate( - cls, - if cls.isSubClass(this) then - if this.isStatic && this.typeParams.isEmpty then Some(this.localRef) - else foldGlb(None, cls.parents) - else None - ) + // Do not use computeIfAbsent because it is not lock-free + val cachedResult = baseTypeForClassCache.get(cls) + if cachedResult != null then cachedResult + else + val computed = + if cls.isSubClass(this) then + if this.isStatic && this.typeParams.isEmpty then Some(this.localRef) + else foldGlb(None, cls.parents) + else None + + val concurrentlyCachedResult = baseTypeForClassCache.putIfAbsent(cls, computed) + if concurrentlyCachedResult != null then concurrentlyCachedResult + else computed end baseTypeForClass /** Computes the (unapplied) baseType of a class type constructor. @@ -1499,10 +1522,10 @@ object Symbols { lookup(linearization) end resolveMatchingMember - private var myThisType: ThisType | Null = null + private val myThisType: Memo[ThisType] = uninitializedMemo /** The `ThisType` for this class, as visible from inside this class. */ - final def thisType: ThisType = memoized(myThisType, myThisType = _) { + final def thisType: ThisType = memoized(myThisType) { ThisType(localRef) } @@ -1515,7 +1538,8 @@ object Symbols { throw IllegalArgumentException(s"Illegal $this.setScala2SealedChildren($children) for non-Scala 2 class") if myScala2SealedChildren.isDefined then throw IllegalStateException(s"Scala 2 sealed children were already set for $this") - if mySealedChildren != null then throw IllegalStateException(s"Sealed children were already computed for $this") + if mySealedChildren.isInitialized then + throw IllegalStateException(s"Sealed children were already computed for $this") myScala2SealedChildren = Some(children) end setScala2SealedChildren @@ -1532,7 +1556,7 @@ object Symbols { * The results are ordered by their declaration order in the source. */ final def sealedChildren(using Context): List[ClassSymbol | TermSymbol] = - memoized(mySealedChildren, mySealedChildren = _) { + memoized(mySealedChildren) { if !flags.is(Sealed) then Nil else myScala2SealedChildren match @@ -1706,24 +1730,43 @@ object Symbols { extends Symbol(owner) with DeclaringSymbol { import PackageSymbol.* + import tastyquery.reader.Loaders.PackageLoadingInfo type DefiningTreeType = Nothing type DeclType = Symbol private[Symbols] val specialKind: SpecialKind = computeSpecialKind(name, owner) + /** Package loading info with raw data from the classpath. */ + private var optLoadingInfo: Option[PackageLoadingInfo] = None + // DeclaringSymbol-related fields - private val myDeclarations = mutable.HashMap[UnsignedName, Symbol]() + + /** Atomically swapped when `loadingNewRoots` successfully finishes. + * + * Other threads can read this reference at any time. + */ + private val myDeclarations = new AtomicReference[Map[UnsignedName, Symbol]](Map.empty) + + /** The pending declarations while `loadingNewRoots` is executing. + * + * Only the thread performing `loadingNewRoots` is allowed to use this map. + */ private val pendingDeclarations = mutable.HashMap[UnsignedName, Symbol]() - private var isLoadingNewRoots: Boolean = false + + /** Whether we are currently loading new roots; atomically set and reset by `loadingNewRoots`. */ + private val isLoadingNewRoots = new AtomicBoolean(false) // Cache fields val packageRef: PackageRef = new PackageRef(this) - private var myAllPackageObjectDecls: List[ClassSymbol] | Null = null + private val myAllPackageObjectDecls: Memo[List[ClassSymbol]] = uninitializedMemo this.setFlags(EmptyFlagSet, None) this.setAnnotations(Nil) + private def getMyDeclaractions: Map[UnsignedName, Symbol] = + myDeclarations.get().nn + private lazy val _fullName: PackageFullName = if owner == null || name == nme.EmptyPackageName then PackageFullName.rootPackageName else owner.fullName.select(name) @@ -1748,6 +1791,10 @@ object Symbols { /** Is this the scala package? */ private[tastyquery] def isScalaPackage: Boolean = specialKind == SpecialKind.scala + private[tastyquery] def setLoadingInfo(loadingInfo: PackageLoadingInfo): Unit = + if optLoadingInfo.isDefined then throw IllegalStateException(s"Loading info already set for $this") + optLoadingInfo = Some(loadingInfo) + /** Gets the subpackage with the specified `name`, if it exists. * * If this package contains a subpackage with the name `name`, returns @@ -1764,16 +1811,13 @@ object Symbols { /* All subpackages are created eagerly when initializing contexts, * so we can directly access myDeclarations here. */ - myDeclarations.get(name).collect { case pkg: PackageSymbol => + getMyDeclaractions.get(name).collect { case pkg: PackageSymbol => pkg } end getPackageDecl private[Symbols] final def addDecl(decl: Symbol): Unit = - assert( - !myDeclarations.contains(decl.name) && !pendingDeclarations.contains(decl.name), - s"trying to add a second entry $decl for name ${decl.name} in $this" - ) + def duplicateMessage: String = s"trying to add a second entry $decl for name ${decl.name} in $this" /* If we are loading new roots and the decl is not a package, * add the declaration to the pending set only. They will be committed @@ -1782,13 +1826,28 @@ object Symbols { * Packages are always eagerly committed. */ decl match - case decl: TermOrTypeSymbol if isLoadingNewRoots => + case decl: TermOrTypeSymbol if isLoadingNewRoots.get() => + assert(!getMyDeclaractions.contains(decl.name) && !pendingDeclarations.contains(decl.name), duplicateMessage) pendingDeclarations(decl.name) = decl + case _ => - myDeclarations(decl.name) = decl + // Manual CAS loop because `updateAndGet` does not say anything about when the lambda throws + @tailrec + def loop(): Unit = + val prev = getMyDeclaractions + assert(!prev.contains(decl.name), duplicateMessage) + val next = prev + (decl.name -> decl) + if !myDeclarations.compareAndSet(prev, next) then loop() + end loop + + loop() end addDecl /** Performs an operation that can load new roots from the class loader. + * + * This operation is synchronized per package. If another thread is + * already loading roots, this method will synchronously wait for it to + * be done. * * While loading new roots, any new non-package member sent to `addDecl` * is added to `pendingDeclarations` instead of `myDeclarations`. They @@ -1797,27 +1856,48 @@ object Symbols { * * This way, any exception occurring during loading does not pollute the * publicly visible state in `myDeclarations`. + * + * @return + * true iff at least one new declaration was added to the package during the operation */ - private def loadingNewRoots[A](op: Loader => A)(using Context): A = - if isLoadingNewRoots then throw IllegalStateException(s"Cyclic loading of new roots in package $this") - - isLoadingNewRoots = true - try - val result = op(ctx.classloader) + private def loadingNewRoots(op: PackageLoadingInfo => Unit)(using Context): Boolean = + optLoadingInfo match + case None => + false - // Upon success, commit pending declations - myDeclarations ++= pendingDeclarations + case Some(loadingInfo) => + val myDeclarationsBefore = getMyDeclaractions + + val myDeclarationsAfter = loadingInfo.synchronized { + if !isLoadingNewRoots.compareAndSet(false, true) then + throw IllegalStateException(s"Cyclic loading of new roots in package $this") + + try + op(loadingInfo) + + // Upon success, commit pending declations + if pendingDeclarations.nonEmpty then + myDeclarations.updateAndGet({ prev => + prev.nn ++ pendingDeclarations + }) + else + // get without updating to test whether another thread has brought some changes + myDeclarations.get() + finally + pendingDeclarations.clear() // whether or not they were committed + isLoadingNewRoots.set(false) + } - result - finally - pendingDeclarations.clear() // whether or not they were committed - isLoadingNewRoots = false + /* This could be true even if `pendingDeclarations.nonEmpty`, if two + * threads concurrently ask to load the same root. + */ + myDeclarationsAfter ne myDeclarationsBefore end loadingNewRoots final def getDecl(name: Name)(using Context): Option[Symbol] = name match case name: UnsignedName => - myDeclarations.get(name).orElse { - if loadingNewRoots(_.loadRoot(this, name)) then myDeclarations.get(name) + getMyDeclaractions.get(name).orElse { + if loadingNewRoots(_.loadOneRoot(name)) then getMyDeclaractions.get(name) else None } case _: SignedName => @@ -1846,14 +1926,14 @@ object Symbols { } final def declarations(using Context): List[Symbol] = - loadingNewRoots(_.loadAllRoots(this)) - myDeclarations.values.toList + loadingNewRoots(_.loadAllRoots()) + getMyDeclaractions.values.toList // See PackageRef.findMember private[tastyquery] def allPackageObjectDecls()(using Context): List[ClassSymbol] = - memoized(myAllPackageObjectDecls, myAllPackageObjectDecls = _) { - loadingNewRoots(_.loadAllPackageObjectRoots(this)) - myDeclarations.valuesIterator.collect { + memoized(myAllPackageObjectDecls) { + loadingNewRoots(_.loadAllPackageObjectRoots()) + getMyDeclaractions.valuesIterator.collect { case cls: ClassSymbol if cls.name.isPackageObjectClassName => cls }.toList .sortBy(_.name.toString) // sort for determinism diff --git a/tasty-query/shared/src/main/scala/tastyquery/Trees.scala b/tasty-query/shared/src/main/scala/tastyquery/Trees.scala index e14ae3f8..28b3d1cf 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Trees.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Trees.scala @@ -70,7 +70,7 @@ object Trees { end StatementTree sealed abstract class TermTree(pos: SourcePosition) extends StatementTree(pos): - private var myType: TermType | Null = null + private val myType: Memo[TermType] = uninitializedMemo def withPos(pos: SourcePosition): TermTree @@ -94,7 +94,7 @@ object Trees { protected def calculateType(using Context): TermType /** The term type of this tree. */ - final def tpe(using Context): TermType = memoized(myType, myType = _) { + final def tpe(using Context): TermType = memoized(myType) { calculateType } end TermTree @@ -338,24 +338,23 @@ object Trees { } /** fun(args) */ - final case class Apply protected[tastyquery] (fun: TermTree, args: List[TermTree])( - private var _methodType: MethodType | Null, - pos: SourcePosition - ) extends TermTree(pos): + final case class Apply(fun: TermTree, args: List[TermTree])(pos: SourcePosition) extends TermTree(pos): import Apply.* - def this(fun: TermTree, args: List[TermTree])(pos: SourcePosition) = this(fun, args)(null, pos) + private val myMethodType: Memo[MethodType] = uninitializedMemo + + protected[tastyquery] def this( + fun: TermTree, + args: List[TermTree] + )(methodType: MethodType | Null, pos: SourcePosition) = + this(fun, args)(pos) + if methodType != null then initializeMemo(myMethodType, methodType) - def methodType(using Context): MethodType = - val local = _methodType - if local != null then local - else - val computed = fun.tpe.widenTermRef match - case funTpe: MethodType => funTpe - case funTpe => throw NonMethodReferenceException(s"application to $funTpe") - _methodType = computed - computed - end methodType + def methodType(using Context): MethodType = memoized(myMethodType) { + fun.tpe.widenTermRef match + case funTpe: MethodType => funTpe + case funTpe => throw NonMethodReferenceException(s"application to $funTpe") + } private def instantiateMethodType(args: List[TermType])(using Context): TermType = for arg <- args do @@ -383,7 +382,7 @@ object Trees { object Apply: def apply(fun: TermTree, args: List[TermTree])(pos: SourcePosition): Apply = - new Apply(fun, args)(null, pos) + new Apply(fun, args)(pos) def forSignaturePolymorphic(fun: TermTree, methodType: MethodType, args: List[TermTree])( pos: SourcePosition @@ -716,13 +715,13 @@ object Trees { end TypeArgTree sealed abstract class TypeTree(pos: SourcePosition) extends TypeArgTree(pos) { - private var myType: NonEmptyPrefix | Null = null + private val myType: Memo[NonEmptyPrefix] = uninitializedMemo protected def calculateType: NonEmptyPrefix def withPos(pos: SourcePosition): TypeTree - final def toPrefix: NonEmptyPrefix = memoized(myType, myType = _) { + final def toPrefix: NonEmptyPrefix = memoized(myType) { calculateType } @@ -882,9 +881,9 @@ object Trees { } final case class WildcardTypeArgTree(bounds: TypeBoundsTree)(pos: SourcePosition) extends TypeArgTree(pos) { - private var myTypeOrWildcard: WildcardTypeArg | Null = null + private val myTypeOrWildcard: Memo[WildcardTypeArg] = uninitializedMemo - def toTypeOrWildcard: TypeOrWildcard = memoized(myTypeOrWildcard, myTypeOrWildcard = _) { + def toTypeOrWildcard: TypeOrWildcard = memoized(myTypeOrWildcard) { WildcardTypeArg(bounds.toTypeBounds) } diff --git a/tasty-query/shared/src/main/scala/tastyquery/Types.scala b/tasty-query/shared/src/main/scala/tastyquery/Types.scala index 20c28a4c..19919f48 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Types.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Types.scala @@ -3,7 +3,6 @@ package tastyquery import scala.annotation.{constructorOnly, tailrec, targetName} import scala.collection.mutable -import scala.compiletime.uninitialized import tastyquery.Annotations.* import tastyquery.Constants.* @@ -909,7 +908,7 @@ object Types { private[tastyquery] final def designatorInternal: AnyDesignatorType = designator - private var myName: ThisName | Null = null + private val myName: Memo[ThisName] = uninitializedMemo private[tastyquery] final def isLocalRef(sym: Symbol): Boolean = prefix == NoPrefix && (designator eq sym) @@ -939,7 +938,7 @@ object Types { */ def name: Name - protected final def nameImpl: ThisName = memoized(myName, myName = _) { + protected final def nameImpl: ThisName = memoized(myName) { (designator match { case name: Name => name case sym: TermOrTypeSymbol => sym.name @@ -1018,10 +1017,11 @@ object Types { /** The singleton type for path prefix#myDesignator. */ final class TermRef private ( val prefix: Prefix, - private var myDesignator: TermSymbol | TermName | LookupIn | Scala2ExternalSymRef + private val myDesignator: TermSymbol | TermName | LookupIn | Scala2ExternalSymRef ) extends NamedType with SingletonType with TermReferenceType { + import TermRef.Resolved protected type ThisName = TermName private[tastyquery] type ThisSymbolType = TermSymbol @@ -1029,14 +1029,11 @@ object Types { protected type ThisDesignatorType = TermSymbol | TermName | LookupIn | Scala2ExternalSymRef // Cache fields - private var mySymbol: TermSymbol | Null = null - private var myUnderlying: TypeOrMethodic | Null = null - private var myIsStable: Boolean = false // only meaningful once mySymbol is non-null + private val myResolved: Memo[Resolved] = uninitializedMemo private def this(prefix: NonEmptyPrefix, resolved: ResolveMemberResult.TermMember) = this(prefix, resolved.symbols.head) - mySymbol = resolved.symbols.head - myUnderlying = resolved.tpe + initializeMemo(myResolved, Resolved(resolved.symbols.head, resolved.tpe, resolved.isStable)) end this protected def designator: ThisDesignatorType = myDesignator @@ -1047,56 +1044,51 @@ object Types { final def name: TermName = nameImpl final def symbol(using Context): TermSymbol = - ensureResolved() - mySymbol.nn + resolved.symbol - private def ensureResolved()(using Context): Unit = - if mySymbol == null then resolve() + private def resolved(using Context): Resolved = memoized(myResolved) { + doResolve() + } - private def resolve()(using Context): Unit = - def storeResolved(sym: TermSymbol, tpe: TypeOrMethodic, isStable: Boolean): Unit = - mySymbol = sym - myDesignator = sym - myUnderlying = tpe - myIsStable = isStable + private def doResolve()(using Context): Resolved = + def resolveToMember(sym: TermSymbol, tpe: TypeOrMethodic, isStable: Boolean): Resolved = + Resolved(sym, tpe, isStable) - def storeSymbol(sym: TermSymbol): Unit = - storeResolved(sym, sym.typeAsSeenFrom(prefix), sym.isStableMember) + def resolveToSymbol(sym: TermSymbol): Resolved = + resolveToMember(sym, sym.typeAsSeenFrom(prefix), sym.isStableMember) designator match case sym: TermSymbol => - storeSymbol(sym) + resolveToSymbol(sym) case lookupIn: LookupIn => val sym = TermRef.resolveLookupIn(lookupIn) - storeSymbol(sym) + resolveToSymbol(sym) case externalRef: Scala2ExternalSymRef => val sym = NamedType.resolveScala2ExternalRef(externalRef).asTerm - storeSymbol(sym) + resolveToSymbol(sym) case name: TermName => prefix match case prefix: NonEmptyPrefix => TermRef.resolvePolyFunctionApply(prefix, name, prefix.resolveMember(name)) match case ResolveMemberResult.TermMember(symbols, tpe, isStable) if symbols.nonEmpty => - storeResolved(symbols.head, tpe, isStable) + resolveToMember(symbols.head, tpe, isStable) case _ => throw MemberNotFoundException(prefix, name) case NoPrefix => throw new AssertionError(s"found reference by name $name without a prefix") - end resolve + end doResolve final def optSymbol(using Context): Option[TermSymbol] = Some(symbol) def underlyingOrMethodic(using Context): TypeOrMethodic = - ensureResolved() - myUnderlying.asInstanceOf[TypeOrMethodic] + resolved.tpe override def underlying(using Context): Type = underlyingOrMethodic.requireType final override def isStable(using Context): Boolean = - ensureResolved() - myIsStable + resolved.isStable private[tastyquery] override def resolveMember(name: Name, pre: Type)(using Context): ResolveMemberResult = underlyingOrMethodic match @@ -1139,6 +1131,8 @@ object Types { } object TermRef: + private final class Resolved(val symbol: TermSymbol, val tpe: TypeOrMethodic, val isStable: Boolean) + def apply(prefix: NonEmptyPrefix, name: TermName): TermRef = new TermRef(prefix, name) def apply(prefix: Prefix, symbol: TermSymbol): TermRef = new TermRef(prefix, symbol) @@ -1236,8 +1230,9 @@ object Types { final class TypeRef private ( val prefix: Prefix, - private var myDesignator: TypeName | TypeSymbol | LookupTypeIn | Scala2ExternalSymRef + private val myDesignator: TypeName | TypeSymbol | LookupTypeIn | Scala2ExternalSymRef ) extends NamedType { + import TypeRef.Resolved protected type ThisName = TypeName private[tastyquery] type ThisSymbolType = TypeSymbol @@ -1245,20 +1240,16 @@ object Types { protected type ThisDesignatorType = TypeName | TypeSymbol | LookupTypeIn | Scala2ExternalSymRef // Cache fields - private var myOptSymbol: Option[TypeSymbol] | Null = null - private var myBounds: TypeBounds | Null = null + private val myResolved: Memo[Resolved] = uninitializedMemo private def this(prefix: NonEmptyPrefix, resolved: ResolveMemberResult.ClassMember) = this(prefix, resolved.cls) - myOptSymbol = Some(resolved.cls) + initializeMemo(myResolved, Resolved(Some(resolved.cls), null)) end this private def this(prefix: NonEmptyPrefix, name: TypeName, resolved: ResolveMemberResult.TypeMember) = this(prefix, name) - val optSymbol = resolved.symbols.headOption - myOptSymbol = optSymbol - if optSymbol.isDefined then myDesignator = optSymbol.get - myBounds = resolved.bounds + initializeMemo(myResolved, Resolved(resolved.symbols.headOption, resolved.bounds)) end this final def name: TypeName = nameImpl @@ -1268,49 +1259,46 @@ object Types { override def toString(): String = s"TypeRef($prefix, $myDesignator)" - private def ensureResolved()(using Context): Unit = - if myOptSymbol == null then resolve() + private def resolved(using Context): Resolved = memoized(myResolved) { + doResolve() + } - private def resolve()(using Context): Unit = - def storeClass(cls: ClassSymbol): Unit = - myOptSymbol = Some(cls) - myDesignator = cls + private def doResolve()(using Context): Resolved = + def resolveToClass(cls: ClassSymbol): Resolved = + Resolved(Some(cls), null) - def storeTypeMember(sym: Option[TypeSymbolWithBounds], bounds: TypeBounds): Unit = - myOptSymbol = sym - if sym.isDefined then myDesignator = sym.get - myBounds = bounds + def resolveToTypeMember(sym: Option[TypeSymbolWithBounds], bounds: TypeBounds): Resolved = + Resolved(sym, bounds) - def storeSymbol(sym: TypeSymbol): Unit = sym match - case cls: ClassSymbol => storeClass(cls) - case sym: TypeSymbolWithBounds => storeTypeMember(Some(sym), sym.boundsAsSeenFrom(prefix)) + def resolveToSymbol(sym: TypeSymbol): Resolved = sym match + case cls: ClassSymbol => resolveToClass(cls) + case sym: TypeSymbolWithBounds => resolveToTypeMember(Some(sym), sym.boundsAsSeenFrom(prefix)) designator match case sym: TypeSymbol => - storeSymbol(sym) + resolveToSymbol(sym) case lookupTypeIn: LookupTypeIn => val sym = TypeRef.resolveLookupTypeIn(lookupTypeIn) - storeSymbol(sym) + resolveToSymbol(sym) case externalRef: Scala2ExternalSymRef => val sym = NamedType.resolveScala2ExternalRef(externalRef).asType - storeSymbol(sym) + resolveToSymbol(sym) case name: TypeName => prefix match case prefix: NonEmptyPrefix => prefix.resolveMember(name) match case ResolveMemberResult.ClassMember(cls) => - storeClass(cls) + resolveToClass(cls) case ResolveMemberResult.TypeMember(symbols, bounds) => - storeTypeMember(symbols.headOption, bounds) + resolveToTypeMember(symbols.headOption, bounds) case _ => throw MemberNotFoundException(prefix, name) case NoPrefix => throw new AssertionError(s"found reference by name $name without a prefix") - end resolve + end doResolve final def optSymbol(using Context): Option[TypeSymbol] = - ensureResolved() - myOptSymbol.nn + resolved.optSymbol final def isClass(using Context): Boolean = optSymbol.exists(_.isClass) @@ -1325,8 +1313,7 @@ object Types { bounds.high final def bounds(using Context): TypeBounds = - ensureResolved() - val local = myBounds + val local = resolved.bounds if local == null then throw AssertionError(s"TypeRef $this has no `underlying` because it refers to a `ClassSymbol`") else local @@ -1337,8 +1324,7 @@ object Types { throw AssertionError(s"No typeDef for $this") def optAliasedType(using Context): Option[Type] = - ensureResolved() - myBounds match + resolved.bounds match case TypeAlias(alias) => Some(alias) case _ => None @@ -1391,6 +1377,8 @@ object Types { } object TypeRef: + private final class Resolved(val optSymbol: Option[TypeSymbol], val bounds: TypeBounds | Null) + def apply(prefix: NonEmptyPrefix, name: TypeName): TypeRef = new TypeRef(prefix, name) def apply(prefix: Prefix, symbol: TypeSymbol): TypeRef = new TypeRef(prefix, symbol) @@ -1431,9 +1419,9 @@ object Types { end TypeRef final class ThisType(val tref: TypeRef) extends SingletonType { - private var myUnderlying: Type | Null = null + private val myUnderlying: Memo[Type] = uninitializedMemo - override def underlying(using Context): Type = memoized(myUnderlying, myUnderlying = _) { + override def underlying(using Context): Type = memoized(myUnderlying) { val cls = this.cls if cls.isStatic then cls.selfType else cls.selfType.asSeenFrom(tref.prefix, cls) @@ -1449,10 +1437,12 @@ object Types { * by `super`. */ final class SuperType(val thistpe: ThisType, val explicitSupertpe: Option[Type]) extends TypeProxy with SingletonType: - private var mySupertpe: Type | Null = explicitSupertpe.orNull + private val mySupertpe: Memo[Type] = uninitializedMemo - private[tastyquery] final def supertpe(using Context): Type = memoized(mySupertpe, mySupertpe = _) { - thistpe.cls.parents.reduceLeft(_ & _) + private[tastyquery] final def supertpe(using Context): Type = memoized(mySupertpe) { + explicitSupertpe.getOrElse { + thistpe.cls.parents.reduceLeft(_ & _) + } } override def underlying(using Context): Type = supertpe @@ -1555,9 +1545,9 @@ object Types { /** The type of a repeated parameter of the form `T*`. */ final class RepeatedType(val elemType: Type) extends TypeProxy: - private var myUnderlying: Type | Null = null + private val myUnderlying: Memo[Type] = uninitializedMemo - override def underlying(using Context): Type = memoized(myUnderlying, myUnderlying = _) { + override def underlying(using Context): Type = memoized(myUnderlying) { defn.SeqTypeOf(elemType) } @@ -2059,11 +2049,11 @@ object Types { ) extends RefinedType: // Cache fields private[tastyquery] val isMethodic = refinedType.isInstanceOf[MethodicType] - private var mySignedName: SignedName | Null = null + private val mySignedName: Memo[SignedName] = uninitializedMemo require(!(isStable && isMethodic), s"Ill-formed $this") - private[tastyquery] def signedName(using Context): SignedName = memoized(mySignedName, mySignedName = _) { + private[tastyquery] def signedName(using Context): SignedName = memoized(mySignedName) { val sig = Signature.fromType(refinedType, SourceLanguage.Scala3, optCtorReturn = None) SignedName(refinedName, sig) } @@ -2251,11 +2241,11 @@ object Types { /** selector match { cases } */ final class MatchType(val bound: Type, val scrutinee: Type, val cases: List[MatchTypeCase]) extends TypeProxy: - private var myReduced: Option[Type] | Null = null + private val myReduced: Memo[Option[Type]] = uninitializedMemo def underlying(using Context): Type = bound - def reduced(using Context): Option[Type] = memoized(myReduced, myReduced = _) { + def reduced(using Context): Option[Type] = memoized(myReduced) { TypeMatching.matchCases(scrutinee, cases) } @@ -2357,19 +2347,11 @@ object Types { // ----- Ground Types ------------------------------------------------- final class OrType(val first: Type, val second: Type) extends GroundType { - private var myJoin: Type | Null = uninitialized + private val myJoin: Memo[Type] = uninitializedMemo /** Returns the closest non-OrType type above this one. */ - def join(using Context): Type = { - val myJoin = this.myJoin - if (myJoin != null) then myJoin - else - val computedJoin = computeJoin() - this.myJoin = computedJoin - computedJoin - } + def join(using Context): Type = memoized(myJoin) { - private def computeJoin()(using Context): Type = /** The minimal set of classes in `classes` which derive all other classes in `classes` */ def dominators(classes: List[ClassSymbol], acc: List[ClassSymbol]): List[ClassSymbol] = classes match case cls :: rest => @@ -2384,7 +2366,7 @@ object Types { val commonBaseClasses = TypeOps.baseClasses(prunedNulls) val doms = dominators(commonBaseClasses, Nil) doms.flatMap(cls => prunedNulls.baseType(cls)).reduceLeft(AndType.make(_, _)) - end computeJoin + } private def tryPruneNulls(tp: Type)(using Context): Type = tp match case tp: OrType => diff --git a/tasty-query/shared/src/main/scala/tastyquery/Utils.scala b/tasty-query/shared/src/main/scala/tastyquery/Utils.scala index f02d229c..8d5cacaa 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Utils.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Utils.scala @@ -1,24 +1,76 @@ package tastyquery +import scala.annotation.targetName + +import java.util.concurrent.atomic.AtomicReference + private[tastyquery] object Utils: + opaque type Memo[A] = AtomicReference[A] + + opaque type SingleAssign[A] = A | Null + + // Memo + + inline def uninitializedMemo[A]: Memo[A] = new AtomicReference[A]() + + extension [A](memo: Memo[A]) + @targetName("isMemoInitialized") + def isInitialized: Boolean = memo.get() != null + /** A memoized computation `computed`, stored in `memo` using the `store` setter. */ - inline def memoized[A](memo: A | Null, inline store: A => Unit)(inline compute: => A): A = - if memo != null then memo + inline def memoized[A](memo: Memo[A])(inline compute: => A): A = + val existing = memo.get() + if existing != null then existing else // Extracted in a separate def for good jitting of the code calling `memoized` def computeAndStore(): A = val computed = compute - store(computed) - computed + if memo.compareAndSet(null, computed) then computed + else memo.get().nn computeAndStore() end memoized - inline def assignOnce(existing: Any, inline assign: => Unit)(inline msgIfAlreadyAssigned: => String): Unit = + inline def memoized2[A](memo: Memo[A])(inline compute: => A)(inline afterCompute: A => Unit): A = + val existing = memo.get() + if existing != null then existing + else + // Extracted in a separate def for good jitting of the code calling `memoized2` + def computeAndStore(): A = + val computed = compute + if memo.compareAndSet(null, computed) then + afterCompute(computed) + computed + else + // We wasted the computation; use the stored value so that only once instance survives for the GC + memo.get().nn + computeAndStore() + end memoized2 + + inline def initializeMemo[A](memo: Memo[A], value: A): Unit = + memo.compareAndSet(null, value) + + inline def assignOnceMemo[A](existing: Memo[A], value: A)(inline msgIfAlreadyAssigned: => String): Unit = + if !existing.compareAndSet(null, value) then throw IllegalStateException(msgIfAlreadyAssigned) + + // SingleAssign + + inline def uninitializedSingleAssign[A]: SingleAssign[A] = null + + extension [A](singleAssign: SingleAssign[A]) + @targetName("isSingleAssignInitialized") + def isInitialized: Boolean = singleAssign != null + + inline def assignOnce[A](existing: SingleAssign[A], inline assign: SingleAssign[A] => Unit, value: A)( + inline msgIfAlreadyAssigned: => String + ): Unit = // Methods calling `assignOnce` are not in fast paths, so no need to extract the exception in a local def if existing != null then throw IllegalStateException(msgIfAlreadyAssigned) - assign + assign(value) + + inline def overwriteSingleAssign[A](inline assign: SingleAssign[A] => Unit, value: A): Unit = + assign(value) - inline def getAssignedOnce[A](value: A | Null)(inline msgIfNotAssignedYet: => String): A = + inline def getAssignedOnce[A](value: SingleAssign[A])(inline msgIfNotAssignedYet: => String): A = if value != null then value else // Extracted in a separate def for good jitting of the code calling `getAssignedOnce` diff --git a/tasty-query/shared/src/main/scala/tastyquery/reader/Loaders.scala b/tasty-query/shared/src/main/scala/tastyquery/reader/Loaders.scala index 70789983..e2f9cc3a 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/reader/Loaders.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/reader/Loaders.scala @@ -10,15 +10,20 @@ import tastyquery.Exceptions.* import tastyquery.Names.* import tastyquery.Symbols.* import tastyquery.Trees.* +import tastyquery.Utils.* import tastyquery.reader.ReaderContext.rctx import tastyquery.reader.classfiles.{ClassfileParser, ClassfileReader} import tastyquery.reader.classfiles.ClassfileParser.{ClassKind, InnerClassDecl, Resolver} +import tastyquery.reader.classfiles.ClassfileReader.Structure import tastyquery.reader.tasties.TastyUnpickler private[tastyquery] object Loaders { - private final class PackageLoadingInfo(val pkg: PackageSymbol, initPackageData: List[PackageData]): + private[tastyquery] final class PackageLoadingInfo private[Loaders] ( + pkg: PackageSymbol, + initPackageData: List[PackageData] + ): private lazy val dataByBinaryName = val localRoots = mutable.HashMap.empty[String, ClassData] for packageData <- initPackageData do @@ -29,14 +34,12 @@ private[tastyquery] object Loaders { localRoots end dataByBinaryName - private val topLevelTastys = mutable.HashMap.empty[String, List[Tree]] - private type LoadedFiles = mutable.HashSet[String] - def topLevelTastyFor(rootName: String): Option[List[Tree]] = - topLevelTastys.get(rootName) + /** Loads all the roots of the associated package. */ + def loadAllRoots()(using Context): Unit = + given ReaderContext = ReaderContext(ctx) - def loadAllRoots()(using ReaderContext, Resolver): Unit = // Sort for determinism, and to make sure that outer classes always come before their inner classes val allNames = dataByBinaryName.keysIterator.toList.sorted @@ -48,7 +51,10 @@ private[tastyquery] object Loaders { dataByBinaryName.clear() end loadAllRoots - def loadAllPackageObjectRoots()(using ReaderContext, Resolver): Unit = + /** Loads all the roots of the associated package that could be package objects. */ + def loadAllPackageObjectRoots()(using Context): Unit = + given ReaderContext = ReaderContext(ctx) + def isPackageObjectBinaryName(name: String): Boolean = name == "package" || name.endsWith("$package") @@ -60,12 +66,27 @@ private[tastyquery] object Loaders { } end loadAllPackageObjectRoots - def loadOneRoot(binaryName: String)(using ReaderContext, Resolver): Boolean = + /** Loads the root of the associated package that would define `name`, if there is one such root. */ + def loadOneRoot(name: Name)(using Context): Unit = + given ReaderContext = ReaderContext(ctx) + loadingRoots { loadedFiles => + val binaryName = topLevelSymbolNameToRootName(name) tryLoadRoot(binaryName, loadedFiles) } end loadOneRoot + private def topLevelSymbolNameToRootName(name: Name): String = name match + case name: TypeName => + topLevelSymbolNameToRootName(name.toTermName) + case ObjectClassName(objName) => + topLevelSymbolNameToRootName(objName) + case name: SimpleName => + NameTransformer.encode(name.name) + case _ => + throw IllegalStateException(s"Invalid top-level symbol name ${name.toDebugString}") + end topLevelSymbolNameToRootName + private def loadingRoots[A](op: LoadedFiles => A): A = val loadedFiles = mutable.HashSet.empty[String] val result = op(loadedFiles) @@ -76,32 +97,27 @@ private[tastyquery] object Loaders { result end loadingRoots - private def tryLoadRoot(binaryName: String, loadedFiles: LoadedFiles)(using ReaderContext, Resolver): Boolean = + private def tryLoadRoot(binaryName: String, loadedFiles: LoadedFiles)(using ReaderContext): Unit = dataByBinaryName.get(binaryName) match case None => - false + () case Some(classData) => // Avoid reading inner classes that we already loaded through their outer classes. if loadedFiles.add(binaryName) then - if classData.hasTastyFile then - doLoadTasty(classData) - true - else if doLoadClassFile(classData, loadedFiles) then true - else + if classData.hasTastyFile then doLoadTasty(classData) + else if !doLoadClassFile(classData, loadedFiles) then /* Oops, maybe we will need this one later, if it is a (non-local) * inner class of another Java class. * Removing it from loadedFiles so that we do not throw away the file. */ loadedFiles -= binaryName - false - else false end tryLoadRoot private lazy val fullBinaryNamePrefix: String = if pkg.isEmptyPackage then "" else pkg.fullName.path.mkString("", "/", "/") - def doLoadClassFile(classData: ClassData, loadedFiles: LoadedFiles)(using ReaderContext, Resolver): Boolean = + def doLoadClassFile(classData: ClassData, loadedFiles: LoadedFiles)(using ReaderContext): Boolean = val structure = ClassfileReader.readStructure(pkg, classData) val kind = ClassfileParser.detectClassKind(structure) kind match @@ -109,8 +125,7 @@ private[tastyquery] object Loaders { ClassfileParser.loadScala2Class(structure) true case ClassKind.Java => - val innerDecls = ClassfileParser.loadJavaClass(pkg, termName(classData.binaryName), structure) - doLoadJavaInnerClasses(innerDecls, loadedFiles) + doLoadJavaTopLevelClass(classData, structure, loadedFiles) true case ClassKind.TASTy => throw TastyFormatException(s"Missing TASTy file for class ${classData.binaryName} in package $pkg") @@ -122,6 +137,16 @@ private[tastyquery] object Loaders { false end doLoadClassFile + private def doLoadJavaTopLevelClass(classData: ClassData, structure: Structure, loadedFiles: LoadedFiles)( + using ReaderContext + ): Unit = + // The resolver for this top-level class and all its inner classes + given Resolver = Resolver() + + val innerDecls = ClassfileParser.loadJavaClass(pkg, termName(classData.binaryName), structure) + doLoadJavaInnerClasses(innerDecls, loadedFiles) + end doLoadJavaTopLevelClass + private def doLoadJavaInnerClasses(explore: List[InnerClassDecl], loadedFiles: LoadedFiles)( using ReaderContext, Resolver @@ -147,7 +172,7 @@ private[tastyquery] object Loaders { private def doLoadTasty(classData: ClassData)(using ReaderContext): Unit = val unpickler = TastyUnpickler(classData.readTastyFileBytes()) val debugPath = classData.toString() - val trees = unpickler + unpickler .unpickle( debugPath, TastyUnpickler.TreeSectionUnpickler( @@ -156,19 +181,15 @@ private[tastyquery] object Loaders { ) .get .unpickle() - topLevelTastys += classData.binaryName -> trees end doLoadTasty end PackageLoadingInfo class Loader(val classpath: Classpath) { - given Resolver = Resolver() - private type ByEntryMap = Map[ClasspathEntry, IArray[(PackageSymbol, IArray[String])]] private var initialized = false - private var packages: Map[PackageSymbol, PackageLoadingInfo] = compiletime.uninitialized - private var _hasGenericTuples: Boolean = compiletime.uninitialized - private var byEntry: ByEntryMap | Null = null + private var _hasGenericTuples: Boolean = false + private val byEntry: Memo[ByEntryMap] = uninitializedMemo private def toPackageName(dotSeparated: String): PackageFullName = val parts = @@ -176,58 +197,9 @@ private[tastyquery] object Loaders { else dotSeparated.split('.').toList.map(termName(_)) PackageFullName(parts) - private def topLevelSymbolNameToRootName(name: Name): String = name match - case name: TypeName => - topLevelSymbolNameToRootName(name.toTermName) - case ObjectClassName(objName) => - topLevelSymbolNameToRootName(objName) - case name: SimpleName => - NameTransformer.encode(name.name) - case _ => - throw IllegalStateException(s"Invalid top-level symbol name ${name.toDebugString}") - end topLevelSymbolNameToRootName - private def rootNameToTopLevelTermSymbolName(rootName: String): SimpleName = termName(NameTransformer.decode(rootName)) - /** If this is a root symbol, lookup possible top level tasty trees associated with it. */ - private[tastyquery] def topLevelTasty(rootSymbol: Symbol)(using Context): Option[List[Tree]] = - rootSymbol.owner match - case pkg: PackageSymbol => - val rootName = topLevelSymbolNameToRootName(rootSymbol.name) - packages.get(pkg).flatMap(_.topLevelTastyFor(rootName)) - case _ => None - - /** Loads all the roots of the given `pkg`. */ - private[tastyquery] def loadAllRoots(pkg: PackageSymbol)(using Context): Unit = - for loadingInfo <- packages.get(pkg) do loadingInfo.loadAllRoots()(using ReaderContext(ctx)) - - /** Loads all the roots of the given `pkg` that could be package objects. */ - private[tastyquery] def loadAllPackageObjectRoots(pkg: PackageSymbol)(using Context): Unit = - for loadingInfo <- packages.get(pkg) do loadingInfo.loadAllPackageObjectRoots()(using ReaderContext(ctx)) - - /** Loads the root of the given `pkg` that would define `name`, if there is one such root. - * - * When this method returns `true`, it is not guaranteed that the - * particular `name` corresponds to a `Symbol`. But when it returns - * `false`, there is a guarantee that no new symbol with the given `name` - * was loaded. - * - * Whether this method returns `true` or `false`, any subsequent call to - * `loadRoot` with the same arguments will return `false`. - * - * @return - * `true` if a root was loaded, `false` otherwise. - */ - private[tastyquery] def loadRoot(pkg: PackageSymbol, name: Name)(using Context): Boolean = - packages.get(pkg) match - case Some(loadingInfo) => - val rootName = topLevelSymbolNameToRootName(name) - loadingInfo.loadOneRoot(rootName)(using ReaderContext(ctx)) - case None => - false - end loadRoot - def lookupByEntry(src: ClasspathEntry)(using Context): Option[Iterable[TermOrTypeSymbol]] = def lookupRoots(pkg: PackageSymbol, rootNames: IArray[String]) = val buf = IArray.newBuilder[TermOrTypeSymbol] @@ -245,12 +217,10 @@ private[tastyquery] object Loaders { case Some(pkgs) => Some(pkgs.view.flatMap(lookupRoots)) case None => None - val localByEntry = byEntry - if localByEntry == null then - val newByEntry = computeByEntry() - byEntry = newByEntry - computeLookup(newByEntry) - else computeLookup(localByEntry) + val localByEntry = memoized(byEntry) { + computeByEntry() + } + computeLookup(localByEntry) end lookupByEntry def initPackages()(using ctx: Context): Unit = @@ -267,13 +237,13 @@ private[tastyquery] object Loaders { ) end loadPackages - val rawMap = loadPackages().groupBy(_._1) - packages = rawMap.map((pkg, pairs) => pkg -> new PackageLoadingInfo(pkg, pairs.map(_._2))) - _hasGenericTuples = rawMap - .get(defn.scalaPackage) - .exists(_.exists { (pkg, data) => - data.getClassDataByBinaryName("$times$colon").isDefined - }) + for (pkg, pairs) <- loadPackages().groupBy(_._1) do + val initPackageData = pairs.map(_._2) + pkg.setLoadingInfo(new PackageLoadingInfo(pkg, initPackageData)) + + if pkg.isScalaPackage then + _hasGenericTuples = initPackageData.exists(_.getClassDataByBinaryName("$times$colon").isDefined) + end for end initPackages def hasGenericTuples: Boolean = _hasGenericTuples diff --git a/tasty-query/shared/src/main/scala/tastyquery/reader/ReaderContext.scala b/tasty-query/shared/src/main/scala/tastyquery/reader/ReaderContext.scala index d1a4f7dc..ef2eb084 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/reader/ReaderContext.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/reader/ReaderContext.scala @@ -69,7 +69,7 @@ private[reader] final class ReaderContext(underlying: Context): def getSourceFile(path: String): SourceFile = underlying.getSourceFile(path) - def hasGenericTuples: Boolean = underlying.classloader.hasGenericTuples + def hasGenericTuples: Boolean = underlying.hasGenericTuples def createObjectMagicMethods(cls: ClassSymbol): Unit = underlying.defn.createObjectMagicMethods(cls) diff --git a/tasty-query/shared/src/main/scala/tastyquery/reader/tasties/TreeUnpickler.scala b/tasty-query/shared/src/main/scala/tastyquery/reader/tasties/TreeUnpickler.scala index 7b110a30..d0d49249 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/reader/tasties/TreeUnpickler.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/reader/tasties/TreeUnpickler.scala @@ -44,21 +44,24 @@ private[tasties] class TreeUnpickler private ( using ReaderContext ) = this(filename, reader, nameAtRef, posUnpicklerOpt, new TreeUnpickler.Caches) - def unpickle(): List[Tree] = + def unpickle(): Unit = @tailrec - def read(acc: ListBuffer[Tree])(using SourceFile): List[Tree] = + def read(acc: ListBuffer[TopLevelTree])(using SourceFile): List[TopLevelTree] = acc += readTopLevelStat if !reader.isAtEnd then read(acc) else acc.toList fork.enterSymbols() - val result = maybeAdjustSourceFileIn { - read(new ListBuffer[Tree]) + val topLevelTasty = maybeAdjustSourceFileIn { + read(new ListBuffer[TopLevelTree]) }(using SourceFile.NoSource) - // Check that all the Symbols we created have been completed - for sym <- caches.allRegisteredSymbols do sym.checkCompleted() - - result + // Check that all the Symbols we created have been completed, and fill in top-level TASTy trees + for sym <- caches.allRegisteredSymbols do + sym match + case sym: ClassSymbol if sym.owner.isPackage => sym.setTopLevelTasty(topLevelTasty) + case _ => () + sym.checkCompleted() + end for end unpickle private def enterSymbols(): Unit = diff --git a/tasty-query/shared/src/test/scala/tastyquery/ClasspathEntrySuite.scala b/tasty-query/shared/src/test/scala/tastyquery/ClasspathEntrySuite.scala index a09b2222..87bfa184 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/ClasspathEntrySuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/ClasspathEntrySuite.scala @@ -9,7 +9,7 @@ import tastyquery.testutil.TestPlatform class ClasspathEntrySuite extends UnrestrictedUnpicklingSuite: def scala3ClasspathEntry(using Context): ClasspathEntry = - ctx.classloader.classpath(TestPlatform.scala3ClasspathIndex) + ctx.internalClasspathForTestsOnly(TestPlatform.scala3ClasspathIndex) def lookupSyms(entry: ClasspathEntry)(using Context): IArray[Symbol] = IArray.from(ctx.findSymbolsByClasspathEntry(entry)) diff --git a/tasty-query/shared/src/test/scala/tastyquery/RestrictedUnpicklingSuite.scala b/tasty-query/shared/src/test/scala/tastyquery/RestrictedUnpicklingSuite.scala index 9cd67e4d..f945f220 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/RestrictedUnpicklingSuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/RestrictedUnpicklingSuite.scala @@ -16,9 +16,9 @@ abstract class RestrictedUnpicklingSuite extends BaseUnpicklingSuite { for base <- initRestrictedContext(rootSymbolPath, extraRootSymbolPaths) yield given Context = base val rootSym = findTopLevelClassOrModuleClass(rootSymbolPath) - val tree = base.classloader.topLevelTasty(rootSym) match - case Some(trees) => trees.head - case _ => fail(s"Missing tasty for $rootSymbolPath, but resolved root $rootSym") + val tree = rootSym.topLevelTasty match + case firstTree :: _ => firstTree + case Nil => fail(s"Missing tasty for $rootSymbolPath, but resolved root $rootSym") (base, tree) end findTopLevelTasty diff --git a/tasty-query/shared/src/test/scala/tastyquery/UnrestrictedUnpicklingSuite.scala b/tasty-query/shared/src/test/scala/tastyquery/UnrestrictedUnpicklingSuite.scala index 3ac469ac..e26605ea 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/UnrestrictedUnpicklingSuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/UnrestrictedUnpicklingSuite.scala @@ -5,13 +5,29 @@ import scala.concurrent.ExecutionContext.Implicits.global import tastyquery.Contexts.* abstract class UnrestrictedUnpicklingSuite extends BaseUnpicklingSuite { + import UnrestrictedUnpicklingSuite.* + + /** Set this to true to stress-test thread-safety by using a common `Context` across all test suites. */ + private final val useParallelTesting = false + def testWithContext(name: String)(using munit.Location)(body: Context ?=> Unit): Unit = testWithContext(new munit.TestOptions(name))(body) def testWithContext(options: munit.TestOptions)(using munit.Location)(body: Context ?=> Unit): Unit = test(options) { - for classpath <- testClasspath yield - val ctx = Context.initialize(classpath) - body(using ctx) + if useParallelTesting then + // use the common context + for ctx <- commonContextForParallelTesting yield body(using ctx) + else + // create an isolated context + for classpath <- testClasspath yield + val ctx = Context.initialize(classpath) + body(using ctx) + end if } } + +object UnrestrictedUnpicklingSuite: + private lazy val commonContextForParallelTesting = + tastyquery.testutil.TestPlatform.loadClasspath().map(Context.initialize(_)) +end UnrestrictedUnpicklingSuite