From 32e119958fd6a8b4610b6b96ea44faf4b6d2bd3a Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Wed, 17 Jul 2024 10:04:57 +0200 Subject: [PATCH] [Fix] #613 spike_train_generation module to handle multichannel AnalogSignal inputs (#614) * fix docstring add type annotations * fix input checks peak detection * add tests for peak_extraction * add handling of multichannel analogsignals to peak detection --- doc/conf.py | 2 +- elephant/spike_train_generation.py | 286 ++++++++++++++----- elephant/test/test_spike_train_generation.py | 171 +++++++++-- 3 files changed, 355 insertions(+), 104 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 67b766adb..907952f7f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -351,7 +351,7 @@ intersphinx_mapping = { 'viziphant': ('https://viziphant.readthedocs.io/en/stable/', None), 'numpy': ('https://numpy.org/doc/stable', None), - 'neo': ('https://neo.readthedocs.io/en/stable/', None), + 'neo': ('https://neo.readthedocs.io/en/latest/', None), 'quantities': ('https://python-quantities.readthedocs.io/en/stable/', None), 'python': ('https://docs.python.org/3/', None), 'scipy': ('https://docs.scipy.org/doc/scipy/', None) diff --git a/elephant/spike_train_generation.py b/elephant/spike_train_generation.py index 1c279c61a..ecac8b41d 100644 --- a/elephant/spike_train_generation.py +++ b/elephant/spike_train_generation.py @@ -52,9 +52,10 @@ from __future__ import division, print_function, unicode_literals import warnings -from typing import List, Union, Optional +from typing import List, Literal, Union, Optional import neo +from neo.core.spiketrainlist import SpikeTrainList import numpy as np import quantities as pq from scipy import stats @@ -83,53 +84,21 @@ ] -def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above', - time_stamps=None, interval=(-2 * pq.ms, 4 * pq.ms)): - """ - Return the peak times for all events that cross threshold and the - waveforms. Usually used for extracting spikes from a membrane - potential to calculate waveform properties. - - Parameters - ---------- - signal : neo.AnalogSignal - An analog input signal. - threshold : pq.Quantity, optional - Contains a value that must be reached for an event to be detected. - Default: 0.0 * pq.mV - sign : {'above', 'below'}, optional - Determines whether to count threshold crossings that cross above or - below the threshold. - Default: 'above' - time_stamps : pq.Quantity, optional - If `spike_train` is a `pq.Quantity` array, `time_stamps` provides the - time stamps around which the waveform is extracted. If it is None, the - function `peak_detection` is used to calculate the time_stamps - from signal. - Default: None - interval : tuple of pq.Quantity - Specifies the time interval around the `time_stamps` where the waveform - is extracted. - Default: (-2 * pq.ms, 4 * pq.ms) - - Returns - ------- - result_st : neo.SpikeTrain - Contains the time_stamps of each of the spikes and the waveforms in - `result_st.waveforms`. - - See Also - -------- - elephant.spike_train_generation.peak_detection - """ +def _spike_extraction_from_single_channel( + signal: neo.core.AnalogSignal, + threshold: pq.Quantity = 0.0 * pq.mV, + sign: Literal['above', 'below'] = 'above', + time_stamps: neo.core.SpikeTrain = None, + interval: tuple = (-2 * pq.ms, 4 * pq.ms) + ) -> neo.core.SpikeTrain: # Get spike time_stamps if time_stamps is None: time_stamps = peak_detection(signal, threshold, sign=sign) elif hasattr(time_stamps, 'times'): time_stamps = time_stamps.times - elif isinstance(time_stamps, pq.Quantity): - raise TypeError("time_stamps must be None, a pq.Quantity array or" + - " expose the.times interface") + else: + raise TypeError("time_stamps must be None, a `neo.core.SpikeTrain`" + " or expose the.times interface") if len(time_stamps) == 0: return neo.SpikeTrain(time_stamps, units=signal.times.units, @@ -139,6 +108,7 @@ def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above', # Unpack the extraction interval from tuple or array extr_left, extr_right = interval + if extr_left > extr_right: raise ValueError("interval[0] must be < interval[1]") @@ -185,15 +155,23 @@ def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above', left_sweep=extr_left) -def threshold_detection(signal, threshold=0.0 * pq.mV, sign='above'): +def spike_extraction( + signal: neo.core.AnalogSignal, + threshold: pq.Quantity = 0.0 * pq.mV, + sign: Literal['above', 'below'] = 'above', + time_stamps: neo.core.SpikeTrain = None, + interval: tuple = (-2 * pq.ms, 4 * pq.ms), + always_as_list: bool = False + ) -> Union[neo.core.SpikeTrain, SpikeTrainList]: """ - Returns the times when the analog signal crosses a threshold. - Usually used for extracting spike times from a membrane potential. + Return the peak times for all events that cross threshold and the + waveforms. Usually used for extracting spikes from a membrane + potential to calculate waveform properties. Parameters ---------- - signal : neo.AnalogSignal - An analog input signal. + signal : :class:`neo.core.AnalogSignal` + An analog input signal one or more channels. threshold : pq.Quantity, optional Contains a value that must be reached for an event to be detected. Default: 0.0 * pq.mV @@ -201,20 +179,66 @@ def threshold_detection(signal, threshold=0.0 * pq.mV, sign='above'): Determines whether to count threshold crossings that cross above or below the threshold. Default: 'above' + time_stamps : :class:`neo.core.SpikeTrain` , optional + Provides the time stamps around which the waveform is extracted. If it + is None, the function `peak_detection` is used to calculate the + `time_stamps` from signal. + Default: None + interval : tuple of :class:`pq.Quantity` + Specifies the time interval around the `time_stamps` where the waveform + is extracted. + Default: (-2 * pq.ms, 4 * pq.ms) + always_as_list: bool, optional + If True, :class:`neo.core.spiketrainslist.SpikeTrainList` is returned. + Default: False Returns - ------- - result_st : neo.SpikeTrain - Contains the spike times of each of the events (spikes) extracted from - the signal. - """ + ------- # noqa + result_st : :class:`neo.core.SpikeTrain`, :class:`neo.core.spiketrainslist.SpikeTrainList`. + Contains the time_stamps of each of the spikes and the waveforms in + `result_st.waveforms`. - if not isinstance(threshold, pq.Quantity): - raise ValueError('threshold must be a pq.Quantity') + See Also + -------- + :func:`elephant.spike_train_generation.peak_detection` + """ + if isinstance(signal, neo.core.AnalogSignal): + if signal.shape[1] == 1: + if always_as_list: + return SpikeTrainList(items=( + _spike_extraction_from_single_channel( + signal, + threshold=threshold, + time_stamps=time_stamps, + interval=interval, + sign=sign),)) + else: + return _spike_extraction_from_single_channel( + signal, threshold=threshold, time_stamps=time_stamps, + interval=interval, sign=sign) + elif signal.shape[1] > 1: + spiketrainlist = SpikeTrainList() + for channel in range(signal.shape[1]): + spiketrainlist.append( + _spike_extraction_from_single_channel( + neo.core.AnalogSignal( + signal[:, channel], + sampling_rate=signal.sampling_rate), + threshold=threshold, sign=sign, + time_stamps=time_stamps, + interval=interval, + )) + return spiketrainlist + else: + raise TypeError( + f"Signal must be AnalogSignal, provided: {type(signal)}") - if sign not in ('above', 'below'): - raise ValueError("sign should be 'above' or 'below'") +def _threshold_detection_from_single_channel( + signal: neo.core.AnalogSignal, + threshold: pq.Quantity = 0.0 * pq.mV, + sign: str = 'above' + ) -> neo.core.SpikeTrain: if sign == 'above': cutout = np.where(signal > threshold)[0] else: @@ -242,53 +266,88 @@ def threshold_detection(signal, threshold=0.0 * pq.mV, sign='above'): return result_st -def peak_detection(signal, threshold=0.0 * pq.mV, sign='above', - as_array=False): +def threshold_detection( + signal: neo.core.AnalogSignal, + threshold: pq.Quantity = 0.0 * pq.mV, + sign: Literal['above', 'below'] = 'above', + always_as_list: bool = False, + ) -> Union[neo.core.SpikeTrain, SpikeTrainList]: """ - Return the peak times for all events that cross threshold. + Returns the times when the analog signal crosses a threshold. Usually used for extracting spike times from a membrane potential. - Similar to spike_train_generation.threshold_detection. Parameters ---------- - signal : neo.AnalogSignal - An analog input signal. - threshold : pq.Quantity, optional + signal : :class:`neo.core.AnalogSignal` + An analog input signal with one or multiple channels. + threshold : :class:`pq.Quantity`, optional Contains a value that must be reached for an event to be detected. - Default: 0.*pq.mV + Default: 0.0 * pq.mV sign : {'above', 'below'}, optional Determines whether to count threshold crossings that cross above or below the threshold. Default: 'above' - as_array : bool, optional - If True, a NumPy array of the resulting peak times is returned instead - of a (default) `neo.SpikeTrain` object. + always_as_list: bool, optional + If True, a :class:`neo.core.spiketrainslist.SpikeTrainList`. Default: False Returns - ------- - result_st : neo.SpikeTrain + ------- # noqa + result_st : :class:`neo.core.SpikeTrain`, :class:`neo.core.spiketrainslist.SpikeTrainList` Contains the spike times of each of the events (spikes) extracted from - the signal. + the signal. If `signal` is an AnalogSignal with multiple channels, or + `always_return_list=True` , a + :class:`neo.core.spiketrainlist.SpikeTrainList` is returned. """ if not isinstance(threshold, pq.Quantity): - raise ValueError("threshold must be a pq.Quantity") + raise TypeError('threshold must be a pq.Quantity') if sign not in ('above', 'below'): raise ValueError("sign should be 'above' or 'below'") + if isinstance(signal, neo.core.AnalogSignal): + if signal.shape[1] == 1: + if always_as_list: + return SpikeTrainList(items=( + _threshold_detection_from_single_channel( + signal, threshold=threshold, sign=sign),)) + else: + return _threshold_detection_from_single_channel( + signal, threshold=threshold, sign=sign) + elif signal.shape[1] > 1: + spiketrainlist = SpikeTrainList() + for channel in range(signal.shape[1]): + spiketrainlist.append(_threshold_detection_from_single_channel( + neo.core.AnalogSignal(signal[:, channel], + sampling_rate=signal.sampling_rate), + threshold=threshold, sign=sign) + ) + return spiketrainlist + else: + raise TypeError( + f"Signal must be AnalogSignal, provided: {type(signal)}") + + +# legacy implementation of peak_detection +def _peak_detection_from_single_channel( + signal: neo.core.AnalogSignal, + threshold: pq.Quantity = 0.0 * pq.mV, + sign: str = 'above', + as_array: bool = False + ) -> neo.core.SpikeTrain: if sign == 'above': cutout = np.where(signal > threshold)[0] peak_func = np.argmax - else: - # sign == 'below' + elif sign == 'below': cutout = np.where(signal < threshold)[0] peak_func = np.argmin + else: + raise ValueError("sign should be 'above' or 'below'") if len(cutout) == 0: events_base = np.zeros(0) else: - # Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2 + # Select the crossings lasting at least 2 dtps, np.diff(cutout) > 2 # This avoids empty slices border_start = np.where(np.diff(cutout) > 1)[0] border_end = border_start + 1 @@ -327,6 +386,83 @@ def peak_detection(signal, threshold=0.0 * pq.mV, sign='above', return result_st +def peak_detection(signal: neo.core.AnalogSignal, + threshold: pq.Quantity = 0.0 * pq.mV, + sign: Literal['above', 'below'] = 'above', + as_array: bool = False, + always_as_list: bool = False + ) -> Union[neo.core.SpikeTrain, SpikeTrainList]: + """ + Return the peak times for all events that cross threshold. + Usually used for extracting spike times from a membrane potential. + Similar to spike_train_generation.threshold_detection. + + Parameters + ---------- + signal : :class:`neo.core.AnalogSignal` + An analog input signal or a list of analog input signals. + threshold : :class:`pq.Quantity`, optional + Contains a value that must be reached for an event to be detected. + Default: 0.*pq.mV + sign : {'above', 'below'}, optional + Determines whether to count threshold crossings that cross above or + below the threshold. + Default: 'above' + as_array : bool, optional + If True, a NumPy array of the resulting peak times is returned instead + of a (default) `neo.SpikeTrain` object. + Default: False + always_as_list: bool, optional + If True, a :class:`neo.core.spiketrainslist.SpikeTrainList` is returned. + Default: False + + Returns + ------- # noqa + result_st : :class:`neo.core.SpikeTrain`, :class:`neo.core.spiketrainslist.SpikeTrainList` + :class:`np.ndarrav`, List[:class:`np.ndarrav`] + Contains the spike times of each of the events (spikes) extracted from + the signal. + If `signal` is an AnalogSignal with multiple channels or + `always_return_list=True` a list is returned. + """ + if not isinstance(threshold, pq.Quantity): + raise TypeError( + f"threshold must be a pq.Quantity, provided: {type(threshold)}") + + if isinstance(signal, neo.core.AnalogSignal): + if signal.shape[1] == 1: + if always_as_list and not as_array: + return SpikeTrainList(items=( + _peak_detection_from_single_channel( + signal, threshold=threshold, sign=sign, + as_array=as_array),)) + elif always_as_list and as_array: + return [_peak_detection_from_single_channel( + signal, threshold=threshold, sign=sign, as_array=as_array)] + else: + return _peak_detection_from_single_channel( + signal, threshold=threshold, sign=sign, as_array=as_array) + elif signal.shape[1] > 1 and as_array: + return [_peak_detection_from_single_channel(neo.core.AnalogSignal( + signal[:, channel], sampling_rate=signal.sampling_rate), + threshold=threshold, + sign=sign, as_array=as_array + ) for channel in range(signal.shape[1])] + elif signal.shape[1] > 1 and not as_array: + spiketrainlist = SpikeTrainList() + for channel in range(signal.shape[1]): + spiketrainlist.append(_peak_detection_from_single_channel( + neo.core.AnalogSignal(signal[:, channel], + sampling_rate=signal.sampling_rate), + threshold=threshold, + sign=sign, as_array=as_array + )) + return spiketrainlist + else: + raise TypeError( + f"Signal must be AnalogSignal, provided: {type(signal)}") + + class AbstractPointProcess: """ Abstract point process to subclass from. diff --git a/elephant/test/test_spike_train_generation.py b/elephant/test/test_spike_train_generation.py index 3ea160c35..e1a43716d 100644 --- a/elephant/test/test_spike_train_generation.py +++ b/elephant/test/test_spike_train_generation.py @@ -13,6 +13,7 @@ import warnings import neo +from neo.core.spiketrainlist import SpikeTrainList import numpy as np from numpy.testing import assert_array_almost_equal, assert_allclose import quantities as pq @@ -37,9 +38,10 @@ def pdiff(a, b): return abs((a - b) / a) -class AnalogSignalThresholdDetectionTestCase(unittest.TestCase): +class ThresholdDetectionTestCase(unittest.TestCase): - def setUp(self): + @classmethod + def setUpClass(cls): # Load membrane potential simulated using Brian2 # according to make_spike_extraction_test_data.py. curr_dir = os.path.dirname(os.path.realpath(__file__)) @@ -49,10 +51,14 @@ def setUp(self): with open(raw_data_file_loc, 'r') as f: for x in f.readlines(): raw_data.append(float(x)) - self.vm = neo.AnalogSignal( + cls.vm = neo.AnalogSignal( raw_data, units=pq.V, sampling_period=0.1 * pq.ms) - self.true_time_stamps = [0.0123, 0.0354, 0.0712, 0.1191, 0.1694, - 0.2200, 0.2711] * pq.s + cls.vm_3d = neo.AnalogSignal(np.array([raw_data, + raw_data, + raw_data]).T, + units=pq.V, sampling_period=0.1 * pq.ms) + cls.true_time_stamps = [0.0123, 0.0354, 0.0712, 0.1191, 0.1694, + 0.2200, 0.2711] * pq.s def test_threshold_detection(self): # Test whether spikes are extracted at the correct times from @@ -81,15 +87,52 @@ def test_threshold_detection(self): except AttributeError: # If numpy version too old to have allclose self.assertTrue(np.array_equal(spike_train, self.true_time_stamps)) - def test_peak_detection_threshold(self): + def test_threshold_detection_threshold(self): # Test for empty SpikeTrain when threshold is too high result = threshold_detection(self.vm, threshold=30 * pq.mV) self.assertEqual(len(result), 0) + def test_threshold_raise_type_error(self): + with self.assertRaises(TypeError): + threshold_detection(self.vm, threshold=30) -class AnalogSignalPeakDetectionTestCase(unittest.TestCase): + def test_sign_raise_value_error(self): + with self.assertRaises(ValueError): + threshold_detection(self.vm, sign="wrong input") - def setUp(self): + def test_return_is_neo_spike_train(self): + self.assertIsInstance(threshold_detection(self.vm), + neo.core.SpikeTrain) + + def test_signal_raise_type_error(self): + with self.assertRaises(TypeError): + threshold_detection(self.vm.magnitude) + + def test_always_return_as_list(self): + self.assertIsInstance(threshold_detection(self.vm, + always_as_list=True), + SpikeTrainList) + + def test_analog_signal_multiple_channels(self): + list_of_spike_trains = threshold_detection(self.vm_3d) + self.assertEqual(len(list_of_spike_trains), 3) + for spike_train in list_of_spike_trains: + with self.subTest(value=spike_train): + self.assertIsInstance(spike_train, neo.SpikeTrain) + self.assertIsInstance(list_of_spike_trains, SpikeTrainList) + + def test_empty_analog_signal(self): + empty_analog_signal = neo.AnalogSignal([], units='V', + sampling_period=1*pq.ms) + self.assertEqual(empty_analog_signal.shape, (0, 1)) + self.assertIsInstance(threshold_detection(empty_analog_signal), + neo.core.SpikeTrain) + + +class PeakDetectionTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): curr_dir = os.path.dirname(os.path.realpath(__file__)) raw_data_file_loc = os.path.join( curr_dir, 'spike_extraction_test_data.txt') @@ -97,16 +140,19 @@ def setUp(self): with open(raw_data_file_loc, 'r') as f: for x in f.readlines(): raw_data.append(float(x)) - self.vm = neo.AnalogSignal( + cls.vm = neo.AnalogSignal( raw_data, units=pq.V, sampling_period=0.1 * pq.ms) - self.true_time_stamps = [0.0124, 0.0354, 0.0713, 0.1192, 0.1695, - 0.2201, 0.2711] * pq.s - - def test_peak_detection_time_stamps(self): + cls.vm_3d = neo.AnalogSignal(np.array([raw_data, + raw_data, + raw_data]).T, + units=pq.V, sampling_period=0.1 * pq.ms) + cls.true_time_stamps = [0.0124, 0.0354, 0.0713, 0.1192, 0.1695, + 0.2201, 0.2711] * pq.s + + def test_peak_detection_validate_result(self): # Test with default arguments result = peak_detection(self.vm) self.assertEqual(len(self.true_time_stamps), len(result)) - self.assertIsInstance(result, neo.core.SpikeTrain) try: assert_array_almost_equal(result, self.true_time_stamps) @@ -118,10 +164,48 @@ def test_peak_detection_threshold(self): result = peak_detection(self.vm, threshold=30 * pq.mV) self.assertEqual(len(result), 0) + def test_threshold_raise_type_error(self): + with self.assertRaises(TypeError): + peak_detection(self.vm, threshold=30) -class AnalogSignalSpikeExtractionTestCase(unittest.TestCase): + def test_sign_raise_value_error(self): + with self.assertRaises(ValueError): + peak_detection(self.vm, sign="wrong input") - def setUp(self): + def test_return_is_neo_spike_train(self): + self.assertIsInstance(peak_detection(self.vm), neo.core.SpikeTrain) + + def test_signal_raise_type_error(self): + with self.assertRaises(TypeError): + peak_detection(self.vm.magnitude) + + def test_always_return_as_list(self): + self.assertIsInstance(peak_detection(self.vm, always_as_list=True), + SpikeTrainList) + + def test_analog_signal_multiple_channels(self): + list_of_spike_trains = peak_detection(self.vm_3d) + self.assertEqual(len(list_of_spike_trains), 3) + for spike_train in list_of_spike_trains: + with self.subTest(value=spike_train): + self.assertIsInstance(spike_train, neo.SpikeTrain) + + def test_analog_signal_multiple_channels_as_array(self): + list_of_spike_trains = peak_detection(self.vm_3d, as_array=True) + self.assertEqual(len(list_of_spike_trains), 3) + for spike_train in list_of_spike_trains: + with self.subTest(value=spike_train): + self.assertIsInstance(spike_train, np.ndarray) + + def test_analog_signal_single_channel_as_array(self): + array = peak_detection(self.vm, as_array=True) + self.assertIsInstance(array, np.ndarray) + self.assertEqual(array.ndim, 1) + + +class SpikeExtractionTestCase(unittest.TestCase): + @classmethod + def setUpClass(cls): curr_dir = os.path.dirname(os.path.realpath(__file__)) raw_data_file_loc = os.path.join( curr_dir, 'spike_extraction_test_data.txt') @@ -129,27 +213,58 @@ def setUp(self): with open(raw_data_file_loc, 'r') as f: for x in f.readlines(): raw_data.append(float(x)) - self.vm = neo.AnalogSignal( + cls.vm = neo.AnalogSignal( raw_data, units=pq.V, sampling_period=0.1 * pq.ms) - self.first_spike = np.array([-0.04084546, -0.03892033, -0.03664779, - -0.03392689, -0.03061474, -0.02650277, - -0.0212756, -0.01443531, -0.00515365, - 0.00803962, 0.02797951, -0.07, - -0.06974495, -0.06950466, -0.06927778, - -0.06906314, -0.06885969, -0.06866651, - -0.06848277, -0.06830773, -0.06814071, - -0.06798113, -0.06782843, -0.06768213, - -0.06754178, -0.06740699, -0.06727737, - -0.06715259, -0.06703235, -0.06691635]) + cls.vm_3d = neo.AnalogSignal(np.array([raw_data, + raw_data, + raw_data]).T, + units=pq.V, sampling_period=0.1 * pq.ms) + cls.first_spike = np.array([-0.04084546, -0.03892033, -0.03664779, + -0.03392689, -0.03061474, -0.02650277, + -0.0212756, -0.01443531, -0.00515365, + 0.00803962, 0.02797951, -0.07, + -0.06974495, -0.06950466, -0.06927778, + -0.06906314, -0.06885969, -0.06866651, + -0.06848277, -0.06830773, -0.06814071, + -0.06798113, -0.06782843, -0.06768213, + -0.06754178, -0.06740699, -0.06727737, + -0.06715259, -0.06703235, -0.06691635]) def test_spike_extraction_waveform(self): - spike_train = spike_extraction(self.vm.reshape(-1), + spike_train = spike_extraction(self.vm, interval=(-1 * pq.ms, 2 * pq.ms)) assert_array_almost_equal( spike_train.waveforms[0][0].magnitude.reshape(-1), self.first_spike) + def test_threshold_raise_type_error(self): + with self.assertRaises(TypeError): + spike_extraction(self.vm, threshold=30) + + def test_sign_raise_value_error(self): + with self.assertRaises(ValueError): + spike_extraction(self.vm, sign="wrong input") + + def test_return_is_neo_spike_train(self): + self.assertIsInstance(spike_extraction(self.vm), neo.core.SpikeTrain) + + def test_signal_raise_type_error(self): + with self.assertRaises(TypeError): + spike_extraction(self.vm.magnitude) + + def test_always_return_as_list(self): + self.assertIsInstance(spike_extraction(self.vm, always_as_list=True), + SpikeTrainList) + + def test_analog_signal_multiple_channels(self): + list_of_spike_trains = spike_extraction(self.vm_3d) + self.assertEqual(len(list_of_spike_trains), 3) + for spike_train in list_of_spike_trains: + with self.subTest(value=spike_train): + self.assertIsInstance(spike_train, neo.SpikeTrain) + self.assertIsInstance(list_of_spike_trains, SpikeTrainList) + class AbstractPointProcessTestCase(unittest.TestCase): def test_not_implemented_error(self):