Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add DOS net #3452

Merged
merged 32 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
159c4c6
feat: add dos net
anyangml Mar 12, 2024
1f8c74c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
f9b0b06
feat: add dp
anyangml Mar 12, 2024
a990122
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
5ce0e71
fix: serialize
anyangml Mar 13, 2024
9e396e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
62f4150
fix: dim_out serialize
anyangml Mar 13, 2024
40ee0f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
812d563
fix: UTs
anyangml Mar 13, 2024
407eb48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
8ae55d3
fix: UTs
anyangml Mar 13, 2024
ac5f1be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
fc9d01a
feat: add UTs
anyangml Mar 13, 2024
81a16b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
ef008d5
Merge branch 'devel' into feat/dos
anyangml Mar 13, 2024
8a7c250
feat: add training
anyangml Mar 13, 2024
e212776
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
4b985b4
fix: hack consistency UT
anyangml Mar 14, 2024
3825274
Merge branch 'devel' into feat/dos
anyangml Mar 14, 2024
a61f462
fix: remove UT hack
Mar 14, 2024
e5b6cf0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
06528f4
fix: precommit
Mar 14, 2024
6a1f995
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
d7ecbba
fix: UTs
anyangml Mar 14, 2024
3916f5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
9f3b47d
Merge branch 'devel' into feat/dos
anyangml Mar 14, 2024
71f24ff
fix: update tf UTs
anyangml Mar 14, 2024
ad1d5ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
73313f0
fix: deep test
anyangml Mar 15, 2024
5cc34df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2024
60315ee
Merge branch 'devel' into feat/dos
anyangml Mar 15, 2024
784d7b9
fix: address comments
anyangml Mar 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (

Check warning on line 3 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L2-L3

Added lines #L2 - L3 were not covered by tests
TYPE_CHECKING,
List,
Optional,
)

from deepmd.dpmodel.common import (

Check warning on line 9 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L9

Added line #L9 was not covered by tests
DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.invar_fitting import (

Check warning on line 12 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L12

Added line #L12 was not covered by tests
InvarFitting,
)

if TYPE_CHECKING:
from deepmd.dpmodel.fitting.general_fitting import (

Check warning on line 17 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L16-L17

Added lines #L16 - L17 were not covered by tests
GeneralFitting,
)
from deepmd.utils.version import (

Check warning on line 20 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L20

Added line #L20 was not covered by tests
check_version_compatibility,
)


@InvarFitting.register("dos")
class DOSFittingNet(InvarFitting):
def __init__(

Check warning on line 27 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L25-L27

Added lines #L25 - L27 were not covered by tests
self,
ntypes: int,
dim_descrpt: int,
neuron: List[int] = [120, 120, 120],
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_dos: int = 300,
rcond: Optional[float] = None,
trainable: Optional[List[bool]] = None,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
mixed_types: bool = False,
exclude_types: List[int] = [],
# not used
seed: Optional[int] = None,
):
super().__init__(

Check warning on line 45 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L45

Added line #L45 was not covered by tests
var_name="dos",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
dim_out=numb_dos,
neuron=neuron,
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
rcond=rcond,
trainable=trainable,
activation_function=activation_function,
precision=precision,
mixed_types=mixed_types,
exclude_types=exclude_types,
)

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name")
data.pop("dim_out")
data.pop("tot_ener_zero")
data.pop("layer_name")
data.pop("use_aparam_as_mask")
data.pop("spin")
data.pop("atom_ener")
return super().deserialize(data)

Check warning on line 73 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L62-L73

Added lines #L62 - L73 were not covered by tests

def serialize(self) -> dict:

Check warning on line 75 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L75

Added line #L75 was not covered by tests
"""Serialize the fitting to dict."""
return {

Check warning on line 77 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L77

Added line #L77 was not covered by tests
**super().serialize(),
"type": "dos",
}
80 changes: 80 additions & 0 deletions deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (

Check warning on line 2 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L2

Added line #L2 was not covered by tests
Dict,
Optional,
)

import torch

Check warning on line 7 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L7

Added line #L7 was not covered by tests

from .dp_model import (

Check warning on line 9 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L9

Added line #L9 was not covered by tests
DPModel,
)
Comment on lines +9 to +11

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.dp_model
begins an import cycle.


class DOSModel(DPModel):
model_type = "dos"

Check warning on line 15 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L14-L15

Added lines #L14 - L15 were not covered by tests

def __init__(

Check warning on line 17 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L17

Added line #L17 was not covered by tests
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)

Check warning on line 22 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L22

Added line #L22 was not covered by tests

def forward(

Check warning on line 24 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L24

Added line #L24 was not covered by tests
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
model_ret = self.forward_common(

Check warning on line 33 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L33

Added line #L33 was not covered by tests
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["atom_dos"] = model_ret["dos"]
model_predict["dos"] = model_ret["dos_redu"]

Check warning on line 44 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L41-L44

Added lines #L41 - L44 were not covered by tests

if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]

Check warning on line 47 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L46-L47

Added lines #L46 - L47 were not covered by tests
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
return model_predict

Check warning on line 51 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L49-L51

Added lines #L49 - L51 were not covered by tests

@torch.jit.export
def forward_lower(

Check warning on line 54 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L53-L54

Added lines #L53 - L54 were not covered by tests
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(

Check warning on line 64 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L64

Added line #L64 was not covered by tests
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["atom_dos"] = model_ret["dos"]
model_predict["dos"] = model_ret["energy_redu"]

Check warning on line 76 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L73-L76

Added lines #L73 - L76 were not covered by tests

else:
model_predict = model_ret
return model_predict

Check warning on line 80 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L79-L80

Added lines #L79 - L80 were not covered by tests
8 changes: 8 additions & 0 deletions deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)
from deepmd.pt.model.task.dos import (

Check warning on line 21 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L21

Added line #L21 was not covered by tests
DOSFittingNet,
)
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
EnergyFittingNetDirect,
Expand Down Expand Up @@ -45,6 +48,9 @@
from deepmd.pt.model.model.dipole_model import (
DipoleModel,
)
from deepmd.pt.model.model.dos_model import (

Check warning on line 51 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L51

Added line #L51 was not covered by tests
DOSModel,
)
Comment on lines +51 to +53

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.dos_model
begins an import cycle.
from deepmd.pt.model.model.ener_model import (
EnergyModel,
)
Expand All @@ -68,6 +74,8 @@
cls = DipoleModel
elif isinstance(fitting, PolarFittingNet):
cls = PolarModel
elif isinstance(fitting, DOSFittingNet):
cls = DOSModel

Check warning on line 78 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L77-L78

Added lines #L77 - L78 were not covered by tests
# else: unknown fitting type, fall back to DPModel
return super().__new__(cls)

Expand Down
90 changes: 90 additions & 0 deletions deepmd/pt/model/task/dos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from typing import (

Check warning on line 4 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L2-L4

Added lines #L2 - L4 were not covered by tests
List,
Optional,
)

import torch

Check warning on line 9 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L9

Added line #L9 was not covered by tests

from deepmd.pt.model.task.ener import (

Check warning on line 11 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L11

Added line #L11 was not covered by tests
InvarFitting,
)
from deepmd.pt.model.task.fitting import (

Check warning on line 14 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L14

Added line #L14 was not covered by tests
Fitting,
)
from deepmd.pt.utils import (

Check warning on line 17 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L17

Added line #L17 was not covered by tests
env,
)
from deepmd.pt.utils.env import (

Check warning on line 20 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L20

Added line #L20 was not covered by tests
DEFAULT_PRECISION,
)
from deepmd.utils.version import (

Check warning on line 23 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L23

Added line #L23 was not covered by tests
check_version_compatibility,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE

Check warning on line 28 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L27-L28

Added lines #L27 - L28 were not covered by tests

log = logging.getLogger(__name__)

Check warning on line 30 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L30

Added line #L30 was not covered by tests


@Fitting.register("dos")
class DOSFittingNet(InvarFitting):
def __init__(

Check warning on line 35 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L33-L35

Added lines #L33 - L35 were not covered by tests
self,
ntypes: int,
dim_descrpt: int,
neuron: List[int] = [128, 128, 128],
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_dos: int = 300,
rcond: Optional[float] = None,
bias_dos: Optional[torch.Tensor] = None,
trainable: Optional[List[bool]] = None,
seed: Optional[int] = None,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
exclude_types: List[int] = [],
mixed_types: bool = True,
**kwargs,
):
super().__init__(

Check warning on line 54 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L54

Added line #L54 was not covered by tests
var_name="dos",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
dim_out=numb_dos,
neuron=neuron,
bias_atom_e=bias_dos,
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
activation_function=activation_function,
precision=precision,
mixed_types=mixed_types,
rcond=rcond,
seed=seed,
exclude_types=exclude_types,
trainable=trainable,
**kwargs,
)

@classmethod
def deserialize(cls, data: dict) -> "InvarFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)

Check warning on line 80 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L74-L80

Added lines #L74 - L80 were not covered by tests

def serialize(self) -> dict:

Check warning on line 82 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L82

Added line #L82 was not covered by tests
"""Serialize the fitting to dict."""
return {

Check warning on line 84 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L84

Added line #L84 was not covered by tests
**super().serialize(),
"type": "dos",
}

# make jit happy with torch 2.0.0
exclude_types: List[int]

Check warning on line 90 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L90

Added line #L90 was not covered by tests
Loading
Loading