Commit

Merge pull request #6 from ChEB-AI/features-sfluegel
Various features
sfluegel05 authored Jan 5, 2024
2 parents 6fa8da2 + 6c72805 commit 0ccd7fb
Showing 50 changed files with 4,897 additions and 1,993 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -1,8 +1,8 @@
repos:
- repo: https://github.com/PyCQA/isort
rev: "5.12.0"
hooks:
- id: isort
#- repo: https://github.com/PyCQA/isort
# rev: "5.12.0"
# hooks:
# - id: isort
- repo: https://github.com/psf/black
rev: "22.10.0"
hooks:
51 changes: 50 additions & 1 deletion README.md
@@ -24,4 +24,53 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co

```
python -m chebai train --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --ckpt_path=[path-to-model-with-ontology-pretraining]
```

## Predicting classes given SMILES strings

```
python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
```
The input file should contain a list of SMILES strings, one per line. This generates a CSV file with
one row for each SMILES string and one column for each class.
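
For illustration, an input file could look like this (the SMILES strings are arbitrary examples, not taken from the repository):
```
CC(=O)Oc1ccccc1C(=O)O
c1ccccc1
CCO
```
For this input, the generated CSV would contain three rows, one per SMILES string.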


## Cross-validation
You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
set. For that, you need to specify the total number of folds as
```
--data.init_args.inner_k_folds=K
```
and the fold to be used in the current optimisation run as
```
--data.init_args.fold_index=I
```
To train K models, you need to make K such calls, each with a different `fold_index`. On the first call with a given
`inner_k_folds`, all folds are created and stored in the data directory.
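
For illustration, training all models for K=3 could look like this (the config path is a placeholder, as in the examples above):
```
python -m chebai fit --config=[path-to-your-config] --data.init_args.inner_k_folds=3 --data.init_args.fold_index=0
python -m chebai fit --config=[path-to-your-config] --data.init_args.inner_k_folds=3 --data.init_args.fold_index=1
python -m chebai fit --config=[path-to-your-config] --data.init_args.inner_k_folds=3 --data.init_args.fold_index=2
```
The first of these calls creates and stores all three folds; the later calls reuse them.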

## ChEBI versions
Change the ChEBI version used for all sets (default: 200):
```
--data.init_args.chebi_version=VERSION
```
To change only the version of the train and validation sets independently of the test set, use
```
--data.init_args.chebi_version_train=VERSION
```
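
For example, to evaluate on one release while training and validating on an older one (version numbers are illustrative):
```
python -m chebai fit --config=[path-to-your-config] --data.init_args.chebi_version=231 --data.init_args.chebi_version_train=200
```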

## Data folder structure
Data is stored in and retrieved from the raw and processed folders
```
data/${dataset_name}/${chebi_version}/raw/
```
and
```
data/${dataset_name}/${chebi_version}/processed/${reader_name}/
```
where `${dataset_name}` is the `_name` attribute of the `DataModule` used,
`${chebi_version}` refers to the ChEBI version used (only for ChEBI datasets), and
`${reader_name}` is the `name` attribute of the `Reader` class associated with the dataset.

For cross-validation, the folds are stored as `cv_${n_folds}_fold/fold_{fold_index}_train.pkl`
and `cv_${n_folds}_fold/fold_{fold_index}_validation.pkl` in the raw directory.
In the processed directory, `.pt` is used instead of `.pkl`.
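
As a sketch, with a hypothetical dataset name, reader name and five folds, the resulting layout would be:
```
data/
└── [dataset_name]/
    └── 200/
        ├── raw/
        │   └── cv_5_fold/
        │       ├── fold_0_train.pkl
        │       └── fold_0_validation.pkl
        └── processed/
            └── [reader_name]/
                └── cv_5_fold/
                    ├── fold_0_train.pt
                    └── fold_0_validation.pt
```
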
1 change: 1 addition & 0 deletions chebai/__init__.py
@@ -1,4 +1,5 @@
import os

import torch

MODULE_PATH = os.path.abspath(os.path.dirname(__file__))
5 changes: 3 additions & 2 deletions chebai/callbacks.py
@@ -1,7 +1,8 @@
import json
import os

from lightning.pytorch.callbacks import BasePredictionWriter
import torch
import os
import json


class ChebaiPredictionWriter(BasePredictionWriter):
Empty file added chebai/callbacks/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions chebai/callbacks/epoch_metrics.py
@@ -0,0 +1,49 @@
import torch
import torchmetrics


def custom_reduce_fx(input):
print(f"called reduce (device: {input.device})")
return torch.sum(input, dim=0)


class MacroF1(torchmetrics.Metric):
def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
super().__init__(dist_sync_on_step=dist_sync_on_step)

self.add_state(
"true_positives",
default=torch.zeros(num_labels, dtype=torch.int),
dist_reduce_fx="sum",
)
self.add_state(
"positive_predictions",
default=torch.zeros(num_labels, dtype=torch.int),
dist_reduce_fx="sum",
)
self.add_state(
"positive_labels",
default=torch.zeros(num_labels, dtype=torch.int),
dist_reduce_fx="sum",
)
self.threshold = threshold

def update(self, preds: torch.Tensor, labels: torch.Tensor):
tps = torch.sum(
torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
)
self.true_positives += tps
self.positive_predictions += torch.sum(preds > self.threshold, dim=0)
self.positive_labels += torch.sum(labels, dim=0)

def compute(self):
# ignore classes without positive labels
# classes with positive labels, but no positive predictions will get a precision of "nan" (0 divided by 0),
# which is propagated to the classwise_f1 and then turned into 0
mask = self.positive_labels != 0
precision = self.true_positives[mask] / self.positive_predictions[mask]
recall = self.true_positives[mask] / self.positive_labels[mask]
classwise_f1 = 2 * precision * recall / (precision + recall)
# if (precision and recall are 0) or (precision is nan), set f1 to 0
classwise_f1 = classwise_f1.nan_to_num()
return torch.mean(classwise_f1)
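
A minimal usage sketch for the metric above (the tensors are made-up illustrations; in training, `preds` would be the model's sigmoid outputs and `labels` the multi-hot targets; the import assumes the new module path `chebai.callbacks.epoch_metrics`):
```python
import torch

from chebai.callbacks.epoch_metrics import MacroF1

# 3 samples, 4 classes: predicted probabilities vs. multi-hot integer labels
preds = torch.tensor(
    [[0.9, 0.2, 0.7, 0.1], [0.4, 0.8, 0.6, 0.0], [0.1, 0.9, 0.3, 0.0]]
)
labels = torch.tensor([[1, 0, 1, 0], [0, 1, 1, 0], [0, 1, 0, 0]])

metric = MacroF1(num_labels=4)
metric.update(preds, labels)
# mean F1 over the classes that have at least one positive label;
# class 3 (all-zero labels) is ignored by the mask in compute()
print(metric.compute())
```
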
63 changes: 63 additions & 0 deletions chebai/callbacks/model_checkpoint.py
@@ -0,0 +1,63 @@
import os

from lightning.fabric.utilities.cloud_io import _is_dir
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities.rank_zero import rank_zero_info
from lightning_utilities.core.rank_zero import rank_zero_warn


class CustomModelCheckpoint(ModelCheckpoint):
"""Checkpoint class that resolves checkpoint paths s.t. for the CustomLogger, checkpoints get saved to the
same directory as the other logs"""

def setup(
self, trainer: "Trainer", pl_module: "LightningModule", stage: str
) -> None:
"""Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir"""
if self.dirpath is not None:
self.dirpath = None
dirpath = self.__resolve_ckpt_dir(trainer)
dirpath = trainer.strategy.broadcast(dirpath)
self.dirpath = dirpath
if trainer.is_global_zero and stage == "fit":
self.__warn_if_dir_not_empty(self.dirpath)

def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
"""Same as in parent class, duplicated because method in parent class is not accessible"""
if (
self.save_top_k != 0
and _is_dir(self._fs, dirpath, strict=True)
and len(self._fs.ls(dirpath)) > 0
):
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
"""Overwritten for compatibility with wandb -> saves checkpoints in same dir as wandb logs"""
rank_zero_info(f"Resolving checkpoint dir (custom)")
if self.dirpath is not None:
# short circuit if dirpath was passed to ModelCheckpoint
return self.dirpath
if len(trainer.loggers) > 0:
if trainer.loggers[0].save_dir is not None:
save_dir = trainer.loggers[0].save_dir
else:
save_dir = trainer.default_root_dir
name = trainer.loggers[0].name
version = trainer.loggers[0].version
version = version if isinstance(version, str) else f"version_{version}"
logger = trainer.loggers[0]
if isinstance(logger, WandbLogger) and isinstance(
logger.experiment.dir, str
):
ckpt_path = os.path.join(logger.experiment.dir, "checkpoints")
else:
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
else:
# if no loggers, use default_root_dir
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

rank_zero_info(f"Now using checkpoint path {ckpt_path}")
return ckpt_path
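
A wiring sketch for the callback above, combined with the `CustomLogger` added in this PR (the monitored metric name and the logger arguments are placeholders, not taken from the repository's configs):
```python
from lightning.pytorch import Trainer

from chebai.callbacks.model_checkpoint import CustomModelCheckpoint
from chebai.loggers.custom import CustomLogger

# hypothetical logger and checkpoint configuration
logger = CustomLogger(save_dir="logs", name="electra", project="chebai")
checkpoint = CustomModelCheckpoint(monitor="val_loss", save_top_k=3, mode="min")
# with a wandb-based logger, the resolved checkpoint dir sits next to the wandb run files
trainer = Trainer(logger=logger, callbacks=[checkpoint])
```
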
30 changes: 27 additions & 3 deletions chebai/cli.py
@@ -1,15 +1,39 @@
from lightning.pytorch.cli import LightningCLI
from typing import Dict, Set

from lightning.pytorch.cli import LightningArgumentParser, LightningCLI

from chebai.trainer.CustomTrainer import CustomTrainer


class ChebaiCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
def __init__(self, *args, **kwargs):
super().__init__(trainer_class=CustomTrainer, *args, **kwargs)

def add_arguments_to_parser(self, parser: LightningArgumentParser):
for kind in ("train", "val", "test"):
for average in ("micro", "macro"):
parser.link_arguments(
"model.init_args.out_dim",
f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
)
parser.link_arguments(
"model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
)

@staticmethod
def subcommands() -> Dict[str, Set[str]]:
"""Defines the list of available subcommands and the arguments to skip."""
return {
"fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
"validate": {"model", "dataloaders", "datamodule"},
"test": {"model", "dataloaders", "datamodule"},
"predict": {"model", "dataloaders", "datamodule"},
"predict_from_file": {"model"},
}


def cli():
r = ChebaiCLI(save_config_callback=None, parser_kwargs={"parser_mode": "omegaconf"})
r = ChebaiCLI(
save_config_kwargs={"config_filename": "lightning_config.yaml"},
parser_kwargs={"parser_mode": "omegaconf"},
)
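
Given the subcommands defined above, validation can be launched analogously to the README examples; a hypothetical call (paths are placeholders):
```
python -m chebai validate --config=[path-to-your-config] --ckpt_path=[path-to-checkpoint]
```
Because `model.init_args.out_dim` is linked to the `num_labels` of the F1 metrics and of the callback, the output dimension only needs to be specified once in the model config.
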
Empty file added chebai/loggers/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions chebai/loggers/custom.py
@@ -0,0 +1,83 @@
from datetime import datetime
from typing import Literal, Optional, Union
import os

from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
import wandb


class CustomLogger(WandbLogger):
"""Adds support for custom naming of runs and cross-validation"""

def __init__(
self,
save_dir: _PATH,
name: str = "logs",
version: Optional[Union[int, str]] = None,
prefix: str = "",
fold: Optional[int] = None,
project: Optional[str] = None,
entity: Optional[str] = None,
offline: bool = False,
log_model: Union[Literal["all"], bool] = False,
**kwargs,
):
if version is None:
version = f"{datetime.now():%y%m%d-%H%M}"
self._version = version
self._name = name
self._fold = fold
super().__init__(
name=self.name,
save_dir=save_dir,
version=None,
prefix=prefix,
log_model=log_model,
entity=entity,
project=project,
offline=offline,
**kwargs,
)

@property
def name(self) -> Optional[str]:
name = f"{self._name}_{self.version}"
if self._fold is not None:
name += f"_fold{self._fold}"
return name

@property
def version(self) -> Optional[str]:
return self._version

@property
def root_dir(self) -> Optional[str]:
return os.path.join(self.save_dir, self.name)

@property
def log_dir(self) -> str:
version = (
self.version if isinstance(self.version, str) else f"version_{self.version}"
)
if self._fold is None:
return os.path.join(self.root_dir, version)
return os.path.join(self.root_dir, version, f"fold_{self._fold}")

def set_fold(self, fold: int):
if fold != self._fold:
self._fold = fold
# start new experiment
wandb.finish()
self._wandb_init["name"] = self.name
self._experiment = None
_ = self.experiment

@property
def fold(self):
return self._fold

def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
# don't save checkpoint as wandb artifact
pass
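
A small sketch of how the naming logic above plays out (run name and directories are illustrative; accessing `logger.experiment` or calling `set_fold` would start an actual wandb run):
```python
from chebai.loggers.custom import CustomLogger

# hypothetical arguments; version defaults to a timestamp such as "240105-1230"
logger = CustomLogger(save_dir="logs", name="electra", project="chebai", fold=0)

print(logger.name)     # e.g. "electra_240105-1230_fold0"
print(logger.log_dir)  # e.g. "logs/electra_240105-1230_fold0/240105-1230/fold_0"
```
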
2 changes: 1 addition & 1 deletion chebai/loss/pretraining.py
@@ -1,5 +1,6 @@
import torch


class ElectraPreLoss(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -14,4 +14,3 @@ def forward(self, input, target, **loss_kwargs):
target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred
)
return gen_loss + disc_loss

8 changes: 5 additions & 3 deletions chebai/loss/semantic.py
@@ -1,9 +1,11 @@
import torch
from chebai.models.electra import extract_class_hierarchy
import os
import csv
import os
import pickle

import torch

from chebai.models.electra import extract_class_hierarchy

IMPLICATION_CACHE_FILE = "chebi.cache"

