From 6eec2635c9baaa7b1d23344070cb20a7f49b8624 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 17 Sep 2024 17:15:47 +0200 Subject: [PATCH 1/2] Switch to 88 line length --- src/pyparrm/_utils/_plotting.py | 207 +++++++++++------------------- src/pyparrm/_utils/_power.py | 16 ++- src/pyparrm/parrm.py | 217 ++++++++++++-------------------- tests/test_parrm.py | 178 +++++++------------------- 4 files changed, 205 insertions(+), 413 deletions(-) diff --git a/src/pyparrm/_utils/_plotting.py b/src/pyparrm/_utils/_plotting.py index 9711edb..634317f 100644 --- a/src/pyparrm/_utils/_plotting.py +++ b/src/pyparrm/_utils/_plotting.py @@ -19,30 +19,29 @@ class _ExploreParams: Parameters ---------- parrm : pyparrm.PARRM - PARRM object containing the data for which filter parameters should be - explored. + PARRM object containing the data for which filter parameters should be explored. time_range : list of int or float | None (default None) - Range of the times to plot and filter in a list of length two, - containing the first and last timepoints, respectively, in seconds. If - ``None``, all timepoints are used. + Range of the times to plot and filter in a list of length two, containing the + first and last timepoints, respectively, in seconds. If ``None``, all timepoints + are used. time_res : int | float (default 0.01) Time resolution, in seconds, to use when plotting the time-series data. freq_range : list of int or float | None (default None) - Range of the frequencies to plot in a list of length two, containing - the first and last frequencies, respectively, in Hz. If ``None``, all - frequencies are used. + Range of the frequencies to plot in a list of length two, containing the first + and last frequencies, respectively, in Hz. If ``None``, all frequencies are + used. freq_res : int | float (default 5.0) - Frequency resolution, in Hz, to use when computing the power spectra of - the data. 
+ Frequency resolution, in Hz, to use when computing the power spectra of the + data. n_jobs : int (default 1) - Number of jobs to run in parallel when computing power spectra. Must be - less than the number of available CPUs and greater than 0 (unless it is - -1, in which case all available CPUs are used). + Number of jobs to run in parallel when computing power spectra. Must be less + than the number of available CPUs and greater than 0 (unless it is -1, in which + case all available CPUs are used). Methods ------- @@ -110,9 +109,8 @@ def _check_sort_init_inputs( ) -> None: """Check and sort init. inputs.""" assert parrm._period is not None, ( - "PyPARRM Internal Error: `_ParamSelection` should only be called " - "if the period has been estimated. Please contact the PyPARRM " - "developers." + "PyPARRM Internal Error: `_ParamSelection` should only be called if the " + "period has been estimated. Please contact the PyPARRM developers." ) self.parrm = deepcopy(parrm) self.parrm._verbose = False @@ -128,12 +126,10 @@ def _check_sort_init_inputs( raise ValueError("`time_range` must have a length of 2.") if ( time_range[0] < 0 - or time_range[1] - > self.parrm._n_samples / self.parrm._sampling_freq + or time_range[1] > self.parrm._n_samples / self.parrm._sampling_freq ): raise ValueError( - "Entries of `time_range` must lie in the range [0, " - "max. time]." + "Entries of `time_range` must lie in the range [0, max. time]." ) if time_range[0] >= time_range[1]: raise ValueError("`time_range[1]` must be > `time_range[0]`.") @@ -147,13 +143,8 @@ def _check_sort_init_inputs( # time_res if not isinstance(time_res, (int, float)): raise TypeError("`time_res` must be an int or a float.") - if ( - time_res <= 0 - or time_res >= self.time_range[-1] / self.parrm._sampling_freq - ): - raise ValueError( - "`time_res` must lie in the range (0, max. time)." 
- ) + if time_res <= 0 or time_res >= self.time_range[-1] / self.parrm._sampling_freq: + raise ValueError("`time_res` must lie in the range (0, max. time).") self.time_res = time_res self.decim = int(np.ceil(self.time_res * self.parrm._sampling_freq)) @@ -168,8 +159,7 @@ def _check_sort_init_inputs( raise ValueError("`freq_range` must have a length of 2.") if freq_range[0] <= 0 or freq_range[1] > self.parrm._sampling_freq / 2: raise ValueError( - "Entries of `freq_range` must lie in the range (0, " - "Nyquist frequency]." + "Entries of `freq_range` must lie in the range (0, Nyquist frequency]." ) if freq_range[0] >= freq_range[1]: raise ValueError("`freq_range[1]` must be > `freq_range[0]`.") @@ -179,18 +169,14 @@ def _check_sort_init_inputs( if not isinstance(freq_res, (int, float)): raise TypeError("`freq_res` must be an int or a float.") if freq_res <= 0 or freq_res > self.parrm._sampling_freq / 2: - raise ValueError( - "`freq_res` must lie in the range (0, Nyquist frequency]." - ) + raise ValueError("`freq_res` must lie in the range (0, Nyquist frequency].") self.freq_res = deepcopy(freq_res) # n_jobs if not isinstance(n_jobs, int): raise TypeError("`n_jobs` must be an int.") if n_jobs > cpu_count(): - raise ValueError( - "`n_jobs` must be <= the number of available CPUs." 
- ) + raise ValueError("`n_jobs` must be <= the number of available CPUs.") if n_jobs <= 0 and n_jobs != -1: raise ValueError("If `n_jobs` is <= 0, it must be -1.") if n_jobs == -1: @@ -204,8 +190,7 @@ def _check_sort_init_inputs( self.current_omit_n_samples = self.parrm._omit_n_samples self.current_sample_period_xvals = np.mod( - np.arange(self.current_filter_half_width * 2 - 1), - self.parrm._period, + np.arange(self.current_filter_half_width * 2 - 1), self.parrm._period ) self.current_channel_idx = 0 self.current_sample_period_yvals = np.diff( @@ -217,8 +202,7 @@ def _check_sort_init_inputs( def _initialise_parrm_data_info(self) -> None: """Initialise information from PARRM data for plotting.""" self.largest_sample_period_xvals = np.mod( - np.arange((self.parrm._n_samples // 2) - 1), - self.parrm._period, + np.arange((self.parrm._n_samples // 2) - 1), self.parrm._period ) self.largest_sample_period_xvals_range = ( self.largest_sample_period_xvals.max() @@ -226,9 +210,7 @@ def _initialise_parrm_data_info(self) -> None: ) # filtered data info. - self.times = (self.time_range / self.parrm._sampling_freq)[ - :: self.decim - ] + self.times = (self.time_range / self.parrm._sampling_freq)[:: self.decim] # freq data info. 
self.fft_n_points = int(self.parrm._sampling_freq // self.freq_res) @@ -263,12 +245,8 @@ def update_period_half_width(half_width: str) -> None: def update_filter_half_width(half_width: str) -> None: """Update filter half width according to the textbox.""" half_width = int(half_width) - half_width = int( - np.max((half_width, self.filter_half_width_limits[0])) - ) - half_width = int( - np.min((half_width, self.filter_half_width_limits[1])) - ) + half_width = int(np.max((half_width, self.filter_half_width_limits[0]))) + half_width = int(np.min((half_width, self.filter_half_width_limits[1]))) self.textbox_filter_half_width.set_val(str(half_width)) self.current_filter_half_width = half_width @@ -289,9 +267,7 @@ def update_omit_n_samples(n_samples: str) -> None: self.current_omit_n_samples = n_samples if n_samples >= self.current_filter_half_width: - self.slider_omit_n_samples.set_val( - self.current_filter_half_width - 1 - ) + self.slider_omit_n_samples.set_val(self.current_filter_half_width - 1) return self._update_suptitle() self._update_filter() @@ -329,53 +305,39 @@ def _initialise_plot(self) -> None: self.figure.set_layout_engine(None) # stop updates to layout plt.ioff() # no longer needed - self.figure.canvas.mpl_connect( - "key_press_event", self._check_key_event - ) + self.figure.canvas.mpl_connect("key_press_event", self._check_key_event) # samples in period space focused plot self.sample_period_focused_axis = axes["upper left"] - self.sample_period_focused_scatter = ( - self.sample_period_focused_axis.scatter( - self.current_sample_period_xvals, - self.current_sample_period_yvals, - marker="o", - edgecolors="#1f77b4", - facecolors="none", - ) - ) - self.sample_period_focused_axis.set_xlim( - (0, self.current_period_half_width) - ) + self.sample_period_focused_scatter = self.sample_period_focused_axis.scatter( + self.current_sample_period_xvals, + self.current_sample_period_yvals, + marker="o", + edgecolors="#1f77b4", + facecolors="none", + ) + 
self.sample_period_focused_axis.set_xlim((0, self.current_period_half_width)) self._update_sample_period_focused_ylim() - self.sample_period_focused_axis.set_xlabel( - "Sample-period modulus (A.U.)" - ) + self.sample_period_focused_axis.set_xlabel("Sample-period modulus (A.U.)") self.sample_period_focused_axis.set_ylabel("Amplitude (data units)") # samples in period space overview plot self.sample_period_overview_axis = axes["upper inner"] - self.sample_period_overview_scatter = ( - self.sample_period_overview_axis.scatter( - self.current_sample_period_xvals, - self.current_sample_period_yvals, - marker=".", - s=1, - edgecolors="#1f77b4", - alpha=0.5, - ) + self.sample_period_overview_scatter = self.sample_period_overview_axis.scatter( + self.current_sample_period_xvals, + self.current_sample_period_yvals, + marker=".", + s=1, + edgecolors="#1f77b4", + alpha=0.5, ) self.sample_period_overview_axis.set_xlim( self.sample_period_overview_axis.get_xlim() ) - self.sample_period_focus_highlight = ( - self.sample_period_overview_axis.axvspan( - 0, self.current_period_half_width, color="red", alpha=0.2 - ) - ) - self.sample_period_overview_axis.set_xlabel( - "Sample-period modulus (A.U.)" + self.sample_period_focus_highlight = self.sample_period_overview_axis.axvspan( + 0, self.current_period_half_width, color="red", alpha=0.2 ) + self.sample_period_overview_axis.set_xlabel("Sample-period modulus (A.U.)") self.sample_period_overview_axis.set_ylabel("Amplitude (data units)") self.sample_period_overview_axis.set_title( r"$\Longleftarrow$ navigate with the arrow keys $\Longrightarrow$" @@ -431,9 +393,7 @@ def _initialise_plot(self) -> None: )[0] self.freq_data_axis.set_xlabel("Log frequency (Hz)") self.freq_data_axis.set_ylabel("Log power (dB/Hz)") - self.freq_data_axis.legend( - loc="upper left", bbox_to_anchor=(0.7, 1.22) - ) + self.freq_data_axis.legend(loc="upper left", bbox_to_anchor=(0.7, 1.22)) def _initialise_widgets(self) -> None: """Initialise widgets to use on the 
plot.""" @@ -454,10 +414,7 @@ def _initialise_widgets(self) -> None: textalignment="center", ) - self.filter_half_width_limits = [ - 1, - int((self.parrm._n_samples - 1) / 2), - ] + self.filter_half_width_limits = [1, int((self.parrm._n_samples - 1) / 2)] self.textbox_filter_half_width = TextBox( self.figure.add_axes((0.32, 0.09, 0.15, 0.03)), f"Filter half-width [{self.filter_half_width_limits[0]} - " @@ -466,10 +423,7 @@ def _initialise_widgets(self) -> None: textalignment="center", ) - self.omit_n_samples_limits = [ - 0, - int(((self.parrm._n_samples - 1) / 2) - 1), - ] + self.omit_n_samples_limits = [0, int(((self.parrm._n_samples - 1) / 2) - 1)] self.textbox_omit_n_samples = TextBox( self.figure.add_axes((0.32, 0.06, 0.15, 0.03)), f"Omitted samples [{self.omit_n_samples_limits[0]} - " @@ -478,9 +432,7 @@ def _initialise_widgets(self) -> None: textalignment="center", ) - buttons_filter_direction_axis = self.figure.add_axes( - (0.07, 0.06, 0.05, 0.1) - ) + buttons_filter_direction_axis = self.figure.add_axes((0.07, 0.06, 0.05, 0.1)) buttons_filter_direction_axis.set_title("Filter direction") self.buttons_filter_direction = RadioButtons( buttons_filter_direction_axis, @@ -523,8 +475,7 @@ def _update_period_window(self, step: float) -> None: def _update_sample_period_vals_plots(self) -> None: """Update values and plots of samples in period space.""" self.current_sample_period_xvals = np.mod( - np.arange(self.current_filter_half_width * 2 - 1), - self.parrm._period, + np.arange(self.current_filter_half_width * 2 - 1), self.parrm._period ) self.current_sample_period_yvals = np.diff( self.parrm._data[ @@ -533,26 +484,22 @@ def _update_sample_period_vals_plots(self) -> None: ) self.sample_period_focused_scatter.remove() - self.sample_period_focused_scatter = ( - self.sample_period_focused_axis.scatter( - self.current_sample_period_xvals, - self.current_sample_period_yvals, - marker="o", - edgecolors="#1f77b4", - facecolors="none", - ) + 
self.sample_period_focused_scatter = self.sample_period_focused_axis.scatter( + self.current_sample_period_xvals, + self.current_sample_period_yvals, + marker="o", + edgecolors="#1f77b4", + facecolors="none", ) self.sample_period_overview_scatter.remove() - self.sample_period_overview_scatter = ( - self.sample_period_overview_axis.scatter( - self.current_sample_period_xvals, - self.current_sample_period_yvals, - marker=".", - s=1, - edgecolors="#1f77b4", - alpha=0.5, - ) + self.sample_period_overview_scatter = self.sample_period_overview_axis.scatter( + self.current_sample_period_xvals, + self.current_sample_period_yvals, + marker=".", + s=1, + edgecolors="#1f77b4", + alpha=0.5, ) def _change_channel(self, step: int) -> None: @@ -571,9 +518,7 @@ def _update_sample_period_focused_xlim_position(self, step: float) -> None: step = 0 - xlim[0] if xlim[1] + step > self.current_sample_period_xvals.max(): step = self.current_sample_period_xvals.max() - xlim[1] - self.sample_period_focused_axis.set_xlim( - (xlim[0] + step, xlim[1] + step) - ) + self.sample_period_focused_axis.set_xlim((xlim[0] + step, xlim[1] + step)) def _update_sample_period_focused_xlim_width(self, width: float) -> None: """Update width of xlim of sample-period space focused plot.""" @@ -604,10 +549,8 @@ def _update_sample_period_focus_highlight(self) -> None: """Update shaded area displaying current period window.""" xlim = self.sample_period_focused_axis.get_xlim() self.sample_period_focus_highlight.remove() # clear old patch - self.sample_period_focus_highlight = ( - self.sample_period_overview_axis.axvspan( - xlim[0], xlim[1], color="red", alpha=0.2 - ) + self.sample_period_focus_highlight = self.sample_period_overview_axis.axvspan( + xlim[0], xlim[1], color="red", alpha=0.2 ) def _update_suptitle(self) -> None: @@ -682,9 +625,7 @@ def _update_filtered_data_lines(self) -> None: # timeseries data self.filtered_data_line_time = self.time_data_axis.plot( self.times, - 
self.filtered_data_time[self.current_channel_idx][ - :: self.decim - ], + self.filtered_data_time[self.current_channel_idx][:: self.decim], linewidth=0.5, color="#ff7f0e", label="Filtered data", @@ -730,16 +671,12 @@ def _update_sample_period_overview_ylim(self) -> None: """Update ylim of sample-period modulus overview plot.""" self.sample_period_focus_highlight.remove() # highlight affects ylim self.sample_period_overview_axis.relim() - self.sample_period_overview_axis.autoscale_view( - scalex=False, scaley=True - ) + self.sample_period_overview_axis.autoscale_view(scalex=False, scaley=True) # restore highlight for new ylim xlim = self.sample_period_focused_axis.get_xlim() ylim = self.sample_period_overview_axis.get_ylim() - self.sample_period_focus_highlight = ( - self.sample_period_overview_axis.axvspan( - xlim[0], xlim[1], ylim[0], ylim[1], color="red", alpha=0.2 - ) + self.sample_period_focus_highlight = self.sample_period_overview_axis.axvspan( + xlim[0], xlim[1], ylim[0], ylim[1], color="red", alpha=0.2 ) self._update_sample_period_focus_highlight() diff --git a/src/pyparrm/_utils/_power.py b/src/pyparrm/_utils/_power.py index 0b8d163..8e4835f 100644 --- a/src/pyparrm/_utils/_power.py +++ b/src/pyparrm/_utils/_power.py @@ -25,12 +25,12 @@ def compute_psd( Sampling frequency, in Hz, of `data`. n_points : int - Number of points to use when computing the Fourier coefficients. Should - be double the desired number of frequencies in the power spectra. + Number of points to use when computing the Fourier coefficients. Should be + double the desired number of frequencies in the power spectra. max_freq : int | float | None (default None) - The maximum frequency that should be returned. If :obj:`None`, values - for all computed frequencies returned. + The maximum frequency that should be returned. If :obj:`None`, values for all + computed frequencies returned. n_jobs : int (default ``1``) Number of jobs to run in parallel. 
@@ -49,12 +49,10 @@ def compute_psd( Data is converted to, and power is returned as, float32 values for speed. - As `data` is assumed to be real-valued, only positive frequencies are - returned. The zero frequency is also discarded. + As `data` is assumed to be real-valued, only positive frequencies are returned. The + zero frequency is also discarded. """ - freqs = np.abs( - fftfreq(n_points, 1.0 / sampling_freq)[1 : (n_points // 2) + 1] - ) + freqs = np.abs(fftfreq(n_points, 1.0 / sampling_freq)[1 : (n_points // 2) + 1]) if max_freq is None: max_freq = freqs[-1] max_freq_i = np.argwhere(freqs <= max_freq)[-1][0] diff --git a/src/pyparrm/parrm.py b/src/pyparrm/parrm.py index f57d381..2564649 100644 --- a/src/pyparrm/parrm.py +++ b/src/pyparrm/parrm.py @@ -20,10 +20,10 @@ class PARRM: """Class for removing stimulation artefacts from data using PARRM. - The Period-based Artefact Reconstruction and Removal Method (PARRM) is - described in Dastin-van Rijn *et al.* (2021) :footcite:`DastinEtAl2021`. - PARRM assumes that the artefacts are semi-regular, periodic, and linearly - combined with the signal of interest. + The Period-based Artefact Reconstruction and Removal Method (PARRM) is described in + Dastin-van Rijn *et al.* (2021) :footcite:`DastinEtAl2021`. PARRM assumes that the + artefacts are semi-regular, periodic, and linearly combined with the signal of + interest. The methods should be called in the following order: 1. :meth:`find_period` @@ -33,8 +33,7 @@ class PARRM: Parameters ---------- data : ~numpy.ndarray, shape of [channels, times] - Time-series from which stimulation artefacts should be identified and - removed. + Time-series from which stimulation artefacts should be identified and removed. sampling_freq : int | float Sampling frequency of :attr:`data`, in Hz. 
@@ -110,7 +109,7 @@ def __init__( verbose: bool = True, ) -> None: # noqa D107 self._check_init_inputs(data, sampling_freq, artefact_freq, verbose) - (self._n_chans, self._n_samples) = self._data.shape + self._n_chans, self._n_samples = self._data.shape def _check_init_inputs( self, @@ -165,22 +164,21 @@ def find_period( :obj:`None`, all samples are used. assumed_periods : int | float | tuple[int or float] | None (default None) - Guess(es) of the artefact period. If :obj:`None`, the period is - assumed to be ``sampling_freq`` / ``artefact_freq``. + Guess(es) of the artefact period. If :obj:`None`, the period is assumed to + be ``sampling_freq`` / ``artefact_freq``. outlier_boundary : int | float (default 3.0) - Boundary (in standard deviation) to consider outlier values in - :attr:`data`. + Boundary (in standard deviation) to consider outlier values in :attr:`data`. random_seed: int | None (default None) - Seed to use when generating indices of samples to search for the - period. Only used if the number of available samples is less than - the number of requested samples. + Seed to use when generating indices of samples to search for the period. + Only used if the number of available samples is less than the number of + requested samples. n_jobs : int (default 1) - Number of jobs to run in parallel when optimising the period - estimates. Must lie in the range [1, number of CPUs] (unless it is - -1, in which case all available CPUs are used). + Number of jobs to run in parallel when optimising the period estimates. Must + lie in the range [1, number of CPUs] (unless it is -1, in which case all + available CPUs are used). 
""" # noqa E501 if self._verbose: print("\nFinding the artefact period...") @@ -188,11 +186,7 @@ def find_period( self._reset_result_attrs() self._check_sort_find_stim_period_inputs( - search_samples, - assumed_periods, - outlier_boundary, - random_seed, - n_jobs, + search_samples, assumed_periods, outlier_boundary, random_seed, n_jobs ) self._standardise_data() @@ -227,9 +221,7 @@ def _check_sort_find_stim_period_inputs( n_jobs: int, ) -> None: """Check and sort `find_stim_period` inputs.""" - if search_samples is not None and not isinstance( - search_samples, np.ndarray - ): + if search_samples is not None and not isinstance(search_samples, np.ndarray): raise TypeError("`search_samples` must be a NumPy array or None.") if search_samples is None: search_samples = np.arange(self._n_samples - 1) @@ -238,8 +230,7 @@ def _check_sort_find_stim_period_inputs( search_samples = np.sort(search_samples) if search_samples[0] < 0 or search_samples[-1] >= self._n_samples: raise ValueError( - "Entries of `search_samples` must lie in the range [0, " - "n_samples)." + "Entries of `search_samples` must lie in the range [0, n_samples)." ) self._search_samples = search_samples.copy() @@ -250,17 +241,12 @@ def _check_sort_find_stim_period_inputs( "`assumed_periods` must be an int, a float, a tuple, or None." ) if assumed_periods is None: - assumed_periods = tuple( - [self._sampling_freq / self._artefact_freq] - ) + assumed_periods = tuple([self._sampling_freq / self._artefact_freq]) elif isinstance(assumed_periods, (int, float)): assumed_periods = tuple([assumed_periods]) - elif not all( - isinstance(entry, (int, float)) for entry in assumed_periods - ): + elif not all(isinstance(entry, (int, float)) for entry in assumed_periods): raise TypeError( - "If a tuple, entries of `assumed_periods` must be ints or " - "floats." + "If a tuple, entries of `assumed_periods` must be ints or floats." 
) self._assumed_periods = deepcopy(assumed_periods) @@ -278,9 +264,7 @@ def _check_sort_find_stim_period_inputs( if not isinstance(n_jobs, int): raise TypeError("`n_jobs` must be an int.") if n_jobs > cpu_count(): - raise ValueError( - "`n_jobs` must be <= the number of available CPUs." - ) + raise ValueError("`n_jobs` must be <= the number of available CPUs.") if n_jobs <= 0 and n_jobs != -1: raise ValueError("If `n_jobs` is <= 0, it must be -1.") if n_jobs == -1: @@ -334,8 +318,8 @@ def _optimise_period_estimate(self) -> None: if np.isnan(estimated_period[0]): raise ValueError( - "The period cannot be estimated from the data. Check that " - "your data does not contain NaNs." + "The period cannot be estimated from the data. Check that your data " + "does not contain NaNs." ) self._period = self._optimise_period_estimate_final_run( @@ -359,8 +343,8 @@ def _get_centre_indices( Portion of the data segment to ignore when getting the indices. random_state : numpy.random.RandomState - Random state object to use to generate numbers if the available - number of samples is less than that requested. + Random state object to use to generate numbers if the available number of + samples is less than that requested. Returns ------- @@ -385,9 +369,7 @@ def _get_centre_indices( return ( np.unique( random_state.randint( - 0, - end_idx - start_idx, - np.min((use_n_samples, end_idx - start_idx)), + 0, end_idx - start_idx, np.min((use_n_samples, end_idx - start_idx)) ) ) + start_idx @@ -425,11 +407,7 @@ def _get_possible_periods( return np.unique(periods) def _optimise_period_estimate_first_run( - self, - periods: np.ndarray, - indices: np.ndarray, - bandwidth: int, - lambda_: float, + self, periods: np.ndarray, indices: np.ndarray, bandwidth: int, lambda_: float ) -> tuple[np.ndarray, np.ndarray]: """Perform initial period estimate optimisation run. @@ -453,8 +431,8 @@ def _optimise_period_estimate_first_run( Periods in ascending order according to the fit error. 
fit_error : numpy.ndarray, shape of [periods] - Error of the fit between the data and the sinusoidal harmonics - computed from the period. + Error of the fit between the data and the sinusoidal harmonics computed from + the period. """ optimise_local_args = [ { @@ -482,8 +460,8 @@ def _optimise_period_estimate_first_run( periods = periods[min_fit_error_idcs[fit_error != np.inf]] if periods.shape == (0,): # if no valid periods raise ValueError( - "The period cannot be estimated from the data. Check " - "that your data does not contain NaNs." + "The period cannot be estimated from the data. Check that your data " + "does not contain NaNs." ) return periods, fit_error @@ -504,8 +482,7 @@ def _optimise_period_estimate_second_run( Possible periods of the artefact. fit_errors : numpy.ndarray, shape of [periods] - Fit errors between the sinusoidal harmonics and the data for each - period. + Fit errors between the sinusoidal harmonics and the data for each period. indices : numpy.ndarray, shape of [samples] Sample indices of the data to use. @@ -570,12 +547,7 @@ def _optimise_period_estimate_final_run( return fmin( self._optimise_local, period, - ( - self._standard_data, - indices, - bandwidth, - 0.0, # lambda - ), + (self._standard_data, indices, bandwidth, 0.0), # <-- lambda = 0.0 disp=False, )[0] @@ -602,13 +574,13 @@ def _optimise_local( Returns ------- fit_error : float - Error of the fit between the data and the sinusoidal harmonics - computed from the period, averaged across channels. + Error of the fit between the data and the sinusoidal harmonics computed from + the period, averaged across channels. Notes ----- - ``period`` must be the first input argument for scipy.optimize.fmin to - work on this function. + ``period`` must be the first input argument for scipy.optimize.fmin to work on + this function. 
""" fit_error = 0.0 @@ -627,11 +599,7 @@ def _optimise_local( return fit_error / self._n_chans def _fit_waves_to_data( - self, - data: np.ndarray, - indices: np.ndarray, - period: float, - bandwidth: int, + self, data: np.ndarray, indices: np.ndarray, period: float, bandwidth: int ) -> tuple[np.ndarray, np.ndarray] | tuple[float, float]: """Fit sine and cosine waves to the data. @@ -642,14 +610,13 @@ def _fit_waves_to_data( Returns ------- residuals : numpy.ndarray, shape of [samples] - Squared residuals of the fit between between the data and the - sinusoidal harmonics. A np.inf value is returned if the matrices - are singular. + Squared residuals of the fit between the data and the sinusoidal + harmonics. An np.inf value is returned if the matrices are singular. beta : numpy.ndarray, shape of [2 * `bandwidth` + 1, samples] - Squared beta coefficient of the linear regression between the data - and the sinusoidal harmonics. An np.inf value is returned if the - matrices are singular. + Squared beta coefficient of the linear regression between the data and the + sinusoidal harmonics. An np.inf value is returned if the matrices are + singular. """ angles = (indices + 1) * (2 * np.pi / period) waves = np.ones((data.shape[0], 2 * bandwidth + 1)) @@ -682,27 +649,26 @@ def explore_filter_params( Parameters ---------- time_range : list of int or float | None (default None) - Range of the times to plot and filter in a list of length two, - containing the first and last timepoints, respectively, in seconds. - If :obj:`None`, all timepoints are used. + Range of the times to plot and filter in a list of length two, containing + the first and last timepoints, respectively, in seconds. If :obj:`None`, all + timepoints are used. time_res : int | float (default ``0.01``) - Time resolution, in seconds, to use when plotting the time-series - data. + Time resolution, in seconds, to use when plotting the time-series data. 
freq_range : list of int or float | None (default None) - Range of the frequencies to plot in a list of length two, - containing the first and last frequencies, respectively, in Hz. If - :obj:`None`, all frequencies are used. + Range of the frequencies to plot in a list of length two, containing the + first and last frequencies, respectively, in Hz. If :obj:`None`, all + frequencies are used. freq_res : int | float (default ``5.0``) - Frequency resolution, in Hz, to use when computing the power - spectra of the data. + Frequency resolution, in Hz, to use when computing the power spectra of the + data. n_jobs : int (default ``1``) - Number of jobs to run in parallel when computing the power spectra. - Must lie in the range [1, number of CPUs] (unless it is -1, in - which case all available CPUs are used). + Number of jobs to run in parallel when computing the power spectra. Must lie + in the range [1, number of CPUs] (unless it is -1, in which case all + available CPUs are used). Notes ----- @@ -714,8 +680,8 @@ def explore_filter_params( if self._period is None: raise ValueError( - "The period has not yet been estimated. The `find_period` " - "method must be called first." + "The period has not yet been estimated. The `find_period` method must " + "be called first." ) param_explorer = _ExploreParams( self, time_range, time_res, freq_range, freq_res, n_jobs @@ -737,37 +703,34 @@ def create_filter( Parameters ---------- filter_half_width : int | None (default None) - Half-width of the filter to create, in samples. If :obj:`None`, a - filter half-width will be generated based on ``omit_n_samples``. + Half-width of the filter to create, in samples. If :obj:`None`, a filter + half-width will be generated based on ``omit_n_samples``. omit_n_samples : int (default ``0``) Number of samples to omit from the centre of ``filter_half_width``. 
filter_direction : str (default "both") - Direction from which samples should be taken to create the filter, - relative to the centre of the filter window. Can be: "both" for - backward and forward samples; "past" for backward samples only; and - "future" for forward samples only. + Direction from which samples should be taken to create the filter, relative + to the centre of the filter window. Can be: "both" for backward and forward + samples; "past" for backward samples only; and "future" for forward samples + only. period_half_width : int | float | None (default None) - Half-width of the window in samples of period space for which - points at a similar location in the waveform will be averaged. If - :obj:`None`, :attr:`period` / 50 is used. + Half-width of the window in samples of period space for which points at a + similar location in the waveform will be averaged. If :obj:`None`, + :attr:`period` / 50 is used. """ if self._verbose: print("Creating the filter...") if self._period is None: raise ValueError( - "The period has not yet been estimated. The `find_period` " - "method must be called first." + "The period has not yet been estimated. The `find_period` method must " + "be called first." ) self._check_sort_create_filter_inputs( - filter_half_width, - omit_n_samples, - filter_direction, - period_half_width, + filter_half_width, omit_n_samples, filter_direction, period_half_width ) self._generate_filter() @@ -787,8 +750,7 @@ def _check_sort_create_filter_inputs( raise TypeError("`omit_n_samples` must be an int.") if omit_n_samples < 0 or omit_n_samples >= (self._n_samples - 1) // 2: raise ValueError( - "`omit_n_samples` must lie in the range [0, (no. of samples - " - "1) // 2)." + "`omit_n_samples` must lie in the range [0, (no. of samples - 1) // 2)." 
) self._omit_n_samples = deepcopy(omit_n_samples) @@ -842,9 +804,7 @@ def _get_filter_half_width(self) -> int: def _generate_filter(self) -> None: """Generate linear filter for removing stimulation artefacts.""" - window = np.arange( - -self._filter_half_width, self._filter_half_width + 1 - ) + window = np.arange(-self._filter_half_width, self._filter_half_width + 1) modulus = np.mod(window, self._period) filter_ = np.zeros_like(window, dtype=np.float64) @@ -863,14 +823,12 @@ def _generate_filter(self) -> None: if np.all(filter_ == np.zeros_like(filter_)): raise RuntimeError( - "A suitable filter cannot be created with the specified " - "settings. Try reducing the number of omitted samples and/or " - "increasing the filter half-width." + "A suitable filter cannot be created with the specified settings. Try " + "reducing the number of omitted samples and/or increasing the filter " + "half-width." ) - filter_ = -filter_ / np.max( - (filter_.sum(), np.finfo(filter_.dtype).eps) - ) + filter_ = -filter_ / np.max((filter_.sum(), np.finfo(filter_.dtype).eps)) filter_[window == 0] = 1 @@ -879,8 +837,7 @@ def _generate_filter(self) -> None: def filter_data(self, data: np.ndarray | None = None) -> np.ndarray: """Apply the PARRM filter to the data and return it. - Can only be called after the filter has been created with - :meth:`create_filter`. + Can only be called after the filter has been created with :meth:`create_filter`. Parameters ---------- @@ -897,19 +854,15 @@ def filter_data(self, data: np.ndarray | None = None) -> np.ndarray: if self._filter is None: raise ValueError( - "The filter has not yet been created. The `create_filter` " - "method must be called first." + "The filter has not yet been created. The `create_filter` method must " + "be called first." 
) data = self._check_sort_filter_data_inputs(data) - numerator = ( - convolve(data.T, self._filter[:, np.newaxis], "same") - data.T - ) + numerator = convolve(data.T, self._filter[:, np.newaxis], "same") - data.T denominator = 1 - convolve( - np.ones_like(data).T, - self._filter[:, np.newaxis], - "same", + np.ones_like(data).T, self._filter[:, np.newaxis], "same" ) filtered_data = (numerator / denominator + data.T).T @@ -924,9 +877,7 @@ def filter_data(self, data: np.ndarray | None = None) -> np.ndarray: return self._filtered_data - def _check_sort_filter_data_inputs( - self, data: np.ndarray | None - ) -> np.ndarray: + def _check_sort_filter_data_inputs(self, data: np.ndarray | None) -> np.ndarray: """Check and sort `filter_data` inputs.""" if data is None: data = self._data @@ -967,9 +918,7 @@ def filtered_data(self) -> np.ndarray: def settings(self) -> dict: """Return the settings used to generate the PARRM filter.""" if self._period is None or self._filter is None: - raise AttributeError( - "Analysis settings have not been established yet." 
- ) + raise AttributeError("Analysis settings have not been established yet.") return { "data": { "sampling_freq": self._sampling_freq, diff --git a/tests/test_parrm.py b/tests/test_parrm.py index 00d380f..b97f7ab 100644 --- a/tests/test_parrm.py +++ b/tests/test_parrm.py @@ -34,9 +34,7 @@ def test_parrm(n_chans: int, n_samples: int, verbose: bool, n_jobs: int): verbose=verbose, ) parrm.find_period( - assumed_periods=sampling_freq / artefact_freq, - random_seed=44, - n_jobs=n_jobs, + assumed_periods=sampling_freq / artefact_freq, random_seed=44, n_jobs=n_jobs ) for direction in ["future", "past", "both"]: parrm.create_filter(filter_direction=direction) @@ -88,9 +86,7 @@ def test_parrm_attrs(): settings = parrm.settings assert settings["data"]["sampling_freq"] == sampling_freq assert settings["data"]["artefact_freq"] == artefact_freq - assert np.all( - settings["period"]["search_samples"] == parrm._search_samples - ) + assert np.all(settings["period"]["search_samples"] == parrm._search_samples) assert settings["period"]["assumed_periods"] == parrm._assumed_periods assert settings["period"]["outlier_boundary"] == parrm._outlier_boundary assert settings["period"]["random_seed"] == parrm._random_seed @@ -107,26 +103,12 @@ def test_parrm_wrong_type_inputs(): # init object with pytest.raises(TypeError, match="`data` must be a NumPy array."): PARRM( - data=data.tolist(), - sampling_freq=sampling_freq, - artefact_freq=artefact_freq, - ) - with pytest.raises( - TypeError, match="`sampling_freq` must be an int or a float." - ): - PARRM( - data=data, - sampling_freq=str(sampling_freq), - artefact_freq=artefact_freq, - ) - with pytest.raises( - TypeError, match="`artefact_freq` must be an int or a float." 
- ): - PARRM( - data=data, - sampling_freq=sampling_freq, - artefact_freq=str(artefact_freq), + data=data.tolist(), sampling_freq=sampling_freq, artefact_freq=artefact_freq ) + with pytest.raises(TypeError, match="`sampling_freq` must be an int or a float."): + PARRM(data=data, sampling_freq=str(sampling_freq), artefact_freq=artefact_freq) + with pytest.raises(TypeError, match="`artefact_freq` must be an int or a float."): + PARRM(data=data, sampling_freq=sampling_freq, artefact_freq=str(artefact_freq)) with pytest.raises(TypeError, match="`verbose` must be a bool."): PARRM( data=data, @@ -134,11 +116,7 @@ def test_parrm_wrong_type_inputs(): artefact_freq=artefact_freq, verbose=str(False), ) - parrm = PARRM( - data=data, - sampling_freq=sampling_freq, - artefact_freq=artefact_freq, - ) + parrm = PARRM(data=data, sampling_freq=sampling_freq, artefact_freq=artefact_freq) # find_period with pytest.raises( @@ -146,22 +124,18 @@ def test_parrm_wrong_type_inputs(): ): parrm.find_period(search_samples=0) with pytest.raises( - TypeError, - match="`assumed_periods` must be an int, a float, a tuple, or None.", + TypeError, match="`assumed_periods` must be an int, a float, a tuple, or None." ): parrm.find_period(assumed_periods=[0]) with pytest.raises( - TypeError, - match="If a tuple, entries of `assumed_periods` must be ints or ", + TypeError, match="If a tuple, entries of `assumed_periods` must be ints or " ): parrm.find_period(assumed_periods=tuple(["test"])) with pytest.raises( TypeError, match="`outlier_boundary` must be an int or a float." ): parrm.find_period(outlier_boundary=[0]) - with pytest.raises( - TypeError, match="`random_seed` must be an int or None." 
- ): + with pytest.raises(TypeError, match="`random_seed` must be an int or None."): parrm.find_period(random_seed=1.5) with pytest.raises(TypeError, match="`n_jobs` must be an int."): parrm.find_period(n_jobs=1.5) @@ -176,9 +150,7 @@ def test_parrm_wrong_type_inputs(): TypeError, match="`time_range` must be a list of ints or floats." ): parrm.explore_filter_params(time_range=[0, "end"]) - with pytest.raises( - TypeError, match="`time_res` must be an int or a float." - ): + with pytest.raises(TypeError, match="`time_res` must be an int or a float."): parrm.explore_filter_params(time_res="all") with pytest.raises( TypeError, match="`freq_range` must be a list of ints or floats." @@ -188,9 +160,7 @@ def test_parrm_wrong_type_inputs(): TypeError, match="`freq_range` must be a list of ints or floats." ): parrm.explore_filter_params(freq_range=[0, "Nyquist"]) - with pytest.raises( - TypeError, match="`freq_res` must be an int or a float." - ): + with pytest.raises(TypeError, match="`freq_res` must be an int or a float."): parrm.explore_filter_params(freq_res=[0]) with pytest.raises(TypeError, match="`n_jobs` must be an int."): parrm.explore_filter_params(n_jobs=1.5) @@ -225,36 +195,20 @@ def test_parrm_wrong_value_inputs(): artefact_freq=artefact_freq, ) with pytest.raises(ValueError, match="`sampling_freq` must be > 0."): - PARRM( - data=data, - sampling_freq=0, - artefact_freq=artefact_freq, - ) + PARRM(data=data, sampling_freq=0, artefact_freq=artefact_freq) with pytest.raises(ValueError, match="`artefact_freq` must be > 0."): - PARRM( - data=data, - sampling_freq=sampling_freq, - artefact_freq=0, - ) - parrm = PARRM( - data=data, - sampling_freq=sampling_freq, - artefact_freq=artefact_freq, - ) + PARRM(data=data, sampling_freq=sampling_freq, artefact_freq=0) + parrm = PARRM(data=data, sampling_freq=sampling_freq, artefact_freq=artefact_freq) # find_period - with pytest.raises( - ValueError, match="`search_samples` must be a 1D array." 
- ): + with pytest.raises(ValueError, match="`search_samples` must be a 1D array."): parrm.find_period(search_samples=np.zeros((1, 1))) with pytest.raises( - ValueError, - match="Entries of `search_samples` must lie in the range ", + ValueError, match="Entries of `search_samples` must lie in the range " ): parrm.find_period(search_samples=np.array([-1, 1])) with pytest.raises( - ValueError, - match="Entries of `search_samples` must lie in the range ", + ValueError, match="Entries of `search_samples` must lie in the range " ): parrm.find_period(search_samples=np.array([0, data.shape[1]])) with pytest.raises(ValueError, match="`outlier_boundary` must be > 0."): @@ -263,44 +217,28 @@ def test_parrm_wrong_value_inputs(): ValueError, match="`n_jobs` must be <= the number of available CPUs." ): parrm.find_period(n_jobs=cpu_count() + 1) - with pytest.raises( - ValueError, match="If `n_jobs` is <= 0, it must be -1." - ): + with pytest.raises(ValueError, match="If `n_jobs` is <= 0, it must be -1."): parrm.find_period(n_jobs=-2) parrm.find_period() # explore_filter_params - with pytest.raises( - ValueError, match="`time_range` must have a length of 2." 
- ): + with pytest.raises(ValueError, match="`time_range` must have a length of 2."): parrm.explore_filter_params(time_range=[0, 1, 2]) - with pytest.raises( - ValueError, match="`time_range` must lie in the range " - ): + with pytest.raises(ValueError, match="`time_range` must lie in the range "): parrm.explore_filter_params(time_range=[-1, 1]) - with pytest.raises( - ValueError, match="`time_range` must lie in the range " - ): - parrm.explore_filter_params( - time_range=[0, (data.shape[1] / sampling_freq) + 1] - ) + with pytest.raises(ValueError, match="`time_range` must lie in the range "): + parrm.explore_filter_params(time_range=[0, (data.shape[1] / sampling_freq) + 1]) with pytest.raises(ValueError, match="`time_range"): parrm.explore_filter_params(time_range=[1, 0]) with pytest.raises(ValueError, match="`time_res` must lie in the range "): parrm.explore_filter_params(time_res=0) with pytest.raises(ValueError, match="`time_res` must lie in the range "): parrm.explore_filter_params(time_res=data.shape[1] / sampling_freq) - with pytest.raises( - ValueError, match="`freq_range` must have a length of 2." - ): + with pytest.raises(ValueError, match="`freq_range` must have a length of 2."): parrm.explore_filter_params(freq_range=[0, 1, 2]) - with pytest.raises( - ValueError, match="`freq_range` must lie in the range " - ): + with pytest.raises(ValueError, match="`freq_range` must lie in the range "): parrm.explore_filter_params(freq_range=[-1, 1]) - with pytest.raises( - ValueError, match="`freq_range` must lie in the range " - ): + with pytest.raises(ValueError, match="`freq_range` must lie in the range "): parrm.explore_filter_params(freq_range=[0, (sampling_freq / 2) + 1]) with pytest.raises(ValueError, match="`freq_range"): parrm.explore_filter_params(freq_range=[1, 0]) @@ -312,31 +250,20 @@ def test_parrm_wrong_value_inputs(): ValueError, match="`n_jobs` must be <= the number of available CPUs." 
): parrm.explore_filter_params(n_jobs=cpu_count() + 1) - with pytest.raises( - ValueError, match="If `n_jobs` is <= 0, it must be -1." - ): + with pytest.raises(ValueError, match="If `n_jobs` is <= 0, it must be -1."): parrm.explore_filter_params(n_jobs=-2) # create_filter - with pytest.raises( - ValueError, match="`filter_half_width` must lie in the range" - ): + with pytest.raises(ValueError, match="`filter_half_width` must lie in the range"): parrm.create_filter(filter_half_width=1, omit_n_samples=2) - with pytest.raises( - ValueError, match="`filter_half_width` must lie in the range" - ): + with pytest.raises(ValueError, match="`filter_half_width` must lie in the range"): parrm.create_filter(filter_half_width=((data.shape[1] - 1) // 2) + 1) - with pytest.raises( - ValueError, match="`omit_n_samples` must lie in the range" - ): + with pytest.raises(ValueError, match="`omit_n_samples` must lie in the range"): parrm.create_filter(omit_n_samples=-1) - with pytest.raises( - ValueError, match="`omit_n_samples` must lie in the range" - ): + with pytest.raises(ValueError, match="`omit_n_samples` must lie in the range"): parrm.create_filter(omit_n_samples=(data.shape[1] - 1) // 2) with pytest.raises( - ValueError, - match="`period_half_width` must be lie in the range ", + ValueError, match="`period_half_width` must be lie in the range " ): parrm.create_filter(period_half_width=0) with pytest.raises( @@ -347,8 +274,7 @@ def test_parrm_wrong_value_inputs(): with pytest.raises(ValueError, match="`filter_direction` must be one of "): parrm.create_filter(filter_direction="not_a_direction") with pytest.raises( - RuntimeError, - match="A suitable filter cannot be created with the specified ", + RuntimeError, match="A suitable filter cannot be created with the specified " ): parrm.create_filter(omit_n_samples=48) parrm.create_filter() @@ -366,38 +292,25 @@ def test_parrm_premature_method_attribute_calls(): artefact_freq=artefact_freq, verbose=False, ) - with pytest.raises( - 
ValueError, match="The period has not yet been estimated." - ): + with pytest.raises(ValueError, match="The period has not yet been estimated."): parrm.explore_filter_params() - with pytest.raises( - ValueError, match="The period has not yet been estimated." - ): + with pytest.raises(ValueError, match="The period has not yet been estimated."): parrm.create_filter() - with pytest.raises( - ValueError, match="The filter has not yet been created." - ): + with pytest.raises(ValueError, match="The filter has not yet been created."): parrm.filter_data() - with pytest.raises( - AttributeError, match="No period has been computed yet." - ): + with pytest.raises(AttributeError, match="No period has been computed yet."): parrm.period - with pytest.raises( - AttributeError, match="No filter has been computed yet." - ): + with pytest.raises(AttributeError, match="No filter has been computed yet."): parrm.filter with pytest.raises(AttributeError, match="No data has been filtered yet."): parrm.filtered_data with pytest.raises( - AttributeError, - match="Analysis settings have not been established yet.", + AttributeError, match="Analysis settings have not been established yet." ): parrm.settings parrm.find_period() - with pytest.raises( - ValueError, match="The filter has not yet been created." - ): + with pytest.raises(ValueError, match="The filter has not yet been created."): parrm.filter_data() @@ -424,10 +337,7 @@ def test_compute_psd(n_chans: int, n_jobs: int): n_freqs = 5 freqs, psd = compute_psd( - data=data, - sampling_freq=sampling_freq, - n_points=n_freqs * 2, - n_jobs=n_jobs, + data=data, sampling_freq=sampling_freq, n_points=n_freqs * 2, n_jobs=n_jobs ) assert psd.shape == (n_chans, n_freqs) @@ -454,9 +364,7 @@ def test_get_example_data_paths() -> None: for name, file in DATASETS.items(): path = get_example_data_paths(name=name) assert isinstance(path, str), "`path` should be a str." - assert path.endswith( - file - ), "`path` should end with the name of the dataset." 
+ assert path.endswith(file), "`path` should end with the name of the dataset." assert os.path.exists(path), "`path` should point to an existing file." # test it catches incorrect inputs From b4a5b88fd3d49fc0cd03029e99139b3197d205f2 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 17 Sep 2024 17:18:20 +0200 Subject: [PATCH 2/2] Add ignore --- .git-blame-ignore-revs | 1 + 1 file changed, 1 insertion(+) create mode 100644 .git-blame-ignore-revs diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..bd17827 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1 @@ +6eec2635c9baaa7b1d23344070cb20a7f49b8624 # linting and line length \ No newline at end of file