Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PT: keep the same checkpoint behavior as TF #3191

Merged
merged 3 commits into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import glob
import json
import os
import platform
import shutil
import warnings
from pathlib import (
Path,
Expand Down Expand Up @@ -268,3 +272,30 @@
return np.float64
else:
raise RuntimeError(f"{precision} is not a valid precision")


def symlink_prefix_files(old_prefix: str, new_prefix: str):
    """Create symlinks from old checkpoint prefix to new one.

    Every file named ``old_prefix + ".<suffix>"`` gets a counterpart named
    ``new_prefix + ".<suffix>"``. Any existing file at the new name is
    replaced. On Windows this function will copy files instead of creating
    symlinks.

    Parameters
    ----------
    old_prefix : str
        old checkpoint prefix, all files with this prefix will be symlinked.
        The prefix is matched literally, even if it contains glob
        metacharacters such as ``[``, ``*`` or ``?``.
    new_prefix : str
        new checkpoint prefix
    """
    # Escape the prefix so that characters like "[" or "*" in a checkpoint
    # name are not interpreted as glob patterns; only the trailing ".*" is
    # meant to be a wildcard.
    original_files = glob.glob(glob.escape(old_prefix) + ".*")
    for ori_ff in original_files:
        new_ff = new_prefix + ori_ff[len(old_prefix) :]
        if os.path.abspath(new_ff) == os.path.abspath(ori_ff):
            # Source and target are the same path: removing the "old" target
            # would delete the checkpoint itself, so leave it untouched.
            continue
        try:
            # remove old one
            os.remove(new_ff)
        except OSError:
            # Best effort: the target usually does not exist yet. If it
            # exists but cannot be removed, the symlink/copy below will
            # raise a more informative error.
            pass
        if platform.system() != "Windows":
            # by default one does not have access to create symlink on Windows
            # The link target is made relative to the link's own directory so
            # the pair stays valid if the containing directory is moved.
            os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff)
        else:
            shutil.copyfile(ori_ff, new_ff)

Check warning on line 301 in deepmd/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/common.py#L301

Added line #L301 was not covered by tests
2 changes: 1 addition & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def main_parser() -> argparse.ArgumentParser:
"--checkpoint",
type=str,
default=".",
help="Path to checkpoint. TensorFlow backend: a folder; PyTorch backend: either a folder containing model.pt, or a pt file",
help="Path to checkpoint. TensorFlow backend: a folder; PyTorch backend: either a folder containing checkpoint, or a pt file",
)
parser_frz.add_argument(
"-o",
Expand Down
6 changes: 3 additions & 3 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@
test(FLAGS)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
# TODO: automatically generate model.pt during training
# FLAGS.model = str(Path(FLAGS.checkpoint).joinpath("model.pt"))
raise NotImplementedError("Checkpoint should give a file")
checkpoint_path = Path(FLAGS.checkpoint_folder)
latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))

Check warning on line 313 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L311-L313

Added lines #L311 - L313 were not covered by tests
else:
FLAGS.model = FLAGS.checkpoint_folder
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
Expand Down
19 changes: 9 additions & 10 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import time
from copy import (
deepcopy,
Expand All @@ -22,6 +21,9 @@
logging_redirect_tqdm,
)

from deepmd.common import (
symlink_prefix_files,
)
from deepmd.pt.loss import (
DenoiseLoss,
EnergyStdLoss,
Expand Down Expand Up @@ -102,7 +104,7 @@ def __init__(
self.num_steps = training_params["numb_steps"]
self.disp_file = training_params.get("disp_file", "lcurve.out")
self.disp_freq = training_params.get("disp_freq", 1000)
self.save_ckpt = training_params.get("save_ckpt", "model.pt")
self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")
self.save_freq = training_params.get("save_freq", 1000)
self.lcurve_should_print_header = True

Expand Down Expand Up @@ -650,13 +652,14 @@ def log_loss_valid(_task_key="Default"):
or (_step_id + 1) == self.num_steps
) and (self.rank == 0 or dist.get_rank() == 0):
# Handle the case if rank 0 aborted and re-assigned
self.latest_model = Path(self.save_ckpt)
self.latest_model = self.latest_model.with_name(
f"{self.latest_model.stem}_{_step_id + 1}{self.latest_model.suffix}"
)
self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pt")

module = self.wrapper.module if dist.is_initialized() else self.wrapper
self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
logging.info(f"Saved model to {self.latest_model}")
symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
with open("checkpoint", "w") as f:
f.write(str(self.latest_model))

self.t0 = time.time()
with logging_redirect_tqdm():
Expand Down Expand Up @@ -694,10 +697,6 @@ def log_loss_valid(_task_key="Default"):
logging.info(
f"Frozen model for inferencing has been saved to {pth_model_path}"
)
try:
os.symlink(self.latest_model, self.save_ckpt)
except OSError:
self.save_model(self.save_ckpt, lr=0, step=self.num_steps)
logging.info(f"Trained model has been saved to: {self.save_ckpt}")

if fout:
Expand Down
19 changes: 4 additions & 15 deletions deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: LGPL-3.0-or-later
import glob
import logging
import os
import platform
import shutil
import time
from typing import (
Expand All @@ -22,6 +20,9 @@

# load grad of force module
import deepmd.tf.op # noqa: F401
from deepmd.common import (
symlink_prefix_files,
)
from deepmd.tf.common import (
data_requirement,
get_precision,
Expand Down Expand Up @@ -830,19 +831,7 @@ def save_checkpoint(self, cur_batch: int):
) from e
# make symlinks from prefix with step to that without step to break nothing
# get all checkpoint files
original_files = glob.glob(ckpt_prefix + ".*")
for ori_ff in original_files:
new_ff = self.save_ckpt + ori_ff[len(ckpt_prefix) :]
try:
# remove old one
os.remove(new_ff)
except OSError:
pass
if platform.system() != "Windows":
# by default one does not have access to create symlink on Windows
os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff)
else:
shutil.copyfile(ori_ff, new_ff)
symlink_prefix_files(ckpt_prefix, self.save_ckpt)
log.info("saved checkpoint %s" % self.save_ckpt)

def get_feed_dict(self, batch, is_training):
Expand Down
1 change: 1 addition & 0 deletions source/tests/pt/water/se_atten.json
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"disp_file": "lcurve.out",
"disp_freq": 100,
"save_freq": 1000,
"save_ckpt": "model",
"_comment": "that's all"
}
}
Loading