Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport "Three fixes to SAM type handling" to LTS #22132

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 44 additions & 19 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -305,24 +305,41 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def TypeDef(sym: TypeSymbol)(using Context): TypeDef =
ta.assignType(untpd.TypeDef(sym.name, TypeTree(sym.info)), sym)

def ClassDef(cls: ClassSymbol, constr: DefDef, body: List[Tree], superArgs: List[Tree] = Nil)(using Context): TypeDef = {
/** Create a class definition
* @param cls the class symbol of the created class
* @param constr its primary constructor
* @param body the statements in its template
* @param superArgs the arguments to pass to the superclass constructor
* @param adaptVarargs if true, allow matching a vararg superclass constructor
* with a missing argument in superArgs, and synthesize an
* empty repeated parameter in the supercall in this case
*/
def ClassDef(cls: ClassSymbol, constr: DefDef, body: List[Tree],
superArgs: List[Tree] = Nil, adaptVarargs: Boolean = false)(using Context): TypeDef =
val firstParent :: otherParents = cls.info.parents: @unchecked

def adaptedSuperArgs(ctpe: Type): List[Tree] = ctpe match
case ctpe: PolyType =>
adaptedSuperArgs(ctpe.instantiate(firstParent.argTypes))
case ctpe: MethodType
if ctpe.paramInfos.length == superArgs.length + 1 =>
// last argument must be a vararg, otherwise isApplicable would have failed
superArgs :+
repeated(Nil, TypeTree(ctpe.paramInfos.last.argInfos.head, inferred = true))
case _ =>
superArgs

val superRef =
if (cls.is(Trait)) TypeTree(firstParent)
else {
def isApplicable(ctpe: Type): Boolean = ctpe match {
case ctpe: PolyType =>
isApplicable(ctpe.instantiate(firstParent.argTypes))
case ctpe: MethodType =>
(superArgs corresponds ctpe.paramInfos)(_.tpe <:< _)
case _ =>
false
}
val constr = firstParent.decl(nme.CONSTRUCTOR).suchThat(constr => isApplicable(constr.info))
New(firstParent, constr.symbol.asTerm, superArgs)
}
if cls.is(Trait) then TypeTree(firstParent)
else
val parentConstr = firstParent.applicableConstructors(superArgs.tpes, adaptVarargs) match
case Nil => assert(false, i"no applicable parent constructor of $firstParent for supercall arguments $superArgs")
case constr :: Nil => constr
case _ => assert(false, i"multiple applicable parent constructors of $firstParent for supercall arguments $superArgs")
New(firstParent, parentConstr.asTerm, adaptedSuperArgs(parentConstr.info))

ClassDefWithParents(cls, constr, superRef :: otherParents.map(TypeTree(_)), body)
}
end ClassDef

def ClassDefWithParents(cls: ClassSymbol, constr: DefDef, parents: List[Tree], body: List[Tree])(using Context): TypeDef = {
val selfType =
Expand All @@ -349,13 +366,18 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* @param parents a non-empty list of class types
* @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to.
* @param typeMembers a possibly-empty list of type members specified by their name and their right hand side.
* @param adaptVarargs if true, allow matching a vararg superclass constructor
* with a missing argument in superArgs, and synthesize an
* empty repeated parameter in the supercall in this case
*
* The class has the same owner as the first function in `termForwarders`.
* Its position is the union of all symbols in `termForwarders`.
*/
def AnonClass(parents: List[Type], termForwarders: List[(TermName, TermSymbol)],
typeMembers: List[(TypeName, TypeBounds)] = Nil)(using Context): Block = {
AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _)) { cls =>
def AnonClass(parents: List[Type],
termForwarders: List[(TermName, TermSymbol)],
typeMembers: List[(TypeName, TypeBounds)],
adaptVarargs: Boolean)(using Context): Block = {
AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _), adaptVarargs) { cls =>
def forwarder(name: TermName, fn: TermSymbol) = {
val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm
for overridden <- fwdMeth.allOverriddenSymbols do
Expand All @@ -375,6 +397,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* with the specified owner and position.
*/
def AnonClass(owner: Symbol, parents: List[Type], coord: Coord)(body: ClassSymbol => List[Tree])(using Context): Block =
AnonClass(owner, parents, coord, adaptVarargs = false)(body)

private def AnonClass(owner: Symbol, parents: List[Type], coord: Coord, adaptVarargs: Boolean)(body: ClassSymbol => List[Tree])(using Context): Block =
val parents1 =
if (parents.head.classSymbol.is(Trait)) {
val head = parents.head.parents.head
Expand All @@ -383,7 +408,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
else parents
val cls = newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic | Final, parents1, coord = coord)
val constr = newConstructor(cls, Synthetic, Nil, Nil).entered
val cdef = ClassDef(cls, DefDef(constr), body(cls))
val cdef = ClassDef(cls, DefDef(constr), body(cls), Nil, adaptVarargs)
Block(cdef :: Nil, New(cls.typeRef, Nil))

def Import(expr: Tree, selectors: List[untpd.ImportSelector])(using Context): Import =
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ object Phases {
def sbtExtractDependenciesPhase(using Context): Phase = ctx.base.sbtExtractDependenciesPhase
def picklerPhase(using Context): Phase = ctx.base.picklerPhase
def inliningPhase(using Context): Phase = ctx.base.inliningPhase
def stagingPhase(using Context): Phase = ctx.base.stagingPhase
def stagingPhase(using Context): Phase = ctx.base.stagingPhase
def splicingPhase(using Context): Phase = ctx.base.splicingPhase
def firstTransformPhase(using Context): Phase = ctx.base.firstTransformPhase
def refchecksPhase(using Context): Phase = ctx.base.refchecksPhase
Expand Down
34 changes: 30 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ package core
import TypeErasure.ErasedValueType
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
import Names.Name
import StdNames.nme

class TypeUtils {
class TypeUtils:
/** A decorator that provides methods on types
* that are needed in the transformer pipeline.
*/
extension (self: Type) {
extension (self: Type)

def isErasedValueType(using Context): Boolean =
self.isInstanceOf[ErasedValueType]
Expand Down Expand Up @@ -125,5 +126,30 @@ class TypeUtils {
def takesImplicitParams(using Context): Boolean = self.stripPoly match
case mt: MethodType => mt.isImplicitMethod || mt.resType.takesImplicitParams
case _ => false
}
}

/** The constructors of this type that are applicable to `argTypes`, without needing
* an implicit conversion. Curried constructors are always excluded.
* @param adaptVarargs if true, allow a constructor with just a varargs argument to
* match an empty argument list.
*/
def applicableConstructors(argTypes: List[Type], adaptVarargs: Boolean)(using Context): List[Symbol] =
def isApplicable(constr: Symbol): Boolean =
def recur(ctpe: Type): Boolean = ctpe match
case ctpe: PolyType =>
if argTypes.isEmpty then recur(ctpe.resultType) // no need to know instances
else recur(ctpe.instantiate(self.argTypes))
case ctpe: MethodType =>
var paramInfos = ctpe.paramInfos
if adaptVarargs && paramInfos.length == argTypes.length + 1
&& atPhaseNoLater(Phases.elimRepeatedPhase)(constr.info.isVarArgsMethod)
then // accept missing argument for varargs parameter
paramInfos = paramInfos.init
argTypes.corresponds(paramInfos)(_ <:< _) && !ctpe.resultType.isInstanceOf[MethodType]
case _ =>
false
recur(constr.info)

self.decl(nme.CONSTRUCTOR).altsWith(isApplicable).map(_.symbol)

end TypeUtils

21 changes: 11 additions & 10 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5612,17 +5612,18 @@ object Types extends TypeUtils {

def samClass(tp: Type)(using Context): Symbol = tp match
case tp: ClassInfo =>
def zeroParams(tp: Type): Boolean = tp.stripPoly match
case mt: MethodType => mt.paramInfos.isEmpty && !mt.resultType.isInstanceOf[MethodType]
case et: ExprType => true
case _ => false
val cls = tp.cls
val validCtor =
val ctor = cls.primaryConstructor
// `ContextFunctionN` does not have constructors
!ctor.exists || zeroParams(ctor.info)
val isInstantiable = !cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
if validCtor && isInstantiable then tp.cls
def takesNoArgs(tp: Type) =
!tp.classSymbol.primaryConstructor.exists
// e.g. `ContextFunctionN` does not have constructors
|| tp.applicableConstructors(Nil, adaptVarargs = true).lengthCompare(1) == 0
// we require a unique constructor so that SAM expansion is deterministic
val noArgsNeeded: Boolean =
takesNoArgs(tp)
&& (!tp.cls.is(Trait) || takesNoArgs(tp.parents.head))
def isInstantiable =
!tp.cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
if noArgsNeeded && isInstantiable then tp.cls
else NoSymbol
case tp: AppliedType =>
samClass(tp.superType)
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ class ExpandSAMs extends MiniPhase:
val tpe1 = collectAndStripRefinements(tpe)
val Seq(samDenot) = tpe1.possibleSamMethods
cpy.Block(tree)(stats,
AnonClass(List(tpe1),
List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm),
refinements.toList
)
transformFollowingDeep:
AnonClass(List(tpe1),
List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm),
refinements.toList,
adaptVarargs = true
)
)
}
case _ =>
Expand Down
15 changes: 15 additions & 0 deletions tests/neg/i15855.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class MyFunction(args: String)

trait MyFunction0[+R] extends MyFunction {
def apply(): R
}

def fromFunction0[R](f: Function0[R]): MyFunction0[R] = () => f() // error

class MyFunctionWithImplicit(implicit args: String)

trait MyFunction0WithImplicit[+R] extends MyFunctionWithImplicit {
def apply(): R
}

def fromFunction1[R](f: Function0[R]): MyFunction0WithImplicit[R] = () => f() // error
17 changes: 17 additions & 0 deletions tests/run/i15855.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
class MyFunction(args: String*)

trait MyFunction0[+R] extends MyFunction {
def apply(): R
}

abstract class MyFunction1[R](args: R*):
def apply(): R

def fromFunction0[R](f: Function0[R]): MyFunction0[R] = () => f()
def fromFunction1[R](f: Function0[R]): MyFunction1[R] = () => f()

@main def Test =
val m0: MyFunction0[Int] = fromFunction0(() => 1)
val m1: MyFunction1[Int] = fromFunction1(() => 2)
assert(m0() == 1)
assert(m1() == 2)
Loading