Skip to content

Commit

Permalink
chore: improve method completion
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Dec 6, 2023
1 parent fb0248f commit 2d54a39
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 44 deletions.
15 changes: 5 additions & 10 deletions crates/els/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,10 @@ impl<Checker: BuildRunnable, Parser: Parsable> Server<Checker, Parser> {
self.send_log(format!("CompletionKind: {comp_kind:?}"))?;
let mut result: Vec<CompletionItem> = vec![];
let mut already_appeared = Set::new();
let contexts = if comp_kind.should_be_local() {
self.get_local_ctx(&uri, pos)
let (receiver_t, contexts) = if comp_kind.should_be_local() {
(None, self.get_local_ctx(&uri, pos))
} else {
self.get_receiver_ctxs(&uri, pos)?
self.get_receiver_and_ctxs(&uri, pos)?
};
let offset = match comp_kind {
CompletionKind::RetriggerLocal => 0,
Expand Down Expand Up @@ -586,11 +586,6 @@ impl<Checker: BuildRunnable, Parser: Parsable> Server<Checker, Parser> {
}
_ => None,
});
let receiver_t = comp_kind
.should_be_method()
.then(|| self.get_min_expr(&uri, pos, -2))
.flatten()
.map(|(_, expr)| expr.t());
let Some(mod_ctx) = self.get_mod_ctx(&uri) else {
return Ok(None);
};
Expand All @@ -600,10 +595,10 @@ impl<Checker: BuildRunnable, Parser: Parsable> Server<Checker, Parser> {
}
// only show static methods, if the receiver is a type
if vi.t.is_method()
&& receiver_t.as_ref().map_or(true, |t| {
&& receiver_t.as_ref().map_or(true, |receiver| {
!mod_ctx
.context
.subtype_of(t, vi.t.self_t().unwrap_or(Type::OBJ))
.subtype_of(receiver, vi.t.self_t().unwrap_or(Type::OBJ))
})
{
continue;
Expand Down
22 changes: 20 additions & 2 deletions crates/els/file_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use erg_common::set::Set;
use erg_common::shared::Shared;
use erg_common::traits::DequeStream;
use erg_compiler::erg_parser::lex::Lexer;
use erg_compiler::erg_parser::token::{Token, TokenCategory, TokenStream};
use erg_compiler::erg_parser::token::{Token, TokenCategory, TokenKind, TokenStream};

use crate::_log;
use crate::server::{ELSResult, RedirectableStdout};
Expand Down Expand Up @@ -176,12 +176,25 @@ impl FileCache {
/// a{pos}\n -> \n -> a
pub fn get_symbol(&self, uri: &NormalizedUrl, pos: Position) -> Option<Token> {
let mut token = self.get_token(uri, pos)?;
let mut offset = 0;
while !matches!(token.category(), TokenCategory::Symbol) {
token = self.get_token_relatively(uri, pos, -1)?;
offset -= 1;
token = self.get_token_relatively(uri, pos, offset)?;
}
Some(token)
}

pub fn get_receiver(&self, uri: &NormalizedUrl, attr_marker_pos: Position) -> Option<Token> {
let mut token = self.get_token(uri, attr_marker_pos)?;
let mut offset = 0;
while !matches!(token.kind, TokenKind::Dot | TokenKind::DblColon) {
offset -= 1;
token = self.get_token_relatively(uri, attr_marker_pos, offset)?;
}
offset -= 1;
self.get_token_relatively(uri, attr_marker_pos, offset)
}

pub fn get_token_relatively(
&self,
uri: &NormalizedUrl,
Expand All @@ -200,6 +213,11 @@ impl FileCache {
return Some(i);
}
}
for (i, tok) in tokens.iter().enumerate() {
if util::roughly_pos_in_loc(tok, pos) {
return Some(i);
}
}
None
})()?;
let index = (index as isize + offset) as usize;
Expand Down
13 changes: 12 additions & 1 deletion crates/els/hir_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ impl<'a> HIRVisitor<'a> {
Expr::Record(record) => self.get_expr_from_record(expr, record, pos),
Expr::Set(set) => self.get_expr_from_set(expr, set, pos),
Expr::Tuple(tuple) => self.get_expr_from_tuple(expr, tuple, pos),
Expr::TypeAsc(type_asc) => self.get_expr(&type_asc.expr, pos),
Expr::TypeAsc(type_asc) => self.get_expr_from_type_asc(expr, type_asc, pos),
Expr::Dummy(dummy) => self.get_expr_from_dummy(dummy, pos),
Expr::Compound(block) | Expr::Code(block) => {
self.get_expr_from_block(block.iter(), pos)
Expand Down Expand Up @@ -534,6 +534,17 @@ impl<'a> HIRVisitor<'a> {
} // _ => None, // todo!(),
}
}

fn get_expr_from_type_asc<'e>(
&'e self,
expr: &'e Expr,
type_asc: &'e TypeAscription,
pos: Position,
) -> Option<&Expr> {
self.get_expr(&type_asc.expr, pos)
.or_else(|| self.get_expr(&type_asc.spec.expr, pos))
.or_else(|| self.return_expr_if_contains(expr, pos, type_asc))
}
}

impl<'a> HIRVisitor<'a> {
Expand Down
26 changes: 14 additions & 12 deletions crates/els/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use erg_compiler::error::CompileWarning;
use erg_compiler::hir::HIR;
use erg_compiler::lower::ASTLowerer;
use erg_compiler::module::{IRs, ModuleEntry, SharedCompilerResource};
use erg_compiler::ty::HasType;
use erg_compiler::ty::{HasType, Type};

pub use molc::RedirectableStdout;
use molc::{FakeClient, LangServer};
Expand Down Expand Up @@ -873,21 +873,19 @@ impl<Checker: BuildRunnable, Parser: Parsable> Server<Checker, Parser> {
ctxs
}

pub(crate) fn get_receiver_ctxs(
pub(crate) fn get_receiver_and_ctxs(
&self,
uri: &NormalizedUrl,
attr_marker_pos: Position,
) -> ELSResult<Vec<&Context>> {
) -> ELSResult<(Option<Type>, Vec<&Context>)> {
let Some(module) = self.raw_get_mod_ctx(uri) else {
return Ok(vec![]);
return Ok((None, vec![]));
};
let maybe_token = self
.file_cache
.get_token_relatively(uri, attr_marker_pos, -2);
let maybe_token = self.file_cache.get_receiver(uri, attr_marker_pos);
if let Some(token) = maybe_token {
// self.send_log(format!("token: {token}"))?;
// _log!(self, "token: {token}");
let mut ctxs = vec![];
if let Some(visitor) = self.get_visitor(uri) {
let expr = if let Some(visitor) = self.get_visitor(uri) {
if let Some(expr) =
loc_to_pos(token.loc()).and_then(|pos| visitor.get_min_expr(pos))
{
Expand All @@ -902,14 +900,18 @@ impl<Checker: BuildRunnable, Parser: Parsable> Server<Checker, Parser> {
{
ctxs.extend(singular_ctxs);
}
Some(expr.t())
} else {
_log!(self, "expr not found: {token}");
None
}
}
Ok(ctxs)
} else {
None
};
Ok((expr, ctxs))
} else {
self.send_log("token not found")?;
Ok(vec![])
Ok((None, vec![]))
}
}

Expand Down
36 changes: 36 additions & 0 deletions crates/els/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const FILE_C: &str = "tests/c.er";
const FILE_IMPORTS: &str = "tests/imports.er";
const FILE_INVALID_SYNTAX: &str = "tests/invalid_syntax.er";
const FILE_RETRIGGER: &str = "tests/retrigger.er";
const FILE_TOLERANT_COMPLETION: &str = "tests/tolerant_completion.er";

use els::{NormalizedUrl, Server};
use erg_proc_macros::exec_new_thread;
Expand Down Expand Up @@ -129,6 +130,41 @@ fn test_completion_retrigger() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

#[test]
fn test_tolerant_completion() -> Result<(), Box<dyn std::error::Error>> {
let mut client = Server::bind_fake_client();
client.request_initialize()?;
client.notify_initialized()?;
let uri = NormalizedUrl::from_file_path(Path::new(FILE_TOLERANT_COMPLETION).canonicalize()?)?;
client.notify_open(FILE_TOLERANT_COMPLETION)?;
let _ = client.wait_diagnostics()?;
client.notify_change(uri.clone().raw(), add_char(2, 9, "."))?;
let resp = client.request_completion(uri.clone().raw(), 2, 10, ".")?;
if let Some(CompletionResponse::Array(items)) = resp {
assert!(items.len() >= 10);
assert!(items.iter().any(|item| item.label == "tqdm"));
} else {
return Err(format!("not items: {resp:?}").into());
}
client.notify_change(uri.clone().raw(), add_char(5, 16, "."))?;
let resp = client.request_completion(uri.clone().raw(), 5, 17, ".")?;
if let Some(CompletionResponse::Array(items)) = resp {
assert!(items.len() >= 40);
assert!(items.iter().any(|item| item.label == "capitalize"));
} else {
return Err(format!("not items: {resp:?}").into());
}
client.notify_change(uri.clone().raw(), add_char(5, 14, "."))?;
let resp = client.request_completion(uri.raw(), 5, 15, ".")?;
if let Some(CompletionResponse::Array(items)) = resp {
assert!(items.len() >= 40);
assert!(items.iter().any(|item| item.label == "abs"));
Ok(())
} else {
Err(format!("not items: {resp:?}").into())
}
}

#[test]
#[exec_new_thread]
fn test_rename() -> Result<(), Box<dyn std::error::Error>> {
Expand Down
6 changes: 6 additions & 0 deletions crates/els/tests/tolerant_completion.er
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
tqdm = pyimport "tqdm"

f _: tqdm
i = 1
s = "a"
g() = None + i s i
2 changes: 1 addition & 1 deletion crates/erg_compiler/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
.map(|op| op.is_import())
.unwrap_or(false) =>
{
Ok(hir::Expr::Call(self.lower_call(call, None)?))
Ok(hir::Expr::Call(self.lower_call(call, None)))
}
other => Err(LowerErrors::from(LowerError::declare_error(
self.cfg().input.clone(),
Expand Down
43 changes: 25 additions & 18 deletions crates/erg_compiler/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1267,11 +1267,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {

/// returning `Ok(call)` does not mean the call is valid, just means it is syntactically valid
/// `ASTLowerer` is designed to cause as little information loss in HIR as possible
pub(crate) fn lower_call(
&mut self,
call: ast::Call,
expect: Option<&Type>,
) -> LowerResult<hir::Call> {
pub(crate) fn lower_call(&mut self, call: ast::Call, expect: Option<&Type>) -> hir::Call {
log!(info "entered {}({}{}(...))", fn_name!(), call.obj, fmt_option!(call.attr_name));
let pushed = if let (Some(name), None) = (call.obj.get_name(), &call.attr_name) {
self.module.context.higher_order_caller.push(name.clone());
Expand All @@ -1288,7 +1284,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
self.module.context.higher_order_caller.pop();
}
errs.extend(es);
return Err(errs);
hir::Expr::Dummy(hir::Dummy::empty())
}
};
let opt_vi = self.module.context.get_call_t_without_args(
Expand Down Expand Up @@ -1375,10 +1371,12 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
self.module.context.higher_order_caller.pop();
}
if errs.is_empty() {
self.exec_additional_op(&mut call)?;
if let Err(es) = self.exec_additional_op(&mut call) {
errs.extend(es);
}
}
self.errs.extend(errs);
Ok(call)
call
}

/// importing is done in [preregister](https://github.com/erg-lang/erg/blob/ffd33015d540ff5a0b853b28c01370e46e0fcc52/crates/erg_compiler/context/register.rs#L819)
Expand Down Expand Up @@ -2676,14 +2674,27 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
) -> LowerResult<hir::TypeAscription> {
log!(info "entered {}({tasc})", fn_name!());
let kind = tasc.kind();
let spec_t = self
let spec_t = match self
.module
.context
.instantiate_typespec(&tasc.t_spec.t_spec)?;
.instantiate_typespec(&tasc.t_spec.t_spec)
{
Ok(spec_t) => spec_t,
Err(errs) => {
self.errs.extend(errs);
Type::Failure
}
};
let expect = expect.map_or(Some(&spec_t), |exp| {
self.module.context.min(exp, &spec_t).ok().or(Some(&spec_t))
});
let expr = self.lower_expr(*tasc.expr, expect)?;
let expr = match self.lower_expr(*tasc.expr, expect) {
Ok(expr) => expr,
Err(errs) => {
self.errs.extend(errs);
hir::Expr::Dummy(hir::Dummy::new(vec![]))
}
};
match kind {
AscriptionKind::TypeOf | AscriptionKind::AsCast => {
self.module.context.sub_unify(
Expand Down Expand Up @@ -2825,7 +2836,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
ast::Expr::Accessor(acc) => hir::Expr::Accessor(self.lower_acc(acc, expect)?),
ast::Expr::BinOp(bin) => hir::Expr::BinOp(self.lower_bin(bin, expect)),
ast::Expr::UnaryOp(unary) => hir::Expr::UnaryOp(self.lower_unary(unary, expect)),
ast::Expr::Call(call) => hir::Expr::Call(self.lower_call(call, expect)?),
ast::Expr::Call(call) => hir::Expr::Call(self.lower_call(call, expect)),
ast::Expr::DataPack(pack) => hir::Expr::Call(self.lower_pack(pack, expect)?),
ast::Expr::Lambda(lambda) => hir::Expr::Lambda(self.lower_lambda(lambda, expect)?),
ast::Expr::TypeAscription(tasc) => {
Expand All @@ -2835,7 +2846,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
// Checking is also performed for expressions in Dummy. However, it has no meaning in code generation
ast::Expr::Dummy(dummy) => hir::Expr::Dummy(self.lower_dummy(dummy, expect)?),
ast::Expr::InlineModule(inline) => {
hir::Expr::Call(self.lower_inline_module(inline, expect)?)
hir::Expr::Call(self.lower_inline_module(inline, expect))
}
other => {
log!(err "unreachable: {other}");
Expand Down Expand Up @@ -2926,11 +2937,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
Ok(hir::Dummy::new(hir_dummy))
}

fn lower_inline_module(
&mut self,
inline: InlineModule,
expect: Option<&Type>,
) -> LowerResult<hir::Call> {
fn lower_inline_module(&mut self, inline: InlineModule, expect: Option<&Type>) -> hir::Call {
log!(info "entered {}", fn_name!());
let Some(ast::Expr::Literal(mod_name)) = inline.import.args.get_left_or_key("Path") else {
unreachable!();
Expand Down

0 comments on commit 2d54a39

Please sign in to comment.