From 2ee8a3b1e7bce6a4fcfbccaf3cd16c6e490e7ba0 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 11 Mar 2024 12:00:31 +0800 Subject: [PATCH] Feat: Add polar stat constant matrix calculation to PT (#3426) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/dpmodel/fitting/dipole_fitting.py | 10 ++ deepmd/dpmodel/fitting/ener_fitting.py | 4 + deepmd/dpmodel/fitting/general_fitting.py | 4 - deepmd/dpmodel/fitting/invar_fitting.py | 10 ++ .../dpmodel/fitting/polarizability_fitting.py | 35 ++++++ deepmd/pt/model/task/dipole.py | 10 ++ deepmd/pt/model/task/ener.py | 10 ++ deepmd/pt/model/task/fitting.py | 4 - deepmd/pt/model/task/polarizability.py | 102 +++++++++++++++++- deepmd/tf/fit/polar.py | 18 ++-- source/tests/consistent/common.py | 5 + source/tests/pt/model/test_polar_stat.py | 75 +++++++++++++ source/tests/pt/test_training.py | 3 + 13 files changed, 272 insertions(+), 18 deletions(-) create mode 100644 source/tests/pt/model/test_polar_stat.py diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index e00f031549..6d6324770c 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Any, Dict, @@ -19,6 +20,9 @@ OutputVariableDef, fitting_check_output, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .general_fitting import ( GeneralFitting, @@ -153,6 +157,12 @@ def serialize(self) -> dict: data["c_differentiable"] = self.c_differentiable return data + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + return super().deserialize(data) + def output_def(self): return FittingOutputDef( [ diff --git a/deepmd/dpmodel/fitting/ener_fitting.py b/deepmd/dpmodel/fitting/ener_fitting.py index 3a0e9909b9..7f83f1e886 100644 --- a/deepmd/dpmodel/fitting/ener_fitting.py +++ b/deepmd/dpmodel/fitting/ener_fitting.py @@ -18,6 +18,9 @@ from deepmd.dpmodel.fitting.general_fitting import ( GeneralFitting, ) +from deepmd.utils.version import ( + check_version_compatibility, +) @InvarFitting.register("ener") @@ -69,6 +72,7 @@ def __init__( @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("var_name") data.pop("dim_out") return super().deserialize(data) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 01bf107c63..e9dddae2de 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -21,9 +21,6 @@ FittingNet, NetworkCollection, ) -from deepmd.utils.version import ( - check_version_compatibility, -) from .base_fitting import ( BaseFitting, @@ -256,7 +253,6 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("@class") data.pop("type") variables = data.pop("@variables") diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index fd556ff074..e795953a75 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Any, Dict, @@ -16,6 +17,9 @@ OutputVariableDef, fitting_check_output, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .general_fitting import ( GeneralFitting, @@ -169,6 +173,12 @@ def serialize(self) -> dict: data["atom_ener"] = self.atom_ener return data + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + return super().deserialize(data) + def _net_out_dim(self): """Set the FittingNet output dim.""" return self.dim_out diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 4f7c33b9a8..5d75037137 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Any, Dict, @@ -22,6 +23,9 @@ OutputVariableDef, fitting_check_output, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .general_fitting import ( GeneralFitting, @@ -139,6 +143,7 @@ def __init__( ntypes, 1 ) self.shift_diag = shift_diag + self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION) super().__init__( var_name=var_name, ntypes=ntypes, @@ -168,15 +173,36 @@ def _net_out_dim(self): else self.embedding_width * self.embedding_width ) + def __setitem__(self, key, value): + if key in ["constant_matrix"]: + self.constant_matrix = value + else: + super().__setitem__(key, value) + + def __getitem__(self, key): + if key in ["constant_matrix"]: + return self.constant_matrix + else: + return super().__getitem__(key) + def serialize(self) -> dict: data = super().serialize() data["type"] = "polar" + data["@version"] = 2 data["embedding_width"] = self.embedding_width data["old_impl"] = self.old_impl data["fit_diag"] = self.fit_diag + data["shift_diag"] = self.shift_diag data["@variables"]["scale"] = self.scale + data["@variables"]["constant_matrix"] = self.constant_matrix return data + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 2, 1) + return super().deserialize(data) + def output_def(self): return FittingOutputDef( [ @@ -246,4 +272,13 @@ def call( "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out ) # (nframes * nloc, 3, 3) out = out.reshape(nframes, nloc, 3, 3) + if self.shift_diag: + bias = self.constant_matrix[atype] + # (nframes, nloc, 1) + bias = np.expand_dims(bias, axis=-1) * self.scale[atype] + eye = np.eye(3) + eye = np.tile(eye, (nframes, nloc, 1, 1)) + # (nframes, nloc, 3, 3) + bias = np.expand_dims(bias, axis=-1) * eye + out = out + bias return {self.var_name: out} diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 21372888d6..b8892c2d95 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import logging from typing import ( Callable, @@ -25,6 +26,9 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) log = logging.getLogger(__name__) @@ -123,6 +127,12 @@ def serialize(self) -> dict: data["c_differentiable"] = self.c_differentiable return data + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + return super().deserialize(data) + def output_def(self) -> FittingOutputDef: return FittingOutputDef( [ diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index b593ddc3cc..55ffd8c650 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -36,6 +36,9 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE @@ -140,6 +143,12 @@ def serialize(self) -> dict: data["atom_ener"] = self.atom_ener return data + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + return super().deserialize(data) + def compute_output_stats( self, merged: Union[Callable[[], List[dict]], List[dict]], @@ -241,6 +250,7 @@ def __init__( @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("var_name") data.pop("dim_out") return super().deserialize(data) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 09f8563bfb..48ffe34084 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -49,9 +49,6 @@ from deepmd.utils.finetune import ( change_energy_bias_lower, ) -from deepmd.utils.version import ( - check_version_compatibility, -) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE @@ -371,7 +368,6 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) variables = data.pop("@variables") nets = data.pop("nets") obj = cls(**data) diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index fa4f6d7f37..eb6ccc2b7d 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import logging from typing import ( Callable, @@ -7,6 +8,7 @@ Union, ) +import numpy as np import torch from deepmd.dpmodel import ( @@ -25,9 +27,16 @@ from deepmd.pt.utils.utils import ( to_numpy_array, ) +from deepmd.utils.out_stat import ( + compute_stats_from_atomic, + compute_stats_from_redu, +) from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) log = logging.getLogger(__name__) @@ -114,6 +123,9 @@ def __init__( self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE ).view(ntypes, 1) self.shift_diag = shift_diag + self.constant_matrix = torch.zeros( + ntypes, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) super().__init__( var_name=kwargs.pop("var_name", "polar"), ntypes=ntypes, @@ -140,16 +152,36 @@ def _net_out_dim(self): else self.embedding_width * self.embedding_width ) + def __setitem__(self, key, value): + if key in ["constant_matrix"]: + self.constant_matrix = value + else: + super().__setitem__(key, value) + + def __getitem__(self, key): + if key in ["constant_matrix"]: + return self.constant_matrix + else: + return super().__getitem__(key) + def serialize(self) -> dict: data = super().serialize() data["type"] = "polar" + data["@version"] = 2 data["embedding_width"] = self.embedding_width data["old_impl"] = self.old_impl data["fit_diag"] = self.fit_diag - data["fit_diag"] = self.fit_diag + data["shift_diag"] = self.shift_diag data["@variables"]["scale"] = to_numpy_array(self.scale) + data["@variables"]["constant_matrix"] = to_numpy_array(self.constant_matrix) return data + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 2, 1) + return super().deserialize(data) + def output_def(self) -> FittingOutputDef: return FittingOutputDef( [ @@ -167,7 +199,7 @@ def compute_output_stats( self, merged: Union[Callable[[], List[dict]], List[dict]], stat_file_path: Optional[DPPath] = None, - ): + ) -> None: """ Compute the output statistics (e.g. energy bias) for the fitting net from packed data. @@ -184,7 +216,60 @@ def compute_output_stats( The path to the stat file. """ - pass + if self.shift_diag: + if stat_file_path is not None: + stat_file_path = stat_file_path / "constant_matrix" + if stat_file_path is not None and stat_file_path.is_file(): + constant_matrix = stat_file_path.load_numpy() + else: + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + + sys_constant_matrix = [] + for sys in range(len(sampled)): + nframs = sampled[sys]["type"].shape[0] + + if sampled[sys]["find_atomic_polarizability"] > 0.0: + sys_atom_polar = compute_stats_from_atomic( + sampled[sys]["atomic_polarizability"].numpy(force=True), + sampled[sys]["type"].numpy(force=True), + )[0] + else: + if not sampled[sys]["find_polarizability"] > 0.0: + continue + sys_type_count = np.zeros( + (nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION + ) + for itype in range(self.ntypes): + type_mask = sampled[sys]["type"] == itype + sys_type_count[:, itype] = type_mask.sum(dim=1).numpy( + force=True + ) + + sys_bias_redu = sampled[sys]["polarizability"].numpy(force=True) + + sys_atom_polar = compute_stats_from_redu( + sys_bias_redu, sys_type_count, rcond=self.rcond + )[0] + cur_constant_matrix = np.zeros( + self.ntypes, dtype=env.GLOBAL_NP_FLOAT_PRECISION + ) + + for itype in range(self.ntypes): + cur_constant_matrix[itype] = np.mean( + np.diagonal(sys_atom_polar[itype].reshape(3, 3)) + ) + sys_constant_matrix.append(cur_constant_matrix) + constant_matrix = np.stack(sys_constant_matrix).mean(axis=0) + + # handle nan values. + constant_matrix = np.nan_to_num(constant_matrix) + if stat_file_path is not None: + stat_file_path.save_numpy(constant_matrix) + self.constant_matrix = torch.tensor(constant_matrix, device=env.DEVICE) def forward( self, @@ -218,5 +303,16 @@ def forward( "bim,bmj->bij", gr.transpose(1, 2), out ) # (nframes * nloc, 3, 3) out = out.view(nframes, nloc, 3, 3) + if self.shift_diag: + bias = self.constant_matrix[atype] + + # (nframes, nloc, 1) + bias = bias.unsqueeze(-1) * self.scale[atype] + + eye = torch.eye(3, device=env.DEVICE) + eye = eye.repeat(nframes, nloc, 1, 1) + # (nframes, nloc, 3, 3) + bias = bias.unsqueeze(-1) * eye + out = out + bias return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 7ac31809f3..41ea989521 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -183,6 +183,7 @@ def compute_output_stats(self, all_stat): mean_polar = np.zeros([len(self.sel_type), 9]) sys_matrix, polar_bias = [], [] for ss in range(len(all_stat["type"])): + nframes = all_stat["type"][ss].shape[0] atom_has_polar = [ w for w in all_stat["type"][ss][0] if (w in self.sel_type) ] # select atom with polar @@ -193,7 +194,7 @@ def compute_output_stats(self, all_stat): index_lis = [ index for index, w in enumerate(atom_has_polar) - if atom_has_polar[index] == self.sel_type[itype] + if w == self.sel_type[itype] ] # select index in this type sys_matrix.append(np.zeros((1, len(self.sel_type)))) @@ -201,10 +202,9 @@ def compute_output_stats(self, all_stat): polar_bias.append( np.sum( - all_stat["atomic_polarizability"][ss].reshape((-1, 9))[ - index_lis - ], - axis=0, + all_stat["atomic_polarizability"][ss][:, index_lis, :] + / nframes, + axis=(0, 1), ).reshape((1, 9)) ) else: # No atomic polar in this system, so it should have global polar @@ -228,7 +228,9 @@ def compute_output_stats(self, all_stat): sys_matrix[-1][0, itype] = len(index_lis) # add polar_bias - polar_bias.append(all_stat["polarizability"][ss].reshape((1, 9))) + polar_bias.append( + np.mean(all_stat["polarizability"][ss], axis=0).reshape((1, 9)) + ) matrix, bias = ( np.concatenate(sys_matrix, axis=0), @@ -584,7 +586,9 @@ def deserialize(cls, data: dict, suffix: str): The deserialized model """ data = data.copy() - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility( + data.pop("@version", 1), 2, 1 + ) # to allow PT version. fitting = cls(**data) fitting.fitting_net_variables = cls.deserialize_network( data["nets"], diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 622e2ed3cf..5a35ced0a1 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -257,6 +257,11 @@ def test_tf_consistent_with_ref(self): common_keys = set(data1.keys()) & set(data2.keys()) data1 = {k: data1[k] for k in common_keys} data2 = {k: data2[k] for k in common_keys} + + # not comparing version + data1.pop("@version") + data2.pop("@version") + np.testing.assert_equal(data1, data2) for rr1, rr2 in zip(ret1, ret2): np.testing.assert_allclose( diff --git a/source/tests/pt/model/test_polar_stat.py b/source/tests/pt/model/test_polar_stat.py new file mode 100644 index 0000000000..ca3b037011 --- /dev/null +++ b/source/tests/pt/model/test_polar_stat.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt.model.task.polarizability import ( + PolarFittingNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) +from deepmd.tf.fit.polar import ( + PolarFittingSeA, +) + + +class TestConsistency(unittest.TestCase): + def setUp(self) -> None: + types = torch.randint(0, 4, (1, 5), device=env.DEVICE) + types = torch.cat((types, types, types), dim=0) + types[:, -1] = 3 + ntypes = 4 + atomic_polarizability = torch.rand((3, 5, 9), device=env.DEVICE) + polarizability = torch.rand((3, 9), device=env.DEVICE) + find_polarizability = torch.rand(1, device=env.DEVICE) + find_atomic_polarizability = torch.rand(1, device=env.DEVICE) + self.sampled = [ + { + "type": types, + "find_atomic_polarizability": find_atomic_polarizability, + "atomic_polarizability": atomic_polarizability, + "polarizability": polarizability, + "find_polarizability": find_polarizability, + } + ] + self.all_stat = { + k: [v.numpy(force=True)] for d in self.sampled for k, v in d.items() + } + self.tfpolar = PolarFittingSeA( + ntypes=ntypes, + dim_descrpt=1, + embedding_width=1, + sel_type=list(range(ntypes)), + ) + self.ptpolar = PolarFittingNet( + ntypes=ntypes, + dim_descrpt=1, + embedding_width=1, + ) + + def test_atomic_consistency(self): + self.tfpolar.compute_output_stats(self.all_stat) + tfbias = self.tfpolar.constant_matrix + self.ptpolar.compute_output_stats(self.sampled) + ptbias = self.ptpolar.constant_matrix + np.testing.assert_allclose(tfbias, to_numpy_array(ptbias)) + + def test_global_consistency(self): + self.sampled[0]["find_atomic_polarizability"] = -1 + self.sampled[0]["polarizability"] = self.sampled[0][ + "atomic_polarizability" + ].sum(dim=1) + self.all_stat["find_atomic_polarizability"] = [-1] + self.all_stat["polarizability"] = [ + self.all_stat["atomic_polarizability"][0].sum(axis=1) + ] + self.tfpolar.compute_output_stats(self.all_stat) + tfbias = self.tfpolar.constant_matrix + self.ptpolar.compute_output_stats(self.sampled) + ptbias = self.ptpolar.constant_matrix + np.testing.assert_allclose(tfbias, to_numpy_array(ptbias), rtol=1e-5, atol=1e-5) diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index db69a1bcea..d3b6bd67b5 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -293,6 +293,7 @@ def setUp(self): self.config["model"]["atom_exclude_types"] = [1] self.config["model"]["fitting_net"]["type"] = "polar" self.config["model"]["fitting_net"]["fit_diag"] = False + self.config["model"]["fitting_net"]["shift_diag"] = False self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 # can not set requires_grad false for all parameters, @@ -326,6 +327,7 @@ def setUp(self): self.config["model"]["atom_exclude_types"] = [1] self.config["model"]["fitting_net"]["type"] = "polar" self.config["model"]["fitting_net"]["fit_diag"] = False + self.config["model"]["fitting_net"]["shift_diag"] = False self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 # can not set requires_grad false for all parameters, @@ -359,6 +361,7 @@ def setUp(self): self.config["model"]["atom_exclude_types"] = [1] self.config["model"]["fitting_net"]["type"] = "polar" self.config["model"]["fitting_net"]["fit_diag"] = False + self.config["model"]["fitting_net"]["shift_diag"] = False self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 # can not set requires_grad false for all parameters,