Skip to content

Commit

Permalink
construct a test case that old implementation will fail
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Apr 7, 2024
1 parent 41a4b89 commit 51c3054
Showing 1 changed file with 8 additions and 22 deletions.
30 changes: 8 additions & 22 deletions source/tests/tf/test_model_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ def test_smoothness_of_stripped_type_embedding_smooth_model_excluded_types(self)
jdata["model"]["descriptor"]["attn_layer"] = 1
jdata["model"]["descriptor"]["rcut"] = 6.0
jdata["model"]["descriptor"]["rcut_smth"] = 4.0
jdata["model"]["descriptor"]["exclude_types"] = [[0, 1], [1, 1]]
jdata["model"]["descriptor"]["exclude_types"] = [[0, 0], [0, 1]]
jdata["model"]["descriptor"]["set_davg_zero"] = False
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
Expand Down Expand Up @@ -944,6 +944,8 @@ def test_smoothness_of_stripped_type_embedding_smooth_model_excluded_types(self)
}
model._compute_input_stat(input_data)
model.descrpt.bias_atom_e = data.compute_energy_shift()
# make the original implementation failed
model.descrpt.davg[:] += 1e-1

t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c")
t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy")
Expand Down Expand Up @@ -996,16 +998,17 @@ def test_smoothness_of_stripped_type_embedding_smooth_model_excluded_types(self)
pf, pv = pf.reshape(-1), pv.reshape(-1)

eps = 1e-4
delta = 1e-6
fdf, fdv = finite_difference_fv(
sess, energy, feed_dict_test, t_coord, t_box, delta=eps
)
np.testing.assert_allclose(pf, fdf, atol=1e-10)
np.testing.assert_allclose(pv, fdv, atol=1e-8)
np.testing.assert_allclose(pf, fdf, delta)
np.testing.assert_allclose(pv, fdv, delta)

tested_eps = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
for eps in tested_eps:
deltae = eps
deltad = eps
deltae = 1e-15
deltad = 1e-15
de, df, dv = check_smooth_efv(
sess,
energy,
Expand All @@ -1019,20 +1022,3 @@ def test_smoothness_of_stripped_type_embedding_smooth_model_excluded_types(self)
np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae)
np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad)
np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad)

for eps in tested_eps:
deltae = 5.0 * eps
deltad = 5.0 * eps
de, df, dv = check_smooth_efv(
sess,
energy,
force,
virial,
feed_dict_test,
t_coord,
jdata["model"]["descriptor"]["rcut_smth"],
delta=eps,
)
np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae)
np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad)
np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad)

0 comments on commit 51c3054

Please sign in to comment.