Skip to content

Commit

Permalink
fix some more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianmurariu committed Jul 22, 2024
1 parent 59ba432 commit a672642
Show file tree
Hide file tree
Showing 26 changed files with 171 additions and 116 deletions.
10 changes: 6 additions & 4 deletions raphtory-api/src/core/entities/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use std::{borrow::Cow, fmt::{Display, Formatter}};
use std::{
borrow::Cow,
fmt::{Display, Formatter},
};

#[cfg(feature = "python")]
use pyo3::prelude::*;
#[cfg(feature = "python")]
use pyo3::exceptions::PyTypeError;
#[cfg(feature = "python")]
use pyo3::prelude::*;
use serde::{Deserialize, Serialize};

use self::edges::edge_ref::EdgeRef;
Expand Down Expand Up @@ -157,7 +160,6 @@ impl<'source> FromPyObject<'source> for GID {
}
}


#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub enum GidRef<'a> {
U64(u64),
Expand Down
41 changes: 23 additions & 18 deletions raphtory/src/algorithms/community_detection/label_propagation.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use rand::{rngs::StdRng, seq::SliceRandom, thread_rng, SeedableRng};
use raphtory_api::core::entities::VID;
use raphtory_api::core::entities::GID;
use std::collections::{BTreeMap, HashMap, HashSet};

use crate::{
Expand All @@ -25,12 +25,13 @@ pub fn label_propagation<G>(
where
G: StaticGraphViewOps,
{
let mut labels: Vec<usize> = Vec::with_capacity(graph.count_nodes());
let nodes = graph.nodes();
labels = nodes.iter().map(|n| n.node.0).collect();
let mut labels: HashMap<NodeView<&G>, GID> = HashMap::new();
let nodes = &graph.nodes();
for node in nodes.iter() {
labels.insert(node, node.id());
}

let nodes = graph.nodes();
let mut shuffled_nodes: Vec<NodeView<_>> = nodes.iter().collect();
let mut shuffled_nodes: Vec<NodeView<&G>> = nodes.iter().collect();
if let Some(seed_value) = seed {
let mut rng = StdRng::from_seed(seed_value);
shuffled_nodes.shuffle(&mut rng);
Expand All @@ -43,37 +44,35 @@ where
changed = false;
for node in &shuffled_nodes {
let neighbors = node.neighbours();
let mut label_count: BTreeMap<usize, f64> = BTreeMap::new();
let mut label_count: BTreeMap<GID, f64> = BTreeMap::new();

for neighbour in neighbors {
let key = labels[neighbour.node.0];
*label_count.entry(key).or_insert(0.0) += 1.0;
*label_count.entry(labels[&neighbour].clone()).or_insert(0.0) += 1.0;
}

if let Some(max_label) = find_max_label(&label_count) {
if max_label != labels[node.node.0] {
labels[node.node.0] = max_label;
if max_label != labels[node] {
labels.insert(node.clone(), max_label);
changed = true;
}
}
}
}

// Group nodes by their labels to form communities
let mut communities: HashMap<usize, HashSet<NodeView<G>>> = HashMap::new();
for (node, label) in labels.into_iter().enumerate() {
let node = graph.node(VID(node)).unwrap();
communities.entry(label).or_default().insert(node);
let mut communities: HashMap<GID, HashSet<NodeView<G>>> = HashMap::new();
for (node, label) in labels {
communities.entry(label).or_default().insert(node.cloned());
}

Ok(communities.values().cloned().collect())
}

fn find_max_label(label_count: &BTreeMap<usize, f64>) -> Option<usize> {
fn find_max_label(label_count: &BTreeMap<GID, f64>) -> Option<GID> {
label_count
.iter()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(label, _)| *label)
.map(|(label, _)| label.clone())
}

#[cfg(test)]
Expand All @@ -100,10 +99,16 @@ mod lpa_tests {
for (ts, src, dst) in edges {
graph.add_edge(ts, src, dst, NO_PROPS, None).unwrap();
}

test_storage!(&graph, |graph| {
let seed = Some([5; 32]);
let result = label_propagation(graph, seed).unwrap();

let ids = result
.iter()
.map(|n_set| n_set.iter().map(|n| n.node).collect::<Vec<_>>())
.collect::<Vec<_>>();
println!("{:?}", ids);

let expected = vec![
HashSet::from([
graph.node("R1").unwrap(),
Expand Down
12 changes: 7 additions & 5 deletions raphtory/src/algorithms/components/connected_components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
algorithms::algorithm_result::AlgorithmResult,
core::{entities::VID, state::compute_state::ComputeStateVec},
db::{
api::view::{BaseNodeViewOps, NodeViewOps, StaticGraphViewOps},
api::view::{NodeViewOps, StaticGraphViewOps},
task::{
context::Context,
node::eval_node::EvalNodeView,
Expand All @@ -13,7 +13,6 @@ use crate::{
};
use raphtory_api::core::entities::GID;
use rayon::prelude::*;
use std::cmp;

#[derive(Clone, Debug, Default)]
struct WccState {
Expand Down Expand Up @@ -280,7 +279,7 @@ mod cc_test {

let vs = vs.into_iter().unique().collect::<Vec<u64>>();

let smallest = vs.iter().min().unwrap();
// let smallest = vs.iter().min().unwrap();

let first = vs[0];

Expand All @@ -299,7 +298,7 @@ mod cc_test {
// now we do connected community_detection over window 0..1
let res = weakly_connected_components(graph, usize::MAX, None).group_by();

let actual = res
let (node, size) = res
.into_iter()
.map(|(cc, group)| (cc, Reverse(group.len())))
.sorted_by(|l, r| l.1.cmp(&r.1))
Expand All @@ -308,7 +307,10 @@ mod cc_test {
.next()
.unwrap();

assert_eq!(actual, (GID::U64(*smallest), edges.len()));
let node = graph.node(node).map(|node| node.node);
assert_eq!(node, Some(VID(0)));

assert_eq!(size, edges.len());
});
}
}
Expand Down
2 changes: 1 addition & 1 deletion raphtory/src/algorithms/components/lcc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl LargestConnectedComponent for Graph {
{
let mut connected_components_map =
weakly_connected_components(self, usize::MAX, None).group_by();
let mut lcc_key:GID = GID::U64(0);
let mut lcc_key: GID = GID::U64(0);
let mut key_length = 0;
let mut is_tie = false;

Expand Down
10 changes: 7 additions & 3 deletions raphtory/src/algorithms/components/out_components.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use raphtory_api::core::entities::GID;
use rayon::prelude::*;
use crate::{
algorithms::algorithm_result::AlgorithmResult,
core::{entities::VID, state::compute_state::ComputeStateVec},
Expand All @@ -13,6 +11,8 @@ use crate::{
},
},
};
use raphtory_api::core::entities::GID;
use rayon::prelude::*;
use std::collections::HashSet;

#[derive(Clone, Debug, Default)]
Expand Down Expand Up @@ -77,7 +77,11 @@ where
.par_iter()
.map(|node| {
let VID(id) = node.node;
let comps = local[id].out_components.iter().map(|vid| graph.node_id(*vid)).collect();
let comps = local[id]
.out_components
.iter()
.map(|vid| graph.node_id(*vid))
.collect();
(id, comps)
})
.collect()
Expand Down
13 changes: 8 additions & 5 deletions raphtory/src/algorithms/pathing/temporal_reachability.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use crate::{
algorithms::algorithm_result::AlgorithmResult,
core::{entities::nodes::node_ref::AsNodeRef, state::{
accumulator_id::accumulators::{hash_set, min, or},
compute_state::ComputeStateVec,
}},
core::{
entities::nodes::node_ref::AsNodeRef,
state::{
accumulator_id::accumulators::{hash_set, min, or},
compute_state::ComputeStateVec,
},
},
db::{
api::view::StaticGraphViewOps,
task::{
Expand All @@ -17,7 +20,7 @@ use crate::{
};
use itertools::Itertools;
use num_traits::Zero;
use raphtory_api::core::{entities::VID, input::input_node::InputNode};
use raphtory_api::core::entities::VID;
use std::{collections::HashMap, ops::Add};

#[derive(Eq, Hash, PartialEq, Clone, Debug, Default)]
Expand Down
6 changes: 5 additions & 1 deletion raphtory/src/core/entities/graph/tgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,11 @@ impl TemporalGraph {
Some(*v_id)
}
NodeRef::ExternalStr(string) => {
let v_id = self.logical_to_physical.get(&GID::Str(string.to_owned()))?;
let v_id = self
.logical_to_physical
.get(&GID::Str(string.to_owned()))
.or_else(|| self.logical_to_physical.get(&GID::U64(string.id())))
.or_else(|| self.logical_to_physical.get(&GID::I64(string.id() as i64)))?;
Some(*v_id)
}
}
Expand Down
4 changes: 1 addition & 3 deletions raphtory/src/db/api/mutation/addition_ops.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::time_from_input;
use crate::{
core::{
entities::{edges::edge_ref::EdgeRef, nodes::node_ref::AsNodeRef},
Expand All @@ -12,9 +13,6 @@ use crate::{
graph::{edge::EdgeView, node::NodeView},
},
};
use raphtory_api::core::input::input_node::InputNode;

use super::time_from_input;

pub trait AdditionOps: StaticGraphViewOps {
// TODO: Probably add vector reference here like add
Expand Down
4 changes: 1 addition & 3 deletions raphtory/src/db/api/mutation/deletion_ops.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::time_from_input;
use crate::{
core::{
entities::nodes::node_ref::AsNodeRef,
Expand All @@ -8,9 +9,6 @@ use crate::{
TryIntoInputTime,
},
};
use raphtory_api::core::input::input_node::InputNode;

use super::time_from_input;

pub trait DeletionOps: InternalDeletionOps + InternalAdditionOps + Sized {
fn delete_edge<V: AsNodeRef, T: TryIntoInputTime>(
Expand Down
1 change: 0 additions & 1 deletion raphtory/src/db/api/mutation/import_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::{
mutation::internal::{
InternalAdditionOps, InternalDeletionOps, InternalPropertyAdditionOps,
},
storage::nodes::node_storage_ops::NodeStorageOps,
view::{internal::InternalMaterialize, IntoDynamic, StaticGraphViewOps},
},
graph::{edge::EdgeView, node::NodeView},
Expand Down
20 changes: 9 additions & 11 deletions raphtory/src/db/api/view/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ impl<'graph, G: BoxableGraphView + Sized + Clone + 'graph> GraphViewOps<'graph>
for ee in ee.explode() {
g.add_edge(
ee.time().expect("exploded edge"),
ee.src().name(),
ee.dst().name(),
ee.src().id(),
ee.dst().id(),
ee.properties().temporal().collect_properties(),
layer_name,
)?;
Expand All @@ -186,17 +186,17 @@ impl<'graph, G: BoxableGraphView + Sized + Clone + 'graph> GraphViewOps<'graph>
let v_type_string = v.node_type(); //stop it being dropped
let v_type_str = v_type_string.as_str();
for h in v.history() {
g.add_node(h, v.name(), NO_PROPS, v_type_str)?;
g.add_node(h, v.id(), NO_PROPS, v_type_str)?;
}
for (name, prop_view) in v.properties().temporal().iter() {
for (t, prop) in prop_view.iter() {
g.add_node(t, v.name(), [(name.clone(), prop)], v_type_str)?;
g.add_node(t, v.id(), [(name.clone(), prop)], v_type_str)?;
}
}

let node = match g.node(v.id()) {
Some(node) => node,
None => g.add_node(earliest, v.name(), NO_PROPS, v_type_str)?,
None => g.add_node(earliest, v.id(), NO_PROPS, v_type_str)?,
};

node.add_constant_properties(v.properties().constant())?;
Expand Down Expand Up @@ -502,12 +502,10 @@ mod test_materialize {
g.add_edge(0, 1, 2, [("layer2", "2")], Some("2")).unwrap();

let gm = g.materialize().unwrap();
assert!(gm
.nodes()
.name()
.values()
.collect::<Vec<String>>()
.eq(&vec!["1", "2"]));
assert_eq!(
gm.nodes().name().values().collect::<Vec<String>>(),
vec!["1", "2"]
);

assert!(!g
.layers("2")
Expand Down
13 changes: 11 additions & 2 deletions raphtory/src/db/api/view/serialise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ impl<'graph, G: GraphViewOps<'graph>> StableEncoder for G {
},
};
serialise::Node {
gid: proto_gid,
gid: Some(proto_gid),
vid: vid.0 as u64,
name,
}
Expand Down Expand Up @@ -408,7 +408,8 @@ impl<

// align the nodes
for node in g.nodes {
let l_vid = graph.resolve_node(node.gid)?;
let gid = from_proto_gid(node.gid.and_then(|gid| gid.gid).expect("Missing GID"));
let l_vid = graph.resolve_node(gid)?;
assert_eq!(l_vid, VID(node.vid as usize));
}

Expand Down Expand Up @@ -575,6 +576,14 @@ impl<
}
}

fn from_proto_gid(gid: gid::Gid) -> GID {
match gid {
gid::Gid::GidU64(n) => GID::U64(n),
gid::Gid::GidI64(n) => GID::I64(n),
gid::Gid::GidStr(s) => GID::Str(s),
}
}

fn as_prop(prop_pair: &PropPair) -> (usize, Prop) {
let PropPair { key, value } = prop_pair;
let value = value.as_ref().expect("Missing prop value");
Expand Down
10 changes: 8 additions & 2 deletions raphtory/src/db/graph/views/deletion_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1277,9 +1277,15 @@ mod test_deletions {
pg.add_edge(0, 0, 1, [("added", Prop::I64(0))], None)
.unwrap();
pg.delete_edge(10, 0, 1, None).unwrap();
assert_eq!(pg.edges().id().collect::<Vec<_>>(), vec![(GID::U64(0), GID::U64(1))]);
assert_eq!(
pg.edges().id().collect::<Vec<_>>(),
vec![(GID::U64(0), GID::U64(1))]
);

let g = pg.event_graph();
assert_eq!(g.edges().id().collect::<Vec<_>>(), vec![(GID::U64(0), GID::U64(1))]);
assert_eq!(
g.edges().id().collect::<Vec<_>>(),
vec![(GID::U64(0), GID::U64(1))]
);
}
}
7 changes: 6 additions & 1 deletion raphtory/src/db/graph/views/window_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,12 @@ mod views_test {
test_storage!(&graph, |graph| {
let wg = graph.window(-2, 0);

let actual = wg.nodes().id().values().filter_map(|id| id.to_u64()).collect::<Vec<_>>();
let actual = wg
.nodes()
.id()
.values()
.filter_map(|id| id.to_u64())
.collect::<Vec<_>>();

let expected = vec![1, 2];

Expand Down
2 changes: 1 addition & 1 deletion raphtory/src/graph.proto
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ message GID {
}

message Node{
uint64 GID = 1;
GID gid = 1;
uint64 vid = 2;
optional string name = 3;
}
Expand Down
Loading

0 comments on commit a672642

Please sign in to comment.