From 8bf8c1e349fa7a3de1c72b7f55c46c38fa0c73ee Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 1 Apr 2024 14:24:19 -0400 Subject: [PATCH 1/9] fix(tf): make se_atten_v2 smooth when exclude_types is given and set_davg_zero is False Signed-off-by: Jinzhe Zeng --- deepmd/tf/descriptor/se_atten.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 51e34e9b08..c278c4c69e 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -681,7 +681,10 @@ def _pass_filter( tf.shape(inputs_i)[0], self.nei_type_vec, # extra input for atten ) - inputs_i *= mask + if self.smooth: + inputs_i = tf.where(mask, inputs_i, self.avg_looked_up) + else: + inputs_i *= mask if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor: inputs_i = descrpt2r4(inputs_i, atype) layer, qmat = self._filter( From 9dfa85a7e060003fb35f97eae99f3b037d2c7137 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 1 Apr 2024 15:06:17 -0400 Subject: [PATCH 2/9] cast mask to bool Signed-off-by: Jinzhe Zeng --- deepmd/tf/descriptor/se_atten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index c278c4c69e..d0990b0a30 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -682,7 +682,7 @@ def _pass_filter( self.nei_type_vec, # extra input for atten ) if self.smooth: - inputs_i = tf.where(mask, inputs_i, self.avg_looked_up) + inputs_i = tf.where(tf.cast(mask, tf.bool), inputs_i, self.avg_looked_up) else: inputs_i *= mask if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor: From 56e6cac87e3de6f451369078cb3ae65f16687a06 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Apr 2024 19:06:42 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/tf/descriptor/se_atten.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index d0990b0a30..3f4483dd74 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -682,7 +682,9 @@ def _pass_filter( self.nei_type_vec, # extra input for atten ) if self.smooth: - inputs_i = tf.where(tf.cast(mask, tf.bool), inputs_i, self.avg_looked_up) + inputs_i = tf.where( + tf.cast(mask, tf.bool), inputs_i, self.avg_looked_up + ) else: inputs_i *= mask if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor: From 3563de74469ce4216903febb80ab4bcef45c1c73 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 1 Apr 2024 15:43:26 -0400 Subject: [PATCH 4/9] tile + reshape Signed-off-by: Jinzhe Zeng --- deepmd/tf/descriptor/se_atten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 3f4483dd74..5e0487f500 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -683,7 +683,7 @@ def _pass_filter( ) if self.smooth: inputs_i = tf.where( - tf.cast(mask, tf.bool), inputs_i, self.avg_looked_up + tf.cast(mask, tf.bool), inputs_i, tf.reshape(tf.tile(self.avg_looked_up, [1, 1, 4]), tf.shape(inputs_i)) ) else: inputs_i *= mask From 97e1e0a7cfd335bba4f12e7e63c3acc3ca154c44 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Apr 2024 19:43:51 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/tf/descriptor/se_atten.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 5e0487f500..39faae8001 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -683,7 +683,11 @@ def _pass_filter( ) if self.smooth: inputs_i = tf.where( - tf.cast(mask, tf.bool), inputs_i, tf.reshape(tf.tile(self.avg_looked_up, [1, 1, 4]), tf.shape(inputs_i)) + tf.cast(mask, tf.bool), + inputs_i, + tf.reshape( + tf.tile(self.avg_looked_up, [1, 1, 4]), tf.shape(inputs_i) + ), ) else: inputs_i *= mask From 6f8eff487889e3fa21c933f101bfa3411475ec09 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 1 Apr 2024 16:21:21 -0400 Subject: [PATCH 6/9] fix shape Signed-off-by: Jinzhe Zeng --- deepmd/tf/descriptor/se_atten.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 39faae8001..e561e32095 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -685,8 +685,9 @@ def _pass_filter( inputs_i = tf.where( tf.cast(mask, tf.bool), inputs_i, + # broadcast is available: (nframes * nloc, 1) -> (nframes * nloc, ndescrpt) tf.reshape( - tf.tile(self.avg_looked_up, [1, 1, 4]), tf.shape(inputs_i) + self.avg_looked_up, [-1, 1] ), ) else: From 3314c6d6408c293bbff5ec5d6d0541e510c5fd85 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Apr 2024 20:22:26 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/tf/descriptor/se_atten.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index e561e32095..b2923537e9 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -686,9 +686,7 @@ def _pass_filter( tf.cast(mask, tf.bool), inputs_i, # broadcast is available: (nframes * nloc, 1) -> (nframes * nloc, ndescrpt) - tf.reshape( - self.avg_looked_up, [-1, 1] - ), + tf.reshape(self.avg_looked_up, [-1, 1]), ) else: inputs_i *= mask From db0373ef0aab48acea08500918f5a7814ad2941c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 1 Apr 2024 17:52:55 -0400 Subject: [PATCH 8/9] broadcast seems not working Signed-off-by: Jinzhe Zeng --- deepmd/tf/descriptor/se_atten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index b2923537e9..5553e2e3c4 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -685,8 +685,8 @@ def _pass_filter( inputs_i = tf.where( tf.cast(mask, tf.bool), inputs_i, - # broadcast is available: (nframes * nloc, 1) -> (nframes * nloc, ndescrpt) - tf.reshape(self.avg_looked_up, [-1, 1]), + # (nframes * nloc, 1) -> (nframes * nloc, ndescrpt) + tf.tile(tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt]), ) else: inputs_i *= mask From df1f4fbc63ed4104c1350b06e91a20e3283b6f67 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Apr 2024 21:54:45 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/tf/descriptor/se_atten.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 5553e2e3c4..82184dec02 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -686,7 +686,9 @@ def _pass_filter( tf.cast(mask, tf.bool), inputs_i, # (nframes * nloc, 1) -> (nframes * nloc, ndescrpt) - tf.tile(tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt]), + tf.tile( + tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt] + ), ) else: inputs_i *= mask