Skip to content

Commit

Permalink
Merge pull request #339 from sjrd/lambda-sam-class-symbol
Browse files Browse the repository at this point in the history
Add Lambda.samClassSymbol to get the class symbol of the SAM type.
  • Loading branch information
sjrd authored Aug 7, 2023
2 parents 713d2f8 + bdfc947 commit 9dd543d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
15 changes: 14 additions & 1 deletion tasty-query/shared/src/main/scala/tastyquery/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ object Trees {
* @param tpt Defined only if the lambda's type is a SAMtype rather than a function type.
*/
final case class Lambda(meth: TermTree, tpt: Option[TypeTree])(span: Span) extends TermTree(span) {
protected final def calculateType(using Context): TermType = tpt match
protected final def calculateType(using Context): Type = tpt match
case Some(tpt) =>
tpt.toType

Expand All @@ -548,6 +548,19 @@ object Trees {
else functionNTypeRef.appliedTo(methodType.paramTypes :+ methodType.resultType.asInstanceOf[Type])
end calculateType

/** The class symbol of the SAM type of this lambda.
*
* A `Lambda` can be considered as an anonymous class of the form `new tpt { ... }`.
* Given that observation, `samClassSymbol` represents the `parentClasses.head` of that
* hypothetical anonymous class.
*
* When `tpt` is `None`, `samClassSymbol` will be one of the `scala.FunctionN` classes.
*/
def samClassSymbol(using Context): ClassSymbol =
tpe.requireType.classSymbol.getOrElse {
throw InvalidProgramStructureException(s"Non-class type $tpe for SAM type of $this")
}

override final def withSpan(span: Span): Lambda = Lambda(meth, tpt)(span)
}

Expand Down
8 changes: 8 additions & 0 deletions tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1295,16 +1295,23 @@ class TypeSuite extends UnrestrictedUnpicklingSuite {
def getRhsOf(name: String): TermTree =
FunctionClass.findNonOverloadedDecl(termName(name)).tree.get.asInstanceOf[ValDef].rhs.get

def findLambdaOf(tree: TermTree): Lambda = tree match
case tree: Lambda => tree
case Block(_, expr) => findLambdaOf(expr)
case _ => fail("lambda expected", clues(tree))

// val functionLambda = (x: Int) => x + 1
val functionLambda = getRhsOf("functionLambda")
assert(
clue(functionLambda.tpe)
.isApplied(_.isRef(defn.FunctionNClass(1)), List(_.isRef(defn.IntClass), _.isRef(defn.IntClass)))
)
assert(clue(findLambdaOf(functionLambda).samClassSymbol) == defn.FunctionNClass(1))

// val samLambda: Runnable = () => ()
val samLambda = getRhsOf("samLambda")
assert(clue(samLambda.tpe).isRef(ctx.findTopLevelClass("java.lang.Runnable")))
assert(clue(findLambdaOf(samLambda).samClassSymbol) == ctx.findTopLevelClass("java.lang.Runnable"))

// val polyID: [T] => T => T = [T] => (x: T) => x
val polyID = getRhsOf("polyID")
Expand Down Expand Up @@ -1346,6 +1353,7 @@ class TypeSuite extends UnrestrictedUnpicklingSuite {
case dependentIDTpe =>
fail(s"unexpected dependentID type: $dependentIDTpe")
end match
assert(clue(findLambdaOf(dependentID).samClassSymbol) == defn.FunctionNClass(1))
}

testWithContext("varargs") {
Expand Down

0 comments on commit 9dd543d

Please sign in to comment.