Skip to content

Commit

Permalink
use saturating crate, fix assertion bug
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Sep 26, 2023
1 parent e7845c5 commit 53386e6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 69 deletions.
15 changes: 8 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,27 @@ repository = "https://github.com/egraphs-good/egg"
version = "0.9.5"

[dependencies]
env_logger = {version = "0.9.0", default-features = false}
env_logger = { version = "0.9.0", default-features = false }
fxhash = "0.2.1"
hashbrown = "0.12.1"
indexmap = "1.8.1"
instant = "0.1.12"
log = "0.4.17"
smallvec = {version = "1.8.0", features = ["union", "const_generics"]}
symbol_table = {version = "0.2.0", features = ["global"]}
smallvec = { version = "1.8.0", features = ["union", "const_generics"] }
symbol_table = { version = "0.2.0", features = ["global"] }
symbolic_expressions = "5.0.3"
thiserror = "1.0.31"

# for the lp feature
coin_cbc = {version = "0.1.6", optional = true}
coin_cbc = { version = "0.1.6", optional = true }

# for the serde-1 feature
serde = {version = "1.0.137", features = ["derive"], optional = true}
vectorize = {version = "0.2.0", optional = true}
serde = { version = "1.0.137", features = ["derive"], optional = true }
vectorize = { version = "0.2.0", optional = true }

# for the reports feature
serde_json = {version = "1.0.81", optional = true}
serde_json = { version = "1.0.81", optional = true }
saturating = "0.1.0"

[dev-dependencies]
ordered-float = "3.0.0"
Expand Down
122 changes: 60 additions & 62 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ use crate::{
util::pretty_print, Analysis, EClass, EGraph, ENodeOrVar, FromOp, HashMap, HashSet, Id,
Language, Pattern, PatternAst, RecExpr, Rewrite, Subst, UnionFind, Var,
};
use saturating::Saturating;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque};
use std::fmt::{self, Debug, Display, Formatter};
use std::rc::Rc;

use symbolic_expressions::Sexp;

type ProofCost = Saturating<usize>;

const CONGRUENCE_LIMIT: usize = 2;
const GREEDY_NUM_ITERS: usize = 2;

Expand Down Expand Up @@ -62,14 +65,14 @@ pub struct Explain<L: Language> {
// the explanation.
// Invariant: The distance is always <= the unoptimized distance
// That is, less than or equal to the result of `distance_between`
shortest_explanation_memo: HashMap<(Id, Id), (usize, Id)>,
shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>,
}

#[derive(Default)]
struct DistanceMemo {
parent_distance: Vec<(Id, usize)>,
parent_distance: Vec<(Id, ProofCost)>,
common_ancestor: HashMap<(Id, Id), Id>,
tree_depth: HashMap<Id, usize>,
tree_depth: HashMap<Id, ProofCost>,
}

/// Explanation trees are the compact representation showing
Expand Down Expand Up @@ -233,12 +236,12 @@ impl<L: Language + Display + FromOp> Explanation<L> {

/// Get the size of this explanation tree in terms of the number of rewrites
/// in the let-bound version of the tree.
pub fn get_tree_size(&self) -> usize {
pub fn get_tree_size(&self) -> ProofCost {
let mut seen = Default::default();
let mut seen_adjacent = Default::default();
let mut sum = 0;
let mut sum: ProofCost = Saturating(0);
for e in self.explanation_trees.iter() {
sum += self.tree_size(&mut seen, &mut seen_adjacent, e);
sum = sum + self.tree_size(&mut seen, &mut seen_adjacent, e);
}
sum
}
Expand All @@ -248,29 +251,29 @@ impl<L: Language + Display + FromOp> Explanation<L> {
seen: &mut HashSet<*const TreeTerm<L>>,
seen_adjacent: &mut HashSet<(Id, Id)>,
current: &Rc<TreeTerm<L>>,
) -> usize {
) -> ProofCost {
if !seen.insert(&**current as *const TreeTerm<L>) {
return 0;
return Saturating(0);
}
let mut my_size = 0;
let mut my_size: ProofCost = Saturating(0);
if current.forward_rule.is_some() {
my_size += 1;
my_size += Saturating(1);
}
if current.backward_rule.is_some() {
my_size += 1;
my_size += Saturating(1);
}
assert!(my_size <= 1);
if my_size == 1 {
assert!(my_size <= Saturating(1));
if my_size == Saturating(1) {
if !seen_adjacent.insert((current.current, current.last)) {
return 0;
return Saturating(0);
} else {
seen_adjacent.insert((current.last, current.current));
}
}

for child_proof in &current.child_proofs {
for child in child_proof {
my_size += self.tree_size(seen, seen_adjacent, child);
my_size = self.tree_size(seen, seen_adjacent, child);
}
}
my_size
Expand Down Expand Up @@ -853,7 +856,7 @@ impl<L: Language> FlatTerm<L> {
// Make sure to use push_increase instead of push when using priority queue
#[derive(Copy, Clone, Eq, PartialEq)]
struct HeapState<I> {
cost: usize,
cost: ProofCost,
item: I,
}
// The priority queue depends on `Ord`.
Expand Down Expand Up @@ -1080,7 +1083,7 @@ impl<L: Language> Explain<L> {
return;
}
if let Some((cost, _)) = self.shortest_explanation_memo.get(&(node1, node2)) {
if cost <= &1 {
if cost <= &Saturating(1) {
return;
}
}
Expand All @@ -1106,9 +1109,9 @@ impl<L: Language> Explain<L> {
.neighbors
.push(rconnection);
self.shortest_explanation_memo
.insert((node1, node2), (1, node2));
.insert((node1, node2), (Saturating(1), node2));
self.shortest_explanation_memo
.insert((node2, node1), (1, node1));
.insert((node2, node1), (Saturating(1), node1));
}

pub(crate) fn union(
Expand All @@ -1132,9 +1135,9 @@ impl<L: Language> Explain<L> {

if let Justification::Rule(_) = justification {
self.shortest_explanation_memo
.insert((node1, node2), (1, node2));
.insert((node1, node2), (Saturating(1), node2));
self.shortest_explanation_memo
.insert((node2, node1), (1, node1));
.insert((node2, node1), (Saturating(1), node1));
}

let pconnection = Connection {
Expand Down Expand Up @@ -1455,28 +1458,28 @@ impl<L: Language> Explain<L> {
enodes
}

fn add_tree_depths(&self, node: Id, depths: &mut HashMap<Id, usize>) -> usize {
fn add_tree_depths(&self, node: Id, depths: &mut HashMap<Id, ProofCost>) -> ProofCost {
if depths.get(&node).is_none() {
let parent = self.parent(node);
let depth = if parent == node {
0
Saturating(0)
} else {
self.add_tree_depths(parent, depths) + 1
self.add_tree_depths(parent, depths) + Saturating(1)
};
depths.insert(node, depth);
}
return *depths.get(&node).unwrap();
}

fn calculate_tree_depths(&self) -> HashMap<Id, usize> {
fn calculate_tree_depths(&self) -> HashMap<Id, ProofCost> {
let mut depths = HashMap::default();
for i in 0..self.explainfind.len() {
self.add_tree_depths(Id::from(i), &mut depths);
}
depths
}

fn replace_distance(&mut self, current: Id, next: Id, right: Id, distance: usize) {
fn replace_distance(&mut self, current: Id, next: Id, right: Id, distance: ProofCost) {
self.shortest_explanation_memo
.insert((current, right), (distance, next));
}
Expand All @@ -1486,11 +1489,10 @@ impl<L: Language> Explain<L> {
right: Id,
left_connections: &[Connection],
distance_memo: &mut DistanceMemo,
target_cost: usize,
) {
self.shortest_explanation_memo
.insert((right, right), (0, right));
let mut last_cost = 0;
.insert((right, right), (Saturating(0), right));
let mut last_cost = Saturating(0);
for connection in left_connections.iter().rev() {
let next = connection.next;
let current = connection.current;
Expand All @@ -1503,12 +1505,16 @@ impl<L: Language> Explain<L> {
last_cost = dist + next_cost;
self.replace_distance(current, next, right, next_cost + dist);
}
assert!(last_cost <= target_cost);
}

fn distance_between(&mut self, left: Id, right: Id, distance_memo: &mut DistanceMemo) -> usize {
fn distance_between(
&mut self,
left: Id,
right: Id,
distance_memo: &mut DistanceMemo,
) -> ProofCost {
if left == right {
return 0;
return Saturating(0);
}
let ancestor = if let Some(a) = distance_memo.common_ancestor.get(&(left, right)) {
*a
Expand All @@ -1535,11 +1541,13 @@ impl<L: Language> Explain<L> {
);

// calculate distance to find upper bound
match b.checked_add(c) {
Some(added) => added
.checked_sub(a.checked_mul(2).unwrap_or(0))
.unwrap_or(usize::MAX),
None => usize::MAX,
match b.0.checked_add(c.0) {
Some(added) => Saturating(
added
.checked_sub(a.0.checked_mul(2).unwrap_or(0))
.unwrap_or(usize::MAX),
),
None => Saturating(usize::MAX),
}

//assert_eq!(dist+1, Explanation::new(self.explain_enodes(left, right, &mut Default::default())).make_flat_explanation().len());
Expand All @@ -1550,20 +1558,16 @@ impl<L: Language> Explain<L> {
current: Id,
next: Id,
distance_memo: &mut DistanceMemo,
) -> usize {
) -> ProofCost {
let current_node = self.explainfind[usize::from(current)].node.clone();
let next_node = self.explainfind[usize::from(next)].node.clone();
let mut cost: usize = 0;
let mut cost: ProofCost = Saturating(0);
for (left_child, right_child) in current_node
.children()
.iter()
.zip(next_node.children().iter())
{
cost = cost.saturating_add(self.distance_between(
*left_child,
*right_child,
distance_memo,
));
cost += self.distance_between(*left_child, *right_child, distance_memo);
}
cost
}
Expand All @@ -1572,12 +1576,12 @@ impl<L: Language> Explain<L> {
&mut self,
connection: &Connection,
distance_memo: &mut DistanceMemo,
) -> usize {
) -> ProofCost {
match connection.justification {
Justification::Congruence => {
self.congruence_distance(connection.current, connection.next, distance_memo)
}
Justification::Rule(_) => 1,
Justification::Rule(_) => Saturating(1),
}
}

Expand All @@ -1586,7 +1590,7 @@ impl<L: Language> Explain<L> {
enode: Id,
ancestor: Id,
distance_memo: &mut DistanceMemo,
) -> usize {
) -> ProofCost {
loop {
let parent = distance_memo.parent_distance[usize::from(enode)].0;
let dist = distance_memo.parent_distance[usize::from(enode)].1;
Expand All @@ -1596,8 +1600,7 @@ impl<L: Language> Explain<L> {

let parent_parent = distance_memo.parent_distance[usize::from(parent)].0;
if parent_parent != parent {
let new_dist =
dist.saturating_add(distance_memo.parent_distance[usize::from(parent)].1);
let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1;
distance_memo.parent_distance[usize::from(enode)] = (parent_parent, new_dist);
} else {
if ancestor == Id::from(usize::MAX) {
Expand All @@ -1617,7 +1620,7 @@ impl<L: Language> Explain<L> {
Justification::Congruence => {
self.congruence_distance(current, next, distance_memo)
}
Justification::Rule(_) => 1,
Justification::Rule(_) => Saturating(1),
};
distance_memo.parent_distance[usize::from(parent)] = (self.parent(parent), cost);
}
Expand Down Expand Up @@ -1703,7 +1706,7 @@ impl<L: Language> Explain<L> {
) -> Option<(Vec<Connection>, Vec<Connection>)> {
let mut todo = BinaryHeap::new();
todo.push(HeapState {
cost: 0,
cost: Saturating(0),
item: Connection {
current: start,
next: start,
Expand Down Expand Up @@ -1737,7 +1740,7 @@ impl<L: Language> Explain<L> {

for neighbor in &self.explainfind[usize::from(current)].neighbors {
if let Justification::Rule(_) = neighbor.justification {
let neighbor_cost = cost_so_far.saturating_add(1);
let neighbor_cost = cost_so_far + Saturating(1);
todo.push(HeapState {
item: neighbor.clone(),
cost: neighbor_cost,
Expand All @@ -1748,7 +1751,7 @@ impl<L: Language> Explain<L> {
for other in congruence_neighbors[usize::from(current)].iter() {
let next = other;
let distance = self.congruence_distance(current, *next, distance_memo);
let next_cost = cost_so_far.saturating_add(distance);
let next_cost = cost_so_far + distance;
todo.push(HeapState {
item: Connection {
current,
Expand All @@ -1767,7 +1770,7 @@ impl<L: Language> Explain<L> {
let mut right_connections = vec![];

// we would like to assert that we found a path better than the normal one
// but since proof sizes are saturated (saturating_add) this is not true
// but since proof sizes are saturated this is not true
/*let dist = self.distance_between(start, end, distance_memo);
if *total_cost.unwrap() > dist {
panic!(
Expand All @@ -1776,7 +1779,7 @@ impl<L: Language> Explain<L> {
dist
);
}*/
if *total_cost.unwrap() == self.distance_between(start, end, distance_memo) {
if *total_cost.unwrap() >= self.distance_between(start, end, distance_memo) {
let (a_left_connections, a_right_connections) = self.get_path_unoptimized(start, end);
left_connections = a_left_connections;
right_connections = a_right_connections;
Expand All @@ -1793,12 +1796,7 @@ impl<L: Language> Explain<L> {
}
}
connections.reverse();
self.populate_path_length(
end,
&connections,
distance_memo,
*path_cost.get(&end).unwrap(),
);
self.populate_path_length(end, &connections, distance_memo);
left_connections = connections;
}

Expand Down Expand Up @@ -1974,7 +1972,7 @@ impl<L: Language> Explain<L> {
) {
let mut congruence_neighbors = vec![vec![]; self.explainfind.len()];
self.find_congruence_neighbors::<N>(classes, &mut congruence_neighbors, unionfind);
let mut parent_distance = vec![(Id::from(0), 0); self.explainfind.len()];
let mut parent_distance = vec![(Id::from(0), Saturating(0)); self.explainfind.len()];
for (i, entry) in parent_distance.iter_mut().enumerate() {
entry.0 = Id::from(i);
}
Expand Down

0 comments on commit 53386e6

Please sign in to comment.