From a14fed24c16a7f0e21020a57865084bbbb1a7e26 Mon Sep 17 00:00:00 2001 From: Chengqian-Zhang <2000011006@stu.pku.edu.cn> Date: Mon, 3 Jun 2024 09:47:51 +0000 Subject: [PATCH] Add serialize in se_attn_v2 tf --- deepmd/tf/descriptor/se_atten_v2.py | 65 +++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/deepmd/tf/descriptor/se_atten_v2.py b/deepmd/tf/descriptor/se_atten_v2.py index 6204f27855..2329a40a79 100644 --- a/deepmd/tf/descriptor/se_atten_v2.py +++ b/deepmd/tf/descriptor/se_atten_v2.py @@ -109,3 +109,68 @@ def __init__( smooth_type_embedding=True, **kwargs, ) + + @classmethod + def deserialize(cls, data: dict, suffix: str = ""): + """Deserialize the model. + Parameters + ---------- + data : dict + The serialized data + Returns + ------- + Model + The deserialized model + """ + if cls is not DescrptSeAttenV2: + raise NotImplementedError(f"Not implemented in class {cls.__name__}") + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + embedding_net_variables = cls.deserialize_network( + data.pop("embeddings"), suffix=suffix + ) + attention_layer_variables = cls.deserialize_attention_layers( + data.pop("attention_layers"), suffix=suffix + ) + data.pop("env_mat") + variables = data.pop("@variables") + type_one_side = data["type_one_side"] + two_side_embeeding_net_variables = cls.deserialize_network_strip( + data.pop("embeddings_strip"), + suffix=suffix, + type_one_side=type_one_side, + ) + descriptor = cls(**data) + descriptor.embedding_net_variables = embedding_net_variables + descriptor.attention_layer_variables = attention_layer_variables + descriptor.two_side_embeeding_net_variables = two_side_embeeding_net_variables + descriptor.davg = variables["davg"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + descriptor.dstd = variables["dstd"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + return descriptor + + def serialize(self, suffix: str = "") -> dict: + """Serialize the model. + Parameters + ---------- + suffix : str, optional + The suffix of the scope + Returns + ------- + dict + The serialized data + """ + data = super().serialize(suffix) + data.pop("smooth_type_embedding") + data.pop("tebd_input_mode") + data.update( + { + "type": "se_atten_v2" + } + ) + return data \ No newline at end of file