diff --git a/deepmd_utils/model_format.py b/deepmd_utils/model_format.py new file mode 100644 index 0000000000..68a6d4045b --- /dev/null +++ b/deepmd_utils/model_format.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Native DP model format for multiple backends. + +See issue #2982 for more information. +""" +import json +from typing import ( + List, + Optional, +) + +import h5py +import numpy as np + +try: + from deepmd_utils._version import version as __version__ +except ImportError: + __version__ = "unknown" + + +def traverse_model_dict(model_obj, callback: callable, is_variable: bool = False): + """Traverse a model dict and call callback on each variable. + + Parameters + ---------- + model_obj : object + The model object to traverse. + callback : callable + The callback function to call on each variable. + is_variable : bool, optional + Whether the current node is a variable. + + Returns + ------- + object + The model object after traversing. + """ + if isinstance(model_obj, dict): + for kk, vv in model_obj.items(): + model_obj[kk] = traverse_model_dict( + vv, callback, is_variable=is_variable or kk == "@variables" + ) + elif isinstance(model_obj, list): + for ii, vv in enumerate(model_obj): + model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable) + elif is_variable: + model_obj = callback(model_obj) + return model_obj + + +class Counter: + """A callable counter. + + Examples + -------- + >>> counter = Counter() + >>> counter() + 0 + >>> counter() + 1 + """ + + def __init__(self): + self.count = -1 + + def __call__(self): + self.count += 1 + return self.count + + +def save_dp_model(filename: str, model_dict: dict, extra_info: Optional[dict] = None): + """Save a DP model to a file in the native format. + + Parameters + ---------- + filename : str + The filename to save to. + model_dict : dict + The model dict to save. + extra_info : dict, optional + Extra meta information to save. + """ + model_dict = model_dict.copy() + variable_counter = Counter() + if extra_info is not None: + extra_info = extra_info.copy() + else: + extra_info = {} + with h5py.File(filename, "w") as f: + model_dict = traverse_model_dict( + model_dict, + lambda x: f.create_dataset( + f"variable_{variable_counter():04d}", data=x + ).name, + ) + save_dict = { + "model": model_dict, + "software": "deepmd-kit", + "version": __version__, + **extra_info, + } + f.attrs["json"] = json.dumps(save_dict, separators=(",", ":")) + + +def load_dp_model(filename: str) -> dict: + """Load a DP model from a file in the native format. + + Parameters + ---------- + filename : str + The filename to load from. + + Returns + ------- + dict + The loaded model dict, including meta information. + """ + with h5py.File(filename, "r") as f: + model_dict = json.loads(f.attrs["json"]) + model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy()) + return model_dict + + +class NativeLayer: + """Native representation of a layer. + + Parameters + ---------- + w : np.ndarray, optional + The weights of the layer. + b : np.ndarray, optional + The biases of the layer. + idt : np.ndarray, optional + The identity matrix of the layer. + """ + + def __init__( + self, + w: Optional[np.ndarray] = None, + b: Optional[np.ndarray] = None, + idt: Optional[np.ndarray] = None, + ) -> None: + self.w = w + self.b = b + self.idt = idt + + def serialize(self) -> dict: + """Serialize the layer to a dict. + + Returns + ------- + dict + The serialized layer. + """ + data = { + "w": self.w, + "b": self.b, + } + if self.idt is not None: + data["idt"] = self.idt + return data + + @classmethod + def deserialize(cls, data: dict) -> "NativeLayer": + """Deserialize the layer from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + return cls(data["w"], data["b"], data.get("idt", None)) + + def __setitem__(self, key, value): + if key in ("w", "matrix"): + self.w = value + elif key in ("b", "bias"): + self.b = value + elif key == "idt": + self.idt = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("w", "matrix"): + return self.w + elif key in ("b", "bias"): + return self.b + elif key == "idt": + return self.idt + else: + raise KeyError(key) + + +class NativeNet: + """Native representation of a neural network. + + Parameters + ---------- + layers : list[NativeLayer], optional + The layers of the network. + """ + + def __init__(self, layers: Optional[List[NativeLayer]] = None) -> None: + if layers is None: + layers = [] + self.layers = layers + + def serialize(self) -> dict: + """Serialize the network to a dict. + + Returns + ------- + dict + The serialized network. + """ + return {"layers": [layer.serialize() for layer in self.layers]} + + @classmethod + def deserialize(cls, data: dict) -> "NativeNet": + """Deserialize the network from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + return cls([NativeLayer.deserialize(layer) for layer in data["layers"]]) + + def __getitem__(self, key): + assert isinstance(key, int) + if len(self.layers) <= key: + self.layers.extend([NativeLayer()] * (key - len(self.layers) + 1)) + return self.layers[key] + + def __setitem__(self, key, value): + assert isinstance(key, int) + if len(self.layers) <= key: + self.layers.extend([NativeLayer()] * (key - len(self.layers) + 1)) + self.layers[key] = value diff --git a/source/tests/test_model_format_utils.py b/source/tests/test_model_format_utils.py new file mode 100644 index 0000000000..b959ace3f6 --- /dev/null +++ b/source/tests/test_model_format_utils.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os +import unittest +from copy import ( + deepcopy, +) + +import numpy as np + +from deepmd_utils.model_format import ( + NativeNet, + load_dp_model, + save_dp_model, +) + + +class TestNativeNet(unittest.TestCase): + def setUp(self) -> None: + self.w = np.full((3, 2), 3.0) + self.b = np.full((3,), 4.0) + + def test_serialize(self): + network = NativeNet() + network[1]["w"] = self.w + network[1]["b"] = self.b + network[0]["w"] = self.w + network[0]["b"] = self.b + jdata = network.serialize() + np.testing.assert_array_equal(jdata["layers"][0]["w"], self.w) + np.testing.assert_array_equal(jdata["layers"][0]["b"], self.b) + np.testing.assert_array_equal(jdata["layers"][1]["w"], self.w) + np.testing.assert_array_equal(jdata["layers"][1]["b"], self.b) + + def test_deserialize(self): + network = NativeNet.deserialize( + { + "layers": [ + {"w": self.w, "b": self.b}, + {"w": self.w, "b": self.b}, + ] + } + ) + np.testing.assert_array_equal(network[0]["w"], self.w) + np.testing.assert_array_equal(network[0]["b"], self.b) + np.testing.assert_array_equal(network[1]["w"], self.w) + np.testing.assert_array_equal(network[1]["b"], self.b) + + +class TestDPModel(unittest.TestCase): + def setUp(self) -> None: + self.w = np.full((3, 2), 3.0) + self.b = np.full((3,), 4.0) + self.model_dict = { + "type": "some_type", + "@variables": { + "layers": [ + {"w": self.w, "b": self.b}, + {"w": self.w, "b": self.b}, + ] + }, + } + self.filename = "test_dp_model_format.dp" + + def test_save_load_model(self): + save_dp_model(self.filename, deepcopy(self.model_dict)) + model = load_dp_model(self.filename) + np.testing.assert_equal(model["model"], self.model_dict) + assert "software" in model + assert "version" in model + + def tearDown(self) -> None: + if os.path.exists(self.filename): + os.remove(self.filename)