add epoch-level macro-f1
sfluegel committed Nov 24, 2023
1 parent 2aae072 commit 684bdf7
Showing 7 changed files with 75 additions and 8 deletions.
Empty file added chebai/callbacks/__init__.py
58 changes: 58 additions & 0 deletions chebai/callbacks/epoch_metrics.py
@@ -0,0 +1,58 @@
+from typing import Any
+
+import lightning as pl
+import torch
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torchmetrics.classification import MultilabelF1Score
+
+
+class _EpochLevelMetric(Callback):
+    """Applies a metric to data from a whole training epoch, instead of batch-wise (the default in Lightning)."""
+
+    def __init__(self, num_labels):
+        super().__init__()
+        self.train_labels, self.val_labels = None, None
+        self.train_preds, self.val_preds = None, None
+        self.num_labels = num_labels
+
+    @property
+    def metric_name(self):
+        raise NotImplementedError
+
+    def apply_metric(self, target, pred):
+        raise NotImplementedError
+
+    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        # accumulate on the module's device: one row per sample, one column per label
+        self.train_labels = torch.empty((0, self.num_labels), dtype=torch.int, device=pl_module.device)
+        self.train_preds = torch.empty((0, self.num_labels), device=pl_module.device)
+
+    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT,
+                           batch: Any, batch_idx: int) -> None:
+        self.train_labels = torch.cat((self.train_labels, outputs['labels'].int()))
+        self.train_preds = torch.cat((self.train_preds, outputs['preds']))
+
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        pl_module.log(f'train_{self.metric_name}', self.apply_metric(self.train_labels, self.train_preds))
+
+    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self.val_labels = torch.empty((0, self.num_labels), dtype=torch.int, device=pl_module.device)
+        self.val_preds = torch.empty((0, self.num_labels), device=pl_module.device)
+
+    def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT,
+                                batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        self.val_labels = torch.cat((self.val_labels, outputs['labels'].int()))
+        self.val_preds = torch.cat((self.val_preds, outputs['preds']))
+
+    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        pl_module.log(f'val_{self.metric_name}', self.apply_metric(self.val_labels, self.val_preds))
+
+
+class EpochLevelMacroF1(_EpochLevelMetric):
+
+    @property
+    def metric_name(self):
+        return 'ep_macro-f1'
+
+    def apply_metric(self, target, pred):
+        # torchmetrics metrics take (preds, target); move the metric to the data's device
+        f1 = MultilabelF1Score(num_labels=self.num_labels, average='macro').to(pred.device)
+        return f1(pred, target)
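
For orientation, this is how a second epoch-level metric could be derived from _EpochLevelMetric: a subclass only supplies metric_name and apply_metric. EpochLevelMicroF1 below is a hypothetical illustration, not part of this commit:

    from torchmetrics.classification import MultilabelF1Score

    class EpochLevelMicroF1(_EpochLevelMetric):
        """Hypothetical example: micro-averaged F1 over a whole epoch."""

        @property
        def metric_name(self):
            return 'ep_micro-f1'

        def apply_metric(self, target, pred):
            f1 = MultilabelF1Score(num_labels=self.num_labels, average='micro').to(pred.device)
            return f1(pred, target)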
1 change: 1 addition & 0 deletions chebai/cli.py
@@ -16,6 +16,7 @@ def add_arguments_to_parser(self, parser):
"model.init_args.out_dim",
f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
)
parser.link_arguments("model.init_args.out_dim", "trainer.callbacks.num_labels")
# parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why

    @staticmethod
4 changes: 2 additions & 2 deletions chebai/models/base.py
@@ -78,7 +78,8 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
        data = self._process_batch(batch, batch_idx)
        labels = data["labels"]
        model_output = self(data, **data.get("model_kwargs", dict()))
-        d = dict(data=data, labels=labels, output=model_output)
+        pr, tar = self._get_prediction_and_labels(data, labels, model_output)
+        d = dict(data=data, labels=labels, output=model_output, preds=pr)
        if log:
            if self.criterion is not None:
                loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss(
@@ -100,7 +101,6 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
                sync_dist=sync_dist,
            )
        if metrics and labels is not None:
-            pr, tar = self._get_prediction_and_labels(data, labels, model_output)
            for metric_name, metric in metrics.items():
                m = metric(pr, tar)
                if isinstance(m, dict):
8 changes: 7 additions & 1 deletion chebai/models/electra.py
@@ -198,7 +198,13 @@ def _get_prediction_and_labels(self, data, labels, model_output):

    def forward(self, data, **kwargs):
        self.batch_size = data["features"].shape[0]
-        inp = self.electra.embeddings.forward(data["features"])
+        try:
+            inp = self.electra.embeddings.forward(data["features"])
+        except RuntimeError as e:
+            # dump the offending batch before re-raising, so the failure stays debuggable
+            print('RuntimeError at forward')
+            print(f'data: {data}')
+            print(f'data[features]: {data["features"]}')
+            raise e
        inp = self.word_dropout(inp)
        electra = self.electra(inputs_embeds=inp, **kwargs)
        d = electra.last_hidden_state[:, 0, :]
9 changes: 5 additions & 4 deletions configs/metrics/micro-macro-f1.yml
@@ -5,7 +5,8 @@ init_args:
      class_path: torchmetrics.classification.MultilabelF1Score
      init_args:
        average: micro
-    macro-f1:
-      class_path: torchmetrics.classification.MultilabelF1Score
-      init_args:
-        average: macro
+    # not functioning, results are calculated batch-wise instead of epoch-wise
+    #macro-f1:
+    #  class_path: torchmetrics.classification.MultilabelF1Score
+    #  init_args:
+    #    average: macro
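
The entry is commented out because, logged this way, macro-F1 is computed per batch and the batch scores are averaged, which is not the same quantity as macro-F1 over the whole epoch. A toy illustration (not from the repository):

    import torch
    from torchmetrics.classification import MultilabelF1Score

    f1 = MultilabelF1Score(num_labels=2, average='macro')
    batches = [
        (torch.tensor([[1, 0]]), torch.tensor([[1, 1]])),  # (preds, target), batch 1
        (torch.tensor([[0, 1]]), torch.tensor([[1, 1]])),  # batch 2
    ]
    batch_scores = [f1(preds, target) for preds, target in batches]
    print(sum(batch_scores) / len(batch_scores))  # tensor(0.5000), mean of per-batch scores

    # epoch-level: concatenate first, then evaluate once
    preds = torch.cat([p for p, _ in batches])
    target = torch.cat([t for _, t in batches])
    print(MultilabelF1Score(num_labels=2, average='macro')(preds, target))  # tensor(0.6667)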
3 changes: 2 additions & 1 deletion configs/training/default_callbacks.yml
@@ -9,4 +9,5 @@
  init_args:
    filename: 'per_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}'
    every_n_epochs: 5
-    save_top_k: -1
+    save_top_k: -1
+- class_path: chebai.callbacks.epoch_metrics.EpochLevelMacroF1
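
For reference, a minimal sketch of what this callback entry amounts to when a Trainer is built by hand. NUM_LABELS is a placeholder for the value that the new link_arguments call is meant to derive from model.init_args.out_dim:

    from lightning.pytorch import Trainer
    from chebai.callbacks.epoch_metrics import EpochLevelMacroF1

    NUM_LABELS = 500  # placeholder; in the CLI setup this comes from model.init_args.out_dim
    trainer = Trainer(callbacks=[EpochLevelMacroF1(num_labels=NUM_LABELS)])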
