diff --git a/raphtory-benchmark/src/main.rs b/raphtory-benchmark/src/main.rs index 4687d095a5..ca535c5748 100644 --- a/raphtory-benchmark/src/main.rs +++ b/raphtory-benchmark/src/main.rs @@ -1,4 +1,3 @@ -use chrono::DateTime; use clap::{ArgAction, Parser}; use csv::StringRecord; use flate2::read::GzDecoder; @@ -50,6 +49,10 @@ struct Args { /// Debug to print more info to the screen #[arg(long, action=ArgAction::SetTrue)] debug: bool, + + /// Set the number of locks for the node and edge storage + #[arg(long)] + num_shards: Option, } fn main() { @@ -85,6 +88,7 @@ fn main() { let to_column = args.to_column; let time_column = args.time_column; let download = args.download; + let num_shards = args.num_shards; if download { let url = "https://osf.io/download/nbq6h/"; @@ -139,38 +143,35 @@ fn main() { println!("Running setup..."); let mut now = Instant::now(); // Iterate over the CSV records - let g = { - let g = Graph::new(); - CsvLoader::new(file_path) - .set_header(header) - .set_delimiter(&delimiter) - .load_rec_into_graph(&g, |generic_loader: StringRecord, g: &Graph| { - let src_id = generic_loader - .get(from_column) - .map(|s| s.to_owned()) - .unwrap(); - let dst_id = generic_loader.get(to_column).map(|s| s.to_owned()).unwrap(); - let mut edge_time = DateTime::from_timestamp(1, 0).unwrap().naive_utc(); - if time_column != -1 { - edge_time = DateTime::from_timestamp_millis( - generic_loader - .get(time_column as usize) - .unwrap() - .parse() - .unwrap(), - ) - .unwrap() - .naive_utc(); - } - if debug { - println!("Adding edge {} -> {} at time {}", src_id, dst_id, edge_time); - } - g.add_edge(edge_time, src_id, dst_id, NO_PROPS, None) - .expect("Failed to add edge"); - }) - .expect("Failed to load graph from CSV data files"); - g + let g = match num_shards { + Some(num_shards) => { + println!("Constructing graph with {num_shards} shards."); + Graph::new_with_shards(num_shards) + } + None => Graph::new(), }; + CsvLoader::new(file_path) + .set_header(header) + .set_delimiter(&delimiter) + .load_rec_into_graph(&g, |generic_loader: StringRecord, g: &Graph| { + let src_id = generic_loader.get(from_column).unwrap(); + let dst_id = generic_loader.get(to_column).unwrap(); + let edge_time = if time_column != -1 { + generic_loader + .get(time_column as usize) + .unwrap() + .parse() + .unwrap() + } else { + 1i64 + }; + if debug { + println!("Adding edge {} -> {} at time {}", src_id, dst_id, edge_time); + } + g.add_edge(edge_time, src_id, dst_id, NO_PROPS, None) + .expect("Failed to add edge"); + }) + .expect("Failed to load graph from CSV data files"); println!("Setup took {} seconds", now.elapsed().as_secs_f64()); if debug { diff --git a/raphtory/src/db/graph/graph.rs b/raphtory/src/db/graph/graph.rs index af4841a00a..63fd86f29a 100644 --- a/raphtory/src/db/graph/graph.rs +++ b/raphtory/src/db/graph/graph.rs @@ -140,7 +140,7 @@ impl InheritMutationOps for Graph {} impl InheritViewOps for Graph {} impl Graph { - /// Create a new graph with the specified number of shards + /// Create a new graph /// /// Returns: /// @@ -156,6 +156,14 @@ impl Graph { Self(Arc::new(InternalGraph::default())) } + /// Create a new graph with specified number of shards + /// + /// Returns: + /// + /// A raphtory graph + pub fn new_with_shards(num_shards: usize) -> Self { + Self(Arc::new(InternalGraph::new(num_shards))) + } pub(crate) fn from_internal_graph(internal_graph: Arc) -> Self { Self(internal_graph) } @@ -2903,4 +2911,14 @@ mod db_tests { ] ); } + + #[test] + fn num_locks_same_as_threads() { + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(5) + .build() + .unwrap(); + let graph = pool.install(|| Graph::new()); + assert_eq!(graph.0.inner().storage.nodes.data.len(), 5); + } }