Skip to content

Commit

Permalink
feat: add value completions for union types
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed Jun 13, 2024
1 parent e46edc9 commit cb7312b
Show file tree
Hide file tree
Showing 5 changed files with 401 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -402,4 +402,10 @@ trait CommonMtagsEnrichments {
}
}

implicit class CommonXtensionList[T](lst: List[T]) {
def get(i: Int): Option[T] =
if (i >= 0 && i < lst.length) Some(lst(i))
else None
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,16 @@ object CompletionValue:
description
override def insertMode: Option[InsertTextMode] = Some(InsertTextMode.AsIs)

case class SingletonValue(label: String, info: Type, override val range: Option[Range])
extends CompletionValue:
override def insertText: Option[String] = Some(label)
override def labelWithDescription(printer: MetalsPrinter)(using Context): String =
s"$label: ${printer.tpe(info)}"

override def completionItemKind(using Context): CompletionItemKind =
CompletionItemKind.Constant


def namedArg(label: String, sym: ParamSymbol)(using
Context
): CompletionValue =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class Completions(
val ScalaCliCompletions =
new ScalaCliCompletions(coursierComplete, pos, text)

path match
val (advanced, exclusive) = path match
case ScalaCliCompletions(dependency) =>
(ScalaCliCompletions.contribute(dependency), true)

Expand Down Expand Up @@ -533,7 +533,10 @@ class Completions(
config.isCompletionSnippetsEnabled,
)
(args, false)
end match
val singletonCompletions = InterCompletionType.inferType(path).map(
SingletonCompletions.contribute(path, _, completionPos)
).getOrElse(Nil)
(singletonCompletions ++ advanced, exclusive)
end advancedCompletions

private def isAmmoniteCompletionPosition(
Expand Down Expand Up @@ -696,6 +699,7 @@ class Completions(
case fileSysMember: CompletionValue.FileSystemMember =>
(fileSysMember.label, true)
case ii: CompletionValue.IvyImport => (ii.label, true)
case sv: CompletionValue.SingletonValue => (sv.label, true)

if !isSeen(id) && include then
isSeen += id
Expand Down Expand Up @@ -904,37 +908,18 @@ class Completions(
else 2
},
)

/**
* This one is used for the following case:
* ```scala
* def foo(argument: Int): Int = ???
* val argument = 42
* foo(arg@@) // completions should be ordered as :
* // - argument (local val) - actual value comes first
* // - argument = ... (named arg) - named arg after
* // - ... all other options
* ```
*/
def compareInApplyParams(o1: CompletionValue, o2: CompletionValue): Int =
def prioritizeByClass(o1: CompletionValue, o2: CompletionValue): Int =
def priority(v: CompletionValue): Int =
v match
case _: CompletionValue.Compiler => 0
case _ => 1
case _: CompletionValue.SingletonValue => 0
case _: CompletionValue.Compiler => 1
case _: CompletionValue.CaseKeyword => 2
case _: CompletionValue.NamedArg => 3
case _: CompletionValue.Keyword => 4
case _ => 5

priority(o1) - priority(o2)
end compareInApplyParams

def prioritizeKeywords(o1: CompletionValue, o2: CompletionValue): Int =
def priority(v: CompletionValue): Int =
v match
case _: CompletionValue.CaseKeyword => 0
case _: CompletionValue.NamedArg => 1
case _: CompletionValue.Keyword => 2
case _ => 3

priority(o1) - priority(o2)
end prioritizeKeywords
end prioritizeByClass

/**
* Some completion values should be shown first such as CaseKeyword and
Expand Down Expand Up @@ -1010,12 +995,9 @@ class Completions(
end if
end if
case _ =>
val byApplyParams = compareInApplyParams(o1, o2)
if byApplyParams != 0 then byApplyParams
else
val keywords = prioritizeKeywords(o1, o2)
if keywords != 0 then keywords
else compareByRelevance(o1, o2)
val byClass = prioritizeByClass(o1, o2)
if byClass != 0 then byClass
else compareByRelevance(o1, o2)
end compare

end Completions
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package scala.meta.internal.pc.completions

import scala.meta.internal.metals.Fuzzy
import scala.meta.internal.mtags.MtagsEnrichments.*
import scala.meta.internal.mtags.MtagsEnrichments.metalsDealias
import scala.meta.internal.pc.completions.CompletionValue.SingletonValue

import dotty.tools.dotc.ast.tpd.*
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Flags
import dotty.tools.dotc.core.StdNames
import dotty.tools.dotc.core.Types.AndType
import dotty.tools.dotc.core.Types.AppliedType
import dotty.tools.dotc.core.Types.ConstantType
import dotty.tools.dotc.core.Types.OrType
import dotty.tools.dotc.core.Types.Type
import dotty.tools.dotc.core.Types.TypeRef
import dotty.tools.dotc.util.Spans.Span

object SingletonCompletions:
def contribute(
path: List[Tree],
tpe: Type,
completionPos: CompletionPos
)(using ctx: Context): List[CompletionValue] =
for
(name, span) <-
path match
case (i @ Ident(name)) :: _ => List(name.toString() -> i.span)
case (l @ Literal(const)) :: _ => List(const.show -> l.span)
case _ => Nil
query = name.replace(Cursor.value, "")
singletonValues = collectSingletons(tpe).map(_.show)
range = completionPos.cursorPos.withStart(span.start).withEnd(span.start + query.length).toLsp
value <- singletonValues.collect:
case name if Fuzzy.matches(query, name) =>
SingletonValue(name, tpe, Some(range))
yield value

private def collectSingletons(tpe: Type)(using Context): List[Constant] =
tpe.metalsDealias match
case ConstantType(value) => List(value)
case OrType(tpe1, tpe2) =>
collectSingletons(tpe1) ++ collectSingletons(tpe2)
case AndType(tpe1, tpe2) =>
collectSingletons(tpe1).intersect(collectSingletons(tpe2))
case _ => Nil

object InterCompletionType:
def inferType(path: List[Tree])(using Context): Option[Type] =
path match
case (lit: Literal) :: Select(Literal(_), _) :: Apply(Select(Literal(_), _), List(Literal(Constant(null)))) :: rest => inferType(rest, lit.span)
case ident :: rest => inferType(rest, ident.span)
case _ => None

def inferType(path: List[Tree], span: Span)(using Context): Option[Type] =
path match
// List(@@)
case SeqLiteral(_, tpe) :: _ if !tpe.tpe.isErroneous => Some(tpe.tpe)
case Block(_, expr) :: rest if expr.span.contains(span) =>
inferType(rest, span)
case If(cond, _, _) :: rest if !cond.span.contains(span) =>
inferType(rest, span)
case (defn: ValOrDefDef) :: rest if !defn.tpt.tpe.isErroneous => Some(defn.tpt.tpe)
case CaseDef(_, _, body) :: Match(_, cases) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) =>
inferType(rest, span)
// f(@@)
case (app: Apply) :: rest =>
val param =
for
ind <- app.args.zipWithIndex.collectFirst:
case (arg, id) if arg.span.contains(span) => id
params <- app.symbol.paramSymss.find(!_.exists(_.isTypeParam))
param <- params.get(ind)
yield param.info
param match
// def f[T](a: T): T = ???
// f[Int](@@)
// val _: Int = f(@@)
case Some(t : TypeRef) if t.symbol.is(Flags.TypeParam) =>
for
(typeParams, args) <-
app match
case Apply(TypeApply(fun, args), _) =>
val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam))
typeParams.map((_, args.map(_.tpe)))
// val f: (j: "a") => Int
// f(@@)
case Apply(Select(v, StdNames.nme.apply), _) =>
v.symbol.info match
case AppliedType(des, args) =>
Some((des.typeSymbol.typeParams, args))
case _ => None
case _ => None
ind = typeParams.indexOf(t.symbol)
tpe <- args.get(ind)
if !tpe.isErroneous
yield tpe
case Some(tpe) => Some(tpe)
case _ => None
case _ => None

Loading

0 comments on commit cb7312b

Please sign in to comment.