Skip to content

Commit

Permalink
Add query key population_type (#229)
Browse files Browse the repository at this point in the history
* Added possibility to query Edge IDs and Node IDs based on edge/node population type

* minor fixes

* add population_type in docstring

* return False in `_positional_mask` if no ids

* update CHANGELOG

* add an example in group concept

* add population_type query examples in the notebooks

* fix test_query.py
  • Loading branch information
joni-herttuainen authored Apr 12, 2024
1 parent fb301e2 commit dc2793c
Show file tree
Hide file tree
Showing 9 changed files with 444 additions and 37 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ Improvements
- ``properties`` is now a keyword argument in ``EdgePopulation.get``
- Added ``EdgePopulation.stats`` with two methods: ``divergence``, ``convergence``
- Added new notebooks covering node sets as well as node and edge queries
- Added the possibility to query Edge IDs and Node IDs based on edge/node population type using query key ``population_type``

- The types conform to the `node types <https://sonata-extension.readthedocs.io/en/latest/sonata_config.html#populations>`_ and `edge types <https://sonata-extension.readthedocs.io/en/latest/sonata_config.html#id4>`_ defined in the SONATA specification


Version v3.0.1
Expand Down
2 changes: 1 addition & 1 deletion bluepysnap/edges/edge_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _edge_ids_by_filter(self, queries, raise_missing_prop):
chunk_size = int(1e8)
for chunk in np.array_split(ids, 1 + len(ids) // chunk_size):
data = self.get(chunk, properties - unknown_props)
res.extend(chunk[query.resolve_ids(data, self.name, queries)])
res.extend(chunk[query.resolve_ids(data, self.name, self.type, queries)])
return np.array(res, dtype=IDS_DTYPE)

def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
Expand Down
4 changes: 3 additions & 1 deletion bluepysnap/nodes/node_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
>>> nodes.ids(group={ Node.LAYER: 2}) # returns list of IDs matching layer==2
>>> nodes.ids(group={ Node.LAYER: [2, 3]}) # returns list of IDs with layer in [2,3]
>>> nodes.ids(group={ Node.X: (0, 1)}) # returns list of IDs with 0 < x < 1
>>> # returns list of IDs of biophysical node populations
>>> nodes.ids(group={ "population_type": "biophysical"})
>>> # returns list of IDs matching one of the queries inside the 'or' list
>>> nodes.ids(group={'$or': [{ Node.LAYER: [2, 3]},
>>> { Node.X: (0, 1), Node.MTYPE: 'L1_SLAC' }]})
Expand Down Expand Up @@ -371,7 +373,7 @@ def _node_ids_by_filter(self, queries, raise_missing_prop):
self._check_properties(properties)
# load all the properties needed to execute the query, excluding the unknown properties
data = self._get_data(properties & self.property_names)
idx = query.resolve_ids(data, self.name, queries)
idx = query.resolve_ids(data, self.name, self.type, queries)
return idx.nonzero()[0]

def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
Expand Down
29 changes: 22 additions & 7 deletions bluepysnap/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,21 @@
NODE_ID_KEY = "node_id"
EDGE_ID_KEY = "edge_id"
POPULATION_KEY = "population"
POPULATION_TYPE_KEY = "population_type"
OR_KEY = "$or"
AND_KEY = "$and"
REGEX_KEY = "$regex"
NODE_SET_KEY = "$node_set"
VALUE_KEYS = {REGEX_KEY}
ALL_KEYS = {NODE_ID_KEY, EDGE_ID_KEY, POPULATION_KEY, OR_KEY, AND_KEY, NODE_SET_KEY} | VALUE_KEYS
ALL_KEYS = {
NODE_ID_KEY,
EDGE_ID_KEY,
POPULATION_KEY,
POPULATION_TYPE_KEY,
OR_KEY,
AND_KEY,
NODE_SET_KEY,
} | VALUE_KEYS


def _logical_and(masks):
Expand Down Expand Up @@ -92,29 +101,34 @@ def _positional_mask(data, ids):
return True
if isinstance(ids, int):
ids = [ids]
elif len(ids) == 0:
return False
mask = np.full(len(data), fill_value=False)
indices = data.index.get_indexer(ids)
mask[indices[indices > -1]] = True
return mask


def _circuit_mask(data, population_name, queries):
def _circuit_mask(data, population_name, population_type, queries):
"""Handle the population, node ID queries."""
populations = queries.pop(POPULATION_KEY, None)
if populations is not None and population_name not in set(utils.ensure_list(populations)):
types = queries.pop(POPULATION_TYPE_KEY, None)
if populations is not None and population_name not in utils.ensure_list(populations):
ids = []
elif types is not None and population_type not in utils.ensure_list(types):
ids = []
else:
ids = queries.pop(NODE_ID_KEY, queries.pop(EDGE_ID_KEY, None))
return queries, _positional_mask(data, ids)


def _properties_mask(data, population_name, queries):
def _properties_mask(data, population_name, population_type, queries):
"""Return mask of IDs matching `props` dict."""
unknown_props = set(queries) - set(data.columns) - ALL_KEYS
if unknown_props:
return False

queries, mask = _circuit_mask(data, population_name, queries)
queries, mask = _circuit_mask(data, population_name, population_type, queries)
if mask is False or isinstance(mask, np.ndarray) and not mask.any():
# Avoid failures and/or wasted processing time if the population name or type does not match, or if no IDs are selected
return False
Expand Down Expand Up @@ -202,12 +216,13 @@ def _resolve(queries, queries_key):
return resolved_queries


def resolve_ids(data, population_name, queries):
def resolve_ids(data, population_name, population_type, queries):
"""Returns an index mask of `data` for given `queries`.
Args:
data (pd.DataFrame): data
population_name (str): population name of `data`
population_type (str): population type of `data`
queries (dict): queries
Returns:
Expand All @@ -229,7 +244,7 @@ def _collect(queries, queries_key):
queries[queries_key] = _logical_and(children_mask)
else:
queries[queries_key] = _properties_mask(
data, population_name, {queries_key: queries[queries_key]}
data, population_name, population_type, {queries_key: queries[queries_key]}
)

queries = deepcopy(queries)
Expand Down
159 changes: 150 additions & 9 deletions doc/source/notebooks/09_node_queries.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"On top of these, a query can also be based on `node_id` or the `population_type`.\n",
"\n",
"When the query is a `dict` and there is a `list` in the query, it is (usually) considered as an \"OR\" condition, and the keys of the query are considered as an \"AND\" condition. E.g.,\n",
"```python\n",
"circuit.nodes.ids({ # give me the ids of nodes that\n",
Expand Down Expand Up @@ -398,6 +400,145 @@
"pd.concat([df for _,df in result])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Querying with population type\n",
"We can query nodes (or edges) based on their population type as specified in the [SONATA circuit configuration file](https://sonata-extension.readthedocs.io/en/latest/sonata_config.html).\n",
"\n",
"Let's find all the source nodes of projections (i.e., `virtual` nodes):"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>model_template</th>\n",
" <th>model_type</th>\n",
" </tr>\n",
" <tr>\n",
" <th>population</th>\n",
" <th>node_ids</th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">CorticoThalamic_projections</th>\n",
" <th>0</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">MedialLemniscus_projections</th>\n",
" <th>5018</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5019</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5020</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5021</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5022</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>88443 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" model_template model_type\n",
"population node_ids \n",
"CorticoThalamic_projections 0 virtual\n",
" 1 virtual\n",
" 2 virtual\n",
" 3 virtual\n",
" 4 virtual\n",
"... ... ...\n",
"MedialLemniscus_projections 5018 virtual\n",
" 5019 virtual\n",
" 5020 virtual\n",
" 5021 virtual\n",
" 5022 virtual\n",
"\n",
"[88443 rows x 2 columns]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = circuit.nodes.get({'population_type': ['virtual']})\n",
"pd.concat([df for _,df in result])"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -409,7 +550,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -508,7 +649,7 @@
"[65198 rows x 1 columns]"
]
},
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -532,7 +673,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -562,7 +703,7 @@
"dtype: object"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -583,7 +724,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -684,7 +825,7 @@
"28381 135.235031 517.312378 428.180695"
]
},
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -710,7 +851,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -740,7 +881,7 @@
" names=['population', 'node_ids'], length=35567)"
]
},
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -769,7 +910,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [
{
Expand Down
Loading

0 comments on commit dc2793c

Please sign in to comment.