diff --git a/deepmd_utils/model_format/network.py b/deepmd_utils/model_format/network.py index c73e53f5cb..d9071784ca 100644 --- a/deepmd_utils/model_format/network.py +++ b/deepmd_utils/model_format/network.py @@ -3,6 +3,7 @@ See issue #2982 for more information. """ +import copy import itertools import json from typing import ( @@ -150,23 +151,26 @@ class NativeLayer(NativeOP): def __init__( self, - w: Optional[np.ndarray] = None, - b: Optional[np.ndarray] = None, - idt: Optional[np.ndarray] = None, + num_in, + num_out, + bias: bool = True, + use_timestep: bool = False, activation_function: Optional[str] = None, resnet: bool = False, precision: str = DEFAULT_PRECISION, ) -> None: prec = PRECISION_DICT[precision.lower()] self.precision = precision - self.w = w.astype(prec) if w is not None else None - self.b = b.astype(prec) if b is not None else None - self.idt = idt.astype(prec) if idt is not None else None + rng = np.random.default_rng() + self.w = rng.normal(size=(num_in, num_out)).astype(prec) + self.b = rng.normal(size=(num_out,)).astype(prec) if bias else None + self.idt = rng.normal(size=(num_out,)).astype(prec) if use_timestep else None self.activation_function = ( activation_function if activation_function is not None else "none" ) self.resnet = resnet self.check_type_consistency() + self.check_shape_consistency() def serialize(self) -> dict: """Serialize the layer to a dict. @@ -179,10 +183,11 @@ def serialize(self) -> dict: data = { "w": self.w, "b": self.b, + "idt": self.idt, } - if self.idt is not None: - data["idt"] = self.idt return { + "bias": self.b is not None, + "use_timestep": self.idt is not None, "activation_function": self.activation_function, "resnet": self.resnet, "precision": self.precision, @@ -198,15 +203,34 @@ def deserialize(cls, data: dict) -> "NativeLayer": data : dict The dict to deserialize from. 
""" - precision = data.get("precision", DEFAULT_PRECISION) - return cls( - w=data["@variables"]["w"], - b=data["@variables"].get("b", None), - idt=data["@variables"].get("idt", None), - activation_function=data["activation_function"], - resnet=data.get("resnet", False), - precision=precision, + data = copy.deepcopy(data) + variables = data.pop("@variables") + assert variables["w"] is not None and len(variables["w"].shape) == 2 + num_in, num_out = variables["w"].shape + obj = cls( + num_in, + num_out, + **data, ) + obj.w, obj.b, obj.idt = ( + variables["w"], + variables.get("b", None), + variables.get("idt", None), + ) + obj.check_shape_consistency() + return obj + + def check_shape_consistency(self): + if self.b is not None and self.w.shape[1] != self.b.shape[0]: + raise ValueError( + f"dim 1 of w {self.w.shape[1]} is not equal to shape " + f"of b {self.b.shape[0]}", + ) + if self.idt is not None and self.w.shape[1] != self.idt.shape[0]: + raise ValueError( + f"dim 1 of w {self.w.shape[1]} is not equal to shape " + f"of idt {self.idt.shape[0]}", + ) def check_type_consistency(self): precision = self.precision @@ -252,6 +276,14 @@ def __getitem__(self, key): else: raise KeyError(key) + @property + def dim_in(self) -> int: + return self.w.shape[0] + + @property + def dim_out(self) -> int: + return self.w.shape[1] + def call(self, x: np.ndarray) -> np.ndarray: """Forward pass. @@ -303,6 +335,7 @@ def __init__(self, layers: Optional[List[dict]] = None) -> None: if layers is None: layers = [] self.layers = [NativeLayer.deserialize(layer) for layer in layers] + self.check_shape_consistency() def serialize(self) -> dict: """Serialize the network to a dict. 
@@ -327,16 +360,21 @@ def deserialize(cls, data: dict) -> "NativeNet": def __getitem__(self, key): assert isinstance(key, int) - if len(self.layers) <= key: - self.layers.extend([NativeLayer()] * (key - len(self.layers) + 1)) return self.layers[key] def __setitem__(self, key, value): assert isinstance(key, int) - if len(self.layers) <= key: - self.layers.extend([NativeLayer()] * (key - len(self.layers) + 1)) self.layers[key] = value + def check_shape_consistency(self): + for ii in range(len(self.layers) - 1): + if self.layers[ii].dim_out != self.layers[ii + 1].dim_in: + raise ValueError( + f"the dim of layer {ii} output {self.layers[ii].dim_out} " + f"does not match the dim of layer {ii+1} " + f"input {self.layers[ii + 1].dim_in}", + ) + def call(self, x: np.ndarray) -> np.ndarray: """Forward pass. @@ -389,9 +427,10 @@ def __init__( i_ot = ii layers.append( NativeLayer( - rng.normal(size=(i_in, i_ot)), - b=rng.normal(size=(i_ot)), - idt=rng.normal(size=(i_ot)) if resnet_dt else None, + i_in, + i_ot, + bias=True, + use_timestep=resnet_dt, activation_function=activation_function, resnet=True, precision=precision, @@ -431,6 +470,7 @@ def deserialize(cls, data: dict) -> "EmbeddingNet": data : dict The dict to deserialize from. """ + data = copy.deepcopy(data) layers = data.pop("layers") obj = cls(**data) super(EmbeddingNet, obj).__init__(layers) @@ -481,9 +521,10 @@ def __init__( i_in, i_ot = neuron[-1], out_dim self.layers.append( NativeLayer( - rng.normal(size=(i_in, i_ot)), - b=rng.normal(size=(i_ot)) if bias_out else None, - idt=None, + i_in, + i_ot, + bias=bias_out, + use_timestep=False, activation_function=None, resnet=False, precision=precision, @@ -520,6 +561,7 @@ def deserialize(cls, data: dict) -> "FittingNet": data : dict The dict to deserialize from. 
""" + data = copy.deepcopy(data) layers = data.pop("layers") obj = cls(**data) NativeNet.__init__(obj, layers) diff --git a/deepmd_utils/model_format/se_e2_a.py b/deepmd_utils/model_format/se_e2_a.py index a34694a882..b9143ee360 100644 --- a/deepmd_utils/model_format/se_e2_a.py +++ b/deepmd_utils/model_format/se_e2_a.py @@ -6,6 +6,7 @@ except ImportError: __version__ = "unknown" +import copy from typing import ( Any, List, @@ -270,6 +271,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "DescrptSeA": + data = copy.deepcopy(data) variables = data.pop("@variables") embeddings = data.pop("embeddings") env_mat = data.pop("env_mat") diff --git a/source/tests/test_model_format_utils.py b/source/tests/test_model_format_utils.py index b08c3bcf52..f588647096 100644 --- a/source/tests/test_model_format_utils.py +++ b/source/tests/test_model_format_utils.py @@ -35,39 +35,74 @@ def test_serialize_deserize(self): [None, [4], [3, 2]], ["float32", "float64", "single", "double"], ): - ww = np.full((ni, no), 3.0) - bb = np.full((no,), 4.0) if bias else None - idt = np.full((no,), 5.0) if ut else None - nl0 = NativeLayer(ww, bb, idt, activation_function, resnet, prec) + nl0 = NativeLayer( + ni, + no, + bias=bias, + use_timestep=ut, + activation_function=activation_function, + resnet=resnet, + precision=prec, + ) nl1 = NativeLayer.deserialize(nl0.serialize()) - inp_shap = [ww.shape[0]] + inp_shap = [ni] if ashp is not None: inp_shap = ashp + inp_shap inp = np.arange(np.prod(inp_shap)).reshape(inp_shap) np.testing.assert_allclose(nl0.call(inp), nl1.call(inp)) + def test_shape_error(self): + self.w0 = np.full((2, 3), 3.0) + self.b0 = np.full((2,), 4.0) + self.b1 = np.full((3,), 4.0) + self.idt0 = np.full((2,), 4.0) + with self.assertRaises(ValueError) as context: + network = NativeLayer.deserialize( + { + "activation_function": "tanh", + "resnet": True, + "@variables": {"w": self.w0, "b": self.b0}, + } + ) + assert "not equal to shape of b" in 
str(context.exception) + with self.assertRaises(ValueError) as context: + network = NativeLayer.deserialize( + { + "activation_function": "tanh", + "resnet": True, + "@variables": {"w": self.w0, "b": self.b1, "idt": self.idt0}, + } + ) + assert "not equal to shape of idt" in str(context.exception) + class TestNativeNet(unittest.TestCase): def setUp(self) -> None: - self.w = np.full((2, 3), 3.0) - self.b = np.full((3,), 4.0) - self.idt = np.full((3,), 5.0) + self.w0 = np.full((2, 3), 3.0) + self.b0 = np.full((3,), 4.0) + self.w1 = np.full((3, 4), 3.0) + self.b1 = np.full((4,), 4.0) def test_serialize(self): - network = NativeNet() - network[1]["w"] = self.w - network[1]["b"] = self.b - network[0]["w"] = self.w - network[0]["b"] = self.b + network = NativeNet( + [ + NativeLayer(2, 3).serialize(), + NativeLayer(3, 4).serialize(), + ] + ) + network[1]["w"] = self.w1 + network[1]["b"] = self.b1 + network[0]["w"] = self.w0 + network[0]["b"] = self.b0 network[1]["activation_function"] = "tanh" network[0]["activation_function"] = "tanh" network[1]["resnet"] = True network[0]["resnet"] = True jdata = network.serialize() - np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["w"], self.w) - np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["b"], self.b) - np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["w"], self.w) - np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["b"], self.b) + np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["w"], self.w0) + np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["b"], self.b0) + np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["w"], self.w1) + np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["b"], self.b1) np.testing.assert_array_equal(jdata["layers"][0]["activation_function"], "tanh") np.testing.assert_array_equal(jdata["layers"][1]["activation_function"], "tanh") np.testing.assert_array_equal(jdata["layers"][0]["resnet"], True) @@ -80,25 
+115,45 @@ def test_deserialize(self): { "activation_function": "tanh", "resnet": True, - "@variables": {"w": self.w, "b": self.b}, + "@variables": {"w": self.w0, "b": self.b0}, }, { "activation_function": "tanh", "resnet": True, - "@variables": {"w": self.w, "b": self.b}, + "@variables": {"w": self.w1, "b": self.b1}, }, ], } ) - np.testing.assert_array_equal(network[0]["w"], self.w) - np.testing.assert_array_equal(network[0]["b"], self.b) - np.testing.assert_array_equal(network[1]["w"], self.w) - np.testing.assert_array_equal(network[1]["b"], self.b) + np.testing.assert_array_equal(network[0]["w"], self.w0) + np.testing.assert_array_equal(network[0]["b"], self.b0) + np.testing.assert_array_equal(network[1]["w"], self.w1) + np.testing.assert_array_equal(network[1]["b"], self.b1) np.testing.assert_array_equal(network[0]["activation_function"], "tanh") np.testing.assert_array_equal(network[1]["activation_function"], "tanh") np.testing.assert_array_equal(network[0]["resnet"], True) np.testing.assert_array_equal(network[1]["resnet"], True) + def test_shape_error(self): + with self.assertRaises(ValueError) as context: + network = NativeNet.deserialize( + { + "layers": [ + { + "activation_function": "tanh", + "resnet": True, + "@variables": {"w": self.w0, "b": self.b0}, + }, + { + "activation_function": "tanh", + "resnet": True, + "@variables": {"w": self.w0, "b": self.b0}, + }, + ], + } + ) + assert "does not match the dim of layer" in str(context.exception) + class TestEmbeddingNet(unittest.TestCase): def test_embedding_net(self): @@ -146,19 +201,21 @@ def test_fitting_net(self): class TestNetworkCollection(unittest.TestCase): def setUp(self) -> None: - w = np.full((2, 3), 3.0) - b = np.full((3,), 4.0) + w0 = np.full((2, 3), 3.0) + b0 = np.full((3,), 4.0) + w1 = np.full((3, 4), 3.0) + b1 = np.full((4,), 4.0) self.network = { "layers": [ { "activation_function": "tanh", "resnet": True, - "@variables": {"w": w, "b": b}, + "@variables": {"w": w0, "b": b0}, }, { 
"activation_function": "tanh", "resnet": True, - "@variables": {"w": w, "b": b}, + "@variables": {"w": w1, "b": b1}, }, ], }