Skip to content

Commit

Permalink
fix bug in test. add doc for forward_common_atomic
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Mar 15, 2024
1 parent a15875d commit 8d05a4e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 2 deletions.
31 changes: 31 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,37 @@ def forward_common_atomic(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
"""Common interface for atomic inference.
This method accept extended coordinates, extended atom typs, neighbor list,
and predict the atomic contribution of the fit property.
Parameters
----------
extended_coord
extended coodinates, shape: nf x (nall x 3)
extended_atype
extended atom typs, shape: nf x nall
for a type < 0 indicating the atomic is virtual.
nlist
neighbor list, shape: nf x nloc x nsel
mapping
extended to local index mapping, shape: nf x nall
fparam
frame parameters, shape: nf x dim_fparam
aparam
atomic parameter, shape: nf x nloc x dim_aparam
Returns
-------
ret_dict
dict of output atomic properties.
should implement the definition of `fitting_output_def`.
ret_dit["mask"] of shape nf x nloc will be provided.
ret_dit["mask"][ff,ii] == 1 indicating the ii-th atom of the ff-th frame is real.
ret_dit["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.
"""
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
if self.pair_excl is not None:
Expand Down
28 changes: 27 additions & 1 deletion deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,33 @@ def forward_common_atomic(
) -> Dict[str, torch.Tensor]:
"""Common interface for atomic inference.
This method accept
This method accept extended coordinates, extended atom typs, neighbor list,
and predict the atomic contribution of the fit property.
Parameters
----------
extended_coord
extended coodinates, shape: nf x (nall x 3)
extended_atype
extended atom typs, shape: nf x nall
for a type < 0 indicating the atomic is virtual.
nlist
neighbor list, shape: nf x nloc x nsel
mapping
extended to local index mapping, shape: nf x nall
fparam
frame parameters, shape: nf x dim_fparam
aparam
atomic parameter, shape: nf x nloc x dim_aparam
Returns
-------
ret_dict
dict of output atomic properties.
should implement the definition of `fitting_output_def`.
ret_dit["mask"] of shape nf x nloc will be provided.
ret_dit["mask"][ff,ii] == 1 indicating the ii-th atom of the ff-th frame is real.
ret_dit["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.
"""
_, nloc, _ = nlist.shape
Expand Down
2 changes: 1 addition & 1 deletion source/tests/common/dpmodel/test_nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def setUp(self):
self.nsel = [10, 10]
self.ref_nlist = np.array(
[
[-1] * sum(self.sel),
[-1] * sum(self.nsel),
[1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1],
[1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 2, 2, -1, -1, -1, -1],
]
Expand Down

0 comments on commit 8d05a4e

Please sign in to comment.