diff --git a/chebai/trainer/InnerCVTrainer.py b/chebai/trainer/InnerCVTrainer.py index 559e51df..d1ccf8a6 100644 --- a/chebai/trainer/InnerCVTrainer.py +++ b/chebai/trainer/InnerCVTrainer.py @@ -57,6 +57,7 @@ def predict_from_file(self, model: LightningModule, checkpoint_path: _PATH, inpu loaded_model= model.__class__.load_from_checkpoint(checkpoint_path) with open(input_path, 'r') as input: smiles_strings = [inp.strip() for inp in input.readlines()] + loaded_model.eval() predictions = self._predict_smiles(loaded_model, smiles_strings) predictions_df = pd.DataFrame(predictions.detach().numpy()) if classes_path is not None: