Skip to content

Commit

Permalink
Add exclude_types for data stat
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Mar 5, 2024
1 parent 8a72e01 commit 9772fa4
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 4 deletions.
6 changes: 3 additions & 3 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Callable,
List,
Optional,
Tuple,
Union,
)

Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
type_one_side: bool = True,
precision: str = "default",
trainable: bool = True,
exclude_types: Optional[List[List[int]]] = None,
exclude_types: List[Tuple[int, int]] = [],
stripped_type_embedding: bool = False,
smooth_type_embdding: bool = False,
):
Expand All @@ -73,8 +74,6 @@ def __init__(
raise NotImplementedError("type_one_side is not supported.")
if precision != "default" and precision != "float64":
raise NotImplementedError("precison is not supported.")
if exclude_types is not None and exclude_types != []:
raise NotImplementedError("exclude_types is not supported.")
if stripped_type_embedding:
raise NotImplementedError("stripped_type_embedding is not supported.")
if smooth_type_embdding:
Expand Down Expand Up @@ -103,6 +102,7 @@ def __init__(
normalize=normalize,
temperature=temperature,
return_rot=return_rot,
exclude_types=exclude_types,
env_protection=env_protection,
)
self.type_embedding = TypeEmbedNet(ntypes, tebd_dim)
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Callable,
List,
Optional,
Tuple,
Union,
)

Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(
repformer_add_type_ebd_to_seq: bool = False,
env_protection: float = 0.0,
trainable: bool = True,
exclude_types: List[Tuple[int, int]] = [],
type: Optional[
str
] = None, # work around the bad design in get_trainer and DpLoaderSet!
Expand Down Expand Up @@ -176,6 +178,9 @@ def __init__(
repformers block: concatenate the type embedding at the output.
trainable : bool
If the parameters in the descriptor are trainable.
exclude_types : List[Tuple[int, int]] = [],
The excluded pairs of types which have no interaction with each other.
For example, `[[0, 1]]` means no interaction between type 0 and type 1.
Returns
-------
Expand Down Expand Up @@ -206,6 +211,7 @@ def __init__(
tebd_input_mode="concat",
# tebd_input_mode='dot_residual_s',
set_davg_zero=repinit_set_davg_zero,
exclude_types=exclude_types,
env_protection=env_protection,
activation_function=repinit_activation,
)
Expand Down Expand Up @@ -238,6 +244,7 @@ def __init__(
set_davg_zero=repformer_set_davg_zero,
smooth=True,
add_type_ebd_to_seq=repformer_add_type_ebd_to_seq,
exclude_types=exclude_types,
env_protection=env_protection,
)
self.type_embedding = TypeEmbedNet(ntypes, tebd_dim)
Expand Down
14 changes: 14 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Dict,
List,
Optional,
Tuple,
Union,
)

Expand All @@ -24,6 +25,9 @@
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.pt.utils.exclude_mask import (

Check warning on line 28 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L28

Added line #L28 was not covered by tests
PairExcludeMask,
)
from deepmd.pt.utils.utils import (
get_activation_fn,
)
Expand Down Expand Up @@ -83,6 +87,7 @@ def __init__(
set_davg_zero: bool = True, # TODO
smooth: bool = True,
add_type_ebd_to_seq: bool = False,
exclude_types: List[Tuple[int, int]] = [],
env_protection: float = 0.0,
type: Optional[str] = None,
):
Expand Down Expand Up @@ -115,6 +120,8 @@ def __init__(
self.act = get_activation_fn(activation_function)
self.direct_dist = direct_dist
self.add_type_ebd_to_seq = add_type_ebd_to_seq
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
self.env_protection = env_protection

Check warning on line 125 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L124-L125

Added lines #L124 - L125 were not covered by tests

self.g2_embd = mylinear(1, self.g2_dim)
Expand Down Expand Up @@ -213,6 +220,13 @@ def dim_emb(self):
"""Returns the embedding dimension g2."""
return self.get_dim_emb()

def reinit_exclude(

Check warning on line 223 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L223

Added line #L223 was not covered by tests
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

Check warning on line 228 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L227-L228

Added lines #L227 - L228 were not covered by tests

def forward(
self,
nlist: torch.Tensor,
Expand Down
14 changes: 14 additions & 0 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Dict,
List,
Optional,
Tuple,
Union,
)

Expand All @@ -26,6 +27,9 @@
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.pt.utils.exclude_mask import (

Check warning on line 30 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L30

Added line #L30 was not covered by tests
PairExcludeMask,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
Expand Down Expand Up @@ -61,6 +65,7 @@ def __init__(
normalize=True,
temperature=None,
return_rot=False,
exclude_types: List[Tuple[int, int]] = [],
env_protection: float = 0.0,
type: Optional[str] = None,
):
Expand Down Expand Up @@ -108,6 +113,8 @@ def __init__(
self.split_sel = self.sel
self.nnei = sum(sel)
self.ndescrpt = self.nnei * 4
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)

Check warning on line 117 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L117

Added line #L117 was not covered by tests
self.dpa1_attention = NeighborWiseAttention(
self.attn_layer,
self.nnei,
Expand Down Expand Up @@ -251,6 +258,13 @@ def get_stats(self) -> Dict[str, StatItem]:
)
return self.stats

def reinit_exclude(

Check warning on line 261 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L261

Added line #L261 was not covered by tests
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

Check warning on line 266 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L265-L266

Added lines #L265 - L266 were not covered by tests

def forward(
self,
nlist: torch.Tensor,
Expand Down
10 changes: 9 additions & 1 deletion deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def __init__(
self.old_impl = False # this does not support old implementation.
self.exclude_types = exclude_types
self.ntypes = len(sel)
self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types)
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
self.env_protection = env_protection

Check warning on line 87 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L86-L87

Added lines #L86 - L87 were not covered by tests

self.sel = sel
Expand Down Expand Up @@ -255,6 +256,13 @@ def __getitem__(self, key):
else:
raise KeyError(key)

def reinit_exclude(

Check warning on line 259 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L259

Added line #L259 was not covered by tests
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

Check warning on line 264 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L263-L264

Added lines #L263 - L264 were not covered by tests

def forward(
self,
coord_ext: torch.Tensor,
Expand Down

0 comments on commit 9772fa4

Please sign in to comment.