Skip to content

Commit

Permalink
Merge pull request #521 from erg-lang/perf_or_type
Browse files Browse the repository at this point in the history
Change And/Or-type structures
  • Loading branch information
mtshiba authored Sep 18, 2024
2 parents 43828f6 + df837d7 commit 0408117
Show file tree
Hide file tree
Showing 28 changed files with 1,035 additions and 764 deletions.
28 changes: 17 additions & 11 deletions crates/els/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,30 @@ fn comp_item_kind(t: &Type, muty: Mutability) -> CompletionItemKind {
Type::Subr(_) | Type::Quantified(_) => CompletionItemKind::FUNCTION,
Type::ClassType => CompletionItemKind::CLASS,
Type::TraitType => CompletionItemKind::INTERFACE,
Type::Or(l, r) => {
let l = comp_item_kind(l, muty);
let r = comp_item_kind(r, muty);
if l == r {
l
Type::Or(tys) => {
let fst = comp_item_kind(tys.iter().next().unwrap(), muty);
if tys
.iter()
.map(|t| comp_item_kind(t, muty))
.all(|k| k == fst)
{
fst
} else if muty.is_const() {
CompletionItemKind::CONSTANT
} else {
CompletionItemKind::VARIABLE
}
}
Type::And(l, r) => {
let l = comp_item_kind(l, muty);
let r = comp_item_kind(r, muty);
if l == CompletionItemKind::VARIABLE {
r
Type::And(tys) => {
for k in tys.iter().map(|t| comp_item_kind(t, muty)) {
if k != CompletionItemKind::VARIABLE {
return k;
}
}
if muty.is_const() {
CompletionItemKind::CONSTANT
} else {
l
CompletionItemKind::VARIABLE
}
}
Type::Refinement(r) => comp_item_kind(&r.t, muty),
Expand Down
12 changes: 12 additions & 0 deletions crates/erg_common/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ impl<K, V> Dict<K, V> {
}
}

/// ```
/// # use erg_common::dict;
/// # use erg_common::dict::Dict;
/// let mut dict = Dict::with_capacity(3);
/// assert_eq!(dict.capacity(), 3);
/// dict.insert("a", 1);
/// assert_eq!(dict.capacity(), 3);
/// dict.insert("b", 2);
/// dict.insert("c", 3);
/// dict.insert("d", 4);
/// assert_ne!(dict.capacity(), 3);
/// ```
pub fn with_capacity(capacity: usize) -> Self {
Self {
dict: FxHashMap::with_capacity_and_hasher(capacity, Default::default()),
Expand Down
11 changes: 11 additions & 0 deletions crates/erg_common/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,17 @@ impl RecursionCounter {

#[macro_export]
macro_rules! set_recursion_limit {
(panic, $msg:expr, $limit:expr) => {
use std::sync::atomic::AtomicU32;

static COUNTER: AtomicU32 = AtomicU32::new($limit);

let counter = $crate::macros::RecursionCounter::new(&COUNTER);
if counter.limit_reached() {
$crate::log!(err "Recursion limit reached");
panic!($msg);
}
};
($returns:expr, $limit:expr) => {
use std::sync::atomic::AtomicU32;

Expand Down
14 changes: 14 additions & 0 deletions crates/erg_common/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,20 @@ impl<T: Hash + Eq + Clone> Set<T> {
self.insert(other);
self
}

/// ```
/// # use erg_common::set;
/// assert_eq!(set!{1, 2}.product(&set!{3, 4}), set!{(&1, &3), (&1, &4), (&2, &3), (&2, &4)});
/// ```
pub fn product<'l, 'r, U: Hash + Eq>(&'l self, other: &'r Set<U>) -> Set<(&'l T, &'r U)> {
let mut res = set! {};
for x in self.iter() {
for y in other.iter() {
res.insert((x, y));
}
}
res
}
}

impl<T: Hash + Ord> Set<T> {
Expand Down
2 changes: 2 additions & 0 deletions crates/erg_common/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,8 @@ impl<T: Immutable + ?Sized> Immutable for &T {}
impl<T: Immutable> Immutable for Option<T> {}
impl<T: Immutable> Immutable for Vec<T> {}
impl<T: Immutable> Immutable for [T] {}
impl<T: Immutable, U: Immutable> Immutable for (T, U) {}
impl<T: Immutable, U: Immutable, V: Immutable> Immutable for (T, U, V) {}
impl<T: Immutable + ?Sized> Immutable for Box<T> {}
impl<T: Immutable + ?Sized> Immutable for std::rc::Rc<T> {}
impl<T: Immutable + ?Sized> Immutable for std::sync::Arc<T> {}
10 changes: 10 additions & 0 deletions crates/erg_common/triple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ impl<T: fmt::Display, E: fmt::Display> fmt::Display for Triple<T, E> {
}

impl<T, E> Triple<T, E> {
pub const fn is_ok(&self) -> bool {
matches!(self, Triple::Ok(_))
}
pub const fn is_err(&self) -> bool {
matches!(self, Triple::Err(_))
}
pub const fn is_none(&self) -> bool {
matches!(self, Triple::None)
}

pub fn none_then(self, err: E) -> Result<T, E> {
match self {
Triple::None => Err(err),
Expand Down
127 changes: 69 additions & 58 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,9 +763,7 @@ impl Context {
self.structural_supertype_of(&l, rhs)
}
// {1, 2, 3} :> {1} or {2, 3} == true
(Refinement(_refine), Or(l, r)) => {
self.supertype_of(lhs, l) && self.supertype_of(lhs, r)
}
(Refinement(_refine), Or(tys)) => tys.iter().all(|ty| self.supertype_of(lhs, ty)),
// ({I: Int | True} :> Int) == true
// {N: Nat | ...} :> Int) == false
// ({I: Int | I >= 0} :> Int) == false
Expand Down Expand Up @@ -817,41 +815,37 @@ impl Context {
self.sub_unify(&inst, l, &(), None).is_ok()
}
// Int or Str :> Str or Int == (Int :> Str && Str :> Int) || (Int :> Int && Str :> Str) == true
(Or(l_1, l_2), Or(r_1, r_2)) => {
if l_1.is_union_type() && self.supertype_of(l_1, rhs) {
return true;
}
if l_2.is_union_type() && self.supertype_of(l_2, rhs) {
return true;
}
(self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2))
|| (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1))
}
// Int or Str or NoneType :> Str or Int
// Int or Str or NoneType :> Str or NoneType or Nat
(Or(l), Or(r)) => r.iter().all(|r| l.iter().any(|l| self.supertype_of(l, r))),
// not Nat :> not Int == true
(Not(l), Not(r)) => self.subtype_of(l, r),
// (Int or Str) :> Nat == Int :> Nat || Str :> Nat == true
// (Num or Show) :> Show == Num :> Show || Show :> Num == true
(Or(l_or, r_or), rhs) => self.supertype_of(l_or, rhs) || self.supertype_of(r_or, rhs),
(Or(ors), rhs) => ors.iter().any(|or| self.supertype_of(or, rhs)),
// Int :> (Nat or Str) == Int :> Nat && Int :> Str == false
(lhs, Or(l_or, r_or)) => self.supertype_of(lhs, l_or) && self.supertype_of(lhs, r_or),
(And(l_1, l_2), And(r_1, r_2)) => {
if l_1.is_intersection_type() && self.supertype_of(l_1, rhs) {
(lhs, Or(ors)) => ors.iter().all(|or| self.supertype_of(lhs, or)),
// Hash and Eq :> HashEq and ... == true
// Add(T) and Eq :> Add(Int) and Eq == true
(And(l), And(r)) => {
if r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))) {
return true;
}
if l_2.is_intersection_type() && self.supertype_of(l_2, rhs) {
return true;
if l.len() == r.len() {
let mut r = r.clone();
for _ in 1..l.len() {
if l.iter().zip(&r).all(|(l, r)| self.supertype_of(l, r)) {
return true;
}
r.rotate_left(1);
}
}
(self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2))
|| (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1))
false
}
// (Num and Show) :> Show == false
(And(l_and, r_and), rhs) => {
self.supertype_of(l_and, rhs) && self.supertype_of(r_and, rhs)
}
(And(ands), rhs) => ands.iter().all(|and| self.supertype_of(and, rhs)),
// Show :> (Num and Show) == true
(lhs, And(l_and, r_and)) => {
self.supertype_of(lhs, l_and) || self.supertype_of(lhs, r_and)
}
(lhs, And(ands)) => ands.iter().any(|and| self.supertype_of(lhs, and)),
// Not(Eq) :> Float == !(Eq :> Float) == true
(Not(_), Obj) => false,
(Not(l), rhs) => !self.supertype_of(l, rhs),
Expand Down Expand Up @@ -923,18 +917,18 @@ impl Context {
Type::NamedTuple(fields) => fields.iter().cloned().collect(),
Type::Refinement(refine) => self.fields(&refine.t),
Type::Structural(t) => self.fields(t),
Type::Or(l, r) => {
let l_fields = self.fields(l);
let r_fields = self.fields(r);
let l_field_names = l_fields.keys().collect::<Set<_>>();
let r_field_names = r_fields.keys().collect::<Set<_>>();
let field_names = l_field_names.intersection(&r_field_names);
Type::Or(tys) => {
let or_fields = tys.iter().map(|t| self.fields(t)).collect::<Set<_>>();
let field_names = or_fields
.iter()
.flat_map(|fs| fs.keys())
.collect::<Set<_>>();
let mut fields = Dict::new();
for (name, l_t, r_t) in field_names
for (name, tys) in field_names
.iter()
.map(|&name| (name, &l_fields[name], &r_fields[name]))
.map(|&name| (name, or_fields.iter().filter_map(|fields| fields.get(name))))
{
let union = self.union(l_t, r_t);
let union = tys.fold(Never, |acc, ty| self.union(&acc, ty));
fields.insert(name.clone(), union);
}
fields
Expand Down Expand Up @@ -1417,6 +1411,8 @@ impl Context {
/// union({ .a = Int }, { .a = Str }) == { .a = Int or Str }
/// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int } or { .a = Int; .b = Int } # not to lost `b` information
/// union((A and B) or C) == (A or C) and (B or C)
/// union(Never, Int) == Int
/// union(Obj, Int) == Obj
/// ```
pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type {
if lhs == rhs {
Expand Down Expand Up @@ -1479,10 +1475,9 @@ impl Context {
(Some(sub), Some(sup)) => bounded(sub.clone(), sup.clone()),
_ => self.simple_union(lhs, rhs),
},
(other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other),
(other, or @ Or(_)) | (or @ Or(_), other) => self.union_add(or, other),
// (A and B) or C ==> (A or C) and (B or C)
(and_t @ And(_, _), other) | (other, and_t @ And(_, _)) => {
let ands = and_t.ands();
(And(ands), other) | (other, And(ands)) => {
let mut t = Type::Obj;
for branch in ands.iter() {
let union = self.union(branch, other);
Expand Down Expand Up @@ -1666,6 +1661,12 @@ impl Context {

/// Returns intersection of two types (`A and B`).
/// If `A` and `B` have a subtype relationship, it is equal to `min(A, B)`.
/// ```erg
/// intersection(Nat, Int) == Nat
/// intersection(Int, Str) == Never
/// intersection(Obj, Int) == Int
/// intersection(Never, Int) == Never
/// ```
pub(crate) fn intersection(&self, lhs: &Type, rhs: &Type) -> Type {
if lhs == rhs {
return lhs.clone();
Expand Down Expand Up @@ -1696,12 +1697,9 @@ impl Context {
(_, Not(r)) => self.diff(lhs, r),
(Not(l), _) => self.diff(rhs, l),
// A and B and A == A and B
(other, and @ And(_, _)) | (and @ And(_, _), other) => {
self.intersection_add(and, other)
}
(other, and @ And(_)) | (and @ And(_), other) => self.intersection_add(and, other),
// (A or B) and C == (A and C) or (B and C)
(or_t @ Or(_, _), other) | (other, or_t @ Or(_, _)) => {
let ors = or_t.ors();
(Or(ors), other) | (other, Or(ors)) => {
if ors.iter().any(|t| t.has_unbound_var()) {
return self.simple_intersection(lhs, rhs);
}
Expand Down Expand Up @@ -1797,13 +1795,15 @@ impl Context {
/// intersection_add(Int and ?T(:> NoneType), Str) == Never
/// ```
fn intersection_add(&self, intersection: &Type, elem: &Type) -> Type {
let ands = intersection.ands();
let mut ands = intersection.ands();
let bounded = ands.iter().map(|t| t.lower_bounded());
for t in bounded {
if self.subtype_of(&t, elem) {
return intersection.clone();
} else if self.supertype_of(&t, elem) {
return constructors::ands(ands.linear_exclude(&t).include(elem.clone()));
ands.retain(|ty| ty != &t);
ands.push(elem.clone());
return constructors::ands(ands);
}
}
and(intersection.clone(), elem.clone())
Expand Down Expand Up @@ -1836,21 +1836,21 @@ impl Context {
fn narrow_type_by_pred(&self, t: Type, pred: &Predicate) -> Type {
match (t, pred) {
(
Type::Or(l, r),
Type::Or(tys),
Predicate::NotEqual {
rhs: TyParam::Value(v),
..
},
) if v.is_none() => {
let l = self.narrow_type_by_pred(*l, pred);
let r = self.narrow_type_by_pred(*r, pred);
if l.is_nonetype() {
r
} else if r.is_nonetype() {
l
} else {
or(l, r)
let mut new_tys = Set::new();
for ty in tys {
let ty = self.narrow_type_by_pred(ty, pred);
if ty.is_nonelike() {
continue;
}
new_tys.insert(ty);
}
Type::checked_or(new_tys)
}
(Type::Refinement(refine), _) => {
let t = self.narrow_type_by_pred(*refine.t, pred);
Expand Down Expand Up @@ -1992,8 +1992,12 @@ impl Context {
guard.target.clone(),
self.complement(&guard.to),
)),
Or(l, r) => self.intersection(&self.complement(l), &self.complement(r)),
And(l, r) => self.union(&self.complement(l), &self.complement(r)),
Or(ors) => ors
.iter()
.fold(Obj, |l, r| self.intersection(&l, &self.complement(r))),
And(ands) => ands
.iter()
.fold(Never, |l, r| self.union(&l, &self.complement(r))),
other => not(other.clone()),
}
}
Expand All @@ -2011,7 +2015,14 @@ impl Context {
match lhs {
Type::FreeVar(fv) if fv.is_linked() => self.diff(&fv.crack(), rhs),
// Type::And(l, r) => self.intersection(&self.diff(l, rhs), &self.diff(r, rhs)),
Type::Or(l, r) => self.union(&self.diff(l, rhs), &self.diff(r, rhs)),
Type::Or(tys) => {
let mut new_tys = vec![];
for ty in tys {
let diff = self.diff(ty, rhs);
new_tys.push(diff);
}
new_tys.into_iter().fold(Never, |l, r| self.union(&l, &r))
}
_ => lhs.clone(),
}
}
Expand Down
Loading

0 comments on commit 0408117

Please sign in to comment.