From a12e398df421961489a2e5d3044e3ec3bd737223 Mon Sep 17 00:00:00 2001 From: ljeub-pometry <97447091+ljeub-pometry@users.noreply.github.com> Date: Wed, 20 Nov 2024 15:44:38 +0100 Subject: [PATCH] add wrapper for constructing vec from any python iterable (#1862) * add wrapper for constructing vec from any python iterable * add test for new import option --- python/python/raphtory/__init__.pyi | 39 ++++ python/python/raphtory/graphql/__init__.pyi | 22 ++- .../tests/graphql/edit_graph/test_graphql.py | 133 ++++++++++--- python/tests/test_graphdb/test_graphdb.py | 181 ++++++++++-------- .../test_graphdb/test_graphdb_imports.py | 28 +++ raphtory/src/python/graph/graph.rs | 5 +- raphtory/src/python/types/iterable.rs | 36 +++- 7 files changed, 322 insertions(+), 122 deletions(-) diff --git a/python/python/raphtory/__init__.pyi b/python/python/raphtory/__init__.pyi index 7a6ad44fc5..d000b539d3 100644 --- a/python/python/raphtory/__init__.pyi +++ b/python/python/raphtory/__init__.pyi @@ -1235,6 +1235,25 @@ class Graph(GraphView): path (str): The path to the cache file """ + def create_node( + self, + timestamp: TimeInput, + id: str | int, + properties: Optional[PropInput] = None, + node_type: Optional[str] = None, + ) -> MutableNode: + """ + Creates a new node with the given id and properties to the graph. It fails if the node already exists. + + Arguments: + timestamp (TimeInput): The timestamp of the node. + id (str|int): The id of the node. + properties (PropInput, optional): The properties of the node. + node_type (str, optional): The optional string which will be used as a node type + Returns: + MutableNode: The created node + """ + @staticmethod def deserialise(bytes: bytes): """ @@ -3506,6 +3525,26 @@ class PersistentGraph(GraphView): path (str): The path to the cache file """ + def create_node( + self, + timestamp: TimeInput, + id: str | int, + properties: dict = None, + node_type: str = None, + ): + """ + Creates a new node with the given id and properties to the graph. It fails if the node already exists. + + Arguments: + timestamp (TimeInput): The timestamp of the node. + id (str | int): The id of the node. + properties (dict): The properties of the node. + node_type (str) : The optional string which will be used as a node type + + Returns: + MutableNode + """ + def delete_edge( self, timestamp: int, src: str | int, dst: str | int, layer: str = None ): diff --git a/python/python/raphtory/graphql/__init__.pyi b/python/python/raphtory/graphql/__init__.pyi index 28827d4877..c66635a92c 100644 --- a/python/python/raphtory/graphql/__init__.pyi +++ b/python/python/raphtory/graphql/__init__.pyi @@ -329,9 +329,6 @@ class RemoteEdgeAddition(object): """Create and return a new object. See help(type) for accurate signature.""" class RemoteGraph(object): - def __new__(self, path, client) -> RemoteGraph: - """Create and return a new object. See help(type) for accurate signature.""" - def add_constant_properties(self, properties: dict): """ Adds constant properties to the remote graph. @@ -406,6 +403,25 @@ class RemoteGraph(object): properties (dict): The temporal properties of the graph. """ + def create_node( + self, + timestamp: int | str | datetime, + id: str | int, + properties: Optional[dict] = None, + node_type: Optional[str] = None, + ): + """ + Create a new node with the given id and properties to the remote graph and fail if the node already exists. + + Arguments: + timestamp (int|str|datetime): The timestamp of the node. + id (str|int): The id of the node. + properties (dict, optional): The properties of the node. + node_type (str, optional): The optional string which will be used as a node type + Returns: + RemoteNode + """ + def delete_edge( self, timestamp: int, diff --git a/python/tests/graphql/edit_graph/test_graphql.py b/python/tests/graphql/edit_graph/test_graphql.py index 28b36f959d..a75d6259f0 100644 --- a/python/tests/graphql/edit_graph/test_graphql.py +++ b/python/tests/graphql/edit_graph/test_graphql.py @@ -21,8 +21,8 @@ def test_encode_graph(): encoded = encode_graph(g) assert ( - encoded - == "EgxaCgoIX2RlZmF1bHQSDBIKCghfZGVmYXVsdBoFCgNiZW4aCQoFaGFtemEYARoLCgdoYWFyb29uGAIiAhABIgYIAhABGAEiBBACGAIqAhoAKgQSAhABKgQSAhADKgIKACoGEgQIARABKgYSBAgBEAIqBAoCCAEqBhIECAIQAioGEgQIAhADKgQKAggCKgQ6AhABKgIyACoIOgYIARACGAEqBDICCAEqCDoGCAIQAxgCKgQyAggC" + encoded + == "EgxaCgoIX2RlZmF1bHQSDBIKCghfZGVmYXVsdBoFCgNiZW4aCQoFaGFtemEYARoLCgdoYWFyb29uGAIiAhABIgYIAhABGAEiBBACGAIqAhoAKgQSAhABKgQSAhADKgIKACoGEgQIARABKgYSBAgBEAIqBAoCCAEqBhIECAIQAioGEgQIAhADKgQKAggCKgQ6AhABKgIyACoIOgYIARACGAEqBDICCAEqCDoGCAIQAxgCKgQyAggC" ) @@ -42,8 +42,8 @@ def test_wrong_url(): with pytest.raises(Exception) as excinfo: client = RaphtoryClient("http://broken_url.com") assert ( - str(excinfo.value) - == "Could not connect to the given server - no response --error sending request for url (http://broken_url.com/)" + str(excinfo.value) + == "Could not connect to the given server - no response --error sending request for url (http://broken_url.com/)" ) @@ -393,16 +393,14 @@ def test_create_node(): query_nodes = """{graph(path: "g") {nodes {list {name}}}}""" assert client.query(query_nodes) == { - "graph": { - "nodes": { - "list": [{"name": "ben"}, {"name": "shivam"}] - } - } + "graph": {"nodes": {"list": [{"name": "ben"}, {"name": "shivam"}]}} } create_node_query = """{updateGraph(path: "g") { createNode(time: 0, name: "oogway") { success } }}""" - assert client.query(create_node_query) == {"updateGraph": {"createNode": {"success": True}}} + assert client.query(create_node_query) == { + "updateGraph": {"createNode": {"success": True}} + } assert client.query(query_nodes) == { "graph": { "nodes": { @@ -428,11 +426,7 @@ def test_create_node_using_client(): query_nodes = """{graph(path: "g") {nodes {list {name}}}}""" assert client.query(query_nodes) == { - "graph": { - "nodes": { - "list": [{"name": "ben"}, {"name": "shivam"}] - } - } + "graph": {"nodes": {"list": [{"name": "ben"}, {"name": "shivam"}]}} } remote_graph = client.remote_graph(path="g") @@ -460,23 +454,56 @@ def test_create_node_using_client_with_properties(): client = RaphtoryClient("http://localhost:1737") client.send_graph(path="g", graph=g) - query_nodes = """{graph(path: "g") {nodes {list {name, properties { keys }}}}}""" + query_nodes = ( + """{graph(path: "g") {nodes {list {name, properties { keys }}}}}""" + ) assert client.query(query_nodes) == { "graph": { "nodes": { - "list": [{"name": "ben", 'properties': {'keys': []}}, {"name": "shivam", 'properties': {'keys': []}}] + "list": [ + {"name": "ben", "properties": {"keys": []}}, + {"name": "shivam", "properties": {"keys": []}}, + ] } } } remote_graph = client.remote_graph(path="g") - remote_graph.create_node(timestamp=0, id="oogway", properties={"prop1": 60, "prop2": 31.3, "prop3": "abc123", "prop4": True, "prop5": [1, 2, 3]}) - nodes = json.loads(json.dumps(client.query(query_nodes)))['graph']['nodes']['list'] - node_oogway = next(node for node in nodes if node['name'] == 'oogway') - assert sorted(node_oogway['properties']['keys']) == ['prop1', 'prop2', 'prop3', 'prop4', 'prop5'] + remote_graph.create_node( + timestamp=0, + id="oogway", + properties={ + "prop1": 60, + "prop2": 31.3, + "prop3": "abc123", + "prop4": True, + "prop5": [1, 2, 3], + }, + ) + nodes = json.loads(json.dumps(client.query(query_nodes)))["graph"]["nodes"][ + "list" + ] + node_oogway = next(node for node in nodes if node["name"] == "oogway") + assert sorted(node_oogway["properties"]["keys"]) == [ + "prop1", + "prop2", + "prop3", + "prop4", + "prop5", + ] with pytest.raises(Exception) as excinfo: - remote_graph.create_node(timestamp=0, id="oogway", properties={"prop1": 60, "prop2": 31.3, "prop3": "abc123", "prop4": True, "prop5": [1, 2, 3]}) + remote_graph.create_node( + timestamp=0, + id="oogway", + properties={ + "prop1": 60, + "prop2": 31.3, + "prop3": "abc123", + "prop4": True, + "prop5": [1, 2, 3], + }, + ) assert "Node already exists" in str(excinfo.value) @@ -494,20 +521,57 @@ def test_create_node_using_client_with_properties_node_type(): assert client.query(query_nodes) == { "graph": { "nodes": { - "list": [{"name": "ben", 'nodeType': None, 'properties': {'keys': []}}, {"name": "shivam", 'nodeType': None, 'properties': {'keys': []}}] + "list": [ + {"name": "ben", "nodeType": None, "properties": {"keys": []}}, + { + "name": "shivam", + "nodeType": None, + "properties": {"keys": []}, + }, + ] } } } remote_graph = client.remote_graph(path="g") - remote_graph.create_node(timestamp=0, id="oogway", properties={"prop1": 60, "prop2": 31.3, "prop3": "abc123", "prop4": True, "prop5": [1, 2, 3]}, node_type="master") - nodes = json.loads(json.dumps(client.query(query_nodes)))['graph']['nodes']['list'] - node_oogway = next(node for node in nodes if node['name'] == 'oogway') - assert node_oogway['nodeType'] == 'master' - assert sorted(node_oogway['properties']['keys']) == ['prop1', 'prop2', 'prop3', 'prop4', 'prop5'] + remote_graph.create_node( + timestamp=0, + id="oogway", + properties={ + "prop1": 60, + "prop2": 31.3, + "prop3": "abc123", + "prop4": True, + "prop5": [1, 2, 3], + }, + node_type="master", + ) + nodes = json.loads(json.dumps(client.query(query_nodes)))["graph"]["nodes"][ + "list" + ] + node_oogway = next(node for node in nodes if node["name"] == "oogway") + assert node_oogway["nodeType"] == "master" + assert sorted(node_oogway["properties"]["keys"]) == [ + "prop1", + "prop2", + "prop3", + "prop4", + "prop5", + ] with pytest.raises(Exception) as excinfo: - remote_graph.create_node(timestamp=0, id="oogway", properties={"prop1": 60, "prop2": 31.3, "prop3": "abc123", "prop4": True, "prop5": [1, 2, 3]}, node_type="master") + remote_graph.create_node( + timestamp=0, + id="oogway", + properties={ + "prop1": 60, + "prop2": 31.3, + "prop3": "abc123", + "prop4": True, + "prop5": [1, 2, 3], + }, + node_type="master", + ) assert "Node already exists" in str(excinfo.value) @@ -525,7 +589,10 @@ def test_create_node_using_client_with_node_type(): assert client.query(query_nodes) == { "graph": { "nodes": { - "list": [{"name": "ben", 'nodeType': None}, {"name": "shivam", 'nodeType': None}] + "list": [ + {"name": "ben", "nodeType": None}, + {"name": "shivam", "nodeType": None}, + ] } } } @@ -535,7 +602,11 @@ def test_create_node_using_client_with_node_type(): assert client.query(query_nodes) == { "graph": { "nodes": { - "list": [{"name": "ben", 'nodeType': None}, {"name": "shivam", 'nodeType': None}, {"name": "oogway", 'nodeType': "master"}] + "list": [ + {"name": "ben", "nodeType": None}, + {"name": "shivam", "nodeType": None}, + {"name": "oogway", "nodeType": "master"}, + ] } } } diff --git a/python/tests/test_graphdb/test_graphdb.py b/python/tests/test_graphdb/test_graphdb.py index eb1c43b6bd..bed24790ca 100644 --- a/python/tests/test_graphdb/test_graphdb.py +++ b/python/tests/test_graphdb/test_graphdb.py @@ -353,8 +353,8 @@ def test_getitem(): @with_disk_graph def check(g): assert ( - g.node(1).properties.temporal.get("cost") - == g.node(1).properties.temporal["cost"] + g.node(1).properties.temporal.get("cost") + == g.node(1).properties.temporal["cost"] ) check(g) @@ -607,7 +607,7 @@ def time_history_test(time, key, value): assert g.at(time).node(1).properties.temporal.get(key) is None assert g.at(time).nodes.properties.temporal.get(key) is None assert ( - g.at(time).nodes.out_neighbours.properties.temporal.get(key) is None + g.at(time).nodes.out_neighbours.properties.temporal.get(key) is None ) else: assert g.at(time).node(1).properties.temporal.get(key).items() == value @@ -812,22 +812,22 @@ def no_static_property_test(key, value): assert sorted(g.node(1).properties.temporal.keys()) == expected_names_no_static assert sorted(g.nodes.properties.temporal.keys()) == expected_names_no_static assert ( - sorted(g.nodes.out_neighbours.properties.temporal.keys()) - == expected_names_no_static + sorted(g.nodes.out_neighbours.properties.temporal.keys()) + == expected_names_no_static ) expected_names_no_static_at_1 = sorted(["prop 4", "prop 1", "prop 3"]) assert ( - sorted(g.at(1).node(1).properties.temporal.keys()) - == expected_names_no_static_at_1 + sorted(g.at(1).node(1).properties.temporal.keys()) + == expected_names_no_static_at_1 ) assert ( - sorted(g.at(1).nodes.properties.temporal.keys()) - == expected_names_no_static_at_1 + sorted(g.at(1).nodes.properties.temporal.keys()) + == expected_names_no_static_at_1 ) assert ( - sorted(g.at(1).nodes.out_neighbours.properties.temporal.keys()) - == expected_names_no_static_at_1 + sorted(g.at(1).nodes.out_neighbours.properties.temporal.keys()) + == expected_names_no_static_at_1 ) # testing has_property @@ -1325,11 +1325,11 @@ def test_constant_property_update(): def updates(v): v.update_constant_properties({"prop1": "value1", "prop2": 123}) assert ( - v.properties.get("prop1") == "value1" and v.properties.get("prop2") == 123 + v.properties.get("prop1") == "value1" and v.properties.get("prop2") == 123 ) v.update_constant_properties({"prop1": "value2", "prop2": 345}) assert ( - v.properties.get("prop1") == "value2" and v.properties.get("prop2") == 345 + v.properties.get("prop1") == "value2" and v.properties.get("prop2") == 345 ) v.add_constant_properties({"name": "value1"}) @@ -1666,18 +1666,18 @@ def check(g): assert g.exclude_layer("layer2").count_edges() == 4 with pytest.raises( - Exception, - match=re.escape( - "Invalid layer: test_layer. Valid layers: _default, layer1, layer2" - ), + Exception, + match=re.escape( + "Invalid layer: test_layer. Valid layers: _default, layer1, layer2" + ), ): g.layers(["test_layer"]) with pytest.raises( - Exception, - match=re.escape( - "Invalid layer: test_layer. Valid layers: _default, layer1, layer2" - ), + Exception, + match=re.escape( + "Invalid layer: test_layer. Valid layers: _default, layer1, layer2" + ), ): g.edge(1, 2).layers(["test_layer"]) @@ -1754,20 +1754,20 @@ def test_layer_name(): assert str(e.value) == error_msg assert [ - list(iterator) for iterator in g.nodes.neighbours.edges.explode().layer_name - ] == [ - ["_default", "awesome layer"], - ["_default", "awesome layer"], - ["_default", "awesome layer"], - ] + list(iterator) for iterator in g.nodes.neighbours.edges.explode().layer_name + ] == [ + ["_default", "awesome layer"], + ["_default", "awesome layer"], + ["_default", "awesome layer"], + ] assert [ - list(iterator) - for iterator in g.nodes.neighbours.edges.explode_layers().layer_name - ] == [ - ["_default", "awesome layer"], - ["_default", "awesome layer"], - ["_default", "awesome layer"], - ] + list(iterator) + for iterator in g.nodes.neighbours.edges.explode_layers().layer_name + ] == [ + ["_default", "awesome layer"], + ["_default", "awesome layer"], + ["_default", "awesome layer"], + ] def test_time(): @@ -1801,12 +1801,12 @@ def check(g): # assert str(e.value) == error_msg assert [ - list(iterator) for iterator in g.nodes.neighbours.edges.explode().time - ] == [ - [0, 0, 1], - [0, 0, 1], - [0, 0, 1], - ] + list(iterator) for iterator in g.nodes.neighbours.edges.explode().time + ] == [ + [0, 0, 1], + [0, 0, 1], + [0, 0, 1], + ] check(g) @@ -2371,8 +2371,8 @@ def test_weird_windows(): @with_disk_graph def check(g): with pytest.raises( - Exception, - match="'ddd' is not a valid datetime, valid formats are RFC3339, RFC2822, %Y-%m-%d, %Y-%m-%dT%H:%M:%S%.3f, %Y-%m-%dT%H:%M:%S%, %Y-%m-%d %H:%M:%S%.3f and %Y-%m-%d %H:%M:%S%", + Exception, + match="'ddd' is not a valid datetime, valid formats are RFC3339, RFC2822, %Y-%m-%d, %Y-%m-%dT%H:%M:%S%.3f, %Y-%m-%dT%H:%M:%S%, %Y-%m-%d %H:%M:%S%.3f and %Y-%m-%d %H:%M:%S%", ): g.window("ddd", "aaa") @@ -2569,9 +2569,9 @@ def check(g): assert g.nodes.type_filter(["a"]).neighbours.type_filter( ["c"] ).name.collect() == [ - [], - ["5"], - ] + [], + ["5"], + ] assert g.nodes.type_filter(["a"]).neighbours.type_filter([]).name.collect() == [ [], [], @@ -2582,9 +2582,9 @@ def check(g): assert g.nodes.type_filter(["a"]).neighbours.type_filter( ["d"] ).name.collect() == [ - [], - [], - ] + [], + [], + ] assert g.nodes.type_filter(["a"]).neighbours.neighbours.name.collect() == [ ["1", "3", "4"], ["1", "3", "4", "4", "6"], @@ -2639,20 +2639,20 @@ def check(g): for edge in edges: time_nested.append(edge.time) assert [ - item - for sublist in g.nodes.edges.explode().time.collect() - for item in sublist - ] == time_nested + item + for sublist in g.nodes.edges.explode().time.collect() + for item in sublist + ] == time_nested date_time_nested = [] for edges in g.nodes.edges.explode(): for edge in edges: date_time_nested.append(edge.date_time) assert [ - item - for sublist in g.nodes.edges.explode().date_time.collect() - for item in sublist - ] == date_time_nested + item + for sublist in g.nodes.edges.explode().date_time.collect() + for item in sublist + ] == date_time_nested check(g) @@ -2966,44 +2966,44 @@ def check(g): assert len(index.fuzzy_search_nodes("name:habza", levenshtein_distance=1)) == 1 assert ( - len( - index.fuzzy_search_nodes( - "name:haa", levenshtein_distance=1, prefix=True - ) + len( + index.fuzzy_search_nodes( + "name:haa", levenshtein_distance=1, prefix=True ) - == 2 + ) + == 2 ) assert ( - len( - index.fuzzy_search_nodes( - "value_str:abc123", levenshtein_distance=2, prefix=True - ) + len( + index.fuzzy_search_nodes( + "value_str:abc123", levenshtein_distance=2, prefix=True ) - == 2 + ) + == 2 ) assert ( - len( - index.fuzzy_search_nodes( - "value_str:dsss312", levenshtein_distance=2, prefix=False - ) + len( + index.fuzzy_search_nodes( + "value_str:dsss312", levenshtein_distance=2, prefix=False ) - == 1 + ) + == 1 ) assert len(index.fuzzy_search_edges("from:bon", levenshtein_distance=1)) == 2 assert ( - len( - index.fuzzy_search_edges("from:bo", levenshtein_distance=1, prefix=True) - ) - == 2 + len( + index.fuzzy_search_edges("from:bo", levenshtein_distance=1, prefix=True) + ) + == 2 ) assert ( - len( - index.fuzzy_search_edges( - "from:eon", levenshtein_distance=2, prefix=True - ) + len( + index.fuzzy_search_edges( + "from:eon", levenshtein_distance=2, prefix=True ) - == 2 + ) + == 2 ) check(g) @@ -3011,26 +3011,38 @@ def check(g): def test_create_node_graph(): g = Graph() - g.create_node(1, "shivam", properties={"value": 60, "value_f": 31.3, "value_str": "abc123"}) + g.create_node( + 1, "shivam", properties={"value": 60, "value_f": 31.3, "value_str": "abc123"} + ) node = g.node("shivam") assert node.name == "shivam" assert node.properties == {"value": 60, "value_f": 31.3, "value_str": "abc123"} with pytest.raises(Exception) as excinfo: - g.create_node(1, "shivam", properties={"value": 60, "value_f": 31.3, "value_str": "abc123"}) + g.create_node( + 1, + "shivam", + properties={"value": 60, "value_f": 31.3, "value_str": "abc123"}, + ) assert "Node already exists" in str(excinfo.value) def test_create_node_graph_with_deletion(): g = PersistentGraph() - g.create_node(1, "shivam", properties={"value": 60, "value_f": 31.3, "value_str": "abc123"}) + g.create_node( + 1, "shivam", properties={"value": 60, "value_f": 31.3, "value_str": "abc123"} + ) node = g.node("shivam") assert node.name == "shivam" assert node.properties == {"value": 60, "value_f": 31.3, "value_str": "abc123"} with pytest.raises(Exception) as excinfo: - g.create_node(1, "shivam", properties={"value": 60, "value_f": 31.3, "value_str": "abc123"}) + g.create_node( + 1, + "shivam", + properties={"value": 60, "value_f": 31.3, "value_str": "abc123"}, + ) assert "Node already exists" in str(excinfo.value) @@ -3047,6 +3059,7 @@ def datadir(tmpdir, request): raise e return tmpdir + # def currently_broken_fuzzy_search(): #TODO: Fix fuzzy searching for properties # g = Graph() # g.add_edge(2,"haaroon","hamza", properties={"value":60,"value_f":31.3,"value_str":"abc123"}) diff --git a/python/tests/test_graphdb/test_graphdb_imports.py b/python/tests/test_graphdb/test_graphdb_imports.py index 836771562b..0dd6fda55a 100644 --- a/python/tests/test_graphdb/test_graphdb_imports.py +++ b/python/tests/test_graphdb/test_graphdb_imports.py @@ -56,3 +56,31 @@ def test_import_with_int(): g2.import_node(g.node(1)) g2.import_nodes([g.node(2), g.node(3)]) assert g2.count_nodes() == g.count_nodes() + + +def test_import_edges(): + g = Graph() + g.add_node(1, 1) + g.add_node(1, 2) + g.add_node(1, 3) + g.add_edge(1, 4, 5) + g.add_edge(1, 6, 7) + g.add_edge(1, 8, 9) + g2 = Graph() + g2.import_edges(g.edges) + assert g2.count_edges() == 3 + assert g.edges.id == g2.edges.id + + +def test_import_edges_iterator(): + g = Graph() + g.add_node(1, 1) + g.add_node(1, 2) + g.add_node(1, 3) + g.add_edge(1, 4, 5) + g.add_edge(1, 6, 7) + g.add_edge(1, 8, 9) + g2 = Graph() + g2.import_edges(iter(g.edges)) + assert g2.count_edges() == 3 + assert g.edges.id == g2.edges.id diff --git a/raphtory/src/python/graph/graph.rs b/raphtory/src/python/graph/graph.rs index 5f48490171..583710b659 100644 --- a/raphtory/src/python/graph/graph.rs +++ b/raphtory/src/python/graph/graph.rs @@ -17,6 +17,7 @@ use crate::{ edge::PyEdge, graph_with_deletions::PyPersistentGraph, io::pandas_loaders::*, node::PyNode, views::graph_view::PyGraphView, }, + types::iterable::FromIterable, utils::{PyNodeRef, PyTime}, }, serialise::{StableDecode, StableEncode}, @@ -302,7 +303,7 @@ impl PyGraph { /// force (bool): An optional boolean flag indicating whether to force the import of the nodes. /// #[pyo3(signature = (nodes, force = false))] - pub fn import_nodes(&self, nodes: Vec, force: bool) -> Result<(), GraphError> { + pub fn import_nodes(&self, nodes: FromIterable, force: bool) -> Result<(), GraphError> { let node_views = nodes.iter().map(|node| &node.node); self.graph.import_nodes(node_views, force) } @@ -338,7 +339,7 @@ impl PyGraph { /// edges (List[Edge]): A list of Edge objects representing the edges to be imported. /// force (bool): An optional boolean flag indicating whether to force the import of the edges. #[pyo3(signature = (edges, force = false))] - pub fn import_edges(&self, edges: Vec, force: bool) -> Result<(), GraphError> { + pub fn import_edges(&self, edges: FromIterable, force: bool) -> Result<(), GraphError> { let edge_views = edges.iter().map(|edge| &edge.edge); self.graph.import_edges(edge_views, force) } diff --git a/raphtory/src/python/types/iterable.rs b/raphtory/src/python/types/iterable.rs index 9f73ffea8b..9ac5b6fa9c 100644 --- a/raphtory/src/python/types/iterable.rs +++ b/raphtory/src/python/types/iterable.rs @@ -2,8 +2,12 @@ use crate::{ db::api::view::BoxedIter, python::types::repr::{iterator_repr, Repr}, }; -use pyo3::{IntoPy, PyObject}; -use std::{marker::PhantomData, sync::Arc}; +use pyo3::prelude::*; +use std::{ + marker::PhantomData, + ops::{Deref, DerefMut}, + sync::Arc, +}; pub struct Iterable + From + Repr> { pub name: &'static str, @@ -113,3 +117,31 @@ impl + From + Repr> Repr for NestedIterable(Vec); +impl Deref for FromIterable { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for FromIterable { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl<'py, T: FromPyObject<'py>> FromPyObject<'py> for FromIterable { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let len = ob.len().unwrap_or(0); + let mut vec = Vec::::with_capacity(len); + { + for value in ob.iter()? { + vec.push(value?.extract()?) + } + } + Ok(Self(vec)) + } +}