From f5b7aa37710f0c21269a5ce0af69251026471e9e Mon Sep 17 00:00:00 2001 From: Lysithea <52808607+CaRoLZhangxy@users.noreply.github.com> Date: Wed, 27 Mar 2024 15:19:48 +0800 Subject: [PATCH] fix pt bug: missing get_ntype method (#3612) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/pt/model/model/model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index a62050b2d1..bf97472e33 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -45,3 +45,8 @@ def compute_or_load_stat( def get_model_def_script(self) -> str: """Get the model definition script.""" return self.model_def_script + + @torch.jit.export + def get_ntypes(self): + """Returns the number of element types.""" + return len(self.get_type_map())