Feat: add DOS net #3452
Merged
Changes from 4 commits (32 commits in total)

Commits:
159c4c6  feat: add dos net  (anyangml)
1f8c74c  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
f9b0b06  feat: add dp  (anyangml)
a990122  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
5ce0e71  fix: serialize  (anyangml)
9e396e2  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
62f4150  fix: dim_out serialize  (anyangml)
40ee0f2  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
812d563  fix: UTs  (anyangml)
407eb48  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
8ae55d3  fix: UTs  (anyangml)
ac5f1be  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
fc9d01a  feat: add UTs  (anyangml)
81a16b1  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
ef008d5  Merge branch 'devel' into feat/dos  (anyangml)
8a7c250  feat: add training  (anyangml)
e212776  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
4b985b4  fix: hack consistency UT  (anyangml)
3825274  Merge branch 'devel' into feat/dos  (anyangml)
a61f462  fix: remove UT hack
e5b6cf0  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
06528f4  fix: precommit
6a1f995  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
d7ecbba  fix: UTs  (anyangml)
3916f5e  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
9f3b47d  Merge branch 'devel' into feat/dos  (anyangml)
71f24ff  fix: update tf UTs  (anyangml)
ad1d5ae  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
73313f0  fix: deep test  (anyangml)
5cc34df  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
60315ee  Merge branch 'devel' into feat/dos  (anyangml)
784d7b9  fix: address comments  (anyangml)
New file (80 lines added): the DOS fitting net for the dpmodel backend.
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
    TYPE_CHECKING,
    List,
    Optional,
)

from deepmd.dpmodel.common import (
    DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.invar_fitting import (
    InvarFitting,
)

if TYPE_CHECKING:
    from deepmd.dpmodel.fitting.general_fitting import (
        GeneralFitting,
    )
from deepmd.utils.version import (
    check_version_compatibility,
)


@InvarFitting.register("dos")
class DOSFittingNet(InvarFitting):
    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        neuron: List[int] = [120, 120, 120],
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        numb_dos: int = 300,
        rcond: Optional[float] = None,
        trainable: Optional[List[bool]] = None,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = False,
        exclude_types: List[int] = [],
        # not used
        seed: Optional[int] = None,
    ):
        super().__init__(
            var_name="dos",
            ntypes=ntypes,
            dim_descrpt=dim_descrpt,
            dim_out=numb_dos,
            neuron=neuron,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            rcond=rcond,
            trainable=trainable,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            exclude_types=exclude_types,
        )

    @classmethod
    def deserialize(cls, data: dict) -> "GeneralFitting":
        data = copy.deepcopy(data)
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data.pop("var_name")
        data.pop("dim_out")
        data.pop("tot_ener_zero")
        data.pop("layer_name")
        data.pop("use_aparam_as_mask")
        data.pop("spin")
        data.pop("atom_ener")
        return super().deserialize(data)

    def serialize(self) -> dict:
        """Serialize the fitting to dict."""
        return {
            **super().serialize(),
            "type": "dos",
        }
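
For orientation, a minimal sketch of how this dpmodel fitting might be exercised. The import path is an assumption (the diff does not show where the new file lives), and the round trip assumes GeneralFitting.serialize() emits the keys that the deserialize() override pops, as it does for the other dpmodel fittings; all sizes are illustrative.

# Hypothetical import path; the diff does not show the file's location.
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet

# A small DOS fitting net: 2 atom types, descriptor width 8,
# and the default 300-point DOS grid.
fitting = DOSFittingNet(ntypes=2, dim_descrpt=8, neuron=[16, 16])

# Round-trip through the dict form added here; serialize() tags the payload
# with type="dos" so the plugin registry can route it back to this class.
data = fitting.serialize()
assert data["type"] == "dos"
restored = DOSFittingNet.deserialize(data)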
New file (80 lines added): the DOSModel wrapper for the PyTorch backend.
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Dict,
    Optional,
)

import torch

from .dp_model import (
    DPModel,
)


class DOSModel(DPModel):
    model_type = "dos"

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        coord,
        atype,
        box: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, torch.Tensor]:
        model_ret = self.forward_common(
            coord,
            atype,
            box,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
        if self.get_fitting_net() is not None:
            model_predict = {}
            model_predict["atom_dos"] = model_ret["dos"]
            model_predict["dos"] = model_ret["dos_redu"]

            if "mask" in model_ret:
                model_predict["mask"] = model_ret["mask"]
        else:
            model_predict = model_ret
            model_predict["updated_coord"] += coord
        return model_predict

    @torch.jit.export
    def forward_lower(
        self,
        extended_coord,
        extended_atype,
        nlist,
        mapping: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ):
        model_ret = self.forward_common_lower(
            extended_coord,
            extended_atype,
            nlist,
            mapping,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
        if self.get_fitting_net() is not None:
            model_predict = {}
            model_predict["atom_dos"] = model_ret["dos"]
model_predict["dos"] = model_ret["energy_redu"] | ||

        else:
            model_predict = model_ret
        return model_predict
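
For orientation, a self-contained sketch of the key remapping that forward() performs. The remap_dos_output helper and the toy tensors are purely illustrative and not part of the PR; the dictionary below stands in for the output of forward_common().

from typing import Dict

import torch


def remap_dos_output(model_ret: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Mirrors the renaming in DOSModel.forward: the fitting emits the
    # per-atom DOS under "dos" and the frame-reduced DOS under "dos_redu".
    model_predict = {
        "atom_dos": model_ret["dos"],
        "dos": model_ret["dos_redu"],
    }
    if "mask" in model_ret:
        model_predict["mask"] = model_ret["mask"]
    return model_predict


# Toy shapes: 1 frame, 4 atoms, 50 DOS grid points.
fake_ret = {
    "dos": torch.rand(1, 4, 50),
    "dos_redu": torch.rand(1, 50),
}
print(remap_dos_output(fake_ret)["atom_dos"].shape)  # torch.Size([1, 4, 50])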
New file (90 lines added): the DOS fitting net for the PyTorch backend.
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from typing import (
    List,
    Optional,
)

import torch

from deepmd.pt.model.task.ener import (
    InvarFitting,
)
from deepmd.pt.model.task.fitting import (
    Fitting,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.pt.utils.env import (
    DEFAULT_PRECISION,
)
from deepmd.utils.version import (
    check_version_compatibility,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE

log = logging.getLogger(__name__)


@Fitting.register("dos")
class DOSFittingNet(InvarFitting):
    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        neuron: List[int] = [128, 128, 128],
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        numb_dos: int = 300,
        rcond: Optional[float] = None,
        bias_dos: Optional[torch.Tensor] = None,
        trainable: Optional[List[bool]] = None,
        seed: Optional[int] = None,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        exclude_types: List[int] = [],
        mixed_types: bool = True,
        **kwargs,
    ):
        super().__init__(
            var_name="dos",
            ntypes=ntypes,
            dim_descrpt=dim_descrpt,
            dim_out=numb_dos,
            neuron=neuron,
            bias_atom_e=bias_dos,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            rcond=rcond,
            seed=seed,
            exclude_types=exclude_types,
            trainable=trainable,
            **kwargs,
        )

    @classmethod
    def deserialize(cls, data: dict) -> "InvarFitting":
        data = copy.deepcopy(data)
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data.pop("var_name")
        data.pop("dim_out")
        return super().deserialize(data)

    def serialize(self) -> dict:
        """Serialize the fitting to dict."""
        return {
            **super().serialize(),
            "type": "dos",
        }

    # make jit happy with torch 2.0.0
    exclude_types: List[int]
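
With the @Fitting.register("dos") hook above and the training support added in commit 8a7c250, a training input can select the new fitting by its type key. Below is a hedged sketch of a model section written as a Python dict: the fitting_net keys mirror the constructor arguments shown above, while the type_map, the descriptor block, and every numeric value are illustrative placeholders rather than values taken from this PR.

# Illustrative "model" section for a DOS training run. Only the
# "fitting_net" keys are grounded in the DOSFittingNet constructor above;
# everything else is a placeholder.
model_config = {
    "type_map": ["O", "H"],
    "descriptor": {
        "type": "se_e2_a",  # any supported descriptor; shown as a placeholder
        "rcut": 6.0,
        "sel": [46, 92],
        "neuron": [25, 50, 100],
    },
    "fitting_net": {
        "type": "dos",  # routes to the registered DOSFittingNet
        "numb_dos": 250,  # number of grid points in the target DOS
        "neuron": [120, 120, 120],
        "resnet_dt": True,
        "seed": 1,
    },
}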
Check notice (Code scanning / CodeQL): Cyclic import.