From 2e6050713c45fc3da8fd92cd8f09305d6d6237fa Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 30 Nov 2024 21:36:01 +0900 Subject: [PATCH] fix: `Predicate::Or(Set)` --- crates/erg_compiler/context/compare.rs | 33 ++-- crates/erg_compiler/context/eval.rs | 29 ++-- crates/erg_compiler/context/generalize.rs | 21 +-- crates/erg_compiler/context/instantiate.rs | 10 +- crates/erg_compiler/context/unify.rs | 20 ++- crates/erg_compiler/ty/predicate.rs | 166 ++++++++++++--------- 6 files changed, 165 insertions(+), 114 deletions(-) diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 1a73a3e51..22325deb9 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -2217,9 +2217,9 @@ impl Context { | Predicate::GeneralLessEqual { rhs, .. } | Predicate::GeneralNotEqual { rhs, .. } => self.get_pred_type(rhs), // x == 1 or x == "a" => Int or Str - Predicate::Or(lhs, rhs) => { - self.union(&self.get_pred_type(lhs), &self.get_pred_type(rhs)) - } + Predicate::Or(ors) => ors + .iter() + .fold(Never, |l, r| self.union(&l, &self.get_pred_type(r))), // REVIEW: Predicate::And(lhs, rhs) => { self.intersection(&self.get_pred_type(lhs), &self.get_pred_type(rhs)) @@ -2345,14 +2345,17 @@ impl Context { (None, None) => None, } } - Predicate::Or(l, r) => { - let l = self.eliminate_type_mismatched_preds(var, t, *l); - let r = self.eliminate_type_mismatched_preds(var, t, *r); - match (l, r) { - (Some(l), Some(r)) => Some(l | r), - (Some(l), None) => Some(l), - (None, Some(r)) => Some(r), - (None, None) => None, + Predicate::Or(preds) => { + let mut new_preds = Set::with_capacity(preds.len()); + for pred in preds { + if let Some(new_pred) = self.eliminate_type_mismatched_preds(var, t, pred) { + new_preds.insert(new_pred); + } + } + if new_preds.is_empty() { + None + } else { + Some(Predicate::Or(new_preds)) } } _ => Some(pred), @@ -2462,7 +2465,7 @@ impl Context { // {I == 1 or I == 0} !:> {I == 0 or I == 1 or I == 3} // NG: (self.is_super_pred_of(l1, l2) && self.is_super_pred_of(r1, r2)) // || (self.is_super_pred_of(l1, r2) && self.is_super_pred_of(r1, l2)) - (Pred::Or(_, _), Pred::Or(_, _)) => { + (Pred::Or(_), Pred::Or(_)) => { let lhs_ors = self.reduce_preds("or", lhs.ors()); let rhs_ors = self.reduce_preds("or", rhs.ors()); for r_val in rhs_ors.iter() { @@ -2509,8 +2512,8 @@ impl Context { (lhs, Pred::And(l, r)) => { self.is_super_pred_of(lhs, l) || self.is_super_pred_of(lhs, r) } - (lhs, Pred::Or(l, r)) => self.is_super_pred_of(lhs, l) && self.is_super_pred_of(lhs, r), - (Pred::Or(l, r), rhs) => self.is_super_pred_of(l, rhs) || self.is_super_pred_of(r, rhs), + (lhs, Pred::Or(ors)) => ors.iter().all(|or| self.is_super_pred_of(lhs, or)), + (Pred::Or(ors), rhs) => ors.iter().any(|or| self.is_super_pred_of(or, rhs)), (Pred::And(l, r), rhs) => { self.is_super_pred_of(l, rhs) && self.is_super_pred_of(r, rhs) } @@ -2523,7 +2526,7 @@ impl Context { } } - fn is_sub_pred_of(&self, lhs: &Predicate, rhs: &Predicate) -> bool { + pub(crate) fn is_sub_pred_of(&self, lhs: &Predicate, rhs: &Predicate) -> bool { self.is_super_pred_of(rhs, lhs) } diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 83d4b78a5..f06ae722d 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -4166,22 +4166,19 @@ impl Context { }; lhs & rhs } - Predicate::Or(l, r) => { - let lhs = match self.eval_pred(*l) { - Ok(pred) => pred, - Err((pred, es)) => { - errs.extend(es); - pred - } - }; - let rhs = match self.eval_pred(*r) { - Ok(pred) => pred, - Err((pred, es)) => { - errs.extend(es); - pred - } - }; - lhs | rhs + Predicate::Or(preds) => { + let mut new_preds = Set::with_capacity(preds.len()); + for pred in preds { + let pred = match self.eval_pred(pred) { + Ok(pred) => pred, + Err((pred, es)) => { + errs.extend(es); + pred + } + }; + new_preds.insert(pred); + } + Predicate::Or(new_preds) } Predicate::Not(pred) => { let pred = match self.eval_pred(*pred) { diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index f764d28e8..2af037edb 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -434,11 +434,12 @@ impl<'c> Generalizer<'c> { let rhs = self.generalize_pred(*rhs, uninit); Predicate::and(lhs, rhs) } - Predicate::Or(lhs, rhs) => { - let lhs = self.generalize_pred(*lhs, uninit); - let rhs = self.generalize_pred(*rhs, uninit); - Predicate::or(lhs, rhs) - } + Predicate::Or(preds) => Predicate::Or( + preds + .into_iter() + .map(|pred| self.generalize_pred(pred, uninit)) + .collect(), + ), Predicate::Not(pred) => { let pred = self.generalize_pred(*pred, uninit); !pred @@ -816,10 +817,12 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> { let rhs = self.deref_pred(*rhs)?; Ok(Predicate::and(lhs, rhs)) } - Predicate::Or(lhs, rhs) => { - let lhs = self.deref_pred(*lhs)?; - let rhs = self.deref_pred(*rhs)?; - Ok(Predicate::or(lhs, rhs)) + Predicate::Or(preds) => { + let mut new_preds = Set::with_capacity(preds.len()); + for pred in preds.into_iter() { + new_preds.insert(self.deref_pred(pred)?); + } + Ok(Predicate::Or(new_preds)) } Predicate::Not(pred) => { let pred = self.deref_pred(*pred)?; diff --git a/crates/erg_compiler/context/instantiate.rs b/crates/erg_compiler/context/instantiate.rs index 15b7bc074..9b5684d7b 100644 --- a/crates/erg_compiler/context/instantiate.rs +++ b/crates/erg_compiler/context/instantiate.rs @@ -630,10 +630,12 @@ impl Context { let r = self.instantiate_pred(*r, tmp_tv_cache, loc)?; Ok(Predicate::and(l, r)) } - Predicate::Or(l, r) => { - let l = self.instantiate_pred(*l, tmp_tv_cache, loc)?; - let r = self.instantiate_pred(*r, tmp_tv_cache, loc)?; - Ok(Predicate::or(l, r)) + Predicate::Or(preds) => { + let mut new_preds = Set::with_capacity(preds.len()); + for pred in preds { + new_preds.insert(self.instantiate_pred(pred, tmp_tv_cache, loc)?); + } + Ok(Predicate::Or(new_preds)) } Predicate::Not(pred) => { let pred = self.instantiate_pred(*pred, tmp_tv_cache, loc)?; diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index eabc5ebc4..e07dd528b 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -848,12 +848,30 @@ impl Unifier<'_, '_, '_, L> { | (Pred::NotEqual { rhs, .. }, Pred::NotEqual { rhs: rhs2, .. }) => { self.sub_unify_tp(rhs, rhs2, None, false) } - (Pred::And(l1, r1), Pred::And(l2, r2)) | (Pred::Or(l1, r1), Pred::Or(l2, r2)) => { + (Pred::And(l1, r1), Pred::And(l2, r2)) => { match (self.sub_unify_pred(l1, l2), self.sub_unify_pred(r1, r2)) { (Ok(()), Ok(())) => Ok(()), (Ok(()), Err(e)) | (Err(e), Ok(())) | (Err(e), Err(_)) => Err(e), } } + (Pred::Or(l_preds), Pred::Or(r_preds)) => { + let mut l_preds_ = l_preds.clone(); + let mut r_preds_ = r_preds.clone(); + for l_pred in l_preds { + if r_preds_.linear_remove(l_pred) { + l_preds_.linear_remove(l_pred); + } + } + for l_pred in l_preds_.iter() { + for r_pred in r_preds_.iter() { + if self.ctx.is_sub_pred_of(l_pred, r_pred) { + self.sub_unify_pred(l_pred, r_pred)?; + continue; + } + } + } + Ok(()) + } (Pred::Not(l), Pred::Not(r)) => self.sub_unify_pred(r, l), // sub_unify_pred(I == M, I <= ?N(: Nat)) ==> ?N(: M..) (Pred::Equal { rhs, .. }, Pred::LessEqual { rhs: rhs2, .. }) => { diff --git a/crates/erg_compiler/ty/predicate.rs b/crates/erg_compiler/ty/predicate.rs index 6bebdf0ca..91154a1cb 100644 --- a/crates/erg_compiler/ty/predicate.rs +++ b/crates/erg_compiler/ty/predicate.rs @@ -4,7 +4,7 @@ use std::ops::{BitAnd, BitOr, Not}; #[allow(unused_imports)] use erg_common::log; use erg_common::set::Set; -use erg_common::traits::{LimitedDisplay, StructuralEq}; +use erg_common::traits::{Immutable, LimitedDisplay, StructuralEq}; use erg_common::{fmt_option, set, Str}; use super::free::{Constraint, HasLevel}; @@ -12,6 +12,8 @@ use super::typaram::TyParam; use super::value::ValueObj; use super::{SharedFrees, Type}; +impl Immutable for Predicate {} + #[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub enum Predicate { Value(ValueObj), // True/False @@ -61,7 +63,7 @@ pub enum Predicate { lhs: Box, rhs: Box, }, - Or(Box, Box), + Or(Set), And(Box, Box), Not(Box), #[default] @@ -97,7 +99,16 @@ impl fmt::Display for Predicate { Self::GeneralLessEqual { lhs, rhs } => write!(f, "{lhs} <= {rhs}"), Self::GeneralGreaterEqual { lhs, rhs } => write!(f, "{lhs} >= {rhs}"), Self::GeneralNotEqual { lhs, rhs } => write!(f, "{lhs} != {rhs}"), - Self::Or(l, r) => write!(f, "({l}) or ({r})"), + Self::Or(preds) => { + write!(f, "(")?; + for (i, pred) in preds.iter().enumerate() { + if i != 0 { + write!(f, " or ")?; + } + write!(f, "{pred}")?; + } + write!(f, ")") + } Self::And(l, r) => write!(f, "({l}) and ({r})"), Self::Not(p) => write!(f, "not ({p})"), Self::Failure => write!(f, ""), @@ -166,11 +177,14 @@ impl LimitedDisplay for Predicate { write!(f, " != ")?; rhs.limited_fmt(f, limit - 1) } - Self::Or(l, r) => { + Self::Or(preds) => { write!(f, "(")?; - l.limited_fmt(f, limit - 1)?; - write!(f, ") or (")?; - r.limited_fmt(f, limit - 1)?; + for (i, pred) in preds.iter().enumerate() { + if i != 0 { + write!(f, " or ")?; + } + pred.limited_fmt(f, limit - 1)?; + } write!(f, ")") } Self::And(l, r) => { @@ -231,9 +245,7 @@ impl StructuralEq for Predicate { && name == n && args.iter().zip(a.iter()).all(|(l, r)| l.structural_eq(r)) } - (Self::Or(_, _), Self::Or(_, _)) => { - let self_ors = self.ors(); - let other_ors = other.ors(); + (Self::Or(self_ors), Self::Or(other_ors)) => { if self_ors.len() != other_ors.len() { return false; } @@ -280,9 +292,8 @@ impl HasLevel for Predicate { | Self::GeneralNotEqual { lhs, rhs } => { lhs.level().zip(rhs.level()).map(|(a, b)| a.min(b)) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.level().zip(rhs.level()).map(|(a, b)| a.min(b)) - } + Self::Or(preds) => preds.iter().filter_map(|p| p.level()).min(), + Self::And(lhs, rhs) => lhs.level().zip(rhs.level()).map(|(a, b)| a.min(b)), Self::Not(p) => p.level(), Self::Call { receiver, args, .. } => receiver .level() @@ -317,7 +328,12 @@ impl HasLevel for Predicate { lhs.set_level(level); rhs.set_level(level); } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { + Self::Or(preds) => { + for pred in preds { + pred.set_level(level); + } + } + Self::And(lhs, rhs) => { lhs.set_level(level); rhs.set_level(level); } @@ -453,14 +469,9 @@ impl Predicate { | (_, Predicate::Value(ValueObj::Bool(true))) => Predicate::TRUE, (Predicate::Value(ValueObj::Bool(false)), p) => p, (p, Predicate::Value(ValueObj::Bool(false))) => p, - (Predicate::Or(l, r), other) | (other, Predicate::Or(l, r)) => { - if l.as_ref() == &other { - *r | other - } else if r.as_ref() == &other { - *l | other - } else { - Self::Or(Box::new(Self::Or(l, r)), Box::new(other)) - } + (Predicate::Or(mut preds), other) | (other, Predicate::Or(mut preds)) => { + preds.insert(other); + Self::Or(preds) } // I == 1 or I >= 1 => I >= 1 ( @@ -474,7 +485,7 @@ impl Predicate { if p1 == p2 { p1 } else { - Self::Or(Box::new(p1), Box::new(p2)) + Self::Or(set! { p1, p2 }) } } } @@ -487,9 +498,8 @@ impl Predicate { pub fn consist_of_equal(&self) -> bool { match self { Self::Equal { .. } => true, - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.consist_of_equal() && rhs.consist_of_equal() - } + Self::Or(preds) => preds.iter().all(|p| p.consist_of_equal()), + Self::And(lhs, rhs) => lhs.consist_of_equal() && rhs.consist_of_equal(), Self::Not(pred) => pred.consist_of_equal(), _ => false, } @@ -508,11 +518,7 @@ impl Predicate { pub fn ors(&self) -> Set<&Predicate> { match self { - Self::Or(lhs, rhs) => { - let mut set = lhs.ors(); - set.extend(rhs.ors()); - set - } + Self::Or(preds) => preds.iter().collect(), _ => set! { self }, } } @@ -523,7 +529,18 @@ impl Predicate { | Self::LessEqual { lhs, .. } | Self::GreaterEqual { lhs, .. } | Self::NotEqual { lhs, .. } => Some(&lhs[..]), - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { + Self::Or(preds) => { + let mut iter = preds.iter(); + let first = iter.next()?; + let subject = first.subject()?; + for pred in iter { + if subject != pred.subject()? { + return None; + } + } + Some(subject) + } + Self::And(lhs, rhs) => { let l = lhs.subject(); let r = rhs.subject(); if l != r { @@ -548,9 +565,12 @@ impl Predicate { lhs.change_subject_name(name.clone()), rhs.change_subject_name(name), ), - Self::Or(lhs, rhs) => Self::or( - lhs.change_subject_name(name.clone()), - rhs.change_subject_name(name), + Self::Or(preds) => Self::Or( + preds + .iter() + .cloned() + .map(|p| p.change_subject_name(name.clone())) + .collect(), ), Self::Not(pred) => Self::not(pred.change_subject_name(name)), Self::GeneralEqual { lhs, rhs } => Self::general_eq( @@ -588,7 +608,7 @@ impl Predicate { Self::LessEqual { lhs, rhs } => Self::le(lhs, rhs.substitute(var, tp)), Self::NotEqual { lhs, rhs } => Self::ne(lhs, rhs.substitute(var, tp)), Self::And(lhs, rhs) => Self::and(lhs.substitute(var, tp), rhs.substitute(var, tp)), - Self::Or(lhs, rhs) => Self::or(lhs.substitute(var, tp), rhs.substitute(var, tp)), + Self::Or(preds) => Self::Or(preds.into_iter().map(|p| p.substitute(var, tp)).collect()), Self::Not(pred) => Self::not(pred.substitute(var, tp)), Self::GeneralEqual { lhs, rhs } => { Self::general_eq(lhs.substitute(var, tp), rhs.substitute(var, tp)) @@ -638,7 +658,8 @@ impl Predicate { | Self::GeneralGreaterEqual { lhs, rhs } | Self::GeneralNotEqual { lhs, rhs } => lhs.mentions(name) || rhs.mentions(name), Self::Not(pred) => pred.mentions(name), - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.mentions(name) || rhs.mentions(name), + Self::And(lhs, rhs) => lhs.mentions(name) || rhs.mentions(name), + Self::Or(preds) => preds.iter().any(|p| p.mentions(name)), _ => false, } } @@ -647,7 +668,14 @@ impl Predicate { match self { Self::Value(l) => Some(matches!(l, ValueObj::Bool(false))), Self::Const(_) => None, - Self::Or(lhs, rhs) => Some(lhs.can_be_false()? || rhs.can_be_false()?), + Self::Or(preds) => { + for pred in preds { + if pred.can_be_false()? { + return Some(true); + } + } + Some(false) + } Self::And(lhs, rhs) => Some(lhs.can_be_false()? && rhs.can_be_false()?), Self::Not(pred) => Some(!pred.can_be_false()?), _ => Some(true), @@ -676,7 +704,8 @@ impl Predicate { | Self::GeneralNotEqual { lhs, rhs } => { lhs.qvars().concat(rhs.qvars()).into_iter().collect() } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.qvars().concat(rhs.qvars()), + Self::And(lhs, rhs) => lhs.qvars().concat(rhs.qvars()), + Self::Or(preds) => preds.iter().fold(set! {}, |acc, p| acc.union(&p.qvars())), Self::Not(pred) => pred.qvars(), } } @@ -699,9 +728,8 @@ impl Predicate { | Self::GeneralNotEqual { lhs, rhs } => { lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f) } - Self::Or(lhs, rhs) | Self::And(lhs, rhs) => { - lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f) - } + Self::And(lhs, rhs) => lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f), + Self::Or(preds) => preds.iter().any(|p| p.has_type_satisfies(f)), Self::Not(pred) => pred.has_type_satisfies(f), } } @@ -722,7 +750,8 @@ impl Predicate { | Self::GeneralLessEqual { lhs, rhs } | Self::GeneralGreaterEqual { lhs, rhs } | Self::GeneralNotEqual { lhs, rhs } => lhs.has_qvar() || rhs.has_qvar(), - Self::Or(lhs, rhs) | Self::And(lhs, rhs) => lhs.has_qvar() || rhs.has_qvar(), + Self::And(lhs, rhs) => lhs.has_qvar() || rhs.has_qvar(), + Self::Or(preds) => preds.iter().any(|p| p.has_qvar()), Self::Not(pred) => pred.has_qvar(), } } @@ -743,9 +772,8 @@ impl Predicate { | Self::GeneralLessEqual { lhs, rhs } | Self::GeneralGreaterEqual { lhs, rhs } | Self::GeneralNotEqual { lhs, rhs } => lhs.has_unbound_var() || rhs.has_unbound_var(), - Self::Or(lhs, rhs) | Self::And(lhs, rhs) => { - lhs.has_unbound_var() || rhs.has_unbound_var() - } + Self::And(lhs, rhs) => lhs.has_unbound_var() || rhs.has_unbound_var(), + Self::Or(preds) => preds.iter().any(|p| p.has_unbound_var()), Self::Not(pred) => pred.has_unbound_var(), } } @@ -769,9 +797,8 @@ impl Predicate { | Self::GeneralNotEqual { lhs, rhs } => { lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var() } - Self::Or(lhs, rhs) | Self::And(lhs, rhs) => { - lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var() - } + Self::And(lhs, rhs) => lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var(), + Self::Or(preds) => preds.iter().any(|p| p.has_undoable_linked_var()), Self::Not(pred) => pred.has_undoable_linked_var(), } } @@ -825,9 +852,8 @@ impl Predicate { | Self::GeneralLessEqual { .. } | Self::GeneralGreaterEqual { .. } | Self::GeneralNotEqual { .. } => vec![], - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.typarams().into_iter().chain(rhs.typarams()).collect() - } + Self::And(lhs, rhs) => lhs.typarams().into_iter().chain(rhs.typarams()).collect(), + Self::Or(preds) => preds.iter().flat_map(|p| p.typarams()).collect(), Self::Not(pred) => pred.typarams(), } } @@ -850,7 +876,7 @@ impl Predicate { pub fn possible_tps(&self) -> Vec<&TyParam> { match self { - Self::Or(lhs, rhs) => [lhs.possible_tps(), rhs.possible_tps()].concat(), + Self::Or(preds) => preds.iter().flat_map(|p| p.possible_tps()).collect(), Self::Equal { rhs, .. } => vec![rhs], _ => vec![], } @@ -863,7 +889,7 @@ impl Predicate { rhs: TyParam::Value(value), .. } => vec![value], - Self::Or(lhs, rhs) => [lhs.possible_values(), rhs.possible_values()].concat(), + Self::Or(preds) => preds.iter().flat_map(|p| p.possible_values()).collect(), _ => vec![], } } @@ -888,7 +914,10 @@ impl Predicate { | Self::GeneralLessEqual { lhs, rhs } | Self::GeneralGreaterEqual { lhs, rhs } | Self::GeneralNotEqual { lhs, rhs } => lhs.variables().concat(rhs.variables()), - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.variables().concat(rhs.variables()), + Self::And(lhs, rhs) => lhs.variables().concat(rhs.variables()), + Self::Or(preds) => preds + .iter() + .fold(set! {}, |acc, p| acc.union(&p.variables())), Self::Not(pred) => pred.variables(), } } @@ -911,9 +940,8 @@ impl Predicate { | Self::GeneralNotEqual { lhs, rhs } => { lhs.contains_value(value) || rhs.contains_value(value) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.contains_value(value) || rhs.contains_value(value) - } + Self::And(lhs, rhs) => lhs.contains_value(value) || rhs.contains_value(value), + Self::Or(preds) => preds.iter().any(|p| p.contains_value(value)), Self::Not(pred) => pred.contains_value(value), Self::Failure => false, } @@ -934,7 +962,8 @@ impl Predicate { | Self::GeneralLessEqual { lhs, rhs } | Self::GeneralGreaterEqual { lhs, rhs } | Self::GeneralNotEqual { lhs, rhs } => lhs.contains_tp(tp) || rhs.contains_tp(tp), - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.contains_tp(tp) || rhs.contains_tp(tp), + Self::And(lhs, rhs) => lhs.contains_tp(tp) || rhs.contains_tp(tp), + Self::Or(preds) => preds.iter().any(|p| p.contains_tp(tp)), Self::Not(pred) => pred.contains_tp(tp), Self::Failure | Self::Const(_) => false, } @@ -955,7 +984,8 @@ impl Predicate { | Self::GeneralLessEqual { lhs, rhs } | Self::GeneralGreaterEqual { lhs, rhs } | Self::GeneralNotEqual { lhs, rhs } => lhs.contains_t(t) || rhs.contains_t(t), - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.contains_t(t) || rhs.contains_t(t), + Self::And(lhs, rhs) => lhs.contains_t(t) || rhs.contains_t(t), + Self::Or(preds) => preds.iter().any(|p| p.contains_t(t)), Self::Not(pred) => pred.contains_t(t), Self::Const(_) | Self::Failure => false, } @@ -1035,9 +1065,7 @@ impl Predicate { Self::And(lhs, rhs) => { Self::And(Box::new(lhs.map_t(f, tvs)), Box::new(rhs.map_t(f, tvs))) } - Self::Or(lhs, rhs) => { - Self::Or(Box::new(lhs.map_t(f, tvs)), Box::new(rhs.map_t(f, tvs))) - } + Self::Or(preds) => Self::Or(preds.into_iter().map(|p| p.map_t(f, tvs)).collect()), Self::Not(pred) => Self::Not(Box::new(pred.map_t(f, tvs))), Self::Failure => self, } @@ -1095,9 +1123,7 @@ impl Predicate { Self::And(lhs, rhs) => { Self::And(Box::new(lhs.map_tp(f, tvs)), Box::new(rhs.map_tp(f, tvs))) } - Self::Or(lhs, rhs) => { - Self::Or(Box::new(lhs.map_tp(f, tvs)), Box::new(rhs.map_tp(f, tvs))) - } + Self::Or(preds) => Self::Or(preds.into_iter().map(|p| p.map_tp(f, tvs)).collect()), Self::Not(pred) => Self::Not(Box::new(pred.map_tp(f, tvs))), Self::Failure => self, } @@ -1147,9 +1173,11 @@ impl Predicate { Box::new(lhs.try_map_tp(f, tvs)?), Box::new(rhs.try_map_tp(f, tvs)?), )), - Self::Or(lhs, rhs) => Ok(Self::Or( - Box::new(lhs.try_map_tp(f, tvs)?), - Box::new(rhs.try_map_tp(f, tvs)?), + Self::Or(preds) => Ok(Self::Or( + preds + .into_iter() + .map(|p| p.try_map_tp(f, tvs)) + .collect::>()?, )), Self::Not(pred) => Ok(Self::Not(Box::new(pred.try_map_tp(f, tvs)?))), Self::Failure | Self::Const(_) => Ok(self),