From 7f6104801fb920a0c7dbeba2583d5e91fb4dd927 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 26 Jul 2024 04:37:43 -0400 Subject: [PATCH] feat: plain text model format (#4025) Propose a plain text model format based on YAML, which humans can easily read and might be easier to track changes in the git repository (which is good for #2103). Example: [deeppot_dpa_sel.yaml](https://github.com/user-attachments/files/16384230/deeppot_dpa_sel.yaml.txt) ## Summary by CodeRabbit - **New Features** - Added support for additional file formats (.yaml and .yml) for model saving and loading. - Enhanced the ability to serialize and deserialize model data in multiple formats. - **Bug Fixes** - Improved error handling for unsupported file formats during model loading. - **Documentation** - Updated documentation to reflect new supported file formats and clarify backend capabilities. - **Tests** - Introduced new test cases to ensure functionality for saving and loading models in YAML format. Signed-off-by: Jinzhe Zeng --- deepmd/backend/dpmodel.py | 2 +- deepmd/dpmodel/utils/serialization.py | 84 +++++++++++++++++---- doc/backend.md | 8 +- source/tests/common/dpmodel/test_network.py | 10 +++ 4 files changed, 84 insertions(+), 20 deletions(-) diff --git a/deepmd/backend/dpmodel.py b/deepmd/backend/dpmodel.py index 30591fb51a..c51d097d5a 100644 --- a/deepmd/backend/dpmodel.py +++ b/deepmd/backend/dpmodel.py @@ -37,7 +37,7 @@ class DPModelBackend(Backend): Backend.Feature.DEEP_EVAL | Backend.Feature.NEIGHBOR_STAT | Backend.Feature.IO ) """The features of the backend.""" - suffixes: ClassVar[List[str]] = [".dp"] + suffixes: ClassVar[List[str]] = [".dp", ".yaml", ".yml"] """The suffixes of the backend.""" def is_available(self) -> bool: diff --git a/deepmd/dpmodel/utils/serialization.py b/deepmd/dpmodel/utils/serialization.py index a69170e51d..6529bac692 100644 --- a/deepmd/dpmodel/utils/serialization.py +++ b/deepmd/dpmodel/utils/serialization.py @@ -3,11 +3,16 @@ from datetime import ( datetime, ) +from pathlib import ( + Path, +) from typing import ( Callable, ) import h5py +import numpy as np +import yaml try: from deepmd._version import version as __version__ @@ -33,6 +38,8 @@ def traverse_model_dict(model_obj, callback: Callable, is_variable: bool = False The model object after traversing. """ if isinstance(model_obj, dict): + if model_obj.get("@is_variable", False): + return callback(model_obj) for kk, vv in model_obj.items(): model_obj[kk] = traverse_model_dict( vv, callback, is_variable=is_variable or kk == "@variables" @@ -78,22 +85,48 @@ def save_dp_model(filename: str, model_dict: dict) -> None: The model dict to save. """ model_dict = model_dict.copy() - variable_counter = Counter() - with h5py.File(filename, "w") as f: + filename_extension = Path(filename).suffix + extra_dict = { + "software": "deepmd-kit", + "version": __version__, + # use UTC+0 time + "time": str(datetime.utcnow()), + } + if filename_extension == ".dp": + variable_counter = Counter() + with h5py.File(filename, "w") as f: + model_dict = traverse_model_dict( + model_dict, + lambda x: f.create_dataset( + f"variable_{variable_counter():04d}", data=x + ).name, + ) + save_dict = { + **extra_dict, + **model_dict, + } + f.attrs["json"] = json.dumps(save_dict, separators=(",", ":")) + elif filename_extension in {".yaml", ".yml"}: model_dict = traverse_model_dict( model_dict, - lambda x: f.create_dataset( - f"variable_{variable_counter():04d}", data=x - ).name, + lambda x: { + "@class": "np.ndarray", + "@is_variable": True, + "@version": 1, + "dtype": x.dtype.name, + "value": x.tolist(), + }, ) - save_dict = { - "software": "deepmd-kit", - "version": __version__, - # use UTC+0 time - "time": str(datetime.utcnow()), - **model_dict, - } - f.attrs["json"] = json.dumps(save_dict, separators=(",", ":")) + with open(filename, "w") as f: + yaml.safe_dump( + { + **extra_dict, + **model_dict, + }, + f, + ) + else: + raise ValueError(f"Unknown filename extension: {filename_extension}") def load_dp_model(filename: str) -> dict: @@ -109,7 +142,26 @@ def load_dp_model(filename: str) -> dict: dict The loaded model dict, including meta information. """ - with h5py.File(filename, "r") as f: - model_dict = json.loads(f.attrs["json"]) - model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy()) + filename_extension = Path(filename).suffix + if filename_extension == ".dp": + with h5py.File(filename, "r") as f: + model_dict = json.loads(f.attrs["json"]) + model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy()) + elif filename_extension in {".yaml", ".yml"}: + + def convert_numpy_ndarray(x): + if isinstance(x, dict) and x.get("@class") == "np.ndarray": + dtype = np.dtype(x["dtype"]) + value = np.asarray(x["value"], dtype=dtype) + return value + return x + + with open(filename) as f: + model_dict = yaml.safe_load(f) + model_dict = traverse_model_dict( + model_dict, + convert_numpy_ndarray, + ) + else: + raise ValueError(f"Unknown filename extension: {filename_extension}") return model_dict diff --git a/doc/backend.md b/doc/backend.md index e164cd8405..8639396941 100644 --- a/doc/backend.md +++ b/doc/backend.md @@ -29,13 +29,15 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different This backend is only for development and should not take into production. ::: -- Model filename extension: `.dp` +- Model filename extension: `.dp`, `.yaml`, `.yml` DP is a reference backend for development, which uses pure [NumPy](https://numpy.org/) to implement models without using any heavy deep-learning frameworks. Due to the limitation of NumPy, it doesn't support gradient calculation and thus cannot be used for training. As a reference backend, it is not aimed at the best performance, but only the correct results. -The DP backend uses [HDF5](https://docs.h5py.org/) to store model serialization data, which is backend-independent. -Only Python inference interface can load this format. +The DP backend has two formats, both of which are backend-independent: +The `.dp` format uses [HDF5](https://docs.h5py.org/) to store model serialization data, which has good performance. +The `.yaml` or `.yml` use [YAML](https://yaml.org/) to save the data as plain texts, which is easy to read for human beings. +Only Python inference interface can load these formats. NumPy 1.21 or above is required. diff --git a/source/tests/common/dpmodel/test_network.py b/source/tests/common/dpmodel/test_network.py index 047eee501c..381c542272 100644 --- a/source/tests/common/dpmodel/test_network.py +++ b/source/tests/common/dpmodel/test_network.py @@ -283,6 +283,7 @@ def setUp(self) -> None: ], } self.filename = "test_dp_dpmodel.dp" + self.filename_yaml = "test_dp_dpmodel.yaml" def test_save_load_model(self): save_dp_model(self.filename, {"model": deepcopy(self.model_dict)}) @@ -291,6 +292,15 @@ def test_save_load_model(self): assert "software" in model assert "version" in model + def test_save_load_model_yaml(self): + save_dp_model(self.filename_yaml, {"model": deepcopy(self.model_dict)}) + model = load_dp_model(self.filename_yaml) + np.testing.assert_equal(model["model"], self.model_dict) + assert "software" in model + assert "version" in model + def tearDown(self) -> None: if os.path.exists(self.filename): os.remove(self.filename) + if os.path.exists(self.filename_yaml): + os.remove(self.filename_yaml)