diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 3f4483dd74..5e0487f500 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -683,7 +683,7 @@ def _pass_filter( ) if self.smooth: inputs_i = tf.where( - tf.cast(mask, tf.bool), inputs_i, self.avg_looked_up + tf.cast(mask, tf.bool), inputs_i, tf.reshape(tf.tile(self.avg_looked_up, [1, 1, 4]), tf.shape(inputs_i)) ) else: inputs_i *= mask