Skip to content

Commit

Permalink
refactor type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
Moritz-Alexander-Kern committed Nov 14, 2024
1 parent 7b0ccf9 commit 1925bf3
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,9 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None):
return rates


def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray, elephant.trials.Trials],
warn_tolerance:pq.Quantity=0.1 * pq.ms, pool_trials:bool=False, pool_spike_trains:bool=False):
def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray], elephant.trials.Trials],
warn_tolerance: pq.Quantity = 0.1 * pq.ms, pool_trials: bool = False, pool_spike_trains: bool = False
) -> Union[float, List[float], List[List[float]]]:
r"""
Evaluates the empirical Fano factor F of the spike counts of
a list of `neo.SpikeTrain` objects or `elephant.trials.Trial` object.
Expand All @@ -296,7 +297,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray,
pool_trials and pool_spike_trains parameters.
warn_tolerance : pq.Quantity
In case of a list of input neo.SpikeTrains, if their durations vary by
more than `warn_tolerence` in their absolute values, throw a warning
more than `warn_tolerance` in their absolute values, throw a warning
(see Notes).
Default: 0.1 ms
pool_trials : bool, optional
Expand All @@ -308,7 +309,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray,
Returns
-------
fano : float or list of floats
fano : float, list of floats or list of list of floats
The Fano factor of the spike counts of the input spike trains.
Returns np.NaN if an empty list is specified, or if all spike trains
are empty. If a `Trial` object is provided, returns a list of Fano
Expand Down Expand Up @@ -338,7 +339,15 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray,
0.07142857142857142
"""
def _compute_fano(spiketrains: neo.SpikeTrain) -> float:
# Check if parameters are of the correct type
if not isinstance(pool_trials, bool):
raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}")
elif not isinstance(pool_spike_trains, bool):
raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}")
elif not is_time_quantity(warn_tolerance):
raise TypeError("'warn_tolerance' must be a time quantity.")

def _compute_fano(spiketrains: List[neo.SpikeTrain]) -> float:
# Build array of spike counts (one per spike train)
spike_counts = np.array([len(st) for st in spiketrains])

Expand All @@ -348,8 +357,6 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float:
return np.nan

if all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
if not is_time_quantity(warn_tolerance):
raise TypeError("'warn_tolerance' must be a time quantity.")
durations = [(st.t_stop - st.t_start).simplified.item()
for st in spiketrains]
durations_min = min(durations)
Expand All @@ -364,11 +371,6 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float:
return fano

if isinstance(spiketrains, elephant.trials.Trials):
# Check if parameters are of the correct type
if not isinstance(pool_trials, bool):
raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}")
elif not isinstance(pool_spike_trains, bool):
raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}")
if not pool_trials and not pool_spike_trains:
return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)]
for idx in range(spiketrains.n_trials)]
Expand All @@ -386,9 +388,6 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float:
return [_compute_fano(
[spiketrain for trial_no in range(spiketrains.n_trials)
for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)])]
else:
raise TypeError(f"pool_spiketrains and pool_trials must be of type: bool, but are "
f"{type(pool_spike_trains)} and {type(pool_trials)}")
else: # Legacy behavior
return _compute_fano(spiketrains)

Expand Down

0 comments on commit 1925bf3

Please sign in to comment.