Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: emit as extension method if member of implicit class with val in constructor #6190

Merged
merged 3 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ class ScalaToplevelMtags(
case _ => region
}

private def isInParenthesis(region: Region): Boolean =
region match {
case (_: Region.InParenCaseClass) | (_: Region.InParenClass) =>
true
case _ => false
}

@tailrec
private def loop(
indent: Int,
Expand Down Expand Up @@ -290,12 +297,23 @@ class ScalaToplevelMtags(
)
case DEF | VAL | VAR | GIVEN
if expectTemplate.map(!_.isExtension).getOrElse(true) =>
val isImplicit =
if (isInParenthesis(region))
expectTemplate.exists(
_.isImplicit
)
else region.isImplicit
if (needEmitTermMember()) {
withOwner(currRegion.termOwner) {
emitTerm(currRegion)
emitTerm(currRegion, isImplicit)
}
} else scanner.nextToken()
loop(indent, isAfterNewline = false, currRegion, newExpectIgnoreBody)
loop(
indent,
isAfterNewline = false,
currRegion,
if (isInParenthesis(region)) expectTemplate else newExpectIgnoreBody
)
case TYPE if expectTemplate.map(!_.isExtension).getOrElse(true) =>
if (needEmitMember(currRegion) && !prevWasDot) {
withOwner(currRegion.termOwner) {
Expand Down Expand Up @@ -408,15 +426,24 @@ class ScalaToplevelMtags(
expectTemplate match {
case Some(expect)
if needToParseBody(expect) || needToParseExtension(expect) =>
val next =
expect.startInBraceRegion(
currRegion,
expect.isExtension,
expect.isImplicit
)
resetRegion(next)
scanner.nextToken()
loop(indent, isAfterNewline = false, next, None)
if (isInParenthesis(region)) {
// inside of a class constructor
// e.g. class A(val foo: Foo { type T = Int })
// ^
acceptBalancedDelimeters(LBRACE, RBRACE)
scanner.nextToken()
loop(indent, isAfterNewline = false, currRegion, expectTemplate)
} else {
val next =
expect.startInBraceRegion(
currRegion,
expect.isExtension,
expect.isImplicit
)
resetRegion(next)
scanner.nextToken()
loop(indent, isAfterNewline = false, next, None)
}
case _ =>
acceptBalancedDelimeters(LBRACE, RBRACE)
scanner.nextToken()
Expand Down Expand Up @@ -716,7 +743,8 @@ class ScalaToplevelMtags(
/**
* Enters a global element (def/val/var/given)
*/
def emitTerm(region: Region): Unit = {
def emitTerm(region: Region, isParentImplicit: Boolean): Unit = {
val extensionProperty = if (isParentImplicit) EXTENSION else 0
val kind = scanner.curr.token
acceptTrivia()
kind match {
Expand All @@ -726,7 +754,7 @@ class ScalaToplevelMtags(
name.name,
name.pos,
Kind.METHOD,
SymbolInformation.Property.VAL.value
SymbolInformation.Property.VAL.value | extensionProperty
)
resetRegion(region)
})
Expand All @@ -736,7 +764,7 @@ class ScalaToplevelMtags(
name.name,
"()",
name.pos,
SymbolInformation.Property.VAR.value
SymbolInformation.Property.VAR.value | extensionProperty
)
resetRegion(region)
})
Expand All @@ -746,7 +774,7 @@ class ScalaToplevelMtags(
name.name,
region.overloads.disambiguator(name.name),
name.pos,
0
extensionProperty
)
)
case GIVEN =>
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/src/test/scala/tests/Example.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package tests

package a
object O {
trait Foo {
type T
}

implicit class A(val foo: Foo { type T = Int }) {
def get: Int = 1
}
}
59 changes: 49 additions & 10 deletions tests/unit/src/test/scala/tests/ScalaToplevelSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ class ScalaToplevelSuite extends BaseSuite {
List(
"a/",
"a/A.",
"a/A.bar().",
"a/A.foo().",
"a/A.bar(). EXT",
"a/A.foo(). EXT",
),
mode = All,
dialect = dialects.Scala3,
Expand All @@ -370,8 +370,8 @@ class ScalaToplevelSuite extends BaseSuite {
List(
"a/",
"a/Test$package.",
"a/Test$package.bar().",
"a/Test$package.foo().",
"a/Test$package.bar(). EXT",
"a/Test$package.foo(). EXT",
),
mode = All,
dialect = dialects.Scala3,
Expand All @@ -387,8 +387,8 @@ class ScalaToplevelSuite extends BaseSuite {
| def baz: Long = ???
|""".stripMargin,
List(
"a/", "a/Test$package.", "a/Test$package.foo().", "a/Test$package.bar().",
"a/Test$package.baz().",
"a/", "a/Test$package.", "a/Test$package.foo(). EXT",
"a/Test$package.bar(). EXT", "a/Test$package.baz(). EXT",
),
mode = All,
dialect = dialects.Scala3,
Expand Down Expand Up @@ -655,6 +655,42 @@ class ScalaToplevelSuite extends BaseSuite {
mode = All,
)

check(
"refined-type",
"""|package a
|object O {
| trait Foo {
| type T
| }
|
| implicit class A(val foo: Foo { type T = Int }) {
| def get: Int = 1
| }
|}
|""".stripMargin,
List(
"a/", "a/O.", "a/O.A#", "a/O.A#foo. EXT", "a/O.A#get(). EXT", "a/O.Foo#",
"a/O.Foo#T#",
),
mode = All,
)

check(
"implicit-class-with-val",
"""|package a
|object Foo {
| implicit class IntOps(private val i: Int) extends AnyVal {
| def inc: Int = i + 1
| }
|}
|""".stripMargin,
List(
"a/", "a/Foo.", "a/Foo.IntOps# -> AnyVal", "a/Foo.IntOps#i. EXT",
"a/Foo.IntOps#inc(). EXT",
),
mode = All,
)

def check(
options: TestOptions,
code: String,
Expand All @@ -672,11 +708,14 @@ class ScalaToplevelSuite extends BaseSuite {
val includeMembers = mode == All
val (doc, overrides) =
Mtags.indexWithOverrides(input, dialect, includeMembers)
val symbols = doc.occurrences.map(_.symbol).toList
// additionalSymbolCheck(doc.symbols)
// val symbols = doc.symbols.map(_.symbol).toList
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftover?

val overriddenMap = overrides.toMap
symbols.map { symbol =>
doc.symbols.map { symbolInfo =>
val symbol = symbolInfo.symbol
val suffix = if (symbolInfo.isExtension) " EXT" else ""
overriddenMap.get(symbol) match {
case None => symbol
case None => s"$symbol$suffix"
case Some(symbols) =>
val overridden =
symbols
Expand All @@ -685,7 +724,7 @@ class ScalaToplevelSuite extends BaseSuite {
case UnresolvedOverriddenSymbol(name) => name
}
.mkString(", ")
s"$symbol -> $overridden"
s"$symbol$suffix -> $overridden"
}
}
case Toplevel => Mtags.topLevelSymbols(input, dialect)
Expand Down
Loading