Skip to content

Commit

Permalink
feat: DeepEval.get_model_def_script and common dp show
Browse files Browse the repository at this point in the history
Fix deepmodeling#4019.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Sep 16, 2024
1 parent 96ed5df commit bd74bff
Show file tree
Hide file tree
Showing 13 changed files with 200 additions and 72 deletions.
5 changes: 5 additions & 0 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -372,3 +373,7 @@ def _get_output_shape(self, odef, nframes, natoms):
return [nframes, natoms, *odef.shape, 1]
else:
raise RuntimeError("unknown category")

def get_model_def_script(self) -> dict:
"""Get model defination script."""
return json.loads(self.model.get_model_def_script())
5 changes: 5 additions & 0 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from deepmd.entrypoints.neighbor_stat import (
neighbor_stat,
)
from deepmd.entrypoints.show import (
show,
)
from deepmd.entrypoints.test import (
test,
)
Expand Down Expand Up @@ -81,5 +84,7 @@ def main(args: argparse.Namespace):
start_dpgui(**dict_args)
elif args.command == "convert-backend":
convert_backend(**dict_args)
elif args.command == "show":
show(**dict_args)
else:
raise ValueError(f"Unknown command: {args.command}")
68 changes: 68 additions & 0 deletions deepmd/entrypoints/show.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
)

from deepmd.infer.deep_eval import (
DeepEval,
)

log = logging.getLogger(__name__)


def show(
*,
INPUT: str,
ATTRIBUTES: List[str],
**kwargs,
):
model = DeepEval(INPUT, head=0)
model_params = model.get_model_def_script()
model_is_multi_task = "model_dict" in model_params
log.info("This is a multitask model") if model_is_multi_task else log.info(
"This is a singletask model"
)

if "model-branch" in ATTRIBUTES:
# The model must be multitask mode
if not model_is_multi_task:
raise RuntimeError(
"The 'model-branch' option requires a multitask model."
" The provided model does not meet this criterion."
)
model_branches = list(model_params["model_dict"].keys())
model_branches += ["RANDOM"]
log.info(
f"Available model branches are {model_branches}, "
f"where 'RANDOM' means using a randomly initialized fitting net."
)
if "type-map" in ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
type_map = model_params["model_dict"][branch]["type_map"]
log.info(f"The type_map of branch {branch} is {type_map}")
else:
type_map = model_params["type_map"]
log.info(f"The type_map is {type_map}")
if "descriptor" in ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
descriptor = model_params["model_dict"][branch]["descriptor"]
log.info(f"The descriptor parameter of branch {branch} is {descriptor}")
else:
descriptor = model_params["descriptor"]
log.info(f"The descriptor parameter is {descriptor}")
if "fitting-net" in ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
fitting_net = model_params["model_dict"][branch]["fitting_net"]
log.info(
f"The fitting_net parameter of branch {branch} is {fitting_net}"
)
else:
fitting_net = model_params["fitting_net"]
log.info(f"The fitting_net parameter is {fitting_net}")
8 changes: 8 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ def get_has_spin(self):
def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model. Only used in old implement."""

def get_model_def_script(self) -> dict:
"""Get model defination script."""
raise NotImplementedError("Not implemented in this backend.")


class DeepEval(ABC):
"""High-level Deep Evaluator interface.
Expand Down Expand Up @@ -546,3 +550,7 @@ def has_spin(self) -> bool:
def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model. Only used in old implement."""
return self.deep_eval.get_ntypes_spin()

def get_model_def_script(self) -> dict:
"""Get model defination script."""
return self.deep_eval.get_model_def_script()
12 changes: 9 additions & 3 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,15 +880,21 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
return parsed_args


def main():
def main(args: Optional[List[str]] = None):
"""DeePMD-kit new entry point.
Parameters
----------
args : List[str]
list of command line arguments, main purpose is testing default option None
takes arguments from sys.argv
Raises
------
RuntimeError
if no command was input
"""
args = parse_args()
args = parse_args(args=args)

if args.backend not in BACKEND_TABLE:
raise ValueError(f"Unknown backend {args.backend}")
Expand All @@ -900,6 +906,7 @@ def main():
"neighbor-stat",
"gui",
"convert-backend",
"show",
):
# common entrypoints
from deepmd.entrypoints.main import main as deepmd_main
Expand All @@ -910,7 +917,6 @@ def main():
"compress",
"convert-from",
"train-nvnmd",
"show",
"change-bias",
):
deepmd_main = BACKENDS[args.backend]().entry_point_hook
Expand Down
67 changes: 0 additions & 67 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,71 +356,6 @@ def freeze(FLAGS):
)


def show(FLAGS):
if FLAGS.INPUT.split(".")[-1] == "pt":
state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
if "model" in state_dict:
state_dict = state_dict["model"]
model_params = state_dict["_extra_state"]["model_params"]
elif FLAGS.INPUT.split(".")[-1] == "pth":
model_params_string = torch.jit.load(
FLAGS.INPUT, map_location=env.DEVICE
).model_def_script
model_params = json.loads(model_params_string)
else:
raise RuntimeError(
"The model provided must be a checkpoint file with a .pt extension "
"or a frozen model with a .pth extension"
)
model_is_multi_task = "model_dict" in model_params
log.info("This is a multitask model") if model_is_multi_task else log.info(
"This is a singletask model"
)

if "model-branch" in FLAGS.ATTRIBUTES:
# The model must be multitask mode
if not model_is_multi_task:
raise RuntimeError(
"The 'model-branch' option requires a multitask model."
" The provided model does not meet this criterion."
)
model_branches = list(model_params["model_dict"].keys())
model_branches += ["RANDOM"]
log.info(
f"Available model branches are {model_branches}, "
f"where 'RANDOM' means using a randomly initialized fitting net."
)
if "type-map" in FLAGS.ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
type_map = model_params["model_dict"][branch]["type_map"]
log.info(f"The type_map of branch {branch} is {type_map}")
else:
type_map = model_params["type_map"]
log.info(f"The type_map is {type_map}")
if "descriptor" in FLAGS.ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
descriptor = model_params["model_dict"][branch]["descriptor"]
log.info(f"The descriptor parameter of branch {branch} is {descriptor}")
else:
descriptor = model_params["descriptor"]
log.info(f"The descriptor parameter is {descriptor}")
if "fitting-net" in FLAGS.ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
fitting_net = model_params["model_dict"][branch]["fitting_net"]
log.info(
f"The fitting_net parameter of branch {branch} is {fitting_net}"
)
else:
fitting_net = model_params["fitting_net"]
log.info(f"The fitting_net parameter is {fitting_net}")


def change_bias(FLAGS):
if FLAGS.INPUT.endswith(".pt"):
old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
Expand Down Expand Up @@ -574,8 +509,6 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
FLAGS.model = FLAGS.checkpoint_folder
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
freeze(FLAGS)
elif FLAGS.command == "show":
show(FLAGS)
elif FLAGS.command == "change-bias":
change_bias(FLAGS)
else:
Expand Down
13 changes: 12 additions & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -100,7 +101,7 @@ def __init__(
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
head: Optional[str] = None,
head: Optional[Union[str, int]] = None,
**kwargs: Any,
):
self.output_def = output_def
Expand All @@ -110,9 +111,12 @@ def __init__(
if "model" in state_dict:
state_dict = state_dict["model"]
self.input_param = state_dict["_extra_state"]["model_params"]
self.model_def_script = self.input_param
self.multi_task = "model_dict" in self.input_param
if self.multi_task:
model_keys = list(self.input_param["model_dict"].keys())
if isinstance(head, int):
head = model_keys[0]
assert (
head is not None
), f"Head must be set for multitask model! Available heads are: {model_keys}"
Expand All @@ -134,6 +138,9 @@ def __init__(
elif str(self.model_path).endswith(".pth"):
model = torch.jit.load(model_file, map_location=env.DEVICE)
self.dp = ModelWrapper(model)
self.model_def_script = json.loads(
self.dp.model["Default"].get_model_def_script()
)
else:
raise ValueError("Unknown model file format!")
self.rcut = self.dp.model["Default"].get_rcut()
Expand Down Expand Up @@ -590,6 +597,10 @@ def eval_typeebd(self) -> np.ndarray:
typeebd = torch.cat(out, dim=1)
return to_numpy_array(typeebd)

def get_model_def_script(self) -> str:
"""Get model defination script."""
return self.model_def_script


# For tests only
def eval_model(
Expand Down
8 changes: 8 additions & 0 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from functools import (
lru_cache,
)
Expand Down Expand Up @@ -1123,6 +1124,13 @@ def get_numb_dos(self) -> int:
def get_has_efield(self) -> bool:
return self.has_efield

def get_model_def_script(self) -> dict:
"""Get model defination script."""
t_script = self._get_tensor("train_attr/training_script:0")
[script] = run_sess(self.sess, [t_script], feed_dict={})
model_def_script = script.decode("utf-8")
return json.loads(model_def_script)["model"]


class DeepEvalOld:
# old class for DipoleChargeModifier only
Expand Down
1 change: 1 addition & 0 deletions source/tests/infer/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(self, filename: str):
self.type_map = config["type_map"]
self.dim_fparam = config["dim_fparam"]
self.dim_aparam = config["dim_aparam"]
self.model_def_script = config.get("model_def_script")

@lru_cache
def get_model(self, suffix: str, out_file: Optional[str] = None) -> str:
Expand Down
38 changes: 38 additions & 0 deletions source/tests/infer/deeppot-testcase.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,44 @@ rcut: 6.0
type_map: ["O", "H"]
dim_fparam: 0
dim_aparam: 0
model_def_script:
{
"data_bias_nsample": 10,
"data_stat_nbatch": 10,
"data_stat_protect": 0.01,
"descriptor":
{
"activation_function": "tanh",
"axis_neuron": 4,
"exclude_types": [],
"neuron": [2, 4, 8],
"precision": "default",
"rcut": 6.0,
"rcut_smth": 0.5,
"resnet_dt": False,
"seed": 1,
"sel": [46, 92],
"set_davg_zero": False,
"trainable": True,
"type": "se_e2_a",
"type_one_side": False,
},
"fitting_net":
{
"activation_function": "tanh",
"atom_ener": [],
"neuron": [6, 6, 6],
"numb_aparam": 0,
"numb_fparam": 0,
"precision": "default",
"rcond": 0.001,
"resnet_dt": True,
"seed": 1,
"trainable": True,
"type": "ener",
},
"type_map": ["O", "H"],
}
results:
- coord:
[
Expand Down
Loading

0 comments on commit bd74bff

Please sign in to comment.