Skip to content

Commit

Permalink
add check_completeness
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 10, 2024
1 parent a209d1a commit 811447d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
16 changes: 16 additions & 0 deletions deepmd_utils/model_format/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
See issue #2982 for more information.
"""
import itertools
import json
from typing import (
Dict,
Expand Down Expand Up @@ -438,10 +439,12 @@ class NetworkCollection:
def __init__(
self,
ndim: int,
ntypes: int,
network_type: str = "network",
networks: Dict[Union[str, tuple], Union[NativeNet, dict]] = {},
):
self.ndim = ndim
self.ntypes = ntypes
if network_type == "network":
self.network_type = NativeNet
elif network_type == "embedding_network":
Expand All @@ -452,6 +455,18 @@ def __init__(
for kk, vv in networks.items():
self[kk] = vv

def check_completeness(self):
"""Check whether the collection is complete.
Raises
------
RuntimeError
If the collection is incomplete.
"""
for tt in itertools.product(range(self.ntypes), repeat=self.ndim):
if tuple(tt) not in self._networks:
raise RuntimeError(f"network for {tt} not found")

def _convert_key(self, key):
if isinstance(key, tuple):
pass
Expand Down Expand Up @@ -491,6 +506,7 @@ def serialize(self) -> dict:
raise NotImplementedError(self.network_type)

Check warning on line 506 in deepmd_utils/model_format/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/network.py#L506

Added line #L506 was not covered by tests
return {
"ndim": self.ndim,
"ntypes": self.ntypes,
"network_type": network_type_name,
"networks": {
("_".join(["type"] + [str(tt) for tt in key])): value.serialize()
Expand Down
15 changes: 12 additions & 3 deletions source/tests/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,14 @@ def setUp(self) -> None:
}

def test_two_dim(self):
networks = NetworkCollection(ndim=2)
networks = NetworkCollection(ndim=2, ntypes=2)
networks[(0, 0)] = self.network
networks[(1, 1)] = self.network
networks[(0, 1)] = self.network
with self.assertRaises(RuntimeError):
networks.check_completeness()
networks[(1, 0)] = self.network
networks.check_completeness()
np.testing.assert_equal(
networks.serialize(),
NetworkCollection.deserialize(networks.serialize()).serialize(),
Expand All @@ -148,9 +153,12 @@ def test_two_dim(self):
)

def test_one_dim(self):
networks = NetworkCollection(ndim=1)
networks = NetworkCollection(ndim=1, ntypes=2)
networks[(0,)] = self.network
with self.assertRaises(RuntimeError):
networks.check_completeness()
networks[(1,)] = self.network
networks.check_completeness()
np.testing.assert_equal(
networks.serialize(),
NetworkCollection.deserialize(networks.serialize()).serialize(),
Expand All @@ -160,8 +168,9 @@ def test_one_dim(self):
)

def test_zero_dim(self):
networks = NetworkCollection(ndim=0)
networks = NetworkCollection(ndim=0, ntypes=2)
networks[()] = self.network
networks.check_completeness()
np.testing.assert_equal(
networks.serialize(),
NetworkCollection.deserialize(networks.serialize()).serialize(),
Expand Down

0 comments on commit 811447d

Please sign in to comment.