From d302a5af02c727dc4553f4a50839d3f4382af03e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 8 Feb 2024 19:16:13 -0500 Subject: [PATCH] pt: infer model type from ModelOutputDef Signed-off-by: Jinzhe Zeng --- deepmd/pt/infer/deep_eval.py | 33 ++++++++++++++++++++++++++--- deepmd/pt/model/model/make_model.py | 1 + 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index b42bee1dbe..cb602dd172 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -18,12 +18,26 @@ OutputVariableCategory, OutputVariableDef, ) +from deepmd.infer.deep_dipole import ( + DeepDipole, +) +from deepmd.infer.deep_dos import ( + DeepDOS, +) +from deepmd.infer.deep_eval import DeepEval as DeepEvalWrapper from deepmd.infer.deep_eval import ( DeepEvalBackend, ) +from deepmd.infer.deep_polar import ( + DeepGlobalPolar, + DeepPolar, +) from deepmd.infer.deep_pot import ( DeepPot, ) +from deepmd.infer.deep_wfc import ( + DeepWFC, +) from deepmd.pt.model.model import ( get_model, ) @@ -44,8 +58,6 @@ if TYPE_CHECKING: import ase.neighborlist - from deepmd.infer.deep_eval import DeepEval as DeepEvalWrapper - class DeepEval(DeepEvalBackend): """PyTorch backend implementaion of DeepEval. @@ -127,7 +139,22 @@ def get_dim_aparam(self) -> int: @property def model_type(self) -> "DeepEvalWrapper": """The the evaluator of the model type.""" - return DeepPot + output_def = self.dp.model["Default"].model_output_def() + var_defs = output_def.var_defs + if "energy" in var_defs: + return DeepPot + elif "dos" in var_defs: + return DeepDOS + elif "dipole" in var_defs: + return DeepDipole + elif "polar" in var_defs: + return DeepPolar + elif "global_polar" in var_defs: + return DeepGlobalPolar + elif "wfc" in var_defs: + return DeepWFC + else: + raise RuntimeError("Unknown model type") def get_sel_type(self) -> List[int]: """Get the selected atom types of this model. diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 1e76c6a468..8a863b8cdc 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -57,6 +57,7 @@ def __init__( **kwargs, ) + @torch.jit.export def model_output_def(self): """Get the output def for the model.""" return ModelOutputDef(self.fitting_output_def())