diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py
index f42b08c3..48188cd7 100644
--- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py
+++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py
@@ -1,53 +1,29 @@
 import os
 from collections import OrderedDict
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Tuple

 import pandas as pd
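+# MultilabelStratifiedShuffleSplit ships with the "iterative-stratification"
+# package on PyPI, which is imported under the module name iterstrat.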
""" - valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys()) + valid_go_branches = list(DeepGO1MigratedData.GO_BRANCH_MAPPING.keys()) if go_branch not in valid_go_branches: raise ValueError(f"go_branch must be one of {valid_go_branches}") self._go_branch = go_branch @@ -69,34 +45,60 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): self._terms_df: Optional[pd.DataFrame] = None self._classes: Optional[List[str]] = None + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + data_with_labels_df = self._extract_required_data_from_splits() + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_with_labels_df, splits_df) + def _load_data(self) -> None: """ Loads the test, train, validation, and terms data from the pickled files in the data directory. """ try: - print(f"Loading data from {self._data_dir}......") + print(f"Loading data files from directory: {self._data_dir}") self._test_df = pd.DataFrame( pd.read_pickle( os.path.join(self._data_dir, f"test-{self._go_branch}.pkl") ) ) - self._train_df = pd.DataFrame( + + # DeepGO 1 lacks a validation split, so we will create one by further splitting the training set. + # Although this reduces the training data slightly compared to the original DeepGO setup, + # given the data size, the impact should be minimal. + train_df = pd.DataFrame( pd.read_pickle( os.path.join(self._data_dir, f"train-{self._go_branch}.pkl") ) ) - # self._validation_df = pd.DataFrame( - # pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl")) - # ) - - # DeepGO1 data does not include a separate validation split, but our data structure requires one. - # To accommodate this, we will create a placeholder validation split by duplicating a small subset of the - # training data. However, to ensure a fair comparison with DeepGO1, we will retain the full training set - # without creating an exclusive validation split from it. - # Therefore, any metrics calculated on this placeholder validation set should be disregarded, as they do not - # reflect true validation performance. - self._validation_df = self._train_df[len(self._train_df) - 5 :] + + self._train_df, self._validation_df = self._get_train_val_split(train_df) + self._terms_df = pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl")) ) @@ -104,6 +106,35 @@ def _load_data(self) -> None: except FileNotFoundError as e: print(f"Error loading data: {e}") + @staticmethod + def _get_train_val_split( + train_df: pd.DataFrame, + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Splits the training data into a smaller training set and a validation set. + + Args: + train_df (pd.DataFrame): Original training DataFrame. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: Training and validation DataFrames. 
+ """ + labels_list_train = train_df["labels"].tolist() + train_split = 0.85 + test_size = ((1 - train_split) ** 2) / train_split + + splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=42 + ) + + train_indices, validation_indices = next( + splitter.split(labels_list_train, labels_list_train) + ) + + df_validation = train_df.iloc[validation_indices] + df_train = train_df.iloc[train_indices] + return df_train, df_validation + def _record_splits(self) -> pd.DataFrame: """ Creates a DataFrame that stores the IDs and their corresponding data splits. @@ -111,7 +142,7 @@ def _record_splits(self) -> pd.DataFrame: Returns: pd.DataFrame: A combined DataFrame containing split assignments. """ - print("Recording splits...") + print("Recording data splits for train, validation, and test sets.") split_assignment_list: List[pd.DataFrame] = [ pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), pd.DataFrame( @@ -123,37 +154,6 @@ def _record_splits(self) -> pd.DataFrame: combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) return combined_split_assignment - def migrate(self) -> None: - """ - Executes the data migration by loading, processing, and saving the data. - """ - print("Migration started......") - self._load_data() - if not all( - df is not None - for df in [ - self._train_df, - self._validation_df, - self._test_df, - self._terms_df, - ] - ): - raise Exception( - "Data splits or terms data is not available in instance variables." - ) - splits_df = self._record_splits() - - data_with_labels_df = self._extract_required_data_from_splits() - - if not all( - var is not None for var in [data_with_labels_df, splits_df, self._classes] - ): - raise Exception( - "Data splits or terms data is not available in instance variables." - ) - - self.save_migrated_data(data_with_labels_df, splits_df) - def _extract_required_data_from_splits(self) -> pd.DataFrame: """ Extracts required columns from the combined data splits. @@ -161,12 +161,11 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: Returns: pd.DataFrame: A DataFrame containing the essential columns for processing. """ - print("Combining the data splits with required data..... ") + print("Combining data splits into a single DataFrame with required columns.") required_columns = [ "proteins", "accessions", "sequences", - # Note: The GO classes here only directly related one, and not transitive GO classes "gos", "labels", ] @@ -183,7 +182,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: lambda row: self.extract_go_id(row["gos"]), axis=1 ) - labels_df = self._get_labels_colums(new_df) + labels_df = self._get_labels_columns(new_df) data_df = pd.DataFrame( OrderedDict( @@ -198,28 +197,32 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: return df - def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame: + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: """ - Generates a DataFrame with one-hot encoded columns for each GO term label, - based on the terms provided in `self._terms_df` and the existing labels in `data_df`. + Extracts and parses GO IDs from a list of GO annotations. - This method extracts GO IDs from the `functions` column of `self._terms_df`, - creating a list of all unique GO IDs. It then uses this list to create new - columns in the returned DataFrame, where each row has binary values - (0 or 1) indicating the presence of each GO ID in the corresponding entry of - `data_df['labels']`. 
+        Args:
+            go_list (List[str]): List of GO annotation strings.
+
+        Returns:
+            List[int]: List of parsed GO IDs.
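+
+        Example (illustrative only; assumes _parse_go_id maps an annotation string
+        such as "GO:0005575" to the integer 5575):
+            extract_go_id(["GO:0005575", "GO:0008150"]) -> [5575, 8150]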
""" - print("Saving transformed data......") - go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ - self._go_branch - ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN) + print("Saving transformed data files.") - go_class_instance.save_processed( - data_df, go_class_instance.processed_main_file_names_dict["data"] + deepgo_migr_inst: DeepGO1MigratedData = DeepGO1MigratedData( + go_branch=DeepGO1MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._MAXLEN, + ) + + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] ) print( - f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" ) + # Save splits file splits_df.to_csv( - os.path.join(go_class_instance.processed_dir_main, "splits.csv"), + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go1.csv"), index=False, ) - print(f"splits.csv saved to {go_class_instance.processed_dir_main}") + print(f"splits_deep_go1.csv saved to {deepgo_migr_inst.processed_dir_main}") + # Save classes file classes = sorted(self._classes) with open( - os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt" + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go1.txt"), + "wt", ) as fout: fout.writelines(str(node) + "\n" for node in classes) - print(f"classes.txt saved to {go_class_instance.processed_dir_main}") - print("Migration completed!") + print(f"classes_deep_go1.txt saved to {deepgo_migr_inst.processed_dir_main}") + + print("Migration process completed!") class Main: