Skip to content

Commit

Permalink
Added possibility to query Edge IDs and Node IDs based on edge/node p…
Browse files Browse the repository at this point in the history
…opulation type
  • Loading branch information
Joni Herttuainen committed Jan 4, 2024
1 parent 19866e0 commit da04bcc
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ New Features

- deprecated ``validate``

Improvements
~~~~~~~~~~~~
- Possibility to query Edge IDs and Node IDs based on edge/node population type using query key ``population_type``

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion bluepysnap/edges/edge_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _edge_ids_by_filter(self, queries, raise_missing_prop):
chunk_size = int(1e8)
for chunk in np.array_split(ids, 1 + len(ids) // chunk_size):
data = self.get(chunk, properties - unknown_props)
res.extend(chunk[query.resolve_ids(data, self.name, queries)])
res.extend(chunk[query.resolve_ids(data, self.name, self.type, queries)])
return np.array(res, dtype=IDS_DTYPE)

def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
Expand Down
2 changes: 1 addition & 1 deletion bluepysnap/nodes/node_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def _node_ids_by_filter(self, queries, raise_missing_prop):
self._check_properties(properties)
# load all the properties needed to execute the query, excluding the unknown properties
data = self._get_data(properties & self.property_names)
idx = query.resolve_ids(data, self.name, queries)
idx = query.resolve_ids(data, self.name, self.type, queries)
return idx.nonzero()[0]

def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
Expand Down
24 changes: 18 additions & 6 deletions bluepysnap/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,21 @@
NODE_ID_KEY = "node_id"
EDGE_ID_KEY = "edge_id"
POPULATION_KEY = "population"
POPULATION_TYPE_KEY = "population_type"
OR_KEY = "$or"
AND_KEY = "$and"
REGEX_KEY = "$regex"
NODE_SET_KEY = "$node_set"
VALUE_KEYS = {REGEX_KEY}
ALL_KEYS = {NODE_ID_KEY, EDGE_ID_KEY, POPULATION_KEY, OR_KEY, AND_KEY, NODE_SET_KEY} | VALUE_KEYS
ALL_KEYS = {
NODE_ID_KEY,
EDGE_ID_KEY,
POPULATION_KEY,
POPULATION_TYPE_KEY,
OR_KEY,
AND_KEY,
NODE_SET_KEY,
} | VALUE_KEYS


def _logical_and(masks):
Expand Down Expand Up @@ -97,23 +106,26 @@ def _positional_mask(data, ids):
return mask


def _circuit_mask(data, population_name, queries):
def _circuit_mask(data, population_name, population_type, queries):
"""Handle the population, node ID queries."""
populations = queries.pop(POPULATION_KEY, None)
types = queries.pop(POPULATION_TYPE_KEY, None)
if populations is not None and population_name not in set(utils.ensure_list(populations)):
ids = []
elif types is not None and population_type not in set(utils.ensure_list(types)):
ids = []
else:
ids = queries.pop(NODE_ID_KEY, queries.pop(EDGE_ID_KEY, None))
return queries, _positional_mask(data, ids)


def _properties_mask(data, population_name, queries):
def _properties_mask(data, population_name, population_type, queries):
"""Return mask of IDs matching `props` dict."""
unknown_props = set(queries) - set(data.columns) - ALL_KEYS
if unknown_props:
return False

queries, mask = _circuit_mask(data, population_name, queries)
queries, mask = _circuit_mask(data, population_name, population_type, queries)
if mask is False or isinstance(mask, np.ndarray) and not mask.any():
# Avoid fail and/or processing time if wrong population or no nodes
return False
Expand Down Expand Up @@ -201,7 +213,7 @@ def _resolve(queries, queries_key):
return resolved_queries


def resolve_ids(data, population_name, queries):
def resolve_ids(data, population_name, population_type, queries):
"""Returns an index mask of `data` for given `queries`.
Args:
Expand All @@ -228,7 +240,7 @@ def _collect(queries, queries_key):
queries[queries_key] = _logical_and(children_mask)
else:
queries[queries_key] = _properties_mask(
data, population_name, {queries_key: queries[queries_key]}
data, population_name, population_type, {queries_key: queries[queries_key]}
)

queries = deepcopy(queries)
Expand Down
27 changes: 26 additions & 1 deletion tests/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from bluepysnap.circuit_ids_types import IDS_DTYPE, CircuitEdgeId, CircuitNodeId
from bluepysnap.exceptions import BluepySnapError

from utils import PICKLED_SIZE_ADJUSTMENT, TEST_DATA_DIR
from utils import PICKLED_SIZE_ADJUSTMENT, TEST_DATA_DIR, copy_test_data, edit_config


class TestEdges:
Expand Down Expand Up @@ -183,6 +183,31 @@ def test_ids(self):
expected = CircuitEdgeIds.from_dict({"default": [1, 2, 3], "default2": [1, 2, 3]})
assert tested == expected

# Test querying with the population type
with copy_test_data() as (_, config_path):
with edit_config(config_path) as config:
config["networks"]["edges"][0]["populations"]["default"]["type"] = "electrical"

circuit = Circuit(config_path)

tested = circuit.edges.ids({"population_type": "electrical"})
expected = CircuitEdgeIds.from_arrays(4 * ["default"], [0, 1, 2, 3])
assert tested == expected

tested = circuit.edges.ids({"population_type": "chemical"})
expected = CircuitEdgeIds.from_arrays(4 * ["default2"], [0, 1, 2, 3])
assert tested == expected

tested = circuit.edges.ids(
{"population_type": ["electrical", "chemical"], "node_id": [0]}
)
expected = CircuitEdgeIds.from_tuples([("default", 0), ("default2", 0)])
assert tested == expected

tested = circuit.edges.ids({"population_type": "fake"})
expected = CircuitEdgeIds.from_arrays([], [])
assert tested == expected

def test_get(self):
with pytest.raises(BluepySnapError, match="You need to set edge_ids in get."):
self.test_obj.get(properties=["other2"])
Expand Down
71 changes: 70 additions & 1 deletion tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.node_sets import NodeSets

from utils import PICKLED_SIZE_ADJUSTMENT, TEST_DATA_DIR
from utils import PICKLED_SIZE_ADJUSTMENT, TEST_DATA_DIR, copy_test_data, edit_config


class TestNodes:
Expand Down Expand Up @@ -234,6 +234,31 @@ def test_ids(self):
expected = CircuitNodeIds.from_dict({"default": [0, 1, 2], "default2": [0]})
assert tested == expected

# Test querying with the population type
with copy_test_data() as (_, config_path):
with edit_config(config_path) as config:
config["networks"]["nodes"][0]["populations"]["default"]["type"] = "virtual"

circuit = Circuit(config_path)

tested = circuit.nodes.ids({"population_type": "virtual"})
expected = CircuitNodeIds.from_arrays(3 * ["default"], [0, 1, 2])
assert tested == expected

tested = circuit.nodes.ids({"population_type": "biophysical"})
expected = CircuitNodeIds.from_arrays(4 * ["default2"], [0, 1, 2, 3])
assert tested == expected

tested = circuit.nodes.ids(
{"population_type": ["virtual", "biophysical"], "node_id": [0]}
)
expected = CircuitNodeIds.from_tuples([("default", 0), ("default2", 0)])
assert tested == expected

tested = circuit.nodes.ids({"population_type": "fake"})
expected = CircuitNodeIds.from_arrays([], [])
assert tested == expected

def test_get(self):
# return all properties for all the ids
tested = self.test_obj.get()
Expand Down Expand Up @@ -366,6 +391,50 @@ def test_get(self):
with pytest.raises(BluepySnapError, match="Unknown properties required: {'unknown'}"):
next(self.test_obj.get(properties="unknown"))

# Test querying with the population type
with copy_test_data() as (_, config_path):
with edit_config(config_path) as config:
config["networks"]["nodes"][0]["populations"]["default"]["type"] = "virtual"

circuit = Circuit(config_path)

tested = circuit.nodes.get(group={"population_type": "virtual"}, properties=["layer"])
tested = pd.concat(df for _, df in tested)
expected = pd.DataFrame(
{"layer": np.array([2, 6, 6], dtype=int)},
index=pd.MultiIndex.from_tuples(
[
("default", 0),
("default", 1),
("default", 2),
],
names=["population", "node_ids"],
),
)
pdt.assert_frame_equal(tested, expected)

tested = circuit.nodes.get(
group={"population_type": "biophysical"}, properties=["layer"]
)
tested = pd.concat(df for _, df in tested)
expected = pd.DataFrame(
{"layer": np.array([7, 8, 8, 2], dtype=int)},
index=pd.MultiIndex.from_tuples(
[
("default2", 0),
("default2", 1),
("default2", 2),
("default2", 3),
],
names=["population", "node_ids"],
),
)
pdt.assert_frame_equal(tested, expected)

tested = circuit.nodes.get(group={"population_type": "fake"}, properties=["layer"])
tested = dict(tested)
assert tested == {}

def test_functionality_with_separate_node_set(self):
with pytest.raises(BluepySnapError, match="Undefined node set"):
self.test_obj.ids("ExtraLayer2")
Expand Down
32 changes: 18 additions & 14 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,25 @@ def test_positional_mask():

def test_population_mask():
data = pd.DataFrame(range(3))
queries, mask = _circuit_mask(data, "default", {"population": "default", "other": "val"})
queries, mask = _circuit_mask(
data, "default", "biophysical", {"population": "default", "other": "val"}
)
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [True, True, True])

queries, mask = _circuit_mask(data, "default", {"population": "unknown", "other": "val"})
queries, mask = _circuit_mask(
data, "default", "biophysical", {"population": "unknown", "other": "val"}
)
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [False, False, False])

queries, mask = _circuit_mask(
data, "default", {"population": "default", "node_id": [2], "other": "val"}
data, "default", "biophysical", {"population": "default", "node_id": [2], "other": "val"}
)
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [False, False, True])

queries, mask = _circuit_mask(data, "default", {"other": "val"})
queries, mask = _circuit_mask(data, "default", "biophysical", {"other": "val"})
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [True, True, True])

Expand Down Expand Up @@ -105,28 +109,28 @@ def test_resolve_ids():
data = pd.DataFrame(
[[1, 0.4, "seven"], [2, 0.5, "eight"], [3, 0.6, "nine"]], columns=["int", "float", "str"]
)
assert [False, True, False] == resolve_ids(data, "", {"str": "eight"}).tolist()
assert [False, False, False] == resolve_ids(data, "", {"int": 1, "str": "eight"}).tolist()
assert [False, True, False] == resolve_ids(data, "", "", {"str": "eight"}).tolist()
assert [False, False, False] == resolve_ids(data, "", "", {"int": 1, "str": "eight"}).tolist()
assert [True, True, False] == resolve_ids(
data, "", {"$or": [{"str": "seven"}, {"float": (0.5, 0.5)}]}
data, "", "", {"$or": [{"str": "seven"}, {"float": (0.5, 0.5)}]}
).tolist()
assert [False, False, True] == resolve_ids(
data, "", {"$and": [{"str": "nine"}, {"int": 3}]}
data, "", "", {"$and": [{"str": "nine"}, {"int": 3}]}
).tolist()
assert [False, False, True] == resolve_ids(
data, "", {"$and": [{"str": "nine"}, {"$or": [{"int": 1}, {"int": 3}]}]}
data, "", "", {"$and": [{"str": "nine"}, {"$or": [{"int": 1}, {"int": 3}]}]}
).tolist()
assert [False, True, True] == resolve_ids(
data, "", {"$or": [{"float": (0.59, 0.61)}, {"$and": [{"str": "eight"}, {"int": 2}]}]}
data, "", "", {"$or": [{"float": (0.59, 0.61)}, {"$and": [{"str": "eight"}, {"int": 2}]}]}
).tolist()
assert [True, False, True] == resolve_ids(
data, "", {"$or": [{"node_id": 0}, {"edge_id": 2}]}
data, "", "", {"$or": [{"node_id": 0}, {"edge_id": 2}]}
).tolist()
assert [False, False, False] == resolve_ids(data, "", {"$or": []}).tolist()
assert [True, True, True] == resolve_ids(data, "", {"$and": []}).tolist()
assert [False, False, False] == resolve_ids(data, "", "", {"$or": []}).tolist()
assert [True, True, True] == resolve_ids(data, "", "", {"$and": []}).tolist()

with pytest.raises(BluepySnapError) as e:
resolve_ids(data, "", {"str": {"$regex": "*.some", "edge_id": 2}})
resolve_ids(data, "", "", {"str": {"$regex": "*.some", "edge_id": 2}})
assert "Value operators can't be used with plain values" in e.value.args[0]


Expand Down

0 comments on commit da04bcc

Please sign in to comment.