diff --git a/python/tests/test_graphdb.py b/python/tests/test_graphdb.py index 97bb07221a..b769acb547 100644 --- a/python/tests/test_graphdb.py +++ b/python/tests/test_graphdb.py @@ -2046,15 +2046,92 @@ def test_one_hop_filter_reset(): assert len(out_out_2) == 0 -def test_node_types(): +def test_type_filter(): g = Graph() g.add_node(1, 1, node_type="wallet") g.add_node(1, 2, node_type="timer") g.add_node(1, 3, node_type="timer") g.add_node(1, 4, node_type="wallet") - assert g.nodes.type_filter(["wallet"]).node_type.collect() == ["1", "4"] - assert g.subgraph_node_types(["timer"]).nodes.name.collect() == ["2", "3"] + assert [node.name for node in g.nodes.type_filter(["wallet"])] == ['1', '4'] + assert g.subgraph_node_types(["timer"]).nodes.name.collect() == ['2', '3'] + + g = PersistentGraph() + g.add_node(1, 1, node_type="wallet") + g.add_node(2, 2, node_type="timer") + g.add_node(3, 3, node_type="timer") + g.add_node(4, 4, node_type="wallet") + + assert [node.name for node in g.nodes.type_filter(["wallet"])] == ['1', '4'] + assert g.subgraph_node_types(["timer"]).nodes.name.collect() == ['2', '3'] + + subgraph = g.subgraph([1, 2, 3]) + assert [node.name for node in subgraph.nodes.type_filter(["wallet"])] == ['1'] + assert subgraph.subgraph_node_types(["timer"]).nodes.name.collect() == ['2', '3'] + + w = g.window(1, 3) + assert [node.name for node in w.nodes.type_filter(["wallet"])] == ['1'] + assert w.subgraph_node_types(["timer"]).nodes.name.collect() == ['2', '3'] + + g = Graph() + g.add_node(1, 1, node_type="wallet") + g.add_node(2, 2, node_type="timer") + g.add_node(3, 3, node_type="timer") + g.add_node(4, 4, node_type="counter") + g.add_edge(1, 1, 2, layer="layer1") + g.add_edge(2, 2, 3, layer="layer1") + g.add_edge(3, 2, 4, layer="layer2") + layer = g.layers(["layer1"]) + assert [node.name for node in layer.nodes.type_filter(["wallet"])] == ['1'] + assert layer.subgraph_node_types(["timer"]).nodes.name.collect() == ['2', '3'] + + g = Graph() + g.add_node(1, 1, node_type="a") + g.add_node(1, 2, node_type="b") + g.add_node(1, 3, node_type="b") + g.add_node(1, 4, node_type="a") + g.add_node(1, 5, node_type="c") + g.add_node(1, 6, node_type="e") + g.add_edge(2, 1, 2, layer="a") + g.add_edge(2, 3, 2, layer="a") + g.add_edge(2, 2, 4, layer="a") + g.add_edge(2, 4, 5, layer="a") + g.add_edge(2, 4, 5, layer="a") + g.add_edge(2, 5, 6, layer="a") + g.add_edge(2, 3, 6, layer="a") + + assert g.nodes.type_filter(["a"]).name.collect() == ['1', '4'] + assert g.nodes.type_filter(["a", "c"]).name.collect() == ['1', '4', '5'] + assert g.nodes.type_filter(["a"]).neighbours.name.collect() == [['2'], ['2', '5']] + + assert g.nodes.degree().collect() == [1, 3, 2, 2, 2, 2] + assert g.nodes.type_filter(['a']).degree().collect() == [1, 2] + assert g.nodes.type_filter(['d']).degree().collect() == [] + assert g.nodes.type_filter([]).name.collect() == [] + + assert len(g.nodes) == 6 + assert len(g.nodes.type_filter(['b'])) == 2 + assert len(g.nodes.type_filter(['d'])) == 0 + + assert g.nodes.type_filter(['d']).neighbours.name.collect() == [] + assert g.nodes.type_filter(['a']).neighbours.name.collect() == [['2'], ['2', '5']] + assert g.nodes.type_filter(['a', 'c']).neighbours.name.collect() == [['2'], ['2', '5'], ['4', '6']] + + assert g.nodes.type_filter(['a']).neighbours.type_filter(['c']).name.collect() == [[], ['5']] + assert g.nodes.type_filter(['a']).neighbours.type_filter([]).name.collect() == [[], []] + assert g.nodes.type_filter(['a']).neighbours.type_filter(['b', 'c']).name.collect() == [['2'], ['2', '5']] + assert g.nodes.type_filter(['a']).neighbours.type_filter(['d']).name.collect() == [[], []] + assert g.nodes.type_filter(['a']).neighbours.neighbours.name.collect() == [['1', '3', '4'], ['1', '3', '4', '4', '6']] + assert g.nodes.type_filter(['a']).neighbours.type_filter(['c']).neighbours.name.collect() == [[], ['4', '6']] + assert g.nodes.type_filter(['a']).neighbours.type_filter(['d']).neighbours.name.collect() == [[], []] + + assert g.node('2').neighbours.type_filter(['b']).name.collect() == ['3'] + assert g.node('2').neighbours.type_filter(['d']).name.collect() == [] + assert g.node('2').neighbours.type_filter([]).name.collect() == [] + assert g.node('2').neighbours.type_filter(['c', 'a']).name.collect() == ['1', '4'] + assert g.node('2').neighbours.type_filter(['c']).neighbours.name.collect() == [] + assert g.node('2').neighbours.neighbours.name.collect() == ['2', '2', '6', '2', '5'] + def test_time_exploded_edges(): diff --git a/raphtory-graphql/src/lib.rs b/raphtory-graphql/src/lib.rs index 78b6ac6186..ba637c47ab 100644 --- a/raphtory-graphql/src/lib.rs +++ b/raphtory-graphql/src/lib.rs @@ -969,4 +969,120 @@ mod graphql_test { let graph_roundtrip = url_decode_graph(graph_encoded).unwrap().into_dynamic(); assert_eq!(g, graph_roundtrip); } + + #[tokio::test] + async fn test_type_filter() { + let graph = Graph::new(); + graph.add_constant_properties([("name", "graph")]).unwrap(); + graph.add_node(1, 1, NO_PROPS, Some("a")).unwrap(); + graph.add_node(1, 2, NO_PROPS, Some("b")).unwrap(); + graph.add_node(1, 3, NO_PROPS, Some("b")).unwrap(); + graph.add_node(1, 4, NO_PROPS, Some("a")).unwrap(); + graph.add_node(1, 5, NO_PROPS, Some("c")).unwrap(); + graph.add_node(1, 6, NO_PROPS, Some("e")).unwrap(); + graph.add_edge(2, 1, 2, NO_PROPS, Some("a")).unwrap(); + graph.add_edge(2, 3, 2, NO_PROPS, Some("a")).unwrap(); + graph.add_edge(2, 2, 4, NO_PROPS, Some("a")).unwrap(); + graph.add_edge(2, 4, 5, NO_PROPS, Some("a")).unwrap(); + graph.add_edge(2, 4, 5, NO_PROPS, Some("a")).unwrap(); + graph.add_edge(2, 5, 6, NO_PROPS, Some("a")).unwrap(); + graph.add_edge(2, 3, 6, NO_PROPS, Some("a")).unwrap(); + + let graphs = HashMap::from([("graph".to_string(), graph)]); + let data = Data::from_map(graphs); + let schema = App::create_schema().data(data).finish().unwrap(); + + let req = r#" + { + graph(name: "graph") { + nodes { + typeFilter(nodeTypes: ["a"]) { + list { + name + } + } + } + } + } + "#; + + let req = Request::new(req); + let res = schema.execute(req).await; + let data = res.data.into_json().unwrap(); + assert_eq!( + data, + json!({ + "graph": { + "nodes": { + "typeFilter": { + "list": [ + { + "name": "1" + }, + { + "name": "4" + } + ] + } + } + } + }), + ); + + let req = r#" + { + graph(name: "graph") { + nodes { + typeFilter(nodeTypes: ["a"]) { + list{ + neighbours { + list { + name + } + } + } + } + } + } + } + "#; + + let req = Request::new(req); + let res = schema.execute(req).await; + let data = res.data.into_json().unwrap(); + assert_eq!( + data, + json!({ + "graph": { + "nodes": { + "typeFilter": { + "list": [ + { + "neighbours": { + "list": [ + { + "name": "2" + } + ] + } + }, + { + "neighbours": { + "list": [ + { + "name": "2" + }, + { + "name": "5" + } + ] + } + } + ] + } + } + } + }), + ); + } } diff --git a/raphtory-graphql/src/model/graph/nodes.rs b/raphtory-graphql/src/model/graph/nodes.rs index 33e4dfad44..ee7e8c51b5 100644 --- a/raphtory-graphql/src/model/graph/nodes.rs +++ b/raphtory-graphql/src/model/graph/nodes.rs @@ -78,7 +78,7 @@ impl GqlNodes { } async fn type_filter(&self, node_types: Vec) -> Self { - self.update(self.nn.type_filter(node_types)) + self.update(self.nn.type_filter(&node_types)) } //////////////////////// diff --git a/raphtory-graphql/src/model/graph/path_from_node.rs b/raphtory-graphql/src/model/graph/path_from_node.rs index 80509414b3..218e54045e 100644 --- a/raphtory-graphql/src/model/graph/path_from_node.rs +++ b/raphtory-graphql/src/model/graph/path_from_node.rs @@ -79,7 +79,7 @@ impl GqlPathFromNode { } async fn type_filter(&self, node_types: Vec) -> Self { - self.update(self.nn.type_filter(node_types)) + self.update(self.nn.type_filter(&node_types)) } //////////////////////// diff --git a/raphtory-graphql/src/model/graph/property.rs b/raphtory-graphql/src/model/graph/property.rs index 61dde9a923..2fc3ee9253 100644 --- a/raphtory-graphql/src/model/graph/property.rs +++ b/raphtory-graphql/src/model/graph/property.rs @@ -9,7 +9,7 @@ use raphtory::{ }, }; use serde_json::Number; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; #[derive(Clone, Debug, Scalar)] pub struct GqlPropValue(pub Prop); diff --git a/raphtory/src/arrow/graph_impl/core_ops.rs b/raphtory/src/arrow/graph_impl/core_ops.rs index 46ca6491e7..0185fade54 100644 --- a/raphtory/src/arrow/graph_impl/core_ops.rs +++ b/raphtory/src/arrow/graph_impl/core_ops.rs @@ -28,13 +28,17 @@ use crate::{ }, storage_ops::GraphStorage, }, - view::{internal::CoreGraphOps, BoxedIter}, + view::{ + internal::{CoreGraphOps, DelegateCoreOps}, + BoxedIter, + }, }, }; use itertools::Itertools; use polars_arrow::datatypes::ArrowDataType; use raphtory_arrow::{properties::Properties, GidRef, GID}; use rayon::prelude::*; + impl CoreGraphOps for ArrowGraph { fn unfiltered_num_nodes(&self) -> usize { self.inner.num_nodes() @@ -222,6 +226,11 @@ impl CoreGraphOps for ArrowGraph { .map(|layer| layer.num_edges()) .sum() } + + fn node_type_id(&self, v: VID) -> usize { + // self.graph().node_type_id(v) TODO: Impl node types for arrow graphs + 0 + } } pub fn const_props(props: &Properties, index: Index, id: usize) -> Option diff --git a/raphtory/src/arrow/graph_impl/edge_storage_ops.rs b/raphtory/src/arrow/graph_impl/edge_storage_ops.rs index 21b32fbb00..ef41512437 100644 --- a/raphtory/src/arrow/graph_impl/edge_storage_ops.rs +++ b/raphtory/src/arrow/graph_impl/edge_storage_ops.rs @@ -8,8 +8,8 @@ use crate::{ edges::edge_storage_ops::{EdgeStorageOps, TimeIndexRef}, tprop_storage_ops::TPropOps, }, - prelude::TimeIndexEntry, }; +use raphtory_api::core::storage::timeindex::TimeIndexEntry; use raphtory_arrow::{edge::Edge, tprops::ArrowTProp}; use rayon::prelude::*; use std::{iter, ops::Range}; diff --git a/raphtory/src/arrow/graph_impl/graph_ops.rs b/raphtory/src/arrow/graph_impl/graph_ops.rs deleted file mode 100644 index 8b13789179..0000000000 --- a/raphtory/src/arrow/graph_impl/graph_ops.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/raphtory/src/arrow/graph_impl/interop.rs b/raphtory/src/arrow/graph_impl/interop.rs index 22b1bffb44..416d31a2f3 100644 --- a/raphtory/src/arrow/graph_impl/interop.rs +++ b/raphtory/src/arrow/graph_impl/interop.rs @@ -16,7 +16,10 @@ use crate::{ }; use itertools::Itertools; use polars_arrow::array::Array; -use raphtory_api::core::entities::{EID, VID}; +use raphtory_api::core::{ + entities::{EID, VID}, + storage::timeindex::TimeIndexEntry, +}; use raphtory_arrow::interop::GraphLike; impl GraphLike for Graph { diff --git a/raphtory/src/arrow/graph_impl/mod.rs b/raphtory/src/arrow/graph_impl/mod.rs index e75ebd8732..2dce974570 100644 --- a/raphtory/src/arrow/graph_impl/mod.rs +++ b/raphtory/src/arrow/graph_impl/mod.rs @@ -28,8 +28,6 @@ pub mod const_properties_ops; pub mod core_ops; pub mod edge_filter_ops; mod edge_storage_ops; -pub mod graph_ops; - mod interop; pub mod layer_ops; mod list_ops; diff --git a/raphtory/src/arrow/graph_impl/time_index_into_ops.rs b/raphtory/src/arrow/graph_impl/time_index_into_ops.rs index bfc269bb72..e658c6e0c9 100644 --- a/raphtory/src/arrow/graph_impl/time_index_into_ops.rs +++ b/raphtory/src/arrow/graph_impl/time_index_into_ops.rs @@ -1,8 +1,8 @@ use crate::{ core::storage::timeindex::{TimeIndexIntoOps, TimeIndexOps}, db::api::view::IntoDynBoxed, - prelude::TimeIndexEntry, }; +use raphtory_api::core::storage::timeindex::TimeIndexEntry; use raphtory_arrow::{ prelude::{ArrayOps, BaseArrayOps}, timestamps::TimeStamps, diff --git a/raphtory/src/arrow/graph_impl/time_semantics.rs b/raphtory/src/arrow/graph_impl/time_semantics.rs index ac231a277a..8800090f12 100644 --- a/raphtory/src/arrow/graph_impl/time_semantics.rs +++ b/raphtory/src/arrow/graph_impl/time_semantics.rs @@ -16,6 +16,7 @@ use crate::{ prelude::*, }; use itertools::Itertools; +use raphtory_api::core::storage::timeindex::TimeIndexEntry; use rayon::prelude::*; use std::{iter, ops::Range}; diff --git a/raphtory/src/arrow/graph_impl/tprops.rs b/raphtory/src/arrow/graph_impl/tprops.rs index da77989cc9..e779d0d223 100644 --- a/raphtory/src/arrow/graph_impl/tprops.rs +++ b/raphtory/src/arrow/graph_impl/tprops.rs @@ -5,8 +5,9 @@ use crate::{ }, core::storage::timeindex::TimeIndexIntoOps, db::api::{storage::tprop_storage_ops::TPropOps, view::IntoDynBoxed}, - prelude::{Prop, TimeIndexEntry}, + prelude::Prop, }; +use raphtory_api::core::storage::timeindex::TimeIndexEntry; use raphtory_arrow::{ chunked_array::{col::ChunkedPrimitiveCol, utf8_col::StringCol}, edge::Edge, diff --git a/raphtory/src/arrow/storage_interface/edge.rs b/raphtory/src/arrow/storage_interface/edge.rs index 47a391a189..b4f11038f6 100644 --- a/raphtory/src/arrow/storage_interface/edge.rs +++ b/raphtory/src/arrow/storage_interface/edge.rs @@ -4,8 +4,8 @@ use crate::{ storage::timeindex::TimeIndexOps, }, db::api::storage::edges::edge_storage_ops::EdgeStorageIntoOps, - prelude::TimeIndexEntry, }; +use raphtory_api::core::storage::timeindex::TimeIndexEntry; use raphtory_arrow::{edge::Edge, edges::Edges, graph::TemporalGraph, timestamps::TimeStamps}; use std::ops::Range; diff --git a/raphtory/src/core/entities/graph/tgraph.rs b/raphtory/src/core/entities/graph/tgraph.rs index 8ef101a2b4..dcb51de815 100644 --- a/raphtory/src/core/entities/graph/tgraph.rs +++ b/raphtory/src/core/entities/graph/tgraph.rs @@ -238,6 +238,11 @@ impl TemporalGraph { self.node_meta.get_node_type_name_by_id(node.node_type) } + pub(crate) fn node_type_id(&self, v: VID) -> usize { + let node = self.storage.get_node(v); + node.node_type + } + pub(crate) fn get_all_node_types(&self) -> Vec { self.node_meta.get_all_node_types() } diff --git a/raphtory/src/core/mod.rs b/raphtory/src/core/mod.rs index 221c74eff7..30e8bb8a3c 100644 --- a/raphtory/src/core/mod.rs +++ b/raphtory/src/core/mod.rs @@ -50,7 +50,7 @@ extern crate core; pub mod entities; pub mod state; -pub(crate) mod storage; +pub mod storage; pub mod utils; // this is here because Arc annoyingly doesn't implement all the expected comparisons diff --git a/raphtory/src/core/storage/timeindex.rs b/raphtory/src/core/storage/timeindex.rs index 1347fbfc1c..d2473b49ac 100644 --- a/raphtory/src/core/storage/timeindex.rs +++ b/raphtory/src/core/storage/timeindex.rs @@ -6,6 +6,7 @@ use crate::{ use chrono::{DateTime, NaiveDateTime, Utc}; use itertools::Itertools; use num_traits::Saturating; +use raphtory_api::core::entities::VID; use rayon::prelude::*; use serde::{Deserialize, Serialize}; use std::{ diff --git a/raphtory/src/db/api/storage/edges/edge_owned_entry.rs b/raphtory/src/db/api/storage/edges/edge_owned_entry.rs index f21741f420..fa482d8704 100644 --- a/raphtory/src/db/api/storage/edges/edge_owned_entry.rs +++ b/raphtory/src/db/api/storage/edges/edge_owned_entry.rs @@ -17,8 +17,8 @@ use crate::{ }, tprop_storage_ops::TPropOps, }, - prelude::TimeIndexEntry, }; +use raphtory_api::core::storage::timeindex::TimeIndexEntry; use rayon::iter::ParallelIterator; use std::ops::Range; diff --git a/raphtory/src/db/api/storage/edges/edge_storage_ops.rs b/raphtory/src/db/api/storage/edges/edge_storage_ops.rs index 599ef51cde..76faaaf27b 100644 --- a/raphtory/src/db/api/storage/edges/edge_storage_ops.rs +++ b/raphtory/src/db/api/storage/edges/edge_storage_ops.rs @@ -7,7 +7,6 @@ use crate::{ storage::timeindex::{TimeIndex, TimeIndexIntoOps, TimeIndexOps, TimeIndexWindow}, }, db::api::view::IntoDynBoxed, - prelude::TimeIndexEntry, }; #[cfg(feature = "arrow")] @@ -17,6 +16,7 @@ use crate::{ core::entities::properties::tprop::TProp, db::api::storage::{tprop_storage_ops::TPropOps, variants::layer_variants::LayerVariants}, }; +use raphtory_api::core::storage::timeindex::TimeIndexEntry; use rayon::prelude::*; use std::ops::Range; diff --git a/raphtory/src/db/api/storage/tprop_storage_ops.rs b/raphtory/src/db/api/storage/tprop_storage_ops.rs index 413a2b5fd4..9563e72414 100644 --- a/raphtory/src/db/api/storage/tprop_storage_ops.rs +++ b/raphtory/src/db/api/storage/tprop_storage_ops.rs @@ -1,9 +1,7 @@ +use crate::core::{entities::properties::tprop::TProp, storage::timeindex::AsTime, Prop}; #[cfg(feature = "arrow")] use crate::db::api::storage::variants::storage_variants::StorageVariants; -use crate::{ - core::{entities::properties::tprop::TProp, storage::timeindex::AsTime, Prop}, - prelude::TimeIndexEntry, -}; +use raphtory_api::core::storage::timeindex::TimeIndexEntry; #[cfg(feature = "arrow")] use raphtory_arrow::tprops::ArrowTProp; use std::ops::Range; diff --git a/raphtory/src/db/api/storage/variants/storage_variants.rs b/raphtory/src/db/api/storage/variants/storage_variants.rs index 147660a189..ef761b7412 100644 --- a/raphtory/src/db/api/storage/variants/storage_variants.rs +++ b/raphtory/src/db/api/storage/variants/storage_variants.rs @@ -1,4 +1,5 @@ -use crate::{core::Prop, db::api::storage::tprop_storage_ops::TPropOps, prelude::TimeIndexEntry}; +use crate::{core::Prop, db::api::storage::tprop_storage_ops::TPropOps}; +use raphtory_api::core::storage::timeindex::TimeIndexEntry; use rayon::iter::{ plumbing::{Consumer, ProducerCallback, UnindexedConsumer}, IndexedParallelIterator, ParallelIterator, diff --git a/raphtory/src/db/api/view/graph.rs b/raphtory/src/db/api/view/graph.rs index b6b13b60da..84cf1d2a6d 100644 --- a/raphtory/src/db/api/view/graph.rs +++ b/raphtory/src/db/api/view/graph.rs @@ -132,6 +132,7 @@ impl<'graph, G: BoxableGraphView + Sized + Clone + 'graph> GraphViewOps<'graph> let graph = self.clone(); Nodes::new(graph) } + fn materialize(&self) -> Result { let g = InternalGraph::default(); diff --git a/raphtory/src/db/api/view/internal/core_ops.rs b/raphtory/src/db/api/view/internal/core_ops.rs index 97bc4d700d..10bad849e4 100644 --- a/raphtory/src/db/api/view/internal/core_ops.rs +++ b/raphtory/src/db/api/view/internal/core_ops.rs @@ -80,6 +80,8 @@ pub trait CoreGraphOps { /// Returns the type of node fn node_type(&self, v: VID) -> Option; + fn node_type_id(&self, v: VID) -> usize; + /// Gets the internal reference for an external node reference and keeps internal references unchanged. fn internalise_node(&self, v: NodeRef) -> Option; @@ -362,6 +364,9 @@ impl CoreGraphOps for G { fn unfiltered_num_edges(&self) -> usize { self.graph().unfiltered_num_edges() } + fn node_type_id(&self, v: VID) -> usize { + self.graph().node_type_id(v) + } } pub enum NodeAdditions<'a> { diff --git a/raphtory/src/db/api/view/internal/materialize.rs b/raphtory/src/db/api/view/internal/materialize.rs index 7e61190e52..a948a77ac9 100644 --- a/raphtory/src/db/api/view/internal/materialize.rs +++ b/raphtory/src/db/api/view/internal/materialize.rs @@ -7,7 +7,7 @@ use crate::{ properties::{graph_meta::GraphMeta, props::Meta, tprop::TProp}, LayerIds, EID, ELID, VID, }, - storage::locked_view::LockedView, + storage::{locked_view::LockedView, timeindex::TimeIndexEntry}, utils::errors::GraphError, ArcStr, PropType, }, diff --git a/raphtory/src/db/api/view/mod.rs b/raphtory/src/db/api/view/mod.rs index a3cb6cddf1..37c908afa9 100644 --- a/raphtory/src/db/api/view/mod.rs +++ b/raphtory/src/db/api/view/mod.rs @@ -5,7 +5,6 @@ mod graph; pub(crate) mod internal; mod layer; pub(crate) mod node; -mod node_types_filter; mod reset_filter; pub(crate) mod time; @@ -19,7 +18,6 @@ pub use internal::{ pub use layer::*; pub(crate) use node::BaseNodeViewOps; pub use node::NodeViewOps; -pub use node_types_filter::*; pub use reset_filter::*; pub use time::*; diff --git a/raphtory/src/db/api/view/node_types_filter.rs b/raphtory/src/db/api/view/node_types_filter.rs deleted file mode 100644 index 3c8901672c..0000000000 --- a/raphtory/src/db/api/view/node_types_filter.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::db::{ - api::view::internal::{CoreGraphOps, OneHopFilter}, - graph::views::node_type_filtered_subgraph::TypeFilteredSubgraph, -}; -use std::borrow::Borrow; - -pub trait NodeTypesFilter<'graph>: OneHopFilter<'graph> { - fn type_filter, V: Borrow>( - &self, - node_types: I, - ) -> Self::Filtered> { - let meta = self.current_filter().node_meta().node_type_meta(); - let r = node_types - .into_iter() - .flat_map(|nt| meta.get_id(nt.borrow())) - .collect(); - self.one_hop_filtered(TypeFilteredSubgraph::new(self.current_filter().clone(), r)) - } -} diff --git a/raphtory/src/db/graph/graph.rs b/raphtory/src/db/graph/graph.rs index 89156e5f09..3469852431 100644 --- a/raphtory/src/db/graph/graph.rs +++ b/raphtory/src/db/graph/graph.rs @@ -209,7 +209,9 @@ mod db_tests { time::internal::InternalTimeOps, EdgeViewOps, Layer, LayerOps, NodeViewOps, StaticGraphViewOps, TimeOps, }, - graph::{edge::EdgeView, edges::Edges, node::NodeView, path::PathFromNode}, + graph::{ + edge::EdgeView, edges::Edges, node::NodeView, nodes::Nodes, path::PathFromNode, + }, }, graphgen::random_attachment::random_attachment, prelude::{AdditionOps, PropertyAdditionOps}, @@ -2487,6 +2489,464 @@ mod db_tests { // test(&arrow_graph); } + #[test] + fn test_type_filter() { + let g = PersistentGraph::new(); + + g.add_node(1, 1, NO_PROPS, Some("wallet")).unwrap(); + g.add_node(1, 2, NO_PROPS, Some("timer")).unwrap(); + g.add_node(1, 3, NO_PROPS, Some("timer")).unwrap(); + g.add_node(1, 4, NO_PROPS, Some("wallet")).unwrap(); + + assert_eq!( + g.nodes().type_filter(&vec!["wallet"]).name().collect_vec(), + vec!["1", "4"] + ); + + let g = Graph::new(); + g.add_node(1, 1, NO_PROPS, Some("a")).unwrap(); + g.add_node(1, 2, NO_PROPS, Some("b")).unwrap(); + g.add_node(1, 3, NO_PROPS, Some("b")).unwrap(); + g.add_node(1, 4, NO_PROPS, Some("a")).unwrap(); + g.add_node(1, 5, NO_PROPS, Some("c")).unwrap(); + g.add_node(1, 6, NO_PROPS, Some("e")).unwrap(); + g.add_edge(2, 1, 2, NO_PROPS, Some("a")).unwrap(); + g.add_edge(2, 3, 2, NO_PROPS, Some("a")).unwrap(); + g.add_edge(2, 2, 4, NO_PROPS, Some("a")).unwrap(); + g.add_edge(2, 4, 5, NO_PROPS, Some("a")).unwrap(); + g.add_edge(2, 4, 5, NO_PROPS, Some("a")).unwrap(); + g.add_edge(2, 5, 6, NO_PROPS, Some("a")).unwrap(); + g.add_edge(2, 3, 6, NO_PROPS, Some("a")).unwrap(); + + let w = g.window(1, 4); + assert_eq!( + w.nodes() + .type_filter(&vec!["a"]) + .iter() + .map(|v| v.degree()) + .collect::>(), + vec![1, 2] + ); + assert_eq!( + w.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["c", "b"]) + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec!["2"], vec!["2", "5"]] + ); + + let l = g.layers(["a"]).unwrap(); + assert_eq!( + l.nodes() + .type_filter(&vec!["a"]) + .iter() + .map(|v| v.degree()) + .collect::>(), + vec![1, 2] + ); + assert_eq!( + l.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["c", "b"]) + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec!["2"], vec!["2", "5"]] + ); + + let sg = g.subgraph([1, 2, 3, 4, 5, 6]); + assert_eq!( + sg.nodes() + .type_filter(&vec!["a"]) + .iter() + .map(|v| v.degree()) + .collect::>(), + vec![1, 2] + ); + assert_eq!( + sg.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["c", "b"]) + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec!["2"], vec!["2", "5"]] + ); + + assert_eq!( + g.nodes().iter().map(|v| v.degree()).collect::>(), + vec![1, 3, 2, 2, 2, 2] + ); + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .iter() + .map(|v| v.degree()) + .collect::>(), + vec![1, 2] + ); + assert_eq!( + g.nodes() + .type_filter(&vec!["d"]) + .iter() + .map(|v| v.degree()) + .collect::>(), + Vec::::new() + ); + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .par_iter() + .map(|v| v.degree()) + .collect::>(), + vec![1, 2] + ); + assert_eq!( + g.nodes() + .type_filter(&vec!["d"]) + .par_iter() + .map(|v| v.degree()) + .collect::>(), + Vec::::new() + ); + + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .collect() + .into_iter() + .map(|n| n.name()) + .collect_vec(), + vec!["1", "4"] + ); + assert_eq!( + g.nodes() + .type_filter(&Vec::<&str>::new()) + .collect() + .into_iter() + .map(|n| n.name()) + .collect_vec(), + Vec::<&str>::new() + ); + + assert_eq!(g.nodes().len(), 6); + assert_eq!(g.nodes().type_filter(&vec!["b"]).len(), 2); + assert_eq!(g.nodes().type_filter(&vec!["d"]).len(), 0); + + assert_eq!(g.nodes().is_empty(), false); + assert_eq!(g.nodes().type_filter(&vec!["d"]).is_empty(), true); + + assert_eq!( + g.nodes().type_filter(&vec!["a"]).name().collect_vec(), + vec!["1", "4"] + ); + assert_eq!( + g.nodes().type_filter(&vec!["a", "c"]).name().collect_vec(), + vec!["1", "4", "5"] + ); + + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec!["2"], vec!["2", "5"]] + ); + assert_eq!( + g.nodes() + .type_filter(&vec!["a", "c"]) + .neighbours() + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec!["2"], vec!["2", "5"], vec!["4", "6"]] + ); + assert_eq!( + g.nodes() + .type_filter(&vec!["d"]) + .neighbours() + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + Vec::>::new() + ); + + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["c"]) + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec![], vec!["5"]] + ); + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&Vec::<&str>::new()) + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec![], Vec::<&str>::new()] + ); + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["c", "b"]) + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec!["2"], vec!["2", "5"]] + ); + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["d"]) + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec![], Vec::<&str>::new()] + ); + + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .neighbours() + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec!["1", "3", "4"], vec!["1", "3", "4", "4", "6"]] + ); + + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["c"]) + .neighbours() + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![vec![], vec!["4", "6"]] + ); + + assert_eq!( + g.nodes() + .neighbours() + .neighbours() + .name() + .map(|n| { n.collect::>() }) + .collect_vec(), + vec![ + vec!["1", "3", "4"], + vec!["2", "2", "6", "2", "5"], + vec!["1", "3", "4", "3", "5"], + vec!["1", "3", "4", "4", "6"], + vec!["2", "5", "3", "5"], + vec!["2", "6", "4", "6"], + ] + ); + + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["d"]) + .total_count(), + 0 + ); + + assert!(g + .nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["d"]) + .is_all_empty()); + + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["d"]) + .iter() + .map(|n| { n.name().collect::>() }) + .collect_vec(), + vec![vec![], Vec::<&str>::new()] + ); + + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["b"]) + .collect() + .into_iter() + .flatten() + .map(|n| n.name()) + .collect_vec(), + vec!["2", "2"] + ); + + assert_eq!( + g.nodes() + .type_filter(&vec!["a"]) + .neighbours() + .type_filter(&vec!["d"]) + .collect() + .into_iter() + .flatten() + .map(|n| n.name()) + .collect_vec(), + Vec::<&str>::new() + ); + + assert_eq!( + g.node("2").unwrap().neighbours().name().collect_vec(), + vec!["1", "3", "4"] + ); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["b"]) + .name() + .collect_vec(), + vec!["3"] + ); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["d"]) + .name() + .collect_vec(), + Vec::<&str>::new() + ); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["c", "a"]) + .name() + .collect_vec(), + vec!["1", "4"] + ); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["c"]) + .neighbours() + .name() + .collect_vec(), + Vec::<&str>::new() + ); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .neighbours() + .name() + .collect_vec(), + vec!["2", "2", "6", "2", "5"], + ); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["d"]) + .len(), + 0 + ); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["a"]) + .neighbours() + .len(), + 3 + ); + + assert!(g + .node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["d"]) + .is_empty()); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["a"]) + .neighbours() + .is_empty(), + false + ); + + assert!(g + .node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["d"]) + .neighbours() + .is_empty()); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["d"]) + .iter() + .collect_vec(), + Vec::>::new() + ); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["b"]) + .collect() + .into_iter() + .map(|n| n.name()) + .collect_vec(), + vec!["3"] + ); + + assert_eq!( + g.node("2") + .unwrap() + .neighbours() + .type_filter(&vec!["d"]) + .collect() + .into_iter() + .map(|n| n.name()) + .collect_vec(), + Vec::<&str>::new() + ); + } + #[test] fn test_persistent_graph() { let g = Graph::new(); diff --git a/raphtory/src/db/graph/mod.rs b/raphtory/src/db/graph/mod.rs index c067f49b6e..51285343e5 100644 --- a/raphtory/src/db/graph/mod.rs +++ b/raphtory/src/db/graph/mod.rs @@ -1,3 +1,6 @@ +use crate::core::entities::properties::props::DictMapper; +use std::sync::Arc; + pub mod edge; pub mod edges; pub mod graph; @@ -5,3 +8,19 @@ pub mod node; pub mod nodes; pub mod path; pub mod views; + +pub(crate) fn create_node_type_filter( + dict_mapper: &DictMapper, + node_types: &[impl AsRef], +) -> Arc<[bool]> { + let len = dict_mapper.len(); + let mut bool_arr = vec![false; len]; + + for nt in node_types { + if let Some(id) = dict_mapper.get_id(nt.as_ref()) { + bool_arr[id] = true; + } + } + + bool_arr.into() +} diff --git a/raphtory/src/db/graph/nodes.rs b/raphtory/src/db/graph/nodes.rs index 811e72cbe6..26dbe58437 100644 --- a/raphtory/src/db/graph/nodes.rs +++ b/raphtory/src/db/graph/nodes.rs @@ -7,13 +7,14 @@ use crate::{ view::{ internal::{OneHopFilter, Static}, BaseNodeViewOps, BoxedLIter, DynamicGraph, IntoDynBoxed, IntoDynamic, - NodeTypesFilter, }, }, graph::{edges::NestedEdges, node::NodeView, path::PathFromGraph}, }, prelude::*, }; + +use crate::db::graph::create_node_type_filter; use rayon::iter::ParallelIterator; use std::{marker::PhantomData, sync::Arc}; @@ -21,36 +22,52 @@ use std::{marker::PhantomData, sync::Arc}; pub struct Nodes<'graph, G, GH = G> { pub(crate) base_graph: G, pub(crate) graph: GH, + pub(crate) node_types_filter: Option>, _marker: PhantomData<&'graph ()>, } -impl< - 'graph, - G: GraphViewOps<'graph> + IntoDynamic, - GH: GraphViewOps<'graph> + IntoDynamic + Static, - > From> for Nodes<'graph, DynamicGraph, DynamicGraph> +impl<'graph, G, GH> From> for Nodes<'graph, DynamicGraph, DynamicGraph> +where + G: GraphViewOps<'graph> + IntoDynamic, + GH: GraphViewOps<'graph> + IntoDynamic + Static, { fn from(value: Nodes<'graph, G, GH>) -> Self { - Nodes::new_filtered(value.base_graph.into_dynamic(), value.graph.into_dynamic()) + let base_graph = value.base_graph.into_dynamic(); + let graph = value.graph.into_dynamic(); + Nodes { + base_graph, + graph, + node_types_filter: value.node_types_filter, + _marker: PhantomData, + } } } -impl<'graph, G: GraphViewOps<'graph>> Nodes<'graph, G, G> { - pub fn new(graph: G) -> Nodes<'graph, G, G> { +impl<'graph, G> Nodes<'graph, G, G> +where + G: GraphViewOps<'graph> + Clone, +{ + pub fn new(graph: G) -> Self { let base_graph = graph.clone(); Self { base_graph, graph, + node_types_filter: None, _marker: PhantomData, } } } -impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> Nodes<'graph, G, GH> { - pub fn new_filtered(base_graph: G, graph: GH) -> Self { +impl<'graph, G, GH> Nodes<'graph, G, GH> +where + G: GraphViewOps<'graph> + 'graph, + GH: GraphViewOps<'graph> + 'graph, +{ + pub fn new_filtered(base_graph: G, graph: GH, node_types_filter: Option>) -> Self { Self { base_graph, graph, + node_types_filter, _marker: PhantomData, } } @@ -58,7 +75,14 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> Nodes<'graph, G, #[inline] pub(crate) fn iter_refs(&self) -> impl Iterator + 'graph { let g = self.graph.core_graph(); - g.into_nodes_iter(self.graph.clone()) + let base_graph = self.base_graph.clone(); + let node_types_filter = self.node_types_filter.clone(); + g.into_nodes_iter(self.graph.clone()).filter(move |v| { + let node_type_id = base_graph.node_type_id(*v); + node_types_filter + .as_ref() + .map_or(true, |filter| filter[node_type_id]) + }) } pub fn iter(&self) -> BoxedLIter<'graph, NodeView> { @@ -71,18 +95,26 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> Nodes<'graph, G, pub fn par_iter(&self) -> impl ParallelIterator> + '_ { let cg = self.graph.core_graph(); + let node_types_filter = self.node_types_filter.clone(); + let base_graph = self.base_graph.clone(); cg.into_nodes_par(&self.graph) + .filter(move |v| { + let node_type_id = base_graph.node_type_id(*v); + node_types_filter + .as_ref() + .map_or(true, |filter| filter[node_type_id]) + }) .map(|v| NodeView::new_one_hop_filtered(&self.base_graph, &self.graph, v)) } /// Returns the number of nodes in the graph. pub fn len(&self) -> usize { - self.graph.count_nodes() + self.iter().count() } /// Returns true if the graph contains no nodes. pub fn is_empty(&self) -> bool { - self.graph.is_empty() + self.iter().next().is_none() } pub fn get(&self, node: V) -> Option> { @@ -94,6 +126,19 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> Nodes<'graph, G, )) } + pub fn type_filter(&self, node_types: &[impl AsRef]) -> Nodes<'graph, G, GH> { + let node_types_filter = Some(create_node_type_filter( + self.graph.node_meta().node_type_meta(), + node_types, + )); + Nodes { + base_graph: self.base_graph.clone(), + graph: self.graph.clone(), + node_types_filter, + _marker: PhantomData, + } + } + pub fn collect(&self) -> Vec> { self.iter().collect() } @@ -107,8 +152,10 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> Nodes<'graph, G, } } -impl<'graph, G: GraphViewOps<'graph> + 'graph, GH: GraphViewOps<'graph> + 'graph> - BaseNodeViewOps<'graph> for Nodes<'graph, G, GH> +impl<'graph, G, GH> BaseNodeViewOps<'graph> for Nodes<'graph, G, GH> +where + G: GraphViewOps<'graph> + 'graph, + GH: GraphViewOps<'graph> + 'graph, { type BaseGraph = G; type Graph = GH; @@ -171,8 +218,10 @@ impl<'graph, G: GraphViewOps<'graph> + 'graph, GH: GraphViewOps<'graph> + 'graph } } -impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> OneHopFilter<'graph> - for Nodes<'graph, G, GH> +impl<'graph, G, GH> OneHopFilter<'graph> for Nodes<'graph, G, GH> +where + G: GraphViewOps<'graph> + 'graph, + GH: GraphViewOps<'graph> + 'graph, { type BaseGraph = G; type FilteredGraph = GH; @@ -194,18 +243,16 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> OneHopFilter<'gr Nodes { base_graph, graph: filtered_graph, + node_types_filter: self.node_types_filter.clone(), _marker: PhantomData, } } } -impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> NodeTypesFilter<'graph> - for Nodes<'graph, G, GH> -{ -} - -impl<'graph, G: GraphViewOps<'graph> + 'graph, GH: GraphViewOps<'graph> + 'graph> IntoIterator - for Nodes<'graph, G, GH> +impl<'graph, G, GH> IntoIterator for Nodes<'graph, G, GH> +where + G: GraphViewOps<'graph> + 'graph, + GH: GraphViewOps<'graph> + 'graph, { type Item = NodeView; type IntoIter = BoxedLIter<'graph, Self::Item>; diff --git a/raphtory/src/db/graph/path.rs b/raphtory/src/db/graph/path.rs index 0012a60092..239d67b3a6 100644 --- a/raphtory/src/db/graph/path.rs +++ b/raphtory/src/db/graph/path.rs @@ -9,6 +9,7 @@ use crate::{ }, }, graph::{ + create_node_type_filter, edges::{Edges, NestedEdges}, node::NodeView, views::{ @@ -23,8 +24,8 @@ use std::sync::Arc; #[derive(Clone)] pub struct PathFromGraph<'graph, G, GH> { - pub(crate) graph: GH, pub(crate) base_graph: G, + pub(crate) graph: GH, pub(crate) nodes: Arc BoxedLIter<'graph, VID> + Send + Sync + 'graph>, pub(crate) op: Arc BoxedLIter<'graph, VID> + Send + Sync + 'graph>, } @@ -47,6 +48,21 @@ impl<'graph, G: GraphViewOps<'graph>> PathFromGraph<'graph, G, G> { } impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> PathFromGraph<'graph, G, GH> { + fn new_filtered BoxedLIter<'graph, VID> + Send + Sync + 'graph>( + base_graph: G, + graph: GH, + nodes: Arc BoxedLIter<'graph, VID> + Send + Sync + 'graph>, + op: OP, + ) -> Self { + let op = Arc::new(op); + PathFromGraph { + graph, + base_graph, + nodes, + op, + } + } + fn base_iter(&self) -> BoxedLIter<'graph, VID> { (self.nodes)() } @@ -55,9 +71,9 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> PathFromGraph<'g let graph = self.graph.clone(); let base_graph = self.base_graph.clone(); let op = self.op.clone(); - self.base_iter().map(move |node| { + self.base_iter().map(move |v| { let op = op.clone(); - let node_op = Arc::new(move || op(node)); + let node_op = Arc::new(move || op(v)); PathFromNode::new_one_hop_filtered(base_graph.clone(), graph.clone(), node_op) }) } @@ -67,12 +83,44 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> PathFromGraph<'g self.base_iter().map(move |vid| op(vid)) } + pub fn total_count(&self) -> usize { + self.iter_refs().flatten().count() + } + pub fn len(&self) -> usize { - self.iter().count() + self.iter_refs().count() + } + + pub fn is_all_empty(&self) -> bool { + self.iter_refs().flatten().next().is_none() } pub fn is_empty(&self) -> bool { - self.iter().next().is_none() + self.iter_refs().next().is_none() + } + + pub fn type_filter(&self, node_types: &[impl AsRef]) -> PathFromGraph<'graph, G, GH> { + let node_types_filter = + create_node_type_filter(self.graph.node_meta().node_type_meta(), node_types); + + let base_graph = self.base_graph.clone(); + let old_op = self.op.clone(); + + PathFromGraph::new_filtered( + self.base_graph.clone(), + self.graph.clone(), + self.nodes.clone(), + move |vid| { + let base_graph = base_graph.clone(); + let node_types_filter = node_types_filter.clone(); + old_op(vid) + .filter(move |v| { + let node_type_id = base_graph.node_type_id(*v); + node_types_filter[node_type_id] + }) + .into_dyn_boxed() + }, + ) } pub fn collect(&self) -> Vec>> { @@ -260,6 +308,18 @@ impl<'graph, G: GraphViewOps<'graph>> PathFromNode<'graph, G, G> { } impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> PathFromNode<'graph, G, GH> { + pub(crate) fn new_one_hop_filtered( + base_graph: G, + graph: GH, + op: Arc BoxedLIter<'graph, VID> + Send + Sync + 'graph>, + ) -> Self { + Self { + base_graph, + graph, + op, + } + } + pub fn iter_refs(&self) -> BoxedLIter<'graph, VID> { (self.op)() } @@ -281,21 +341,32 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> PathFromNode<'gr self.iter().next().is_none() } - pub fn collect(&self) -> Vec> { - self.iter().collect() - } + pub fn type_filter(&self, node_types: &[impl AsRef]) -> PathFromNode<'graph, G, GH> { + let node_types_filter = + create_node_type_filter(self.graph.node_meta().node_type_meta(), node_types); - pub(crate) fn new_one_hop_filtered( - base_graph: G, - graph: GH, - op: Arc BoxedLIter<'graph, VID> + Send + Sync + 'graph>, - ) -> Self { - Self { - base_graph, - graph, - op, + let base_graph = self.base_graph.clone(); + let old_op = self.op.clone(); + + PathFromNode { + base_graph: self.base_graph.clone(), + graph: self.graph.clone(), + op: Arc::new(move || { + let base_graph = base_graph.clone(); + let node_types_filter = node_types_filter.clone(); + old_op() + .filter(move |v| { + let node_type_id = base_graph.node_type_id(*v); + node_types_filter[node_type_id] + }) + .into_dyn_boxed() + }), } } + + pub fn collect(&self) -> Vec> { + self.iter().collect() + } } impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> BaseNodeViewOps<'graph> @@ -407,16 +478,6 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> OneHopFilter<'gr } } -impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> NodeTypesFilter<'graph> - for PathFromGraph<'graph, G, GH> -{ -} - -impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> NodeTypesFilter<'graph> - for PathFromNode<'graph, G, GH> -{ -} - #[cfg(test)] mod test { use crate::prelude::*; diff --git a/raphtory/src/db/internal/core_ops.rs b/raphtory/src/db/internal/core_ops.rs index 25277e7c29..bbabc190e8 100644 --- a/raphtory/src/db/internal/core_ops.rs +++ b/raphtory/src/db/internal/core_ops.rs @@ -94,6 +94,10 @@ impl CoreGraphOps for InternalGraph { self.inner().node_type(v) } + #[inline] + fn node_type_id(&self, v: VID) -> usize { + self.inner().node_type_id(v) + } #[inline] fn internalise_node(&self, v: NodeRef) -> Option { self.inner().resolve_node_ref(v) diff --git a/raphtory/src/db/internal/time_semantics.rs b/raphtory/src/db/internal/time_semantics.rs index 9870127fdf..1bbbc5aed0 100644 --- a/raphtory/src/db/internal/time_semantics.rs +++ b/raphtory/src/db/internal/time_semantics.rs @@ -17,9 +17,10 @@ use crate::{ BoxedIter, IntoDynBoxed, }, }, - prelude::{Prop, TimeIndexEntry}, + prelude::Prop, }; use itertools::{kmerge, Itertools}; +use raphtory_api::core::storage::timeindex::TimeIndexEntry; use rayon::prelude::*; use std::ops::Range; diff --git a/raphtory/src/db/task/edge/eval_edges.rs b/raphtory/src/db/task/edge/eval_edges.rs index afbea45040..f2d725303b 100644 --- a/raphtory/src/db/task/edge/eval_edges.rs +++ b/raphtory/src/db/task/edge/eval_edges.rs @@ -33,7 +33,13 @@ impl<'graph, 'a: 'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>, CS: for EvalEdges<'graph, 'a, G, GH, CS, S> { fn clone(&self) -> Self { - todo!() + Self { + ss: self.ss, + edges: self.edges.clone(), + storage: self.storage, + node_state: self.node_state.clone(), + local_state_prev: self.local_state_prev, + } } } diff --git a/raphtory/src/db/task/node/eval_node.rs b/raphtory/src/db/task/node/eval_node.rs index 3c252467d8..8e95b55bde 100644 --- a/raphtory/src/db/task/node/eval_node.rs +++ b/raphtory/src/db/task/node/eval_node.rs @@ -14,12 +14,12 @@ use crate::{ storage::storage_ops::GraphStorage, view::{internal::OneHopFilter, BaseNodeViewOps, BoxedLIter, IntoDynBoxed}, }, - graph::{edges::Edges, node::NodeView, path::PathFromNode}, + graph::{create_node_type_filter, edges::Edges, node::NodeView, path::PathFromNode}, task::{ edge::eval_edges::EvalEdges, eval_graph::EvalGraph, node::eval_node_state::EVState, }, }, - prelude::{GraphViewOps, NodeTypesFilter}, + prelude::GraphViewOps, }; use std::{cell::Ref, sync::Arc}; @@ -276,6 +276,29 @@ impl< self.iter_refs() .map(move |v| EvalNodeView::new_filtered(v, base_graph.clone(), graph.clone(), None)) } + + pub fn type_filter(&self, node_types: &[impl AsRef]) -> Self { + let node_types_filter = + create_node_type_filter(self.graph.node_meta().node_type_meta(), node_types); + + let base_graph = self.base_graph.base_graph.clone(); + let old_op = self.op.clone(); + + EvalPathFromNode { + base_graph: self.base_graph.clone(), + graph: self.graph.clone(), + op: Arc::new(move || { + let base_graph = base_graph.clone(); + let node_types_filter = node_types_filter.clone(); + old_op() + .filter(move |v| { + let node_type_id = base_graph.node_type_id(*v); + node_types_filter[node_type_id] + }) + .into_dyn_boxed() + }), + } + } } impl< @@ -539,17 +562,6 @@ impl< } } -impl< - 'graph, - 'a: 'graph, - G: GraphViewOps<'graph>, - S, - CS: ComputeState + 'a, - GH: GraphViewOps<'graph>, - > NodeTypesFilter<'graph> for EvalPathFromNode<'graph, 'a, G, GH, CS, S> -{ -} - /// Represents an entry in the shuffle table. /// /// The entry contains a reference to a `ShuffleComputeState` and an `AccId` representing the accumulator diff --git a/raphtory/src/lib.rs b/raphtory/src/lib.rs index a4e33cc7b0..1aafebaa75 100644 --- a/raphtory/src/lib.rs +++ b/raphtory/src/lib.rs @@ -107,13 +107,12 @@ pub mod vectors; pub mod prelude { pub const NO_PROPS: [(&str, Prop); 0] = []; pub use crate::{ - core::{storage::timeindex::TimeIndexEntry, IntoProp, Prop, PropUnwrap}, + core::{IntoProp, Prop, PropUnwrap}, db::{ api::{ mutation::{AdditionOps, DeletionOps, ImportOps, PropertyAdditionOps}, view::{ - EdgeViewOps, GraphViewOps, Layer, LayerOps, NodeTypesFilter, NodeViewOps, - ResetFilter, TimeOps, + EdgeViewOps, GraphViewOps, Layer, LayerOps, NodeViewOps, ResetFilter, TimeOps, }, }, graph::graph::Graph, diff --git a/raphtory/src/python/graph/node.rs b/raphtory/src/python/graph/node.rs index bd3fd214c5..28e0b8b476 100644 --- a/raphtory/src/python/graph/node.rs +++ b/raphtory/src/python/graph/node.rs @@ -423,13 +423,6 @@ pub struct PyNodes { pub(crate) nodes: Nodes<'static, DynamicGraph, DynamicGraph>, } -impl_nodetypesfilter!( - PyNodes, - nodes, - Nodes<'static, DynamicGraph, DynamicGraph>, - "Nodes" -); - impl_nodeviewops!( PyNodes, nodes, @@ -451,7 +444,7 @@ impl let graph = value.graph.into_dynamic(); let base_graph = value.base_graph.into_dynamic(); Self { - nodes: Nodes::new_filtered(base_graph, graph), + nodes: Nodes::new_filtered(base_graph, graph, value.node_types_filter), } } } @@ -673,6 +666,10 @@ impl PyNodes { Ok(df_data.to_object(py)) }) } + + pub fn type_filter(&self, node_types: Vec<&str>) -> Nodes<'static, DynamicGraph> { + self.nodes.type_filter(&node_types) + } } impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> Repr for Nodes<'static, G, GH> { @@ -686,13 +683,6 @@ pub struct PyPathFromGraph { path: PathFromGraph<'static, DynamicGraph, DynamicGraph>, } -impl_nodetypesfilter!( - PyPathFromGraph, - path, - PathFromGraph<'static, DynamicGraph, DynamicGraph>, - "PathFromGraph" -); - impl_nodeviewops!( PyPathFromGraph, path, @@ -785,6 +775,13 @@ impl PyPathFromGraph { let path = self.path.clone(); (move || path.out_degree()).into() } + + pub fn type_filter( + &self, + node_types: Vec<&str>, + ) -> PathFromGraph<'static, DynamicGraph, DynamicGraph> { + self.path.type_filter(&node_types) + } } impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> Repr @@ -823,13 +820,6 @@ pub struct PyPathFromNode { path: PathFromNode<'static, DynamicGraph, DynamicGraph>, } -impl_nodetypesfilter!( - PyPathFromNode, - path, - PathFromNode<'static, DynamicGraph, DynamicGraph>, - "PathFromNode" -); - impl_nodeviewops!( PyPathFromNode, path, @@ -918,6 +908,13 @@ impl PyPathFromNode { let path = self.path.clone(); (move || path.degree()).into() } + + pub fn type_filter( + &self, + node_types: Vec<&str>, + ) -> PathFromNode<'static, DynamicGraph, DynamicGraph> { + self.path.type_filter(&node_types) + } } impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> Repr diff --git a/raphtory/src/python/types/macros/trait_impl/mod.rs b/raphtory/src/python/types/macros/trait_impl/mod.rs index 7293215748..abc9d14852 100644 --- a/raphtory/src/python/types/macros/trait_impl/mod.rs +++ b/raphtory/src/python/types/macros/trait_impl/mod.rs @@ -14,6 +14,3 @@ mod nodeviewops; mod repr; #[macro_use] mod iterable_mixin; - -#[macro_use] -mod nodetypesfilter; diff --git a/raphtory/src/python/types/macros/trait_impl/nodetypesfilter.rs b/raphtory/src/python/types/macros/trait_impl/nodetypesfilter.rs deleted file mode 100644 index 07a2accc5c..0000000000 --- a/raphtory/src/python/types/macros/trait_impl/nodetypesfilter.rs +++ /dev/null @@ -1,17 +0,0 @@ -macro_rules! impl_nodetypesfilter { - ($obj:ty, $field:ident, $base_type:ty, $name:literal) => { - - #[pyo3::pymethods] - impl $obj { - - /// Filter nodes by node types - /// - /// Arguments: - /// node_types (list[str]): list of node types - /// - fn type_filter(&self, node_types: Vec<&str>) -> <$base_type as $crate::db::api::view::internal::OneHopFilter<'static>>::Filtered<$crate::db::graph::views::node_type_filtered_subgraph::TypeFilteredSubgraph<<$base_type as $crate::db::api::view::internal::OneHopFilter<'static>>::FilteredGraph>> { - self.$field.type_filter(node_types) - } - } - }; -}