Skip to content

Commit

Permalink
fix(tf): fix model out_bias deserialize
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 13, 2024
1 parent 7416c9f commit 72957bb
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
33 changes: 30 additions & 3 deletions deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from deepmd.common import (
j_get_type,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.tf.descriptor.descriptor import (
Descriptor,
)
Expand Down Expand Up @@ -806,6 +809,17 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
data = data.copy()
check_version_compatibility(data.pop("@version", 2), 2, 1)
descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix)
if data["fitting"].get("@variables", {}).get("bias_atom_e") is not None:
# careful: copy each level and don't modify the input array,
# otherwise it will affect the original data
# deepcopy is not used for performance reasons
data["fitting"] = data["fitting"].copy()
data["fitting"]["@variables"] = data["fitting"]["@variables"].copy()
data["fitting"]["@variables"]["bias_atom_e"] = data["fitting"][
"@variables"
]["bias_atom_e"] + data["@variables"]["out_bias"].reshape(
data["fitting"]["@variables"]["bias_atom_e"].shape
)
fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix)
# pass descriptor type embedding to model
if descriptor.explicit_ntypes:
Expand All @@ -814,8 +828,10 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
else:
type_embedding = None
# BEGIN not supported keys
data.pop("atom_exclude_types")
data.pop("pair_exclude_types")
if len(data.pop("atom_exclude_types")) > 0:
raise NotImplementedError("atom_exclude_types is not supported")
if len(data.pop("pair_exclude_types")) > 0:
raise NotImplementedError("pair_exclude_types is not supported")
data.pop("rcond", None)
data.pop("preset_out_bias", None)
data.pop("@variables", None)
Expand Down Expand Up @@ -848,6 +864,17 @@ def serialize(self, suffix: str = "") -> dict:

ntypes = len(self.get_type_map())
dict_fit = self.fitting.serialize(suffix=suffix)
if dict_fit.get("@variables", {}).get("bias_atom_e") is not None:
out_bias = dict_fit["@variables"]["bias_atom_e"].reshape(
[1, ntypes, dict_fit["dim_out"]]
)
dict_fit["@variables"]["bias_atom_e"] = np.zeros_like(
dict_fit["@variables"]["bias_atom_e"]
)
else:
out_bias = np.zeros(
[1, ntypes, dict_fit["dim_out"]], dtype=GLOBAL_NP_FLOAT_PRECISION
)
return {
"@class": "Model",
"type": "standard",
Expand All @@ -861,7 +888,7 @@ def serialize(self, suffix: str = "") -> dict:
"rcond": None,
"preset_out_bias": None,
"@variables": {
"out_bias": np.zeros([1, ntypes, dict_fit["dim_out"]]), # pylint: disable=no-explicit-dtype
"out_bias": out_bias,
"out_std": np.ones([1, ntypes, dict_fit["dim_out"]]), # pylint: disable=no-explicit-dtype
},
}
Expand Down
4 changes: 3 additions & 1 deletion source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def pass_data_to_cls(self, cls, data) -> Any:
if cls is EnergyModelDP:
return get_model_dp(data)
elif cls is EnergyModelPT:
return get_model_pt(data)
model = get_model_pt(data)
model.atomic_model.out_bias += 1.0
return model
elif cls is EnergyModelJAX:
return get_model_jax(data)
return cls(**data, **self.additional_data)
Expand Down

0 comments on commit 72957bb

Please sign in to comment.