From 1c94003376f3a947ec7ba4e3cde649dc51f6a0af Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sun, 10 Mar 2024 00:50:20 +0800 Subject: [PATCH] Add `max_ckpt_keep` for trainer (#3441) Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> Co-authored-by: Jinzhe Zeng --- deepmd/pt/train/training.py | 10 ++++++++++ deepmd/tf/train/trainer.py | 5 ++++- deepmd/utils/argcheck.py | 6 ++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 62bc5a4c97..fb28f0c4f2 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -132,6 +132,7 @@ def __init__( self.disp_freq = training_params.get("disp_freq", 1000) self.save_ckpt = training_params.get("save_ckpt", "model.ckpt") self.save_freq = training_params.get("save_freq", 1000) + self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5) self.lcurve_should_print_header = True def get_opt_param(params): @@ -924,6 +925,15 @@ def save_model(self, save_path, lr=0.0, step=0): {"model": module.state_dict(), "optimizer": self.optimizer.state_dict()}, save_path, ) + checkpoint_dir = save_path.parent + checkpoint_files = [ + f + for f in checkpoint_dir.glob("*.pt") + if not f.is_symlink() and f.name.startswith(self.save_ckpt) + ] + if len(checkpoint_files) > self.max_ckpt_keep: + checkpoint_files.sort(key=lambda x: x.stat().st_mtime) + checkpoint_files[0].unlink() def get_data(self, is_train=True, task_key="Default"): if not self.multi_task: diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 1dd31fd0bb..27478abaa1 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -164,6 +164,7 @@ def get_lr_and_coef(lr_param): self.disp_freq = tr_data.get("disp_freq", 1000) self.save_freq = tr_data.get("save_freq", 1000) self.save_ckpt = tr_data.get("save_ckpt", "model.ckpt") + self.max_ckpt_keep = tr_data.get("max_ckpt_keep", 5) self.display_in_training = tr_data.get("disp_training", True) self.timing_in_training = tr_data.get("time_training", True) self.profiling = self.run_opt.is_chief and tr_data.get("profiling", False) @@ -498,7 +499,9 @@ def _init_session(self): # Initializes or restore global variables init_op = tf.global_variables_initializer() if self.run_opt.is_chief: - self.saver = tf.train.Saver(save_relative_paths=True) + self.saver = tf.train.Saver( + save_relative_paths=True, max_to_keep=self.max_ckpt_keep + ) if self.run_opt.init_mode == "init_from_scratch": log.info("initialize model from scratch") run_sess(self.sess, init_op) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 5e8db431f8..e822e18d50 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2134,6 +2134,11 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. doc_disp_freq = "The frequency of printing learning curve." doc_save_freq = "The frequency of saving check point." doc_save_ckpt = "The path prefix of saving check point files." + doc_max_ckpt_keep = ( + "The maximum number of checkpoints to keep. " + "The oldest checkpoints will be deleted once the number of checkpoints exceeds max_ckpt_keep. " + "Defaults to 5." + ) doc_disp_training = "Displaying verbose information during training." doc_time_training = "Timing durining training." doc_profiling = "Profiling during training." @@ -2192,6 +2197,7 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. Argument( "save_ckpt", str, optional=True, default="model.ckpt", doc=doc_save_ckpt ), + Argument("max_ckpt_keep", int, optional=True, default=5, doc=doc_max_ckpt_keep), Argument( "disp_training", bool, optional=True, default=True, doc=doc_disp_training ),