Commit

fix data preparation
sfluegel committed Dec 5, 2023
1 parent a21e3f8 commit e0f72ae
Showing 4 changed files with 55 additions and 89 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -161,6 +161,3 @@ cython_debug/
 #.idea/

 configs/
-# the notebook I put in the wrong folder
-chebai/preprocessing/datasets/demo_old_chebi.ipynb
-demo_examine_pretraining_data.ipynb
6 changes: 5 additions & 1 deletion chebai/callbacks/model_checkpoint.py
@@ -1,6 +1,10 @@
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
 from lightning.fabric.utilities.types import _PATH
-
+from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch import Trainer, LightningModule
+import os
+from lightning.fabric.utilities.cloud_io import _is_dir
+from lightning.pytorch.utilities.rank_zero import rank_zero_info

 class CustomModelCheckpoint(ModelCheckpoint):
     """Checkpoint class that resolves checkpoint paths s.t. for the CustomLogger, checkpoints get saved to the
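Side note on the new imports: _is_dir and rank_zero_info are the utilities Lightning's own ModelCheckpoint relies on when it resolves and reports checkpoint directories, which fits the class docstring above. As a rough, hypothetical sketch of how a subclass can use them (LoggingModelCheckpoint, its setup override, and the makedirs call are illustrations, not code from this commit):

from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.utilities.rank_zero import rank_zero_info


class LoggingModelCheckpoint(ModelCheckpoint):
    """Hypothetical example, not the CustomModelCheckpoint from this commit."""

    def setup(self, trainer, pl_module, stage):
        super().setup(trainer, pl_module, stage)
        if self.dirpath is not None:
            fs = get_filesystem(self.dirpath)
            if not _is_dir(fs, self.dirpath):
                fs.makedirs(self.dirpath, exist_ok=True)  # fsspec filesystem, also handles remote paths
            rank_zero_info(f"Checkpoints will be saved to {self.dirpath}")  # printed once, on rank 0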
133 changes: 49 additions & 84 deletions chebai/preprocessing/datasets/chebi.py
@@ -119,12 +119,11 @@ def select_classes(self, g, split_name, *args, **kwargs):
         raise NotImplementedError

     def graph_to_raw_dataset(self, g, split_name=None):
-        """Preparation step before creating splits, uses graph created by extract_class_hierarchy()
+        """Preparation step before creating splits, uses graph created by extract_class_hierarchy(),
+        split_name is only relevant, if a separate train_version is set"""
         smiles = nx.get_node_attributes(g, "smiles")
         names = nx.get_node_attributes(g, "name")

-        print("build labels")
+        print(f"Process graph")

         molecules, smiles_list = zip(
@@ -199,68 +198,50 @@ def setup_processed(self):
         self._setup_pruned_test_set()
         self.reader.save_token_cache()

-    def get_splits(self, df: pd.DataFrame):
-        print("Split dataset")
+    def get_test_split(self, df: pd.DataFrame):
+        print("Split dataset into train (including val) / test")

         df_list = df.values.tolist()
         df_list = [row[3:] for row in df_list]

-        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0)
+        test_size = 1 - self.train_split - (1 - self.train_split) ** 2
+        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=0)

         train_split = []
         test_split = []
         for (train_split, test_split) in msss.split(
-            df_list, df_list,
+            df_list, df_list,
         ):
             train_split = train_split
             test_split = test_split
             break
         df_train = df.iloc[train_split]
         df_test = df.iloc[test_split]
-        if self.use_inner_cross_validation:
-            return df_train, df_test
+        return df_train, df_test

-        df_test_list = df_test.values.tolist()
-        df_test_list = [row[3:] for row in df_test_list]
-        validation_split = []
-        test_split = []
-        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0)
-        for (test_split, validation_split) in msss.split(
-            df_test_list, df_test_list
-        ):
-            test_split = test_split
-            validation_split = validation_split
-            break
+    def get_train_val_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFrame):
+        """ Use test set (e.g., loaded from another chebi version or generated in get_test_split), avoid overlap"""
+        print(f"Split dataset into train / val with given test set")

-        df_validation = df_test.iloc[validation_split]
-        df_test = df_test.iloc[test_split]
-        return df_train, df_test, df_validation
-
-    def get_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFrame):
-        """ Use test set from another chebi version the model does not train on, avoid overlap"""
-        print(f"Split dataset for chebi_v{self.chebi_version_train}")
         df_trainval = df
         test_smiles = test_df['SMILES'].tolist()
-        mask = []
-        for row in df_trainval:
-            if row['SMILES'] in test_smiles:
-                mask.append(False)
-            else:
-                mask.append(True)
+        mask = [smiles not in test_smiles for smiles in df_trainval['SMILES']]
         df_trainval = df_trainval[mask]

         if self.use_inner_cross_validation:
             return df_trainval

-        # assume that size of validation split should relate to train split as in get_splits()
-        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=self.train_split ** 2, random_state=0)
+        # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
+        test_size = ((1 - self.train_split) ** 2) / self.train_split
+        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=0)

-        df_trainval_list = df_trainval.tolist()
+        df_trainval_list = df_trainval.values.tolist()
         df_trainval_list = [row[3:] for row in df_trainval_list]
         train_split = []
         validation_split = []
         for (train_split, validation_split) in msss.split(
-            df_trainval_list, df_trainval_list
+            df_trainval_list, df_trainval_list
         ):
             train_split = train_split
             validation_split = validation_split
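A note on the arithmetic behind the two test_size values: with t = train_split, the old get_splits produced fractions of t (train), t·(1-t) (test) and (1-t)² (validation). The rewrite keeps those targets: get_test_split carves the test set out of the full data directly, and get_train_val_splits_given_test rescales the validation share because the test set is already absent from its input. A quick standalone sanity check (t = 0.85 is an assumed example value, not taken from the diff):

t = 0.85  # hypothetical self.train_split

# get_test_split: 1 - t - (1 - t)**2 simplifies to t * (1 - t),
# the same overall test fraction the old two-stage split produced
test_size = 1 - t - (1 - t) ** 2
assert abs(test_size - t * (1 - t)) < 1e-12
print(test_size)  # 0.1275

# get_train_val_splits_given_test: the remaining pool is roughly a fraction t
# of the full data, so (1 - t)**2 / t of the pool is (1 - t)**2 overall
val_size = ((1 - t) ** 2) / t
print(val_size)      # ~0.0265 of the train+val pool
print(t * val_size)  # 0.0225 == (1 - t)**2 of the full data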
@@ -309,6 +290,16 @@ def processed_file_names(self):
     def raw_file_names(self):
         return list(self.raw_file_names_dict.values())

+    def _load_chebi(self, version: int):
+        chebi_name = f'chebi.obo' if version == self.chebi_version else f'chebi_v{version}.obo'
+        chebi_path = os.path.join(self.raw_dir, chebi_name)
+        if not os.path.isfile(chebi_path):
+            print(f"Load ChEBI ontology (v_{version})")
+            url = f"http://purl.obolibrary.org/obo/chebi/{version}/chebi.obo"
+            r = requests.get(url, allow_redirects=True)
+            open(chebi_path, "wb").write(r.content)
+        return chebi_path
+
     def prepare_data(self, *args, **kwargs):
         print("Check for raw data in", self.raw_dir)
         if any(
@@ -317,56 +308,30 @@ def prepare_data(self, *args, **kwargs):
         ):
             os.makedirs(self.raw_dir, exist_ok=True)
             print("Missing raw data. Go fetch...")
-            if self.chebi_version_train is None:
-                # load chebi_v{chebi_version}, create splits
-                chebi_path = os.path.join(self.raw_dir, f"chebi.obo")
-                if not os.path.isfile(chebi_path):
-                    print("Load ChEBI ontology")
-                    url = f"http://purl.obolibrary.org/obo/chebi/{self.chebi_version}/chebi.obo"
-                    r = requests.get(url, allow_redirects=True)
-                    open(chebi_path, "wb").write(r.content)
+            # missing test set -> create
+            if not os.path.isfile(os.path.join(self.raw_dir, self.raw_file_names_dict['test'])):
+                chebi_path = self._load_chebi(self.chebi_version)
                 g = extract_class_hierarchy(chebi_path)
-                splits = {}
-                full_data = self.graph_to_raw_dataset(g)
-                if self.use_inner_cross_validation:
-                    splits['train_val'], splits['test'] = self.get_splits(full_data)
-                else:
-                    splits['train'], splits['test'], splits['validation'] = self.get_splits(full_data)
-                for label, split in splits.items():
-                    self.save(split, self.raw_file_names_dict[label])
+                df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['test'])
+                _, test_df = self.get_test_split(df)
+                self.save(test_df, self.raw_file_names_dict['test'])
+            # load test_split from file
             else:
-                # missing test set -> create
-                if not os.path.isfile(os.path.join(self.raw_dir, self.raw_file_names_dict['test'])):
-                    chebi_path = os.path.join(self.raw_dir, f"chebi.obo")
-                    if not os.path.isfile(chebi_path):
-                        print("Load ChEBI ontology")
-                        url = f"http://purl.obolibrary.org/obo/chebi/{self.chebi_version}/chebi.obo"
-                        r = requests.get(url, allow_redirects=True)
-                        open(chebi_path, "wb").write(r.content)
-                    g = extract_class_hierarchy(chebi_path)
-                    df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['test'])
-                    _, test_split, _ = self.get_splits(df)
-                    self.save(df, self.raw_file_names_dict['test'])
-                else:
-                    # load test_split from file
-                    with open(os.path.join(self.raw_dir, self.raw_file_names_dict['test']), "rb") as input_file:
-                        test_split = [row[0] for row in pickle.load(input_file).values]
-                chebi_path = os.path.join(self.raw_dir, f"chebi_v{self.chebi_version_train}.obo")
-                if not os.path.isfile(chebi_path):
-                    print(f"Load ChEBI ontology (v_{self.chebi_version_train})")
-                    url = f"http://purl.obolibrary.org/obo/chebi/{self.chebi_version_train}/chebi.obo"
-                    r = requests.get(url, allow_redirects=True)
-                    open(chebi_path, "wb").write(r.content)
-                g = extract_class_hierarchy(chebi_path)
-                if self.use_inner_cross_validation:
-                    df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train_val'])
-                    train_val_df = self.get_splits_given_test(df, test_split)
-                    self.save(train_val_df, self.raw_file_names_dict['train_val'])
-                else:
-                    df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train'])
-                    train_split, val_split = self.get_splits_given_test(df, test_split)
-                    self.save(train_split, self.raw_file_names_dict['train'])
-                    self.save(val_split, self.raw_file_names_dict['validation'])
+                with open(os.path.join(self.raw_dir, self.raw_file_names_dict['test']), "rb") as input_file:
+                    test_df = pickle.load(input_file)
+            # create train/val split based on test set
+            chebi_path = self._load_chebi(
+                self.chebi_version_train if self.chebi_version_train is not None else self.chebi_version)
+            g = extract_class_hierarchy(chebi_path)
+            if self.use_inner_cross_validation:
+                df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train_val'])
+                train_val_df = self.get_train_val_splits_given_test(df, test_df)
+                self.save(train_val_df, self.raw_file_names_dict['train_val'])
+            else:
+                df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train'])
+                train_split, val_split = self.get_train_val_splits_given_test(df, test_df)
+                self.save(train_split, self.raw_file_names_dict['train'])
+                self.save(val_split, self.raw_file_names_dict['validation'])


 class JCIExtendedBase(_ChEBIDataExtractor):
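Taken together, the chebi.py changes collapse three near-identical download blocks into the single _load_chebi helper, and prepare_data now always persists a test set first and derives train/val from it, whether that test set was freshly created or loaded from disk. A standalone sketch of the helper's assumed semantics (the timeout, raise_for_status and the with-statement are defensive touches of this sketch, not part of the commit):

import os

import requests


def load_obo(raw_dir: str, version: int, current_version: int) -> str:
    """Sketch of the new _load_chebi semantics; names and defaults are assumptions."""
    # the currently configured release keeps the plain name, pinned versions get a suffix
    name = "chebi.obo" if version == current_version else f"chebi_v{version}.obo"
    path = os.path.join(raw_dir, name)
    if not os.path.isfile(path):
        url = f"http://purl.obolibrary.org/obo/chebi/{version}/chebi.obo"
        r = requests.get(url, allow_redirects=True, timeout=600)
        r.raise_for_status()  # defensive: avoid caching an error page as an .obo file
        with open(path, "wb") as f:
            f.write(r.content)
    return path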
2 changes: 1 addition & 1 deletion chebai/trainer/InnerCVTrainer.py
@@ -39,7 +39,7 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwargs):
             train_dataloader = datamodule.train_dataloader(ids=train_ids)
             val_dataloader = datamodule.val_dataloader(ids=val_ids)
             init_kwargs = self.init_kwargs
-            new_trainer = Trainer(*self.init_args, **init_kwargs)
+            new_trainer = InnerCVTrainer(*self.init_args, **init_kwargs)
             logger = new_trainer.logger
             if isinstance(logger, CustomLogger):
                 logger.set_fold(fold)
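This one-line change matters because cv_fit creates a fresh trainer per fold: constructing the base Trainer silently dropped every InnerCVTrainer override (and with it the CustomLogger fold handling) inside the folds. A generic, self-contained illustration of that failure mode (the class names are stand-ins, not chebai code):

class Trainer:
    def fit(self):
        return "plain fit"


class CVTrainer(Trainer):
    """Stand-in for InnerCVTrainer."""

    def fit(self):
        return "fold-aware fit"

    def spawn_buggy(self):
        return Trainer()    # before the fix: subclass behaviour is lost

    def spawn_fixed(self):
        return CVTrainer()  # after the fix: overrides are preserved


print(CVTrainer().spawn_buggy().fit())  # plain fit
print(CVTrainer().spawn_fixed().fit())  # fold-aware fit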
