Skip to content

Commit

Permalink
add type check for pool parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Moritz-Alexander-Kern committed Nov 14, 2024
1 parent 1711f43 commit 7b0ccf9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
14 changes: 11 additions & 3 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,23 +364,31 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float:
return fano

if isinstance(spiketrains, elephant.trials.Trials):
# Check if parameters are of the correct type
if not isinstance(pool_trials, bool):
raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}")
elif not isinstance(pool_spike_trains, bool):
raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}")
if not pool_trials and not pool_spike_trains:
return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)]
for idx in range(spiketrains.n_trials)]
if not pool_trials and pool_spike_trains:
elif not pool_trials and pool_spike_trains:
return [_compute_fano(spiketrains.get_spiketrains_from_trial_as_list(idx))
for idx in range(spiketrains.n_trials)]
if pool_trials and not pool_spike_trains:
elif pool_trials and not pool_spike_trains:
list_of_lists_of_spiketrains = [
spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)
for trial_no in range(spiketrains.n_trials)]
return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no]
for trial_no in range(len(list_of_lists_of_spiketrains))])
for st_no in range(len(list_of_lists_of_spiketrains[0]))]
if pool_trials and pool_spike_trains:
elif pool_trials and pool_spike_trains:
return [_compute_fano(
[spiketrain for trial_no in range(spiketrains.n_trials)
for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)])]
else:
raise TypeError(f"pool_spiketrains and pool_trials must be of type: bool, but are "
f"{type(pool_spike_trains)} and {type(pool_trials)}")
else: # Legacy behavior
return _compute_fano(spiketrains)

Expand Down
6 changes: 6 additions & 0 deletions elephant/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,12 @@ def test_fanofactor_trials_pool_trials_false_pool_spiketrains_false(self):
for result in results:
self.assertEqual(len(result), self.test_trials.n_spiketrains_trial_by_trial[0])

def test_fanofactor_trials_pool_spike_trains_wrong_type(self):
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trains="Wrong Type")
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type")
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type",
pool_spike_trains="Wrong Type")


class LVTestCase(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 7b0ccf9

Please sign in to comment.