Skip to content

Commit

Permalink
add Networks
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 9, 2024
1 parent f181a30 commit ff76186
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepmd_utils/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
EmbeddingNet,
NativeLayer,
NativeNet,
Networks,
load_dp_model,
save_dp_model,
traverse_model_dict,
Expand All @@ -12,6 +13,7 @@
"EmbeddingNet",
"NativeLayer",
"NativeNet",
"Networks",
"load_dp_model",
"save_dp_model",
"traverse_model_dict",
Expand Down
97 changes: 97 additions & 0 deletions deepmd_utils/model_format/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
ABC,
)
from typing import (
Dict,
List,
Optional,
Union,
)

import h5py
Expand Down Expand Up @@ -389,3 +391,98 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
obj = cls(**data)
super(EmbeddingNet, obj).__init__(layers)
return obj


class Networks:
"""A collection of networks for multiple elements.
The number of dimesions for types might be 0, 1, or 2.
- 0: embedding or fitting with type embedding, in ()
- 1: embedding with type_one_side, or fitting, in (type_i)
- 2: embedding without type_one_side, in (type_i, type_j)
Serialized keys are in the form of type, type_i, type_i_j, ...
Parameters
----------
ndim : int
The number of dimensions.
network_type : str, optional
The type of the network.
networks : dict, optional
The networks to initialize with.
"""

def __init__(
self,
ndim: int,
network_type: str = "network",
networks: Dict[Union[str, tuple], Union[NativeNet, dict]] = {},
):
self.ndim = ndim
if network_type == "network":
self.network_type = NativeNet
elif network_type == "embedding_network":
self.network_type = EmbeddingNet
else:
raise NotImplementedError(network_type)
self._networks = {}
for kk, vv in networks.items():
self[kk] = vv

def _convert_key(self, key):
if isinstance(key, tuple):
pass
elif isinstance(key, str):
key = tuple([int(tt) for tt in key.split("_")[1:]])
else:
raise TypeError(key)
assert isinstance(key, tuple)
assert len(key) == self.ndim
return key

def __getitem__(self, key):
return self._networks[self._convert_key(key)]

def __setitem__(self, key, value):
if isinstance(value, self.network_type):
pass
elif isinstance(value, dict):
value = self.network_type.deserialize(value)
else:
raise TypeError(value)
self._networks[self._convert_key(key)] = value

def serialize(self) -> dict:
"""Serialize the networks to a dict.
Returns
-------
dict
The serialized networks.
"""
if self.network_type is NativeNet:
network_type_name = "network"
elif self.network_type is EmbeddingNet:
network_type_name = "embedding_network"
else:
raise NotImplementedError(self.network_type)
return {
"ndim": self.ndim,
"network_type": network_type_name,
"networks": {
("_".join(["type"] + [str(tt) for tt in key])): value.serialize()
for key, value in self._networks.items()
},
}

@classmethod
def deserialize(cls, data: dict) -> "Networks":
"""Deserialize the networks from a dict.
Parameters
----------
data : dict
The dict to deserialize from.
"""
return cls(**data)
56 changes: 56 additions & 0 deletions source/tests/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
EmbeddingNet,
NativeLayer,
NativeNet,
Networks,
load_dp_model,
save_dp_model,
)
Expand Down Expand Up @@ -103,6 +104,61 @@ def test_embedding_net(self):
np.testing.assert_allclose(en0.call(inp), en1.call(inp))


class TestNetworks(unittest.TestCase):
def setUp(self) -> None:
w = np.full((2, 3), 3.0)
b = np.full((3,), 4.0)
self.network = {
"layers": [
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": w, "b": b},
},
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": w, "b": b},
},
],
}

def test_two_dim(self):
networks = Networks(ndim=2)
networks[(0, 0)] = self.network
networks[(0, 1)] = self.network
self.assertDictEqual(
networks.serialize(),
Networks.deserialize(networks.serialize()).serialize(),
)
self.assertDictEqual(
networks[(0, 0)].serialize(), networks.serialize()["networks"]["type_0_0"]
)

def test_one_dim(self):
networks = Networks(ndim=1)
networks[(0,)] = self.network
networks[(1,)] = self.network
self.assertDictEqual(
networks.serialize(),
Networks.deserialize(networks.serialize()).serialize(),
)
self.assertDictEqual(
networks[(0,)].serialize(), networks.serialize()["networks"]["type_0"]
)

def test_zero_dim(self):
networks = Networks(ndim=0)
networks[()] = self.network
self.assertDictEqual(
networks.serialize(),
Networks.deserialize(networks.serialize()).serialize(),
)
self.assertDictEqual(
networks[()].serialize(), networks.serialize()["networks"]["type"]
)


class TestDPModel(unittest.TestCase):
def setUp(self) -> None:
self.w = np.full((3, 2), 3.0)
Expand Down

0 comments on commit ff76186

Please sign in to comment.