From 6be7153e57af2cde553af2374766ffba0376df07 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 12 Jan 2024 17:08:38 +0800 Subject: [PATCH 1/2] support fitting net --- deepmd_utils/model_format/__init__.py | 2 + deepmd_utils/model_format/network.py | 116 +++++++++++++++++++++++- source/tests/test_model_format_utils.py | 28 ++++++ 3 files changed, 143 insertions(+), 3 deletions(-) diff --git a/deepmd_utils/model_format/__init__.py b/deepmd_utils/model_format/__init__.py index 356eaaf4fa..9b544aedf8 100644 --- a/deepmd_utils/model_format/__init__.py +++ b/deepmd_utils/model_format/__init__.py @@ -8,6 +8,7 @@ ) from .network import ( EmbeddingNet, + FittingNet, NativeLayer, NativeNet, NetworkCollection, @@ -31,6 +32,7 @@ "DescrptSeA", "EnvMat", "EmbeddingNet", + "FittingNet", "NativeLayer", "NativeNet", "NetworkCollection", diff --git a/deepmd_utils/model_format/network.py b/deepmd_utils/model_format/network.py index c587b08cf6..aec62d74af 100644 --- a/deepmd_utils/model_format/network.py +++ b/deepmd_utils/model_format/network.py @@ -162,7 +162,9 @@ def __init__( self.w = w.astype(prec) if w is not None else None self.b = b.astype(prec) if b is not None else None self.idt = idt.astype(prec) if idt is not None else None - self.activation_function = activation_function + self.activation_function = ( + activation_function if activation_function is not None else "none" + ) self.resnet = resnet self.check_type_consistency() @@ -354,6 +356,24 @@ def call(self, x: np.ndarray) -> np.ndarray: class EmbeddingNet(NativeNet): + """The embedding network. + + Parameters + ---------- + in_dim + Input dimension. + neuron + The number of neurons in each layer. The output dimension + is the same as the dimension of the last layer. + activation_function + The activation function. + resnet_dt + Use time step at the resnet architecture. + precision + Floating point precision for the model paramters. + + """ + def __init__( self, in_dim, @@ -370,8 +390,8 @@ def __init__( layers.append( NativeLayer( rng.normal(size=(i_in, i_ot)), - b=rng.normal(size=(ii)), - idt=rng.normal(size=(ii)) if resnet_dt else None, + b=rng.normal(size=(i_ot)), + idt=rng.normal(size=(i_ot)) if resnet_dt else None, activation_function=activation_function, resnet=True, precision=precision, @@ -417,6 +437,95 @@ def deserialize(cls, data: dict) -> "EmbeddingNet": return obj +class FittingNet(EmbeddingNet): + """The fitting network. It may be implemented as an embedding + net connected with a linear output layer. + + Parameters + ---------- + in_dim + Input dimension. + out_dim + Output dimension + neuron + The number of neurons in each hidden layer. + activation_function + The activation function. + resnet_dt + Use time step at the resnet architecture. + precision + Floating point precision for the model paramters. + bias_out + The last linear layer has bias. + + """ + + def __init__( + self, + in_dim, + out_dim, + neuron: List[int] = [24, 48, 96], + activation_function: str = "tanh", + resnet_dt: bool = False, + precision: str = DEFAULT_PRECISION, + bias_out: bool = True, + ): + super().__init__( + in_dim, + neuron=neuron, + activation_function=activation_function, + resnet_dt=resnet_dt, + precision=precision, + ) + rng = np.random.default_rng() + i_in, i_ot = neuron[-1], out_dim + self.layers.append( + NativeLayer( + rng.normal(size=(i_in, i_ot)), + b=rng.normal(size=(i_ot)) if bias_out else None, + idt=None, + activation_function=None, + resnet=False, + precision=precision, + ) + ) + self.out_dim = out_dim + self.bias_out = bias_out + + def serialize(self) -> dict: + """Serialize the network to a dict. + + Returns + ------- + dict + The serialized network. + """ + return { + "in_dim": self.in_dim, + "out_dim": self.out_dim, + "neuron": self.neuron.copy(), + "activation_function": self.activation_function, + "resnet_dt": self.resnet_dt, + "precision": self.precision, + "bias_out": self.bias_out, + "layers": [layer.serialize() for layer in self.layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "EmbeddingNet": + """Deserialize the network from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + layers = data.pop("layers") + obj = cls(**data) + NativeNet.__init__(obj, layers) + return obj + + class NetworkCollection: """A collection of networks for multiple elements. @@ -439,6 +548,7 @@ class NetworkCollection: NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = { "network": NativeNet, "embedding_network": EmbeddingNet, + "fitting_network": FittingNet, } def __init__( diff --git a/source/tests/test_model_format_utils.py b/source/tests/test_model_format_utils.py index 98138fee17..b08c3bcf52 100644 --- a/source/tests/test_model_format_utils.py +++ b/source/tests/test_model_format_utils.py @@ -12,6 +12,7 @@ DescrptSeA, EmbeddingNet, EnvMat, + FittingNet, NativeLayer, NativeNet, NetworkCollection, @@ -98,6 +99,8 @@ def test_deserialize(self): np.testing.assert_array_equal(network[0]["resnet"], True) np.testing.assert_array_equal(network[1]["resnet"], True) + +class TestEmbeddingNet(unittest.TestCase): def test_embedding_net(self): for ni, act, idt, prec in itertools.product( [1, 10], @@ -116,6 +119,31 @@ def test_embedding_net(self): np.testing.assert_allclose(en0.call(inp), en1.call(inp)) +class TestFittingNet(unittest.TestCase): + def test_fitting_net(self): + for ni, no, act, idt, prec, bo in itertools.product( + [1, 10], + [1, 7], + ["tanh", "none"], + [True, False], + ["double", "single"], + [True, False], + ): + en0 = FittingNet( + ni, + no, + activation_function=act, + precision=prec, + resnet_dt=idt, + bias_out=bo, + ) + en1 = FittingNet.deserialize(en0.serialize()) + inp = np.ones([ni]) + en0.call(inp) + en1.call(inp) + np.testing.assert_allclose(en0.call(inp), en1.call(inp)) + + class TestNetworkCollection(unittest.TestCase): def setUp(self) -> None: w = np.full((2, 3), 3.0) From 9f538e0919139606dcaf9ce340729e755c122575 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 12 Jan 2024 18:22:52 +0800 Subject: [PATCH 2/2] fix typo --- deepmd_utils/model_format/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd_utils/model_format/network.py b/deepmd_utils/model_format/network.py index aec62d74af..c73e53f5cb 100644 --- a/deepmd_utils/model_format/network.py +++ b/deepmd_utils/model_format/network.py @@ -512,7 +512,7 @@ def serialize(self) -> dict: } @classmethod - def deserialize(cls, data: dict) -> "EmbeddingNet": + def deserialize(cls, data: dict) -> "FittingNet": """Deserialize the network from a dict. Parameters