Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: support virtual atom #3469

Merged
merged 7 commits into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 57 additions & 25 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,19 @@

def atomic_output_def(self) -> FittingOutputDef:
old_def = self.fitting_output_def()
if self.atom_excl is None:
return old_def
else:
old_list = list(old_def.get_data().values())
return FittingOutputDef(
old_list # noqa:RUF005
+ [
OutputVariableDef(
name="mask",
shape=[1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
)
]
)
old_list = list(old_def.get_data().values())
return FittingOutputDef(

Check warning on line 60 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L59-L60

Added lines #L59 - L60 were not covered by tests
old_list # noqa:RUF005
+ [
OutputVariableDef(
name="mask",
shape=[1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
)
]
)

def forward_common_atomic(
self,
Expand All @@ -82,31 +79,66 @@
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
"""Common interface for atomic inference.

This method accept extended coordinates, extended atom typs, neighbor list,
and predict the atomic contribution of the fit property.

Parameters
----------
extended_coord
extended coodinates, shape: nf x (nall x 3)
extended_atype
extended atom typs, shape: nf x nall
for a type < 0 indicating the atomic is virtual.
nlist
neighbor list, shape: nf x nloc x nsel
mapping
extended to local index mapping, shape: nf x nall
fparam
frame parameters, shape: nf x dim_fparam
aparam
atomic parameter, shape: nf x nloc x dim_aparam

Returns
-------
ret_dict
dict of output atomic properties.
should implement the definition of `fitting_output_def`.
ret_dict["mask"] of shape nf x nloc will be provided.
ret_dict["mask"][ff,ii] == 1 indicating the ii-th atom of the ff-th frame is real.
ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.

"""
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
if self.pair_excl is not None:
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = np.where(pair_mask == 1, nlist, -1)

ext_atom_mask = self.make_atom_mask(extended_atype)

Check warning on line 120 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L120

Added line #L120 was not covered by tests
ret_dict = self.forward_atomic(
extended_coord,
extended_atype,
np.where(ext_atom_mask, extended_atype, 0),
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].astype(np.int32)

Check warning on line 131 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L131

Added line #L131 was not covered by tests
if self.atom_excl is not None:
atom_mask = self.atom_excl.build_type_exclude_mask(atype)
for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask
atom_mask *= self.atom_excl.build_type_exclude_mask(atype)

Check warning on line 133 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L133

Added line #L133 was not covered by tests

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
ret_dict[kk] = (

Check warning on line 137 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L135-L137

Added lines #L135 - L137 were not covered by tests
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask

Check warning on line 141 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L141

Added line #L141 was not covered by tests

return ret_dict

Expand Down
22 changes: 22 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,28 @@
def deserialize(cls, data: dict):
pass

def make_atom_mask(
self,
atype: t_tensor,
) -> t_tensor:
"""The atoms with type < 0 are treated as virutal atoms,
which serves as place-holders for multi-frame calculations
with different number of atoms in different frames.

Parameters
----------
atype
Atom types. >= 0 for real atoms <0 for virtual atoms.

Returns
-------
mask
True for real atoms and False for virutal atoms.

"""
# supposed to be supported by all backends
return atype >= 0

Check warning on line 159 in deepmd/dpmodel/atomic_model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/make_base_atomic_model.py#L159

Added line #L159 was not covered by tests

def do_grad_r(
self,
var_name: Optional[str] = None,
Expand Down
22 changes: 16 additions & 6 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

## translated from torch implemantation by chatgpt
def build_neighbor_list(
coord1: np.ndarray,
coord: np.ndarray,
atype: np.ndarray,
nloc: int,
rcut: float,
Expand All @@ -26,10 +26,11 @@

Parameters
----------
coord1 : np.ndarray
coord : np.ndarray
exptended coordinates of shape [batch_size, nall x 3]
atype : np.ndarray
extended atomic types of shape [batch_size, nall]
type < 0 the atom is treat as virtual atoms.
nloc : int
number of local atoms.
rcut : float
Expand All @@ -54,11 +55,20 @@
if distinguish_types==True and we have two types
|---- nsel[0] -----| |---- nsel[1] -----|
xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1
For virtual atoms all neighboring positions are filled with -1.

"""
batch_size = coord1.shape[0]
coord1 = coord1.reshape(batch_size, -1)
nall = coord1.shape[1] // 3
batch_size = coord.shape[0]
coord = coord.reshape(batch_size, -1)
nall = coord.shape[1] // 3

Check warning on line 63 in deepmd/dpmodel/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/nlist.py#L61-L63

Added lines #L61 - L63 were not covered by tests
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
xmax = np.max(coord) + 2.0 * rcut

Check warning on line 66 in deepmd/dpmodel/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/nlist.py#L66

Added line #L66 was not covered by tests
# nf x nall
is_vir = atype < 0
coord1 = np.where(is_vir[:, :, None], xmax, coord.reshape(-1, nall, 3)).reshape(

Check warning on line 69 in deepmd/dpmodel/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/nlist.py#L68-L69

Added lines #L68 - L69 were not covered by tests
-1, nall * 3
)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
Expand Down Expand Up @@ -88,7 +98,7 @@
axis=-1,
)
assert list(nlist.shape) == [batch_size, nloc, nsel]
nlist = np.where((rr > rcut), -1, nlist)
nlist = np.where(np.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist)

Check warning on line 101 in deepmd/dpmodel/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/nlist.py#L101

Added line #L101 was not covered by tests

if distinguish_types:
return nlist_distinguish_types(nlist, atype, sel)
Expand Down
105 changes: 80 additions & 25 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,44 @@
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

# to make jit happy...
def make_atom_mask(

Check warning on line 62 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L62

Added line #L62 was not covered by tests
self,
atype: torch.Tensor,
) -> torch.Tensor:
"""The atoms with type < 0 are treated as virutal atoms,
which serves as place-holders for multi-frame calculations
with different number of atoms in different frames.

Parameters
----------
atype
Atom types. >= 0 for real atoms <0 for virtual atoms.

Returns
-------
mask
True for real atoms and False for virutal atoms.

"""
# supposed to be supported by all backends
return atype >= 0

Check warning on line 82 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L82

Added line #L82 was not covered by tests

def atomic_output_def(self) -> FittingOutputDef:
old_def = self.fitting_output_def()
if self.atom_excl is None:
return old_def
else:
old_list = list(old_def.get_data().values())
return FittingOutputDef(
old_list # noqa:RUF005
+ [
OutputVariableDef(
name="mask",
shape=[1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
)
]
)
old_list = list(old_def.get_data().values())
return FittingOutputDef(

Check warning on line 87 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L86-L87

Added lines #L86 - L87 were not covered by tests
old_list # noqa:RUF005
+ [
OutputVariableDef(
name="mask",
shape=[1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
)
]
)

def forward_common_atomic(
self,
Expand All @@ -86,6 +106,37 @@
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""Common interface for atomic inference.

This method accept extended coordinates, extended atom typs, neighbor list,
and predict the atomic contribution of the fit property.

Parameters
----------
extended_coord
extended coodinates, shape: nf x (nall x 3)
extended_atype
extended atom typs, shape: nf x nall
for a type < 0 indicating the atomic is virtual.
nlist
neighbor list, shape: nf x nloc x nsel
mapping
extended to local index mapping, shape: nf x nall
fparam
frame parameters, shape: nf x dim_fparam
aparam
atomic parameter, shape: nf x nloc x dim_aparam

Returns
-------
ret_dict
dict of output atomic properties.
should implement the definition of `fitting_output_def`.
ret_dict["mask"] of shape nf x nloc will be provided.
ret_dict["mask"][ff,ii] == 1 indicating the ii-th atom of the ff-th frame is real.
ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.

"""
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]

Expand All @@ -94,24 +145,28 @@
# exclude neighbors in the nlist
nlist = torch.where(pair_mask == 1, nlist, -1)

ext_atom_mask = self.make_atom_mask(extended_atype)

Check warning on line 148 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L148

Added line #L148 was not covered by tests
ret_dict = self.forward_atomic(
extended_coord,
extended_atype,
torch.where(ext_atom_mask, extended_atype, 0),
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].to(torch.int32)

Check warning on line 159 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L159

Added line #L159 was not covered by tests
if self.atom_excl is not None:
atom_mask = self.atom_excl(atype)
for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask
atom_mask *= self.atom_excl(atype)

Check warning on line 161 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L161

Added line #L161 was not covered by tests

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
ret_dict[kk] = (

Check warning on line 165 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L163-L165

Added lines #L163 - L165 were not covered by tests
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).view(out_shape)
ret_dict["mask"] = atom_mask

Check warning on line 169 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L169

Added line #L169 was not covered by tests

return ret_dict

Expand Down
24 changes: 18 additions & 6 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@


def build_neighbor_list(
coord1: torch.Tensor,
coord: torch.Tensor,
atype: torch.Tensor,
nloc: int,
rcut: float,
Expand All @@ -62,10 +62,11 @@

Parameters
----------
coord1 : torch.Tensor
coord : torch.Tensor
exptended coordinates of shape [batch_size, nall x 3]
atype : torch.Tensor
extended atomic types of shape [batch_size, nall]
if type < 0 the atom is treat as virtual atoms.
nloc : int
number of local atoms.
rcut : float
Expand All @@ -90,11 +91,20 @@
if distinguish_types==True and we have two types
|---- nsel[0] -----| |---- nsel[1] -----|
xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1
For virtual atoms all neighboring positions are filled with -1.

"""
batch_size = coord1.shape[0]
coord1 = coord1.view(batch_size, -1)
nall = coord1.shape[1] // 3
batch_size = coord.shape[0]
coord = coord.view(batch_size, -1)
nall = coord.shape[1] // 3

Check warning on line 99 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L97-L99

Added lines #L97 - L99 were not covered by tests
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
xmax = torch.max(coord) + 2.0 * rcut

Check warning on line 102 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L102

Added line #L102 was not covered by tests
# nf x nall
is_vir = atype < 0
coord1 = torch.where(is_vir[:, :, None], xmax, coord.view(-1, nall, 3)).view(

Check warning on line 105 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L104-L105

Added lines #L104 - L105 were not covered by tests
-1, nall * 3
)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
Expand Down Expand Up @@ -133,7 +143,9 @@
dim=-1,
)
assert list(nlist.shape) == [batch_size, nloc, nsel]
nlist = nlist.masked_fill((rr > rcut), -1)
nlist = torch.where(

Check warning on line 146 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L146

Added line #L146 was not covered by tests
torch.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist
)

if distinguish_types:
return nlist_distinguish_types(nlist, atype, sel)
Expand Down
Loading
Loading