Skip to content

Commit

Permalink
expr: align visitor logic in MirScalarExpr* with MirRelationExpr
Browse files Browse the repository at this point in the history
- Factor out `MirScalarExpr::*visit*` code in a dedicated
  struct `MirScalarExprVisitor`.
- Use `MirScalarExpr::*visit*` to delegate to the latter.
- Implement `CheckedRecursion` for `MirScalarExprVisitor`
  (for now unused).
  • Loading branch information
aalexandrov committed Dec 2, 2021
1 parent 89df30d commit abe7d88
Showing 1 changed file with 154 additions and 73 deletions.
227 changes: 154 additions & 73 deletions src/expr/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use std::collections::HashSet;
use std::fmt;
use std::mem;

use ore::stack::CheckedRecursion;
use ore::stack::RecursionGuard;
use serde::{Deserialize, Serialize};

use lowertest::MzEnumReflect;
Expand All @@ -25,6 +27,7 @@ use repr::{ColumnType, Datum, RelationType, Row, RowArena, ScalarType};

use self::func::{BinaryFunc, NullaryFunc, UnaryFunc, VariadicFunc};
use crate::scalar::func::parse_timezone;
use crate::RECURSION_LIMIT;

pub mod func;
pub mod like_pattern;
Expand Down Expand Up @@ -116,112 +119,54 @@ impl MirScalarExpr {
}

/// Applies an infallible immutable `f` to each child of type `MirScalarExpr`.
pub fn visit_children<'a, F>(&'a self, mut f: F)
///
/// Deletages to `MirScalarExprVisitor::visit_children`.
pub fn visit_children<'a, F>(&'a self, f: F)
where
F: FnMut(&'a Self),
{
match self {
MirScalarExpr::Column(_) => (),
MirScalarExpr::Literal(_, _) => (),
MirScalarExpr::CallNullary(_) => (),
MirScalarExpr::CallUnary { expr, .. } => {
f(expr);
}
MirScalarExpr::CallBinary { expr1, expr2, .. } => {
f(expr1);
f(expr2);
}
MirScalarExpr::CallVariadic { exprs, .. } => {
for expr in exprs {
f(expr);
}
}
MirScalarExpr::If { cond, then, els } => {
f(cond);
f(then);
f(els);
}
}
MirScalarExprVisitor::new().visit_children(self, f)
}

/// Applies an infallible mutable `f` to each child of type `MirScalarExpr`.
pub fn visit_mut_children<'a, F>(&'a mut self, mut f: F)
///
/// Deletages to `MirScalarExprVisitor::visit_mut_children`.
pub fn visit_mut_children<'a, F>(&'a mut self, f: F)
where
F: FnMut(&'a mut Self),
{
match self {
MirScalarExpr::Column(_) => (),
MirScalarExpr::Literal(_, _) => (),
MirScalarExpr::CallNullary(_) => (),
MirScalarExpr::CallUnary { expr, .. } => {
f(expr);
}
MirScalarExpr::CallBinary { expr1, expr2, .. } => {
f(expr1);
f(expr2);
}
MirScalarExpr::CallVariadic { exprs, .. } => {
for expr in exprs {
f(expr);
}
}
MirScalarExpr::If { cond, then, els } => {
f(cond);
f(then);
f(els);
}
}
MirScalarExprVisitor::new().visit_mut_children(self, f)
}

/// Post-order immutable infallible `MirScalarExpr` visitor.
///
/// Grows the recursion stack if needed in order to avoid stack overflow errors.
/// Deletages to `MirScalarExprVisitor::visit_post`.
pub fn visit_post<'a, F>(&'a self, f: &mut F)
where
F: FnMut(&'a Self),
{
maybe_grow(|| {
self.visit_children(|e| e.visit_post(f));
f(self);
})
MirScalarExprVisitor::new().visit_post(self, f)
}

/// Post-order mutable infallible `MirScalarExpr` visitor.
///
/// Grows the recursion stack if needed in order to avoid stack overflow errors.
/// Deletages to `MirScalarExprVisitor::visit_mut_post`.
pub fn visit_mut_post<F>(&mut self, f: &mut F)
where
F: FnMut(&mut Self),
{
maybe_grow(|| {
self.visit_mut_children(|e| e.visit_mut_post(f));
f(self);
})
MirScalarExprVisitor::new().visit_mut_post(self, f)
}

/// A generalization of `visit_mut`. The function `pre` runs on a
/// `MirScalarExpr` before it runs on any of the child `MirScalarExpr`s.
/// The function `post` runs on child `MirScalarExpr`s first before the
/// parent. Optionally, `pre` can return which child `MirScalarExpr`s, if
/// any, should be visited (default is to visit all children).
/// A generalization of `visit_mut`.
///
/// Grows the recursion stack if needed in order to avoid stack overflow errors.
/// Deletages to `MirScalarExprVisitor::visit_mut_pre_post`.
pub fn visit_mut_pre_post<F1, F2>(&mut self, pre: &mut F1, post: &mut F2)
where
F1: FnMut(&mut Self) -> Option<Vec<&mut MirScalarExpr>>,
F2: FnMut(&mut Self),
{
maybe_grow(|| {
let to_visit = pre(self);
if let Some(to_visit) = to_visit {
for e in to_visit {
e.visit_mut_pre_post(pre, post);
}
} else {
self.visit_mut_children(|e| e.visit_mut_pre_post(pre, post));
}
post(self);
})
MirScalarExprVisitor::new().visit_mut_pre_post(self, pre, post)
}

/// Rewrites column indices with their value in `permutation`.
Expand Down Expand Up @@ -1023,6 +968,142 @@ impl fmt::Display for MirScalarExpr {
}
}

#[derive(Debug)]
struct MirScalarExprVisitor {
recursion_guard: RecursionGuard,
}

/// Contains visitor implementations.
///
/// [child, pre, post] x [fallible, infallible] x [immutable, mutable]
impl MirScalarExprVisitor {
/// Constructs a new MirScalarExprVisitor using a [`RecursionGuard`] with [`RECURSION_LIMIT`].
fn new() -> MirScalarExprVisitor {
MirScalarExprVisitor {
recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
}
}

/// Applies an infallible immutable `f` to each `expr` child of type `MirScalarExpr`.
fn visit_children<'a, F>(&self, expr: &'a MirScalarExpr, mut f: F)
where
F: FnMut(&'a MirScalarExpr),
{
match expr {
MirScalarExpr::Column(_) => (),
MirScalarExpr::Literal(_, _) => (),
MirScalarExpr::CallNullary(_) => (),
MirScalarExpr::CallUnary { expr, .. } => {
f(expr);
}
MirScalarExpr::CallBinary { expr1, expr2, .. } => {
f(expr1);
f(expr2);
}
MirScalarExpr::CallVariadic { exprs, .. } => {
for expr in exprs {
f(expr);
}
}
MirScalarExpr::If { cond, then, els } => {
f(cond);
f(then);
f(els);
}
}
}

/// Applies an infallible mutable `f` to each `expr` child of type `MirScalarExpr`.
fn visit_mut_children<'a, F>(&self, expr: &'a mut MirScalarExpr, mut f: F)
where
F: FnMut(&'a mut MirScalarExpr),
{
match expr {
MirScalarExpr::Column(_) => (),
MirScalarExpr::Literal(_, _) => (),
MirScalarExpr::CallNullary(_) => (),
MirScalarExpr::CallUnary { expr, .. } => {
f(expr);
}
MirScalarExpr::CallBinary { expr1, expr2, .. } => {
f(expr1);
f(expr2);
}
MirScalarExpr::CallVariadic { exprs, .. } => {
for expr in exprs {
f(expr);
}
}
MirScalarExpr::If { cond, then, els } => {
f(cond);
f(then);
f(els);
}
}
}

/// Post-order immutable infallible `MirScalarExpr` visitor for `expr`.
///
/// Grows the recursion stack if needed in order to avoid stack overflow errors.
#[inline]
fn visit_post<'a, F>(&self, expr: &'a MirScalarExpr, f: &mut F)
where
F: FnMut(&'a MirScalarExpr),
{
maybe_grow(|| {
self.visit_children(expr, |e| self.visit_post(e, f));
f(expr)
})
}

/// Post-order mutable infallible `MirScalarExpr` visitor for `expr`.
///
/// Grows the recursion stack if needed in order to avoid stack overflow errors.
#[inline]
fn visit_mut_post<F>(&self, expr: &mut MirScalarExpr, f: &mut F)
where
F: FnMut(&mut MirScalarExpr),
{
maybe_grow(|| {
self.visit_mut_children(expr, |e| self.visit_mut_post(e, f));
f(expr)
})
}

/// A generalization of `visit_mut`. The function `pre` runs on a
/// `MirScalarExpr` before it runs on any of the child `MirScalarExpr`s.
/// The function `post` runs on child `MirScalarExpr`s first before the
/// parent. Optionally, `pre` can return which child `MirScalarExpr`s, if
/// any, should be visited (default is to visit all children).
///
/// Grows the recursion stack if needed in order to avoid stack overflow errors.
#[inline]
pub fn visit_mut_pre_post<F1, F2>(&self, expr: &mut MirScalarExpr, pre: &mut F1, post: &mut F2)
where
F1: FnMut(&mut MirScalarExpr) -> Option<Vec<&mut MirScalarExpr>>,
F2: FnMut(&mut MirScalarExpr),
{
maybe_grow(|| {
let to_visit = pre(expr);
if let Some(to_visit) = to_visit {
for e in to_visit {
self.visit_mut_pre_post(e, pre, post);
}
} else {
self.visit_mut_children(expr, |e| self.visit_mut_pre_post(e, pre, post));
}
post(expr);
})
}
}

/// Add checked recursion support for [`MirScalarExprVisitor`].
impl CheckedRecursion for MirScalarExprVisitor {
fn recursion_guard(&self) -> &RecursionGuard {
&self.recursion_guard
}
}

#[derive(
Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzEnumReflect,
)]
Expand Down

0 comments on commit abe7d88

Please sign in to comment.