Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: dp and pt: implement fitting exclude types #3282

Merged
merged 4 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,12 @@
EmbeddingNet,
EnvMat,
NetworkCollection,
PairExcludeMask,
)

from .base_descriptor import (
BaseDescriptor,
)
from .exclude_mask import (
ExcludeMask,
)


class DescrptSeA(NativeOP, BaseDescriptor):
Expand Down Expand Up @@ -160,7 +158,7 @@
self.activation_function = activation_function
self.precision = precision
self.spin = spin
self.emask = ExcludeMask(self.ntypes, self.exclude_types)
self.emask = PairExcludeMask(self.ntypes, self.exclude_types)

Check warning on line 161 in deepmd/dpmodel/descriptor/se_e2_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_e2_a.py#L161

Added line #L161 was not covered by tests

in_dim = 1 # not considiering type embedding
self.embeddings = NetworkCollection(
Expand Down
11 changes: 11 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
fitting_check_output,
)
from deepmd.dpmodel.utils import (
AtomExcludeMask,
FittingNet,
NetworkCollection,
)
Expand Down Expand Up @@ -126,6 +127,7 @@
use_aparam_as_mask: bool = False,
spin: Any = None,
distinguish_types: bool = False,
exclude_types: List[int] = [],
):
# seed, uniform_seed are not included
if tot_ener_zero:
Expand Down Expand Up @@ -159,8 +161,10 @@
self.use_aparam_as_mask = use_aparam_as_mask
self.spin = spin
self.distinguish_types = distinguish_types
self.exclude_types = exclude_types

Check warning on line 164 in deepmd/dpmodel/fitting/invar_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/invar_fitting.py#L164

Added line #L164 was not covered by tests
if self.spin is not None:
raise NotImplementedError("spin is not supported")
self.emask = AtomExcludeMask(self.ntypes, exclude_types=self.exclude_types)

Check warning on line 167 in deepmd/dpmodel/fitting/invar_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/invar_fitting.py#L167

Added line #L167 was not covered by tests

# init constants
self.bias_atom_e = np.zeros([self.ntypes, self.dim_out])
Expand Down Expand Up @@ -260,6 +264,7 @@
"precision": self.precision,
"distinguish_types": self.distinguish_types,
"nets": self.nets.serialize(),
"exclude_types": self.exclude_types,
"@variables": {
"bias_atom_e": self.bias_atom_e,
"fparam_avg": self.fparam_avg,
Expand Down Expand Up @@ -370,4 +375,10 @@
outs = outs + atom_energy # Shape is [nframes, natoms[0], 1]
else:
outs = self.nets[()](xx) + self.bias_atom_e[atype]

# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)

Check warning on line 380 in deepmd/dpmodel/fitting/invar_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/invar_fitting.py#L380

Added line #L380 was not covered by tests
# nf x nloc x nod
outs = outs * exclude_mask[:, :, None]

Check warning on line 382 in deepmd/dpmodel/fitting/invar_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/invar_fitting.py#L382

Added line #L382 was not covered by tests

return {self.var_name: outs}
6 changes: 6 additions & 0 deletions deepmd/dpmodel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from .env_mat import (
EnvMat,
)
from .exclude_mask import (
AtomExcludeMask,
PairExcludeMask,
)
from .network import (
EmbeddingNet,
FittingNet,
Expand Down Expand Up @@ -53,4 +57,6 @@
"inter2phys",
"phys2inter",
"to_face_distance",
"AtomExcludeMask",
"PairExcludeMask",
]
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,54 @@
import numpy as np


class ExcludeMask:
"""Computes the atom type exclusion mask."""
class AtomExcludeMask:
"""Computes the type exclusion mask for atoms."""

def __init__(
self,
ntypes: int,
exclude_types: List[int] = [],
):
self.ntypes = ntypes
self.exclude_types = exclude_types
self.type_mask = np.array(

Check warning on line 20 in deepmd/dpmodel/utils/exclude_mask.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/exclude_mask.py#L18-L20

Added lines #L18 - L20 were not covered by tests
[1 if tt_i not in self.exclude_types else 0 for tt_i in range(ntypes)],
dtype=np.int32,
)
# (ntypes)
self.type_mask = self.type_mask.reshape([-1])

Check warning on line 25 in deepmd/dpmodel/utils/exclude_mask.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/exclude_mask.py#L25

Added line #L25 was not covered by tests

def build_type_exclude_mask(
self,
atype: np.ndarray,
):
"""Compute type exclusion mask for atoms.

Parameters
----------
atype
The extended aotm types. shape: nf x natom

Returns
-------
mask
The type exclusion mask for atoms. shape: nf x natom
Element [ff,ii] being 0 if type(ii) is excluded,
otherwise being 1.

"""
nf, natom = atype.shape
return self.type_mask[atype].reshape(nf, natom)

Check warning on line 47 in deepmd/dpmodel/utils/exclude_mask.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/exclude_mask.py#L46-L47

Added lines #L46 - L47 were not covered by tests


class PairExcludeMask:
"""Computes the type exclusion mask for atom pairs."""

def __init__(
self,
ntypes: int,
exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.ntypes = ntypes
self.exclude_types = set()
for tt in exclude_types:
Expand All @@ -41,7 +80,7 @@
nlist: np.ndarray,
atype_ext: np.ndarray,
):
"""Compute type exclusion mask.
"""Compute type exclusion mask for atom pairs.

Parameters
----------
Expand All @@ -53,7 +92,7 @@
Returns
-------
mask
The type exclusion mask of shape: nf x nloc x nnei.
The type exclusion mask for pair atoms of shape: nf x nloc x nnei.
Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded,
otherwise being 1.

Expand Down
78 changes: 0 additions & 78 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
Callable,
List,
Optional,
Set,
Tuple,
Union,
)

Expand All @@ -22,9 +20,6 @@
from deepmd.pt.utils.plugin import (
Plugin,
)
from deepmd.pt.utils.utils import (
to_torch_tensor,
)

from .base_descriptor import (
BaseDescriptor,
Expand Down Expand Up @@ -211,32 +206,6 @@ class DescriptorBlock(torch.nn.Module, ABC):
__plugins = Plugin()
local_cluster = False

def __init__(
self,
ntypes: int,
exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
_exclude_types: Set[Tuple[int, int]] = set()
for tt in exclude_types:
assert len(tt) == 2
_exclude_types.add((tt[0], tt[1]))
_exclude_types.add((tt[1], tt[0]))
# ntypes + 1 for nlist masks
self.type_mask = np.array(
[
[
1 if (tt_i, tt_j) not in _exclude_types else 0
for tt_i in range(ntypes + 1)
]
for tt_j in range(ntypes + 1)
],
dtype=np.int32,
)
# (ntypes+1 x ntypes+1)
self.type_mask = to_torch_tensor(self.type_mask).view([-1])
self.no_exclusion = len(_exclude_types) == 0

@staticmethod
def register(key: str) -> Callable:
"""Register a DescriptorBlock plugin.
Expand Down Expand Up @@ -365,53 +334,6 @@ def forward(
"""Calculate DescriptorBlock."""
pass

# may have a better place for this method...
def build_type_exclude_mask(
self,
nlist: torch.Tensor,
atype_ext: torch.Tensor,
) -> torch.Tensor:
"""Compute type exclusion mask.

Parameters
----------
nlist
The neighbor list. shape: nf x nloc x nnei
atype_ext
The extended aotm types. shape: nf x nall

Returns
-------
mask
The type exclusion mask of shape: nf x nloc x nnei.
Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded,
otherwise being 1.

"""
if self.no_exclusion:
# safely return 1 if nothing is excluded.
return torch.ones_like(nlist, dtype=torch.int32, device=nlist.device)
nf, nloc, nnei = nlist.shape
nall = atype_ext.shape[1]
# add virtual atom of type ntypes. nf x nall+1
ae = torch.cat(
[
atype_ext,
self.get_ntypes()
* torch.ones([nf, 1], dtype=atype_ext.dtype, device=atype_ext.device),
],
dim=-1,
)
type_i = atype_ext[:, :nloc].view(nf, nloc) * (self.get_ntypes() + 1)
# nf x nloc x nnei
index = torch.where(nlist == -1, nall, nlist).view(nf, nloc * nnei)
type_j = torch.gather(ae, 1, index).view(nf, nloc, nnei)
type_ij = type_i[:, :, None] + type_j
# nf x (nloc x nnei)
type_ij = type_ij.view(nf, nloc * nnei)
mask = self.type_mask[type_ij].view(nf, nloc, nnei)
return mask


def compute_std(sumv2, sumv, sumn, rcut_r):
"""Compute standard deviation."""
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
- descriptor_list: list of descriptors.
- descriptor_param: descriptor configs.
"""
super().__init__(ntypes)
super().__init__()

Check warning on line 35 in deepmd/pt/model/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/hybrid.py#L35

Added line #L35 was not covered by tests
supported_descrpt = ["se_atten", "se_uni"]
descriptor_list = []
for descriptor_param_item in list:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
whether or not add an type embedding to seq_input.
If no seq_input is given, it has no effect.
"""
super().__init__(ntypes)
super().__init__()

Check warning on line 92 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L92

Added line #L92 was not covered by tests
del type
self.epsilon = 1e-4 # protection of 1./nnei
self.rcut = rcut
Expand Down
10 changes: 6 additions & 4 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from deepmd.pt.model.network.network import (
TypeFilter,
)
from deepmd.pt.utils.exclude_mask import (

Check warning on line 41 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L41

Added line #L41 was not covered by tests
PairExcludeMask,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
)
Expand Down Expand Up @@ -272,7 +275,7 @@
- filter_neuron: Number of neurons in each hidden layers of the embedding net.
- axis_neuron: Number of columns of the sub-matrix of the embedding matrix.
"""
super().__init__(len(sel), exclude_types=exclude_types)
super().__init__()

Check warning on line 278 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L278

Added line #L278 was not covered by tests
self.rcut = rcut
self.rcut_smth = rcut_smth
self.neuron = neuron
Expand All @@ -286,6 +289,7 @@
self.old_impl = old_impl
self.exclude_types = exclude_types
self.ntypes = len(sel)
self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types)

Check warning on line 292 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L292

Added line #L292 was not covered by tests

self.sel = sel
self.sec = torch.tensor(
Expand Down Expand Up @@ -528,9 +532,7 @@
[nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE
)
# nfnl x nnei
exclude_mask = self.build_type_exclude_mask(nlist, extended_atype).view(
nfnl, -1
)
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1)

Check warning on line 535 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L535

Added line #L535 was not covered by tests
for ii, ll in enumerate(self.filter_layers.networks):
# nfnl x nt
mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
- filter_neuron: Number of neurons in each hidden layers of the embedding net.
- axis_neuron: Number of columns of the sub-matrix of the embedding matrix.
"""
super().__init__(ntypes)
super().__init__()

Check warning on line 67 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L67

Added line #L67 was not covered by tests
del type
self.rcut = rcut
self.rcut_smth = rcut_smth
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
distinguish_types: bool = False,
rcond: Optional[float] = None,
seed: Optional[int] = None,
exclude_types: List[int] = [],
**kwargs,
):
super().__init__(
Expand All @@ -106,6 +107,7 @@ def __init__(
distinguish_types=distinguish_types,
rcond=rcond,
seed=seed,
exclude_types=exclude_types,
**kwargs,
)

Expand Down
Loading
Loading