Skip to content

Commit

Permalink
add uts for the compression of smooth se_atten
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Oct 13, 2023
1 parent d567995 commit 7ceabd2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
18 changes: 11 additions & 7 deletions source/tests/test_model_compression_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ def _subprocess_run(command):
# - type embedding FP32, se_atten FP64
# - type embedding FP32, se_atten FP32
tests = [
{"se_atten precision": "float64", "type embedding precision": "float64"},
{"se_atten precision": "float64", "type embedding precision": "float32"},
{"se_atten precision": "float32", "type embedding precision": "float64"},
{"se_atten precision": "float32", "type embedding precision": "float32"},
{"se_atten precision": "float64", "type embedding precision": "float64", "smooth_type_embdding": True},
{"se_atten precision": "float64", "type embedding precision": "float64", "smooth_type_embdding": False},
{"se_atten precision": "float64", "type embedding precision": "float32", "smooth_type_embdding": True},
{"se_atten precision": "float32", "type embedding precision": "float64", "smooth_type_embdding": True},
{"se_atten precision": "float32", "type embedding precision": "float32", "smooth_type_embdding": True},
]


Expand All @@ -73,6 +74,9 @@ def _init_models():
jdata["model"]["descriptor"]["stripped_type_embedding"] = True
jdata["model"]["descriptor"]["sel"] = 120
jdata["model"]["descriptor"]["attn_layer"] = 0
jdata["model"]["descriptor"]["smooth_type_embdding"] = tests[i][
"smooth_type_embdding"
]
jdata["model"]["type_embedding"] = {}
jdata["model"]["type_embedding"]["precision"] = tests[i][
"type embedding precision"
Expand Down Expand Up @@ -479,9 +483,9 @@ def test_1frame(self):
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)
np.testing.assert_almost_equal(ff0, ff1, default_places, err_msg=str(tests[i]))
np.testing.assert_almost_equal(ee0, ee1, default_places, err_msg=str(tests[i]))
np.testing.assert_almost_equal(vv0, vv1, default_places, err_msg=str(tests[i]))

def test_1frame_atm(self):
for i in range(len(tests)):
Expand Down
2 changes: 2 additions & 0 deletions source/tests/test_model_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,8 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self):
jdata["model"]["descriptor"]["stripped_type_embedding"] = True
jdata["model"]["descriptor"]["smooth_type_embdding"] = True
jdata["model"]["descriptor"]["attn_layer"] = 1
jdata["model"]["descriptor"]["rcut"] = 6.0
jdata["model"]["descriptor"]["rcut_smth"] = 4.0
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["descrpt"] = descrpt
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
Expand Down

0 comments on commit 7ceabd2

Please sign in to comment.