Skip to content

Commit

Permalink
WIP Convert query to nodesets
Browse files Browse the repository at this point in the history
  • Loading branch information
GianlucaFicarelli committed Jun 14, 2023
1 parent 2da1c6d commit 1fd44cf
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 1 deletion.
47 changes: 47 additions & 0 deletions bluepysnap/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
VALUE_KEYS = {REGEX_KEY}
ALL_KEYS = {NODE_ID_KEY, EDGE_ID_KEY, POPULATION_KEY, OR_KEY, AND_KEY, NODE_SET_KEY} | VALUE_KEYS

GT_KEY = "$gt"
LT_KEY = "$lt"
GTE_KEY = "$gte"
LTE_KEY = "$lte"


# TODO: move to `libsonata` library
def _complex_query(prop, query):
Expand Down Expand Up @@ -158,3 +163,45 @@ def _collect(queries, queries_key):
queries = deepcopy(queries)
traverse_queries_bottom_up(queries, _collect)
return _merge_queries_masks(queries)


def _convert(queries, node_sets, node_set_name):
"""Convert a query to node_sets."""
for n, (queries_key, queries_value) in enumerate(queries.items()):
if queries_key == OR_KEY:
# create a node_set for each item, and a combined node_set with the list
assert len(queries) == 1, f"Mixing {OR_KEY} and other keys isn't supported yet"
names = []
for n2, x in enumerate(queries_value):
assert isinstance(x, dict)
name = f"{node_set_name}_{n2}"
_convert(x, node_sets, name)
names.append(name)
node_sets[node_set_name] = names
elif queries_key == AND_KEY:
assert len(queries) == 1, f"Mixing {AND_KEY} and other keys isn't supported yet"
raise NotImplementedError
else:
if isinstance(queries_value, tuple) and len(queries_value) == 2:
start, stop = queries_value
queries_value = {GTE_KEY: start, LTE_KEY: stop}
assert (
isinstance(queries_value, (str, int, float))
or isinstance(queries_value, list)
and all(isinstance(i, (str, int, float)) for i in queries_value)
or isinstance(queries_value, dict)
and {REGEX_KEY, GT_KEY, LT_KEY, GTE_KEY, LTE_KEY}.issuperset(queries_value)
and all(isinstance(i, (str, int, float)) for i in queries_value.values())
), (
"Value should be a scalar, a list of scalars, a dict of operators, "
"or a tuple of 2 elements representing an interval."
)
node_sets.setdefault(node_set_name, {})[queries_key] = deepcopy(queries_value)


def to_node_set(queries):
name = "ns"
node_sets = {}
_convert(queries, node_sets, name)
return node_sets, name
# return NodeSets.from_dict(node_sets)[name]
31 changes: 30 additions & 1 deletion tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from bluepysnap import BluepySnapError
from bluepysnap.query import _circuit_mask, _positional_mask, resolve_ids
from bluepysnap.query import _circuit_mask, _positional_mask, resolve_ids, to_node_set


def test_positional_mask():
Expand Down Expand Up @@ -58,3 +58,32 @@ def test_resolve_ids():
with pytest.raises(BluepySnapError) as e:
resolve_ids(data, "", {"str": {"$regex": "*.some", "edge_id": 2}})
assert "Value operators can't be used with plain values" in e.value.args[0]


@pytest.mark.parametrize(
"queries, expected",
[
(
{"x": (0, 1), "mtype": "L1_SLAC"},
{"ns": {"mtype": "L1_SLAC", "x": {"$gte": 0, "$lte": 1}}},
),
(
{"$or": [{"layer": [2, 3]}, {"x": (0, 1), "mtype": "L1_SLAC"}]},
{
"ns_0": {"layer": [2, 3]},
"ns_1": {"x": {"$gte": 0, "$lte": 1}, "mtype": "L1_SLAC"},
"ns": ["ns_0", "ns_1"],
},
),
],
)
def test_to_node_set(queries, expected):
node_sets, name = to_node_set(queries)
assert name == "ns"
assert node_sets == expected


def test_to_node_raises():
queries = {"$and": [{"mtype": "L6_Y"}, {"morphology": "morph-B"}]}
with pytest.raises(NotImplementedError):
to_node_set(queries)

0 comments on commit 1fd44cf

Please sign in to comment.