From c6a03fc46afb138480231520e0b2ab63703378bf Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Jan 2024 17:46:31 -0500 Subject: [PATCH 1/3] PT: keep the same checkpoint behavior as TF Set the default save_ckpt to `model.ckpt` as the prefix. When saving checkpoints, `model.ckpt-100.pt` will be saved, and `model.ckpt.pt` will be symlinked to `model.ckpt-100.pt`. A `checkpoint` file will be saved to record `model.ckpt-100.pt`. This keeps the same behavior as the TF backend. Signed-off-by: Jinzhe Zeng --- deepmd/common.py | 31 +++++++++++++++++++++++++++++++ deepmd/pt/entrypoints/main.py | 6 +++--- deepmd/pt/train/training.py | 19 +++++++++---------- deepmd/tf/train/trainer.py | 19 ++++--------------- 4 files changed, 47 insertions(+), 28 deletions(-) diff --git a/deepmd/common.py b/deepmd/common.py index f950b50919..05d02234b4 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import glob import json +import os +import platform +import shutil import warnings from pathlib import ( Path, @@ -268,3 +272,30 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype: return np.float64 else: raise RuntimeError(f"{precision} is not a valid precision") + + +def symlink_prefix_files(old_prefix: str, new_prefix: str): + """Create symlinks from old checkpoint prefix to new one. + + On Windows this function will copy files instead of creating symlinks. + + Parameters + ---------- + old_prefix : str + old checkpoint prefix, all files with this prefix will be symlinked + new_prefix : str + new checkpoint prefix + """ + original_files = glob.glob(old_prefix + ".*") + for ori_ff in original_files: + new_ff = new_prefix + ori_ff[len(old_prefix) :] + try: + # remove old one + os.remove(new_ff) + except OSError: + pass + if platform.system() != "Windows": + # by default one does not have access to create symlink on Windows + os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff) + else: + shutil.copyfile(ori_ff, new_ff) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index c5e551ebd8..ad5e92d495 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -308,9 +308,9 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None): test(FLAGS) elif FLAGS.command == "freeze": if Path(FLAGS.checkpoint_folder).is_dir(): - # TODO: automatically generate model.pt during training - # FLAGS.model = str(Path(FLAGS.checkpoint).joinpath("model.pt")) - raise NotImplementedError("Checkpoint should give a file") + checkpoint_path = Path(FLAGS.checkpoint_folder) + latest_ckpt_file = (checkpoint_path / "checkpoint").read_text() + FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file)) else: FLAGS.model = FLAGS.checkpoint_folder FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth")) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 049685a6e3..8ea69c8489 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import os import time from copy import ( deepcopy, @@ -22,6 +21,9 @@ logging_redirect_tqdm, ) +from deepmd.common import ( + symlink_prefix_files, +) from deepmd.pt.loss import ( DenoiseLoss, EnergyStdLoss, @@ -102,7 +104,7 @@ def __init__( self.num_steps = training_params["numb_steps"] self.disp_file = training_params.get("disp_file", "lcurve.out") self.disp_freq = training_params.get("disp_freq", 1000) - self.save_ckpt = training_params.get("save_ckpt", "model.pt") + self.save_ckpt = training_params.get("save_ckpt", "model.ckpt") self.save_freq = training_params.get("save_freq", 1000) self.lcurve_should_print_header = True @@ -650,13 +652,14 @@ def log_loss_valid(_task_key="Default"): or (_step_id + 1) == self.num_steps ) and (self.rank == 0 or dist.get_rank() == 0): # Handle the case if rank 0 aborted and re-assigned - self.latest_model = Path(self.save_ckpt) - self.latest_model = self.latest_model.with_name( - f"{self.latest_model.stem}_{_step_id + 1}{self.latest_model.suffix}" - ) + self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pt") + module = self.wrapper.module if dist.is_initialized() else self.wrapper self.save_model(self.latest_model, lr=cur_lr, step=_step_id) logging.info(f"Saved model to {self.latest_model}") + symlink_prefix_files(self.latest_model.stem, self.save_ckpt) + with open("checkpoint", "w") as f: + f.write(str(self.latest_model)) self.t0 = time.time() with logging_redirect_tqdm(): @@ -694,10 +697,6 @@ def log_loss_valid(_task_key="Default"): logging.info( f"Frozen model for inferencing has been saved to {pth_model_path}" ) - try: - os.symlink(self.latest_model, self.save_ckpt) - except OSError: - self.save_model(self.save_ckpt, lr=0, step=self.num_steps) logging.info(f"Trained model has been saved to: {self.save_ckpt}") if fout: diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 19b81d7a13..2d29a1a1c1 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -1,9 +1,7 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: LGPL-3.0-or-later -import glob import logging import os -import platform import shutil import time from typing import ( @@ -22,6 +20,9 @@ # load grad of force module import deepmd.tf.op # noqa: F401 +from deepmd.common import ( + symlink_prefix_files, +) from deepmd.tf.common import ( data_requirement, get_precision, @@ -830,19 +831,7 @@ def save_checkpoint(self, cur_batch: int): ) from e # make symlinks from prefix with step to that without step to break nothing # get all checkpoint files - original_files = glob.glob(ckpt_prefix + ".*") - for ori_ff in original_files: - new_ff = self.save_ckpt + ori_ff[len(ckpt_prefix) :] - try: - # remove old one - os.remove(new_ff) - except OSError: - pass - if platform.system() != "Windows": - # by default one does not have access to create symlink on Windows - os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff) - else: - shutil.copyfile(ori_ff, new_ff) + symlink_prefix_files(ckpt_prefix, self.save_ckpt) log.info("saved checkpoint %s" % self.save_ckpt) def get_feed_dict(self, batch, is_training): From 43d271d832d462335da9067330083cb910fa6bdc Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Jan 2024 17:56:06 -0500 Subject: [PATCH 2/3] update docs Signed-off-by: Jinzhe Zeng --- deepmd/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/main.py b/deepmd/main.py index 30d2b293c0..ff7120c8e7 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -275,7 +275,7 @@ def main_parser() -> argparse.ArgumentParser: "--checkpoint", type=str, default=".", - help="Path to checkpoint. TensorFlow backend: a folder; PyTorch backend: either a folder containing model.pt, or a pt file", + help="Path to checkpoint. TensorFlow backend: a folder; PyTorch backend: either a folder containing checkpoint, or a pt file", ) parser_frz.add_argument( "-o", From 968ae482b55754cc6e2f5e7e5fa6304f747725c9 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Jan 2024 18:04:25 -0500 Subject: [PATCH 3/3] forgot to push Signed-off-by: Jinzhe Zeng --- source/tests/pt/water/se_atten.json | 1 + 1 file changed, 1 insertion(+) diff --git a/source/tests/pt/water/se_atten.json b/source/tests/pt/water/se_atten.json index 8867e0db41..3ed80ae892 100644 --- a/source/tests/pt/water/se_atten.json +++ b/source/tests/pt/water/se_atten.json @@ -79,6 +79,7 @@ "disp_file": "lcurve.out", "disp_freq": 100, "save_freq": 1000, + "save_ckpt": "model", "_comment": "that's all" } }