Skip to content

Commit

Permalink
fix(pt): fix lammps nlist sort with large sel (deepmodeling#3993)
Browse files Browse the repository at this point in the history
`nlist` was not sorted and passed into `build_multiple_neighbor_list`
when using dpa2 and large sel (larger than lammps max nnei).

Bug fixed on several error systems but a UT needed.

- [x] UT for special sel case in dpa2.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Summary by CodeRabbit

- **New Features**
- Introduced a new parameter to enhance sorting flexibility in
computation methods.
- Added methods to check if sorted neighbor lists are required,
providing improved control over model behavior.

- **Bug Fixes**
- Enhanced functionality related to neighbor list formatting for
improved performance across models.

- **Documentation**
- Updated method signatures to reflect new parameters and
functionalities, clarifying usage for end-users.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duo <[email protected]>
Co-authored-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
iProzd and njzjz authored Jul 24, 2024
1 parent 7f9300d commit a708c7a
Show file tree
Hide file tree
Showing 37 changed files with 908 additions and 9 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return self.descriptor.need_sorted_nlist_for_lower()

def forward_atomic(
self,
extended_coord: np.ndarray,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return True

def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def mixed_types(self) -> bool:
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""

@abstractmethod
def fwd(
self,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return False

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def call(
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""

@abstractmethod
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""


def extend_descrpt_stat(des, type_map, des_with_stat=None):
r"""
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return self.se_atten.need_sorted_nlist_for_lower()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down Expand Up @@ -956,6 +960,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False


class NeighborGatedAttention(NativeOP):
def __init__(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,10 @@ def has_message_passing(self) -> bool:
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return True

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return True

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def mixed_types(self) -> bool:
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""

@abstractmethod
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False


# translated by GPT and modified
def get_residual(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
32 changes: 27 additions & 5 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,12 @@ def call_lower(
"""
nframes, nall = extended_atype.shape[:2]
extended_coord = extended_coord.reshape(nframes, -1, 3)
nlist = self.format_nlist(extended_coord, extended_atype, nlist)
nlist = self.format_nlist(
extended_coord,
extended_atype,
nlist,
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
)
cc_ext, _, fp, ap, input_prec = self.input_type_cast(
extended_coord, fparam=fparam, aparam=aparam
)
Expand Down Expand Up @@ -311,6 +316,7 @@ def format_nlist(
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
extra_nlist_sort: bool = False,
):
"""Format the neighbor list.
Expand All @@ -336,6 +342,8 @@ def format_nlist(
atomic type in extended region. nf x nall
nlist
neighbor list. nf x nloc x nsel
extra_nlist_sort
whether to forcibly sort the nlist.
Returns
-------
Expand All @@ -345,7 +353,12 @@ def format_nlist(
"""
n_nf, n_nloc, n_nnei = nlist.shape
mixed_types = self.mixed_types()
ret = self._format_nlist(extended_coord, nlist, sum(self.get_sel()))
ret = self._format_nlist(
extended_coord,
nlist,
sum(self.get_sel()),
extra_nlist_sort=extra_nlist_sort,
)
if not mixed_types:
ret = nlist_distinguish_types(ret, extended_atype, self.get_sel())
return ret
Expand All @@ -355,6 +368,7 @@ def _format_nlist(
extended_coord: np.ndarray,
nlist: np.ndarray,
nnei: int,
extra_nlist_sort: bool = False,
):
n_nf, n_nloc, n_nnei = nlist.shape
extended_coord = extended_coord.reshape([n_nf, -1, 3])
Expand All @@ -370,7 +384,9 @@ def _format_nlist(
],
axis=-1,
)
elif n_nnei > nnei:

if n_nnei > nnei or extra_nlist_sort:
n_nf, n_nloc, n_nnei = nlist.shape
# make a copy before revise
m_real_nei = nlist >= 0
ret = np.where(m_real_nei, nlist, 0)
Expand All @@ -384,9 +400,11 @@ def _format_nlist(
ret = np.take_along_axis(ret, ret_mapping, axis=2)
ret = np.where(rr > rcut, -1, ret)
ret = ret[..., :nnei]
else: # n_nnei == nnei:
# copy anyway...
# not extra_nlist_sort and n_nnei <= nnei:
elif n_nnei == nnei:
ret = nlist
else:
pass
assert ret.shape[-1] == nnei
return ret

Expand Down Expand Up @@ -483,6 +501,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the model needs sorted nlist when using `forward_lower`."""
return self.atomic_model.need_sorted_nlist_for_lower()

def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return self.descriptor.need_sorted_nlist_for_lower()

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return True

def get_out_bias(self) -> torch.Tensor:
return self.out_bias

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return False

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ def forward(
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""

@abstractmethod
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""


def make_default_type_embedding(
ntypes,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return self.se_atten.need_sorted_nlist_for_lower()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ def has_message_passing(self) -> bool:
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return True

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
# the env_protection of repinit is the same as that of the repformer
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return True

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,7 @@ def get_stats(self) -> Dict[str, StatItem]:
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.sea.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return self.sea.need_sorted_nlist_for_lower()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.sea.get_env_protection()
Expand Down Expand Up @@ -712,3 +716,7 @@ def forward(
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False


class NeighborGatedAttention(nn.Module):
def __init__(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.seat.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return self.seat.need_sorted_nlist_for_lower()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.seat.get_env_protection()
Expand Down Expand Up @@ -727,3 +731,7 @@ def forward(
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False
1 change: 1 addition & 0 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def forward_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
)
if self.get_fitting_net() is not None:
model_predict = {}
Expand Down
Loading

0 comments on commit a708c7a

Please sign in to comment.