Skip to content

Commit

Permalink
Add serialize in se_attn_v2 tf
Browse files Browse the repository at this point in the history
  • Loading branch information
Chengqian-Zhang committed Jun 3, 2024
1 parent 3d66922 commit a14fed2
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions deepmd/tf/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,68 @@ def __init__(
smooth_type_embedding=True,
**kwargs,
)

@classmethod
def deserialize(cls, data: dict, suffix: str = ""):
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
Model
The deserialized model
"""
if cls is not DescrptSeAttenV2:
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
data.pop("@class")
data.pop("type")
embedding_net_variables = cls.deserialize_network(
data.pop("embeddings"), suffix=suffix
)
attention_layer_variables = cls.deserialize_attention_layers(
data.pop("attention_layers"), suffix=suffix
)
data.pop("env_mat")
variables = data.pop("@variables")
type_one_side = data["type_one_side"]
two_side_embeeding_net_variables = cls.deserialize_network_strip(
data.pop("embeddings_strip"),
suffix=suffix,
type_one_side=type_one_side,
)
descriptor = cls(**data)
descriptor.embedding_net_variables = embedding_net_variables
descriptor.attention_layer_variables = attention_layer_variables
descriptor.two_side_embeeding_net_variables = two_side_embeeding_net_variables
descriptor.davg = variables["davg"].reshape(
descriptor.ntypes, descriptor.ndescrpt
)
descriptor.dstd = variables["dstd"].reshape(
descriptor.ntypes, descriptor.ndescrpt
)
return descriptor

def serialize(self, suffix: str = "") -> dict:
"""Serialize the model.
Parameters
----------
suffix : str, optional
The suffix of the scope
Returns
-------
dict
The serialized data
"""
data = super().serialize(suffix)
data.pop("smooth_type_embedding")
data.pop("tebd_input_mode")
data.update(
{
"type": "se_atten_v2"
}
)
return data

0 comments on commit a14fed2

Please sign in to comment.