Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 8, 2024
1 parent a84f161 commit 33d5fd9
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
8 changes: 7 additions & 1 deletion source/tests/pt/model/test_descriptor_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,11 @@ def test_descriptor_block(self):
des = DescrptBlockSeAtten(
**dparams,
).to(env.DEVICE)
des.load_state_dict(torch.load(self.file_model_param, weights_only=True))
state_dict = torch.load(self.file_model_param, weights_only=True)
# this is an old state dict, modify manually
state_dict["compress_info.0"] = des.compress_info[0]
state_dict["compress_data.0"] = des.compress_data[0]
des.load_state_dict(state_dict)
coord = self.coord
atype = self.atype
box = self.cell
Expand Down Expand Up @@ -371,5 +375,7 @@ def translate_se_atten_and_type_embd_dicts_to_dpa1(
tk = "type_embedding." + kk
record[all_keys.index(tk)] = True
target_dict[tk] = type_embd_dict[kk]
record[all_keys.index("se_atten.compress_data.0")] = True
record[all_keys.index("se_atten.compress_info.0")] = True
assert all(record)
return target_dict
2 changes: 2 additions & 0 deletions source/tests/pt/model/test_descriptor_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,5 +194,7 @@ def translate_type_embd_dicts_to_dpa2(
tk = "type_embedding." + kk
record[all_keys.index(tk)] = True
target_dict[tk] = type_embd_dict[kk]
record[all_keys.index("repinit.compress_data.0")] = True
record[all_keys.index("repinit.compress_info.0")] = True
assert all(record)
return target_dict
3 changes: 2 additions & 1 deletion source/tests/pt/model/test_unused_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def get_contributing_params(y, top_level=True):
contributing_parameters = set(get_contributing_params(ret0["energy"]))
all_parameters = set(self.model.parameters())
non_contributing = all_parameters - contributing_parameters
self.assertEqual(len(non_contributing), 0)
# 2 for compression
self.assertEqual(len(non_contributing), 2)


if __name__ == "__main__":
Expand Down

0 comments on commit 33d5fd9

Please sign in to comment.