From 9e5a61d6d7370f530766720f60d740dcfc3a06c6 Mon Sep 17 00:00:00 2001 From: Lucas Jeub <lucas.jeub@pometry.com> Date: Wed, 5 Jun 2024 17:45:32 +0200 Subject: [PATCH] add equality check support for python node state --- raphtory/src/python/types/macros/iterable.rs | 4 ++-- .../python/types/macros/trait_impl/node_state.rs | 15 +++++++++++++++ raphtory/src/python/types/wrappers/iterables.rs | 1 - 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/raphtory/src/python/types/macros/iterable.rs b/raphtory/src/python/types/macros/iterable.rs index cdeef18103..9006b5b176 100644 --- a/raphtory/src/python/types/macros/iterable.rs +++ b/raphtory/src/python/types/macros/iterable.rs @@ -249,7 +249,7 @@ macro_rules! py_iterable_comp { ) -> Box<dyn Iterator<Item = $cmp_item> + 'py> { match self { Self::Vec(v) => Box::new(v.iter().cloned()), - Self::This(t) => Box::new(t.borrow(py).iter().map_into()), + Self::This(t) => Box::new(t.borrow(py).iter().map(|v| v.into())), } } } @@ -262,7 +262,7 @@ macro_rules! py_iterable_comp { impl<I: Iterator<Item = J>, J: Into<$cmp_item>> From<I> for $cmp_internal { fn from(value: I) -> Self { - Self::Vec(value.map_into().collect()) + Self::Vec(value.map(|v| v.into()).collect()) } } diff --git a/raphtory/src/python/types/macros/trait_impl/node_state.rs b/raphtory/src/python/types/macros/trait_impl/node_state.rs index 5965070fea..0fcfcbf63f 100644 --- a/raphtory/src/python/types/macros/trait_impl/node_state.rs +++ b/raphtory/src/python/types/macros/trait_impl/node_state.rs @@ -14,6 +14,7 @@ use chrono::{DateTime, Utc}; use pyo3::{ exceptions::{PyKeyError, PyTypeError}, prelude::*, + types::PyNotImplemented, }; use std::sync::Arc; @@ -129,6 +130,20 @@ macro_rules! impl_node_state_ord_ops { .median_item() .map(|(n, v)| (n.cloned(), ($to_owned)(v))) } + + fn __eq__<'py>(&'py self, other: &'py PyAny, py: Python<'py>) -> PyObject { + if let Ok(other) = other.extract::<PyRef<Self>>() { + return self.inner.values().eq(other.inner.values()).into_py(py); + } else if let Ok(other) = other.extract::<Vec<$value>>() { + return self + .inner + .values() + .map($to_owned) + .eq(other.iter().cloned()) + .into_py(py); + } + PyNotImplemented::get(py).into_py(py) + } } }; } diff --git a/raphtory/src/python/types/wrappers/iterables.rs b/raphtory/src/python/types/wrappers/iterables.rs index c35240652a..4520155db1 100644 --- a/raphtory/src/python/types/wrappers/iterables.rs +++ b/raphtory/src/python/types/wrappers/iterables.rs @@ -1,6 +1,5 @@ use crate::{core::ArcStr, db::api::view::BoxedIter, prelude::Prop, python::types::repr::Repr}; use chrono::{DateTime, Utc}; -use itertools::Itertools; use num::cast::AsPrimitive; use pyo3::prelude::*; use std::{i64, iter::Sum};