def symlink_prefix_files(old_prefix: str, new_prefix: str) -> None:
    """Create symlinks from old checkpoint prefix to new one.

    Every file named ``<old_prefix>.<suffix>`` gets a counterpart named
    ``<new_prefix>.<suffix>``. An existing counterpart is removed first so
    that the link always points at the most recent checkpoint. The link
    target is stored relative to the directory of the new file.

    On Windows this function will copy files instead of creating symlinks.

    Parameters
    ----------
    old_prefix : str
        old checkpoint prefix, all files with this prefix will be symlinked
    new_prefix : str
        new checkpoint prefix
    """
    # Escape the prefix so that glob metacharacters it may contain
    # ("[", "]", "*", "?") are matched literally instead of as wildcards.
    original_files = glob.glob(glob.escape(old_prefix) + ".*")
    # Slice the suffix off the basename: glob may normalise the directory
    # part of a match, but the basename always starts with the prefix's.
    base_len = len(os.path.basename(old_prefix))
    for ori_ff in original_files:
        new_ff = new_prefix + os.path.basename(ori_ff)[base_len:]
        same_path = os.path.normcase(os.path.abspath(new_ff)) == os.path.normcase(
            os.path.abspath(ori_ff)
        )
        if same_path:
            # Source and destination are the same file: removing the
            # "old" link would delete the checkpoint itself and leave a
            # self-referencing symlink behind, so leave it untouched.
            continue
        try:
            # remove old one
            os.remove(new_ff)
        except OSError:
            # Best effort: usually there is simply nothing to remove yet.
            # Any real problem resurfaces from symlink/copyfile below.
            pass
        if platform.system() != "Windows":
            # by default one does not have access to create symlink on Windows
            os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff)
        else:
            shutil.copyfile(ori_ff, new_ff)
TensorFlow backend: a folder; PyTorch backend: either a folder containing checkpoint, or a pt file", ) parser_frz.add_argument( "-o", diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index c5e551ebd8..ad5e92d495 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -308,9 +308,9 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None): test(FLAGS) elif FLAGS.command == "freeze": if Path(FLAGS.checkpoint_folder).is_dir(): - # TODO: automatically generate model.pt during training - # FLAGS.model = str(Path(FLAGS.checkpoint).joinpath("model.pt")) - raise NotImplementedError("Checkpoint should give a file") + checkpoint_path = Path(FLAGS.checkpoint_folder) + latest_ckpt_file = (checkpoint_path / "checkpoint").read_text() + FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file)) else: FLAGS.model = FLAGS.checkpoint_folder FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth")) diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index 673491d788..a14518e8a0 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -70,6 +70,8 @@ def task_deriv_one( if do_atomic_virial: extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy) extended_virial = extended_virial + extended_virial_corr + # to [...,3,3] -> [...,9] + extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9]) # noqa:RUF005 return extended_force, extended_virial @@ -106,18 +108,18 @@ def take_deriv( split_svv1 = torch.split(svv1, [1] * size, dim=-1) split_ff, split_avir = [], [] for vvi, svvi in zip(split_vv1, split_svv1): - # nf x nloc x 3, nf x nloc x 3 x 3 + # nf x nloc x 3, nf x nloc x 9 ffi, aviri = task_deriv_one( vvi, svvi, coord_ext, do_atomic_virial=do_atomic_virial ) - # nf x nloc x 1 x 3, nf x nloc x 1 x 3 x 3 + # nf x nloc x 1 x 3, nf x nloc x 1 x 9 ffi = ffi.unsqueeze(-2) - aviri = 
aviri.unsqueeze(-3) + aviri = aviri.unsqueeze(-2) split_ff.append(ffi) split_avir.append(aviri) - # nf x nloc x v_dim x 3, nf x nloc x v_dim x 3 x 3 + # nf x nloc x v_dim x 3, nf x nloc x v_dim x 9 ff = torch.concat(split_ff, dim=-2) - avir = torch.concat(split_avir, dim=-3) + avir = torch.concat(split_avir, dim=-2) return ff, avir @@ -185,7 +187,7 @@ def communicate_extended_output( force = torch.zeros( vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device ) - # nf x nloc x 1 x 3 + # nf x nloc x nvar x 3 new_ret[kk_derv_r] = torch.scatter_reduce( force, 1, @@ -193,13 +195,15 @@ def communicate_extended_output( src=model_ret[kk_derv_r], reduce="sum", ) - mapping = mapping.unsqueeze(-1).expand( - [-1] * (len(mldims) + len(derv_r_ext_dims)) + [3] + derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005 + # nf x nloc x nvar x 3 -> nf x nloc x nvar x 9 + mapping = torch.tile( + mapping, [1] * (len(mldims) + len(vdef.shape)) + [3] ) virial = torch.zeros( - vldims + derv_r_ext_dims + [3], dtype=vv.dtype, device=vv.device + vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device ) - # nf x nloc x 1 x 3 + # nf x nloc x nvar x 9 new_ret[kk_derv_c] = torch.scatter_reduce( virial, 1, diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 049685a6e3..8ea69c8489 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import os import time from copy import ( deepcopy, @@ -22,6 +21,9 @@ logging_redirect_tqdm, ) +from deepmd.common import ( + symlink_prefix_files, +) from deepmd.pt.loss import ( DenoiseLoss, EnergyStdLoss, @@ -102,7 +104,7 @@ def __init__( self.num_steps = training_params["numb_steps"] self.disp_file = training_params.get("disp_file", "lcurve.out") self.disp_freq = training_params.get("disp_freq", 1000) - self.save_ckpt = training_params.get("save_ckpt", "model.pt") + self.save_ckpt = training_params.get("save_ckpt", "model.ckpt") 
self.save_freq = training_params.get("save_freq", 1000) self.lcurve_should_print_header = True @@ -650,13 +652,14 @@ def log_loss_valid(_task_key="Default"): or (_step_id + 1) == self.num_steps ) and (self.rank == 0 or dist.get_rank() == 0): # Handle the case if rank 0 aborted and re-assigned - self.latest_model = Path(self.save_ckpt) - self.latest_model = self.latest_model.with_name( - f"{self.latest_model.stem}_{_step_id + 1}{self.latest_model.suffix}" - ) + self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pt") + module = self.wrapper.module if dist.is_initialized() else self.wrapper self.save_model(self.latest_model, lr=cur_lr, step=_step_id) logging.info(f"Saved model to {self.latest_model}") + symlink_prefix_files(self.latest_model.stem, self.save_ckpt) + with open("checkpoint", "w") as f: + f.write(str(self.latest_model)) self.t0 = time.time() with logging_redirect_tqdm(): @@ -694,10 +697,6 @@ def log_loss_valid(_task_key="Default"): logging.info( f"Frozen model for inferencing has been saved to {pth_model_path}" ) - try: - os.symlink(self.latest_model, self.save_ckpt) - except OSError: - self.save_model(self.save_ckpt, lr=0, step=self.num_steps) logging.info(f"Trained model has been saved to: {self.save_ckpt}") if fout: diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 19b81d7a13..2d29a1a1c1 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -1,9 +1,7 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: LGPL-3.0-or-later -import glob import logging import os -import platform import shutil import time from typing import ( @@ -22,6 +20,9 @@ # load grad of force module import deepmd.tf.op # noqa: F401 +from deepmd.common import ( + symlink_prefix_files, +) from deepmd.tf.common import ( data_requirement, get_precision, @@ -830,19 +831,7 @@ def save_checkpoint(self, cur_batch: int): ) from e # make symlinks from prefix with step to that without step to break nothing # get all checkpoint files - 
original_files = glob.glob(ckpt_prefix + ".*") - for ori_ff in original_files: - new_ff = self.save_ckpt + ori_ff[len(ckpt_prefix) :] - try: - # remove old one - os.remove(new_ff) - except OSError: - pass - if platform.system() != "Windows": - # by default one does not have access to create symlink on Windows - os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff) - else: - shutil.copyfile(ori_ff, new_ff) + symlink_prefix_files(ckpt_prefix, self.save_ckpt) log.info("saved checkpoint %s" % self.save_ckpt) def get_feed_dict(self, batch, is_training): diff --git a/source/tests/pt/test_autodiff.py b/source/tests/pt/test_autodiff.py index 4f303a8bb3..8840fbdd4c 100644 --- a/source/tests/pt/test_autodiff.py +++ b/source/tests/pt/test_autodiff.py @@ -121,9 +121,11 @@ def np_infer( def ff(bb): return np_infer(bb)["energy"] - fdv = -( - finite_difference(ff, cell, delta=delta).transpose(0, 2, 1) @ cell - ).squeeze() + fdv = ( + -(finite_difference(ff, cell, delta=delta).transpose(0, 2, 1) @ cell) + .squeeze() + .reshape(9) + ) rfv = np_infer(cell)["virial"] np.testing.assert_almost_equal(fdv, rfv, decimal=places) diff --git a/source/tests/pt/test_rot.py b/source/tests/pt/test_rot.py index b5d9d9b64b..7222fd6f69 100644 --- a/source/tests/pt/test_rot.py +++ b/source/tests/pt/test_rot.py @@ -65,8 +65,8 @@ def test( ) if not hasattr(self, "test_virial") or self.test_virial: torch.testing.assert_close( - torch.matmul(rmat.T, torch.matmul(ret0["virial"], rmat)), - ret1["virial"], + torch.matmul(rmat.T, torch.matmul(ret0["virial"].view([3, 3]), rmat)), + ret1["virial"].view([3, 3]), rtol=prec, atol=prec, ) @@ -102,8 +102,8 @@ def test( ) if not hasattr(self, "test_virial") or self.test_virial: torch.testing.assert_close( - torch.matmul(rmat.T, torch.matmul(ret0["virial"], rmat)), - ret1["virial"], + torch.matmul(rmat.T, torch.matmul(ret0["virial"].view([3, 3]), rmat)), + ret1["virial"].view([3, 3]), rtol=prec, atol=prec, ) diff --git 
a/source/tests/pt/test_rotation.py b/source/tests/pt/test_rotation.py index 4b49377a27..58ec80e0d6 100644 --- a/source/tests/pt/test_rotation.py +++ b/source/tests/pt/test_rotation.py @@ -121,9 +121,10 @@ def test_rotation(self): if "virial" in result1: self.assertTrue( torch.allclose( - result2["virial"][0], + result2["virial"][0].view([3, 3]), torch.matmul( - torch.matmul(rotation, result1["virial"][0].T), rotation.T + torch.matmul(rotation, result1["virial"][0].view([3, 3]).T), + rotation.T, ), ) ) diff --git a/source/tests/pt/water/se_atten.json b/source/tests/pt/water/se_atten.json index 8867e0db41..3ed80ae892 100644 --- a/source/tests/pt/water/se_atten.json +++ b/source/tests/pt/water/se_atten.json @@ -79,6 +79,7 @@ "disp_file": "lcurve.out", "disp_freq": 100, "save_freq": 1000, + "save_ckpt": "model", "_comment": "that's all" } }