diff --git a/README.md b/README.md
index 68015f4d..c824a171 100644
--- a/README.md
+++ b/README.md
@@ -36,12 +36,18 @@ one row for each SMILES string and one column for each class.
 
 ## Cross-validation
-Use inner cross-validation by not splitting between test and validation sets at dataset creation,
-but using k-fold cross-validation at runtime. This creates k models with separate metrics and checkpoints.
-For training with `k`-fold cross-validation, use the `cv-fit` subcommand and the options
+You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
+set. For that, you need to specify the total number of folds as
 ```
---data.init_args.inner_k_folds=k --n_splits=k
+--data.init_args.inner_k_folds=K
 ```
+and the fold to be used in the current optimisation run as
+```
+--data.init_args.fold_index=I
+```
+To train K models, you need to do K such calls, each with a different `fold_index` (see the usage example below). On
+the first call with a given `inner_k_folds`, all folds will be created and stored in the data directory.
+
 ## Chebi versions
 Change the chebi version used for all sets (default: 200):
 ```
@@ -51,3 +57,20 @@ To change only the version of the train and validation sets independently of the
 ```
 --data.init_args.chebi_version_train=VERSION
 ```
+
+## Data folder structure
+Data is stored in and retrieved from the raw and processed folders
+```
+data/${dataset_name}/${chebi_version}/raw/
+```
+and
+```
+data/${dataset_name}/${chebi_version}/processed/${reader_name}/
+```
+where `${dataset_name}` is the `_name`-attribute of the `DataModule` used,
+`${chebi_version}` refers to the ChEBI version used (only for ChEBI-datasets) and
+`${reader_name}` is the `name`-attribute of the `Reader` class associated with the dataset.
+
+For cross-validation, the folds are stored as `cv_${n_folds}_fold/fold_${fold_index}_train.pkl`
+and `cv_${n_folds}_fold/fold_${fold_index}_validation.pkl` in the raw directory.
+In the processed directory, `.pt` is used instead of `.pkl`.
\ No newline at end of file
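For illustration, a full 3-fold cross-validation then amounts to three separate calls of your usual training command, differing only in the fold index (a sketch in the README's flag style; all other options are omitted):
```
--data.init_args.inner_k_folds=3 --data.init_args.fold_index=0
--data.init_args.inner_k_folds=3 --data.init_args.fold_index=1
--data.init_args.inner_k_folds=3 --data.init_args.fold_index=2
```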
diff --git a/chebai/cli.py b/chebai/cli.py
index 6d53b9fe..9d7c2f8a 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -29,7 +29,6 @@ def subcommands() -> Dict[str, Set[str]]:
         "validate": {"model", "dataloaders", "datamodule"},
         "test": {"model", "dataloaders", "datamodule"},
         "predict": {"model", "dataloaders", "datamodule"},
-        "cv_fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
         "predict_from_file": {"model"},
     }
 
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index b76de163..622551bb 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -28,6 +28,7 @@ def __init__(
         num_workers: int = 1,
         chebi_version: int = 200,
         inner_k_folds: int = -1,  # use inner cross-validation if > 1
+        fold_index: typing.Optional[int] = None,
         base_dir=None,
         **kwargs,
     ):
@@ -51,9 +52,20 @@
         self.use_inner_cross_validation = (
             inner_k_folds > 1
         )  # only use cv if there are at least 2 folds
+        assert (
+            fold_index is None or self.use_inner_cross_validation
+        ), "fold_index can only be set if cross-validation is used"
+        if fold_index is not None:
+            assert (
+                fold_index < self.inner_k_folds
+            ), "fold_index can't be larger than the total number of folds"
+        self.fold_index = fold_index
         self._base_dir = base_dir
         os.makedirs(self.raw_dir, exist_ok=True)
         os.makedirs(self.processed_dir, exist_ok=True)
+        if self.use_inner_cross_validation:
+            os.makedirs(os.path.join(self.raw_dir, self.fold_dir), exist_ok=True)
+            os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True)
 
     @property
     def identifier(self):
@@ -78,6 +90,11 @@ def processed_dir(self):
     def raw_dir(self):
         return os.path.join(self.base_dir, "raw")
 
+    @property
+    def fold_dir(self):
+        """Name of the dir where the folds from inner cross-validation (i.e., the train and val sets) are stored"""
+        return f"cv_{self.inner_k_folds}_fold"
+
     @property
     def _name(self):
         raise NotImplementedError
@@ -95,7 +112,12 @@ def load_processed_data(self, kind: str = None, filename: str = None) -> List:
         if kind is not None and filename is None:
             try:
                 # processed_file_names_dict is only implemented for _ChEBIDataExtractor
-                filename = self.processed_file_names_dict[kind]
+                if self.use_inner_cross_validation and kind != "test":
+                    filename = self.processed_file_names_dict[
+                        f"fold_{self.fold_index}_{kind}"
+                    ]
+                else:
+                    filename = self.processed_file_names_dict[kind]
             except NotImplementedError:
                 filename = f"{kind}.pt"
         return torch.load(os.path.join(self.processed_dir, filename))
@@ -158,7 +180,7 @@ def _load_data_from_file(self, path):
 
     def train_dataloader(self, *args, **kwargs) -> DataLoader:
         return self.dataloader(
-            "train" if not self.use_inner_cross_validation else "train_val",
+            "train",
             shuffle=True,
             num_workers=self.num_workers,
             persistent_workers=True,
@@ -167,7 +189,7 @@
 
     def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
         return self.dataloader(
-            "validation" if not self.use_inner_cross_validation else "train_val",
+            "validation",
             shuffle=False,
             num_workers=self.num_workers,
             persistent_workers=True,
@@ -191,13 +213,6 @@ def setup(self, **kwargs):
         ):
             self.setup_processed()
 
-        if self.use_inner_cross_validation:
-            self.train_val_data = torch.load(
-                os.path.join(
-                    self.processed_dir, self.processed_file_names_dict["train_val"]
-                )
-            )
-
         if not ("keep_reader" in kwargs and kwargs["keep_reader"]):
             self.reader.on_finish()
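As a reading aid (not part of the patch): given the `fold_dir` property and the fold-aware branch of `load_processed_data` above, the processed file of one fold resolves roughly as sketched here; the dataset, version, and reader names are placeholders.

```python
import os

# Mirrors fold_dir and load_processed_data from base.py above (illustrative only).
inner_k_folds, fold_index = 5, 2
fold_dir = f"cv_{inner_k_folds}_fold"  # value of the fold_dir property
filename = os.path.join(fold_dir, f"fold_{fold_index}_train.pt")
# Placeholder names following the README's data/${dataset_name}/${chebi_version}/processed/${reader_name}/ layout:
processed_dir = os.path.join("data", "my_dataset", "200", "processed", "my_reader")
# load_processed_data would then call torch.load on this path:
print(os.path.join(processed_dir, filename))
# -> data/my_dataset/200/processed/my_reader/cv_5_fold/fold_2_train.pt
```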
diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 8bf1d12d..0b673d29 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -15,7 +15,10 @@
 import pickle
 import random
 
-from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
+from iterstrat.ml_stratifiers import (
+    MultilabelStratifiedShuffleSplit,
+    MultilabelStratifiedKFold,
+)
 import fastobo
 import networkx as nx
@@ -116,7 +119,7 @@ def __init__(
         self.single_class = single_class
         super(_ChEBIDataExtractor, self).__init__(**kwargs)
         # use different version of chebi for training and validation (if not None)
-        # (still use self.chebi_version for test set)
+        # (still uses self.chebi_version for test set)
         self.chebi_version_train = chebi_version_train
 
     def select_classes(self, g, split_name, *args, **kwargs):
@@ -150,8 +153,8 @@ def graph_to_raw_dataset(self, g, split_name=None):
         data = data[data.iloc[:, 3:].any(axis=1)]
         return data
 
-    def save(self, data: pd.DataFrame, split_name: str):
-        pickle.dump(data, open(os.path.join(self.raw_dir, split_name), "wb"))
+    def save_raw(self, data: pd.DataFrame, filename: str):
+        pickle.dump(data, open(os.path.join(self.raw_dir, filename), "wb"))
 
     def _load_dict(self, input_file_path):
         with open(input_file_path, "rb") as input_file:
@@ -249,18 +252,32 @@ def get_train_val_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFrame):
         test_smiles = test_df["SMILES"].tolist()
         mask = [smiles not in test_smiles for smiles in df_trainval["SMILES"]]
         df_trainval = df_trainval[mask]
+        df_trainval_list = df_trainval.values.tolist()
+        df_trainval_list = [row[3:] for row in df_trainval_list]
 
         if self.use_inner_cross_validation:
-            return df_trainval
+            folds = {}
+            kfold = MultilabelStratifiedKFold(n_splits=self.inner_k_folds)
+            for fold, (train_ids, val_ids) in enumerate(
+                kfold.split(
+                    df_trainval_list,
+                    df_trainval_list,
+                )
+            ):
+                df_validation = df_trainval.iloc[val_ids]
+                df_train = df_trainval.iloc[train_ids]
+                folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train
+                folds[
+                    self.raw_file_names_dict[f"fold_{fold}_validation"]
+                ] = df_validation
+
+            return folds
 
         # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
         test_size = ((1 - self.train_split) ** 2) / self.train_split
         msss = MultilabelStratifiedShuffleSplit(
             n_splits=1, test_size=test_size, random_state=0
         )
-
-        df_trainval_list = df_trainval.values.tolist()
-        df_trainval_list = [row[3:] for row in df_trainval_list]
         train_split = []
         validation_split = []
         for train_split, validation_split in msss.split(
@@ -271,7 +288,10 @@
         df_validation = df_trainval.iloc[validation_split]
         df_train = df_trainval.iloc[train_split]
 
-        return df_train, df_validation
+        return {
+            self.raw_file_names_dict["train"]: df_train,
+            self.raw_file_names_dict["validation"]: df_validation,
+        }
 
     @property
     def processed_dir(self):
@@ -295,13 +315,14 @@ def processed_file_names_dict(self) -> dict:
             f"_v{self.chebi_version_train}" if self.chebi_version_train else ""
         )
         res = {"test": f"test{train_v_str}.pt"}
-        if self.use_inner_cross_validation:
-            res[
-                "train_val"
-            ] = f"trainval{train_v_str}.pt"  # for cv, split train/val on runtime
-        else:
-            res["train"] = f"train{train_v_str}.pt"
-            res["validation"] = f"validation{train_v_str}.pt"
+        for set_name in ["train", "validation"]:
+            if self.use_inner_cross_validation:
+                for i in range(self.inner_k_folds):
+                    res[f"fold_{i}_{set_name}"] = os.path.join(
+                        self.fold_dir, f"fold_{i}_{set_name}{train_v_str}.pt"
+                    )
+            else:
+                res[set_name] = f"{set_name}{train_v_str}.pt"
         return res
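The splitting strategy introduced above can be tried in isolation. A minimal sketch with toy data: random multi-hot labels stand in for the ChEBI class columns (`row[3:]` in the patch), which is what both arguments of `kfold.split` carry there.

```python
import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

# Toy stand-ins: 20 samples, 3 binary label columns.
X = np.arange(20).reshape(20, 1)
y = np.random.randint(0, 2, size=(20, 3))

# Same class and n_splits parameter as used in get_train_val_splits_given_test.
kfold = MultilabelStratifiedKFold(n_splits=5)
for fold, (train_ids, val_ids) in enumerate(kfold.split(X, y)):
    print(f"fold {fold}: {len(train_ids)} train / {len(val_ids)} validation samples")
```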
 
     @property
@@ -313,14 +334,14 @@ def raw_file_names_dict(self) -> dict:
         res = {
             "test": "test.pkl"
         }  # no extra raw test version for chebi_version_train - use default test set and only
         # adapt processed file
-        if self.use_inner_cross_validation:
-            res[
-                "train_val"
-            ] = f"trainval{train_v_str}.pkl"  # for cv, split train/val on runtime
-        else:
-            res["train"] = f"train{train_v_str}.pkl"
-            res["validation"] = f"validation{train_v_str}.pkl"
-
+        for set_name in ["train", "validation"]:
+            if self.use_inner_cross_validation:
+                for i in range(self.inner_k_folds):
+                    res[f"fold_{i}_{set_name}"] = os.path.join(
+                        self.fold_dir, f"fold_{i}_{set_name}{train_v_str}.pkl"
+                    )
+            else:
+                res[set_name] = f"{set_name}{train_v_str}.pkl"
         return res
 
     @property
@@ -359,7 +380,7 @@ def prepare_data(self, *args, **kwargs):
             g = extract_class_hierarchy(chebi_path)
             df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["test"])
             _, test_df = self.get_test_split(df)
-            self.save(test_df, self.raw_file_names_dict["test"])
+            self.save_raw(test_df, self.raw_file_names_dict["test"])
         # load test_split from file
         else:
             with open(
@@ -374,16 +395,14 @@
             )
             g = extract_class_hierarchy(chebi_path)
             if self.use_inner_cross_validation:
-                df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["train_val"])
-                train_val_df = self.get_train_val_splits_given_test(df, test_df)
-                self.save(train_val_df, self.raw_file_names_dict["train_val"])
+                df = self.graph_to_raw_dataset(
+                    g, self.raw_file_names_dict["fold_0_train"]
+                )
             else:
                 df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["train"])
-                train_split, val_split = self.get_train_val_splits_given_test(
-                    df, test_df
-                )
-                self.save(train_split, self.raw_file_names_dict["train"])
-                self.save(val_split, self.raw_file_names_dict["validation"])
+            train_val_dict = self.get_train_val_splits_given_test(df, test_df)
+            for name, df_split in train_val_dict.items():
+                self.save_raw(df_split, name)
 
 
 class JCIExtendedBase(_ChEBIDataExtractor):
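The raw folds that `prepare_data` writes via `save_raw` are pickled DataFrames, so they can be unpickled directly for a sanity check. A sketch; the dataset name and version in the path are placeholders.

```python
import os
import pickle

# Raw folds live under the fold directory inside raw/ (see the README section above).
raw_dir = os.path.join("data", "my_dataset", "200", "raw")  # placeholder names
with open(os.path.join(raw_dir, "cv_5_fold", "fold_0_train.pkl"), "rb") as f:
    df_train = pickle.load(f)
print(f"fold 0 training set: {len(df_train)} rows")
```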
diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py
index 431be8b5..e9cafa2b 100644
--- a/chebai/trainer/CustomTrainer.py
+++ b/chebai/trainer/CustomTrainer.py
@@ -37,41 +37,6 @@ def __init__(self, *args, **kwargs):
         # instantiation custom logger connector
         self._logger_connector.on_trainer_init(self.logger, 1)
 
-    def cv_fit(self, datamodule: XYBaseDataModule, *args, **kwargs):
-        n_splits = datamodule.inner_k_folds
-        if n_splits < 2:
-            self.fit(datamodule=datamodule, *args, **kwargs)
-        else:
-            datamodule.prepare_data()
-            datamodule.setup()
-
-            kfold = MultilabelStratifiedKFold(n_splits=n_splits)
-
-            for fold, (train_ids, val_ids) in enumerate(
-                kfold.split(
-                    datamodule.train_val_data,
-                    [data["labels"] for data in datamodule.train_val_data],
-                )
-            ):
-                train_dataloader = datamodule.train_dataloader(ids=train_ids)
-                val_dataloader = datamodule.val_dataloader(ids=val_ids)
-                init_kwargs = self.init_kwargs
-                new_trainer = CustomTrainer(*self.init_args, **init_kwargs)
-                logger = new_trainer.logger
-                if isinstance(logger, CustomLogger):
-                    logger.set_fold(fold)
-                    rank_zero_info(f"Logging this fold at {logger.experiment.dir}")
-                else:
-                    rank_zero_warn(
-                        f"Using k-fold cross-validation without an adapted logger class"
-                    )
-                new_trainer.fit(
-                    train_dataloaders=train_dataloader,
-                    val_dataloaders=val_dataloader,
-                    *args,
-                    **kwargs,
-                )
-
     def predict_from_file(
         self,
         model: LightningModule,
diff --git a/configs/training/default_trainer.yml b/configs/training/default_trainer.yml
index 147c3500..0c6b860f 100644
--- a/configs/training/default_trainer.yml
+++ b/configs/training/default_trainer.yml
@@ -1,5 +1,5 @@
 min_epochs: 100
 max_epochs: 100
 default_root_dir: &default_root_dir logs
-logger: csv_logger.yml
+logger: wandb_logger.yml
 callbacks: default_callbacks.yml
\ No newline at end of file
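With `cv_fit` removed, the k folds now run as k independent training jobs, so aggregating their metrics is left to the user. A minimal sketch in plain Python; the per-fold scores are placeholder numbers, not results.

```python
from statistics import mean, stdev

# One final validation score per fold_index 0..4 (placeholder values).
fold_scores = [0.71, 0.69, 0.73, 0.70, 0.72]
print(f"mean over {len(fold_scores)} folds: {mean(fold_scores):.3f} +/- {stdev(fold_scores):.3f}")
```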