Skip to content

Commit

Permalink
migration : add esm2 embeddings
Browse files Browse the repository at this point in the history
- modify deepgo2 migration script to migrate the esm2 embeddings too
- modify migration class to use esm2 embeddings or reader features, based on input
  • Loading branch information
aditya0by0 committed Dec 9, 2024
1 parent 66732a7 commit e7b3d80
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 3 deletions.
95 changes: 93 additions & 2 deletions chebai/preprocessing/datasets/deepGO/go_uniprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import pandas as pd
import requests
import torch
import tqdm
from Bio import SwissProt

from chebai.preprocessing import reader as dr
Expand Down Expand Up @@ -892,12 +893,95 @@ class DeepGO2MigratedData(_DeepGOMigratedData):
dict: Dictionary with file names specific to DeepGO2.
"""

def __init__(self, **kwargs):
# Positional indices into each dataframe row, used by `_load_dict` to pick out
# the ESM2 embeddings column and the first label column.
_LABELS_START_IDX: int = 5 # additional esm2_embeddings column in the dataframe
_ESM_EMBEDDINGS_COL_IDX: int = 4

def __init__(self, use_esm2_embeddings=False, **kwargs):
    """
    Initialize the migrated DeepGO2 dataset.

    Args:
        use_esm2_embeddings (bool): If True, features are taken from the ESM2
            embeddings column of the data file; otherwise the configured reader
            generates the features from the sequence column.
        **kwargs: Forwarded to the parent initializer. Must contain
            `max_sequence_length` with a value equal to 1000.

    Raises:
        AssertionError: If `max_sequence_length` is missing or not equal to 1000.
    """
    # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
    max_sequence_length = kwargs.get("max_sequence_length")
    # Check for a missing value explicitly so the failure carries a readable
    # message instead of surfacing as a TypeError raised by int(None).
    assert max_sequence_length is not None and int(max_sequence_length) == 1000, (
        "DeepGO2 migrated data requires `max_sequence_length` to be 1000, "
        f"got {max_sequence_length!r}"
    )

    self.use_esm2_embeddings: bool = use_esm2_embeddings
    # NOTE(review): the two-argument super() starts the MRO lookup *after*
    # `_DeepGOMigratedData`, so that class's own __init__ (if any) is bypassed —
    # confirm this is intentional.
    super(_DeepGOMigratedData, self).__init__(**kwargs)

# ------------------------------ Phase: Setup data -----------------------------------
def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
    """
    Read the data file at `path` and return its usable entries as dictionaries.

    How an entry is produced depends on `use_esm2_embeddings`:
    - True: the dictionary yielded by `_load_dict` is used as-is, since it
      already carries the numerical features (ESM2 embeddings).
    - False: the dictionary is passed through `self.reader.to_data` to
      generate the numerical features.

    Entries whose `features` value is None, either before or after this step,
    are dropped.

    Args:
        path (str): The path to the input file.

    Returns:
        List[Dict[str, Any]]: One dictionary per retained entry, with the keys
        `features`, `labels` and `ident`.
    """
    lines = self._get_data_size(path)
    print(f"Processing {lines} lines...")

    # Read the flag once so every row of this call is handled the same way.
    keep_raw = self.use_esm2_embeddings

    entries = []
    for raw in tqdm.tqdm(self._load_dict(path), total=lines):
        # Rows without input features cannot be used at all.
        if raw["features"] is None:
            continue
        entry = raw if keep_raw else self.reader.to_data(raw)
        # The reader may itself fail to produce features for a row.
        if entry["features"] is not None:
            entries.append(entry)

    return entries

def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
    """
    Load a pickled dataframe and yield one dictionary per row.

    Each row is read positionally:
    - `self._ID_IDX`: ID of the GO data instance
    - `self._DATA_REPRESENTATION_IDX`: sequence representation of the protein
    - `self._ESM_EMBEDDINGS_COL_IDX`: ESM2 embeddings of the protein
    - `self._LABELS_START_IDX` onwards: labels

    The feature column depends on the `use_esm2_embeddings` flag:
    - True: features come from column `self._ESM_EMBEDDINGS_COL_IDX`.
    - False: features come from column `self._DATA_REPRESENTATION_IDX`.

    Args:
        input_file_path (str): The path to the pickled input file.

    Yields:
        Dict[str, Any]: A dictionary containing:
            - `features` (Any): Sequence or embedding data for the instance.
            - `labels` (np.ndarray): Boolean array built from the columns at
              `self._LABELS_START_IDX` and beyond.
            - `ident` (Any): The value found at `self._ID_IDX`.
    """
    # NOTE(review): unpickling can execute arbitrary code — only point this at
    # data files that come from a trusted source.
    # The handle is only needed for the read itself, so it is closed before any
    # row is yielded rather than staying open while the caller iterates.
    with open(input_file_path, "rb") as input_file:
        df = pd.read_pickle(input_file)

    features_idx = (
        self._ESM_EMBEDDINGS_COL_IDX
        if self.use_esm2_embeddings
        else self._DATA_REPRESENTATION_IDX
    )

    for row in df.values:
        yield dict(
            features=row[features_idx],
            labels=row[self._LABELS_START_IDX :].astype(bool),
            ident=row[self._ID_IDX],
        )

# ------------------------------ Phase: Raw Properties -----------------------------------
@property
def processed_main_file_names_dict(self) -> Dict[str, str]:
"""
Expand All @@ -917,3 +1001,10 @@ def processed_file_names_dict(self) -> Dict[str, str]:
dict: Dictionary with data file name for DeepGO2.
"""
return {"data": "data_deep_go2.pt"}

@property
def identifier(self) -> tuple:
    """
    Identifier for the dataset.

    Returns:
        tuple: A one-element tuple holding the name reported by
        `dr.ESM2EmbeddingReader` when `use_esm2_embeddings` is set, and the
        name reported by `self.reader` otherwise.
    """
    reader_name = (
        dr.ESM2EmbeddingReader.name()
        if self.use_esm2_embeddings
        else self.reader.name()
    )
    return (reader_name,)
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:
"exp_annotations", # Directly associated GO ids
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69
"prop_annotations", # Transitively associated GO ids
"esm2",
]

new_df = pd.concat(
Expand All @@ -239,6 +240,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:
accession=new_df["accessions"],
go_ids=new_df["go_ids"],
sequence=new_df["sequences"],
esm2_embeddings=new_df["esm2"],
)
)
return data_df
Expand Down
3 changes: 2 additions & 1 deletion chebai/preprocessing/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,8 @@ def load_hub_workaround(self, url) -> torch.Tensor:
)
return data

def name(self) -> None:
@staticmethod
def name() -> None:
"""
Returns the name of the data reader. This method identifies the specific type of data reader.
Expand Down

0 comments on commit e7b3d80

Please sign in to comment.