Skip to content

Commit

Permalink
add equality check support for python node state
Browse files Browse the repository at this point in the history
  • Loading branch information
ljeub-pometry authored and fabianmurariu committed Jun 11, 2024
1 parent ad20136 commit 40a77c1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
4 changes: 2 additions & 2 deletions raphtory/src/python/types/macros/iterable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ macro_rules! py_iterable_comp {
) -> Box<dyn Iterator<Item = $cmp_item> + 'py> {
match self {
Self::Vec(v) => Box::new(v.iter().cloned()),
Self::This(t) => Box::new(t.borrow(py).iter().map_into()),
Self::This(t) => Box::new(t.borrow(py).iter().map(|v| v.into())),
}
}
}
Expand All @@ -262,7 +262,7 @@ macro_rules! py_iterable_comp {

impl<I: Iterator<Item = J>, J: Into<$cmp_item>> From<I> for $cmp_internal {
fn from(value: I) -> Self {
Self::Vec(value.map_into().collect())
Self::Vec(value.map(|v| v.into()).collect())
}
}

Expand Down
15 changes: 15 additions & 0 deletions raphtory/src/python/types/macros/trait_impl/node_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use chrono::{DateTime, Utc};
use pyo3::{
exceptions::{PyKeyError, PyTypeError},
prelude::*,
types::PyNotImplemented,
};
use std::sync::Arc;

Expand Down Expand Up @@ -129,6 +130,20 @@ macro_rules! impl_node_state_ord_ops {
.median_item()
.map(|(n, v)| (n.cloned(), ($to_owned)(v)))
}

fn __eq__<'py>(&'py self, other: &'py PyAny, py: Python<'py>) -> PyObject {
if let Ok(other) = other.extract::<PyRef<Self>>() {
return self.inner.values().eq(other.inner.values()).into_py(py);
} else if let Ok(other) = other.extract::<Vec<$value>>() {
return self
.inner
.values()
.map($to_owned)
.eq(other.iter().cloned())
.into_py(py);
}
PyNotImplemented::get(py).into_py(py)
}
}
};
}
Expand Down
1 change: 0 additions & 1 deletion raphtory/src/python/types/wrappers/iterables.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::{core::ArcStr, db::api::view::BoxedIter, prelude::Prop, python::types::repr::Repr};
use chrono::{DateTime, Utc};
use itertools::Itertools;
use num::cast::AsPrimitive;
use pyo3::prelude::*;
use std::{i64, iter::Sum};
Expand Down

0 comments on commit 40a77c1

Please sign in to comment.