diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 9c8879f8f..7bc17bf7c 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -7,7 +7,7 @@ use erg_common::dict::Dict; use erg_common::set::Set; use erg_common::style::colors::DEBUG_ERROR; use erg_common::traits::StructuralEq; -use erg_common::{assume_unreachable, log, set_recursion_limit}; +use erg_common::{assume_unreachable, log, set, set_recursion_limit}; use erg_common::{Str, Triple}; use crate::context::eval::UndoableLinkedList; @@ -1457,6 +1457,7 @@ impl Context { /// union(List(Int, 2), List(Str, 2)) == List(Int or Str, 2) /// union(List(Int, 2), List(Str, 3)) == List(Int, 2) or List(Int, 3) /// union(List(Int, 2), List(Int, ?N)) == List(Int, ?N) + /// union(List(Int, 2), List(?T, 2)) == List(Int or ?T, 2) /// union({ .a = Int }, { .a = Str }) == { .a = Int or Str } /// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int } or { .a = Int; .b = Int } # not to lost `b` information /// union((A and B) or C) == (A or C) and (B or C) @@ -1566,6 +1567,8 @@ impl Context { /// union_tp(1, 1) => Some(1) /// union_tp(1, 2) => None /// union_tp(?N, 2) => Some(2) # REVIEW: + /// union_tp(_: Obj, 1) == Some(_: Obj) + /// union_tp(_: Nat, -1) == Some(_: Int) /// ``` pub(crate) fn union_tp(&self, lhs: &TyParam, rhs: &TyParam) -> Option { match (lhs, rhs) { @@ -1590,6 +1593,40 @@ impl Context { } Some(TyParam::List(tps)) } + (TyParam::Tuple(l), TyParam::Tuple(r)) => { + let mut tps = vec![]; + for (l, r) in l.iter().zip(r.iter()) { + if let Some(tp) = self.union_tp(l, r) { + tps.push(tp); + } else { + return None; + } + } + Some(TyParam::Tuple(tps)) + } + (TyParam::UnsizedList(l), TyParam::UnsizedList(r)) => { + Some(TyParam::unsized_list(self.union_tp(l, r)?)) + } + (TyParam::Set(l), TyParam::Set(r)) if l.len() == 1 && r.len() == 1 => { + let l = l.iter().next().unwrap(); + let r = r.iter().next().unwrap(); + Some(TyParam::Set(set! { self.union_tp(l, r)? })) + } + (TyParam::Record(l), TyParam::Record(r)) if l.len() == 1 && r.len() == 1 => { + let mut tps = Dict::new(); + for (l_k, l_v) in l.iter() { + if let Some(r_v) = r.get(l_k) { + if let Some(tp) = self.union_tp(l_v, r_v) { + tps.insert(l_k.clone(), tp); + } else { + return None; + } + } else { + return None; + } + } + Some(TyParam::Record(tps)) + } (fv @ TyParam::FreeVar(f), other) | (other, fv @ TyParam::FreeVar(f)) if f.is_unbound() => { @@ -1601,6 +1638,31 @@ impl Context { None } } + (TyParam::Erased(t), other) | (other, TyParam::Erased(t)) => { + let other_t = self.get_tp_t(other).ok()?.derefine(); + Some(TyParam::erased(self.union(t, &other_t))) + } + ( + TyParam::App { + name: ln, + args: las, + }, + TyParam::App { + name: rn, + args: ras, + }, + ) if ln == rn => { + debug_assert_eq!(las.len(), ras.len()); + let mut unified_args = vec![]; + for (lp, rp) in las.iter().zip(ras.iter()) { + if let Some(union) = self.union_tp(lp, rp) { + unified_args.push(union); + } else { + return None; + } + } + Some(TyParam::app(ln.clone(), unified_args)) + } (_, _) => { if self.eq_tp(lhs, rhs) { Some(lhs.clone()) @@ -1811,6 +1873,12 @@ impl Context { } } + /// ```erg + /// intersection_tp(1, 1) => Some(1) + /// intersection_tp(1, 2) => None + /// intersection_tp(?N, 2) => Some(2) # REVIEW: + /// intersection_tp(_: Nat, 1) == Some(1) + /// intersection_tp(_: Str, 1) == None pub(crate) fn intersection_tp(&self, lhs: &TyParam, rhs: &TyParam) -> Option { match (lhs, rhs) { (TyParam::Value(ValueObj::Type(l)), TyParam::Value(ValueObj::Type(r))) => { @@ -1834,6 +1902,40 @@ impl Context { } Some(TyParam::List(tps)) } + (TyParam::Tuple(l), TyParam::Tuple(r)) => { + let mut tps = vec![]; + for (l, r) in l.iter().zip(r.iter()) { + if let Some(tp) = self.intersection_tp(l, r) { + tps.push(tp); + } else { + return None; + } + } + Some(TyParam::Tuple(tps)) + } + (TyParam::UnsizedList(l), TyParam::UnsizedList(r)) => { + Some(TyParam::unsized_list(self.intersection_tp(l, r)?)) + } + (TyParam::Set(l), TyParam::Set(r)) if l.len() == 1 && r.len() == 1 => { + let l = l.iter().next().unwrap(); + let r = r.iter().next().unwrap(); + Some(TyParam::Set(set! { self.intersection_tp(l, r)? })) + } + (TyParam::Record(l), TyParam::Record(r)) if l.len() == 1 && r.len() == 1 => { + let mut tps = Dict::new(); + for (l_k, l_v) in l.iter() { + if let Some(r_v) = r.get(l_k) { + if let Some(tp) = self.intersection_tp(l_v, r_v) { + tps.insert(l_k.clone(), tp); + } else { + return None; + } + } else { + return None; + } + } + Some(TyParam::Record(tps)) + } (fv @ TyParam::FreeVar(f), other) | (other, fv @ TyParam::FreeVar(f)) if f.is_unbound() => { @@ -1845,6 +1947,36 @@ impl Context { None } } + (TyParam::Erased(t), other) | (other, TyParam::Erased(t)) => { + let other_t = self.get_tp_t(other).ok()?.derefine(); + let isec = self.intersection(t, &other_t); + if isec != Never { + Some(other.clone()) + } else { + None + } + } + ( + TyParam::App { + name: ln, + args: las, + }, + TyParam::App { + name: rn, + args: ras, + }, + ) if ln == rn => { + debug_assert_eq!(las.len(), ras.len()); + let mut unified_args = vec![]; + for (lp, rp) in las.iter().zip(ras.iter()) { + if let Some(intersec) = self.intersection_tp(lp, rp) { + unified_args.push(intersec); + } else { + return None; + } + } + Some(TyParam::app(ln.clone(), unified_args)) + } (_, _) => { if self.eq_tp(lhs, rhs) { Some(lhs.clone()) diff --git a/crates/erg_compiler/context/initialize/classes.rs b/crates/erg_compiler/context/initialize/classes.rs index a8f031ef6..b8e669937 100644 --- a/crates/erg_compiler/context/initialize/classes.rs +++ b/crates/erg_compiler/context/initialize/classes.rs @@ -2370,7 +2370,7 @@ impl Context { let V = type_q(TY_V); let get_t = no_var_fn_met( dict! { K.clone() => V.clone() }.into(), - vec![kw(KW_KEY, K.clone())], + vec![kw(KW_KEY, Obj)], vec![kw_default(KW_DEFAULT, Def.clone(), NoneType)], or(V.clone(), Def), ) @@ -3855,19 +3855,15 @@ impl Context { ); /* Dict! */ let dict_mut_t = poly(MUT_DICT, vec![D.clone()]); + let dict_kv_t = poly(DICT, vec![dict! { K.clone() => V.clone() }.into()]); + let dict_mut_kv_t = poly(MUT_DICT, vec![dict! { K.clone() => V.clone() }.into()]); let mut dict_mut = Self::builtin_poly_class(MUT_DICT, vec![PS::named_nd(TY_D, mono(GENERIC_DICT))], 3); dict_mut.register_superclass(dict_t.clone(), &dict_); let K = type_q(TY_K); let V = type_q(TY_V); let insert_t = pr_met( - ref_mut( - dict_mut_t.clone(), - Some(poly( - MUT_DICT, - vec![D.clone() + dict! { K.clone() => V.clone() }.into()], - )), - ), + ref_mut(dict_mut_kv_t.clone(), None), vec![kw(KW_KEY, K.clone()), kw(KW_VALUE, V.clone())], None, vec![], @@ -3876,30 +3872,16 @@ impl Context { .quantify(); dict_mut.register_py_builtin(PROC_INSERT, insert_t, Some(FUNDAMENTAL_SETITEM), 12); let remove_t = pr_met( - ref_mut( - dict_mut_t.clone(), - Some(poly( - MUT_DICT, - vec![D - .clone() - .proj_call(FUNC_DIFF.into(), vec![dict! { K.clone() => Never }.into()])], - )), - ), + ref_mut(dict_mut_kv_t.clone(), None), vec![kw(KW_KEY, K.clone())], None, vec![], - proj_call(D.clone(), FUNDAMENTAL_GETITEM, vec![ty_tp(K.clone())]) | NoneType, + V.clone() | NoneType, ) .quantify(); dict_mut.register_py_builtin(PROC_REMOVE, remove_t, Some(FUNC_REMOVE), 19); let update_t = pr_met( - ref_mut( - dict_mut_t.clone(), - Some(poly( - MUT_DICT, - vec![D.clone() + dict! { K.clone() => V.clone() }.into()], - )), - ), + ref_mut(dict_mut_kv_t.clone(), None), vec![kw( KW_ITERABLE, poly(ITERABLE, vec![ty_tp(tuple_t(vec![K.clone(), V.clone()]))]), @@ -3914,14 +3896,8 @@ impl Context { .quantify(); dict_mut.register_py_builtin(PROC_UPDATE, update_t, Some(FUNC_UPDATE), 26); let merge_t = pr_met( - ref_mut( - dict_mut_t.clone(), - Some(poly( - MUT_DICT, - vec![D.proj_call(FUNC_CONCAT.into(), vec![D2.clone()])], - )), - ), - vec![kw(KW_OTHER, poly(DICT, vec![D2.clone()]))], + ref_mut(dict_mut_kv_t.clone(), None), + vec![kw(KW_OTHER, dict_kv_t.clone())], None, vec![], NoneType, diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index addfabe9e..efd95746f 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -4045,7 +4045,7 @@ impl Context { pub fn is_trait(&self, typ: &Type) -> bool { match typ { Type::Never => false, - Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()), + Type::FreeVar(fv) if fv.is_linked() => self.is_trait(&fv.crack()), Type::FreeVar(_) => false, Type::And(tys, _) => tys.iter().any(|t| self.is_trait(t)), Type::Or(tys) => tys.iter().all(|t| self.is_trait(t)), diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index b9d3d5526..395b780e8 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -367,7 +367,8 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { if sub.len() == 1 && sup.len() == 1 { let sub_key = sub.keys().next().unwrap(); let sup_key = sup.keys().next().unwrap(); - self.sub_unify_value(sub_key, sup_key)?; + // contravariant + self.sub_unify_value(sup_key, sub_key)?; let sub_value = sub.values().next().unwrap(); let sup_value = sup.values().next().unwrap(); self.sub_unify_value(sub_value, sup_value)?; @@ -698,6 +699,16 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { Ok(()) } (TyParam::Dict(sub), TyParam::Dict(sup)) => { + if sub.len() == 1 && sup.len() == 1 { + let sub_key = sub.keys().next().unwrap(); + let sup_key = sup.keys().next().unwrap(); + // contravariant + self.sub_unify_tp(sup_key, sub_key, _variance, allow_divergence)?; + let sub_value = sub.values().next().unwrap(); + let sup_value = sup.values().next().unwrap(); + self.sub_unify_tp(sub_value, sup_value, _variance, allow_divergence)?; + return Ok(()); + } for (sub_k, sub_v) in sub.iter() { if let Some(sup_v) = sup .linear_get(sub_k) diff --git a/tests/should_err/dict.er b/tests/should_err/dict.er index caeb4fecd..225cb28f8 100644 --- a/tests/should_err/dict.er +++ b/tests/should_err/dict.er @@ -2,3 +2,6 @@ for! {"a": 1, "b": 2}.values(), s => print! "key: " + s # ERR for! {"a": 1, "b": 2}.keys(), i => print! i + 0 # ERR + +dic as {Nat: Int} = {1: -1} +_ = dic[-1] # ERR diff --git a/tests/should_err/method.er b/tests/should_err/method.er index d400075a0..52c4f3a94 100644 --- a/tests/should_err/method.er +++ b/tests/should_err/method.er @@ -1,2 +1,13 @@ .call_method obj, x = obj.method(x) # ERR .member obj = obj.member # ERR + +.D2! = Class Dict! { Str: List!(Int) } +.D2!. + new!() = .D2! !{:} + insert!(ref! self, name, val) = + self::base.insert! name, ![val] + +d = .D2!.new!() +d.insert! "aaa", 1 # OK +d.insert! "aaa", "bbb" # ERR +d.insert! 1, 1 # ERR diff --git a/tests/should_err/mut_dict.er b/tests/should_err/mut_dict.er index 331970a2e..7f09f0af4 100644 --- a/tests/should_err/mut_dict.er +++ b/tests/should_err/mut_dict.er @@ -1,4 +1,4 @@ -d = {"a": 1} +d as {{"a", "b"}: {1, 2}} = {"a": 1} dict = !d dict.insert! "b", 2 diff --git a/tests/should_ok/mut_dict.er b/tests/should_ok/mut_dict.er index 695e0463d..53d0dc671 100644 --- a/tests/should_ok/mut_dict.er +++ b/tests/should_ok/mut_dict.er @@ -1,4 +1,4 @@ -d = {"a": 1} +d as {{"a", "b"}: {1, 2}} = {"a": 1} dic = !d dic.insert! "b", 2 diff --git a/tests/test.rs b/tests/test.rs index 3bae7e6a4..e1df7ae49 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -599,7 +599,7 @@ fn exec_dependent_err() -> Result<(), ()> { #[test] fn exec_dict_err() -> Result<(), ()> { - expect_compile_failure("tests/should_err/dict.er", 0, 2) + expect_compile_failure("tests/should_err/dict.er", 0, 3) } #[test] @@ -741,7 +741,7 @@ fn exec_callable() -> Result<(), ()> { #[test] fn exec_method_err() -> Result<(), ()> { - expect_compile_failure("tests/should_err/method.er", 0, 2) + expect_compile_failure("tests/should_err/method.er", 0, 4) } #[test]