Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(tf): make se_atten_v2 masking smooth when davg is not zero #3632

Merged
merged 9 commits into from
Apr 2, 2024
12 changes: 11 additions & 1 deletion deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,17 @@
tf.shape(inputs_i)[0],
self.nei_type_vec, # extra input for atten
)
inputs_i *= mask
if self.smooth:
inputs_i = tf.where(

Check warning on line 685 in deepmd/tf/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten.py#L684-L685

Added lines #L684 - L685 were not covered by tests
tf.cast(mask, tf.bool),
inputs_i,
# (nframes * nloc, 1) -> (nframes * nloc, ndescrpt)
tf.tile(
tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt]
),
)
else:
inputs_i *= mask

Check warning on line 694 in deepmd/tf/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten.py#L694

Added line #L694 was not covered by tests
if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor:
inputs_i = descrpt2r4(inputs_i, atype)
layer, qmat = self._filter(
Expand Down