From c2b64bc35fc0af2c0f9785862eb448cc385df740 Mon Sep 17 00:00:00 2001 From: Martin Schulz Date: Fri, 8 Dec 2023 17:08:19 +0100 Subject: [PATCH] FIX: Connectivity saved to .nc and other fixes (#71) --- mne_pipeline_hd/__main__.py | 80 +++-- .../development/console_widget_speed.py | 101 ++++-- mne_pipeline_hd/extra/functions.csv | 25 +- mne_pipeline_hd/extra/parameters.csv | 7 +- mne_pipeline_hd/functions/operations.py | 204 ++++++------ mne_pipeline_hd/functions/plot.py | 310 ++++++++---------- mne_pipeline_hd/gui/base_widgets.py | 68 ++-- mne_pipeline_hd/gui/dialogs.py | 38 ++- mne_pipeline_hd/gui/function_widgets.py | 9 +- mne_pipeline_hd/gui/gui_utils.py | 131 ++++---- mne_pipeline_hd/gui/loading_widgets.py | 132 +++++--- mne_pipeline_hd/gui/main_window.py | 11 +- mne_pipeline_hd/gui/models.py | 10 +- mne_pipeline_hd/gui/parameter_widgets.py | 147 ++++++--- mne_pipeline_hd/gui/plot_widgets.py | 17 +- mne_pipeline_hd/pipeline/controller.py | 60 +++- mne_pipeline_hd/pipeline/function_utils.py | 41 +-- mne_pipeline_hd/pipeline/legacy.py | 17 +- mne_pipeline_hd/pipeline/loading.py | 203 +++++++----- mne_pipeline_hd/pipeline/pipeline_utils.py | 48 ++- mne_pipeline_hd/pipeline/project.py | 61 ++-- mne_pipeline_hd/tests/test_concurrent.py | 4 +- mne_pipeline_hd/tests/test_console.py | 40 +++ mne_pipeline_hd/tests/test_loading.py | 5 +- .../tests/test_parameter_widgets.py | 42 ++- requirements.txt | 1 + 26 files changed, 1058 insertions(+), 754 deletions(-) create mode 100644 mne_pipeline_hd/tests/test_console.py diff --git a/mne_pipeline_hd/__main__.py b/mne_pipeline_hd/__main__.py index 8643cbe..fe8c0cc 100644 --- a/mne_pipeline_hd/__main__.py +++ b/mne_pipeline_hd/__main__.py @@ -5,10 +5,11 @@ Github: https://github.com/marsipu/mne-pipeline-hd """ -import logging import os +import re import sys from importlib import resources +from os.path import join import qtpy from qtpy.QtCore import QTimer, Qt @@ -19,7 +20,14 @@ from mne_pipeline_hd.gui.gui_utils import StdoutStderrStream, UncaughtHook from mne_pipeline_hd.gui.welcome_window import WelcomeWindow from mne_pipeline_hd.pipeline.legacy import legacy_import_check -from mne_pipeline_hd.pipeline.pipeline_utils import ismac, islin, QS +from mne_pipeline_hd.pipeline.pipeline_utils import ( + ismac, + islin, + QS, + iswin, + init_logging, + logger, +) # Check for changes in required packages legacy_import_check() @@ -27,6 +35,12 @@ import qdarktheme # noqa: E402 +def init_streams(): + # Redirect stdout and stderr to capture it later in GUI + sys.stdout = StdoutStderrStream("stdout") + sys.stderr = StdoutStderrStream("stderr") + + def main(): app_name = "mne-pipeline-hd" organization_name = "marsipu" @@ -40,6 +54,8 @@ def main(): app.setApplicationName(app_name) app.setOrganizationName(organization_name) app.setOrganizationDomain(domain_name) + # For Spyder to make console accessible again + app.lastWindowClosed.connect(app.quit) # Avoid file-dialog-problems with custom file-managers in linux if islin: @@ -54,37 +70,23 @@ def main(): # # Set multiprocessing method to spawn # multiprocessing.set_start_method('spawn') - # Redirect stdout to capture it later in GUI - sys.stdout = StdoutStderrStream("stdout") - # Redirect stderr to capture it later in GUI - sys.stderr = StdoutStderrStream("stderr") + init_streams() debug_mode = os.environ.get("MNEPHD_DEBUG", False) == "true" + init_logging(debug_mode) - # Initialize Logger (root) - logger = logging.getLogger() - if debug_mode: - logger.setLevel(logging.DEBUG) - else: - logger.setLevel(QS().value("log_level", defaultValue=logging.INFO)) - formatter = logging.Formatter( - "%(asctime)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S" - ) - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - logger.info("Starting MNE-Pipeline HD") + logger().info("Starting MNE-Pipeline HD") # Show Qt-binding if any([qtpy.PYQT5, qtpy.PYQT6]): qt_version = qtpy.PYQT_VERSION else: qt_version = qtpy.PYSIDE_VERSION - logger.info(f"Using {qtpy.API_NAME} {qt_version}") + logger().info(f"Using {qtpy.API_NAME} {qt_version}") # Initialize Exception-Hook if debug_mode: - logger.info("Debug-Mode is activated") + logger().info("Debug-Mode is activated") else: qt_exception_hook = UncaughtHook() # this registers the exception_hook() function @@ -103,40 +105,34 @@ def main(): if app_style not in ["dark", "light", "auto"]: app_style = "auto" - if app_style == "dark": - qdarktheme.setup_theme("dark") + qdarktheme.setup_theme(app_style) + st = qdarktheme.load_stylesheet(app_style) + is_dark = "background:rgba(32, 33, 36, 1.000)" in st + if is_dark: icon_name = "mne_pipeline_icon_dark.png" - elif app_style == "light": - qdarktheme.setup_theme("light") - icon_name = "mne_pipeline_icon_light.png" + # Fix ToolTip-Problem on Windows + # https://github.com/5yutan5/PyQtDarkTheme/issues/239 + if iswin: + match = re.search(r"QToolTip \{([^\{\}]+)\}", st) + if match is not None: + replace_str = "QToolTip {" + match.group(1) + ";border: 0px}" + st = st.replace(match.group(0), replace_str) + QApplication.instance().setStyleSheet(st) else: - qdarktheme.setup_theme("auto") - st = qdarktheme.load_stylesheet("auto") - if "background:rgba(32, 33, 36, 1.000)" in st: - icon_name = "mne_pipeline_icon_dark.png" - else: - icon_name = "mne_pipeline_icon_light.png" - - icon_path = resources.files(mne_pipeline_hd.extra) / icon_name + icon_name = "mne_pipeline_icon_light.png" + + icon_path = join(resources.files(mne_pipeline_hd.extra), icon_name) app_icon = QIcon(str(icon_path)) app.setWindowIcon(app_icon) # Initiate WelcomeWindow WelcomeWindow() - # Redirect stdout to capture it later in GUI - sys.stdout = StdoutStderrStream("stdout") - # Redirect stderr to capture it later in GUI - sys.stderr = StdoutStderrStream("stderr") - # Command-Line interrupt with Ctrl+C possible timer = QTimer() timer.timeout.connect(lambda: None) timer.start(500) - # For Spyder to make console accessible again - app.lastWindowClosed.connect(app.quit) - sys.exit(app.exec()) diff --git a/mne_pipeline_hd/development/console_widget_speed.py b/mne_pipeline_hd/development/console_widget_speed.py index e078f37..8c8bf81 100644 --- a/mne_pipeline_hd/development/console_widget_speed.py +++ b/mne_pipeline_hd/development/console_widget_speed.py @@ -1,42 +1,75 @@ # -*- coding: utf-8 -*- import sys -from time import perf_counter -import numpy as np from PyQt5.QtCore import QTimer from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton from mne_pipeline_hd.gui.gui_utils import ConsoleWidget -app = QApplication(sys.argv) -widget = QWidget() -layout = QVBoxLayout() -cw = ConsoleWidget() -layout.addWidget(cw) -close_bt = QPushButton("Close") -close_bt.clicked.connect(widget.close) -layout.addWidget(close_bt) -widget.setLayout(layout) -widget.show() -last_time = perf_counter() - -performance_buffer = list() - - -def test_write(): - global last_time - cw.write_progress("\r" + (f"Test {len(performance_buffer)}" * 1000)) - diff = perf_counter() - last_time - performance_buffer.append(diff) - if len(performance_buffer) >= 100: - fps = 1 / np.mean(performance_buffer) - print(f"Performance is: {fps:.2f} FPS") - performance_buffer.clear() - last_time = perf_counter() - - -timer = QTimer() -timer.timeout.connect(test_write) -timer.start(1) - -sys.exit(app.exec()) +test_text = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed non risus. +Suspendisse lectus tortor, dignissim sit amet, adipiscing nec, ultricies sed, +dolor. Cras elementum ultrices diam. Maecenas ligula massa, varius a, semper +congue, euismod non, mi. Proin porttitor, orci nec nonummy molestie, enim est +eleifend mi, non fermentum diam nisl sit amet erat. Duis semper. Duis arcu +massa, scelerisque vitae, consequat in, pretium a, enim. Pellentesque congue. +Ut in risus volutpat libero pharetra tempor. Cras vestibulum bibendum augue. +Praesent egestas leo in pede. Praesent blandit odio eu enim. Pellentesque +sed dui ut augue blandit sodales. Vestibulum ante ipsum primis in faucibus +orci luctus et ultrices posuere cubilia Curae; Aliquam nibh. Mauris ac mauris +sed pede pellentesque fermentum. Maecenas adipiscing ante non diam sodales +hendrerit. +\r Progress: 0% +\r Progress: 10% +\r Progress: 20% +\r Progress: 30% +\r Progress: 40% +\r Progress: 50% +\r Progress: 60% +\r Progress: 70% +\r Progress: 80% +\r Progress: 90% +\r Progress: 100%""" + + +class SpeedWidget(QWidget): + def __init__(self): + super().__init__() + + layout = QVBoxLayout(self) + self.cw = ConsoleWidget() + layout.addWidget(self.cw) + startbt = QPushButton("Start") + startbt.clicked.connect(self.start) + layout.addWidget(startbt) + stopbt = QPushButton("Stop") + stopbt.clicked.connect(self.stop) + layout.addWidget(stopbt) + close_bt = QPushButton("Close") + close_bt.clicked.connect(self.close) + layout.addWidget(close_bt) + + self.test_text = test_text.split("\n") + self.line_idx = 0 + + self.timer = QTimer(self) + self.timer.timeout.connect(self.write) + + def start(self): + self.timer.start(42) + + def stop(self): + self.timer.stop() + + def write(self): + if self.line_idx >= len(self.test_text): + self.line_idx = 0 + text = self.test_text[self.line_idx] + self.cw.write_stdout(text) + self.line_idx += 1 + + +if __name__ == "__main__": + app = QApplication(sys.argv) + w = SpeedWidget() + w.show() + sys.exit(app.exec()) diff --git a/mne_pipeline_hd/extra/functions.csv b/mne_pipeline_hd/extra/functions.csv index f393ef9..ac4ad8c 100644 --- a/mne_pipeline_hd/extra/functions.csv +++ b/mne_pipeline_hd/extra/functions.csv @@ -1,6 +1,8 @@ ;alias;target;tab;group;matplotlib;mayavi;dependencies;module;pkg_name;func_args find_bads;Find Bad Channels;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,n_jobs filter_data;Filter;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,filter_target,highpass,lowpass,filter_length,l_trans_bandwidth,h_trans_bandwidth,filter_method,iir_params,fir_phase,fir_window,fir_design,skip_by_annotation,fir_pad,n_jobs,enable_cuda,erm_t_limit,bad_interpolation +notch_filter;Notch Filter;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,notch_frequencies,n_jobs +interpolate_bads;Interpolate Bads;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,bad_interpolation add_erm_ssp;Empty-Room SSP;MEEG;Compute;Preprocessing;True;False;;operations;basic;meeg,erm_ssp_duration,erm_n_grad,erm_n_mag,erm_n_eeg,n_jobs,show_plots eeg_reference_raw;Set EEG Reference;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,ref_channels find_events;Find events;MEEG;Compute;events;False;False;;operations;basic;meeg,stim_channels,min_duration,shortest_event,adjust_timeline_by_msec @@ -10,7 +12,6 @@ estimate_noise_covariance;Noise-Covariance;MEEG;Compute;Preprocessing;False;Fals run_ica;Run ICA;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,ica_method,ica_fitto,n_components,ica_noise_cov,ica_remove_proj,ica_reject,ica_autoreject,overwrite_ar,ch_types,ch_names,reject_by_annotation,ica_eog,eog_channel,ica_ecg,ecg_channel apply_ica;Apply ICA;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,ica_apply_target,n_pca_components get_evokeds;Get Evokeds;MEEG;Compute;events;False;False;;operations;basic;meeg -interpolate_bads;Interpolate Bads;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,bad_interpolation compute_psd_raw;Compute PSD (Raw);MEEG;Compute;Time-Frequency;False;False;;operations;basic;meeg,psd_method,n_jobs compute_psd_epochs;Compute PSD (Epochs);MEEG;Compute;Time-Frequency;False;False;;operations;basic;meeg,psd_method,n_jobs tfr;Time-Frequency;MEEG;Compute;Time-Frequency;False;False;;operations;basic;meeg,tfr_freqs,tfr_n_cycles,tfr_average,tfr_use_fft,tfr_baseline,tfr_baseline_mode,tfr_method,multitaper_bandwidth,stockwell_width,n_jobs @@ -26,9 +27,9 @@ morph_labels_from_fsaverage;;FSMRI;Compute;MRI-Preprocessing;False;False;;operat create_inverse_operator;;MEEG;Compute;Inverse;False;False;;operations;basic;meeg source_estimate;;MEEG;Compute;Inverse;False;False;;operations;basic;meeg,inverse_method,pick_ori,lambda2 apply_morph;;MEEG;Compute;Inverse;False;False;;operations;basic;meeg,morph_to -label_time_course;;MEEG;Compute;Inverse;False;False;;operations;basic;meeg,target_labels,target_parcellation,extract_mode +label_time_course;;MEEG;Compute;Inverse;False;False;;operations;basic;meeg,target_labels,extract_mode ecd_fit;;MEEG;Compute;Inverse;False;False;;operations;basic;meeg,ecd_times,ecd_positions,ecd_orientations,t_epoch -src_connectivity;;MEEG;Compute;Inverse;False;False;;operations;basic;meeg,target_labels,target_parcellation,inverse_method,lambda2,con_methods,con_frequencies,con_time_window,n_jobs +src_connectivity;;MEEG;Compute;Time-Frequency;False;False;;operations;basic;meeg,target_labels,inverse_method,lambda2,con_methods,con_fmin,con_fmax,con_time_window,n_jobs grand_avg_evokeds;;Group;Compute;Grand-Average;False;False;;operations;basic;group,ga_interpolate_bads,ga_drop_bads grand_avg_tfr;;Group;Compute;Grand-Average;False;False;;operations;basic;group grand_avg_morphed;;Group;Compute;Grand-Average;False;False;;operations;basic;group,morph_to @@ -61,21 +62,21 @@ plot_evoked_white;;MEEG;Plot;Evoked;True;False;;plot;basic;meeg,show_plots plot_evoked_image;;MEEG;Plot;Evoked;True;False;;plot;basic;meeg,show_plots plot_compare_evokeds;;MEEG;Plot;Evoked;True;False;;plot;basic;meeg,show_plots plot_gfp;;MEEG;Plot;Evoked;True;False;;plot;basic;meeg,show_plots -plot_stc;Plot Source-Estimate;MEEG;Plot;Inverse;True;True;;plot;basic;meeg,target_labels,target_parcellation,label_colors,stc_surface,stc_hemi,stc_views,stc_time,stc_clim,stc_roll,stc_azimuth,stc_elevation,backend_3d +plot_stc;Plot Source-Estimate;MEEG;Plot;Inverse;True;True;;plot;basic;meeg,target_labels,label_colors,stc_surface,stc_hemi,stc_views,stc_time,stc_clim,stc_roll,stc_azimuth,stc_elevation,backend_3d plot_stc_interactive;;MEEG;Plot;Inverse;True;True;;plot;basic;meeg,stc_surface,stc_hemi,stc_views,stc_time,stc_clim,stc_roll,stc_azimuth,stc_elevation,backend_3d -plot_labels;;FSMRI;Plot;Inverse;True;True;;plot;basic;fsmri,target_labels,target_parcellation,label_colors,stc_hemi,stc_surface,stc_views,backend_3d -plot_animated_stc;Plot Source-Estimate Video;MEEG;Plot;Inverse;True;True;;plot;basic;meeg,target_labels,target_parcellation,label_colors,stc_surface,stc_hemi,stc_views,stc_time,stc_clim,stc_roll,stc_azimuth,stc_elevation,stc_animation_span,stc_animation_dilat,backend_3d +plot_labels;;FSMRI;Plot;Inverse;True;True;;plot;basic;fsmri,target_labels,label_colors,stc_hemi,stc_surface,stc_views,backend_3d +plot_animated_stc;Plot Source-Estimate Video;MEEG;Plot;Inverse;True;True;;plot;basic;meeg,target_labels,label_colors,stc_surface,stc_hemi,stc_views,stc_time,stc_clim,stc_roll,stc_azimuth,stc_elevation,stc_animation_span,stc_animation_dilat,backend_3d plot_snr;;MEEG;Plot;Inverse;True;False;;plot;basic;meeg,show_plots -plot_label_time_course;;MEEG;Plot;Inverse;True;False;;plot;basic;meeg,show_plots +plot_label_time_course;;MEEG;Plot;Inverse;True;False;;plot;basic;meeg,label_colors,show_plots plot_ecd;;MEEG;Plot;Inverse;True;True;;plot;basic;meeg -plot_src_connectivity;;MEEG;Plot;Time-Frequency;True;False;;plot;basic;meeg,show_plots +plot_src_connectivity;;MEEG;Plot;Time-Frequency;True;False;;plot;basic;meeg,label_colors,show_plots plot_grand_avg_evokeds;;Group;Plot;Grand-Average;True;False;;plot;basic;group,show_plots plot_grand_avg_tfr;;Group;Plot;Grand-Average;True;False;;plot;basic;group,show_plots -plot_grand_avg_stc;;Group;Plot;Grand-Average;True;True;;plot;basic;group,target_labels,target_parcellation,label_colors,stc_surface,stc_hemi,stc_views,stc_time,stc_clim,stc_roll,stc_azimuth,stc_elevation,backend_3d -plot_grand_avg_stc_anim;;Group;Plot;Grand-Average;True;True;;plot;basic;group,target_labels,target_parcellation,label_colors,stc_surface,stc_hemi,stc_views,stc_time,stc_clim,stc_roll,stc_azimuth,stc_elevation,stc_animation_span,stc_animation_dilat,backend_3d +plot_grand_avg_stc;;Group;Plot;Grand-Average;True;True;;plot;basic;group,target_labels,label_colors,stc_surface,stc_hemi,stc_views,stc_time,stc_clim,stc_roll,stc_azimuth,stc_elevation,backend_3d +plot_grand_avg_stc_anim;;Group;Plot;Grand-Average;True;True;;plot;basic;group,target_labels,label_colors,stc_surface,stc_hemi,stc_views,stc_time,stc_clim,stc_roll,stc_azimuth,stc_elevation,stc_animation_span,stc_animation_dilat,backend_3d plot_grand_average_stc_interactive;;Group;Plot;Grand-Average;True;True;;plot;basic;group,label_colors,stc_surface,stc_hemi,stc_views,stc_time,stc_clim,stc_roll,stc_azimuth,stc_elevation,backend_3d -plot_grand_avg_ltc;;Group;Plot;Grand-Average;True;False;;plot;basic;group,show_plots -plot_grand_avg_connect;;Group;Plot;Grand-Average;True;False;;plot;basic;group,morph_to,show_plots,connectivity_vmin,connectivity_vmax,con_group_boundaries +plot_grand_avg_ltc;;Group;Plot;Grand-Average;True;False;;plot;basic;group,label_colors,show_plots +plot_grand_avg_connect;;Group;Plot;Grand-Average;True;False;;plot;basic;group,label_colors,show_plots plot_ica_components;Plot ICA-Components;MEEG;Plot;ICA;True;False;;plot;basic;meeg,show_plots,close_func plot_ica_sources;Plot ICA-Sources;MEEG;Plot;ICA;True;False;;plot;basic;meeg,ica_source_data,show_plots,close_func plot_ica_overlay;Plot ICA-Overlay;MEEG;Plot;ICA;True;False;;plot;basic;meeg,ica_overlay_data,show_plots diff --git a/mne_pipeline_hd/extra/parameters.csv b/mne_pipeline_hd/extra/parameters.csv index 760ff28..07aca06 100644 --- a/mne_pipeline_hd/extra/parameters.csv +++ b/mne_pipeline_hd/extra/parameters.csv @@ -76,12 +76,12 @@ stc_azimuth;;Inverse;70;;Azimuth for view for Source Estimate Plots;IntGui;{'max stc_elevation;;Inverse;60;;Elevation for view for Source Estimate Plots;IntGui;{'max_val': 360} stc_animation_span;;Inverse;(0,0.5);s;time-span for stc-animation[s];TupleGui; stc_animation_dilat;;Inverse;20;;time-dilation for stc-animation;IntGui; -target_parcellation;Target Parcellation;Inverse;aparc;;The parcellation to use.;StringGui; target_labels;Target Labels;Inverse;[];;;LabelGui; label_colors;;Inverse;{};;Set custom colors for labels.;ColorGui;{'keys': 'target_labels', 'none_select':True} extract_mode;Label-Extraction-Mode;Inverse;auto;;mode for extracting label-time-course from Source-Estimate;ComboGui;{'options': ['auto', 'max', 'mean', 'mean_flip', 'pca_flip']} con_methods;;Connectivity;['coh'];;methods for connectivity;CheckListGui;{'options': ['coh', 'cohy', 'imcoh', 'plv', 'ciplv', 'ppc', 'pli', 'pli2_unbiased', 'wpli', 'wpli2_debiased']} -con_frequencies;;Connectivity;(30, 80);;frequencies for connectivity;TupleGui;{'none_select': True, 'step': 1} +con_fmin;;Connectivity;30;;lower frequency/frequencies for connectivity;MultiTypeGui;{'type_selection': True, 'types': ['float', 'list']} +con_fmax;;Connectivity;80;;upper frequency/frequencies for connectivity;MultiTypeGui;{'type_selection': True, 'types': ['float', 'list']} con_time_window;;Connectivity;(0, 0.5);;time-window for connectivity;TupleGui;{'none_select': True, 'step': 0.001} ecd_times;;Inverse;{};;;DictGui; ecd_positions;;Inverse;{};;;DictGui; @@ -96,9 +96,8 @@ erm_n_mag;;Preprocessing;2;;The number of projections for Magnetometer;IntGui; erm_n_eeg;;Preprocessing;0;;The number of projections for EEG;IntGui; ga_interpolate_bads;;Grand-Average;True;;If to interpolate bad channels for the Grand-Average;BoolGui; ga_drop_bads;;Grand-Average;True;;If to drop bad channels for the Grand-Average;BoolGui; -connectivity_vmin;;Connectivity;None;;Minimum value for colormap;FloatGui;{'step': 0.01, 'none_select':True} -connectivity_vmax;;Connectivity;None;;Maximum value for colormap;FloatGui;{'step': 0.01, 'none_select':True} psd_method;;Time-Frequency;welch;;The method for spectral estimation;ComboGui;{'options': ['welch', 'multitaper']} psd_topomap_bands;;Time-Frequency;None;;The frequency bands for the topomap-plot;DictGui;{'none_select': True} backend_3d;3D-Backend;Plot;pyvistaqt;;Choose the 3D-Backend for Brain-plots.;ComboGui;{'options': ['pyvistaqt', 'notebook']} con_group_boundaries;;Connectivity;None;;Set group-boundaries for circular plot.;FuncGui;{'none_select': True} +notch_frequencies;;Preprocessing;50;;Set frequencies for Notch filtering;FuncGui;"" diff --git a/mne_pipeline_hd/functions/operations.py b/mne_pipeline_hd/functions/operations.py index a2ba1cb..ce7ae41 100644 --- a/mne_pipeline_hd/functions/operations.py +++ b/mne_pipeline_hd/functions/operations.py @@ -8,7 +8,6 @@ from __future__ import print_function import gc -import logging import os import shutil import subprocess @@ -24,14 +23,16 @@ import mne_connectivity import numpy as np from mne.preprocessing import ICA, find_bad_channels_maxwell +from mne_connectivity import SpectralConnectivity -from mne_pipeline_hd.pipeline.loading import MEEG +from mne_pipeline_hd.pipeline.loading import MEEG, FSMRI from mne_pipeline_hd.pipeline.pipeline_utils import ( check_kwargs, compare_filep, ismac, iswin, get_n_jobs, + logger, ) @@ -53,7 +54,7 @@ def find_bads(meeg, n_jobs, **kwargs): noisy_chs, flat_chs = find_bad_channels_maxwell( raw, coord_frame=coord_frame, **kwargs ) - logging.info(f"Noisy channels: {noisy_chs}\n" f"Flat channels: {flat_chs}") + logger().info(f"Noisy channels: {noisy_chs}\n" f"Flat channels: {flat_chs}") raw.info["bads"] = noisy_chs + flat_chs + raw.info["bads"] meeg.set_bad_channels(raw.info["bads"]) meeg.save_raw(raw) @@ -87,7 +88,7 @@ def filter_data( if any([results[key] != "equal" for key in results]): # Load Data - data = meeg.io_dict[filter_target]["load"]() + data = meeg.load(filter_target) # use cuda for filtering if enabled if enable_cuda: @@ -131,9 +132,9 @@ def filter_data( # Save Data if filter_target == "raw": - meeg.io_dict["raw_filtered"]["save"](data) + meeg.save("raw_filtered", data) else: - meeg.io_dict[filter_target]["save"](data) + meeg.save(filter_target, data) # Remove raw to avoid memory overload del data @@ -195,16 +196,27 @@ def filter_data( print("no erm_file assigned") +def notch_filter(meeg, notch_frequencies, n_jobs): + raw_filtered = meeg.load_filtered() + + raw_filtered = raw_filtered.notch_filter(notch_frequencies, n_jobs=1) + meeg.save_filtered(raw_filtered) + + def interpolate_bads(meeg, bad_interpolation): - data = meeg.io_dict[bad_interpolation]["load"]() + data = meeg.load(bad_interpolation) if bad_interpolation == "evoked": for evoked in data: - evoked.interpolate_bads() + # Add bads for channels present + evoked.info["bads"] = [b for b in meeg.bad_channels if b in data.ch_names] + evoked.interpolate_bads(reset_bads=True) else: - data.interpolate_bads() + # Add bads for channels present + data.info["bads"] = [b for b in meeg.bad_channels if b in data.ch_names] + data.interpolate_bads(reset_bads=True) - meeg.io_dict[bad_interpolation]["save"](data) + meeg.save(bad_interpolation, data) def add_erm_ssp( @@ -789,10 +801,10 @@ def apply_ica(meeg, ica_apply_target, n_pca_components): def get_evokeds(meeg): - epochs = meeg.load_epochs() - evokeds = [] - for trial in meeg.sel_trials: - evoked = epochs[trial].average() + meeg.load_epochs() + evokeds = list() + for trial, epoch in meeg.get_trial_epochs(): + evoked = epoch.average() # Todo: optional if you want weights in your evoked.comment?! evoked.comment = trial evokeds.append(evoked) @@ -812,12 +824,22 @@ def calculate_gfp(evoked): def grand_avg_evokeds(group, ga_interpolate_bads, ga_drop_bads): - trial_dict = {} + trial_dict = dict() for name in group.group_list: meeg = MEEG(name, group.ct) print(f"Add {name} to grand_average") evokeds = meeg.load_evokeds() for evoked in evokeds: + if ga_interpolate_bads: + bad_evoked = evoked.copy().pick(np.arange(len(meeg.bad_channels))) + bad_evoked = bad_evoked.rename_channels( + { + old: new + for old, new in zip(bad_evoked.ch_names, meeg.bad_channels) + } + ) + bad_evoked.info["bads"] = meeg.bad_channels + evoked.add_channels([bad_evoked]) if evoked.nave != 0: if evoked.comment in trial_dict: trial_dict[evoked.comment].append(evoked) @@ -873,15 +895,15 @@ def tfr( powers = list() itcs = list() - epochs = meeg.load_epochs() + meeg.load_epochs() # Calculate Time-Frequency for each trial from epochs # using the selected method - for trial in meeg.sel_trials: + for trial, epoch in meeg.get_trial_epochs(): if tfr_method == "multitaper": multitaper_kwargs = check_kwargs(kwargs, mne.time_frequency.tfr_multitaper) tfr_result = mne.time_frequency.tfr_multitaper( - epochs[trial], + epoch, freqs=tfr_freqs, n_cycles=tfr_n_cycles, time_bandwidth=multitaper_bandwidth, @@ -895,7 +917,7 @@ def tfr( fmin, fmax = tfr_freqs[[0, -1]] stockwell_kwargs = check_kwargs(kwargs, mne.time_frequency.tfr_stockwell) tfr_result = mne.time_frequency.tfr_stockwell( - epochs[trial], + epoch, fmin=fmin, fmax=fmax, width=stockwell_width, @@ -906,7 +928,7 @@ def tfr( else: morlet_kwargs = check_kwargs(kwargs, mne.time_frequency.tfr_morlet) tfr_result = mne.time_frequency.tfr_morlet( - epochs[trial], + epoch, freqs=tfr_freqs, n_cycles=tfr_n_cycles, n_jobs=n_jobs, @@ -1113,8 +1135,7 @@ def make_dense_scalp_surfaces(fsmri): "mne", "make_scalp_surfaces", "--overwrite", - "--subject", - fsmri.name, + f"--subject={fsmri.name}", "--force", ] @@ -1170,15 +1191,17 @@ def prepare_bem(fsmri, bem_spacing, bem_conductivity): def morph_fsmri(meeg, morph_to): if meeg.fsmri.name != morph_to: forward = meeg.load_forward() + fsmri_to = FSMRI(morph_to, meeg.ct) morph = mne.compute_source_morph( forward["src"], subject_from=meeg.fsmri.name, subject_to=morph_to, subjects_dir=meeg.subjects_dir, + src_to=fsmri_to.load_source_space(), ) meeg.save_source_morph(morph) else: - logging.info( + logger().info( f"There is no need to morph the source-space for {meeg.name}, " f'because the morph-destination "{morph_to}" ' f"is the same as the associated FSMRI." @@ -1288,7 +1311,7 @@ def source_estimate(meeg, inverse_method, pick_ori, lambda2): inverse_operator = meeg.load_inverse_operator() evokeds = meeg.load_evokeds() - stcs = {} + stcs = dict() for evoked in [ev for ev in evokeds if ev.comment in meeg.sel_trials]: stc = mne.minimum_norm.apply_inverse( evoked, inverse_operator, lambda2, method=inverse_method, pick_ori=pick_ori @@ -1298,15 +1321,20 @@ def source_estimate(meeg, inverse_method, pick_ori, lambda2): meeg.save_source_estimates(stcs) -def label_time_course(meeg, target_labels, target_parcellation, extract_mode): +def label_time_course(meeg, target_labels, extract_mode): + if len(target_labels) == 0: + raise RuntimeError( + "No labels selected for label time course extraction. " + "Please select at least one label." + ) stcs = meeg.load_source_estimates() src = meeg.fsmri.load_source_space() - labels = meeg.fsmri.get_labels(target_labels, target_parcellation) + labels = meeg.fsmri.get_labels(target_labels) - ltc_dict = {} + ltc_dict = dict() for trial in stcs: - ltc_dict[trial] = {} + ltc_dict[trial] = dict() times = stcs[trial].times for label in labels: ltc = stcs[trial].extract_label_time_course(label, src, mode=extract_mode)[ @@ -1339,8 +1367,8 @@ def mixed_norm_estimate(meeg, pick_ori, inverse_method): evoked, inv_op, lambda2, method="dSPM" ) - mixn_dips = {} - mixn_stcs = {} + mixn_dips = dict() + mixn_stcs = dict() for evoked in [ev for ev in evokeds if ev.comment in meeg.sel_trials]: alpha = 30 # regularization parameter between 0 and 100 (100 is high) @@ -1403,11 +1431,11 @@ def ecd_fit(meeg, ecd_times, ecd_positions, ecd_orientations, t_epoch): bem = meeg.fsmri.load_bem_solution() trans = meeg.load_transformation() - ecd_dips = {} + ecd_dips = dict() for evoked in evokeds: trial = evoked.comment - ecd_dips[trial] = {} + ecd_dips[trial] = dict() for dip in ecd_time: tmin, tmax = ecd_time[dip] copy_evoked = evoked.copy().crop(tmin, tmax) @@ -1455,12 +1483,12 @@ def apply_morph(meeg, morph_to): stcs = meeg.load_source_estimates() morph = meeg.load_source_morph() - morphed_stcs = {} + morphed_stcs = dict() for trial in stcs: morphed_stcs[trial] = morph.apply(stcs[trial]) meeg.save_morphed_source_estimates(morphed_stcs) else: - logging.info( + logger().info( f"{meeg.name} is already in source-space of {morph_to} " f"and won't be morphed" ) @@ -1469,34 +1497,36 @@ def apply_morph(meeg, morph_to): def src_connectivity( meeg, target_labels, - target_parcellation, inverse_method, lambda2, con_methods, - con_frequencies, + con_fmin, + con_fmax, con_time_window, n_jobs, ): + if len(target_labels) == 0: + raise RuntimeError( + "No labels selected for connectivity estimation. " + "Please select at least one label." + ) info = meeg.load_info() - all_epochs = meeg.load_epochs() inverse_operator = meeg.load_inverse_operator() src = inverse_operator["src"] - labels = meeg.fsmri.get_labels(target_labels, target_parcellation) + labels = meeg.fsmri.get_labels(target_labels) if len(labels) == 0: - raise RuntimeError( - "No labels found, check your target_labels and target_parcellation" - ) + raise RuntimeError("No labels found, check your target_labels") if len(meeg.sel_trials) == 0: raise RuntimeError( "No trials selected, check your Selected IDs in Preparation/" ) - con_dict = {} + con_dict = dict() - for trial in meeg.sel_trials: - con_dict[trial] = {} - epochs = all_epochs[trial] + for trial, epoch in meeg.get_trial_epochs(): + con_dict[trial] = dict() + epochs = epoch # Crop if necessary if con_time_window is not None: @@ -1523,11 +1553,12 @@ def src_connectivity( sfreq = info["sfreq"] # the sampling frequency con = mne_connectivity.spectral_connectivity_epochs( label_ts, + names=target_labels, method=con_methods, mode="multitaper", sfreq=sfreq, - fmin=con_frequencies[0], - fmax=con_frequencies[1], + fmin=con_fmin, + fmax=con_fmax, faverage=True, mt_adaptive=True, n_jobs=n_jobs, @@ -1539,14 +1570,7 @@ def src_connectivity( # con is a 3D array, get the connectivity for the first (and only) # freq. band for each con_method for method, c in zip(con_methods, con): - con_dict[trial][method] = c.get_data(output="dense")[:, :, 0] - - # Add target_labels for later identification - con_dict["__info__"] = { - "labels": target_labels, - "parcellation": target_parcellation, - "frequencies": con_frequencies, - } + con_dict[trial][method] = c meeg.save_connectivity(con_dict) @@ -1556,9 +1580,9 @@ def grand_avg_morphed(group, morph_to): # stc in the end!!! n_chunks = 8 # divide in chunks to save memory - fusion_dict = {} + fusion_dict = dict() for i in range(0, len(group.group_list), n_chunks): - sub_trial_dict = {} + sub_trial_dict = dict() ga_chunk = group.group_list[i : i + n_chunks] print(ga_chunk) for name in ga_chunk: @@ -1592,7 +1616,7 @@ def grand_avg_morphed(group, morph_to): else: fusion_dict.update({trial: [sub_trial_average]}) - ga_stcs = {} + ga_stcs = dict() for trial in fusion_dict: if len(fusion_dict[trial]) != 0: print(f"grand_average for {group.name}-{trial}") @@ -1611,7 +1635,7 @@ def grand_avg_morphed(group, morph_to): def grand_avg_ltc(group): - ltc_average_dict = {} + ltc_average_dict = dict() times = None for name in group.group_list: meeg = MEEG(name, group.ct) @@ -1619,7 +1643,7 @@ def grand_avg_ltc(group): ltc_dict = meeg.load_ltc() for trial in ltc_dict: if trial not in ltc_average_dict: - ltc_average_dict[trial] = {} + ltc_average_dict[trial] = dict() for label in ltc_dict[trial]: # First row of array is label-time-course-data, # second row is time-array @@ -1630,21 +1654,14 @@ def grand_avg_ltc(group): # Should be the same for each trial and label times = ltc_dict[trial][label][1] - ga_ltc = {} + ga_ltc = dict() for trial in ltc_average_dict: - ga_ltc[trial] = {} + ga_ltc[trial] = dict() for label in ltc_average_dict[trial]: if len(ltc_average_dict[trial][label]) != 0: print(f"grand_average for {trial}-{label}") ltc_list = ltc_average_dict[trial][label] - # Take the absolute values - ltc_list = [abs(it) for it in ltc_list] - n_subjects = len(ltc_list) - average = ltc_list[0] - for idx in range(1, n_subjects): - average += ltc_list[idx] - - average /= n_subjects + average = np.mean(ltc_list, axis=0) ga_ltc[trial][label] = np.vstack((average, times)) @@ -1653,40 +1670,39 @@ def grand_avg_ltc(group): def grand_avg_connect(group): # Prepare the Average-Dict - con_average_dict = {} + con_average_dict = dict() for name in group.group_list: meeg = MEEG(name, group.ct) print(f"Add {name} to grand_average") con_dict = meeg.load_connectivity() - con_info = con_dict.pop("__info__") for trial in con_dict: if trial not in con_average_dict: - con_average_dict[trial] = {} - for con_method in con_dict[trial]: - if con_method in con_average_dict[trial]: - con_average_dict[trial][con_method].append( - con_dict[trial][con_method] - ) - else: - con_average_dict[trial][con_method] = [con_dict[trial][con_method]] + con_average_dict[trial] = dict() + for con_method, con in con_dict[trial].items(): + if con_method not in con_average_dict[trial]: + con_average_dict[trial][con_method] = list() + con_average_dict[trial][con_method].append(con) - ga_con = {"__info__": con_info} + ga_con_dict = dict() for trial in con_average_dict: - ga_con[trial] = {} - for con_method in con_average_dict[trial]: - if len(con_average_dict[trial][con_method]) != 0: + ga_con_dict[trial] = dict() + for con_method, con_list in con_average_dict[trial].items(): + if len(con_list) != 0: print(f"grand_average for {trial}-{con_method}") - con_list = con_average_dict[trial][con_method] - n_subjects = len(con_list) - average = con_list[0] - for idx in range(1, n_subjects): - average += con_list[idx] - - average /= n_subjects - - ga_con[trial][con_method] = average + avg_data = np.mean([con.get_data() for con in con_list], axis=0) + + ga_con = SpectralConnectivity( + data=avg_data, + freqs=con_list[0].freqs, + n_nodes=con_list[0].n_nodes, + names=con_list[0].names, + indices=con_list[0].indices, + method=con_list[0].method, + n_epochs_used=len(con_list), + ) + ga_con_dict[trial][con_method] = ga_con - group.save_ga_con(ga_con) + group.save_ga_con(ga_con_dict) def print_info(meeg): diff --git a/mne_pipeline_hd/functions/plot.py b/mne_pipeline_hd/functions/plot.py index ddd314e..8d7ca81 100644 --- a/mne_pipeline_hd/functions/plot.py +++ b/mne_pipeline_hd/functions/plot.py @@ -7,7 +7,7 @@ from __future__ import print_function -import gc +import itertools from functools import partial from os.path import join @@ -17,7 +17,7 @@ import numpy as np # Make use of program also possible with sensor-space installation of mne -from mne_pipeline_hd.pipeline.loading import FSMRI +from mne_pipeline_hd.pipeline.loading import MEEG from mne_pipeline_hd.pipeline.plot_utils import pipeline_plot try: @@ -78,7 +78,7 @@ def plot_filtered(meeg, show_plots, close_func=_save_raw_on_close, **kwargs): events = None print("No events found") - fig = raw.plot( + raw.plot( events=events, bad_color="red", scalings="auto", @@ -88,18 +88,6 @@ def plot_filtered(meeg, show_plots, close_func=_save_raw_on_close, **kwargs): **kwargs, ) - if hasattr(fig, "canvas"): - # Connect to closing of Matplotlib-Figure - fig.canvas.mpl_connect( - "close_event", - partial(close_func, meeg=meeg, raw=raw, raw_type="raw_filtered"), - ) - else: - # Connect to closing of PyQt-Figure - fig.gotClosed.connect( - partial(close_func, None, meeg=meeg, raw=raw, raw_type="raw_filtered") - ) - def plot_sensors(meeg, plot_sensors_kind, ch_types, show_plots): loaded_info = meeg.load_info() @@ -380,7 +368,7 @@ def plot_evoked_image(meeg, show_plots): def plot_compare_evokeds(meeg, show_plots): evokeds = meeg.load_evokeds() - evokeds = {evoked.comment: evoked for evoked in evokeds} + evokeds = {f"{evoked.comment}={evoked.nave}": evoked for evoked in evokeds} fig = mne.viz.plot_compare_evokeds(evokeds, show=show_plots) @@ -595,7 +583,6 @@ def _brain_plot( stc_time, stc_clim, target_labels, - target_parcellation, label_colors, stc_roll, stc_azimuth, @@ -620,7 +607,7 @@ def _brain_plot( brain.show_view(roll=stc_roll, azimuth=stc_azimuth, elevation=stc_elevation) brain.add_text(0, 0.9, title, "title", color="w", font_size=14) if not interactive: - labels = meeg.fsmri.get_labels(target_labels, target_parcellation) + labels = meeg.fsmri.get_labels(target_labels) for label in labels: color = label_colors.get(label.name) brain.add_label(label, borders=True, color=color) @@ -648,7 +635,6 @@ def _brain_plot( def plot_stc( meeg, target_labels, - target_parcellation, label_colors, stc_surface, stc_hemi, @@ -670,7 +656,6 @@ def plot_stc( stc_time=stc_time, stc_clim=stc_clim, target_labels=target_labels, - target_parcellation=target_parcellation, label_colors=label_colors, stc_roll=stc_roll, stc_azimuth=stc_azimuth, @@ -701,7 +686,6 @@ def plot_stc_interactive( stc_time=stc_time, stc_clim=stc_clim, target_labels=None, - target_parcellation=None, label_colors=None, stc_roll=stc_roll, stc_azimuth=stc_azimuth, @@ -714,7 +698,6 @@ def plot_stc_interactive( def plot_animated_stc( meeg, target_labels, - target_parcellation, label_colors, stc_surface, stc_hemi, @@ -738,7 +721,6 @@ def plot_animated_stc( stc_time=stc_time, stc_clim=stc_clim, target_labels=target_labels, - target_parcellation=target_parcellation, label_colors=label_colors, stc_roll=stc_roll, stc_azimuth=stc_azimuth, @@ -752,7 +734,6 @@ def plot_animated_stc( def plot_labels( fsmri, target_labels, - target_parcellation, label_colors, stc_hemi, stc_surface, @@ -762,17 +743,23 @@ def plot_labels( with mne.viz.use_3d_backend(backend_3d): Brain = mne.viz.get_brain_class() brain = Brain( - subject_id=fsmri.name, + subject=fsmri.name, hemi=stc_hemi, surf=stc_surface, subjects_dir=fsmri.subjects_dir, views=stc_views, ) - labels = fsmri.get_labels(target_labels, target_parcellation) + labels = fsmri.get_labels(target_labels) + y = 0.95 for label in labels: color = label_colors.get(label.name) + if color is None: + color = label.color brain.add_label(label, borders=False, color=color) + brain.add_text(x=0.05, y=y, text=label.name, color=color, font_size=14) + y -= 0.05 + fsmri.plot_save("labels", brain=brain) @@ -844,81 +831,128 @@ def plot_snr(meeg, show_plots): meeg.plot_save("snr", trial=trial, matplotlib_figure=fig) -def plot_label_time_course(meeg, show_plots): +def plot_label_time_course(meeg, label_colors, show_plots): ltcs = meeg.load_ltc() for trial in ltcs: - for label in ltcs[trial]: - plt.figure() - plt.plot(ltcs[trial][label][1], ltcs[trial][label][0]) - plt.title( - f"{meeg.name}-{trial}-{label}\n" - f'Extraction-Mode: {meeg.pa["extract_mode"]}' - ) - plt.xlabel("Time in s") - plt.ylabel("Source amplitude") - if show_plots: - plt.show() + plt.figure() + plt.title( + f"{meeg.name}-{trial}\n" f'Extraction-Mode: {meeg.pa["extract_mode"]}' + ) + plt.xlabel("Time in s") + plt.ylabel("Source amplitude") + for label_name, data in ltcs[trial].items(): + color = label_colors.get(label_name, "black") + plt.plot(data[1], data[0], color=color, label=label_name) + plt.legend() + if show_plots: + plt.show() + meeg.plot_save("label-time-course", trial=trial) - meeg.plot_save("label-time-course", subfolder=label, trial=trial) +def _get_n_subplots(n_items): + n_subplots = np.ceil(np.sqrt(n_items)).astype(int) + if n_items <= 2: + nrows = 1 + ax_idxs = range(n_subplots) + else: + nrows = n_subplots + ax_idxs = itertools.product(range(n_subplots), repeat=2) + ncols = n_subplots -def plot_src_connectivity(meeg, show_plots): - con_dict = meeg.load_connectivity() - con_info = con_dict.pop("__info__") - labels = meeg.fsmri.get_labels(con_info["labels"], con_info["parcellation"]) - if "unknown-lh" in labels: - labels.pop("unknown-lh") - label_colors = [label.color for label in labels] - label_names = [label.name for label in labels] - lh_labels = [l_name for l_name in label_names if l_name.endswith("lh")] - rh_labels = [l_name for l_name in label_names if l_name.endswith("rh")] - - # Get the y-location of the label - lh_label_ypos = [np.mean(lb.pos[:, 1]) for lb in labels if lb.name in lh_labels] - rh_label_ypos = [np.mean(lb.pos[:, 1]) for lb in labels if lb.name in rh_labels] - - # Reorder the labels based on their location - lh_labels = [label for (yp, label) in sorted(zip(lh_label_ypos, lh_labels))] - rh_labels = [label for (yp, label) in sorted(zip(rh_label_ypos, rh_labels))] - - # Save the plot order and create a circular layout - node_order = list() - node_order.extend(lh_labels[::-1]) # reverse the order - node_order.extend(rh_labels) - - node_angles = mne.viz.circular_layout( - label_names, - node_order, - start_pos=90, - group_boundaries=[0, len(label_names) / 2], - ) + return nrows, ncols, ax_idxs - # Plot the graph using node colors from the FreeSurfer parcellation. - # We only show the 300 strongest connections. + +def _plot_connectivity(obj, con_dict, label_colors, show_plots): for trial in con_dict: - for con_method in con_dict[trial]: - title = ( - f"{trial}: {con_info['frequencies'][0]}-{con_info['frequencies'][1]}" - ) - fig, axes = mne_connectivity.viz.plot_connectivity_circle( - con_dict[trial][con_method], + for con_method, con in con_dict[trial].items(): + labels = obj.fsmri.get_labels(con.names) + if "unknown-lh" in labels: + labels.pop("unknown-lh") + colors = list() + for label in labels: + color = label_colors.get(label.name) + if color is None: + color = label.color + colors.append(color) + label_names = [label.name for label in labels] + lh_labels = [l_name for l_name in label_names if l_name.endswith("lh")] + rh_labels = [l_name for l_name in label_names if l_name.endswith("rh")] + + # Get the y-location of the label + lh_label_ypos = [ + np.mean(lb.pos[:, 1]) for lb in labels if lb.name in lh_labels + ] + rh_label_ypos = [ + np.mean(lb.pos[:, 1]) for lb in labels if lb.name in rh_labels + ] + + # Reorder the labels based on their location + lh_labels = [label for (yp, label) in sorted(zip(lh_label_ypos, lh_labels))] + rh_labels = [label for (yp, label) in sorted(zip(rh_label_ypos, rh_labels))] + + # Save the plot order and create a circular layout + node_order = list() + node_order.extend(lh_labels[::-1]) # reverse the order + node_order.extend(rh_labels) + + node_angles = mne.viz.circular_layout( label_names, - n_lines=100, - node_angles=node_angles, - node_colors=label_colors, - title=title, - fontsize_names=8, - show=show_plots, + node_order, + start_pos=90, + group_boundaries=[0, len(label_names) / 2], ) - - meeg.plot_save( - "connectivity", subfolder=con_method, trial=trial, matplotlib_figure=fig + con_data = con.get_data(output="dense") + + nrows, ncols, ax_idxs = _get_n_subplots(len(con.freqs)) + ax_idxs = list(ax_idxs) + fig, axes = plt.subplots( + nrows=nrows, + ncols=ncols, + subplot_kw={"projection": "polar"}, + facecolor="black", + figsize=(8, 8), + ) + # Remove extra axes + if nrows**2 > len(con.freqs): + fig.delaxes(axes[-1, -1]) + + for freq_idx, freq in enumerate(con.freqs): + if isinstance(axes, np.ndarray): + ax = axes[ax_idxs[freq_idx]] + else: + ax = axes + title = f"Frequency={freq:.1f} Hz" + mne_connectivity.viz.plot_connectivity_circle( + con_data[:, :, freq_idx], + label_names, + n_lines=None, + node_angles=node_angles, + node_colors=colors, + title=title, + show=show_plots, + ax=ax, + ) + fig.suptitle( + f"{obj.name}-{trial}-{con_method}", + horizontalalignment="center", + color="white", + ) + plt.tight_layout() + if isinstance(obj, MEEG): + plot_name = "connectivity" + else: + plot_name = "ga_connectivity" + obj.plot_save( + plot_name, subfolder=con_method, trial=trial, matplotlib_figure=fig ) -# %% Grand-Average Plots +def plot_src_connectivity(meeg, label_colors, show_plots): + con_dict = meeg.load_connectivity() + _plot_connectivity(meeg, con_dict, label_colors, show_plots) +# %% Grand-Average Plots def plot_grand_avg_evokeds(group, show_plots): ga_evokeds = group.load_ga_evokeds() @@ -962,7 +996,6 @@ def plot_grand_avg_tfr(group, show_plots): def plot_grand_avg_stc( group, target_labels, - target_parcellation, label_colors, stc_surface, stc_hemi, @@ -984,7 +1017,6 @@ def plot_grand_avg_stc( stc_time=stc_time, stc_clim=stc_clim, target_labels=target_labels, - target_parcellation=target_parcellation, label_colors=label_colors, stc_roll=stc_roll, stc_azimuth=stc_azimuth, @@ -1016,7 +1048,6 @@ def plot_grand_average_stc_interactive( stc_time=stc_time, stc_clim=stc_clim, target_labels=None, - target_parcellation=None, label_colors=label_colors, stc_roll=stc_roll, stc_azimuth=stc_azimuth, @@ -1029,7 +1060,6 @@ def plot_grand_average_stc_interactive( def plot_grand_avg_stc_anim( group, target_labels, - target_parcellation, label_colors, stc_surface, stc_hemi, @@ -1053,7 +1083,6 @@ def plot_grand_avg_stc_anim( stc_time=stc_time, stc_clim=stc_clim, target_labels=target_labels, - target_parcellation=target_parcellation, label_colors=label_colors, stc_roll=stc_roll, stc_azimuth=stc_azimuth, @@ -1064,91 +1093,30 @@ def plot_grand_avg_stc_anim( ) -def plot_grand_avg_ltc(group, show_plots): +def plot_grand_avg_ltc(group, label_colors, show_plots): ga_ltc = group.load_ga_ltc() for trial in ga_ltc: - for label in ga_ltc[trial]: - plt.figure() - plt.plot(ga_ltc[trial][label][1], ga_ltc[trial][label][0]) - plt.title( - f"Label-Time-Course for {group.name}-{trial}-{label}\n" - f'with Extraction-Mode: {group.pa["extract_mode"]}' - ) - plt.xlabel("Time in ms") - plt.ylabel("Source amplitude") - if show_plots: - plt.show() + plt.figure() + plt.title( + f"Label-Time-Course for {group.name}-{trial}\n" + f'with Extraction-Mode: {group.pa["extract_mode"]}' + ) + plt.xlabel("Time in ms") + plt.ylabel("Source amplitude") + for label_name, data in ga_ltc[trial].items(): + color = label_colors.get(label_name, "black") + plt.plot(data[1], data[0], color=color, label=label_name) + plt.legend() + if show_plots: + plt.show() - group.plot_save("ga_label-time-course", subfolder=label, trial=trial) + group.plot_save("ga_label-time-course", trial=trial) def plot_grand_avg_connect( group, - morph_to, + label_colors, show_plots, - connectivity_vmin, - connectivity_vmax, - con_group_boundaries, ): - ga_dict = group.load_ga_con() - con_info = ga_dict.pop("__info__") - # Get labels for FreeSurfer 'aparc' cortical parcellation - # with 34 labels/hemi - fsmri = FSMRI(morph_to, group.ct) - labels = fsmri.get_labels(con_info["labels"], con_info["parcellation"]) - if "unknown-lh" in labels: - labels.remove("unknown-lh") - - label_colors = [label.color for label in labels] - label_names = [lb.name for lb in labels] - - lh_labels = [l_name for l_name in label_names if l_name.endswith("lh")] - rh_labels = [l_name for l_name in label_names if l_name.endswith("rh")] - - # Get the y-location of the label - lh_label_ypos = [np.mean(lb.pos[:, 1]) for lb in labels if lb.name in lh_labels] - rh_label_ypos = [np.mean(lb.pos[:, 1]) for lb in labels if lb.name in rh_labels] - - # Reorder the labels based on their location - lh_labels = [label for (yp, label) in sorted(zip(lh_label_ypos, lh_labels))] - rh_labels = [label for (yp, label) in sorted(zip(rh_label_ypos, rh_labels))] - - # Save the plot order and create a circular layout - node_order = list() - node_order.extend(lh_labels[::-1]) # reverse the order - node_order.extend(rh_labels) - - # ToDo: what happens when there is an odd number of labels/groups - node_angles = mne.viz.circular_layout( - label_names, - node_order, - start_pos=90, - group_boundaries=con_group_boundaries, - ) - - for trial in ga_dict: - for method in ga_dict[trial]: - title = ( - f"{method}: {con_info['frequencies'][0]}-{con_info['frequencies'][1]}" - ) - fig, axes = mne_connectivity.viz.plot_connectivity_circle( - ga_dict[trial][method], - label_names, - n_lines=300, - node_angles=node_angles, - node_colors=label_colors, - title=title, - vmin=connectivity_vmin, - vmax=connectivity_vmax, - fontsize_names=16, - show=show_plots, - ) - - group.plot_save( - "ga_connectivity", subfolder=method, trial=trial, matplotlib_figure=fig - ) - - -def close_all(): - plt.close("all") - gc.collect() + con_dict = group.load_ga_con() + _plot_connectivity(group, con_dict, label_colors, show_plots) diff --git a/mne_pipeline_hd/gui/base_widgets.py b/mne_pipeline_hd/gui/base_widgets.py index 879e591..e54fe7f 100644 --- a/mne_pipeline_hd/gui/base_widgets.py +++ b/mne_pipeline_hd/gui/base_widgets.py @@ -6,7 +6,6 @@ """ import itertools -import logging import re import sys @@ -49,7 +48,7 @@ FileManagementModel, TreeModel, ) -from mne_pipeline_hd.pipeline.pipeline_utils import QS +from mne_pipeline_hd.pipeline.pipeline_utils import QS, logger class Base(QWidget): @@ -77,6 +76,9 @@ def __init__(self, model, view, drag_drop, parent, title): self.view.selectionModel().currentChanged.connect(self._current_changed) self.view.selectionModel().selectionChanged.connect(self._selection_changed) self.model.dataChanged.connect(self._data_changed) + # Also send signal when rows are removed/added + self.model.rowsInserted.connect(self._data_changed) + self.model.rowsRemoved.connect(self._data_changed) self.init_ui() @@ -104,11 +106,13 @@ def get_current(self): def _current_changed(self, current_idx, previous_idx): current = self.model.getData(current_idx) + # ToDo: For ListWidget after removal, + # there is a bug when previous_idx is too high previous = self.model.getData(previous_idx) self.currentChanged.emit(current, previous) - logging.debug(f"Current changed from {previous} to {current}") + logger().debug(f"Current changed from {previous} to {current}") def get_selected(self): try: @@ -126,13 +130,13 @@ def _selection_changed(self): self.selectionChanged.emit(selected) - logging.debug(f"Selection changed to {selected}") + logger().debug(f"Selection changed to {selected}") def _data_changed(self, index, _): data = self.model.getData(index) self.dataChanged.emit(data, index) - logging.debug(f"{data} changed at {index}") + logger().debug(f"{data} changed at {index}") def content_changed(self): """Informs ModelView about external change made in data""" @@ -187,8 +191,6 @@ class SimpleList(BaseList): Parent Widget (QWidget or inherited) or None if there is no parent. title : str | None An optional title. - verbose : bool - Set True to see debugging for signals. Notes ----- @@ -238,8 +240,6 @@ class EditList(BaseList): An optional title. model : QAbstractItemModel Provide an alternative to EditListModel. - verbose : bool - Set True to see debugging for signals Notes ----- @@ -354,8 +354,6 @@ class CheckList(BaseList): Parent Widget (QWidget or inherited) or None if there is no parent. title : str | None An optional title - verbose : bool - Set True to see debugging for signals Notes ----- @@ -430,7 +428,7 @@ def init_ui(self): def _checked_changed(self): self.checkedChanged.emit(self.model._checked) - logging.debug(f"Changed values: {self.model._checked}") + logger().debug(f"Changed values: {self.model._checked}") def replace_checked(self, new_checked): """Replaces model._checked with new checked list""" @@ -478,8 +476,6 @@ class CheckDictList(BaseList): Parent Widget (QWidget or inherited) or None if there is no parent. title : str | None An optional title. - verbose : bool - Set True to see debugging for signals. Notes ----- @@ -551,8 +547,6 @@ class CheckDictEditList(EditList): Parent Widget (QWidget or inherited) or None if there is no parent. title : str | None An optional title. - verbose : bool - Set True to see debugging for signals. Notes ----- @@ -652,7 +646,7 @@ def _current_changed(self, current_idx, previous_idx): self.currentChanged.emit(current_data, previous_data) - logging.debug(f"Current changed from {current_data} to {previous_data}") + logger().debug(f"Current changed from {current_data} to {previous_data}") def _selected_keyvalue(self, indexes): try: @@ -668,7 +662,7 @@ def _selection_changed(self): self.selectionChanged.emit(selected_data) - logging.debug(f"Selection to {selected_data}") + logger().debug(f"Selection to {selected_data}") def select(self, keys, values, clear_selection=True): key_indices = [i for i, x in enumerate(self.model._data.keys()) if x in keys] @@ -705,8 +699,6 @@ class SimpleDict(BaseDict): Set True to resize the rows to contents. resize_columns : bool Set True to resize the columns to contents. - verbose : bool - Set True to see debugging for signals. """ @@ -730,6 +722,7 @@ def __init__( ) +# ToDo: DataChanged somehow not emitted when row is removed class EditDict(BaseDict): """A Widget to display and edit a Dictionary @@ -752,8 +745,6 @@ class EditDict(BaseDict): Set True to resize the rows to contents. resize_columns : bool Set True to resize the columns to contents. - verbose : bool - Set True to see debugging for signals. """ @@ -851,8 +842,6 @@ class BasePandasTable(Base): The view for the pandas DataFrame. title : str | None An optional title. - verbose : bool - Set True to see debugging for signals. """ def __init__( @@ -921,7 +910,7 @@ def _current_changed(self, current_idx, previous_idx): self.currentChanged.emit(current_list, previous_list) - logging.debug(f"Current changed from {previous_list} to {current_list}") + logger().debug(f"Current changed from {previous_list} to {current_list}") def get_selected(self): # Somehow, the indexes got from selectionChanged @@ -936,7 +925,7 @@ def _selection_changed(self): selection_list = self.get_selected() self.selectionChanged.emit(selection_list) - logging.debug(f"Selection changed to {selection_list}") + logger().debug(f"Selection changed to {selection_list}") def select(self, values=None, rows=None, columns=None, clear_selection=True): """ @@ -1007,8 +996,6 @@ class SimplePandasTable(BasePandasTable): Set True to resize the rows to contents resize_columns : bool Set True to resize the columns to contents - verbose : bool - Set True to see debugging for signals Notes ----- @@ -1058,8 +1045,6 @@ class EditPandasTable(BasePandasTable): Set True to resize the rows to contents. resize_columns : bool Set True to resize the columns to contents. - verbose : bool - Set True to see debugging for signals. Notes ----- @@ -1247,8 +1232,6 @@ class FilePandasTable(BasePandasTable): Parent Widget (QWidget or inherited) or None if there is no parent title : str | None An optional title - verbose : bool - Set True to see debugging for signals Notes ----- @@ -1555,6 +1538,9 @@ def warning( ) +# ToDo: Proper testing +# Testing all signals properly emitted (also on row add/remove) +# Testing when _data is empty and get-Data, what happens? class AllBaseWidgets(QWidget): def __init__(self): super().__init__() @@ -1608,30 +1594,26 @@ def __init__(self): "SimpleList": { "extended_selection": True, "title": "BaseList", - "verbose": True, }, "EditList": { "ui_button_pos": "bottom", "extended_selection": True, "title": "EditList", - "verbose": True, }, - "CheckList": {"one_check": False, "title": "CheckList", "verbose": True}, + "CheckList": {"one_check": False, "title": "CheckList"}, "CheckDictList": { "extended_selection": True, "title": "CheckDictList", - "verbose": True, }, - "CheckDictEditList": {"title": "CheckDictEditList", "verbose": True}, - "SimpleDict": {"title": "BaseDict", "verbose": True}, - "EditDict": {"ui_button_pos": "left", "title": "EditDict", "verbose": True}, - "SimplePandasTable": {"title": "BasePandasTable", "verbose": True}, - "EditPandasTable": {"title": "EditPandasTable", "verbose": True}, - "DictTree": {"title": "BaseDictTree", "verbose": True}, + "CheckDictEditList": {"title": "CheckDictEditList"}, + "SimpleDict": {"title": "BaseDict"}, + "EditDict": {"ui_button_pos": "left", "title": "EditDict"}, + "SimplePandasTable": {"title": "BasePandasTable"}, + "EditPandasTable": {"title": "EditPandasTable"}, + "DictTree": {"title": "BaseDictTree"}, "AssignWidget": { "properties_editable": True, "title": "AssignWidget", - "verbose": True, }, } diff --git a/mne_pipeline_hd/gui/dialogs.py b/mne_pipeline_hd/gui/dialogs.py index e7a9e47..cd74adf 100644 --- a/mne_pipeline_hd/gui/dialogs.py +++ b/mne_pipeline_hd/gui/dialogs.py @@ -254,24 +254,28 @@ def __init__(self, main_win): layout.addWidget(self.from_cmbx, 1, 0) layout.addWidget(QLabel("Parameter-Preset:"), 2, 0) self.from_pp_cmbx = QComboBox() + self.from_pp_cmbx.currentTextChanged.connect(self.from_pp_selected) layout.addWidget(self.from_pp_cmbx, 3, 0) layout.addWidget(QLabel("To:"), 0, 1) self.to_cmbx = QComboBox() self.to_cmbx.currentTextChanged.connect(self.to_selected) - self.to_cmbx.setEnabled(False) layout.addWidget(self.to_cmbx, 1, 1) layout.addWidget(QLabel("Parameter-Preset:"), 2, 1) self.to_pp_cmbx = QComboBox() self.to_pp_cmbx.setEditable(True) layout.addWidget(self.to_pp_cmbx, 3, 1) + layout.addWidget(QLabel("Parameter:"), 4, 0, 1, 2) + self.param_cmbx = QComboBox() + layout.addWidget(self.param_cmbx, 5, 0, 1, 2) + copy_bt = QPushButton("Copy") copy_bt.clicked.connect(self.copy_parameters) - layout.addWidget(copy_bt, 4, 0) + layout.addWidget(copy_bt, 6, 0) close_bt = QPushButton("Close") close_bt.clicked.connect(self.close) - layout.addWidget(close_bt, 4, 1) + layout.addWidget(close_bt, 6, 1) widget.setLayout(layout) super().__init__( @@ -282,6 +286,9 @@ def __init__(self, main_win): show_close_bt=False, ) + # Initialize with first from-entry + self.from_selected(self.from_cmbx.currentText()) + def _get_p_presets(self, pr_name): if self.ct.pr.name == pr_name: project = self.ct.pr @@ -290,15 +297,27 @@ def _get_p_presets(self, pr_name): return list(project.parameters.keys()) + def from_pp_selected(self, from_pp_name): + if from_pp_name: + self.param_cmbx.clear() + params = list( + Project(self.ct, self.from_cmbx.currentText()) + .parameters[from_pp_name] + .keys() + ) + params.insert(0, "") + self.param_cmbx.addItems(params) + def from_selected(self, from_name): if from_name: - self.to_cmbx.setEnabled(True) self.to_cmbx.clear() self.to_cmbx.addItems([p for p in self.ct.projects if p != from_name]) self.from_pp_cmbx.clear() self.from_pp_cmbx.addItems(self._get_p_presets(from_name)) + self.from_pp_selected(self.from_pp_cmbx.currentText()) + def to_selected(self, to_name): if to_name: self.to_pp_cmbx.clear() @@ -309,12 +328,19 @@ def copy_parameters(self): from_pp = self.from_pp_cmbx.currentText() to_name = self.to_cmbx.currentText() to_pp = self.to_pp_cmbx.currentText() + param = self.param_cmbx.currentText() + if param == "": + param = None if from_name and to_name: - self.ct.copy_parameters_between_projects(from_name, from_pp, to_name, to_pp) + self.ct.copy_parameters_between_projects( + from_name, from_pp, to_name, to_pp, param + ) if to_name == self.ct.pr.name: self.main_win.parameters_dock.redraw_param_widgets() QMessageBox().information( - self, "Finished", f"Parameters copied from {from_name} " f"to {to_name}!" + self, + "Finished", + f"Copied parameter '{param}' from {from_name} to {to_name}!", ) diff --git a/mne_pipeline_hd/gui/function_widgets.py b/mne_pipeline_hd/gui/function_widgets.py index dddf9d9..8daca15 100644 --- a/mne_pipeline_hd/gui/function_widgets.py +++ b/mne_pipeline_hd/gui/function_widgets.py @@ -137,11 +137,8 @@ def init_ui(self): self.restart_bt.clicked.connect(self.restart) bt_layout.addWidget(self.restart_bt) - if QS().value("use_qthread"): - self.reload_chbx = None - else: - self.reload_chbx = QCheckBox("Reload Modules") - bt_layout.addWidget(self.reload_chbx) + self.reload_chbx = QCheckBox("Reload Modules") + bt_layout.addWidget(self.reload_chbx) self.autoscroll_bt = QPushButton("Auto-Scroll") self.autoscroll_bt.setCheckable(True) @@ -181,6 +178,8 @@ def restart(self): # ToDo: MP # if self.reload_chbx and self.reload_chbx.isChecked(): # init_mp_pool() + if self.reload_chbx.isChecked(): + self.mw.ct.reload_modules() # Clear Console-Widget self.console_widget.clear() diff --git a/mne_pipeline_hd/gui/gui_utils.py b/mne_pipeline_hd/gui/gui_utils.py index 4a59428..13eb3ec 100644 --- a/mne_pipeline_hd/gui/gui_utils.py +++ b/mne_pipeline_hd/gui/gui_utils.py @@ -40,7 +40,7 @@ ) from mne_pipeline_hd import _object_refs -from mne_pipeline_hd.pipeline.pipeline_utils import QS +from mne_pipeline_hd.pipeline.pipeline_utils import QS, logger def center(widget): @@ -151,7 +151,7 @@ def show_error_dialog(exc_str): if QApplication.instance() is not None: ErrorDialog(exc_str, title="A unexpected error occurred") else: - logging.debug("No QApplication instance available.") + logger().debug("No QApplication instance available.") def gui_error_decorator(func): @@ -204,7 +204,7 @@ def exception_hook(self, exc_type, exc_value, exc_traceback): exc_value, "".join(traceback.format_tb(exc_traceback)), ) - logging.critical( + logger().critical( f"Uncaught exception:\n" f"{exc_str[0]}: {exc_str[1]}\n" f"{exc_str[2]}", @@ -220,15 +220,6 @@ def __init__(self, parent=None): super().__init__(parent) -def _html_compatible(text): - text = text.replace("<", "<") - text = text.replace(">", ">") - text = text.replace("\n", "
") - text = text.replace("\x1b", "") - - return text - - class ConsoleWidget(QPlainTextEdit): """A Widget displaying formatted stdout/stderr-output""" @@ -237,62 +228,70 @@ def __init__(self): self.setReadOnly(True) self.autoscroll = True - - self.buffer_time = 1 + self.is_progress = False # Buffer to avoid crash for too many inputs self.buffer = list() + self.buffer_time = 50 self.buffer_timer = QTimer() self.buffer_timer.timeout.connect(self.write_buffer) self.buffer_timer.start(self.buffer_time) - def _add_html(self, text): - self.appendHtml(text) - if self.autoscroll: - self.ensureCursorVisible() - def write_buffer(self): + if self.is_progress: + # Delete last line + cursor = self.textCursor() + # Avoid having no break between progress and text + # Remove last line + cursor.select(QTextCursor.LineUnderCursor) + cursor.removeSelectedText() + cursor.deletePreviousChar() + self.is_progress = False + if len(self.buffer) > 0: - text, kind = self.buffer.pop(0) - if kind == "html": - self._add_html(text) - - elif kind == "progress": - text = text.replace("\r", "") - text = _html_compatible(text) - text = f'{text}' - # Delete last line - cursor = self.textCursor() - cursor.select(QTextCursor.LineUnderCursor) - cursor.removeSelectedText() - self._add_html(text) - - elif kind == "stdout": - text = _html_compatible(text) - self._add_html(text) - - elif kind == "stderr": - # weird characters in some progress are excluded - # (e.g. from autoreject) - if "\x1b" not in text: - text = _html_compatible(text) - text = f'{text}' - self._add_html(text) + text_list = self.buffer.copy() + self.buffer.clear() + text = "".join(text_list) + # Remove last break because of appendHtml above + if text[-4:] == "
": + text = text[:-4] + self.appendHtml(text) + if self.autoscroll: + self.ensureCursorVisible() def set_autoscroll(self, autoscroll): self.autoscroll = autoscroll def write_html(self, text): - self.buffer.append((text, "html")) + self.buffer.append(text) + + def _html_compatible(self, text): + text = text.replace("<", "<") + text = text.replace(">", ">") + text = text.replace("\n", "
") + text = text.replace("\x1b", "") + + if text[:1] == "\r": + self.is_progress = True + text = text.replace("\r", "") + # Avoid having no break between progress and text + text = f"{text}" + if len(self.buffer) > 0: + if self.buffer[-1][:20] == "": + self.buffer.pop(-1) + return text def write_stdout(self, text): - self.buffer.append((text, "stdout")) + text = self._html_compatible(text) + self.buffer.append(text) def write_stderr(self, text): - self.buffer.append((text, "stderr")) - - def write_progress(self, text): - self.buffer.append((text, "progress")) + text = self._html_compatible(text) + if text[-4:] == "
": + text = f'{text[:-4]}
' + else: + text = f'{text}' + self.buffer.append(text) # Make sure cursor is not moved def mousePressEvent(self, event): @@ -312,40 +311,34 @@ def __init__(self): # Connect custom stdout and stderr to display-function sys.stdout.signal.text_written.connect(self.write_stdout) sys.stderr.signal.text_written.connect(self.write_stderr) - # Handle progress-bars - sys.stdout.signal.text_updated.connect(self.write_progress) - sys.stderr.signal.text_updated.connect(self.write_progress) class StreamSignals(QObject): - text_updated = Signal(str) text_written = Signal(str) -# ToDo: Buffering and halting signal-emission -# (continue writing to sys.__stdout__/__stderr__) -# when no accepted/printed-signal is coming back from receiving Widget class StdoutStderrStream(io.TextIOBase): def __init__(self, kind): super().__init__() self.signal = StreamSignals() - if kind == "stdout": + self.kind = kind + if self.kind == "stdout": self.original_stream = sys.__stdout__ - else: + elif self.kind == "stderr": self.original_stream = sys.__stderr__ + else: + self.original_stream = None def write(self, text): - # Still send output to the command-line - self.original_stream.write(text) - - # Get progress-text with '\r' as prefix - if text[:1] == "\r": - self.signal.text_updated.emit(text) - else: - self.signal.text_written.emit(text) + if self.original_stream is not None: + # Still send output to the command-line + self.original_stream.write(text) + # Emit signal to display in GUI + self.signal.text_written.emit(text) def flush(self): - self.original_stream.flush() + if self.original_stream is not None: + self.original_stream.flush() class WorkerSignals(QObject): @@ -738,7 +731,7 @@ def get_user_input_string(prompt, title="Input required!", force=False): "You need to provide an appropriate input to proceed!", ) else: - logging.warning( + logger().warning( "Input required! You need to provide " "an appropriate input to proceed!" ) diff --git a/mne_pipeline_hd/gui/loading_widgets.py b/mne_pipeline_hd/gui/loading_widgets.py index 7c332d4..ef6cb83 100644 --- a/mne_pipeline_hd/gui/loading_widgets.py +++ b/mne_pipeline_hd/gui/loading_widgets.py @@ -4,7 +4,6 @@ License: BSD 3-Clause Github: https://github.com/marsipu/mne-pipeline-hd """ -import logging import os import re import shutil @@ -79,10 +78,10 @@ from mne_pipeline_hd.gui.models import AddFilesModel from mne_pipeline_hd.gui.parameter_widgets import ComboGui from mne_pipeline_hd.pipeline.loading import FSMRI, Group, MEEG -from mne_pipeline_hd.pipeline.pipeline_utils import compare_filep, QS +from mne_pipeline_hd.pipeline.pipeline_utils import compare_filep, QS, logger -def index_parser(index, all_items): +def index_parser(index, all_items, groups=None): """ Parses indices from a index-string in all_items @@ -96,7 +95,7 @@ def index_parser(index, all_items): ------- """ - run = list() + indices = list() rm = list() try: @@ -115,9 +114,9 @@ def index_parser(index, all_items): rm.append(int(sp[1:])) elif "all" in sp: for i in range(len(all_items)): - run.append(i) + indices.append(i) else: - run = [x for x in range(len(all_items))] + indices = [x for x in range(len(all_items))] elif "," in index and "-" in index: z = index.split(",") @@ -125,9 +124,9 @@ def index_parser(index, all_items): if "-" in i and "!" not in i: x, y = i.split("-") for n in range(int(x), int(y) + 1): - run.append(n) + indices.append(n) elif "!" not in i: - run.append(int(i)) + indices.append(int(i)) elif "!" in i and "-" in i: x, y = i.split("-") x = x[1:] @@ -138,7 +137,7 @@ def index_parser(index, all_items): elif "-" in index and "," not in index: x, y = index.split("-") - run = [x for x in range(int(x), int(y) + 1)] + indices = [x for x in range(int(x), int(y) + 1)] elif "," in index and "-" not in index: splits = index.split(",") @@ -146,21 +145,25 @@ def index_parser(index, all_items): if "!" in sp: rm.append(int(sp)) else: - run.append(int(sp)) + indices.append(int(sp)) + + elif groups is not None and index in groups: + files = [x for x in all_items if x in groups[index]] + indices = [all_items.index(x) for x in files] else: if len(all_items) < int(index) or int(index) < 0: - run = [] + indices = [] else: - run = [int(index)] + indices = [int(index)] - run = [i for i in run if i not in rm] - files = [x for (i, x) in enumerate(all_items) if i in run] + indices = [i for i in indices if i not in rm] + files = np.asarray(all_items)[indices].tolist() - return files, run + return files except ValueError: - return [], [] + return [] class RemoveDialog(QDialog): @@ -232,7 +235,8 @@ def init_ui(self): "'1-4,7,20-26' (The last two combined)\n" "'1-20,!4-6' (1-20 except 4-6)\n" "'all' (All files in file_list.py)\n" - "'all,!4-6' (All files except 4-6)" + "'all,!4-6' (All files except 4-6)\n" + " (All files in group)" ) if self.meeg_view: @@ -328,13 +332,15 @@ def reload_dock(self): def select_meeg(self): index = self.meeg_ledit.text() - self.mw.ct.pr.sel_meeg, idxs = index_parser(index, self.mw.ct.pr.all_meeg) + self.mw.ct.pr.sel_meeg = index_parser( + index, self.mw.ct.pr.all_meeg, self.mw.ct.pr.all_groups + ) # Replace _checked in CheckListModel because of rereferencing above self.meeg_list.replace_checked(self.mw.ct.pr.sel_meeg) def select_fsmri(self): index = self.fsmri_ledit.text() - self.mw.ct.pr.sel_fsmri, idxs = index_parser(index, self.mw.ct.pr.all_fsmri) + self.mw.ct.pr.sel_fsmri = index_parser(index, self.mw.ct.pr.all_fsmri) # Replace _checked in CheckListModel because of rereferencing above self.fsmri_list.replace_checked(self.mw.ct.pr.sel_fsmri) @@ -771,7 +777,7 @@ def add_files(self, worker_signals): self.pr.add_meeg(name, file_path, is_erm) worker_signals.pgbar_n.emit(n + 1) else: - logging.info("Canceled Loading") + logger().info("Canceled Loading") break def add_files_starter(self): @@ -930,9 +936,9 @@ def import_mri_subject(self): self.paths.update({fsmri: folder_path}) self.populate_list_widget() else: - logging.info(f"{fsmri} already existing in {self.ct.subjects_dir}") + logger().info(f"{fsmri} already existing in {self.ct.subjects_dir}") else: - logging.warning( + logger().warning( "Selected Folder doesn't seem to " "be a Freesurfer-Segmentation" ) @@ -952,9 +958,9 @@ def import_mri_subjects(self): self.folders.append(fsmri) self.paths.update({fsmri: folder_path}) else: - logging.info(f"{fsmri} already existing in {self.ct.subjects_dir}") + logger().info(f"{fsmri} already existing in {self.ct.subjects_dir}") else: - logging.warning( + logger().warning( "Selected Folder doesn't seem to be " "a Freesurfer-Segmentation" ) self.populate_list_widget() @@ -1424,13 +1430,14 @@ def __init__(self, main_win): self.name = None self.event_id = dict() + self.queries = dict() self.labels = list() self.checked_labels = list() self.layout = QVBoxLayout() self.init_ui() - self.open() + self.show() def init_ui(self): list_layout = QHBoxLayout() @@ -1462,6 +1469,15 @@ def init_ui(self): list_layout.addLayout(event_id_layout) + self.query_widget = EditDict( + self.queries, ui_buttons=True, title="Metadata-Queries" + ) + self.query_widget.setToolTip( + "Add Metadata-Queries as value for trials which are named with key" + ) + self.query_widget.dataChanged.connect(self.update_check_list) + list_layout.addWidget(self.query_widget) + self.check_widget = CheckList(title="Select IDs") list_layout.addWidget(self.check_widget) @@ -1493,15 +1509,27 @@ def get_event_id(self): self.event_id = dict() self.event_id_widget.replace_data(self.event_id) + meeg = MEEG(self.name, self.ct, suppress_warnings=True) try: # Load events from File - meeg = MEEG(self.name, self.ct, suppress_warnings=True) events = meeg.load_events() except FileNotFoundError: - self.event_id_label.setText(f"No events found for {self.name}") + label_text = f"No events found for {self.name}" else: - ids = np.unique(events[:, 2]) - self.event_id_label.setText(f"events found: {ids}") + label_text = f"events found: {np.unique(events[:, 2])}" + + try: + # Load epochs from File + epochs = meeg.load_epochs() + assert epochs.metadata is not None + except (FileNotFoundError, AssertionError): + self.query_widget.setEnabled(False) + label_text += "\nNo metadata found" + else: + self.query_widget.setEnabled(True) + label_text += "\nMetadata found" + + self.event_id_label.setText(label_text) def save_event_id(self): if self.name: @@ -1509,8 +1537,14 @@ def save_event_id(self): # Write Event-ID to Project self.pr.meeg_event_id[self.name] = self.event_id - # Get selected Trials and write them to meeg.pr - self.pr.sel_event_id[self.name] = self.checked_labels + # Get selected Trials, add queries and write them to meeg.pr + sel_event_id = dict() + for label in self.checked_labels: + if label in self.queries: + sel_event_id[label] = self.queries[label] + else: + sel_event_id[label] = None + self.pr.sel_event_id[self.name] = sel_event_id def file_selected(self, current, _): """Called when File from file_widget is selected""" @@ -1523,12 +1557,23 @@ def file_selected(self, current, _): # Load checked trials if self.name in self.pr.sel_event_id: - self.checked_labels = self.pr.sel_event_id[self.name] + # Update query-widget + if self.query_widget.isEnabled(): + sel_trials = self.pr.sel_event_id[self.name] + if not isinstance(sel_trials, dict): + sel_trials = {k: None for k in sel_trials} + self.queries = {k: v for k, v in sel_trials.items() if v is not None} + self.query_widget.replace_data(self.queries) + # Legacy to allow reading lists before + # they were changed to dicts for queries + self.checked_labels = list(self.pr.sel_event_id[self.name]) else: self.checked_labels = list() self.update_check_list() + # ToDo: Make all combinations possible def update_check_list(self): + self.labels = [k for k in self.queries.keys()] # Get selectable trials and update widget prelabels = [i.split("/") for i in self.event_id.keys() if i != ""] if len(prelabels) > 0: @@ -1538,11 +1583,14 @@ def update_check_list(self): for item in prelabels[1:]: conc_labels += item # Make sure that only unique labels exist - self.labels = list(set(conc_labels)) + self.labels += list(set(conc_labels)) # Make sure, that only trials, which exist in event_id exist for chk_label in self.checked_labels: - if not any(chk_label in key for key in self.event_id): + if ( + not any(chk_label in key for key in self.event_id) + and chk_label not in self.queries + ): self.checked_labels.remove(chk_label) else: self.labels = list() @@ -1568,6 +1616,10 @@ class EvIDApply(QDialog): def __init__(self, parent): super().__init__(parent) self.p = parent + + # Save to make sel_event_id available in apply_evid + self.p.save_event_id() + self.apply_to = list() self.layout = QVBoxLayout() @@ -1599,8 +1651,8 @@ def apply_evid(self): for file in self.apply_to: # Avoid with copy that CheckList-Model changes selected # for all afterwards (same reference) - self.p.pr.meeg_event_id[file] = self.p.event_id.copy() - self.p.pr.sel_event_id[file] = self.p.checked_labels.copy() + self.p.pr.meeg_event_id[file] = self.p.pr.meeg_event_id[self.p.name].copy() + self.p.pr.sel_event_id[file] = self.p.pr.sel_event_id[self.p.name].copy() class CopyTrans(QDialog): @@ -1740,7 +1792,7 @@ def get_file_tables(self, kind): obj_pd = self.pd_group obj_pd_time = self.pd_group_time obj_pd_size = self.pd_group_size - logging.debug(f"Loading {kind}") + logger().debug(f"Loading {kind}") for obj_name in obj_list: if kind == "MEEG": @@ -2302,7 +2354,7 @@ def reload_raw(self, selected_raw, raw_path): meeg = MEEG(selected_raw, self.ct) raw = mne.io.read_raw(raw_path, preload=True) meeg.save_raw(raw) - logging.info(f"Reloaded raw for {selected_raw}") + logger().info(f"Reloaded raw for {selected_raw}") def start_reload(self): # Not with partial because otherwise the clicked-arg @@ -2378,7 +2430,7 @@ def _init_ui(self): def export_data(self): if self.dest_path: - logging.info("Starting Export\n") + logger().info("Starting Export\n") for meeg_name, path_types in self.export_paths.items(): os.mkdir(join(self.dest_path, meeg_name)) for path_type in [pt for pt in path_types if pt in self.selected_types]: @@ -2388,6 +2440,6 @@ def export_data(self): shutil.copy2( src_path, join(self.dest_path, meeg_name, dest_name) ) - logging.info(f"\r{meeg_name}: Copying {path_type}...") + logger().info(f"\r{meeg_name}: Copying {path_type}...") else: QMessageBox.warning(self, "Ups!", "Destination-Path not set!") diff --git a/mne_pipeline_hd/gui/main_window.py b/mne_pipeline_hd/gui/main_window.py index 8ffa690..8706049 100644 --- a/mne_pipeline_hd/gui/main_window.py +++ b/mne_pipeline_hd/gui/main_window.py @@ -4,7 +4,6 @@ License: BSD 3-Clause Github: https://github.com/marsipu/mne-pipeline-hd """ -import logging import sys from functools import partial @@ -32,7 +31,6 @@ ) from mne_pipeline_hd import _object_refs -from mne_pipeline_hd.functions.plot import close_all from mne_pipeline_hd.gui.dialogs import ( QuickGuide, RawInfo, @@ -79,12 +77,14 @@ from mne_pipeline_hd.gui.plot_widgets import PlotViewSelection from mne_pipeline_hd.gui.tools import DataTerminal from mne_pipeline_hd.pipeline.controller import Controller +from mne_pipeline_hd.pipeline.function_utils import close_all from mne_pipeline_hd.pipeline.pipeline_utils import ( restart_program, ismac, QS, _run_from_script, iswin, + logger, ) @@ -202,6 +202,10 @@ def project_changed(self, idx): self.update_project_ui() + def pr_rename(self): + self.ct.rename_project() + self.update_project_box() + def pr_clean_fp(self): WorkerDialog( self, @@ -299,6 +303,7 @@ def init_menu(self): # Project project_menu = self.menuBar().addMenu("&Project") + project_menu.addAction("&Rename Project", self.pr_rename) project_menu.addAction("&Clean File-Parameters", self.pr_clean_fp) project_menu.addAction("&Clean Plot-Files", self.pr_clean_pf) project_menu.addAction( @@ -556,7 +561,7 @@ def update_func_bts(self): try: tab.deleteLater() except RuntimeError: - logging.debug("Tab already deleted") + logger().debug("Tab already deleted") self.bt_dict = dict() self.add_func_bts() diff --git a/mne_pipeline_hd/gui/models.py b/mne_pipeline_hd/gui/models.py index 3040731..c9707cd 100644 --- a/mne_pipeline_hd/gui/models.py +++ b/mne_pipeline_hd/gui/models.py @@ -4,7 +4,6 @@ License: BSD 3-Clause Github: https://github.com/marsipu/mne-pipeline-hd """ - from ast import literal_eval from datetime import datetime @@ -19,6 +18,8 @@ from qtpy.QtGui import QBrush, QFont from mne_pipeline_hd.gui.gui_utils import get_std_icon +from mne_pipeline_hd.pipeline.pipeline_utils import logger + # ToDo: Merge models and base widgets @@ -46,6 +47,9 @@ def __init__(self, data=None, show_index=False, drag_drop=False, **kwargs): self._data = data def getData(self, index): + if len(self._data) == 0: + logger().debug("List is empty") + return None return self._data[index.row()] def data(self, index, role=None): @@ -182,6 +186,8 @@ def data(self, index, role=None): def setData(self, index, value, role=None): if role == Qt.CheckStateRole: + # ToDo: This does not work under PySide6 + # since Qt.Checked returns no integer (only Qt.Checked.value) if value == Qt.Checked: if self.one_check: self._checked.clear() @@ -769,6 +775,8 @@ def setData(self, index, value, role=None): role == Qt.CheckStateRole and self._data.columns[index.column()] == "Empty-Room?" ): + # ToDo: This does not work under PySide6 + # since Qt.Checked returns no integer (only Qt.Checked.value) if value == Qt.Checked: self._data.iloc[index.row(), index.column()] = 1 else: diff --git a/mne_pipeline_hd/gui/parameter_widgets.py b/mne_pipeline_hd/gui/parameter_widgets.py index ef56954..ea519b6 100644 --- a/mne_pipeline_hd/gui/parameter_widgets.py +++ b/mne_pipeline_hd/gui/parameter_widgets.py @@ -4,7 +4,6 @@ License: BSD 3-Clause Github: https://github.com/marsipu/mne-pipeline-hd """ -import logging from ast import literal_eval from copy import copy from functools import partial @@ -60,7 +59,7 @@ ) from mne_pipeline_hd.pipeline.controller import Controller from mne_pipeline_hd.pipeline.loading import FSMRI -from mne_pipeline_hd.pipeline.pipeline_utils import QS, iswin +from mne_pipeline_hd.pipeline.pipeline_utils import QS, iswin, logger # ToDo: Unify None-select and more @@ -237,31 +236,36 @@ def set_param(self, value): def _read_data(self, name): # get data from dictionary - if isinstance(self.data, dict): - if name in self.data: - value = self.data[name] - else: - value = self.default + if isinstance(self.data, dict) and name in self.data: + value = self.data[name] # get data from Parameters in Project in MainWindow # (depending on selected parameter-preset and selected Project) - elif isinstance(self.data, Controller): - if name in self.data.pr.parameters[self.data.pr.p_preset]: - value = self.data.pr.parameters[self.data.pr.p_preset][name] - else: - value = self.default + elif ( + isinstance(self.data, Controller) + and name in self.data.pr.parameters[self.data.pr.p_preset] + ): + value = self.data.pr.parameters[self.data.pr.p_preset][name] # get data from QSettings - elif isinstance(self.data, QS): - if name in self.data.childKeys(): - value = self.data.value(name) - else: - value = self.default + elif isinstance(self.data, QS) and name in self.data.childKeys(): + value = self.data.value(name) + else: + value = self.default return value def read_param(self): - self.param_value = self._read_data(self.name) + data = self._read_data(self.name) + if not self.none_select: + if self.data_type != "multiple": + if not isinstance(data, self.data_type): + logger().warning( + f"Data for {self.name} has to be of type {self.data_type}, " + f"but is of type {type(data)} instead!" + ) + data = self.data_type() + self.param_value = data def _save_data(self, name, value): if isinstance(self.data, dict): @@ -278,6 +282,8 @@ def save_param(self): class IntGui(Param): """A GUI for Integer-Parameters""" + data_type = int + def __init__(self, min_val=0, max_val=1000, special_value_text=None, **kwargs): """ Parameters @@ -324,6 +330,8 @@ def get_value(self): class FloatGui(Param): """A GUI for Float-Parameters""" + data_type = float + def __init__(self, min_val=-1000.0, max_val=1000.0, step=0.1, decimals=2, **kwargs): """ Parameters @@ -373,6 +381,8 @@ class StringGui(Param): A GUI for String-Parameters """ + data_type = str + def __init__(self, **kwargs): """ @@ -415,6 +425,8 @@ def _eval_param(param_exp): class FuncGui(Param): """A GUI for Parameters defined by small functions, e.g from numpy""" + data_type = "multiple" + def __init__(self, **kwargs): """ Parameters @@ -502,6 +514,8 @@ def save_param(self): class BoolGui(Param): """A GUI for Boolean-Parameters""" + data_type = bool + def __init__(self, return_integer=False, **kwargs): """ Parameters @@ -541,6 +555,8 @@ def get_value(self): class TupleGui(Param): """A GUI for Tuple-Parameters""" + data_type = tuple + def __init__(self, min_val=-1000.0, max_val=1000.0, step=0.1, **kwargs): """ Parameters @@ -601,10 +617,11 @@ def set_value(self, value): # Signal valueChanged is already emitted after first setValue, # which leads to second param_value being 0 without being # preserved in self.loaded_value - self._external_set = True - self.param_widget1.setValue(value[0]) - self.param_widget2.setValue(value[1]) - self._external_set = False + if len(value) == 2: + self._external_set = True + self.param_widget1.setValue(value[0]) + self.param_widget2.setValue(value[1]) + self._external_set = False def _get_param(self): if not self._external_set: @@ -614,9 +631,12 @@ def get_value(self): return self.param_widget1.value(), self.param_widget2.value() +# ToDo: make options replacable class ComboGui(Param): """A GUI for a Parameter with limited options""" + data_type = "multiple" + def __init__(self, options, **kwargs): """ Parameters @@ -698,6 +718,8 @@ def closeEvent(self, event): class ListGui(Param): """A GUI for as list""" + data_type = list + def __init__(self, value_string_length=30, **kwargs): """ Parameters @@ -791,9 +813,12 @@ def closeEvent(self, event): event.accept() +# ToDo: make options replacable class CheckListGui(Param): """A GUI to select items from a list of options""" + data_type = list + def __init__(self, options, value_string_length=30, one_check=False, **kwargs): """ Parameters @@ -892,6 +917,8 @@ def closeEvent(self, event): class DictGui(Param): """A GUI for a dictionary""" + data_type = dict + def __init__(self, value_string_length=30, **kwargs): """ @@ -967,6 +994,8 @@ def get_value(self): class SliderGui(Param): """A GUI to show a slider for Int/Float-Parameters""" + data_type = "multiple" + def __init__(self, min_val=0, max_val=100, step=1, tracking=True, **kwargs): """ Parameters @@ -1059,6 +1088,8 @@ def get_value(self): class MultiTypeGui(Param): """A GUI which accepts multiple types of values in a single LineEdit""" + data_type = "multiple" + def __init__(self, type_selection=False, types=None, type_kwargs=None, **kwargs): """ Parameters @@ -1151,7 +1182,7 @@ def change_type(self, type_idx): try: old_widget.widget().deleteLater() except RuntimeError: - logging.debug("Old widget already deleted") + logger().debug("Old widget already deleted") del old_widget, self.param_widget self.param_type = self.types[type_idx] @@ -1303,16 +1334,23 @@ def _label_picked(self, vtk_picker, _): self._add_label_name(label.name, hemi, label) self.selected.append(label.name) self.list_changed_slot() + self.paramdlg.update_selected_display() # Update label text if "label" in self._actors["text"]: self.remove_text("label") + if label.color is not None: + color = label.color[:3] + opacity = label.color[-1] + else: + color = "w" + opacity = 1 self.add_text( 0, 0.05, label.name, - color=label.color[:3], - opacity=label.color[-1], + color=color, + opacity=opacity, font_size=12, name="label", ) @@ -1400,6 +1438,12 @@ def _init_layout(self): self.surface_cmbx.activated.connect(self._surface_changed) layout.addWidget(self.surface_cmbx) + self.selected_display = SimpleList( + data=self._selected_parc_labels + self._selected_extra_labels, + title="Selected Labels", + ) + layout.addWidget(self.selected_display) + self.parc_label_list = CheckList( data=self._parc_labels, checked=self._selected_parc_labels, @@ -1438,27 +1482,30 @@ def _subject_changed(self): self.parcellation_cmbx.clear() self.parcellation_cmbx.addItems(self._fsmri.parcellations) - # Get currently set parcellation - if ( - self.ct.pr.parameters[self.ct.pr.p_preset]["target_parcellation"] - in self._fsmri.parcellations - ): - self.parcellation_cmbx.setCurrentText( - self.ct.pr.parameters[self.ct.pr.p_preset]["target_parcellation"] - ) - - # Add extra labels + # Update extra labels self._extra_labels.clear() self._extra_labels += [lb.name for lb in self._fsmri.labels["Other"]] self.extra_label_list.content_changed() - old_selected = self._selected_extra_labels.copy() + old_selected_extra = self._selected_extra_labels.copy() self._selected_extra_labels.clear() self._selected_extra_labels += [ - lb for lb in old_selected if lb in self._extra_labels + lb for lb in old_selected_extra if lb in self._extra_labels ] self.extra_label_list.content_changed() + # Update selected parcellation labels + all_labels_exept_other = list() + for parc_name, labels in self._fsmri.labels.items(): + if parc_name != "Other": + all_labels_exept_other += [lb.name for lb in labels] + old_selected_parc = self._selected_parc_labels.copy() + self._selected_parc_labels.clear() + self._selected_parc_labels += [ + lb for lb in old_selected_parc if lb in all_labels_exept_other + ] + self.parc_label_list.content_changed() + # Update pickers if open if self._parc_picker is not None and not self._parc_picker.isclosed(): self._parc_picker.close() @@ -1481,16 +1528,13 @@ def _parc_changed(self): lb.name for lb in self._fsmri.labels[self._parcellation] ] - # get former selected - old_selected = self._selected_parc_labels.copy() - self._selected_parc_labels.clear() - self._selected_parc_labels += [ - lb for lb in old_selected if lb in self._parc_labels - ] - self.parc_label_list.content_changed() - if self._parc_picker is not None and not self._parc_picker.isclosed(): self._parc_picker._set_annotations(self._parcellation) + for label_name in [ + lb for lb in self._selected_parc_labels if lb in self._parc_labels + ]: + hemi = label_name[-2:] + self._parc_picker._add_label_name(label_name, hemi) def _surface_changed(self): self._surface = self.surface_cmbx.currentText() @@ -1498,6 +1542,11 @@ def _surface_changed(self): self._parc_picker.close() self._open_parc_picker() + def update_selected_display(self): + self.selected_display.replace_data( + self._selected_parc_labels + self._selected_extra_labels + ) + def _labels_changed(self, labels, picker_name): picker = ( self._parc_picker if picker_name == "parcellation" else self._extra_picker @@ -1510,6 +1559,8 @@ def _labels_changed(self, labels, picker_name): for remove_name in [lb for lb in shown_labels if lb not in labels]: hemi = remove_name[-2:] picker._remove_label_name(remove_name, hemi) + # Update display + self.update_selected_display() # Keep pickers on top def _open_parc_picker(self): @@ -1543,6 +1594,8 @@ def closeEvent(self, event): class LabelGui(Param): """This GUI lets the user pick labels from a brain.""" + data_type = list + def __init__(self, value_string_length=30, **kwargs): """ Parameters @@ -1620,6 +1673,8 @@ def get_value(self): class ColorGui(Param): """A GUI to pick a color and returns a dictionary with HexRGBA-Strings.""" + data_type = dict + def __init__(self, keys, **kwargs): """ Parameters @@ -1703,6 +1758,8 @@ def _pick_color(self): class PathGui(Param): """A GUI to pick a path.""" + data_type = str + def __init__(self, pick_mode="file", **kwargs): """ Parameters diff --git a/mne_pipeline_hd/gui/plot_widgets.py b/mne_pipeline_hd/gui/plot_widgets.py index 7cfd3b0..7524962 100644 --- a/mne_pipeline_hd/gui/plot_widgets.py +++ b/mne_pipeline_hd/gui/plot_widgets.py @@ -4,11 +4,15 @@ License: BSD 3-Clause Github: https://github.com/marsipu/mne-pipeline-hd """ -import logging from functools import partial from importlib import import_module from os.path import join, isfile +from matplotlib import pyplot as plt +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg +from matplotlib.figure import Figure +from mne.viz import Brain +from mne_qt_browser._pg_figure import MNEQtBrowser from qtpy.QtCore import Qt, QThreadPool from qtpy.QtGui import QPixmap, QFont from qtpy.QtWidgets import ( @@ -29,11 +33,8 @@ QToolBar, QSpinBox, ) -from matplotlib import pyplot as plt -from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg -from matplotlib.figure import Figure -from mne.viz import Brain -from mne_qt_browser._pg_figure import MNEQtBrowser + +from mne_pipeline_hd.pipeline.pipeline_utils import logger try: from mne.viz import Figure3D @@ -110,7 +111,7 @@ def add_plot(self, plot, name, func_name): elif isinstance(subplot, Figure3D): plot_widget = subplot else: - logging.error( + logger().error( f'Unrecognized type "{type(subplot)}" ' f'for "{func_name}"' ) plot_widget = QWidget() @@ -318,7 +319,7 @@ def error_func(err_tuple, o_name=obj_name, ppreset=p_preset): self.thread_error(err_tuple, o_name, ppreset, "plot") worker.signals.error.connect(error_func) - logging.info( + logger().info( f"Starting Thread for Object= {obj_name} " f"and Parameter-Preset= {p_preset}" ) diff --git a/mne_pipeline_hd/pipeline/controller.py b/mne_pipeline_hd/pipeline/controller.py index 7c98e2e..dbd0997 100644 --- a/mne_pipeline_hd/pipeline/controller.py +++ b/mne_pipeline_hd/pipeline/controller.py @@ -23,7 +23,7 @@ from mne_pipeline_hd import functions, extra from mne_pipeline_hd.gui.gui_utils import get_user_input_string from mne_pipeline_hd.pipeline.legacy import transfer_file_params_to_single_subject -from mne_pipeline_hd.pipeline.pipeline_utils import QS +from mne_pipeline_hd.pipeline.pipeline_utils import QS, logger from mne_pipeline_hd.pipeline.project import Project home_dirs = ["custom_packages", "freesurfer", "projects"] @@ -50,10 +50,14 @@ def __init__(self, home_path=None, selected_project=None, edu_program_name=None) # Initialize log-file self.logging_path = join(self.home_path, "_pipeline.log") + file_handlers = [h for h in logger().handlers if h.name == "file"] + if len(file_handlers) > 0: + logger().removeHandler(file_handlers[0]) file_handler = logging.FileHandler(self.logging_path, "w") - logging.getLogger().addHandler(file_handler) + file_handler.set_name("file") + logger().addHandler(file_handler) - logging.info(f"Home-Path: {self.home_path}") + logger().info(f"Home-Path: {self.home_path}") QS().setValue("home_path", self.home_path) # Create subdirectories if not existing for a valid home_path for subdir in [d for d in home_dirs if not isdir(join(self.home_path, d))]: @@ -169,7 +173,7 @@ def save_settings(self): ) as file: json.dump(self.settings, file, indent=4) except FileNotFoundError: - logging.warning("Settings could not be saved!") + logger().warning("Settings could not be saved!") # Sync QSettings with other instances QS().sync() @@ -187,7 +191,7 @@ def change_project(self, new_project): self.settings["selected_project"] = new_project if new_project not in self.projects: self.projects.append(new_project) - logging.info(f"Selected-Project: {self.pr.name}") + logger().info(f"Selected-Project: {self.pr.name}") # Legacy transfer_file_params_to_single_subject(self) @@ -209,20 +213,58 @@ def remove_project(self, project): shutil.rmtree(join(self.projects_path, project)) except OSError as error: print(error) - logging.warning( + logger().warning( f"The folder of {project} can't be deleted " f"and has to be deleted manually!" ) + def rename_project(self): + check_writable = os.access(self.pr.project_path, os.W_OK) + if check_writable: + new_project_name = get_user_input_string( + f'Change the name of project "{self.pr.name}" to:', + "Rename Project", + force=False, + ) + if new_project_name is not None: + try: + old_name = self.pr.name + self.pr.rename(new_project_name) + except PermissionError: + # ToDo: Warning-Function for GUI with dialog and non-GUI + logger().critical( + f"Can't rename {old_name} to {new_project_name}. " + f"Probably a file from inside the project is still opened. " + f"Please close all files and try again." + ) + else: + self.projects.remove(old_name) + self.projects.append(new_project_name) + else: + logger().warning( + "The project-folder seems to be not writable at the moment, " + "maybe some files inside are still in use?" + ) + def copy_parameters_between_projects( - self, from_name, from_p_preset, to_name, to_p_preset + self, + from_name, + from_p_preset, + to_name, + to_p_preset, + parameter=None, ): from_project = Project(self, from_name) if to_name == self.pr.name: to_project = self.pr else: to_project = Project(self, to_name) - to_project.parameters[to_p_preset] = from_project.parameters[from_p_preset] + if parameter is not None: + from_param = from_project.parameters[from_p_preset][parameter] + to_project.parameters[to_p_preset][parameter] = from_param + else: + from_param = from_project.parameters[from_p_preset] + to_project.parameters[to_p_preset] = from_param to_project.save() def save(self, worker_signals=None): @@ -362,7 +404,7 @@ def import_custom_modules(self): else: missing_files = [key for key in file_dict if file_dict[key] is None] - logging.warning( + logger().warning( f"Files for import of {pkg_name} " f"are missing: {missing_files}" ) diff --git a/mne_pipeline_hd/pipeline/function_utils.py b/mne_pipeline_hd/pipeline/function_utils.py index 515f6be..e84e8fb 100644 --- a/mne_pipeline_hd/pipeline/function_utils.py +++ b/mne_pipeline_hd/pipeline/function_utils.py @@ -4,22 +4,24 @@ License: BSD 3-Clause Github: https://github.com/marsipu/mne-pipeline-hd """ +from __future__ import print_function +import gc import inspect import io -import logging import sys from collections import OrderedDict from importlib import import_module from multiprocessing import Pipe +from matplotlib import pyplot as plt from qtpy.QtCore import QThreadPool, QRunnable, Slot, QObject, Signal from qtpy.QtWidgets import QAbstractItemView from mne_pipeline_hd.gui.base_widgets import TimedMessageBox from mne_pipeline_hd.gui.gui_utils import get_exception_tuple, ExceptionTuple, Worker from mne_pipeline_hd.pipeline.loading import BaseLoading, FSMRI, Group, MEEG -from mne_pipeline_hd.pipeline.pipeline_utils import shutdown, ismac, QS +from mne_pipeline_hd.pipeline.pipeline_utils import shutdown, ismac, QS, logger def get_func(func_name, obj): @@ -98,10 +100,7 @@ def write(self, text): while self.manager.pipe_busy: pass self.manager.pipe_busy = True - if text[:1] == "\r": - kind = "progress" - else: - kind = self.kind + kind = self.kind self.pipe.send((text, kind)) self.manager.pipe_busy = False @@ -109,7 +108,6 @@ def write(self, text): class StreamRcvSignals(QObject): stdout_received = Signal(str) stderr_received = Signal(str) - progress_received = Signal(str) class StreamReceiver(QRunnable): @@ -126,12 +124,10 @@ def run(self): except EOFError: break else: - if kind == "stdout": - self.signals.stdout_received.emit(text) - elif kind == "stderr": + if kind == "stderr": self.signals.stderr_received.emit(text) else: - self.signals.progress_received.emit(text) + self.signals.stdout_received.emit(text) def run_func(func, keywargs, pipe=None): @@ -277,14 +273,14 @@ def process_finished(self, result): def finished(self): for name, func, error in self.errors: - logging.critical(f"Error in {name} <- {func}: {error}") + logger().critical(f"Error in {name} <- {func}: {error}") def prepare_start(self): # Take first step of all_steps until there are no steps left. if len(self.all_steps) > 0: # Getting information as encoded in init_lists self.current_obj_name, self.current_func = self.all_steps.pop(0) - logging.debug( + logger().debug( f"Running {self.current_func} for " f"{self.current_obj_name}" ) # Get current object @@ -312,7 +308,7 @@ def start(self): kwds = dict() kwds["func"] = get_func(self.current_func, self.current_object) kwds["keywargs"] = get_arguments(kwds["func"], self.current_object) - logging.info( + logger().info( f"########################################\n" f"Running {self.current_func} for {self.current_obj_name}\n" f"########################################\n" @@ -418,6 +414,9 @@ def finished(self): self.rd.restart_bt.setEnabled(True) self.rd.close_bt.setEnabled(True) + if not self.ct.get_setting("show_plots"): + close_all() + if self.ct.get_setting("shutdown"): self.ct.save() ans = TimedMessageBox.information( @@ -445,19 +444,19 @@ def start(self): or (ismpl and show_plots and use_qthread) or (ismpl and not show_plots and use_qthread and ismac) ): - logging.info("Starting in Main-Thread.") + logger().info("Starting in Main-Thread.") result = run_func(**kwds) self.process_finished(result) elif QS().value("use_qthread"): - logging.info("Starting in separate Thread.") + logger().info("Starting in separate Thread.") worker = Worker(function=run_func, **kwds) worker.signals.error.connect(self.process_finished) worker.signals.finished.connect(self.process_finished) QThreadPool.globalInstance().start(worker) else: - logging.info("Starting in process from multiprocessing.") + logger().info("Starting in process from multiprocessing.") recv_pipe, send_pipe = Pipe(False) kwds["pipe"] = send_pipe stream_rcv = StreamReceiver(recv_pipe) @@ -467,11 +466,13 @@ def start(self): stream_rcv.signals.stderr_received.connect( self.rd.console_widget.write_stderr ) - stream_rcv.signals.progress_received.connect( - self.rd.console_widget.write_progress - ) QThreadPool.globalInstance().start(stream_rcv) # ToDO: MP self.pool.apply_async( func=run_func, kwds=kwds, callback=self.process_finished ) + + +def close_all(): + plt.close("all") + gc.collect() diff --git a/mne_pipeline_hd/pipeline/legacy.py b/mne_pipeline_hd/pipeline/legacy.py index 7c2cafb..fcaf6c1 100644 --- a/mne_pipeline_hd/pipeline/legacy.py +++ b/mne_pipeline_hd/pipeline/legacy.py @@ -5,14 +5,13 @@ Github: https://github.com/marsipu/mne-pipeline-hd """ import json -import logging import os import subprocess import sys from os.path import isdir, join, isfile from mne_pipeline_hd.pipeline.loading import MEEG, FSMRI, Group -from mne_pipeline_hd.pipeline.pipeline_utils import type_json_hook +from mne_pipeline_hd.pipeline.pipeline_utils import type_json_hook, logger renamed_parameters = { "filter_target": {"Raw": "raw", "Epochs": "epochs", "Evoked": "evoked"}, @@ -51,7 +50,7 @@ def install_package(package_name): - logging.info(f"Installing {package_name}...") + logger().info(f"Installing {package_name}...") print( subprocess.check_output( [sys.executable, "-m", "pip", "install", package_name], text=True @@ -60,7 +59,7 @@ def install_package(package_name): def uninstall_package(package_name): - logging.info(f"Uninstalling {package_name}...") + logger().info(f"Uninstalling {package_name}...") print( subprocess.check_output( [sys.executable, "-m", "pip", "uninstall", "-y", package_name], text=True @@ -81,7 +80,7 @@ def legacy_import_check(test_package=None): try: __import__(import_name) except ImportError: - logging.info( + logger().info( f"The package {import_name} " f"is required for this application.\n" ) ans = input("Do you want to install the " "new package now? [y/n]").lower() @@ -89,10 +88,10 @@ def legacy_import_check(test_package=None): try: install_package(install_name) except subprocess.CalledProcessError: - logging.critical("Installation failed!") + logger().critical("Installation failed!") else: return - logging.info( + logger().info( f"Please install the new package {import_name} " f"manually with:\n\n" f"> pip install {install_name}" @@ -103,7 +102,7 @@ def legacy_import_check(test_package=None): def transfer_file_params_to_single_subject(ct): old_fp_path = join(ct.pr.pscripts_path, f"file_parameters_{ct.pr.name}.json") if isfile(old_fp_path): - logging.info("Transfering File-Parameters to single files...") + logger().info("Transfering File-Parameters to single files...") with open(old_fp_path, "r") as file: file_parameters = json.load(file, object_hook=type_json_hook) for obj_name in file_parameters: @@ -122,4 +121,4 @@ def transfer_file_params_to_single_subject(ct): obj.save_file_parameter_file() obj.clean_file_parameters() os.remove(old_fp_path) - logging.info("Done!") + logger().info("Done!") diff --git a/mne_pipeline_hd/pipeline/loading.py b/mne_pipeline_hd/pipeline/loading.py index cc8808c..17e91d7 100644 --- a/mne_pipeline_hd/pipeline/loading.py +++ b/mne_pipeline_hd/pipeline/loading.py @@ -11,7 +11,6 @@ import inspect import itertools import json -import logging import os import pickle import shutil @@ -22,6 +21,7 @@ import matplotlib.pyplot as plt import mne +import mne_connectivity import numpy as np from tqdm import tqdm @@ -30,6 +30,7 @@ type_json_hook, QS, _test_run, + logger, ) @@ -53,7 +54,7 @@ def load_decorator(load_func): def load_wrapper(self, *args, **kwargs): # Get matching data-type from IO-Dict data_type = _get_data_type_from_func(self, load_func, "load") - logging.info(f"Loading {data_type} for {self.name}") + logger().info(f"Loading {data_type} for {self.name}") if data_type in self.data_dict: data = self.data_dict[data_type] @@ -73,7 +74,7 @@ def load_wrapper(self, *args, **kwargs): self.io_dict[data_type]["path"] = dp data = load_func(self, *args, **kwargs) self.io_dict[data_type]["path"] = new_path - logging.info( + logger().info( f"Deprecated path: Saving file for " f"{data_type} in updated path..." ) @@ -85,7 +86,7 @@ def load_wrapper(self, *args, **kwargs): os.remove(dp) elif self.p_preset != "Default": - logging.info( + logger().info( f"No File for {data_type} from {self.name}" f" with Parameter-Preset={self.p_preset} found," f" trying Default" @@ -130,7 +131,7 @@ def save_wrapper(self, *args, **kwargs): for path in [p for p in paths if not isdir(Path(p).parent)]: makedirs(Path(path).parent, exist_ok=True) - logging.info(f"Saving {data_type} for {self.name}") + logger().info(f"Saving {data_type} for {self.name}") save_func(self, *args, **kwargs) # Save data in data-dict for machines with big RAM @@ -320,7 +321,7 @@ def clean_file_parameters(self): for file_name in remove_files: self.file_parameters.pop(file_name) - logging.info( + logger().info( f"Removed {len(remove_files)} Files " f"and {n_remove_params} Parameters." ) self.save_file_parameter_file() @@ -423,7 +424,7 @@ def plot_save( ) idx_file_path = join(dir_path, idx_file_name) figure.savefig(idx_file_path) - logging.info(f"figure: {idx_file_path} has been saved") + logger().info(f"figure: {idx_file_path} has been saved") # Only store relative path to be compatible across OS plot_files_save_path = os.path.relpath( idx_file_path, self.figures_path @@ -438,7 +439,7 @@ def plot_save( if self.img_format != ".svg": file_name = file_name.strip(self.img_format) + ".svg" save_path = join(dir_path, file_name) - logging.info("Pyvista-Plots are saved as .svg") + logger().info("Pyvista-Plots are saved as .svg") pyvista_figure.plotter.save_graphics(save_path, title=file_name) elif brain: if brain_movie_kwargs is not None: @@ -451,7 +452,7 @@ def plot_save( brain.save_image(save_path) else: plt.savefig(save_path, dpi=dpi) - logging.info(f"figure: {save_path} has been saved") + logger().info(f"figure: {save_path} has been saved") if not isinstance(matplotlib_figure, list): # Only store relative path to be compatible across OS @@ -460,7 +461,7 @@ def plot_save( if plot_files_save_path not in self.plot_files[calling_func]: self.plot_files[calling_func].append(plot_files_save_path) else: - logging.info('Not saving plots; set "save_plots" to "True" to save') + logger().info('Not saving plots; set "save_plots" to "True" to save') # ToDo: Should have load-decorator! def load(self, data_type, **kwargs): @@ -482,10 +483,10 @@ def load_json(self, file_name, default=None): with open(file_path, "r") as file: data = json.load(file, object_hook=type_json_hook) except json.JSONDecodeError: - logging.warning(f"{file_path} could not be loaded") + logger().warning(f"{file_path} could not be loaded") data = default except FileNotFoundError: - logging.warning(f"{file_path} could not be found") + logger().warning(f"{file_path} could not be found") data = default return data @@ -499,7 +500,7 @@ def save_json(self, file_name, data): with open(file_path, "w") as file: json.dump(data, file, cls=TypedJSONEncoder, indent=4) except json.JSONDecodeError: - logging.warning(f"{file_path} could not be saved") + logger().warning(f"{file_path} could not be saved") self.save_file_params(file_path) @@ -508,11 +509,11 @@ def remove_json(self, file_name): try: os.remove(file_path) except FileNotFoundError: - logging.warning(f"{file_path} was not found") + logger().warning(f"{file_path} was not found") except OSError as err: - logging.warning(f"{file_path} could not be removed due to {err}") + logger().warning(f"{file_path} could not be removed due to {err}") else: - logging.warning(f"{file_path} was removed") + logger().warning(f"{file_path} was removed") def get_existing_paths(self): """Get existing paths and add the mapped File-Type @@ -550,7 +551,7 @@ def remove_path(self, data_type): for pn in [p_name_lh, p_name_rh]: self.file_parameters.pop(pn) except KeyError: - logging.warning(f"{Path(p).name} not in file-parameters") + logger().warning(f"{Path(p).name} not in file-parameters") try: os.remove(p) except FileNotFoundError: @@ -561,16 +562,16 @@ def remove_path(self, data_type): for ps in [p_lh, p_rh]: os.remove(ps) except FileNotFoundError: - logging.warning(f"{p} was not found") + logger().warning(f"{p} was not found") except IsADirectoryError: try: shutil.rmtree(p) except OSError as err: - logging.warning(f"{p} could not be removed due to {err}") + logger().warning(f"{p} could not be removed due to {err}") except OSError as err: - logging.warning(f"{p} could not be removed due to {err}") + logger().warning(f"{p} could not be removed due to {err}") else: - logging.warning(f"{p} was removed") + logger().warning(f"{p} was removed") sample_paths = { @@ -605,7 +606,7 @@ def init_attributes(self): if self.name not in self.pr.meeg_to_erm: self.erm = None if not self.suppress_warnings: - logging.warning( + logger().warning( f"No Empty-Room-Measurement assigned for {self.name}," f' defaulting to "None"' ) @@ -625,7 +626,7 @@ def init_attributes(self): else: self.fsmri = FSMRI(None, self.ct) if not self.suppress_warnings: - logging.warning( + logger().warning( f"No Freesurfer-MRI-Subject assigned for {self.name}," f' defaulting to "None"' ) @@ -634,7 +635,7 @@ def init_attributes(self): if self.name not in self.pr.meeg_bad_channels: self.bad_channels = list() if not self.suppress_warnings: - logging.warning( + logger().warning( f"No bad channels assigned for {self.name}," f" defaulting to empty list" ) @@ -643,19 +644,23 @@ def init_attributes(self): # The selected trials from the event-id if self.name not in self.pr.sel_event_id: - self.sel_trials = list() + self.sel_trials = dict() if not self.suppress_warnings: - logging.warning( + logger().warning( f"No Trials selected for {self.name}," f" defaulting to empty list" ) else: self.sel_trials = self.pr.sel_event_id[self.name] + # Legacy for before when sel_event_id was a list + if isinstance(self.sel_trials, list): + self.sel_trials = {k: None for k in self.sel_trials} + self.pr.sel_event_id[self.name] = self.sel_trials # The assigned event-id if self.name not in self.pr.meeg_event_id: self.event_id = dict() if not self.suppress_warnings: - logging.warning( + logger().warning( f"No EventID assigned for {self.name}," f" defaulting to empty dictionary" ) @@ -787,7 +792,7 @@ def init_paths(self): con_method: join( self.save_dir, "connectivity", - f"{self.name}_{trial}_{self.p_preset}_{con_method}-con.npy", + f"{self.name}_{trial}_{self.p_preset}_{con_method}-con.nc", ) for con_method in self.pa["con_methods"] } @@ -949,19 +954,26 @@ def init_sample(self): self.fsmri = FSMRI("fsaverage", self.ct) # Add event_id - self.event_id = { - "auditory/left": 1, - "auditory/right": 2, - "visual/left": 3, - "visual/right": 4, - "face": 5, - "buttonpress": 32, - } - self.pr.meeg_event_id[self.name] = self.event_id + if self.name not in self.pr.meeg_event_id: + self.event_id = { + "auditory/left": 1, + "auditory/right": 2, + "visual/left": 3, + "visual/right": 4, + "face": 5, + "buttonpress": 32, + } + self.pr.meeg_event_id[self.name] = self.event_id + else: + self.event_id = self.pr.meeg_event_id[self.name] + # ToDo: Here is problem, since there is no way # to select "auditory/left" from the gui. - self.sel_trials = ["auditory"] - self.pr.sel_event_id[self.name] = self.sel_trials + if self.name not in self.pr.sel_event_id: + self.sel_trials = {"auditory": None} + self.pr.sel_event_id[self.name] = self.sel_trials + else: + self.sel_trials = self.pr.sel_event_id[self.name] # init paths again self.init_paths() @@ -974,18 +986,18 @@ def init_sample(self): test_file_path = join(test_data_folder, test_file_name) file_path = self.io_dict[data_type]["path"] if data_type == "stcs": - file_path = file_path[self.sel_trials[0]] + file_path = file_path["auditory"] if not isfile(file_path + "-lh.stc"): - logging.debug(f"Copying {data_type} from sample-dataset...") + logger().debug(f"Copying {data_type} from sample-dataset...") stcs = mne.source_estimate.read_source_estimate(test_file_path) stcs.save(file_path) elif isfile(test_file_path) and not isfile(file_path): - logging.debug(f"Copying {data_type} from sample-dataset...") + logger().debug(f"Copying {data_type} from sample-dataset...") folder = Path(file_path).parent if not isdir(folder): os.mkdir(folder) shutil.copy2(test_file_path, file_path) - logging.debug("Done!") + logger().debug("Done!") # Add bad_channels self.bad_channels = self.load_info()["bads"] @@ -1054,7 +1066,6 @@ def save_raw(self, raw): @load_decorator def load_filtered(self): raw = mne.io.read_raw_fif(self.raw_filtered_path, preload=True) - raw.info["bads"] = [bc for bc in self.bad_channels if bc in raw.ch_names] return raw @save_decorator @@ -1098,6 +1109,19 @@ def load_epochs(self): def save_epochs(self, epochs): epochs.save(self.epochs_path, overwrite=True) + def get_trial_epochs(self): + """Return epochs for each trial in self.sel_trials""" + epochs = self.load_epochs() + for trial, meta_query in self.sel_trials.items(): + epoch_trial = meta_query or trial + # ToDo: Make this optional (at own risk) and not for normal trials + try: + epoch_trial = eval(epoch_trial) + except (NameError, SyntaxError, ValueError, TypeError): + pass + + yield trial, epochs[epoch_trial] + @load_decorator def load_reject_log(self): with open(self.reject_log_path, "rb") as file: @@ -1296,7 +1320,7 @@ def load_mixn_dipoles(self): dip_list.append(mne.read_dipole(mixn_dip_path)) idx += 1 mixn_dips[trial] = dip_list - logging.info(f"{idx + 1} dipoles read for {self.name}-{trial}") + logger().info(f"{idx + 1} dipoles read for {self.name}-{trial}") return mixn_dips @@ -1381,22 +1405,22 @@ def save_ltc(self, ltcs): @load_decorator def load_connectivity(self): - con_dict = {"__info__": self.load_json("con_labels")} + con_dict = dict() for trial in self.con_paths: con_dict[trial] = dict() - for con_method in self.con_paths[trial]: - con_dict[trial][con_method] = np.load(self.con_paths[trial][con_method]) + for con_method, con_path in self.con_paths[trial].items(): + con_dict[trial][con_method] = mne_connectivity.read_connectivity( + con_path + ) return con_dict @save_decorator def save_connectivity(self, con_dict): # Write info about label and parcellation into json - con_info = con_dict.pop("__info__") - self.save_json("con_labels", con_info) for trial in con_dict: - for con_method in con_dict[trial]: - np.save(self.con_paths[trial][con_method], con_dict[trial][con_method]) + for con_method, con in con_dict[trial].items(): + con.save(self.con_paths[trial][con_method]) fsaverage_paths = { @@ -1481,7 +1505,7 @@ def init_fsaverage(self): # so fsaverage will be downloaded to "~/mne_data/MNE-fsaverage-data" if _test_run(): mne.set_config("SUBJECTS_DIR", None) - logging.info("Downloading fsaverage...") + logger().info("Downloading fsaverage...") fsaverage_dir = mne.datasets.fetch_fsaverage(subjects_dir=None) if _test_run(): mne.set_config("SUBJECTS_DIR", self.ct.subjects_dir) @@ -1495,15 +1519,15 @@ def init_fsaverage(self): to_path = self.io_dict[data_type]["path"] if not isfile(to_path): os.rename(from_path, to_path) - logging.info(f"Renamed {from_path} to {to_path}") + logger().info(f"Renamed {from_path} to {to_path}") def _get_available_parc(self): annot_dir = join(self.subjects_dir, self.name, "label") try: files = os.listdir(annot_dir) - annotations = [file[3:-6] for file in files if file[-6:] == ".annot"] + annotations = set([file[3:-6] for file in files if file[-6:] == ".annot"]) except FileNotFoundError: - annotations = list() + annotations = set() return annotations @@ -1521,10 +1545,10 @@ def _get_available_labels(self): try: label = mne.read_label(join(label_dir, label_path), self.name) except ValueError: - logging.warning(f"Label {label_path} could not be loaded!") + logger().warning(f"Label {label_path} could not be loaded!") labels["Other"].append(label) except FileNotFoundError: - logging.warning(f"No label directory found for {self.name}!") + logger().warning(f"No label directory found for {self.name}!") if self.parcellations is None: self.parcellations = self._get_available_parc() @@ -1540,24 +1564,36 @@ def _get_available_labels(self): verbose="warning", ) except (RuntimeError, OSError): - logging.warning(f"Parcellation {parcellation} could not be loaded!") + logger().warning(f"Parcellation {parcellation} could not be loaded!") return labels def get_labels(self, target_labels=None, parcellation=None): labels = list() if self.name is None: - logging.warning("FSMRI-Object has no name and is empty!") + logger().warning("FSMRI-Object has no name and is empty!") else: - parcellation = parcellation or "Other" + # Get available parcellations if self.labels is None: self.labels = self._get_available_labels() + + # Subselect labels with parcellation + if parcellation is None: + search_labels = list() + for parcellation in self.labels: + search_labels += self.labels[parcellation] + else: + if parcellation in self.labels: + search_labels = self.labels[parcellation] + else: + raise RuntimeError( + f"Parcellation '{parcellation}' not found for {self.name}!" + ) + if target_labels is not None: - labels += [ - lb for lb in self.labels[parcellation] if lb.name in target_labels - ] + labels += [lb for lb in search_labels if lb.name in target_labels] else: - labels = self.labels[parcellation] + labels = search_labels return labels @@ -1611,7 +1647,7 @@ def init_attributes(self): if self.name not in self.pr.all_groups: self.group_list = [] if not self.suppress_warnings: - logging.warning( + logger().warning( f"No objects assigned for {self.name}," f" defaulting to empty list" ) else: @@ -1624,12 +1660,11 @@ def init_attributes(self): ]: self.event_id = {**self.event_id, **self.ct.pr.meeg_event_id[group_item]} - # The selected trials from the event-id - self.sel_trials = set() - for group_item in [ - gi for gi in self.group_list if gi in self.ct.pr.sel_event_id - ]: - self.sel_trials = self.sel_trials | set(self.ct.pr.sel_event_id[group_item]) + # The selected trials from the event-id (assume first to allow meta-queries) + if len(self.group_list) > 0: + self.sel_trials = MEEG(self.group_list[0], self.ct).sel_trials + else: + self.sel_trials = dict() # The fsmri where all group members are morphed to self.fsmri = FSMRI(self.pa["morph_to"], self.ct) @@ -1681,7 +1716,7 @@ def init_paths(self): con_method: join( self.save_dir, "connectivity", - f"{self.name}_{trial}_" f"{self.p_preset}_{con_method}.npy", + f"{self.name}_{trial}_" f"{self.p_preset}_{con_method}.nc", ) for con_method in self.pa["con_methods"] } @@ -1731,7 +1766,7 @@ def load_items(self, obj_type="MEEG", data_type=None): elif obj_type == "FSMRI": obj = FSMRI(obj_name, self.ct) else: - logging.error(f"The object-type {obj_type} is not valid!") + logger().error(f"The object-type {obj_type} is not valid!") continue if data_type is None: yield obj @@ -1739,7 +1774,7 @@ def load_items(self, obj_type="MEEG", data_type=None): data = obj.io_dict[data_type]["load"]() yield data, obj else: - logging.error(f"{data_type} is not valid for {obj_type}") + logger().error(f"{data_type} is not valid for {obj_type}") @load_decorator def load_ga_evokeds(self): @@ -1800,20 +1835,18 @@ def save_ga_ltc(self, ga_ltc): @load_decorator def load_ga_con(self): - ga_connect = {"__info__": self.load_json("con_labels")} + ga_connect = dict() for trial in self.ga_con_paths: ga_connect[trial] = {} - for con_method in self.ga_con_paths[trial]: - ga_connect[trial][con_method] = np.load( - self.ga_con_paths[trial][con_method] + for con_method, con_path in self.ga_con_paths[trial].items(): + ga_connect[trial][con_method] = mne_connectivity.read_connectivity( + con_path ) return ga_connect @save_decorator - def save_ga_con(self, ga_con): - con_info = ga_con.pop("__info__") - self.save_json("con_labels", con_info) - for trial in ga_con: - for con_method in ga_con[trial]: - np.save(self.ga_con_paths[trial][con_method], ga_con[trial][con_method]) + def save_ga_con(self, ga_con_dict): + for trial in ga_con_dict: + for con_method, ga_con in ga_con_dict[trial].items(): + ga_con.save(self.ga_con_paths[trial][con_method]) diff --git a/mne_pipeline_hd/pipeline/pipeline_utils.py b/mne_pipeline_hd/pipeline/pipeline_utils.py index b9c20d8..8b3ca3b 100644 --- a/mne_pipeline_hd/pipeline/pipeline_utils.py +++ b/mne_pipeline_hd/pipeline/pipeline_utils.py @@ -21,7 +21,7 @@ import numpy as np import psutil -import mne_pipeline_hd +from mne_pipeline_hd import extra datetime_format = "%d.%m.%Y %H:%M:%S" @@ -29,6 +29,34 @@ iswin = sys.platform.startswith("win32") islin = not ismac and not iswin +_logger = None + + +def init_logging(debug_mode=False): + global _logger + # Initialize Logger + _logger = logging.getLogger("mne_pipeline_hd") + if debug_mode: + _logger.setLevel(logging.DEBUG) + else: + _logger.setLevel(QS().value("log_level", defaultValue=logging.INFO)) + if debug_mode: + fmt = "[%(levelname)s] %(module)s.%(funcName)s(): %(message)s" + else: + fmt = "[%(levelname)s] %(message)s" + formatter = logging.Formatter(fmt) + console_handler = logging.StreamHandler() + console_handler.set_name("console") + console_handler.setFormatter(formatter) + _logger.addHandler(console_handler) + + +def logger(): + global _logger + if _logger is None: + _logger = logging.getLogger("mne_pipeline_hd") + return _logger + def get_n_jobs(n_jobs): """Get the number of jobs to use for parallel processing""" @@ -137,12 +165,12 @@ def compare_filep(obj, path, target_parameters=None, verbose=True): if str(previous_value) == str(current_value): result_dict[param] = "equal" if verbose: - logging.debug(f"{param} equal for {file_name}") + logger().debug(f"{param} equal for {file_name}") else: if param in critical_params: result_dict[param] = (previous_value, current_value, True) if verbose: - logging.debug( + logger().debug( f"{param} changed from {previous_value} to " f"{current_value} for {file_name} " f"and is probably crucial for {function}" @@ -150,19 +178,19 @@ def compare_filep(obj, path, target_parameters=None, verbose=True): else: result_dict[param] = (previous_value, current_value, False) if verbose: - logging.debug( + logger().debug( f"{param} changed from {previous_value} to " f"{current_value} for {file_name}" ) except KeyError: result_dict[param] = "missing" if verbose: - logging.warning(f"{param} is missing in records for {file_name}") + logger().warning(f"{param} is missing in records for {file_name}") if obj.ct.settings["overwrite"]: result_dict[param] = "overwrite" if verbose: - logging.info( + logger().info( f"{file_name} will be overwritten anyway" f" because Overwrite=True (Settings)" ) @@ -210,13 +238,13 @@ def shutdown(): def restart_program(): """Restarts the current program, with file objects and descriptors cleanup.""" - logging.info("Restarting") + logger().info("Restarting") try: p = psutil.Process(os.getpid()) for handler in p.open_files() + p.connections(): os.close(handler.fd) except Exception as e: - logging.error(e) + logger().error(e) python = sys.executable os.execl(python, python, *sys.argv) @@ -234,9 +262,7 @@ def _get_func_param_kwargs(func, params): class BaseSettings: def __init__(self): # Load default settings - default_settings_path = join( - resources.files(mne_pipeline_hd.extra), "default_settings.json" - ) + default_settings_path = join(resources.files(extra), "default_settings.json") with open(default_settings_path, "r") as file: self.default_qsettings = json.load(file)["qsettings"] diff --git a/mne_pipeline_hd/pipeline/project.py b/mne_pipeline_hd/pipeline/project.py index 69ff8e6..8f4a9cd 100644 --- a/mne_pipeline_hd/pipeline/project.py +++ b/mne_pipeline_hd/pipeline/project.py @@ -6,7 +6,6 @@ """ import json -import logging import os import shutil from ast import literal_eval @@ -25,6 +24,7 @@ count_dict_keys, encode_tuples, type_json_hook, + logger, ) @@ -72,7 +72,7 @@ def init_main_paths(self): for path in self.main_paths: if not exists(path): makedirs(path) - logging.debug(f"{path} created") + logger().debug(f"{path} created") def init_attributes(self): # Stores the names of all MEG/EEG-Files @@ -181,6 +181,23 @@ def init_pipeline_scripts(self): self.sel_p_preset_path: "p_preset", } + def rename(self, new_name): + # Rename folder + old_name = self.name + os.rename(self.project_path, join(self.ct.projects_path, new_name)) + self.name = new_name + self.init_main_paths() + # Rename project-files + old_paths = [Path(p).name for p in self.path_to_attribute] + self.init_pipeline_scripts() + new_paths = [Path(p).name for p in self.path_to_attribute] + for old_path, new_path in zip(old_paths, new_paths): + os.rename( + join(self.pscripts_path, old_path), join(self.pscripts_path, new_path) + ) + logger().info(f"Renamed project-script {old_path} to {new_path}") + logger().info(f'Finished renaming project "{old_name}" to "{new_name}"') + def load_lists(self): # Old Paths to allow transition (22.11.2020) self.old_all_meeg_path = join(self.pscripts_path, "file_list.json") @@ -364,7 +381,7 @@ def save(self, worker_signals=None): json.dump(attribute, file, cls=TypedJSONEncoder, indent=4) except json.JSONDecodeError as err: - logging.warning(f"There is a problem with path:\n" f"{err}") + logger().warning(f"There is a problem with path:\n" f"{err}") def add_meeg(self, name, file_path=None, is_erm=False): if is_erm: @@ -395,7 +412,7 @@ def remove_meeg(self, remove_files): # Remove MEEG from Lists/Dictionaries self.all_meeg.remove(meeg) except ValueError: - logging.warning(f"{meeg} already removed!") + logger().warning(f"{meeg} already removed!") self.meeg_to_erm.pop(meeg, None) self.meeg_to_fsmri.pop(meeg, None) self.meeg_bad_channels.pop(meeg, None) @@ -404,9 +421,9 @@ def remove_meeg(self, remove_files): try: remove_path = join(self.data_path, meeg) shutil.rmtree(remove_path) - logging.info(f"Succesful removed {remove_path}") + logger().info(f"Succesful removed {remove_path}") except FileNotFoundError: - logging.critical(join(self.data_path, meeg) + " not found!") + logger().critical(join(self.data_path, meeg) + " not found!") self.sel_meeg.clear() def add_fsmri(self, name, src_dir=None): @@ -416,15 +433,15 @@ def add_fsmri(self, name, src_dir=None): if src_dir is not None: dst_dir = join(self.ct.subjects_dir, name) if not isdir(dst_dir): - logging.debug(f"Copying Folder from {src_dir}...") + logger().debug(f"Copying Folder from {src_dir}...") try: shutil.copytree(src_dir, dst_dir) # surfaces with .H and .K at the end can't be copied except shutil.Error: pass - logging.debug(f"Finished Copying to {dst_dir}") + logger().debug(f"Finished Copying to {dst_dir}") else: - logging.info(f"{dst_dir} already exists") + logger().info(f"{dst_dir} already exists") return fsmri @@ -433,12 +450,12 @@ def remove_fsmri(self, remove_files): try: self.all_fsmri.remove(fsmri) except ValueError: - logging.warning(f"{fsmri} already deleted!") + logger().warning(f"{fsmri} already deleted!") if remove_files: try: shutil.rmtree(join(self.ct.subjects_dir, fsmri)) except FileNotFoundError: - logging.info(join(self.ct.subjects_dir, fsmri) + " not found!") + logger().info(join(self.ct.subjects_dir, fsmri) + " not found!") self.sel_fsmri.clear() def add_group(self): @@ -488,7 +505,7 @@ def clean_file_parameters(self, worker_signals=None): if worker_signals is not None: worker_signals.pgbar_n.emit(count) if worker_signals.was_canceled: - logging.info("Cleaning was canceled by the user!") + logger().info("Cleaning was canceled by the user!") return for fsmri in self.all_fsmri: @@ -499,7 +516,7 @@ def clean_file_parameters(self, worker_signals=None): if worker_signals is not None: worker_signals.pgbar_n.emit(count) if worker_signals.was_canceled: - logging.info("Cleaning was canceled by the user!") + logger().info("Cleaning was canceled by the user!") return for group in self.all_groups: @@ -510,7 +527,7 @@ def clean_file_parameters(self, worker_signals=None): if worker_signals is not None: worker_signals.pgbar_n.emit(count) if worker_signals.was_canceled: - logging.info("Cleaning was canceled by the user!") + logger().info("Cleaning was canceled by the user!") return def clean_plot_files(self, worker_signals=None): @@ -561,7 +578,7 @@ def clean_plot_files(self, worker_signals=None): worker_signals is not None and worker_signals.was_canceled ): - logging.info("Cleaning was canceled by user") + logger().info("Cleaning was canceled by user") return if func not in self.ct.pd_funcs.index: @@ -596,18 +613,18 @@ def clean_plot_files(self, worker_signals=None): self.plot_files[obj_key].pop(remove_preset_key) n_remove_ppreset += len(remove_p_preset) - logging.info( + logger().info( f"Removed {n_remove_ppreset} Parameter-Presets and " f"{n_remove_funcs} from {obj_key}" ) for remove_key in remove_obj: self.plot_files.pop(remove_key) - logging.info(f"Removed {len(remove_obj)} Objects from Plot-Files") + logger().info(f"Removed {len(remove_obj)} Objects from Plot-Files") # Remove image-files, which aren't listed in plot_files. free_space = 0 - logging.info("Removing unregistered images...") + logger().info("Removing unregistered images...") n_removed_images = 0 for root, _, files in os.walk(self.figures_path): files = [join(root, f) for f in files] @@ -617,10 +634,10 @@ def clean_plot_files(self, worker_signals=None): free_space += getsize(file_path) n_removed_images += 1 os.remove(file_path) - logging.info(f"Removed {n_removed_images} images") + logger().info(f"Removed {n_removed_images} images") # Remove empty folders (loop until all empty folders are removed) - logging.info("Removing empty folders...") + logger().info("Removing empty folders...") n_removed_folders = 0 folder_loop = True # Redo the file-walk because folders can get empty @@ -633,6 +650,6 @@ def clean_plot_files(self, worker_signals=None): os.rmdir(folder) n_removed_folders += 1 folder_loop = True - logging.info(f"Removed {n_removed_folders} folders") + logger().info(f"Removed {n_removed_folders} folders") - logging.info(f"{round(free_space / (1024 ** 2), 2)} MB of space was freed!") + logger().info(f"{round(free_space / (1024 ** 2), 2)} MB of space was freed!") diff --git a/mne_pipeline_hd/tests/test_concurrent.py b/mne_pipeline_hd/tests/test_concurrent.py index 5a5e76f..84099b3 100644 --- a/mne_pipeline_hd/tests/test_concurrent.py +++ b/mne_pipeline_hd/tests/test_concurrent.py @@ -8,14 +8,14 @@ # def test_blocking_worker_dialog(qtbot): # def _test_func(): # time.sleep(2) -# logging.info('Finished Test-Func') +# logger().info('Finished Test-Func') # # time1 = time.time() # dlg = WorkerDialog(None, _test_func, blocking=True) # qtbot.addWidget(dlg) # time2 = time.time() # -# logging.info(f'Worker-Dialog took {round(time2 - time1, 2)} s') +# logger().info(f'Worker-Dialog took {round(time2 - time1, 2)} s') # assert time2 - time1 >= 2 # def test_qprocess_worker(qtbot): diff --git a/mne_pipeline_hd/tests/test_console.py b/mne_pipeline_hd/tests/test_console.py new file mode 100644 index 0000000..e7f1329 --- /dev/null +++ b/mne_pipeline_hd/tests/test_console.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +""" +Authors: Martin Schulz +License: BSD 3-Clause +Github: https://github.com/marsipu/mne-pipeline-hd +""" +import os + +import pytest + +from mne_pipeline_hd.__main__ import init_streams +from mne_pipeline_hd.gui.gui_utils import MainConsoleWidget +from mne_pipeline_hd.pipeline.pipeline_utils import logger, init_logging + + +def test_logging(qtbot): + """Test streaming and logging to GUI-Console.""" + # Enable debugging + os.environ["MNEPHD_DEBUG"] = "true" + + init_streams() + init_logging() + + console = MainConsoleWidget() + qtbot.addWidget(console) + + wait_time = console.buffer_time * 2 + + print("Print-Test") + qtbot.wait(wait_time) + assert "Print-Test" in console.toPlainText() + + with pytest.raises(RuntimeError): + raise RuntimeError("Test-Error") + qtbot.wait(wait_time) + assert "Test-Error" in console.toPlainText() + + logger().info("Logging-Test") + qtbot.wait(wait_time) + assert "[INFO] Logging-Test" in console.toPlainText() diff --git a/mne_pipeline_hd/tests/test_loading.py b/mne_pipeline_hd/tests/test_loading.py index d690804..055af59 100644 --- a/mne_pipeline_hd/tests/test_loading.py +++ b/mne_pipeline_hd/tests/test_loading.py @@ -4,14 +4,15 @@ License: BSD 3-Clause Github: https://github.com/marsipu/mne-pipeline-hd """ -import logging import pytest +from mne_pipeline_hd.pipeline.pipeline_utils import logger + def _test_load_save(obj, available_test_paths, excepted_data_types=[]): for data_type in [d for d in obj.io_dict if d not in excepted_data_types]: - logging.info(f"Testing {data_type}") + logger().info(f"Testing {data_type}") if data_type not in available_test_paths: with pytest.raises((OSError, FileNotFoundError)): obj.load(data_type) diff --git a/mne_pipeline_hd/tests/test_parameter_widgets.py b/mne_pipeline_hd/tests/test_parameter_widgets.py index aeac262..322c097 100644 --- a/mne_pipeline_hd/tests/test_parameter_widgets.py +++ b/mne_pipeline_hd/tests/test_parameter_widgets.py @@ -138,27 +138,26 @@ def test_basic_param_guis(qtbot, gui_name): # Test MultiTypeGui if gui_name == "MultiTypeGui": - for gui_type, gui_name in gui.gui_types.items(): - gui.set_param(parameters[gui_name]) - assert gui.get_value() == parameters[gui_name] + for gui_type, type_gui_name in gui.gui_types.items(): + gui.set_param(parameters[type_gui_name]) + assert gui.get_value() == parameters[type_gui_name] assert type(gui.get_value()).__name__ == gui_type kwargs["type_selection"] = True kwargs["type_kwargs"] = dict() - for gui_name in gui.gui_types.values(): - type_class = getattr(parameter_widgets, gui_name) + for type_gui_name in gui.gui_types.values(): + type_class = getattr(parameter_widgets, type_gui_name) gui_parameters = list(inspect.signature(type_class).parameters) + list( inspect.signature(Param).parameters ) t_kwargs = { key: value for key, value in gui_kwargs.items() if key in gui_parameters } - kwargs["type_kwargs"][gui_name] = t_kwargs + kwargs["type_kwargs"][type_gui_name] = t_kwargs gui = gui_class(data=parameters, name=gui_name, **kwargs) - for gui_type, gui_name in gui.gui_types.items(): - type_idx = gui.types.index(gui_type) + for type_idx, (gui_type, type_gui_name) in enumerate(gui.gui_types.items()): gui.change_type(type_idx) - gui.set_param(parameters[gui_name]) - assert gui.get_value() == parameters[gui_name] + gui.set_param(parameters[type_gui_name]) + assert gui.get_value() == parameters[type_gui_name] assert type(gui.get_value()).__name__ == gui_type @@ -184,13 +183,16 @@ def test_label_gui(qtbot, controller): label_gui.param_widget.click() dlg = label_gui._dialog - # Test start labels checked + # Test start labels in checked assert ["insula-lh", "postcentral-lh"] == dlg._selected_parc_labels assert "lh.BA1-lh" in dlg._selected_extra_labels # Open Parc-Picker dlg.choose_parc_bt.click() parc_plot = dlg._parc_picker._renderer.plotter + # Select "aparc" parcellation + dlg.parcellation_cmbx.setCurrentText("aparc") + dlg._parc_changed() # Only triggered by mouse click with .activated # Check if start labels are shown assert "insula-lh" in dlg._parc_picker._shown_labels assert "postcentral-lh" in dlg._parc_picker._shown_labels @@ -220,17 +222,23 @@ def test_label_gui(qtbot, controller): # Change parcellation dlg.parcellation_cmbx.setCurrentText("aparc.a2009s") dlg._parc_changed() # Only triggered by mouse click with .activated - # Check if labels where removed - assert dlg._selected_parc_labels == [] # Add label by clicking on plot qtbot.mouseClick(parc_plot, Qt.LeftButton, pos=parc_plot.rect().center(), delay=100) assert "G_front_sup-rh" in dlg._selected_parc_labels - - # Add all labels + # Add label by selecting from list toggle_checked_list_model(dlg.parc_label_list.model, value=1, row=0) - dlg.close() - assert label_gui.param_value == [ + assert "G_Ins_lg_and_S_cent_ins-lh" in dlg._selected_parc_labels + + final_selection = [ + "insula-lh", + "postcentral-lh", "G_front_sup-rh", "G_Ins_lg_and_S_cent_ins-lh", "lh.BA1-lh", ] + # Check display widget + assert dlg.selected_display.model._data == final_selection + + # Add all labels + dlg.close() + assert label_gui.param_value == final_selection diff --git a/requirements.txt b/requirements.txt index d3d1151..b0c868c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ # Pipeline-related autoreject h5io +h5netcdf # only until new version of mne-connectivity is released # MNE-related mne