diff --git a/bluepysnap/frame_report.py b/bluepysnap/frame_report.py index 3e4f8f48..c914fbac 100644 --- a/bluepysnap/frame_report.py +++ b/bluepysnap/frame_report.py @@ -16,6 +16,7 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. """Frame report access.""" import logging +from collections.abc import Mapping import numpy as np import pandas as pd @@ -23,6 +24,7 @@ from libsonata import ElementReportReader, SonataError import bluepysnap._plotting +from bluepysnap import query from bluepysnap.exceptions import BluepySnapError from bluepysnap.utils import ensure_ids @@ -62,7 +64,7 @@ def name(self): """Access to the population name.""" return self._population_name - def _resolve(self, group): + def resolve_nodes(self, group, raise_missing_property=True): """Transform a group into ids array. Notes: @@ -96,7 +98,7 @@ def get(self, group=None, t_start=None, t_stop=None, t_step=None): if t_stride < 1: msg = f"Invalid {t_step=}. It should be None or a multiple of {self.frame_report.dt}." raise BluepySnapError(msg) - ids = self._resolve(group).tolist() + ids = self.resolve_nodes(group).tolist() try: view = self._frame_population.get( node_ids=ids, tstart=t_start, tstop=t_stop, tstride=t_stride @@ -164,7 +166,7 @@ def report(self): dataframes = {} for population in self.frame_report.population_names: frames = self.frame_report[population] - ids = frames.nodes.ids(group=self.group, raise_missing_property=False) + ids = frames.resolve_nodes(self.group, raise_missing_property=False) df = frames.get(group=ids, t_start=self.t_start, t_stop=self.t_stop) dataframes[population] = df # optimize when there is at most one non-empty df: use copy=False, and no need to sort @@ -291,14 +293,25 @@ def filter(self, group=None, t_start=None, t_stop=None): class PopulationCompartmentReport(PopulationFrameReport): """Access to PopulationCompartmentsReport data.""" + @property + def _node_sets(self): + """Access to simulation node sets.""" + return self.frame_report.simulation.node_sets + @cached_property def nodes(self): """Returns the NodePopulation corresponding to this report.""" return self.frame_report.simulation.circuit.nodes[self._population_name] - def _resolve(self, group): + def resolve_nodes(self, group, raise_missing_property=True): """Transform a group into a node_id array.""" - return self.nodes.ids(group=group) + if isinstance(group, str): + group = self._node_sets[group] + elif isinstance(group, Mapping): + group = query.resolve_nodesets( + self._node_sets, self.nodes, group, raise_missing_property + ) + return self.nodes.ids(group=group, raise_missing_property=raise_missing_property) class CompartmentReport(FrameReport): diff --git a/bluepysnap/nodes/node_population.py b/bluepysnap/nodes/node_population.py index e9d114c7..ea16598c 100644 --- a/bluepysnap/nodes/node_population.py +++ b/bluepysnap/nodes/node_population.py @@ -59,7 +59,6 @@ """ import inspect from collections.abc import Mapping, Sequence -from copy import deepcopy import libsonata import numpy as np @@ -181,7 +180,7 @@ def _get_data(self, properties=None, node_ids=None): cached_columns = properties_set.intersection(result.columns) if len(cached_columns) < len(properties_set): # some requested properties miss from the cache - nodes = self._population + nodes = self.to_libsonata # insert columns at the correct position for n, (loc, name) in enumerate( self._iter_selected_properties(existing=result.columns, desired=properties_set) @@ -196,7 +195,7 @@ def _properties(self): return self._circuit.to_libsonata.node_population_properties(self.name) @property - def _population(self): + def to_libsonata(self): """Libsonata node population. Not cached because it would keep the hdf5 file open. @@ -206,7 +205,7 @@ def _population(self): @cached_property def size(self): """Node population size.""" - return self._population.size + return self.to_libsonata.size @property def type(self): @@ -215,11 +214,11 @@ def type(self): @cached_property def _property_names(self): - return set(self._population.attribute_names) + return set(self.to_libsonata.attribute_names) @cached_property def _dynamics_params_names(self): - return set(utils.add_dynamic_prefix(self._population.dynamics_attribute_names)) + return set(utils.add_dynamic_prefix(self.to_libsonata.dynamics_attribute_names)) @cached_property def _ordered_property_names(self): @@ -349,21 +348,6 @@ def _check_properties(self, properties): if unknown_props: raise BluepySnapError(f"Unknown node properties: {sorted(unknown_props)}") - def _resolve_nodesets(self, queries, raise_missing_prop): - def _resolve(queries, queries_key): - if queries_key == query.NODE_SET_KEY: - if query.AND_KEY not in queries: - queries[query.AND_KEY] = [] - node_set = self._node_sets[queries[queries_key]] - queries[query.AND_KEY].append( - {query.NODE_ID_KEY: node_set.get_ids(self._population, raise_missing_prop)} - ) - del queries[queries_key] - - resolved_queries = deepcopy(queries) - query.traverse_queries_bottom_up(resolved_queries, _resolve) - return resolved_queries - def _node_ids_by_filter(self, queries, raise_missing_prop): """Return node IDs if their properties match the `queries` dict. @@ -381,7 +365,7 @@ def _node_ids_by_filter(self, queries, raise_missing_prop): >>> { Node.X: (0, 1), Node.MTYPE: 'L1_SLAC' }]}) """ - queries = self._resolve_nodesets(queries, raise_missing_prop) + queries = query.resolve_nodesets(self._node_sets, self, queries, raise_missing_prop) properties = query.get_properties(queries) if raise_missing_prop: self._check_properties(properties) @@ -420,7 +404,7 @@ def ids(self, group=None, limit=None, sample=None, raise_missing_property=True): if group is None: result = np.arange(self.size) elif isinstance(group, NodeSet): - result = group.get_ids(self._population, raise_missing_property) + result = group.get_ids(self.to_libsonata, raise_missing_property) elif isinstance(group, Mapping): result = self._node_ids_by_filter(group, raise_missing_property) elif isinstance(group, np.ndarray): diff --git a/bluepysnap/query.py b/bluepysnap/query.py index ede5824a..079b62c5 100644 --- a/bluepysnap/query.py +++ b/bluepysnap/query.py @@ -172,6 +172,35 @@ def _collect(_, query_key): return props +def resolve_nodesets(node_sets, population, queries, raise_missing_prop): + """Resolve node sets in queries. + + Args: + node_sets (bluepysnap.NodeSets): node sets instance + population (bluepysnap.NodePopulation): node population + queries (dict): queries to resolve + raise_missing_prop (bool): raise if property not present in population + + Returns: + dict: queries with resolved node sets + """ + + def _resolve(queries, queries_key): + if queries_key == NODE_SET_KEY: + if AND_KEY not in queries: + queries[AND_KEY] = [] + node_set = node_sets[queries[queries_key]] + queries[AND_KEY].append( + {NODE_ID_KEY: node_set.get_ids(population.to_libsonata, raise_missing_prop)} + ) + del queries[queries_key] + + resolved_queries = deepcopy(queries) + traverse_queries_bottom_up(resolved_queries, _resolve) + + return resolved_queries + + def resolve_ids(data, population_name, queries): """Returns an index mask of `data` for given `queries`. diff --git a/bluepysnap/spike_report.py b/bluepysnap/spike_report.py index 1d513cf9..23856b54 100644 --- a/bluepysnap/spike_report.py +++ b/bluepysnap/spike_report.py @@ -16,6 +16,7 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. """Spike report access.""" +from collections.abc import Mapping from contextlib import contextmanager from pathlib import Path @@ -25,6 +26,7 @@ from libsonata import SonataError, SpikeReader import bluepysnap._plotting +from bluepysnap import query from bluepysnap.circuit_ids_types import IDS_DTYPE from bluepysnap.exceptions import BluepySnapError @@ -58,6 +60,11 @@ def __init__(self, spike_report, population_name): self._spike_population = _get_reader(self.spike_report.config)[population_name] self._population_name = population_name + @property + def _node_sets(self): + """Access to simulation node sets.""" + return self.spike_report.simulation.node_sets + @property def _sorted_by(self): """Access to the sorting attribute. @@ -83,9 +90,15 @@ def nodes(self): """Return the NodePopulation corresponding to this spike report.""" return self.spike_report.simulation.circuit.nodes[self._population_name] - def _resolve_nodes(self, group): + def resolve_nodes(self, group, raise_missing_property=True): """Transform a node group into a node_id array.""" - return self.nodes.ids(group=group) + if isinstance(group, str): + group = self._node_sets[group] + elif isinstance(group, Mapping): + group = query.resolve_nodesets( + self._node_sets, self.nodes, group, raise_missing_property + ) + return self.nodes.ids(group=group, raise_missing_property=raise_missing_property) def get(self, group=None, t_start=None, t_stop=None): """Fetch spikes from the report. @@ -98,7 +111,7 @@ def get(self, group=None, t_start=None, t_stop=None): Returns: pandas.Series: return spiking node_ids indexed by sorted spike time. """ - node_ids = self._resolve_nodes(group).tolist() + node_ids = self.resolve_nodes(group).tolist() series_name = "ids" try: @@ -161,7 +174,7 @@ def report(self): dfs = [] for population in self.spike_report.population_names: spikes = self.spike_report[population] - ids = spikes.nodes.ids(group=self.group, raise_missing_property=False) + ids = spikes.resolve_nodes(self.group, raise_missing_property=False) data = spikes.get(group=ids, t_start=self.t_start, t_stop=self.t_stop).to_frame() data["population"] = np.full(len(data), population) diff --git a/tests/data/node_sets_simple.json b/tests/data/node_sets_simple.json index 49c44c56..7b2c28c4 100644 --- a/tests/data/node_sets_simple.json +++ b/tests/data/node_sets_simple.json @@ -1,5 +1,8 @@ { "Layer23": { "layer": [2,3] + }, + "only_exists_in_simulation": { + "node_id": [0,2] } } diff --git a/tests/test_frame_report.py b/tests/test_frame_report.py index 6e03e032..c74a4c2d 100644 --- a/tests/test_frame_report.py +++ b/tests/test_frame_report.py @@ -223,7 +223,7 @@ def test_name(self): def test__resolve(self): with pytest.raises(NotImplementedError): - self.test_obj._resolve([1]) + self.test_obj.resolve_nodes([1]) class TestPopulationCompartmentReport: @@ -241,9 +241,9 @@ def empty_df(self): return self.df.iloc[:0, :0] def test__resolve(self): - npt.assert_array_equal(self.test_obj._resolve({Cell.MTYPE: "L6_Y"}), [1, 2]) - assert self.test_obj._resolve({Cell.MTYPE: "L2_X"}) == [0] - npt.assert_array_equal(self.test_obj._resolve("Node12_L6_Y"), [1, 2]) + npt.assert_array_equal(self.test_obj.resolve_nodes({Cell.MTYPE: "L6_Y"}), [1, 2]) + assert self.test_obj.resolve_nodes({Cell.MTYPE: "L2_X"}) == [0] + npt.assert_array_equal(self.test_obj.resolve_nodes("Node12_L6_Y"), [1, 2]) def test_nodes(self): assert self.test_obj.nodes.get(group=2, properties=Cell.MTYPE) == "L6_Y" @@ -328,6 +328,11 @@ def _assert_frame_equal(df1, df2): ids = CircuitNodeIds.from_arrays(["default", "default", "default2"], [0, 2, 1]) _assert_frame_equal(self.test_obj.get(group=ids, t_step=t_step), self.df.loc[:, [0, 2]]) + # test that simulation node_set is used + _assert_frame_equal( + self.test_obj.get("only_exists_in_simulation", t_step=t_step), self.df.loc[:, [0, 2]] + ) + with pytest.raises( BluepySnapError, match="All node IDs must be >= 0 and < 3 for population 'default'" ): @@ -351,11 +356,13 @@ def test_get_with_invalid_t_step(self, t_step): self.test_obj.get(t_step=t_step) def test_get_partially_not_in_report(self): - with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([0, 4])): + with patch.object( + self.test_obj.__class__, "resolve_nodes", return_value=np.asarray([0, 4]) + ): pdt.assert_frame_equal(self.test_obj.get([0, 4]), self.df.loc[:, [0]]) def test_get_not_in_report(self): - with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([4])): + with patch.object(self.test_obj.__class__, "resolve_nodes", return_value=np.asarray([4])): pdt.assert_frame_equal(self.test_obj.get([4]), self.empty_df) def test_node_ids(self): diff --git a/tests/test_node_population.py b/tests/test_node_population.py index 7fafa8eb..2b905d91 100644 --- a/tests/test_node_population.py +++ b/tests/test_node_population.py @@ -303,7 +303,10 @@ def test_ids(self): _call([CircuitNodeId("default", 1), CircuitNodeId("default2", 1), ("default2", 1)]) def test_node_ids_by_filter_complex_query(self): - test_obj = create_node_population(str(TEST_DATA_DIR / "nodes.h5"), "default") + test_obj = create_node_population( + str(TEST_DATA_DIR / "nodes.h5"), "default", node_sets=NodeSets.from_dict({}) + ) + data = pd.DataFrame( { Cell.MTYPE: ["L23_MC", "L4_BP", "L6_BP", "L6_BPC"], @@ -684,7 +687,7 @@ def test_filter_properties(self): assert actual_item == expected_item def test_get_values_from_sonata(self): - nodes = self.test_obj._population + nodes = self.test_obj.to_libsonata # valid attributes result = self.test_obj._get_values_from_sonata(nodes, "mtype", [0, 1]) diff --git a/tests/test_query.py b/tests/test_query.py index 0f2f1dde..768b5ad1 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -3,8 +3,21 @@ import pandas as pd import pytest -from bluepysnap import BluepySnapError -from bluepysnap.query import _circuit_mask, _logical_and, _logical_or, _positional_mask, resolve_ids +from bluepysnap import BluepySnapError, Circuit +from bluepysnap.node_sets import NodeSets +from bluepysnap.query import ( + AND_KEY, + NODE_ID_KEY, + NODE_SET_KEY, + _circuit_mask, + _logical_and, + _logical_or, + _positional_mask, + resolve_ids, + resolve_nodesets, +) + +from utils import TEST_DATA_DIR def test_positional_mask(): @@ -34,6 +47,60 @@ def test_population_mask(): npt.assert_array_equal(mask, [True, True, True]) +@pytest.mark.parametrize( + "query,expected", + [ + ( + {NODE_SET_KEY: "Node12_L6_Y", NODE_ID_KEY: 1}, + {NODE_ID_KEY: 1, AND_KEY: [{NODE_ID_KEY: [1, 2]}]}, + ), + ( + {NODE_SET_KEY: "Layer23", "any_query": "any_value"}, + {"any_query": "any_value", AND_KEY: [{NODE_ID_KEY: [0]}]}, + ), + ( + {NODE_SET_KEY: "Empty_nodes"}, + {AND_KEY: [{NODE_ID_KEY: []}]}, + ), + ( + {"any_query": "any_value"}, + {"any_query": "any_value"}, + ), + ( + {AND_KEY: [{NODE_SET_KEY: "Node12_L6_Y"}, {NODE_SET_KEY: "Layer23"}]}, + {AND_KEY: [{AND_KEY: [{NODE_ID_KEY: [1, 2]}]}, {AND_KEY: [{NODE_ID_KEY: [0]}]}]}, + ), + ], +) +def test_resolve_nodesets_success(query, expected): + circuit = Circuit(str(TEST_DATA_DIR / "circuit_config.json")) + population, node_sets = circuit.nodes["default"], circuit.node_sets + + res = resolve_nodesets(node_sets, population, query, raise_missing_prop=True) + npt.assert_equal(res, expected) + + +def test_resolve_nodesets_exceptions(): + population = Circuit(str(TEST_DATA_DIR / "circuit_config.json")).nodes["default"] + node_sets = NodeSets.from_dict({"fake_set": {"missing_property": "fake"}}) + + query = {NODE_SET_KEY: "fake_set"} + res = resolve_nodesets(node_sets, population, query, raise_missing_prop=False) + expected = {AND_KEY: [{NODE_ID_KEY: []}]} + npt.assert_equal(res, expected) + + with pytest.raises(BluepySnapError, match="No such attribute: 'missing_property'"): + resolve_nodesets(node_sets, population, query, raise_missing_prop=True) + + query = {NODE_SET_KEY: "missing_set"} + with pytest.raises(BluepySnapError, match="Undefined node set: 'missing_set'"): + resolve_nodesets(node_sets, population, query, raise_missing_prop=True) + + query = {NODE_SET_KEY: ["Node12_L6_Y", "Layer23"]} + with pytest.raises(BluepySnapError, match=r"Unexpected type: 'list' \(expected: 'str'\)"): + resolve_nodesets(node_sets, population, query, raise_missing_prop=True) + + def test_resolve_ids(): data = pd.DataFrame( [[1, 0.4, "seven"], [2, 0.5, "eight"], [3, 0.6, "nine"]], columns=["int", "float", "str"] diff --git a/tests/test_simulation.py b/tests/test_simulation.py index abb62ad7..3410110f 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -61,6 +61,7 @@ def test_all(): expected_content = { **json.loads((TEST_DATA_DIR / "node_sets.json").read_text()), "Layer23": {"layer": [2, 3]}, + "only_exists_in_simulation": {"node_id": [0, 2]}, } assert simulation.node_sets.content == expected_content diff --git a/tests/test_spike_report.py b/tests/test_spike_report.py index 93cc5e56..c97c5f09 100644 --- a/tests/test_spike_report.py +++ b/tests/test_spike_report.py @@ -153,9 +153,9 @@ def test_nodes_invalid_population(self): test_obj.nodes def test__resolve_nodes(self): - npt.assert_array_equal(self.test_obj._resolve_nodes({Cell.MTYPE: "L6_Y"}), [1, 2]) - assert self.test_obj._resolve_nodes({Cell.MTYPE: "L2_X"}) == [0] - npt.assert_array_equal(self.test_obj._resolve_nodes("Node12_L6_Y"), [1, 2]) + npt.assert_array_equal(self.test_obj.resolve_nodes({Cell.MTYPE: "L6_Y"}), [1, 2]) + assert self.test_obj.resolve_nodes({Cell.MTYPE: "L2_X"}) == [0] + npt.assert_array_equal(self.test_obj.resolve_nodes("Node12_L6_Y"), [1, 2]) def test_get(self): tested = self.test_obj.get() @@ -226,6 +226,12 @@ def test_get(self): self.test_obj.get(group="Layer23"), _create_series([0, 0], [0.2, 1.3]) ) + # test that simulation node_set is used + pdt.assert_series_equal( + self.test_obj.get("only_exists_in_simulation"), + _create_series([2, 0, 2, 0], [0.1, 0.2, 0.7, 1.3]), + ) + # no 0.1, 0.7 from ("default2", 2) ids = CircuitNodeIds.from_arrays(["default", "default", "default2"], [0, 1, 2]) npt.assert_array_equal(self.test_obj.get(ids), _create_series([0, 1, 0], [0.2, 0.3, 1.3])) @@ -254,13 +260,13 @@ def test_get2(self): ) @patch( - test_module.__name__ + ".PopulationSpikeReport._resolve_nodes", return_value=np.asarray([4]) + test_module.__name__ + ".PopulationSpikeReport.resolve_nodes", return_value=np.asarray([4]) ) def test_get_not_in_report(self, mock): pdt.assert_series_equal(self.test_obj.get(4), _create_series([], [])) @patch( - test_module.__name__ + ".PopulationSpikeReport._resolve_nodes", + test_module.__name__ + ".PopulationSpikeReport.resolve_nodes", return_value=np.asarray([0, 4]), ) def test_get_not_in_report(self, mock): diff --git a/tox.ini b/tox.ini index 9fb92ff4..1aa2a302 100644 --- a/tox.ini +++ b/tox.ini @@ -73,3 +73,8 @@ python = 3.9: py39 3.10: py310 3.11: py311, lint, docs + +[pytest] +filterwarnings = + # ignoring the warning about Simulation node sets overwriting Circuit node sets in tests + ignore:Simulation node sets overwrite:RuntimeWarning:bluepysnap