Skip to content

Commit

Permalink
add subcommand predict_from_file
Browse files Browse the repository at this point in the history
  • Loading branch information
sfluegel committed Dec 4, 2023
1 parent 0ef01cd commit afb4fbb
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 13 deletions.
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,23 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co
python -m chebai train --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --ckpt_path=[path-to-model-with-ontology-pretraining]
```

## Features on branch `features-sfluegel`
### Cross-validation
## Predicting classes given SMILES strings

```
python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
```
The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the
one row for each SMILES string and one column for each class.


## Cross-validation
Use inner cross-validation by not splitting between test and validation sets at dataset creation,
but using k-fold cross-validation at runtime. This creates k models with separate metrics and checkpoints.
For training with `k`-fold cross-validation, use the `cv-fit` subcommand and the options
```
--data.init_args.inner_k_folds=k --n_splits=k
```
### Chebi versions
## Chebi versions
Change the chebi version used for all sets (default: 200):
```
--data.init_args.chebi_version=VERSION
Expand Down
1 change: 1 addition & 0 deletions chebai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def subcommands() -> Dict[str, Set[str]]:
"test": {"model", "dataloaders", "datamodule"},
"predict": {"model", "dataloaders", "datamodule"},
"cv_fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
"predict_from_file": {"model"}
}


Expand Down
13 changes: 8 additions & 5 deletions chebai/preprocessing/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ class ChemDataReader(DataReader):
def name(cls):
return "smiles_token"

def __init__(self, *args, **kwargs):
def __init__(self, token_path = None, *args, **kwargs):
super().__init__(*args, **kwargs)
dirname = os.path.dirname(__file__)
with open(os.path.join(dirname, "bin", "tokens.txt"), "r") as pk:
if token_path is None:
dirname = os.path.dirname(__file__)
token_path = os.path.join(dirname, "bin", "tokens.txt")
self.token_path = token_path
with open(self.token_path, "r") as pk:
self.cache = [x.strip() for x in pk]

def _get_token_index(self, token):
Expand All @@ -104,8 +107,8 @@ def _read_data(self, raw_data):
def save_token_cache(self):
"""write contents of self.cache into tokens.txt"""
dirname = os.path.dirname(__file__)
with open(os.path.join(dirname, "bin", "tokens.txt"), "w") as pk:
print(f'saving tokens to {os.path.join(dirname, "bin", "tokens.txt")}...')
with open(self.token_path, "w") as pk:
print(f'saving tokens to {self.token_path}...')
print(f'first 10 tokens: {self.cache[:10]}')
pk.writelines([f'{c}\n' for c in self.cache])

Expand Down
39 changes: 34 additions & 5 deletions chebai/trainer/InnerCVTrainer.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import logging
import os
from typing import Optional, Union, Iterable
from typing import Optional, Union

import pandas as pd
from lightning import Trainer, LightningModule
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.trainer.connectors.logger_connector import _LoggerConnector
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
from lightning.pytorch.loggers import CSVLogger, Logger, TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning_utilities.core.rank_zero import WarningCache

from lightning.pytorch.loggers import CSVLogger
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from lightning.pytorch.callbacks.model_checkpoint import _is_dir, rank_zero_warn

from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.preprocessing.collate import RaggedCollater
from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader
from torch.nn.utils.rnn import pad_sequence
import torch
import pandas as pd

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,6 +52,32 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar
print(f'Logging this fold at {new_trainer.logger.log_dir}')
new_trainer.fit(train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, *args, **kwargs)

def predict_from_file(self, model: LightningModule, checkpoint_path: _PATH, input_path: _PATH,
save_to: _PATH='predictions.csv', classes_path: Optional[_PATH] = None):
loaded_model= model.__class__.load_from_checkpoint(checkpoint_path)
with open(input_path, 'r') as input:
smiles_strings = [inp.strip() for inp in input.readlines()]
predictions = self._predict_smiles(loaded_model, smiles_strings)
predictions_df = pd.DataFrame(predictions.detach().numpy())
if classes_path is not None:
with open(classes_path, 'r') as f:
predictions_df.columns = [cls.strip() for cls in f.readlines()]
predictions_df.index = smiles_strings
predictions_df.to_csv(save_to)


def _predict_smiles(self, model: LightningModule, smiles: list[str]):
reader = ChemDataReader()
parsed_smiles = [reader._read_data(s) for s in smiles]
x = pad_sequence([torch.tensor(a) for a in parsed_smiles], batch_first=True)
cls_tokens = (torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) * CLS_TOKEN)
features = torch.cat((cls_tokens, x), dim=1)
model_output = model({'features': features})
preds = torch.sigmoid(model_output['logits'])

print(preds.shape)
return preds


# extend CSVLogger to include fold number in log path
class CSVLoggerCVSupport(CSVLogger):
Expand Down

0 comments on commit afb4fbb

Please sign in to comment.