Skip to content

Commit

Permalink
Add max_ckpt_keep for trainer (#3441)
Browse the repository at this point in the history
Signed-off-by: Duo <[email protected]>
Co-authored-by: Jinzhe Zeng <[email protected]>
(cherry picked from commit fd82f04)

--------

Code for PyTorch is removed.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
iProzd authored and njzjz committed Apr 6, 2024
1 parent 99a2d44 commit a788daa
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 4 additions & 1 deletion deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def get_lr_and_coef(lr_param):
self.disp_freq = tr_data.get("disp_freq", 1000)
self.save_freq = tr_data.get("save_freq", 1000)
self.save_ckpt = tr_data.get("save_ckpt", "model.ckpt")
self.max_ckpt_keep = tr_data.get("max_ckpt_keep", 5)
self.display_in_training = tr_data.get("disp_training", True)
self.timing_in_training = tr_data.get("time_training", True)
self.profiling = self.run_opt.is_chief and tr_data.get("profiling", False)
Expand Down Expand Up @@ -493,7 +494,9 @@ def _init_session(self):
# Initializes or restore global variables
init_op = tf.global_variables_initializer()
if self.run_opt.is_chief:
self.saver = tf.train.Saver(save_relative_paths=True)
self.saver = tf.train.Saver(
save_relative_paths=True, max_to_keep=self.max_ckpt_keep
)
if self.run_opt.init_mode == "init_from_scratch":
log.info("initialize model from scratch")
run_sess(self.sess, init_op)
Expand Down
6 changes: 6 additions & 0 deletions deepmd_utils/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,6 +1681,11 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
doc_disp_freq = "The frequency of printing learning curve."
doc_save_freq = "The frequency of saving check point."
doc_save_ckpt = "The path prefix of saving check point files."
doc_max_ckpt_keep = (
"The maximum number of checkpoints to keep. "
"The oldest checkpoints will be deleted once the number of checkpoints exceeds max_ckpt_keep. "
"Defaults to 5."
)
doc_disp_training = "Displaying verbose information during training."
doc_time_training = "Timing durining training."
doc_profiling = "Profiling during training."
Expand Down Expand Up @@ -1722,6 +1727,7 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
Argument(
"save_ckpt", str, optional=True, default="model.ckpt", doc=doc_save_ckpt
),
Argument("max_ckpt_keep", int, optional=True, default=5, doc=doc_max_ckpt_keep),
Argument(
"disp_training", bool, optional=True, default=True, doc=doc_disp_training
),
Expand Down

0 comments on commit a788daa

Please sign in to comment.