From 72957bb6f9809633c346177babe50181c7638e8b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 13 Nov 2024 03:46:04 -0500 Subject: [PATCH 1/3] fix(tf): fix model out_bias deserialize Signed-off-by: Jinzhe Zeng --- deepmd/tf/model/model.py | 33 ++++++++++++++++++++-- source/tests/consistent/model/test_ener.py | 4 ++- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 5e3f99bc2d..d12a101ea4 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -16,6 +16,9 @@ from deepmd.common import ( j_get_type, ) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) from deepmd.tf.descriptor.descriptor import ( Descriptor, ) @@ -806,6 +809,17 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": data = data.copy() check_version_compatibility(data.pop("@version", 2), 2, 1) descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix) + if data["fitting"].get("@variables", {}).get("bias_atom_e") is not None: + # careful: copy each level and don't modify the input array, + # otherwise it will affect the original data + # deepcopy is not used for performance reasons + data["fitting"] = data["fitting"].copy() + data["fitting"]["@variables"] = data["fitting"]["@variables"].copy() + data["fitting"]["@variables"]["bias_atom_e"] = data["fitting"][ + "@variables" + ]["bias_atom_e"] + data["@variables"]["out_bias"].reshape( + data["fitting"]["@variables"]["bias_atom_e"].shape + ) fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix) # pass descriptor type embedding to model if descriptor.explicit_ntypes: @@ -814,8 +828,10 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": else: type_embedding = None # BEGINE not supported keys - data.pop("atom_exclude_types") - data.pop("pair_exclude_types") + if len(data.pop("atom_exclude_types")) > 0: + raise NotImplementedError("atom_exclude_types is not supported") + if len(data.pop("pair_exclude_types")) > 0: + raise NotImplementedError("pair_exclude_types is not supported") data.pop("rcond", None) data.pop("preset_out_bias", None) data.pop("@variables", None) @@ -848,6 +864,17 @@ def serialize(self, suffix: str = "") -> dict: ntypes = len(self.get_type_map()) dict_fit = self.fitting.serialize(suffix=suffix) + if dict_fit.get("@variables", {}).get("bias_atom_e") is not None: + out_bias = dict_fit["@variables"]["bias_atom_e"].reshape( + [1, ntypes, dict_fit["dim_out"]] + ) + dict_fit["@variables"]["bias_atom_e"] = np.zeros_like( + dict_fit["@variables"]["bias_atom_e"] + ) + else: + out_bias = np.zeros( + [1, ntypes, dict_fit["dim_out"]], dtype=GLOBAL_NP_FLOAT_PRECISION + ) return { "@class": "Model", "type": "standard", @@ -861,7 +888,7 @@ def serialize(self, suffix: str = "") -> dict: "rcond": None, "preset_out_bias": None, "@variables": { - "out_bias": np.zeros([1, ntypes, dict_fit["dim_out"]]), # pylint: disable=no-explicit-dtype + "out_bias": out_bias, "out_std": np.ones([1, ntypes, dict_fit["dim_out"]]), # pylint: disable=no-explicit-dtype }, } diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 5d0253c5e8..f1518cb6e5 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -141,7 +141,9 @@ def pass_data_to_cls(self, cls, data) -> Any: if cls is EnergyModelDP: return get_model_dp(data) elif cls is EnergyModelPT: - return get_model_pt(data) + model = get_model_pt(data) + model.atomic_model.out_bias += 1.0 + return model elif cls is EnergyModelJAX: return get_model_jax(data) return cls(**data, **self.additional_data) From 4939e1e6ae658aee45ac3a12045da5a06290d06c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 13 Nov 2024 05:09:43 -0500 Subject: [PATCH 2/3] Update source/tests/consistent/model/test_ener.py Signed-off-by: Jinzhe Zeng --- source/tests/consistent/model/test_ener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index f1518cb6e5..350cf0a412 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -142,7 +142,7 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is EnergyModelPT: model = get_model_pt(data) - model.atomic_model.out_bias += 1.0 + model.atomic_model.out_bias.uniform_() return model elif cls is EnergyModelJAX: return get_model_jax(data) From f6538a750aa533f4e612e85857ae0d702e85846a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 13 Nov 2024 15:26:24 -0500 Subject: [PATCH 3/3] throw error if both are provided Signed-off-by: Jinzhe Zeng --- deepmd/tf/model/model.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index d12a101ea4..810db67982 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -805,6 +805,11 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": ------- Descriptor The deserialized descriptor + + Raises + ------ + ValueError + If both fitting/@variables/bias_atom_e and @variables/out_bias are non-zero """ data = data.copy() check_version_compatibility(data.pop("@version", 2), 2, 1) @@ -815,6 +820,14 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": # deepcopy is not used for performance reasons data["fitting"] = data["fitting"].copy() data["fitting"]["@variables"] = data["fitting"]["@variables"].copy() + if ( + int(np.any(data["fitting"]["@variables"]["bias_atom_e"])) + + int(np.any(data["@variables"]["out_bias"])) + > 1 + ): + raise ValueError( + "fitting/@variables/bias_atom_e and @variables/out_bias should not be both non-zero" + ) data["fitting"]["@variables"]["bias_atom_e"] = data["fitting"][ "@variables" ]["bias_atom_e"] + data["@variables"]["out_bias"].reshape(