Skip to content

Commit

Permalink
Add very basic testing for _plotting.py, fix a few minor things (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
joni-herttuainen authored Oct 23, 2023
1 parent 6315a6b commit 95619fa
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 12 deletions.
25 changes: 14 additions & 11 deletions bluepysnap/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _get_pyplot():
return plt


def spikes_firing_rate_histogram(filtered_report, time_binsize=None, ax=None): # pragma: no cover
def spikes_firing_rate_histogram(filtered_report, time_binsize=None, ax=None):
"""Spike firing rate histogram.
This plot shows the number of nodes firing during a range of time.
Expand Down Expand Up @@ -88,7 +88,7 @@ def spikes_firing_rate_histogram(filtered_report, time_binsize=None, ax=None):
return ax


def spike_raster(filtered_report, y_axis=None, ax=None): # pragma: no cover
def spike_raster(filtered_report, y_axis=None, ax=None):
"""Spike raster plot.
Shows a global overview of the circuit's firing nodes. The y axis can project either the
Expand Down Expand Up @@ -119,11 +119,14 @@ def spike_raster(filtered_report, y_axis=None, ax=None): # pragma: no cover
"ymax": -np.inf,
}

def _is_categorical_or_object(dtype):
return pd.api.types.is_object_dtype(dtype) or isinstance(dtype, pd.CategoricalDtype)

def _update_raster_properties():
if y_axis is None:
props["node_id_offset"] += spikes.nodes.size
props["pop_separators"].append(props["node_id_offset"])
elif isinstance(spikes.nodes.property_dtypes[y_axis], pd.CategoricalDtype):
elif _is_categorical_or_object(spikes.nodes.property_dtypes[y_axis]):
props["categorical_values"].update(spikes.nodes.property_values(y_axis))
else:
props["ymin"] = min(props["ymin"], spikes.nodes.get(properties=y_axis).min())
Expand All @@ -133,7 +136,7 @@ def _update_raster_properties():

# use np.int64 if displaying node_ids
dtype = spike_report[population_names[0]].nodes.property_dtypes[y_axis] if y_axis else IDS_DTYPE
if isinstance(dtype, pd.CategoricalDtype):
if _is_categorical_or_object(dtype):
# this is to prevent the problems when concatenating categoricals with unknown categories
dtype = str
data = pd.Series(index=report.index, dtype=dtype)
Expand Down Expand Up @@ -163,7 +166,7 @@ def _update_raster_properties():
ax.set_ylim(0, props["node_id_offset"])
ax.set_ylabel("nodes")
else:
if np.issubdtype(type(data.iloc[0]), np.number):
if np.issubdtype(data.dtype, np.number):
# automatically expended by plt if ymin == ymax
ax.set_ylim(props["ymin"], props["ymax"])
else:
Expand All @@ -181,7 +184,7 @@ def _update_raster_properties():
return ax


def spikes_isi(filtered_report, use_frequency=False, binsize=None, ax=None): # pragma: no cover
def spikes_isi(filtered_report, use_frequency=False, binsize=None, ax=None):
# pylint: disable=too-many-locals
"""Interspike interval histogram.
Expand All @@ -204,7 +207,9 @@ def spikes_isi(filtered_report, use_frequency=False, binsize=None, ax=None): #
if binsize is not None and binsize <= 0:
raise BluepySnapError(f"Invalid binsize = {binsize}. Should be > 0.")

gb = filtered_report.report.groupby(["ids", "population"])
# Added `observed=True` to silence pandas warning about changing default value.
# However, report should not contain categories that are not in the dataframe.
gb = filtered_report.report.groupby(["ids", "population"], observed=True)
values = np.concatenate([np.diff(node_spikes.index.to_numpy()) for _, node_spikes in gb])

if len(values) == 0:
Expand Down Expand Up @@ -232,9 +237,7 @@ def spikes_isi(filtered_report, use_frequency=False, binsize=None, ax=None): #
return ax


def spikes_firing_animation(
filtered_report, x_axis=Node.X, y_axis=Node.Y, dt=20, ax=None
): # pragma: no cover
def spikes_firing_animation(filtered_report, x_axis=Node.X, y_axis=Node.Y, dt=20, ax=None):
# pylint: disable=too-many-locals,too-many-arguments,anomalous-backslash-in-string
"""Simple animation of simulation spikes.
Expand Down Expand Up @@ -332,7 +335,7 @@ def update_animation(frame):
return anim, ax


def frame_trace(filtered_report, plot_type="mean", ax=None): # pragma: no cover
def frame_trace(filtered_report, plot_type="mean", ax=None):
"""Returns a plot displaying the voltage of a node or a compartment as a function of time.
Args:
Expand Down
135 changes: 134 additions & 1 deletion tests/test__plotting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import sys
from unittest.mock import patch
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
import pytest

import bluepysnap._plotting as test_module
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.simulation import Simulation
from bluepysnap.spike_report import FilteredSpikeReport, SpikeReport

from utils import TEST_DATA_DIR

# NOTE: The tests here are primarily to make sure all the code is covered and deprecation warnings,
# etc. are raised. They don't ensure nor really test the correctness of the functionality.


def test__get_pyplot():
Expand All @@ -15,3 +25,126 @@ def test__get_pyplot():

plt_test = test_module._get_pyplot()
assert plt_test is matplotlib.pyplot


def _get_filtered_spike_report():
return Simulation(TEST_DATA_DIR / "simulation_config.json").spikes.filter()


def _get_filtered_frame_report():
return Simulation(TEST_DATA_DIR / "simulation_config.json").reports["soma_report"].filter()


def test_spikes_firing_rate_histogram():
with pytest.raises(BluepySnapError, match="Invalid time_binsize"):
test_module.spikes_firing_rate_histogram(filtered_report=None, time_binsize=0)

filtered_report = _get_filtered_spike_report()
ax = test_module.spikes_firing_rate_histogram(filtered_report)
assert ax.xaxis.label.get_text() == "Time [ms]"
assert ax.yaxis.label.get_text() == "PSTH [Hz]"

ax.xaxis.label.set_text("Fake X")
ax.yaxis.label.set_text("Fake Y")

ax = test_module.spikes_firing_rate_histogram(filtered_report, ax=ax)
assert ax.xaxis.label.get_text() == "Fake X"
assert ax.yaxis.label.get_text() == "Fake Y"


def test_spike_raster():
filtered_report = _get_filtered_spike_report()

test_module.spike_raster(filtered_report)
test_module.spike_raster(filtered_report, y_axis="y")

ax = test_module.spike_raster(filtered_report, y_axis="mtype")

assert ax.xaxis.label.get_text() == "Time [ms]"
assert ax.yaxis.label.get_text() == "mtype"

ax.xaxis.label.set_text("Fake X")
ax.yaxis.label.set_text("Fake Y")

ax = test_module.spike_raster(filtered_report, y_axis="mtype", ax=ax)
assert ax.xaxis.label.get_text() == "Fake X"
assert ax.yaxis.label.get_text() == "Fake Y"

# Have error raised in node_population get
filtered_report.spike_report["default"].nodes.get = Mock(
side_effect=BluepySnapError("Fake error")
)
test_module.spike_raster(filtered_report, y_axis="mtype")


def test_spikes_isi():
with pytest.raises(BluepySnapError, match="Invalid binsize"):
test_module.spikes_isi(filtered_report=None, binsize=0)

filtered_report = _get_filtered_spike_report()

ax = test_module.spikes_isi(filtered_report)
assert ax.xaxis.label.get_text() == "Interspike interval [ms]"
assert ax.yaxis.label.get_text() == "Bin weight"

ax = test_module.spikes_isi(filtered_report, use_frequency=True, binsize=42)
assert ax.xaxis.label.get_text() == "Frequency [Hz]"
assert ax.yaxis.label.get_text() == "Bin weight"

ax.xaxis.label.set_text("Fake X")
ax.yaxis.label.set_text("Fake Y")
ax = test_module.spikes_isi(filtered_report, use_frequency=True, binsize=42, ax=ax)
assert ax.xaxis.label.get_text() == "Fake X"
assert ax.yaxis.label.get_text() == "Fake Y"

with patch.object(test_module.np, "concatenate", Mock(return_value=[])):
with pytest.raises(BluepySnapError, match="No data to display"):
test_module.spikes_isi(filtered_report)


def test_spikes_firing_animation(tmp_path):
with pytest.raises(BluepySnapError, match="Fake is not a valid axis"):
test_module.spikes_firing_animation(filtered_report=None, x_axis="Fake")

with pytest.raises(BluepySnapError, match="Fake is not a valid axis"):
test_module.spikes_firing_animation(filtered_report=None, y_axis="Fake")

filtered_report = _get_filtered_spike_report()
anim, ax = test_module.spikes_firing_animation(filtered_report, dt=0.2)
assert ax.title.get_text() == "time = 0.1ms"

# convert to video to have `update_animation` called
anim.save(tmp_path / "test.gif")

ax.title.set_text("Fake Title")
anim, ax = test_module.spikes_firing_animation(filtered_report, dt=0.2, ax=ax)
assert ax.title.get_text() == "Fake Title"
anim.save(tmp_path / "test.gif")

# Have error raised in node_population get
filtered_report.spike_report["default"].nodes.get = Mock(
side_effect=BluepySnapError("Fake error")
)

anim, _ = test_module.spikes_firing_animation(filtered_report, dt=0.2)
anim.save(tmp_path / "test.gif")


def test_frame_trace():
with pytest.raises(BluepySnapError, match="Unknown plot_type Fake."):
test_module.frame_trace(filtered_report=None, plot_type="Fake", ax="also fake")

filtered_report = _get_filtered_frame_report()
test_module.frame_trace(filtered_report)
ax = test_module.frame_trace(filtered_report, plot_type="all")

assert ax.xaxis.label.get_text() == "Time [ms]"
assert ax.yaxis.label.get_text() == "Voltage [mV]"

ax.xaxis.label.set_text("Fake X")
ax.yaxis.label.set_text("Fake Y")

ax = test_module.frame_trace(filtered_report, plot_type="all", ax=ax)

assert ax.xaxis.label.get_text() == "Fake X"
assert ax.yaxis.label.get_text() == "Fake Y"

0 comments on commit 95619fa

Please sign in to comment.