Skip to content

Commit

Permalink
Merge pull request #87 from yurujaja/best_checkpoint_metric
Browse files Browse the repository at this point in the history
Best checkpoint metric
  • Loading branch information
SebastianHafner authored Oct 8, 2024
2 parents 60b0db5 + 978f016 commit a04f8bc
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
1 change: 1 addition & 0 deletions configs/task/change_detection.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ trainer:
ckpt_interval: 20
eval_interval: 5
log_interval: 5
best_metric_key: IoU
use_wandb: ${use_wandb}

evaluator:
Expand Down
1 change: 1 addition & 0 deletions configs/task/regression.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ trainer:
ckpt_interval: 20
eval_interval: 5
log_interval: 5
best_metric_key: MSE
use_wandb: ${use_wandb}

evaluator:
Expand Down
1 change: 1 addition & 0 deletions configs/task/segmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ trainer:
ckpt_interval: 20
eval_interval: 5
log_interval: 5
best_metric_key: mIoU
use_wandb: ${use_wandb}

evaluator:
Expand Down
26 changes: 20 additions & 6 deletions pangaea/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import os
import pathlib
import time
import numpy as np

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Subset
from pangaea.utils.logger import RunningAverageMeter, sec_to_hm


Expand All @@ -30,6 +31,7 @@ def __init__(
ckpt_interval: int,
eval_interval: int,
log_interval: int,
best_metric_key: str,
):
"""Initialize the Trainer.
Expand All @@ -48,6 +50,7 @@ def __init__(
ckpt_interval (int): interval to save the checkpoint.
eval_interval (int): interval to evaluate the model.
log_interval (int): interval to log the training information.
best_metric_key (str): metric that determines best checkpoints.
"""
self.rank = int(os.environ["RANK"])
self.criterion = criterion
Expand All @@ -65,15 +68,19 @@ def __init__(
self.ckpt_interval = ckpt_interval
self.eval_interval = eval_interval
self.log_interval = log_interval
self.best_metric_key = best_metric_key

self.training_stats = {
name: RunningAverageMeter(length=self.batch_per_epoch)
for name in ["loss", "data_time", "batch_time", "eval_time"]
}
self.training_metrics = {}
self.best_ckpt = None
self.best_metric_key = None
self.best_metric_comp = operator.gt
if isinstance(self.train_loader.dataset, Subset):
self.num_classes = self.train_loader.dataset.dataset.num_classes
else:
self.num_classes = self.train_loader.dataset.num_classes

assert precision in [
"fp32",
Expand Down Expand Up @@ -265,8 +272,11 @@ def set_best_checkpoint(
eval_metrics (dict[str, float | list[float]]): metrics computed by the evaluator on the validation set, keyed by metric name.
epoch (int): number of the epoch.
"""
if self.best_metric_comp(eval_metrics[self.best_metric_key], self.best_metric):
self.best_metric = eval_metrics[self.best_metric_key]
curr_metric = eval_metrics[self.best_metric_key]
if isinstance(curr_metric, list):
curr_metric = curr_metric[0] if self.num_classes == 1 else np.mean(curr_metric)
if self.best_metric_comp(curr_metric, self.best_metric):
self.best_metric = curr_metric
self.best_ckpt = self.get_checkpoint(epoch)

@torch.no_grad()
Expand Down Expand Up @@ -362,6 +372,7 @@ def __init__(
ckpt_interval: int,
eval_interval: int,
log_interval: int,
best_metric_key: str,
):
"""Initialize the Trainer for segmentation task.
Args:
Expand All @@ -379,6 +390,7 @@ def __init__(
ckpt_interval (int): interval to save the checkpoint.
eval_interval (int): interval to evaluate the model.
log_interval (int): interval to log the training information.
best_metric_key (str): metric that determines best checkpoints.
"""
super().__init__(
model=model,
Expand All @@ -395,13 +407,13 @@ def __init__(
ckpt_interval=ckpt_interval,
eval_interval=eval_interval,
log_interval=log_interval,
best_metric_key=best_metric_key,
)

self.training_metrics = {
name: RunningAverageMeter(length=100) for name in ["Acc", "mAcc", "mIoU"]
}
self.best_metric = float("-inf")
self.best_metric_key = "mIoU"
self.best_metric_comp = operator.gt

def compute_loss(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -486,6 +498,7 @@ def __init__(
ckpt_interval: int,
eval_interval: int,
log_interval: int,
best_metric_key: str,
):
"""Initialize the Trainer for regression task.
Args:
Expand All @@ -503,6 +516,7 @@ def __init__(
ckpt_interval (int): interval to save the checkpoint.
eval_interval (int): interval to evaluate the model.
log_interval (int): interval to log the training information.
best_metric_key (str): metric that determines best checkpoints.
"""
super().__init__(
model=model,
Expand All @@ -519,13 +533,13 @@ def __init__(
ckpt_interval=ckpt_interval,
eval_interval=eval_interval,
log_interval=log_interval,
best_metric_key=best_metric_key,
)

self.training_metrics = {
name: RunningAverageMeter(length=100) for name in ["MSE"]
}
self.best_metric = float("inf")
self.best_metric_key = "MSE"
self.best_metric_comp = operator.lt

def compute_loss(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Expand Down

0 comments on commit a04f8bc

Please sign in to comment.