Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add polar stat constant matrix calculation to PT #3426

Merged
merged 35 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
97548d3
feat: add constant_matrix calc
anyangml Mar 6, 2024
1ae79a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
2da5a35
feat: add output
anyangml Mar 6, 2024
fa2dc20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
abb3ca4
fix: typo
anyangml Mar 6, 2024
2503147
fix UTs
anyangml Mar 7, 2024
38969ff
fix: reuse out_stat
anyangml Mar 7, 2024
1734875
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
b261abc
fix: atomic stat
anyangml Mar 7, 2024
b81ec36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
ee28334
feat: add UTs
anyangml Mar 7, 2024
a742f49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
590aeba
fix: precommit
anyangml Mar 7, 2024
7748723
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
254c198
fix: precommit
anyangml Mar 7, 2024
8699c28
fix: UTs
anyangml Mar 7, 2024
4d22a5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
62e0386
fix UTs
anyangml Mar 7, 2024
78f1483
fix: CUDA
anyangml Mar 7, 2024
b9090c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
cd3ac47
fix: serialize
anyangml Mar 8, 2024
f416c99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
71ae5e3
Merge branch 'devel' into feat/polar_stat
anyangml Mar 8, 2024
cacbfb6
chore: bump version
anyangml Mar 8, 2024
4e8f8fe
fix: refacotr version check
anyangml Mar 8, 2024
3de4709
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
831158c
chore: refactor
anyangml Mar 8, 2024
543fb6e
Merge branch 'devel' into feat/polar_stat
anyangml Mar 8, 2024
5b3cb4a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
2ffca0b
fix: precommit
anyangml Mar 8, 2024
4f63e95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
e4086fb
fix: precommit
anyangml Mar 8, 2024
edee7e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
042c86b
Merge branch 'devel' into feat/polar_stat
anyangml Mar 11, 2024
9ad2c97
fix: typo
anyangml Mar 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Dict,
Expand All @@ -19,6 +20,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .general_fitting import (
GeneralFitting,
Expand Down Expand Up @@ -153,6 +157,12 @@ def serialize(self) -> dict:
data["c_differentiable"] = self.c_differentiable
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

def output_def(self):
return FittingOutputDef(
[
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from deepmd.dpmodel.fitting.general_fitting import (
GeneralFitting,
)
from deepmd.utils.version import (
check_version_compatibility,
)


@InvarFitting.register("ener")
Expand Down Expand Up @@ -69,6 +72,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)
Expand Down
4 changes: 0 additions & 4 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
FittingNet,
NetworkCollection,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .base_fitting import (
BaseFitting,
Expand Down Expand Up @@ -256,7 +253,6 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
Expand Down
10 changes: 10 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Dict,
Expand All @@ -16,6 +17,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .general_fitting import (
GeneralFitting,
Expand Down Expand Up @@ -169,6 +173,12 @@ def serialize(self) -> dict:
data["atom_ener"] = self.atom_ener
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

def _net_out_dim(self):
"""Set the FittingNet output dim."""
return self.dim_out
Expand Down
35 changes: 35 additions & 0 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Dict,
Expand All @@ -22,6 +23,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .general_fitting import (
GeneralFitting,
Expand Down Expand Up @@ -139,6 +143,7 @@
ntypes, 1
)
self.shift_diag = shift_diag
self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION)
super().__init__(
var_name=var_name,
ntypes=ntypes,
Expand Down Expand Up @@ -168,15 +173,36 @@
else self.embedding_width * self.embedding_width
)

def __setitem__(self, key, value):
if key in ["constant_matrix"]:
self.constant_matrix = value
else:
super().__setitem__(key, value)

def __getitem__(self, key):
Fixed Show fixed Hide fixed
if key in ["constant_matrix"]:
return self.constant_matrix

Check warning on line 184 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L183-L184

Added lines #L183 - L184 were not covered by tests
else:
return super().__getitem__(key)

Check warning on line 186 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L186

Added line #L186 was not covered by tests

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "polar"
data["@version"] = 2
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["fit_diag"] = self.fit_diag
data["shift_diag"] = self.shift_diag
data["@variables"]["scale"] = self.scale
data["@variables"]["constant_matrix"] = self.constant_matrix
anyangml marked this conversation as resolved.
Show resolved Hide resolved
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
return super().deserialize(data)

def output_def(self):
return FittingOutputDef(
[
Expand Down Expand Up @@ -246,4 +272,13 @@
"bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out
) # (nframes * nloc, 3, 3)
out = out.reshape(nframes, nloc, 3, 3)
if self.shift_diag:
bias = self.constant_matrix[atype]
# (nframes, nloc, 1)
bias = np.expand_dims(bias, axis=-1) * self.scale[atype]
eye = np.eye(3)
eye = np.tile(eye, (nframes, nloc, 1, 1))
# (nframes, nloc, 3, 3)
bias = np.expand_dims(bias, axis=-1) * eye
out = out + bias
return {self.var_name: out}
10 changes: 10 additions & 0 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from typing import (
Callable,
Expand All @@ -25,6 +26,9 @@
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -123,6 +127,12 @@ def serialize(self) -> dict:
data["c_differentiable"] = self.c_differentiable
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

def output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
Expand Down
10 changes: 10 additions & 0 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -140,6 +143,12 @@ def serialize(self) -> dict:
data["atom_ener"] = self.atom_ener
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
Expand Down Expand Up @@ -241,6 +250,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)
Expand Down
4 changes: 0 additions & 4 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@
from deepmd.utils.finetune import (
change_energy_bias_lower,
)
from deepmd.utils.version import (
check_version_compatibility,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -371,7 +368,6 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
variables = data.pop("@variables")
nets = data.pop("nets")
obj = cls(**data)
Expand Down
Loading