Skip to content

Commit

Permalink
Add include_empty keyword argument
Browse files Browse the repository at this point in the history
When using get-like functions (edges.pathway_edges, nodes.get), it is common to concatenate
the outputs of the functions to mimic the old bluepysnap interface.
However, if no nodes/edges satisfy the query, this approach raises an error instead of producing an empty dataframe,
because the generator yields no dataframes at all, leaving nothing to concatenate.

The include_empty keyword argument lets the user request that a (possibly empty) dataframe still be yielded for populations in which the query matches nothing.

One example where this is useful for me is when I use the length of a connection dataframe to determine connection probability.
  • Loading branch information
HDictus committed Nov 29, 2023
1 parent c094e4f commit 42acc44
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 8 deletions.
19 changes: 15 additions & 4 deletions bluepysnap/edges/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ def ids(self, group=None, sample=None, limit=None):
fun = lambda x: (x.ids(group), x.name)
return self._get_ids_from_pop(fun, CircuitEdgeIds, sample=sample, limit=limit)

def get(self, edge_ids=None, properties=None): # pylint: disable=arguments-renamed
def get(self,
edge_ids=None,
properties=None,
include_empty=False
): # pylint: disable=arguments-renamed
"""Edge properties by iterating populations.
Args:
Expand All @@ -102,7 +106,7 @@ def get(self, edge_ids=None, properties=None): # pylint: disable=arguments-rena
raise BluepySnapError("You need to set edge_ids in get.")
if properties is None:
return edge_ids
return super().get(edge_ids, properties)
return super().get(edge_ids, properties, include_empty=include_empty)

def afferent_nodes(self, target, unique=True):
"""Get afferent CircuitNodeIDs for given target ``node_id``.
Expand Down Expand Up @@ -148,13 +152,20 @@ def efferent_nodes(self, source, unique=True):
result.unique(inplace=True)
return result

def pathway_edges(self, source=None, target=None, properties=None):
def pathway_edges(
self,
source=None,
target=None,
properties=None,
include_empty=False
):
"""Get edges corresponding to ``source`` -> ``target`` connections.
Args:
source: source node group
target: target node group
properties: None / edge property name / list of edge property names
include_empty: whether to include populations for which the query is empty
Returns:
- CircuitEdgeIDs, if ``properties`` is None;
Expand All @@ -172,7 +183,7 @@ def pathway_edges(self, source=None, target=None, properties=None):
)

if properties:
return self.get(result, properties)
return self.get(result, properties, include_empty=include_empty)
return result

def afferent_edges(self, node_id, properties=None):
Expand Down
4 changes: 2 additions & 2 deletions bluepysnap/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def ids(self, group=None, sample=None, limit=None):
"""Resolves the ids of the NetworkObject."""

@abc.abstractmethod
def get(self, group=None, properties=None):
def get(self, group=None, properties=None, include_empty=False):
"""Yields the properties of the NetworkObject."""
ids = self.ids(group)
properties = utils.ensure_list(properties)
Expand All @@ -147,7 +147,7 @@ def get(self, group=None, properties=None):
# since ids is sorted, global_pop_ids should be sorted as well
global_pop_ids = ids.filter_population(name)
pop_ids = global_pop_ids.get_ids()
if len(pop_ids) > 0:
if len(pop_ids) > 0 or include_empty:
pop_properties = properties_set & pop.property_names
# Since the columns are passed as Series, index cannot be specified directly.
# However, it's a bit more performant than converting the Series to numpy arrays.
Expand Down
6 changes: 4 additions & 2 deletions bluepysnap/nodes/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def ids(self, group=None, sample=None, limit=None):
fun = lambda x: (x.ids(group, raise_missing_property=False), x.name)
return self._get_ids_from_pop(fun, CircuitNodeIds, sample=sample, limit=limit)

def get(self, group=None, properties=None): # pylint: disable=arguments-differ
def get(self, group=None, properties=None, include_empty=False): # pylint: disable=arguments-differ
"""Node properties by iterating populations.
Args:
Expand All @@ -130,6 +130,8 @@ def get(self, group=None, properties=None): # pylint: disable=arguments-differ
properties (str/list): If specified, return only the properties in the list.
Otherwise return all properties.
include_empty: whether to include populations for which the query is empty
Returns:
generator: yields tuples of ``(<population_name>, pandas.DataFrame)``:
Expand All @@ -142,7 +144,7 @@ def get(self, group=None, properties=None): # pylint: disable=arguments-differ
if properties is None:
# not strictly needed, but ensure that the properties are always in the same order
properties = sorted(self.property_names)
return super().get(group, properties)
return super().get(group, properties, include_empty=include_empty)

def __getstate__(self):
"""Make Nodes pickle-able, without storing state of caches."""
Expand Down
16 changes: 16 additions & 0 deletions tests/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,22 @@ def test_pathway_edges(self):
check_dtype=False,
)

# check that 'include_empty' kwarg works
tested = self.test_obj.pathway_edges(
source=[1],
properties=properties,
include_empty=True)

expected = pd.DataFrame(
{prop: [] for prop in properties},
index=pd.MultiIndex.from_tuples(
[],
names=["population", "edge_ids"],
),
)
tested = pd.concat([df for _, df in tested])
pdt.assert_frame_equal(tested, expected, check_dtype=False, check_index_type=False)

# use global mapping for nodes
assert self.test_obj.pathway_edges(
source={"mtype": "L6_Y"}, target={"mtype": "L2_X"}
Expand Down
18 changes: 18 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,24 @@ def test_get(self):

with pytest.raises(BluepySnapError, match="Unknown properties required: {'unknown'}"):
next(self.test_obj.get(properties="unknown"))

tested = pd.concat([
df for _, df in self.test_obj.get({'layer': 10}, include_empty=True)
])
expected = pd.DataFrame(
{p: [] for p in list(self.test_obj.property_names)},
index=pd.MultiIndex.from_tuples(
[],
names=["population", "node_ids"]
)
)
pdt.assert_frame_equal(
tested.sort_index(axis=1),
expected.sort_index(axis=1),
check_dtype=False,
check_index_type=False
)


def test_functionality_with_separate_node_set(self):
with pytest.raises(BluepySnapError, match="Undefined node set"):
Expand Down

0 comments on commit 42acc44

Please sign in to comment.