Skip to content

Commit

Permalink
drop '.dat' support in simulation validation (#252)
Browse files Browse the repository at this point in the history
* check all populations if `source` not defined for input spikes

* small typo in tests

* remove support for 'dat' files'

* Remove source frome 'synapse_replay'

* Update tests to not expect synapse_replay.source to be validated

* fix lint
  • Loading branch information
joni-herttuainen authored May 13, 2024
1 parent 06d9e52 commit d7ca855
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 95 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ Improvements

- the types conform to `node types <https://sonata-extension.readthedocs.io/en/latest/sonata_config.html#populations>`_ and `edge types <https://sonata-extension.readthedocs.io/en/latest/sonata_config.html#id4>`_ defined in the sonata specification
- teach the `bluepysnap validate-circuit` and `bluepysnap validate-simulation` the ability to `--ignore-datatype-errors` so that mismatches of datatypes to the specification are ignored
- Update simulation validation to conform to the SONATA spec

- ``synapse_replay.source`` and ``.dat`` spike input files are no longer supported


Version v3.0.1
Expand Down
2 changes: 0 additions & 2 deletions bluepysnap/schemas/definitions/simulation_input.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ $input_defs:
type: number
spike_file:
type: string
source:
type: string
tau:
type: number
variance:
Expand Down
52 changes: 25 additions & 27 deletions bluepysnap/simulation_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import h5py
import libsonata
import numpy as np
import pandas as pd

from bluepysnap import schemas
from bluepysnap.circuit_ids import CircuitNodeIds
Expand Down Expand Up @@ -219,15 +218,12 @@ def _get_ids_from_spike_file(file_):
"""Get unique gids from an input spikes file."""
file_ = Path(file_)
suffix = file_.suffix
if suffix == ".dat":
spikes = pd.read_csv(file_, delimiter=r"\s+", skiprows=1, header=None, names=["t", "id"])
return set(spikes["id"].values - 1)
elif suffix == ".h5":
if suffix == ".h5":
spikes = libsonata.SpikeReader(file_)
populations = spikes.get_population_names()
return {pop: set(spikes[pop].get_dict()["node_ids"]) for pop in populations}

raise IOError(f"Unknown file type: '{suffix}' (supported: '.h5', '.dat')")
raise IOError(f"Unsupported file type: '{suffix}' (supported: '.h5')")


def _get_ids_from_node_set(node_set, config):
Expand All @@ -244,6 +240,19 @@ def _get_ids_from_node_set(node_set, config):
return ids_per_population


def _get_ids_from_populations(config, only_non_virtual=False):
"""Get node ids of populations."""
circuit = libsonata.CircuitConfig.from_file(config["_circuit_config"])
populations = circuit.node_populations

if only_non_virtual:
populations = [
pop for pop in populations if circuit.node_population_properties(pop).type != "virtual"
]

return {pop: circuit.node_population(pop).select_all().flatten() for pop in populations}


def _get_missing_ids(sub_ids, super_ids):
"""Get `sub_ids` ids missing from `super_ids`."""
if isinstance(sub_ids, set):
Expand Down Expand Up @@ -277,26 +286,26 @@ def _validate_spike_file_contents(input_, config, prefix):
except IOError as e:
return [BluepySnapValidationError.fatal(f"{prefix}: {' '.join(map(str,e.args))}")]

nodeset_ids = _get_ids_from_node_set(input_["source"], config)
source = f"node set '{input_['source']}'"
if nodeset := input_.get("source"):
sim_ids = _get_ids_from_node_set(nodeset, config)
source = f"node set '{nodeset}'"
else:
sim_ids = _get_ids_from_populations(config)
source = "node populations"

return _compare_ids(spike_ids, nodeset_ids, source, prefix)
return _compare_ids(spike_ids, sim_ids, source, prefix)


def _validate_spike_input(name, input_, config):
errors = []

if (key := "source") in input_:
prefix = f"inputs.{name}.{key}"
errors += _validate_node_set_exists(config, input_[key], prefix=prefix)

if (key := "spike_file") in input_:
spike_path = _resolve_path(input_[key], config)

prefix = f"inputs.{name}.{key}"
errors += _validate_file_exists(spike_path, prefix=prefix)

if len(errors) > 0 or "source" not in input_ or not _file_exists(config["_circuit_config"]):
if len(errors) > 0 or not _file_exists(config["_circuit_config"]):
errors += [BluepySnapValidationError.fatal(f"{prefix}: Can not validate file contents")]
else:
errors += _validate_spike_file_contents(input_, config, prefix)
Expand Down Expand Up @@ -429,17 +438,6 @@ def validate_reports(config):
return errors


def _get_ids_from_non_virtual_pops(config):
"""Get ids of all non-virtual populations."""
circuit = libsonata.CircuitConfig.from_file(config["_circuit_config"])

return {
pop: circuit.node_population(pop).select_all().flatten()
for pop in circuit.node_populations
if circuit.node_population_properties(pop).type != "virtual"
}


def _validate_electrodes_file(path, config):
"""Validate the ids for each of the populations in `electrodes_file` can be found."""
prefix = "run.electrodes_file"
Expand All @@ -458,8 +456,8 @@ def _validate_electrodes_file(path, config):
source = f"node set '{node_set}'"
sim_ids = _get_ids_from_node_set(node_set, config)
else:
source = "non-virtual populations"
sim_ids = _get_ids_from_non_virtual_pops(config)
source = "non-virtual node populations"
sim_ids = _get_ids_from_populations(config, only_non_virtual=True)

return _compare_ids(elec_ids, sim_ids, source, prefix)

Expand Down
145 changes: 79 additions & 66 deletions tests/test_simulation_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from unittest.mock import MagicMock, Mock, call, patch

import numpy.testing as npt
import pandas as pd
import pytest

import bluepysnap.simulation_validation as test_module
Expand Down Expand Up @@ -294,16 +293,11 @@ def test_validate_connection_overrides(mock_validate_override):


def test__get_ids_from_spike_file(tmp_path):
spike_path = tmp_path / "spikes.dat"
pd.DataFrame({"/scatter": [1]}).to_csv(spike_path, sep="\t")

assert test_module._get_ids_from_spike_file(spike_path) == {0}

spike_path = TEST_DATA_DIR / "input_spikes.h5"
assert test_module._get_ids_from_spike_file(spike_path) == {"default": {0}}

with pytest.raises(IOError, match=r"Unknown file type: '.fake' \(supported: '.h5', '.dat'\)"):
test_module._get_ids_from_spike_file("fake_spikes.fake")
with pytest.raises(IOError, match=r"Unsupported file type: '.dat' \(supported: '.h5'\)"):
test_module._get_ids_from_spike_file(tmp_path / "spikes.dat")


def test__get_ids_from_node_set():
Expand All @@ -323,6 +317,22 @@ def test__get_ids_from_node_set():
assert test_module._get_ids_from_node_set("fake_node_set", config) == {}


def test__get_ids_from_populations():
with copy_test_data() as (_, config_path):
with edit_config(config_path) as circuit_config:
circuit_config["networks"]["nodes"][0]["populations"]["default2"]["type"] = "virtual"

config = {"_circuit_config": TEST_DATA_DIR / "circuit_config.json"}
res = test_module._get_ids_from_populations(config)
expected = {"default": [0, 1, 2], "default2": [0, 1, 2, 3]}
npt.assert_equal(res, expected)

config = {"_circuit_config": config_path}
res = test_module._get_ids_from_populations(config, only_non_virtual=True)
expected = {"default": [0, 1, 2]}
npt.assert_equal(res, expected)


def test__get_missing_ids():
nodeset_ids = {"test": [1, 2, 3], "test2": [4, 5]}
spike_ids_from_dat = {1, 2, 3, 4, 5}
Expand Down Expand Up @@ -367,45 +377,66 @@ def test__compare_ids(mock_missing_ids):
assert test_module._compare_ids(None, None, source, prefix) == expected


@patch.object(test_module, "_get_ids_from_node_set", new=Mock())
@patch.object(test_module, "_resolve_path", new=Mock())
@patch.object(test_module, "_get_ids_from_spike_file")
@patch.object(test_module, "_get_missing_ids")
def test__validate_spike_file_contents(mock_missing_ids, mock_ids_from_spikes):
input_config = {"source": "fake_node_set", "spike_file": "fake_spikes.h5"}

mock_missing_ids.return_value = []
res = test_module._validate_spike_file_contents(input_config, config=None, prefix="")
expected = []
assert res == expected

mock_missing_ids.return_value = [0, 1, 2]
res = test_module._validate_spike_file_contents(input_config, config=None, prefix="fake_prefix")
msg = "fake_prefix: 3 id(s) not found in node set 'fake_node_set': 0, 1, 2"
expected = [BluepySnapValidationError.fatal(msg)]
assert res == expected

mock_missing_ids.return_value = [("fake_population", id_) for id_ in [0, 1, 2]]
res = test_module._validate_spike_file_contents(input_config, config=None, prefix="fake_prefix")
msg = (
"fake_prefix: 3 id(s) not found in node set 'fake_node_set': "
"('fake_population', 0), ('fake_population', 1), ('fake_population', 2)"
)
expected = [BluepySnapValidationError.fatal(msg)]
assert res == expected
@pytest.mark.parametrize(
"input_config,expected_message,spike_ids_side_effect",
[
[{"spike_file": "fake_spikes.h5"}, None, lambda *_: {0, 1}],
[
{"source": "fake_node_set", "spike_file": "fake_spikes.h5"},
"3 id(s) not found in node set 'fake_node_set': 5, 6, 7",
lambda *_: {5, 6, 7},
],
[
{"spike_file": "fake_spikes.h5"},
(
"3 id(s) not found in node populations: "
"('fake_population', 0), ('fake_population', 1), ('fake_population', 2)"
),
lambda *_: {"fake_population": [0, 1, 2]},
],
[{"source": "fake_node_set", "spike_file": "fake_spikes.h5"}, None, lambda *_: {0, 1}],
[
{"spike_file": "fake_spikes.h5"},
"3 id(s) not found in node populations: 5, 6, 7",
lambda *_: {5, 6, 7},
],
[
{"source": "fake_node_set", "spike_file": "fake_spikes.h5"},
(
"3 id(s) not found in node set 'fake_node_set': "
"('fake_population', 0), ('fake_population', 1), ('fake_population', 2)"
),
lambda *_: {"fake_population": [0, 1, 2]},
],
[{"spike_file": "fake_spikes.h5"}, "Unknown IOError", IOError("Unknown", "IOError")],
],
)
def test__validate_spike_file_contents(
mock_ids_from_spikes, input_config, expected_message, spike_ids_side_effect
):
config = {
"_config_dir": TEST_DATA_DIR,
"_circuit_config": TEST_DATA_DIR / "circuit_config.json",
"_node_sets_instance": NodeSets.from_dict(
{"fake_node_set": {"population": ["default"], "node_id": [0, 1]}}
),
}
prefix = "fake"
mock_ids_from_spikes.side_effect = spike_ids_side_effect
if expected_message is not None:
expected = [BluepySnapValidationError.fatal(f"{prefix}: {expected_message}")]
else:
expected = []
res = test_module._validate_spike_file_contents(input_config, config, prefix)

mock_ids_from_spikes.side_effect = IOError("Unknown", "IOError")
res = test_module._validate_spike_file_contents(input_config, config=None, prefix="fake_prefix")
msg = "fake_prefix: Unknown IOError"
expected = [BluepySnapValidationError.fatal(msg)]
assert res == expected


def test__validate_spike_input():
node_sets = NodeSets.from_dict({"fake_node_set": {"node_id": [0]}})

input_config = {
"source": "fake_node_set",
"spike_file": TEST_DATA_DIR / "input_spikes.h5",
}
config = {
Expand All @@ -421,7 +452,6 @@ def test__validate_spike_input():
}

expected_error_messages = [
"inputs.test.source: Unknown node set: 'fail_node_set'",
f"inputs.test.spike_file: No such file: {input_config['spike_file']}",
"inputs.test.spike_file: Can not validate file contents",
]
Expand Down Expand Up @@ -535,16 +565,14 @@ def test_validate_inputs():
"pass_3": {"module": "not_synapse_replay", "source": "fail_node_set"},
"pass_4": {"module": "not_synapse_replay", "spike_file": fail_spike_file},
"fail_0": {"module": "test_module", "node_set": "fail_node_set"},
"fail_1": {"module": "synapse_replay", "source": "fail_node_set"},
"fail_2": {"module": "synapse_replay", "spike_file": fail_spike_file},
"fail_1": {"module": "synapse_replay", "spike_file": fail_spike_file},
},
}

expected_error_messages = [
"inputs.fail_0.node_set: Unknown node set: 'fail_node_set'",
"inputs.fail_1.source: Unknown node set: 'fail_node_set'",
f"inputs.fail_2.spike_file: No such file: {fail_spike_file}",
f"inputs.fail_2.spike_file: Can not validate file contents",
f"inputs.fail_1.spike_file: No such file: {fail_spike_file}",
f"inputs.fail_1.spike_file: Can not validate file contents",
]

expected = [BluepySnapValidationError.fatal(msg) for msg in expected_error_messages]
Expand Down Expand Up @@ -760,27 +788,12 @@ def test_validate_reports(tmp_path):
assert test_module.validate_reports(config) == expected


def test__get_ids_from_non_virtual_pops():
config = {"_circuit_config": TEST_DATA_DIR / "circuit_config.json"}
res = test_module._get_ids_from_non_virtual_pops(config)
expected = {"default": [0, 1, 2], "default2": [0, 1, 2, 3]}
npt.assert_equal(res, expected)

with copy_test_data() as (_, config_path):
with edit_config(config_path) as circuit_config:
circuit_config["networks"]["nodes"][0]["populations"]["default2"]["type"] = "virtual"

config = {"_circuit_config": config_path}
res = test_module._get_ids_from_non_virtual_pops(config)
expected = {"default": [0, 1, 2]}
npt.assert_equal(res, expected)


def test__validate_electrodes_file():
path = "./fake_path"
prefix = "run.electrodes_file"
expected = [
BluepySnapValidationError.fatal(f"run.electrodes_file: No such file: {TEST_DATA_DIR/path}"),
BluepySnapValidationError.fatal(f"run.electrodes_file: Can not validate file contents"),
BluepySnapValidationError.fatal(f"{prefix}: No such file: {TEST_DATA_DIR/path}"),
BluepySnapValidationError.fatal(f"{prefix}: Can not validate file contents"),
]
config = {"run": {"electrodes_file": path}, "_config_dir": TEST_DATA_DIR}
assert test_module.validate_run(config) == expected
Expand All @@ -793,9 +806,9 @@ def test__validate_electrodes_file():
}
assert test_module.validate_run(config) == []

with patch.object(test_module, "_get_ids_from_non_virtual_pops") as patched:
with patch.object(test_module, "_get_ids_from_populations") as patched:
patched.return_value = {"default": {0}}
msg = "run.electrodes_file: 1 id(s) not found in non-virtual populations: ('default', 1)"
msg = f"{prefix}: 1 id(s) not found in non-virtual node populations: ('default', 1)"
expected = [BluepySnapValidationError.fatal(msg)]
assert test_module.validate_run(config) == expected

Expand All @@ -804,12 +817,12 @@ def test__validate_electrodes_file():
assert test_module.validate_run(config) == []

config["node_set"] = "Layer23"
msg = "run.electrodes_file: 1 id(s) not found in node set 'Layer23': ('default', 1)"
msg = f"{prefix}: 1 id(s) not found in node set 'Layer23': ('default', 1)"
expected = [BluepySnapValidationError.fatal(msg)]
assert test_module.validate_run(config) == expected

config["_circuit_config"] = ""
msg = "run.electrodes_file: Can not validate file contents"
msg = f"{prefix}: Can not validate file contents"
expected = [BluepySnapValidationError.fatal(msg)]
assert test_module.validate_run(config) == expected

Expand Down

0 comments on commit d7ca855

Please sign in to comment.