From 34b41b9b998b26316cc158b8706655d92cb38f89 Mon Sep 17 00:00:00 2001 From: GeorgWa Date: Tue, 17 Oct 2023 00:43:22 +0200 Subject: [PATCH] CHORE major refactoring --- alphadia/__init__.py | 2 +- alphadia/analysis/__init__.py | 1 - alphadia/analysis/actions.py | 144 - alphadia/annotation/__init__.py | 57 - alphadia/annotation/identification.py | 288 -- alphadia/annotation/library.py | 32 - alphadia/annotation/percolation.py | 433 --- alphadia/annotation/psm_stats.py | 266 -- alphadia/calibration.py | 943 +++++- alphadia/cli.py | 202 +- alphadia/{extraction => }/data/bruker.py | 27 +- alphadia/{extraction => }/data/thermo.py | 14 +- alphadia/dia.py | 82 - alphadia/extraction/__init__.py | 0 alphadia/extraction/calibration.py | 811 ----- alphadia/{extraction => }/fdr.py | 17 +- alphadia/{extraction => }/fdrexperimental.py | 73 +- alphadia/{extraction => }/features.py | 10 +- alphadia/gui.py | 1 - alphadia/{extraction => }/hybridselection.py | 38 +- alphadia/library.py | 1851 ----------- alphadia/{extraction => }/libtransform.py | 94 +- alphadia/{extraction => }/numba/config.py | 12 +- alphadia/{extraction => }/numba/fragments.py | 9 +- alphadia/{extraction => }/numba/numeric.py | 7 + alphadia/{extraction => }/planning.py | 35 +- alphadia/{extraction => }/plexscoring.py | 34 +- alphadia/{extraction => }/plotting/cycle.py | 26 +- alphadia/{extraction => }/plotting/debug.py | 12 +- alphadia/{extraction => }/plotting/utils.py | 8 +- alphadia/prefilter.py | 1265 -------- alphadia/preprocessing/__init__.py | 186 -- alphadia/preprocessing/calibration.py | 191 -- alphadia/preprocessing/connecting.py | 142 - alphadia/preprocessing/deisotoping.py | 212 -- alphadia/preprocessing/msmsgeneration.py | 394 --- alphadia/preprocessing/peakfinding.py | 363 --- alphadia/preprocessing/peakstats.py | 636 ---- alphadia/preprocessing/smoothing.py | 210 -- alphadia/{extraction => }/quadrupole.py | 8 +- alphadia/smoothing.py | 2858 ----------------- alphadia/{extraction => }/testing.py | 25 +- alphadia/thermo.py | 215 -- alphadia/{extraction => }/utils.py | 9 - alphadia/{extraction => }/validate.py | 13 +- alphadia/venn.py | 473 --- alphadia/{extraction => }/workflow/base.py | 14 +- alphadia/{extraction => }/workflow/manager.py | 25 +- .../workflow/peptidecentric.py | 23 +- .../{extraction => }/workflow/reporting.py | 68 +- nbs/search/library_search.ipynb | 41 +- 51 files changed, 1129 insertions(+), 11771 deletions(-) delete mode 100644 alphadia/analysis/__init__.py delete mode 100644 alphadia/analysis/actions.py delete mode 100644 alphadia/annotation/__init__.py delete mode 100644 alphadia/annotation/identification.py delete mode 100644 alphadia/annotation/library.py delete mode 100644 alphadia/annotation/percolation.py delete mode 100644 alphadia/annotation/psm_stats.py rename alphadia/{extraction => }/data/bruker.py (98%) rename alphadia/{extraction => }/data/thermo.py (99%) delete mode 100644 alphadia/dia.py delete mode 100644 alphadia/extraction/__init__.py delete mode 100644 alphadia/extraction/calibration.py rename alphadia/{extraction => }/fdr.py (98%) rename alphadia/{extraction => }/fdrexperimental.py (87%) rename alphadia/{extraction => }/features.py (99%) rename alphadia/{extraction => }/hybridselection.py (99%) delete mode 100644 alphadia/library.py rename alphadia/{extraction => }/libtransform.py (94%) rename alphadia/{extraction => }/numba/config.py (97%) rename alphadia/{extraction => }/numba/fragments.py (98%) rename alphadia/{extraction => }/numba/numeric.py (99%) rename alphadia/{extraction => }/planning.py (91%) rename alphadia/{extraction => }/plexscoring.py (98%) rename alphadia/{extraction => }/plotting/cycle.py (94%) rename alphadia/{extraction => }/plotting/debug.py (98%) rename alphadia/{extraction => }/plotting/utils.py (97%) delete mode 100644 alphadia/prefilter.py delete mode 100644 alphadia/preprocessing/__init__.py delete mode 100644 alphadia/preprocessing/calibration.py delete mode 100644 alphadia/preprocessing/connecting.py delete mode 100644 alphadia/preprocessing/deisotoping.py delete mode 100644 alphadia/preprocessing/msmsgeneration.py delete mode 100644 alphadia/preprocessing/peakfinding.py delete mode 100644 alphadia/preprocessing/peakstats.py delete mode 100644 alphadia/preprocessing/smoothing.py rename alphadia/{extraction => }/quadrupole.py (98%) delete mode 100644 alphadia/smoothing.py rename alphadia/{extraction => }/testing.py (96%) delete mode 100644 alphadia/thermo.py rename alphadia/{extraction => }/utils.py (98%) rename alphadia/{extraction => }/validate.py (99%) delete mode 100644 alphadia/venn.py rename alphadia/{extraction => }/workflow/base.py (96%) rename alphadia/{extraction => }/workflow/manager.py (98%) rename alphadia/{extraction => }/workflow/peptidecentric.py (98%) rename alphadia/{extraction => }/workflow/reporting.py (93%) diff --git a/alphadia/__init__.py b/alphadia/__init__.py index 17161df3..fcf5ab7d 100644 --- a/alphadia/__init__.py +++ b/alphadia/__init__.py @@ -40,4 +40,4 @@ } __extra_requirements__ = { "development": "requirements_development.txt", -} +} \ No newline at end of file diff --git a/alphadia/analysis/__init__.py b/alphadia/analysis/__init__.py deleted file mode 100644 index 8dbe83e5..00000000 --- a/alphadia/analysis/__init__.py +++ /dev/null @@ -1 +0,0 @@ -import alphadia.analysis.actions as actions diff --git a/alphadia/analysis/actions.py b/alphadia/analysis/actions.py deleted file mode 100644 index ff2cd179..00000000 --- a/alphadia/analysis/actions.py +++ /dev/null @@ -1,144 +0,0 @@ -"""A module to analyse timsTOF DIA data.""" - -import logging -import collections -import abc - -import alphatims.bruker -import alphadia.smoothing - -class ActionDeque(collections.deque): - - def run_all_consecutive_actions(self) -> None: - if len(self) == 0: - logging.info("No actions in ActionDeque") - else: - if len(self) == 1: - logging.info("Running 1 action in ActionDeque") - else: - logging.info(f"Running {len(self)} actions in ActionDeque") - for action_to_take in self: - action_to_take.run() - - -class Action(abc.ABC): - - def __init__(self, **parameters): - self.update_parameters(**parameters) - - @property - @abc.abstractmethod - def default_parameters(self) -> dict: - pass - - @property - def parameters(self) -> dict: - if not hasattr(self, "_parameters"): - self._parameters = self.default_parameters - return self._parameters - - def update_parameters(self, **parameters) -> None: - self._parameters = self.parse_valid_parameters(**parameters) - - def parse_valid_parameters(self, **parameters) -> None: - current_parameters = self.parameters - for parameter_key, parameter_value in parameters.items(): - current_parameters[parameter_key] = parameter_value - return current_parameters - - def set_output(self, output: type) -> type: - self._output = output - - @property - def output(self) -> type: - if not hasattr(self, "_output"): - raise ValueError("No output has been defined for this action") - return self._output - - @property - def is_completed(self) -> bool: - return hasattr(self, "_output") - - def run(self, redo_completed: bool = False, **parameters) -> None: - if redo_completed or not self.is_completed: - if len(parameters) > 0: - self.update_parameters(**parameters) - logging.info(f"Running '{self.__class__.__name__}'") - try: - output = self._run() - self.set_output(output) - except Exception as raised_exception: - if hasattr(self, "_output"): - del self._output - raise raised_exception - else: - logging.info( - f"'{self.__class__.__name__}' is already completed" - ) - return self.output - - @property - @abc.abstractmethod - def runnable_function(self) -> callable: - pass - - def _run(self) -> type: - return self.runnable_function(**self.parameters) - - # @staticmethod - # def create(name): - # if name == "import": - # return ImportAction() - -class ImportAction(Action): - - @property - def default_parameters(self) -> dict: - return { - "bruker_d_folder_name": None, - } - - @property - def runnable_function(self) -> callable: - return alphatims.bruker.TimsTOF - -class ConnectAction(Action): - - @property - def default_parameters(self) -> dict: - return { - "scan_tolerance": 6, - "dia_data": None, - "multiple_frames": False, - "ms1": True, - "ms2": True, - } - - @property - def runnable_function(self) -> callable: - # import functools - # _func = functools.partial( - # alphadia.smoothing.get_connections_within_cycle, - # scan_max_index=self.parameters["dia_data"].scan_max_index, - # dia_mz_cycle=self.parameters["dia_data"].dia_mz_cycle - # ) - def _func2(**kwargs): - parameters = self.parameters.copy() - dia_data = parameters.pop("dia_data") - return alphadia.smoothing.get_connections_within_cycle( - scan_max_index=dia_data.scan_max_index, - dia_mz_cycle=dia_data.dia_mz_cycle, - **parameters, - ) - # parameters = self.parameters.copy() - # dia_data = parameters.pop("dia_data") - # _func = functools.partial( - # alphadia.smoothing.get_connections_within_cycle, - # scan_max_index=dia_data.scan_max_index, - # dia_mz_cycle=dia_data.dia_mz_cycle, - # **parameters, - # ) - # result = _func() - # def _func2(**kwargs): - # return result - return _func2 diff --git a/alphadia/annotation/__init__.py b/alphadia/annotation/__init__.py deleted file mode 100644 index 34bd5279..00000000 --- a/alphadia/annotation/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Identify pseudo MSMS data data.""" - -from . import identification -from . import psm_stats -from . import library -from . import percolation - - -class Annotator: - - def set_ions(self, precursor_df, fragment_df): - # self.preprocessing_workflow = preprocessing_workflow - self.precursor_df = precursor_df - self.fragment_df = fragment_df - - def set_library(self, library): - self.library = library - - def set_msms_identifier(self): - self.msms_identifier = identification.MSMSIdentifier() - # self.msms_identifier.set_preprocessor(self.preprocessing_workflow) - self.msms_identifier.set_ions( - self.precursor_df, - self.fragment_df, - ) - self.msms_identifier.set_library(self.library) - self.msms_identifier.identify() - - def set_psm_stats_calculator(self): - self.psm_stats_calculator = psm_stats.PSMStatsCalculator() - # self.psm_stats_calculator.set_preprocessor(self.preprocessing_workflow) - self.psm_stats_calculator.set_ions(self.precursor_df, self.fragment_df) - self.psm_stats_calculator.set_library(self.library) - self.psm_stats_calculator.set_annotation( - self.msms_identifier.annotation - ) - self.psm_stats_calculator.estimate_mz_tolerance() - - def set_percolator(self): - self.percolator = percolation.Percolator() - self.percolator.set_annotation( - self.psm_stats_calculator.annotation - ) - self.percolator.percolate() - - def run_default(self): - self.set_msms_identifier() - self.set_psm_stats_calculator() - self.msms_identifier.update_ppm_values_from_stats_calculator( - self.psm_stats_calculator - ) - self.msms_identifier.identify() - self.psm_stats_calculator.set_annotation( - self.msms_identifier.annotation - ) - self.psm_stats_calculator.update_annotation_stats() - self.set_percolator() diff --git a/alphadia/annotation/identification.py b/alphadia/annotation/identification.py deleted file mode 100644 index 2f6742cb..00000000 --- a/alphadia/annotation/identification.py +++ /dev/null @@ -1,288 +0,0 @@ -"""Annotate pseudo MSMS spectra.""" - -import logging - -import numpy as np - -import alphatims.utils - - -class MSMSIdentifier: - - def __init__( - self, - precursor_ppm=50, - fragment_ppm=50, - min_size=10, - ppm_mean=0, - min_hit_count=1, - append_stats=True, - top_n_hits=1, - ): - self.precursor_ppm = precursor_ppm - self.fragment_ppm = fragment_ppm - self.min_size = min_size - self.ppm_mean = ppm_mean - self.min_hit_count = min_hit_count - self.append_stats = append_stats - self.top_n_hits = top_n_hits - - def set_ions(self, precursor_df, fragment_df): - self.precursor_df = precursor_df - self.fragment_df = fragment_df - - def set_library(self, library): - self.library = library - - def update_ppm_values_from_stats_calculator( - self, - psm_stats_calculator - ): - self.ppm_mean = psm_stats_calculator.ppm_mean - self.fragment_ppm = psm_stats_calculator.ppm_width - self.precursor_ppm = psm_stats_calculator.ppm_width - - def identify( - self, - ): - logging.info( - f"Quick library annotation of mono isotopes with {self.ppm_mean=} and {self.precursor_ppm=}" - ) - spectrum_sizes = (self.precursor_df.fragment_end - self.precursor_df.fragment_start).values - o = np.argsort(self.precursor_df.tof_indices.values) - p_mzs = self.precursor_df.mz_average.values[o] - lower = np.empty( - len(self.precursor_df), - dtype=np.int64 - ) - upper = np.empty( - len(self.precursor_df), - dtype=np.int64 - ) - lower[o] = np.searchsorted( - self.library.predicted_library_df.precursor_mz.values, - p_mzs / (1 + self.precursor_ppm * 10**-6) - ) - upper[o] = np.searchsorted( - self.library.predicted_library_df.precursor_mz.values, - p_mzs * (1 + self.precursor_ppm * 10**-6) - ) - logging.info( - f"PSMs to test: {np.sum(((upper - lower) * (spectrum_sizes >= self.min_size)))}" - ) - ( - precursor_indices, - precursor_indptr, - hit_counts, - frequency_counts, - db_indices, - ) = annotate( - range(len(lower)), - # range(100), - self.library.predicted_library_df.frag_start_idx.values, - self.library.predicted_library_df.frag_end_idx.values, - self.precursor_df.fragment_start.values, - self.precursor_df.fragment_end.values, - self.fragment_df.mz_average.values * (1 + self.ppm_mean * 10**-6), - self.fragment_df[ - [i for i in self.fragment_df.columns if "correlation" in i] - ].prod(axis=1).values, # TODO - self.fragment_ppm, - lower, - upper, - self.library.y_mzs, - self.library.b_mzs, - self.min_size, - self.min_hit_count, - self.top_n_hits, - ) - - precursor_selection = np.repeat(precursor_indices, precursor_indptr) - hits = self.precursor_df.iloc[precursor_selection].reset_index() - hits["inet_index"] = precursor_selection - hits["candidates"] = (upper - lower)[precursor_selection] - hits["total_peaks"] = spectrum_sizes[precursor_selection] - hits["db_index"] = db_indices.astype(np.int64) - # hits["counts"] = np.repeat(hit_counts, precursor_indptr) - hits["counts"] = hit_counts - hits["frequency_counts"] = frequency_counts - self.annotation = hits.rename(columns={"charge": "precursor_charge"}) - self.annotation = self.annotation.join(self.library.predicted_library_df, on="db_index") - self.annotation["im_diff"] = self.annotation.mobility_pred - self.annotation.mobility_values - self.annotation["mz_diff"] = self.annotation.precursor_mz - self.annotation.mz_values - self.annotation["ppm_diff"] = self.annotation.mz_diff / self.annotation.precursor_mz * 10**6 - self.annotation["target"] = ~self.annotation.decoy - self.annotation.reset_index(drop=True, inplace=True) - - -def annotate( - iterable, - frag_start_idx, - frag_end_idx, - frag_start, - frag_end, - frag_mzs, - frag_weights, - fragment_ppm, - lower, - upper, - y_mzs, - b_mzs, - min_size, - min_hit_count, - top_n_hits, -): - import multiprocessing - - def starfunc(index): - # return alphadia.prefilter.annotate_pool( - return annotate_pool2( - index, - frag_start_idx, - frag_end_idx, - frag_start, - frag_end, - frag_mzs, - frag_weights, - fragment_ppm, - lower, - upper, - y_mzs, - b_mzs, - min_size, - min_hit_count, - top_n_hits, - ) - precursor_indices = [] - max_hit_counts = [] - max_frequency_counts = [] - db_indices = [] - precursor_indptr = [] - with multiprocessing.pool.ThreadPool(alphatims.utils.MAX_THREADS) as pool: - for ( - precursor_index, - hit_count, - frequency_count, - db_indices_, - ) in alphatims.utils.progress_callback( - pool.imap(starfunc, iterable), - total=len(iterable), - include_progress_callback=True - ): - # if hit_count >= min_hit_count: - if True: - precursor_indices.append(precursor_index) - precursor_indptr.append(len(db_indices_)) - max_hit_counts.append(hit_count) - max_frequency_counts.append(frequency_count) - db_indices.append(db_indices_) - return ( - np.array(precursor_indices), - np.array(precursor_indptr), - # np.array(max_hit_counts), - np.concatenate(max_hit_counts), - np.concatenate(max_frequency_counts), - np.concatenate(db_indices), - ) - - -@alphatims.utils.njit(nogil=True) -def annotate_pool2( - index, - frag_start_idx, - frag_end_idx, - frag_start, - frag_end, - frag_mzs, - frag_weights, - fragment_ppm, - lower, - upper, - y_mzs, - b_mzs, - min_size, - min_hit_count, - top_n_hits, -): - start = frag_start[index] - end = frag_end[index] - results = [0][1:] # this defines the type - hit_counts = [0][1:] # this defines the type - frequency_counts = [0.0][1:] # this defines the type - if (end - start) < min_size: - return index, hit_counts, frequency_counts, results - if (end - start) < min_hit_count: - return index, hit_counts, frequency_counts, results - frequencies = frag_weights[start: end] - fragment_mzs = frag_mzs[start: end] - max_hit_count = min_hit_count - for db_index in range(lower[index], upper[index]): - frag_start = frag_start_idx[db_index] - frag_end = frag_end_idx[db_index] - y_hits, y_frequency = hit_and_frequency_count( - fragment_mzs, - frequencies, - y_mzs[frag_start: frag_end][::-1], - fragment_ppm, - ) - b_hits, b_frequency = hit_and_frequency_count( - fragment_mzs, - frequencies, - b_mzs[frag_start: frag_end], - fragment_ppm, - ) - hit_count = b_hits + y_hits - frequency_count = b_frequency + y_frequency - if top_n_hits == 1: - if frequency_count == max_hit_count: - results.append(db_index) - hit_counts.append(hit_count) - frequency_counts.append(frequency_count) - elif frequency_count > max_hit_count: - results = [db_index] - hit_counts = [hit_count] - frequency_counts = [frequency_count] - max_hit_count = hit_count - elif frequency_count >= min_hit_count: - if len(results) >= top_n_hits: - for min_index, freq_count in enumerate(frequency_counts): - if freq_count == min_hit_count: - results[min_index] = db_index - hit_counts[min_index] = hit_count - frequency_counts[min_index] = frequency_count - break - min_hit_count = min(frequency_counts) - else: - results.append(db_index) - hit_counts.append(hit_count) - frequency_counts.append(frequency_count) - # return index, max_hit_count, results - return index, hit_counts, frequency_counts, results - - - -@alphatims.utils.njit(nogil=True) -def hit_and_frequency_count( - fragment_mzs, - frequencies, - database_mzs, - fragment_ppm, -): - fragment_index = 0 - database_index = 0 - hits = 0 - summed_frequency = 0 - while (fragment_index < len(fragment_mzs)) and (database_index < len(database_mzs)): - fragment_mz = fragment_mzs[fragment_index] - database_mz = database_mzs[database_index] - frequency = frequencies[fragment_index] - if fragment_mz < (database_mz / (1 + 10**-6 * fragment_ppm)): - fragment_index += 1 - elif database_mz < (fragment_mz / (1 + 10**-6 * fragment_ppm)): - database_index += 1 - else: - hits += 1 - summed_frequency += frequency - fragment_index += 1 - database_index += 1 - return hits, summed_frequency diff --git a/alphadia/annotation/library.py b/alphadia/annotation/library.py deleted file mode 100644 index 5548a77d..00000000 --- a/alphadia/annotation/library.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Import library.""" - -import logging - -import numpy as np - -import alphabase.io.hdf - - -class Library: - - def import_from_file(self, library_file_name, is_already_mmapped=True): - logging.info("Loading library") - self.library_file_name = library_file_name - self.lib = alphabase.io.hdf.HDF_File( - self.library_file_name, - read_only=is_already_mmapped, - ) - - predicted_library_df = self.lib.library.precursor_df[...] - # predicted_library_df.sort_values(by=["rt_pred", "mobility_pred"], inplace=True) - predicted_library_df.sort_values(by="precursor_mz", inplace=True) - predicted_library_df.reset_index(level=0, inplace=True) - predicted_library_df.rename(columns={"index": "original_index"}, inplace=True) - predicted_library_df.decoy = predicted_library_df.decoy.astype(np.bool_) - - self.y_mzs = self.lib.library.fragment_mz_df.y_z1.mmap - self.b_mzs = self.lib.library.fragment_mz_df.b_z1.mmap - self.y_ions_intensities = self.lib.library.fragment_intensity_df.y_z1.mmap - self.b_ions_intensities = self.lib.library.fragment_intensity_df.b_z1.mmap - - self.predicted_library_df = predicted_library_df diff --git a/alphadia/annotation/percolation.py b/alphadia/annotation/percolation.py deleted file mode 100644 index 823fcd41..00000000 --- a/alphadia/annotation/percolation.py +++ /dev/null @@ -1,433 +0,0 @@ -"""Percolate results.""" - -import logging - -import numpy as np -import pandas as pd -import sklearn -import sklearn.model_selection -import sklearn.decomposition -import sklearn.neighbors -import sklearn.preprocessing -import sklearn.ensemble -import sklearn.pipeline - - -import alphatims.utils - - -class Percolator: - - def __init__( - self, - fdr=0.01, - train_fdr_level_pre_calibration=0.1, - train_fdr_level_post_calibration=0.33, - n_neighbors=4, - test_size=0.5, - random_state=0, - ): - self.fdr = fdr - self.train_fdr_level_pre_calibration = train_fdr_level_pre_calibration - self.train_fdr_level_post_calibration = train_fdr_level_post_calibration - self.n_neighbors = n_neighbors - self.test_size = test_size - self.random_state = random_state - - def set_annotation(self, annotation): - self.annotation = annotation - - def percolate(self): - logging.info("Percolating PSMs") - val_names = [ - "counts", - "frequency_counts", - "ppm_diff", - "im_diff", - "charge", - "total_peaks", - "nAA", - "b_hit_counts", - "y_hit_counts", - "b_mean_ppm", - "y_mean_ppm", - "relative_found_b_int", - "relative_missed_b_int", - "relative_found_y_int", - "relative_missed_y_int", - "relative_found_int", - "relative_missed_int", - "pearsons", - "pearsons_log", - "candidates", - ] - logging.info("Calculating quick log odds") - score_df = self.annotation.copy() - log_odds = calculate_log_odds_product( - score_df, - val_names, - ) - # log_odds = score_df["frequency_counts"].values - score_df["log_odds"] = log_odds - # score_df = alphadia.prefilter.train_and_score( - # score_df, - # val_names, - # ini_score="log_odds", - # train_fdr_level=train_fdr_level_pre_calibration, - # ).reset_index(drop=True) - score_df = get_q_values(score_df, "log_odds", 'decoy', drop=True) - score_df_above_fdr = score_df[ - (score_df.q_value < self.fdr) & (score_df.target) - ].reset_index(drop=True) - logging.info( - f"Found {len(score_df_above_fdr)} targets for calibration" - ) - score_df_above_fdr["im_pred"] = score_df_above_fdr.mobility_pred - score_df_above_fdr["im_values"] = score_df_above_fdr.mobility_values - self.predictors = {} - for dimension in ["rt", "im"]: - X = score_df_above_fdr[f"{dimension}_pred"].values.reshape(-1, 1) - y = score_df_above_fdr[f"{dimension}_values"].values - ( - X_train, - X_test, - y_train, - y_test - ) = sklearn.model_selection.train_test_split( - X, - y, - test_size=self.test_size, - random_state=self.random_state, - ) - self.predictors[dimension] = sklearn.neighbors.KNeighborsRegressor( - n_neighbors=self.n_neighbors, - # weights="distance", - n_jobs=alphatims.utils.set_threads(alphatims.utils.MAX_THREADS) - ) - self.predictors[dimension].fit(X_train, y_train) - score_df_above_fdr[f"{dimension}_calibrated"] = self.predictors[dimension].predict( - score_df_above_fdr[f"{dimension}_pred"].values.reshape(-1, 1) - ) - score_df_above_fdr[f"{dimension}_diff"] = score_df_above_fdr[f"{dimension}_values"] - score_df_above_fdr[f"{dimension}_calibrated"] - score_df["rt_calibrated"] = self.predictors["rt"].predict( - score_df.rt_pred.values.reshape(-1, 1) - ) - score_df["im_calibrated"] = self.predictors["im"].predict( - score_df.mobility_pred.values.reshape(-1, 1) - ) - ppm_mean = np.mean(score_df_above_fdr.ppm_diff.values) - score_df["mz_calibrated"] = score_df.precursor_mz * ( - 1 - ppm_mean * 10**-6 - ) - - score_df["ppm_diff_calibrated"] = (score_df.mz_calibrated - score_df.mz_values) / score_df.mz_calibrated * 10**6 - score_df["rt_diff_calibrated"] = score_df.rt_calibrated - score_df.rt_values - score_df["im_diff_calibrated"] = score_df.im_calibrated - score_df.mobility_values - # self.score_df = score_df.reset_index(drop=True) - self.score_df = train_and_score( - # score_df[np.abs(score_df.rt_diff_calibrated) < 250].reset_index(drop=True), - score_df, - [ - "counts", - "frequency_counts", - "ppm_diff_calibrated", - "im_diff_calibrated", - "rt_diff_calibrated", - "charge", - "total_peaks", - "nAA", - "b_hit_counts", - "y_hit_counts", - "b_mean_ppm", - "y_mean_ppm", - "relative_found_b_int", - "relative_missed_b_int", - "relative_found_y_int", - "relative_missed_y_int", - "relative_found_int", - "relative_missed_int", - "pearsons", - "pearsons_log", - "candidates", - # "log_odds", - ], - ini_score="log_odds", - train_fdr_level=self.train_fdr_level_post_calibration, - ).reset_index(drop=True) - - self.score_df["target_type"] = np.array([-1, 0])[ - self.score_df.target.astype(np.int) - ] - self.score_df["target_type"][ - (self.score_df.q_value < self.fdr) & (self.score_df.target) - ] = 1 - - -@alphatims.utils.njit(nogil=True) -def fdr_to_q_values(fdr_values): - q_values = np.zeros_like(fdr_values) - min_q_value = np.max(fdr_values) - for i in range(len(fdr_values) - 1, -1, -1): - fdr = fdr_values[i] - if fdr < min_q_value: - min_q_value = fdr - q_values[i] = min_q_value - return q_values - - -def get_q_values(_df, score_column, decoy_column, drop=False): - _df = _df.reset_index(drop=drop) - _df = _df.sort_values([score_column, score_column], ascending=False) - target_values = 1-_df['decoy'].values - decoy_cumsum = np.cumsum(_df['decoy'].values) - target_cumsum = np.cumsum(target_values) - fdr_values = decoy_cumsum/target_cumsum - _df['q_value'] = fdr_to_q_values(fdr_values) - return _df - - -def calculate_odds( - df, - column_name, - *, - target_name="target", - smooth=1, - plot=False -): - negatives, positives = np.bincount(df.target.values) - if negatives > positives: - raise ValueError( - f"Found more decoys ({negatives}) than targets ({positives})" - ) - tp_count = 1000 - else: - tp_count = positives - negatives - n = int(tp_count * smooth) - order = np.argsort(df[column_name].values) - forward = np.cumsum(df[target_name].values[order]) - odds = np.zeros_like(forward, dtype=np.float) - odds[n:-n] = forward[2*n:] - forward[:-2*n] - odds[:n] = forward[n:2*n] - odds[-n:] = forward[-1] - forward[-2*n:-n] - odds[n:-n] /= 2*n - odds[:n] /= np.arange(n, 2*n) - odds[-n:] /= np.arange(n, 2*n)[::-1] - odds /= (1 - odds) - odds = odds[np.argsort(order)] - if plot: - import matplotlib.pyplot as plt - plt.scatter(df[column_name], odds, marker=".") - return odds - - -def calculate_log_odds_product( - df_, - val_names -): - df = df_[val_names] - df = sklearn.preprocessing.StandardScaler().fit_transform(df) - pca = sklearn.decomposition.PCA(n_components=df.shape[1]) - pca.fit(df) - df = pd.DataFrame(pca.transform(df)) - df["target"] = df_.target - negative, positive = np.bincount(df.target) - log_odds = np.zeros(len(df)) - for val_name in range(df.shape[1] - 1): - odds = calculate_odds(df, val_name, smooth=1) - log_odds += np.log2(odds) * pca.explained_variance_[val_name] - return log_odds - # new_df = analysis1.score_df[["decoy", "target"]] - # new_df['odds'] = log_odds - # new_df = alphadia.library.get_q_values(new_df, "odds", 'decoy', drop=True) - # new_df.reset_index(drop=True, inplace=True) - - -def train_and_score( - scores_df, - features, - train_fdr_level: float = 0.1, - ini_score: str = "count", - min_train: int = 1000, - test_size: float = 0.8, - max_depth: list = [5, 25, 50], - max_leaf_nodes: list = [150, 200, 250], - n_jobs: int = -1, - scoring: str = 'accuracy', - plot: bool = False, - random_state: int = 42, -): - df = scores_df.copy() - cv = train_RF( - df, - features, - train_fdr_level=train_fdr_level, - ini_score=ini_score, - min_train=min_train, - test_size=test_size, - max_depth=max_depth, - max_leaf_nodes=max_leaf_nodes, - n_jobs=n_jobs, - scoring=scoring, - plot=plot, - random_state=random_state, - ) - df['score'] = cv.predict_proba(df[features])[:, 1] - return get_q_values(df, "score", 'decoy', drop=True) - - -def train_RF( - df: pd.DataFrame, - features: list, - train_fdr_level: float = 0.1, - ini_score: str = None, - min_train: int = 1000, - test_size: float = 0.8, - max_depth: list = [5, 25, 50], - max_leaf_nodes: list = [150, 200, 250], - n_jobs: int = -1, - scoring: str = 'accuracy', - plot: bool = False, - random_state: int = 42, -): - # Setup ML pipeline - scaler = sklearn.preprocessing.StandardScaler() - rfc = sklearn.ensemble.RandomForestClassifier(random_state=random_state) - ## Initiate scaling + classification pipeline - pipeline = sklearn.pipeline.Pipeline([('scaler', scaler), ('clf', rfc)]) - parameters = { - 'clf__max_depth': (max_depth), - 'clf__max_leaf_nodes': (max_leaf_nodes) - } - ## Setup grid search framework for parameter selection and internal cross validation - cv = sklearn.model_selection.GridSearchCV( - pipeline, - param_grid=parameters, - cv=5, - scoring=scoring, - verbose=0, - return_train_score=True, - n_jobs=n_jobs - ) - # Prepare target and decoy df - dfD = df[df.decoy.values] - # Select high scoring targets (<= train_fdr_level) - # df_prescore = filter_score(df) - # df_prescore = filter_precursor(df_prescore) - # scored = cut_fdr(df_prescore, fdr_level = train_fdr_level, plot=False)[1] - # highT = scored[scored.decoy==False] - # dfT_high = dfT[dfT['query_idx'].isin(highT.query_idx)] - # dfT_high = dfT_high[dfT_high['db_idx'].isin(highT.db_idx)] - if ini_score is None: - selection = None - best_hit_count = 0 - best_feature = "" - for feature in features: - new_df = get_q_values(df, feature, 'decoy') - hits = ( - new_df['q_value'] <= train_fdr_level - ) & ( - new_df['decoy'] == 0 - ) - hit_count = np.sum(hits) - if hit_count > best_hit_count: - best_hit_count = hit_count - selection = hits - best_feature = feature - logging.info(f'Using optimal "{best_feature}" as initial_feature') - dfT_high = df[selection] - else: - logging.info(f'Using selected "{ini_score}" as initial_feature') - new_df = get_q_values(df, ini_score, 'decoy') - dfT_high = df[ - (new_df['q_value'] <= train_fdr_level) & (new_df['decoy'] == 0) - ] - - # Determine the number of psms for semi-supervised learning - n_train = int(dfT_high.shape[0]) - if dfD.shape[0] < n_train: - n_train = int(dfD.shape[0]) - logging.info( - "The total number of available decoys is lower than " - "the initial set of high scoring targets." - ) - if n_train < min_train: - raise ValueError( - "There are fewer high scoring targets or decoys than " - "required by 'min_train'." - ) - - # Subset the targets and decoys datasets to result in a balanced dataset - df_training = dfT_high.append( - dfD.sample(n=n_train, random_state=random_state) - ) - # df_training = dfT_high.append(dfD) - - # Select training and test sets - X = df_training[features] - y = df_training['target'].astype(int) - ( - X_train, - X_test, - y_train, - y_test - ) = sklearn.model_selection.train_test_split( - X.values, - y.values, - test_size=test_size, - random_state=random_state, - stratify=y.values - ) - - # Train the classifier on the training set via 5-fold cross-validation and subsequently test on the test set - logging.info( - 'Training & cross-validation on {} targets and {} decoys'.format( - # np.sum(y_train), X_train.shape[0] - np.sum(y_train) - *np.bincount(y_train)[::-1] - ) - ) - cv.fit(X_train, y_train) - - logging.info( - 'The best parameters selected by 5-fold cross-validation were {}'.format( - cv.best_params_ - ) - ) - logging.info( - 'The train {} was {}'.format(scoring, cv.score(X_train, y_train)) - ) - logging.info( - 'Testing on {} targets and {} decoys'.format( - np.sum(y_test), - X_test.shape[0] - np.sum(y_test) - ) - ) - logging.info( - 'The test {} was {}'.format(scoring, cv.score(X_test, y_test)) - ) - - feature_importances = cv.best_estimator_.named_steps['clf'].feature_importances_ - indices = np.argsort(feature_importances)[::-1][:40] - - top_features = X.columns[indices][:40] - top_score = feature_importances[indices][:40] - - feature_dict = dict(zip(top_features, top_score)) - logging.info(f"Top features {feature_dict}") - - # Inspect feature importances - if plot: - import matplotlib.pyplot as plt - import seaborn as sns - g = sns.barplot( - y=X.columns[indices][:40], - x=feature_importances[indices][:40], - orient='h', - palette='RdBu' - ) - g.set_xlabel("Relative importance", fontsize=12) - g.set_ylabel("Features", fontsize=12) - g.tick_params(labelsize=9) - g.set_title("Feature importance") - plt.show() - - return cv diff --git a/alphadia/annotation/psm_stats.py b/alphadia/annotation/psm_stats.py deleted file mode 100644 index cd0fde33..00000000 --- a/alphadia/annotation/psm_stats.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Calculate PSM stats.""" - -import logging - -import numpy as np - -import alphatims.utils - - -class PSMStatsCalculator: - - def __init__( - self, - pseudo_int=10**-6, - ): - self.pseudo_int = pseudo_int - - def set_ions(self, precursor_df, fragment_df): - self.precursor_df = precursor_df - self.fragment_df = fragment_df - - def set_library(self, library): - self.library = library - - def set_annotation(self, annotation): - self.annotation = annotation - - def estimate_mz_tolerance(self): - logging.info("Estimating ppm values") - ppm_diffs = self.annotation.ppm_diff - order = np.argsort(ppm_diffs.values) - - decoys, targets = np.bincount(self.annotation.decoy.values) - distribution = np.cumsum( - [ - 1 / targets if i else -1 / decoys for i in self.annotation.decoy.values[order] - ] - ) - low = ppm_diffs[order[np.argmin(distribution)]] - high = ppm_diffs[order[np.argmax(distribution)]] - self.ppm_mean = (low + high) / 2 - self.ppm_width = abs(high - low) - # plt.plot( - # ppm_diffs[order], - # distribution, - # ) - # sns.histplot( - # data=self.annotation, - # x="ppm_diff", - # hue="decoy", - # ) - - def update_annotation_stats(self): - logging.info("Appending stats to quick annotation") - b_hit_counts = np.zeros(len(self.annotation)) - y_hit_counts = np.zeros(len(self.annotation)) - b_mean_ppm = np.zeros(len(self.annotation)) - y_mean_ppm = np.zeros(len(self.annotation)) - relative_found_b_int = np.zeros(len(self.annotation)) - relative_missed_b_int = np.zeros(len(self.annotation)) - relative_found_y_int = np.zeros(len(self.annotation)) - relative_missed_y_int = np.zeros(len(self.annotation)) - relative_found_int = np.zeros(len(self.annotation)) - relative_missed_int = np.zeros(len(self.annotation)) - pearsons = np.zeros(len(self.annotation)) - pearsons_log = np.zeros(len(self.annotation)) - update_annotation( - range(len(self.annotation)), - # 1000, - self.annotation.db_index.values, - self.library.predicted_library_df.frag_start_idx.values, - self.library.predicted_library_df.frag_end_idx.values, - self.library.y_mzs, - self.library.b_mzs, - self.library.y_ions_intensities, - self.library.b_ions_intensities, - self.annotation.inet_index.values, - - self.precursor_df.fragment_start.values, - self.precursor_df.fragment_end.values, - self.fragment_df.summed_intensity_values.values, - self.fragment_df.mz_average.values * (1 + self.ppm_mean * 10**-6), - # self.precursor_indptr, - # self.fragment_indices, - # self.tof_indices, - # self.smooth_intensity_values, #.astype(np.float64), - # self.mz_values * (1 + self.ppm_mean * 10**-6), - - - self.ppm_width, - b_hit_counts, - y_hit_counts, - b_mean_ppm, - y_mean_ppm, - relative_found_b_int, - relative_missed_b_int, - relative_found_y_int, - relative_missed_y_int, - relative_found_int, - relative_missed_int, - pearsons, - pearsons_log, - np.float32(self.pseudo_int), - ) - self.annotation["b_hit_counts"] = b_hit_counts - self.annotation["y_hit_counts"] = y_hit_counts - self.annotation["b_mean_ppm"] = b_mean_ppm - self.annotation["y_mean_ppm"] = y_mean_ppm - self.annotation["relative_found_b_int"] = relative_found_b_int - self.annotation["relative_missed_b_int"] = relative_missed_b_int - self.annotation["relative_found_y_int"] = relative_found_y_int - self.annotation["relative_missed_y_int"] = relative_missed_y_int - self.annotation["relative_found_int"] = relative_found_int - self.annotation["relative_missed_int"] = relative_missed_int - pearsons[~np.isfinite(pearsons)] = 0 - self.annotation["pearsons"] = pearsons - pearsons_log[~np.isfinite(pearsons_log)] = 0 - self.annotation["pearsons_log"] = pearsons_log - - -@alphatims.utils.pjit -# @alphatims.utils.njit(nogil=True) -def update_annotation( - index, - database_indices, - database_frag_starts, - database_frag_ends, - database_y_mzs, - database_b_mzs, - database_y_ints, - database_b_ints, - inet_indices, - fragment_start, - fragment_end, - fragment_intensities, - fragment_mzs, - # precursor_indptr, - # fragment_indices, - # tof_indices, - # intensity_values, - # mz_values, - fragment_ppm, - b_hit_counts, - y_hit_counts, - b_mean_ppm, - y_mean_ppm, - relative_found_b_int, - relative_missed_b_int, - relative_found_y_int, - relative_missed_y_int, - relative_found_int, - relative_missed_int, - pearsons, - pearsons_log, - pseudo_int, -): - if index >= len(database_indices): - return - database_index = database_indices[index] - db_frag_start_idx = database_frag_starts[database_index] - db_frag_end_idx = database_frag_ends[database_index] - db_y_mzs = database_y_mzs[db_frag_start_idx: db_frag_end_idx][::-1] - db_b_mzs = database_b_mzs[db_frag_start_idx: db_frag_end_idx] - db_y_ints = database_y_ints[db_frag_start_idx: db_frag_end_idx][::-1] - db_b_ints = database_b_ints[db_frag_start_idx: db_frag_end_idx] - if pseudo_int > 0: - db_y_ints = db_y_ints + pseudo_int - db_b_ints = db_b_ints + pseudo_int - precursor_index = inet_indices[index] - frag_start_idx = fragment_start[precursor_index] - frag_end_idx = fragment_end[precursor_index] - fragment_mzs = fragment_mzs[frag_start_idx: frag_end_idx] - fragment_ints = fragment_intensities[frag_start_idx: frag_end_idx] - fragment_b_hits, db_b_hits = find_hits( - fragment_mzs, - db_b_mzs, - fragment_ppm, - ) - total_b_int = np.sum(db_b_ints) - if total_b_int == 0: - total_b_int = 1 - if len(db_b_hits) > 0: - b_ppm = np.mean( - (db_b_mzs[db_b_hits] - fragment_mzs[fragment_b_hits]) / db_b_mzs[db_b_hits] * 10**6 - ) - found_b_int = np.sum(db_b_ints[db_b_hits]) - min_b_int = np.min(db_b_ints[db_b_hits]) - else: # TODO defaults are not reflective of good/bad scores - b_ppm = fragment_ppm - found_b_int = 0 - min_b_int = -1 - fragment_y_hits, db_y_hits = find_hits( - fragment_mzs, - db_y_mzs, - fragment_ppm, - ) - total_y_int = np.sum(db_y_ints) - if total_y_int == 0: - total_y_int = 1 - if len(db_y_hits) > 0: - y_ppm = np.mean( - (db_y_mzs[db_y_hits] - fragment_mzs[fragment_y_hits]) / db_y_mzs[db_y_hits] * 10**6 - ) - found_y_int = np.sum(db_y_ints[db_y_hits]) - min_y_int = np.min(db_y_ints[db_y_hits]) - else: # TODO defaults are not reflective of good/bad scores - y_ppm = fragment_ppm - found_y_int = 0 - min_y_int = -1 - missed_b_int = np.sum( - np.array([intsy for i, intsy in enumerate(db_b_ints) if (i not in db_b_hits) and (intsy > min_b_int)]) - ) - missed_y_int = np.sum( - np.array([intsy for i, intsy in enumerate(db_y_ints) if (i not in db_y_hits) and (intsy > min_y_int)]) - ) - # all_frags = fragment_ints - b_hit_counts[index] = len(db_b_hits) - y_hit_counts[index] = len(db_y_hits) - b_mean_ppm[index] = b_ppm - y_mean_ppm[index] = y_ppm - relative_found_b_int[index] = found_b_int / total_b_int - relative_missed_b_int[index] = missed_b_int / total_b_int - relative_found_y_int[index] = found_y_int / total_y_int - relative_missed_y_int[index] = missed_y_int / total_y_int - relative_found_int[index] = (found_b_int + found_y_int) / (total_b_int + total_y_int) - relative_missed_int[index] = (missed_b_int + missed_y_int) / (total_b_int + total_y_int) - all_db_ints = [] - all_frag_ints = [] - for b_int in db_b_ints[db_b_hits]: - all_db_ints.append(b_int) - for y_int in db_y_ints[db_y_hits]: - all_db_ints.append(y_int) - for frag_int in fragment_ints[fragment_b_hits]: - all_frag_ints.append(frag_int) - for frag_int in fragment_ints[fragment_y_hits]: - all_frag_ints.append(frag_int) - pearsons[index] = np.corrcoef(all_db_ints, all_frag_ints)[0, 1] - pearsons_log[index] = np.corrcoef( - np.log(np.array(all_db_ints)), - np.log(np.array(all_frag_ints)), - )[0, 1] - - -@alphatims.utils.njit(nogil=True) -def find_hits( - fragment_mzs, - database_mzs, - fragment_ppm, -): - fragment_index = 0 - database_index = 0 - fragment_hits = [] - db_hits = [] - while (fragment_index < len(fragment_mzs)) and (database_index < len(database_mzs)): - fragment_mz = fragment_mzs[fragment_index] - database_mz = database_mzs[database_index] - if fragment_mz < (database_mz / (1 + 10**-6 * fragment_ppm)): - fragment_index += 1 - elif database_mz < (fragment_mz / (1 + 10**-6 * fragment_ppm)): - database_index += 1 - else: - fragment_hits.append(fragment_index) - db_hits.append(database_index) - fragment_index += 1 - database_index += 1 - return np.array(fragment_hits), np.array(db_hits) diff --git a/alphadia/calibration.py b/alphadia/calibration.py index c77050a6..5297a896 100644 --- a/alphadia/calibration.py +++ b/alphadia/calibration.py @@ -1,143 +1,808 @@ -"""Calibrate quad""" +# native imports +import os +import logging +import typing +import pickle -import alphatims.bruker +# alphadia imports +from alphadia.plotting.utils import density_scatter + +# alpha family imports import alphatims.utils -import numpy as np +from alphabase.statistics.regression import LOESSRegression + +# third party imports import pandas as pd -import alphatims.plotting - - -@alphatims.utils.njit(nogil=True, cache=False) -def merge_cyclic_pushes( - cyclic_push_index, - intensity_values, - tof_indices, - push_indptr, - zeroth_frame, - cycle_length, - tof_max_index, - scan_max_index, - return_sparse=False, -): - offset = scan_max_index * zeroth_frame + cyclic_push_index - intensity_buffer = np.zeros(tof_max_index) - tofs = [] - for push_index in range(offset, len(push_indptr) - 1, cycle_length): - start = push_indptr[push_index] - end = push_indptr[push_index + 1] - for index in range(start, end): - tof = tof_indices[index] - intensity = intensity_values[index] - if intensity_buffer[tof] == 0: - tofs.append(tof) - intensity_buffer[tof] += intensity - tofs = np.array(tofs, dtype=tof_indices.dtype) - if return_sparse: - tofs = np.sort(tofs) - intensity_buffer = intensity_buffer[tofs] - return tofs, intensity_buffer - - -def guesstimate_quad_settings( - dia_data, - smooth_window=100, - gaussian_blur=5, - percentile=50, - regresion_mz_lower_cutoff=400, - regresion_mz_upper_cutoff=1000, -): - dia_mz_cycle = np.empty_like(dia_data.dia_mz_cycle) - weights = np.zeros(len(dia_mz_cycle)) - for cyclic_push_index, (low_quad, high_quad) in alphatims.utils.progress_callback( - enumerate(dia_data.dia_mz_cycle), - total=len(dia_data.dia_mz_cycle) - ): - if (low_quad == -1) and (high_quad == -1): - dia_mz_cycle[cyclic_push_index] = (low_quad, high_quad) - continue - tofs, intensity_buffer = merge_cyclic_pushes( - cyclic_push_index=cyclic_push_index, - intensity_values=dia_data.intensity_values, - tof_indices=dia_data.tof_indices, - push_indptr=dia_data.push_indptr, - zeroth_frame=dia_data.zeroth_frame, - cycle_length=len(dia_data.dia_mz_cycle), - tof_max_index=dia_data.tof_max_index, - scan_max_index=dia_data.scan_max_index, - return_sparse=True, - ) - if len(tofs) > 0: - cum_int = np.cumsum(intensity_buffer) - low_threshold = cum_int[-1] * percentile / 100 / 2 - high_threshold = cum_int[-1] * (1 - (percentile / 100 / 2)) - low_index = np.searchsorted(cum_int, low_threshold) - high_index = np.searchsorted(cum_int, high_threshold, "right") - low_quad_estimate = dia_data.mz_values[tofs[low_index]] - high_quad_estimate = dia_data.mz_values[tofs[high_index]] +import numpy as np +from matplotlib import pyplot as plt + +import sklearn.base +from sklearn.linear_model import LinearRegression +from sklearn.preprocessing import PolynomialFeatures +from sklearn.pipeline import Pipeline + +class Calibration(): + def __init__(self, + name : str = '', + function : object = None, + input_columns : typing.List[str] = [], + target_columns : typing.List[str] = [], + output_columns : typing.List[str] = [], + transform_deviation : typing.Union[None, float] = None, + **kwargs): + """A single estimator for a property (mz, rt, etc.). + + Calibration is performed by modeling the deviation of an input values (e.g. mz_library) from an observed property (e.g. mz_observed) using a function (e.g. LinearRegression). Once calibrated, calibrated values (e.g. mz_calibrated) can be predicted from input values (e.g. mz_library). Additional explaining variabels can be added to the input values (e.g. rt_library) to improve the calibration. + + Parameters + ---------- + + name : str + Name of the estimator for logging and plotting e.g. 'mz' + + function : object + The estimator object instance which must have a fit and predict method. + This will usually be a sklearn estimator or a custom estimator. + + input_columns : list of str + The columns of the dataframe that are used as input for the estimator e.g. ['mz_library']. + The first column is the property which should be calibrated, additional columns can be used as explaining variables e.g. ['mz_library', 'rt_library']. + + target_columns : list of str + The columns of the dataframe that are used as target for the estimator e.g. ['mz_observed']. + At the moment only one target column is supported. + + output_columns : list of str + The columns of the dataframe that are used as output for the estimator e.g. ['mz_calibrated']. + At the moment only one output column is supported. + + transform_deviation : typing.List[Union[None, float]] + If set to a valid float, the deviation is expressed as a fraction of the input value e.g. 1e6 for ppm. + If set to None, the deviation is expressed in absolute units. + + """ + + self.name = name + self.function = function + self.input_columns = input_columns + self.target_columns = target_columns + self.output_columns = output_columns + self.transform_deviation = float(transform_deviation) if transform_deviation is not None else None + self.is_fitted = False + + def __repr__(self) -> str: + return f'' + + def save(self, file_name: str): + """Save the estimator to pickle file. + + Parameters + ---------- + + file_name : str + Path to the pickle file + + """ + + with open(file_name, 'wb') as f: + pickle.dump(self, f) + + def load(self, file_name: str): + """Load the estimator from pickle file. + + Parameters + ---------- + + file_name : str + Path to the pickle file + + """ + + with open(file_name, 'rb') as f: + loaded_calibration = pickle.load(f) + self.__dict__.update(loaded_calibration.__dict__) + + def validate_columns( + self, + dataframe : pd.DataFrame + ): + """Validate that the input and target columns are present in the dataframe. + + Parameters + ---------- + dataframe : pandas.DataFrame + Dataframe containing the input and target columns + + Returns + ------- + bool + True if all columns are present, False otherwise + + """ + + valid = True + + if len(self.target_columns) > 1 : + logging.warning('Only one target column supported') + valid = False + + required_columns = set(self.input_columns + self.target_columns) + if not required_columns.issubset(dataframe.columns): + logging.warning(f'{self.name}, at least one column {required_columns} not found in dataframe') + valid = False + + return valid + + def fit( + self, + dataframe : pd.DataFrame, + plot : bool = False, + **kwargs + ): + """Fit the estimator based on the input and target columns of the dataframe. + + Parameters + ---------- + + dataframe : pandas.DataFrame + Dataframe containing the input and target columns + + plot : bool, default=False + If True, a plot of the calibration is generated. + + Returns + ------- + + np.ndarray + Array of shape (n_input_columns, ) containing the mean absolute deviation of the residual deviation at the given confidence interval + + """ + + if not self.validate_columns(dataframe): + logging.warning(f'{self.name} calibration was skipped') + return + + if self.function is None: + raise ValueError('No estimator function provided') + + input_values = dataframe[self.input_columns].values + target_value = dataframe[self.target_columns].values + + try: + self.function.fit(input_values, target_value) + self.is_fitted = True + except Exception as e: + logging.error(f'Could not fit estimator {self.name}: {e}') + return + + if plot == True: + self.plot(dataframe, **kwargs) + + + def predict(self, dataframe, inplace=True): + """Perform a prediction based on the input columns of the dataframe. + + Parameters + ---------- + dataframe : pandas.DataFrame + Dataframe containing the input and target columns + + inplace : bool, default=True + If True, the prediction is added as a new column to the dataframe. If False, the prediction is returned as a numpy array. + + Returns + ------- + np.ndarray + Array of shape (n_samples, ) containing the prediction + + """ + + if self.is_fitted == False: + logging.warning(f'{self.name} prediction was skipped as it has not been fitted yet') + return + + if not set(self.input_columns).issubset(dataframe.columns): + logging.warning(f'{self.name} calibration was skipped as input column {self.input_columns} not found in dataframe') + return + + input_values = dataframe[self.input_columns].values + + if inplace: + dataframe[self.output_columns[0]] = self.function.predict(input_values) else: - low_quad_estimate, high_quad_estimate = -1, -1 - dia_mz_cycle[cyclic_push_index] = ( - low_quad_estimate, - high_quad_estimate - ) - weights[cyclic_push_index] = np.sum(intensity_buffer) - predicted_dia_mz_cycle = predict_dia_mz_cycle( - dia_mz_cycle, - dia_data, - weights, - ) - return dia_mz_cycle, predicted_dia_mz_cycle - - - -def predict_dia_mz_cycle( - dia_mz_cycle, - dia_data, - weights, -): - import sklearn.linear_model - df = pd.DataFrame( - { - "detected_lower": dia_mz_cycle[:, 0], - "detected_upper": dia_mz_cycle[:, 1], - "frame": np.arange(len(dia_mz_cycle)) // dia_data.scan_max_index, - "scan": np.arange(len(dia_mz_cycle)) % dia_data.scan_max_index, - "weights": weights, - } - ) - frame_reg_lower = {} - frame_reg_upper = {} - model = sklearn.linear_model.HuberRegressor - for frame in np.unique(df.frame): - if np.all(dia_data.dia_mz_cycle[df.frame == frame] == -1): - continue - selection = df[df.frame == frame] - frame_reg_lower[frame] = model().fit( - selection.scan.values.reshape(-1, 1), - selection.detected_lower.values.reshape(-1, 1), - selection.weights.values, - ) - frame_reg_upper[frame] = model().fit( - selection.scan.values.reshape(-1, 1), - selection.detected_upper.values.reshape(-1, 1), - selection.weights.values, - ) - predicted_upper = [] - predicted_lower = [] - for index, frame in enumerate(df.frame.values): - if frame not in frame_reg_upper: - predicted_upper.append(-1) - predicted_lower.append(-1) - continue - predicted_lower_ = frame_reg_lower[frame].predict( - df.scan.values[index: index + 1].reshape(-1, 1) - ) - predicted_upper_ = frame_reg_upper[frame].predict( - df.scan.values[index: index + 1].reshape(-1, 1) - ) - predicted_lower.append(predicted_lower_[0]) - predicted_upper.append(predicted_upper_[0]) - predicted_dia_mz_cycle = np.vstack( - [predicted_lower, predicted_upper] - ).T - return predicted_dia_mz_cycle + return self.function.predict(input_values) + + def fit_predict( + self, + dataframe : pd.DataFrame, + plot : bool = False, + inplace : bool = True + ): + """Fit the estimator and perform a prediction based on the input columns of the dataframe. + + Parameters + ---------- + + dataframe : pandas.DataFrame + Dataframe containing the input and target columns + + plot : bool, default=False + If True, a plot of the calibration is generated. + + inplace : bool, default=True + If True, the prediction is added as a new column to the dataframe. If False, the prediction is returned as a numpy array. + + """ + self.fit(dataframe, plot=plot) + return self.predict(dataframe, inplace=inplace) + + def deviation(self, dataframe : pd.DataFrame): + """ Calculate the deviations between the input, target and calibrated values. + + Parameters + ---------- + dataframe : pandas.DataFrame + Dataframe containing the input and target columns + + Returns + ------- + np.ndarray + Array of shape (n_samples, 3 + n_input_columns). + The second dimension contains the observed deviation, calibrated deviation, residual deviation and the input values. + + """ + + # the first column is the unclaibrated input property + # all other columns are explaining variables + input_values = dataframe[self.input_columns].values + + # the first column is the unclaibrated input property + uncalibrated_values = input_values[:, [0]] + + # only one target column is supported + target_values = dataframe[self.target_columns].values[:, [0]] + input_transform = self.transform_deviation + + calibrated_values = self.predict(dataframe, inplace=False) + if calibrated_values.ndim == 1: + calibrated_values = calibrated_values[:, np.newaxis] + + # only one output column is supported + calibrated_dim = calibrated_values[:, [0]] + + # deviation is the difference between the (observed) target value and the uncalibrated input value + observed_deviation = target_values - uncalibrated_values + if input_transform is not None: + observed_deviation = observed_deviation/uncalibrated_values * float(input_transform) + + # calibrated deviation is the explained difference between the (calibrated) target value and the uncalibrated input value + calibrated_deviation = calibrated_dim - uncalibrated_values + if input_transform is not None: + calibrated_deviation = calibrated_deviation/uncalibrated_values * float(input_transform) + + # residual deviation is the unexplained difference between the (observed) target value and the (calibrated) target value + residual_deviation = observed_deviation - calibrated_deviation + + return np.concatenate([observed_deviation, calibrated_deviation, residual_deviation, input_values], axis=1) + + def ci(self, dataframe, ci : float = 0.95): + """Calculate the residual deviation at the given confidence interval. + + Parameters + ---------- + + dataframe : pandas.DataFrame + Dataframe containing the input and target columns + + ci : float, default=0.95 + confidence interval + + Returns + ------- + + float + the confidence interval of the residual deviation after calibration + """ + + if not 0 < ci < 1: + raise ValueError('Confidence interval must be between 0 and 1') + + if not self.is_fitted: + return 0 + + ci_percentile = [100*(1-ci)/2, 100*(1+ci)/2] + + deviation = self.deviation(dataframe) + residual_deviation = deviation[:, 2] + return np.mean(np.abs(np.percentile(residual_deviation, ci_percentile))) + + def get_transform_unit( + self, + transform_deviation : typing.Union[None, float] + ): + + """Get the unit of the deviation based on the transform deviation. + + Parameters + ---------- + + transform_deviation : typing.Union[None, float] + If set to a valid float, the deviation is expressed as a fraction of the input value e.g. 1e6 for ppm. + + Returns + ------- + str + The unit of the deviation + + """ + if transform_deviation is not None: + if np.isclose(transform_deviation,1e6): + return '(ppm)' + elif np.isclose(transform_deviation,1e2): + return '(%)' + else: + return f'({transform_deviation})' + else: + return '(absolute)' + + + def plot( + self, + dataframe : pd.DataFrame, + figure_path : str = None, + #neptune_run : str = None, + #neptune_key :str = None, + **kwargs + ): + + """Plot the data and calibration model. + + Parameters + ---------- + + dataframe : pandas.DataFrame + Dataframe containing the input and target columns + + figure_path : str, default=None + If set, the figure is saved to the given path. + + neptune_run : str, default=None + If set, the figure is logged to the given neptune run. + + neptune_key : str, default=None + key under which the figure is logged to the neptune run. + + """ + + deviation = self.deviation(dataframe) + + n_input_properties = deviation.shape[1] - 3 + + transform_unit = self.get_transform_unit(self.transform_deviation) + + fig, axs = plt.subplots(n_input_properties, 2, figsize=(6.5, 3.5*n_input_properties), squeeze=False) + + for input_property in range(n_input_properties): + + # plot the relative observed deviation + density_scatter( + deviation[:, 3+input_property], + deviation[:, 0], + axis=axs[input_property, 0], + s=1 + ) + + # plot the calibration model + x_values = deviation[:, 3+input_property] + y_values = deviation[:, 1] + order = np.argsort(x_values) + x_values = x_values[order] + y_values = y_values[order] + + axs[input_property, 0].plot(x_values, y_values, color='red') + + # plot the calibrated deviation + + density_scatter( + deviation[:, 3+input_property], + deviation[:, 2], + axis=axs[input_property, 1], + s=1 + ) + + for ax, dim in zip(axs[input_property, :],[0,2]): + ax.set_xlabel(self.input_columns[input_property]) + ax.set_ylabel(f'observed deviation {transform_unit}') + + # get absolute y value and set limites to plus minus absolute y + y = deviation[:, dim] + y_abs = np.abs(y) + ax.set_ylim(-y_abs.max()*1.05, y_abs.max()*1.05) + + fig.tight_layout() + + # log figure to neptune ai + #if neptune_run is not None and neptune_key is not None: + # neptune_run[f'calibration/{neptune_key}'].log(fig) + + #if figure_path is not None: + + # i = 0 + # file_name = os.path.join(figure_path, f'calibration_{neptune_key}_{i}.png') + # while os.path.exists(file_name): + # file_name = os.path.join(figure_path, f'calibration_{neptune_key}_{i}.png') + # i += 1 + + # fig.savefig(file_name) + + plt.show() + + plt.close() + +class CalibrationManager(): + + def __init__( + self, + config : typing.Union[None, dict] = None, + path : typing.Union[None, str] = None, + load_calibration : bool = True): + + """Contains, updates and applies all calibrations for a single run. + + Calibrations are grouped into calibration groups. Each calibration group is applied to a single data structure (precursor dataframe, fragment fataframe, etc.). Each calibration group contains multiple estimators which each calibrate a single property (mz, rt, etc.). Each estimator is a `Calibration` object which contains the estimator function. + + Parameters + ---------- + + config : typing.Union[None, dict], default=None + Calibration config dict. If None, the default config is used. + + path : str, default=None + Path where the current parameter set is saved to and loaded from. + + load_calibration : bool, default=True + If True, the calibration manager is loaded from the given path. + + """ + self._is_loaded_from_file = False + self.estimator_groups = [] + self.path = path + + logging.info('========= Initializing Calibration Manager =========') + + self.load_config(config) + if load_calibration: + self.load() + + logging.info('====================================================') + + @property + def is_loaded_from_file(self): + """Check if the calibration manager was loaded from file. + """ + return self._is_loaded_from_file + + @property + def is_fitted(self): + """Check if all estimators in all calibration groups are fitted. + """ + + is_fitted = True + for group in self.estimator_groups: + for estimator in group['estimators']: + if not estimator.is_fitted: + is_fitted = False + break + + return is_fitted and len(self.estimator_groups) > 0 + + def load_config(self, config : dict): + """Load calibration config from config Dict. + + each calibration config is a list of calibration groups which consist of multiple estimators. + For each estimator the `model` and `model_args` are used to request a model from the calibration_model_provider and to initialize it. + The estimator is then initialized with the `Calibration` class and added to the group. + + Parameters + ---------- + + config : dict + Calibration config dict + + Example + ------- + + Create a calibration manager with a single group and a single estimator: + + .. code-block:: python + + calibration_manager = calibration.CalibrationManager() + calibration_manager.load_config([{ + 'name': 'mz_calibration', + 'estimators': [ + { + 'name': 'mz', + 'model': 'LOESSRegression', + 'model_args': { + 'n_kernels': 2 + }, + 'input_columns': ['mz_library'], + 'target_columns': ['mz_observed'], + 'output_columns': ['mz_calibrated'], + 'transform_deviation': 1e6 + }, + + ] + }]) + + """ + + logging.info('loading calibration config') + logging.info(f'found {len(config)} calibration groups') + for group in config: + logging.info(f'Calibration group :{group["name"]}, found {len(group["estimators"])} estimator(s)') + for estimator in group['estimators']: + try: + template = calibration_model_provider.get_model(estimator['model']) + model_args = estimator['model_args'] if 'model_args' in estimator else {} + estimator['function'] = template(**model_args) + except Exception as e: + logging.error(f'Could not load estimator {estimator["name"]}: {e}') + + group_copy = {'name': group['name']} + group_copy['estimators'] = [Calibration(**x) for x in group['estimators']] + self.estimator_groups.append(group_copy) + + def save(self): + """Save the calibration manager state to pickle file. + """ + if self.path is not None: + with open(self.path, 'wb') as f: + pickle.dump(self, f) + + def load(self): + """Load the calibration manager from pickle file. + """ + if self.path is not None and os.path.exists(self.path): + try: + with open(self.path, 'rb') as f: + loaded_state = pickle.load(f) + self.__dict__.update(loaded_state.__dict__) + self._is_loaded_from_file = True + except: + logging.warning(f'Could not load calibration manager from {self.path}') + else: + logging.info(f'Loaded calibration manager from {self.path}') + else: + logging.warning(f'Calibration manager path {self.path} does not exist') + + def get_group_names(self): + """Get the names of all calibration groups. + + Returns + ------- + list of str + List of calibration group names + """ + + return [x['name'] for x in self.estimator_groups] + + def get_group(self, group_name : str): + """Get the calibration group by name. + + Parameters + ---------- + + group_name : str + Name of the calibration group + + Returns + ------- + dict + Calibration group dict with `name` and `estimators` keys\ + + """ + for group in self.estimator_groups: + if group['name'] == group_name: + return group + + logging.error(f'could not get_group: {group_name}') + return None + + def get_estimator_names(self, group_name : str): + """Get the names of all estimators in a calibration group. + + Parameters + ---------- + + group_name : str + Name of the calibration group + + Returns + ------- + list of str + List of estimator names + """ + + group = self.get_group(group_name) + if group is not None: + return [x.name for x in group['estimators']] + logging.error(f'could not get_estimator_names: {group_name}') + return None + + def get_estimator(self, group_name : str, estimator_name : str): + + """Get an estimator from a calibration group. + + Parameters + ---------- + + group_name : str + Name of the calibration group + + estimator_name : str + Name of the estimator + + Returns + ------- + Calibration + The estimator object + + """ + group = self.get_group(group_name) + if group is not None: + for estimator in group['estimators']: + if estimator.name == estimator_name: + return estimator + logging.error(f'could not get_estimator: {group_name}, {estimator_name}') + return None + + def fit( + self, + df : pd.DataFrame, + group_name : str, + *args, + **kwargs + ): + """Fit all estimators in a calibration group. + + Parameters + ---------- + + df : pandas.DataFrame + Dataframe containing the input and target columns + + group_name : str + Name of the calibration group + + """ + + if len(self.estimator_groups) == 0: + raise ValueError('No estimators defined') + + group_idx = [i for i, x in enumerate(self.estimator_groups) if x['name'] == group_name] + if len(group_idx) == 0: + raise ValueError(f'No group named {group_name} found') + for group in group_idx: + for estimator in self.estimator_groups[group]['estimators']: + logging.info(f'calibration group: {group_name}, fitting {estimator.name} estimator ') + estimator.fit(df, *args, neptune_key=f'{group_name}_{estimator.name}', **kwargs) + + def predict( + self, + df : pd.DataFrame, + group_name : str, + *args, + **kwargs): + + """Predict all estimators in a calibration group. + + Parameters + ---------- + + df : pandas.DataFrame + Dataframe containing the input and target columns + + group_name : str + Name of the calibration group + + """ + + if len(self.estimator_groups) == 0: + raise ValueError('No estimators defined') + + group_idx = [i for i, x in enumerate(self.estimator_groups) if x['name'] == group_name] + if len(group_idx) == 0: + raise ValueError(f'No group named {group_name} found') + for group in group_idx: + for estimator in self.estimator_groups[group]['estimators']: + logging.info(f'calibration group: {group_name}, predicting {estimator.name}') + estimator.predict(df, inplace=True, *args, **kwargs) + + def fit_predict( + self, + df : pd.DataFrame, + group_name : str, + plot : bool = True, + ): + """Fit and predict all estimators in a calibration group. + + Parameters + ---------- + + df : pandas.DataFrame + Dataframe containing the input and target columns + + group_name : str + Name of the calibration group + + plot : bool, default=True + If True, a plot of the calibration is generated. + + """ + self.fit(df, group_name, plot=plot) + self.predict(df, group_name) + +class CalibrationModelProvider: + def __init__(self): + + """Provides a collection of scikit-learn compatible models for calibration. + """ + self.model_dict = {} + + def __repr__(self) -> str: + string = ' str: - return f'' - - def save(self, file_name: str): - """Save the estimator to pickle file. - - Parameters - ---------- - - file_name : str - Path to the pickle file - - """ - - with open(file_name, 'wb') as f: - pickle.dump(self, f) - - def load(self, file_name: str): - """Load the estimator from pickle file. - - Parameters - ---------- - - file_name : str - Path to the pickle file - - """ - - with open(file_name, 'rb') as f: - loaded_calibration = pickle.load(f) - self.__dict__.update(loaded_calibration.__dict__) - - def validate_columns( - self, - dataframe : pd.DataFrame - ): - """Validate that the input and target columns are present in the dataframe. - - Parameters - ---------- - dataframe : pandas.DataFrame - Dataframe containing the input and target columns - - Returns - ------- - bool - True if all columns are present, False otherwise - - """ - - valid = True - - if len(self.target_columns) > 1 : - logging.warning('Only one target column supported') - valid = False - - required_columns = set(self.input_columns + self.target_columns) - if not required_columns.issubset(dataframe.columns): - logging.warning(f'{self.name}, at least one column {required_columns} not found in dataframe') - valid = False - - return valid - - def fit( - self, - dataframe : pd.DataFrame, - plot : bool = False, - **kwargs - ): - """Fit the estimator based on the input and target columns of the dataframe. - - Parameters - ---------- - - dataframe : pandas.DataFrame - Dataframe containing the input and target columns - - plot : bool, default=False - If True, a plot of the calibration is generated. - - Returns - ------- - - np.ndarray - Array of shape (n_input_columns, ) containing the mean absolute deviation of the residual deviation at the given confidence interval - - """ - - if not self.validate_columns(dataframe): - logging.warning(f'{self.name} calibration was skipped') - return - - if self.function is None: - raise ValueError('No estimator function provided') - - input_values = dataframe[self.input_columns].values - target_value = dataframe[self.target_columns].values - - try: - self.function.fit(input_values, target_value) - self.is_fitted = True - except Exception as e: - logging.error(f'Could not fit estimator {self.name}: {e}') - return - - if plot == True: - self.plot(dataframe, **kwargs) - - - def predict(self, dataframe, inplace=True): - """Perform a prediction based on the input columns of the dataframe. - - Parameters - ---------- - dataframe : pandas.DataFrame - Dataframe containing the input and target columns - - inplace : bool, default=True - If True, the prediction is added as a new column to the dataframe. If False, the prediction is returned as a numpy array. - - Returns - ------- - np.ndarray - Array of shape (n_samples, ) containing the prediction - - """ - - if self.is_fitted == False: - logging.warning(f'{self.name} prediction was skipped as it has not been fitted yet') - return - - if not set(self.input_columns).issubset(dataframe.columns): - logging.warning(f'{self.name} calibration was skipped as input column {self.input_columns} not found in dataframe') - return - - input_values = dataframe[self.input_columns].values - - if inplace: - dataframe[self.output_columns[0]] = self.function.predict(input_values) - else: - return self.function.predict(input_values) - - def fit_predict( - self, - dataframe : pd.DataFrame, - plot : bool = False, - inplace : bool = True - ): - """Fit the estimator and perform a prediction based on the input columns of the dataframe. - - Parameters - ---------- - - dataframe : pandas.DataFrame - Dataframe containing the input and target columns - - plot : bool, default=False - If True, a plot of the calibration is generated. - - inplace : bool, default=True - If True, the prediction is added as a new column to the dataframe. If False, the prediction is returned as a numpy array. - - """ - self.fit(dataframe, plot=plot) - return self.predict(dataframe, inplace=inplace) - - def deviation(self, dataframe : pd.DataFrame): - """ Calculate the deviations between the input, target and calibrated values. - - Parameters - ---------- - dataframe : pandas.DataFrame - Dataframe containing the input and target columns - - Returns - ------- - np.ndarray - Array of shape (n_samples, 3 + n_input_columns). - The second dimension contains the observed deviation, calibrated deviation, residual deviation and the input values. - - """ - - # the first column is the unclaibrated input property - # all other columns are explaining variables - input_values = dataframe[self.input_columns].values - - # the first column is the unclaibrated input property - uncalibrated_values = input_values[:, [0]] - - # only one target column is supported - target_values = dataframe[self.target_columns].values[:, [0]] - input_transform = self.transform_deviation - - calibrated_values = self.predict(dataframe, inplace=False) - if calibrated_values.ndim == 1: - calibrated_values = calibrated_values[:, np.newaxis] - - # only one output column is supported - calibrated_dim = calibrated_values[:, [0]] - - # deviation is the difference between the (observed) target value and the uncalibrated input value - observed_deviation = target_values - uncalibrated_values - if input_transform is not None: - observed_deviation = observed_deviation/uncalibrated_values * float(input_transform) - - # calibrated deviation is the explained difference between the (calibrated) target value and the uncalibrated input value - calibrated_deviation = calibrated_dim - uncalibrated_values - if input_transform is not None: - calibrated_deviation = calibrated_deviation/uncalibrated_values * float(input_transform) - - # residual deviation is the unexplained difference between the (observed) target value and the (calibrated) target value - residual_deviation = observed_deviation - calibrated_deviation - - return np.concatenate([observed_deviation, calibrated_deviation, residual_deviation, input_values], axis=1) - - def ci(self, dataframe, ci : float = 0.95): - """Calculate the residual deviation at the given confidence interval. - - Parameters - ---------- - - dataframe : pandas.DataFrame - Dataframe containing the input and target columns - - ci : float, default=0.95 - confidence interval - - Returns - ------- - - float - the confidence interval of the residual deviation after calibration - """ - - if not 0 < ci < 1: - raise ValueError('Confidence interval must be between 0 and 1') - - if not self.is_fitted: - return 0 - - ci_percentile = [100*(1-ci)/2, 100*(1+ci)/2] - - deviation = self.deviation(dataframe) - residual_deviation = deviation[:, 2] - return np.mean(np.abs(np.percentile(residual_deviation, ci_percentile))) - - def get_transform_unit( - self, - transform_deviation : typing.Union[None, float] - ): - - """Get the unit of the deviation based on the transform deviation. - - Parameters - ---------- - - transform_deviation : typing.Union[None, float] - If set to a valid float, the deviation is expressed as a fraction of the input value e.g. 1e6 for ppm. - - Returns - ------- - str - The unit of the deviation - - """ - if transform_deviation is not None: - if np.isclose(transform_deviation,1e6): - return '(ppm)' - elif np.isclose(transform_deviation,1e2): - return '(%)' - else: - return f'({transform_deviation})' - else: - return '(absolute)' - - - def plot( - self, - dataframe : pd.DataFrame, - figure_path : str = None, - #neptune_run : str = None, - #neptune_key :str = None, - **kwargs - ): - - """Plot the data and calibration model. - - Parameters - ---------- - - dataframe : pandas.DataFrame - Dataframe containing the input and target columns - - figure_path : str, default=None - If set, the figure is saved to the given path. - - neptune_run : str, default=None - If set, the figure is logged to the given neptune run. - - neptune_key : str, default=None - key under which the figure is logged to the neptune run. - - """ - - deviation = self.deviation(dataframe) - - n_input_properties = deviation.shape[1] - 3 - - transform_unit = self.get_transform_unit(self.transform_deviation) - - fig, axs = plt.subplots(n_input_properties, 2, figsize=(6.5, 3.5*n_input_properties), squeeze=False) - - for input_property in range(n_input_properties): - - # plot the relative observed deviation - density_scatter( - deviation[:, 3+input_property], - deviation[:, 0], - axis=axs[input_property, 0], - s=1 - ) - - # plot the calibration model - x_values = deviation[:, 3+input_property] - y_values = deviation[:, 1] - order = np.argsort(x_values) - x_values = x_values[order] - y_values = y_values[order] - - axs[input_property, 0].plot(x_values, y_values, color='red') - - # plot the calibrated deviation - - density_scatter( - deviation[:, 3+input_property], - deviation[:, 2], - axis=axs[input_property, 1], - s=1 - ) - - for ax, dim in zip(axs[input_property, :],[0,2]): - ax.set_xlabel(self.input_columns[input_property]) - ax.set_ylabel(f'observed deviation {transform_unit}') - - # get absolute y value and set limites to plus minus absolute y - y = deviation[:, dim] - y_abs = np.abs(y) - ax.set_ylim(-y_abs.max()*1.05, y_abs.max()*1.05) - - fig.tight_layout() - - # log figure to neptune ai - #if neptune_run is not None and neptune_key is not None: - # neptune_run[f'calibration/{neptune_key}'].log(fig) - - #if figure_path is not None: - - # i = 0 - # file_name = os.path.join(figure_path, f'calibration_{neptune_key}_{i}.png') - # while os.path.exists(file_name): - # file_name = os.path.join(figure_path, f'calibration_{neptune_key}_{i}.png') - # i += 1 - - # fig.savefig(file_name) - - plt.show() - - plt.close() - -class CalibrationManager(): - - def __init__( - self, - config : typing.Union[None, dict] = None, - path : typing.Union[None, str] = None, - load_calibration : bool = True): - - """Contains, updates and applies all calibrations for a single run. - - Calibrations are grouped into calibration groups. Each calibration group is applied to a single data structure (precursor dataframe, fragment fataframe, etc.). Each calibration group contains multiple estimators which each calibrate a single property (mz, rt, etc.). Each estimator is a `Calibration` object which contains the estimator function. - - Parameters - ---------- - - config : typing.Union[None, dict], default=None - Calibration config dict. If None, the default config is used. - - path : str, default=None - Path where the current parameter set is saved to and loaded from. - - load_calibration : bool, default=True - If True, the calibration manager is loaded from the given path. - - """ - self._is_loaded_from_file = False - self.estimator_groups = [] - self.path = path - - logging.info('========= Initializing Calibration Manager =========') - - self.load_config(config) - if load_calibration: - self.load() - - logging.info('====================================================') - - @property - def is_loaded_from_file(self): - """Check if the calibration manager was loaded from file. - """ - return self._is_loaded_from_file - - @property - def is_fitted(self): - """Check if all estimators in all calibration groups are fitted. - """ - - is_fitted = True - for group in self.estimator_groups: - for estimator in group['estimators']: - if not estimator.is_fitted: - is_fitted = False - break - - return is_fitted and len(self.estimator_groups) > 0 - - def load_config(self, config : dict): - """Load calibration config from config Dict. - - each calibration config is a list of calibration groups which consist of multiple estimators. - For each estimator the `model` and `model_args` are used to request a model from the calibration_model_provider and to initialize it. - The estimator is then initialized with the `Calibration` class and added to the group. - - Parameters - ---------- - - config : dict - Calibration config dict - - Example - ------- - - Create a calibration manager with a single group and a single estimator: - - .. code-block:: python - - calibration_manager = calibration.CalibrationManager() - calibration_manager.load_config([{ - 'name': 'mz_calibration', - 'estimators': [ - { - 'name': 'mz', - 'model': 'LOESSRegression', - 'model_args': { - 'n_kernels': 2 - }, - 'input_columns': ['mz_library'], - 'target_columns': ['mz_observed'], - 'output_columns': ['mz_calibrated'], - 'transform_deviation': 1e6 - }, - - ] - }]) - - """ - - logging.info('loading calibration config') - logging.info(f'found {len(config)} calibration groups') - for group in config: - logging.info(f'Calibration group :{group["name"]}, found {len(group["estimators"])} estimator(s)') - for estimator in group['estimators']: - try: - template = calibration_model_provider.get_model(estimator['model']) - model_args = estimator['model_args'] if 'model_args' in estimator else {} - estimator['function'] = template(**model_args) - except Exception as e: - logging.error(f'Could not load estimator {estimator["name"]}: {e}') - - group_copy = {'name': group['name']} - group_copy['estimators'] = [Calibration(**x) for x in group['estimators']] - self.estimator_groups.append(group_copy) - - def save(self): - """Save the calibration manager state to pickle file. - """ - if self.path is not None: - with open(self.path, 'wb') as f: - pickle.dump(self, f) - - def load(self): - """Load the calibration manager from pickle file. - """ - if self.path is not None and os.path.exists(self.path): - try: - with open(self.path, 'rb') as f: - loaded_state = pickle.load(f) - self.__dict__.update(loaded_state.__dict__) - self._is_loaded_from_file = True - except: - logging.warning(f'Could not load calibration manager from {self.path}') - else: - logging.info(f'Loaded calibration manager from {self.path}') - else: - logging.warning(f'Calibration manager path {self.path} does not exist') - - def get_group_names(self): - """Get the names of all calibration groups. - - Returns - ------- - list of str - List of calibration group names - """ - - return [x['name'] for x in self.estimator_groups] - - def get_group(self, group_name : str): - """Get the calibration group by name. - - Parameters - ---------- - - group_name : str - Name of the calibration group - - Returns - ------- - dict - Calibration group dict with `name` and `estimators` keys\ - - """ - for group in self.estimator_groups: - if group['name'] == group_name: - return group - - logging.error(f'could not get_group: {group_name}') - return None - - def get_estimator_names(self, group_name : str): - """Get the names of all estimators in a calibration group. - - Parameters - ---------- - - group_name : str - Name of the calibration group - - Returns - ------- - list of str - List of estimator names - """ - - group = self.get_group(group_name) - if group is not None: - return [x.name for x in group['estimators']] - logging.error(f'could not get_estimator_names: {group_name}') - return None - - def get_estimator(self, group_name : str, estimator_name : str): - - """Get an estimator from a calibration group. - - Parameters - ---------- - - group_name : str - Name of the calibration group - - estimator_name : str - Name of the estimator - - Returns - ------- - Calibration - The estimator object - - """ - group = self.get_group(group_name) - if group is not None: - for estimator in group['estimators']: - if estimator.name == estimator_name: - return estimator - logging.error(f'could not get_estimator: {group_name}, {estimator_name}') - return None - - def fit( - self, - df : pd.DataFrame, - group_name : str, - *args, - **kwargs - ): - """Fit all estimators in a calibration group. - - Parameters - ---------- - - df : pandas.DataFrame - Dataframe containing the input and target columns - - group_name : str - Name of the calibration group - - """ - - if len(self.estimator_groups) == 0: - raise ValueError('No estimators defined') - - group_idx = [i for i, x in enumerate(self.estimator_groups) if x['name'] == group_name] - if len(group_idx) == 0: - raise ValueError(f'No group named {group_name} found') - for group in group_idx: - for estimator in self.estimator_groups[group]['estimators']: - logging.info(f'calibration group: {group_name}, fitting {estimator.name} estimator ') - estimator.fit(df, *args, neptune_key=f'{group_name}_{estimator.name}', **kwargs) - - def predict( - self, - df : pd.DataFrame, - group_name : str, - *args, - **kwargs): - - """Predict all estimators in a calibration group. - - Parameters - ---------- - - df : pandas.DataFrame - Dataframe containing the input and target columns - - group_name : str - Name of the calibration group - - """ - - if len(self.estimator_groups) == 0: - raise ValueError('No estimators defined') - - group_idx = [i for i, x in enumerate(self.estimator_groups) if x['name'] == group_name] - if len(group_idx) == 0: - raise ValueError(f'No group named {group_name} found') - for group in group_idx: - for estimator in self.estimator_groups[group]['estimators']: - logging.info(f'calibration group: {group_name}, predicting {estimator.name}') - estimator.predict(df, inplace=True, *args, **kwargs) - - def fit_predict( - self, - df : pd.DataFrame, - group_name : str, - plot : bool = True, - ): - """Fit and predict all estimators in a calibration group. - - Parameters - ---------- - - df : pandas.DataFrame - Dataframe containing the input and target columns - - group_name : str - Name of the calibration group - - plot : bool, default=True - If True, a plot of the calibration is generated. - - """ - self.fit(df, group_name, plot=plot) - self.predict(df, group_name) - -class CalibrationModelProvider: - def __init__(self): - - """Provides a collection of scikit-learn compatible models for calibration. - """ - self.model_dict = {} - - def __repr__(self) -> str: - string = '