diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 3ca763870b..bde8365775 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -705,7 +705,7 @@ def _pass_filter( ), ) self.recovered_switch *= tf.reshape( - tf.slice(tf.reshape(mask, [-1, 4]), [0, 0], [-1, 1]), + tf.slice(tf.reshape(tf.cast(mask, self.filter_precision), [-1, 4]), [0, 0], [-1, 1]), [-1, natoms[0], self.sel_all_a[0]], ) else: