Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add native Networks for mutiple Network classes #3117

Merged
merged 9 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepmd_utils/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
EmbeddingNet,
NativeLayer,
NativeNet,
NetworkCollection,
load_dp_model,
save_dp_model,
traverse_model_dict,
Expand All @@ -24,6 +25,7 @@
"EmbeddingNet",
"NativeLayer",
"NativeNet",
"NetworkCollection",
"load_dp_model",
"save_dp_model",
"traverse_model_dict",
Expand Down
112 changes: 112 additions & 0 deletions deepmd_utils/model_format/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@

See issue #2982 for more information.
"""
import itertools
import json
from typing import (
ClassVar,
Dict,
List,
Optional,
Union,
)

import h5py
Expand Down Expand Up @@ -411,3 +415,111 @@
obj = cls(**data)
super(EmbeddingNet, obj).__init__(layers)
return obj


class NetworkCollection:
"""A collection of networks for multiple elements.

The number of dimesions for types might be 0, 1, or 2.
- 0: embedding or fitting with type embedding, in ()
- 1: embedding with type_one_side, or fitting, in (type_i)
- 2: embedding without type_one_side, in (type_i, type_j)

Parameters
----------
ndim : int
The number of dimensions.
network_type : str, optional
The type of the network.
networks : dict, optional
The networks to initialize with.
"""

# subclass may override this
NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = {
"network": NativeNet,
"embedding_network": EmbeddingNet,
}

def __init__(
self,
ndim: int,
ntypes: int,
network_type: str = "network",
networks: List[Union[NativeNet, dict]] = [],
):
self.ndim = ndim
self.ntypes = ntypes
self.network_type = self.NETWORK_TYPE_MAP[network_type]
self._networks = [None for ii in range(ntypes**ndim)]
for ii, network in enumerate(networks):
self[ii] = network
if len(networks):
self.check_completeness()

def check_completeness(self):
"""Check whether the collection is complete.

Raises
------
RuntimeError
If the collection is incomplete.
"""
for tt in itertools.product(range(self.ntypes), repeat=self.ndim):
if self[tuple(tt)] is None:
raise RuntimeError(f"network for {tt} not found")

def _convert_key(self, key):
if isinstance(key, int):
idx = key
else:
if isinstance(key, tuple):
pass
elif isinstance(key, str):
key = tuple([int(tt) for tt in key.split("_")[1:]])

Check warning on line 479 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L478-L479

Added lines #L478 - L479 were not covered by tests
else:
raise TypeError(key)

Check warning on line 481 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L481

Added line #L481 was not covered by tests
assert isinstance(key, tuple)
assert len(key) == self.ndim
idx = sum([tt * self.ntypes**ii for ii, tt in enumerate(key)])
return idx

def __getitem__(self, key):
return self._networks[self._convert_key(key)]

def __setitem__(self, key, value):
if isinstance(value, self.network_type):
pass
elif isinstance(value, dict):
value = self.network_type.deserialize(value)
else:
raise TypeError(value)

Check warning on line 496 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L496

Added line #L496 was not covered by tests
self._networks[self._convert_key(key)] = value

def serialize(self) -> dict:
"""Serialize the networks to a dict.

Returns
-------
dict
The serialized networks.
"""
network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()}
network_type_name = network_type_map_inv[self.network_type]
return {
"ndim": self.ndim,
"ntypes": self.ntypes,
"network_type": network_type_name,
"networks": [nn.serialize() for nn in self._networks],
}

@classmethod
def deserialize(cls, data: dict) -> "NetworkCollection":
"""Deserialize the networks from a dict.

Parameters
----------
data : dict
The dict to deserialize from.
"""
return cls(**data)
27 changes: 15 additions & 12 deletions deepmd_utils/model_format/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from .network import (
EmbeddingNet,
NetworkCollection,
)


Expand Down Expand Up @@ -154,16 +155,18 @@ def __init__(
self.spin = spin

in_dim = 1 # not considiering type embedding
self.embeddings = []
self.embeddings = NetworkCollection(
ntypes=self.ntypes,
ndim=(1 if self.type_one_side else 2),
network_type="embedding_network",
)
for ii in range(self.ntypes):
self.embeddings.append(
EmbeddingNet(
in_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)
self.embeddings[(ii,)] = EmbeddingNet(
in_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)
self.env_mat = EnvMat(self.rcut, self.rcut_smth)
self.nnei = np.sum(self.sel)
Expand Down Expand Up @@ -196,7 +199,7 @@ def cal_g(
nf, nloc, nnei = ss.shape[0:3]
ss = ss.reshape(nf, nloc, nnei, 1)
# nf x nloc x nnei x ng
gg = self.embeddings[ll].call(ss)
gg = self.embeddings[(ll,)].call(ss)
return gg

def call(
Expand Down Expand Up @@ -258,7 +261,7 @@ def serialize(self) -> dict:
"precision": self.precision,
"spin": self.spin,
"env_mat": self.env_mat.serialize(),
"embeddings": [ii.serialize() for ii in self.embeddings],
"embeddings": self.embeddings.serialize(),
"@variables": {
"davg": self.davg,
"dstd": self.dstd,
Expand All @@ -274,6 +277,6 @@ def deserialize(cls, data: dict) -> "DescrptSeA":

obj["davg"] = variables["davg"]
obj["dstd"] = variables["dstd"]
obj.embeddings = [EmbeddingNet.deserialize(dd) for dd in embeddings]
obj.embeddings = NetworkCollection.deserialize(embeddings)
obj.env_mat = EnvMat.deserialize(env_mat)
return obj
65 changes: 65 additions & 0 deletions source/tests/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
EnvMat,
NativeLayer,
NativeNet,
NetworkCollection,
load_dp_model,
save_dp_model,
)
Expand Down Expand Up @@ -115,6 +116,70 @@ def test_embedding_net(self):
np.testing.assert_allclose(en0.call(inp), en1.call(inp))


class TestNetworkCollection(unittest.TestCase):
def setUp(self) -> None:
w = np.full((2, 3), 3.0)
b = np.full((3,), 4.0)
self.network = {
"layers": [
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": w, "b": b},
},
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": w, "b": b},
},
],
}

def test_two_dim(self):
networks = NetworkCollection(ndim=2, ntypes=2)
networks[(0, 0)] = self.network
networks[(1, 1)] = self.network
networks[(0, 1)] = self.network
njzjz marked this conversation as resolved.
Show resolved Hide resolved
with self.assertRaises(RuntimeError):
networks.check_completeness()
networks[(1, 0)] = self.network
networks.check_completeness()
np.testing.assert_equal(
networks.serialize(),
NetworkCollection.deserialize(networks.serialize()).serialize(),
)
np.testing.assert_equal(
networks[(0, 0)].serialize(), networks.serialize()["networks"][0]
)

def test_one_dim(self):
networks = NetworkCollection(ndim=1, ntypes=2)
networks[(0,)] = self.network
with self.assertRaises(RuntimeError):
networks.check_completeness()
networks[(1,)] = self.network
networks.check_completeness()
np.testing.assert_equal(
networks.serialize(),
NetworkCollection.deserialize(networks.serialize()).serialize(),
)
np.testing.assert_equal(
networks[(0,)].serialize(), networks.serialize()["networks"][0]
)

def test_zero_dim(self):
networks = NetworkCollection(ndim=0, ntypes=2)
networks[()] = self.network
networks.check_completeness()
np.testing.assert_equal(
networks.serialize(),
NetworkCollection.deserialize(networks.serialize()).serialize(),
)
np.testing.assert_equal(
networks[()].serialize(), networks.serialize()["networks"][0]
)


class TestDPModel(unittest.TestCase):
def setUp(self) -> None:
self.w = np.full((3, 2), 3.0)
Expand Down
Loading