Skip to content

Commit

Permalink
allow per-language discriminants
Browse files Browse the repository at this point in the history
  • Loading branch information
bjchambers committed Sep 18, 2023
1 parent e7845c5 commit e910479
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 10 deletions.
7 changes: 3 additions & 4 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
};

#[cfg(feature = "serde-1")]
use ::serde::{Deserialize, Serialize};
use serde::{Deserialize, Serialize};

use log::*;

Expand Down Expand Up @@ -76,7 +76,7 @@ pub struct EGraph<L: Language, N: Analysis<L>> {
pub(crate) classes: HashMap<Id, EClass<L, N::Data>>,
#[cfg_attr(feature = "serde-1", serde(skip))]
#[cfg_attr(feature = "serde-1", serde(default = "default_classes_by_op"))]
pub(crate) classes_by_op: HashMap<std::mem::Discriminant<L>, HashSet<Id>>,
pub(crate) classes_by_op: HashMap<L::Discriminant, HashSet<Id>>,
/// Whether or not reading operation are allowed on this e-graph.
/// Mutating operations will set this to `false`, and
/// [`EGraph::rebuild`] will set it to true.
Expand Down Expand Up @@ -977,9 +977,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
trimmed += old_len - class.nodes.len();

let mut add = |n: &L| {
#[allow(enum_intrinsics_non_enums)]
classes_by_op
.entry(std::mem::discriminant(n))
.entry(n.discriminant())
.or_default()
.insert(class.id)
};
Expand Down
11 changes: 11 additions & 0 deletions src/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ use thiserror::Error;
/// See [`SymbolLang`] for quick-and-dirty use cases.
#[allow(clippy::len_without_is_empty)]
pub trait Language: Debug + Clone + Eq + Ord + Hash {
type Discriminant: Debug + Clone + Eq + Hash;

#[allow(enum_intrinsics_non_enums)]
fn discriminant(&self) -> Self::Discriminant;

/// Returns true if this enode matches another enode.
/// This should only consider the operator, not the children `Id`s.
fn matches(&self, other: &Self) -> bool;
Expand Down Expand Up @@ -826,6 +831,12 @@ impl SymbolLang {
}

impl Language for SymbolLang {
type Discriminant = Symbol;

fn discriminant(&self) -> Self::Discriminant {
self.op.clone()
}

fn matches(&self, other: &Self) -> bool {
self.op == other.op && self.len() == other.len()
}
Expand Down
7 changes: 3 additions & 4 deletions src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ fn for_each_matching_node<L, D>(
where
L: Language,
{
#[allow(enum_intrinsics_non_enums)]
if eclass.nodes.len() < 50 {
eclass
.nodes
Expand All @@ -53,17 +52,17 @@ where
debug_assert!(node.all(|id| id == Id::from(0)));
debug_assert!(eclass.nodes.windows(2).all(|w| w[0] < w[1]));
let mut start = eclass.nodes.binary_search(node).unwrap_or_else(|i| i);
let discrim = std::mem::discriminant(node);
let discrim = node.discriminant();
while start > 0 {
if std::mem::discriminant(&eclass.nodes[start - 1]) == discrim {
if eclass.nodes[start - 1].discriminant() == discrim {
start -= 1;
} else {
break;
}
}
let mut matching = eclass.nodes[start..]
.iter()
.take_while(|&n| std::mem::discriminant(n) == discrim)
.take_while(|&n| n.discriminant() == discrim)
.filter(|n| node.matches(n));
debug_assert_eq!(
matching.clone().count(),
Expand Down
7 changes: 7 additions & 0 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ macro_rules! __define_language {
$vis enum $name $decl

impl $crate::Language for $name {
type Discriminant = std::mem::Discriminant<Self>;

#[inline(always)]
fn discriminant(&self) -> Self::Discriminant {
std::mem::discriminant(self)
}

#[inline(always)]
fn matches(&self, other: &Self) -> bool {
::std::mem::discriminant(self) == ::std::mem::discriminant(other) &&
Expand Down
20 changes: 18 additions & 2 deletions src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,24 @@ pub enum ENodeOrVar<L> {
Var(Var),
}

/// The discriminant for the language of [`Pattern`]s.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub enum ENodeOrVarDiscriminant<L: Language> {
ENode(L::Discriminant),
Var(Var),
}

impl<L: Language> Language for ENodeOrVar<L> {
type Discriminant = ENodeOrVarDiscriminant<L>;

#[inline(always)]
fn discriminant(&self) -> Self::Discriminant {
match self {
ENodeOrVar::ENode(n) => ENodeOrVarDiscriminant::ENode(n.discriminant()),
ENodeOrVar::Var(v) => ENodeOrVarDiscriminant::Var(*v),
}
}

fn matches(&self, _other: &Self) -> bool {
panic!("Should never call this")
}
Expand Down Expand Up @@ -271,8 +288,7 @@ impl<L: Language, A: Analysis<L>> Searcher<L, A> for Pattern<L> {
fn search_with_limit(&self, egraph: &EGraph<L, A>, limit: usize) -> Vec<SearchMatches<L>> {
match self.ast.as_ref().last().unwrap() {
ENodeOrVar::ENode(e) => {
#[allow(enum_intrinsics_non_enums)]
let key = std::mem::discriminant(e);
let key = e.discriminant();
match egraph.classes_by_op.get(&key) {
None => vec![],
Some(ids) => rewrite::search_eclasses_with_limit(
Expand Down

0 comments on commit e910479

Please sign in to comment.