From 419d603d43c88c9df95a2fe6c8e68f35ebebf3a0 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 5 Dec 2024 13:25:02 +0100 Subject: [PATCH] Added new file for Evaluation --- chebai/result/utils.py | 4 +++ eval.py | 73 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 eval.py diff --git a/chebai/result/utils.py b/chebai/result/utils.py index 31063747..b6d56a5e 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -94,6 +94,7 @@ def evaluate_model( Returns: Tensors with predictions and labels. """ + print("Start of evaluate_model") model.eval() collate = data_module.reader.COLLATOR() @@ -157,6 +158,7 @@ def evaluate_model( torch.cat(labels_list), os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), ) + print("End of evaluate_model") def load_results_from_buffer( @@ -172,6 +174,7 @@ def load_results_from_buffer( Returns: Tensors with predictions and labels. """ + print("Start of load_results_from_buffer") preds_list = [] labels_list = [] @@ -208,6 +211,7 @@ def load_results_from_buffer( else: test_labels = None + print("End of load_results_from_buffer") return test_preds, test_labels diff --git a/eval.py b/eval.py new file mode 100644 index 00000000..28157903 --- /dev/null +++ b/eval.py @@ -0,0 +1,73 @@ +import pandas as pd + +from chebai.result.utils import ( + evaluate_model, + load_results_from_buffer, +) +from chebai.result.classification import print_metrics +from chebai.models.electra import Electra +from chebai.preprocessing.datasets.chebi import ChEBIOver50, ChEBIOver100 +import os +import tqdm +import torch +import pickle + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print(DEVICE) + + +# Specify paths and parameters +checkpoint_name = "best_epoch=14_val_loss=0.0017_val_macro-f1=0.9226_val_micro-f1=0.9847.ckpt" +print("checkpoint_name",checkpoint_name) +checkpoint_path = os.path.join("logs/wandb/run-20241128_214007-ukcabied/files/checkpoints", f"{checkpoint_name}.ckpt") +print("checkpoint_path",checkpoint_path) +kind = "test" # Change to "train" or "validation" as needed +buffer_dir = os.path.join("results_buffer", checkpoint_name, kind) +print("buffer_dir",buffer_dir) +batch_size = 10 # Set batch size + +# Load data module +data_module = ChEBIOver100(chebi_version=231) + +model_class = Electra + +# evaluates model, stores results in buffer_dir +model = model_class.load_from_checkpoint(checkpoint_path) +if buffer_dir is None: + preds, labels = evaluate_model( + model, + data_module, + buffer_dir=buffer_dir, + # No need to provide this parameter for Chebi dataset, "kind" parameter should be provided + # filename=data_module.processed_file_names_dict[kind], + batch_size=10, + kind=kind, + ) +else: + evaluate_model( + model, + data_module, + buffer_dir=buffer_dir, + # No need to provide this parameter for Chebi dataset, "kind" parameter should be provided + # filename=data_module.processed_file_names_dict[kind], + batch_size=10, + kind=kind, + ) + # load data from buffer_dir + preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE) + + +# Load classes from the classes.txt +with open(os.path.join(data_module.processed_dir_main, "classes.txt"), "r") as f: + classes = [line.strip() for line in f.readlines()] + + +# output relevant metrics +print_metrics( + preds, + labels.to(torch.int), + DEVICE, + classes=classes, + markdown_output=False, + top_k=10, +)