Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add polar stat constant matrix calculation to PT #3426

Merged
merged 35 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
97548d3
feat: add constant_matrix calc
anyangml Mar 6, 2024
1ae79a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
2da5a35
feat: add output
anyangml Mar 6, 2024
fa2dc20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
abb3ca4
fix: typo
anyangml Mar 6, 2024
2503147
fix UTs
anyangml Mar 7, 2024
38969ff
fix: reuse out_stat
anyangml Mar 7, 2024
1734875
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
b261abc
fix: atomic stat
anyangml Mar 7, 2024
b81ec36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
ee28334
feat: add UTs
anyangml Mar 7, 2024
a742f49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
590aeba
fix: precommit
anyangml Mar 7, 2024
7748723
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
254c198
fix: precommit
anyangml Mar 7, 2024
8699c28
fix: UTs
anyangml Mar 7, 2024
4d22a5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
62e0386
fix UTs
anyangml Mar 7, 2024
78f1483
fix: CUDA
anyangml Mar 7, 2024
b9090c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
cd3ac47
fix: serialize
anyangml Mar 8, 2024
f416c99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
71ae5e3
Merge branch 'devel' into feat/polar_stat
anyangml Mar 8, 2024
cacbfb6
chore: bump version
anyangml Mar 8, 2024
4e8f8fe
fix: refacotr version check
anyangml Mar 8, 2024
3de4709
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
831158c
chore: refactor
anyangml Mar 8, 2024
543fb6e
Merge branch 'devel' into feat/polar_stat
anyangml Mar 8, 2024
5b3cb4a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
2ffca0b
fix: precommit
anyangml Mar 8, 2024
4f63e95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
e4086fb
fix: precommit
anyangml Mar 8, 2024
edee7e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
042c86b
Merge branch 'devel' into feat/polar_stat
anyangml Mar 11, 2024
9ad2c97
fix: typo
anyangml Mar 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Dict,
Expand All @@ -19,6 +20,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .general_fitting import (
GeneralFitting,
Expand Down Expand Up @@ -153,6 +157,12 @@
data["c_differentiable"] = self.c_differentiable
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

Check warning on line 164 in deepmd/dpmodel/fitting/dipole_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dipole_fitting.py#L162-L164

Added lines #L162 - L164 were not covered by tests

def output_def(self):
return FittingOutputDef(
[
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from deepmd.dpmodel.fitting.general_fitting import (
GeneralFitting,
)
from deepmd.utils.version import (
check_version_compatibility,
)


@InvarFitting.register("ener")
Expand Down Expand Up @@ -69,6 +72,7 @@
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)

Check warning on line 75 in deepmd/dpmodel/fitting/ener_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/ener_fitting.py#L75

Added line #L75 was not covered by tests
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)
Expand Down
4 changes: 0 additions & 4 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
FittingNet,
NetworkCollection,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .base_fitting import (
BaseFitting,
Expand Down Expand Up @@ -256,7 +253,6 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
Expand Down
10 changes: 10 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Dict,
Expand All @@ -16,6 +17,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .general_fitting import (
GeneralFitting,
Expand Down Expand Up @@ -169,6 +173,12 @@
data["atom_ener"] = self.atom_ener
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

Check warning on line 180 in deepmd/dpmodel/fitting/invar_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/invar_fitting.py#L178-L180

Added lines #L178 - L180 were not covered by tests

def _net_out_dim(self):
"""Set the FittingNet output dim."""
return self.dim_out
Expand Down
35 changes: 35 additions & 0 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Dict,
Expand All @@ -22,6 +23,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .general_fitting import (
GeneralFitting,
Expand Down Expand Up @@ -139,6 +143,7 @@
ntypes, 1
)
self.shift_diag = shift_diag
self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION)

Check warning on line 146 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L146

Added line #L146 was not covered by tests
super().__init__(
var_name=var_name,
ntypes=ntypes,
Expand Down Expand Up @@ -168,15 +173,36 @@
else self.embedding_width * self.embedding_width
)

def __setitem__(self, key, value):
if key in ["constant_matrix"]:
self.constant_matrix = value

Check warning on line 178 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L177-L178

Added lines #L177 - L178 were not covered by tests
else:
super().__setitem__(key, value)

Check warning on line 180 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L180

Added line #L180 was not covered by tests

def __getitem__(self, key):
Fixed Show fixed Hide fixed
if key in ["constant_matrix"]:
return self.constant_matrix

Check warning on line 184 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L183-L184

Added lines #L183 - L184 were not covered by tests
else:
super().__getitem__(key)

Check warning on line 186 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L186

Added line #L186 was not covered by tests

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "polar"
data["@version"] = 2

Check warning on line 191 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L191

Added line #L191 was not covered by tests
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["fit_diag"] = self.fit_diag
data["shift_diag"] = self.shift_diag

Check warning on line 195 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L195

Added line #L195 was not covered by tests
data["@variables"]["scale"] = self.scale
data["@variables"]["constant_matrix"] = self.constant_matrix

Check warning on line 197 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L197

Added line #L197 was not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
return super().deserialize(data)

Check warning on line 204 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L202-L204

Added lines #L202 - L204 were not covered by tests

def output_def(self):
return FittingOutputDef(
[
Expand Down Expand Up @@ -246,4 +272,13 @@
"bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out
) # (nframes * nloc, 3, 3)
out = out.reshape(nframes, nloc, 3, 3)
if self.shift_diag:
bias = self.constant_matrix[atype]

Check warning on line 276 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L275-L276

Added lines #L275 - L276 were not covered by tests
# (nframes, nloc, 1)
bias = np.expand_dims(bias, axis=-1) * self.scale[atype]
eye = np.eye(3)
eye = np.tile(eye, (nframes, nloc, 1, 1))

Check warning on line 280 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L278-L280

Added lines #L278 - L280 were not covered by tests
# (nframes, nloc, 3, 3)
bias = np.expand_dims(bias, axis=-1) * eye
out = out + bias

Check warning on line 283 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L282-L283

Added lines #L282 - L283 were not covered by tests
return {self.var_name: out}
10 changes: 10 additions & 0 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy

Check warning on line 2 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L2

Added line #L2 was not covered by tests
import logging
from typing import (
Callable,
Expand All @@ -25,6 +26,9 @@
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (

Check warning on line 29 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L29

Added line #L29 was not covered by tests
check_version_compatibility,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -123,6 +127,12 @@
data["c_differentiable"] = self.c_differentiable
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

Check warning on line 134 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L130-L134

Added lines #L130 - L134 were not covered by tests

def output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
Expand Down
10 changes: 10 additions & 0 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (

Check warning on line 39 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L39

Added line #L39 was not covered by tests
check_version_compatibility,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -140,6 +143,12 @@
data["atom_ener"] = self.atom_ener
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

Check warning on line 150 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L146-L150

Added lines #L146 - L150 were not covered by tests

def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
Expand Down Expand Up @@ -241,6 +250,7 @@
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)

Check warning on line 253 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L253

Added line #L253 was not covered by tests
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)
Expand Down
4 changes: 0 additions & 4 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@
from deepmd.utils.finetune import (
change_energy_bias_lower,
)
from deepmd.utils.version import (
check_version_compatibility,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -371,7 +368,6 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
variables = data.pop("@variables")
nets = data.pop("nets")
obj = cls(**data)
Expand Down
Loading
Loading