diff --git a/deepmd/descriptor/descriptor.py b/deepmd/descriptor/descriptor.py index bd731004cb..cbcbe119bb 100644 --- a/deepmd/descriptor/descriptor.py +++ b/deepmd/descriptor/descriptor.py @@ -509,3 +509,31 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): # call subprocess cls = cls.get_class_by_input(local_jdata) return cls.update_sel(global_jdata, local_jdata) + + @classmethod + def deserialize(cls, data: dict): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + if cls is Descriptor: + return Descriptor.get_class_by_input(data).deserialize(data) + raise NotImplementedError("Not implemented in class %s" % cls.__name__) + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + raise NotImplementedError("Not implemented in class %s" % self.__name__) diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 2de0b63245..72295dd017 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -192,6 +192,7 @@ def __init__( self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron) self.trainable = trainable self.compress_activation_fn = get_activation_func(activation_function) + self.activation_function_name = activation_function self.filter_activation_fn = get_activation_func(activation_function) self.filter_precision = get_precision(precision) self.filter_np_precision = get_np_precision(precision) @@ -1334,3 +1335,57 @@ def explicit_ntypes(self) -> bool: if self.stripped_type_embedding: return True return False + + @classmethod + def deserialize(cls, data: dict): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + descriptor = cls(**data) + descriptor.davg = data["@variables"]["davg"] + descriptor.dstd = data["@variables"]["dstd"] + descriptor.embedding_net_variables = data["@variables"] + descriptor.original_sel = data["@variables"]["original_sel"] + return descriptor + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + return { + "type": "se_e2_a", + "rcut": self.rcut_r, + "rcut_smth": self.rcut_r_smth, + "sel": self.sel_a, + "neuron": self.filter_neuron, + "axis_neuron": self.n_axis_neuron, + "resnet_dt": self.filter_resnet_dt, + "trainable": self.trainable, + "seed": self.seed, + "type_one_side": self.type_one_side, + "exclude_types": list(self.exclude_types), + "set_davg_zero": self.set_davg_zero, + "activation_function": self.activation_function_name, + "precision": self.filter_precision.name, + "uniform_seed": self.uniform_seed, + "stripped_type_embedding": self.stripped_type_embedding, + "@variables": { + **self.embedding_net_variables, + "davg": self.davg, + "dstd": self.dstd, + "original_sel": self.original_sel, + }, + } diff --git a/deepmd/entrypoints/convert.py b/deepmd/entrypoints/convert.py index bea047ba72..e3f262cd84 100644 --- a/deepmd/entrypoints/convert.py +++ b/deepmd/entrypoints/convert.py @@ -9,6 +9,10 @@ convert_pbtxt_to_pb, convert_to_21, ) +from deepmd.utils.convert_dp import ( + convert_dp_to_pb, + convert_pb_to_dp, +) def convert( @@ -23,6 +27,11 @@ def convert( convert_pb_to_pbtxt(input_model, output_model) else: raise RuntimeError("input model is already pbtxt") + elif output_model.endswith(".dp"): + if input_model.endswith(".pb"): + convert_pb_to_dp(input_model, output_model) + else: + raise RuntimeError("Unsupported format") else: if FROM == "auto": convert_to_21(input_model, output_model) @@ -39,5 +48,7 @@ def convert( convert_20_to_21(input_model, output_model) elif FROM == "pbtxt": convert_pbtxt_to_pb(input_model, output_model) + elif FROM == "dp": + convert_dp_to_pb(input_model, output_model) else: raise RuntimeError("unsupported model version " + FROM) diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index e74d4a7e6d..9fc80bf2fd 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging from typing import ( + TYPE_CHECKING, List, Optional, ) @@ -53,6 +54,11 @@ Spin, ) +if TYPE_CHECKING: + from deepmd.descriptor import ( + Descriptor, + ) + log = logging.getLogger(__name__) @@ -130,7 +136,7 @@ class EnerFitting(Fitting): def __init__( self, - descrpt: tf.Tensor, + descrpt: "Descriptor", neuron: List[int] = [120, 120, 120], resnet_dt: bool = True, numb_fparam: int = 0, @@ -176,6 +182,7 @@ def __init__( self.ntypes_spin = self.spin.get_ntypes_spin() if self.spin is not None else 0 self.seed_shift = one_layer_rand_seed_shift() self.tot_ener_zero = tot_ener_zero + self.activation_function_name = activation_function self.fitting_activation_fn = get_activation_func(activation_function) self.fitting_precision = get_precision(precision) self.trainable = trainable @@ -916,3 +923,66 @@ def get_loss(self, loss: dict, lr) -> Loss: return EnerSpinLoss(**loss, use_spin=self.spin.use_spin) else: raise RuntimeError("unknown loss type") + + @classmethod + def deserialize(cls, data: dict): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + fitting = cls(**data) + fitting.fitting_net_variables = data["@variables"] + fitting.bias_atom_e = fitting.fitting_net_variables.pop("bias_atom_e") + if fitting.numb_fparam > 0: + fitting.fparam_avg = fitting.fitting_net_variables.pop("fparam_avg") + fitting.fparam_inv_std = fitting.fitting_net_variables.pop("fparam_inv_std") + if fitting.numb_aparam > 0: + fitting.aparam_avg = fitting.fitting_net_variables.pop("aparam_avg") + fitting.aparam_inv_std = fitting.fitting_net_variables.pop("aparam_inv_std") + return fitting + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + data = { + "type": "ener", + "neuron": self.n_neuron, + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "rcond": self.rcond, + "tot_ener_zero": self.tot_ener_zero, + "trainable": self.trainable, + "seed": self.seed, + "atom_ener": self.atom_ener, + "activation_function": self.activation_function_name, + "precision": self.fitting_precision.name, + "uniform_seed": self.uniform_seed, + "layer_name": self.layer_name, + "use_aparam_as_mask": self.use_aparam_as_mask, + "@variables": { + **self.fitting_net_variables, + "bias_atom_e": self.bias_atom_e, + }, + } + + if self.numb_fparam > 0: + data["@variables"]["fparam_avg"] = self.fparam_avg + data["@variables"]["fparam_inv_std"] = self.fparam_inv_std + if self.numb_aparam > 0: + data["@variables"]["aparam_avg"] = self.aparam_avg + data["@variables"]["aparam_inv_std"] = self.aparam_inv_std + return data diff --git a/deepmd/fit/fitting.py b/deepmd/fit/fitting.py index a467ec1201..55dcc791cb 100644 --- a/deepmd/fit/fitting.py +++ b/deepmd/fit/fitting.py @@ -43,16 +43,20 @@ class SomeFitting(Fitting): """ return Fitting.__plugins.register(key) + @classmethod + def get_class_by_input(cls, input: dict): + try: + fitting_type = input["type"] + except KeyError: + raise KeyError("the type of fitting should be set by `type`") + if fitting_type in Fitting.__plugins.plugins: + return Fitting.__plugins.plugins[fitting_type] + else: + raise RuntimeError("Unknown descriptor type: " + fitting_type) + def __new__(cls, *args, **kwargs): if cls is Fitting: - try: - fitting_type = kwargs["type"] - except KeyError: - raise KeyError("the type of fitting should be set by `type`") - if fitting_type in Fitting.__plugins.plugins: - cls = Fitting.__plugins.plugins[fitting_type] - else: - raise RuntimeError("Unknown descriptor type: " + fitting_type) + cls = Fitting.get_class_by_input(kwargs) return super().__new__(cls) @property @@ -102,3 +106,31 @@ def get_loss(self, loss: dict, lr) -> Loss: Loss the loss function """ + + @classmethod + def deserialize(cls, data: dict): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + if cls is Fitting: + return Fitting.get_class_by_input(data).deserialize(data) + raise NotImplementedError("Not implemented in class %s" % cls.__name__) + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + raise NotImplementedError("Not implemented in class %s" % self.__name__) diff --git a/deepmd/model/ener.py b/deepmd/model/ener.py index 1976c1ad51..c6f914d004 100644 --- a/deepmd/model/ener.py +++ b/deepmd/model/ener.py @@ -7,12 +7,18 @@ import numpy as np +from deepmd.descriptor.descriptor import ( + Descriptor, +) from deepmd.env import ( MODEL_VERSION, global_cvt_2_ener_float, op_module, tf, ) +from deepmd.fit.fitting import ( + Fitting, +) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -512,3 +518,59 @@ def change_energy_bias( bias_shift, self.data_bias_nsample, ) + + @classmethod + def deserialize(cls, data: dict): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + data = data.copy() + data["descriptor"] = Descriptor.deserialize(data["descriptor"]) + fitting_data = data["fitting_net"].copy() + fitting_data["descrpt"] = data["descriptor"] + data["fitting_net"] = Fitting.deserialize(fitting_data) + if data.get("type_embedding") is not None: + data["type_embedding"] = TypeEmbedNet.deserialize(data["type_embedding"]) + return cls(**data) + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + data = { + "type": "standard", + "descriptor": self.descrpt.serialize(), + "fitting_net": self.fitting.serialize(), + "type_embedding": self.typeebd.serialize() + if self.typeebd is not None + else None, + "type_map": self.get_type_map(), + "data_stat_nbatch": self.data_stat_nbatch, + "data_stat_protect": self.data_stat_protect, + "use_srtab": self.srtab_name, + "spin": self.spin.serialize() if self.spin is not None else None, + "data_bias_nsample": self.data_bias_nsample, + } + if self.srtab_name is not None: + data.merge( + { + "srtab_add_bias": self.srtab_add_bias, + "smin_alpha": self.smin_alpha, + "sw_rmin": self.sw_rmin, + "sw_rmax": self.sw_rmax, + } + ) + return data diff --git a/deepmd/model/model.py b/deepmd/model/model.py index 3f24e42aec..0a3825d47b 100644 --- a/deepmd/model/model.py +++ b/deepmd/model/model.py @@ -32,7 +32,11 @@ from deepmd.utils.data_system import ( DeepmdDataSystem, ) +from deepmd.utils.errors import ( + GraphWithoutTensorError, +) from deepmd.utils.graph import ( + get_tensor_by_name_from_graph, load_graph_def, ) from deepmd.utils.pair_tab import ( @@ -44,6 +48,10 @@ from deepmd.utils.type_embed import ( TypeEmbedNet, ) +from deepmd_utils.model_format import ( + load_dp_model, + save_dp_model, +) class Model(ABC): @@ -506,6 +514,78 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict: cls = cls.get_class_by_input(local_jdata) return cls.update_sel(global_jdata, local_jdata) + @classmethod + def deserialize(cls, data: dict): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + if cls is Model: + return Model.get_class_by_input(data).deserialize(data) + raise NotImplementedError("Not implemented in class %s" % cls.__name__) + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + raise NotImplementedError("Not implemented in class %s" % self.__name__) + + @classmethod + def load_model(cls, filename: str) -> "Model": + """Load the model from file. + + Parameters + ---------- + filename : str + The filename + + Returns + ------- + Model + The loaded model + """ + data = load_dp_model(filename=filename) + return cls.deserialize(data["model"]) + + def save_model(self, filename: str, graph: tf.Graph, graph_def: tf.GraphDef): + """Save the model to file. + + Parameters + ---------- + filename : str + The filename + """ + self.init_variables(graph=graph, graph_def=graph_def) + model_dict = self.serialize() + try: + t_min_nbor_dist = get_tensor_by_name_from_graph( + graph, "train_attr/min_nbor_dist" + ) + except GraphWithoutTensorError as e: + pass + else: + model_dict.setdefault("@variables", {}) + model_dict["@variables"]["min_nbor_dist"] = t_min_nbor_dist + save_dp_model( + filename=filename, + model_dict=model_dict, + extra_info={ + "module": __name__, + }, + ) + class StandardModel(Model): """Standard model, which must contain a descriptor and a fitting. @@ -522,7 +602,8 @@ class StandardModel(Model): The type map """ - def __new__(cls, *args, **kwargs): + @classmethod + def get_class_by_input(cls, input: dict): from .dos import ( DOSModel, ) @@ -534,20 +615,23 @@ def __new__(cls, *args, **kwargs): PolarModel, ) + fitting_type = input["fitting_net"]["type"] + # init model + # infer model type by fitting_type + if fitting_type == "ener": + return EnerModel + elif fitting_type == "dos": + return DOSModel + elif fitting_type == "dipole": + return DipoleModel + elif fitting_type == "polar": + return PolarModel + else: + raise RuntimeError("get unknown fitting type when building model") + + def __new__(cls, *args, **kwargs): if cls is StandardModel: - fitting_type = kwargs["fitting_net"]["type"] - # init model - # infer model type by fitting_type - if fitting_type == "ener": - cls = EnerModel - elif fitting_type == "dos": - cls = DOSModel - elif fitting_type == "dipole": - cls = DipoleModel - elif fitting_type == "polar": - cls = PolarModel - else: - raise RuntimeError("get unknown fitting type when building model") + cls = cls.get_class_by_input(kwargs) return cls.__new__(cls) return super().__new__(cls) @@ -665,3 +749,21 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): global_jdata, local_jdata["descriptor"] ) return local_jdata_cpy + + @classmethod + def deserialize(cls, data: dict): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + if cls is StandardModel: + return StandardModel.get_class_by_input(data).deserialize(data) + raise NotImplementedError("Not implemented in class %s" % cls.__name__) diff --git a/deepmd/utils/convert_dp.py b/deepmd/utils/convert_dp.py new file mode 100644 index 0000000000..694135c933 --- /dev/null +++ b/deepmd/utils/convert_dp.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import tempfile + +from deepmd.entrypoints.freeze import ( + freeze, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, + tf, +) +from deepmd.model.model import ( + Model, +) +from deepmd.utils.graph import ( + get_tensor_by_name_from_graph, + load_graph_def, +) +from deepmd.utils.sess import ( + run_sess, +) + + +def convert_pb_to_dp(input_model: str, output_model: str): + """Convert a frozen model to a native DP model. + + Parameters + ---------- + input_model : str + The input frozen model. + output_model : str + The output DP model. + """ + graph, graph_def = load_graph_def(input_model) + t_jdata = get_tensor_by_name_from_graph(graph, "train_attr/training_script") + jdata = json.loads(t_jdata) + model = Model(**jdata["model"]) + model.save_model(output_model, graph, graph_def) + + +def convert_dp_to_pb(input_model: str, output_model: str): + """Convert a native DP model to a frozen model. + + Parameters + ---------- + input_model : str + The input DP model. + output_model : str + The output frozen model. + """ + model = Model.load_model(input_model) + with tf.Session() as sess: + place_holders = {} + for ii in ["coord", "box"]: + place_holders[ii] = tf.placeholder( + GLOBAL_NP_FLOAT_PRECISION, [None, None], name="t_" + ii + ) + place_holders["type"] = tf.placeholder(tf.int32, [None], name="t_type") + place_holders["natoms_vec"] = tf.placeholder( + tf.int32, [model.get_ntypes() + 2], name="t_natoms" + ) + place_holders["default_mesh"] = tf.placeholder(tf.int32, [None], name="t_mesh") + # TODO: fparam, aparam + + model.build( + place_holders["coord"], + place_holders["type"], + place_holders["natoms_vec"], + place_holders["box"], + place_holders["default_mesh"], + place_holders, + reuse=False, + ) + init = tf.global_variables_initializer() + run_sess(sess, init) + saver = tf.train.Saver() + with tempfile.TemporaryDirectory() as nt: + saver.save( + sess, + os.path.join(nt, "model.ckpt"), + global_step=0, + ) + freeze(checkpoint_folder=nt, output=output_model, node_names=None) diff --git a/deepmd_utils/main.py b/deepmd_utils/main.py index 3dc54db052..c36b09e7e5 100644 --- a/deepmd_utils/main.py +++ b/deepmd_utils/main.py @@ -481,7 +481,7 @@ def main_parser() -> argparse.ArgumentParser: nargs="?", default="auto", type=str, - choices=["auto", "0.12", "1.0", "1.1", "1.2", "1.3", "2.0", "pbtxt"], + choices=["auto", "0.12", "1.0", "1.1", "1.2", "1.3", "2.0", "pbtxt", "dp"], help="The original model compatibility", ) parser_transform.add_argument( diff --git a/deepmd_utils/model_format.py b/deepmd_utils/model_format.py new file mode 100644 index 0000000000..11942eab63 --- /dev/null +++ b/deepmd_utils/model_format.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Native DP model format for multiple backends.""" +import json +from typing import ( + Optional, +) + +import h5py + +try: + from deepmd_utils._version import version as __version__ +except ImportError: + __version__ = "unknown" + + +def traverse_model_dict(model_dict: dict, callback: callable): + """Traverse a model dict and call callback on each variable. + + Parameters + ---------- + model_dict : dict + The model dict to traverse. + callback : callable + The callback function to call on each variable. + """ + for kk, vv in model_dict.items(): + if isinstance(vv, dict): + if kk == "@variables": + variable_dict = vv.copy() + for k2, v2 in variable_dict.items(): + variable_dict[k2] = callback(v2) + model_dict[kk] = variable_dict + else: + traverse_model_dict(vv, callback) + + +class Counter: + """A callable counter. + + Examples + -------- + >>> counter = Counter() + >>> counter() + 0 + >>> counter() + 1 + """ + + def __init__(self): + self.count = -1 + + def __call__(self): + self.count += 1 + return self.count + + +def save_dp_model(filename: str, model_dict: dict, extra_info: Optional[dict] = None): + """Save a DP model to a file in the native format. + + Parameters + ---------- + filename : str + The filename to save to. + model_dict : dict + The model dict to save. + extra_info : dict, optional + Extra meta information to save. + """ + model_dict = model_dict.copy() + variable_counter = Counter() + if extra_info is not None: + extra_info = extra_info.copy() + else: + extra_info = {} + with h5py.File(filename, "w") as f: + traverse_model_dict( + model_dict, + lambda x: f.create_dataset( + f"variable_{variable_counter():04d}", data=x + ).name, + ) + save_dict = { + "model": model_dict, + "software": "deepmd-kit", + "version": __version__, + **extra_info, + } + f.attrs["json"] = json.dumps(save_dict, separators=(",", ":")) + + +def load_dp_model(filename: str) -> dict: + """Load a DP model from a file in the native format. + + Parameters + ---------- + filename : str + The filename to load from. + + Returns + ------- + dict + The loaded model dict, including meta information. + """ + with h5py.File(filename, "r") as f: + model_dict = json.loads(f.attrs["json"]) + traverse_model_dict(model_dict, lambda x: f[x][()].copy()) + return model_dict