Skip to content

Commit

Permalink
introduce EmbeddingNet
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 8, 2024
1 parent 4f19ea3 commit e12c10f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 18 deletions.
51 changes: 37 additions & 14 deletions deepmd/descriptor/se.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
defaultdict,
)
from typing import (
List,
Tuple,
)

Expand All @@ -15,8 +16,8 @@
get_embedding_net_variables_from_graph_def,
get_tensor_by_name_from_graph,
)
from deepmd_utils.model_format import (
NativeNet,
from deepmd_utils.model_format.network import (
EmbeddingNet,
)

from .descriptor import (
Expand Down Expand Up @@ -169,21 +170,43 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
local_jdata_cpy = local_jdata.copy()
return update_one_sel(global_jdata, local_jdata_cpy, False)

def to_dp_variables(self, variables: dict) -> dict:
"""Convert the variables to deepmd format.
def serialize_network(
self,
in_dim: int,
neuron: List[int],
activation_function: str,
resnet_dt: bool,
variables: dict,
) -> dict:
"""Serialize network.
Parameters
----------
in_dim : int
The input dimension
neuron : List[int]
The neuron list
activation_function : str
The activation function
resnet_dt : bool
Whether to use resnet
variables : dict
The input variables
Returns
-------
dict
The converted variables
The converted network data
"""
# TODO: unclear how to hand suffix, maybe we need to add a suffix argument?
networks = defaultdict(NativeNet)
networks = defaultdict(

Check warning on line 202 in deepmd/descriptor/se.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se.py#L202

Added line #L202 was not covered by tests
lambda: EmbeddingNet(
in_dim=in_dim,
neuron=neuron,
activation_function=activation_function,
resnet_dt=resnet_dt,
)
)
for key, value in variables.items():
m = re.search(EMBEDDING_NET_PATTERN, key)
m = [mm for mm in m.groups() if mm is not None]

Check warning on line 212 in deepmd/descriptor/se.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se.py#L210-L212

Added lines #L210 - L212 were not covered by tests
Expand All @@ -196,29 +219,29 @@ def to_dp_variables(self, variables: dict) -> dict:
return {key: value.serialize() for key, value in networks.items()}

Check warning on line 219 in deepmd/descriptor/se.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se.py#L214-L219

Added lines #L214 - L219 were not covered by tests

@classmethod
def from_dp_variables(cls, variables: dict) -> dict:
"""Convert the variables from deepmd format.
def deserialize_network(cls, data: dict) -> Tuple[List[int], str, bool, dict, str]:
"""Deserialize network.
Parameters
----------
variables : dict
The input variables
data : dict
The input network data
Returns
-------
dict
The converted variables
variables : dict
The input variables
"""
embedding_net_variables = {}
for key, value in variables.items():
for key, value in data.items():
keys = key.split("/")
key0 = keys[0][5:]
key1 = keys[1][5:]
if key1 == "all":
key1 = ""

Check warning on line 241 in deepmd/descriptor/se.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se.py#L235-L241

Added lines #L235 - L241 were not covered by tests
else:
key1 = "_" + key1
network = NativeNet.deserialize(value)
network = EmbeddingNet.deserialize(value)
for layer_idx, layer in enumerate(network.layers):
embedding_net_variables[

Check warning on line 246 in deepmd/descriptor/se.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se.py#L243-L246

Added lines #L243 - L246 were not covered by tests
f"filter_type_{key0}/matrix_{layer_idx}{key1}"
Expand Down
18 changes: 14 additions & 4 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,12 +1358,13 @@ def deserialize(cls, data: dict):
Model
The deserialized model
"""
if type(cls) is not DescrptSeA:
raise NotImplementedError("Unsupported")
embedding_net_variables = cls.deserialize_network(data["networks"])
descriptor = cls(**data)
descriptor.embedding_net_variables = embedding_net_variables
descriptor.davg = data["@variables"]["davg"]
descriptor.dstd = data["@variables"]["dstd"]
descriptor.embedding_net_variables = cls.from_dp_variables(
data["@variables"]["networks"]
)
descriptor.original_sel = data["@variables"]["original_sel"]
return descriptor

Check warning on line 1369 in deepmd/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se_a.py#L1361-L1369

Added lines #L1361 - L1369 were not covered by tests

Expand All @@ -1375,6 +1376,8 @@ def serialize(self) -> dict:
dict
The serialized data
"""
if type(self) is not DescrptSeA:
raise NotImplementedError("Unsupported")
return {

Check warning on line 1381 in deepmd/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se_a.py#L1379-L1381

Added lines #L1379 - L1381 were not covered by tests
"type": "se_e2_a",
"rcut": self.rcut_r,
Expand All @@ -1392,8 +1395,15 @@ def serialize(self) -> dict:
"precision": self.filter_precision.name,
"uniform_seed": self.uniform_seed,
"stripped_type_embedding": self.stripped_type_embedding,
"networks": self.serialize_network(
# TODO: how to consider type embedding?
in_dim=1,
neuron=self.filter_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.filter_resnet_dt,
variables=self.embedding_net_variables,
),
"@variables": {
"networks": self.to_dp_variables(self.embedding_net_variables),
"davg": self.davg,
"dstd": self.dstd,
"original_sel": self.original_sel,
Expand Down

0 comments on commit e12c10f

Please sign in to comment.