Skip to content

Commit

Permalink
do not return g2, h2, sw in hybrid descriptors (#3396)
Browse files Browse the repository at this point in the history
g2, h2, and sw are heavily dependent on the neighbor list. We cannot
ensure the sub descriptors require the same neighbor list as the parent
descriptor.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Mar 3, 2024
1 parent 9c508b7 commit 13a8adf
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 14 deletions.
8 changes: 1 addition & 7 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def call(
"""
out_descriptor = []
out_gr = []
out_g2 = []
out_g2 = None
out_h2 = None
out_sw = None
if self.sel_no_mixed_types is not None:
Expand All @@ -199,15 +199,9 @@ def call(
out_descriptor.append(odescriptor)
if gr is not None:
out_gr.append(gr)
if g2 is not None:
out_g2.append(g2)
if self.get_rcut() == descrpt.get_rcut():
out_h2 = h2
out_sw = sw

out_descriptor = np.concatenate(out_descriptor, axis=-1)
out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None
out_g2 = np.concatenate(out_g2, axis=-1) if out_g2 else None
return out_descriptor, out_gr, out_g2, out_h2, out_sw

@classmethod
Expand Down
8 changes: 1 addition & 7 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def forward(
"""
out_descriptor = []
out_gr = []
out_g2 = []
out_g2: Optional[torch.Tensor] = None
out_h2: Optional[torch.Tensor] = None
out_sw: Optional[torch.Tensor] = None
if self.sel_no_mixed_types is not None:
Expand All @@ -225,14 +225,8 @@ def forward(
out_descriptor.append(odescriptor)
if gr is not None:
out_gr.append(gr)
if g2 is not None:
out_g2.append(g2)
if self.get_rcut() == descrpt.get_rcut():
out_h2 = h2
out_sw = sw
out_descriptor = torch.cat(out_descriptor, dim=-1)
out_gr = torch.cat(out_gr, dim=-2) if out_gr else None
out_g2 = torch.cat(out_g2, dim=-1) if out_g2 else None
return out_descriptor, out_gr, out_g2, out_h2, out_sw

@classmethod
Expand Down

0 comments on commit 13a8adf

Please sign in to comment.