-
Notifications
You must be signed in to change notification settings - Fork 523
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Refact the DPA2 descriptor in PyTorch with clearer interface - Support residual - Remove bn - Add numpy implement <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added a new descriptor class `DescrptDPA2` implementing DPA-2 functionality for computing descriptors and representations based on input coordinates and atom types. - Expanded supported backends for DPA-2 descriptor to include DP in addition to PyTorch. - **Documentation** - Updated the supported backends information in the documentation for the DPA-2 descriptor to reflect the addition of DP backend support. - Added a reference to the model implementation and a training example link in the DPA-2 descriptor documentation. - **Tests** - Introduced test cases for the `DescrptDPA2` class in different frameworks like TensorFlow, PyTorch, and DeepMD to cover various parameters and configurations. - Validated the functionality of the `DescrptDPA2` descriptor class for deep learning models in the test case class `TestDescrptDPA2`. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Duo <[email protected]> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
7ab3040
commit 9e8f55c
Showing
35 changed files
with
6,165 additions
and
828 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import logging | ||
from abc import ( | ||
ABC, | ||
abstractmethod, | ||
) | ||
from typing import ( | ||
Callable, | ||
Dict, | ||
List, | ||
Optional, | ||
Union, | ||
) | ||
|
||
import numpy as np | ||
|
||
from deepmd.utils.env_mat_stat import ( | ||
StatItem, | ||
) | ||
from deepmd.utils.path import ( | ||
DPPath, | ||
) | ||
from deepmd.utils.plugin import ( | ||
make_plugin_registry, | ||
) | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class DescriptorBlock(ABC, make_plugin_registry("DescriptorBlock")): | ||
"""The building block of descriptor. | ||
Given the input descriptor, provide with the atomic coordinates, | ||
atomic types and neighbor list, calculate the new descriptor. | ||
""" | ||
|
||
local_cluster = False | ||
|
||
def __new__(cls, *args, **kwargs): | ||
if cls is DescriptorBlock: | ||
try: | ||
descrpt_type = kwargs["type"] | ||
except KeyError: | ||
raise KeyError("the type of DescriptorBlock should be set by `type`") | ||
cls = cls.get_class_by_type(descrpt_type) | ||
return super().__new__(cls) | ||
|
||
@abstractmethod | ||
def get_rcut(self) -> float: | ||
"""Returns the cut-off radius.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_nsel(self) -> int: | ||
"""Returns the number of selected atoms in the cut-off radius.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_sel(self) -> List[int]: | ||
"""Returns the number of selected atoms for each type.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_ntypes(self) -> int: | ||
"""Returns the number of element types.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_dim_out(self) -> int: | ||
"""Returns the output dimension.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_dim_in(self) -> int: | ||
"""Returns the input dimension.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_dim_emb(self) -> int: | ||
"""Returns the embedding dimension.""" | ||
pass | ||
|
||
def compute_input_stats( | ||
self, | ||
merged: Union[Callable[[], List[dict]], List[dict]], | ||
path: Optional[DPPath] = None, | ||
): | ||
""" | ||
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. | ||
Parameters | ||
---------- | ||
merged : Union[Callable[[], List[dict]], List[dict]] | ||
- List[dict]: A list of data samples from various data systems. | ||
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` | ||
originating from the `i`-th data system. | ||
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format | ||
only when needed. Since the sampling process can be slow and memory-intensive, | ||
the lazy function helps by only sampling once. | ||
path : Optional[DPPath] | ||
The path to the stat file. | ||
""" | ||
raise NotImplementedError | ||
|
||
def get_stats(self) -> Dict[str, StatItem]: | ||
"""Get the statistics of the descriptor.""" | ||
raise NotImplementedError | ||
|
||
def share_params(self, base_class, shared_level, resume=False): | ||
""" | ||
Share the parameters of self to the base_class with shared_level during multitask training. | ||
If not start from checkpoint (resume is False), | ||
some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def call( | ||
self, | ||
nlist: np.ndarray, | ||
extended_coord: np.ndarray, | ||
extended_atype: np.ndarray, | ||
extended_atype_embd: Optional[np.ndarray] = None, | ||
mapping: Optional[np.ndarray] = None, | ||
): | ||
"""Calculate DescriptorBlock.""" | ||
pass |
Oops, something went wrong.