-
Notifications
You must be signed in to change notification settings - Fork 523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat : pt: support property fitting #3488
Closed
Closed
Changes from 33 commits
Commits
Show all changes
46 commits
Select commit
Hold shift + click to select a range
d1d0a0a
3.16 update:support property fitting(only zero bias and mean pooling)…
Chengqian-Zhang b9a7be4
3.18 update
Chengqian-Zhang 0d58b71
Merge branch 'deepmodeling:devel' into devel
Chengqian-Zhang 3f95a82
Add DeepProperty and UT
Chengqian-Zhang 0dc11d6
Merge branch 'devel' of github.com:Chengqian-Zhang/deepmd-kit into devel
Chengqian-Zhang 786b528
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] cea7476
fix pre-commit
Chengqian-Zhang 5cd0d56
Merge branch 'devel' of github.com:Chengqian-Zhang/deepmd-kit into devel
Chengqian-Zhang 4965529
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a05cace
Add example
Chengqian-Zhang d93bfdb
Merge branch 'devel' of github.com:Chengqian-Zhang/deepmd-kit into devel
Chengqian-Zhang 3050172
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 20cee41
delete input
Chengqian-Zhang 89f6f31
resolve cof
Chengqian-Zhang 501b46d
fix push bug
Chengqian-Zhang 8c5645e
Merge branch 'deepmodeling:devel' into devel
Chengqian-Zhang c8dad8d
3.19 update
Chengqian-Zhang d5eaf30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 70674f6
recover
Chengqian-Zhang da7e8c2
Merge branch 'devel' of github.com:Chengqian-Zhang/deepmd-kit into devel
Chengqian-Zhang 55c5f96
recover
Chengqian-Zhang b645176
delete denoise
Chengqian-Zhang 03cbaba
delete denoise
Chengqian-Zhang 196eb0d
delete denoise argcheck
Chengqian-Zhang e137ad5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8afc64f
delete denoise file
Chengqian-Zhang 00c8ef2
Merge branch 'devel' of github.com:Chengqian-Zhang/deepmd-kit into devel
Chengqian-Zhang f334119
delete kwargs in property head
Chengqian-Zhang 479b106
fix pre-commit
Chengqian-Zhang 63649db
Merge branch 'devel' into devel
Chengqian-Zhang 5cffff9
task_num->task_dim
Chengqian-Zhang a364947
Merge branch 'devel' of github.com:Chengqian-Zhang/deepmd-kit into devel
Chengqian-Zhang 5bda82e
delete loss
Chengqian-Zhang b00b678
Merge branch 'deepmodeling:devel' into devel
Chengqian-Zhang 00ec256
Add property loss
Chengqian-Zhang 6396fb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 698ef21
fix eval
Chengqian-Zhang ff61d6b
Merge branch 'devel' of github.com:Chengqian-Zhang/deepmd-kit into devel
Chengqian-Zhang 5fafabb
resolve conversation
Chengqian-Zhang c76e23e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 97bd86f
Add example to tests
Chengqian-Zhang af6a7fe
Merge branch 'devel' of github.com:Chengqian-Zhang/deepmd-kit into devel
Chengqian-Zhang cbb9c4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 09d4de1
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
Chengqian-Zhang 1818271
fix bug of loss_func
Chengqian-Zhang 34ea4d8
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
Chengqian-Zhang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import copy | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
Callable, | ||
List, | ||
Optional, | ||
Union, | ||
) | ||
|
||
from deepmd.dpmodel.common import ( | ||
DEFAULT_PRECISION, | ||
) | ||
from deepmd.dpmodel.fitting.invar_fitting import ( | ||
InvarFitting, | ||
) | ||
from deepmd.dpmodel.output_def import ( | ||
FittingOutputDef, | ||
OutputVariableDef, | ||
) | ||
from deepmd.utils.path import ( | ||
DPPath, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from deepmd.dpmodel.fitting.general_fitting import ( | ||
GeneralFitting, | ||
) | ||
|
||
from deepmd.utils.version import ( | ||
check_version_compatibility, | ||
) | ||
|
||
|
||
@InvarFitting.register("property") | ||
class PropertyFittingNet(InvarFitting): | ||
def __init__( | ||
self, | ||
ntypes: int, | ||
dim_descrpt: int, | ||
task_dim: int = 1, | ||
neuron: List[int] = [128, 128, 128], | ||
resnet_dt: bool = True, | ||
numb_fparam: int = 0, | ||
numb_aparam: int = 0, | ||
rcond: Optional[float] = None, | ||
tot_ener_zero: bool = False, | ||
trainable: Optional[List[bool]] = None, | ||
atom_ener: Optional[List[float]] = None, | ||
activation_function: str = "tanh", | ||
precision: str = DEFAULT_PRECISION, | ||
layer_name: Optional[List[Optional[str]]] = None, | ||
use_aparam_as_mask: bool = False, | ||
spin: Any = None, | ||
mixed_types: bool = False, | ||
exclude_types: List[int] = [], | ||
# not used | ||
seed: Optional[int] = None, | ||
): | ||
self.task_dim = task_dim | ||
super().__init__( | ||
var_name="property", | ||
ntypes=ntypes, | ||
dim_descrpt=dim_descrpt, | ||
dim_out=task_dim, | ||
neuron=neuron, | ||
resnet_dt=resnet_dt, | ||
numb_fparam=numb_fparam, | ||
numb_aparam=numb_aparam, | ||
rcond=rcond, | ||
tot_ener_zero=tot_ener_zero, | ||
trainable=trainable, | ||
atom_ener=atom_ener, | ||
activation_function=activation_function, | ||
precision=precision, | ||
layer_name=layer_name, | ||
use_aparam_as_mask=use_aparam_as_mask, | ||
spin=spin, | ||
mixed_types=mixed_types, | ||
exclude_types=exclude_types, | ||
) | ||
|
||
@classmethod | ||
def deserialize(cls, data: dict) -> "GeneralFitting": | ||
data = copy.deepcopy(data) | ||
check_version_compatibility(data.pop("@version", 1), 1, 1) | ||
data.pop("var_name") | ||
data.pop("dim_out") | ||
return super().deserialize(data) | ||
|
||
def serialize(self) -> dict: | ||
"""Serialize the fitting to dict.""" | ||
return {**super().serialize(), "type": "property", "task_dim": self.task_dim} | ||
|
||
def output_def(self) -> FittingOutputDef: | ||
return FittingOutputDef( | ||
[ | ||
OutputVariableDef( | ||
self.var_name, | ||
[self.dim_out], | ||
reduciable=True, | ||
r_differentiable=False, | ||
c_differentiable=False, | ||
), | ||
] | ||
) | ||
|
||
def compute_output_stats( | ||
self, | ||
merged: Union[Callable[[], List[dict]], List[dict]], | ||
stat_file_path: Optional[DPPath] = None, | ||
): | ||
""" | ||
Compute the output statistics (e.g. energy bias) for the fitting net from packed data. | ||
|
||
Parameters | ||
---------- | ||
merged : Union[Callable[[], List[dict]], List[dict]] | ||
- List[dict]: A list of data samples from various data systems. | ||
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` | ||
originating from the `i`-th data system. | ||
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format | ||
only when needed. Since the sampling process can be slow and memory-intensive, | ||
the lazy function helps by only sampling once. | ||
stat_file_path : Optional[DPPath] | ||
The path to the stat file. | ||
|
||
""" | ||
pass | ||
|
||
# make jit happy with torch 2.0.0 | ||
exclude_types: List[int] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from typing import ( | ||
Any, | ||
Dict, | ||
List, | ||
Optional, | ||
Tuple, | ||
Union, | ||
) | ||
|
||
import numpy as np | ||
|
||
from deepmd.dpmodel.output_def import ( | ||
FittingOutputDef, | ||
ModelOutputDef, | ||
OutputVariableDef, | ||
) | ||
|
||
from .deep_eval import ( | ||
DeepEval, | ||
) | ||
|
||
|
||
class DeepProperty(DeepEval): | ||
"""Properties of structures. | ||
|
||
Parameters | ||
---------- | ||
model_file : Path | ||
The name of the frozen model file. | ||
*args : list | ||
Positional arguments. | ||
auto_batch_size : bool or int or AutoBatchSize, default: True | ||
If True, automatic batch size will be used. If int, it will be used | ||
as the initial batch size. | ||
neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional | ||
The ASE neighbor list class to produce the neighbor list. If None, the | ||
neighbor list will be built natively in the model. | ||
**kwargs : dict | ||
Keyword arguments. | ||
""" | ||
|
||
@property | ||
def output_def(self) -> ModelOutputDef: | ||
"""Get the output definition of this model.""" | ||
return ModelOutputDef( | ||
FittingOutputDef( | ||
[ | ||
OutputVariableDef( | ||
"property", | ||
shape=[-1], | ||
reduciable=True, | ||
atomic=True, | ||
), | ||
] | ||
) | ||
) | ||
|
||
@property | ||
def numb_task(self) -> int: | ||
"""Get the number of task.""" | ||
return self.get_numb_task() | ||
|
||
def eval( | ||
self, | ||
coords: np.ndarray, | ||
cells: Optional[np.ndarray], | ||
atom_types: Union[List[int], np.ndarray], | ||
atomic: bool = False, | ||
fparam: Optional[np.ndarray] = None, | ||
aparam: Optional[np.ndarray] = None, | ||
mixed_type: bool = False, | ||
**kwargs: Dict[str, Any], | ||
) -> Tuple[np.ndarray, ...]: | ||
"""Evaluate properties. If atomic is True, also return atomic property. | ||
|
||
Parameters | ||
---------- | ||
coords : np.ndarray | ||
The coordinates of the atoms, in shape (nframes, natoms, 3). | ||
cells : np.ndarray | ||
The cell vectors of the system, in shape (nframes, 9). If the system | ||
is not periodic, set it to None. | ||
atom_types : List[int] or np.ndarray | ||
The types of the atoms. If mixed_type is False, the shape is (natoms,); | ||
otherwise, the shape is (nframes, natoms). | ||
atomic : bool, optional | ||
Whether to return atomic property, by default False. | ||
fparam : np.ndarray, optional | ||
The frame parameters, by default None. | ||
aparam : np.ndarray, optional | ||
The atomic parameters, by default None. | ||
mixed_type : bool, optional | ||
Whether the atom_types is mixed type, by default False. | ||
**kwargs : Dict[str, Any] | ||
Keyword arguments. | ||
|
||
Returns | ||
------- | ||
property | ||
The properties of the system, in shape (nframes, num_tasks). | ||
""" | ||
( | ||
coords, | ||
cells, | ||
atom_types, | ||
fparam, | ||
aparam, | ||
nframes, | ||
natoms, | ||
) = self._standard_input(coords, cells, atom_types, fparam, aparam, mixed_type) | ||
results = self.deep_eval.eval( | ||
coords, | ||
cells, | ||
atom_types, | ||
atomic, | ||
fparam=fparam, | ||
aparam=aparam, | ||
**kwargs, | ||
) | ||
atomic_property = results["property"].reshape(nframes, natoms, -1) | ||
property = np.sum(atomic_property, axis=1) | ||
|
||
if atomic: | ||
return ( | ||
property, | ||
atomic_property, | ||
) | ||
else: | ||
return (property,) | ||
|
||
def get_numb_task(self) -> int: | ||
return self.deep_eval.get_numb_task() | ||
|
||
|
||
__all__ = ["DeepProperty"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Check notice
Code scanning / CodeQL
Returning tuples with varying lengths Note