
Commit

fix tests
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz committed Mar 9, 2024
1 parent db9e881 commit 8f6614c
Showing 4 changed files with 22 additions and 19 deletions.
9 changes: 6 additions & 3 deletions deepmd/pt/train/training.py

@@ -519,8 +519,11 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
                     model_params["type_map"],
                     model_params["new_type_map"],
                 )
-                if hasattr(self.model, "fitting_net"):
-                    self.model.fitting_net.change_energy_bias(
+                # TODO: need an interface instead of fetching fitting_net!!!!!!!!!
+                if hasattr(self.model, "atomic_model") and hasattr(
+                    self.model.atomic_model, "fitting_net"
+                ):
+                    self.model.atomic_model.fitting_net.change_energy_bias(
                         config,
                         self.model,
                         old_type_map,
@@ -530,7 +533,7 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
                     )
                 elif isinstance(self.model, DPZBLModel):
                     # need to updated
-                    self.model.change_energy_bias()
+                    self.model.atomic_model.change_energy_bias()
                 else:
                     raise NotImplementedError
             if init_frz_model is not None:
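This hunk reflects the relocation of fitting_net from the model onto its atomic_model submodule; the new TODO asks for a proper accessor instead of attribute probing at every call site. Below is a hypothetical sketch of such a dispatch helper, assuming only the attributes visible in this diff (the trailing arguments of change_energy_bias are elided by the rendered hunk, so they are passed through untouched):

# Hypothetical helper, not part of this commit: it centralizes the
# hasattr probing above so call sites stop reaching into model internals.
def dispatch_change_energy_bias(model, *args, **kwargs):
    atomic_model = getattr(model, "atomic_model", None)
    if atomic_model is None:
        raise NotImplementedError("model has no atomic_model submodule")
    if hasattr(atomic_model, "fitting_net"):
        # Energy models: the bias lives on the fitting network.
        atomic_model.fitting_net.change_energy_bias(*args, **kwargs)
    else:
        # DPZBL-style models expose the method on the atomic model itself
        # (called with no arguments in this commit; *args is a sketch-level
        # simplification).
        atomic_model.change_energy_bias(*args, **kwargs)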
10 changes: 4 additions & 6 deletions deepmd/pt/train/wrapper.py

@@ -75,12 +75,12 @@ def share_params(self, shared_links, resume=False):
             shared_level_base = shared_base["shared_level"]
             if "descriptor" in class_type_base:
                 if class_type_base == "descriptor":
-                    base_class = self.model[model_key_base].__getattr__("descriptor")
+                    base_class = self.model[model_key_base].get_descriptor()
                 elif "hybrid" in class_type_base:
                     hybrid_index = int(class_type_base.split("_")[-1])
                     base_class = (
                         self.model[model_key_base]
-                        .__getattr__("descriptor")
+                        .get_descriptor()
                         .descriptor_list[hybrid_index]
                     )
                 else:
@@ -96,14 +96,12 @@ def share_params(self, shared_links, resume=False):
                             "descriptor" in class_type_link
                         ), f"Class type mismatched: {class_type_base} vs {class_type_link}!"
                         if class_type_link == "descriptor":
-                            link_class = self.model[model_key_link].__getattr__(
-                                "descriptor"
-                            )
+                            link_class = self.model[model_key_link].get_descriptor()
                         elif "hybrid" in class_type_link:
                             hybrid_index = int(class_type_link.split("_")[-1])
                             link_class = (
                                 self.model[model_key_link]
-                                .__getattr__("descriptor")
+                                .get_descriptor()
                                 .descriptor_list[hybrid_index]
                             )
                         else:
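The wrapper change replaces dynamic __getattr__("descriptor") lookups with the public get_descriptor() accessor. A standalone illustration of the lookup logic share_params performs, including how a hybrid class type such as "hybrid_2" is parsed into a list index (the classes here are local stand-ins, not the real deepmd-kit types):

# Sketch only: stand-ins for the descriptor lookup in share_params.
class HybridDescriptor:
    def __init__(self, descriptor_list):
        self.descriptor_list = descriptor_list

class ModelStub:
    def __init__(self, descriptor):
        self._descriptor = descriptor

    def get_descriptor(self):
        # Public accessor the commit switches to, replacing
        # __getattr__("descriptor") string lookups.
        return self._descriptor

def resolve_descriptor(model, class_type):
    if class_type == "descriptor":
        return model.get_descriptor()
    if "hybrid" in class_type:
        hybrid_index = int(class_type.split("_")[-1])  # "hybrid_2" -> 2
        return model.get_descriptor().descriptor_list[hybrid_index]
    raise RuntimeError(f"Unknown class_type: {class_type}")

# Usage: resolve_descriptor(ModelStub(HybridDescriptor(["d0", "d1"])), "hybrid_1")
# returns "d1".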
18 changes: 10 additions & 8 deletions source/tests/pt/test_finetune.py

@@ -44,27 +44,29 @@ def test_finetune_change_energy_bias(self):
         else:
             model = get_model(self.model_config)
         if isinstance(model, EnergyModel):
-            model.fitting_net.bias_atom_e = torch.rand_like(
-                model.fitting_net.bias_atom_e
+            model.get_fitting_net().bias_atom_e = torch.rand_like(
+                model.get_fitting_net().bias_atom_e
             )
             energy_bias_before = deepcopy(
-                model.fitting_net.bias_atom_e.detach().cpu().numpy().reshape(-1)
+                model.get_fitting_net().bias_atom_e.detach().cpu().numpy().reshape(-1)
             )
             bias_atom_e_input = deepcopy(
-                model.fitting_net.bias_atom_e.detach().cpu().numpy().reshape(-1)
+                model.get_fitting_net().bias_atom_e.detach().cpu().numpy().reshape(-1)
             )
         elif isinstance(model, DPZBLModel):
-            model.dp_model.fitting_net.bias_atom_e = torch.rand_like(
-                model.dp_model.fitting_net.bias_atom_e
+            model.dp_model.get_fitting_net().bias_atom_e = torch.rand_like(
+                model.dp_model.get_fitting_net().bias_atom_e
             )
             energy_bias_before = deepcopy(
-                model.dp_model.fitting_net.bias_atom_e.detach()
+                model.dp_model.get_fitting_net()
+                .bias_atom_e.detach()
                 .cpu()
                 .numpy()
                 .reshape(-1)
             )
             bias_atom_e_input = deepcopy(
-                model.dp_model.fitting_net.bias_atom_e.detach()
+                model.dp_model.get_fitting_net()
+                .bias_atom_e.detach()
                 .cpu()
                 .numpy()
                 .reshape(-1)
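The finetune test now mutates the bias through get_fitting_net() and snapshots it as a flat NumPy array before the bias change. A minimal sketch of that snapshot pattern, assuming only PyTorch and the standard library (the tensor below is a stand-in for bias_atom_e):

from copy import deepcopy

import torch

# Stand-in for fitting_net.bias_atom_e: one energy bias per atom type.
bias_atom_e = torch.rand(2, 1)
# detach() drops autograd history; numpy() on a CPU tensor shares memory
# with it, which is why the test wraps the result in deepcopy to get a
# true snapshot; reshape(-1) flattens for easy comparison.
snapshot = deepcopy(bias_atom_e.detach().cpu().numpy().reshape(-1))
bias_atom_e += 1.0  # in-place change; the deepcopied snapshot is unaffected
assert snapshot.shape == (2,)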
4 changes: 2 additions & 2 deletions source/tests/pt/test_training.py

@@ -52,11 +52,11 @@ def test_trainable(self):
             fix_params["model"]["descriptor"]["trainable"] = True
             trainer_fix = get_trainer(fix_params)
             model_dict_before_training = deepcopy(
-                trainer_fix.model.fitting_net.state_dict()
+                trainer_fix.model.get_fitting_net().state_dict()
             )
             trainer_fix.run()
             model_dict_after_training = deepcopy(
-                trainer_fix.model.fitting_net.state_dict()
+                trainer_fix.model.get_fitting_net().state_dict()
             )
         else:
             trainer_fix = get_trainer(fix_params)
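test_trainable snapshots the fitting net's state_dict before and after trainer_fix.run() to check which parameters moved during training. A hedged sketch of the comparison such a test performs (the helper name is local to this sketch, and whether equality or inequality is expected depends on trainable flags set elsewhere in the file):

import torch

def state_dicts_allclose(before, after):
    # True when every tensor in the snapshot is unchanged by training.
    return before.keys() == after.keys() and all(
        torch.allclose(before[k], after[k]) for k in before
    )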
