Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide divergence and convergence stats for EdgePopulation #242

Merged
merged 4 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Improvements

- Both now return ``self.ids(query)`` if ``properties=None``
- ``properties`` is now a keyword argument in ``EdgePopulation.get``
- Added ``EdgePopulation.stats`` with two methods: ``divergence``, ``convergence``


Version v3.0.1
Expand Down
6 changes: 6 additions & 0 deletions bluepysnap/edges/edge_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from bluepysnap import query, utils
from bluepysnap.circuit_ids import CircuitEdgeIds, CircuitNodeId
from bluepysnap.circuit_ids_types import IDS_DTYPE, CircuitEdgeId
from bluepysnap.edges.edge_population_stats import StatsHelper
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.sonata_constants import DYNAMICS_PREFIX, ConstContainer, Edge

Expand Down Expand Up @@ -147,6 +148,11 @@ def property_dtypes(self):
"""
return self.get([0], list(self.property_names)).dtypes.sort_index()

@cached_property
def stats(self):
"""Access edge population stats methods."""
return StatsHelper(self)

def container_property_names(self, container):
"""Lists the ConstContainer properties shared with the EdgePopulation.

Expand Down
91 changes: 91 additions & 0 deletions bluepysnap/edges/edge_population_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""EdgePopulation stats helper."""

import numpy as np

from bluepysnap.exceptions import BluepySnapError


class StatsHelper:
"""EdgePopulation stats helper."""

def __init__(self, edge_population):
"""Initialize StatsHelper with an EdgePopulation instance."""
self._edge_population = edge_population

def divergence(self, source, target, by, sample=None):
"""`source` -> `target` divergence.

Calculate the divergence based on number of `"connections"` or `"synapses"` each `source`
cell shares with the cells specified in `target`.

* `connections`: number of unique target cells each source cell shares a connection with
* `synapses`: number of unique synapses between a source cell and its target cells

Args:
source (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): source nodes
target (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): target nodes
by (str): 'synapses' or 'connections'
joni-herttuainen marked this conversation as resolved.
Show resolved Hide resolved
sample (int): if specified, sample size for source group

Returns:
Array with synapse / connection count per each cell from `source` sample
(taking into account only connections to cells in `target`).
"""
by_alternatives = {"synapses", "connections"}
if by not in by_alternatives:
raise BluepySnapError(f"`by` should be one of {by_alternatives}; got: {by}")

source_sample = self._edge_population.source.ids(source, sample=sample)

result = {id_: 0 for id_ in source_sample}
if by == "synapses":
connections = self._edge_population.iter_connections(
source_sample, target, return_synapse_count=True
)
for pre_gid, _, synapse_count in connections:
result[pre_gid] += synapse_count
else:
connections = self._edge_population.iter_connections(source_sample, target)
for pre_gid, _ in connections:
result[pre_gid] += 1

return np.array(list(result.values()))

def convergence(self, source, target, by=None, sample=None):
"""`source` -> `target` convergence.

Calculate the convergence based on number of `"connections"` or `"synapses"` each `target`
cell shares with the cells specified in `source`.

* `connections`: number of unique source cells each target cell shares a connection with
* `synapses`: number of unique synapses between a target cell and its source cells

Args:
source (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): source nodes
target (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): target nodes
by (str): 'synapses' or 'connections'
sample (int): if specified, sample size for target group

Returns:
Array with synapse / connection count per each cell from `target` sample
(taking into account only connections from cells in `source`).
"""
by_alternatives = {"synapses", "connections"}
if by not in by_alternatives:
raise BluepySnapError(f"`by` should be one of {by_alternatives}; got: {by}")

target_sample = self._edge_population.target.ids(target, sample=sample)

result = {id_: 0 for id_ in target_sample}
if by == "synapses":
connections = self._edge_population.iter_connections(
source, target_sample, return_synapse_count=True
)
for _, post_gid, synapse_count in connections:
result[post_gid] += synapse_count
else:
connections = self._edge_population.iter_connections(source, target_sample)
for _, post_gid in connections:
result[post_gid] += 1

return np.array(list(result.values()))
2 changes: 2 additions & 0 deletions tests/test_edge_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from bluepysnap.circuit import Circuit
from bluepysnap.circuit_ids import CircuitEdgeIds, CircuitNodeIds
from bluepysnap.circuit_ids_types import IDS_DTYPE, CircuitEdgeId, CircuitNodeId
from bluepysnap.edges.edge_population_stats import StatsHelper
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.sonata_constants import DEFAULT_EDGE_TYPE, Edge

Expand Down Expand Up @@ -41,6 +42,7 @@ def test_basic(self):
assert self.test_obj.source.name == "default"
assert self.test_obj.target.name == "default"
assert self.test_obj.size, 4
assert isinstance(self.test_obj.stats, StatsHelper)
assert sorted(self.test_obj.property_names) == sorted(
[
Synapse.SOURCE_NODE_ID,
Expand Down
43 changes: 43 additions & 0 deletions tests/test_edge_population_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from unittest.mock import Mock

import numpy.testing as npt
import pytest

import bluepysnap.edges.edge_population_stats as test_module
from bluepysnap.exceptions import BluepySnapError


class TestStatsHelper:
def setup_method(self):
self.edge_pop = Mock()
self.stats = test_module.StatsHelper(self.edge_pop)

def test_divergence_by_synapses(self):
self.edge_pop.source.ids.return_value = [1, 2]
self.edge_pop.iter_connections.return_value = [(1, None, 42), (1, None, 43)]
actual = self.stats.divergence("pre", "post", by="synapses")
npt.assert_equal(actual, [85, 0])

def test_divergence_by_connections(self):
self.edge_pop.source.ids.return_value = [1, 2]
self.edge_pop.iter_connections.return_value = [(1, None), (1, None)]
actual = self.stats.divergence("pre", "post", by="connections")
npt.assert_equal(actual, [2, 0])

def test_divergence_error(self):
pytest.raises(BluepySnapError, self.stats.divergence, "pre", "post", by="err")

def test_convergence_by_synapses(self):
self.edge_pop.target.ids.return_value = [1, 2]
self.edge_pop.iter_connections.return_value = [(None, 2, 42), (None, 2, 43)]
actual = self.stats.convergence("pre", "post", by="synapses")
npt.assert_equal(actual, [0, 85])

def test_convergence_by_connections(self):
self.edge_pop.target.ids.return_value = [1, 2]
self.edge_pop.iter_connections.return_value = [(None, 2), (None, 2)]
actual = self.stats.convergence("pre", "post", by="connections")
npt.assert_equal(actual, [0, 2])

def test_convergence_error(self):
pytest.raises(BluepySnapError, self.stats.convergence, "pre", "post", by="err")
Loading