Skip to content

Commit

Permalink
enable choosing from final or best ckpt
Browse files Browse the repository at this point in the history
  • Loading branch information
yurujaja committed Nov 8, 2024
1 parent ecfefa1 commit dc90879
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
1 change: 1 addition & 0 deletions configs/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ batch_size: 8


# EXPERIMENT
use_final_ckpt: false
finetune: false
ckpt_dir: ???

Expand Down
1 change: 1 addition & 0 deletions configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limited_label_val: 1
limited_label_strategy: stratified # Options: stratified, oversampled, random
stratification_bins: 3 # number of bins for stratified sampling, only for stratified
data_replicate: 1
use_final_ckpt: false


defaults:
Expand Down
9 changes: 7 additions & 2 deletions pangaea/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pangaea.utils.utils import (
fix_seed,
get_best_model_ckpt_path,
get_final_model_ckpt_path,
get_generator,
seed_worker,
)
Expand Down Expand Up @@ -278,8 +279,12 @@ def main(cfg: DictConfig) -> None:
test_evaluator: Evaluator = instantiate(
cfg.task.evaluator, val_loader=test_loader, exp_dir=exp_dir, device=device
)
best_model_ckpt_path = get_best_model_ckpt_path(exp_dir)
test_evaluator.evaluate(decoder, "best_model", best_model_ckpt_path)

if cfg.use_final_ckpt:
model_ckpt_path = get_final_model_ckpt_path(exp_dir)
else:
model_ckpt_path = get_best_model_ckpt_path(exp_dir)
test_evaluator.evaluate(decoder, "test_model", model_ckpt_path)

if cfg.use_wandb and rank == 0:
wandb.finish()
Expand Down

0 comments on commit dc90879

Please sign in to comment.