diff --git a/deepmd/pt/model/model/dp_atomic_model.py b/deepmd/pt/model/model/dp_atomic_model.py index 245c0f3d3f..a222c8e6f6 100644 --- a/deepmd/pt/model/model/dp_atomic_model.py +++ b/deepmd/pt/model/model/dp_atomic_model.py @@ -93,11 +93,11 @@ def __init__( ) fitting_net["type"] = fitting_net.get("type", "ener") - if self.descriptor_type not in ["se_e2_a"]: - fitting_net["ntypes"] = self.descriptor.get_ntype() + fitting_net["ntypes"] = self.descriptor.get_ntype() + if self.descriptor_type in ["se_e2_a"]: + fitting_net["distinguish_types"] = True else: - fitting_net["ntypes"] = self.descriptor.get_ntype() - fitting_net["use_tebd"] = False + fitting_net["distinguish_types"] = False fitting_net["embedding_width"] = self.descriptor.dim_out self.grad_force = "direct" not in fitting_net["type"]