From 3de4709f436475826482984ea7ab08ff131253f7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Mar 2024 06:57:26 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/fitting/dipole_fitting.py | 8 ++++---- deepmd/dpmodel/fitting/general_fitting.py | 1 + deepmd/dpmodel/fitting/invar_fitting.py | 9 +++++---- deepmd/dpmodel/fitting/polarizability_fitting.py | 8 ++++---- deepmd/pt/model/task/dipole.py | 3 +-- deepmd/pt/model/task/ener.py | 3 +-- deepmd/pt/model/task/polarizability.py | 2 +- deepmd/tf/fit/polar.py | 4 +++- source/tests/consistent/common.py | 2 +- 9 files changed, 21 insertions(+), 19 deletions(-) diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 2b19be0b45..6d6324770c 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -20,13 +20,13 @@ OutputVariableDef, fitting_check_output, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .general_fitting import ( GeneralFitting, ) -from deepmd.utils.version import ( - check_version_compatibility, -) @BaseFitting.register("dipole") @@ -156,7 +156,7 @@ def serialize(self) -> dict: data["r_differentiable"] = self.r_differentiable data["c_differentiable"] = self.c_differentiable return data - + @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 40283db205..9865bf6e30 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -21,6 +21,7 @@ FittingNet, NetworkCollection, ) + from .base_fitting import ( BaseFitting, ) diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index 392853c3be..e795953a75 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -17,13 +17,14 @@ OutputVariableDef, fitting_check_output, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .general_fitting import ( GeneralFitting, ) -from deepmd.utils.version import ( - check_version_compatibility, -) + @GeneralFitting.register("invar") @fitting_check_output @@ -171,7 +172,7 @@ def serialize(self) -> dict: data["dim_out"] = self.dim_out data["atom_ener"] = self.atom_ener return data - + @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 58c84bb3cf..cd3f3682fb 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -23,14 +23,14 @@ OutputVariableDef, fitting_check_output, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .general_fitting import ( GeneralFitting, ) -from deepmd.utils.version import ( - check_version_compatibility, -) @BaseFitting.register("polar") @fitting_check_output @@ -184,7 +184,7 @@ def serialize(self) -> dict: data["@variables"]["scale"] = self.scale data["@variables"]["constant_matrix"] = self.constant_matrix return data - + @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 1da0573ff7..b8892c2d95 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -26,7 +26,6 @@ from deepmd.utils.path import ( DPPath, ) - from deepmd.utils.version import ( check_version_compatibility, ) @@ -127,7 +126,7 @@ def serialize(self) -> dict: data["r_differentiable"] = self.r_differentiable data["c_differentiable"] = self.c_differentiable return data - + @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 05ded09f50..55ffd8c650 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -142,14 +142,13 @@ def serialize(self) -> dict: data["dim_out"] = self.dim_out data["atom_ener"] = self.atom_ener return data - + @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 1, 1) return super().deserialize(data) - def compute_output_stats( self, merged: Union[Callable[[], List[dict]], List[dict]], diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index e88725a4e2..bd19844277 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -163,7 +163,7 @@ def serialize(self) -> dict: data["@variables"]["scale"] = to_numpy_array(self.scale) data["@variables"]["constant_matrix"] = to_numpy_array(self.constant_matrix) return data - + @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 6c5fa1807f..41ea989521 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -586,7 +586,9 @@ def deserialize(cls, data: dict, suffix: str): The deserialized model """ data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) # to allow PT version. + check_version_compatibility( + data.pop("@version", 1), 2, 1 + ) # to allow PT version. fitting = cls(**data) fitting.fitting_net_variables = cls.deserialize_network( data["nets"], diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 5b11bf3794..5a35ced0a1 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -257,7 +257,7 @@ def test_tf_consistent_with_ref(self): common_keys = set(data1.keys()) & set(data2.keys()) data1 = {k: data1[k] for k in common_keys} data2 = {k: data2[k] for k in common_keys} - + # not comparing version data1.pop("@version") data2.pop("@version")