deepgo1: further split train set into train and val for
- migration structure changes
aditya0by0 committed Nov 13, 2024
1 parent a15d492 commit e0a8524
Showing 1 changed file with 118 additions and 123 deletions.
241 changes: 118 additions & 123 deletions chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py
@@ -1,63 +1,39 @@
import os
from collections import OrderedDict
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Tuple

import pandas as pd
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from jsonargparse import CLI

from chebai.preprocessing.datasets.go_uniprot import (
GOUniProtOver50,
GOUniProtOver250,
_GOUniProtDataExtractor,
)
from chebai.preprocessing.datasets.go_uniprot import DeepGO1MigratedData


class DeepGo1DataMigration:
"""
A class to handle data migration and processing for the DeepGO project.
It migrates the deepGO data to our data structure followed for GO-UniProt data.
It migrates the DeepGO data to our data structure followed for GO-UniProt data.
It migrates the data of DeepGO model of the below research paper:
This class handles data from the DeepGO model as described in:
Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf,
DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier,
Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668
(https://doi.org/10.1093/bioinformatics/btx624),
Attributes:
_CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes.
_MAXLEN (int): Maximum sequence length for sequences.
_LABELS_START_IDX (int): Starting index for labels in the dataset.
Methods:
__init__(data_dir, go_branch): Initializes the data directory and GO branch.
_load_data(): Loads train, validation, test, and terms data from the specified directory.
_record_splits(): Creates a DataFrame with IDs and their corresponding split.
migrate(): Executes the migration process including data loading, processing, and saving.
_extract_required_data_from_splits(): Extracts required columns from the splits data.
_get_labels_columns(data_df): Generates label columns for the data based on GO terms.
extract_go_id(go_list): Extracts GO IDs from a list.
save_migrated_data(data_df, splits_df): Saves the processed data and splits.
(https://doi.org/10.1093/bioinformatics/btx624).
"""

# Number of annotations for each go_branch as per the research paper
_CORRESPONDING_GO_CLASSES = {
"cc": GOUniProtOver50,
"mf": GOUniProtOver50,
"bp": GOUniProtOver250,
}

# Max sequence length as per DeepGO1
_MAXLEN = 1002
_LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX
_LABELS_START_IDX = DeepGO1MigratedData._LABELS_START_IDX

def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
"""
Initializes the data migration object with a data directory and GO branch.
Args:
data_dir (str): Directory containing the data files.
go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process).
go_branch (Literal["cc", "mf", "bp"]): GO branch to use.
"""
valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys())
valid_go_branches = list(DeepGO1MigratedData.GO_BRANCH_MAPPING.keys())
if go_branch not in valid_go_branches:
raise ValueError(f"go_branch must be one of {valid_go_branches}")
self._go_branch = go_branch
@@ -69,49 +45,104 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
self._terms_df: Optional[pd.DataFrame] = None
self._classes: Optional[List[str]] = None

def migrate(self) -> None:
"""
Executes the data migration by loading, processing, and saving the data.
"""
print("Starting the migration process...")
self._load_data()
if not all(
df is not None
for df in [
self._train_df,
self._validation_df,
self._test_df,
self._terms_df,
]
):
raise Exception(
"Data splits or terms data is not available in instance variables."
)
splits_df = self._record_splits()
data_with_labels_df = self._extract_required_data_from_splits()

if not all(
var is not None for var in [data_with_labels_df, splits_df, self._classes]
):
raise Exception(
"Data splits or terms data is not available in instance variables."
)

self.save_migrated_data(data_with_labels_df, splits_df)

def _load_data(self) -> None:
"""
Loads the test, train, validation, and terms data from the pickled files
in the data directory.
"""
try:
print(f"Loading data from {self._data_dir}......")
print(f"Loading data files from directory: {self._data_dir}")
self._test_df = pd.DataFrame(
pd.read_pickle(
os.path.join(self._data_dir, f"test-{self._go_branch}.pkl")
)
)
self._train_df = pd.DataFrame(

# DeepGO 1 lacks a validation split, so we will create one by further splitting the training set.
# Although this reduces the training data slightly compared to the original DeepGO setup,
# given the data size, the impact should be minimal.
train_df = pd.DataFrame(
pd.read_pickle(
os.path.join(self._data_dir, f"train-{self._go_branch}.pkl")
)
)
# self._validation_df = pd.DataFrame(
# pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl"))
# )

# DeepGO1 data does not include a separate validation split, but our data structure requires one.
# To accommodate this, we will create a placeholder validation split by duplicating a small subset of the
# training data. However, to ensure a fair comparison with DeepGO1, we will retain the full training set
# without creating an exclusive validation split from it.
# Therefore, any metrics calculated on this placeholder validation set should be disregarded, as they do not
# reflect true validation performance.
self._validation_df = self._train_df[len(self._train_df) - 5 :]

self._train_df, self._validation_df = self._get_train_val_split(train_df)

self._terms_df = pd.DataFrame(
pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl"))
)

except FileNotFoundError as e:
print(f"Error loading data: {e}")

@staticmethod
def _get_train_val_split(
train_df: pd.DataFrame,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""
Splits the training data into a smaller training set and a validation set.
Args:
train_df (pd.DataFrame): Original training DataFrame.
Returns:
Tuple[pd.DataFrame, pd.DataFrame]: Training and validation DataFrames.
"""
labels_list_train = train_df["labels"].tolist()
train_split = 0.85
test_size = ((1 - train_split) ** 2) / train_split
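# With train_split = 0.85, test_size = 0.15 ** 2 / 0.85 ≈ 0.0265, so roughly
# 2.6% of the DeepGO training data is held out as the validation set.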

splitter = MultilabelStratifiedShuffleSplit(
n_splits=1, test_size=test_size, random_state=42
)

train_indices, validation_indices = next(
splitter.split(labels_list_train, labels_list_train)
)

df_validation = train_df.iloc[validation_indices]
df_train = train_df.iloc[train_indices]
return df_train, df_validation
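# A minimal sketch of what the splitter above does on a toy multilabel matrix
# (`toy_labels` and the sizes below are illustrative, not part of this module):
#
#     import numpy as np
#     from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
#
#     toy_labels = np.array(
#         [[1, 0], [0, 1], [1, 1], [1, 0], [0, 1], [1, 1], [1, 0], [0, 1]]
#     )
#     splitter = MultilabelStratifiedShuffleSplit(
#         n_splits=1, test_size=0.25, random_state=42
#     )
#     train_idx, val_idx = next(splitter.split(toy_labels, toy_labels))
#     # Unlike a plain random split, each label retains approximately the same
#     # relative frequency in both subsets.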

def _record_splits(self) -> pd.DataFrame:
"""
Creates a DataFrame that stores the IDs and their corresponding data splits.
Returns:
pd.DataFrame: A combined DataFrame containing split assignments.
"""
print("Recording splits...")
print("Recording data splits for train, validation, and test sets.")
split_assignment_list: List[pd.DataFrame] = [
pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}),
pd.DataFrame(
@@ -123,50 +154,18 @@ def _record_splits(self) -> pd.DataFrame:
combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
return combined_split_assignment
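# The combined frame maps each protein id to exactly one split, e.g.
# (ids illustrative):
#
#     id          split
#     SP000001    train
#     SP000002    validation
#     SP000003    test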

def migrate(self) -> None:
"""
Executes the data migration by loading, processing, and saving the data.
"""
print("Migration started......")
self._load_data()
if not all(
df is not None
for df in [
self._train_df,
self._validation_df,
self._test_df,
self._terms_df,
]
):
raise Exception(
"Data splits or terms data is not available in instance variables."
)
splits_df = self._record_splits()

data_with_labels_df = self._extract_required_data_from_splits()

if not all(
var is not None for var in [data_with_labels_df, splits_df, self._classes]
):
raise Exception(
"Data splits or terms data is not available in instance variables."
)

self.save_migrated_data(data_with_labels_df, splits_df)

def _extract_required_data_from_splits(self) -> pd.DataFrame:
"""
Extracts required columns from the combined data splits.
Returns:
pd.DataFrame: A DataFrame containing the essential columns for processing.
"""
print("Combining the data splits with required data..... ")
print("Combining data splits into a single DataFrame with required columns.")
required_columns = [
"proteins",
"accessions",
"sequences",
# Note: "gos" contains only the directly annotated GO classes, not the transitive ones
"gos",
"labels",
]
@@ -183,7 +182,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:
lambda row: self.extract_go_id(row["gos"]), axis=1
)

labels_df = self._get_labels_colums(new_df)
labels_df = self._get_labels_columns(new_df)

data_df = pd.DataFrame(
OrderedDict(
@@ -198,28 +197,32 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:

return df

def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame:
@staticmethod
def extract_go_id(go_list: List[str]) -> List[int]:
"""
Generates a DataFrame with one-hot encoded columns for each GO term label,
based on the terms provided in `self._terms_df` and the existing labels in `data_df`.
Extracts and parses GO IDs from a list of GO annotations.
This method extracts GO IDs from the `functions` column of `self._terms_df`,
creating a list of all unique GO IDs. It then uses this list to create new
columns in the returned DataFrame, where each row has binary values
(0 or 1) indicating the presence of each GO ID in the corresponding entry of
`data_df['labels']`.
Args:
go_list (List[str]): List of GO annotation strings.
Returns:
List[int]: List of parsed GO IDs.
"""
return [DeepGO1MigratedData._parse_go_id(go_id_str) for go_id_str in go_list]
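# For example, assuming _parse_go_id maps "GO:0005575" to the integer 5575:
#
#     extract_go_id(["GO:0005575", "GO:0003674"])  # -> [5575, 3674]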

def _get_labels_columns(self, data_df: pd.DataFrame) -> pd.DataFrame:
"""
Generates columns for labels based on provided selected terms.
Args:
data_df (pd.DataFrame): DataFrame containing data with a 'labels' column,
which holds lists of GO ID labels for each row.
data_df (pd.DataFrame): DataFrame with GO annotations and labels.
Returns:
pd.DataFrame: A DataFrame with the same index as `data_df` and one column
per GO ID, containing binary values indicating label presence.
pd.DataFrame: DataFrame with label columns.
"""
print("Generating labels based on terms.pkl file.......")
print("Generating label columns from provided selected terms.")
parsed_go_ids: pd.Series = self._terms_df["functions"].apply(
lambda gos: _GOUniProtDataExtractor._parse_go_id(gos)
lambda gos: DeepGO1MigratedData._parse_go_id(gos)
)
all_go_ids_list = parsed_go_ids.values.tolist()
self._classes = all_go_ids_list
@@ -230,21 +233,6 @@ def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame:

return new_label_columns
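# The returned frame has one binary column per selected GO id, row-aligned
# with data_df, e.g. (values illustrative):
#
#     3674    5575    8150
#     1       0       1
#     0       1       0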

@staticmethod
def extract_go_id(go_list: List[str]) -> List[int]:
"""
Extracts and parses GO IDs from a list of GO annotations.
Args:
go_list (List[str]): List of GO annotation strings.
Returns:
List[str]: List of parsed GO IDs.
"""
return [
_GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list
]

def save_migrated_data(
self, data_df: pd.DataFrame, splits_df: pd.DataFrame
) -> None:
@@ -255,31 +243,38 @@ def save_migrated_data(
data_df (pd.DataFrame): Data with GO labels.
splits_df (pd.DataFrame): Split assignment DataFrame.
"""
print("Saving transformed data......")
go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[
self._go_branch
](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN)
print("Saving transformed data files.")

go_class_instance.save_processed(
data_df, go_class_instance.processed_main_file_names_dict["data"]
deepgo_migr_inst: DeepGO1MigratedData = DeepGO1MigratedData(
go_branch=DeepGO1MigratedData.GO_BRANCH_MAPPING[self._go_branch],
max_sequence_length=self._MAXLEN,
)

# Save data file
deepgo_migr_inst.save_processed(
data_df, deepgo_migr_inst.processed_main_file_names_dict["data"]
)
print(
f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}"
f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}"
)

# Save splits file
splits_df.to_csv(
os.path.join(go_class_instance.processed_dir_main, "splits.csv"),
os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go1.csv"),
index=False,
)
print(f"splits.csv saved to {go_class_instance.processed_dir_main}")
print(f"splits_deep_go1.csv saved to {deepgo_migr_inst.processed_dir_main}")

# Save classes file
classes = sorted(self._classes)
with open(
os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt"
os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go1.txt"),
"wt",
) as fout:
fout.writelines(str(node) + "\n" for node in classes)
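# classes_deep_go1.txt then holds one parsed GO id per line in ascending
# order, e.g. (illustrative):
#
#     3674
#     5575
#     8150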
print(f"classes.txt saved to {go_class_instance.processed_dir_main}")
print("Migration completed!")
print(f"classes_deep_go1.txt saved to {deepgo_migr_inst.processed_dir_main}")

print("Migration process completed!")


class Main:
