From 7b0ccf92323bcd5fa756162da9084804ce22c414 Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:32:49 +0100 Subject: [PATCH] add type check for pool parameters --- elephant/statistics.py | 14 +++++++++++--- elephant/test/test_statistics.py | 6 ++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index f4a3c4031..9217b7b4c 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -364,23 +364,31 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float: return fano if isinstance(spiketrains, elephant.trials.Trials): + # Check if parameters are of the correct type + if not isinstance(pool_trials, bool): + raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}") + elif not isinstance(pool_spike_trains, bool): + raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}") if not pool_trials and not pool_spike_trains: return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)] for idx in range(spiketrains.n_trials)] - if not pool_trials and pool_spike_trains: + elif not pool_trials and pool_spike_trains: return [_compute_fano(spiketrains.get_spiketrains_from_trial_as_list(idx)) for idx in range(spiketrains.n_trials)] - if pool_trials and not pool_spike_trains: + elif pool_trials and not pool_spike_trains: list_of_lists_of_spiketrains = [ spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no) for trial_no in range(spiketrains.n_trials)] return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no] for trial_no in range(len(list_of_lists_of_spiketrains))]) for st_no in range(len(list_of_lists_of_spiketrains[0]))] - if pool_trials and pool_spike_trains: + elif pool_trials and pool_spike_trains: return [_compute_fano( [spiketrain for trial_no in range(spiketrains.n_trials) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)])] + else: + raise TypeError(f"pool_spiketrains and pool_trials must be of type: bool, but are " + f"{type(pool_spike_trains)} and {type(pool_trials)}") else: # Legacy behavior return _compute_fano(spiketrains) diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index f24adaa25..57ddeb96a 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -375,6 +375,12 @@ def test_fanofactor_trials_pool_trials_false_pool_spiketrains_false(self): for result in results: self.assertEqual(len(result), self.test_trials.n_spiketrains_trial_by_trial[0]) + def test_fanofactor_trials_pool_spike_trains_wrong_type(self): + self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trains="Wrong Type") + self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type") + self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type", + pool_spike_trains="Wrong Type") + class LVTestCase(unittest.TestCase): def setUp(self):