Merge pull request #113 from yurujaja/best_checkpoint_memory
Save best checkpoint to disk instead of keeping it in memory
yurujaja authored Oct 31, 2024
2 parents a4f848f + 8497d6e commit fcd7560
Showing 2 changed files with 13 additions and 17 deletions.
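In short: get_checkpoint previously returned copy.deepcopy(checkpoint), and the trainer held that copy in self.best_ckpt until the end of training, so the best checkpoint's state dicts were duplicated in memory for the whole run. After this change the checkpoint is written to disk the moment it becomes the best. A minimal sketch of the before/after pattern (standalone, with hypothetical model/optimizer arguments and a hypothetical file name; not the repository's code):

import copy
import os
import torch

def best_checkpoint_in_memory(model, optimizer, epoch):
    # Old pattern: deepcopy duplicates every tensor in the state dicts and
    # keeps the copy alive in memory until training ends
    return copy.deepcopy({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    })

def best_checkpoint_on_disk(model, optimizer, epoch, exp_dir):
    # New pattern: serialize immediately; nothing but the file needs to persist
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    path = os.path.join(exp_dir, "checkpoint_best.pth")
    torch.save(checkpoint, path)
    return path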
pangaea/engine/evaluator.py (4 changes: 2 additions & 2 deletions)

@@ -290,7 +290,7 @@ def format_metric(name, values, mean_value):
self.logger.info(recall_str)
self.logger.info(macc_str)

- if self.use_wandb:
+ if self.use_wandb and self.rank == 0:
wandb.log(
{
f"{self.split}_mIoU": metrics["mIoU"],
@@ -400,5 +400,5 @@ def log_metrics(self, metrics):
rmse = "-------------------\n" + 'RMSE \t{:>7}'.format('%.3f' % metrics['RMSE'])
self.logger.info(header + mse + rmse)

- if self.use_wandb:
+ if self.use_wandb and self.rank == 0:
wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]})
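Both evaluator hunks add the same guard: under distributed training every process runs the evaluation loop, so without the rank check each GPU would push an identical entry to wandb. A minimal sketch of the guarded logging (how self.rank is populated in the evaluator is not shown here, so deriving the rank from torch.distributed is an assumption):

import torch.distributed as dist
import wandb

def log_metrics_once(metrics: dict, split: str, use_wandb: bool) -> None:
    # Only rank 0 reports, so each metric shows up once rather than once per process
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if use_wandb and rank == 0:
        wandb.log({f"{split}_mIoU": metrics["mIoU"]})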
pangaea/engine/trainer.py (26 changes: 11 additions & 15 deletions)

@@ -76,7 +76,6 @@ def __init__(
for name in ["loss", "data_time", "batch_time", "eval_time"]
}
self.training_metrics = {}
- self.best_ckpt = None
self.best_metric_comp = operator.gt
self.num_classes = self.train_loader.dataset.num_classes

@@ -105,7 +104,7 @@ def train(self) -> None:
if epoch % self.eval_interval == 0:
metrics, used_time = self.evaluator(self.model, f"epoch {epoch}")
self.training_stats["eval_time"].update(used_time)
- self.set_best_checkpoint(metrics, epoch)
+ self.save_best_checkpoint(metrics, epoch)

self.logger.info("============ Starting epoch %i ... ============" % epoch)
# set sampler
@@ -117,17 +116,11 @@ def train(self) -> None:

metrics, used_time = self.evaluator(self.model, "final model")
self.training_stats["eval_time"].update(used_time)
- self.set_best_checkpoint(metrics, self.n_epochs)
+ self.save_best_checkpoint(metrics, self.n_epochs)

# save last model
self.save_model(self.n_epochs, is_final=True)

- # save best model
- if self.best_ckpt:
-     self.save_model(
-         self.best_ckpt["epoch"], is_best=True, checkpoint=self.best_ckpt
-     )

def train_one_epoch(self, epoch: int) -> None:
"""Train model for one epoch.
@@ -186,7 +179,7 @@ def train_one_epoch(self, epoch: int) -> None:
end_time = time.time()

def get_checkpoint(self, epoch: int) -> dict[str, dict | int]:
"""Create a checkpoint dictionary.
"""Create a checkpoint dictionary, containing references to the pytorch tensors.
Args:
epoch (int): number of the epoch.
@@ -201,7 +194,7 @@ def get_checkpoint(self, epoch: int) -> dict[str, dict | int]:
"scaler": self.scaler.state_dict(),
"epoch": epoch,
}
- return copy.deepcopy(checkpoint)
+ return checkpoint

def save_model(
self,
@@ -222,8 +215,8 @@ def save_model(
torch.distributed.barrier()
return
checkpoint = self.get_checkpoint(epoch) if checkpoint is None else checkpoint
suffix = "_best" if is_best else "_final" if is_final else ""
checkpoint_path = os.path.join(self.exp_dir, f"checkpoint_{epoch}{suffix}.pth")
suffix = "_best" if is_best else f"{epoch}_final" if is_final else f"{epoch}"
checkpoint_path = os.path.join(self.exp_dir, f"checkpoint_{suffix}.pth")
torch.save(checkpoint, checkpoint_path)
self.logger.info(
f"Epoch {epoch} | Training checkpoint saved at {checkpoint_path}"
@@ -267,7 +260,7 @@ def compute_loss(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
raise NotImplementedError

- def set_best_checkpoint(
+ def save_best_checkpoint(
self, eval_metrics: dict[float, list[float]], epoch: int
) -> None:
"""Update the best checkpoint according to the evaluation metrics.
@@ -281,7 +274,10 @@ def set_best_checkpoint(
curr_metric = curr_metric[0] if self.num_classes == 1 else np.mean(curr_metric)
if self.best_metric_comp(curr_metric, self.best_metric):
self.best_metric = curr_metric
- self.best_ckpt = self.get_checkpoint(epoch)
+ best_ckpt = self.get_checkpoint(epoch)
+ self.save_model(
+     epoch, is_best=True, checkpoint=best_ckpt
+ )

@torch.no_grad()
def compute_logging_metrics(

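Taken together, the trainer no longer stashes a deep-copied best checkpoint in self.best_ckpt. Whenever an evaluation beats self.best_metric, save_best_checkpoint builds the checkpoint dict from the live state dicts and hands it straight to save_model(..., is_best=True), which writes it to disk on the main process. A condensed sketch of that flow, with simplified names and only the pieces visible in this diff (the initial best_metric value and the rank handling are assumptions):

import operator
import os
import torch

class TrainerSketch:
    def __init__(self, model, optimizer, exp_dir, rank=0):
        self.model, self.optimizer = model, optimizer
        self.exp_dir, self.rank = exp_dir, rank
        self.best_metric = float("-inf")          # assumed: higher is better
        self.best_metric_comp = operator.gt

    def get_checkpoint(self, epoch):
        # References to the live tensors; nothing is copied here
        return {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "epoch": epoch,
        }

    def save_model(self, epoch, is_best=False, is_final=False, checkpoint=None):
        if self.rank != 0:                        # only the main process writes
            return
        checkpoint = checkpoint if checkpoint is not None else self.get_checkpoint(epoch)
        suffix = "_best" if is_best else f"{epoch}_final" if is_final else f"{epoch}"
        torch.save(checkpoint, os.path.join(self.exp_dir, f"checkpoint_{suffix}.pth"))

    def save_best_checkpoint(self, curr_metric, epoch):
        # An improved metric is serialized immediately instead of kept as a deepcopy
        if self.best_metric_comp(curr_metric, self.best_metric):
            self.best_metric = curr_metric
            self.save_model(epoch, is_best=True, checkpoint=self.get_checkpoint(epoch))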