From 5b24626256fb0a7b549161dd208e396e21afd839 Mon Sep 17 00:00:00 2001 From: Yuru Jia <91590963+yurujaja@users.noreply.github.com> Date: Fri, 8 Nov 2024 16:05:32 +0100 Subject: [PATCH] enable selection between the final and best checkpoint (#117) --- configs/test.yaml | 1 + configs/train.yaml | 1 + pangaea/run.py | 9 +++++++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/configs/test.yaml b/configs/test.yaml index 1b4e5f53..f3965f28 100644 --- a/configs/test.yaml +++ b/configs/test.yaml @@ -9,6 +9,7 @@ batch_size: 8 # EXPERIMENT +use_final_ckpt: false finetune: false ckpt_dir: ??? diff --git a/configs/train.yaml b/configs/train.yaml index a1ea932c..28b7fa7d 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -19,6 +19,7 @@ limited_label_val: 1 limited_label_strategy: stratified # Options: stratified, oversampled, random stratification_bins: 3 # number of bins for stratified sampling, only for stratified data_replicate: 1 +use_final_ckpt: false defaults: diff --git a/pangaea/run.py b/pangaea/run.py index 82f2c9af..d6f4ec19 100644 --- a/pangaea/run.py +++ b/pangaea/run.py @@ -24,6 +24,7 @@ from pangaea.utils.utils import ( fix_seed, get_best_model_ckpt_path, + get_final_model_ckpt_path, get_generator, seed_worker, ) @@ -278,8 +279,12 @@ def main(cfg: DictConfig) -> None: test_evaluator: Evaluator = instantiate( cfg.task.evaluator, val_loader=test_loader, exp_dir=exp_dir, device=device ) - best_model_ckpt_path = get_best_model_ckpt_path(exp_dir) - test_evaluator.evaluate(decoder, "best_model", best_model_ckpt_path) + + if cfg.use_final_ckpt: + model_ckpt_path = get_final_model_ckpt_path(exp_dir) + else: + model_ckpt_path = get_best_model_ckpt_path(exp_dir) + test_evaluator.evaluate(decoder, "test_model", model_ckpt_path) if cfg.use_wandb and rank == 0: wandb.finish()