From e12c10fb4728f3f0d7701a32ff75b4d14910512a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 8 Jan 2024 03:15:00 -0500 Subject: [PATCH] introduce EmbeddingNet Signed-off-by: Jinzhe Zeng --- deepmd/descriptor/se.py | 51 ++++++++++++++++++++++++++++----------- deepmd/descriptor/se_a.py | 18 +++++++++++--- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/deepmd/descriptor/se.py b/deepmd/descriptor/se.py index 1cbdf05531..07f12de34d 100644 --- a/deepmd/descriptor/se.py +++ b/deepmd/descriptor/se.py @@ -4,6 +4,7 @@ defaultdict, ) from typing import ( + List, Tuple, ) @@ -15,8 +16,8 @@ get_embedding_net_variables_from_graph_def, get_tensor_by_name_from_graph, ) -from deepmd_utils.model_format import ( - NativeNet, +from deepmd_utils.model_format.network import ( + EmbeddingNet, ) from .descriptor import ( @@ -169,21 +170,43 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): local_jdata_cpy = local_jdata.copy() return update_one_sel(global_jdata, local_jdata_cpy, False) - def to_dp_variables(self, variables: dict) -> dict: - """Convert the variables to deepmd format. + def serialize_network( + self, + in_dim: int, + neuron: List[int], + activation_function: str, + resnet_dt: bool, + variables: dict, + ) -> dict: + """Serialize network. Parameters ---------- + in_dim : int + The input dimension + neuron : List[int] + The neuron list + activation_function : str + The activation function + resnet_dt : bool + Whether to use resnet variables : dict The input variables Returns ------- dict - The converted variables + The converted network data """ # TODO: unclear how to hand suffix, maybe we need to add a suffix argument? - networks = defaultdict(NativeNet) + networks = defaultdict( + lambda: EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function=activation_function, + resnet_dt=resnet_dt, + ) + ) for key, value in variables.items(): m = re.search(EMBEDDING_NET_PATTERN, key) m = [mm for mm in m.groups() if mm is not None] @@ -196,21 +219,21 @@ def to_dp_variables(self, variables: dict) -> dict: return {key: value.serialize() for key, value in networks.items()} @classmethod - def from_dp_variables(cls, variables: dict) -> dict: - """Convert the variables from deepmd format. + def deserialize_network(cls, data: dict) -> Tuple[List[int], str, bool, dict, str]: + """Deserialize network. Parameters ---------- - variables : dict - The input variables + data : dict + The input network data Returns ------- - dict - The converted variables + variables : dict + The input variables """ embedding_net_variables = {} - for key, value in variables.items(): + for key, value in data.items(): keys = key.split("/") key0 = keys[0][5:] key1 = keys[1][5:] @@ -218,7 +241,7 @@ def from_dp_variables(cls, variables: dict) -> dict: key1 = "" else: key1 = "_" + key1 - network = NativeNet.deserialize(value) + network = EmbeddingNet.deserialize(value) for layer_idx, layer in enumerate(network.layers): embedding_net_variables[ f"filter_type_{key0}/matrix_{layer_idx}{key1}" diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 2b5c54fd5a..0573294c35 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -1358,12 +1358,13 @@ def deserialize(cls, data: dict): Model The deserialized model """ + if type(cls) is not DescrptSeA: + raise NotImplementedError("Unsupported") + embedding_net_variables = cls.deserialize_network(data["networks"]) descriptor = cls(**data) + descriptor.embedding_net_variables = embedding_net_variables descriptor.davg = data["@variables"]["davg"] descriptor.dstd = data["@variables"]["dstd"] - descriptor.embedding_net_variables = cls.from_dp_variables( - data["@variables"]["networks"] - ) descriptor.original_sel = data["@variables"]["original_sel"] return descriptor @@ -1375,6 +1376,8 @@ def serialize(self) -> dict: dict The serialized data """ + if type(self) is not DescrptSeA: + raise NotImplementedError("Unsupported") return { "type": "se_e2_a", "rcut": self.rcut_r, @@ -1392,8 +1395,15 @@ def serialize(self) -> dict: "precision": self.filter_precision.name, "uniform_seed": self.uniform_seed, "stripped_type_embedding": self.stripped_type_embedding, + "networks": self.serialize_network( + # TODO: how to consider type embedding? + in_dim=1, + neuron=self.filter_neuron, + activation_function=self.activation_function_name, + resnet_dt=self.filter_resnet_dt, + variables=self.embedding_net_variables, + ), "@variables": { - "networks": self.to_dp_variables(self.embedding_net_variables), "davg": self.davg, "dstd": self.dstd, "original_sel": self.original_sel,