Skip to content

Commit

Permalink
Filter/count erased parameters directly on parameters types
Browse files Browse the repository at this point in the history
We can filter the erased parameters by looking at the `ErasedParamAnnot`.

[Cherry-picked 7c0a848][modified]
  • Loading branch information
WojciechMazur committed Jun 20, 2024
1 parent 5f19c1d commit e0e421e
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

def etaExpandCFT(using Context): Tree =
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
case defn.ContextFunctionType(argTypes, resType, _) =>
case defn.ContextFunctionType(argTypes, resType) =>
val anonFun = newAnonFun(
ctx.owner,
MethodType.companion(isContextual = true)(argTypes, resType),
Expand Down
7 changes: 3 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1881,14 +1881,13 @@ class Definitions {
* types `As`, the result type `B` and a whether the type is an erased context function.
*/
object ContextFunctionType:
def unapply(tp: Type)(using Context): Option[(List[Type], Type, List[Boolean])] =
def unapply(tp: Type)(using Context): Option[(List[Type], Type)] =
asContextFunctionType(tp) match
case PolyFunctionOf(mt: MethodType) =>
Some((mt.paramInfos, mt.resType, mt.erasedParams))
Some((mt.paramInfos, mt.resType))
case tp1 if tp1.exists =>
val args = tp1.functionArgInfos
val erasedParams = List.fill(functionArity(tp1)) { false }
Some((args.init, args.last, erasedParams))
Some((args.init, args.last))
case _ => None

/* Returns a list of erased booleans marking whether parameters are erased, for a function type. */
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/Bridges.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
ctx.typer.typed(untpd.cpy.Apply(ref)(ref, args), member.info.finalResultType)
else
val mtWithoutErasedParams = atPhase(erasurePhase) {
val defn.ContextFunctionType(argTypes, resType, erasedParams) = tp.dealias: @unchecked
val paramInfos = argTypes.zip(erasedParams).collect { case (argType, erased) if !erased => argType }
val defn.ContextFunctionType(argTypes, resType) = tp.dealias: @unchecked
val paramInfos = argTypes.filterNot(_.hasAnnotation(defn.ErasedParamAnnot))
MethodType(paramInfos, resType)
}
val anonFun = newAnonFun(ctx.owner, mtWithoutErasedParams, coord = ctx.owner.coord)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ object ContextFunctionResults:
*/
def annotateContextResults(mdef: DefDef)(using Context): Unit =
def contextResultCount(rhs: Tree, tp: Type): Int = tp match
case defn.ContextFunctionType(_, resTpe, _) =>
case defn.ContextFunctionType(_, resTpe) =>
rhs match
case closureDef(meth) => 1 + contextResultCount(meth.rhs, resTpe)
case _ => 0
Expand Down Expand Up @@ -58,7 +58,8 @@ object ContextFunctionResults:
*/
def contextResultsAreErased(sym: Symbol)(using Context): Boolean =
def allErased(tp: Type): Boolean = tp.dealias match
case defn.ContextFunctionType(_, resTpe, erasedParams) => !erasedParams.contains(false) && allErased(resTpe)
case defn.ContextFunctionType(argTpes, resTpe) =>
argTpes.forall(_.hasAnnotation(defn.ErasedParamAnnot)) && allErased(resTpe)
case _ => true
contextResultCount(sym) > 0 && allErased(sym.info.finalResultType)

Expand All @@ -72,7 +73,7 @@ object ContextFunctionResults:
integrateContextResults(rt, crCount)
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount))
case defn.ContextFunctionType(argTypes, resType, erasedParams) =>
case defn.ContextFunctionType(argTypes, resType) =>
MethodType(argTypes, integrateContextResults(resType, crCount - 1))

/** The total number of parameters of method `sym`, not counting
Expand All @@ -83,9 +84,10 @@ object ContextFunctionResults:
def contextParamCount(tp: Type, crCount: Int): Int =
if crCount == 0 then 0
else
val defn.ContextFunctionType(params, resTpe, erasedParams) = tp: @unchecked
val defn.ContextFunctionType(params, resTpe) = tp: @unchecked
val rest = contextParamCount(resTpe, crCount - 1)
if erasedParams.contains(true) then erasedParams.count(_ == false) + rest else params.length + rest
val nonErasedParams = params.count(!_.hasAnnotation(defn.ErasedParamAnnot))
nonErasedParams + rest

def normalParamCount(tp: Type): Int = tp.widenExpr.stripPoly match
case mt @ MethodType(pnames) =>
Expand All @@ -103,7 +105,7 @@ object ContextFunctionResults:
def recur(tp: Type, n: Int): Type =
if n == 0 then tp
else tp match
case defn.ContextFunctionType(_, resTpe, _) => recur(resTpe, n - 1)
case defn.ContextFunctionType(_, resTpe) => recur(resTpe, n - 1)
recur(meth.info.finalResultType, depth)

/** Should selection `tree` be eliminated since it refers to an `apply`
Expand All @@ -118,7 +120,7 @@ object ContextFunctionResults:
case Select(qual, name) =>
if name == nme.apply then
qual.tpe match
case defn.ContextFunctionType(_, _, _) =>
case defn.ContextFunctionType(_, _) =>
integrateSelect(qual, n + 1)
case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs
integrateSelect(qual, n + 1)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ object ErrorReporting {
val normPt = normalize(pt, pt)

def contextFunctionCount(tp: Type): Int = tp.stripped match
case defn.ContextFunctionType(_, restp, _) => 1 + contextFunctionCount(restp)
case defn.ContextFunctionType(_, restp) => 1 + contextFunctionCount(restp)
case _ => 0
def strippedTpCount = contextFunctionCount(tree.tpe) - contextFunctionCount(normTp)
def strippedPtCount = contextFunctionCount(pt) - contextFunctionCount(normPt)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1885,7 +1885,7 @@ class Namer { typer: Typer =>
val originalTp = defaultParamType
val approxTp = wildApprox(originalTp)
approxTp.stripPoly match
case atp @ defn.ContextFunctionType(_, resType, _)
case atp @ defn.ContextFunctionType(_, resType)
if !defn.isNonRefinedFunction(atp) // in this case `resType` is lying, gives us only the non-dependent upper bound
|| resType.existsPart(_.isInstanceOf[WildcardType], StopAt.Static, forceLazy = false) =>
originalTp
Expand Down

0 comments on commit e0e421e

Please sign in to comment.