diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py
index 5e6ec7e..0455c4b 100644
--- a/pangaea/engine/evaluator.py
+++ b/pangaea/engine/evaluator.py
@@ -290,7 +290,7 @@ def format_metric(name, values, mean_value):
         self.logger.info(recall_str)
         self.logger.info(macc_str)
 
-        if self.use_wandb:
+        if self.use_wandb and self.rank == 0:
             wandb.log(
                 {
                     f"{self.split}_mIoU": metrics["mIoU"],
@@ -400,5 +400,5 @@ def log_metrics(self, metrics):
         rmse = "-------------------\n" + 'RMSE \t{:>7}'.format('%.3f' % metrics['RMSE'])
         self.logger.info(header + mse + rmse)
 
-        if self.use_wandb:
+        if self.use_wandb and self.rank == 0:
             wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]})
diff --git a/pangaea/engine/trainer.py b/pangaea/engine/trainer.py
index 5aebcda..2291f0a 100644
--- a/pangaea/engine/trainer.py
+++ b/pangaea/engine/trainer.py
@@ -76,7 +76,6 @@ def __init__(
             for name in ["loss", "data_time", "batch_time", "eval_time"]
         }
         self.training_metrics = {}
-        self.best_ckpt = None
         self.best_metric_comp = operator.gt
 
         self.num_classes = self.train_loader.dataset.num_classes
@@ -105,7 +104,7 @@ def train(self) -> None:
             if epoch % self.eval_interval == 0:
                 metrics, used_time = self.evaluator(self.model, f"epoch {epoch}")
                 self.training_stats["eval_time"].update(used_time)
-                self.set_best_checkpoint(metrics, epoch)
+                self.save_best_checkpoint(metrics, epoch)
 
             self.logger.info("============ Starting epoch %i ... ============" % epoch)
             # set sampler
@@ -117,17 +116,11 @@ def train(self) -> None:
 
         metrics, used_time = self.evaluator(self.model, "final model")
         self.training_stats["eval_time"].update(used_time)
-        self.set_best_checkpoint(metrics, self.n_epochs)
+        self.save_best_checkpoint(metrics, self.n_epochs)
 
         # save last model
         self.save_model(self.n_epochs, is_final=True)
 
-        # save best model
-        if self.best_ckpt:
-            self.save_model(
-                self.best_ckpt["epoch"], is_best=True, checkpoint=self.best_ckpt
-            )
-
     def train_one_epoch(self, epoch: int) -> None:
         """Train model for one epoch.
 
@@ -186,7 +179,7 @@ def train_one_epoch(self, epoch: int) -> None:
             end_time = time.time()
 
     def get_checkpoint(self, epoch: int) -> dict[str, dict | int]:
-        """Create a checkpoint dictionary.
+        """Create a checkpoint dictionary, containing references to the pytorch tensors.
 
         Args:
             epoch (int): number of the epoch.
@@ -201,7 +194,7 @@ def get_checkpoint(self, epoch: int) -> dict[str, dict | int]:
             "scaler": self.scaler.state_dict(),
             "epoch": epoch,
         }
-        return copy.deepcopy(checkpoint)
+        return checkpoint
 
     def save_model(
         self,
@@ -222,8 +215,8 @@ def save_model(
             torch.distributed.barrier()
             return
         checkpoint = self.get_checkpoint(epoch) if checkpoint is None else checkpoint
-        suffix = "_best" if is_best else "_final" if is_final else ""
-        checkpoint_path = os.path.join(self.exp_dir, f"checkpoint_{epoch}{suffix}.pth")
+        suffix = "_best" if is_best else f"{epoch}_final" if is_final else f"{epoch}"
+        checkpoint_path = os.path.join(self.exp_dir, f"checkpoint_{suffix}.pth")
         torch.save(checkpoint, checkpoint_path)
         self.logger.info(
             f"Epoch {epoch} | Training checkpoint saved at {checkpoint_path}"
@@ -267,7 +260,7 @@ def compute_loss(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tens
         """
        raise NotImplementedError
 
-    def set_best_checkpoint(
+    def save_best_checkpoint(
         self, eval_metrics: dict[float, list[float]], epoch: int
     ) -> None:
         """Update the best checkpoint according to the evaluation metrics.
@@ -281,7 +274,10 @@ def set_best_checkpoint(
         curr_metric = curr_metric[0] if self.num_classes == 1 else np.mean(curr_metric)
         if self.best_metric_comp(curr_metric, self.best_metric):
             self.best_metric = curr_metric
-            self.best_ckpt = self.get_checkpoint(epoch)
+            best_ckpt = self.get_checkpoint(epoch)
+            self.save_model(
+                epoch, is_best=True, checkpoint=best_ckpt
+            )
 
     @torch.no_grad()
     def compute_logging_metrics(