diff --git a/crates/erg_common/triple.rs b/crates/erg_common/triple.rs index 91547caf2..a743d43ab 100644 --- a/crates/erg_common/triple.rs +++ b/crates/erg_common/triple.rs @@ -141,6 +141,14 @@ impl Triple { Triple::Ok(a) | Triple::Err(a) => Some(a), } } + + pub fn merge_or(self, default: T) -> T { + match self { + Triple::None => default, + Triple::Ok(ok) => ok, + Triple::Err(err) => err, + } + } } impl Triple { diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 4fb64adb5..425d28c48 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -119,6 +119,11 @@ impl Context { self.supertype_of(lhs, rhs) || self.subtype_of(lhs, rhs) } + pub(crate) fn _related_tp(&self, lhs: &TyParam, rhs: &TyParam) -> bool { + self._subtype_of_tp(lhs, rhs, Variance::Covariant) + || self.supertype_of_tp(lhs, rhs, Variance::Covariant) + } + /// lhs :> rhs ? pub(crate) fn supertype_of(&self, lhs: &Type, rhs: &Type) -> bool { let res = match Self::cheap_supertype_of(lhs, rhs) { @@ -1118,6 +1123,10 @@ impl Context { } } + pub(crate) fn covariant_supertype_of_tp(&self, lp: &TyParam, rp: &TyParam) -> bool { + self.supertype_of_tp(lp, rp, Variance::Covariant) + } + /// lhs <: rhs? pub(crate) fn structural_subtype_of(&self, lhs: &Type, rhs: &Type) -> bool { self.structural_supertype_of(rhs, lhs) @@ -1282,6 +1291,7 @@ impl Context { /// union(Array(Int, 2), Array(Str, 3)) == Array(Int, 2) or Array(Int, 3) /// union({ .a = Int }, { .a = Str }) == { .a = Int or Str } /// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int } + /// union((A and B) or C) == (A or C) and (B or C) /// ``` pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type { if lhs == rhs { @@ -1345,6 +1355,16 @@ impl Context { _ => self.simple_union(lhs, rhs), }, (other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other), + // (A and B) or C ==> (A or C) and (B or C) + (and_t @ And(_, _), other) | (other, and_t @ And(_, _)) => { + let ands = and_t.ands(); + let mut t = Type::Obj; + for branch in ands.iter() { + let union = self.union(branch, other); + t = and(t, union); + } + t + } (t, Type::Never) | (Type::Never, t) => t.clone(), // Array({1, 2}, 2), Array({3, 4}, 2) ==> Array({1, 2, 3, 4}, 2) ( @@ -1497,12 +1517,6 @@ impl Context { self.intersection(&fv.crack(), other) } (Refinement(l), Refinement(r)) => Type::Refinement(self.intersection_refinement(l, r)), - (other, Refinement(refine)) | (Refinement(refine), other) => { - let other = other.clone().into_refinement(); - let intersec = self.intersection_refinement(&other, refine); - self.try_squash_refinement(intersec) - .unwrap_or_else(Type::Refinement) - } (Structural(l), Structural(r)) => self.intersection(l, r).structuralize(), (Guard(l), Guard(r)) => { if l.namespace == r.namespace && l.target == r.target { @@ -1527,6 +1541,25 @@ impl Context { (other, and @ And(_, _)) | (and @ And(_, _), other) => { self.intersection_add(and, other) } + // (A or B) and C == (A and C) or (B and C) + (or_t @ Or(_, _), other) | (other, or_t @ Or(_, _)) => { + let ors = or_t.ors(); + if ors.iter().any(|t| t.has_unbound_var()) { + return self.simple_intersection(lhs, rhs); + } + let mut t = Type::Never; + for branch in ors.iter() { + let isec = self.intersection(branch, other); + t = self.union(&t, &isec); + } + t + } + (other, Refinement(refine)) | (Refinement(refine), other) => { + let other = other.clone().into_refinement(); + let intersec = self.intersection_refinement(&other, refine); + self.try_squash_refinement(intersec) + .unwrap_or_else(Type::Refinement) + } // overloading (l, r) if l.is_subr() && r.is_subr() => and(lhs.clone(), rhs.clone()), _ => self.simple_intersection(lhs, rhs), diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 2d09a8f5d..ba4b59257 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -155,6 +155,7 @@ impl<'c> Substituter<'c> { /// e.g. /// ```erg /// qt: Array(T, N), st: Array(Int, 3) + /// qt: T or NoneType, st: NoneType or Int (T == Int) /// ``` /// invalid (no effect): /// ```erg @@ -167,8 +168,15 @@ impl<'c> Substituter<'c> { st: &Type, ) -> EvalResult> { let qtps = qt.typarams(); - let stps = st.typarams(); - if qt.qual_name() != st.qual_name() || qtps.len() != stps.len() { + let mut stps = st.typarams(); + // Or, And are commutative, choose fitting order + if qt.qual_name() == st.qual_name() && (st.qual_name() == "Or" || st.qual_name() == "And") { + if ctx.covariant_supertype_of_tp(&qtps[0], &stps[1]) + && ctx.covariant_supertype_of_tp(&qtps[1], &stps[0]) + { + stps.swap(0, 1); + } + } else if qt.qual_name() != st.qual_name() || qtps.len() != stps.len() { if let Some(inner) = st.ref_inner().or_else(|| st.ref_mut_inner()) { return Self::substitute_typarams(ctx, qt, &inner); } else if let Some(sub) = st.get_sub() { diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index a692edfe0..898ca24ba 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -1180,6 +1180,9 @@ impl Context { super_exists } + /// Check if a trait implementation exists for a polymorphic class. + /// This is needed because the trait implementation spec can contain projection types. + /// e.g. `Tuple(Ts) <: Container(Ts.union())` fn poly_class_trait_impl_exists(&self, class: &Type, trait_: &Type) -> bool { let class_hash = get_hash(&class); let trait_hash = get_hash(&trait_); diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index d85c3cb9d..300ae91a6 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -3657,6 +3657,7 @@ impl Context { /// ```erg /// recover_typarams(Int, Nat) == Nat /// recover_typarams(Array!(Int, _), Array(Nat, 2)) == Array!(Nat, 2) + /// recover_typarams(Str or NoneType, {"a", "b"}) == {"a", "b"} /// ``` /// ```erg /// # REVIEW: should be? diff --git a/crates/erg_compiler/tests/infer.er b/crates/erg_compiler/tests/infer.er index 9b5c148f0..2243ed302 100644 --- a/crates/erg_compiler/tests/infer.er +++ b/crates/erg_compiler/tests/infer.er @@ -30,3 +30,11 @@ c_new x, y = C.new x, y C = Class Int C. new x, y = Self x + y + +val!() = + for! [{ "a": "b" }], (pkg as {Str: Str}) => + x = pkg.get("a", "c") + assert x in {"b"} + val!::return x + "d" +val = val!() diff --git a/crates/erg_compiler/tests/test.rs b/crates/erg_compiler/tests/test.rs index 80bbec857..89bcdd598 100644 --- a/crates/erg_compiler/tests/test.rs +++ b/crates/erg_compiler/tests/test.rs @@ -87,6 +87,9 @@ fn _test_infer_types() -> Result<(), ()> { let c_new_t = func2(add_r, r, c.clone()).quantify(); module.context.assert_var_type("c_new", &c_new_t)?; module.context.assert_attr_type(&c, "new", &c_new_t)?; + module + .context + .assert_var_type("val", &v_enum(set! { "b".into(), "d".into() }))?; Ok(()) }