From 3366043a2d36da44598f16b714eef941c544ebf6 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 21 Sep 2024 17:05:57 +0900 Subject: [PATCH] fix: `Dict::get` --- .../context/initialize/classes.rs | 11 ++++---- .../context/initialize/const_func.rs | 27 ++++++++++--------- .../erg_compiler/context/initialize/funcs.rs | 5 ++-- crates/erg_compiler/lower.rs | 12 ++++++++- tests/should_err/mut_dict.er | 4 ++- tests/should_err/refinement.er | 3 +++ tests/should_ok/dict.er | 1 + tests/should_ok/refinement.er | 3 +++ tests/test.rs | 2 +- 9 files changed, 45 insertions(+), 23 deletions(-) diff --git a/crates/erg_compiler/context/initialize/classes.rs b/crates/erg_compiler/context/initialize/classes.rs index 5145da490..3800d5102 100644 --- a/crates/erg_compiler/context/initialize/classes.rs +++ b/crates/erg_compiler/context/initialize/classes.rs @@ -2359,14 +2359,13 @@ impl Context { ))); dict_.register_builtin_const(FUNC_AS_RECORD, Visibility::BUILTIN_PUBLIC, None, as_record); let Def = type_q(TY_DEFAULT); + let K = type_q(TY_K); + let V = type_q(TY_V); let get_t = no_var_fn_met( - dict_t.clone(), - vec![kw(KW_KEY, T.clone())], + dict! { K.clone() => V.clone() }.into(), + vec![kw(KW_KEY, K.clone())], vec![kw_default(KW_DEFAULT, Def.clone(), NoneType)], - or( - proj_call(D.clone(), FUNDAMENTAL_GETITEM, vec![ty_tp(T.clone())]), - Def, - ), + or(V.clone(), Def), ) .quantify(); dict_.register_py_builtin(FUNC_GET, get_t, Some(FUNC_GET), 9); diff --git a/crates/erg_compiler/context/initialize/const_func.rs b/crates/erg_compiler/context/initialize/const_func.rs index e5a556d6b..eec8979ee 100644 --- a/crates/erg_compiler/context/initialize/const_func.rs +++ b/crates/erg_compiler/context/initialize/const_func.rs @@ -239,21 +239,21 @@ pub(crate) fn sub_vdict_get<'d>( ) -> Option<&'d ValueObj> { let mut matches = vec![]; for (k, v) in dict.iter() { - match (key, k) { - (ValueObj::Type(idx), ValueObj::Type(kt)) - if ctx.subtype_of(&idx.typ().lower_bounded(), &kt.typ().lower_bounded()) => + if key == k { + return Some(v); + } + match (ctx.convert_value_into_type(key.clone()), ctx.convert_value_into_type(k.clone())) { + (Ok(idx), Ok(kt)) + if ctx.subtype_of(&idx.lower_bounded(), &kt.lower_bounded()) /*|| dict.len() == 1*/ => { matches.push((idx, kt, v)); } - (idx, k) if idx == k => { - return Some(v); - } _ => {} } } for (idx, kt, v) in matches.into_iter() { let list = UndoableLinkedList::new(); - match ctx.undoable_sub_unify(idx.typ(), kt.typ(), &(), &list, None) { + match ctx.undoable_sub_unify(&idx, &kt, &(), &list, None) { Ok(_) => { return Some(v); } @@ -272,21 +272,24 @@ pub(crate) fn sub_tpdict_get<'d>( ) -> Option<&'d TyParam> { let mut matches = vec![]; for (k, v) in dict.iter() { - match (<&Type>::try_from(key), <&Type>::try_from(k)) { + if key == k { + return Some(v); + } + match ( + ctx.convert_tp_into_type(key.clone()), + ctx.convert_tp_into_type(k.clone()), + ) { (Ok(idx), Ok(kt)) if ctx.subtype_of(&idx.lower_bounded(), &kt.lower_bounded()) || dict.len() == 1 => { matches.push((idx, kt, v)); } - (_, _) if key == k => { - return Some(v); - } _ => {} } } for (idx, kt, v) in matches.into_iter() { let list = UndoableLinkedList::new(); - match ctx.undoable_sub_unify(idx, kt, &(), &list, None) { + match ctx.undoable_sub_unify(&idx, &kt, &(), &list, None) { Ok(_) => { return Some(v); } diff --git a/crates/erg_compiler/context/initialize/funcs.rs b/crates/erg_compiler/context/initialize/funcs.rs index 261627406..08bb15e08 100644 --- a/crates/erg_compiler/context/initialize/funcs.rs +++ b/crates/erg_compiler/context/initialize/funcs.rs @@ -1077,8 +1077,9 @@ impl Context { ); let E = mono_q(TY_E, subtypeof(mono(EQ))); let E2 = mono_q(TY_E, subtypeof(mono(IRREGULAR_EQ))); - let op_t = bin_op(E.clone(), E, Bool).quantify() - & bin_op(E2.clone(), E2.clone(), E2.proj(OUTPUT)).quantify(); + let op_t = (bin_op(E.clone(), E, Bool).quantify() + & bin_op(E2.clone(), E2.clone(), E2.proj(OUTPUT)).quantify()) + .with_default_intersec_index(0); self.register_builtin_py_impl( OP_EQ, op_t.clone(), diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 2a2d752f4..978607aa8 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -1221,7 +1221,17 @@ impl GenericASTLowerer { Some(guard(namespace, target, to)) } TokenKind::Symbol if &op.content[..] == "isinstance" => { - let to = self.module.context.expr_to_type(rhs.clone()).ok()?; + // isinstance(x, (T, U)) => x: T or U + let to = if let ast::Expr::Tuple(ast::Tuple::Normal(tys)) = rhs { + tys.elems.pos_args.iter().fold(Type::Never, |acc, ex| { + let Ok(ty) = self.module.context.expr_to_type(ex.expr.clone()) else { + return acc; + }; + self.module.context.union(&acc, &ty) + }) + } else { + self.module.context.expr_to_type(rhs.clone()).ok()? + }; Some(guard(namespace, target, to)) } TokenKind::IsOp | TokenKind::DblEq => { diff --git a/tests/should_err/mut_dict.er b/tests/should_err/mut_dict.er index 461cee210..331970a2e 100644 --- a/tests/should_err/mut_dict.er +++ b/tests/should_err/mut_dict.er @@ -4,4 +4,6 @@ dict = !d dict.insert! "b", 2 _ = dict.get("a") == "a" # ERR _ = dict.get("b") == "a" # ERR -_ = dict.get("c") # ERR +_ = dict.get("c") # OK +_ = dict["b"] # OK +_ = dict["c"] # ERR diff --git a/tests/should_err/refinement.er b/tests/should_err/refinement.er index b116874e2..d1e76bcd1 100644 --- a/tests/should_err/refinement.er +++ b/tests/should_err/refinement.er @@ -13,3 +13,6 @@ _: {I: Int | (I < 5 or I != 3) and I != 4} = 4 # ERR check _: {S: Str | S.replace("abc", "") == ""} = None check "abcd" # ERR + +dic as Dict({{111}: {222}}) = {111: 222} +_ = dic[333] # ERR diff --git a/tests/should_ok/dict.er b/tests/should_ok/dict.er index 8ef36e093..841a24cac 100644 --- a/tests/should_ok/dict.er +++ b/tests/should_ok/dict.er @@ -6,5 +6,6 @@ for! {"a": 1, "b": 2}.values(), i => dic = { "a": 1, "b": 2 } assert dic.concat({ "c": 3 }) == { "a": 1, "b": 2, "c": 3 } assert dic.diff({ "a": 1 }) == { "b": 2 } +assert dic.get("a"+"b", 3) == 3 rec = dic.as_record() assert rec.a == 1 and rec.b == 2 diff --git a/tests/should_ok/refinement.er b/tests/should_ok/refinement.er index 3a2dd6e15..141d700e5 100644 --- a/tests/should_ok/refinement.er +++ b/tests/should_ok/refinement.er @@ -6,3 +6,6 @@ _: {I: Int | I < 5 or I != 3 and I != 4} = 4 check _: {S: Str | S.replace("abc", "") == ""} = None check "abc" + +dic as Dict({{111}: {222}}) = {111: 222} +_: {222} = dic[111] diff --git a/tests/test.rs b/tests/test.rs index e2704fa80..2e6c5bf90 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -766,7 +766,7 @@ fn exec_recursive_fn_err() -> Result<(), ()> { #[test] fn exec_refinement_err() -> Result<(), ()> { - expect_failure("tests/should_err/refinement.er", 0, 9) + expect_failure("tests/should_err/refinement.er", 0, 10) } #[test]