From 842878155186af716ca07918e818d09490f82d4d Mon Sep 17 00:00:00 2001 From: Fabian Murariu Date: Mon, 25 Nov 2024 20:46:52 +0000 Subject: [PATCH] just collect in parallel to vec when masking --- raphtory/src/db/graph/views/masked_graph.rs | 32 ++++++++++++++++----- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/raphtory/src/db/graph/views/masked_graph.rs b/raphtory/src/db/graph/views/masked_graph.rs index 9d072b3a5..20412b636 100644 --- a/raphtory/src/db/graph/views/masked_graph.rs +++ b/raphtory/src/db/graph/views/masked_graph.rs @@ -7,12 +7,14 @@ use crate::{ nodes::{node_ref::NodeStorageRef, node_storage_ops::NodeStorageOps}, }, view::internal::{ - Base, EdgeFilterOps, Immutable, InheritCoreOps, InheritLayerOps, InheritListOps, - InheritMaterialize, InheritTimeSemantics, NodeFilterOps, Static, + Base, CoreGraphOps, EdgeFilterOps, Immutable, InheritCoreOps, InheritLayerOps, + InheritListOps, InheritMaterialize, InheritTimeSemantics, InternalLayerOps, + NodeFilterOps, Static, }, }, prelude::{GraphViewOps, LayerOps}, }; +use rayon::prelude::*; use roaring::RoaringTreemap; use std::{ fmt::{Debug, Formatter}, @@ -57,12 +59,28 @@ impl<'graph, G: GraphViewOps<'graph>> MaskedGraph { for l_name in graph.unique_layers() { let l_id = graph.get_layer_id(&l_name).unwrap(); let layer_g = graph.layers(l_name).unwrap(); - let nodes = layer_g.nodes(); - let edges = layer_g.edges(); - let nodes: RoaringTreemap = nodes.into_iter().map(|id| id.node.as_u64()).collect(); - let edges: RoaringTreemap = - edges.into_iter().map(|id| id.edge.pid().as_u64()).collect(); + let nodes = layer_g + .nodes() + .par_iter() + .map(|node| node.node.as_u64()) + .collect::>(); + + let nodes: RoaringTreemap = nodes.into_iter().collect(); + + let edges = layer_g.core_edges(); + + let edges = edges + .as_ref() + .par_iter(&LayerIds::All) + .filter(|edge| { + graph.filter_edge(edge.as_ref(), layer_g.layer_ids()) + && nodes.contains(edge.src().as_u64()) + && nodes.contains(edge.dst().as_u64()) + }) + .map(|edge| edge.eid().as_u64()) + .collect::>(); + let edges: RoaringTreemap = edges.into_iter().collect(); if layered_masks.len() < l_id + 1 { layered_masks.resize(l_id + 1, (RoaringTreemap::new(), RoaringTreemap::new()));