Skip to content

Commit

Permalink
add tests ensuring consistent output for pooling options trial object
Browse files Browse the repository at this point in the history
  • Loading branch information
Moritz-Alexander-Kern committed Nov 15, 2024
1 parent 123ca04 commit 1757196
Showing 1 changed file with 40 additions and 4 deletions.
44 changes: 40 additions & 4 deletions elephant/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,15 +482,15 @@ def test_cv2_raise_error(self):
self.assertRaises(ValueError, statistics.cv2, np.array([seq, seq]))


class InstantaneousRateTest(unittest.TestCase):
class InstantaneousRateTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
"""
Run once before tests:
"""

block = _create_trials_block(n_trials=36)
block = _create_trials_block(n_trials=36, n_spiketrains=5)
cls.block = block
cls.trial_object = TrialsFromBlock(block,
description='trials are segments')
Expand Down Expand Up @@ -988,6 +988,42 @@ def test_instantaneous_rate_trials_pool_trials(self):
pool_spike_trains=False,
pool_trials=True)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])

def test_instantaneous_rate_trials_pool_spiketrains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, list)
self.assertEqual(len(rate), self.trial_object.n_trials)
self.assertEqual(rate[0].shape[1], 1)

def test_instantaneous_rate_trials_pool_spiketrains_pool_trials(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=True)
self.assertIsInstance(rate, neo.AnalogSignal)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_trials_pool_spiketrains_false_pool_trials_false(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=False,
pool_trials=False)
self.assertIsInstance(rate, list)
self.assertEqual(len(rate), self.trial_object.n_trials)
self.assertEqual(rate[0].shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])

def test_instantaneous_rate_list_pool_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)
Expand All @@ -999,7 +1035,7 @@ def test_instantaneous_rate_list_pool_spike_trains(self):
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.magnitude.shape[1], 1)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_list_of_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)
Expand All @@ -1010,7 +1046,7 @@ def test_instantaneous_rate_list_of_spike_trains(self):
pool_spike_trains=False,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.magnitude.shape[1], 2)
self.assertEqual(rate.magnitude.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])


class TimeHistogramTestCase(unittest.TestCase):
Expand Down

0 comments on commit 1757196

Please sign in to comment.