Skip to content

Commit

Permalink
add user warning and did refactoring of function
Browse files Browse the repository at this point in the history
  • Loading branch information
Moritz-Alexander-Kern committed Nov 14, 2024
1 parent 3c1eb5d commit 6bb2fc3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 21 deletions.
44 changes: 23 additions & 21 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[
spiketrains : list or elephant.trials.Trial
List of `neo.SpikeTrain` or `pq.Quantity` or `np.ndarray` or list of
spike times for which to compute the Fano factor of spike counts, or
an `elephant.trials.Trial` object, here the behavior can be controlled with the
an `elephant.trials.Trial` object, here the behavior can be controlled with the
pool_trials and pool_spike_trains parameters.
warn_tolerance : pq.Quantity
In case of a list of input neo.SpikeTrains, if their durations vary by
Expand Down Expand Up @@ -325,7 +325,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[
Notes
-----
The check for the equal duration of the input spike trains is performed
only if the input is of type`neo.SpikeTrain`: if you pass a numpy array,
only if the input is of type`neo.SpikeTrain`: if you pass e.g. a numpy array,
please make sure that they all have the same duration manually.
Examples
Expand All @@ -346,30 +346,32 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[
elif not isinstance(pool_spike_trains, bool):
raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}")
elif not is_time_quantity(warn_tolerance):
raise TypeError("'warn_tolerance' must be a time quantity.")
raise TypeError(f"'warn_tolerance' must be a time quantity, but got {type(warn_tolerance)}")

def _check_input_spiketrains_durations(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity],
List[np.ndarray]]) -> None:
if spiketrains and all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
durations = np.array(tuple(st.duration for st in spiketrains))
if np.max(durations) - np.min(durations) > warn_tolerance:
warnings.warn(f"Fano factor calculated for spike trains of "
f"different duration (minimum: {np.min(durations)}s, maximum "
f"{np.max(durations)}s).")
else:
warnings.warn(f"Spiketrains was of type {type(spiketrains)}, which does not support automatic duration"
f"check. The parameter 'warn_tolerance' will have no effect. Please ensure manually that"
f"all spike trains have the same duration.")

def _compute_fano(spiketrains: List[neo.SpikeTrain]) -> float:
def _compute_fano(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray]]) -> float:
# Check spike train durations
_check_input_spiketrains_durations(spiketrains)
# Build array of spike counts (one per spike train)
spike_counts = np.array([len(st) for st in spiketrains])

spike_counts = np.array(tuple(len(st) for st in spiketrains))
# Compute FF
if all(count == 0 for count in spike_counts):
if np.all(np.array(spike_counts) == 0):
# empty list of spiketrains reaches this branch, and NaN is returned
return np.nan

if all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
durations = [(st.t_stop - st.t_start).simplified.item()
for st in spiketrains]
durations_min = min(durations)
durations_max = max(durations)
if durations_max - durations_min > warn_tolerance.simplified.item():
warnings.warn("Fano factor calculated for spike trains of "
"different duration (minimum: {_min}s, maximum "
"{_max}s).".format(_min=durations_min,
_max=durations_max))

fano = spike_counts.var() / spike_counts.mean()
return fano
else:
return spike_counts.var()/spike_counts.mean()

if isinstance(spiketrains, elephant.trials.Trials):
if not pool_trials and not pool_spike_trains:
Expand Down
5 changes: 5 additions & 0 deletions elephant/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,11 @@ def test_fanofactor_trials_pool_spike_trains_wrong_type(self):
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type",
pool_spike_trains="Wrong Type")

def test_fanofactor_warn_durations_manual_check(self):
st1 = [1, 2, 3] * pq.s
st2 = [1, 2, 3] * pq.s
self.assertWarns(UserWarning, statistics.fanofactor, (st1, st2))


class LVTestCase(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 6bb2fc3

Please sign in to comment.