Skip to content

Commit

Permalink
reimplement cross-validation
Browse files Browse the repository at this point in the history
  • Loading branch information
sfluegel committed Jan 3, 2024
1 parent 8e55600 commit c108686
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 84 deletions.
31 changes: 27 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,18 @@ one row for each SMILES string and one column for each class.


## Cross-validation
Use inner cross-validation by not splitting between test and validation sets at dataset creation,
but using k-fold cross-validation at runtime. This creates k models with separate metrics and checkpoints.
For training with `k`-fold cross-validation, use the `cv-fit` subcommand and the options
You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
set. For that, you need to specify the total number of folds as
```
--data.init_args.inner_k_folds=k --n_splits=k
--data.init_args.inner_k_folds=K
```
and the fold to be used in the current optimisation run as
```
--data.init_args.fold_index=I
```
To train K models, you need to do K such calls, each with a different `fold_index`. On the first call with a given
`inner_k_folds`, all folds will be created and stored in the data directory.

## ChEBI versions
Change the ChEBI version used for all sets (default: 200):
```
Expand All @@ -51,3 +57,20 @@ To change only the version of the train and validation sets independently of the
```
--data.init_args.chebi_version_train=VERSION
```

## Data folder structure
Data is stored in and retrieved from the raw and processed folders
```
data/${dataset_name}/${chebi_version}/raw/
```
and
```
data/${dataset_name}/${chebi_version}/processed/${reader_name}/
```
where `${dataset_name}` is the `_name`-attribute of the `DataModule` used,
`${chebi_version}` refers to the ChEBI version used (only for ChEBI-datasets) and
`${reader_name}` is the `name`-attribute of the `Reader` class associated with the dataset.

For cross-validation, the folds are stored as `cv_${n_folds}_fold/fold_${fold_index}_train.pkl`
and `cv_${n_folds}_fold/fold_${fold_index}_validation.pkl` in the raw directory.
In the processed directory, `.pt` is used instead of `.pkl`.
1 change: 0 additions & 1 deletion chebai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def subcommands() -> Dict[str, Set[str]]:
"validate": {"model", "dataloaders", "datamodule"},
"test": {"model", "dataloaders", "datamodule"},
"predict": {"model", "dataloaders", "datamodule"},
"cv_fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
"predict_from_file": {"model"},
}

Expand Down
35 changes: 25 additions & 10 deletions chebai/preprocessing/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
num_workers: int = 1,
chebi_version: int = 200,
inner_k_folds: int = -1, # use inner cross-validation if > 1
fold_index: typing.Optional[int] = None,
base_dir=None,
**kwargs,
):
Expand All @@ -51,9 +52,20 @@ def __init__(
self.use_inner_cross_validation = (
inner_k_folds > 1
) # only use cv if there are at least 2 folds
assert (
fold_index is None or self.use_inner_cross_validation is not None
), "fold_index can only be set if cross validation is used"
if fold_index is not None and self.inner_k_folds is not None:
assert (
fold_index < self.inner_k_folds
), "fold_index can't be larger than the total number of folds"
self.fold_index = fold_index
self._base_dir = base_dir
os.makedirs(self.raw_dir, exist_ok=True)
os.makedirs(self.processed_dir, exist_ok=True)
if self.use_inner_cross_validation:
os.makedirs(os.path.join(self.raw_dir, self.fold_dir), exist_ok=True)
os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True)

@property
def identifier(self):
Expand All @@ -78,6 +90,11 @@ def processed_dir(self):
def raw_dir(self):
return os.path.join(self.base_dir, "raw")

@property
def fold_dir(self):
    """Name of the directory holding the inner-CV folds (train/val splits), e.g. ``cv_5_fold``."""
    return "cv_{}_fold".format(self.inner_k_folds)

@property
def _name(self):
raise NotImplementedError
Expand All @@ -95,7 +112,12 @@ def load_processed_data(self, kind: str = None, filename: str = None) -> List:
if kind is not None and filename is None:
try:
# processed_file_names_dict is only implemented for _ChEBIDataExtractor
filename = self.processed_file_names_dict[kind]
if self.use_inner_cross_validation and kind != "test":
filename = self.processed_file_names_dict[
f"fold_{self.fold_index}_{kind}"
]
else:
filename = self.processed_file_names_dict[kind]
except NotImplementedError:
filename = f"{kind}.pt"
return torch.load(os.path.join(self.processed_dir, filename))
Expand Down Expand Up @@ -158,7 +180,7 @@ def _load_data_from_file(self, path):

def train_dataloader(self, *args, **kwargs) -> DataLoader:
return self.dataloader(
"train" if not self.use_inner_cross_validation else "train_val",
"train",
shuffle=True,
num_workers=self.num_workers,
persistent_workers=True,
Expand All @@ -167,7 +189,7 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader:

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return self.dataloader(
"validation" if not self.use_inner_cross_validation else "train_val",
"validation",
shuffle=False,
num_workers=self.num_workers,
persistent_workers=True,
Expand All @@ -191,13 +213,6 @@ def setup(self, **kwargs):
):
self.setup_processed()

if self.use_inner_cross_validation:
self.train_val_data = torch.load(
os.path.join(
self.processed_dir, self.processed_file_names_dict["train_val"]
)
)

if not ("keep_reader" in kwargs and kwargs["keep_reader"]):
self.reader.on_finish()

Expand Down
85 changes: 52 additions & 33 deletions chebai/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
import pickle
import random

from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from iterstrat.ml_stratifiers import (
MultilabelStratifiedShuffleSplit,
MultilabelStratifiedKFold,
)

import fastobo
import networkx as nx
Expand Down Expand Up @@ -116,7 +119,7 @@ def __init__(
self.single_class = single_class
super(_ChEBIDataExtractor, self).__init__(**kwargs)
# use different version of chebi for training and validation (if not None)
# (still use self.chebi_version for test set)
# (still uses self.chebi_version for test set)
self.chebi_version_train = chebi_version_train

def select_classes(self, g, split_name, *args, **kwargs):
Expand Down Expand Up @@ -150,8 +153,8 @@ def graph_to_raw_dataset(self, g, split_name=None):
data = data[data.iloc[:, 3:].any(axis=1)]
return data

def save(self, data: pd.DataFrame, split_name: str):
pickle.dump(data, open(os.path.join(self.raw_dir, split_name), "wb"))
def save_raw(self, data: pd.DataFrame, filename: str):
pickle.dump(data, open(os.path.join(self.raw_dir, filename), "wb"))

def _load_dict(self, input_file_path):
with open(input_file_path, "rb") as input_file:
Expand Down Expand Up @@ -249,18 +252,32 @@ def get_train_val_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFram
test_smiles = test_df["SMILES"].tolist()
mask = [smiles not in test_smiles for smiles in df_trainval["SMILES"]]
df_trainval = df_trainval[mask]
df_trainval_list = df_trainval.values.tolist()
df_trainval_list = [row[3:] for row in df_trainval_list]

if self.use_inner_cross_validation:
return df_trainval
folds = {}
kfold = MultilabelStratifiedKFold(n_splits=self.inner_k_folds)
for fold, (train_ids, val_ids) in enumerate(
kfold.split(
df_trainval_list,
df_trainval_list,
)
):
df_validation = df_trainval.iloc[val_ids]
df_train = df_trainval.iloc[train_ids]
folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train
folds[
self.raw_file_names_dict[f"fold_{fold}_validation"]
] = df_validation

return folds

# scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
test_size = ((1 - self.train_split) ** 2) / self.train_split
msss = MultilabelStratifiedShuffleSplit(
n_splits=1, test_size=test_size, random_state=0
)

df_trainval_list = df_trainval.values.tolist()
df_trainval_list = [row[3:] for row in df_trainval_list]
train_split = []
validation_split = []
for train_split, validation_split in msss.split(
Expand All @@ -271,7 +288,10 @@ def get_train_val_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFram

df_validation = df_trainval.iloc[validation_split]
df_train = df_trainval.iloc[train_split]
return df_train, df_validation
return {
self.raw_file_names_dict["train"]: df_train,
self.raw_file_names_dict["validation"]: df_validation,
}

@property
def processed_dir(self):
Expand All @@ -295,13 +315,14 @@ def processed_file_names_dict(self) -> dict:
f"_v{self.chebi_version_train}" if self.chebi_version_train else ""
)
res = {"test": f"test{train_v_str}.pt"}
if self.use_inner_cross_validation:
res[
"train_val"
] = f"trainval{train_v_str}.pt" # for cv, split train/val on runtime
else:
res["train"] = f"train{train_v_str}.pt"
res["validation"] = f"validation{train_v_str}.pt"
for set in ["train", "validation"]:
if self.use_inner_cross_validation:
for i in range(self.inner_k_folds):
res[f"fold_{i}_{set}"] = os.path.join(
self.fold_dir, f"fold_{i}_{set}{train_v_str}.pt"
)
else:
res[set] = f"{set}{train_v_str}.pt"
return res

@property
Expand All @@ -313,14 +334,14 @@ def raw_file_names_dict(self) -> dict:
"test": f"test.pkl"
} # no extra raw test version for chebi_version_train - use default test set and only
# adapt processed file
if self.use_inner_cross_validation:
res[
"train_val"
] = f"trainval{train_v_str}.pkl" # for cv, split train/val on runtime
else:
res["train"] = f"train{train_v_str}.pkl"
res["validation"] = f"validation{train_v_str}.pkl"

for set in ["train", "validation"]:
if self.use_inner_cross_validation:
for i in range(self.inner_k_folds):
res[f"fold_{i}_{set}"] = os.path.join(
self.fold_dir, f"fold_{i}_{set}{train_v_str}.pkl"
)
else:
res[set] = f"{set}{train_v_str}.pkl"
return res

@property
Expand Down Expand Up @@ -359,7 +380,7 @@ def prepare_data(self, *args, **kwargs):
g = extract_class_hierarchy(chebi_path)
df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["test"])
_, test_df = self.get_test_split(df)
self.save(test_df, self.raw_file_names_dict["test"])
self.save_raw(test_df, self.raw_file_names_dict["test"])
# load test_split from file
else:
with open(
Expand All @@ -374,16 +395,14 @@ def prepare_data(self, *args, **kwargs):
)
g = extract_class_hierarchy(chebi_path)
if self.use_inner_cross_validation:
df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["train_val"])
train_val_df = self.get_train_val_splits_given_test(df, test_df)
self.save(train_val_df, self.raw_file_names_dict["train_val"])
df = self.graph_to_raw_dataset(
g, self.raw_file_names_dict[f"fold_0_train"]
)
else:
df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["train"])
train_split, val_split = self.get_train_val_splits_given_test(
df, test_df
)
self.save(train_split, self.raw_file_names_dict["train"])
self.save(val_split, self.raw_file_names_dict["validation"])
train_val_dict = self.get_train_val_splits_given_test(df, test_df)
for name, df in train_val_dict.items():
self.save_raw(df, name)


class JCIExtendedBase(_ChEBIDataExtractor):
Expand Down
35 changes: 0 additions & 35 deletions chebai/trainer/CustomTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,41 +37,6 @@ def __init__(self, *args, **kwargs):
# instantiation custom logger connector
self._logger_connector.on_trainer_init(self.logger, 1)

def cv_fit(self, datamodule: XYBaseDataModule, *args, **kwargs):
n_splits = datamodule.inner_k_folds
if n_splits < 2:
self.fit(datamodule=datamodule, *args, **kwargs)
else:
datamodule.prepare_data()
datamodule.setup()

kfold = MultilabelStratifiedKFold(n_splits=n_splits)

for fold, (train_ids, val_ids) in enumerate(
kfold.split(
datamodule.train_val_data,
[data["labels"] for data in datamodule.train_val_data],
)
):
train_dataloader = datamodule.train_dataloader(ids=train_ids)
val_dataloader = datamodule.val_dataloader(ids=val_ids)
init_kwargs = self.init_kwargs
new_trainer = CustomTrainer(*self.init_args, **init_kwargs)
logger = new_trainer.logger
if isinstance(logger, CustomLogger):
logger.set_fold(fold)
rank_zero_info(f"Logging this fold at {logger.experiment.dir}")
else:
rank_zero_warn(
f"Using k-fold cross-validation without an adapted logger class"
)
new_trainer.fit(
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
*args,
**kwargs,
)

def predict_from_file(
self,
model: LightningModule,
Expand Down
2 changes: 1 addition & 1 deletion configs/training/default_trainer.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
min_epochs: 100
max_epochs: 100
default_root_dir: &default_root_dir logs
logger: csv_logger.yml
logger: wandb_logger.yml
callbacks: default_callbacks.yml

0 comments on commit c108686

Please sign in to comment.