Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add native Networks for mutiple Network classes #3117

Merged
merged 9 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepmd_utils/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
EmbeddingNet,
NativeLayer,
NativeNet,
Networks,
load_dp_model,
save_dp_model,
traverse_model_dict,
Expand All @@ -19,6 +20,7 @@
"EmbeddingNet",
"NativeLayer",
"NativeNet",
"Networks",
"load_dp_model",
"save_dp_model",
"traverse_model_dict",
Expand Down
97 changes: 97 additions & 0 deletions deepmd_utils/model_format/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
"""
import json
from typing import (
Dict,
List,
Optional,
Union,
)

import h5py
Expand Down Expand Up @@ -409,3 +411,98 @@
obj = cls(**data)
super(EmbeddingNet, obj).__init__(layers)
return obj


class Networks:
njzjz marked this conversation as resolved.
Show resolved Hide resolved
"""A collection of networks for multiple elements.

The number of dimesions for types might be 0, 1, or 2.
- 0: embedding or fitting with type embedding, in ()
- 1: embedding with type_one_side, or fitting, in (type_i)
- 2: embedding without type_one_side, in (type_i, type_j)

Serialized keys are in the form of type, type_i, type_i_j, ...

Parameters
----------
ndim : int
The number of dimensions.
network_type : str, optional
The type of the network.
networks : dict, optional
The networks to initialize with.
"""

def __init__(
self,
ndim: int,
network_type: str = "network",
networks: Dict[Union[str, tuple], Union[NativeNet, dict]] = {},
):
self.ndim = ndim
if network_type == "network":
self.network_type = NativeNet
elif network_type == "embedding_network":
self.network_type = EmbeddingNet

Check warning on line 446 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L445-L446

Added lines #L445 - L446 were not covered by tests
else:
raise NotImplementedError(network_type)

Check warning on line 448 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L448

Added line #L448 was not covered by tests
self._networks = {}
for kk, vv in networks.items():
self[kk] = vv

def _convert_key(self, key):
if isinstance(key, tuple):
pass
elif isinstance(key, str):
key = tuple([int(tt) for tt in key.split("_")[1:]])
else:
raise TypeError(key)

Check warning on line 459 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L459

Added line #L459 was not covered by tests
assert isinstance(key, tuple)
assert len(key) == self.ndim
return key

def __getitem__(self, key):
return self._networks[self._convert_key(key)]

def __setitem__(self, key, value):
if isinstance(value, self.network_type):
pass

Check warning on line 469 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L469

Added line #L469 was not covered by tests
elif isinstance(value, dict):
value = self.network_type.deserialize(value)
else:
raise TypeError(value)

Check warning on line 473 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L473

Added line #L473 was not covered by tests
self._networks[self._convert_key(key)] = value

def serialize(self) -> dict:
"""Serialize the networks to a dict.

Returns
-------
dict
The serialized networks.
"""
if self.network_type is NativeNet:
network_type_name = "network"
elif self.network_type is EmbeddingNet:
network_type_name = "embedding_network"

Check warning on line 487 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L486-L487

Added lines #L486 - L487 were not covered by tests
else:
raise NotImplementedError(self.network_type)

Check warning on line 489 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L489

Added line #L489 was not covered by tests
return {
"ndim": self.ndim,
"network_type": network_type_name,
"networks": {
("_".join(["type"] + [str(tt) for tt in key])): value.serialize()
for key, value in self._networks.items()
},
}

@classmethod
def deserialize(cls, data: dict) -> "Networks":
"""Deserialize the networks from a dict.

Parameters
----------
data : dict
The dict to deserialize from.
"""
return cls(**data)
56 changes: 56 additions & 0 deletions source/tests/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
EnvMat,
NativeLayer,
NativeNet,
Networks,
load_dp_model,
save_dp_model,
)
Expand Down Expand Up @@ -108,6 +109,61 @@ def test_embedding_net(self):
np.testing.assert_allclose(en0.call(inp), en1.call(inp))


class TestNetworks(unittest.TestCase):
def setUp(self) -> None:
w = np.full((2, 3), 3.0)
b = np.full((3,), 4.0)
self.network = {
"layers": [
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": w, "b": b},
},
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": w, "b": b},
},
],
}

def test_two_dim(self):
networks = Networks(ndim=2)
networks[(0, 0)] = self.network
networks[(0, 1)] = self.network
njzjz marked this conversation as resolved.
Show resolved Hide resolved
np.testing.assert_equal(
networks.serialize(),
Networks.deserialize(networks.serialize()).serialize(),
)
np.testing.assert_equal(
networks[(0, 0)].serialize(), networks.serialize()["networks"]["type_0_0"]
)

def test_one_dim(self):
networks = Networks(ndim=1)
networks[(0,)] = self.network
networks[(1,)] = self.network
np.testing.assert_equal(
networks.serialize(),
Networks.deserialize(networks.serialize()).serialize(),
)
np.testing.assert_equal(
networks[(0,)].serialize(), networks.serialize()["networks"]["type_0"]
)

def test_zero_dim(self):
networks = Networks(ndim=0)
networks[()] = self.network
np.testing.assert_equal(
networks.serialize(),
Networks.deserialize(networks.serialize()).serialize(),
)
np.testing.assert_equal(
networks[()].serialize(), networks.serialize()["networks"]["type"]
)


class TestDPModel(unittest.TestCase):
def setUp(self) -> None:
self.w = np.full((3, 2), 3.0)
Expand Down
Loading