diff --git a/chebai/result/classification.py b/chebai/result/classification.py
index 914d1c3e..4c0ed0cf 100644
--- a/chebai/result/classification.py
+++ b/chebai/result/classification.py
@@ -9,6 +9,11 @@
 import torch
 import tqdm
 
+from chebai.models import ChebaiBaseNet
+from chebai.preprocessing.datasets import XYBaseDataModule
+from typing import Optional
+from lightning.fabric.utilities.types import _PATH
+
 
 def visualise_f1(logs_path):
     df = pd.read_csv(os.path.join(logs_path, "metrics.csv"))
@@ -27,42 +32,139 @@ def visualise_f1(logs_path):
 
 # get predictions from model
-def evaluate_model(logs_base_path, model_filename, data_module):
-    model = electra.Electra.load_from_checkpoint(
-        os.path.join(
-            logs_base_path,
-            "best_epoch=85_val_loss=0.0147_val_micro-f1=0.90.ckpt",
-            model_filename,
-        )
-    )
-    assert isinstance(model, electra.Electra)
+def evaluate_model(
+    model: ChebaiBaseNet,
+    data_module: XYBaseDataModule,
+    data_path: Optional[_PATH] = None,
+    buffer_dir=None,
+):
+    """Runs the model on the test set of data_module (or, if data_path is not None, on the data set found at
+    data_path). If buffer_dir is set, results will be saved in buffer_dir. Returns tensors with predictions and labels."""
+    model.eval()
     collate = data_module.reader.COLLATER()
 
-    test_file = "test.pt"
-    data_path = os.path.join(data_module.processed_dir, test_file)
+    if data_path is None:
+        test_file = data_module.processed_file_names_dict["test"]
+        data_path = os.path.join(data_module.processed_dir, test_file)
     data_list = torch.load(data_path)
     preds_list = []
     labels_list = []
+    if buffer_dir is not None:
+        os.makedirs(buffer_dir, exist_ok=True)
+    save_ind = 0
+    save_batch_size = 128
+    n_saved = 0
 
     for row in tqdm.tqdm(data_list):
-        processable_data = model._process_batch(collate([row]), 0)
-        model_output = model(processable_data, **processable_data["model_kwargs"])
+        collated = collate([row])
+        collated.x = collated.to_x(model.device)
+        collated.y = collated.to_y(model.device)
+        processable_data = model._process_batch(collated, 0)
+        model_output = model(processable_data)
         preds, labels = model._get_prediction_and_labels(
             processable_data, processable_data["labels"], model_output
         )
         preds_list.append(preds)
         labels_list.append(labels)
+        if buffer_dir is not None:
+            n_saved += 1
+            if n_saved >= save_batch_size:
+                torch.save(
+                    torch.cat(preds_list),
+                    os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
+                )
+                torch.save(
+                    torch.cat(labels_list),
+                    os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
+                )
+                preds_list = []
+                labels_list = []
+                save_ind += 1
+                n_saved = 0
 
     test_preds = torch.cat(preds_list)
     test_labels = torch.cat(labels_list)
-    print(test_preds.shape)
-    print(test_labels.shape)
-    test_loss = ElectraPreLoss()
-    print(f"Loss on test set: {test_loss(test_preds, test_labels)}")
-    f1_macro = MultilabelF1Score(test_preds.shape[1], average="macro")
-    f1_micro = MultilabelF1Score(test_preds.shape[1], average="micro")
+
+    return test_preds, test_labels
+
+
+def load_results_from_buffer(buffer_dir, device):
+    """Loads results stored by evaluate_model()."""
+    preds_list = []
+    labels_list = []
+
+    i = 0
+    filename = f"preds{i:03d}.pt"
+    while os.path.isfile(os.path.join(buffer_dir, filename)):
+        preds_list.append(
+            torch.load(
+                os.path.join(buffer_dir, filename),
+                map_location=torch.device(device),
+            )
+        )
+        i += 1
+        filename = f"preds{i:03d}.pt"
+
+    i = 0
+    filename = f"labels{i:03d}.pt"
+    while os.path.isfile(os.path.join(buffer_dir, filename)):
+        labels_list.append(
+            torch.load(
+                os.path.join(buffer_dir, filename),
+                map_location=torch.device(device),
+            )
+        )
+        i += 1
+        filename = f"labels{i:03d}.pt"
+
+    test_preds = torch.cat(preds_list)
+    test_labels = torch.cat(labels_list)
+
+    return test_preds, test_labels
+
+
+def print_metrics(preds, labels, device, classes=None, top_k=10, markdown_output=False):
+    """Prints relevant metrics, including micro and macro F1, recall and precision,
+    the best k classes by F1-score and the classes with an F1-score of 0."""
+    f1_macro = MultilabelF1Score(preds.shape[1], average="macro").to(device=device)
+    f1_micro = MultilabelF1Score(preds.shape[1], average="micro").to(device=device)
     print(
-        f"Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}"
+        f"Macro-F1 on test set with {preds.shape[1]} classes: {f1_macro(preds, labels):3f}"
     )
     print(
-        f"Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}"
+        f"Micro-F1 on test set with {preds.shape[1]} classes: {f1_micro(preds, labels):3f}"
+    )
+    precision_macro = MultilabelPrecision(preds.shape[1], average="macro").to(
+        device=device
     )
+    precision_micro = MultilabelPrecision(preds.shape[1], average="micro").to(
+        device=device
+    )
+    recall_macro = MultilabelRecall(preds.shape[1], average="macro").to(device=device)
+    recall_micro = MultilabelRecall(preds.shape[1], average="micro").to(device=device)
+    print(f"Macro-Precision: {precision_macro(preds, labels):3f}")
+    print(f"Micro-Precision: {precision_micro(preds, labels):3f}")
+    print(f"Macro-Recall: {recall_macro(preds, labels):3f}")
+    print(f"Micro-Recall: {recall_micro(preds, labels):3f}")
+    if markdown_output:
+        print(
+            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall |"
+        )
+        print(f"| --- | --- | --- | --- | --- | --- | --- |")
+        print(
+            f"| | {f1_macro(preds, labels):3f} | {f1_micro(preds, labels):3f} | {precision_macro(preds, labels):3f} | {precision_micro(preds, labels):3f} | {recall_macro(preds, labels):3f} | {recall_micro(preds, labels):3f} |"
+        )
+
+    classwise_f1_fn = MultilabelF1Score(preds.shape[1], average=None).to(device=device)
+    classwise_f1 = classwise_f1_fn(preds, labels)
+    best_classwise_f1 = torch.topk(classwise_f1, top_k).indices
+    worst_classwise_f1 = torch.topk(classwise_f1, top_k, largest=False).indices
+    print(f"Top {top_k} classes (F1-score):")
+    for i, best in enumerate(best_classwise_f1):
+        print(
+            f"{i + 1}. {classes[best] if classes is not None else best} - F1: {classwise_f1[best]:3f}"
+        )
+
+    zeros = []
+    for i, f1 in enumerate(classwise_f1):
+        if f1 == 0.0:
+            zeros.append(f"{classes[i] if classes is not None else i}")
+    print(f'Classes with F1-score = 0: {", ".join(zeros)}')
diff --git a/configs/training/default_trainer.yml b/configs/training/default_trainer.yml
index bcdebf66..63bf0724 100644
--- a/configs/training/default_trainer.yml
+++ b/configs/training/default_trainer.yml
@@ -2,9 +2,9 @@ min_epochs: 100
 max_epochs: 100
 default_root_dir: &default_root_dir logs
 logger:
-#  class_path: chebai.loggers.custom.CSVLogger
-#  init_args:
-#    save_dir: *default_root_dir
+#  class_path: lightning.pytorch.loggers.CSVLogger
+#  init_args:
+#    save_dir: *default_root_dir
   class_path: chebai.loggers.custom.CustomLogger  # Extension of Wandb logger
   init_args:
     save_dir: *default_root_dir