From 4e8f8fe48f5f598b5f179cc8116bfd74680057a2 Mon Sep 17 00:00:00 2001 From: anyangml Date: Fri, 8 Mar 2024 06:56:57 +0000 Subject: [PATCH] fix: refacotr version check --- deepmd/dpmodel/fitting/dipole_fitting.py | 10 ++++++++++ deepmd/dpmodel/fitting/ener_fitting.py | 4 ++++ deepmd/dpmodel/fitting/general_fitting.py | 5 ----- deepmd/dpmodel/fitting/invar_fitting.py | 11 ++++++++++- deepmd/dpmodel/fitting/polarizability_fitting.py | 10 ++++++++++ deepmd/pt/model/task/dipole.py | 11 +++++++++++ deepmd/pt/model/task/ener.py | 11 +++++++++++ deepmd/pt/model/task/fitting.py | 4 ---- deepmd/pt/model/task/polarizability.py | 10 ++++++++++ deepmd/tf/fit/polar.py | 2 +- source/tests/consistent/common.py | 5 +++++ 11 files changed, 72 insertions(+), 11 deletions(-) diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index e00f031549..2b19be0b45 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Any, Dict, @@ -23,6 +24,9 @@ from .general_fitting import ( GeneralFitting, ) +from deepmd.utils.version import ( + check_version_compatibility, +) @BaseFitting.register("dipole") @@ -152,6 +156,12 @@ def serialize(self) -> dict: data["r_differentiable"] = self.r_differentiable data["c_differentiable"] = self.c_differentiable return data + + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + return super().deserialize(data) def output_def(self): return FittingOutputDef( diff --git a/deepmd/dpmodel/fitting/ener_fitting.py b/deepmd/dpmodel/fitting/ener_fitting.py index 3a0e9909b9..7f83f1e886 100644 --- a/deepmd/dpmodel/fitting/ener_fitting.py +++ b/deepmd/dpmodel/fitting/ener_fitting.py @@ -18,6 +18,9 @@ from deepmd.dpmodel.fitting.general_fitting import ( GeneralFitting, ) +from deepmd.utils.version import ( + check_version_compatibility, +) @InvarFitting.register("ener") @@ -69,6 +72,7 @@ def __init__( @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("var_name") data.pop("dim_out") return super().deserialize(data) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 79927b276f..40283db205 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -21,10 +21,6 @@ FittingNet, NetworkCollection, ) -from deepmd.utils.version import ( - check_version_compatibility, -) - from .base_fitting import ( BaseFitting, ) @@ -260,7 +256,6 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("@class") data.pop("type") variables = data.pop("@variables") diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index fd556ff074..392853c3be 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Any, Dict, @@ -20,7 +21,9 @@ from .general_fitting import ( GeneralFitting, ) - +from deepmd.utils.version import ( + check_version_compatibility, +) @GeneralFitting.register("invar") @fitting_check_output @@ -168,6 +171,12 @@ def serialize(self) -> dict: data["dim_out"] = self.dim_out data["atom_ener"] = self.atom_ener return data + + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + return super().deserialize(data) def _net_out_dim(self): """Set the FittingNet output dim.""" diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 4ddc0b387d..58c84bb3cf 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Any, Dict, @@ -27,6 +28,9 @@ GeneralFitting, ) +from deepmd.utils.version import ( + check_version_compatibility, +) @BaseFitting.register("polar") @fitting_check_output @@ -180,6 +184,12 @@ def serialize(self) -> dict: data["@variables"]["scale"] = self.scale data["@variables"]["constant_matrix"] = self.constant_matrix return data + + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 2, 1) + return super().deserialize(data) def output_def(self): return FittingOutputDef( diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 21372888d6..1da0573ff7 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import logging from typing import ( Callable, @@ -26,6 +27,10 @@ DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) + log = logging.getLogger(__name__) @@ -122,6 +127,12 @@ def serialize(self) -> dict: data["r_differentiable"] = self.r_differentiable data["c_differentiable"] = self.c_differentiable return data + + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + return super().deserialize(data) def output_def(self) -> FittingOutputDef: return FittingOutputDef( diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index b593ddc3cc..05ded09f50 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -36,6 +36,9 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE @@ -139,6 +142,13 @@ def serialize(self) -> dict: data["dim_out"] = self.dim_out data["atom_ener"] = self.atom_ener return data + + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + return super().deserialize(data) + def compute_output_stats( self, @@ -241,6 +251,7 @@ def __init__( @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("var_name") data.pop("dim_out") return super().deserialize(data) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 2331bb5737..69c21b2020 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -49,9 +49,6 @@ from deepmd.utils.finetune import ( change_energy_bias_lower, ) -from deepmd.utils.version import ( - check_version_compatibility, -) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE @@ -371,7 +368,6 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) variables = data.pop("@variables") nets = data.pop("nets") obj = cls(**data) diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 5f2a902d3b..e88725a4e2 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import logging from typing import ( Callable, @@ -33,6 +34,9 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) log = logging.getLogger(__name__) @@ -159,6 +163,12 @@ def serialize(self) -> dict: data["@variables"]["scale"] = to_numpy_array(self.scale) data["@variables"]["constant_matrix"] = to_numpy_array(self.constant_matrix) return data + + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 2, 1) + return super().deserialize(data) def output_def(self) -> FittingOutputDef: return FittingOutputDef( diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 08b0740459..6c5fa1807f 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -586,7 +586,7 @@ def deserialize(cls, data: dict, suffix: str): The deserialized model """ data = data.copy() - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility(data.pop("@version", 1), 2, 1) # to allow PT version. fitting = cls(**data) fitting.fitting_net_variables = cls.deserialize_network( data["nets"], diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 622e2ed3cf..5b11bf3794 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -257,6 +257,11 @@ def test_tf_consistent_with_ref(self): common_keys = set(data1.keys()) & set(data2.keys()) data1 = {k: data1[k] for k in common_keys} data2 = {k: data2[k] for k in common_keys} + + # not comparing version + data1.pop("@version") + data2.pop("@version") + np.testing.assert_equal(data1, data2) for rr1, rr2 in zip(ret1, ret2): np.testing.assert_allclose(