Have reports use Simulation Node Sets (#241)
* Move NodePopulation._resolve_nodesets -> query.resolve_nodesets
* Use simulation node set in SpikeReport
* Use simulation node set in frame_report.py
* make reports' node resolution methods public
* NodePopulation: '_population' -> 'to_libsonata' + use that in query.resolve_nodesets
* add small tests checking that reports use simulation node sets
* remove unnecessary NodeSets._resolve_nodesets method
* ignore simulation node sets overwrite warnings in tests
joni-herttuainen authored Nov 23, 2023
1 parent e97429f commit c094e4f
Showing 11 changed files with 178 additions and 47 deletions.
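The bullet points above describe the API change: the reports' node-resolution methods become public (resolve_nodes) and resolve string and mapping groups against the simulation's node sets. A minimal usage sketch follows; the config path, the population name "default", and the Simulation.spikes access path are assumptions, not part of this diff.

# Hedged sketch of the new public API.
from bluepysnap import Simulation

sim = Simulation("simulation_config.json")  # hypothetical config path
spikes = sim.spikes["default"]              # PopulationSpikeReport

# "only_exists_in_simulation" is defined only in the simulation node sets
# (see tests/data/node_sets_simple.json below); resolution now goes through
# those node sets, so this name is accepted here.
ids = spikes.resolve_nodes("only_exists_in_simulation")
print(spikes.get(group=ids, t_start=0.0, t_stop=10.0))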
23 changes: 18 additions & 5 deletions bluepysnap/frame_report.py
@@ -16,13 +16,15 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
"""Frame report access."""
import logging
from collections.abc import Mapping

import numpy as np
import pandas as pd
from cached_property import cached_property
from libsonata import ElementReportReader, SonataError

import bluepysnap._plotting
from bluepysnap import query
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.utils import ensure_ids

@@ -62,7 +64,7 @@ def name(self):
"""Access to the population name."""
return self._population_name

def _resolve(self, group):
def resolve_nodes(self, group, raise_missing_property=True):
"""Transform a group into ids array.
Notes:
@@ -96,7 +98,7 @@ def get(self, group=None, t_start=None, t_stop=None, t_step=None):
if t_stride < 1:
msg = f"Invalid {t_step=}. It should be None or a multiple of {self.frame_report.dt}."
raise BluepySnapError(msg)
ids = self._resolve(group).tolist()
ids = self.resolve_nodes(group).tolist()
try:
view = self._frame_population.get(
node_ids=ids, tstart=t_start, tstop=t_stop, tstride=t_stride
@@ -164,7 +166,7 @@ def report(self):
dataframes = {}
for population in self.frame_report.population_names:
frames = self.frame_report[population]
ids = frames.nodes.ids(group=self.group, raise_missing_property=False)
ids = frames.resolve_nodes(self.group, raise_missing_property=False)
df = frames.get(group=ids, t_start=self.t_start, t_stop=self.t_stop)
dataframes[population] = df
# optimize when there is at most one non-empty df: use copy=False, and no need to sort
@@ -291,14 +293,25 @@ def filter(self, group=None, t_start=None, t_stop=None):
class PopulationCompartmentReport(PopulationFrameReport):
"""Access to PopulationCompartmentsReport data."""

@property
def _node_sets(self):
"""Access to simulation node sets."""
return self.frame_report.simulation.node_sets

@cached_property
def nodes(self):
"""Returns the NodePopulation corresponding to this report."""
return self.frame_report.simulation.circuit.nodes[self._population_name]

def _resolve(self, group):
def resolve_nodes(self, group, raise_missing_property=True):
"""Transform a group into a node_id array."""
return self.nodes.ids(group=group)
if isinstance(group, str):
group = self._node_sets[group]
elif isinstance(group, Mapping):
group = query.resolve_nodesets(
self._node_sets, self.nodes, group, raise_missing_property
)
return self.nodes.ids(group=group, raise_missing_property=raise_missing_property)


class CompartmentReport(FrameReport):
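The same pattern applies to compartment/frame reports above: a string group is looked up in the simulation node sets, and a Mapping is passed through query.resolve_nodesets before NodePopulation.ids. A sketch under assumptions: the report name, the Simulation.reports access path, and the "$node_set" query key are not shown in this diff.

# Sketch only; the report name and access path are assumptions.
from bluepysnap import Simulation

sim = Simulation("simulation_config.json")      # hypothetical config path
frames = sim.reports["soma_report"]["default"]  # PopulationCompartmentReport

# A Mapping group is routed through query.resolve_nodesets, so a "$node_set"
# entry inside the query now refers to the simulation node sets.
ids = frames.resolve_nodes({"$node_set": "only_exists_in_simulation"})
df = frames.get(group=ids, t_start=0.0, t_stop=1.0)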
30 changes: 7 additions & 23 deletions bluepysnap/nodes/node_population.py
@@ -59,7 +59,6 @@
"""
import inspect
from collections.abc import Mapping, Sequence
from copy import deepcopy

import libsonata
import numpy as np
@@ -181,7 +180,7 @@ def _get_data(self, properties=None, node_ids=None):
cached_columns = properties_set.intersection(result.columns)
if len(cached_columns) < len(properties_set):
# some requested properties miss from the cache
nodes = self._population
nodes = self.to_libsonata
# insert columns at the correct position
for n, (loc, name) in enumerate(
self._iter_selected_properties(existing=result.columns, desired=properties_set)
@@ -196,7 +195,7 @@ def _properties(self):
return self._circuit.to_libsonata.node_population_properties(self.name)

@property
def _population(self):
def to_libsonata(self):
"""Libsonata node population.
Not cached because it would keep the hdf5 file open.
@@ -206,7 +205,7 @@ def _population(self):
@cached_property
def size(self):
"""Node population size."""
return self._population.size
return self.to_libsonata.size

@property
def type(self):
@@ -215,11 +214,11 @@ def type(self):

@cached_property
def _property_names(self):
return set(self._population.attribute_names)
return set(self.to_libsonata.attribute_names)

@cached_property
def _dynamics_params_names(self):
return set(utils.add_dynamic_prefix(self._population.dynamics_attribute_names))
return set(utils.add_dynamic_prefix(self.to_libsonata.dynamics_attribute_names))

@cached_property
def _ordered_property_names(self):
@@ -349,21 +348,6 @@ def _check_properties(self, properties):
if unknown_props:
raise BluepySnapError(f"Unknown node properties: {sorted(unknown_props)}")

def _resolve_nodesets(self, queries, raise_missing_prop):
def _resolve(queries, queries_key):
if queries_key == query.NODE_SET_KEY:
if query.AND_KEY not in queries:
queries[query.AND_KEY] = []
node_set = self._node_sets[queries[queries_key]]
queries[query.AND_KEY].append(
{query.NODE_ID_KEY: node_set.get_ids(self._population, raise_missing_prop)}
)
del queries[queries_key]

resolved_queries = deepcopy(queries)
query.traverse_queries_bottom_up(resolved_queries, _resolve)
return resolved_queries

def _node_ids_by_filter(self, queries, raise_missing_prop):
"""Return node IDs if their properties match the `queries` dict.
@@ -381,7 +365,7 @@ def _node_ids_by_filter(self, queries, raise_missing_prop):
>>> { Node.X: (0, 1), Node.MTYPE: 'L1_SLAC' }]})
"""
queries = self._resolve_nodesets(queries, raise_missing_prop)
queries = query.resolve_nodesets(self._node_sets, self, queries, raise_missing_prop)
properties = query.get_properties(queries)
if raise_missing_prop:
self._check_properties(properties)
@@ -420,7 +404,7 @@ def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
if group is None:
result = np.arange(self.size)
elif isinstance(group, NodeSet):
result = group.get_ids(self._population, raise_missing_property)
result = group.get_ids(self.to_libsonata, raise_missing_property)
elif isinstance(group, Mapping):
result = self._node_ids_by_filter(group, raise_missing_property)
elif isinstance(group, np.ndarray):
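Besides dropping _resolve_nodesets, this file renames the private _population property to the public to_libsonata, matching the Circuit.to_libsonata accessor already used a few lines above. A small sketch of what it exposes; the circuit config path is hypothetical.

# Sketch, assuming a circuit config at this hypothetical path.
from bluepysnap import Circuit

nodes = Circuit("circuit_config.json").nodes["default"]
pop = nodes.to_libsonata            # libsonata.NodePopulation; not cached, so the HDF5 file is not kept open
print(pop.size)                     # same value as nodes.size
print(sorted(pop.attribute_names))  # raw SONATA attribute names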
29 changes: 29 additions & 0 deletions bluepysnap/query.py
@@ -172,6 +172,35 @@ def _collect(_, query_key):
return props


def resolve_nodesets(node_sets, population, queries, raise_missing_prop):
"""Resolve node sets in queries.
Args:
node_sets (bluepysnap.NodeSets): node sets instance
population (bluepysnap.NodePopulation): node population
queries (dict): queries to resolve
raise_missing_prop (bool): raise if property not present in population
Returns:
dict: queries with resolved node sets
"""

def _resolve(queries, queries_key):
if queries_key == NODE_SET_KEY:
if AND_KEY not in queries:
queries[AND_KEY] = []
node_set = node_sets[queries[queries_key]]
queries[AND_KEY].append(
{NODE_ID_KEY: node_set.get_ids(population.to_libsonata, raise_missing_prop)}
)
del queries[queries_key]

resolved_queries = deepcopy(queries)
traverse_queries_bottom_up(resolved_queries, _resolve)

return resolved_queries


def resolve_ids(data, population_name, queries):
"""Returns an index mask of `data` for given `queries`.
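For reference, resolve_nodesets above rewrites each "$node_set" entry into an "$and" clause over explicit node ids, so downstream property filtering no longer needs the node sets. A minimal sketch of the transformation; the circuit path and the Layer23 definition are stand-ins.

# Stand-in inputs; the query rewriting is the point here.
from bluepysnap import Circuit, query
from bluepysnap.node_sets import NodeSets

population = Circuit("circuit_config.json").nodes["default"]  # hypothetical path
node_sets = NodeSets.from_dict({"Layer23": {"layer": [2, 3]}})

queries = {"$node_set": "Layer23", "mtype": "L6_Y"}
resolved = query.resolve_nodesets(node_sets, population, queries, raise_missing_prop=True)
# resolved is roughly:
# {"mtype": "L6_Y", "$and": [{"node_id": <ids of Layer23 in 'default'>}]}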
21 changes: 17 additions & 4 deletions bluepysnap/spike_report.py
@@ -16,6 +16,7 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
"""Spike report access."""

from collections.abc import Mapping
from contextlib import contextmanager
from pathlib import Path

@@ -25,6 +26,7 @@
from libsonata import SonataError, SpikeReader

import bluepysnap._plotting
from bluepysnap import query
from bluepysnap.circuit_ids_types import IDS_DTYPE
from bluepysnap.exceptions import BluepySnapError

@@ -58,6 +60,11 @@ def __init__(self, spike_report, population_name):
self._spike_population = _get_reader(self.spike_report.config)[population_name]
self._population_name = population_name

@property
def _node_sets(self):
"""Access to simulation node sets."""
return self.spike_report.simulation.node_sets

@property
def _sorted_by(self):
"""Access to the sorting attribute.
@@ -83,9 +90,15 @@ def nodes(self):
"""Return the NodePopulation corresponding to this spike report."""
return self.spike_report.simulation.circuit.nodes[self._population_name]

def _resolve_nodes(self, group):
def resolve_nodes(self, group, raise_missing_property=True):
"""Transform a node group into a node_id array."""
return self.nodes.ids(group=group)
if isinstance(group, str):
group = self._node_sets[group]
elif isinstance(group, Mapping):
group = query.resolve_nodesets(
self._node_sets, self.nodes, group, raise_missing_property
)
return self.nodes.ids(group=group, raise_missing_property=raise_missing_property)

def get(self, group=None, t_start=None, t_stop=None):
"""Fetch spikes from the report.
@@ -98,7 +111,7 @@ def get(self, group=None, t_start=None, t_stop=None):
Returns:
pandas.Series: return spiking node_ids indexed by sorted spike time.
"""
node_ids = self._resolve_nodes(group).tolist()
node_ids = self.resolve_nodes(group).tolist()

series_name = "ids"
try:
@@ -161,7 +174,7 @@ def report(self):
dfs = []
for population in self.spike_report.population_names:
spikes = self.spike_report[population]
ids = spikes.nodes.ids(group=self.group, raise_missing_property=False)
ids = spikes.resolve_nodes(self.group, raise_missing_property=False)
data = spikes.get(group=ids, t_start=self.t_start, t_stop=self.t_stop).to_frame()
data["population"] = np.full(len(data), population)

3 changes: 3 additions & 0 deletions tests/data/node_sets_simple.json
@@ -1,5 +1,8 @@
{
"Layer23": {
"layer": [2,3]
},
"only_exists_in_simulation": {
"node_id": [0,2]
}
}
19 changes: 13 additions & 6 deletions tests/test_frame_report.py
@@ -223,7 +223,7 @@ def test_name(self):

def test__resolve(self):
with pytest.raises(NotImplementedError):
self.test_obj._resolve([1])
self.test_obj.resolve_nodes([1])


class TestPopulationCompartmentReport:
@@ -241,9 +241,9 @@ def empty_df(self):
return self.df.iloc[:0, :0]

def test__resolve(self):
npt.assert_array_equal(self.test_obj._resolve({Cell.MTYPE: "L6_Y"}), [1, 2])
assert self.test_obj._resolve({Cell.MTYPE: "L2_X"}) == [0]
npt.assert_array_equal(self.test_obj._resolve("Node12_L6_Y"), [1, 2])
npt.assert_array_equal(self.test_obj.resolve_nodes({Cell.MTYPE: "L6_Y"}), [1, 2])
assert self.test_obj.resolve_nodes({Cell.MTYPE: "L2_X"}) == [0]
npt.assert_array_equal(self.test_obj.resolve_nodes("Node12_L6_Y"), [1, 2])

def test_nodes(self):
assert self.test_obj.nodes.get(group=2, properties=Cell.MTYPE) == "L6_Y"
@@ -328,6 +328,11 @@ def _assert_frame_equal(df1, df2):
ids = CircuitNodeIds.from_arrays(["default", "default", "default2"], [0, 2, 1])
_assert_frame_equal(self.test_obj.get(group=ids, t_step=t_step), self.df.loc[:, [0, 2]])

# test that simulation node_set is used
_assert_frame_equal(
self.test_obj.get("only_exists_in_simulation", t_step=t_step), self.df.loc[:, [0, 2]]
)

with pytest.raises(
BluepySnapError, match="All node IDs must be >= 0 and < 3 for population 'default'"
):
@@ -351,11 +356,13 @@ def test_get_with_invalid_t_step(self, t_step):
self.test_obj.get(t_step=t_step)

def test_get_partially_not_in_report(self):
with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([0, 4])):
with patch.object(
self.test_obj.__class__, "resolve_nodes", return_value=np.asarray([0, 4])
):
pdt.assert_frame_equal(self.test_obj.get([0, 4]), self.df.loc[:, [0]])

def test_get_not_in_report(self):
with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([4])):
with patch.object(self.test_obj.__class__, "resolve_nodes", return_value=np.asarray([4])):
pdt.assert_frame_equal(self.test_obj.get([4]), self.empty_df)

def test_node_ids(self):
7 changes: 5 additions & 2 deletions tests/test_node_population.py
@@ -303,7 +303,10 @@ def test_ids(self):
_call([CircuitNodeId("default", 1), CircuitNodeId("default2", 1), ("default2", 1)])

def test_node_ids_by_filter_complex_query(self):
test_obj = create_node_population(str(TEST_DATA_DIR / "nodes.h5"), "default")
test_obj = create_node_population(
str(TEST_DATA_DIR / "nodes.h5"), "default", node_sets=NodeSets.from_dict({})
)

data = pd.DataFrame(
{
Cell.MTYPE: ["L23_MC", "L4_BP", "L6_BP", "L6_BPC"],
@@ -684,7 +687,7 @@ def test_filter_properties(self):
assert actual_item == expected_item

def test_get_values_from_sonata(self):
nodes = self.test_obj._population
nodes = self.test_obj.to_libsonata

# valid attributes
result = self.test_obj._get_values_from_sonata(nodes, "mtype", [0, 1])