diff --git a/crates/els/completion.rs b/crates/els/completion.rs index 47c465110..13cb818ff 100644 --- a/crates/els/completion.rs +++ b/crates/els/completion.rs @@ -42,24 +42,30 @@ fn comp_item_kind(t: &Type, muty: Mutability) -> CompletionItemKind { Type::Subr(_) | Type::Quantified(_) => CompletionItemKind::FUNCTION, Type::ClassType => CompletionItemKind::CLASS, Type::TraitType => CompletionItemKind::INTERFACE, - Type::Or(l, r) => { - let l = comp_item_kind(l, muty); - let r = comp_item_kind(r, muty); - if l == r { - l + Type::Or(tys) => { + let fst = comp_item_kind(tys.iter().next().unwrap(), muty); + if tys + .iter() + .map(|t| comp_item_kind(t, muty)) + .all(|k| k == fst) + { + fst } else if muty.is_const() { CompletionItemKind::CONSTANT } else { CompletionItemKind::VARIABLE } } - Type::And(l, r) => { - let l = comp_item_kind(l, muty); - let r = comp_item_kind(r, muty); - if l == CompletionItemKind::VARIABLE { - r + Type::And(tys) => { + for k in tys.iter().map(|t| comp_item_kind(t, muty)) { + if k != CompletionItemKind::VARIABLE { + return k; + } + } + if muty.is_const() { + CompletionItemKind::CONSTANT } else { - l + CompletionItemKind::VARIABLE } } Type::Refinement(r) => comp_item_kind(&r.t, muty), diff --git a/crates/erg_common/dict.rs b/crates/erg_common/dict.rs index 9f46fd13c..c4b2767c5 100644 --- a/crates/erg_common/dict.rs +++ b/crates/erg_common/dict.rs @@ -129,6 +129,18 @@ impl Dict { } } + /// ``` + /// # use erg_common::dict; + /// # use erg_common::dict::Dict; + /// let mut dict = Dict::with_capacity(3); + /// assert_eq!(dict.capacity(), 3); + /// dict.insert("a", 1); + /// assert_eq!(dict.capacity(), 3); + /// dict.insert("b", 2); + /// dict.insert("c", 3); + /// dict.insert("d", 4); + /// assert_ne!(dict.capacity(), 3); + /// ``` pub fn with_capacity(capacity: usize) -> Self { Self { dict: FxHashMap::with_capacity_and_hasher(capacity, Default::default()), diff --git a/crates/erg_common/macros.rs b/crates/erg_common/macros.rs index 93a6130ff..8e4e3e8e8 100644 --- a/crates/erg_common/macros.rs +++ b/crates/erg_common/macros.rs @@ -627,6 +627,17 @@ impl RecursionCounter { #[macro_export] macro_rules! set_recursion_limit { + (panic, $msg:expr, $limit:expr) => { + use std::sync::atomic::AtomicU32; + + static COUNTER: AtomicU32 = AtomicU32::new($limit); + + let counter = $crate::macros::RecursionCounter::new(&COUNTER); + if counter.limit_reached() { + $crate::log!(err "Recursion limit reached"); + panic!($msg); + } + }; ($returns:expr, $limit:expr) => { use std::sync::atomic::AtomicU32; diff --git a/crates/erg_common/set.rs b/crates/erg_common/set.rs index c925cf1e9..8a2fe180a 100644 --- a/crates/erg_common/set.rs +++ b/crates/erg_common/set.rs @@ -382,6 +382,20 @@ impl Set { self.insert(other); self } + + /// ``` + /// # use erg_common::set; + /// assert_eq!(set!{1, 2}.product(&set!{3, 4}), set!{(&1, &3), (&1, &4), (&2, &3), (&2, &4)}); + /// ``` + pub fn product<'l, 'r, U: Hash + Eq>(&'l self, other: &'r Set) -> Set<(&'l T, &'r U)> { + let mut res = set! {}; + for x in self.iter() { + for y in other.iter() { + res.insert((x, y)); + } + } + res + } } impl Set { diff --git a/crates/erg_common/traits.rs b/crates/erg_common/traits.rs index 7c0c84bf1..168a268f7 100644 --- a/crates/erg_common/traits.rs +++ b/crates/erg_common/traits.rs @@ -1407,6 +1407,8 @@ impl Immutable for &T {} impl Immutable for Option {} impl Immutable for Vec {} impl Immutable for [T] {} +impl Immutable for (T, U) {} +impl Immutable for (T, U, V) {} impl Immutable for Box {} impl Immutable for std::rc::Rc {} impl Immutable for std::sync::Arc {} diff --git a/crates/erg_common/triple.rs b/crates/erg_common/triple.rs index a743d43ab..a3626d3d3 100644 --- a/crates/erg_common/triple.rs +++ b/crates/erg_common/triple.rs @@ -18,6 +18,16 @@ impl fmt::Display for Triple { } impl Triple { + pub const fn is_ok(&self) -> bool { + matches!(self, Triple::Ok(_)) + } + pub const fn is_err(&self) -> bool { + matches!(self, Triple::Err(_)) + } + pub const fn is_none(&self) -> bool { + matches!(self, Triple::None) + } + pub fn none_then(self, err: E) -> Result { match self { Triple::None => Err(err), diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index b5f62b1f3..f6125066d 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -763,9 +763,7 @@ impl Context { self.structural_supertype_of(&l, rhs) } // {1, 2, 3} :> {1} or {2, 3} == true - (Refinement(_refine), Or(l, r)) => { - self.supertype_of(lhs, l) && self.supertype_of(lhs, r) - } + (Refinement(_refine), Or(tys)) => tys.iter().all(|ty| self.supertype_of(lhs, ty)), // ({I: Int | True} :> Int) == true // {N: Nat | ...} :> Int) == false // ({I: Int | I >= 0} :> Int) == false @@ -817,41 +815,37 @@ impl Context { self.sub_unify(&inst, l, &(), None).is_ok() } // Int or Str :> Str or Int == (Int :> Str && Str :> Int) || (Int :> Int && Str :> Str) == true - (Or(l_1, l_2), Or(r_1, r_2)) => { - if l_1.is_union_type() && self.supertype_of(l_1, rhs) { - return true; - } - if l_2.is_union_type() && self.supertype_of(l_2, rhs) { - return true; - } - (self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2)) - || (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1)) - } + // Int or Str or NoneType :> Str or Int + // Int or Str or NoneType :> Str or NoneType or Nat + (Or(l), Or(r)) => r.iter().all(|r| l.iter().any(|l| self.supertype_of(l, r))), // not Nat :> not Int == true (Not(l), Not(r)) => self.subtype_of(l, r), // (Int or Str) :> Nat == Int :> Nat || Str :> Nat == true // (Num or Show) :> Show == Num :> Show || Show :> Num == true - (Or(l_or, r_or), rhs) => self.supertype_of(l_or, rhs) || self.supertype_of(r_or, rhs), + (Or(ors), rhs) => ors.iter().any(|or| self.supertype_of(or, rhs)), // Int :> (Nat or Str) == Int :> Nat && Int :> Str == false - (lhs, Or(l_or, r_or)) => self.supertype_of(lhs, l_or) && self.supertype_of(lhs, r_or), - (And(l_1, l_2), And(r_1, r_2)) => { - if l_1.is_intersection_type() && self.supertype_of(l_1, rhs) { + (lhs, Or(ors)) => ors.iter().all(|or| self.supertype_of(lhs, or)), + // Hash and Eq :> HashEq and ... == true + // Add(T) and Eq :> Add(Int) and Eq == true + (And(l), And(r)) => { + if r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))) { return true; } - if l_2.is_intersection_type() && self.supertype_of(l_2, rhs) { - return true; + if l.len() == r.len() { + let mut r = r.clone(); + for _ in 1..l.len() { + if l.iter().zip(&r).all(|(l, r)| self.supertype_of(l, r)) { + return true; + } + r.rotate_left(1); + } } - (self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2)) - || (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1)) + false } // (Num and Show) :> Show == false - (And(l_and, r_and), rhs) => { - self.supertype_of(l_and, rhs) && self.supertype_of(r_and, rhs) - } + (And(ands), rhs) => ands.iter().all(|and| self.supertype_of(and, rhs)), // Show :> (Num and Show) == true - (lhs, And(l_and, r_and)) => { - self.supertype_of(lhs, l_and) || self.supertype_of(lhs, r_and) - } + (lhs, And(ands)) => ands.iter().any(|and| self.supertype_of(lhs, and)), // Not(Eq) :> Float == !(Eq :> Float) == true (Not(_), Obj) => false, (Not(l), rhs) => !self.supertype_of(l, rhs), @@ -923,18 +917,18 @@ impl Context { Type::NamedTuple(fields) => fields.iter().cloned().collect(), Type::Refinement(refine) => self.fields(&refine.t), Type::Structural(t) => self.fields(t), - Type::Or(l, r) => { - let l_fields = self.fields(l); - let r_fields = self.fields(r); - let l_field_names = l_fields.keys().collect::>(); - let r_field_names = r_fields.keys().collect::>(); - let field_names = l_field_names.intersection(&r_field_names); + Type::Or(tys) => { + let or_fields = tys.iter().map(|t| self.fields(t)).collect::>(); + let field_names = or_fields + .iter() + .flat_map(|fs| fs.keys()) + .collect::>(); let mut fields = Dict::new(); - for (name, l_t, r_t) in field_names + for (name, tys) in field_names .iter() - .map(|&name| (name, &l_fields[name], &r_fields[name])) + .map(|&name| (name, or_fields.iter().filter_map(|fields| fields.get(name)))) { - let union = self.union(l_t, r_t); + let union = tys.fold(Never, |acc, ty| self.union(&acc, ty)); fields.insert(name.clone(), union); } fields @@ -1417,6 +1411,8 @@ impl Context { /// union({ .a = Int }, { .a = Str }) == { .a = Int or Str } /// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int } or { .a = Int; .b = Int } # not to lost `b` information /// union((A and B) or C) == (A or C) and (B or C) + /// union(Never, Int) == Int + /// union(Obj, Int) == Obj /// ``` pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type { if lhs == rhs { @@ -1479,10 +1475,9 @@ impl Context { (Some(sub), Some(sup)) => bounded(sub.clone(), sup.clone()), _ => self.simple_union(lhs, rhs), }, - (other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other), + (other, or @ Or(_)) | (or @ Or(_), other) => self.union_add(or, other), // (A and B) or C ==> (A or C) and (B or C) - (and_t @ And(_, _), other) | (other, and_t @ And(_, _)) => { - let ands = and_t.ands(); + (And(ands), other) | (other, And(ands)) => { let mut t = Type::Obj; for branch in ands.iter() { let union = self.union(branch, other); @@ -1666,6 +1661,12 @@ impl Context { /// Returns intersection of two types (`A and B`). /// If `A` and `B` have a subtype relationship, it is equal to `min(A, B)`. + /// ```erg + /// intersection(Nat, Int) == Nat + /// intersection(Int, Str) == Never + /// intersection(Obj, Int) == Int + /// intersection(Never, Int) == Never + /// ``` pub(crate) fn intersection(&self, lhs: &Type, rhs: &Type) -> Type { if lhs == rhs { return lhs.clone(); @@ -1696,12 +1697,9 @@ impl Context { (_, Not(r)) => self.diff(lhs, r), (Not(l), _) => self.diff(rhs, l), // A and B and A == A and B - (other, and @ And(_, _)) | (and @ And(_, _), other) => { - self.intersection_add(and, other) - } + (other, and @ And(_)) | (and @ And(_), other) => self.intersection_add(and, other), // (A or B) and C == (A and C) or (B and C) - (or_t @ Or(_, _), other) | (other, or_t @ Or(_, _)) => { - let ors = or_t.ors(); + (Or(ors), other) | (other, Or(ors)) => { if ors.iter().any(|t| t.has_unbound_var()) { return self.simple_intersection(lhs, rhs); } @@ -1797,13 +1795,15 @@ impl Context { /// intersection_add(Int and ?T(:> NoneType), Str) == Never /// ``` fn intersection_add(&self, intersection: &Type, elem: &Type) -> Type { - let ands = intersection.ands(); + let mut ands = intersection.ands(); let bounded = ands.iter().map(|t| t.lower_bounded()); for t in bounded { if self.subtype_of(&t, elem) { return intersection.clone(); } else if self.supertype_of(&t, elem) { - return constructors::ands(ands.linear_exclude(&t).include(elem.clone())); + ands.retain(|ty| ty != &t); + ands.push(elem.clone()); + return constructors::ands(ands); } } and(intersection.clone(), elem.clone()) @@ -1836,21 +1836,21 @@ impl Context { fn narrow_type_by_pred(&self, t: Type, pred: &Predicate) -> Type { match (t, pred) { ( - Type::Or(l, r), + Type::Or(tys), Predicate::NotEqual { rhs: TyParam::Value(v), .. }, ) if v.is_none() => { - let l = self.narrow_type_by_pred(*l, pred); - let r = self.narrow_type_by_pred(*r, pred); - if l.is_nonetype() { - r - } else if r.is_nonetype() { - l - } else { - or(l, r) + let mut new_tys = Set::new(); + for ty in tys { + let ty = self.narrow_type_by_pred(ty, pred); + if ty.is_nonelike() { + continue; + } + new_tys.insert(ty); } + Type::checked_or(new_tys) } (Type::Refinement(refine), _) => { let t = self.narrow_type_by_pred(*refine.t, pred); @@ -1992,8 +1992,12 @@ impl Context { guard.target.clone(), self.complement(&guard.to), )), - Or(l, r) => self.intersection(&self.complement(l), &self.complement(r)), - And(l, r) => self.union(&self.complement(l), &self.complement(r)), + Or(ors) => ors + .iter() + .fold(Obj, |l, r| self.intersection(&l, &self.complement(r))), + And(ands) => ands + .iter() + .fold(Never, |l, r| self.union(&l, &self.complement(r))), other => not(other.clone()), } } @@ -2011,7 +2015,14 @@ impl Context { match lhs { Type::FreeVar(fv) if fv.is_linked() => self.diff(&fv.crack(), rhs), // Type::And(l, r) => self.intersection(&self.diff(l, rhs), &self.diff(r, rhs)), - Type::Or(l, r) => self.union(&self.diff(l, rhs), &self.diff(r, rhs)), + Type::Or(tys) => { + let mut new_tys = vec![]; + for ty in tys { + let diff = self.diff(ty, rhs); + new_tys.push(diff); + } + new_tys.into_iter().fold(Never, |l, r| self.union(&l, &r)) + } _ => lhs.clone(), } } diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index a8772d51b..61531a765 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -2285,44 +2285,42 @@ impl Context { Err((t, errs)) } } - Type::And(l, r) => { - let l = match self.eval_t_params(*l, level, t_loc) { - Ok(l) => l, - Err((l, es)) => { - errs.extend(es); - l - } - }; - let r = match self.eval_t_params(*r, level, t_loc) { - Ok(r) => r, - Err((r, es)) => { - errs.extend(es); - r + Type::And(ands) => { + let mut new_ands = set! {}; + for and in ands.into_iter() { + match self.eval_t_params(and, level, t_loc) { + Ok(and) => { + new_ands.insert(and); + } + Err((and, es)) => { + new_ands.insert(and); + errs.extend(es); + } } - }; - let intersec = self.intersection(&l, &r); + } + let intersec = new_ands + .into_iter() + .fold(Type::Obj, |l, r| self.intersection(&l, &r)); if errs.is_empty() { Ok(intersec) } else { Err((intersec, errs)) } } - Type::Or(l, r) => { - let l = match self.eval_t_params(*l, level, t_loc) { - Ok(l) => l, - Err((l, es)) => { - errs.extend(es); - l - } - }; - let r = match self.eval_t_params(*r, level, t_loc) { - Ok(r) => r, - Err((r, es)) => { - errs.extend(es); - r + Type::Or(ors) => { + let mut new_ors = set! {}; + for or in ors.into_iter() { + match self.eval_t_params(or, level, t_loc) { + Ok(or) => { + new_ors.insert(or); + } + Err((or, es)) => { + new_ors.insert(or); + errs.extend(es); + } } - }; - let union = self.union(&l, &r); + } + let union = new_ors.into_iter().fold(Never, |l, r| self.union(&l, &r)); if errs.is_empty() { Ok(union) } else { diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index 13a0049ad..d8899e701 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -289,17 +289,21 @@ impl Generalizer { } proj_call(lhs, attr_name, args) } - And(l, r) => { - let l = self.generalize_t(*l, uninit); - let r = self.generalize_t(*r, uninit); + And(ands) => { // not `self.intersection` because types are generalized - and(l, r) + let ands = ands + .into_iter() + .map(|t| self.generalize_t(t, uninit)) + .collect(); + Type::checked_and(ands) } - Or(l, r) => { - let l = self.generalize_t(*l, uninit); - let r = self.generalize_t(*r, uninit); + Or(ors) => { // not `self.union` because types are generalized - or(l, r) + let ors = ors + .into_iter() + .map(|t| self.generalize_t(t, uninit)) + .collect(); + Type::checked_or(ors) } Not(l) => not(self.generalize_t(*l, uninit)), Structural(ty) => { @@ -1045,15 +1049,23 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> { let pred = self.deref_pred(*refine.pred)?; Ok(refinement(refine.var, t, pred)) } - And(l, r) => { - let l = self.deref_tyvar(*l)?; - let r = self.deref_tyvar(*r)?; - Ok(self.ctx.intersection(&l, &r)) + And(ands) => { + let mut new_ands = vec![]; + for t in ands.into_iter() { + new_ands.push(self.deref_tyvar(t)?); + } + Ok(new_ands + .into_iter() + .fold(Type::Obj, |acc, t| self.ctx.intersection(&acc, &t))) } - Or(l, r) => { - let l = self.deref_tyvar(*l)?; - let r = self.deref_tyvar(*r)?; - Ok(self.ctx.union(&l, &r)) + Or(ors) => { + let mut new_ors = vec![]; + for t in ors.into_iter() { + new_ors.push(self.deref_tyvar(t)?); + } + Ok(new_ors + .into_iter() + .fold(Type::Never, |acc, t| self.ctx.union(&acc, &t))) } Not(ty) => { let ty = self.deref_tyvar(*ty)?; @@ -1733,22 +1745,33 @@ impl Context { /// ``` pub(crate) fn squash_tyvar(&self, typ: Type) -> Type { match typ { - Or(l, r) => { - let l = self.squash_tyvar(*l); - let r = self.squash_tyvar(*r); + Or(tys) => { + let new_tys = tys + .into_iter() + .map(|t| self.squash_tyvar(t)) + .collect::>(); + let mut union = Never; // REVIEW: - if l.is_unnamed_unbound_var() && r.is_unnamed_unbound_var() { - match (self.subtype_of(&l, &r), self.subtype_of(&r, &l)) { - (true, true) | (true, false) => { - let _ = self.sub_unify(&l, &r, &(), None); + if new_tys.iter().all(|t| t.is_unnamed_unbound_var()) { + for ty in new_tys.iter() { + if union == Never { + union = ty.clone(); + continue; } - (false, true) => { - let _ = self.sub_unify(&r, &l, &(), None); + match (self.subtype_of(&union, ty), self.subtype_of(&union, ty)) { + (true, true) | (true, false) => { + let _ = self.sub_unify(&union, ty, &(), None); + } + (false, true) => { + let _ = self.sub_unify(ty, &union, &(), None); + } + _ => {} } - _ => {} } } - self.union(&l, &r) + new_tys + .into_iter() + .fold(Never, |acc, t| self.union(&acc, &t)) } FreeVar(ref fv) if fv.constraint_is_sandwiched() => { let (sub_t, super_t) = fv.get_subsup().unwrap(); diff --git a/crates/erg_compiler/context/hint.rs b/crates/erg_compiler/context/hint.rs index e7826d022..b7b795224 100644 --- a/crates/erg_compiler/context/hint.rs +++ b/crates/erg_compiler/context/hint.rs @@ -116,9 +116,12 @@ impl Context { return Some(hint); } } - (Type::And(l, r), found) => { - let left = self.readable_type(l.as_ref().clone()); - let right = self.readable_type(r.as_ref().clone()); + (Type::And(tys), found) if tys.len() == 2 => { + let mut iter = tys.iter(); + let l = iter.next().unwrap(); + let r = iter.next().unwrap(); + let left = self.readable_type(l.clone()); + let right = self.readable_type(r.clone()); if self.supertype_of(l, found) { let msg = switch_lang!( "japanese" => format!("型{found}は{left}のサブタイプですが、{right}のサブタイプではありません"), diff --git a/crates/erg_compiler/context/initialize/classes.rs b/crates/erg_compiler/context/initialize/classes.rs index 1d9fba476..ae54c018c 100644 --- a/crates/erg_compiler/context/initialize/classes.rs +++ b/crates/erg_compiler/context/initialize/classes.rs @@ -16,6 +16,9 @@ use crate::varinfo::Mutability; use Mutability::*; impl Context { + // NOTE: Registering traits that a class implements requires type checking, + // which means that registering a class requires that the preceding types have already been registered, + // so `register_builtin_type` should be called as early as possible. pub(super) fn init_builtin_classes(&mut self) { let vis = if PYTHON_MODE { Visibility::BUILTIN_PUBLIC @@ -29,6 +32,7 @@ impl Context { let N = mono_q_tp(TY_N, instanceof(Nat)); let M = mono_q_tp(TY_M, instanceof(Nat)); let never = Self::builtin_mono_class(NEVER, 1); + self.register_builtin_type(Never, never, vis.clone(), Const, Some(NEVER)); /* Obj */ let mut obj = Self::builtin_mono_class(OBJ, 2); obj.register_py_builtin( @@ -2965,6 +2969,21 @@ impl Context { None, union, ); + self.register_builtin_type( + mono(GENERIC_TUPLE), + generic_tuple, + vis.clone(), + Const, + Some(FUNC_TUPLE), + ); + self.register_builtin_type( + homo_tuple_t, + homo_tuple, + vis.clone(), + Const, + Some(FUNC_TUPLE), + ); + self.register_builtin_type(_tuple_t, tuple_, vis.clone(), Const, Some(FUNC_TUPLE)); /* Or (true or type) */ let or_t = poly(OR, vec![ty_tp(L), ty_tp(R)]); let mut or = Self::builtin_poly_class(OR, vec![PS::t_nd(TY_L), PS::t_nd(TY_R)], 2); @@ -3673,6 +3692,8 @@ impl Context { Some(FUNC_UPDATE), ); list_mut_.register_trait_methods(list_mut_t.clone(), list_mut_mutable); + self.register_builtin_type(lis_t, list_, vis.clone(), Const, Some(LIST)); + self.register_builtin_type(list_mut_t, list_mut_, vis.clone(), Const, Some(LIST)); /* ByteArray! */ let bytearray_mut_t = mono(MUT_BYTEARRAY); let mut bytearray_mut = Self::builtin_mono_class(MUT_BYTEARRAY, 2); @@ -4213,7 +4234,6 @@ impl Context { let mut qfunc_meta_type = Self::builtin_mono_class(QUANTIFIED_FUNC_META_TYPE, 2); qfunc_meta_type.register_superclass(mono(QUANTIFIED_PROC_META_TYPE), &qproc_meta_type); qfunc_meta_type.register_superclass(mono(QUANTIFIED_FUNC), &qfunc); - self.register_builtin_type(Never, never, vis.clone(), Const, Some(NEVER)); self.register_builtin_type(Obj, obj, vis.clone(), Const, Some(FUNC_OBJECT)); // self.register_type(mono(RECORD), vec![], record, Visibility::BUILTIN_PRIVATE, Const); let name = if PYTHON_MODE { FUNC_INT } else { INT }; @@ -4261,7 +4281,6 @@ impl Context { Const, Some(UNSIZED_LIST), ); - self.register_builtin_type(lis_t, list_, vis.clone(), Const, Some(LIST)); self.register_builtin_type(mono(SLICE), slice, vis.clone(), Const, Some(FUNC_SLICE)); self.register_builtin_type( mono(GENERIC_SET), @@ -4274,21 +4293,6 @@ impl Context { self.register_builtin_type(g_dict_t, generic_dict, vis.clone(), Const, Some(DICT)); self.register_builtin_type(dict_t, dict_, vis.clone(), Const, Some(DICT)); self.register_builtin_type(mono(BYTES), bytes, vis.clone(), Const, Some(BYTES)); - self.register_builtin_type( - mono(GENERIC_TUPLE), - generic_tuple, - vis.clone(), - Const, - Some(FUNC_TUPLE), - ); - self.register_builtin_type( - homo_tuple_t, - homo_tuple, - vis.clone(), - Const, - Some(FUNC_TUPLE), - ); - self.register_builtin_type(_tuple_t, tuple_, vis.clone(), Const, Some(FUNC_TUPLE)); self.register_builtin_type(mono(RECORD), record, vis.clone(), Const, Some(RECORD)); self.register_builtin_type( mono(RECORD_META_TYPE), @@ -4411,7 +4415,6 @@ impl Context { Some(MEMORYVIEW), ); self.register_builtin_type(mono(MUT_FILE), file_mut, vis.clone(), Const, Some(FILE)); - self.register_builtin_type(list_mut_t, list_mut_, vis.clone(), Const, Some(LIST)); self.register_builtin_type( bytearray_mut_t, bytearray_mut, diff --git a/crates/erg_compiler/context/initialize/traits.rs b/crates/erg_compiler/context/initialize/traits.rs index 40bad3696..ba94cb30b 100644 --- a/crates/erg_compiler/context/initialize/traits.rs +++ b/crates/erg_compiler/context/initialize/traits.rs @@ -588,10 +588,10 @@ impl Context { neg.register_builtin_erg_decl(OP_NEG, op_t, Visibility::BUILTIN_PUBLIC); neg.register_builtin_erg_decl(OUTPUT, Type, Visibility::BUILTIN_PUBLIC); /* Num */ - let mut num = Self::builtin_mono_trait(NUM, 2); - num.register_superclass(poly(ADD, vec![]), &add); - num.register_superclass(poly(SUB, vec![]), &sub); - num.register_superclass(poly(MUL, vec![]), &mul); + let num = Self::builtin_mono_trait(NUM, 2); + // num.register_superclass(poly(ADD, vec![]), &add); + // num.register_superclass(poly(SUB, vec![]), &sub); + // num.register_superclass(poly(MUL, vec![]), &mul); /* ToBool */ let mut to_bool = Self::builtin_mono_trait(TO_BOOL, 2); let _Slf = mono_q(SELF, subtypeof(mono(TO_BOOL))); diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 6a6e32b1a..30be2cfc8 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -1046,35 +1046,35 @@ impl Context { } Type::Structural(t) => self.get_attr_info_from_attributive(t, ident, namespace), // TODO: And - Type::Or(l, r) => { - let l_info = self.get_attr_info_from_attributive(l, ident, namespace); - let r_info = self.get_attr_info_from_attributive(r, ident, namespace); - match (l_info, r_info) { - (Triple::Ok(l), Triple::Ok(r)) => { - let res = self.union(&l.t, &r.t); - let vis = if l.vis.is_public() && r.vis.is_public() { - Visibility::DUMMY_PUBLIC - } else { - Visibility::DUMMY_PRIVATE - }; - let vi = VarInfo::new( - res, - l.muty, - vis, - l.kind, - l.comptime_decos, - l.ctx, - l.py_name, - l.def_loc, - ); - Triple::Ok(vi) - } - (Triple::Ok(_), Triple::Err(e)) | (Triple::Err(e), Triple::Ok(_)) => { - Triple::Err(e) + Type::Or(tys) => { + let mut info = Triple::::None; + for ty in tys { + match ( + self.get_attr_info_from_attributive(ty, ident, namespace), + &info, + ) { + (Triple::Ok(vi), Triple::Ok(vi_)) => { + let res = self.union(&vi.t, &vi_.t); + let vis = if vi.vis.is_public() && vi_.vis.is_public() { + Visibility::DUMMY_PUBLIC + } else { + Visibility::DUMMY_PRIVATE + }; + let vi = VarInfo { t: res, vis, ..vi }; + info = Triple::Ok(vi); + } + (Triple::Ok(vi), Triple::None) => { + info = Triple::Ok(vi); + } + (Triple::Err(err), _) => { + info = Triple::Err(err); + break; + } + (Triple::None, _) => {} + (_, Triple::Err(_)) => unreachable!(), } - (Triple::Err(e1), Triple::Err(_e2)) => Triple::Err(e1), - _ => Triple::None, } + info } _other => Triple::None, } @@ -1952,7 +1952,7 @@ impl Context { res } } - Type::And(_, _) => { + Type::And(_) => { let instance = self.resolve_overload( obj, instance.clone(), @@ -3012,32 +3012,30 @@ impl Context { self.get_nominal_super_type_ctxs(&Type) } } - Type::And(l, r) => { - match ( - self.get_nominal_super_type_ctxs(l), - self.get_nominal_super_type_ctxs(r), - ) { - // TODO: sort - (Some(l), Some(r)) => Some([l, r].concat()), - (Some(l), None) => Some(l), - (None, Some(r)) => Some(r), - (None, None) => None, - } - } - // TODO - Type::Or(l, r) => match (l.as_ref(), r.as_ref()) { - (Type::FreeVar(l), Type::FreeVar(r)) - if l.is_unbound_and_sandwiched() && r.is_unbound_and_sandwiched() => + Type::And(tys) => { + let mut acc = vec![]; + for ctxs in tys + .iter() + .filter_map(|t| self.get_nominal_super_type_ctxs(t)) { - let (_lsub, lsup) = l.get_subsup().unwrap(); - let (_rsub, rsup) = r.get_subsup().unwrap(); - self.get_nominal_super_type_ctxs(&self.union(&lsup, &rsup)) + acc.extend(ctxs); } - (Type::Refinement(l), Type::Refinement(r)) if l.t == r.t => { - self.get_nominal_super_type_ctxs(&l.t) + if acc.is_empty() { + None + } else { + Some(acc) } - _ => self.get_nominal_type_ctx(&Obj).map(|ctx| vec![ctx]), - }, + } + Type::Or(tys) => { + let union = tys + .iter() + .fold(Never, |l, r| self.union(&l, &r.upper_bounded())); + if union.is_union_type() { + self.get_nominal_super_type_ctxs(&Obj) + } else { + self.get_nominal_super_type_ctxs(&union) + } + } _ => self .get_simple_nominal_super_type_ctxs(t) .map(|ctxs| ctxs.collect()), @@ -3231,7 +3229,7 @@ impl Context { .unwrap_or(self) .rec_local_get_mono_type("GenericNamedTuple"); } - Type::Or(_l, _r) => { + Type::Or(_) => { if let Some(ctx) = self.get_nominal_type_ctx(&poly("Or", vec![])) { return Some(ctx); } @@ -3366,26 +3364,27 @@ impl Context { match trait_ { // And(Add, Sub) == intersection({Int <: Add(Int), Bool <: Add(Bool) ...}, {Int <: Sub(Int), ...}) // == {Int <: Add(Int) and Sub(Int), ...} - Type::And(l, r) => { - let l_impls = self.get_trait_impls(l); - let l_base = Set::from_iter(l_impls.iter().map(|ti| &ti.sub_type)); - let r_impls = self.get_trait_impls(r); - let r_base = Set::from_iter(r_impls.iter().map(|ti| &ti.sub_type)); - let bases = l_base.intersection(&r_base); + Type::And(tys) => { + let impls = tys + .iter() + .flat_map(|ty| self.get_trait_impls(ty)) + .collect::>(); + let bases = impls.iter().map(|ti| &ti.sub_type); let mut isec = set! {}; - for base in bases.into_iter() { - let lti = l_impls.iter().find(|ti| &ti.sub_type == base).unwrap(); - let rti = r_impls.iter().find(|ti| &ti.sub_type == base).unwrap(); - let sup_trait = self.intersection(<i.sup_trait, &rti.sup_trait); - isec.insert(TraitImpl::new(lti.sub_type.clone(), sup_trait, None)); + for base in bases { + let base_impls = impls.iter().filter(|ti| ti.sub_type == *base); + let sup_trait = + base_impls.fold(Obj, |l, r| self.intersection(&l, &r.sup_trait)); + if sup_trait != Obj { + isec.insert(TraitImpl::new(base.clone(), sup_trait, None)); + } } isec } - Type::Or(l, r) => { - let l_impls = self.get_trait_impls(l); - let r_impls = self.get_trait_impls(r); + Type::Or(tys) => { // FIXME: - l_impls.union(&r_impls) + tys.iter() + .fold(set! {}, |acc, ty| acc.union(&self.get_trait_impls(ty))) } _ => self.get_simple_trait_impls(trait_), } @@ -3955,11 +3954,11 @@ impl Context { pub fn is_class(&self, typ: &Type) -> bool { match typ { - Type::And(_l, _r) => false, + Type::And(_) => false, Type::Never => true, Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()), Type::FreeVar(_) => false, - Type::Or(l, r) => self.is_class(l) && self.is_class(r), + Type::Or(tys) => tys.iter().all(|t| self.is_class(t)), Type::Proj { lhs, rhs } => self .get_proj_candidates(lhs, rhs) .iter() @@ -3982,7 +3981,8 @@ impl Context { Type::Never => false, Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()), Type::FreeVar(_) => false, - Type::And(l, r) | Type::Or(l, r) => self.is_trait(l) && self.is_trait(r), + Type::And(tys) => tys.iter().any(|t| self.is_trait(t)), + Type::Or(tys) => tys.iter().all(|t| self.is_trait(t)), Type::Proj { lhs, rhs } => self .get_proj_candidates(lhs, rhs) .iter() diff --git a/crates/erg_compiler/context/instantiate.rs b/crates/erg_compiler/context/instantiate.rs index 42d0c36f6..be4028548 100644 --- a/crates/erg_compiler/context/instantiate.rs +++ b/crates/erg_compiler/context/instantiate.rs @@ -953,15 +953,21 @@ impl Context { let t = self.instantiate_t_inner(*t, tmp_tv_cache, loc)?; Ok(t.structuralize()) } - And(l, r) => { - let l = self.instantiate_t_inner(*l, tmp_tv_cache, loc)?; - let r = self.instantiate_t_inner(*r, tmp_tv_cache, loc)?; - Ok(self.intersection(&l, &r)) + And(tys) => { + let mut new_tys = vec![]; + for ty in tys.iter().cloned() { + new_tys.push(self.instantiate_t_inner(ty, tmp_tv_cache, loc)?); + } + Ok(new_tys + .into_iter() + .fold(Obj, |l, r| self.intersection(&l, &r))) } - Or(l, r) => { - let l = self.instantiate_t_inner(*l, tmp_tv_cache, loc)?; - let r = self.instantiate_t_inner(*r, tmp_tv_cache, loc)?; - Ok(self.union(&l, &r)) + Or(tys) => { + let mut new_tys = vec![]; + for ty in tys.iter().cloned() { + new_tys.push(self.instantiate_t_inner(ty, tmp_tv_cache, loc)?); + } + Ok(new_tys.into_iter().fold(Never, |l, r| self.union(&l, &r))) } Not(ty) => { let ty = self.instantiate_t_inner(*ty, tmp_tv_cache, loc)?; @@ -998,10 +1004,12 @@ impl Context { let t = fv.crack().clone(); self.instantiate(t, callee) } - And(lhs, rhs) => { - let lhs = self.instantiate(*lhs, callee)?; - let rhs = self.instantiate(*rhs, callee)?; - Ok(lhs & rhs) + And(tys) => { + let tys = tys + .into_iter() + .map(|t| self.instantiate(t, callee)) + .collect::>>()?; + Ok(tys.into_iter().fold(Obj, |l, r| l & r)) } Quantified(quant) => { let mut tmp_tv_cache = TyVarCache::new(self.level, self); @@ -1028,22 +1036,16 @@ impl Context { )?; } } - Type::And(l, r) => { - if let Some(self_t) = l.self_t() { - self.sub_unify( - callee.ref_t(), - self_t, - callee, - Some(&Str::ever("self")), - )?; - } - if let Some(self_t) = r.self_t() { - self.sub_unify( - callee.ref_t(), - self_t, - callee, - Some(&Str::ever("self")), - )?; + Type::And(tys) => { + for ty in tys { + if let Some(self_t) = ty.self_t() { + self.sub_unify( + callee.ref_t(), + self_t, + callee, + Some(&Str::ever("self")), + )?; + } } } other => unreachable!("{other}"), @@ -1066,10 +1068,12 @@ impl Context { let t = fv.crack().clone(); self.instantiate_dummy(t) } - And(lhs, rhs) => { - let lhs = self.instantiate_dummy(*lhs)?; - let rhs = self.instantiate_dummy(*rhs)?; - Ok(lhs & rhs) + And(tys) => { + let tys = tys + .into_iter() + .map(|t| self.instantiate_dummy(t)) + .collect::>>()?; + Ok(tys.into_iter().fold(Obj, |l, r| l & r)) } Quantified(quant) => { let mut tmp_tv_cache = TyVarCache::new(self.level, self); diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index cb3987259..fea7e0dcc 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -70,6 +70,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { /// occur(?T(<: Str) or ?U(<: Int), ?T(<: Str)) ==> Error /// occur(?T(<: ?U or Y), ?U) ==> OK /// occur(?T, ?T.Output) ==> OK + /// occur(?T, ?T or Int) ==> Error /// ``` fn occur(&self, maybe_sub: &Type, maybe_sup: &Type) -> TyCheckResult<()> { if maybe_sub == maybe_sup { @@ -155,17 +156,71 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } - (Or(l, r), Or(l2, r2)) | (And(l, r), And(l2, r2)) => self - .occur(l, l2) - .and(self.occur(r, r2)) - .or(self.occur(l, r2).and(self.occur(r, l2))), - (lhs, Or(l, r)) | (lhs, And(l, r)) => { - self.occur_inner(lhs, l)?; - self.occur_inner(lhs, r) + // FIXME: This is not correct, we must visit all permutations of the types + (And(l), And(r)) if l.len() == r.len() => { + let mut r = r.clone(); + for _ in 0..r.len() { + if l.iter() + .zip(r.iter()) + .all(|(l, r)| self.occur_inner(l, r).is_ok()) + { + return Ok(()); + } + r.rotate_left(1); + } + Err(TyCheckErrors::from(TyCheckError::subtyping_error( + self.ctx.cfg.input.clone(), + line!() as usize, + maybe_sub, + maybe_sup, + self.loc.loc(), + self.ctx.caused_by(), + ))) } - (Or(l, r), rhs) | (And(l, r), rhs) => { - self.occur_inner(l, rhs)?; - self.occur_inner(r, rhs) + (Or(l), Or(r)) if l.len() == r.len() => { + let l = l.to_vec(); + let mut r = r.to_vec(); + for _ in 0..r.len() { + if l.iter() + .zip(r.iter()) + .all(|(l, r)| self.occur_inner(l, r).is_ok()) + { + return Ok(()); + } + r.rotate_left(1); + } + Err(TyCheckErrors::from(TyCheckError::subtyping_error( + self.ctx.cfg.input.clone(), + line!() as usize, + maybe_sub, + maybe_sup, + self.loc.loc(), + self.ctx.caused_by(), + ))) + } + (lhs, And(tys)) => { + for ty in tys.iter() { + self.occur_inner(lhs, ty)?; + } + Ok(()) + } + (lhs, Or(tys)) => { + for ty in tys.iter() { + self.occur_inner(lhs, ty)?; + } + Ok(()) + } + (And(tys), rhs) => { + for ty in tys.iter() { + self.occur_inner(ty, rhs)?; + } + Ok(()) + } + (Or(tys), rhs) => { + for ty in tys.iter() { + self.occur_inner(ty, rhs)?; + } + Ok(()) } _ => Ok(()), } @@ -266,13 +321,29 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } - (lhs, Or(l, r)) | (lhs, And(l, r)) => { - self.occur_inner(lhs, l)?; - self.occur_inner(lhs, r) + (lhs, And(tys)) => { + for ty in tys.iter() { + self.occur_inner(lhs, ty)?; + } + Ok(()) + } + (lhs, Or(tys)) => { + for ty in tys.iter() { + self.occur_inner(lhs, ty)?; + } + Ok(()) } - (Or(l, r), rhs) | (And(l, r), rhs) => { - self.occur_inner(l, rhs)?; - self.occur_inner(r, rhs) + (And(tys), rhs) => { + for ty in tys.iter() { + self.occur_inner(ty, rhs)?; + } + Ok(()) + } + (Or(tys), rhs) => { + for ty in tys.iter() { + self.occur_inner(ty, rhs)?; + } + Ok(()) } _ => Ok(()), } @@ -1227,38 +1298,66 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { // self.sub_unify(&lsub, &union, loc, param_name)?; maybe_sup.update_tyvar(union, intersec, self.undoable, false); } + // TODO: Preferentially compare same-structure types (e.g. K(?T) <: K(?U)) + (And(ltys), And(rtys)) => { + let mut ltys_ = ltys.clone(); + let mut rtys_ = rtys.clone(); + // Show and EqHash and T <: Eq and Show and Ord + // => EqHash and T <: Eq and Ord + for lty in ltys.iter() { + if let Some(idx) = rtys_.iter().position(|r| r == lty) { + rtys_.remove(idx); + let idx = ltys_.iter().position(|l| l == lty).unwrap(); + ltys_.remove(idx); + } + } + // EqHash and T <: Eq and Ord + for lty in ltys_.iter() { + // lty: EqHash + // rty: Eq, Ord + for rty in rtys_.iter() { + if self.ctx.subtype_of(lty, rty) { + self.sub_unify(lty, rty)?; + continue; + } + } + } + } + // TODO: Preferentially compare same-structure types (e.g. K(?T) <: K(?U)) + // Nat or Str or NoneType <: NoneType or ?T or Int + // => Str <: ?T // (Int or ?T) <: (?U or Int) // OK: (Int <: Int); (?T <: ?U) // NG: (Int <: ?U); (?T <: Int) - (Or(l1, r1), Or(l2, r2)) | (And(l1, r1), And(l2, r2)) => { - if self.ctx.subtype_of(l1, l2) && self.ctx.subtype_of(r1, r2) { - let (l_sup, r_sup) = if !l1.is_unbound_var() - && !r2.is_unbound_var() - && self.ctx.subtype_of(l1, r2) - { - (r2, l2) - } else { - (l2, r2) - }; - self.sub_unify(l1, l_sup)?; - self.sub_unify(r1, r_sup)?; - } else { - self.sub_unify(l1, r2)?; - self.sub_unify(r1, l2)?; + (Or(ltys), Or(rtys)) => { + let mut ltys_ = ltys.clone(); + let mut rtys_ = rtys.clone(); + // Nat or T or Str <: Str or Int or NoneType + // => Nat or T <: Int or NoneType + for lty in ltys { + if rtys_.linear_remove(lty) { + ltys_.linear_remove(lty); + } + } + // Nat or T <: Int or NoneType + for lty in ltys_.iter() { + // lty: Nat + // rty: Int, NoneType + for rty in rtys_.iter() { + if self.ctx.subtype_of(lty, rty) { + self.sub_unify(lty, rty)?; + continue; + } + } } } // NG: Nat <: ?T or Int ==> Nat or Int (?T = Nat) // OK: Nat <: ?T or Int ==> ?T or Int - (sub, Or(l, r)) - if l.is_unbound_var() - && !sub.is_unbound_var() - && !r.is_unbound_var() - && self.ctx.subtype_of(sub, r) => {} - (sub, Or(l, r)) - if r.is_unbound_var() - && !sub.is_unbound_var() - && !l.is_unbound_var() - && self.ctx.subtype_of(sub, l) => {} + (sub, Or(tys)) + if !sub.is_unbound_var() + && tys + .iter() + .any(|ty| !ty.is_unbound_var() && self.ctx.subtype_of(sub, ty)) => {} // e.g. Structural({ .method = (self: T) -> Int })/T (Structural(sub), FreeVar(sup_fv)) if sup_fv.is_unbound() && sub.contains_tvar(sup_fv) => {} @@ -1622,30 +1721,34 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } } // (X or Y) <: Z is valid when X <: Z and Y <: Z - (Or(l, r), _) => { - self.sub_unify(l, maybe_sup)?; - self.sub_unify(r, maybe_sup)?; + (Or(tys), _) => { + for ty in tys { + self.sub_unify(ty, maybe_sup)?; + } } // X <: (Y and Z) is valid when X <: Y and X <: Z - (_, And(l, r)) => { - self.sub_unify(maybe_sub, l)?; - self.sub_unify(maybe_sub, r)?; + (_, And(tys)) => { + for ty in tys { + self.sub_unify(maybe_sub, ty)?; + } } // (X and Y) <: Z is valid when X <: Z or Y <: Z - (And(l, r), _) => { - if self.ctx.subtype_of(l, maybe_sup) { - self.sub_unify(l, maybe_sup)?; - } else { - self.sub_unify(r, maybe_sup)?; + (And(tys), _) => { + for ty in tys { + if self.ctx.subtype_of(ty, maybe_sup) { + return self.sub_unify(ty, maybe_sup); + } } + self.sub_unify(tys.iter().next().unwrap(), maybe_sup)?; } // X <: (Y or Z) is valid when X <: Y or X <: Z - (_, Or(l, r)) => { - if self.ctx.subtype_of(maybe_sub, l) { - self.sub_unify(maybe_sub, l)?; - } else { - self.sub_unify(maybe_sub, r)?; + (_, Or(tys)) => { + for ty in tys { + if self.ctx.subtype_of(maybe_sub, ty) { + return self.sub_unify(maybe_sub, ty); + } } + self.sub_unify(maybe_sub, tys.iter().next().unwrap())?; } (Ref(sub), Ref(sup)) => { self.sub_unify(sub, sup)?; @@ -1887,27 +1990,35 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { /// ``` fn unify(&self, lhs: &Type, rhs: &Type) -> Option { match (lhs, rhs) { - (Type::Or(l, r), other) | (other, Type::Or(l, r)) => { - if let Some(t) = self.unify(l, other) { - return self.unify(&t, l); - } else if let Some(t) = self.unify(r, other) { - return self.unify(&t, l); + (Never, other) | (other, Never) => { + return Some(other.clone()); + } + (Or(tys), other) | (other, Or(tys)) => { + let mut unified = Never; + for ty in tys { + if let Some(t) = self.unify(ty, other) { + unified = self.ctx.union(&unified, &t); + } + } + if unified != Never { + return Some(unified); + } else { + return None; } - return None; } - (Type::FreeVar(fv), _) if fv.is_linked() => return self.unify(&fv.crack(), rhs), - (_, Type::FreeVar(fv)) if fv.is_linked() => return self.unify(lhs, &fv.crack()), + (FreeVar(fv), _) if fv.is_linked() => return self.unify(&fv.crack(), rhs), + (_, FreeVar(fv)) if fv.is_linked() => return self.unify(lhs, &fv.crack()), // TODO: unify(?T, ?U) ? - (Type::FreeVar(_), Type::FreeVar(_)) => {} - (Type::FreeVar(fv), _) if fv.constraint_is_sandwiched() => { + (FreeVar(_), FreeVar(_)) => {} + (FreeVar(fv), _) if fv.constraint_is_sandwiched() => { let sub = fv.get_sub()?; return self.unify(&sub, rhs); } - (_, Type::FreeVar(fv)) if fv.constraint_is_sandwiched() => { + (_, FreeVar(fv)) if fv.constraint_is_sandwiched() => { let sub = fv.get_sub()?; return self.unify(lhs, &sub); } - (Type::Refinement(lhs), Type::Refinement(rhs)) => { + (Refinement(lhs), Refinement(rhs)) => { if let Some(_union) = self.unify(&lhs.t, &rhs.t) { return Some(self.ctx.union_refinement(lhs, rhs).into()); } @@ -1917,11 +2028,11 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { let l_sups = self.ctx.get_super_classes(lhs)?; let r_sups = self.ctx.get_super_classes(rhs)?; for l_sup in l_sups { - if self.ctx.supertype_of(&l_sup, &Obj) { + if l_sup == Obj || self.ctx.is_trait(&l_sup) { continue; } for r_sup in r_sups.clone() { - if self.ctx.supertype_of(&r_sup, &Obj) { + if r_sup == Obj || self.ctx.is_trait(&r_sup) { continue; } if let Some(t) = self.ctx.max(&l_sup, &r_sup).either() { diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 04d161b48..5d6724137 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -362,10 +362,9 @@ impl GenericASTLowerer { } } - fn elem_err(&self, l: &Type, r: &Type, elem: &hir::Expr) -> LowerErrors { + fn elem_err(&self, union: Type, elem: &hir::Expr) -> LowerErrors { let elem_disp_notype = elem.to_string_notype(); - let l = self.module.context.readable_type(l.clone()); - let r = self.module.context.readable_type(r.clone()); + let union = self.module.context.readable_type(union); LowerErrors::from(LowerError::syntax_error( self.cfg.input.clone(), line!() as usize, @@ -379,10 +378,10 @@ impl GenericASTLowerer { ) .to_owned(), Some(switch_lang!( - "japanese" => format!("[..., {elem_disp_notype}: {l} or {r}]など明示的に型を指定してください"), - "simplified_chinese" => format!("请明确指定类型,例如: [..., {elem_disp_notype}: {l} or {r}]"), - "traditional_chinese" => format!("請明確指定類型,例如: [..., {elem_disp_notype}: {l} or {r}]"), - "english" => format!("please specify the type explicitly, e.g. [..., {elem_disp_notype}: {l} or {r}]"), + "japanese" => format!("[..., {elem_disp_notype}: {union}]など明示的に型を指定してください"), + "simplified_chinese" => format!("请明确指定类型,例如: [..., {elem_disp_notype}: {union}]"), + "traditional_chinese" => format!("請明確指定類型,例如: [..., {elem_disp_notype}: {union}]"), + "english" => format!("please specify the type explicitly, e.g. [..., {elem_disp_notype}: {union}]"), )), )) } @@ -453,36 +452,25 @@ impl GenericASTLowerer { union: &Type, elem: &hir::Expr, ) -> LowerResult<()> { - if ERG_MODE && expect_elem.is_none() { - if let Some((l, r)) = union_.union_pair() { - match (l.is_unbound_var(), r.is_unbound_var()) { - // e.g. [1, "a"] - (false, false) => { - if let hir::Expr::TypeAsc(type_asc) = elem { - // e.g. [1, "a": Str or NoneType] - if !self - .module - .context - .supertype_of(&type_asc.spec.spec_t, union) - { - return Err(self.elem_err(&l, &r, elem)); - } // else(OK): e.g. [1, "a": Str or Int] - } - // OK: ?T(:> {"a"}) or ?U(:> {"b"}) or {"c", "d"} => {"a", "b", "c", "d"} <: Str - else if self - .module - .context - .coerce(union_.derefine(), &()) - .map_or(true, |coerced| coerced.union_pair().is_some()) - { - return Err(self.elem_err(&l, &r, elem)); - } - } - // TODO: check if the type is compatible with the other type - (true, false) => {} - (false, true) => {} - (true, true) => {} - } + if ERG_MODE && expect_elem.is_none() && union_.union_size() > 1 { + if let hir::Expr::TypeAsc(type_asc) = elem { + // e.g. [1, "a": Str or NoneType] + if !self + .module + .context + .supertype_of(&type_asc.spec.spec_t, union) + { + return Err(self.elem_err(union_.clone(), elem)); + } // else(OK): e.g. [1, "a": Str or Int] + } + // OK: ?T(:> {"a"}) or ?U(:> {"b"}) or {"c", "d"} => {"a", "b", "c", "d"} <: Str + else if self + .module + .context + .coerce(union_.derefine(), &()) + .map_or(true, |coerced| coerced.union_pair().is_some()) + { + return Err(self.elem_err(union_.clone(), elem)); } } Ok(()) @@ -1502,9 +1490,10 @@ impl GenericASTLowerer { } _ => {} }, - Type::And(lhs, rhs) => { - self.push_guard(nth, kind, lhs); - self.push_guard(nth, kind, rhs); + Type::And(tys) => { + for ty in tys { + self.push_guard(nth, kind, ty); + } } _ => {} } diff --git a/crates/erg_compiler/module/promise.rs b/crates/erg_compiler/module/promise.rs index 42525989e..e57b5ea68 100644 --- a/crates/erg_compiler/module/promise.rs +++ b/crates/erg_compiler/module/promise.rs @@ -1,7 +1,7 @@ use std::fmt; use std::thread::{current, JoinHandle, ThreadId}; -use erg_common::consts::DEBUG_MODE; +use erg_common::consts::{DEBUG_MODE, SINGLE_THREAD}; use erg_common::dict::Dict; use erg_common::pathutil::NormalizedPathBuf; use erg_common::shared::Shared; @@ -169,12 +169,19 @@ impl SharedPromises { } pub fn join(&self, path: &NormalizedPathBuf) -> std::thread::Result<()> { + if !self.graph.entries().contains(path) { + return Err(Box::new(format!("not registered: {path}"))); + } if self.graph.ancestors(path).contains(&self.root) { // cycle detected, `self.path` must not in the dependencies // Erg analysis processes never join ancestor threads (although joining ancestors itself is allowed in Rust) // self.wait_until_finished(path); return Ok(()); } + if SINGLE_THREAD { + assert!(self.is_joined(path)); + return Ok(()); + } // Suppose A depends on B and C, and B depends on C. // In this case, B must join C before A joins C. Otherwise, a deadlock will occur. let children = self.graph.children(path); diff --git a/crates/erg_compiler/ty/constructors.rs b/crates/erg_compiler/ty/constructors.rs index 3e78dc33c..eb9ffd61d 100644 --- a/crates/erg_compiler/ty/constructors.rs +++ b/crates/erg_compiler/ty/constructors.rs @@ -593,35 +593,11 @@ pub fn refinement(var: Str, t: Type, pred: Predicate) -> Type { } pub fn and(lhs: Type, rhs: Type) -> Type { - match (lhs, rhs) { - (Type::And(l, r), other) | (other, Type::And(l, r)) => { - if l.as_ref() == &other { - and(*r, other) - } else if r.as_ref() == &other { - and(*l, other) - } else { - Type::And(Box::new(Type::And(l, r)), Box::new(other)) - } - } - (Type::Obj, other) | (other, Type::Obj) => other, - (lhs, rhs) => Type::And(Box::new(lhs), Box::new(rhs)), - } + lhs & rhs } pub fn or(lhs: Type, rhs: Type) -> Type { - match (lhs, rhs) { - (Type::Or(l, r), other) | (other, Type::Or(l, r)) => { - if l.as_ref() == &other { - or(*r, other) - } else if r.as_ref() == &other { - or(*l, other) - } else { - Type::Or(Box::new(Type::Or(l, r)), Box::new(other)) - } - } - (Type::Never, other) | (other, Type::Never) => other, - (lhs, rhs) => Type::Or(Box::new(lhs), Box::new(rhs)), - } + lhs | rhs } pub fn ors(tys: impl IntoIterator) -> Type { diff --git a/crates/erg_compiler/ty/free.rs b/crates/erg_compiler/ty/free.rs index 4cb6d0cf1..e9cad3bfb 100644 --- a/crates/erg_compiler/ty/free.rs +++ b/crates/erg_compiler/ty/free.rs @@ -774,7 +774,12 @@ impl Free { let placeholder = placeholder.unwrap_or(&Type::Failure); let is_recursive = self.is_recursive(); if is_recursive { - self.undoable_link(placeholder); + let target = Type::FreeVar(self.clone()); + let placeholder_ = placeholder + .clone() + .eliminate_subsup(&target) + .eliminate_and_or_recursion(&target); + self.undoable_link(&placeholder_); } let res = f(); if is_recursive { @@ -884,7 +889,9 @@ impl Free { let placeholder = placeholder.unwrap_or(&TyParam::Failure); let is_recursive = self.is_recursive(); if is_recursive { - self.undoable_link(placeholder); + let target = TyParam::FreeVar(self.clone()); + let placeholder_ = placeholder.clone().eliminate_recursion(&target); + self.undoable_link(&placeholder_); } let res = f(); if is_recursive { diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index 32195de6e..32f782779 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -537,48 +537,15 @@ impl SubrType { } pub fn contains_tvar(&self, target: &FreeTyVar) -> bool { - self.non_default_params - .iter() - .any(|pt| pt.typ().contains_tvar(target)) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().contains_tvar(target)) - || self.default_params.iter().any(|pt| { - pt.typ().contains_tvar(target) - || pt.default_typ().is_some_and(|t| t.contains_tvar(target)) - }) - || self.return_t.contains_tvar(target) + self.has_type_satisfies(|t| t.contains_tvar(target)) } pub fn contains_type(&self, target: &Type) -> bool { - self.non_default_params - .iter() - .any(|pt| pt.typ().contains_type(target)) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().contains_type(target)) - || self.default_params.iter().any(|pt| { - pt.typ().contains_type(target) - || pt.default_typ().is_some_and(|t| t.contains_type(target)) - }) - || self.return_t.contains_type(target) + self.has_type_satisfies(|t| t.contains_type(target)) } pub fn contains_tp(&self, target: &TyParam) -> bool { - self.non_default_params - .iter() - .any(|pt| pt.typ().contains_tp(target)) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().contains_tp(target)) - || self.default_params.iter().any(|pt| { - pt.typ().contains_tp(target) - || pt.default_typ().is_some_and(|t| t.contains_tp(target)) - }) - || self.return_t.contains_tp(target) + self.has_type_satisfies(|t| t.contains_tp(target)) } pub fn map(self, f: &mut impl FnMut(Type) -> Type) -> Self { @@ -708,48 +675,27 @@ impl SubrType { Set::multi_intersection(qnames_sets).extended(structural_qname) } - pub fn has_qvar(&self) -> bool { - self.non_default_params.iter().any(|pt| pt.typ().has_qvar()) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().has_qvar()) + pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool { + self.non_default_params.iter().any(|pt| f(pt.typ())) + || self.var_params.as_ref().map_or(false, |pt| f(pt.typ())) || self .default_params .iter() - .any(|pt| pt.typ().has_qvar() || pt.default_typ().is_some_and(|t| t.has_qvar())) - || self.return_t.has_qvar() + .any(|pt| f(pt.typ()) || pt.default_typ().is_some_and(f)) + || self.kw_var_params.as_ref().map_or(false, |pt| f(pt.typ())) + || f(&self.return_t) + } + + pub fn has_qvar(&self) -> bool { + self.has_type_satisfies(|t| t.has_qvar()) } pub fn has_unbound_var(&self) -> bool { - self.non_default_params - .iter() - .any(|pt| pt.typ().has_unbound_var()) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().has_unbound_var()) - || self.default_params.iter().any(|pt| { - pt.typ().has_unbound_var() || pt.default_typ().is_some_and(|t| t.has_unbound_var()) - }) - || self.return_t.has_unbound_var() + self.has_type_satisfies(|t| t.has_unbound_var()) } pub fn has_undoable_linked_var(&self) -> bool { - self.non_default_params - .iter() - .any(|pt| pt.typ().has_undoable_linked_var()) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().has_undoable_linked_var()) - || self.default_params.iter().any(|pt| { - pt.typ().has_undoable_linked_var() - || pt - .default_typ() - .is_some_and(|t| t.has_undoable_linked_var()) - }) - || self.return_t.has_undoable_linked_var() + self.has_type_satisfies(|t| t.has_undoable_linked_var()) } pub fn typarams(&self) -> Vec { @@ -1410,8 +1356,8 @@ pub enum Type { Refinement(RefinementType), // e.g. |T: Type| T -> T Quantified(Box), - And(Box, Box), - Or(Box, Box), + And(Vec), + Or(Set), Not(Box), // NOTE: It was found that adding a new variant above `Poly` may cause a subtyping bug, // possibly related to enum internal numbering, but the cause is unknown. @@ -1504,8 +1450,10 @@ impl PartialEq for Type { (Self::NamedTuple(lhs), Self::NamedTuple(rhs)) => lhs == rhs, (Self::Refinement(l), Self::Refinement(r)) => l == r, (Self::Quantified(l), Self::Quantified(r)) => l == r, - (Self::And(_, _), Self::And(_, _)) => self.ands().linear_eq(&other.ands()), - (Self::Or(_, _), Self::Or(_, _)) => self.ors().linear_eq(&other.ors()), + (Self::And(l), Self::And(r)) => { + l.iter().collect::>().linear_eq(&r.iter().collect()) + } + (Self::Or(l), Self::Or(r)) => l.linear_eq(r), (Self::Not(l), Self::Not(r)) => l == r, ( Self::Poly { @@ -1659,20 +1607,28 @@ impl LimitedDisplay for Type { write!(f, "|")?; quantified.limited_fmt(f, limit - 1) } - Self::And(lhs, rhs) => { - lhs.limited_fmt(f, limit - 1)?; - write!(f, " and ")?; - rhs.limited_fmt(f, limit - 1) + Self::And(ands) => { + for (i, ty) in ands.iter().enumerate() { + if i > 0 { + write!(f, " and ")?; + } + ty.limited_fmt(f, limit - 1)?; + } + Ok(()) + } + Self::Or(ors) => { + for (i, ty) in ors.iter().enumerate() { + if i > 0 { + write!(f, " or ")?; + } + ty.limited_fmt(f, limit - 1)?; + } + Ok(()) } Self::Not(ty) => { write!(f, "not ")?; ty.limited_fmt(f, limit - 1) } - Self::Or(lhs, rhs) => { - lhs.limited_fmt(f, limit - 1)?; - write!(f, " or ")?; - rhs.limited_fmt(f, limit - 1) - } Self::Poly { name, params } => { write!(f, "{name}(")?; if !DEBUG_MODE && self.is_module() { @@ -1845,14 +1801,40 @@ impl From> for Type { impl BitAnd for Type { type Output = Self; fn bitand(self, rhs: Self) -> Self::Output { - Self::And(Box::new(self), Box::new(rhs)) + match (self, rhs) { + (Self::And(l), Self::And(r)) => Self::And([l, r].concat()), + (Self::Obj, other) | (other, Self::Obj) => other, + (Self::Never, _) | (_, Self::Never) => Self::Never, + (Self::And(mut l), r) => { + l.push(r); + Self::And(l) + } + (l, Self::And(mut r)) => { + r.push(l); + Self::And(r) + } + (l, r) => Self::checked_and(vec![l, r]), + } } } impl BitOr for Type { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { - Self::Or(Box::new(self), Box::new(rhs)) + match (self, rhs) { + (Self::Or(l), Self::Or(r)) => Self::Or(l.union(&r)), + (Self::Obj, _) | (_, Self::Obj) => Self::Obj, + (Self::Never, other) | (other, Self::Never) => other, + (Self::Or(mut l), r) => { + l.insert(r); + Self::Or(l) + } + (l, Self::Or(mut r)) => { + r.insert(l); + Self::Or(r) + } + (l, r) => Self::checked_or(set! {l, r}), + } } } @@ -1967,17 +1949,8 @@ impl HasLevel for Type { .filter_map(|o| *o) .min() } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - let l = lhs - .level() - .unwrap_or(GENERIC_LEVEL) - .min(rhs.level().unwrap_or(GENERIC_LEVEL)); - if l == GENERIC_LEVEL { - None - } else { - Some(l) - } - } + Self::And(tys) => tys.iter().filter_map(|t| t.level()).min(), + Self::Or(tys) => tys.iter().filter_map(|t| t.level()).min(), Self::Not(ty) => ty.level(), Self::Record(attrs) => attrs.values().filter_map(|t| t.level()).min(), Self::NamedTuple(attrs) => attrs.iter().filter_map(|(_, t)| t.level()).min(), @@ -2058,9 +2031,15 @@ impl HasLevel for Type { Self::Quantified(quant) => { quant.set_level(level); } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.set_level(level); - rhs.set_level(level); + Self::And(tys) => { + for t in tys.iter() { + t.set_level(level); + } + } + Self::Or(tys) => { + for t in tys.iter() { + t.set_level(level); + } } Self::Not(ty) => ty.set_level(level), Self::Record(attrs) => { @@ -2211,9 +2190,9 @@ impl StructuralEq for Type { (Self::Guard(l), Self::Guard(r)) => l.structural_eq(r), // NG: (l.structural_eq(l2) && r.structural_eq(r2)) // || (l.structural_eq(r2) && r.structural_eq(l2)) - (Self::And(_, _), Self::And(_, _)) => { - let self_ands = self.ands(); - let other_ands = other.ands(); + (Self::And(self_ands), Self::And(other_ands)) => { + let self_ands = self_ands.iter().collect::>(); + let other_ands = other_ands.iter().collect::>(); if self_ands.len() != other_ands.len() { return false; } @@ -2227,9 +2206,7 @@ impl StructuralEq for Type { } true } - (Self::Or(_, _), Self::Or(_, _)) => { - let self_ors = self.ors(); - let other_ors = other.ors(); + (Self::Or(self_ors), Self::Or(other_ors)) => { if self_ors.len() != other_ors.len() { return false; } @@ -2323,10 +2300,30 @@ impl Type { } } + pub fn checked_or(tys: Set) -> Self { + if tys.is_empty() { + panic!("tys is empty"); + } else if tys.len() == 1 { + tys.into_iter().next().unwrap() + } else { + Self::Or(tys) + } + } + + pub fn checked_and(tys: Vec) -> Self { + if tys.is_empty() { + panic!("tys is empty"); + } else if tys.len() == 1 { + tys.into_iter().next().unwrap() + } else { + Self::And(tys) + } + } + pub fn quantify(self) -> Self { debug_assert!(self.is_subr(), "{self} is not subr"); match self { - Self::And(lhs, rhs) => lhs.quantify() & rhs.quantify(), + Self::And(tys) => Self::And(tys.into_iter().map(|t| t.quantify()).collect()), other => Self::Quantified(Box::new(other)), } } @@ -2424,7 +2421,7 @@ impl Type { Self::Quantified(t) => t.is_procedure(), Self::Subr(subr) if subr.kind == SubrKind::Proc => true, Self::Refinement(refine) => refine.t.is_procedure(), - Self::And(lhs, rhs) => lhs.is_procedure() && rhs.is_procedure(), + Self::And(tys) => tys.iter().any(|t| t.is_procedure()), _ => false, } } @@ -2442,6 +2439,7 @@ impl Type { name.ends_with('!') } Self::Refinement(refine) => refine.t.is_mut_type(), + Self::And(tys) => tys.iter().any(|t| t.is_mut_type()), _ => false, } } @@ -2460,6 +2458,7 @@ impl Type { Self::Poly { name, params, .. } if &name[..] == "Tuple" => params.is_empty(), Self::Refinement(refine) => refine.t.is_nonelike(), Self::Bounded { sup, .. } => sup.is_nonelike(), + Self::And(tys) => tys.iter().any(|t| t.is_nonelike()), _ => false, } } @@ -2509,7 +2508,7 @@ impl Type { pub fn is_union_type(&self) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_union_type(), - Self::Or(_, _) => true, + Self::Or(_) => true, Self::Refinement(refine) => refine.t.is_union_type(), _ => false, } @@ -2542,7 +2541,7 @@ impl Type { pub fn is_intersection_type(&self) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_intersection_type(), - Self::And(_, _) => true, + Self::And(_) => true, Self::Refinement(refine) => refine.t.is_intersection_type(), _ => false, } @@ -2556,11 +2555,11 @@ impl Type { fv.do_avoiding_recursion(|| sub.union_size().max(sup.union_size())) } // Or(Or(Int, Str), Nat) == 3 - Self::Or(l, r) => l.union_size() + r.union_size(), + Self::Or(tys) => tys.len(), Self::Refinement(refine) => refine.t.union_size(), Self::Ref(t) => t.union_size(), Self::RefMut { before, after: _ } => before.union_size(), - Self::And(lhs, rhs) => lhs.union_size().max(rhs.union_size()), + Self::And(tys) => tys.iter().map(|ty| ty.union_size()).max().unwrap_or(1), Self::Not(ty) => ty.union_size(), Self::Callable { param_ts, return_t } => param_ts .iter() @@ -2601,7 +2600,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_refinement(), Self::Refinement(_) => true, - Self::And(l, r) => l.is_refinement() && r.is_refinement(), + Self::And(tys) => tys.iter().any(|t| t.is_refinement()), _ => false, } } @@ -2610,6 +2609,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_singleton_refinement(), Self::Refinement(refine) => matches!(refine.pred.as_ref(), Predicate::Equal { .. }), + Self::And(tys) => tys.iter().any(|t| t.is_singleton_refinement()), _ => false, } } @@ -2619,6 +2619,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_record(), Self::Record(_) => true, Self::Refinement(refine) => refine.t.is_record(), + Self::And(tys) => tys.iter().any(|t| t.is_record()), _ => false, } } @@ -2632,6 +2633,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_erg_module(), Self::Refinement(refine) => refine.t.is_erg_module(), Self::Poly { name, .. } => &name[..] == "Module", + Self::And(tys) => tys.iter().any(|t| t.is_erg_module()), _ => false, } } @@ -2641,6 +2643,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_py_module(), Self::Refinement(refine) => refine.t.is_py_module(), Self::Poly { name, .. } => &name[..] == "PyModule", + Self::And(tys) => tys.iter().any(|t| t.is_py_module()), _ => false, } } @@ -2651,7 +2654,7 @@ impl Type { Self::Refinement(refine) => refine.t.is_method(), Self::Subr(subr) => subr.is_method(), Self::Quantified(quant) => quant.is_method(), - Self::And(l, r) => l.is_method() && r.is_method(), + Self::And(tys) => tys.iter().any(|t| t.is_method()), _ => false, } } @@ -2662,7 +2665,7 @@ impl Type { Self::Subr(_) => true, Self::Quantified(quant) => quant.is_subr(), Self::Refinement(refine) => refine.t.is_subr(), - Self::And(l, r) => l.is_subr() && r.is_subr(), + Self::And(tys) => tys.iter().any(|t| t.is_subr()), _ => false, } } @@ -2673,7 +2676,10 @@ impl Type { Self::Subr(subr) => Some(subr.kind), Self::Refinement(refine) => refine.t.subr_kind(), Self::Quantified(quant) => quant.subr_kind(), - Self::And(l, r) => l.subr_kind().and_then(|k| r.subr_kind().map(|k2| k | k2)), + Self::And(tys) => tys + .iter() + .filter_map(|t| t.subr_kind()) + .fold(None, |a, b| Some(a? | b)), _ => None, } } @@ -2683,7 +2689,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_quantified_subr(), Self::Quantified(_) => true, Self::Refinement(refine) => refine.t.is_quantified_subr(), - Self::And(l, r) => l.is_quantified_subr() && r.is_quantified_subr(), + Self::And(tys) => tys.iter().any(|t| t.is_quantified_subr()), _ => false, } } @@ -2720,6 +2726,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_iterable(), Self::Poly { name, .. } => &name[..] == "Iterable", Self::Refinement(refine) => refine.t.is_iterable(), + Self::And(tys) => tys.iter().any(|t| t.is_iterable()), _ => false, } } @@ -2810,6 +2817,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_structural(), Self::Structural(_) => true, Self::Refinement(refine) => refine.t.is_structural(), + Self::And(tys) => tys.iter().any(|t| t.is_structural()), _ => false, } } @@ -2819,6 +2827,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_failure(), Self::Refinement(refine) => refine.t.is_failure(), Self::Failure => true, + Self::And(tys) => tys.iter().any(|t| t.is_failure()), _ => false, } } @@ -2890,92 +2899,37 @@ impl Type { }) .unwrap_or(false) } - Self::Record(rec) => rec.iter().any(|(_, t)| t.contains_tvar(target)), - Self::NamedTuple(rec) => rec.iter().any(|(_, t)| t.contains_tvar(target)), - Self::Poly { params, .. } => params.iter().any(|tp| tp.contains_tvar(target)), - Self::Quantified(t) => t.contains_tvar(target), - Self::Subr(subr) => subr.contains_tvar(target), - // TODO: preds - Self::Refinement(refine) => refine.t.contains_tvar(target), - Self::Structural(ty) => ty.contains_tvar(target), - Self::Proj { lhs, .. } => lhs.contains_tvar(target), - Self::ProjCall { lhs, args, .. } => { - lhs.contains_tvar(target) || args.iter().any(|t| t.contains_tvar(target)) - } - Self::And(lhs, rhs) => lhs.contains_tvar(target) || rhs.contains_tvar(target), - Self::Or(lhs, rhs) => lhs.contains_tvar(target) || rhs.contains_tvar(target), - Self::Not(t) => t.contains_tvar(target), - Self::Ref(t) => t.contains_tvar(target), - Self::RefMut { before, after } => { - before.contains_tvar(target) - || after.as_ref().map_or(false, |t| t.contains_tvar(target)) - } - Self::Bounded { sub, sup } => sub.contains_tvar(target) || sup.contains_tvar(target), - Self::Callable { param_ts, return_t } => { - param_ts.iter().any(|t| t.contains_tvar(target)) || return_t.contains_tvar(target) - } - Self::Guard(guard) => guard.to.contains_tvar(target), - mono_type_pattern!() => false, + _ => self.has_type_satisfies(|t| t.contains_tvar(target)), } } pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool { match self { - Self::FreeVar(fv) if fv.is_linked() => fv.crack().has_type_satisfies(f), - Self::FreeVar(fv) if fv.constraint_is_typeof() => { - fv.get_type().unwrap().has_type_satisfies(f) - } + Self::FreeVar(fv) if fv.is_linked() => f(&fv.crack()), + Self::FreeVar(fv) if fv.constraint_is_typeof() => f(&fv.get_type().unwrap()), Self::FreeVar(fv) => fv .get_subsup() - .map(|(sub, sup)| { - fv.do_avoiding_recursion(|| { - sub.has_type_satisfies(f) || sup.has_type_satisfies(f) - }) - }) + .map(|(sub, sup)| fv.do_avoiding_recursion(|| f(&sub) || f(&sup))) .unwrap_or(false), - Self::Record(rec) => rec.iter().any(|(_, t)| t.has_type_satisfies(f)), - Self::NamedTuple(rec) => rec.iter().any(|(_, t)| t.has_type_satisfies(f)), + Self::Record(rec) => rec.values().any(f), + Self::NamedTuple(rec) => rec.iter().any(|(_, t)| f(t)), Self::Poly { params, .. } => params.iter().any(|tp| tp.has_type_satisfies(f)), - Self::Quantified(t) => t.has_type_satisfies(f), - Self::Subr(subr) => { - subr.non_default_params - .iter() - .any(|pt| pt.typ().has_type_satisfies(f)) - || subr - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().has_type_satisfies(f)) - || subr - .default_params - .iter() - .any(|pt| pt.typ().has_type_satisfies(f)) - || subr - .default_params - .iter() - .any(|pt| pt.default_typ().map_or(false, |t| t.has_type_satisfies(f))) - || subr.return_t.has_type_satisfies(f) - } - // TODO: preds - Self::Refinement(refine) => refine.t.has_type_satisfies(f), - Self::Structural(ty) => ty.has_type_satisfies(f), - Self::Proj { lhs, .. } => lhs.has_type_satisfies(f), + Self::Quantified(t) => f(t), + Self::Subr(subr) => subr.has_type_satisfies(f), + Self::Refinement(refine) => f(&refine.t) || refine.pred.has_type_satisfies(f), + Self::Structural(ty) => f(ty), + Self::Proj { lhs, .. } => f(lhs), Self::ProjCall { lhs, args, .. } => { - lhs.has_type_satisfies(f) || args.iter().any(|t| t.has_type_satisfies(f)) - } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f) - } - Self::Not(t) => t.has_type_satisfies(f), - Self::Ref(t) => t.has_type_satisfies(f), - Self::RefMut { before, after } => { - before.has_type_satisfies(f) - || after.as_ref().map_or(false, |t| t.has_type_satisfies(f)) - } - Self::Bounded { sub, sup } => sub.has_type_satisfies(f) || sup.has_type_satisfies(f), - Self::Callable { param_ts, return_t } => { - param_ts.iter().any(|t| t.has_type_satisfies(f)) || return_t.has_type_satisfies(f) - } - Self::Guard(guard) => guard.to.has_type_satisfies(f), + lhs.has_type_satisfies(f) || args.iter().any(|tp| tp.has_type_satisfies(f)) + } + Self::And(tys) => tys.iter().any(f), + Self::Or(tys) => tys.iter().any(f), + Self::Not(t) => f(t), + Self::Ref(t) => f(t), + Self::RefMut { before, after } => f(before) || after.as_ref().map_or(false, |t| f(t)), + Self::Bounded { sub, sup } => f(sub) || f(sup), + Self::Callable { param_ts, return_t } => param_ts.iter().any(f) || f(return_t), + Self::Guard(guard) => f(&guard.to), mono_type_pattern!() => false, } } @@ -3037,8 +2991,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_type(target) || args.iter().any(|t| t.contains_type(target)) } - Self::And(lhs, rhs) => lhs.contains_type(target) || rhs.contains_type(target), - Self::Or(lhs, rhs) => lhs.contains_type(target) || rhs.contains_type(target), + Self::And(tys) => tys.iter().any(|t| t.contains_type(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_type(target)), Self::Not(t) => t.contains_type(target), Self::Ref(t) => t.contains_type(target), Self::RefMut { before, after } => { @@ -3075,8 +3029,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_tp(target) || args.iter().any(|t| t.contains_tp(target)) } - Self::And(lhs, rhs) => lhs.contains_tp(target) || rhs.contains_tp(target), - Self::Or(lhs, rhs) => lhs.contains_tp(target) || rhs.contains_tp(target), + Self::And(tys) => tys.iter().any(|t| t.contains_tp(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_tp(target)), Self::Not(t) => t.contains_tp(target), Self::Ref(t) => t.contains_tp(target), Self::RefMut { before, after } => { @@ -3109,8 +3063,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_value(target) || args.iter().any(|t| t.contains_value(target)) } - Self::And(lhs, rhs) => lhs.contains_value(target) || rhs.contains_value(target), - Self::Or(lhs, rhs) => lhs.contains_value(target) || rhs.contains_value(target), + Self::And(tys) => tys.iter().any(|t| t.contains_value(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_value(target)), Self::Not(t) => t.contains_value(target), Self::Ref(t) => t.contains_value(target), Self::RefMut { before, after } => { @@ -3153,9 +3107,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_type(self) || args.iter().any(|t| t.contains_type(self)) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.contains_type(self) || rhs.contains_type(self) - } + Self::And(tys) => tys.iter().any(|t| t.contains_type(self)), + Self::Or(tys) => tys.iter().any(|t| t.contains_type(self)), Self::Not(t) => t.contains_type(self), Self::Ref(t) => t.contains_type(self), Self::RefMut { before, after } => { @@ -3225,9 +3178,9 @@ impl Type { Self::Inf => Str::ever("Inf"), Self::NegInf => Str::ever("NegInf"), Self::Mono(name) => name.clone(), - Self::And(_, _) => Str::ever("And"), + Self::And(_) => Str::ever("And"), Self::Not(_) => Str::ever("Not"), - Self::Or(_, _) => Str::ever("Or"), + Self::Or(_) => Str::ever("Or"), Self::Ref(_) => Str::ever("Ref"), Self::RefMut { .. } => Str::ever("RefMut"), Self::Subr(SubrType { @@ -3317,7 +3270,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().contains_intersec(typ), Self::Refinement(refine) => refine.t.contains_intersec(typ), - Self::And(t1, t2) => t1.contains_intersec(typ) || t2.contains_intersec(typ), + Self::And(tys) => tys.iter().any(|t| t.contains_intersec(typ)), _ => self == typ, } } @@ -3326,7 +3279,25 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().union_pair(), Self::Refinement(refine) => refine.t.union_pair(), - Self::Or(t1, t2) => Some((*t1.clone(), *t2.clone())), + Self::Or(tys) if tys.len() == 2 => { + let mut iter = tys.iter(); + Some((iter.next().unwrap().clone(), iter.next().unwrap().clone())) + } + Self::Or(tys) => { + let mut iter = tys.iter(); + let t1 = iter.next().unwrap().clone(); + let t2 = iter.cloned().collect(); + Some((t1, Type::Or(t2))) + } + _ => None, + } + } + + pub fn union_types(&self) -> Option> { + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().union_types(), + Self::Refinement(refine) => refine.t.union_types(), + Self::Or(tys) => Some(tys.clone()), _ => None, } } @@ -3336,7 +3307,7 @@ impl Type { match self { Type::FreeVar(fv) if fv.is_linked() => fv.crack().contains_union(typ), Type::Refinement(refine) => refine.t.contains_union(typ), - Type::Or(t1, t2) => t1.contains_union(typ) || t2.contains_union(typ), + Type::Or(tys) => tys.iter().any(|t| t.contains_union(typ)), _ => self == typ, } } @@ -3350,11 +3321,7 @@ impl Type { .into_iter() .map(|t| t.quantify()) .collect(), - Type::And(t1, t2) => { - let mut types = t1.intersection_types(); - types.extend(t2.intersection_types()); - types - } + Type::And(tys) => tys.clone(), _ => vec![self.clone()], } } @@ -3430,9 +3397,8 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_unbound() => true, Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_totally_unbound(), - Self::Or(t1, t2) | Self::And(t1, t2) => { - t1.is_totally_unbound() && t2.is_totally_unbound() - } + Self::And(tys) => tys.iter().all(|t| t.is_totally_unbound()), + Self::Or(tys) => tys.iter().all(|t| t.is_totally_unbound()), Self::Not(t) => t.is_totally_unbound(), _ => false, } @@ -3539,9 +3505,15 @@ impl Type { sub.destructive_coerce(); self.destructive_link(&sub); } - Type::And(l, r) | Type::Or(l, r) => { - l.destructive_coerce(); - r.destructive_coerce(); + Type::And(tys) => { + for t in tys { + t.destructive_coerce(); + } + } + Type::Or(tys) => { + for t in tys { + t.destructive_coerce(); + } } Type::Not(l) => l.destructive_coerce(), Type::Poly { params, .. } => { @@ -3594,9 +3566,15 @@ impl Type { sub.undoable_coerce(list); self.undoable_link(&sub, list); } - Type::And(l, r) | Type::Or(l, r) => { - l.undoable_coerce(list); - r.undoable_coerce(list); + Type::And(tys) => { + for t in tys { + t.undoable_coerce(list); + } + } + Type::Or(tys) => { + for t in tys { + t.undoable_coerce(list); + } } Type::Not(l) => l.undoable_coerce(list), Type::Poly { params, .. } => { @@ -3655,7 +3633,12 @@ impl Type { .map(|t| t.qvars_inner()) .unwrap_or_else(|| set! {}), ), - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.qvars_inner().concat(rhs.qvars_inner()), + Self::And(tys) => tys + .iter() + .fold(set! {}, |acc, t| acc.concat(t.qvars_inner())), + Self::Or(tys) => tys + .iter() + .fold(set! {}, |acc, t| acc.concat(t.qvars_inner())), Self::Not(ty) => ty.qvars_inner(), Self::Callable { param_ts, return_t } => param_ts .iter() @@ -3719,30 +3702,15 @@ impl Type { opt_t.map_or(false, |t| t.has_qvar()) } } - Self::Ref(ty) => ty.has_qvar(), - Self::RefMut { before, after } => { - before.has_qvar() || after.as_ref().map(|t| t.has_qvar()).unwrap_or(false) - } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.has_qvar() || rhs.has_qvar(), - Self::Not(ty) => ty.has_qvar(), - Self::Callable { param_ts, return_t } => { - param_ts.iter().any(|t| t.has_qvar()) || return_t.has_qvar() - } Self::Subr(subr) => subr.has_qvar(), Self::Quantified(_) => false, // Self::Quantified(quant) => quant.has_qvar(), - Self::Record(r) => r.values().any(|t| t.has_qvar()), - Self::NamedTuple(r) => r.iter().any(|(_, t)| t.has_qvar()), Self::Refinement(refine) => refine.t.has_qvar() || refine.pred.has_qvar(), Self::Poly { params, .. } => params.iter().any(|tp| tp.has_qvar()), - Self::Proj { lhs, .. } => lhs.has_qvar(), Self::ProjCall { lhs, args, .. } => { lhs.has_qvar() || args.iter().any(|tp| tp.has_qvar()) } - Self::Structural(ty) => ty.has_qvar(), - Self::Guard(guard) => guard.to.has_qvar(), - Self::Bounded { sub, sup } => sub.has_qvar() || sup.has_qvar(), - mono_type_pattern!() => false, + _ => self.has_type_satisfies(|t| t.has_qvar()), } } @@ -3768,40 +3736,12 @@ impl Type { opt_t.map_or(false, |t| t.has_undoable_linked_var()) } } - Self::Ref(ty) => ty.has_undoable_linked_var(), - Self::RefMut { before, after } => { - before.has_undoable_linked_var() - || after - .as_ref() - .map(|t| t.has_undoable_linked_var()) - .unwrap_or(false) - } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var() - } - Self::Not(ty) => ty.has_undoable_linked_var(), - Self::Callable { param_ts, return_t } => { - param_ts.iter().any(|t| t.has_undoable_linked_var()) - || return_t.has_undoable_linked_var() - } Self::Subr(subr) => subr.has_undoable_linked_var(), - Self::Quantified(quant) => quant.has_undoable_linked_var(), - Self::Record(r) => r.values().any(|t| t.has_undoable_linked_var()), - Self::NamedTuple(r) => r.iter().any(|(_, t)| t.has_undoable_linked_var()), - Self::Refinement(refine) => { - refine.t.has_undoable_linked_var() || refine.pred.has_undoable_linked_var() - } Self::Poly { params, .. } => params.iter().any(|tp| tp.has_undoable_linked_var()), - Self::Proj { lhs, .. } => lhs.has_undoable_linked_var(), Self::ProjCall { lhs, args, .. } => { lhs.has_undoable_linked_var() || args.iter().any(|tp| tp.has_undoable_linked_var()) } - Self::Structural(ty) => ty.has_undoable_linked_var(), - Self::Guard(guard) => guard.to.has_undoable_linked_var(), - Self::Bounded { sub, sup } => { - sub.has_undoable_linked_var() || sup.has_undoable_linked_var() - } - mono_type_pattern!() => false, + _ => self.has_type_satisfies(|t| t.has_undoable_linked_var()), } } @@ -3812,46 +3752,13 @@ impl Type { pub fn has_unbound_var(&self) -> bool { match self { Self::FreeVar(fv) => fv.has_unbound_var(), - Self::Ref(t) => t.has_unbound_var(), - Self::RefMut { before, after } => { - before.has_unbound_var() - || after.as_ref().map(|t| t.has_unbound_var()).unwrap_or(false) - } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.has_unbound_var() || rhs.has_unbound_var() - } - Self::Not(ty) => ty.has_unbound_var(), - Self::Callable { param_ts, return_t } => { - param_ts.iter().any(|t| t.has_unbound_var()) || return_t.has_unbound_var() - } - Self::Subr(subr) => { - subr.non_default_params - .iter() - .any(|pt| pt.typ().has_unbound_var()) - || subr - .var_params - .as_ref() - .map(|pt| pt.typ().has_unbound_var()) - .unwrap_or(false) - || subr.default_params.iter().any(|pt| { - pt.typ().has_unbound_var() - || pt.default_typ().is_some_and(|t| t.has_unbound_var()) - }) - || subr.return_t.has_unbound_var() - } - Self::Record(r) => r.values().any(|t| t.has_unbound_var()), - Self::NamedTuple(r) => r.iter().any(|(_, t)| t.has_unbound_var()), + Self::Subr(subr) => subr.has_unbound_var(), Self::Refinement(refine) => refine.t.has_unbound_var() || refine.pred.has_unbound_var(), - Self::Quantified(quant) => quant.has_unbound_var(), Self::Poly { params, .. } => params.iter().any(|p| p.has_unbound_var()), - Self::Proj { lhs, .. } => lhs.has_unbound_var(), Self::ProjCall { lhs, args, .. } => { lhs.has_unbound_var() || args.iter().any(|t| t.has_unbound_var()) } - Self::Structural(ty) => ty.has_unbound_var(), - Self::Guard(guard) => guard.to.has_unbound_var(), - Self::Bounded { sub, sup } => sub.has_unbound_var() || sup.has_unbound_var(), - mono_type_pattern!() => false, + _ => self.has_type_satisfies(|t| t.has_unbound_var()), } } @@ -3874,7 +3781,8 @@ impl Type { Self::Refinement(refine) => refine.t.typarams_len(), // REVIEW: Self::Ref(_) | Self::RefMut { .. } => Some(1), - Self::And(_, _) | Self::Or(_, _) => Some(2), + Self::And(tys) => Some(tys.len()), + Self::Or(tys) => Some(tys.len()), Self::Not(_) => Some(1), Self::Subr(subr) => Some( subr.non_default_params.len() @@ -3940,9 +3848,8 @@ impl Type { Self::FreeVar(_unbound) => vec![], Self::Refinement(refine) => refine.t.typarams(), Self::Ref(t) | Self::RefMut { before: t, .. } => vec![TyParam::t(*t.clone())], - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - vec![TyParam::t(*lhs.clone()), TyParam::t(*rhs.clone())] - } + Self::And(tys) => tys.iter().cloned().map(TyParam::t).collect(), + Self::Or(tys) => tys.iter().cloned().map(TyParam::t).collect(), Self::Not(t) => vec![TyParam::t(*t.clone())], Self::Subr(subr) => subr.typarams(), Self::Quantified(quant) => quant.typarams(), @@ -4163,8 +4070,8 @@ impl Type { let r = r.iter().map(|(k, v)| (k.clone(), v.derefine())).collect(); Self::NamedTuple(r) } - Self::And(l, r) => l.derefine() & r.derefine(), - Self::Or(l, r) => l.derefine() | r.derefine(), + Self::And(tys) => Self::checked_and(tys.iter().map(|t| t.derefine()).collect()), + Self::Or(tys) => Self::checked_or(tys.iter().map(|t| t.derefine()).collect()), Self::Not(ty) => !ty.derefine(), Self::Proj { lhs, rhs } => lhs.derefine().proj(rhs.clone()), Self::ProjCall { @@ -4219,7 +4126,7 @@ impl Type { /// (T or U).eliminate_subsup(T) == U /// ?X(<: T or U).eliminate_subsup(T) == ?X(<: U) /// ``` - pub fn eliminate_subsup(self, target: &Type) -> Self { + pub(crate) fn eliminate_subsup(self, target: &Type) -> Self { match self { Self::FreeVar(fv) if fv.is_linked() => fv.unwrap_linked().eliminate_subsup(target), Self::FreeVar(ref fv) if fv.constraint_is_sandwiched() => { @@ -4237,22 +4144,28 @@ impl Type { }); self } - Self::And(l, r) => { - if l.addr_eq(target) { - return r.eliminate_subsup(target); - } else if r.addr_eq(target) { - return l.eliminate_subsup(target); - } - l.eliminate_subsup(target) & r.eliminate_subsup(target) - } - Self::Or(l, r) => { - if l.addr_eq(target) { - return r.eliminate_subsup(target); - } else if r.addr_eq(target) { - return l.eliminate_subsup(target); - } - l.eliminate_subsup(target) | r.eliminate_subsup(target) - } + Self::And(tys) => Self::checked_and( + tys.into_iter() + .filter_map(|t| { + if t.addr_eq(target) { + None + } else { + Some(t.eliminate_subsup(target)) + } + }) + .collect(), + ), + Self::Or(tys) => Self::checked_or( + tys.into_iter() + .filter_map(|t| { + if t.addr_eq(target) { + None + } else { + Some(t.eliminate_subsup(target)) + } + }) + .collect(), + ), other => other, } } @@ -4261,7 +4174,7 @@ impl Type { /// ?T(<: K(X)).eliminate_recursion(X) == ?T(<: K(X)) /// Tuple(X).eliminate_recursion(X) == Tuple(Never) /// ``` - pub fn eliminate_recursion(self, target: &Type) -> Self { + pub(crate) fn eliminate_recursion(self, target: &Type) -> Self { if self.is_free_var() && self.addr_eq(target) { return Self::Never; } @@ -4307,8 +4220,18 @@ impl Type { before: Box::new(before.eliminate_recursion(target)), after: after.map(|t| Box::new(t.eliminate_recursion(target))), }, - Self::And(l, r) => l.eliminate_recursion(target) & r.eliminate_recursion(target), - Self::Or(l, r) => l.eliminate_recursion(target) | r.eliminate_recursion(target), + Self::And(tys) => Self::checked_and( + tys.into_iter() + .filter(|t| !t.addr_eq(target)) + .map(|t| t.eliminate_recursion(target)) + .collect(), + ), + Self::Or(tys) => Self::checked_or( + tys.into_iter() + .filter(|t| !t.addr_eq(target)) + .map(|t| t.eliminate_recursion(target)) + .collect(), + ), Self::Not(ty) => !ty.eliminate_recursion(target), Self::Proj { lhs, rhs } => lhs.eliminate_recursion(target).proj(rhs), Self::ProjCall { @@ -4333,6 +4256,18 @@ impl Type { } } + pub(crate) fn eliminate_and_or_recursion(self, target: &Type) -> Self { + match self { + Self::And(tys) => { + Self::checked_and(tys.into_iter().filter(|t| !t.addr_eq(target)).collect()) + } + Self::Or(tys) => { + Self::checked_or(tys.into_iter().filter(|t| !t.addr_eq(target)).collect()) + } + _ => self, + } + } + pub fn replace(self, target: &Type, to: &Type) -> Type { let table = ReplaceTable::make(target, to); table.replace(self) @@ -4461,8 +4396,8 @@ impl Type { before: Box::new(before.map(f)), after: after.map(|t| Box::new(t.map(f))), }, - Self::And(l, r) => l.map(f) & r.map(f), - Self::Or(l, r) => l.map(f) | r.map(f), + Self::And(tys) => Self::checked_and(tys.into_iter().map(|t| t.map(f)).collect()), + Self::Or(tys) => Self::checked_or(tys.into_iter().map(|t| t.map(f)).collect()), Self::Not(ty) => !ty.map(f), Self::Proj { lhs, rhs } => lhs.map(f).proj(rhs), Self::ProjCall { @@ -4489,7 +4424,7 @@ impl Type { /// Unlike `replace`, this does not make a look-up table. fn _replace(mut self, target: &Type, to: &Type) -> Type { - if self.structural_eq(target) { + if &self == target { self = to.clone(); } self.map(&mut |t| t._replace(target, to)) @@ -4555,8 +4490,12 @@ impl Type { before: Box::new(before._replace_tp(target, to)), after: after.map(|t| Box::new(t._replace_tp(target, to))), }, - Self::And(l, r) => l._replace_tp(target, to) & r._replace_tp(target, to), - Self::Or(l, r) => l._replace_tp(target, to) | r._replace_tp(target, to), + Self::And(tys) => { + Self::checked_and(tys.into_iter().map(|t| t._replace_tp(target, to)).collect()) + } + Self::Or(tys) => { + Self::checked_or(tys.into_iter().map(|t| t._replace_tp(target, to)).collect()) + } Self::Not(ty) => !ty._replace_tp(target, to), Self::Proj { lhs, rhs } => lhs._replace_tp(target, to).proj(rhs), Self::ProjCall { @@ -4632,8 +4571,8 @@ impl Type { before: Box::new(before.map_tp(f)), after: after.map(|t| Box::new(t.map_tp(f))), }, - Self::And(l, r) => l.map_tp(f) & r.map_tp(f), - Self::Or(l, r) => l.map_tp(f) | r.map_tp(f), + Self::And(tys) => Self::checked_and(tys.into_iter().map(|t| t.map_tp(f)).collect()), + Self::Or(tys) => Self::checked_or(tys.into_iter().map(|t| t.map_tp(f)).collect()), Self::Not(ty) => !ty.map_tp(f), Self::Proj { lhs, rhs } => lhs.map_tp(f).proj(rhs), Self::ProjCall { @@ -4721,8 +4660,16 @@ impl Type { after, }) } - Self::And(l, r) => Ok(l.try_map_tp(f)? & r.try_map_tp(f)?), - Self::Or(l, r) => Ok(l.try_map_tp(f)? | r.try_map_tp(f)?), + Self::And(tys) => Ok(Self::checked_and( + tys.into_iter() + .map(|t| t.try_map_tp(f)) + .collect::>()?, + )), + Self::Or(tys) => Ok(Self::checked_or( + tys.into_iter() + .map(|t| t.try_map_tp(f)) + .collect::>()?, + )), Self::Not(ty) => Ok(!ty.try_map_tp(f)?), Self::Proj { lhs, rhs } => Ok(lhs.try_map_tp(f)?.proj(rhs)), Self::ProjCall { @@ -4755,12 +4702,28 @@ impl Type { *refine.t = refine.t.replace_param(target, to); Self::Refinement(refine) } - Self::And(l, r) => l.replace_param(target, to) & r.replace_param(target, to), + Self::And(tys) => Self::And( + tys.into_iter() + .map(|t| t.replace_param(target, to)) + .collect(), + ), Self::Guard(guard) => Self::Guard(guard.replace_param(target, to)), _ => self, } } + pub fn eliminate_and_or(&mut self) { + match self { + Self::And(tys) if tys.len() == 1 => { + *self = tys.remove(0); + } + Self::Or(tys) if tys.len() == 1 => { + *self = tys.take_all().into_iter().next().unwrap(); + } + _ => {} + } + } + pub fn replace_params<'l, 'r>( mut self, target: impl Iterator, @@ -4828,8 +4791,8 @@ impl Type { } Self::NamedTuple(r) } - Self::And(l, r) => l.normalize() & r.normalize(), - Self::Or(l, r) => l.normalize() | r.normalize(), + Self::And(tys) => Self::checked_and(tys.into_iter().map(|t| t.normalize()).collect()), + Self::Or(tys) => Self::checked_or(tys.into_iter().map(|t| t.normalize()).collect()), Self::Not(ty) => !ty.normalize(), Self::Structural(ty) => ty.normalize().structuralize(), Self::Quantified(quant) => quant.normalize().quantify(), @@ -4861,14 +4824,36 @@ impl Type { free.get_sub().unwrap_or(self.clone()) } else { match self { - Self::And(l, r) => l.lower_bounded() & r.lower_bounded(), - Self::Or(l, r) => l.lower_bounded() | r.lower_bounded(), + Self::And(tys) => { + Self::checked_and(tys.iter().map(|t| t.lower_bounded()).collect()) + } + Self::Or(tys) => Self::checked_or(tys.iter().map(|t| t.lower_bounded()).collect()), Self::Not(ty) => !ty.lower_bounded(), _ => self.clone(), } } } + /// ```erg + /// assert Int.upper_bounded() == Int + /// assert ?T(<: Str).upper_bounded() == Str + /// assert (?T(<: Str) or ?U(<: Int)).upper_bounded() == (Str or Int) + /// ``` + pub fn upper_bounded(&self) -> Type { + if let Ok(free) = <&FreeTyVar>::try_from(self) { + free.get_super().unwrap_or(self.clone()) + } else { + match self { + Self::And(tys) => { + Self::checked_and(tys.iter().map(|t| t.upper_bounded()).collect()) + } + Self::Or(tys) => Self::checked_or(tys.iter().map(|t| t.upper_bounded()).collect()), + Self::Not(ty) => !ty.upper_bounded(), + _ => self.clone(), + } + } + } + pub(crate) fn addr_eq(&self, other: &Type) -> bool { match (self, other) { (Self::FreeVar(slf), _) if slf.is_linked() => slf.crack().addr_eq(other), @@ -4891,8 +4876,8 @@ impl Type { } match self { Self::FreeVar(fv) => { - let to = to.clone().eliminate_subsup(self).eliminate_recursion(self); - fv.link(&to); + let to_ = to.clone().eliminate_subsup(self).eliminate_recursion(self); + fv.link(&to_); } Self::Refinement(refine) => refine.t.destructive_link(to), _ => { @@ -4914,8 +4899,12 @@ impl Type { } match self { Self::FreeVar(fv) => { - let to = to.clone().eliminate_subsup(self); // FIXME: .eliminate_recursion(self) - fv.undoable_link(&to); + // NOTE: we can't use `eliminate_recursion` + let to_ = to + .clone() + .eliminate_subsup(self) + .eliminate_and_or_recursion(self); + fv.undoable_link(&to_); } Self::Refinement(refine) => refine.t.undoable_link(to, list), _ => { @@ -5027,12 +5016,12 @@ impl Type { /// Add.ands() == {Add} /// (Add and Sub).ands() == {Add, Sub} /// ``` - pub fn ands(&self) -> Set { + pub fn ands(&self) -> Vec { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().ands(), Self::Refinement(refine) => refine.t.ands(), - Self::And(l, r) => l.ands().union(&r.ands()), - _ => set![self.clone()], + Self::And(tys) => tys.clone(), + _ => vec![self.clone()], } } @@ -5044,7 +5033,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().ors(), Self::Refinement(refine) => refine.t.ors(), - Self::Or(l, r) => l.ors().union(&r.ors()), + Self::Or(tys) => tys.clone(), _ => set![self.clone()], } } @@ -5089,7 +5078,8 @@ impl Type { Self::Callable { param_ts, .. } => { param_ts.iter().flat_map(|t| t.contained_ts()).collect() } - Self::And(l, r) | Self::Or(l, r) => l.contained_ts().union(&r.contained_ts()), + Self::And(tys) => tys.iter().flat_map(|t| t.contained_ts()).collect(), + Self::Or(tys) => tys.iter().flat_map(|t| t.contained_ts()).collect(), Self::Not(t) => t.contained_ts(), Self::Bounded { sub, sup } => sub.contained_ts().union(&sup.contained_ts()), Self::Quantified(ty) | Self::Structural(ty) => ty.contained_ts(), @@ -5158,9 +5148,30 @@ impl Type { } return_t.dereference(); } - Self::And(l, r) | Self::Or(l, r) => { - l.dereference(); - r.dereference(); + Self::And(tys) => { + *tys = std::mem::take(tys) + .into_iter() + .map(|mut t| { + t.dereference(); + t + }) + .collect(); + if tys.len() == 1 { + *self = tys.remove(0); + } + } + Self::Or(tys) => { + *tys = tys + .take_all() + .into_iter() + .map(|mut t| { + t.dereference(); + t + }) + .collect(); + if tys.len() == 1 { + *self = tys.take_all().into_iter().next().unwrap(); + } } Self::Not(ty) => { ty.dereference(); @@ -5268,7 +5279,8 @@ impl Type { set.extend(return_t.variables()); set } - Self::And(l, r) | Self::Or(l, r) => l.variables().union(&r.variables()), + Self::And(tys) => tys.iter().flat_map(|t| t.variables()).collect(), + Self::Or(tys) => tys.iter().flat_map(|t| t.variables()).collect(), Self::Not(ty) => ty.variables(), Self::Bounded { sub, sup } => sub.variables().union(&sup.variables()), Self::Quantified(ty) | Self::Structural(ty) => ty.variables(), @@ -5380,13 +5392,16 @@ impl<'t> ReplaceTable<'t> { self.iterate(l, r); } } - (Type::And(l, r), Type::And(l2, r2)) => { - self.iterate(l, l2); - self.iterate(r, r2); + // FIXME: + (Type::And(tys), Type::And(tys2)) => { + for (l, r) in tys.iter().zip(tys2.iter()) { + self.iterate(l, r); + } } - (Type::Or(l, r), Type::Or(l2, r2)) => { - self.iterate(l, l2); - self.iterate(r, r2); + (Type::Or(tys), Type::Or(tys2)) => { + for (l, r) in tys.iter().zip(tys2.iter()) { + self.iterate(l, r); + } } (Type::Not(t), Type::Not(t2)) => { self.iterate(t, t2); diff --git a/crates/erg_compiler/ty/predicate.rs b/crates/erg_compiler/ty/predicate.rs index 1e6c489ca..d43fd61dc 100644 --- a/crates/erg_compiler/ty/predicate.rs +++ b/crates/erg_compiler/ty/predicate.rs @@ -656,7 +656,8 @@ impl Predicate { pub fn qvars(&self) -> Set<(Str, Constraint)> { match self { - Self::Value(_) | Self::Const(_) | Self::Failure => set! {}, + Self::Const(_) | Self::Failure => set! {}, + Self::Value(val) => val.qvars(), Self::Call { receiver, args, .. } => { let mut set = receiver.qvars(); for arg in args { @@ -680,9 +681,35 @@ impl Predicate { } } + pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool { + match self { + Self::Const(_) | Self::Failure => false, + Self::Value(val) => val.has_type_satisfies(f), + Self::Call { receiver, args, .. } => { + receiver.has_type_satisfies(f) || args.iter().any(|a| a.has_type_satisfies(f)) + } + Self::Attr { receiver, .. } => receiver.has_type_satisfies(f), + Self::Equal { rhs, .. } + | Self::GreaterEqual { rhs, .. } + | Self::LessEqual { rhs, .. } + | Self::NotEqual { rhs, .. } => rhs.has_type_satisfies(f), + Self::GeneralEqual { lhs, rhs } + | Self::GeneralLessEqual { lhs, rhs } + | Self::GeneralGreaterEqual { lhs, rhs } + | Self::GeneralNotEqual { lhs, rhs } => { + lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f) + } + Self::Or(lhs, rhs) | Self::And(lhs, rhs) => { + lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f) + } + Self::Not(pred) => pred.has_type_satisfies(f), + } + } + pub fn has_qvar(&self) -> bool { match self { - Self::Value(_) | Self::Const(_) | Self::Failure => false, + Self::Const(_) | Self::Failure => false, + Self::Value(val) => val.has_qvar(), Self::Call { receiver, args, .. } => { receiver.has_qvar() || args.iter().any(|a| a.has_qvar()) } @@ -702,7 +729,8 @@ impl Predicate { pub fn has_unbound_var(&self) -> bool { match self { - Self::Value(_) | Self::Const(_) | Self::Failure => false, + Self::Const(_) | Self::Failure => false, + Self::Value(val) => val.has_unbound_var(), Self::Call { receiver, args, .. } => { receiver.has_unbound_var() || args.iter().any(|a| a.has_unbound_var()) } @@ -724,7 +752,8 @@ impl Predicate { pub fn has_undoable_linked_var(&self) -> bool { match self { - Self::Value(_) | Self::Const(_) | Self::Failure => false, + Self::Const(_) | Self::Failure => false, + Self::Value(val) => val.has_undoable_linked_var(), Self::Call { receiver, args, .. } => { receiver.has_undoable_linked_var() || args.iter().any(|a| a.has_undoable_linked_var()) diff --git a/crates/erg_compiler/ty/typaram.rs b/crates/erg_compiler/ty/typaram.rs index d104c298b..3ec3aca90 100644 --- a/crates/erg_compiler/ty/typaram.rs +++ b/crates/erg_compiler/ty/typaram.rs @@ -1268,21 +1268,21 @@ impl TyParam { pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().has_type_satisfies(f), - Self::FreeVar(fv) => fv.get_type().map_or(false, |t| t.has_type_satisfies(f)), - Self::Type(t) => t.has_type_satisfies(f), - Self::Erased(t) => t.has_type_satisfies(f), + Self::FreeVar(fv) => fv.get_type().map_or(false, |t| f(&t)), + Self::Type(t) => f(t), + Self::Erased(t) => f(t), Self::Proj { obj, .. } => obj.has_type_satisfies(f), Self::ProjCall { obj, args, .. } => { - obj.has_type_satisfies(f) || args.iter().any(|t| t.has_type_satisfies(f)) + obj.has_type_satisfies(f) || args.iter().any(|tp| tp.has_type_satisfies(f)) } - Self::List(ts) | Self::Tuple(ts) => ts.iter().any(|t| t.has_type_satisfies(f)), + Self::List(ts) | Self::Tuple(ts) => ts.iter().any(|tp| tp.has_type_satisfies(f)), Self::UnsizedList(elem) => elem.has_type_satisfies(f), - Self::Set(ts) => ts.iter().any(|t| t.has_type_satisfies(f)), + Self::Set(ts) => ts.iter().any(|tp| tp.has_type_satisfies(f)), Self::Dict(ts) => ts .iter() .any(|(k, v)| k.has_type_satisfies(f) || v.has_type_satisfies(f)), Self::Record(rec) | Self::DataClass { fields: rec, .. } => { - rec.iter().any(|(_, tp)| tp.has_type_satisfies(f)) + rec.values().any(|tp| tp.has_type_satisfies(f)) } Self::Lambda(lambda) => lambda.body.iter().any(|tp| tp.has_type_satisfies(f)), Self::UnaryOp { val, .. } => val.has_type_satisfies(f), diff --git a/crates/erg_compiler/ty/value.rs b/crates/erg_compiler/ty/value.rs index 878495016..994040acc 100644 --- a/crates/erg_compiler/ty/value.rs +++ b/crates/erg_compiler/ty/value.rs @@ -2080,7 +2080,7 @@ impl ValueObj { pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool { match self { - Self::Type(t) => t.typ().has_type_satisfies(f), + Self::Type(t) => f(t.typ()), Self::List(ts) | Self::Tuple(ts) => ts.iter().any(|t| t.has_type_satisfies(f)), Self::UnsizedList(elem) => elem.has_type_satisfies(f), Self::Set(ts) => ts.iter().any(|t| t.has_type_satisfies(f)), diff --git a/tests/should_err/and.er b/tests/should_err/and.er new file mode 100644 index 000000000..a16daf02c --- /dev/null +++ b/tests/should_err/and.er @@ -0,0 +1,3 @@ +a as Eq and Hash and Show and Add(Str) = "a" +f _: Ord and Eq and Show and Hash = None +f a # ERR diff --git a/tests/should_err/or.er b/tests/should_err/or.er new file mode 100644 index 000000000..20e8937da --- /dev/null +++ b/tests/should_err/or.er @@ -0,0 +1,3 @@ +a as Int or Str or NoneType = 1 +f _: Nat or NoneType or Str = None +f a # ERR diff --git a/tests/should_ok/and.er b/tests/should_ok/and.er new file mode 100644 index 000000000..4f3bee0d7 --- /dev/null +++ b/tests/should_ok/and.er @@ -0,0 +1,7 @@ +a as Eq and Hash and Show = 1 +f _: Eq and Show and Hash = None +f a + +b as Eq and Hash and Ord and Show = 1 +g _: Ord and Eq and Show and Hash = None +g b diff --git a/tests/should_ok/or.er b/tests/should_ok/or.er new file mode 100644 index 000000000..2e36f19b0 --- /dev/null +++ b/tests/should_ok/or.er @@ -0,0 +1,7 @@ +a as Nat or Str or NoneType = 1 +f _: Int or NoneType or Str = None +f a + +b as Nat or Str or NoneType or Bool = 1 +g _: Int or NoneType or Bool or Str = None +g b diff --git a/tests/test.rs b/tests/test.rs index 3259eee09..4ed20fb92 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -16,6 +16,11 @@ fn exec_advanced_type_spec() -> Result<(), ()> { expect_success("tests/should_ok/advanced_type_spec.er", 5) } +#[test] +fn exec_and() -> Result<(), ()> { + expect_success("tests/should_ok/and.er", 0) +} + #[test] fn exec_args_expansion() -> Result<(), ()> { expect_success("tests/should_ok/args_expansion.er", 0) @@ -327,6 +332,11 @@ fn exec_operators() -> Result<(), ()> { expect_success("tests/should_ok/operators.er", 0) } +#[test] +fn exec_or() -> Result<(), ()> { + expect_success("tests/should_ok/or.er", 0) +} + #[test] fn exec_patch() -> Result<(), ()> { expect_success("examples/patch.er", 0) @@ -527,6 +537,11 @@ fn exec_list_member_err() -> Result<(), ()> { expect_failure("tests/should_err/list_member.er", 0, 3) } +#[test] +fn exec_and_err() -> Result<(), ()> { + expect_failure("tests/should_err/and.er", 0, 1) +} + #[test] fn exec_as() -> Result<(), ()> { expect_failure("tests/should_err/as.er", 0, 6) @@ -634,6 +649,11 @@ fn exec_move_check() -> Result<(), ()> { expect_failure("examples/move_check.er", 1, 1) } +#[test] +fn exec_or_err() -> Result<(), ()> { + expect_failure("tests/should_err/or.er", 0, 1) +} + #[test] fn exec_poly_type_spec_err() -> Result<(), ()> { expect_failure("tests/should_err/poly_type_spec.er", 0, 3)