From ba50b820bac25cbae2910c49043491065c0de879 Mon Sep 17 00:00:00 2001
From: Andres Espinosa
Date: Tue, 14 Jan 2020 08:57:51 -0300
Subject: [PATCH] Allow saving checkpoints without validation data

---
 machine/trainer/supervised_trainer.py      |  2 +-
 machine/util/callbacks/model_checkpoint.py | 25 +++++++++++++++---------
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/machine/trainer/supervised_trainer.py b/machine/trainer/supervised_trainer.py
index 2734f055..bac30e24 100644
--- a/machine/trainer/supervised_trainer.py
+++ b/machine/trainer/supervised_trainer.py
@@ -221,7 +221,7 @@ def get_optim(optim_name):
 
         callbacks = CallbackContainer(self,
                                       [Logger(),
-                                       ModelCheckpoint(top_k=top_k),
+                                       ModelCheckpoint(top_k=top_k, save_last=dev_data is None),
                                        History()] + custom_callbacks)
 
         logs = self._train_epoches(data, num_epochs,
diff --git a/machine/util/callbacks/model_checkpoint.py b/machine/util/callbacks/model_checkpoint.py
index e451a0fa..95b4669c 100644
--- a/machine/util/callbacks/model_checkpoint.py
+++ b/machine/util/callbacks/model_checkpoint.py
@@ -11,14 +11,19 @@ class ModelCheckpoint(Callback):
     Model checkpoint to save weights during training.
     This callback is automatically applied
     for every model that is trained with the SupervisedTrainer.
+
+    Args:
+        save_last (optional, bool): if True, save the last top_k models
+            instead of the best top_k models
     """
 
     def __init__(self, top_k=5, monitor='val',
-                 save_best_only=True):
+                 save_last=False):
         super(ModelCheckpoint, self).__init__()
         self.top_k = top_k
         self.monitor = monitor
-        self.save_best_only = save_best_only
+        self.save_last = save_last
+        self.next_index = 1
 
     def set_trainer(self, trainer):
         self.trainer = trainer
@@ -43,14 +48,18 @@ def on_batch_end(self, batch, info=None):
 
         max_eval_loss = max(self.loss_best)
 
-        if total_loss < max_eval_loss:
-            index_max = self.loss_best.index(max_eval_loss)
+        if total_loss < max_eval_loss or self.save_last:
+            if self.save_last:
+                index_to_overwrite = self.next_index
+                self.next_index = (self.next_index + 1) % self.top_k
+            else:
+                index_to_overwrite = self.loss_best.index(max_eval_loss)
             # rm prev model
-            if self.best_checkpoints[index_max] is not None:
+            if self.best_checkpoints[index_to_overwrite] is not None:
                 shutil.rmtree(os.path.join(
-                    self.expt_dir, self.best_checkpoints[index_max]))
-            self.best_checkpoints[index_max] = model_name
-            self.loss_best[index_max] = total_loss
+                    self.expt_dir, self.best_checkpoints[index_to_overwrite]))
+            self.best_checkpoints[index_to_overwrite] = model_name
+            self.loss_best[index_to_overwrite] = total_loss
 
             # save model
             Checkpoint(model=self.trainer.model,
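
A note on the rotation that save_last introduces: when enabled, on_batch_end no
longer evicts the worst-loss checkpoint but overwrites slots in round-robin
order via next_index. The standalone sketch below is hypothetical (RollingSlots
and slot_to_overwrite are illustrative names, not part of the patch); it
isolates only the index arithmetic, keeping the patch's starting value of 1,
which leaves slot 0 untouched until the rotation wraps around.

    # Hypothetical sketch: isolates the round-robin slot arithmetic that
    # save_last adds to ModelCheckpoint.on_batch_end; not part of the patch.
    class RollingSlots:
        def __init__(self, top_k=5):
            self.top_k = top_k
            self.next_index = 1  # mirrors the patch; slot 0 is skipped until wrap-around

        def slot_to_overwrite(self):
            # Return the current slot, then advance cyclically so the oldest
            # of the last top_k checkpoints is replaced on the next save.
            index = self.next_index
            self.next_index = (self.next_index + 1) % self.top_k
            return index

    slots = RollingSlots(top_k=3)
    print([slots.slot_to_overwrite() for _ in range(7)])
    # -> [1, 2, 0, 1, 2, 0, 1]

Since the trainer passes save_last=dev_data is None, the rotation only applies
when no validation set is available; when dev_data is given, the existing
best-top_k-by-validation-loss behaviour is unchanged.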