diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index a62050b2d1..bf97472e33 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -45,3 +45,8 @@ def compute_or_load_stat( def get_model_def_script(self) -> str: """Get the model definition script.""" return self.model_def_script + + @torch.jit.export + def get_ntypes(self): + """Returns the number of element types.""" + return len(self.get_type_map())