diff --git a/python/tests/test_graphdb.py b/python/tests/test_graphdb.py index 04c8bc04ef..4209d3864d 100644 --- a/python/tests/test_graphdb.py +++ b/python/tests/test_graphdb.py @@ -1979,15 +1979,44 @@ def test_one_hop_filter_reset(): assert len(out_out_2) == 0 -def test_node_types(): +def test_node_types_filter(): g = Graph() g.add_node(1, 1, node_type="wallet") g.add_node(1, 2, node_type="timer") g.add_node(1, 3, node_type="timer") g.add_node(1, 4, node_type="wallet") - assert g.nodes.type_filter(["wallet"]).node_type.collect() == ["1", "4"] - assert g.subgraph_node_types(["timer"]).nodes.name.collect() == ["2", "3"] + assert [node.name for node in g.nodes.type_filter(["wallet"])] == ['1', '4'] + assert g.subgraph_node_types(["timer"]).nodes.name.collect() == ['2', '3'] + + g = PersistentGraph() + g.add_node(1, 1, node_type="wallet") + g.add_node(2, 2, node_type="timer") + g.add_node(3, 3, node_type="timer") + g.add_node(4, 4, node_type="wallet") + + assert [node.name for node in g.nodes.type_filter(["wallet"])] == ['1', '4'] + assert g.subgraph_node_types(["timer"]).nodes.name.collect() == ['2', '3'] + + subgraph = g.subgraph([1, 2, 3]) + assert [node.name for node in subgraph.nodes.type_filter(["wallet"])] == ['1'] + assert subgraph.subgraph_node_types(["timer"]).nodes.name.collect() == ['2', '3'] + + w = g.window(1, 3) + assert [node.name for node in w.nodes.type_filter(["wallet"])] == ['1'] + assert w.subgraph_node_types(["timer"]).nodes.name.collect() == ['2', '3'] + + g = Graph() + g.add_node(1, 1, node_type="wallet") + g.add_node(2, 2, node_type="timer") + g.add_node(3, 3, node_type="timer") + g.add_node(4, 4, node_type="counter") + g.add_edge(1, 1, 2, layer="layer1") + g.add_edge(2, 2, 3, layer="layer1") + g.add_edge(3, 2, 4, layer="layer2") + layer = g.layers(["layer1"]) + assert [node.name for node in layer.nodes.type_filter(["wallet"])] == ['1'] + assert layer.subgraph_node_types(["timer"]).nodes.name.collect() == ['2', '3'] def test_time_exploded_edges(): diff --git a/raphtory-graphql/src/model/graph/nodes.rs b/raphtory-graphql/src/model/graph/nodes.rs index 33e4dfad44..69e2b95098 100644 --- a/raphtory-graphql/src/model/graph/nodes.rs +++ b/raphtory-graphql/src/model/graph/nodes.rs @@ -1,7 +1,11 @@ use crate::model::graph::node::Node; use dynamic_graphql::{ResolvedObject, ResolvedObjectFields}; +use itertools::Itertools; use raphtory::{ - db::{api::view::DynamicGraph, graph::nodes::Nodes}, + db::{ + api::view::DynamicGraph, + graph::{node::NodeView, nodes::Nodes}, + }, prelude::*, }; @@ -77,8 +81,12 @@ impl GqlNodes { self.update(self.nn.shrink_end(end)) } - async fn type_filter(&self, node_types: Vec) -> Self { - self.update(self.nn.type_filter(node_types)) + async fn type_filter(&self, node_types: Vec) -> Vec { + self.nn + .type_filter(node_types) + .into_iter() + .map(Node::from) + .collect_vec() } //////////////////////// diff --git a/raphtory-graphql/src/model/graph/path_from_node.rs b/raphtory-graphql/src/model/graph/path_from_node.rs index 80509414b3..f9ab931efb 100644 --- a/raphtory-graphql/src/model/graph/path_from_node.rs +++ b/raphtory-graphql/src/model/graph/path_from_node.rs @@ -78,10 +78,6 @@ impl GqlPathFromNode { self.update(self.nn.shrink_end(end)) } - async fn type_filter(&self, node_types: Vec) -> Self { - self.update(self.nn.type_filter(node_types)) - } - //////////////////////// //// TIME QUERIES ////// //////////////////////// diff --git a/raphtory/src/db/api/view/mod.rs b/raphtory/src/db/api/view/mod.rs index bae339df9b..6812721110 100644 --- a/raphtory/src/db/api/view/mod.rs +++ b/raphtory/src/db/api/view/mod.rs @@ -5,7 +5,6 @@ mod graph; pub(crate) mod internal; mod layer; pub(crate) mod node; -mod node_types_filter; pub(crate) mod time; pub(crate) use edge::BaseEdgeViewOps; @@ -18,7 +17,6 @@ pub use internal::{ pub use layer::*; pub(crate) use node::BaseNodeViewOps; pub use node::NodeViewOps; -pub use node_types_filter::*; pub use time::*; pub type BoxedIter = Box + Send>; diff --git a/raphtory/src/db/api/view/node.rs b/raphtory/src/db/api/view/node.rs index ce3f829522..b9810112fc 100644 --- a/raphtory/src/db/api/view/node.rs +++ b/raphtory/src/db/api/view/node.rs @@ -11,7 +11,7 @@ use crate::{ storage::locked::LockedGraph, view::{ internal::{CoreGraphOps, TimeSemantics}, - TimeOps, + Base, TimeOps, }, }, prelude::{EdgeViewOps, GraphViewOps, LayerOps}, diff --git a/raphtory/src/db/api/view/node_types_filter.rs b/raphtory/src/db/api/view/node_types_filter.rs deleted file mode 100644 index 3c8901672c..0000000000 --- a/raphtory/src/db/api/view/node_types_filter.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::db::{ - api::view::internal::{CoreGraphOps, OneHopFilter}, - graph::views::node_type_filtered_subgraph::TypeFilteredSubgraph, -}; -use std::borrow::Borrow; - -pub trait NodeTypesFilter<'graph>: OneHopFilter<'graph> { - fn type_filter, V: Borrow>( - &self, - node_types: I, - ) -> Self::Filtered> { - let meta = self.current_filter().node_meta().node_type_meta(); - let r = node_types - .into_iter() - .flat_map(|nt| meta.get_id(nt.borrow())) - .collect(); - self.one_hop_filtered(TypeFilteredSubgraph::new(self.current_filter().clone(), r)) - } -} diff --git a/raphtory/src/db/graph/graph.rs b/raphtory/src/db/graph/graph.rs index e61c4988a3..6d81d449d2 100644 --- a/raphtory/src/db/graph/graph.rs +++ b/raphtory/src/db/graph/graph.rs @@ -217,7 +217,9 @@ mod db_tests { EdgeViewOps, Layer, LayerOps, NodeViewOps, TimeOps, }, }, - graph::{edge::EdgeView, edges::Edges, node::NodeView, path::PathFromNode}, + graph::{ + edge::EdgeView, edges::Edges, node::NodeView, nodes::Nodes, path::PathFromNode, + }, }, graphgen::random_attachment::random_attachment, prelude::{AdditionOps, PropertyAdditionOps}, @@ -2154,6 +2156,26 @@ mod db_tests { ); } + #[test] + fn test_node_type() { + let g = Graph::new(); + g.add_node(1, 1, NO_PROPS, Some("a")).unwrap(); + g.add_node(2, 2, NO_PROPS, Some("b")).unwrap(); + g.add_node(2, 3, NO_PROPS, Some("b")).unwrap(); + g.add_edge(3, 1, 2, NO_PROPS, None).unwrap(); + g.add_edge(3, 3, 2, NO_PROPS, None).unwrap(); + + for node in g.nodes().type_filter(vec!["a"]) { + assert_eq!(node.degree(), 1); + } + + for node in g.nodes() { + if node.node_type() == Some(ArcStr::from("a")) { + assert_eq!(node.degree(), 1); + } + } + } + #[test] fn test_persistent_graph() { let g = Graph::new(); diff --git a/raphtory/src/db/graph/nodes.rs b/raphtory/src/db/graph/nodes.rs index 3e564c38cb..90d51523f5 100644 --- a/raphtory/src/db/graph/nodes.rs +++ b/raphtory/src/db/graph/nodes.rs @@ -6,7 +6,6 @@ use crate::{ view::{ internal::{OneHopFilter, Static}, BaseNodeViewOps, BoxedLIter, DynamicGraph, IntoDynBoxed, IntoDynamic, - NodeTypesFilter, }, }, graph::{edges::NestedEdges, node::NodeView, path::PathFromGraph}, @@ -16,7 +15,7 @@ use crate::{ use crate::db::api::storage::locked::LockedGraph; use rayon::iter::ParallelIterator; -use std::{marker::PhantomData, sync::Arc}; +use std::{borrow::Borrow, marker::PhantomData, sync::Arc}; #[derive(Clone)] pub struct Nodes<'graph, G, GH = G> { @@ -95,6 +94,29 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> Nodes<'graph, G, )) } + pub fn type_filter, V: AsRef>( + &self, + node_types: I, + ) -> BoxedLIter<'graph, NodeView> + where + I::IntoIter: Send + Sync + 'graph, + V: Send + Sync + 'graph, + { + let base_graph = self.base_graph.clone(); + node_types + .into_iter() + .flat_map(move |nt| { + base_graph.nodes().into_iter().filter_map(move |node| { + if nt.as_ref() == node.node_type()?.as_ref() { + Some(node) + } else { + None + } + }) + }) + .into_dyn_boxed() + } + pub fn collect(&self) -> Vec> { self.iter().collect() } @@ -192,11 +214,6 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> OneHopFilter<'gr } } -impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> NodeTypesFilter<'graph> - for Nodes<'graph, G, GH> -{ -} - impl<'graph, G: GraphViewOps<'graph> + 'graph, GH: GraphViewOps<'graph> + 'graph> IntoIterator for Nodes<'graph, G, GH> { diff --git a/raphtory/src/db/graph/path.rs b/raphtory/src/db/graph/path.rs index c61928ae0a..dcd02b41e8 100644 --- a/raphtory/src/db/graph/path.rs +++ b/raphtory/src/db/graph/path.rs @@ -407,16 +407,6 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> OneHopFilter<'gr } } -impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> NodeTypesFilter<'graph> - for PathFromGraph<'graph, G, GH> -{ -} - -impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> NodeTypesFilter<'graph> - for PathFromNode<'graph, G, GH> -{ -} - #[cfg(test)] mod test { use crate::prelude::*; diff --git a/raphtory/src/db/task/node/eval_node.rs b/raphtory/src/db/task/node/eval_node.rs index 0d113c5e39..f5cfe1ce44 100644 --- a/raphtory/src/db/task/node/eval_node.rs +++ b/raphtory/src/db/task/node/eval_node.rs @@ -16,7 +16,7 @@ use crate::{ graph::{node::NodeView, path::PathFromNode}, task::{node::eval_node_state::EVState, task_state::Local2}, }, - prelude::{GraphViewOps, NodeTypesFilter}, + prelude::GraphViewOps, }; use crate::db::{api::storage::locked::LockedGraph, task::edge::eval_edges::EvalEdges}; @@ -532,17 +532,6 @@ impl< } } -impl< - 'graph, - 'a: 'graph, - G: GraphViewOps<'graph>, - S, - CS: ComputeState + 'a, - GH: GraphViewOps<'graph>, - > NodeTypesFilter<'graph> for EvalPathFromNode<'graph, 'a, G, GH, CS, S> -{ -} - /// Represents an entry in the shuffle table. /// /// The entry contains a reference to a `ShuffleComputeState` and an `AccId` representing the accumulator diff --git a/raphtory/src/lib.rs b/raphtory/src/lib.rs index 0236f5965f..e7b9e07536 100644 --- a/raphtory/src/lib.rs +++ b/raphtory/src/lib.rs @@ -108,10 +108,7 @@ pub mod prelude { db::{ api::{ mutation::{AdditionOps, DeletionOps, ImportOps, PropertyAdditionOps}, - view::{ - EdgeViewOps, GraphViewOps, Layer, LayerOps, NodeTypesFilter, NodeViewOps, - TimeOps, - }, + view::{EdgeViewOps, GraphViewOps, Layer, LayerOps, NodeViewOps, TimeOps}, }, graph::graph::Graph, }, diff --git a/raphtory/src/python/graph/node.rs b/raphtory/src/python/graph/node.rs index 95a1759326..a223b67b0c 100644 --- a/raphtory/src/python/graph/node.rs +++ b/raphtory/src/python/graph/node.rs @@ -27,6 +27,7 @@ use crate::{ *, }; use chrono::{DateTime, Utc}; +use itertools::Itertools; use pyo3::{ exceptions::{PyIndexError, PyKeyError}, prelude::*, @@ -419,13 +420,6 @@ pub struct PyNodes { pub(crate) nodes: Nodes<'static, DynamicGraph, DynamicGraph>, } -impl_nodetypesfilter!( - PyNodes, - nodes, - Nodes<'static, DynamicGraph, DynamicGraph>, - "Nodes" -); - impl_nodeviewops!( PyNodes, nodes, @@ -669,6 +663,10 @@ impl PyNodes { Ok(df_data.to_object(py)) }) } + + pub fn type_filter(&self, node_types: Vec) -> Vec> { + self.nodes.type_filter(node_types).collect_vec() + } } impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> Repr for Nodes<'static, G, GH> { @@ -682,13 +680,6 @@ pub struct PyPathFromGraph { path: PathFromGraph<'static, DynamicGraph, DynamicGraph>, } -impl_nodetypesfilter!( - PyPathFromGraph, - path, - PathFromGraph<'static, DynamicGraph, DynamicGraph>, - "PathFromGraph" -); - impl_nodeviewops!( PyPathFromGraph, path, @@ -819,13 +810,6 @@ pub struct PyPathFromNode { path: PathFromNode<'static, DynamicGraph, DynamicGraph>, } -impl_nodetypesfilter!( - PyPathFromNode, - path, - PathFromNode<'static, DynamicGraph, DynamicGraph>, - "PathFromNode" -); - impl_nodeviewops!( PyPathFromNode, path, diff --git a/raphtory/src/python/types/macros/trait_impl/mod.rs b/raphtory/src/python/types/macros/trait_impl/mod.rs index 7293215748..abc9d14852 100644 --- a/raphtory/src/python/types/macros/trait_impl/mod.rs +++ b/raphtory/src/python/types/macros/trait_impl/mod.rs @@ -14,6 +14,3 @@ mod nodeviewops; mod repr; #[macro_use] mod iterable_mixin; - -#[macro_use] -mod nodetypesfilter; diff --git a/raphtory/src/python/types/macros/trait_impl/nodetypesfilter.rs b/raphtory/src/python/types/macros/trait_impl/nodetypesfilter.rs deleted file mode 100644 index 07a2accc5c..0000000000 --- a/raphtory/src/python/types/macros/trait_impl/nodetypesfilter.rs +++ /dev/null @@ -1,17 +0,0 @@ -macro_rules! impl_nodetypesfilter { - ($obj:ty, $field:ident, $base_type:ty, $name:literal) => { - - #[pyo3::pymethods] - impl $obj { - - /// Filter nodes by node types - /// - /// Arguments: - /// node_types (list[str]): list of node types - /// - fn type_filter(&self, node_types: Vec<&str>) -> <$base_type as $crate::db::api::view::internal::OneHopFilter<'static>>::Filtered<$crate::db::graph::views::node_type_filtered_subgraph::TypeFilteredSubgraph<<$base_type as $crate::db::api::view::internal::OneHopFilter<'static>>::FilteredGraph>> { - self.$field.type_filter(node_types) - } - } - }; -}