Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add DOS net #3452

Merged
merged 32 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
159c4c6
feat: add dos net
anyangml Mar 12, 2024
1f8c74c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
f9b0b06
feat: add dp
anyangml Mar 12, 2024
a990122
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
5ce0e71
fix: serialize
anyangml Mar 13, 2024
9e396e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
62f4150
fix: dim_out serialize
anyangml Mar 13, 2024
40ee0f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
812d563
fix: UTs
anyangml Mar 13, 2024
407eb48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
8ae55d3
fix: UTs
anyangml Mar 13, 2024
ac5f1be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
fc9d01a
feat: add UTs
anyangml Mar 13, 2024
81a16b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
ef008d5
Merge branch 'devel' into feat/dos
anyangml Mar 13, 2024
8a7c250
feat: add training
anyangml Mar 13, 2024
e212776
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
4b985b4
fix: hack consistency UT
anyangml Mar 14, 2024
3825274
Merge branch 'devel' into feat/dos
anyangml Mar 14, 2024
a61f462
fix: remove UT hack
Mar 14, 2024
e5b6cf0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
06528f4
fix: precommit
Mar 14, 2024
6a1f995
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
d7ecbba
fix: UTs
anyangml Mar 14, 2024
3916f5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
9f3b47d
Merge branch 'devel' into feat/dos
anyangml Mar 14, 2024
71f24ff
fix: update tf UTs
anyangml Mar 14, 2024
ad1d5ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
73313f0
fix: deep test
anyangml Mar 15, 2024
5cc34df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2024
60315ee
Merge branch 'devel' into feat/dos
anyangml Mar 15, 2024
784d7b9
fix: address comments
anyangml Mar 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)
from deepmd.pt.model.task.dos import (
DOSFittingNet,
)
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
EnergyFittingNetDirect,
Expand Down Expand Up @@ -51,6 +54,7 @@
from deepmd.pt.model.model.polar_model import (
PolarModel,
)
from deepmd.pt.model.model.dos_model
Fixed Show fixed Hide fixed

if atomic_model_ is not None:
fitting = atomic_model_.fitting_net
Expand All @@ -68,6 +72,8 @@
cls = DipoleModel
elif isinstance(fitting, PolarFittingNet):
cls = PolarModel
elif isinstance(fitting, DOSFittingNet):
cls = DOSModel
# else: unknown fitting type, fall back to DPModel
return super().__new__(cls)

Expand Down
86 changes: 86 additions & 0 deletions deepmd/pt/model/task/dos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from typing import (

Check warning on line 4 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L2-L4

Added lines #L2 - L4 were not covered by tests
List,
Optional,
)

import torch

Check warning on line 9 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L9

Added line #L9 was not covered by tests

from deepmd.pt.model.task.fitting import (

Check warning on line 11 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L11

Added line #L11 was not covered by tests
InvarFitting,
)
from deepmd.pt.utils import (

Check warning on line 14 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L14

Added line #L14 was not covered by tests
env,
)
from deepmd.pt.utils.env import (

Check warning on line 17 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L17

Added line #L17 was not covered by tests
DEFAULT_PRECISION,
)
from deepmd.utils.version import (

Check warning on line 20 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L20

Added line #L20 was not covered by tests
check_version_compatibility,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE

Check warning on line 25 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L24-L25

Added lines #L24 - L25 were not covered by tests

log = logging.getLogger(__name__)

Check warning on line 27 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L27

Added line #L27 was not covered by tests


@Fitting.register("dos")
class DOSFittingNet(InvarFitting):
def __init__(

Check warning on line 32 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L30-L32

Added lines #L30 - L32 were not covered by tests
self,
ntypes: int,
dim_descrpt: int,
neuron: List[int] = [128, 128, 128],
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_dos: int = 300,
rcond: Optional[float] = None,
bias_dos: Optional[torch.Tensor] = None,
trainable: Optional[List[bool]] = None,
seed: Optional[int] = None,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
exclude_types: List[int] = [],
mixed_types: bool = True,
**kwargs,
):
super().__init__(

Check warning on line 51 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L51

Added line #L51 was not covered by tests
var_name="dos",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
dim_out=numb_dos,
neuron=neuron,
bias_atom_e=bias_dos,
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
activation_function=activation_function,
precision=precision,
mixed_types=mixed_types,
rcond=rcond,
seed=seed,
exclude_types=exclude_types,
**kwargs,
)

@classmethod
def deserialize(cls, data: dict) -> "DOSFittingNet":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)

Check warning on line 76 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L70-L76

Added lines #L70 - L76 were not covered by tests

def serialize(self) -> dict:

Check warning on line 78 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L78

Added line #L78 was not covered by tests
"""Serialize the fitting to dict."""
return {

Check warning on line 80 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L80

Added line #L80 was not covered by tests
**super().serialize(),
"type": "dos",
}

# make jit happy with torch 2.0.0
exclude_types: List[int]

Check warning on line 86 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L86

Added line #L86 was not covered by tests
Loading