compute output stat for atomic model #3642
Merged
wanghan-iapcm merged 13 commits into deepmodeling:devel from wanghan-iapcm:atom-output-stat-1 on Apr 7, 2024
Changes from 12 commits

Commits (13)
66add48  stat: add support for std. out_stat: support output var of any shape.…
6dd726d  fix bugs
9e89b25  model atomic model call output stat. fix test_finetune
75591a1  change the atomic model's init interface of the dpmodel
7ab0a4f  fix ut
ce7ec1f  fix ut
48ee272  support preset atom bias. add doc str to base atomic model
4c038d8  solve name conflict
0dae3c9  fix bugs
4714684  Update source/tests/pt/test_multitask.py (wanghan-iapcm)
1cbbd9b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
a3e6f57  add doc str (wanghan-iapcm)
2171e19  fix ut
@@ -23,6 +23,7 @@
 from deepmd.pt.utils import (
     AtomExcludeMask,
     PairExcludeMask,
+    env,
 )
 from deepmd.pt.utils.nlist import (
     extend_input_and_build_neighbor_list,
@@ -35,19 +36,88 @@ | |
) | ||
|
||
log = logging.getLogger(__name__) | ||
dtype = env.GLOBAL_PT_FLOAT_PRECISION | ||
device = env.DEVICE | ||
|
||
BaseAtomicModel_ = make_base_atomic_model(torch.Tensor) | ||
|
||
|
||
class BaseAtomicModel(BaseAtomicModel_): | ||
class BaseAtomicModel(torch.nn.Module, BaseAtomicModel_): | ||
"""The base of atomic model. | ||
|
||
Parameters | ||
---------- | ||
type_map | ||
Mapping atom type to the name (str) of the type. | ||
For example `type_map[1]` gives the name of the type 1. | ||
atom_exclude_types | ||
Exclude the atomic contribution of the given types | ||
pair_exclude_types | ||
Exclude the pair of atoms of the given types from computing the output | ||
of the atomic model. Implemented by removing the pairs from the nlist. | ||
rcond : float, optional | ||
The condition number for the regression of atomic energy. | ||
preset_out_bias : Dict[str, List[Optional[torch.Tensor]]], optional | ||
Specifying atomic energy contribution in vacuum. Given by key:value pairs. | ||
The value is a list specifying the bias. the elements can be None or np.array of output shape. | ||
For example: [None, [2.]] means type 0 is not set, type 1 is set to [2.] | ||
The `set_davg_zero` key in the descrptor should be set. | ||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
type_map: List[str], | ||
atom_exclude_types: List[int] = [], | ||
pair_exclude_types: List[Tuple[int, int]] = [], | ||
rcond: Optional[float] = None, | ||
preset_out_bias: Optional[Dict[str, torch.Tensor]] = None, | ||
[Review comment: "line 74 is different from line 60"]
    ):
-        super().__init__()
+        torch.nn.Module.__init__(self)
+        BaseAtomicModel_.__init__(self)
        self.type_map = type_map
        self.reinit_atom_exclude(atom_exclude_types)
        self.reinit_pair_exclude(pair_exclude_types)
        self.rcond = rcond
        self.preset_out_bias = preset_out_bias

    def init_out_stat(self):
        """Initialize the output bias."""
        ntypes = self.get_ntypes()
        self.bias_keys: List[str] = list(self.fitting_output_def().keys())
        self.max_out_size = max(
            [self.atomic_output_def()[kk].size for kk in self.bias_keys]
        )
        self.n_out = len(self.bias_keys)
        out_bias_data = torch.zeros(
            [self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device
        )
        out_std_data = torch.ones(
            [self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device
        )
        self.register_buffer("out_bias", out_bias_data)
        self.register_buffer("out_std", out_std_data)

    def __setitem__(self, key, value):
        if key in ["out_bias"]:
            self.out_bias = value
        elif key in ["out_std"]:
            self.out_std = value
        else:
            raise KeyError(key)

    def __getitem__(self, key):
        if key in ["out_bias"]:
            return self.out_bias
        elif key in ["out_std"]:
            return self.out_std
        else:
            raise KeyError(key)

    @torch.jit.export
    def get_type_map(self) -> List[str]:
        """Get the type map."""
        return self.type_map

    def reinit_atom_exclude(
        self,
@@ -165,6 +235,7 @@
            fparam=fparam,
            aparam=aparam,
        )
+        ret_dict = self.apply_out_stat(ret_dict, atype)

        # nf x nloc
        atom_mask = ext_atom_mask[:, :nloc].to(torch.int32)
@@ -210,9 +281,60 @@
        """
        raise NotImplementedError

    def compute_or_load_out_stat(
        self,
        merged: Union[Callable[[], List[dict]], List[dict]],
        stat_file_path: Optional[DPPath] = None,
    ):
        """
        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

        Parameters
        ----------
        merged : Union[Callable[[], List[dict]], List[dict]]
            - List[dict]: A list of data samples from various data systems.
              Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
              originating from the `i`-th data system.
            - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
              only when needed. Since the sampling process can be slow and memory-intensive,
              the lazy function helps by only sampling once.
        stat_file_path : Optional[DPPath]
            The path to the stat file.

        """
        self.change_out_bias(
            merged,
            stat_file_path=stat_file_path,
            bias_adjust_mode="set-by-statistic",
        )

    def apply_out_stat(
        self,
        ret: Dict[str, torch.Tensor],
        atype: torch.Tensor,
    ):
        """Apply the stat to each atomic output.

        The developer may override the method to define how the bias is applied
        to the atomic output of the model.

        Parameters
        ----------
        ret
            The returned dict by the forward_atomic method
        atype
            The atom types. nf x nloc

        """
        out_bias, out_std = self._fetch_out_stat(self.bias_keys)
        for kk in self.bias_keys:
            # nf x nloc x odims, out_bias: ntypes x odims
            ret[kk] = ret[kk] + out_bias[kk][atype]
        return ret

    def change_out_bias(
        self,
        sample_merged,
        stat_file_path: Optional[DPPath] = None,
        bias_adjust_mode="change-by-statistic",
    ) -> None:
        """Change the output bias according to the input data and the pretrained model.
@@ -231,22 +353,32 @@
        'change-by-statistic' : perform predictions on labels of target dataset,
            and do least square on the errors to obtain the target shift as bias.
        'set-by-statistic' : directly use the statistic output bias in the target dataset.
        stat_file_path : Optional[DPPath]
            The path to the stat file.
        """
        if bias_adjust_mode == "change-by-statistic":
-            delta_bias = compute_output_stats(
+            delta_bias, out_std = compute_output_stats(
                sample_merged,
                self.get_ntypes(),
-                keys=self.get_output_keys(),
+                keys=list(self.atomic_output_def().keys()),
                stat_file_path=stat_file_path,
                model_forward=self._get_forward_wrapper_func(),
-            )["energy"]
-            self.set_out_bias(delta_bias, add=True)
+                rcond=self.rcond,
+                preset_bias=self.preset_out_bias,
+            )
+            # self.set_out_bias(delta_bias, add=True)
+            self._store_out_stat(delta_bias, out_std, add=True)
        elif bias_adjust_mode == "set-by-statistic":
-            bias_atom = compute_output_stats(
+            bias_out, std_out = compute_output_stats(
                sample_merged,
                self.get_ntypes(),
-                keys=self.get_output_keys(),
-            )["energy"]
-            self.set_out_bias(bias_atom)
+                keys=list(self.atomic_output_def().keys()),
+                stat_file_path=stat_file_path,
+                rcond=self.rcond,
+                preset_bias=self.preset_out_bias,
+            )
+            # self.set_out_bias(bias_out)
+            self._store_out_stat(bias_out, std_out)
        else:
            raise RuntimeError("Unknown bias_adjust_mode: " + bias_adjust_mode)
@@ -279,3 +411,63 @@ | |
return {kk: vv.detach() for kk, vv in atomic_ret.items()} | ||
|
||
return model_forward | ||
|
||
def _varsize( | ||
self, | ||
shape: List[int], | ||
) -> int: | ||
output_size = 1 | ||
len_shape = len(shape) | ||
for i in range(len_shape): | ||
output_size *= shape[i] | ||
return output_size | ||
|
||
def _get_bias_index( | ||
self, | ||
kk: str, | ||
) -> int: | ||
res: List[int] = [] | ||
for i, e in enumerate(self.bias_keys): | ||
if e == kk: | ||
res.append(i) | ||
assert len(res) == 1 | ||
return res[0] | ||
|
||
def _store_out_stat( | ||
self, | ||
out_bias: Dict[str, torch.Tensor], | ||
out_std: Dict[str, torch.Tensor], | ||
add: bool = False, | ||
): | ||
ntypes = self.get_ntypes() | ||
out_bias_data = torch.clone(self.out_bias) | ||
out_std_data = torch.clone(self.out_std) | ||
for kk in out_bias.keys(): | ||
assert kk in out_std.keys() | ||
idx = self._get_bias_index(kk) | ||
size = self._varsize(self.atomic_output_def()[kk].shape) | ||
if not add: | ||
out_bias_data[idx, :, :size] = out_bias[kk].view(ntypes, size) | ||
else: | ||
out_bias_data[idx, :, :size] += out_bias[kk].view(ntypes, size) | ||
out_std_data[idx, :, :size] = out_std[kk].view(ntypes, size) | ||
self.out_bias.copy_(out_bias_data) | ||
self.out_std.copy_(out_std_data) | ||
|
||
def _fetch_out_stat( | ||
self, | ||
keys: List[str], | ||
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: | ||
ret_bias = {} | ||
ret_std = {} | ||
ntypes = self.get_ntypes() | ||
for kk in keys: | ||
idx = self._get_bias_index(kk) | ||
isize = self._varsize(self.atomic_output_def()[kk].shape) | ||
ret_bias[kk] = self.out_bias[idx, :, :isize].view( | ||
[ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005 | ||
) | ||
ret_std[kk] = self.out_std[idx, :, :isize].view( | ||
[ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005 | ||
) | ||
return ret_bias, ret_std |
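`_store_out_stat` and `_fetch_out_stat` are inverses around the padded buffer: each statistic is flattened to [ntypes, size] on the way in and viewed back to [ntypes, *shape] on the way out. A toy round trip under hypothetical shapes (a single 3x3 per-atom output):

import torch

ntypes = 2
shape = [3, 3]  # hypothetical per-atom output shape
size = 9        # product of shape, as _varsize would compute

out_buffer = torch.zeros([1, ntypes, size])  # one output key, one padded slot

stat = torch.randn([ntypes] + shape)
out_buffer[0, :, :size] = stat.view(ntypes, size)         # store: flatten
fetched = out_buffer[0, :, :size].view([ntypes] + shape)  # fetch: restore
assert torch.equal(fetched, stat)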
Check warning (Code scanning / CodeQL): Overwriting attribute in super-class or sub-class
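The warning points at the explicit double initialization in `__init__`, where each base class is initialized directly instead of through cooperative `super().__init__()`; if the two bases set attributes with the same name, whichever runs last silently wins. A minimal reproduction of the pattern with hypothetical classes:

import torch

class Statistics:
    def __init__(self):
        # Sets a plain attribute; if the name collided with one set by
        # torch.nn.Module, this later call would overwrite it, which is
        # the situation the CodeQL check warns about.
        self.bias = None

class Model(torch.nn.Module, Statistics):
    def __init__(self):
        # Same pattern as the PR: initialize each base explicitly rather
        # than relying on the MRO.
        torch.nn.Module.__init__(self)
        Statistics.__init__(self)

m = Model()
print(m.bias)  # None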