From c3eb21fb904ba0d80524f3bf242d4e267155de6e Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:47:24 +0200 Subject: [PATCH] Fix/dither spike train with numpy>1.23.0 (#589) * add typehints * update docstring * write regression test for issue #586 --- elephant/spike_train_surrogates.py | 118 ++++++------ elephant/test/test_spike_train_surrogates.py | 179 ++++++++++--------- 2 files changed, 159 insertions(+), 138 deletions(-) diff --git a/elephant/spike_train_surrogates.py b/elephant/spike_train_surrogates.py index 94b0cee7e..22f028844 100644 --- a/elephant/spike_train_surrogates.py +++ b/elephant/spike_train_surrogates.py @@ -37,6 +37,7 @@ import random import warnings import copy +from typing import Union, Optional, List import neo import numpy as np @@ -66,14 +67,14 @@ 'bin_shuffling', 'isi_dithering') -def _dither_spikes_with_refractory_period(spiketrain, dither, n_surrogates, - refractory_period): +def _dither_spikes_with_refractory_period(spiketrain: neo.SpikeTrain, + dither: float, + n_surrogates: int, + refractory_period: float + ) -> np.array: units = spiketrain.units t_start = spiketrain.t_start.rescale(units).magnitude t_stop = spiketrain.t_stop.rescale(units).magnitude - - dither = dither.rescale(units).magnitude - refractory_period = refractory_period.rescale(units).magnitude # The initially guesses refractory period is compared to the minimal ISI. # The smaller value is taken as the refractory to calculate with. refractory_period = np.min(np.diff(spiketrain.magnitude), @@ -108,14 +109,45 @@ def _dither_spikes_with_refractory_period(spiketrain, dither, n_surrogates, dithered_spiketrains.append(dithered_st) - dithered_spiketrains = np.array(dithered_spiketrains) * units + dithered_spiketrains = np.array(dithered_spiketrains) + + return dithered_spiketrains + + +def _dither_spikes(spiketrain: neo.SpikeTrain, dither: float, + n_surrogates: int, edges: bool) -> np.array: + units = spiketrain.units + t_start = spiketrain.t_start.rescale(units).magnitude.item() + t_stop = spiketrain.t_stop.rescale(units).magnitude.item() + # Main: generate the surrogates + dithered_spiketrains = \ + spiketrain.magnitude.reshape((1, len(spiketrain))) \ + + 2 * dither * np.random.random_sample( + (n_surrogates, len(spiketrain))) - dither + dithered_spiketrains.sort(axis=1) + + if edges: + # Leave out all spikes outside [spiketrain.t_start, spiketrain.t_stop] + dithered_spiketrains = [ + train[np.all([t_start < train, train < t_stop], axis=0)] + for train in dithered_spiketrains] + else: + # Move all spikes outside + # [spiketrain.t_start, spiketrain.t_stop] to the range's ends + dithered_spiketrains = np.minimum( + np.maximum(dithered_spiketrains, t_start), t_stop) return dithered_spiketrains @deprecated_alias(n='n_surrogates') -def dither_spikes(spiketrain, dither, n_surrogates=1, decimals=None, - edges=True, refractory_period=None): +def dither_spikes(spiketrain: neo.SpikeTrain, + dither: pq.Quantity, + n_surrogates: Optional[int] = 1, + decimals: Optional[int] = None, + edges: Optional[bool] = True, + refractory_period: Optional[Union[pq.Quantity, None]] = None + ) -> List[neo.SpikeTrain]: """ Generates surrogates of a spike train by spike dithering. @@ -129,7 +161,7 @@ def dither_spikes(spiketrain, dither, n_surrogates=1, decimals=None, Parameters ---------- - spiketrain : neo.SpikeTrain + spiketrain : :class:`neo.core.SpikeTrain` The spike train from which to generate the surrogates. dither : pq.Quantity Amount of dithering. A spike at time `t` is placed randomly within @@ -161,8 +193,8 @@ def dither_spikes(spiketrain, dither, n_surrogates=1, decimals=None, Returns ------- - list of neo.SpikeTrain - Each surrogate spike train obtained independently from `spiketrain` by + list of :class:`neo.core.SpikeTrain` + Each surrogate spike train obtained independently of `spiketrain` by randomly dithering its spikes. The range of the surrogate spike trains is the same as of `spiketrain`. @@ -186,54 +218,40 @@ def dither_spikes(spiketrain, dither, n_surrogates=1, decimals=None, [0.0 ms, 1000.0 ms])>] """ + # The trivial case if len(spiketrain) == 0: - # return the empty spiketrain n times + # return the empty spiketrain n_surrogates times return [spiketrain.copy() for _ in range(n_surrogates)] + # Handle units units = spiketrain.units - t_start = spiketrain.t_start.rescale(units).magnitude - t_stop = spiketrain.t_stop.rescale(units).magnitude - - if refractory_period is None or refractory_period == 0: - # Main: generate the surrogates - dither = dither.rescale(units).magnitude - dithered_spiketrains = \ - spiketrain.magnitude.reshape((1, len(spiketrain))) \ - + 2 * dither * np.random.random_sample( - (n_surrogates, len(spiketrain))) - dither - dithered_spiketrains.sort(axis=1) - - if edges: - # Leave out all spikes outside - # [spiketrain.t_start, spiketrain.t_stop] - dithered_spiketrains = \ - [train[ - np.all([t_start < train, train < t_stop], axis=0)] - for train in dithered_spiketrains] - else: - # Move all spikes outside - # [spiketrain.t_start, spiketrain.t_stop] to the range's ends - dithered_spiketrains = np.minimum( - np.maximum(dithered_spiketrains, t_start), - t_stop) - - dithered_spiketrains = dithered_spiketrains * units + dither = dither.rescale(units).magnitude.item() + if not refractory_period: + dithered_spiketrains = _dither_spikes( + spiketrain, dither, n_surrogates, edges) elif isinstance(refractory_period, pq.Quantity): + refractory_period = refractory_period.rescale(units).magnitude.item() + dithered_spiketrains = _dither_spikes_with_refractory_period( spiketrain, dither, n_surrogates, refractory_period) else: raise ValueError("refractory_period must be of type pq.Quantity") # Round the surrogate data to decimal position, if requested - if decimals is not None: - dithered_spiketrains = \ - dithered_spiketrains.rescale(pq.ms).round(decimals).rescale(units) - - # Return the surrogates as list of neo.SpikeTrain - return [neo.SpikeTrain(train, t_start=t_start, t_stop=t_stop, - sampling_rate=spiketrain.sampling_rate) - for train in dithered_spiketrains] + if decimals: + return [neo.SpikeTrain( + (train * units).rescale(pq.ms).round(decimals).rescale(units), + t_start=spiketrain.t_start, t_stop=spiketrain.t_stop, + sampling_rate=spiketrain.sampling_rate) + for train in dithered_spiketrains] + else: + # Return the surrogates as list of neo.SpikeTrain + return [neo.SpikeTrain( + train * units, + t_start=spiketrain.t_start, t_stop=spiketrain.t_stop, + sampling_rate=spiketrain.sampling_rate) + for train in dithered_spiketrains] @deprecated_alias(n='n_surrogates') @@ -393,7 +411,7 @@ def dither_spike_train(spiketrain, shift, n_surrogates=1, decimals=None, Parameters ---------- - spiketrain : neo.SpikeTrain + spiketrain : :class:`neo.core.SpikeTrain` The spike train from which to generate the surrogates. shift : pq.Quantity Amount of shift. `spiketrain` is shifted by a random amount uniformly @@ -413,8 +431,8 @@ def dither_spike_train(spiketrain, shift, n_surrogates=1, decimals=None, Returns ------- - list of neo.SpikeTrain - Each surrogate spike train obtained independently from `spiketrain` by + list of :class:`neo.core.SpikeTrain` + Each surrogate spike train obtained independently of `spiketrain` by randomly dithering the whole spike train. The time range of the surrogate spike trains is the same as in `spiketrain`. diff --git a/elephant/test/test_spike_train_surrogates.py b/elephant/test/test_spike_train_surrogates.py index f6f9996a5..af961b39b 100644 --- a/elephant/test/test_spike_train_surrogates.py +++ b/elephant/test/test_spike_train_surrogates.py @@ -24,25 +24,27 @@ def setUp(self): np.random.seed(0) random.seed(0) - def test_dither_spikes_output_format(self): + @classmethod + def setUpClass(cls) -> None: + st1 = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) + cls.st1 = st1 - spiketrain = neo.SpikeTrain([90, 93, 97, 100, 105, - 150, 180, 350] * pq.ms, t_stop=.5 * pq.s) - spiketrain.t_stop = .5 * pq.s + def test_dither_spikes_output_format(self): + self.st1.t_stop = .5 * pq.s n_surrogates = 2 dither = 10 * pq.ms surrogate_trains = surr.dither_spikes( - spiketrain, dither=dither, n_surrogates=n_surrogates) + self.st1, dither=dither, n_surrogates=n_surrogates) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) self.assertIsInstance(surrogate_trains[0], neo.SpikeTrain) for surrogate_train in surrogate_trains: - self.assertEqual(surrogate_train.units, spiketrain.units) - self.assertEqual(surrogate_train.t_start, spiketrain.t_start) - self.assertEqual(surrogate_train.t_stop, spiketrain.t_stop) - self.assertEqual(len(surrogate_train), len(spiketrain)) + self.assertEqual(surrogate_train.units, self.st1.units) + self.assertEqual(surrogate_train.t_start, self.st1.t_start) + self.assertEqual(surrogate_train.t_stop, self.st1.t_stop) + self.assertEqual(len(surrogate_train), len(self.st1)) assert_array_less(0., np.diff(surrogate_train)) # check ordering def test_dither_spikes_empty_train(self): @@ -54,18 +56,32 @@ def test_dither_spikes_empty_train(self): st, dither=dither, n_surrogates=1)[0] self.assertEqual(len(surrogate_train), 0) - def test_dither_spikes_output_decimals(self): + def test_dither_spikes_refactory_period_zero_or_none(self): + dither = 10 * pq.ms + decimals = 3 + n_surrogates = 1 - st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) + np.random.seed(42) + surrogate_trains_zero = surr.dither_spikes( + self.st1, dither, decimals=decimals, n_surrogates=n_surrogates, + refractory_period=0) + np.random.seed(42) + surrogate_trains_none = surr.dither_spikes( + self.st1, dither, decimals=decimals, n_surrogates=n_surrogates, + refractory_period=None) + np.testing.assert_array_almost_equal( + surrogate_trains_zero[0].magnitude, + surrogate_trains_none[0].magnitude) + def test_dither_spikes_output_decimals(self): n_surrogates = 2 dither = 10 * pq.ms np.random.seed(42) surrogate_trains = surr.dither_spikes( - st, dither=dither, decimals=3, n_surrogates=n_surrogates) + self.st1, dither=dither, decimals=3, n_surrogates=n_surrogates) np.random.seed(42) - dither_values = np.random.random_sample((n_surrogates, len(st))) + dither_values = np.random.random_sample((n_surrogates, len(self.st1))) expected_non_dithered = np.sum(dither_values == 0) observed_non_dithered = 0 @@ -78,17 +94,14 @@ def test_dither_spikes_output_decimals(self): self.assertEqual(observed_non_dithered, expected_non_dithered) def test_dither_spikes_false_edges(self): - - st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) - n_surrogates = 2 dither = 10 * pq.ms surrogate_trains = surr.dither_spikes( - st, dither=dither, n_surrogates=n_surrogates, edges=False) + self.st1, dither=dither, n_surrogates=n_surrogates, edges=False) for surrogate_train in surrogate_trains: for i in range(len(surrogate_train)): - self.assertLessEqual(surrogate_train[i], st.t_stop) + self.assertLessEqual(surrogate_train[i], self.st1.t_stop) def test_dither_spikes_with_refractory_period_output_format(self): @@ -131,24 +144,47 @@ def test_dither_spikes_with_refractory_period_empty_train(self): refractory_period=4 * pq.ms)[0] self.assertEqual(len(surrogate_train), 0) - def test_randomise_spikes_output_format(self): - - spiketrain = neo.SpikeTrain( - [90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) + def test_dither_spikes_regression_issue_586(self): + """ + When using the dither_spikes surrogate generation function, with the + edges=True option, there is an exception when spikes are removed due + to being dithered outside the spiketrain duration. + + Since the arrays in the list will have different dimensions, the + multiplication operator fails. + However, this worked with numpy==1.23 and fails with numpy>=1.24. + See: https://github.com/NeuralEnsemble/elephant/issues/586 + """ + # Generate one spiketrain with a spike close to t_stop + t_stop = 2 * pq.s + st = stg.StationaryPoissonProcess( + rate=10 * pq.Hz, t_stop=t_stop).generate_spiketrain() + st = neo.SpikeTrain(np.hstack([st.magnitude, [1.9999999]]), + units=st.units, t_stop=t_stop) + + # Dither + np.random.seed(5) + surrogate_trains = surr.dither_spikes( + st, dither=15 * pq.ms, n_surrogates=30, edges=True, decimals=2) + for surrogate in surrogate_trains: + with self.subTest(surrogate): + self.assertLess(surrogate[-1], surrogate.t_stop) + self.assertGreater(surrogate[0], surrogate.t_start) + def test_randomise_spikes_output_format(self): n_surrogates = 2 surrogate_trains = surr.randomise_spikes( - spiketrain, n_surrogates=n_surrogates) + self.st1, n_surrogates=n_surrogates) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) self.assertIsInstance(surrogate_trains[0], neo.SpikeTrain) for surrogate_train in surrogate_trains: - self.assertEqual(surrogate_train.units, spiketrain.units) - self.assertEqual(surrogate_train.t_start, spiketrain.t_start) - self.assertEqual(surrogate_train.t_stop, spiketrain.t_stop) - self.assertEqual(len(surrogate_train), len(spiketrain)) + self.assertEqual(surrogate_train.units, self.st1.units) + self.assertEqual(surrogate_train.t_start, self.st1.t_start) + self.assertEqual(surrogate_train.t_stop, self.st1.t_stop) + self.assertEqual(len(surrogate_train), len(self.st1)) def test_randomise_spikes_empty_train(self): @@ -158,12 +194,9 @@ def test_randomise_spikes_empty_train(self): self.assertEqual(len(surrogate_train), 0) def test_randomise_spikes_output_decimals(self): - spiketrain = neo.SpikeTrain( - [90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) - n_surrogates = 2 surrogate_trains = surr.randomise_spikes( - spiketrain, n_surrogates=n_surrogates, decimals=3) + self.st1, n_surrogates=n_surrogates, decimals=3) for surrogate_train in surrogate_trains: for i in range(len(surrogate_train)): @@ -173,23 +206,19 @@ def test_randomise_spikes_output_decimals(self): surrogate_train[i]) def test_shuffle_isis_output_format(self): - - spiketrain = neo.SpikeTrain( - [90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) - n_surrogates = 2 surrogate_trains = surr.shuffle_isis( - spiketrain, n_surrogates=n_surrogates) + self.st1, n_surrogates=n_surrogates) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) self.assertIsInstance(surrogate_trains[0], neo.SpikeTrain) for surrogate_train in surrogate_trains: - self.assertEqual(surrogate_train.units, spiketrain.units) - self.assertEqual(surrogate_train.t_start, spiketrain.t_start) - self.assertEqual(surrogate_train.t_stop, spiketrain.t_stop) - self.assertEqual(len(surrogate_train), len(spiketrain)) + self.assertEqual(surrogate_train.units, self.st1.units) + self.assertEqual(surrogate_train.t_start, self.st1.t_start) + self.assertEqual(surrogate_train.t_stop, self.st1.t_stop) + self.assertEqual(len(surrogate_train), len(self.st1)) def test_shuffle_isis_empty_train(self): @@ -199,16 +228,12 @@ def test_shuffle_isis_empty_train(self): self.assertEqual(len(surrogate_train), 0) def test_shuffle_isis_same_isis(self): + surrogate_train = surr.shuffle_isis(self.st1, n_surrogates=1)[0] - spiketrain = neo.SpikeTrain( - [90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) - - surrogate_train = surr.shuffle_isis(spiketrain, n_surrogates=1)[0] - - st_pq = spiketrain.view(pq.Quantity) + st_pq = self.st1.view(pq.Quantity) surr_pq = surrogate_train.view(pq.Quantity) - isi0_orig = spiketrain[0] - spiketrain.t_start + isi0_orig = self.st1[0] - self.st1.t_start ISIs_orig = np.sort([isi0_orig] + [isi for isi in np.diff(st_pq)]) isi0_surr = surrogate_train[0] - surrogate_train.t_start @@ -217,17 +242,13 @@ def test_shuffle_isis_same_isis(self): self.assertTrue(np.all(ISIs_orig == ISIs_surr)) def test_shuffle_isis_output_decimals(self): - - spiketrain = neo.SpikeTrain( - [90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) - surrogate_train = surr.shuffle_isis( - spiketrain, n_surrogates=1, decimals=95)[0] + self.st1, n_surrogates=1, decimals=95)[0] - st_pq = spiketrain.view(pq.Quantity) + st_pq = self.st1.view(pq.Quantity) surr_pq = surrogate_train.view(pq.Quantity) - isi0_orig = spiketrain[0] - spiketrain.t_start + isi0_orig = self.st1[0] - self.st1.t_start ISIs_orig = np.sort([isi0_orig] + [isi for isi in np.diff(st_pq)]) isi0_surr = surrogate_train[0] - surrogate_train.t_start @@ -254,24 +275,20 @@ def test_shuffle_isis_with_wrongly_ordered_spikes(self): dt=dither) def test_dither_spike_train_output_format(self): - - spiketrain = neo.SpikeTrain( - [90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) - n_surrogates = 2 shift = 10 * pq.ms surrogate_trains = surr.dither_spike_train( - spiketrain, shift=shift, n_surrogates=n_surrogates) + self.st1, shift=shift, n_surrogates=n_surrogates) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) self.assertIsInstance(surrogate_trains[0], neo.SpikeTrain) for surrogate_train in surrogate_trains: - self.assertEqual(surrogate_train.units, spiketrain.units) - self.assertEqual(surrogate_train.t_start, spiketrain.t_start) - self.assertEqual(surrogate_train.t_stop, spiketrain.t_stop) - self.assertEqual(len(surrogate_train), len(spiketrain)) + self.assertEqual(surrogate_train.units, self.st1.units) + self.assertEqual(surrogate_train.t_start, self.st1.t_start) + self.assertEqual(surrogate_train.t_stop, self.st1.t_stop) + self.assertEqual(len(surrogate_train), len(self.st1)) def test_dither_spike_train_empty_train(self): @@ -283,12 +300,10 @@ def test_dither_spike_train_empty_train(self): self.assertEqual(len(surrogate_train), 0) def test_dither_spike_train_output_decimals(self): - st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) - n_surrogates = 2 shift = 10 * pq.ms surrogate_trains = surr.dither_spike_train( - st, shift=shift, n_surrogates=n_surrogates, decimals=3) + self.st1, shift=shift, n_surrogates=n_surrogates, decimals=3) for surrogate_train in surrogate_trains: for i in range(len(surrogate_train)): @@ -298,38 +313,30 @@ def test_dither_spike_train_output_decimals(self): surrogate_train[i]) def test_dither_spike_train_false_edges(self): - - spiketrain = neo.SpikeTrain( - [90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) - n_surrogates = 2 shift = 10 * pq.ms surrogate_trains = surr.dither_spike_train( - spiketrain, shift=shift, n_surrogates=n_surrogates, edges=False) + self.st1, shift=shift, n_surrogates=n_surrogates, edges=False) for surrogate_train in surrogate_trains: for i in range(len(surrogate_train)): - self.assertLessEqual(surrogate_train[i], spiketrain.t_stop) + self.assertLessEqual(surrogate_train[i], self.st1.t_stop) def test_jitter_spikes_output_format(self): - - spiketrain = neo.SpikeTrain( - [90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) - n_surrogates = 2 bin_size = 100 * pq.ms surrogate_trains = surr.jitter_spikes( - spiketrain, bin_size=bin_size, n_surrogates=n_surrogates) + self.st1, bin_size=bin_size, n_surrogates=n_surrogates) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) self.assertIsInstance(surrogate_trains[0], neo.SpikeTrain) for surrogate_train in surrogate_trains: - self.assertEqual(surrogate_train.units, spiketrain.units) - self.assertEqual(surrogate_train.t_start, spiketrain.t_start) - self.assertEqual(surrogate_train.t_stop, spiketrain.t_stop) - self.assertEqual(len(surrogate_train), len(spiketrain)) + self.assertEqual(surrogate_train.units, self.st1.units) + self.assertEqual(surrogate_train.t_start, self.st1.t_start) + self.assertEqual(surrogate_train.t_stop, self.st1.t_stop) + self.assertEqual(len(surrogate_train), len(self.st1)) def test_jitter_spikes_empty_train(self): @@ -341,16 +348,12 @@ def test_jitter_spikes_empty_train(self): self.assertEqual(len(surrogate_train), 0) def test_jitter_spikes_same_bins(self): - - spiketrain = neo.SpikeTrain( - [90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) - bin_size = 100 * pq.ms surrogate_train = surr.jitter_spikes( - spiketrain, bin_size=bin_size, n_surrogates=1)[0] + self.st1, bin_size=bin_size, n_surrogates=1)[0] bin_ids_orig = np.array( - (spiketrain.view( + (self.st1.view( pq.Quantity) / bin_size).rescale( pq.dimensionless).magnitude, @@ -365,7 +368,7 @@ def test_jitter_spikes_same_bins(self): # Bug encountered when the original and surrogate trains have # different number of spikes - self.assertEqual(len(spiketrain), len(surrogate_train)) + self.assertEqual(len(self.st1), len(surrogate_train)) def test_jitter_spikes_unequal_bin_size(self):