diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index 4747e05874..02625f5331 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import json from typing import ( TYPE_CHECKING, Any, @@ -372,3 +373,7 @@ def _get_output_shape(self, odef, nframes, natoms): return [nframes, natoms, *odef.shape, 1] else: raise RuntimeError("unknown category") + + def get_model_def_script(self) -> dict: + """Get model defination script.""" + return json.loads(self.model.get_model_def_script()) diff --git a/deepmd/entrypoints/main.py b/deepmd/entrypoints/main.py index 9f05b9a530..ba2eb90247 100644 --- a/deepmd/entrypoints/main.py +++ b/deepmd/entrypoints/main.py @@ -24,6 +24,9 @@ from deepmd.entrypoints.neighbor_stat import ( neighbor_stat, ) +from deepmd.entrypoints.show import ( + show, +) from deepmd.entrypoints.test import ( test, ) @@ -81,5 +84,7 @@ def main(args: argparse.Namespace): start_dpgui(**dict_args) elif args.command == "convert-backend": convert_backend(**dict_args) + elif args.command == "show": + show(**dict_args) else: raise ValueError(f"Unknown command: {args.command}") diff --git a/deepmd/entrypoints/show.py b/deepmd/entrypoints/show.py new file mode 100644 index 0000000000..6f72c4614d --- /dev/null +++ b/deepmd/entrypoints/show.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + List, +) + +from deepmd.infer.deep_eval import ( + DeepEval, +) + +log = logging.getLogger(__name__) + + +def show( + *, + INPUT: str, + ATTRIBUTES: List[str], + **kwargs, +): + model = DeepEval(INPUT, head=0) + model_params = model.get_model_def_script() + model_is_multi_task = "model_dict" in model_params + log.info("This is a multitask model") if model_is_multi_task else log.info( + "This is a singletask model" + ) + + if "model-branch" in ATTRIBUTES: + # The model must be multitask mode + if not model_is_multi_task: + raise RuntimeError( + "The 'model-branch' option requires a multitask model." + " The provided model does not meet this criterion." + ) + model_branches = list(model_params["model_dict"].keys()) + model_branches += ["RANDOM"] + log.info( + f"Available model branches are {model_branches}, " + f"where 'RANDOM' means using a randomly initialized fitting net." + ) + if "type-map" in ATTRIBUTES: + if model_is_multi_task: + model_branches = list(model_params["model_dict"].keys()) + for branch in model_branches: + type_map = model_params["model_dict"][branch]["type_map"] + log.info(f"The type_map of branch {branch} is {type_map}") + else: + type_map = model_params["type_map"] + log.info(f"The type_map is {type_map}") + if "descriptor" in ATTRIBUTES: + if model_is_multi_task: + model_branches = list(model_params["model_dict"].keys()) + for branch in model_branches: + descriptor = model_params["model_dict"][branch]["descriptor"] + log.info(f"The descriptor parameter of branch {branch} is {descriptor}") + else: + descriptor = model_params["descriptor"] + log.info(f"The descriptor parameter is {descriptor}") + if "fitting-net" in ATTRIBUTES: + if model_is_multi_task: + model_branches = list(model_params["model_dict"].keys()) + for branch in model_branches: + fitting_net = model_params["model_dict"][branch]["fitting_net"] + log.info( + f"The fitting_net parameter of branch {branch} is {fitting_net}" + ) + else: + fitting_net = model_params["fitting_net"] + log.info(f"The fitting_net parameter is {fitting_net}") diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 209c679e1b..f35094df3d 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -284,6 +284,10 @@ def get_has_spin(self): def get_ntypes_spin(self) -> int: """Get the number of spin atom types of this model. Only used in old implement.""" + def get_model_def_script(self) -> dict: + """Get model defination script.""" + raise NotImplementedError("Not implemented in this backend.") + class DeepEval(ABC): """High-level Deep Evaluator interface. @@ -546,3 +550,7 @@ def has_spin(self) -> bool: def get_ntypes_spin(self) -> int: """Get the number of spin atom types of this model. Only used in old implement.""" return self.deep_eval.get_ntypes_spin() + + def get_model_def_script(self) -> dict: + """Get model defination script.""" + return self.deep_eval.get_model_def_script() diff --git a/deepmd/main.py b/deepmd/main.py index 777bfd3aa3..6869ca7e88 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -880,15 +880,21 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: return parsed_args -def main(): +def main(args: Optional[List[str]] = None): """DeePMD-kit new entry point. + Parameters + ---------- + args : List[str] + list of command line arguments, main purpose is testing default option None + takes arguments from sys.argv + Raises ------ RuntimeError if no command was input """ - args = parse_args() + args = parse_args(args=args) if args.backend not in BACKEND_TABLE: raise ValueError(f"Unknown backend {args.backend}") @@ -900,6 +906,7 @@ def main(): "neighbor-stat", "gui", "convert-backend", + "show", ): # common entrypoints from deepmd.entrypoints.main import main as deepmd_main @@ -910,7 +917,6 @@ def main(): "compress", "convert-from", "train-nvnmd", - "show", "change-bias", ): deepmd_main = BACKENDS[args.backend]().entry_point_hook diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 03b9f8c3a1..9133575ec8 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -356,71 +356,6 @@ def freeze(FLAGS): ) -def show(FLAGS): - if FLAGS.INPUT.split(".")[-1] == "pt": - state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE) - if "model" in state_dict: - state_dict = state_dict["model"] - model_params = state_dict["_extra_state"]["model_params"] - elif FLAGS.INPUT.split(".")[-1] == "pth": - model_params_string = torch.jit.load( - FLAGS.INPUT, map_location=env.DEVICE - ).model_def_script - model_params = json.loads(model_params_string) - else: - raise RuntimeError( - "The model provided must be a checkpoint file with a .pt extension " - "or a frozen model with a .pth extension" - ) - model_is_multi_task = "model_dict" in model_params - log.info("This is a multitask model") if model_is_multi_task else log.info( - "This is a singletask model" - ) - - if "model-branch" in FLAGS.ATTRIBUTES: - # The model must be multitask mode - if not model_is_multi_task: - raise RuntimeError( - "The 'model-branch' option requires a multitask model." - " The provided model does not meet this criterion." - ) - model_branches = list(model_params["model_dict"].keys()) - model_branches += ["RANDOM"] - log.info( - f"Available model branches are {model_branches}, " - f"where 'RANDOM' means using a randomly initialized fitting net." - ) - if "type-map" in FLAGS.ATTRIBUTES: - if model_is_multi_task: - model_branches = list(model_params["model_dict"].keys()) - for branch in model_branches: - type_map = model_params["model_dict"][branch]["type_map"] - log.info(f"The type_map of branch {branch} is {type_map}") - else: - type_map = model_params["type_map"] - log.info(f"The type_map is {type_map}") - if "descriptor" in FLAGS.ATTRIBUTES: - if model_is_multi_task: - model_branches = list(model_params["model_dict"].keys()) - for branch in model_branches: - descriptor = model_params["model_dict"][branch]["descriptor"] - log.info(f"The descriptor parameter of branch {branch} is {descriptor}") - else: - descriptor = model_params["descriptor"] - log.info(f"The descriptor parameter is {descriptor}") - if "fitting-net" in FLAGS.ATTRIBUTES: - if model_is_multi_task: - model_branches = list(model_params["model_dict"].keys()) - for branch in model_branches: - fitting_net = model_params["model_dict"][branch]["fitting_net"] - log.info( - f"The fitting_net parameter of branch {branch} is {fitting_net}" - ) - else: - fitting_net = model_params["fitting_net"] - log.info(f"The fitting_net parameter is {fitting_net}") - - def change_bias(FLAGS): if FLAGS.INPUT.endswith(".pt"): old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE) @@ -574,8 +509,6 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None): FLAGS.model = FLAGS.checkpoint_folder FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth")) freeze(FLAGS) - elif FLAGS.command == "show": - show(FLAGS) elif FLAGS.command == "change-bias": change_bias(FLAGS) else: diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 48630007d0..4ff8e3b345 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import json from typing import ( TYPE_CHECKING, Any, @@ -100,7 +101,7 @@ def __init__( *args: Any, auto_batch_size: Union[bool, int, AutoBatchSize] = True, neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None, - head: Optional[str] = None, + head: Optional[Union[str, int]] = None, **kwargs: Any, ): self.output_def = output_def @@ -110,9 +111,12 @@ def __init__( if "model" in state_dict: state_dict = state_dict["model"] self.input_param = state_dict["_extra_state"]["model_params"] + self.model_def_script = self.input_param self.multi_task = "model_dict" in self.input_param if self.multi_task: model_keys = list(self.input_param["model_dict"].keys()) + if isinstance(head, int): + head = model_keys[0] assert ( head is not None ), f"Head must be set for multitask model! Available heads are: {model_keys}" @@ -134,6 +138,9 @@ def __init__( elif str(self.model_path).endswith(".pth"): model = torch.jit.load(model_file, map_location=env.DEVICE) self.dp = ModelWrapper(model) + self.model_def_script = json.loads( + self.dp.model["Default"].get_model_def_script() + ) else: raise ValueError("Unknown model file format!") self.rcut = self.dp.model["Default"].get_rcut() @@ -590,6 +597,10 @@ def eval_typeebd(self) -> np.ndarray: typeebd = torch.cat(out, dim=1) return to_numpy_array(typeebd) + def get_model_def_script(self) -> str: + """Get model defination script.""" + return self.model_def_script + # For tests only def eval_model( diff --git a/deepmd/tf/infer/deep_eval.py b/deepmd/tf/infer/deep_eval.py index c3ed3a3688..0f317bd21f 100644 --- a/deepmd/tf/infer/deep_eval.py +++ b/deepmd/tf/infer/deep_eval.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import json from functools import ( lru_cache, ) @@ -1123,6 +1124,13 @@ def get_numb_dos(self) -> int: def get_has_efield(self) -> bool: return self.has_efield + def get_model_def_script(self) -> dict: + """Get model defination script.""" + t_script = self._get_tensor("train_attr/training_script:0") + [script] = run_sess(self.sess, [t_script], feed_dict={}) + model_def_script = script.decode("utf-8") + return json.loads(model_def_script)["model"] + class DeepEvalOld: # old class for DipoleChargeModifier only diff --git a/source/tests/infer/case.py b/source/tests/infer/case.py index 662d84e7e0..c1bce424c4 100644 --- a/source/tests/infer/case.py +++ b/source/tests/infer/case.py @@ -147,6 +147,7 @@ def __init__(self, filename: str): self.type_map = config["type_map"] self.dim_fparam = config["dim_fparam"] self.dim_aparam = config["dim_aparam"] + self.model_def_script = config.get("model_def_script") @lru_cache def get_model(self, suffix: str, out_file: Optional[str] = None) -> str: diff --git a/source/tests/infer/deeppot-testcase.yaml b/source/tests/infer/deeppot-testcase.yaml index 8e031a1638..9523b8d1ea 100644 --- a/source/tests/infer/deeppot-testcase.yaml +++ b/source/tests/infer/deeppot-testcase.yaml @@ -6,6 +6,44 @@ rcut: 6.0 type_map: ["O", "H"] dim_fparam: 0 dim_aparam: 0 +model_def_script: + { + "data_bias_nsample": 10, + "data_stat_nbatch": 10, + "data_stat_protect": 0.01, + "descriptor": + { + "activation_function": "tanh", + "axis_neuron": 4, + "exclude_types": [], + "neuron": [2, 4, 8], + "precision": "default", + "rcut": 6.0, + "rcut_smth": 0.5, + "resnet_dt": False, + "seed": 1, + "sel": [46, 92], + "set_davg_zero": False, + "trainable": True, + "type": "se_e2_a", + "type_one_side": False, + }, + "fitting_net": + { + "activation_function": "tanh", + "atom_ener": [], + "neuron": [6, 6, 6], + "numb_aparam": 0, + "numb_fparam": 0, + "precision": "default", + "rcond": 0.001, + "resnet_dt": True, + "seed": 1, + "trainable": True, + "type": "ener", + }, + "type_map": ["O", "H"], + } results: - coord: [ diff --git a/source/tests/infer/fparam_aparam-testcase.yaml b/source/tests/infer/fparam_aparam-testcase.yaml index 2ebab64192..220b2df209 100644 --- a/source/tests/infer/fparam_aparam-testcase.yaml +++ b/source/tests/infer/fparam_aparam-testcase.yaml @@ -6,6 +6,45 @@ rcut: 6.0 type_map: ["O"] dim_fparam: 1 dim_aparam: 1 +model_def_script: + { + "data_bias_nsample": 10, + "data_stat_nbatch": 1, + "data_stat_protect": 0.01, + "descriptor": + { + "activation_function": "tanh", + "axis_neuron": 8, + "exclude_types": [], + "neuron": [5, 10, 20], + "precision": "default", + "rcut": 6.0, + "rcut_smth": 1.8, + "resnet_dt": False, + "seed": 1, + "sel": [60], + "set_davg_zero": False, + "trainable": True, + "type": "se_e2_a", + "type_one_side": False, + }, + "fitting_net": + { + "activation_function": "tanh", + "atom_ener": [], + "neuron": [5, 5, 5], + "numb_aparam": 1, + "numb_fparam": 1, + "precision": "default", + "rcond": 0.001, + "resnet_dt": True, + "seed": 1, + "trainable": True, + "type": "ener", + "use_aparam_as_mask": False, + }, + "type_map": ["O"], + } results: - coord: [ diff --git a/source/tests/infer/test_models.py b/source/tests/infer/test_models.py index f193c09616..6b62e994aa 100644 --- a/source/tests/infer/test_models.py +++ b/source/tests/infer/test_models.py @@ -325,6 +325,12 @@ def test_dpdata_driver(self): err_msg=f"Result {ii} virial", ) + def test_model_script_def(self): + if self.case.model_def_script is not None: + self.assertDictEqual( + self.case.model_def_script, self.dp.get_model_def_script() + ) + @parameterized( ("se_e2_a",), # key diff --git a/source/tests/pt/common.py b/source/tests/pt/common.py index 445a417592..8886522360 100644 --- a/source/tests/pt/common.py +++ b/source/tests/pt/common.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from deepmd.pt.entrypoints.main import ( +from deepmd.main import ( main, )