Skip to content

Commit

Permalink
fix(tf): apply exclude types to se_atten_v2 switch (#3651)
Browse files Browse the repository at this point in the history
I construct a test case in which all old implementations fail.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Apr 8, 2024
1 parent ff20efc commit 09fd3bb
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 0 deletions.
4 changes: 4 additions & 0 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,10 @@ def _pass_filter(
tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt]
),
)
self.recovered_switch *= tf.reshape(
tf.slice(tf.reshape(mask, [-1, 4]), [0, 0], [-1, 1]),
[-1, natoms[0], self.sel_all_a[0]],
)
else:
inputs_i *= mask
if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor:
Expand Down
130 changes: 130 additions & 0 deletions source/tests/tf/test_model_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,3 +890,133 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self):
np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae)
np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad)
np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad)

def test_smoothness_of_stripped_type_embedding_smooth_model_excluded_types(self):
"""test: auto-diff, continuity of e,f,v."""
jfile = "water_se_atten.json"
jdata = j_loader(jfile)

systems = j_must_have(jdata, "systems")
set_pfx = j_must_have(jdata, "set_prefix")
batch_size = 1
test_size = 1
rcut = j_must_have(jdata["model"]["descriptor"], "rcut")

data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None)

test_data = data.get_test()
numb_test = 1

jdata["model"]["descriptor"].pop("type", None)
jdata["model"]["descriptor"]["ntypes"] = 2
jdata["model"]["descriptor"]["stripped_type_embedding"] = True
jdata["model"]["descriptor"]["smooth_type_embdding"] = True
jdata["model"]["descriptor"]["attn_layer"] = 1
jdata["model"]["descriptor"]["rcut"] = 6.0
jdata["model"]["descriptor"]["rcut_smth"] = 4.0
jdata["model"]["descriptor"]["exclude_types"] = [[0, 0], [0, 1]]
jdata["model"]["descriptor"]["set_davg_zero"] = False
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
typeebd_param = jdata["model"]["type_embedding"]
typeebd = TypeEmbedNet(
ntypes=descrpt.get_ntypes(),
neuron=typeebd_param["neuron"],
activation_function=None,
resnet_dt=typeebd_param["resnet_dt"],
seed=typeebd_param["seed"],
uniform_seed=True,
padding=True,
)
model = EnerModel(descrpt, fitting, typeebd)

input_data = {
"coord": [test_data["coord"]],
"box": [test_data["box"]],
"type": [test_data["type"]],
"natoms_vec": [test_data["natoms_vec"]],
"default_mesh": [test_data["default_mesh"]],
}
model._compute_input_stat(input_data)
model.descrpt.bias_atom_e = data.compute_energy_shift()
# make the original implementation failed
model.descrpt.davg[:] += 1e-1

t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c")
t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy")
t_force = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_force")
t_virial = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_virial")
t_atom_ener = tf.placeholder(
GLOBAL_TF_FLOAT_PRECISION, [None], name="t_atom_ener"
)
t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord")
t_type = tf.placeholder(tf.int32, [None], name="i_type")
t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms")
t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box")
t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh")
is_training = tf.placeholder(tf.bool)
inputs_dict = {}

model_pred = model.build(
t_coord,
t_type,
t_natoms,
t_box,
t_mesh,
inputs_dict,
suffix=self.filename
+ "-"
+ inspect.stack()[0][3]
+ "test_model_se_atten_model_compressible_excluded_types",
reuse=False,
)
energy = model_pred["energy"]
force = model_pred["force"]
virial = model_pred["virial"]

feed_dict_test = {
t_prop_c: test_data["prop_c"],
t_energy: test_data["energy"][:numb_test],
t_force: np.reshape(test_data["force"][:numb_test, :], [-1]),
t_virial: np.reshape(test_data["virial"][:numb_test, :], [-1]),
t_atom_ener: np.reshape(test_data["atom_ener"][:numb_test, :], [-1]),
t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]),
t_box: test_data["box"][:numb_test, :],
t_type: np.reshape(test_data["type"][:numb_test, :], [-1]),
t_natoms: test_data["natoms_vec"],
t_mesh: test_data["default_mesh"],
is_training: False,
}
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[pe, pf, pv] = sess.run([energy, force, virial], feed_dict=feed_dict_test)
pf, pv = pf.reshape(-1), pv.reshape(-1)

eps = 1e-4
delta = 1e-6
fdf, fdv = finite_difference_fv(
sess, energy, feed_dict_test, t_coord, t_box, delta=eps
)
np.testing.assert_allclose(pf, fdf, delta)
np.testing.assert_allclose(pv, fdv, delta)

tested_eps = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
for eps in tested_eps:
deltae = 1e-15
deltad = 1e-15
de, df, dv = check_smooth_efv(
sess,
energy,
force,
virial,
feed_dict_test,
t_coord,
jdata["model"]["descriptor"]["rcut"],
delta=eps,
)
np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae)
np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad)
np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad)

0 comments on commit 09fd3bb

Please sign in to comment.