Skip to content

Commit

Permalink
fix: dict type bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Oct 2, 2024
1 parent a290684 commit 80eae76
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 40 deletions.
134 changes: 133 additions & 1 deletion crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use erg_common::dict::Dict;
use erg_common::set::Set;
use erg_common::style::colors::DEBUG_ERROR;
use erg_common::traits::StructuralEq;
use erg_common::{assume_unreachable, log, set_recursion_limit};
use erg_common::{assume_unreachable, log, set, set_recursion_limit};
use erg_common::{Str, Triple};

use crate::context::eval::UndoableLinkedList;
Expand Down Expand Up @@ -1457,6 +1457,7 @@ impl Context {
/// union(List(Int, 2), List(Str, 2)) == List(Int or Str, 2)
/// union(List(Int, 2), List(Str, 3)) == List(Int, 2) or List(Int, 3)
/// union(List(Int, 2), List(Int, ?N)) == List(Int, ?N)
/// union(List(Int, 2), List(?T, 2)) == List(Int or ?T, 2)
/// union({ .a = Int }, { .a = Str }) == { .a = Int or Str }
/// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int } or { .a = Int; .b = Int } # not to lost `b` information
/// union((A and B) or C) == (A or C) and (B or C)
Expand Down Expand Up @@ -1566,6 +1567,8 @@ impl Context {
/// union_tp(1, 1) => Some(1)
/// union_tp(1, 2) => None
/// union_tp(?N, 2) => Some(2) # REVIEW:
/// union_tp(_: Obj, 1) == Some(_: Obj)
/// union_tp(_: Nat, -1) == Some(_: Int)
/// ```
pub(crate) fn union_tp(&self, lhs: &TyParam, rhs: &TyParam) -> Option<TyParam> {
match (lhs, rhs) {
Expand All @@ -1590,6 +1593,40 @@ impl Context {
}
Some(TyParam::List(tps))
}
(TyParam::Tuple(l), TyParam::Tuple(r)) => {
let mut tps = vec![];
for (l, r) in l.iter().zip(r.iter()) {
if let Some(tp) = self.union_tp(l, r) {
tps.push(tp);
} else {
return None;
}
}
Some(TyParam::Tuple(tps))
}
(TyParam::UnsizedList(l), TyParam::UnsizedList(r)) => {
Some(TyParam::unsized_list(self.union_tp(l, r)?))
}
(TyParam::Set(l), TyParam::Set(r)) if l.len() == 1 && r.len() == 1 => {
let l = l.iter().next().unwrap();
let r = r.iter().next().unwrap();
Some(TyParam::Set(set! { self.union_tp(l, r)? }))
}
(TyParam::Record(l), TyParam::Record(r)) if l.len() == 1 && r.len() == 1 => {
let mut tps = Dict::new();
for (l_k, l_v) in l.iter() {
if let Some(r_v) = r.get(l_k) {
if let Some(tp) = self.union_tp(l_v, r_v) {
tps.insert(l_k.clone(), tp);
} else {
return None;
}
} else {
return None;
}
}
Some(TyParam::Record(tps))
}
(fv @ TyParam::FreeVar(f), other) | (other, fv @ TyParam::FreeVar(f))
if f.is_unbound() =>
{
Expand All @@ -1601,6 +1638,31 @@ impl Context {
None
}
}
(TyParam::Erased(t), other) | (other, TyParam::Erased(t)) => {
let other_t = self.get_tp_t(other).ok()?.derefine();
Some(TyParam::erased(self.union(t, &other_t)))
}
(
TyParam::App {
name: ln,
args: las,
},
TyParam::App {
name: rn,
args: ras,
},
) if ln == rn => {
debug_assert_eq!(las.len(), ras.len());
let mut unified_args = vec![];
for (lp, rp) in las.iter().zip(ras.iter()) {
if let Some(union) = self.union_tp(lp, rp) {
unified_args.push(union);
} else {
return None;
}
}
Some(TyParam::app(ln.clone(), unified_args))
}
(_, _) => {
if self.eq_tp(lhs, rhs) {
Some(lhs.clone())
Expand Down Expand Up @@ -1811,6 +1873,12 @@ impl Context {
}
}

/// ```erg
/// intersection_tp(1, 1) => Some(1)
/// intersection_tp(1, 2) => None
/// intersection_tp(?N, 2) => Some(2) # REVIEW:
/// intersection_tp(_: Nat, 1) == Some(1)
/// intersection_tp(_: Str, 1) == None
pub(crate) fn intersection_tp(&self, lhs: &TyParam, rhs: &TyParam) -> Option<TyParam> {
match (lhs, rhs) {
(TyParam::Value(ValueObj::Type(l)), TyParam::Value(ValueObj::Type(r))) => {
Expand All @@ -1834,6 +1902,40 @@ impl Context {
}
Some(TyParam::List(tps))
}
(TyParam::Tuple(l), TyParam::Tuple(r)) => {
let mut tps = vec![];
for (l, r) in l.iter().zip(r.iter()) {
if let Some(tp) = self.intersection_tp(l, r) {
tps.push(tp);
} else {
return None;
}
}
Some(TyParam::Tuple(tps))
}
(TyParam::UnsizedList(l), TyParam::UnsizedList(r)) => {
Some(TyParam::unsized_list(self.intersection_tp(l, r)?))
}
(TyParam::Set(l), TyParam::Set(r)) if l.len() == 1 && r.len() == 1 => {
let l = l.iter().next().unwrap();
let r = r.iter().next().unwrap();
Some(TyParam::Set(set! { self.intersection_tp(l, r)? }))
}
(TyParam::Record(l), TyParam::Record(r)) if l.len() == 1 && r.len() == 1 => {
let mut tps = Dict::new();
for (l_k, l_v) in l.iter() {
if let Some(r_v) = r.get(l_k) {
if let Some(tp) = self.intersection_tp(l_v, r_v) {
tps.insert(l_k.clone(), tp);
} else {
return None;
}
} else {
return None;
}
}
Some(TyParam::Record(tps))
}
(fv @ TyParam::FreeVar(f), other) | (other, fv @ TyParam::FreeVar(f))
if f.is_unbound() =>
{
Expand All @@ -1845,6 +1947,36 @@ impl Context {
None
}
}
(TyParam::Erased(t), other) | (other, TyParam::Erased(t)) => {
let other_t = self.get_tp_t(other).ok()?.derefine();
let isec = self.intersection(t, &other_t);
if isec != Never {
Some(other.clone())
} else {
None
}
}
(
TyParam::App {
name: ln,
args: las,
},
TyParam::App {
name: rn,
args: ras,
},
) if ln == rn => {
debug_assert_eq!(las.len(), ras.len());
let mut unified_args = vec![];
for (lp, rp) in las.iter().zip(ras.iter()) {
if let Some(intersec) = self.intersection_tp(lp, rp) {
unified_args.push(intersec);
} else {
return None;
}
}
Some(TyParam::app(ln.clone(), unified_args))
}
(_, _) => {
if self.eq_tp(lhs, rhs) {
Some(lhs.clone())
Expand Down
42 changes: 9 additions & 33 deletions crates/erg_compiler/context/initialize/classes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2370,7 +2370,7 @@ impl Context {
let V = type_q(TY_V);
let get_t = no_var_fn_met(
dict! { K.clone() => V.clone() }.into(),
vec![kw(KW_KEY, K.clone())],
vec![kw(KW_KEY, Obj)],
vec![kw_default(KW_DEFAULT, Def.clone(), NoneType)],
or(V.clone(), Def),
)
Expand Down Expand Up @@ -3855,19 +3855,15 @@ impl Context {
);
/* Dict! */
let dict_mut_t = poly(MUT_DICT, vec![D.clone()]);
let dict_kv_t = poly(DICT, vec![dict! { K.clone() => V.clone() }.into()]);
let dict_mut_kv_t = poly(MUT_DICT, vec![dict! { K.clone() => V.clone() }.into()]);
let mut dict_mut =
Self::builtin_poly_class(MUT_DICT, vec![PS::named_nd(TY_D, mono(GENERIC_DICT))], 3);
dict_mut.register_superclass(dict_t.clone(), &dict_);
let K = type_q(TY_K);
let V = type_q(TY_V);
let insert_t = pr_met(
ref_mut(
dict_mut_t.clone(),
Some(poly(
MUT_DICT,
vec![D.clone() + dict! { K.clone() => V.clone() }.into()],
)),
),
ref_mut(dict_mut_kv_t.clone(), None),
vec![kw(KW_KEY, K.clone()), kw(KW_VALUE, V.clone())],
None,
vec![],
Expand All @@ -3876,30 +3872,16 @@ impl Context {
.quantify();
dict_mut.register_py_builtin(PROC_INSERT, insert_t, Some(FUNDAMENTAL_SETITEM), 12);
let remove_t = pr_met(
ref_mut(
dict_mut_t.clone(),
Some(poly(
MUT_DICT,
vec![D
.clone()
.proj_call(FUNC_DIFF.into(), vec![dict! { K.clone() => Never }.into()])],
)),
),
ref_mut(dict_mut_kv_t.clone(), None),
vec![kw(KW_KEY, K.clone())],
None,
vec![],
proj_call(D.clone(), FUNDAMENTAL_GETITEM, vec![ty_tp(K.clone())]) | NoneType,
V.clone() | NoneType,
)
.quantify();
dict_mut.register_py_builtin(PROC_REMOVE, remove_t, Some(FUNC_REMOVE), 19);
let update_t = pr_met(
ref_mut(
dict_mut_t.clone(),
Some(poly(
MUT_DICT,
vec![D.clone() + dict! { K.clone() => V.clone() }.into()],
)),
),
ref_mut(dict_mut_kv_t.clone(), None),
vec![kw(
KW_ITERABLE,
poly(ITERABLE, vec![ty_tp(tuple_t(vec![K.clone(), V.clone()]))]),
Expand All @@ -3914,14 +3896,8 @@ impl Context {
.quantify();
dict_mut.register_py_builtin(PROC_UPDATE, update_t, Some(FUNC_UPDATE), 26);
let merge_t = pr_met(
ref_mut(
dict_mut_t.clone(),
Some(poly(
MUT_DICT,
vec![D.proj_call(FUNC_CONCAT.into(), vec![D2.clone()])],
)),
),
vec![kw(KW_OTHER, poly(DICT, vec![D2.clone()]))],
ref_mut(dict_mut_kv_t.clone(), None),
vec![kw(KW_OTHER, dict_kv_t.clone())],
None,
vec![],
NoneType,
Expand Down
2 changes: 1 addition & 1 deletion crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4045,7 +4045,7 @@ impl Context {
pub fn is_trait(&self, typ: &Type) -> bool {
match typ {
Type::Never => false,
Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()),
Type::FreeVar(fv) if fv.is_linked() => self.is_trait(&fv.crack()),
Type::FreeVar(_) => false,
Type::And(tys, _) => tys.iter().any(|t| self.is_trait(t)),
Type::Or(tys) => tys.iter().all(|t| self.is_trait(t)),
Expand Down
13 changes: 12 additions & 1 deletion crates/erg_compiler/context/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,8 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
if sub.len() == 1 && sup.len() == 1 {
let sub_key = sub.keys().next().unwrap();
let sup_key = sup.keys().next().unwrap();
self.sub_unify_value(sub_key, sup_key)?;
// contravariant
self.sub_unify_value(sup_key, sub_key)?;
let sub_value = sub.values().next().unwrap();
let sup_value = sup.values().next().unwrap();
self.sub_unify_value(sub_value, sup_value)?;
Expand Down Expand Up @@ -698,6 +699,16 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
Ok(())
}
(TyParam::Dict(sub), TyParam::Dict(sup)) => {
if sub.len() == 1 && sup.len() == 1 {
let sub_key = sub.keys().next().unwrap();
let sup_key = sup.keys().next().unwrap();
// contravariant
self.sub_unify_tp(sup_key, sub_key, _variance, allow_divergence)?;
let sub_value = sub.values().next().unwrap();
let sup_value = sup.values().next().unwrap();
self.sub_unify_tp(sub_value, sup_value, _variance, allow_divergence)?;
return Ok(());
}
for (sub_k, sub_v) in sub.iter() {
if let Some(sup_v) = sup
.linear_get(sub_k)
Expand Down
3 changes: 3 additions & 0 deletions tests/should_err/dict.er
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ for! {"a": 1, "b": 2}.values(), s =>
print! "key: " + s # ERR
for! {"a": 1, "b": 2}.keys(), i =>
print! i + 0 # ERR

dic as {Nat: Int} = {1: -1}
_ = dic[-1] # ERR
11 changes: 11 additions & 0 deletions tests/should_err/method.er
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
.call_method obj, x = obj.method(x) # ERR
.member obj = obj.member # ERR

.D2! = Class Dict! { Str: List!(Int) }
.D2!.
new!() = .D2! !{:}
insert!(ref! self, name, val) =
self::base.insert! name, ![val]

d = .D2!.new!()
d.insert! "aaa", 1 # OK
d.insert! "aaa", "bbb" # ERR
d.insert! 1, 1 # ERR
2 changes: 1 addition & 1 deletion tests/should_err/mut_dict.er
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
d = {"a": 1}
d as {{"a", "b"}: {1, 2}} = {"a": 1}
dict = !d

dict.insert! "b", 2
Expand Down
2 changes: 1 addition & 1 deletion tests/should_ok/mut_dict.er
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
d = {"a": 1}
d as {{"a", "b"}: {1, 2}} = {"a": 1}
dic = !d

dic.insert! "b", 2
Expand Down
4 changes: 2 additions & 2 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ fn exec_dependent_err() -> Result<(), ()> {

#[test]
fn exec_dict_err() -> Result<(), ()> {
expect_compile_failure("tests/should_err/dict.er", 0, 2)
expect_compile_failure("tests/should_err/dict.er", 0, 3)
}

#[test]
Expand Down Expand Up @@ -741,7 +741,7 @@ fn exec_callable() -> Result<(), ()> {

#[test]
fn exec_method_err() -> Result<(), ()> {
expect_compile_failure("tests/should_err/method.er", 0, 2)
expect_compile_failure("tests/should_err/method.er", 0, 4)
}

#[test]
Expand Down

0 comments on commit 80eae76

Please sign in to comment.