
Feat: add DOS net #3452

Merged
merged 32 commits into devel from feat/dos on Mar 15, 2024
Commits
159c4c6
feat: add dos net
anyangml Mar 12, 2024
1f8c74c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
f9b0b06
feat: add dp
anyangml Mar 12, 2024
a990122
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
5ce0e71
fix: serialize
anyangml Mar 13, 2024
9e396e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
62f4150
fix: dim_out serialize
anyangml Mar 13, 2024
40ee0f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
812d563
fix: UTs
anyangml Mar 13, 2024
407eb48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
8ae55d3
fix: UTs
anyangml Mar 13, 2024
ac5f1be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
fc9d01a
feat: add UTs
anyangml Mar 13, 2024
81a16b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
ef008d5
Merge branch 'devel' into feat/dos
anyangml Mar 13, 2024
8a7c250
feat: add training
anyangml Mar 13, 2024
e212776
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
4b985b4
fix: hack consistency UT
anyangml Mar 14, 2024
3825274
Merge branch 'devel' into feat/dos
anyangml Mar 14, 2024
a61f462
fix: remove UT hack
Mar 14, 2024
e5b6cf0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
06528f4
fix: precommit
Mar 14, 2024
6a1f995
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
d7ecbba
fix: UTs
anyangml Mar 14, 2024
3916f5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
9f3b47d
Merge branch 'devel' into feat/dos
anyangml Mar 14, 2024
71f24ff
fix: update tf UTs
anyangml Mar 14, 2024
ad1d5ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
73313f0
fix: deep test
anyangml Mar 15, 2024
5cc34df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2024
60315ee
Merge branch 'devel' into feat/dos
anyangml Mar 15, 2024
784d7b9
fix: address comments
anyangml Mar 15, 2024
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/__init__.py
@@ -2,6 +2,9 @@
 from .dipole_fitting import (
     DipoleFitting,
 )
+from .dos_fitting import (
+    DOSFittingNet,
+)
 from .ener_fitting import (
     EnergyFittingNet,
 )
@@ -21,4 +24,5 @@
     "DipoleFitting",
     "EnergyFittingNet",
     "PolarFitting",
+    "DOSFittingNet",
 ]
93 changes: 93 additions & 0 deletions deepmd/dpmodel/fitting/dos_fitting.py
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
    TYPE_CHECKING,
    List,
    Optional,
    Union,
)

import numpy as np

from deepmd.dpmodel.common import (
    DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.invar_fitting import (
    InvarFitting,
)

if TYPE_CHECKING:
    from deepmd.dpmodel.fitting.general_fitting import (
        GeneralFitting,
    )

from deepmd.utils.version import (
    check_version_compatibility,
)


@InvarFitting.register("dos")
class DOSFittingNet(InvarFitting):
    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        numb_dos: int = 300,
        neuron: List[int] = [120, 120, 120],
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        bias_dos: Optional[np.ndarray] = None,
        rcond: Optional[float] = None,
        trainable: Union[bool, List[bool]] = True,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = False,
        exclude_types: List[int] = [],
        # not used
        seed: Optional[int] = None,
    ):
        if bias_dos is not None:
            self.bias_dos = bias_dos
        else:
            self.bias_dos = np.zeros((ntypes, numb_dos), dtype=DEFAULT_PRECISION)
        super().__init__(
            var_name="dos",
            ntypes=ntypes,
            dim_descrpt=dim_descrpt,
            dim_out=numb_dos,
            neuron=neuron,
            resnet_dt=resnet_dt,
            bias_atom=bias_dos,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            rcond=rcond,
            trainable=trainable,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            exclude_types=exclude_types,
        )

    @classmethod
    def deserialize(cls, data: dict) -> "GeneralFitting":
        data = copy.deepcopy(data)
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data["numb_dos"] = data.pop("dim_out")
        data.pop("tot_ener_zero", None)
        data.pop("var_name", None)
        data.pop("layer_name", None)
        data.pop("use_aparam_as_mask", None)
        data.pop("spin", None)
        data.pop("atom_ener", None)
        return super().deserialize(data)

    def serialize(self) -> dict:
        """Serialize the fitting to dict."""
        dd = {
            **super().serialize(),
            "type": "dos",
        }
        dd["@variables"]["bias_atom_e"] = self.bias_atom_e

        return dd
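For orientation, a minimal serialize/deserialize round-trip sketch. The sizes are illustrative, and the only behavior assumed beyond the diff above is that the base class's serialize() records the output width under "dim_out", which deserialize() maps back to numb_dos (as the pop above shows):

import numpy as np
from deepmd.dpmodel.fitting import DOSFittingNet

# Illustrative sizes: 2 atom types, descriptor width 8, 100-point DOS grid.
fitting = DOSFittingNet(ntypes=2, dim_descrpt=8, numb_dos=100)

data = fitting.serialize()                 # grid size stored as "dim_out"
restored = DOSFittingNet.deserialize(data)
assert restored.serialize()["dim_out"] == 100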
9 changes: 8 additions & 1 deletion deepmd/dpmodel/fitting/general_fitting.py
@@ -40,6 +40,8 @@
         The dimension of the input descriptor.
     neuron
         Number of neurons :math:`N` in each hidden layer of the fitting net
+    bias_atom_e
+        Average energy per atom for each element.
     resnet_dt
         Time-step `dt` in the resnet construction:
         :math:`y = x + dt * \phi (Wx + b)`
@@ -85,6 +87,7 @@
         resnet_dt: bool = True,
         numb_fparam: int = 0,
         numb_aparam: int = 0,
+        bias_atom_e: Optional[np.ndarray] = None,
         rcond: Optional[float] = None,
         tot_ener_zero: bool = False,
         trainable: Optional[List[bool]] = None,
@@ -125,7 +128,11 @@

         net_dim_out = self._net_out_dim()
         # init constants
-        self.bias_atom_e = np.zeros([self.ntypes, net_dim_out])
+        if bias_atom_e is None:
+            self.bias_atom_e = np.zeros([self.ntypes, net_dim_out])
+        else:
+            assert bias_atom_e.shape == (self.ntypes, net_dim_out)
+            self.bias_atom_e = bias_atom_e
         if self.numb_fparam > 0:
             self.fparam_avg = np.zeros(self.numb_fparam)
             self.fparam_inv_std = np.ones(self.numb_fparam)
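A hedged sketch of the new shape contract, using the DOS fitting's bias_dos, which the diffs below forward through bias_atom to bias_atom_e (the sizes and values are made up):

import numpy as np
from deepmd.dpmodel.fitting import DOSFittingNet

ntypes, numb_dos = 2, 100                 # illustrative sizes
bias_dos = np.zeros((ntypes, numb_dos))   # per-type average DOS
bias_dos[1, :] = 0.5                      # fabricated values

# GeneralFitting now asserts bias_atom_e.shape == (ntypes, net_dim_out),
# so a transposed or flattened bias fails fast instead of mis-broadcasting.
fit = DOSFittingNet(
    ntypes=ntypes, dim_descrpt=8, numb_dos=numb_dos, bias_dos=bias_dos
)
assert fit.bias_atom_e.shape == (ntypes, numb_dos)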
6 changes: 5 additions & 1 deletion deepmd/dpmodel/fitting/invar_fitting.py
@@ -82,6 +82,8 @@ class InvarFitting(GeneralFitting):
         Number of atomic parameter
     rcond
         The condition number for the regression of atomic energy.
+    bias_atom
+        Bias for each element.
     tot_ener_zero
         Force the total energy to zero. Useful for the charge fitting.
     trainable
@@ -117,10 +119,11 @@ def __init__(
         resnet_dt: bool = True,
         numb_fparam: int = 0,
         numb_aparam: int = 0,
+        bias_atom: Optional[np.ndarray] = None,
         rcond: Optional[float] = None,
         tot_ener_zero: bool = False,
         trainable: Optional[List[bool]] = None,
-        atom_ener: Optional[List[float]] = [],
+        atom_ener: Optional[List[float]] = None,
         activation_function: str = "tanh",
         precision: str = DEFAULT_PRECISION,
         layer_name: Optional[List[Optional[str]]] = None,
@@ -152,6 +155,7 @@ def __init__(
             numb_fparam=numb_fparam,
             numb_aparam=numb_aparam,
             rcond=rcond,
+            bias_atom_e=bias_atom,
             tot_ener_zero=tot_ener_zero,
             trainable=trainable,
             activation_function=activation_function,
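The atom_ener default moving from [] to None is the standard fix for Python's shared-mutable-default pitfall, illustrated standalone below (toy functions, not PR code):

def append_bad(x, acc=[]):       # the [] is created once, at def time
    acc.append(x)
    return acc

def append_good(x, acc=None):    # None sentinel: a fresh list per call
    if acc is None:
        acc = []
    acc.append(x)
    return acc

assert append_bad(1) == [1]
assert append_bad(2) == [1, 2]   # state leaked from the first call
assert append_good(1) == [1]
assert append_good(2) == [2]     # no leakage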
5 changes: 5 additions & 0 deletions deepmd/infer/deep_dos.py
@@ -56,6 +56,11 @@
         )
     )

+    @property
+    def numb_dos(self) -> int:
+        """Get the number of DOS."""
+        return self.get_numb_dos()
+
     def eval(
         self,
         coords: np.ndarray,
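A hedged usage sketch of the new property (the model path is a placeholder; nothing beyond the property shown in the diff is assumed):

from deepmd.infer.deep_dos import DeepDOS

dp = DeepDOS("dos_model.pth")  # placeholder path to a trained DOS model

# numb_dos is a read-only alias for get_numb_dos(): the number of grid
# points of the predicted density of states.
ngrid = dp.numb_dos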
3 changes: 2 additions & 1 deletion deepmd/pt/model/model/__init__.py
@@ -140,7 +140,8 @@ def get_standard_model(model_params):
     fitting_net["type"] = fitting_net.get("type", "ener")
     fitting_net["ntypes"] = descriptor.get_ntypes()
     fitting_net["mixed_types"] = descriptor.mixed_types()
-    fitting_net["embedding_width"] = descriptor.get_dim_emb()
+    if fitting_net["type"] != "dos":
+        fitting_net["embedding_width"] = descriptor.get_dim_emb()
     fitting_net["dim_descrpt"] = descriptor.get_dim_out()
     grad_force = "direct" not in fitting_net["type"]
     if not grad_force:
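For context, a hedged sketch of how this branch fills in a "dos" fitting section; the values are fabricated stand-ins for the descriptor queries:

fitting_net = {
    "type": "dos",
    "numb_dos": 250,                   # illustrative grid size
    "neuron": [120, 120, 120],
}

fitting_net["ntypes"] = 2              # descriptor.get_ntypes()
fitting_net["mixed_types"] = False     # descriptor.mixed_types()
if fitting_net["type"] != "dos":
    # DOSFittingNet's constructor takes no embedding_width (only tensor
    # fittings need it), so the key is skipped for "dos".
    fitting_net["embedding_width"] = 8  # descriptor.get_dim_emb()
fitting_net["dim_descrpt"] = 64        # descriptor.get_dim_out()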
80 changes: 80 additions & 0 deletions deepmd/pt/model/model/dos_model.py
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Dict,
    Optional,
)

import torch

from .dp_model import (
    DPModel,
)


class DOSModel(DPModel):
    model_type = "dos"

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        coord,
        atype,
        box: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, torch.Tensor]:
        model_ret = self.forward_common(
            coord,
            atype,
            box,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
        if self.get_fitting_net() is not None:
            model_predict = {}
            model_predict["atom_dos"] = model_ret["dos"]
            model_predict["dos"] = model_ret["dos_redu"]

            if "mask" in model_ret:
                model_predict["mask"] = model_ret["mask"]
        else:
            model_predict = model_ret
            model_predict["updated_coord"] += coord
        return model_predict

    @torch.jit.export
    def forward_lower(
        self,
        extended_coord,
        extended_atype,
        nlist,
        mapping: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ):
        model_ret = self.forward_common_lower(
            extended_coord,
            extended_atype,
            nlist,
            mapping,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
        if self.get_fitting_net() is not None:
            model_predict = {}
            model_predict["atom_dos"] = model_ret["dos"]
            model_predict["dos"] = model_ret["dos_redu"]

        else:
            model_predict = model_ret
        return model_predict
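In effect forward() only renames the generic fitting outputs to DOS-specific keys; a toy illustration with fabricated tensors and assumed shapes:

import torch

# Stand-in for forward_common()'s return: 1 frame, 2 atoms, 100 grid points.
model_ret = {
    "dos": torch.zeros(1, 2, 100),    # per-atom DOS
    "dos_redu": torch.zeros(1, 100),  # reduced (summed) over atoms
}

model_predict = {
    "atom_dos": model_ret["dos"],     # per-atom curves
    "dos": model_ret["dos_redu"],     # total curve
}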
8 changes: 8 additions & 0 deletions deepmd/pt/model/model/dp_model.py
@@ -18,6 +18,9 @@
 from deepmd.pt.model.task.dipole import (
     DipoleFittingNet,
 )
+from deepmd.pt.model.task.dos import (
+    DOSFittingNet,
+)
 from deepmd.pt.model.task.ener import (
     EnergyFittingNet,
     EnergyFittingNetDirect,
@@ -45,6 +48,9 @@
         from deepmd.pt.model.model.dipole_model import (
             DipoleModel,
         )
+        from deepmd.pt.model.model.dos_model import (
+            DOSModel,
+        )
         from deepmd.pt.model.model.ener_model import (
             EnergyModel,
         )
@@ -68,6 +74,8 @@
             cls = DipoleModel
         elif isinstance(fitting, PolarFittingNet):
             cls = PolarModel
+        elif isinstance(fitting, DOSFittingNet):
+            cls = DOSModel
         # else: unknown fitting type, fall back to DPModel
         return super().__new__(cls)
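The __new__ dispatch distilled into a standalone pattern (toy classes, not the deepmd ones):

class FitA: ...
class FitB: ...

class Model:
    def __new__(cls, fitting=None):
        # Redirect construction of the base class to the subclass matching
        # the fitting type; unknown fittings fall through to the base.
        if cls is Model:
            if isinstance(fitting, FitA):
                cls = ModelA
            elif isinstance(fitting, FitB):
                cls = ModelB
        return super().__new__(cls)

class ModelA(Model): ...
class ModelB(Model): ...

assert type(Model(fitting=FitB())) is ModelB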