From 1757196cf2489d25267fbee7df48c10f2f9d6192 Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Fri, 15 Nov 2024 15:21:25 +0100 Subject: [PATCH] add tests ensuring consistent output for pooling options trial object --- elephant/test/test_statistics.py | 44 +++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 426111810..28043bdeb 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -482,7 +482,7 @@ def test_cv2_raise_error(self): self.assertRaises(ValueError, statistics.cv2, np.array([seq, seq])) -class InstantaneousRateTest(unittest.TestCase): +class InstantaneousRateTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: @@ -490,7 +490,7 @@ def setUpClass(cls) -> None: Run once before tests: """ - block = _create_trials_block(n_trials=36) + block = _create_trials_block(n_trials=36, n_spiketrains=5) cls.block = block cls.trial_object = TrialsFromBlock(block, description='trials are segments') @@ -988,6 +988,42 @@ def test_instantaneous_rate_trials_pool_trials(self): pool_spike_trains=False, pool_trials=True) self.assertIsInstance(rate, neo.core.AnalogSignal) + self.assertEqual(rate.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0]) + + def test_instantaneous_rate_trials_pool_spiketrains(self): + kernel = kernels.GaussianKernel(sigma=500 * pq.ms) + + rate = statistics.instantaneous_rate(self.trial_object, + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=True, + pool_trials=False) + self.assertIsInstance(rate, list) + self.assertEqual(len(rate), self.trial_object.n_trials) + self.assertEqual(rate[0].shape[1], 1) + + def test_instantaneous_rate_trials_pool_spiketrains_pool_trials(self): + kernel = kernels.GaussianKernel(sigma=500 * pq.ms) + + rate = statistics.instantaneous_rate(self.trial_object, + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=True, + pool_trials=True) + self.assertIsInstance(rate, neo.AnalogSignal) + self.assertEqual(rate.shape[1], 1) + + def test_instantaneous_rate_trials_pool_spiketrains_false_pool_trials_false(self): + kernel = kernels.GaussianKernel(sigma=500 * pq.ms) + + rate = statistics.instantaneous_rate(self.trial_object, + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=False, + pool_trials=False) + self.assertIsInstance(rate, list) + self.assertEqual(len(rate), self.trial_object.n_trials) + self.assertEqual(rate[0].shape[1], self.trial_object.n_spiketrains_trial_by_trial[0]) def test_instantaneous_rate_list_pool_spike_trains(self): kernel = kernels.GaussianKernel(sigma=500 * pq.ms) @@ -999,7 +1035,7 @@ def test_instantaneous_rate_list_pool_spike_trains(self): pool_spike_trains=True, pool_trials=False) self.assertIsInstance(rate, neo.core.AnalogSignal) - self.assertEqual(rate.magnitude.shape[1], 1) + self.assertEqual(rate.shape[1], 1) def test_instantaneous_rate_list_of_spike_trains(self): kernel = kernels.GaussianKernel(sigma=500 * pq.ms) @@ -1010,7 +1046,7 @@ def test_instantaneous_rate_list_of_spike_trains(self): pool_spike_trains=False, pool_trials=False) self.assertIsInstance(rate, neo.core.AnalogSignal) - self.assertEqual(rate.magnitude.shape[1], 2) + self.assertEqual(rate.magnitude.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0]) class TimeHistogramTestCase(unittest.TestCase):