add / improve classification task evaluation functions
sfluegel committed Dec 13, 2023
1 parent 643ffe8 commit d01a67a
Showing 2 changed files with 125 additions and 23 deletions.
142 changes: 122 additions & 20 deletions chebai/result/classification.py
@@ -9,6 +9,11 @@
import torch
import tqdm

from chebai.models import ChebaiBaseNet
from chebai.preprocessing.datasets import XYBaseDataModule
from typing import Optional
from lightning.fabric.utilities.types import _PATH


def visualise_f1(logs_path):
df = pd.read_csv(os.path.join(logs_path, "metrics.csv"))
@@ -27,42 +32,139 @@ def visualise_f1(logs_path):


# get predictions from model
def evaluate_model(
    model: ChebaiBaseNet,
    data_module: XYBaseDataModule,
    data_path: Optional[_PATH] = None,
    buffer_dir=None,
):
"""Runs model on test set of data_module (or, if data_path is not None, on data set found at data_path).
If buffer_dir is set, results will be saved in buffer_dir. Returns tensors with predictions and labels."""
model.eval()
collate = data_module.reader.COLLATER()
    if data_path is None:
        test_file = data_module.processed_file_names_dict["test"]
        data_path = os.path.join(data_module.processed_dir, test_file)
    data_list = torch.load(data_path)
preds_list = []
labels_list = []
if buffer_dir is not None:
os.makedirs(buffer_dir, exist_ok=True)
save_ind = 0
save_batch_size = 128
n_saved = 0

for row in tqdm.tqdm(data_list):
        collated = collate([row])
        collated.x = collated.to_x(model.device)
        collated.y = collated.to_y(model.device)
        processable_data = model._process_batch(collated, 0)
        model_output = model(processable_data)
preds, labels = model._get_prediction_and_labels(
processable_data, processable_data["labels"], model_output
)
preds_list.append(preds)
labels_list.append(labels)
if buffer_dir is not None:
n_saved += 1
if n_saved >= save_batch_size:
torch.save(
torch.cat(preds_list),
os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
)
torch.save(
torch.cat(labels_list),
os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
)
preds_list = []
labels_list = []
save_ind += 1
n_saved = 0

test_preds = torch.cat(preds_list)
test_labels = torch.cat(labels_list)

return test_preds, test_labels


def load_results_from_buffer(buffer_dir):
"""Load results stored in evaluate_model()"""
preds_list = []
labels_list = []

i = 0
filename = f"preds{i:03d}.pt"
while os.path.isfile(os.path.join(buffer_dir, filename)):
preds_list.append(
torch.load(
os.path.join(buffer_dir, filename),
map_location=torch.device(DEVICE),
)
)
i += 1
filename = f"preds{i:03d}.pt"

i = 0
filename = f"labels{i:03d}.pt"
while os.path.isfile(os.path.join(buffer_dir, filename)):
labels_list.append(
torch.load(
os.path.join(buffer_dir, filename),
map_location=torch.device(DEVICE),
)
)
i += 1
filename = f"labels{i:03d}.pt"

test_preds = torch.cat(preds_list)
test_labels = torch.cat(labels_list)

return test_preds, test_labels


def print_metrics(preds, labels, device, classes=None, top_k=10, markdown_output=False):
"""Prints relevant metrics, including micro and macro F1, recall and precision, best k classes and worst classes."""
f1_macro = MultilabelF1Score(preds.shape[1], average="macro").to(device=device)
f1_micro = MultilabelF1Score(preds.shape[1], average="micro").to(device=device)
    print(
        f"Macro-F1 on test set with {preds.shape[1]} classes: {f1_macro(preds, labels):.3f}"
    )
    print(
        f"Micro-F1 on test set with {preds.shape[1]} classes: {f1_micro(preds, labels):.3f}"
    )
precision_macro = MultilabelPrecision(preds.shape[1], average="macro").to(
device=device
)
precision_micro = MultilabelPrecision(preds.shape[1], average="micro").to(
device=device
)
recall_macro = MultilabelRecall(preds.shape[1], average="macro").to(device=device)
recall_micro = MultilabelRecall(preds.shape[1], average="micro").to(device=device)
    print(f"Macro-Precision: {precision_macro(preds, labels):.3f}")
    print(f"Micro-Precision: {precision_micro(preds, labels):.3f}")
    print(f"Macro-Recall: {recall_macro(preds, labels):.3f}")
    print(f"Micro-Recall: {recall_micro(preds, labels):.3f}")
if markdown_output:
print(
f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall |"
)
print(f"| --- | --- | --- | --- | --- | --- | --- |")
        print(
            f"| | {f1_macro(preds, labels):.3f} | {f1_micro(preds, labels):.3f} | {precision_macro(preds, labels):.3f} | {precision_micro(preds, labels):.3f} | {recall_macro(preds, labels):.3f} | {recall_micro(preds, labels):.3f} |"
        )

classwise_f1_fn = MultilabelF1Score(preds.shape[1], average=None).to(device=device)
classwise_f1 = classwise_f1_fn(preds, labels)
best_classwise_f1 = torch.topk(classwise_f1, top_k).indices
worst_classwise_f1 = torch.topk(classwise_f1, top_k, largest=False).indices
print(f"Top {top_k} classes (F1-score):")
for i, best in enumerate(best_classwise_f1):
        print(
            f"{i + 1}. {classes[best] if classes is not None else best} - F1: {classwise_f1[best]:.3f}"
        )
    print(f"Worst {top_k} classes (F1-score):")
    for i, worst in enumerate(worst_classwise_f1):
        print(
            f"{i + 1}. {classes[worst] if classes is not None else worst} - F1: {classwise_f1[worst]:.3f}"
        )

zeros = []
for i, f1 in enumerate(classwise_f1):
if f1 == 0.0:
zeros.append(f"{classes[i] if classes is not None else i}")
print(f'Classes with F1-score = 0: {", ".join(zeros)}')
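Taken together, the new functions are meant to be used in sequence: evaluate a trained model on the test set, optionally buffering predictions to disk, then reload and score them. A minimal sketch of that workflow follows (not part of this commit): the checkpoint path and the ChEBIOver100 data module are illustrative assumptions, and load_results_from_buffer additionally relies on a module-level DEVICE defined earlier in the file.

# Illustrative sketch only; checkpoint path and data module are placeholders.
from chebai.models import electra
from chebai.preprocessing.datasets import ChEBIOver100  # assumed concrete XYBaseDataModule
from chebai.result.classification import (
    evaluate_model,
    load_results_from_buffer,
    print_metrics,
)

model = electra.Electra.load_from_checkpoint("logs/best.ckpt")  # hypothetical path
data_module = ChEBIOver100()
buffer_dir = "eval_buffer"

# Run the test set through the model; predictions and labels are written to buffer_dir in chunks.
evaluate_model(model, data_module, buffer_dir=buffer_dir)

# Read the buffered tensors back and report the metrics.
preds, labels = load_results_from_buffer(buffer_dir)
print_metrics(preds, labels, device="cpu", top_k=10)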
6 changes: 3 additions & 3 deletions configs/training/default_trainer.yml
@@ -2,9 +2,9 @@ min_epochs: 100
max_epochs: 100
default_root_dir: &default_root_dir logs
logger:
  # class_path: lightning.pytorch.loggers.CSVLogger
  # init_args:
  #   save_dir: *default_root_dir
  class_path: chebai.loggers.custom.CustomLogger  # Extension of Wandb logger
  init_args:
    save_dir: *default_root_dir
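For reference, the active logger block above uses Lightning's class_path / init_args convention, so it corresponds to constructing the custom logger directly. A rough Python equivalent, assuming CustomLogger accepts the save_dir keyword that the config passes via init_args:

# Rough equivalent of the YAML `logger` block; per the config comment, CustomLogger
# extends the Wandb logger, and save_dir is the only argument shown in the config.
from chebai.loggers.custom import CustomLogger

logger = CustomLogger(save_dir="logs")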
