From 33d5fd916dcd1778ddeba6e0ed3a9f5df88f2a00 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 12:51:40 -0500 Subject: [PATCH] fix tests Signed-off-by: Jinzhe Zeng --- source/tests/pt/model/test_descriptor_dpa1.py | 8 +++++++- source/tests/pt/model/test_descriptor_dpa2.py | 2 ++ source/tests/pt/model/test_unused_params.py | 3 ++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/source/tests/pt/model/test_descriptor_dpa1.py b/source/tests/pt/model/test_descriptor_dpa1.py index ddd5dc6c3c..9652a63944 100644 --- a/source/tests/pt/model/test_descriptor_dpa1.py +++ b/source/tests/pt/model/test_descriptor_dpa1.py @@ -245,7 +245,11 @@ def test_descriptor_block(self): des = DescrptBlockSeAtten( **dparams, ).to(env.DEVICE) - des.load_state_dict(torch.load(self.file_model_param, weights_only=True)) + state_dict = torch.load(self.file_model_param, weights_only=True) + # this is an old state dict, modify manually + state_dict["compress_info.0"] = des.compress_info[0] + state_dict["compress_data.0"] = des.compress_data[0] + des.load_state_dict(state_dict) coord = self.coord atype = self.atype box = self.cell @@ -371,5 +375,7 @@ def translate_se_atten_and_type_embd_dicts_to_dpa1( tk = "type_embedding." + kk record[all_keys.index(tk)] = True target_dict[tk] = type_embd_dict[kk] + record[all_keys.index("se_atten.compress_data.0")] = True + record[all_keys.index("se_atten.compress_info.0")] = True assert all(record) return target_dict diff --git a/source/tests/pt/model/test_descriptor_dpa2.py b/source/tests/pt/model/test_descriptor_dpa2.py index 17d609a2f9..7efbe0a921 100644 --- a/source/tests/pt/model/test_descriptor_dpa2.py +++ b/source/tests/pt/model/test_descriptor_dpa2.py @@ -194,5 +194,7 @@ def translate_type_embd_dicts_to_dpa2( tk = "type_embedding." + kk record[all_keys.index(tk)] = True target_dict[tk] = type_embd_dict[kk] + record[all_keys.index("repinit.compress_data.0")] = True + record[all_keys.index("repinit.compress_info.0")] = True assert all(record) return target_dict diff --git a/source/tests/pt/model/test_unused_params.py b/source/tests/pt/model/test_unused_params.py index 98bbe7040e..8c223d7590 100644 --- a/source/tests/pt/model/test_unused_params.py +++ b/source/tests/pt/model/test_unused_params.py @@ -86,7 +86,8 @@ def get_contributing_params(y, top_level=True): contributing_parameters = set(get_contributing_params(ret0["energy"])) all_parameters = set(self.model.parameters()) non_contributing = all_parameters - contributing_parameters - self.assertEqual(len(non_contributing), 0) + # 2 for compression + self.assertEqual(len(non_contributing), 2) if __name__ == "__main__":