From df837d70d3fb39e2febfd8e57be528b1898f2c0a Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 17 Sep 2024 17:13:28 +0900 Subject: [PATCH] fix: sub-unification bug --- crates/erg_compiler/context/compare.rs | 1 + crates/erg_compiler/context/unify.rs | 88 ++++++++++++-------------- tests/should_err/and.er | 3 + tests/should_err/or.er | 3 + tests/should_ok/and.er | 7 ++ tests/should_ok/or.er | 7 ++ tests/test.rs | 20 ++++++ 7 files changed, 82 insertions(+), 47 deletions(-) create mode 100644 tests/should_err/and.er create mode 100644 tests/should_err/or.er create mode 100644 tests/should_ok/and.er create mode 100644 tests/should_ok/or.er diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 5707af044..f6125066d 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -816,6 +816,7 @@ impl Context { } // Int or Str :> Str or Int == (Int :> Str && Str :> Int) || (Int :> Int && Str :> Str) == true // Int or Str or NoneType :> Str or Int + // Int or Str or NoneType :> Str or NoneType or Nat (Or(l), Or(r)) => r.iter().all(|r| l.iter().any(|l| self.supertype_of(l, r))), // not Nat :> not Int == true (Not(l), Not(r)) => self.subtype_of(l, r), diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 3df47f5c5..fea7e0dcc 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -156,6 +156,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } + // FIXME: This is not correct, we must visit all permutations of the types (And(l), And(r)) if l.len() == r.len() => { let mut r = r.clone(); for _ in 0..r.len() { @@ -1297,65 +1298,58 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { // self.sub_unify(&lsub, &union, loc, param_name)?; maybe_sup.update_tyvar(union, intersec, self.undoable, false); } + // TODO: Preferentially compare same-structure types (e.g. K(?T) <: K(?U)) (And(ltys), And(rtys)) => { - let mut rtys = rtys.clone(); - for _ in 0..rtys.len() { - if ltys - .iter() - .zip(rtys.iter()) - .all(|(l, r)| self.ctx.subtype_of(l, r)) - { - for (l, r) in ltys.iter().zip(rtys.iter()) { - self.sub_unify(l, r)?; + let mut ltys_ = ltys.clone(); + let mut rtys_ = rtys.clone(); + // Show and EqHash and T <: Eq and Show and Ord + // => EqHash and T <: Eq and Ord + for lty in ltys.iter() { + if let Some(idx) = rtys_.iter().position(|r| r == lty) { + rtys_.remove(idx); + let idx = ltys_.iter().position(|l| l == lty).unwrap(); + ltys_.remove(idx); + } + } + // EqHash and T <: Eq and Ord + for lty in ltys_.iter() { + // lty: EqHash + // rty: Eq, Ord + for rty in rtys_.iter() { + if self.ctx.subtype_of(lty, rty) { + self.sub_unify(lty, rty)?; + continue; } - return Ok(()); } - rtys.rotate_left(1); } - return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error( - self.ctx.cfg.input.clone(), - line!() as usize, - self.loc.loc(), - self.ctx.caused_by(), - self.param_name.as_ref().unwrap_or(&Str::ever("_")), - None, - maybe_sup, - maybe_sub, - self.ctx.get_candidates(maybe_sub), - self.ctx.get_simple_type_mismatch_hint(maybe_sup, maybe_sub), - ))); } + // TODO: Preferentially compare same-structure types (e.g. K(?T) <: K(?U)) + // Nat or Str or NoneType <: NoneType or ?T or Int + // => Str <: ?T // (Int or ?T) <: (?U or Int) // OK: (Int <: Int); (?T <: ?U) // NG: (Int <: ?U); (?T <: Int) (Or(ltys), Or(rtys)) => { - let ltys = ltys.to_vec(); - let mut rtys = rtys.to_vec(); - for _ in 0..rtys.len() { - if ltys - .iter() - .zip(rtys.iter()) - .all(|(l, r)| self.ctx.subtype_of(l, r)) - { - for (l, r) in ltys.iter().zip(rtys.iter()) { - self.sub_unify(l, r)?; + let mut ltys_ = ltys.clone(); + let mut rtys_ = rtys.clone(); + // Nat or T or Str <: Str or Int or NoneType + // => Nat or T <: Int or NoneType + for lty in ltys { + if rtys_.linear_remove(lty) { + ltys_.linear_remove(lty); + } + } + // Nat or T <: Int or NoneType + for lty in ltys_.iter() { + // lty: Nat + // rty: Int, NoneType + for rty in rtys_.iter() { + if self.ctx.subtype_of(lty, rty) { + self.sub_unify(lty, rty)?; + continue; } - return Ok(()); } - rtys.rotate_left(1); } - return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error( - self.ctx.cfg.input.clone(), - line!() as usize, - self.loc.loc(), - self.ctx.caused_by(), - self.param_name.as_ref().unwrap_or(&Str::ever("_")), - None, - maybe_sup, - maybe_sub, - self.ctx.get_candidates(maybe_sub), - self.ctx.get_simple_type_mismatch_hint(maybe_sup, maybe_sub), - ))); } // NG: Nat <: ?T or Int ==> Nat or Int (?T = Nat) // OK: Nat <: ?T or Int ==> ?T or Int diff --git a/tests/should_err/and.er b/tests/should_err/and.er new file mode 100644 index 000000000..a16daf02c --- /dev/null +++ b/tests/should_err/and.er @@ -0,0 +1,3 @@ +a as Eq and Hash and Show and Add(Str) = "a" +f _: Ord and Eq and Show and Hash = None +f a # ERR diff --git a/tests/should_err/or.er b/tests/should_err/or.er new file mode 100644 index 000000000..20e8937da --- /dev/null +++ b/tests/should_err/or.er @@ -0,0 +1,3 @@ +a as Int or Str or NoneType = 1 +f _: Nat or NoneType or Str = None +f a # ERR diff --git a/tests/should_ok/and.er b/tests/should_ok/and.er new file mode 100644 index 000000000..4f3bee0d7 --- /dev/null +++ b/tests/should_ok/and.er @@ -0,0 +1,7 @@ +a as Eq and Hash and Show = 1 +f _: Eq and Show and Hash = None +f a + +b as Eq and Hash and Ord and Show = 1 +g _: Ord and Eq and Show and Hash = None +g b diff --git a/tests/should_ok/or.er b/tests/should_ok/or.er new file mode 100644 index 000000000..2e36f19b0 --- /dev/null +++ b/tests/should_ok/or.er @@ -0,0 +1,7 @@ +a as Nat or Str or NoneType = 1 +f _: Int or NoneType or Str = None +f a + +b as Nat or Str or NoneType or Bool = 1 +g _: Int or NoneType or Bool or Str = None +g b diff --git a/tests/test.rs b/tests/test.rs index 3259eee09..4ed20fb92 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -16,6 +16,11 @@ fn exec_advanced_type_spec() -> Result<(), ()> { expect_success("tests/should_ok/advanced_type_spec.er", 5) } +#[test] +fn exec_and() -> Result<(), ()> { + expect_success("tests/should_ok/and.er", 0) +} + #[test] fn exec_args_expansion() -> Result<(), ()> { expect_success("tests/should_ok/args_expansion.er", 0) @@ -327,6 +332,11 @@ fn exec_operators() -> Result<(), ()> { expect_success("tests/should_ok/operators.er", 0) } +#[test] +fn exec_or() -> Result<(), ()> { + expect_success("tests/should_ok/or.er", 0) +} + #[test] fn exec_patch() -> Result<(), ()> { expect_success("examples/patch.er", 0) @@ -527,6 +537,11 @@ fn exec_list_member_err() -> Result<(), ()> { expect_failure("tests/should_err/list_member.er", 0, 3) } +#[test] +fn exec_and_err() -> Result<(), ()> { + expect_failure("tests/should_err/and.er", 0, 1) +} + #[test] fn exec_as() -> Result<(), ()> { expect_failure("tests/should_err/as.er", 0, 6) @@ -634,6 +649,11 @@ fn exec_move_check() -> Result<(), ()> { expect_failure("examples/move_check.er", 1, 1) } +#[test] +fn exec_or_err() -> Result<(), ()> { + expect_failure("tests/should_err/or.er", 0, 1) +} + #[test] fn exec_poly_type_spec_err() -> Result<(), ()> { expect_failure("tests/should_err/poly_type_spec.er", 0, 3)