Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PT: keep the same checkpoint behavior as TF #3191

Merged
merged 3 commits into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import glob
import json
import os
import platform
import shutil
import warnings
from pathlib import (
Path,
Expand Down Expand Up @@ -268,3 +272,30 @@
return np.float64
else:
raise RuntimeError(f"{precision} is not a valid precision")


def symlink_prefix_files(old_prefix: str, new_prefix: str):
"""Create symlinks from old checkpoint prefix to new one.

On Windows this function will copy files instead of creating symlinks.

Parameters
----------
old_prefix : str
old checkpoint prefix, all files with this prefix will be symlinked
new_prefix : str
new checkpoint prefix
"""
original_files = glob.glob(old_prefix + ".*")
for ori_ff in original_files:
new_ff = new_prefix + ori_ff[len(old_prefix) :]
try:

Check warning on line 292 in deepmd/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/common.py#L289-L292

Added lines #L289 - L292 were not covered by tests
# remove old one
os.remove(new_ff)
except OSError:

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.
pass
if platform.system() != "Windows":

Check warning on line 297 in deepmd/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/common.py#L294-L297

Added lines #L294 - L297 were not covered by tests
# by default one does not have access to create symlink on Windows
os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff)

Check warning on line 299 in deepmd/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/common.py#L299

Added line #L299 was not covered by tests
else:
shutil.copyfile(ori_ff, new_ff)

Check warning on line 301 in deepmd/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/common.py#L301

Added line #L301 was not covered by tests
6 changes: 3 additions & 3 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@
test(FLAGS)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
# TODO: automatically generate model.pt during training
# FLAGS.model = str(Path(FLAGS.checkpoint).joinpath("model.pt"))
raise NotImplementedError("Checkpoint should give a file")
checkpoint_path = Path(FLAGS.checkpoint_folder)
latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))

Check warning on line 313 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L311-L313

Added lines #L311 - L313 were not covered by tests
else:
FLAGS.model = FLAGS.checkpoint_folder
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
Expand Down
19 changes: 9 additions & 10 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import time
from copy import (
deepcopy,
Expand All @@ -22,6 +21,9 @@
logging_redirect_tqdm,
)

from deepmd.common import (

Check warning on line 24 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L24

Added line #L24 was not covered by tests
symlink_prefix_files,
)
from deepmd.pt.loss import (
DenoiseLoss,
EnergyStdLoss,
Expand Down Expand Up @@ -102,7 +104,7 @@
self.num_steps = training_params["numb_steps"]
self.disp_file = training_params.get("disp_file", "lcurve.out")
self.disp_freq = training_params.get("disp_freq", 1000)
self.save_ckpt = training_params.get("save_ckpt", "model.pt")
self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")

Check warning on line 107 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L107

Added line #L107 was not covered by tests
self.save_freq = training_params.get("save_freq", 1000)
self.lcurve_should_print_header = True

Expand Down Expand Up @@ -650,13 +652,14 @@
or (_step_id + 1) == self.num_steps
) and (self.rank == 0 or dist.get_rank() == 0):
# Handle the case if rank 0 aborted and re-assigned
self.latest_model = Path(self.save_ckpt)
self.latest_model = self.latest_model.with_name(
f"{self.latest_model.stem}_{_step_id + 1}{self.latest_model.suffix}"
)
self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pt")

Check warning on line 655 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L655

Added line #L655 was not covered by tests

module = self.wrapper.module if dist.is_initialized() else self.wrapper
self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
logging.info(f"Saved model to {self.latest_model}")
symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
with open("checkpoint", "w") as f:
f.write(str(self.latest_model))

Check warning on line 662 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L660-L662

Added lines #L660 - L662 were not covered by tests

self.t0 = time.time()
with logging_redirect_tqdm():
Expand Down Expand Up @@ -694,10 +697,6 @@
logging.info(
f"Frozen model for inferencing has been saved to {pth_model_path}"
)
try:
os.symlink(self.latest_model, self.save_ckpt)
except OSError:
self.save_model(self.save_ckpt, lr=0, step=self.num_steps)
logging.info(f"Trained model has been saved to: {self.save_ckpt}")

if fout:
Expand Down
19 changes: 4 additions & 15 deletions deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: LGPL-3.0-or-later
import glob
import logging
import os
import platform
import shutil
import time
from typing import (
Expand All @@ -22,6 +20,9 @@

# load grad of force module
import deepmd.tf.op # noqa: F401
from deepmd.common import (

Check warning on line 23 in deepmd/tf/train/trainer.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/train/trainer.py#L23

Added line #L23 was not covered by tests
symlink_prefix_files,
)
from deepmd.tf.common import (
data_requirement,
get_precision,
Expand Down Expand Up @@ -830,19 +831,7 @@
) from e
# make symlinks from prefix with step to that without step to break nothing
# get all checkpoint files
original_files = glob.glob(ckpt_prefix + ".*")
for ori_ff in original_files:
new_ff = self.save_ckpt + ori_ff[len(ckpt_prefix) :]
try:
# remove old one
os.remove(new_ff)
except OSError:
pass
if platform.system() != "Windows":
# by default one does not have access to create symlink on Windows
os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff)
else:
shutil.copyfile(ori_ff, new_ff)
symlink_prefix_files(ckpt_prefix, self.save_ckpt)

Check warning on line 834 in deepmd/tf/train/trainer.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/train/trainer.py#L834

Added line #L834 was not covered by tests
log.info("saved checkpoint %s" % self.save_ckpt)

def get_feed_dict(self, batch, is_training):
Expand Down
Loading