diff --git a/deepmd/main.py b/deepmd/main.py
index be489f8eda..e8b93320c6 100644
--- a/deepmd/main.py
+++ b/deepmd/main.py
@@ -745,6 +745,29 @@ def main_parser() -> argparse.ArgumentParser:
     )
     parser_convert_backend.add_argument("INPUT", help="The input model file.")
     parser_convert_backend.add_argument("OUTPUT", help="The output model file.")
+
+    # * show model ******************************************************************
+    parser_show = subparsers.add_parser(
+        "show",
+        parents=[parser_log],
+        help="(Supported backend: PyTorch) Show the information of a model",
+        formatter_class=RawTextArgumentDefaultsHelpFormatter,
+        epilog=textwrap.dedent(
+            """\
+        examples:
+            dp --pt show model.pt model-branch type-map descriptor fitting-net
+            dp --pt show frozen_model.pth type-map descriptor fitting-net
+        """
+        ),
+    )
+    parser_show.add_argument(
+        "INPUT", help="The input checkpoint file or frozen model file"
+    )
+    parser_show.add_argument(
+        "ATTRIBUTES",
+        choices=["model-branch", "type-map", "descriptor", "fitting-net"],
+        nargs="+",
+    )
     return parser
 
 
@@ -802,6 +825,7 @@ def main():
         "compress",
         "convert-from",
         "train-nvnmd",
+        "show",
     ):
         deepmd_main = BACKENDS[args.backend]().entry_point_hook
     elif args.command is None:
diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index eafce67e84..fba22e6d24 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -44,6 +44,9 @@ from deepmd.pt.train import (
     training,
 )
+from deepmd.pt.utils import (
+    env,
+)
 from deepmd.pt.utils.dataloader import (
     DpLoaderSet,
 )
@@ -297,6 +300,67 @@ def freeze(FLAGS):
     )
 
 
+def show(FLAGS):
+    if FLAGS.INPUT.split(".")[-1] == "pt":
+        state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
+        if "model" in state_dict:
+            state_dict = state_dict["model"]
+        model_params = state_dict["_extra_state"]["model_params"]
+    elif FLAGS.INPUT.split(".")[-1] == "pth":
+        model_params_string = torch.jit.load(
+            FLAGS.INPUT, map_location=env.DEVICE
+        ).model_def_script
+        model_params = json.loads(model_params_string)
+    else:
+        raise RuntimeError(
+            "The model provided must be a checkpoint file with a .pt extension "
+            "or a frozen model with a .pth extension"
+        )
+    model_is_multi_task = "model_dict" in model_params
+    log.info("This is a multitask model") if model_is_multi_task else log.info(
+        "This is a singletask model"
+    )
+
+    if "model-branch" in FLAGS.ATTRIBUTES:
+        # The model must be multitask mode
+        if not model_is_multi_task:
+            raise RuntimeError(
+                "The 'model-branch' option requires a multitask model."
+                " The provided model does not meet this criterion."
+            )
+        model_branches = list(model_params["model_dict"].keys())
+        log.info(f"Available model branches are {model_branches}")
+    if "type-map" in FLAGS.ATTRIBUTES:
+        if model_is_multi_task:
+            model_branches = list(model_params["model_dict"].keys())
+            for branch in model_branches:
+                type_map = model_params["model_dict"][branch]["type_map"]
+                log.info(f"The type_map of branch {branch} is {type_map}")
+        else:
+            type_map = model_params["type_map"]
+            log.info(f"The type_map is {type_map}")
+    if "descriptor" in FLAGS.ATTRIBUTES:
+        if model_is_multi_task:
+            model_branches = list(model_params["model_dict"].keys())
+            for branch in model_branches:
+                descriptor = model_params["model_dict"][branch]["descriptor"]
+                log.info(f"The descriptor parameter of branch {branch} is {descriptor}")
+        else:
+            descriptor = model_params["descriptor"]
+            log.info(f"The descriptor parameter is {descriptor}")
+    if "fitting-net" in FLAGS.ATTRIBUTES:
+        if model_is_multi_task:
+            model_branches = list(model_params["model_dict"].keys())
+            for branch in model_branches:
+                fitting_net = model_params["model_dict"][branch]["fitting_net"]
+                log.info(
+                    f"The fitting_net parameter of branch {branch} is {fitting_net}"
+                )
+        else:
+            fitting_net = model_params["fitting_net"]
+            log.info(f"The fitting_net parameter is {fitting_net}")
+
+
 @record
 def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
     if not isinstance(args, argparse.Namespace):
@@ -319,6 +383,8 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
         FLAGS.model = FLAGS.checkpoint_folder
         FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
         freeze(FLAGS)
+    elif FLAGS.command == "show":
+        show(FLAGS)
     else:
         raise RuntimeError(f"Invalid command {FLAGS.command}!")
 
diff --git a/doc/train/finetuning.md b/doc/train/finetuning.md
index 79ca8cdea4..77630720c7 100644
--- a/doc/train/finetuning.md
+++ b/doc/train/finetuning.md
@@ -99,11 +99,12 @@ $ dp --pt train input.json --finetune multitask_pretrained.pt --model-branch CHO
 ```
 
 :::{note}
-To check the available model branches, you can typically refer to the documentation of the pre-trained model.
-If you're still unsure about the available branches, you can try inputting an arbitrary branch name.
-This will prompt an error message that displays a list of all the available model branches.
+One can check the available model branches in a multi-task pre-trained model by referring to the documentation of the pre-trained model or by using the following command:
+
+```bash
+$ dp --pt show multitask_pretrained.pt model-branch
+```
 
-Please note that this feature will be improved in the upcoming version to provide a more user-friendly experience.
 :::
 
 This command will start fine-tuning based on the pre-trained model's descriptor and the selected branch's fitting net.
diff --git a/source/tests/pt/common.py b/source/tests/pt/common.py
new file mode 100644
index 0000000000..445a417592
--- /dev/null
+++ b/source/tests/pt/common.py
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from deepmd.pt.entrypoints.main import (
+    main,
+)
+
+
+def run_dp(cmd: str) -> int:
+    """Run DP directly from the entry point instead of the subprocess.
+
+    It is quite slow to start DeePMD-kit with subprocess.
+
+    Parameters
+    ----------
+    cmd : str
+        The command to run.
+
+    Returns
+    -------
+    int
+        Always returns 0.
+    """
+    cmds = cmd.split()
+    if cmds[0] == "dp":
+        cmds = cmds[1:]
+    else:
+        raise RuntimeError("The command is not dp")
+
+    main(cmds)
+    return 0
diff --git a/source/tests/pt/test_dp_show.py b/source/tests/pt/test_dp_show.py
new file mode 100644
index 0000000000..da5137d7ae
--- /dev/null
+++ b/source/tests/pt/test_dp_show.py
@@ -0,0 +1,199 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import io
+import json
+import os
+import shutil
+import unittest
+from contextlib import (
+    redirect_stderr,
+)
+from copy import (
+    deepcopy,
+)
+from pathlib import (
+    Path,
+)
+
+from deepmd.pt.entrypoints.main import (
+    get_trainer,
+)
+from deepmd.pt.utils.multi_task import (
+    preprocess_shared_params,
+)
+
+from .common import (
+    run_dp,
+)
+from .model.test_permutation import (
+    model_se_e2_a,
+)
+
+
+class TestSingleTaskModel(unittest.TestCase):
+    def setUp(self):
+        input_json = str(Path(__file__).parent / "water/se_atten.json")
+        with open(input_json) as f:
+            self.config = json.load(f)
+        self.config["training"]["numb_steps"] = 1
+        self.config["training"]["save_freq"] = 1
+        data_file = [str(Path(__file__).parent / "water/data/single")]
+        self.config["training"]["training_data"]["systems"] = data_file
+        self.config["training"]["validation_data"]["systems"] = data_file
+        self.config["model"] = deepcopy(model_se_e2_a)
+        self.config["model"]["type_map"] = ["O", "H", "Au"]
+        trainer = get_trainer(deepcopy(self.config))
+        trainer.run()
+        run_dp("dp --pt freeze")
+
+    def test_checkpoint(self):
+        INPUT = "model.pt"
+        ATTRIBUTES = "type-map descriptor fitting-net"
+        with redirect_stderr(io.StringIO()) as f:
+            run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
+        results = f.getvalue().split("\n")[:-1]
+        assert "This is a singletask model" in results[-4]
+        assert "The type_map is ['O', 'H', 'Au']" in results[-3]
+        assert (
+            "{'type': 'se_e2_a'" and "'sel': [46, 92, 4]" and "'rcut': 4.0"
+        ) in results[-2]
+        assert (
+            "The fitting_net parameter is {'neuron': [24, 24, 24], 'resnet_dt': True, 'seed': 1}"
+            in results[-1]
+        )
+
+    def test_frozen_model(self):
+        INPUT = "frozen_model.pth"
+        ATTRIBUTES = "type-map descriptor fitting-net"
+        with redirect_stderr(io.StringIO()) as f:
+            run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
+        results = f.getvalue().split("\n")[:-1]
+        assert "This is a singletask model" in results[-4]
+        assert "The type_map is ['O', 'H', 'Au']" in results[-3]
+        assert (
+            "{'type': 'se_e2_a'" and "'sel': [46, 92, 4]" and "'rcut': 4.0"
+        ) in results[-2]
+        assert (
+            "The fitting_net parameter is {'neuron': [24, 24, 24], 'resnet_dt': True, 'seed': 1}"
+            in results[-1]
+        )
+
+    def test_checkpoint_error(self):
+        INPUT = "model.pt"
+        ATTRIBUTES = "model-branch type-map descriptor fitting-net"
+        with self.assertRaisesRegex(
+            RuntimeError, "The 'model-branch' option requires a multitask model"
+        ):
+            run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
+
+    def tearDown(self):
+        for f in os.listdir("."):
+            if f.startswith("model") and f.endswith("pt"):
+                os.remove(f)
+            if f in ["lcurve.out", "frozen_model.pth", "output.txt", "checkpoint"]:
+                os.remove(f)
+            if f in ["stat_files"]:
+                shutil.rmtree(f)
+
+
+class TestMultiTaskModel(unittest.TestCase):
+    def setUp(self):
+        input_json = str(Path(__file__).parent / "water/multitask.json")
+        with open(input_json) as f:
+            self.config = json.load(f)
+        self.config["model"]["shared_dict"]["my_descriptor"] = model_se_e2_a[
+            "descriptor"
+        ]
+        data_file = [str(Path(__file__).parent / "water/data/data_0")]
+        self.stat_files = "se_e2_a"
+        os.makedirs(self.stat_files, exist_ok=True)
+        self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = (
+            data_file
+        )
+        self.config["training"]["data_dict"]["model_1"]["validation_data"][
+            "systems"
+        ] = data_file
+        self.config["training"]["data_dict"]["model_1"]["stat_file"] = (
+            f"{self.stat_files}/model_1"
+        )
+        self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = (
+            data_file
+        )
+        self.config["training"]["data_dict"]["model_2"]["validation_data"][
+            "systems"
+        ] = data_file
+        self.config["training"]["data_dict"]["model_2"]["stat_file"] = (
+            f"{self.stat_files}/model_2"
+        )
+        self.config["model"]["model_dict"]["model_1"]["fitting_net"] = {
+            "neuron": [1, 2, 3],
+            "seed": 678,
+        }
+        self.config["model"]["model_dict"]["model_2"]["fitting_net"] = {
+            "neuron": [9, 8, 7],
+            "seed": 1111,
+        }
+        self.config["training"]["numb_steps"] = 1
+        self.config["training"]["save_freq"] = 1
+        self.origin_config = deepcopy(self.config)
+        self.config["model"], self.shared_links = preprocess_shared_params(
+            self.config["model"]
+        )
+        trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
+        trainer.run()
+        run_dp("dp --pt freeze --head model_1")
+
+    def test_checkpoint(self):
+        INPUT = "model.ckpt.pt"
+        ATTRIBUTES = "model-branch type-map descriptor fitting-net"
+        with redirect_stderr(io.StringIO()) as f:
+            run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
+        results = f.getvalue().split("\n")[:-1]
+        assert "This is a multitask model" in results[-8]
+        assert "Available model branches are ['model_1', 'model_2']" in results[-7]
+        assert "The type_map of branch model_1 is ['O', 'H', 'B']" in results[-6]
+        assert "The type_map of branch model_2 is ['O', 'H', 'B']" in results[-5]
+        assert (
+            "model_1"
+            and "'type': 'se_e2_a'"
+            and "'sel': [46, 92, 4]"
+            and "'rcut_smth': 0.5"
+        ) in results[-4]
+        assert (
+            "model_2"
+            and "'type': 'se_e2_a'"
+            and "'sel': [46, 92, 4]"
+            and "'rcut_smth': 0.5"
+        ) in results[-3]
+        assert (
+            "The fitting_net parameter of branch model_1 is {'neuron': [1, 2, 3], 'seed': 678}"
+            in results[-2]
+        )
+        assert (
+            "The fitting_net parameter of branch model_2 is {'neuron': [9, 8, 7], 'seed': 1111}"
+            in results[-1]
+        )
+
+    def test_frozen_model(self):
+        INPUT = "frozen_model.pth"
+        ATTRIBUTES = "type-map descriptor fitting-net"
+        with redirect_stderr(io.StringIO()) as f:
+            run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
+        results = f.getvalue().split("\n")[:-1]
+        assert "This is a singletask model" in results[-4]
+        assert "The type_map is ['O', 'H', 'B']" in results[-3]
+        assert (
+            "'type': 'se_e2_a'" and "'sel': [46, 92, 4]" and "'rcut_smth': 0.5"
+        ) in results[-2]
+        assert (
+            "The fitting_net parameter is {'neuron': [1, 2, 3], 'seed': 678}"
+            in results[-1]
+        )
+
+    def tearDown(self):
+        for f in os.listdir("."):
+            if f.startswith("model") and f.endswith("pt"):
+                os.remove(f)
+            if f in ["lcurve.out", "frozen_model.pth", "checkpoint", "output.txt"]:
+                os.remove(f)
+            if f in ["stat_files", self.stat_files]:
+                shutil.rmtree(f)
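
For reference, the new subcommand can also be driven programmatically, mirroring what `run_dp` in `source/tests/pt/common.py` does. The sketch below is illustrative only; it assumes a checkpoint named `model.pt` already exists in the working directory (e.g. produced by `dp --pt train` followed by `dp --pt freeze`).

```python
# Illustrative sketch: invoke the new "show" subcommand through the PyTorch
# entry point, the same way run_dp() in source/tests/pt/common.py does.
# Assumes a checkpoint named model.pt already exists in the working directory.
from deepmd.pt.entrypoints.main import main

# Equivalent to the CLI call: dp --pt show model.pt type-map descriptor fitting-net
main(["--pt", "show", "model.pt", "type-map", "descriptor", "fitting-net"])
```

The requested attributes are reported through the logger, which is why the tests above capture the output with `redirect_stderr` rather than `redirect_stdout`.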