Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

248 inconsistency between edgepopulationiter connections and edgesiter connections #249

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Changelog
=========

Version v2.2.0~dev0
edasubert marked this conversation as resolved.
Show resolved Hide resolved
--------------

Breaking Changes
~~~~~~~~~~~~~~~~
- Edge populations' ``iter_connections`` returns ``CircuitNodeId``s instead of ``int``s
edasubert marked this conversation as resolved.
Show resolved Hide resolved


Version v2.1.0
--------------

Expand Down
20 changes: 17 additions & 3 deletions bluepysnap/edges/edge_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from more_itertools import first

from bluepysnap import query, utils
from bluepysnap.circuit_ids import CircuitEdgeIds
from bluepysnap.circuit_ids import CircuitEdgeIds, CircuitNodeId
from bluepysnap.circuit_ids_types import IDS_DTYPE, CircuitEdgeId
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.sonata_constants import DYNAMICS_PREFIX, ConstContainer, Edge
Expand Down Expand Up @@ -518,6 +518,14 @@ def _optimal_direction():
secondary_node_ids_used.add(conn_node_id)
break

def _complete_circuit_node_ids(self, connections):
for connection in connections:
yield (
CircuitNodeId(population=self.source.name, id=connection[0]),
CircuitNodeId(population=self.target.name, id=connection[1]),
connection[2],
)

def iter_connections(
self,
source=None,
Expand Down Expand Up @@ -554,12 +562,18 @@ def iter_connections(
source_node_ids = self._resolve_node_ids(self.source, source)
target_node_ids = self._resolve_node_ids(self.target, target)

it = self._iter_connections(source_node_ids, target_node_ids, unique_node_ids, shuffle)
it = self._complete_circuit_node_ids(
self._iter_connections(source_node_ids, target_node_ids, unique_node_ids, shuffle)
)

if return_edge_count:
return it
elif return_edge_ids:
add_edge_ids = lambda x: (x[0], x[1], self.pair_edges(x[0], x[1]))
add_edge_ids = lambda x: (
x[0],
x[1],
CircuitEdgeIds.from_dict({self.name: self.pair_edges(x[0], x[1])}),
)
return map(add_edge_ids, it)
else:
omit_edge_count = lambda x: x[:2]
Expand Down
46 changes: 2 additions & 44 deletions bluepysnap/edges/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from bluepysnap._doctools import AbstractDocSubstitutionMeta
from bluepysnap.circuit_ids import CircuitEdgeIds, CircuitNodeIds
from bluepysnap.circuit_ids_types import CircuitNodeId
from bluepysnap.edges.edge_population import EdgePopulation
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.network import NetworkObject
Expand Down Expand Up @@ -221,39 +220,6 @@ def pair_edges(self, source_node_id, target_node_id, properties=None):
source=source_node_id, target=target_node_id, properties=properties
)

@staticmethod
def _add_circuit_ids(its, source, target):
"""Generator comprehension adding the CircuitIds to the iterator.

Notes:
Using closures or lambda functions would result in override functions and so the
source and target would be the same for all the populations.
"""
return (
(CircuitNodeId(source, source_id), CircuitNodeId(target, target_id), count)
for source_id, target_id, count in its
)

@staticmethod
def _add_edge_ids(its, source, target, pop_name):
"""Generator comprehension adding the CircuitIds to the iterator."""
return (
(
CircuitNodeId(source, source_id),
CircuitNodeId(target, target_id),
CircuitEdgeIds.from_dict({pop_name: edge_id}),
)
for source_id, target_id, edge_id in its
)

@staticmethod
def _omit_edge_count(its, source, target):
"""Generator comprehension adding the CircuitIds to the iterator."""
return (
(CircuitNodeId(source, source_id), CircuitNodeId(target, target_id))
for source_id, target_id in its
)

joni-herttuainen marked this conversation as resolved.
Show resolved Hide resolved
def iter_connections(
self, source=None, target=None, return_edge_ids=False, return_edge_count=False
):
Expand All @@ -276,21 +242,13 @@ def iter_connections(
raise BluepySnapError(
"`return_edge_count` and `return_edge_ids` are mutually exclusive"
)
for name, pop in self.items():
it = pop.iter_connections(
for pop in self.values():
yield from pop.iter_connections(
source=source,
target=target,
return_edge_ids=return_edge_ids,
return_edge_count=return_edge_count,
)
source_pop = pop.source.name
target_pop = pop.target.name
if return_edge_count:
yield from self._add_circuit_ids(it, source_pop, target_pop)
elif return_edge_ids:
yield from self._add_edge_ids(it, source_pop, target_pop, name)
else:
yield from self._omit_edge_count(it, source_pop, target_pop)
joni-herttuainen marked this conversation as resolved.
Show resolved Hide resolved

def __getstate__(self):
"""Make Edges pickle-able, without storing state of caches."""
Expand Down
153 changes: 137 additions & 16 deletions tests/test_edge_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from bluepysnap.bbp import Synapse
from bluepysnap.circuit import Circuit
from bluepysnap.circuit_ids import CircuitEdgeIds, CircuitNodeIds
from bluepysnap.circuit_ids_types import IDS_DTYPE, CircuitEdgeId
from bluepysnap.circuit_ids_types import IDS_DTYPE, CircuitEdgeId, CircuitNodeId
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.sonata_constants import DEFAULT_EDGE_TYPE, Edge

Expand Down Expand Up @@ -519,18 +519,29 @@ def test_pathway_edges_6(self):

def test_iter_connections_1(self):
it = self.test_obj.iter_connections([0, 2], [1])
assert next(it) == (0, 1)
assert next(it) == (2, 1)
assert next(it) == (
CircuitNodeId(population="default", id=0),
CircuitNodeId(population="default", id=1),
)
assert next(it) == (
CircuitNodeId(population="default", id=2),
CircuitNodeId(population="default", id=1),
)
with pytest.raises(StopIteration):
next(it)

def test_iter_connections_2(self):
it = self.test_obj.iter_connections([0, 2], [1], unique_node_ids=True)
assert list(it) == [(0, 1)]
assert list(it) == [
(CircuitNodeId(population="default", id=0), CircuitNodeId(population="default", id=1)),
]

def test_iter_connections_3(self):
it = self.test_obj.iter_connections([0, 2], [1], shuffle=True)
assert sorted(it) == [(0, 1), (2, 1)]
assert sorted(it) == [
(CircuitNodeId(population="default", id=0), CircuitNodeId(population="default", id=1)),
(CircuitNodeId(population="default", id=2), CircuitNodeId(population="default", id=1)),
]

def test_iter_connections_4(self):
it = self.test_obj.iter_connections(None, None)
Expand All @@ -539,23 +550,54 @@ def test_iter_connections_4(self):

def test_iter_connections_5(self):
it = self.test_obj.iter_connections(None, [1])
assert list(it) == [(0, 1), (2, 1)]
assert list(it) == [
(CircuitNodeId(population="default", id=0), CircuitNodeId(population="default", id=1)),
(CircuitNodeId(population="default", id=2), CircuitNodeId(population="default", id=1)),
]

def test_iter_connections_6(self):
it = self.test_obj.iter_connections([2], None)
assert list(it) == [(2, 0), (2, 1)]
assert list(it) == [
(CircuitNodeId(population="default", id=2), CircuitNodeId(population="default", id=0)),
(CircuitNodeId(population="default", id=2), CircuitNodeId(population="default", id=1)),
]

def test_iter_connections_7(self):
it = self.test_obj.iter_connections([], [0, 1, 2])
assert list(it) == []

def test_iter_connections_8(self):
it = self.test_obj.iter_connections([0, 2], [1], return_edge_ids=True)
npt.assert_equal(list(it), [(0, 1, [1, 2]), (2, 1, [3])])
npt.assert_equal(
list(it),
[
(
CircuitNodeId(population="default", id=0),
CircuitNodeId(population="default", id=1),
CircuitEdgeIds.from_dict({"default": [1, 2]}),
),
(
CircuitNodeId(population="default", id=2),
CircuitNodeId(population="default", id=1),
CircuitEdgeIds.from_dict({"default": [3]}),
),
],
)

def test_iter_connections_9(self):
it = self.test_obj.iter_connections([0, 2], [1], return_edge_count=True)
assert list(it) == [(0, 1, 2), (2, 1, 1)]
assert list(it) == [
(
CircuitNodeId(population="default", id=0),
CircuitNodeId(population="default", id=1),
2,
),
(
CircuitNodeId(population="default", id=2),
CircuitNodeId(population="default", id=1),
1,
),
]

def test_iter_connections_10(self):
with pytest.raises(BluepySnapError):
Expand All @@ -577,25 +619,104 @@ def test_iter_connection_unique(self):
test_obj = TestEdgePopulation.get_edge_population(config_path, "default")

it = test_obj.iter_connections([0, 1, 2], [0, 1, 2])
assert sorted(it) == [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
assert sorted(it) == [
(
CircuitNodeId(population="default", id=0),
CircuitNodeId(population="default", id=1),
),
(
CircuitNodeId(population="default", id=0),
CircuitNodeId(population="default", id=2),
),
(
CircuitNodeId(population="default", id=1),
CircuitNodeId(population="default", id=0),
),
(
CircuitNodeId(population="default", id=1),
CircuitNodeId(population="default", id=2),
),
(
CircuitNodeId(population="default", id=2),
CircuitNodeId(population="default", id=0),
),
(
CircuitNodeId(population="default", id=2),
CircuitNodeId(population="default", id=1),
),
]

it = test_obj.iter_connections([0, 1, 2], [0, 1, 2], unique_node_ids=True)
assert sorted(it) == [(0, 1), (1, 0)]
assert sorted(it) == [
(
CircuitNodeId(population="default", id=0),
CircuitNodeId(population="default", id=1),
),
(
CircuitNodeId(population="default", id=1),
CircuitNodeId(population="default", id=0),
),
]

it = test_obj.iter_connections([0, 1, 2], [0, 2], unique_node_ids=True)
assert sorted(it) == [(0, 2), (1, 0)]
assert sorted(it) == [
(
CircuitNodeId(population="default", id=0),
CircuitNodeId(population="default", id=2),
),
(
CircuitNodeId(population="default", id=1),
CircuitNodeId(population="default", id=0),
),
]

it = test_obj.iter_connections([0, 2], [0, 2], unique_node_ids=True)
assert sorted(it) == [(0, 2), (2, 0)]
assert sorted(it) == [
(
CircuitNodeId(population="default", id=0),
CircuitNodeId(population="default", id=2),
),
(
CircuitNodeId(population="default", id=2),
CircuitNodeId(population="default", id=0),
),
]

it = test_obj.iter_connections([0, 1, 2], [0, 2, 1], unique_node_ids=True)
assert sorted(it) == [(0, 1), (1, 0)]
assert sorted(it) == [
(
CircuitNodeId(population="default", id=0),
CircuitNodeId(population="default", id=1),
),
(
CircuitNodeId(population="default", id=1),
CircuitNodeId(population="default", id=0),
),
]

it = test_obj.iter_connections([1, 2], [0, 1, 2], unique_node_ids=True)
assert sorted(it) == [(1, 0), (2, 1)]
assert sorted(it) == [
(
CircuitNodeId(population="default", id=1),
CircuitNodeId(population="default", id=0),
),
(
CircuitNodeId(population="default", id=2),
CircuitNodeId(population="default", id=1),
),
]

it = test_obj.iter_connections([0, 1, 2], [1, 2], unique_node_ids=True)
assert sorted(it) == [(0, 1), (1, 2)]
assert sorted(it) == [
(
CircuitNodeId(population="default", id=0),
CircuitNodeId(population="default", id=1),
),
(
CircuitNodeId(population="default", id=1),
CircuitNodeId(population="default", id=2),
),
]

def test_h5_filepath_from_config(self):
assert self.test_obj.h5_filepath == str(TEST_DATA_DIR / "edges.h5")
Expand Down