Skip to content

Commit

Permalink
fix the smoothness issue of se_attn_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Oct 9, 2023
1 parent d8ee74b commit d567995
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 4 deletions.
23 changes: 19 additions & 4 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,8 @@ def build(
self.filter_precision,
)
self.negative_mask = -(2 << 32) * (1.0 - self.nmask)
# hard coding the magnitude of attention weight shift
self.smth_attn_w_shift = 20.0
# only used when tensorboard was set as true
tf.summary.histogram("descrpt", self.descrpt)
tf.summary.histogram("rij", self.rij)
Expand Down Expand Up @@ -599,7 +601,7 @@ def build(
)
self.recovered_r = (
tf.reshape(
tf.slice(tf.reshape(self.descrpt, [-1, 4]), [0, 0], [-1, 1]),
tf.slice(tf.reshape(self.descrpt_reshape, [-1, 4]), [0, 0], [-1, 1]),
[-1, natoms[0], self.sel_all_a[0]],
)
* self.std_looked_up
Expand Down Expand Up @@ -865,10 +867,23 @@ def _scaled_dot_attn(
save_weights=True,
):
attn = tf.matmul(Q / temperature, K, transpose_b=True)
attn *= self.nmask
attn += self.negative_mask
if self.smooth:
# (nb x nloc) x nsel
nsel = self.sel_all_a[0]
attn = ((attn + self.smth_attn_w_shift) *
tf.reshape(self.recovered_switch, [-1,1,nsel]) *
tf.reshape(self.recovered_switch, [-1,nsel,1]) -
self.smth_attn_w_shift)
else:
attn *= self.nmask
attn += self.negative_mask
attn = tf.nn.softmax(attn, axis=-1)
attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1])
if self.smooth:
attn = (attn *
tf.reshape(self.recovered_switch, [-1,1,nsel]) *
tf.reshape(self.recovered_switch, [-1,nsel,1]))
else:
attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1])
if save_weights:
self.attn_weight[layer] = attn[0] # atom 0
if dotr:
Expand Down
72 changes: 72 additions & 0 deletions source/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,78 @@ def strerch_box(old_coord, old_box, new_box):
return ncoord.reshape(old_coord.shape)


def finite_difference_fv(sess, energy, feed_dict, t_coord, t_box, delta=1e-6):
    """Estimate force and virial of an energy model by finite differences.

    The entries keyed by ``t_coord`` and ``t_box`` are taken out of a copy of
    ``feed_dict``; every other entry is fed unchanged on each evaluation.

    Parameters
    ----------
    sess
        Object whose ``run(fetch, feed_dict=...)`` evaluates ``energy``.
    energy
        Fetch passed to ``sess.run``; its result is flattened with
        ``reshape(-1)``.
    feed_dict
        Feed dictionary that must contain ``t_coord`` and ``t_box``.
        It is not modified.
    t_coord, t_box
        Keys of the coordinate and box entries in ``feed_dict``.
    delta
        Step forwarded to ``finite_difference``.

    Returns
    -------
    fdf, fdv
        Flattened arrays: the negated derivative with respect to the
        coordinates, and the negated, transposed box derivative multiplied
        by the 3x3 box.
    """
    fixed_feeds = feed_dict.copy()
    coord_ref = fixed_feeds.pop(t_coord)
    box_ref = fixed_feeds.pop(t_box)

    def energy_at_coord(coord):
        # vary the coordinates only; the box stays at its reference value
        feeds = {**fixed_feeds, t_coord: coord, t_box: box_ref}
        return sess.run(energy, feed_dict=feeds).reshape(-1)

    def energy_at_box(box):
        # the coordinates are remapped by strerch_box so they follow the box
        feeds = {
            **fixed_feeds,
            t_coord: strerch_box(coord_ref, box_ref, box),
            t_box: box,
        }
        return sess.run(energy, feed_dict=feeds).reshape(-1)

    # NOTE(review): finite_difference is defined outside this view; it is
    # presumably a numerical derivative of the callable at the given point.
    coord_deriv = finite_difference(energy_at_coord, coord_ref, delta=delta)
    fdf = -coord_deriv.reshape(-1)

    box_deriv = finite_difference(energy_at_box, box_ref, delta=delta)
    box_deriv_t = box_deriv.reshape([-1, 3, 3]).transpose(0, 2, 1)
    # box_ref is reshaped to a single 3x3 matrix, so one frame is assumed
    fdv = -(box_deriv_t @ box_ref.reshape(3, 3)).reshape(-1)
    return fdf, fdv


def check_continuity(f, cc, rcut, delta):
    """Evaluate ``f`` with atoms 0 and 1 placed just inside and just outside ``rcut``.

    ``cc`` is viewed as rows of three coordinates. In two independent copies
    the first two rows are overwritten with ``[0, 0, 0]`` and
    ``[rcut -/+ 0.5 * delta, 0, 0]``; all other rows are kept. ``cc`` itself
    is left untouched.

    Parameters
    ----------
    f
        Callable taking the flattened coordinate array.
    cc
        Coordinate array reshapeable to ``(-1, 3)`` with at least two rows.
    rcut
        Distance around which the pair separation is placed.
    delta
        Total width of the gap between the two tested separations.

    Returns
    -------
    tuple
        ``(f(inside), f(outside))``.
    """
    rows = cc.reshape([-1, 3])
    trial_coords = []
    for offset in (-0.5 * delta, 0.5 * delta):
        moved = np.copy(rows)
        moved[:2, :] = np.array(
            [
                [0.0, 0.0, 0.0],
                [rcut + offset, 0.0, 0.0],
            ]
        )
        trial_coords.append(moved)
    inside, outside = trial_coords
    return f(inside.reshape(-1)), f(outside.reshape(-1))


def check_smooth_efv(sess, energy, force, virial, feed_dict, t_coord, rcut, delta=1e-5):
    """Evaluate energy, force and virial on both sides of a pair distance.

    The entry keyed by ``t_coord`` is taken out of a copy of ``feed_dict``
    and used as the reference coordinates; every other entry is fed
    unchanged. Each of ``energy``, ``force`` and ``virial`` is then passed
    through ``check_continuity`` with its own evaluator.

    Parameters
    ----------
    sess
        Object whose ``run(fetch, feed_dict=...)`` evaluates a fetch.
    energy, force, virial
        Fetches passed to ``sess.run``; results are flattened with
        ``reshape(-1)``.
    feed_dict
        Feed dictionary that must contain ``t_coord``. It is not modified.
    t_coord
        Key of the coordinate entry in ``feed_dict``.
    rcut
        Distance forwarded to ``check_continuity``.
    delta
        Gap width forwarded to ``check_continuity``.

    Returns
    -------
    de, df, dv
        One pair per quantity, as returned by ``check_continuity``.
        ``de[0]`` is supposed to be close to ``de[1]``, ``df[0]`` to
        ``df[1]`` and ``dv[0]`` to ``dv[1]``.
    """
    base_dict = feed_dict.copy()
    coord0 = base_dict.pop(t_coord)

    def make_evaluator(tensor):
        # A factory call gives every evaluator its own ``tensor`` binding.
        # Lambdas created directly inside a comprehension share the loop
        # variable and would all end up fetching the last tensor (virial).
        def evaluate(coord):
            return sess.run(
                tensor, feed_dict={**base_dict, t_coord: coord}
            ).reshape(-1)

        return evaluate

    de, df, dv = (
        check_continuity(make_evaluator(tensor), coord0, rcut, delta=delta)
        for tensor in (energy, force, virial)
    )
    return de, df, dv


def run_dp(cmd: str) -> int:
"""Run DP directly from the entry point instead of the subprocess.
Expand Down
139 changes: 139 additions & 0 deletions source/tests/test_model_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from deepmd.utils.type_embed import (
TypeEmbedNet,
)
from common import finite_difference_fv, check_smooth_efv

GLOBAL_ENER_FLOAT_PRECISION = tf.float64
GLOBAL_TF_FLOAT_PRECISION = tf.float64
Expand Down Expand Up @@ -726,3 +727,141 @@ def test_stripped_type_embedding_exclude_types(self):
np.testing.assert_almost_equal(des[:, 0:2], 0.0, 10)
with self.assertRaises(AssertionError):
np.testing.assert_almost_equal(des[:, 2:6], 0.0, 10)


def test_smoothness_of_stripped_type_embedding_smooth_model(self):
    """Check auto-diff consistency and continuity of e, f and v.

    Builds an energy model on a ``DescrptSeAtten`` descriptor configured
    with stripped type embedding, the smooth option and one attention
    layer, then verifies on a single test frame that

    * the predicted force and virial agree with finite differences of the
      energy, and
    * energy, force and virial change by no more than a tolerance
      proportional to ``eps`` when the distance between atoms 0 and 1
      crosses ``rcut`` and ``rcut_smth`` by ``eps``.
    """
    jfile = "water_se_atten.json"
    jdata = j_loader(jfile)

    systems = j_must_have(jdata, "systems")
    set_pfx = j_must_have(jdata, "set_prefix")
    rcut = j_must_have(jdata["model"]["descriptor"], "rcut")
    # one frame is all this test uses, so the sizes are fixed here instead
    # of being read from the JSON file
    batch_size = 1
    test_size = 1
    numb_test = 1

    data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None)
    test_data = data.get_test()

    descrpt_param = jdata["model"]["descriptor"]
    descrpt_param.pop("type", None)
    descrpt_param["ntypes"] = 2
    descrpt_param["stripped_type_embedding"] = True
    # NOTE(review): the key spelling is kept exactly as in the original test;
    # presumably it matches the DescrptSeAtten keyword -- confirm.
    descrpt_param["smooth_type_embdding"] = True
    descrpt_param["attn_layer"] = 1
    descrpt = DescrptSeAtten(**descrpt_param, uniform_seed=True)
    jdata["model"]["fitting_net"]["descrpt"] = descrpt
    fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
    typeebd_param = jdata["model"]["type_embedding"]
    typeebd = TypeEmbedNet(
        neuron=typeebd_param["neuron"],
        activation_function=None,
        resnet_dt=typeebd_param["resnet_dt"],
        seed=typeebd_param["seed"],
        uniform_seed=True,
        padding=True,
    )
    model = EnerModel(descrpt, fitting, typeebd)

    input_data = {
        "coord": [test_data["coord"]],
        "box": [test_data["box"]],
        "type": [test_data["type"]],
        "natoms_vec": [test_data["natoms_vec"]],
        "default_mesh": [test_data["default_mesh"]],
    }
    model._compute_input_stat(input_data)
    model.descrpt.bias_atom_e = data.compute_energy_shift()

    t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c")
    t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy")
    t_force = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_force")
    t_virial = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_virial")
    t_atom_ener = tf.placeholder(
        GLOBAL_TF_FLOAT_PRECISION, [None], name="t_atom_ener"
    )
    t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord")
    t_type = tf.placeholder(tf.int32, [None], name="i_type")
    t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms")
    t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box")
    t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh")
    is_training = tf.placeholder(tf.bool)
    inputs_dict = {}

    model_pred = model.build(
        t_coord,
        t_type,
        t_natoms,
        t_box,
        t_mesh,
        inputs_dict,
        suffix=self.filename
        + "-"
        + inspect.stack()[0][3]
        + "test_model_se_atten_model_compressible",
        reuse=False,
    )
    energy = model_pred["energy"]
    force = model_pred["force"]
    virial = model_pred["virial"]

    feed_dict_test = {
        t_prop_c: test_data["prop_c"],
        t_energy: test_data["energy"][:numb_test],
        t_force: np.reshape(test_data["force"][:numb_test, :], [-1]),
        t_virial: np.reshape(test_data["virial"][:numb_test, :], [-1]),
        t_atom_ener: np.reshape(test_data["atom_ener"][:numb_test, :], [-1]),
        t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]),
        t_box: test_data["box"][:numb_test, :],
        t_type: np.reshape(test_data["type"][:numb_test, :], [-1]),
        t_natoms: test_data["natoms_vec"],
        t_mesh: test_data["default_mesh"],
        is_training: False,
    }

    # the with-block pairs the session's __enter__ with its __exit__, so the
    # session context is left again even when an assertion fails
    with self.cached_session() as sess:
        sess.run(tf.global_variables_initializer())
        pf, pv = sess.run([force, virial], feed_dict=feed_dict_test)
        pf, pv = pf.reshape(-1), pv.reshape(-1)

        # auto-diff force/virial against finite differences of the energy
        fd_step = 1e-4
        fd_rtol = 1e-5
        fdf, fdv = finite_difference_fv(
            sess, energy, feed_dict_test, t_coord, t_box, delta=fd_step
        )
        np.testing.assert_allclose(pf, fdf, rtol=fd_rtol)
        np.testing.assert_allclose(pv, fdv, rtol=fd_rtol)

        # continuity across rcut (tolerance eps) and rcut_smth (tolerance 5 * eps)
        tested_eps = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
        for radius_key, tol_scale in (("rcut", 1.0), ("rcut_smth", 5.0)):
            radius = jdata["model"]["descriptor"][radius_key]
            for eps in tested_eps:
                atol = tol_scale * eps
                de, df, dv = check_smooth_efv(
                    sess,
                    energy,
                    force,
                    virial,
                    feed_dict_test,
                    t_coord,
                    radius,
                    delta=eps,
                )
                for below, above in (de, df, dv):
                    np.testing.assert_allclose(below, above, rtol=0, atol=atol)

0 comments on commit d567995

Please sign in to comment.