Skip to content

Commit

Permalink
add mask in stat data
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Mar 6, 2024
1 parent 7cb8f9c commit 84d1c70
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 41 deletions.
9 changes: 9 additions & 0 deletions deepmd/dpmodel/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def __init__(
# (ntypes)
self.type_mask = self.type_mask.reshape([-1])

def get_exclude_types(self):
return self.exclude_types

def get_type_mask(self):
return self.type_mask

def build_type_exclude_mask(
self,
atype: np.ndarray,
Expand Down Expand Up @@ -75,6 +81,9 @@ def __init__(
# (ntypes+1 x ntypes+1)
self.type_mask = self.type_mask.reshape([-1])

def get_exclude_types(self):
return self.exclude_types

def build_type_exclude_mask(
self,
nlist: np.ndarray,
Expand Down
19 changes: 17 additions & 2 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import functools
import logging
from typing import (
Dict,
Expand Down Expand Up @@ -204,9 +205,23 @@ def compute_or_load_stat(
# descriptors and fitting net with different type_map
# should not share the same parameters
stat_file_path /= " ".join(self.type_map)
self.descriptor.compute_input_stats(sampled_func, stat_file_path)

@functools.lru_cache
def wrapped_sampler():
sampled = sampled_func()
if self.pair_excl is not None:
pair_exclude_types = self.pair_excl.get_exclude_types()
for sample in sampled:
sample["pair_exclude_types"] = list(pair_exclude_types)
if self.atom_excl is not None:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
if self.fitting_net is not None:
self.fitting_net.compute_output_stats(sampled_func, stat_file_path)
self.fitting_net.compute_output_stats(wrapped_sampler, stat_file_path)

@torch.jit.export
def get_dim_fparam(self) -> int:
Expand Down
14 changes: 0 additions & 14 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,6 @@ def get_stats(self) -> Dict[str, StatItem]:
"""Get the statistics of the descriptor."""
raise NotImplementedError

def get_emask(self, nlist: torch.Tensor, atype: torch.Tensor) -> torch.Tensor:
"""
Compute the pair-wise type mask for given nlist and atype, for data stat
with shape same as nlist.
1 for include and 0 for exclude.
"""
if hasattr(self, "emask"):
exclude_mask = self.emask(nlist, atype)
else:
exclude_mask = torch.ones_like(
nlist, dtype=torch.int32, device=nlist.device
)
return exclude_mask

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
Expand Down
14 changes: 0 additions & 14 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,20 +160,6 @@ def mixed_types(self) -> bool:
"""
return False

def get_emask(self, nlist: torch.Tensor, atype: torch.Tensor) -> torch.Tensor:
"""
Compute the pair-wise type mask for given nlist and atype,
with shape same as nlist.
1 for include and 0 for exclude.
"""
if hasattr(self, "emask"):
exclude_mask = self.emask(nlist, atype)
else:
exclude_mask = torch.ones_like(
nlist, dtype=torch.int32, device=nlist.device
)
return exclude_mask

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
Expand Down
13 changes: 9 additions & 4 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
GeneralFitting,
)
from deepmd.pt.utils import (
AtomExcludeMask,
env,
)
from deepmd.pt.utils.env import (
Expand Down Expand Up @@ -176,10 +177,14 @@ def compute_output_stats(
sampled = merged
energy = [item["energy"] for item in sampled]
data_mixed_type = "real_natoms_vec" in sampled[0]
if data_mixed_type:
input_natoms = [item["real_natoms_vec"] for item in sampled]
else:
input_natoms = [item["natoms"] for item in sampled]
natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec"
for system in sampled:
if "atom_exclude_types" in system:
type_mask = AtomExcludeMask(
self.ntypes, system["atom_exclude_types"]
).get_type_mask()
system[natoms_key][:, 2:] *= type_mask.unsqueeze(0)
input_natoms = [item[natoms_key] for item in sampled]
# shape: (nframes, ndim)
merged_energy = to_numpy_array(torch.cat(energy))
# shape: (nframes, ntypes)
Expand Down
21 changes: 14 additions & 7 deletions deepmd/pt/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Dict,
Iterator,
List,
Tuple,
Union,
)

import numpy as np
Expand All @@ -18,6 +20,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
)
Expand Down Expand Up @@ -73,13 +78,13 @@ def __init__(self, descriptor: "DescriptorBlock"):
) # se_r=1, se_a=4

def iter(
self, data: List[Dict[str, torch.Tensor]]
self, data: List[Dict[str, Union[torch.Tensor, List[Tuple[int, int]]]]]
) -> Iterator[Dict[str, StatItem]]:
"""Get the iterator of the environment matrix.
Parameters
----------
data : List[Dict[str, torch.Tensor]]
data : List[Dict[str, Union[torch.Tensor, List[Tuple[int, int]]]]]
The data.
Yields
Expand Down Expand Up @@ -148,9 +153,6 @@ def iter(
self.descriptor.get_nsel(),
self.last_dim,
)
exclude_mask = self.descriptor.get_emask(nlist, extended_atype).view(
coord.shape[0] * coord.shape[1], -1
)
atype = atype.view(coord.shape[0] * coord.shape[1])
# (1, nloc) eq (ntypes, 1), so broadcast is possible
# shape: (ntypes, nloc)
Expand All @@ -160,8 +162,13 @@ def iter(
self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32
).view(-1, 1),
)
# shape: (ntypes, nloc, nnei)
type_idx = torch.logical_and(type_idx.unsqueeze(-1), exclude_mask)
if "pair_exclude_types" in system:
# shape: (1, nloc, nnei)
exclude_mask = PairExcludeMask(
self.descriptor.get_ntypes(), system["pair_exclude_types"]
)(nlist, extended_atype).view(1, coord.shape[0] * coord.shape[1], -1)
# shape: (ntypes, nloc, nnei)
type_idx = torch.logical_and(type_idx.unsqueeze(-1), exclude_mask)
for type_i in range(self.descriptor.get_ntypes()):
dd = env_mat[type_idx[type_i]]
dd = dd.reshape([-1, self.last_dim]) # typen_atoms * unmasked_nnei, 4
Expand Down
9 changes: 9 additions & 0 deletions deepmd/pt/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def reinit(
)
self.type_mask = to_torch_tensor(self.type_mask).view([-1])

def get_exclude_types(self):
return self.exclude_types

def get_type_mask(self):
return self.type_mask

def forward(
self,
atype: torch.Tensor,
Expand Down Expand Up @@ -97,6 +103,9 @@ def reinit(
self.type_mask = to_torch_tensor(self.type_mask).view([-1])
self.no_exclusion = len(self._exclude_types) == 0

def get_exclude_types(self):
return self._exclude_types

# may have a better place for this method...
def forward(
self,
Expand Down

0 comments on commit 84d1c70

Please sign in to comment.