diff --git a/crates/erg_common/triple.rs b/crates/erg_common/triple.rs index 91547caf2..a743d43ab 100644 --- a/crates/erg_common/triple.rs +++ b/crates/erg_common/triple.rs @@ -141,6 +141,14 @@ impl Triple { Triple::Ok(a) | Triple::Err(a) => Some(a), } } + + pub fn merge_or(self, default: T) -> T { + match self { + Triple::None => default, + Triple::Ok(ok) => ok, + Triple::Err(err) => err, + } + } } impl Triple { diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index d85c3cb9d..b12add8f4 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -3657,6 +3657,7 @@ impl Context { /// ```erg /// recover_typarams(Int, Nat) == Nat /// recover_typarams(Array!(Int, _), Array(Nat, 2)) == Array!(Nat, 2) + /// recover_typarams(Str or NoneType, {"a", "b"}) == {"a", "b"} /// ``` /// ```erg /// # REVIEW: should be? @@ -3667,7 +3668,8 @@ impl Context { let is_never = self.subtype_of(&intersec, &Type::Never) && guard.to.as_ref() != &Type::Never; if !is_never { - return Ok(intersec); + let min = self.min(&intersec, &guard.to).merge_or(&intersec); + return Ok(min.clone()); } if guard.to.is_monomorphic() { if self.related(base, &guard.to) { diff --git a/crates/erg_compiler/tests/infer.er b/crates/erg_compiler/tests/infer.er index 9b5c148f0..2243ed302 100644 --- a/crates/erg_compiler/tests/infer.er +++ b/crates/erg_compiler/tests/infer.er @@ -30,3 +30,11 @@ c_new x, y = C.new x, y C = Class Int C. new x, y = Self x + y + +val!() = + for! [{ "a": "b" }], (pkg as {Str: Str}) => + x = pkg.get("a", "c") + assert x in {"b"} + val!::return x + "d" +val = val!() diff --git a/crates/erg_compiler/tests/test.rs b/crates/erg_compiler/tests/test.rs index 80bbec857..89bcdd598 100644 --- a/crates/erg_compiler/tests/test.rs +++ b/crates/erg_compiler/tests/test.rs @@ -87,6 +87,9 @@ fn _test_infer_types() -> Result<(), ()> { let c_new_t = func2(add_r, r, c.clone()).quantify(); module.context.assert_var_type("c_new", &c_new_t)?; module.context.assert_attr_type(&c, "new", &c_new_t)?; + module + .context + .assert_var_type("val", &v_enum(set! { "b".into(), "d".into() }))?; Ok(()) }