diff --git a/deepmd_utils/model_format/__init__.py b/deepmd_utils/model_format/__init__.py index 4b33aa0151..bc157f29d3 100644 --- a/deepmd_utils/model_format/__init__.py +++ b/deepmd_utils/model_format/__init__.py @@ -3,6 +3,7 @@ EmbeddingNet, NativeLayer, NativeNet, + Networks, load_dp_model, save_dp_model, traverse_model_dict, @@ -12,6 +13,7 @@ "EmbeddingNet", "NativeLayer", "NativeNet", + "Networks", "load_dp_model", "save_dp_model", "traverse_model_dict", diff --git a/deepmd_utils/model_format/network.py b/deepmd_utils/model_format/network.py index 04aaa75534..433d70247f 100644 --- a/deepmd_utils/model_format/network.py +++ b/deepmd_utils/model_format/network.py @@ -8,8 +8,10 @@ ABC, ) from typing import ( + Dict, List, Optional, + Union, ) import h5py @@ -389,3 +391,98 @@ def deserialize(cls, data: dict) -> "EmbeddingNet": obj = cls(**data) super(EmbeddingNet, obj).__init__(layers) return obj + + +class Networks: + """A collection of networks for multiple elements. + + The number of dimesions for types might be 0, 1, or 2. + - 0: embedding or fitting with type embedding, in () + - 1: embedding with type_one_side, or fitting, in (type_i) + - 2: embedding without type_one_side, in (type_i, type_j) + + Serialized keys are in the form of type, type_i, type_i_j, ... + + Parameters + ---------- + ndim : int + The number of dimensions. + network_type : str, optional + The type of the network. + networks : dict, optional + The networks to initialize with. + """ + + def __init__( + self, + ndim: int, + network_type: str = "network", + networks: Dict[Union[str, tuple], Union[NativeNet, dict]] = {}, + ): + self.ndim = ndim + if network_type == "network": + self.network_type = NativeNet + elif network_type == "embedding_network": + self.network_type = EmbeddingNet + else: + raise NotImplementedError(network_type) + self._networks = {} + for kk, vv in networks.items(): + self[kk] = vv + + def _convert_key(self, key): + if isinstance(key, tuple): + pass + elif isinstance(key, str): + key = tuple([int(tt) for tt in key.split("_")[1:]]) + else: + raise TypeError(key) + assert isinstance(key, tuple) + assert len(key) == self.ndim + return key + + def __getitem__(self, key): + return self._networks[self._convert_key(key)] + + def __setitem__(self, key, value): + if isinstance(value, self.network_type): + pass + elif isinstance(value, dict): + value = self.network_type.deserialize(value) + else: + raise TypeError(value) + self._networks[self._convert_key(key)] = value + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + if self.network_type is NativeNet: + network_type_name = "network" + elif self.network_type is EmbeddingNet: + network_type_name = "embedding_network" + else: + raise NotImplementedError(self.network_type) + return { + "ndim": self.ndim, + "network_type": network_type_name, + "networks": { + ("_".join(["type"] + [str(tt) for tt in key])): value.serialize() + for key, value in self._networks.items() + }, + } + + @classmethod + def deserialize(cls, data: dict) -> "Networks": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + return cls(**data) diff --git a/source/tests/test_model_format_utils.py b/source/tests/test_model_format_utils.py index f26ebbaa8d..55018a3c9e 100644 --- a/source/tests/test_model_format_utils.py +++ b/source/tests/test_model_format_utils.py @@ -12,6 +12,7 @@ EmbeddingNet, NativeLayer, NativeNet, + Networks, load_dp_model, save_dp_model, ) @@ -103,6 +104,61 @@ def test_embedding_net(self): np.testing.assert_allclose(en0.call(inp), en1.call(inp)) +class TestNetworks(unittest.TestCase): + def setUp(self) -> None: + w = np.full((2, 3), 3.0) + b = np.full((3,), 4.0) + self.network = { + "layers": [ + { + "activation_function": "tanh", + "resnet": True, + "@variables": {"w": w, "b": b}, + }, + { + "activation_function": "tanh", + "resnet": True, + "@variables": {"w": w, "b": b}, + }, + ], + } + + def test_two_dim(self): + networks = Networks(ndim=2) + networks[(0, 0)] = self.network + networks[(0, 1)] = self.network + self.assertDictEqual( + networks.serialize(), + Networks.deserialize(networks.serialize()).serialize(), + ) + self.assertDictEqual( + networks[(0, 0)].serialize(), networks.serialize()["networks"]["type_0_0"] + ) + + def test_one_dim(self): + networks = Networks(ndim=1) + networks[(0,)] = self.network + networks[(1,)] = self.network + self.assertDictEqual( + networks.serialize(), + Networks.deserialize(networks.serialize()).serialize(), + ) + self.assertDictEqual( + networks[(0,)].serialize(), networks.serialize()["networks"]["type_0"] + ) + + def test_zero_dim(self): + networks = Networks(ndim=0) + networks[()] = self.network + self.assertDictEqual( + networks.serialize(), + Networks.deserialize(networks.serialize()).serialize(), + ) + self.assertDictEqual( + networks[()].serialize(), networks.serialize()["networks"]["type"] + ) + + class TestDPModel(unittest.TestCase): def setUp(self) -> None: self.w = np.full((3, 2), 3.0)