diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 62bc5a4c97..97d28cd4a1 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -519,8 +519,11 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
                     model_params["type_map"],
                     model_params["new_type_map"],
                 )
-                if hasattr(self.model, "fitting_net"):
-                    self.model.fitting_net.change_energy_bias(
+                # TODO: need an interface instead of fetching fitting_net!!!!!!!!!
+                if hasattr(self.model, "atomic_model") and hasattr(
+                    self.model.atomic_model, "fitting_net"
+                ):
+                    self.model.atomic_model.fitting_net.change_energy_bias(
                         config,
                         self.model,
                         old_type_map,
@@ -530,7 +533,7 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
                     )
                 elif isinstance(self.model, DPZBLModel):
                     # need to updated
-                    self.model.change_energy_bias()
+                    self.model.atomic_model.change_energy_bias()
                 else:
                     raise NotImplementedError
         if init_frz_model is not None:
diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py
index c1040fb9e3..061cd777db 100644
--- a/deepmd/pt/train/wrapper.py
+++ b/deepmd/pt/train/wrapper.py
@@ -75,12 +75,12 @@ def share_params(self, shared_links, resume=False):
             shared_level_base = shared_base["shared_level"]
             if "descriptor" in class_type_base:
                 if class_type_base == "descriptor":
-                    base_class = self.model[model_key_base].__getattr__("descriptor")
+                    base_class = self.model[model_key_base].get_descriptor()
                 elif "hybrid" in class_type_base:
                     hybrid_index = int(class_type_base.split("_")[-1])
                     base_class = (
                         self.model[model_key_base]
-                        .__getattr__("descriptor")
+                        .get_descriptor()
                         .descriptor_list[hybrid_index]
                     )
                 else:
@@ -96,14 +96,12 @@ def share_params(self, shared_links, resume=False):
                         "descriptor" in class_type_link
                     ), f"Class type mismatched: {class_type_base} vs {class_type_link}!"
                    if class_type_link == "descriptor":
-                        link_class = self.model[model_key_link].__getattr__(
-                            "descriptor"
-                        )
+                        link_class = self.model[model_key_link].get_descriptor()
                     elif "hybrid" in class_type_link:
                         hybrid_index = int(class_type_link.split("_")[-1])
                         link_class = (
                             self.model[model_key_link]
-                            .__getattr__("descriptor")
+                            .get_descriptor()
                             .descriptor_list[hybrid_index]
                         )
                     else:
diff --git a/source/tests/pt/test_finetune.py b/source/tests/pt/test_finetune.py
index d21a44acc7..dd72eb4718 100644
--- a/source/tests/pt/test_finetune.py
+++ b/source/tests/pt/test_finetune.py
@@ -44,27 +44,29 @@ def test_finetune_change_energy_bias(self):
         else:
             model = get_model(self.model_config)
         if isinstance(model, EnergyModel):
-            model.fitting_net.bias_atom_e = torch.rand_like(
-                model.fitting_net.bias_atom_e
+            model.get_fitting_net().bias_atom_e = torch.rand_like(
+                model.get_fitting_net().bias_atom_e
             )
             energy_bias_before = deepcopy(
-                model.fitting_net.bias_atom_e.detach().cpu().numpy().reshape(-1)
+                model.get_fitting_net().bias_atom_e.detach().cpu().numpy().reshape(-1)
             )
             bias_atom_e_input = deepcopy(
-                model.fitting_net.bias_atom_e.detach().cpu().numpy().reshape(-1)
+                model.get_fitting_net().bias_atom_e.detach().cpu().numpy().reshape(-1)
             )
         elif isinstance(model, DPZBLModel):
-            model.dp_model.fitting_net.bias_atom_e = torch.rand_like(
-                model.dp_model.fitting_net.bias_atom_e
+            model.dp_model.get_fitting_net().bias_atom_e = torch.rand_like(
+                model.dp_model.get_fitting_net().bias_atom_e
             )
             energy_bias_before = deepcopy(
-                model.dp_model.fitting_net.bias_atom_e.detach()
+                model.dp_model.get_fitting_net()
+                .bias_atom_e.detach()
                 .cpu()
                 .numpy()
                 .reshape(-1)
             )
             bias_atom_e_input = deepcopy(
-                model.dp_model.fitting_net.bias_atom_e.detach()
+                model.dp_model.get_fitting_net()
+                .bias_atom_e.detach()
                 .cpu()
                 .numpy()
                 .reshape(-1)
diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py
index db69a1bcea..e4403d2251 100644
--- a/source/tests/pt/test_training.py
+++ b/source/tests/pt/test_training.py
@@ -52,11 +52,11 @@ def test_trainable(self):
             fix_params["model"]["descriptor"]["trainable"] = True
             trainer_fix = get_trainer(fix_params)
             model_dict_before_training = deepcopy(
-                trainer_fix.model.fitting_net.state_dict()
+                trainer_fix.model.get_fitting_net().state_dict()
             )
             trainer_fix.run()
             model_dict_after_training = deepcopy(
-                trainer_fix.model.fitting_net.state_dict()
+                trainer_fix.model.get_fitting_net().state_dict()
             )
         else:
             trainer_fix = get_trainer(fix_params)
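Note: all of the hunks above assume the model classes expose explicit `get_descriptor()` / `get_fitting_net()` accessors in place of `nn.Module.__getattr__` lookups and direct `fitting_net` attribute access; those accessor definitions live outside this diff. A minimal sketch of what such accessors could look like follows — the class layout and constructor signatures here are illustrative assumptions, not the actual deepmd-kit source:

```python
import torch


class DPAtomicModel(torch.nn.Module):
    """Hypothetical stand-in for the atomic model owning both submodules."""

    def __init__(self, descriptor: torch.nn.Module, fitting_net: torch.nn.Module):
        super().__init__()
        self.descriptor = descriptor
        self.fitting_net = fitting_net


class EnergyModel(torch.nn.Module):
    """Hypothetical stand-in for the wrapping model class."""

    def __init__(self, atomic_model: DPAtomicModel):
        super().__init__()
        # Assumption: the model wraps an atomic model, as the
        # `self.model.atomic_model` accesses in training.py suggest.
        self.atomic_model = atomic_model

    def get_descriptor(self) -> torch.nn.Module:
        # Explicit accessor replacing __getattr__("descriptor") lookups.
        return self.atomic_model.descriptor

    def get_fitting_net(self) -> torch.nn.Module:
        # Explicit accessor replacing direct fitting_net attribute access.
        return self.atomic_model.fitting_net
```

Under this reading, the accessor keeps the attribute path (`model.atomic_model.fitting_net`) an implementation detail, so call sites such as `share_params` and the tests do not break again if the wrapping changes.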