migration fix: truncate sequences and save data with labels
aditya0by0 committed Dec 4, 2024
1 parent f75e30b commit 1b8b270
Showing 1 changed file with 49 additions and 13 deletions.
chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py (62 changes: 49 additions & 13 deletions)
@@ -20,24 +20,28 @@ class DeepGo2DataMigration:
(https://doi.org/10.1093/bioinformatics/btx624)
"""

# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
_MAXLEN = 1000
_LABELS_START_IDX = DeepGO2MigratedData._LABELS_START_IDX

def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
def __init__(
self, data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000
):
"""
Initializes the data migration object with a data directory and GO branch.
Args:
data_dir (str): Directory containing the data files.
go_branch (Literal["cc", "mf", "bp"]): GO branch to use.
max_len (int): Sequences are truncated to this length. Defaults to 1000.
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
"""
valid_go_branches = list(DeepGO2MigratedData.GO_BRANCH_MAPPING.keys())
if go_branch not in valid_go_branches:
raise ValueError(f"go_branch must be one of {valid_go_branches}")
self._go_branch = go_branch

self._data_dir: str = os.path.join(rf"{data_dir}", go_branch)
self._max_len: int = max_len

self._train_df: Optional[pd.DataFrame] = None
self._test_df: Optional[pd.DataFrame] = None
self._validation_df: Optional[pd.DataFrame] = None
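As a usage note for the updated constructor above, here is a minimal sketch of driving the migration with the new max_len argument. The import path follows the file shown in this diff; the data directory and GO branch values are illustrative.

from chebai.preprocessing.migration.deep_go.migrate_deep_go_2_data import (
    DeepGo2DataMigration,
)

# Illustrative paths/branch; max_len defaults to 1000 and controls truncation.
migration = DeepGo2DataMigration(
    data_dir="data/deepgo2",  # parent directory holding the cc/mf/bp subfolders
    go_branch="mf",
    max_len=1000,
)
migration.migrate()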
@@ -74,33 +78,61 @@ def migrate(self) -> None:
"Data splits or terms data is not available in instance variables."
)

self.save_migrated_data(data_df, splits_df)
self.save_migrated_data(data_with_labels_df, splits_df)

def _load_data(self) -> None:
"""
Loads the test, train, validation, and terms data from the pickled files
in the data directory.
"""

try:
print(f"Loading data from directory: {self._data_dir}......")
self._test_df = pd.DataFrame(
pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl"))
self._test_df = self._truncate_sequences(
pd.DataFrame(
pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl"))
)
)
self._train_df = pd.DataFrame(
pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl"))
self._train_df = self._truncate_sequences(
pd.DataFrame(
pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl"))
)
)
self._validation_df = pd.DataFrame(
pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl"))
self._validation_df = self._truncate_sequences(
pd.DataFrame(
pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl"))
)
)

self._terms_df = pd.DataFrame(
pd.read_pickle(os.path.join(self._data_dir, "terms.pkl"))
)

except FileNotFoundError as e:
raise FileNotFoundError(
f"Data file not found in directory: {e}. "
"Please ensure all required files are available in the specified directory."
)
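For orientation, _load_data above resolves files relative to data_dir joined with the GO branch (see the constructor). A sketch of the expected layout, where the root name is illustrative and the file names are those read above:

data/deepgo2/
    mf/                # or cc/, bp/, matching go_branch
        train_data.pkl
        valid_data.pkl
        test_data.pkl
        terms.pkl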

def _truncate_sequences(
self, df: pd.DataFrame, column: str = "sequences"
) -> pd.DataFrame:
"""
Truncate sequences in a specified column of a dataframe to the maximum length.
https://github.com/bio-ontology-research-group/deepgo2/blob/main/train_cnn.py#L206-L217
Args:
df (pd.DataFrame): The input dataframe containing the data to be processed.
column (str, optional): The column containing sequences to truncate.
Defaults to "sequences".
Returns:
pd.DataFrame: The dataframe with sequences truncated to `self._max_len`.
"""
df[column] = df[column].apply(lambda x: x[: self._max_len])
return df
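A standalone toy illustration of the truncation applied by _truncate_sequences; the values are made up and max_len stands in for self._max_len.

import pandas as pd

max_len = 5  # stand-in for self._max_len
df = pd.DataFrame({"sequences": ["MKTAYIAKQR", "MVLSPADKTN", "MKT"]})
df["sequences"] = df["sequences"].apply(lambda x: x[:max_len])
print(df["sequences"].tolist())  # ['MKTAY', 'MVLSP', 'MKT']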

def _record_splits(self) -> pd.DataFrame:
"""
Creates a DataFrame that stores the IDs and their corresponding data splits.
@@ -217,7 +249,7 @@ def save_migrated_data(
print("Saving transformed data......")
deepgo_migr_inst: DeepGO2MigratedData = DeepGO2MigratedData(
go_branch=DeepGO2MigratedData.GO_BRANCH_MAPPING[self._go_branch],
max_sequence_length=self._MAXLEN,
max_sequence_length=self._max_len,
)

# Save data file
@@ -257,7 +289,9 @@ class Main:
"""

@staticmethod
def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
def migrate(
data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000
) -> None:
"""
Initiates the migration process by creating a DeepGo2DataMigration instance
and invoking its migrate method.
@@ -268,8 +302,10 @@ def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
("cc" for cellular_component,
"mf" for molecular_function,
or "bp" for biological_process).
max_len (int): Sequences are truncated to this length. Defaults to 1000.
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
"""
DeepGo2DataMigration(data_dir, go_branch).migrate()
DeepGo2DataMigration(data_dir, go_branch, max_len).migrate()
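A direct-call sketch of the updated entry point; the CLI wiring under the __main__ guard is not shown in this diff, so this simply invokes the static method with illustrative arguments.

from chebai.preprocessing.migration.deep_go.migrate_deep_go_2_data import Main

Main.migrate(data_dir="data/deepgo2", go_branch="cc", max_len=1000)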


if __name__ == "__main__":
