From db22812de47f4300ccb2a3a6e14e40334bad9f63 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 5 Jan 2024 00:16:09 -0500 Subject: [PATCH] add activation_function and resnet arguments and NumPy implementation to NativeLayer (#3109) Signed-off-by: Jinzhe Zeng --- deepmd_utils/model_format.py | 94 +++++++++++++++++++++++-- source/tests/test_model_format_utils.py | 52 ++++++++++---- 2 files changed, 129 insertions(+), 17 deletions(-) diff --git a/deepmd_utils/model_format.py b/deepmd_utils/model_format.py index 68a6d4045b..0b67131c4d 100644 --- a/deepmd_utils/model_format.py +++ b/deepmd_utils/model_format.py @@ -4,6 +4,9 @@ See issue #2982 for more information. """ import json +from abc import ( + ABC, +) from typing import ( List, Optional, @@ -121,7 +124,15 @@ def load_dp_model(filename: str) -> dict: return model_dict -class NativeLayer: +class NativeOP(ABC): + """The unit operation of a native model.""" + + def call(self, *args, **kwargs): + """Forward pass in NumPy implementation.""" + raise NotImplementedError + + +class NativeLayer(NativeOP): """Native representation of a layer. Parameters @@ -132,6 +143,10 @@ class NativeLayer: The biases of the layer. idt : np.ndarray, optional The identity matrix of the layer. + activation_function : str, optional + The activation function of the layer. + resnet : bool, optional + Whether the layer is a residual layer. """ def __init__( @@ -139,10 +154,14 @@ def __init__( w: Optional[np.ndarray] = None, b: Optional[np.ndarray] = None, idt: Optional[np.ndarray] = None, + activation_function: Optional[str] = None, + resnet: bool = False, ) -> None: self.w = w self.b = b self.idt = idt + self.activation_function = activation_function + self.resnet = resnet def serialize(self) -> dict: """Serialize the layer to a dict. @@ -158,7 +177,11 @@ def serialize(self) -> dict: } if self.idt is not None: data["idt"] = self.idt - return data + return { + "activation_function": self.activation_function, + "resnet": self.resnet, + "@variables": data, + } @classmethod def deserialize(cls, data: dict) -> "NativeLayer": @@ -169,7 +192,13 @@ def deserialize(cls, data: dict) -> "NativeLayer": data : dict The dict to deserialize from. """ - return cls(data["w"], data["b"], data.get("idt", None)) + return cls( + w=data["@variables"]["w"], + b=data["@variables"]["b"], + idt=data.get("idt", None), + activation_function=data["activation_function"], + resnet=data.get("resnet", False), + ) def __setitem__(self, key, value): if key in ("w", "matrix"): @@ -178,6 +207,10 @@ def __setitem__(self, key, value): self.b = value elif key == "idt": self.idt = value + elif key == "activation_function": + self.activation_function = value + elif key == "resnet": + self.resnet = value else: raise KeyError(key) @@ -188,11 +221,47 @@ def __getitem__(self, key): return self.b elif key == "idt": return self.idt + elif key == "activation_function": + return self.activation_function + elif key == "resnet": + return self.resnet else: raise KeyError(key) + def call(self, x: np.ndarray) -> np.ndarray: + """Forward pass. + + Parameters + ---------- + x : np.ndarray + The input. + + Returns + ------- + np.ndarray + The output. + """ + if self.w is None or self.b is None or self.activation_function is None: + raise ValueError("w, b, and activation_function must be set") + if self.activation_function == "tanh": + fn = np.tanh + elif self.activation_function.lower() == "none": + + def fn(x): + return x + else: + raise NotImplementedError(self.activation_function) + y = fn(np.matmul(x, self.w) + self.b) + if self.idt is not None: + y *= self.idt + if self.resnet and self.w.shape[1] == self.w.shape[0]: + y += x + elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]: + y += np.concatenate([x, x], axis=1) + return y + -class NativeNet: +class NativeNet(NativeOP): """Native representation of a neural network. Parameters @@ -238,3 +307,20 @@ def __setitem__(self, key, value): if len(self.layers) <= key: self.layers.extend([NativeLayer()] * (key - len(self.layers) + 1)) self.layers[key] = value + + def call(self, x: np.ndarray) -> np.ndarray: + """Forward pass. + + Parameters + ---------- + x : np.ndarray + The input. + + Returns + ------- + np.ndarray + The output. + """ + for layer in self.layers: + x = layer.call(x) + return x diff --git a/source/tests/test_model_format_utils.py b/source/tests/test_model_format_utils.py index b959ace3f6..3b2aa5d8d4 100644 --- a/source/tests/test_model_format_utils.py +++ b/source/tests/test_model_format_utils.py @@ -25,25 +25,45 @@ def test_serialize(self): network[1]["b"] = self.b network[0]["w"] = self.w network[0]["b"] = self.b + network[1]["activation_function"] = "tanh" + network[0]["activation_function"] = "tanh" + network[1]["resnet"] = True + network[0]["resnet"] = True jdata = network.serialize() - np.testing.assert_array_equal(jdata["layers"][0]["w"], self.w) - np.testing.assert_array_equal(jdata["layers"][0]["b"], self.b) - np.testing.assert_array_equal(jdata["layers"][1]["w"], self.w) - np.testing.assert_array_equal(jdata["layers"][1]["b"], self.b) + np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["w"], self.w) + np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["b"], self.b) + np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["w"], self.w) + np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["b"], self.b) + np.testing.assert_array_equal(jdata["layers"][0]["activation_function"], "tanh") + np.testing.assert_array_equal(jdata["layers"][1]["activation_function"], "tanh") + np.testing.assert_array_equal(jdata["layers"][0]["resnet"], True) + np.testing.assert_array_equal(jdata["layers"][1]["resnet"], True) def test_deserialize(self): network = NativeNet.deserialize( { "layers": [ - {"w": self.w, "b": self.b}, - {"w": self.w, "b": self.b}, - ] + { + "activation_function": "tanh", + "resnet": True, + "@variables": {"w": self.w, "b": self.b}, + }, + { + "activation_function": "tanh", + "resnet": True, + "@variables": {"w": self.w, "b": self.b}, + }, + ], } ) np.testing.assert_array_equal(network[0]["w"], self.w) np.testing.assert_array_equal(network[0]["b"], self.b) np.testing.assert_array_equal(network[1]["w"], self.w) np.testing.assert_array_equal(network[1]["b"], self.b) + np.testing.assert_array_equal(network[0]["activation_function"], "tanh") + np.testing.assert_array_equal(network[1]["activation_function"], "tanh") + np.testing.assert_array_equal(network[0]["resnet"], True) + np.testing.assert_array_equal(network[1]["resnet"], True) class TestDPModel(unittest.TestCase): @@ -52,12 +72,18 @@ def setUp(self) -> None: self.b = np.full((3,), 4.0) self.model_dict = { "type": "some_type", - "@variables": { - "layers": [ - {"w": self.w, "b": self.b}, - {"w": self.w, "b": self.b}, - ] - }, + "layers": [ + { + "activation_function": "tanh", + "resnet": True, + "@variables": {"w": self.w, "b": self.b}, + }, + { + "activation_function": "tanh", + "resnet": True, + "@variables": {"w": self.w, "b": self.b}, + }, + ], } self.filename = "test_dp_model_format.dp"