From 1925bf306b9876cbae28a324ac9eda7574656cdd Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:41:29 +0100 Subject: [PATCH] refactor type annotations --- elephant/statistics.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 9217b7b4c..745453bb0 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -270,8 +270,9 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): return rates -def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray, elephant.trials.Trials], - warn_tolerance:pq.Quantity=0.1 * pq.ms, pool_trials:bool=False, pool_spike_trains:bool=False): +def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray], elephant.trials.Trials], + warn_tolerance: pq.Quantity = 0.1 * pq.ms, pool_trials: bool = False, pool_spike_trains: bool = False + ) -> Union[float, List[float], List[List[float]]]: r""" Evaluates the empirical Fano factor F of the spike counts of a list of `neo.SpikeTrain` objects or `elephant.trials.Trial` object. @@ -296,7 +297,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray, pool_trials and pool_spike_trains parameters. warn_tolerance : pq.Quantity In case of a list of input neo.SpikeTrains, if their durations vary by - more than `warn_tolerence` in their absolute values, throw a warning + more than `warn_tolerance` in their absolute values, throw a warning (see Notes). Default: 0.1 ms pool_trials : bool, optional @@ -308,7 +309,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray, Returns ------- - fano : float or list of floats + fano : float, list of floats or list of list of floats The Fano factor of the spike counts of the input spike trains. Returns np.NaN if an empty list is specified, or if all spike trains are empty. If a `Trial` object is provided, returns a list of Fano @@ -338,7 +339,15 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray, 0.07142857142857142 """ - def _compute_fano(spiketrains: neo.SpikeTrain) -> float: + # Check if parameters are of the correct type + if not isinstance(pool_trials, bool): + raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}") + elif not isinstance(pool_spike_trains, bool): + raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}") + elif not is_time_quantity(warn_tolerance): + raise TypeError("'warn_tolerance' must be a time quantity.") + + def _compute_fano(spiketrains: List[neo.SpikeTrain]) -> float: # Build array of spike counts (one per spike train) spike_counts = np.array([len(st) for st in spiketrains]) @@ -348,8 +357,6 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float: return np.nan if all(isinstance(st, neo.SpikeTrain) for st in spiketrains): - if not is_time_quantity(warn_tolerance): - raise TypeError("'warn_tolerance' must be a time quantity.") durations = [(st.t_stop - st.t_start).simplified.item() for st in spiketrains] durations_min = min(durations) @@ -364,11 +371,6 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float: return fano if isinstance(spiketrains, elephant.trials.Trials): - # Check if parameters are of the correct type - if not isinstance(pool_trials, bool): - raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}") - elif not isinstance(pool_spike_trains, bool): - raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}") if not pool_trials and not pool_spike_trains: return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)] for idx in range(spiketrains.n_trials)] @@ -386,9 +388,6 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float: return [_compute_fano( [spiketrain for trial_no in range(spiketrains.n_trials) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)])] - else: - raise TypeError(f"pool_spiketrains and pool_trials must be of type: bool, but are " - f"{type(pool_spike_trains)} and {type(pool_trials)}") else: # Legacy behavior return _compute_fano(spiketrains)