diff --git a/gqlalchemy/query_builders/declarative_base.py b/gqlalchemy/query_builders/declarative_base.py index e97094a1..7498937d 100644 --- a/gqlalchemy/query_builders/declarative_base.py +++ b/gqlalchemy/query_builders/declarative_base.py @@ -30,7 +30,13 @@ from gqlalchemy.graph_algorithms.integrated_algorithms import IntegratedAlgorithm from gqlalchemy.vendors.memgraph import Memgraph from gqlalchemy.models import Node, Relationship -from gqlalchemy.utilities import to_cypher_labels, to_cypher_properties, to_cypher_value, to_cypher_qm_arguments +from gqlalchemy.utilities import ( + to_cypher_labels, + to_cypher_properties, + to_cypher_value, + to_cypher_qm_arguments, + to_null_operator, +) from gqlalchemy.vendors.database_client import DatabaseClient @@ -186,7 +192,7 @@ def _build_where_query(self, item: str, operator: Operator, **kwargs) -> "Declar if value is None: if literal is None: - raise GQLAlchemyLiteralAndExpressionMissing(clause=self.type) + operator_str = to_null_operator(operator_str) value = to_cypher_value(literal) elif literal is not None: diff --git a/gqlalchemy/utilities.py b/gqlalchemy/utilities.py index b58e1891..c7b1c43b 100644 --- a/gqlalchemy/utilities.py +++ b/gqlalchemy/utilities.py @@ -84,10 +84,20 @@ def _is_torch_tensor(value): return False +def to_null_operator(value: str) -> str: + if value == "=": + return "IS" + if value == "!=" or value == "<>": + return "IS NOT" + raise InvalidOperatorException(f"Operator {value} can not be used with None") + + def to_cypher_value(value: Any, config: NetworkXCypherConfig = None) -> str: """Converts value to a valid Cypher type.""" if config is None: config = NetworkXCypherConfig() + if value is None: + return "null" value_type = type(value) @@ -264,3 +274,7 @@ def __str__(self) -> str: class NanException(Exception): pass + + +class InvalidOperatorException(Exception): + pass diff --git a/tests/query_builders/test_query_builders.py b/tests/query_builders/test_query_builders.py index 7e4ae429..a14ae45a 100644 --- a/tests/query_builders/test_query_builders.py +++ b/tests/query_builders/test_query_builders.py @@ -358,6 +358,26 @@ def test_where_property(self, vendor): mock.assert_called_with(expected_query) + @pytest.mark.parametrize("operator", ["<>", "!=", "="]) + def test_where_property_null_operators(self, vendor, operator): + query_builder = ( + vendor[1] + .match() + .node(labels="L1", variable="n") + .to(relationship_type="TO") + .node(labels="L2", variable="m") + .where(item="n.name", operator=operator, literal=None) + .return_() + ) + expected_query = ( + f" MATCH (n:L1)-[:TO]->(m:L2) WHERE n.name IS {'NOT ' if operator != '=' else ''}null RETURN * " + ) + + with patch.object(vendor[0], "execute_and_fetch", return_value=None) as mock: + query_builder.execute() + + mock.assert_called_with(expected_query) + def test_where_not_property(self, vendor): query_builder = ( vendor[1]