From 13a8adf234f09982d88aa25fbd0b910104e96cd0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Mar 2024 04:19:26 -0500 Subject: [PATCH] do not return g2, h2, sw in hybrid descriptors (#3396) g2, h2, and sw are heavily dependent on the neighbor list. We cannot ensure the sub descriptors require the same neighbor list as the parent descriptor. Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/hybrid.py | 8 +------- deepmd/pt/model/descriptor/hybrid.py | 8 +------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index 46f2616b84..96640d75c8 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -176,7 +176,7 @@ def call( """ out_descriptor = [] out_gr = [] - out_g2 = [] + out_g2 = None out_h2 = None out_sw = None if self.sel_no_mixed_types is not None: @@ -199,15 +199,9 @@ def call( out_descriptor.append(odescriptor) if gr is not None: out_gr.append(gr) - if g2 is not None: - out_g2.append(g2) - if self.get_rcut() == descrpt.get_rcut(): - out_h2 = h2 - out_sw = sw out_descriptor = np.concatenate(out_descriptor, axis=-1) out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None - out_g2 = np.concatenate(out_g2, axis=-1) if out_g2 else None return out_descriptor, out_gr, out_g2, out_h2, out_sw @classmethod diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index b53adca462..204ca7589d 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -200,7 +200,7 @@ def forward( """ out_descriptor = [] out_gr = [] - out_g2 = [] + out_g2: Optional[torch.Tensor] = None out_h2: Optional[torch.Tensor] = None out_sw: Optional[torch.Tensor] = None if self.sel_no_mixed_types is not None: @@ -225,14 +225,8 @@ def forward( out_descriptor.append(odescriptor) if gr is not None: out_gr.append(gr) - if g2 is not None: - out_g2.append(g2) - if self.get_rcut() == descrpt.get_rcut(): - out_h2 = h2 - out_sw = sw out_descriptor = torch.cat(out_descriptor, dim=-1) out_gr = torch.cat(out_gr, dim=-2) if out_gr else None - out_g2 = torch.cat(out_g2, dim=-1) if out_g2 else None return out_descriptor, out_gr, out_g2, out_h2, out_sw @classmethod